In [2]:
import os
import torch
import random
import numpy as np
HF_TOKEN = os.getenv("HF_TOKEN")
print(HF_TOKEN)



None


In [3]:
def set_seed(seed=42):
    # Set seed for Python's built-in random module
    random.seed(seed)
    
    # Set seed for NumPy
    np.random.seed(seed)
    
    # Set seed for PyTorch
    torch.manual_seed(seed)
    
    # Ensure deterministic behavior on CUDA (GPU)
    torch.cuda.manual_seed(seed)

In [74]:
checkpoint = "gpt2-large"
checkpoint_assist = "distilgpt2"
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Select device (GPU or CPU)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=HF_TOKEN)
main_model = AutoModelForCausalLM.from_pretrained(checkpoint, use_auth_token=HF_TOKEN).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(checkpoint_assist, use_auth_token=HF_TOKEN).to(device)


In [75]:
import time
def vanilla_generation(model, tokenizer, prompt, max_tokens=79):
    start = time.time()
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    for i in range(max_tokens):
            with torch.no_grad():
                out = model(**input)
                next_token = out.logits[:, -1, :].argmax(dim=-1)                
                input["input_ids"] = torch.cat((input["input_ids"], next_token[:, None]), dim=-1)
                input["attention_mask"] = torch.cat((input["attention_mask"], torch.ones_like(next_token[:, None])), dim=-1)
    end = time.time()
    print(tokenizer.decode(input["input_ids"][0], skip_special_tokens=True))
    return end - start
set_seed(42)

vanilla_generation(main_model, tokenizer, "Hi")

KeyboardInterrupt: 

In [18]:
def check_models(assistant_model, main_model, tokenizer, prompt, max=8):
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out_main = main_model.generate(**input, max_new_tokens=max)
    out_assist = assistant_model.generate(**input, max_new_tokens=max)
    print(f"Main model output: {tokenizer.decode(out_main[0])}")
    print(f"Asistant model output: {tokenizer.decode(out_assist[0])}")
check_models(assistant_model, main_model, tokenizer, "Hi")

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Main model output: Hi, I'm a new user. I
Asistant model output: Hi, I'm a student at the University


In [7]:
pad_token_id = tokenizer.eos_token_id
main_model.config.pad_token_id = main_model.config.eos_token_id

In [71]:
def speculative_decoding(tokenizer, model, assistant_model, prompt, max_len=50, speculative_len=5, vocab_size=50257):
    # Generating tokens we will speculate on:
    start = time.time()
    cur_len = 0
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
     
    while cur_len < max_len:
        candidate_input_ids = input["input_ids"]
        attn_mask = input["attention_mask"]
        # main_attn_mask = attn_mask
        for i in range(speculative_len):
            with torch.no_grad():
                out = assistant_model(candidate_input_ids, attention_mask=attn_mask)
                next_token = out.logits[:, -1, :].argmax(dim=-1)                
                candidate_input_ids = torch.cat((candidate_input_ids, next_token[:, None]), dim=-1)
                attn_mask = torch.cat((attn_mask, torch.ones_like(next_token[:, None])), dim=-1)

        with torch.no_grad():
            #verifying using main model:
            assistant_ids = candidate_input_ids[:, -speculative_len:]
            if speculative_len > 0:
                out_logits = model(input_ids=candidate_input_ids, attention_mask=attn_mask)
                last_logits = out_logits.logits[:, -speculative_len-1:, :]
                main_ids = torch.argmax(last_logits, dim=-1)
                main = torch.cat((input["input_ids"], main_ids), dim=-1)
                ass = torch.cat((input["input_ids"], assistant_ids), dim=-1)
                # print(f"OUTPUT FROM THE MAIN MODEL wit prompt: {tokenizer.decode(main[0])}")  
                # print(f"OUTPUT FROM THE ASSISTANT MODEL: {tokenizer.decode(ass[0])}")  

                match_mask = ~(assistant_ids == main_ids[:, :-1])

                match_mask = match_mask.cumsum(dim=-1)
                match_mask = match_mask < 1
                n_matches = match_mask.sum().item()
                valid_tokens = main_ids[:, :n_matches+1] # this is key, this ensures that even if n_matches are zero, we can always just come back to normal vanilla gen, because n_matches+1 is always true, because its still sampled from correct senstence it actually agreed with
                attn = torch.ones_like(valid_tokens)
                input["input_ids"] = torch.cat((input["input_ids"], valid_tokens), dim=-1)
                input["attention_mask"] = torch.cat((input["attention_mask"], attn), dim=-1)


                # print(f" Number of matches = {n_matches}")
                cur_len += n_matches
                # print(f"Current input after appending accepted: {tokenizer.decode(input['input_ids'][0])}")
                # print(input["input_ids"].shape)
                if n_matches == speculative_len:
                    speculative_len+=2
                else:
                    speculative_len = max(1, speculative_len-1)
            # else:
            #     # Fallback to vanilla generation when speculative_len becomes 0
            #     vanilla_time = vanilla_generation(model, tokenizer, prompt, max_tokens=max_len)
            #     cur_len = max_len
    end = time.time()
    print(f"Current input after appending accepted: {tokenizer.decode(input['input_ids'][0])}")
    return end - start
                
                



prompt = "Hi"
set_seed(42)
total= speculative_decoding(tokenizer, main_model, assistant_model, prompt)
# print(f"Speculative loop time: {spec}")
# print(f"Total time of vanilla if it was used: {vanilla}")
print(f"Total time of function: {total}")


Current input after appending accepted: Hi, I'm a new user. I'm looking for a way to get my account back. I've been using the site for a few months now and I've been able to get my account back. I'm not sure what happened, but I've been unable to log in. I've tried to contact support, but they've never responded. I've tried to contact the site owner
Total time of function: 4.224996328353882
