# Hogwild! Parallelism: Minimal Example (Colab Edition) [![Code](https://img.shields.io/badge/GitHub-%23121011.svg?logo=github&logoColor=white)](https://github.com/eqimp/hogwild_llm) [![arxiv](https://camo.githubusercontent.com/ce27cdf7b9627a67089c7bec66b101c90c6bbf21c2452a067a5bb5a4eac40d58/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f41725869762d5044462d726564)](https://arxiv.org/abs/2504.06261)

This notebook demonstrates Hogwild! inference on a single problem with 2 workers, **using a small model that fits on Colab's T4 GPU**. The small model can collaborate to some extent, but not as well as larger reasoning-tuned LLMs.

In [1]:
!git clone https://github.com/eqimp/hogwild_llm && cp -r hogwild_llm/* .
%env HOGWILD_USE_TRITON=0
import torch
import transformers
import shared_cache
from generation import MathFormatting, get_logits_processor
from IPython.display import display, Markdown, clear_output
# load a small model that fits in Colab; if you have a larger GPU, load QwQ-32B or R1 for more reliable collaboration
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)
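The comment above suggests switching to QwQ-32B when a larger GPU is available. Below is a minimal sketch of making that choice automatically. It assumes 16-bit weights (about 2 bytes per parameter), roughly 32.5B parameters for QwQ-32B, and a rough 1.2x headroom factor for activations and the KV cache. `estimated_gib` and `pick_model_name` are hypothetical helpers, not part of the hogwild_llm repo, and the thresholds are estimates, not measured requirements.

```python
SMALL_MODEL = "unsloth/Llama-3.2-3B-Instruct"
LARGE_MODEL = "Qwen/QwQ-32B"


def estimated_gib(num_params_billion: float,
                  bytes_per_param: float = 2.0,
                  headroom: float = 1.2) -> float:
    """Rough GPU memory estimate in GiB: 16-bit weights times an assumed headroom factor."""
    return num_params_billion * 1e9 * bytes_per_param * headroom / 2**30


def pick_model_name(total_gpu_gib: float) -> str:
    """Hypothetical helper: use QwQ-32B only if its estimated footprint fits on the GPU."""
    if total_gpu_gib >= estimated_gib(32.5):  # ~72.6 GiB under the assumptions above
        return LARGE_MODEL
    return SMALL_MODEL
```

In Colab you could pass `torch.cuda.get_device_properties(0).total_memory / 2**30` as `total_gpu_gib`; a T4 (about 15 GiB) falls back to the small model.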

__Playground:__ you can define your own problem and see whether the workers collaborate. As noted above, small models like this one often fail in silly ways when they try to interact; use QwQ-32B or a similar model for better results. Still, they clearly *try* to collaborate.

In [3]:
problem = """Calculate x - x^2 + x * (1 - x) for x = 4,5,6,7.""".strip()

print_every_steps = 3
insert_s1_prompt_every_tokens = 512
tokens_since_last_wait = 0

workers = ["Alice", "Bob"]
fmt = MathFormatting(
    tokenizer, workers, pass_system_prompt_as_user_message=True,
)  # ^-- prompts; has options for different model types - see formatting.py
worker_prompts = [
    f"""{fmt.get_step_prefix(workers[0], 1)}Hi, I'm {workers[0]}. Here's how we should collaborate:""",
    f"""{fmt.get_step_prefix(workers[1], 1)}Hi, I'm {workers[1]}."""
]

# define cache structure for the combined layout
cache_common, cache_current_step_header, cache_separator, cache_w1, cache_w2 = (
    shared_cache.CacheBlock(config=model.config) for _ in range(5))
cm = shared_cache.SharedCacheManager(cache_structure=[
    [cache_common, cache_current_step_header, cache_w2, cache_separator, cache_w1],
    [cache_common, cache_current_step_header, cache_w1, cache_separator, cache_w2],
])

logits_processor = get_logits_processor(model)
tokenizer_kwargs = dict(return_tensors='pt', padding=True, padding_side='left', add_special_tokens=False)

# initialize generation state for printing
history = []
current_step_index_by_worker = [1, 1]
current_step_tokens_by_worker = [tokenizer.encode(p, add_special_tokens=False) for p in worker_prompts]


# pre-fill common parts
with torch.inference_mode():
    model(**tokenizer([fmt.apply_chat_template(problem)], **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_common);  # <-- write to common prompt
    model(**tokenizer(fmt.current_step_header, **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_current_step_header);   # <-- write to current step header
    model(**tokenizer(fmt.current_worker_header, **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_separator);   # <-- write to separator between incomplete steps
    
next_inputs = tokenizer(worker_prompts, **tokenizer_kwargs).to(device)
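The `cache_structure` passed to `SharedCacheManager` lists, for each worker, the order in which the shared blocks are concatenated into that worker's context. Both workers see the common prompt and the current step header first, then the *other* worker's in-progress step, then a separator, and finally their own step, so each worker's newest tokens sit at the end of its own view. The toy sketch below illustrates that layout with plain token lists instead of KV caches; `blocks`, `worker_view`, and `append_own_token` are illustrative names only, not the `shared_cache` API.

```python
from typing import Dict, List

# Toy stand-ins for the five cache blocks: each "block" is just a list of string tokens.
blocks: Dict[str, List[str]] = {
    "common": ["<problem>"],
    "step_header": ["<current-step>"],
    "separator": ["<sep>"],
    "w1": ["Alice:", "x=4"],
    "w2": ["Bob:", "x=5"],
}

# Same ordering as the notebook's cache_structure: other worker first, own block last.
layout = [
    ["common", "step_header", "w2", "separator", "w1"],  # Alice's view
    ["common", "step_header", "w1", "separator", "w2"],  # Bob's view
]


def worker_view(worker_index: int) -> List[str]:
    """Concatenate the blocks in the order this worker attends to them."""
    return [tok for name in layout[worker_index] for tok in blocks[name]]


def append_own_token(worker_index: int, token: str) -> None:
    """A new token goes into the worker's own (last) block; the block is shared, so both views see it."""
    blocks[layout[worker_index][-1]].append(token)
```

Because each block object appears in both layouts, a token Alice appends to `w1` immediately shows up in Bob's view, just before the separator.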

In [5]:
for inference_step in range(128):  # <-- modify the number of generation steps here
    # run model with shared cache
    with torch.inference_mode():
        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]
        logits = logits_processor(next_inputs['input_ids'], logits)
        if model.generation_config.do_sample:
            new_tokens = torch.multinomial(logits.softmax(dim=-1), 1).flatten()
        else:
            new_tokens = logits.argmax(-1)

    # process generated tokens for printing; handle step change, update next_inputs
    assert len(new_tokens) == len(fmt.workers)
    next_input_tokens = new_tokens.unsqueeze(-1).tolist()
    for worker_index, (worker_name, worker_tokens, new_token) in enumerate(
            zip(fmt.workers, current_step_tokens_by_worker, new_tokens.tolist())):
        worker_tokens.append(new_token)
        if fmt.is_end_of_step(worker_tokens):
            # worker just finished their step - add it to common history and start a new step
            current_step_index_by_worker[worker_index] += 1
            history.extend(worker_tokens)
            worker_tokens.clear()
            start_msg = fmt.get_step_prefix(worker_name, current_step_index_by_worker[worker_index])
            if tokens_since_last_wait > insert_s1_prompt_every_tokens:
                start_msg += fmt.s1_collab_message   # <-- insert "Wait, am I doing redundant work?"
                tokens_since_last_wait = 0
            worker_tokens.extend(tokenizer.encode(start_msg, add_special_tokens=False))
            cache_common.append_from(cm.cache_structure[worker_index][-1])
            cm.cache_structure[worker_index][-1].clear()
            next_input_tokens[worker_index] = [new_token] + worker_tokens
        tokens_since_last_wait += len(next_input_tokens[worker_index])
    next_inputs = tokenizer.pad(dict(input_ids=next_input_tokens), padding_side='left', return_tensors='pt').to(device)

    if inference_step % print_every_steps == 0:
        clear_output(True)  # display current progress
        output_parts = [f"[**Problem:** {problem}]\n\n"]
        output_parts.append(fmt.history_header + fmt.sep + tokenizer.decode(history))
        output_parts.append(fmt.current_step_header)
        for worker_index, worker_tokens in enumerate(current_step_tokens_by_worker):
            output_parts.append(tokenizer.decode(worker_tokens) + fmt.incomplete_step + fmt.sep)
        display(Markdown(''.join(output_parts)))

[**Problem:** Calculate x - x^2 + x * (1 - x) for x = 4,5,6,7.]

### Past steps

**Bob [1]:** Hi, I'm Bob. Let me see what we need to do.

**Alice [1]:** Hi, I'm Alice. Here's how we should collaborate: Let's split the values of x between us. Since there are four values (4,5,6,7), we can each do two. Let me take x=4 and x=6. You can do 5 and 7. That way we cover all without overlap.

**Bob [2]:**  Okay, Alice, that sounds good. Let me suggest we each take two of them. Since there are four numbers, splitting them into pairs would make sense. Maybe I can take 5 and 7, and you do 4 and 6. That way we can cross-verify each other's results too. 

**Alice [2]:**  Perfect, Bob! Let's do that. I'll start with x=4 first. Let me write down the expression: x - x² + x*(1 - x). Let's substitute x=4. 

### Work in progress (others)

**Alice [3]:**  For x=4: First calculate each term separately to avoid mistakes.<...>

**Bob [3]:**  Alright, I'll handle x=5 and x=7. Let me first compute for x=5. Plugging into the expression: 5 - 5² + 5*(1 - 5<...>
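To judge the workers' arithmetic, note that x * (1 - x) = x - x^2, so the whole expression simplifies to 2x - 2x^2 = 2x(1 - x). A quick reference check in plain Python, independent of the notebook's code:

```python
def expr(x: int) -> int:
    """The problem's expression, evaluated literally."""
    return x - x**2 + x * (1 - x)


def simplified(x: int) -> int:
    """Equivalent closed form: 2x(1 - x)."""
    return 2 * x * (1 - x)


reference = {x: expr(x) for x in (4, 5, 6, 7)}
assert all(expr(x) == simplified(x) for x in reference)
print(reference)  # {4: -24, 5: -40, 6: -60, 7: -84}
```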


**Disclaimer: small models are poor collaborators.** They may be incapable of the more complex interactions required for LIMO, and they sometimes fail to keep their own promises, doing redundant work despite agreeing not to. We recommend using larger models such as [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) when possible; they work significantly better together.
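The generation loop tries to counter redundant work in two ways. A finished step is moved from the worker's private block into the shared history, so the other worker can see it. Once more than `insert_s1_prompt_every_tokens` tokens have been fed since the last reminder, the next step prefix also gets `fmt.s1_collab_message` (the "Wait, am I doing redundant work?" prompt). Below is a simplified, torch-free sketch of that bookkeeping using plain lists. `StepBook`, `END_OF_STEP`, and the `**Name [n]:**` prefix format are hypothetical stand-ins for `fmt.is_end_of_step`, `fmt.get_step_prefix`, and real tokenizer tokens, and the prefix is counted as a single toy token.

```python
from dataclasses import dataclass, field
from typing import List

END_OF_STEP = "\n\n"  # hypothetical end-of-step token (stand-in for fmt.is_end_of_step)
S1_MESSAGE = "Wait, am I doing redundant work?"  # stand-in for fmt.s1_collab_message


@dataclass
class StepBook:
    workers: List[str]
    remind_every: int = 512  # plays the role of insert_s1_prompt_every_tokens
    history: List[str] = field(default_factory=list)
    current: List[List[str]] = field(default_factory=list)
    step_index: List[int] = field(default_factory=list)
    since_last_wait: int = 0

    def __post_init__(self) -> None:
        self.current = [[] for _ in self.workers]
        self.step_index = [1 for _ in self.workers]

    def add_token(self, worker: int, token: str) -> None:
        """Record one generated token; commit the step to shared history if it just ended."""
        tokens = self.current[worker]
        tokens.append(token)
        fed = 1  # tokens fed back to the model for this worker on the next step
        if token == END_OF_STEP:
            self.history.extend(tokens)  # finished step becomes visible to every worker
            tokens.clear()
            self.step_index[worker] += 1
            prefix = f"**{self.workers[worker]} [{self.step_index[worker]}]:** "
            if self.since_last_wait > self.remind_every:
                prefix += S1_MESSAGE  # periodic nudge to check for duplicated work
                self.since_last_wait = 0
            tokens.append(prefix)  # new step starts with its prefix
            fed += 1
        self.since_last_wait += fed
```

Like the notebook, this uses a single reminder counter shared by all workers, so whichever worker finishes a step after the threshold is crossed receives the nudge.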