In [1]:
import torch
import torch.nn.functional as F

import numpy as np

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
@torch.no_grad()
def generate(model, input_ids, n_tokens, seq_lens, pad_token_id, temperature=1):
    """ 
    Sample `n_tokens` new tokens for every row of a right padded batch.

    Rows may have different lengths. Decoding starts at the length of the 
    shortest row and advances one column per forward pass. A row only receives 
    a sampled token in columns that were padding in the prompt and only while it 
    has generated fewer than `n_tokens` tokens, so prompt tokens are never 
    overwritten and every row gets exactly `n_tokens` new tokens.

    model: 
        callable where model(input_ids=ids)["logits"] has shape (N, T', V), 
        e.g. a transformers causal LM 
    input_ids: 
        (N, T) right padded with pad_token_id 
    n_tokens: 
        num tokens to generate per row, must be >= 0
    seq_lens: 
        (N,) number of real (non pad) tokens in each row, each must be >= 1
    pad_token_id: 
        int id of padding token
    temperature: 
        logits are divided by this before the softmax

    returns (out_ids, probs)
        out_ids: (N, max(seq_lens) + n_tokens) prompt + sampled tokens, right padded
        probs:   (N, max(seq_lens) + n_tokens, V) next token distributions of 
                 `model` at every position of out_ids, from one final forward pass

    raises ValueError if n_tokens < 0 or if any row has no real tokens 

    ---------------------------------------
    anything you'd want to implement diff can be implemented as a logits processing fn

    runs under torch.no_grad() since this is sampling only, no graph is kept 

    TODO should use KVCache. 
        - there's a past_key_values elem in the out_dict that could help
    """
    if n_tokens < 0:
        raise ValueError(f"n_tokens must be >= 0, got {n_tokens}")

    device = input_ids.device
    seq_lens = seq_lens.to(device)
    n_rows, n_cols = input_ids.shape

    # plain ints so the loop counter is not a tensor that gets mutated in place
    max_len = int(seq_lens.max())
    min_len = int(seq_lens.min())
    if min_len < 1:
        raise ValueError("every row of input_ids needs at least one non pad token")

    # init output on the same device as the input
    out_len = max_len + n_tokens
    out_ids = torch.full((n_rows, out_len), pad_token_id, dtype=input_ids.dtype, device=device)
    # the input may be padded wider than out_len, anything past max_len is padding anyway
    n_copy = min(n_cols, out_len)
    out_ids[:, :n_copy] = input_ids[:, :n_copy]
    is_pad_mask = out_ids == pad_token_id

    for i in range(min_len, out_len):
        # only the distribution at the last position is needed to sample column i
        next_logits = model(input_ids=out_ids[:, :i])["logits"][:, -1]
        next_probs = F.softmax(next_logits / temperature, dim=-1)
        pred_ids = torch.multinomial(next_probs, num_samples=1).view(-1)

        # num generated tokens < n_tokens
        is_unfinished = (i - seq_lens) < n_tokens
        # this is a waste b/c finished rows still go through the forward pass
        # but we'd need variable n_tokens to support skipping them
        do_write = is_pad_mask[:, i] & is_unfinished
        # write the sample where the slot is free, keep prompt tokens / padding elsewhere
        out_ids[:, i] = torch.where(do_write, pred_ids, out_ids[:, i])

    out_dict = model(input_ids=out_ids)
    probs = F.softmax(out_dict["logits"] / temperature, dim=-1)
    return out_ids, probs

