# Hogwild! Parallelism: example with interleaved cache and full prompt

This is a more advanced version of `basic_example.ipynb` that features a combined layout: interleaved steps with instant (token-level) synchronization.

In [1]:
import torch
import transformers
import shared_cache
from utils import get_math_input_prompts, get_logits_processor
from IPython.display import clear_output, display, Markdown

# Checkpoint to load; the tokenizer and the model weights both come from this name.
MODEL_NAME = "Qwen/QwQ-32B"  # for 48gb gpu, use "Qwen/QwQ-32B-AWQ"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    device_map=device,
)



Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]

In [2]:
# --- problem statement and display / prompting settings ---
problem = """
Three vertices of a cube are $P=(7,12,10)$ , $Q=(8,8,1)$ , and $R=(11,3,9)$ . What is the surface area of the cube?
""".strip()

print_every_steps = 3                  # the generation loop re-renders the transcript every this many steps
insert_s1_prompt_every_tokens = 256    # once the counter below exceeds this, the next new step gets the s1 message
tokens_since_last_wait = 0             # running token counter maintained by the generation loop

# --- workers and their opening prompts ---
workers = ["Alice", "Bob"]
Formatting = get_math_input_prompts(tokenizer, workers)  # <-- prompts are defined here
first_step_prefixes = [Formatting.get_step_prefix(worker_name, 1) for worker_name in workers]
worker_prompts = [
    f"{first_step_prefixes[0]}Hi, I'm {workers[0]}. Here's how we should do this:",
    f"{first_step_prefixes[1]}Hi, I'm {workers[1]}.",
]

# --- cache structure for the combined layout ---
# Five separate cache blocks: the common prompt/history, the current-step header,
# the separator between incomplete steps, and one block per worker.
cache_blocks = [shared_cache.CacheBlock(config=model.config) for _ in range(5)]
cache_common, cache_current_step_header, cache_separator, cache_w1, cache_w2 = cache_blocks

# One row per worker (same order as `workers`); each row ends with that worker's own block,
# preceded by the other worker's block and the separator.
cache_view_worker_0 = [cache_common, cache_current_step_header, cache_w2, cache_separator, cache_w1]
cache_view_worker_1 = [cache_common, cache_current_step_header, cache_w1, cache_separator, cache_w2]
cm = shared_cache.SharedCacheManager(cache_structure=[cache_view_worker_0, cache_view_worker_1])

logits_processor = get_logits_processor(model, Formatting.forbidden_token_ix)
tokenizer_kwargs = {
    "return_tensors": "pt",
    "padding": True,
    "padding_side": "left",
    "add_special_tokens": False,
}

# --- generation state used for printing ---
history = []                              # tokens of all finished steps
current_step_index_by_worker = [1, 1]     # both workers start at step 1
current_step_tokens_by_worker = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in worker_prompts
]

# --- pre-fill the parts of the cache that are shared between workers ---
with torch.inference_mode():
    # full problem prompt -> common block
    common_prompt_inputs = tokenizer([Formatting.get_full_prompt(problem)], **tokenizer_kwargs).to(device)
    model(**common_prompt_inputs, use_cache=True, past_key_values=cache_common)

    # current-step header -> its own block
    step_header_inputs = tokenizer(Formatting.current_step_header, **tokenizer_kwargs).to(device)
    model(**step_header_inputs, use_cache=True, past_key_values=cache_current_step_header)

    # current-worker header -> separator block placed between incomplete steps
    worker_header_inputs = tokenizer(Formatting.current_worker_header, **tokenizer_kwargs).to(device)
    model(**worker_header_inputs, use_cache=True, past_key_values=cache_separator)

# first model inputs of the generation loop: each worker's opening prompt
next_inputs = tokenizer(worker_prompts, **tokenizer_kwargs).to(device)

