# Hogwild! Parallelism: Basic Example

This example demonstrates Hogwild! inference on a single problem with 2 workers and a minimal prompt, defined below. There are no few-shot examples or prompt insertions, and the cache layout is the simplest one possible: two contiguous workspaces. This notebook is intended as a playground; the other notebooks present more advanced prompting and cache layouts.
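
To make the "two contiguous workspaces" layout concrete, here is a toy sketch in which plain Python lists stand in for cache blocks. It mirrors the ordering used later in this notebook (common prompt, the other worker's workspace, separator, own workspace), but the names `common_prompt`, `separator`, and `worker_views` are illustrative assumptions and are not part of the `shared_cache` module.

```python
# Toy model of the layout: lists of strings stand in for KV cache blocks.
# Every name below is hypothetical and only illustrates the ordering.
common_prompt = ["<problem + collaboration rules>"]
separator = ["<the assistant will continue here>"]
alice = ["# Alice workspace"]
bob = ["# Bob workspace"]

def worker_views():
    """Each worker sees: common prompt, the OTHER workspace, separator, its OWN workspace."""
    return {
        "Alice": common_prompt + bob + separator + alice,
        "Bob": common_prompt + alice + separator + bob,
    }

# One parallel step: each worker appends a token to its own workspace only.
alice.append("a1")
bob.append("b1")

views = worker_views()
for name, view in views.items():
    print(name, "sees:", view)
```

Because both views are assembled from the same underlying lists, a token that Alice writes shows up in Bob's view on the next step, and vice versa, without any copying.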

In [None]:
import torch
import transformers
import shared_cache
from IPython.display import display, Markdown, clear_output

MODEL_NAME = "Qwen/QwQ-32B"  # for 48GB gpus, use "Qwen/QwQ-32B-AWQ" instead
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)

parallelism_prompt_common = """
I will collaborate on this problem with another assistant. We refer to each other as Alice and Bob. We are assistants.

We will reason together and try to collaborate. I will take into account what the other assistant is doing and try to help them.

We will write our solutions concurrently. I will write my own thoughts at the bottom, and see the other's thoughts above.

I will not repeat the other assistant's thoughts: I can already see them above.

The other assistant will continue writing their thoughts above while I am writing mine. They will add more text every time I check.

Since we both write our thoughts in parallel, I will initially see only partial (unfinished) thoughts of the other assistant.
I will use these partial thoughts to decide how best to help the other assistant without doing the same work twice.

When reasoning, we will give each other tasks to coordinate (e.g. if Alice writes: Bob, please do this, then Bob should take this into account).

Before doing anything, I will check the other assistant's workspace. If they have already done that or are currently doing it, I don't need to do that again. If so, I will stop (e.g. 'Wait, this is already done') and pivot to a different task.
""".strip()

worker_headers = ["\n\n# Alice workspace\n\n", "\n\n# Bob workspace\n\n"]
prompt_split = " <the assistant will continue here>\n\n"

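# token ids that are penalized during generation: "#", "</think>", and all special tokens
forbidden_token_ix = [tokenizer.vocab[x] for x in ("#", "</think>")]
for x in tokenizer.special_tokens_map.values():
    # a special-token entry is either a single string or a list of strings
    forbidden_token_ix.extend([tokenizer.vocab[x]] if isinstance(x, str) else map(tokenizer.vocab.get, x))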
tokenizer_kwargs = dict(add_special_tokens=False, return_tensors='pt', padding=True, padding_side='left')
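
The `forbidden_token_ix` list is used in the generation loop further down, where 100 is subtracted from the logits of those tokens before the greedy argmax. The sketch below reproduces that idea on a toy 4-token vocabulary with plain Python lists; `greedy_pick` and `toy_logits` are hypothetical names introduced only for illustration.

```python
def greedy_pick(logits, forbidden_ix, penalty=100.0):
    """Return the argmax index after subtracting `penalty` from forbidden entries."""
    adjusted = list(logits)            # copy, so the caller's list is not modified
    for i in forbidden_ix:
        adjusted[i] -= penalty
    return max(range(len(adjusted)), key=adjusted.__getitem__)

toy_logits = [1.5, 9.0, 3.2, 8.7]      # token 1 would win without the penalty

unmasked = greedy_pick(toy_logits, forbidden_ix=[])
masked = greedy_pick(toy_logits, forbidden_ix=[1])
print(unmasked, masked)
```

Note that a finite penalty is a soft constraint: a forbidden token can still win if its logit exceeds every other logit by more than the penalty.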

__Playground:__ you can define your own problem and see whether the workers collaborate. With this simple setup, they do not always collaborate well out of the box, but you can see how the prompt affects their behavior.

In [None]:
# raw string: in a regular string, "\b" in "\boxed" would be parsed as a backspace character
problem = r"""Calculate x - x^2 + x^3 for x = 5,6,7,8. Alice must return all 4 answers in \boxed{ }."""

prompt_full_input = tokenizer.apply_chat_template(
    [dict(role='user', content=problem)], tokenize=False, add_generation_prompt=True
) + "\n\n" + parallelism_prompt_common

worker_prompts = [
    f"""{worker_headers[0]}I am Alice. Let's solve this together, Bob. Here's how we should collaborate:""",
    f"""{worker_headers[1]}I am Bob. Let's solve this together, Alice."""
]

cache_input, cache_split, cache_w1, cache_w2 = (shared_cache.CacheBlock(config=model.config) for _ in range(4))
# one row per worker: common prompt, the other worker's workspace, separator, own workspace
cm = shared_cache.SharedCacheManager(cache_structure=[
    [cache_input, cache_w2, cache_split, cache_w1],
    [cache_input, cache_w1, cache_split, cache_w2],
], write_to=[cache_w1, cache_w2])

# pre-fill common parts
with torch.inference_mode():
    model(**tokenizer(prompt_full_input, **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_input);  # <-- write to common prompt
    model(**tokenizer(prompt_split, **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_split);   # <-- write to common separator

# generate tokens in parallel with each worker
next_inputs = tokenizer(worker_prompts, **tokenizer_kwargs).to(device)
tokens_by_worker = tokenizer(worker_prompts, add_special_tokens=False)["input_ids"]
for inference_step in range(1024):       # <-- change max tokens here
    with torch.inference_mode():
        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]
        logits[..., forbidden_token_ix] -= 100
        new_tokens = logits.argmax(-1)   # <-- greedy generation
        next_inputs = dict(input_ids=new_tokens.view(-1, 1))
    
    for worker_tokens, new_token in zip(tokens_by_worker, new_tokens.tolist()):
        worker_tokens.append(new_token)
    clear_output(True)
    display(Markdown("".join(tokenizer.decode(seq) for seq in tokens_by_worker)))
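
When experimenting with the default problem, it helps to have reference answers to compare against the workers' output. The short sketch below computes them directly; the helper name `f` is just an illustrative choice.

```python
def f(x):
    """The expression from the default problem: x - x^2 + x^3."""
    return x - x**2 + x**3

expected = {x: f(x) for x in (5, 6, 7, 8)}
print(expected)  # {5: 105, 6: 186, 7: 301, 8: 456}
```

For example, for x = 5 this is 5 - 25 + 125 = 105.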