In [127]:
@torch.no_grad()
def speculative_step(m_target, m_draft, prefix_ids, pad_token_id, gamma=10, temperature=1, eps=1e-8):
    """
    Implements a speculative decoding step

    The draft model proposes `gamma` tokens per row, the target model scores 
    all of them in one forward pass, and each row keeps its drafted tokens up 
    to the first rejection. One extra token is then sampled per row:
        - from normalize(relu(p_target - p_draft)) at the first rejected position, or 
        - from the target distribution that follows the last drafted token if 
          all gamma drafts were accepted
    so every row grows by between 1 and gamma + 1 tokens.

    m_target: 
        model we want to sample from 
    m_draft: 
        draft model we use to speculate on tokens to decode
    prefix_ids: 
        (N, T) prefix ids, right padded with pad_token_id. row lengths are 
        recovered by counting non pad tokens, so pad_token_id must not occur 
        as a real token 
    pad_token_id: 
        int id of padding token
    gamma: 
        int number of draft tokens to generate
    temperature: 
        both models' logits are divided by this before the softmax
    eps: 
        added to the draft probability in the accept ratio to avoid dividing by 0

    returns 
        out_ids: (N, T_out) extended rows, right padded with pad_token_id, where 
        T_out is the longest new row length 

    runs under torch.no_grad() since this is sampling only
    """
    device = prefix_ids.device
    N_batch = prefix_ids.shape[0]
    seq_lens = torch.sum(prefix_ids != pad_token_id, dim=1)

    # draft ids contain the prefix as well 
    # 
    # draft_ids: (N, max(seq_lens)+gamma)
    # draft_probs: (N, max(seq_lens)+gamma, V)
    draft_ids, draft_probs = generate(
        model=m_draft, input_ids=prefix_ids, n_tokens=gamma, seq_lens=seq_lens, 
        pad_token_id=pad_token_id, temperature=temperature
    )

    # target_probs: (N, max(seq_lens)+gamma, V)
    target_out = m_target(draft_ids)
    target_probs = F.softmax(target_out["logits"] / temperature, dim=-1)

    batch_idxs = torch.arange(N_batch, device=device)[:, None]
    # token_idxs[b, k]: column of row b holding its k-th drafted token
    # pred_idxs[b, k]:  column whose output distribution predicts that token
    # (N, gamma)
    token_idxs = seq_lens[:, None] + torch.arange(gamma, device=device)[None, :]
    pred_idxs = token_idxs - 1

    # distributions over the vocab at the gamma drafted positions
    # (N, gamma, V)
    draft_dists = draft_probs[batch_idxs, pred_idxs]
    target_dists = target_probs[batch_idxs, pred_idxs]

    # probability each model assigns to the token that was actually drafted
    # (N, gamma)
    drafted = draft_ids[batch_idxs, token_idxs][:, :, None]
    m_draft_prob_of_draft = draft_dists.gather(2, drafted).squeeze(2)
    m_target_prob_of_draft = target_dists.gather(2, drafted).squeeze(2)

    # get uniform probabilities to make the accept / reject decision
    uniform = torch.rand_like(m_target_prob_of_draft, dtype=torch.float)
    do_reject = uniform > (m_target_prob_of_draft / (m_draft_prob_of_draft + eps))

    # number of drafts accepted before the first rejection, gamma if none is rejected
    # (N,)
    n_accepted = torch.cumprod((~do_reject).long(), dim=1).sum(dim=1)
    # column that receives the extra sampled token
    resample_idxs = seq_lens + n_accepted

    # create tensor for output
    # this can be one column wider than draft_ids when the longest row accepts
    # every draft, so allocate it fresh instead of slicing draft_ids
    max_out_len = int(resample_idxs.max()) + 1
    out_ids = torch.full((N_batch, max_out_len), pad_token_id, dtype=draft_ids.dtype, device=device)
    n_copy = min(max_out_len, draft_ids.shape[1])
    # keep prefix + accepted drafts, everything from the first rejection onwards stays pad
    is_kept = torch.arange(n_copy, device=device)[None, :] < resample_idxs[:, None]
    out_ids[:, :n_copy] = torch.where(is_kept, draft_ids[:, :n_copy], out_ids[:, :n_copy])

    # (N, gamma, V)
    residual = F.relu(target_dists - draft_dists)
    residual_mass = torch.sum(residual, dim=-1, keepdim=True)
    # normalize probs to sum to 1
    # zero mass means target == draft at that position, so the target
    # distribution itself is the right thing to sample from (and avoids 0 / 0)
    has_mass = residual_mass > 0
    safe_mass = torch.where(has_mass, residual_mass, torch.ones_like(residual_mass))
    adjusted_probs = torch.where(has_mass, residual / safe_mass, target_dists)

    # (N, gamma+1, V)
    # add probs from target on to adjusted so that if all drafts are accepted they are chosen
    # these are the target's predictions *after* the last drafted token, which
    # sits at column seq_len + gamma - 1
    bonus_probs = target_probs[batch_idxs, (seq_lens + gamma - 1)[:, None]]
    adjusted_probs = torch.cat([adjusted_probs, bonus_probs], dim=1)

    # (N, V)
    adjusted_probs = adjusted_probs[batch_idxs, n_accepted[:, None]].squeeze(1)
    # (N, 1)
    adjusted_ids = torch.multinomial(adjusted_probs, num_samples=1)
    out_ids[batch_idxs, resample_idxs[:, None]] = adjusted_ids.to(out_ids.dtype)

    return out_ids

