In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import LLaMAConfig, LLaMA

In [3]:
modelc = LLaMAConfig(32000, 4096, 1e-6, 32, 0, 32)
model = LLaMA(modelc)

In [4]:
d = torch.load("../../../llama_7b_ckp.pth")['model_state']

In [5]:
keylist = list(d.keys())
for key in keylist:
    if "dec_process" in key:
        value = d.pop(key)
        fields = key.split(".")
        fields[0] = "layers"
        d[".".join(fields)] = value

In [6]:
model.load_state_dict(d, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['rope.freqs'])

In [7]:
from fms.utils.generation import generate
from transformers import AutoTokenizer
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
vinv = {v:k for k,v in t.vocab.items()}

In [53]:
inp = t("Hello! How are you today?")["input_ids"]
print([vinv[x] for x in inp])
inp = torch.IntTensor(inp)

['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?']


In [54]:
oracle = generate(model, inp, 20, 20, do_sample=False, use_cache=True)
t.decode(oracle.tolist())

'<s> Hello! How are you today? I hope you are doing well. I am doing well. I am happy to be here with you'

In [12]:
test = Speculator(n_heads=3)
test.load_state_dict(torch.load("../../../specu_greedy.pth", map_location="cpu")["model_state"])

<All keys matched successfully>

In [13]:
sum(p.numel() for p in test.parameters())

844107776

In [55]:
out, steps = speculative_generate(model, inp, test, 20, 20)
print()
print("Steps:", steps)

Topk@1: ['▁I', '<0x0A>', '!', '▁We', '▁my']
Topk@2: ['▁hope', '▁you', 'I', "'", '’']
Topk@3: ['▁you', 'm', '▁I', '▁is', '▁a']
Speculation: ['?', '▁I', '▁hope', '▁you']
Verification: ['▁I', '▁hope', '▁you', '▁are'] 3
Acc@5: 3
Updated output: ['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?', '▁I', '▁hope', '▁you', '▁are']

Topk@1: ['re', '▁having', '▁doing', '▁a', '▁had']
Topk@2: ['▁having', '▁well', '▁great', '▁doing', '.']
Topk@3: ['▁great', '.', '▁well', '▁a', 'ying']
Speculation: ['▁are', 're', '▁having', '▁great']
Verification: ['▁doing', 'ally', '▁a', '▁day'] 0
Acc@5: 3
Updated output: ['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?', '▁I', '▁hope', '▁you', '▁are', '▁doing']

