# Hogwild! Parallelism: Minimal Example (Colab Edition) [![Code](https://img.shields.io/badge/GitHub-%23121011.svg?logo=github&logoColor=white)](https://github.com/eqimp/hogwild_llm) [![arxiv](https://camo.githubusercontent.com/ce27cdf7b9627a67089c7bec66b101c90c6bbf21c2452a067a5bb5a4eac40d58/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f41725869762d5044462d726564)](https://arxiv.org/abs/2504.06261)

This notebook demonstrates Hogwild! inference on a single problem with 2 workers, **using a small model that fits on Colab's T4 GPU**. The smaller model can collaborate to some extent, but not as well as larger reasoning-tuned LLMs.

In [None]:
!git clone https://github.com/eqimp/hogwild_llm && cp -r hogwild_llm/* .
import torch
import transformers
import shared_cache
from utils import get_math_input_prompts, get_logits_processor
from IPython.display import display, Markdown, clear_output
# Load a smaller model that fits in Colab. If you have a larger GPU, load QwQ-32B or R1 for more reliable collaboration.
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)

Cloning into 'hogwild_llm'...
remote: Enumerating objects: 164, done.
remote: Counting objects: 100% (164/164), done.
remote: Compressing objects: 100% (121/121), done.
remote: Total 164 (delta 74), reused 119 (delta 41), pack-reused 0 (from 0)
Receiving objects: 100% (164/164), 1.75 MiB | 8.39 MiB/s, done.
Resolving deltas: 100% (74/74), done.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



__Playground:__ you can define a problem and see if the workers collaborate. As stated earlier, small models like this one often fail in silly ways when they try to interact, though they clearly *try* to collaborate. Use QwQ-32B or a similar model for better results.

In [None]:
problem = """Calculate x - x^2 + x * (1 - x) for x = 4,5,6,7.""".strip()

print_every_steps = 3
insert_s1_prompt_every_tokens = 512
tokens_since_last_wait = 0

workers = ["Alice", "Bob"]
Formatting = get_math_input_prompts(tokenizer, workers)  # <-- prompts are defined here
worker_prompts = [
    f"""{Formatting.get_step_prefix(workers[0], 1)}Hi, I'm {workers[0]}. Here's how we should collaborate:""",
    f"""{Formatting.get_step_prefix(workers[1], 1)}Hi, I'm {workers[1]}."""
]

# define cache structure for the combined layout
cache_common, cache_current_step_header, cache_separator, cache_w1, cache_w2 = (
    shared_cache.CacheBlock(config=model.config) for _ in range(5))
cm = shared_cache.SharedCacheManager(cache_structure=[
    [cache_common, cache_current_step_header, cache_w2, cache_separator, cache_w1],
    [cache_common, cache_current_step_header, cache_w1, cache_separator, cache_w2],
])

logits_processor = get_logits_processor(model, Formatting.forbidden_token_ix)
tokenizer_kwargs = dict(return_tensors='pt', padding=True, padding_side='left', add_special_tokens=False)

# initialize generation state for printing
history = []
current_step_index_by_worker = [1, 1]
current_step_tokens_by_worker = [tokenizer.encode(p, add_special_tokens=False) for p in worker_prompts]

# pre-fill common parts
with torch.inference_mode():
    model(**tokenizer([Formatting.get_full_prompt(problem)], **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_common);  # <-- write to common prompt
    model(**tokenizer(Formatting.current_step_header, **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_current_step_header);   # <-- write to current step header
    model(**tokenizer(Formatting.current_worker_header, **tokenizer_kwargs).to(device),
          use_cache=True, past_key_values=cache_separator);   # <-- write to separator between incomplete steps

next_inputs = tokenizer(worker_prompts, **tokenizer_kwargs).to(device)
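
The `cache_structure` above gives each worker its own view of the same five blocks: the shared prompt and history first, then the current-step header, the other worker's in-progress step, a separator, and finally the worker's own in-progress step. The sketch below mimics that layout with plain Python lists of strings instead of key-value tensors. It is only a toy illustration: `build_views`, `finish_step` and `render` are hypothetical helpers, not part of `shared_cache`, and it ignores everything the real cache does with positions and attention.

```python
# Toy model of the shared layout: each "block" is a list of text chunks standing in
# for cached keys/values. Hypothetical helpers, NOT the shared_cache API.

def build_views(common, header, separator, w1, w2):
    """Return one view per worker; a worker's own block always comes last."""
    return [
        [common, header, w2, separator, w1],  # view of worker 0 (owns w1)
        [common, header, w1, separator, w2],  # view of worker 1 (owns w2)
    ]

def finish_step(views, worker_index):
    """Move a worker's finished step into the common block, then clear it."""
    common, own = views[worker_index][0], views[worker_index][-1]
    common.extend(own)
    own.clear()

def render(view):
    """Concatenate all chunks of all blocks in the order the worker sees them."""
    return "".join(chunk for block in view for chunk in block)

common, header, separator = ["<prompt>"], ["<current>"], ["<sep>"]
w1, w2 = ["A1 "], ["B1 "]
views = build_views(common, header, separator, w1, w2)

alice_view = render(views[0])      # "<prompt><current>B1 <sep>A1 "
bob_view = render(views[1])        # "<prompt><current>A1 <sep>B1 "

finish_step(views, 0)              # worker 0 completes its step
bob_view_after = render(views[1])  # "<prompt>A1 <current><sep>B1 "
```

Because both views hold references to the same block objects, a step that one worker finishes shows up in the other worker's view without any copying. The generation loop relies on the same idea when it calls `cache_common.append_from(...)` and then clears the worker's own block.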

In [None]:
for inference_step in range(128):  # <-- modify the number of generation steps here
    # run model with shared cache
    with torch.inference_mode():
        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]
        logits = logits_processor(next_inputs['input_ids'], logits)
        if model.generation_config.do_sample:
            new_tokens = torch.multinomial(logits.softmax(dim=-1), 1).flatten()
        else:
            new_tokens = logits.argmax(-1)

    # process generated tokens for printing; handle step change, update next_inputs
    assert len(new_tokens) == len(Formatting.workers)
    next_input_tokens = new_tokens.unsqueeze(-1).tolist()
    for worker_index, (worker_name, worker_tokens, new_token) in enumerate(
            zip(Formatting.workers, current_step_tokens_by_worker, new_tokens.tolist())):
        worker_tokens.append(new_token)
        if Formatting.is_end_of_step(worker_tokens):
            # worker just finished their step - add it to common history and start a new step
            current_step_index_by_worker[worker_index] += 1
            history.extend(worker_tokens)
            worker_tokens.clear()
            start_msg = Formatting.get_step_prefix(worker_name, current_step_index_by_worker[worker_index])
            if tokens_since_last_wait > insert_s1_prompt_every_tokens:
                start_msg += Formatting.s1_collab_message   # <-- insert "Wait, am I doing redundant work?"
                tokens_since_last_wait = 0
            worker_tokens.extend(tokenizer.encode(start_msg, add_special_tokens=False))
            cache_common.append_from(cm.cache_structure[worker_index][-1])
            cm.cache_structure[worker_index][-1].clear()
            next_input_tokens[worker_index] = [new_token] + worker_tokens
        tokens_since_last_wait += len(next_input_tokens[worker_index])
    next_inputs = tokenizer.pad(dict(input_ids=next_input_tokens), padding_side='left', return_tensors='pt').to(device)

    if inference_step % print_every_steps == 0:
        clear_output(True)  # display current progress
        output_parts = [f"[**Problem:** {problem}]\n\n"]
        output_parts.append(Formatting.history_header + Formatting.SEP + tokenizer.decode(history))
        output_parts.append(Formatting.current_step_header)
        for worker_index, worker_tokens in enumerate(current_step_tokens_by_worker):
            output_parts.append(tokenizer.decode(worker_tokens) + Formatting.pivot_message + Formatting.SEP)
        display(Markdown(''.join(output_parts)))

[**Problem:** Calculate x - x^2 + x * (1 - x) for x = 4,5,6,7.]

### Past steps

**Alice [1]:** Hi, I'm Alice. Here's how we should collaborate: I'll calculate x - x² + x(1 - x) for x = 4, 5, 6, 7, and you can do the same for x = 3, 8, 9, 10. 

**Bob [1]:** Hi, I'm Bob. Let's start by calculating the expression for each x value: x - x² + x(1 - x) = x - x² + x - x². Simplifying, we get x - 2x². So we have two terms: x and -2x². 

**Alice [2]:** 4 - 2*4² = 4 - 32 = -28. 

**Bob [2]:** 5 - 2*5² = 5 - 50 = -45.

**Alice [3]:** 6 - 2*6² = 6 - 72 = -66.



### Work in progress (others)

**Alice [4]:** 8 - 2*<...>

**Bob [3]:** 7 - 2*7² = 7 - 98 =<...>



**Disclaimer: small models are poor collaborators** and may be incapable of the more complex interactions required for LIMO. They sometimes fail to keep their own promises, doing redundant work despite agreeing not to. We recommend using larger models such as [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) when possible; they work significantly better together.
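
When judging a transcript like the one above, it helps to have reference values for the playground problem. The following check is plain Python, independent of the notebook's code; the function names are chosen here for illustration. It evaluates the expression directly and confirms that it simplifies to 2x - 2x².

```python
def original(x):
    """The expression exactly as written in the problem statement."""
    return x - x**2 + x * (1 - x)

def simplified(x):
    """Algebraic simplification: x - x^2 + x - x^2 = 2x - 2x^2."""
    return 2 * x - 2 * x**2

reference = {x: original(x) for x in (4, 5, 6, 7)}

# The simplified form must agree with the original on every requested input.
assert all(simplified(x) == value for x, value in reference.items())

print(reference)  # {4: -24, 5: -40, 6: -60, 7: -84}
```

In the sample run shown earlier, the workers split the inputs between themselves sensibly, but the simplification to x - 2x² drops one x, so the values they report (-28, -45, -66) do not match the reference.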