# Hogwild! Parallelism: Fast Kernels

This example demonstrates Hogwild! inference on a single problem using 2 workers and fast custom kernels. Please ensure that you have already installed the `hogwild` module by navigating to the `inference_lib` folder and running:

```bash
pip install -e .  # requires the nvcc CUDA compiler on PATH; otherwise export CUDACXX=/TODO/path/to/nvcc first
```

Currently, the fast kernels only work with QwQ-32B and its quantized versions.


In [None]:
import sys;
import random
from copy import deepcopy
from typing import Dict, NamedTuple, Sequence, Optional

import numpy as np
import torch
import transformers
import torch
import transformers
from hogwild.generation import MathFormatting, get_logits_processor
from hogwild.attention import model_surgery, HogwildCache, merge_caches
from hogwild.formatting import FormattingBase, MathFormatting
from IPython.display import display, Markdown, clear_output

In [None]:
# Checkpoint to load: the AWQ-quantized build is meant for 48GB GPUs; on 80GB GPUs
# the unquantized QwQ checkpoint can be used instead.
MODEL_NAME = "Qwen/QwQ-32B-AWQ"
assert "QwQ" in MODEL_NAME, "Reference implementation only supports QwQ"

# Prefer CUDA when it is available and load tokenizer + weights straight onto that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    device_map=device,
)

# Settings read by the generation loop further down.
max_steps = 8196  # upper bound on the number of decoding iterations
print_every_steps = 1  # the progress display is redrawn every this many iterations
insert_s1_collab_message_every_tokens = 1024  # token count after which a collaboration message is inserted

In [None]:
# Task statement given to the workers; the leading newline of the literal is stripped.
problem = """
Calculate 3x +x^2 for x= 4, 9. Alice must return all answers in \\boxed{ }.""".strip()

print_every_steps = 1
# NOTE(review): the generation loop compares against
# `insert_s1_collab_message_every_tokens`, not this name, so this value is not
# read by the visible cells - confirm which knob is intended.
insert_s1_prompt_every_tokens = 256
tokens_since_last_wait = 0

workers = ["Alice", "Bob"]
# Prompts and optional few-shot examples; has options for different model types - see formatting.py
fmt = MathFormatting(tokenizer, workers)

# Opening line of each worker's first step (step index 1).
# NOTE(review): the visible generation cells read `fmt.worker_prompts` rather than
# this list - confirm whether it is still needed.
first_worker_prefix = fmt.get_step_prefix(workers[0], 1)
second_worker_prefix = fmt.get_step_prefix(workers[1], 1)
worker_prompts = [
    f"""{first_worker_prefix}Hi, I'm {workers[0]}. Here's how we can collaborate""",
    f"""{second_worker_prefix}Hi, I'm {workers[1]}.""",
]

In [None]:
logits_processor = get_logits_processor(model)

# Re-derive the device from the loaded weights instead of trusting the earlier guess.
device = next(model.parameters()).device

# Shared tokenizer options: PyTorch tensors, left padding, no special tokens added.
tokenizer_kwargs = {
    "return_tensors": "pt",
    "padding": True,
    "padding_side": "left",
    "add_special_tokens": False,
}

In [None]:
tokens_since_last_wait = 0

# One KV cache per segment of the context that the workers share.
cache_common = transformers.DynamicCache()
cache_current_step_header = transformers.DynamicCache()
cache_own_header = transformers.DynamicCache()
cache_w1 = transformers.DynamicCache()
cache_w2 = transformers.DynamicCache()


def _worker_view(own_cache, other_cache):
    """Build a fresh segment list for one worker.

    Order: common prompt, current-step header, the other worker's cache,
    own header, and finally the worker's own cache.
    """
    return [cache_common, cache_current_step_header, other_cache, cache_own_header, own_cache]


# Batched cache manager: row 0 is the first worker's view, row 1 the second worker's.
cm = HogwildCache(
    cache_structure=[
        _worker_view(cache_w1, cache_w2),
        _worker_view(cache_w2, cache_w1),
    ],
    write_to=[cache_w1, cache_w2],
    model=model,
)

# Single-row cache managers, one per worker index, each using that worker's view only.
w_prompt_caches = {
    0: HogwildCache(
        cache_structure=[_worker_view(cache_w1, cache_w2)],
        write_to=[cache_w1],
        model=model,
    ),
    1: HogwildCache(
        cache_structure=[_worker_view(cache_w2, cache_w1)],
        write_to=[cache_w2],
        model=model,
    ),
}


In [None]:
# Prepare the model for Hogwild! inference.
# NOTE(review): model_surgery is imported from hogwild.attention; presumably it
# patches the model in place for the shared-cache attention - confirm there.
model_surgery(model)
# Rebind `model` to the compiled wrapper; every later cell calls this object.
model = torch.compile(model)
# Put the model in evaluation mode (in a notebook, this last expression is also
# what the cell displays).
model.eval()

