In [1]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import LLaMAConfig, LLaMA

In [3]:
# Seed torch's global RNG before building the model so that the randomly
# initialised weights -- and therefore every token sequence printed in the
# cells below -- are the same on each fresh kernel run.
torch.manual_seed(0)

# Small model configuration for the demo. The arguments are positional;
# presumably (vocab size, embedding dim, norm eps, attention heads,
# kv heads, layers) -- TODO confirm against the LLaMAConfig field order.
modelc = LLaMAConfig(128, 512, 1e-6, 16, 0, 8)
model = LLaMA(modelc)

In [11]:
from fms.utils.generation import generate

# Reference ("oracle") continuation produced by the library's own generate()
# with sampling disabled; the speculative decoder is compared against it.
inp = torch.arange(8).unsqueeze(0)  # one prompt holding token ids 0..7, shape (1, 8)

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
# The two positional 20s are presumably max_seq_len and max_new_tokens
# (the output below has 20 tokens after the prompt) -- TODO confirm.
oracle = generate(model, inp, 20, 20, do_sample=False, use_cache=True)
print(time.perf_counter() - start)

print(oracle.squeeze())

0.23375678062438965
tensor([  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
         24,  35,  73,  60,  75,  28,  24,  60, 112, 118,  73,  92, 124,  96])


In [10]:
# NOTE: speculative_generate is defined in a cell that appears further down in
# this file; that cell has to be executed before this one.
# Arguments: max_seq_len=20, max_new_tokens=20 (positional, per the signature).
# perf_counter() is used because it is monotonic and intended for timing.
start = time.perf_counter()
out, steps = speculative_generate(model, inp, 20, 20)
print(time.perf_counter() - start)
print("Steps:", steps)

Speculation: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7, 52,  6, 24,  6]])
Verification: tensor([[ 52,   6,  24,   6, 108]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108]])

Speculation: tensor([[108,   0,   0,   0,   0]])
Verification: tensor([[121,  22,  22,  22, 118]]) 0
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121]])

Speculation: tensor([[121,  24,  35,  73,  60]])
Verification: tensor([[24, 35, 73, 60, 75]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75]])

Speculation: tensor([[75, 28, 24, 60,  0]])
Verification: tensor([[ 28,  24,  60, 112,  35]]) 3
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75,  28,  24,  60, 112]])

Speculation: tensor([[112, 118,  73,  92,   0]])
Verification: tensor([[118,  73,  92, 124,  69]]) 3
Updated

In [8]:
# `out` and `oracle` are 2-D, shape (1, n_tokens), so len(out) is the batch
# size (1), not the number of tokens. Iterate over the sequence axis instead,
# otherwise only position 0 would ever be compared.
n_compare = oracle.size(1)
# The speculative output must cover every oracle position to count as a match.
assert out.size(1) >= n_compare, (out.size(1), n_compare)
for i in range(n_compare):
    # On failure the assertion message is the first mismatching position.
    assert out[0, i] == oracle[0, i], i
print("Exact match!")

Exact match!


In [6]:
def blackbox_oracle(inp):
    """Simulated draft model that peeks at the precomputed `oracle` sequence.

    Copies up to four tokens that follow the current prefix in the
    module-level `oracle` tensor, then overwrites a random-length suffix
    (0 to 4 tokens) of that draft with token id 0 so that some guesses
    may be wrong.

    Args:
        inp: 2-D tensor of token ids generated so far; only its length along
            dim 1 is used.

    Returns:
        (draft, n_draft): the proposed tokens (2-D, same leading dim as
        `oracle`) and how many tokens were proposed. `n_draft` is the actual
        width of `draft`, so it is 0 -- never negative -- once the prefix is
        as long as or longer than `oracle`.
    """
    prefix_len = inp.size(1)
    out = oracle[:, prefix_len:prefix_len + 4].clone()
    # Number of trailing draft tokens to corrupt, drawn uniformly from 0..4.
    nzero = int(torch.randint(5, (1,)))
    if nzero > 0:
        out[:, -nzero:] = 0
    # Report the real draft width; computing it from the lengths instead
    # yields a negative count when the prefix already exceeds the oracle.
    return out, out.size(1)

