In [78]:
import os
import torch
import random
import numpy as np
HF_TOKEN = os.getenv("HF_TOKEN")
print(HF_TOKEN)



None


In [79]:
def set_seed(seed=42):
    # Set seed for Python's built-in random module
    random.seed(seed)
    
    # Set seed for NumPy
    np.random.seed(seed)
    
    # Set seed for PyTorch
    torch.manual_seed(seed)
    
    # Ensure deterministic behavior on CUDA (GPU)
    torch.cuda.manual_seed(seed)

In [2]:
checkpoint = "gpt2"
checkpoint_assist = "distilgpt2"
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Select device (GPU or CPU)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=HF_TOKEN)
main_model = AutoModelForCausalLM.from_pretrained(checkpoint, use_auth_token=HF_TOKEN).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint_assist, use_auth_token=HF_TOKEN).to(device)




In [130]:
import time
def vanilla_generation(model, tokenizer, prompt, max_tokens=8):
    start = time.time()
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out = model.generate(**input, max_new_tokens=max_tokens)
    end = time.time()
    print(tokenizer.decode(out[0], skip_special_tokens=True))
    return end - start
set_seed(42)

vanilla_generation(main_model, tokenizer, "Hi, how are")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hi, how are you doing?

I'm doing


0.13299942016601562

In [173]:
def speculative_decoding(tokenizer, model, assistant_model, prompt, max_len=10, speculative_len=5, vocab_size=50257):
    # Generating tokens we will speculate on:
    start = time.time()
    cur_len = 0
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
   
    while cur_len < max_len:
        
        
        for _ in range(speculative_len):
            with torch.no_grad():

                out = assistant_model(**input)
                next_token = out.logits[:, -1, :].argmax(dim=-1)
                input["input_ids"] = torch.cat((input["input_ids"], next_token[:, None]), dim=-1)
                input["attention_mask"] = torch.cat((input["attention_mask"], torch.ones_like(next_token[:, None])), dim=-1)
        speculative_ids = input["input_ids"][:, -speculative_len:]
        with torch.no_grad():
            #verifying using main model:
            speculative_ids = input["input_ids"][:, -speculative_len:]
            if speculative_len > 0:
                out_logits = model(**input)
                last_logits = out_logits.logits[:, -speculative_len:, :]
                verify_ids = torch.argmax(last_logits, dim=-1)
                match_mask = ~(speculative_ids == verify_ids)
                match_mask = match_mask.cumsum(dim=-1)
                match_mask = match_mask < 1
                n_matches = match_mask.sum().item()

                if n_matches != speculative_len:
                    input["input_ids"] = input["input_ids"][:, :-speculative_len + n_matches]
                    input["attention_mask"] = input["attention_mask"][:, :-speculative_len + n_matches]

                print(tokenizer.decode(input["input_ids"][0]))
                cur_len += n_matches

                # Adjust speculative_len dynamically
                if n_matches == speculative_len and speculative_len != 0:
                    speculative_len += 2
                elif n_matches < speculative_len / 5:
                    speculative_len = max(0, speculative_len - 1)  # Ensure speculative_len is not negative

                print(speculative_len)
            else:
                # Fallback to vanilla generation when speculative_len becomes 0
                vanilla_time = vanilla_generation(model, tokenizer, prompt, max_tokens=max_len)
                cur_len = max_len
    end = time.time()
    return end - start, end - start - vanilla_time, vanilla_time
                
                



prompt = "Hi how are"
set_seed(42)
total, spec, vanilla = speculative_decoding(tokenizer, main_model, assistant_model, prompt)
print(f"Speculative loop time: {spec}")
print(f"Total time of vanilla if it was used: {vanilla}")
print(f"Total time of function: {total}")


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hi how are
4
Hi how are
3
Hi how are
2
Hi how are
1
Hi how are
0
Hi how are you doing?

I'm doing great.
Speculative loop time: 0.5059974193572998
Total time of vanilla if it was used: 0.20399999618530273
Total time of function: 0.7099974155426025