In [None]:
# Main generation loop: at every step each worker produces one token; finished steps are
# moved into the common history, and the transcript is re-rendered periodically.
num_generation_steps = 1024  # <-- modify the number of generation steps here
for inference_step in range(num_generation_steps):
    # run model with shared cache
    with torch.inference_mode():
        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]
        logits = logits_processor(next_inputs['input_ids'], logits)
        if model.generation_config.do_sample:
            new_tokens = torch.multinomial(logits.softmax(dim=-1), 1).flatten()
        else:
            new_tokens = logits.argmax(-1)

    # process generated tokens for printing; handle step change, update next_inputs
    assert len(new_tokens) == len(Formatting.workers)
    next_input_tokens = new_tokens.unsqueeze(-1).tolist()  # by default each worker is fed its new token only
    for worker_index, (worker_name, worker_tokens, new_token) in enumerate(
            zip(Formatting.workers, current_step_tokens_by_worker, new_tokens.tolist())):
        worker_tokens.append(new_token)
        if Formatting.is_end_of_step(worker_tokens):
            # worker just finished their step - add it to common history and start a new step
            current_step_index_by_worker[worker_index] += 1
            history.extend(worker_tokens)
            worker_tokens.clear()
            start_msg = Formatting.get_step_prefix(worker_name, current_step_index_by_worker[worker_index])
            if tokens_since_last_wait > insert_s1_prompt_every_tokens:
                start_msg += Formatting.s1_collab_message   # <-- insert "Wait, am I doing redundant work?"
                tokens_since_last_wait = 0
            worker_tokens.extend(tokenizer.encode(start_msg, add_special_tokens=False))
            # append this worker's own cache block (last in its row) to the common block, then empty it
            cache_common.append_from(cm.cache_structure[worker_index][-1])
            cm.cache_structure[worker_index][-1].clear()
            # next input for this worker: the token that ended the step, followed by the new step prefix
            next_input_tokens[worker_index] = [new_token] + worker_tokens
        tokens_since_last_wait += len(next_input_tokens[worker_index])
    next_inputs = tokenizer.pad(dict(input_ids=next_input_tokens), padding_side='left', return_tensors='pt').to(device)

    # Render on the usual cadence, and always after the final step, so the displayed transcript
    # is up to date even when the step count is not aligned with print_every_steps.
    is_last_step = inference_step == num_generation_steps - 1
    if inference_step % print_every_steps == 0 or is_last_step:
        clear_output(True)  # display current progress
        output_parts = [f"[**Problem:** {problem}]\n\n"]
        output_parts.append(Formatting.history_header + Formatting.SEP + tokenizer.decode(history))
        output_parts.append(Formatting.current_step_header)
        for worker_tokens in current_step_tokens_by_worker:
            output_parts.append(tokenizer.decode(worker_tokens) + Formatting.pivot_message + Formatting.SEP)
        display(Markdown(''.join(output_parts)))

[**Problem:** Three vertices of a cube are $P=(7,12,10)$ , $Q=(8,8,1)$ , and $R=(11,3,9)$ . What is the surface area of the cube?]

### Past steps

**Bob [1]:** Hi, I'm Bob. Let me suggest that we first find the vectors between the points and check if they are edges or face diagonals of the cube. Since all edges are equal and angles between edges are 90 degrees, we can use vector dot products.

**Alice [1]:** Hi, I'm Alice. Here's how we should do this: Let's compute the distances between the three points to see if they can be edges, face diagonals, or space diagonals. Since in a cube, edges are equal, face diagonals are edge * sqrt(2), and space diagonals are edge * sqrt(3). So first, let me compute PQ, QR, and RP.



### Work in progress (others)

**Alice [2]:**  Wait, Bob is already calculating PQ. Maybe I'll compute QR instead. QR is between Q(8,8,1) and R(11,3,9). The differences are (3, -5, 8). Squared distance is 3² + (-5<...>

**Bob [2]:**  Let me start by calculating the distance between P and Q. P is (7,12,10), Q is (8,8,1). The differences in coordinates are (1, -4, -9). So the squared distance is (1)^2 + (-4)^2 + (-9)^2 = 1 + 16 + 81 = 9<...>

