In [None]:
%env CUDA_VISIBLE_DEVICES=3
%env HF_HOME=/mnt/LLM
%env OMP_NUM_THREADS=16
import sys; sys.path.insert(0, '..')

import torch
import transformers
from tqdm import tqdm, trange
from IPython.display import clear_output

import shared_cache

MODEL_NAME = "Qwen/QwQ-32B" #"Qwen/QwQ-32B-AWQ"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype='auto', low_cpu_mem_usage=True, device_map=device)


# Hogwild parallelism

In [None]:
problem = """
Calculate x + x^2 + x^3 + x^4 + x^5 for x = 1..10
""".strip()

parallelism_prompt_common = """
I will collaborate this problem with another. We refer to each other as Alice and Bob. We are assistants.

We will reason together and try to collaborate.
I will take into account what the other assistant is doing and try to help them.

We will write our solutions concurrently. I will write my own thoughts at the bottom, and see the other's thoughts above.

I will not repeat the copy assistant's thoughts: I can already see them above.

The other assistant will continue writing their thoughts above while I am writing mine. They will add more text every time I check.

Since we both write our thoughts in parallel, I will initially see only partial (unfinished) thoughts of the other assistant.
I will use these partial thoughts to decide how best to help the other assistant without doing the same work twice.

When reasoning, we will five each other tasks to coordinate (e.g. if Alice writes: Bob, please do this, then Bob should take this into account).

Before doing anything, I will check the other assistant's workspace. If they have already done that or are currently doing it, I don't need to do that again. If so, I will stop (e.g. 'Wait, this is already done') and pivot to a different task.

""".strip()

SEP = '\n\n'
WORKER_PREFIXES = [SEP + "# Alice workspace" + SEP, SEP + "# Bob workspace" + SEP]

prompt_full_input = tokenizer.apply_chat_template(
    [dict(role='user', content=problem)],
    tokenize=False,
    add_generation_prompt=True
) + SEP + parallelism_prompt_common
prompt_split = " <the assistant will continue here>\n\n"

worker_prompts = [
    WORKER_PREFIXES[0] + """I am Alice. Let's solve this together, Bob. Here's how we should collaborate:""".strip(),
    WORKER_PREFIXES[1] + """I am Bob. Let's solve this together, Alice.""".strip()
]

forbidden_token_ix = [tokenizer.vocab[x] for x in ['#', '<|im_end|>', '</think>']]

In [None]:
# One cache block per shared segment: common prompt, split marker, and each worker's own output.
cache_input, cache_split, cache_w1, cache_w2 = (shared_cache.CacheBlock(config=model.config) for _ in range(4))
# Row i lists, in order, the cache segments worker i is conditioned on:
# common prompt -> the OTHER worker's workspace -> split marker -> its own workspace.
# write_to presumably gives the block each worker's new tokens are appended to (Alice -> w1, Bob -> w2).
cm = shared_cache.SharedCacheManager(cache_structure=[
    [cache_input, cache_w2, cache_split, cache_w1],
    [cache_input, cache_w1, cache_split, cache_w2],
], write_to=[cache_w1, cache_w2])

MAX_NEW_TOKENS = 1024  # number of greedy decoding steps (one new token per worker per step)

# pre-fill common parts
with torch.no_grad():
    model(**tokenizer(prompt_full_input, add_special_tokens=False, return_tensors='pt').to(device),
          use_cache=True, past_key_values=cache_input)  # <-- write to common prompt
    model(**tokenizer(prompt_split, add_special_tokens=False, return_tensors='pt').to(device),
          use_cache=True, past_key_values=cache_split)  # <-- write to common separator

# generate texts: the first step feeds the (left-padded) worker prompts, later steps feed one token each
next_inputs = tokenizer(worker_prompts, return_tensors='pt', padding=True, padding_side='left').to(device)
tokens_by_worker = tokenizer(worker_prompts)['input_ids']
for inference_step in range(MAX_NEW_TOKENS):
    with torch.no_grad():
        logits = model(**cm.get_input_kwargs(**next_inputs)).logits[..., -1, :]
        # Hard ban: a finite penalty (the former `-= 100`) can still be overcome by a large enough logit.
        logits[..., forbidden_token_ix] = float('-inf')
        new_tokens = logits.argmax(-1)
        assert len(new_tokens) == len(cm.cache_structure)
        next_inputs = dict(input_ids=new_tokens.view(-1, 1))

    for worker_tokens, new_token in zip(tokens_by_worker, new_tokens.tolist()):
        worker_tokens.append(new_token)
    # Redraw every worker's full text in place after each step.
    clear_output(True)
    for worker_index, worker_tokens in enumerate(tokens_by_worker):
        print(end=tokenizer.decode(worker_tokens))
    print(flush=True)