In [56]:
# Tokenizer of the draft checkpoint; the same tokenizer is used for both models below.
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

# Register a dedicated pad token so prompts of different lengths can be batched.
# The cell output is the number of tokens that were added to the vocabulary.
special_tokens = {"pad_token": "<|pad|>"}
tokenizer.add_special_tokens(special_tokens)

1

In [57]:
# The tokenizer gained a pad token, so both embedding tables are resized to match it.
vocab_size = len(tokenizer)

# m_q: small model, passed as the draft model in the decoding loop
m_q = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
# m_p: larger model, passed as the target model in the decoding loop
m_p = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-33M")

m_q.resize_token_embeddings(vocab_size)
m_p.resize_token_embeddings(vocab_size)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Embedding(50258, 768)

In [124]:
# Prompts of different lengths, so the batch exercises the right-padding logic.
prompts = ["once upon", "once", "once upon a time there was"]

# Pad on the right up to the longest prompt and return PyTorch tensors.
input_dict = tokenizer(
    prompts,
    padding="longest",
    padding_side="right",
    return_tensors="pt",
)

In [125]:
# Right-padded prompt token ids, one row per prompt. The decoding loop cell
# rebinds `pre_ids` to the extended sequences.
pre_ids = input_dict["input_ids"]

In [128]:
# Stop once the padded batch is at least this many columns wide. The width is
# set by the longest row, so shorter rows end in pad tokens.
MAX_TOTAL_LEN = 100
# Number of tokens the draft model proposes per speculative step.
GAMMA = 10

# Sampling is stochastic; fix the seed so a fresh Run All yields the same text every time.
torch.manual_seed(0)

# Always restart from the tokenized prompts so re-running this cell is idempotent
# (previously a second run saw the already-extended pre_ids and did nothing).
pre_ids = input_dict["input_ids"]
n_iters = 0
while pre_ids.shape[1] < MAX_TOTAL_LEN:
    n_iters += 1
    pre_ids = speculative_step(m_p, m_q, pre_ids, tokenizer.pad_token_id, gamma=GAMMA, temperature=1)

In [130]:
# Decode every row back to text. Special tokens are not stripped, so rows shorter
# than the widest one show their trailing <|pad|> tokens in the output.
tokenizer.batch_decode(pre_ids)

['once upon a time, there was a beautiful chest. In the chest there were two special things that made them happy. One special thingful was a little lion cub. It was given a big happy roar.\n\nThe other cub was called Leo. Leo always liked to be stuck, especially in the dark frame. One day, Leo wanted to open the frame, but momma gave him the frame, and made an even bigger roar.\n\nBut Leo was feeling weak and tired. He tried',
 'once upon a time there was an ugly bench. A lived in a park and one day a patch of grassy mud appeared on the bench. Two frogs stepped out of the mud. One of them was very colourful and the other was made out of bright stones.\n\nThe jumped onto the bench and they bowed to each other. The seventh frog bowed back and then swam away. The bench smiled at the birds who sang in<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|>',
 "once upon a time there was a lamp. Every day it would soak in the sun's rays. \n\nOne day the la