In [3]:
import os
import torch
import random
import numpy as np
HF_TOKEN = os.getenv("HF_TOKEN")
print(HF_TOKEN)



None


In [4]:
def set_seed(seed=42):
    # Set seed for Python's built-in random module
    random.seed(seed)
    
    # Set seed for NumPy
    np.random.seed(seed)
    
    # Set seed for PyTorch
    torch.manual_seed(seed)
    
    # Ensure deterministic behavior on CUDA (GPU)
    torch.cuda.manual_seed(seed)

In [23]:
checkpoint = "gpt2-xl"
checkpoint_assist = "gpt2"
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Select device (GPU or CPU)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=HF_TOKEN)
main_model = AutoModelForCausalLM.from_pretrained(checkpoint, use_auth_token=HF_TOKEN).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint_assist, use_auth_token=HF_TOKEN).to(device)




tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [66]:
import time
def vanilla_generation(model, tokenizer, prompt, max_tokens=50):
    start = time.time()
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out = model.generate(**input, max_new_tokens=max_tokens)
    end = time.time()
    tok = out.size(1)
    print(f"\nTotal time of vanilla: {round((end-start), 2)} seconds")
    print(f"Output of vanilla: {tokenizer.decode(out[0], skip_special_tokens=True)}")
    print(f"Speed of vanilla (averaged): {round(tok/(end-start), 2)} tokens per second")

    return end-start
    
set_seed(42)

sec = vanilla_generation(main_model, tokenizer, "Hi")


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Total time of vanilla: 17.47 seconds
Output of vanilla: Hi, I'm a new member of the community. I'm a newbie to the forum, but I'm looking for a good place to start. I'm a newbie to the forum, but I'm looking for a good place to start.
Speed of vanilla (averaged): 2.92 tokens per second
17.46697473526001


In [72]:
def hf_implementation(model, tokenizer, prompt, assistant_model, van_time, max_tokens=50):
    start = time.time()
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out = model.generate(**input, max_new_tokens=max_tokens, assistant_model=assistant_model)
    end = time.time()
    tok = out.size(1)
    print(van_time)

    print(f"\nTotal time of huggingface implementation: {end-start} seconds" 
      f"{f' ({(van_time)/(end-start):.2f}x speedup compared to vanilla!)' if van_time is not None else ''}")
    
    print(f"Output of huggingface implementation of spec dec: {tokenizer.decode(out[0], skip_special_tokens=True)}")
    print(f"Speed of huggingface implementation of spec dec (averaged): {round(tok/(end-start), 2)} tokens per second")


In [7]:
def check_models(assistant_model, main_model, tokenizer, prompt, max=8):
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out_main = main_model.generate(**input, max_new_tokens=max)
    out_assist = assistant_model.generate(**input, max_new_tokens=max)
    print(f"Main model output: {tokenizer.decode(out_main[0])}")
    print(f"Asistant model output: {tokenizer.decode(out_assist[0])}")
check_models(assistant_model, main_model, tokenizer, "Hi")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Main model output: Hi, I'm a student at the University
Asistant model output: Hi. I'm sorry, but I'm


In [73]:
def speculative_decoding(tokenizer, model, assistant_model, prompt, van_time, max_len=50, speculative_len=8):
    # Generating tokens we will speculate on:
    start = time.time()
    cur_len = 0
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
     
    while cur_len+1 < max_len:
        candidate_input_ids = input["input_ids"]
        attn_mask = input["attention_mask"]
        # main_attn_mask = attn_mask
        speculative_len = min(speculative_len, max_len-cur_len)
        for i in range(speculative_len):
            with torch.no_grad():
                out = assistant_model(candidate_input_ids, attention_mask=attn_mask)
                next_token = out.logits[:, -1, :].argmax(dim=-1)                
                candidate_input_ids = torch.cat((candidate_input_ids, next_token[:, None]), dim=-1)
                attn_mask = torch.cat((attn_mask, torch.ones_like(next_token[:, None])), dim=-1)

        with torch.no_grad():
            #verifying using main model:
            assistant_ids = candidate_input_ids[:, -speculative_len:]
            
            out_logits = model(input_ids=candidate_input_ids, attention_mask=attn_mask)
            last_logits = out_logits.logits[:, -speculative_len-1:, :]
            main_ids = torch.argmax(last_logits, dim=-1)
            main = torch.cat((input["input_ids"], main_ids), dim=-1)
            ass = torch.cat((input["input_ids"], assistant_ids), dim=-1)
            # print(f"OUTPUT FROM THE MAIN MODEL wit prompt: {tokenizer.decode(main[0])}")  
            # print(f"OUTPUT FROM THE ASSISTANT MODEL: {tokenizer.decode(ass[0])}")  

            match_mask = ~(assistant_ids == main_ids[:, :-1])

            match_mask = match_mask.cumsum(dim=-1)
            match_mask = match_mask < 1
            n_matches = match_mask.sum().item()
            valid_tokens = main_ids[:, :n_matches+1] # this is key, this ensures that even if n_matches are zero, we can always just come back to normal vanilla gen, because n_matches+1 is always true, because its still sampled from correct senstence it actually agreed with
            attn = torch.ones_like(valid_tokens)
            input["input_ids"] = torch.cat((input["input_ids"], valid_tokens), dim=-1)
            input["attention_mask"] = torch.cat((input["attention_mask"], attn), dim=-1)


            cur_len += n_matches+1
            # print(f"Current input after appending accepted: {tokenizer.decode(input['input_ids'][0])}")
            # print(input["input_ids"].shape)
            if n_matches+1 == speculative_len:
                speculative_len+=2
            else:
                speculative_len = max(1, speculative_len-1)
            
    end = time.time()
    tok = input["input_ids"].size(1)
    print(van_time)
    print(f"\nTotal time of speculative decoding: {end-start} seconds" 
      f"{f' ({(van_time)/(end-start):.2f}x speedup compared to vanilla!)' if van_time is not None else ''}")
    print(f"Speed of speculative decoding (averaged): {round(tok/(end-start), 2)} tokens per second")
    print(f"Output of the model (speculative decoding): {tokenizer.decode(input['input_ids'][0])}")
    
                
                






In [74]:
def main():
    prompt = "Hi"
    set_seed(42)
    
    van_time = vanilla_generation(main_model, tokenizer, prompt)

    hf_implementation(main_model, tokenizer, prompt, assistant_model, van_time=van_time)

    speculative_decoding(tokenizer, main_model, assistant_model, prompt, van_time=van_time)

    
if __name__ == "__main__":
    main()


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Total time of vanilla: 18.26 seconds
Output of vanilla: Hi, I'm a new member of the community. I'm a newbie to the forum, but I'm looking for a good place to start. I'm a newbie to the forum, but I'm looking for a good place to start.
Speed of vanilla (averaged): 2.79 tokens per second
18.261565685272217

Total time of huggingface implementation: 7.939016580581665 seconds (2.30x speedup compared to vanilla!)
Output of huggingface implementation of spec dec: Hi, I'm a new member of the community. I'm a newbie to the forum, but I'm looking for a good place to start. I'm a newbie to the forum, but I'm looking for a good place to start.
Speed of huggingface implementation of spec dec (averaged): 6.42 tokens per second
18.261565685272217

Total time of speculative decoding: 12.354034423828125 seconds (1.48x speedup compared to vanilla!)
Speed of speculative decoding (averaged): 4.05 tokens per second
Output of the model (speculative decoding): Hi, I'm a new member of the community. I'm a n