In [2]:
import os
import torch
import random
import numpy as np
HF_TOKEN = os.getenv("HF_TOKEN")
print(HF_TOKEN)

None


In [3]:
def set_seed(seed=42):
    np.random.seed(seed)

    torch.manual_seed(seed)

    random.seed(seed)
    
    torch.cuda.manual_seed(seed)



In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = "gpt2-xl"
model = AutoModelForCausalLM.from_pretrained(checkpoint, use_auth_token = HF_TOKEN).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token = HF_TOKEN)



In [5]:
import time
def vanilla_generation(model, tokenizer, prompt, max_tokens=20):
    start = time.time()
    input = tokenizer(prompt, return_tensors="pt")
    input = input.to(device)
    out = model.generate(**input, max_new_tokens=max_tokens)
    end = time.time()
    tok = out.size(1)
    print(f"\nTotal time of vanilla: {round((end-start), 2)} seconds")
    print(f"Output of vanilla: {tokenizer.decode(out[0], skip_special_tokens=True)}")
    print(f"Speed of vanilla (averaged): {round(tok/(end-start), 2)} tokens per second")

    return end-start
set_seed(42)
prompt = "Add a ? sign in a random place in this text: Tom is playing voleyball. Version with # added: Tom"
vanilla_generation(model, tokenizer, prompt)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Total time of vanilla: 7.02 seconds
Output of vanilla: Add a? sign in a random place in this text: Tom is playing voleyball. Version with # added: Tom is playing voleyball. Version without # added: Tom is playing voleyball.


Speed of vanilla (averaged): 6.41 tokens per second


7.023592948913574

In [28]:
import time
def speculative_editing(model, tokenizer, prompt, draft, max_tokens=20, speculative_len=2):
    start = time.time()
    cur_len=0
    draft = tokenizer(draft, return_tensors="pt")
    draft = draft.to(device)
    prompt = tokenizer(prompt, return_tensors="pt")
    prompt = prompt.to(device)
    
    
    whole_input = torch.cat((prompt["input_ids"], draft["input_ids"]), dim=-1)
    attn = torch.cat((prompt["attention_mask"], draft["attention_mask"]), dim=-1)
    main_ids = model(input_ids=whole_input, attention_mask=attn).logits.argmax(dim=-1)
    main_ids = main_ids[:, -speculative_len:]
    draft_ids = draft["input_ids"][:, cur_len:cur_len+speculative_len]
    
    print(draft_ids.shape)
    print(main_ids.shape)



    end = time.time()

set_seed(42)
prompt = "Hi how are"
draft = "you? I am great, by the way"
speculative_editing(model, tokenizer, prompt, draft)


torch.Size([1, 2])
torch.Size([1, 2])
