In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import LLaMAConfig, LLaMA

In [3]:
# Build the LLaMA architecture; its weights are loaded from llama_7b_ckp.pth below.
# NOTE(review): LLaMAConfig is called positionally. Presumably the arguments are
# (vocab size, embedding dim, norm eps, n heads, pad id, n layers) — confirm
# against fms.models.llama.LLaMAConfig and switch to keyword arguments.
modelc = LLaMAConfig(32000, 4096, 1e-6, 32, 0, 32)
model = LLaMA(modelc)

In [4]:
# Load the checkpoint's model weights. map_location="cpu" keeps torch.load from
# restoring tensors onto the device they were saved from (possibly a GPU that is
# absent here); load_state_dict copies them into the model's parameters anyway.
# torch.load unpickles the file, so only use trusted checkpoints.
d = torch.load("../../../llama_7b_ckp.pth", map_location="cpu")['model_state']

In [5]:
# Rename checkpoint keys to match this model's module names: every key that
# contains "dec_process" gets its first dotted component replaced by "layers".
# The renamed entries are popped and re-inserted, so they move to the end of d.
for old_key in [k for k in d if "dec_process" in k]:
    new_key = ".".join(["layers"] + old_key.split(".")[1:])
    d[new_key] = d.pop(old_key)

In [6]:
# strict=False is required because the checkpoint carries extra entries
# (e.g. 'rope.freqs'); still fail loudly if any model weight was left unloaded.
load_result = model.load_state_dict(d, strict=False)
assert not load_result.missing_keys, f"checkpoint is missing weights: {load_result.missing_keys}"
load_result

_IncompatibleKeys(missing_keys=[], unexpected_keys=['rope.freqs'])

In [7]:
from fms.utils.generation import generate
from transformers import AutoTokenizer
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Inverse vocabulary (token id -> token string), used for readable debug output.
vinv = {token_id: token for token, token_id in t.vocab.items()}

In [8]:
# Tokenize the prompt, show its individual tokens, and keep the ids as a tensor.
prompt_token_ids = t("Hello! How are you today?")["input_ids"]
print([vinv[tok_id] for tok_id in prompt_token_ids])
inp = torch.IntTensor(prompt_token_ids)

['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?']


In [9]:
# Reference output from plain greedy decoding with the KV cache; the greedy
# speculative decoder below is expected to reproduce these tokens. `oracle` is
# also read as a notebook global by speculative_generate's Acc@5 diagnostic.
# NOTE(review): the two positional 20s are presumably max_seq_len and
# max_new_tokens — confirm against fms.utils.generation.generate.
oracle = generate(model, inp, 20, 20, do_sample=False, use_cache=True)
t.decode(oracle.tolist())

'<s> Hello! How are you today? I hope you are doing well. I am doing well. I am happy to be here with you'

In [10]:
# Trigram draft table consumed by speculative_generate. Judging by how it is
# used there, it maps a (token_id, token_id) pair to a mapping
# {next_token_id: score} — presumably probabilities; NOTE(review): confirm
# against the script that produced cc123_trigram.pth.
# torch.load unpickles the file, so only load trusted files.
trigram = torch.load("../../../cc123_trigram.pth")

In [15]:
# NOTE: speculative_generate is defined in the In [14] cell, which appears
# below this one in the saved notebook; run that cell first.
out, steps = speculative_generate(model, inp, 20, 20)
print("\nSteps:", steps)

Topk@1: ['<0x0A>', '▁What', '▁I', '▁How', '▁Call']
Topk@2: ['The', 'A', 'We', 'I', 'If']
Topk@3: ['▁', '▁first', '▁following', '▁best', '▁new']
Speculation: ['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?', '<0x0A>', 'The', '▁']
Verification: ['▁I', 'I', '▁weather', '2'] 0
Acc@5: 1
Updated output: ['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?', '▁I']

Topk@1: ['’', "'", '▁have', '▁am', '▁don']
Topk@2: ['m', 've', 'll', 'd', 'M']
Topk@3: ['▁not', '▁a', '▁going', '▁sure', '▁so']
Speculation: ['▁I', '’', 'm', '▁not']
Verification: ['▁hope', 'm', '▁doing', '▁sure'] 0
Acc@5: 0
Updated output: ['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?', '▁I', '▁hope']

