In [1]:
import os
import torch
import random
import numpy as np
HF_TOKEN = os.getenv("HF_TOKEN")
print(HF_TOKEN)



None


In [2]:
def set_seed(seed=42):
    # Set seed for Python's built-in random module
    random.seed(seed)
    
    # Set seed for NumPy
    np.random.seed(seed)
    
    # Set seed for PyTorch
    torch.manual_seed(seed)
    
    # Ensure deterministic behavior on CUDA (GPU)
    torch.cuda.manual_seed(seed)

In [3]:
checkpoint = "gpt2-large"
checkpoint_assist = "gpt2-medium"
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Select device (GPU or CPU)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=HF_TOKEN)
main_model = AutoModelForCausalLM.from_pretrained(checkpoint, use_auth_token=HF_TOKEN).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint_assist, use_auth_token=HF_TOKEN).to(device)




In [4]:
import time
def vanilla_generation(model, tokenizer, prompt, max_tokens=8):
    start = time.time()
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out = model.generate(**input, max_new_tokens=max_tokens)
    end = time.time()
    print(tokenizer.decode(out[0], skip_special_tokens=True))
    return end - start
set_seed(42)

vanilla_generation(main_model, tokenizer, "Hi, how")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hi, how are you? I'm fine. I


2.510523557662964

In [5]:
def check_models(assistant_model, main_model, tokenizer, prompt, max=8):
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out_main = main_model.generate(**input, max_new_tokens=max)
    out_assist = assistant_model.generate(**input, max_new_tokens=max)
    print(f"Main model output: {tokenizer.decode(out_main[0])}")
    print(f"Asistant model output: {tokenizer.decode(out_assist[0])}")
check_models(assistant_model, main_model, tokenizer, "Hi, how")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Main model output: Hi, how are you? I'm fine. I
Asistant model output: Hi, how are you doing?

I'm


In [34]:
def speculative_decoding(tokenizer, model, assistant_model, prompt, max_len=10, speculative_len=6, vocab_size=50257):
    # Generating tokens we will speculate on:
    start = time.time()
    cur_len = 0
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
   
    while cur_len < max_len:
        candidate_input_ids = input["input_ids"]
        attn_mask = input["attention_mask"]
        for _ in range(speculative_len):
            with torch.no_grad():
                assistant_model.eval()
                out = assistant_model(candidate_input_ids, attention_mask=attn_mask)
                next_token = out.logits[:, -1, :].argmax(dim=-1)
                candidate_input_ids = torch.cat((candidate_input_ids, next_token[:, None]), dim=-1)
                attn_mask = torch.cat((attn_mask, torch.ones_like(next_token[:, None])), dim=-1)
               
        with torch.no_grad():
            #verifying using main model:
            speculative_ids = candidate_input_ids[:, -speculative_len:]
            print(f"What has assitant model predicted: {tokenizer.decode(speculative_ids[0])}")
            if speculative_len > 0:
                print("INPUT THAT GOES TO THE MAIN MODEL", tokenizer.decode(input["input_ids"][0]))
                model.eval()
                out_logits = model(candidate_input_ids, attention_mask=attn_mask)
                last_logits = out_logits.logits[:, -speculative_len:, :]
                verify_ids = torch.argmax(last_logits, dim=-1)
                print(f" MAIN MODEL IDS {verify_ids[0]}")
                print(f" Asistant IDS {speculative_ids[0]}")

                print(f"How main model thinks it should be: {tokenizer.decode(verify_ids[0])}")

                match_mask = ~(speculative_ids == verify_ids)
                print(match_mask)

                match_mask = match_mask.cumsum(dim=-1)
                match_mask = match_mask < 1
                print(match_mask)
                n_matches = match_mask.sum().item()
                if n_matches != speculative_len:
                    input["input_ids"] = candidate_input_ids[:, :-speculative_len + n_matches]
                    input["attention_mask"] = attn_mask[:, :-speculative_len + n_matches]
                else: 
                    input["input_ids"] = candidate_input_ids 
                    input["attention_mask"] = attn_mask

                print(tokenizer.decode(input["input_ids"][0]))
                print(f" Number of matches = {n_matches}")
                cur_len += n_matches

                # Adjust speculative_len dynamically
                # if n_matches == speculative_len and speculative_len != 0:
                #     speculative_len += 2
                # elif n_matches < speculative_len / 4:
                #     speculative_len = max(0, speculative_len - 1)  # Ensure speculative_len is not negative

            # else:
            #     # Fallback to vanilla generation when speculative_len becomes 0
            #     vanilla_time = vanilla_generation(model, tokenizer, prompt, max_tokens=max_len)
            #     cur_len = max_len
    end = time.time()
    return end - start
                
                



prompt = "Hi, how"
set_seed(42)
total= speculative_decoding(tokenizer, main_model, assistant_model, prompt)
# print(f"Speculative loop time: {spec}")
# print(f"Total time of vanilla if it was used: {vanilla}")
print(f"Total time of function: {total}")


What has assitant model predicted:  are you doing?


INPUT THAT GOES TO THE MAIN MODEL Hi, how
 MAIN MODEL IDS tensor([345,  30,  30, 198, 198,  40], device='cuda:0')
 Asistant IDS tensor([ 389,  345, 1804,   30,  198,  198], device='cuda:0')
How main model thinks it should be:  you??

I
tensor([[ True,  True,  True,  True, False,  True]], device='cuda:0')
tensor([[False, False, False, False, False, False]], device='cuda:0')
Hi, how
 Number of matches = 0
What has assitant model predicted:  are you doing?


INPUT THAT GOES TO THE MAIN MODEL Hi, how
 MAIN MODEL IDS tensor([345,  30,  30, 198, 198,  40], device='cuda:0')
 Asistant IDS tensor([ 389,  345, 1804,   30,  198,  198], device='cuda:0')
How main model thinks it should be:  you??

I
tensor([[ True,  True,  True,  True, False,  True]], device='cuda:0')
tensor([[False, False, False, False, False, False]], device='cuda:0')
Hi, how
 Number of matches = 0
What has assitant model predicted:  are you doing?


INPUT THAT GOES TO THE MAIN

KeyboardInterrupt: 