In [None]:
# Formatter used by the cells below (this rebinds `fmt`). Boxed answers are parsed
# by keeping only their digit characters and converting the result to int.
# NOTE(review): unlike the earlier MathFormatting(tokenizer, workers) call, no worker
# list is passed here, so whatever default the class defines applies - confirm intent.
fmt = MathFormatting(
    tokenizer,
    extract_result=lambda box: int("".join(ch for ch in box if ch.isdigit())),
)

# pre-fill common cache parts
with torch.inference_mode():
    prefill_plan = (
        (fmt.apply_chat_template(problem), cache_common),  # common prompt
        (fmt.current_step_header, cache_current_step_header),  # separator after history
        (fmt.current_worker_header, cache_own_header),  # separator between incomplete steps
    )
    for prefill_text, target_cache in prefill_plan:
        # Run each text through the model once with a one-segment cache so that its
        # keys/values end up in `target_cache`.
        prefill_inputs = tokenizer(prefill_text, **tokenizer_kwargs).to(device)
        model(
            **prefill_inputs,
            use_cache=True,
            past_key_values=HogwildCache([[target_cache]], model=model),
        )

In [None]:
# State for generating interdependent reasoning chains in parallel.
initial_prompts = list(fmt.worker_prompts)

# Every worker starts at step 1.
current_step_index_by_worker = [1, 1]
# Token ids of each worker's current (unfinished) step, as plain Python lists.
current_step_tokens_by_worker = tokenizer(initial_prompts, add_special_tokens=False)["input_ids"]
# Token ids of all finished steps, in the order they were completed.
history = []
# First batch fed to the model: the workers' opening prompts, left-padded.
next_inputs = tokenizer(initial_prompts, **tokenizer_kwargs).to(device)
# Every rendered progress snapshot is kept here.
output_parts_history = []

In [None]:
for inference_step in range(max_steps):
    # Batched forward pass over the shared cache; one new token is chosen per worker.
    with torch.inference_mode():
        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]
        logits = logits_processor(next_inputs["input_ids"], logits)
        if model.generation_config.do_sample:
            new_tokens = torch.multinomial(logits.softmax(dim=-1), 1).flatten()
        else:
            new_tokens = logits.argmax(-1)
        assert len(new_tokens) == len(fmt.workers)

    # By default each worker's next input is just the token it produced.
    sampled_ids = new_tokens.tolist()
    next_input_tokens = [[token_id] for token_id in sampled_ids]

    for worker_index, (worker_name, worker_tokens, new_token) in enumerate(
            zip(fmt.workers, current_step_tokens_by_worker, sampled_ids)):
        worker_tokens.append(new_token)
        if fmt.is_end_of_step(worker_tokens):
            # The step is complete: move it into the shared history and open the next one.
            current_step_index_by_worker[worker_index] += 1
            history.extend(worker_tokens)

            start_msg = fmt.get_step_prefix(worker_name, current_step_index_by_worker[worker_index])
            if tokens_since_last_wait > insert_s1_collab_message_every_tokens:
                # Enough tokens have passed: add the collaboration message and restart the count.
                start_msg += fmt.s1_collab_message
                tokens_since_last_wait = 0

            # Replace the list contents in place so current_step_tokens_by_worker
            # keeps pointing at the same list object.
            worker_tokens[:] = tokenizer.encode(start_msg, add_special_tokens=False)

            # Merge the worker's own cache into the common cache, then empty it.
            finished_step_cache = cm.cache_structure[worker_index][-1]
            merge_caches(cache_common, finished_step_cache, model.model)
            finished_step_cache.crop(0)

            # Feed the step-ending token followed by the new step's prefix.
            next_input_tokens[worker_index] = [new_token] + worker_tokens
        tokens_since_last_wait += len(next_input_tokens[worker_index])

    next_inputs = tokenizer.pad(
        {"input_ids": next_input_tokens}, padding_side="left", return_tensors="pt"
    ).to(device)

    if inference_step % print_every_steps == 0:
        # Render problem, finished history, and every worker's unfinished step.
        output_parts = [
            f"[**Problem:** {problem}]\n\n",
            fmt.history_header + fmt.sep + tokenizer.decode(history),
            fmt.current_step_header,
            *(
                tokenizer.decode(step_tokens) + fmt.incomplete_step + fmt.sep
                for step_tokens in current_step_tokens_by_worker
            ),
        ]
        output_parts_history.append(output_parts)
        clear_output(True)  # display current progress
        display(Markdown("".join(output_parts)))

    if (new_tokens == tokenizer.eos_token_id).any().item():
        break  # at least one worker generated the end-of-sequence token, finish early