Topk@1: ['▁well', '.', '▁a', '▁good', 'ying']
Topk@2: ['▁great', '.', '▁and', '▁I', '▁good']
Topk@3: ['▁I', '▁day', '.', '<0x0A>', '▁great']
Speculation: ['▁doing', '▁well', '▁great', '▁I']
Verification: ['▁well', '.', '.', '▁am'] 1
Acc@5: 3
Updated output: ['<s>', '▁Hello', '!', '▁How'

In [45]:
from fms.modules.layernorm import LayerNormParameterized

class Speculator(nn.Module):
    def __init__(self, emb_dim=4096, vocab_size=32000, n_heads=4):
        super().__init__()
        self.nheads = n_heads
        self.emb_dim = emb_dim
        self.vsize = vocab_size
        self.w_in = nn.Parameter(torch.empty(emb_dim, int((emb_dim * 2.6875 * 2) // 256) * 256 * 2))  # d 2z
        self.a = nn.GELU()
        self.w_out = nn.Parameter(torch.empty(int((emb_dim * 2.6875 * 2) // 256) * 256, emb_dim * n_heads))  # z hd
        self.ln = LayerNormParameterized(emb_dim, elementwise_shift=False, elementwise_scale=True)
        self.head = nn.Parameter(torch.empty(n_heads, emb_dim, vocab_size))  # h d v
        self.reset_params()

    def reset_params(self):
        nn.init.trunc_normal_(self.w_in, 0, (1 / 2.6875) ** (1 / 6) / self.emb_dim**0.5)
        nn.init.trunc_normal_(self.w_out, 0, (1 / 2.6875) ** (1 / 6) / self.emb_dim**0.5)
        nn.init.trunc_normal_(self.head, 0, 1 / self.emb_dim**0.5)

    def forward(self, x):
        # x: b n d
        z, g = x.matmul(self.w_in).chunk(2, dim=2)
        z = z * self.a(g)
        z = z.matmul(self.w_out).view(x.size(0), x.size(1), self.nheads, self.emb_dim)  # b n h d
        z = z + x.unsqueeze(2)
        z = self.ln(z)
        z = torch.einsum("bnhd,hdv->bnhv", z, self.head)
        return z # b n h v
    
def decode_obo(x):
    return [vinv[z] for z in x.squeeze().tolist()]
    
def speculative_generate(
    model: Union[Callable, torch.nn.Module],
    input_ids: torch.LongTensor,
    smallmodel: torch.nn.Module,
    max_seq_len: int = 2048,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_k: int = 10,
    num_beams: int = 1,
):
    do_sample = False
    batched = False
    if num_beams != 1:
        raise NotImplementedError("generate() does yet not support beam search")
    if type(input_ids) == torch.Tensor:
        if input_ids.dim() != 1:
            batched = True
    else:
        raise RuntimeError("generate() requires a tensor of token ids as the prefix")

    if not batched:
        input_ids = input_ids.unsqueeze(0)

    result = input_ids
    next_input = input_ids
    kwargs = dict()
    kwargs["past_key_value_states"] = None
    kwargs["use_cache"] = True

    output = model(input_ids[:,:-1], include_embeds=True, **kwargs)
    _, past_key_value_states, embeds = output
    kwargs["past_key_value_states"] = past_key_value_states
    next_input = next_input[:,-1:]
    
    n_gen = 0
    n_steps = 0
    n_kv_s = past_key_value_states
    while n_gen < max_new_tokens:
        n_steps += 1
        input_ids = next_input[:, -max_seq_len:]
        
        adds = smallmodel(embeds[:,-1].unsqueeze(1))
        n_adds = smallmodel.nheads
        
        topk = adds.topk(5, dim=3)[1]
        for i in range(n_adds):
            print(f"Topk@{i+1}:", decode_obo(topk[0,0,i]))
        adds = adds.argmax(3).squeeze(1) # b h
        input_ids = torch.cat([input_ids, adds], dim=-1)
        mask = torch.ones(input_ids.size(1),input_ids.size(1)+n_kv_s[0][0].size(2))
        mask = mask.tril(diagonal=mask.size(1)-mask.size(0))
        mask = mask.unsqueeze(0).unsqueeze(0).log()
        print("Speculation:", decode_obo(input_ids))
        output = model.forward(input_ids, include_embeds=True, mask=mask, **kwargs)
        logits, past_key_value_states, embeds = output
        logits = logits[:, -n_adds-1:, :]

        if do_sample:
            # get logits from last value in sequence and scale
#             logits = logits / temperature
#             if top_k:
#                 v, _ = torch.topk(logits, top_k)
#                 logits[logits < v[:, [-1]]] = -float("inf")

#             probs = F.softmax(logits, dim=-1)
#             next_val = torch.multinomial(probs, num_samples=1)
            assert False
        else:
            next_vals = torch.argmax(logits, dim=-1)
        
        # Check correctness of smallmodel predictions
        n_correct = 0
        while n_correct < n_adds and next_vals[0,n_correct] == input_ids[0,-n_adds+n_correct]:
            n_correct += 1
        print("Verification:", decode_obo(next_vals), n_correct)
        
        k_correct = 0
        while (
            result.size(1)+k_correct < len(oracle) and 
            k_correct < n_adds and
            oracle[result.size(1)+k_correct] in topk[0,0,k_correct]
        ):
            k_correct += 1
        print("Acc@5:", k_correct)
        
        # Toss any wrong smallmodel outputs
        next_vals = next_vals[:,:n_correct+1]
        n_gen += n_correct+1
        embeds = embeds[:,:n_correct+1]
            
        n_wrong = n_adds - n_correct
        # kv updates are required for torch.compile with
        # mode='reduce-overhead'
        n_kv_s = []
        for layer_idx in range(len(past_key_value_states)):
            n_kv_s.append([])
            for tensor_idx in range(len(past_key_value_states[layer_idx])):
                base = past_key_value_states[layer_idx][tensor_idx]
                if n_wrong > 0:
                    base = base[:,:,:-n_wrong]
                n_kv_s[layer_idx].append(
                    base.clone(memory_format=torch.contiguous_format).detach()
                )
                # torch._dynamo.mark_dynamic(n_kv_s[layer_idx][tensor_idx], 2)
        kwargs["past_key_value_states"] = n_kv_s

        result = torch.cat((result, next_vals), dim=-1)
        print("Updated output:", decode_obo(result))
        print()

        next_input = next_vals[:,-1].unsqueeze(-1)

    if not batched:
        result = result[0]
    return result, n_steps

In [10]:
start = time.time()
out, steps = speculative_generate(model, inp, 20, 20)
print(time.time()-start)
print("Steps:", steps)

Speculation: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7, 52,  6, 24,  6]])
Verification: tensor([[ 52,   6,  24,   6, 108]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108]])

Speculation: tensor([[108,   0,   0,   0,   0]])
Verification: tensor([[121,  22,  22,  22, 118]]) 0
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121]])

Speculation: tensor([[121,  24,  35,  73,  60]])
Verification: tensor([[24, 35, 73, 60, 75]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75]])

Speculation: tensor([[75, 28, 24, 60,  0]])
Verification: tensor([[ 28,  24,  60, 112,  35]]) 3
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75,  28,  24,  60, 112]])

Speculation: tensor([[112, 118,  73,  92,   0]])
Verification: tensor([[118,  73,  92, 124,  69]]) 3
Updated

In [8]:
for i in range(len(out)):
    assert out[0,i]==oracle[0,i], i
print("Exact match!")

Exact match!


In [6]:
def blackbox_oracle(inp):
    out = oracle[:,inp.size(1):inp.size(1)+4].clone()
    nzero = torch.randint(5,(1,))
    if nzero > 0:
        out[:,-nzero:] = 0
    return out, min(oracle.size(1), inp.size(1)+4) - inp.size(1)

