In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
import logging

import torch


class Target:
    """Target (verifier) model for greedy draft-token verification.

    Wraps a causal LM together with the key/value cache returned by its last
    forward pass, so that repeated calls to :meth:`verify` only run the model
    on tokens that are not cached yet.
    """

    def __init__(self, model_name_or_path):
        """Load the causal LM in bfloat16 with ``device_map="auto"``.

        Args:
            model_name_or_path: identifier or path passed to
                ``AutoModelForCausalLM.from_pretrained``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16
        )
        # Key/value cache from the most recent forward pass; None until the
        # first call to `verify`.
        self.pkv = None

    def fix_pkv(self, start_id):
        """Truncate the cache so it only covers positions ``[0, start_id)``.

        Assumes ``self.pkv`` iterates as per-layer ``(key, value)`` pairs whose
        tensors are laid out as (batch, head, seq, dim/head); the slice is
        taken along the third axis.
        """
        self.pkv = tuple(
            (k[:, :, :start_id, :], v[:, :, :start_id, :]) for k, v in self.pkv
        )

    def verify(self, input_ids, start_id):
        """Check draft tokens against the target model's greedy predictions.

        The token at ``start_id`` is treated as already accepted; every token
        after it is compared with the argmax prediction made from the tokens
        before it.

        Args:
            input_ids: 2-D tensor of token ids. The returned token is read
                from row 0, so a batch size of 1 is assumed.
            start_id: absolute index into ``input_ids`` of the last accepted
                token.

        Returns:
            ``(correct_token, mistake_index)``. Both are ``None`` when every
            draft token matches. Otherwise ``mistake_index`` is a 0-dim tensor
            holding the offset of the first mismatch within
            ``input_ids[:, start_id + 1:]`` and ``correct_token`` is the
            model's prediction at that offset.

        Intermediate tensors are logged at DEBUG level; enable them with
        ``logging.basicConfig(level=logging.DEBUG)``.
        """
        cached_len = 0 if self.pkv is None else self.pkv[0][0].shape[2]

        if start_id < cached_len:
            # Cached entries at/after start_id belong to tokens that are being
            # re-checked; drop them so they are recomputed.
            self.fix_pkv(start_id)
            cached_len = start_id

        # Feed everything that is not cached yet. When the cache is empty or
        # shorter than start_id this includes the context *before* start_id,
        # so the predictions are conditioned on the full prefix.
        fresh_ids = input_ids[:, cached_len:]
        if fresh_ids.shape[1] == 0:
            # Nothing new to run through the model, hence nothing to verify.
            return None, None

        with torch.no_grad():  # inference only, no autograd graph needed
            outputs = self.model(
                fresh_ids, past_key_values=self.pkv, return_dict=True
            )
        self.pkv = outputs.past_key_values

        # Skip logits of context tokens that were fed only to fill the cache.
        offset = start_id - cached_len
        # The prediction at position i is for token i + 1, so the last
        # prediction has no draft token to be compared with.
        target_pred = outputs.logits[:, offset:, :].argmax(dim=-1)[:, :-1]
        draft_tokens = input_ids[:, start_id + 1 :].to(target_pred.device)

        logging.debug("draft_tokens=%s, target_pred=%s", draft_tokens, target_pred)

        mismatch = draft_tokens != target_pred
        logging.debug("mismatch=%s", mismatch)

        mismatch_positions = mismatch.nonzero()

        if len(mismatch_positions) == 0:
            # no mistakes
            mistake_index = None
            correct_token = None
        else:
            logging.debug("mismatch_positions=%s", mismatch_positions)
            # Rows of `nonzero()` are (batch, position); take the first one.
            mistake_index = mismatch_positions[0][1]
            correct_token = target_pred[0, mistake_index]
        logging.debug("correct_token=%s, mistake_index=%s", correct_token, mistake_index)

        return correct_token, mistake_index

In [None]:
# Instantiate the target model wrapper (the `Target` class from this notebook) for "gpt2".
target = Target("gpt2")

In [None]:
# Tokenizer for the same "gpt2" checkpoint as the target model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
# Prompt that is tokenized and fed to `generate` in the following cells.
prompt = "Danny has a hat and a suit. What does Danny have? Danny has a"

In [None]:
# Tokenize the prompt, returning PyTorch tensors ("pt").
inps = tokenizer(prompt, return_tensors="pt")

In [None]:
# Greedy decoding (sampling disabled), limited to 30 newly generated tokens.
outs = target.model.generate(**inps, do_sample=False, max_new_tokens=30)

In [None]:
# Token strings for the ids of the first generated sequence.
out_ids = tokenizer.convert_ids_to_tokens(outs[0])

In [None]:
from torch import tensor

In [None]:
# Keep only the first 22 token ids of each generated sequence.
new_ids = outs[:, :22]

In [None]:
# Display the truncated ids.
new_ids

In [None]:
# Hard-coded replacement for `new_ids`: one sequence of 22 token ids.
# NOTE(review): looks like the generated prefix with the last three ids
# (1219, 1113, 1919) edited by hand to act as wrong draft tokens — confirm.
new_ids = tensor(
    [
        [
            45478,
            468,
            257,
            6877,
            290,
            257,
            6050,
            13,
            1867,
            857,
            15105,
            423,
            30,
            15105,
            468,
            257,
            6877,
            290,
            257,
            1219,
            1113,
            1919,
        ]
    ]
)

In [None]:
# Indices (row, column) of every position in `new_ids` whose id equals 1219.
(new_ids == 1219).nonzero()

In [None]:
# Clear any cached keys/values so verification starts from an empty cache.
target.pkv = None
# start_id is the index of the last prompt token (prompt length - 1).
res = target.verify(new_ids, inps["input_ids"].shape[1] - 1)

In [None]:
# Decode each row of `new_ids` back to text for inspection.
tokenizer.batch_decode(new_ids)

In [None]:
# Reference listing of (token string, token id) pairs. This cell only displays
# the list; nothing is assigned.
# NOTE(review): looks like a pasted copy of `list(zip(out_ids, outs[0]))`
# output that may have been edited by hand — confirm before relying on it.
[
    ("Danny", tensor(45478)),
    ("Ġhas", tensor(468)),
    ("Ġa", tensor(257)),
    ("Ġhat", tensor(6877)),
    ("Ġand", tensor(290)),
    ("Ġa", tensor(257)),
    ("Ġsuit", tensor(6050)),
    (".", tensor(13)),
    ("ĠWhat", tensor(1867)),
    ("Ġdoes", tensor(857)),
    ("ĠDanny", tensor(15105)),
    ("Ġhave", tensor(423)),
    ("?", tensor(30)),
    ("ĠDanny", tensor(15105)),
    ("Ġhas", tensor(468)),
    ("Ġa", tensor(257)),
    ("Ġhat", tensor(6877)),
    ("Ġand", tensor(290)),
    ("Ġa", tensor(257)),
    ("Ġsuit", tensor(12893)),
    (".", tensor(13)),
    ("ĠWhat", tensor(1867)),
    ("Ġdoes", tensor(857)),
    ("ĠDanny", tensor(15105)),
    ("Ġhave", tensor(423)),
    ("?", tensor(30)),
    ("Ċ", tensor(198)),
    ("Ċ", tensor(198)),
    ("The", tensor(464)),
    ("Ġfirst", tensor(717)),
    ("Ġtime", tensor(640)),
    ("ĠI", tensor(314)),
    ("Ġsaw", tensor(2497)),
    ("Ġhim", tensor(683)),
    (",", tensor(11)),
    ("ĠI", tensor(314)),
    ("Ġwas", tensor(373)),
    ("Ġlike", tensor(588)),
    (",", tensor(11)),
    ('Ġ"', tensor(366)),
    ("Oh", tensor(5812)),
    (",", tensor(11)),
    ("Ġhe", tensor(339)),
    ("'s", tensor(338)),
    ("Ġa", tensor(257)),
    ("Ġguy", tensor(3516)),
]

In [None]:
# Pair each generated token string with its id.
list(zip(out_ids, outs[0]))

In [None]:
# Scratch check: length of a hard-coded id list (evaluates to 19).
len([2, 31414, 6, 127, 766, 16, 344, 4, 347, 4, 8, 38, 524, 10, 1294, 23, 5, 589, 9])

# Multiprocessing

In [None]:
import os

from torch.multiprocessing import Process, Queue

# Set in this process's environment before any child process is started.
# NOTE(review): presumably meant to avoid the Hugging Face tokenizers
# fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# A GPT-2 instance loaded with default settings, bound to its own name `model`.
model = AutoModelForCausalLM.from_pretrained("gpt2")
# One sequence of 13 token ids used as the forward-pass input.
tok_ids = torch.tensor(
    [[15205, 541, 305, 919, 278, 351, 12905, 2667, 15399, 714, 307, 281, 220]]
)


def fwd(model, tok_ids, queue):
    """Run one forward pass in a worker and send the result back on ``queue``.

    Args:
        model: callable invoked as ``model(tok_ids)``.
        tok_ids: input passed to ``model``.
        queue: object with a ``put`` method that receives the result.

    Exactly one item is always put on ``queue``: the model output, or ``None``
    if the forward pass raised. This keeps a consumer blocked on
    ``queue.get()`` from waiting forever after a failure.
    """
    print("Starting process")
    print(f"{os.environ['TOKENIZERS_PARALLELISM']=}")
    print(f"{type(model)=}")
    print(f"{tok_ids=}")
    try:
        outs = model(tok_ids)
    except Exception as e:
        print(f"Error: {e}")
        # Without this, `outs` would be unbound below and nothing would ever
        # be put on the queue.
        outs = None
    print(f"{outs=}")
    queue.put(outs)


queue = Queue()
pr = Process(target=fwd, args=(model, tok_ids, queue))
pr.start()
# Read the result *before* joining. A child that has put a large item on a
# multiprocessing queue does not exit until that item has been consumed, so
# calling join() first can deadlock.
outs = queue.get()
pr.join()
print(outs)

### Workaround with async-await

In [None]:
# a list of all the prefixes of `tok_ids`
all_tok_ids = [tok_ids[:, :i] for i in range(1, tok_ids.shape[1] + 1)]
all_tok_ids

In [None]:
import asyncio


async def fwd(model, tok_ids):
    """Coroutine that runs one forward pass and returns its result.

    Args:
        model: callable invoked as ``model(tok_ids)``.
        tok_ids: input passed to ``model``.

    Returns:
        The model output, or ``None`` if the forward pass raised.

    Note: the body contains no ``await``, so once started this coroutine runs
    to completion without yielding to the event loop. This definition also
    reuses the name ``fwd`` of the multiprocessing worker defined earlier in
    the notebook and replaces it once this cell has run.
    """
    print("Starting process")
    print(f"{os.environ['TOKENIZERS_PARALLELISM']=}")
    print(f"{type(model)=}")
    print(f"{tok_ids=}")
    try:
        outs = model(tok_ids)
    except Exception as e:
        print(f"Error: {e}")
        # Keep `outs` bound so the lines below do not raise UnboundLocalError.
        outs = None
    print(f"{outs=}")
    return outs


async def main():
    """Create one `fwd` coroutine per entry of `all_tok_ids`, gather them, and print the results."""
    print("Starting main")
    pending = []
    for prefix in all_tok_ids:
        pending.append(fwd(model, prefix))
    print("Running tasks")
    results = await asyncio.gather(*pending)
    print(results)


# Top-level `await` relies on the notebook kernel already running an event
# loop; the commented-out `asyncio.run` form is the equivalent for a plain
# Python script.
# asyncio.run(main())
await main()