Topk@1: ['▁you', '▁to', '▁that', '▁this', '▁it']
Topk@2: ['▁enjoy', '▁will', '’', '▁are', '▁have']
Topk@3: ['▁the', '▁your', '▁this', '▁a', '▁it']
Speculation: ['▁hope', '▁you', '▁enjoy', '▁the']
Verification: ['▁you', '▁are', '▁this', '▁video'] 1
Acc@5: 2
Updated output: ['<s>', '▁Hello', '!', '▁How

In [14]:
def decode_obo(x, vocab_inv=None):
    """Decode each token id in a tensor into its vocabulary string.

    Args:
        x: tensor of token ids of any shape, including 0-d or single-element
            tensors. It is flattened, so a 2-D input with several rows yields
            the rows' tokens concatenated in row-major order.
        vocab_inv: mapping token id -> token string; defaults to the
            notebook-level ``vinv``.

    Returns:
        List of token strings, one per element of ``x``.
    """
    if vocab_inv is None:
        vocab_inv = vinv
    # reshape(-1), unlike squeeze(), always yields a 1-D tensor, so .tolist()
    # returns a list even when x holds a single token.
    return [vocab_inv[token_id] for token_id in x.reshape(-1).tolist()]
    
def _trigram_draft(context, table, n_adds, vocab_size):
    """Greedily roll a token-trigram table forward to draft ``n_adds`` tokens.

    Args:
        context: 2-D tensor of token ids generated so far; only row 0 is read.
        table: mapping from a ``(token_id, token_id)`` pair to a mapping
            ``{next_token_id: score}``.
        n_adds: number of draft positions to fill.
        vocab_size: size of the last dimension of the returned tensor.

    Returns:
        A float CPU tensor of shape ``(1, 1, n_adds, vocab_size)`` with the
        table scores of each drafted position. Rows from the first pair missing
        in ``table`` onward (or all rows, if ``context`` has fewer than two
        tokens) stay all-zero.
    """
    scores = torch.zeros(1, 1, n_adds, vocab_size)
    if context.size(1) < 2:
        return scores
    prev2, prev1 = context[0, -2:].tolist()
    for i in range(n_adds):
        pair = (prev2, prev1)
        if pair not in table:
            break
        for token_id, score in table[pair].items():
            scores[0, 0, i, token_id] = score
        # Condition the next position on the token actually drafted here
        # (the argmax of this row), so the draft chain stays self-consistent.
        prev2, prev1 = prev1, int(scores[0, 0, i].argmax())
    return scores


def _trim_kv_cache(past_key_value_states, n_drop):
    """Return contiguous, detached copies of every cached tensor with the last
    ``n_drop`` entries of dim 2 removed (dim 2 is assumed to be the sequence
    axis — TODO confirm against the model's cache layout).

    The copies keep tensor shapes and storage fresh, which the original code
    notes is required for torch.compile with mode='reduce-overhead'.
    """
    trimmed = []
    for layer in past_key_value_states:
        layer_out = []
        for cached in layer:
            if n_drop > 0:
                cached = cached[:, :, :-n_drop]
            layer_out.append(cached.clone(memory_format=torch.contiguous_format).detach())
        trimmed.append(layer_out)
    return trimmed


def speculative_generate(
    model: Union[Callable, torch.nn.Module],
    input_ids: torch.LongTensor,
    max_seq_len: int = 2048,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_k: int = 10,
    num_beams: int = 1,
    n_adds: int = 3,
    vocab_size: int = 32000,
    ngram_table=None,
    reference=None,
    verbose: bool = True,
):
    """Greedy decoding with trigram-drafted speculative tokens.

    Each step does the following:
      1. Drafts ``n_adds`` tokens from a trigram table.
      2. Runs the model once over [last accepted token, *drafts].
      3. Accepts the longest prefix of drafts that equals the model's own
         argmax predictions.
      4. Appends one extra model-predicted token.

    Every accepted token is therefore the model's greedy prediction. Drafting
    only changes how many model calls are needed.

    Args:
        model: callable returning ``(logits, past_key_value_states, embeds)``
            when called with ``include_embeds=True, use_cache=True``.
        input_ids: 1-D prompt of token ids, or 2-D with a batch size of 1.
        max_seq_len: max context fed to the model when the KV cache is off.
        max_new_tokens: number of tokens to generate. The result is trimmed
            so it never contains more than this many new tokens.
        temperature, top_k: accepted for signature compatibility with
            ``generate()``; unused because decoding is greedy.
        num_beams: must be 1.
        n_adds: number of drafted tokens per step.
        vocab_size: width of the draft score tensor.
        ngram_table: trigram table ``{(tok, tok): {next_tok: score}}``.
            Defaults to the notebook-level ``trigram``.
        reference: optional 1-D tensor of reference token ids (e.g. the greedy
            ``oracle``) for the ``Acc@5`` diagnostic. Defaults to a
            notebook-level ``oracle`` if one exists.
        verbose: print per-step diagnostics.

    Returns:
        ``(result, n_steps)``. ``result`` is the prompt followed by the
        generated tokens, 1-D if the prompt was 1-D. ``n_steps`` is the number
        of model calls made.

    Raises:
        NotImplementedError: if ``num_beams != 1`` or the batch size is not 1.
        RuntimeError: if ``input_ids`` is not a tensor.
    """
    use_cache = True
    if num_beams != 1:
        raise NotImplementedError("generate() does yet not support beam search")
    if not isinstance(input_ids, torch.Tensor):
        raise RuntimeError("generate() requires a tensor of token ids as the prefix")
    batched = input_ids.dim() != 1
    if not batched:
        input_ids = input_ids.unsqueeze(0)
    if input_ids.size(0) != 1:
        # Drafting, acceptance and the diagnostics below only read row 0.
        raise NotImplementedError("speculative_generate() only supports a batch size of 1")

    table = trigram if ngram_table is None else ngram_table
    if reference is None:
        reference = globals().get("oracle")
    n_top = min(5, vocab_size)

    prompt_len = input_ids.size(1)
    result = input_ids
    next_input = input_ids
    kwargs = {"past_key_value_states": None, "use_cache": use_cache}
    # Hidden states of the accepted positions. The trigram drafter does not use
    # them; they are kept for a learned draft head that would read embeds[:, -1].
    # Such a head would also need a prefix forward pass over
    # input_ids[:, :-1] before the loop.
    embeds = None
    n_gen = 0
    n_steps = 0
    while n_gen < max_new_tokens:
        n_steps += 1
        input_ids = next_input[:, -max_seq_len:]

        draft_scores = _trigram_draft(result, table, n_adds, vocab_size)
        topk = draft_scores.topk(n_top, dim=3)[1]
        # (1, n_adds) draft ids, moved to wherever the model inputs live.
        drafts = draft_scores.argmax(3).squeeze(1).to(input_ids.device)
        if verbose:
            for i in range(n_adds):
                print(f"Topk@{i+1}:", decode_obo(topk[0, 0, i]))
        input_ids = torch.cat([input_ids, drafts], dim=-1)
        if verbose:
            print("Speculation:", decode_obo(input_ids))

        output = model(input_ids, include_embeds=True, **kwargs)
        if use_cache:
            logits, past_key_value_states, embeds = output
        else:
            logits, embeds = output
        # The last n_adds + 1 positions cover [last accepted token, *drafts].
        # Position j's argmax is the model's choice for drafted slot j, and the
        # final position yields one extra "bonus" token.
        next_vals = torch.argmax(logits[:, -n_adds - 1:, :], dim=-1)
        embeds = embeds[:, -n_adds - 1:]

        # Check correctness of the drafted tokens.
        n_correct = 0
        while n_correct < n_adds and next_vals[0, n_correct] == input_ids[0, n_correct - n_adds]:
            n_correct += 1
        if verbose:
            print("Verification:", decode_obo(next_vals), n_correct)
            if reference is not None:
                # How many leading drafted slots have the reference token
                # among their top-n_top trigram candidates.
                base = result.size(1)
                k_correct = 0
                while (
                    base + k_correct < len(reference)
                    and k_correct < n_adds
                    and int(reference[base + k_correct]) in topk[0, 0, k_correct].tolist()
                ):
                    k_correct += 1
                print("Acc@5:", k_correct)

        # Keep the matching drafts plus the bonus token; drop the rest.
        next_vals = next_vals[:, :n_correct + 1]
        embeds = embeds[:, :n_correct + 1]
        n_gen += n_correct + 1

        if use_cache:
            # Rejected drafts were written into the cache; remove them.
            kwargs["past_key_value_states"] = _trim_kv_cache(
                past_key_value_states, n_adds - n_correct
            )

        result = torch.cat((result, next_vals), dim=-1)
        if verbose:
            print("Updated output:", decode_obo(result))
            print()

        # With the cache, only the bonus token (not yet cached) is fed next.
        next_input = next_vals[:, -1].unsqueeze(-1) if use_cache else result

    # One step can accept up to n_adds + 1 tokens, so trim any overshoot.
    result = result[:, :prompt_len + max_new_tokens]
    if not batched:
        result = result[0]
    return result, n_steps