def speculative_generate(
    model: Union[Callable, torch.nn.Module],
    input_ids: torch.LongTensor,
    max_seq_len: int = 2048,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_k: int = 10,
    num_beams: int = 1,
    smallmodel: Callable = blackbox_oracle,
):
    """Greedy generation with speculative (draft-and-verify) decoding.

    Each step asks `smallmodel` for a few draft tokens, runs `model` once over
    the pending token plus the draft, accepts the longest draft prefix that
    matches the model's own argmax predictions, and appends one further token
    taken from the model (the correction or the next token).

    Args:
        model: object exposing `forward(input_ids, past_key_value_states=...,
            use_cache=...)` that returns `(logits, kv_cache)`.
        input_ids: 1-D prompt, or 2-D (batch, seq) prompt. Verification only
            inspects row 0, so batches larger than 1 are not handled correctly.
        max_seq_len: the pending input is cropped to its last `max_seq_len`
            tokens before the draft is appended.
        max_new_tokens: exact number of tokens appended to the prompt.
        temperature, top_k: accepted for signature parity but unused, because
            decoding is always greedy here.
        num_beams: must be 1.
        smallmodel: callable taking the tokens generated so far and returning
            `(draft_tokens, n_draft)`.

    Returns:
        (result, n_steps): prompt plus generated tokens (1-D if the prompt was
        1-D, else 2-D) and the number of forward passes through `model`.

    Raises:
        NotImplementedError: if `num_beams != 1`.
        RuntimeError: if `input_ids` is not a tensor.

    Note:
        Progress is printed on every step (draft, verification, output).
    """
    do_sample = False
    use_cache = True
    batched = False
    if num_beams != 1:
        raise NotImplementedError("generate() does yet not support beam search")
    if isinstance(input_ids, torch.Tensor):
        if input_ids.dim() != 1:
            batched = True
    else:
        raise RuntimeError("generate() requires a tensor of token ids as the prefix")

    if not batched:
        input_ids = input_ids.unsqueeze(0)

    result = input_ids
    next_input = input_ids
    kwargs = dict()
    kwargs["past_key_value_states"] = None
    kwargs["use_cache"] = use_cache

    n_gen = 0
    n_steps = 0
    while n_gen < max_new_tokens:
        n_steps += 1
        input_ids = next_input[:, -max_seq_len:]
        adds, n_adds = smallmodel(result)

        # Every step emits n_correct + 1 tokens, so leave room for the one
        # model-produced token and never speculate past max_new_tokens.
        # Also guard against a draft model reporting a negative count or a
        # count larger than the draft it actually returned.
        budget = max_new_tokens - n_gen - 1
        n_adds = max(0, min(n_adds, adds.size(-1), budget))
        adds = adds[:, :n_adds]

        input_ids = torch.cat([input_ids, adds], dim=-1)
        print("Speculation:", input_ids)
        # Only argmax token ids and detached cache copies leave this loop,
        # so there is no need to record an autograd graph.
        with torch.no_grad():
            output = model.forward(input_ids, **kwargs)
        if use_cache:
            logits, past_key_value_states = output
        else:
            logits = output
        # Keep the prediction made after the last confirmed token plus the
        # prediction made after each draft token.
        logits = logits[:, -n_adds-1:, :]

        if do_sample:
            # Sampling (temperature / top-k) is not implemented for the
            # speculative path; fail loudly rather than leave next_vals unset.
            raise NotImplementedError(
                "speculative_generate() only supports greedy decoding"
            )
        next_vals = torch.argmax(logits, dim=-1)

        # Check correctness of smallmodel predictions: draft token i is
        # accepted while it equals the model's prediction for that position.
        n_correct = 0
        while n_correct < n_adds and next_vals[0, n_correct] == adds[0, n_correct]:
            n_correct += 1
        print("Verification:", next_vals, n_correct)

        # Toss any wrong smallmodel outputs; the extra (+1) token is the
        # model's own prediction following the accepted prefix.
        next_vals = next_vals[:, :n_correct+1]
        n_gen += n_correct+1

        n_wrong = n_adds - n_correct
        if use_cache:
            # kv updates are required for torch.compile with
            # mode='reduce-overhead'.
            # Cache entries written for rejected draft tokens are trimmed off
            # the end of dim 2 (assumed to be the sequence axis of each cache
            # tensor -- TODO confirm against the model's cache layout).
            kwargs["past_key_value_states"] = [
                [
                    (cached[:, :, :-n_wrong] if n_wrong > 0 else cached)
                    .clone(memory_format=torch.contiguous_format)
                    .detach()
                    for cached in layer_cache
                ]
                for layer_cache in past_key_value_states
            ]

        result = torch.cat((result, next_vals), dim=-1)
        print("Updated output:", result)
        print()

        if use_cache:
            # With a cache, only the newest token still needs to be fed in.
            next_input = next_vals[:, -1].unsqueeze(-1)
        else:
            next_input = result

    if not batched:
        result = result[0]
    return result, n_steps