In [1]:
import os
import torch # was 20230913, now 20231013
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import convert_hf_llama
from transformers import LlamaForCausalLM

In [3]:
from fms.models import get_model

In [5]:
# Load the 7B LLaMA base model on CPU from a local checkpoint directory.
LLAMA_7B_PATH = "../../../llama_weights/7B-F/"
model = get_model("llama", "7b", model_path=LLAMA_7B_PATH, device_type="cpu", source="meta")

In [7]:
from fms.utils.generation import generate
from transformers import AutoTokenizer
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Reverse vocabulary (token id -> token string) used to print token-level debug output.
vinv = {token_id: token for token, token_id in t.vocab.items()}

In [14]:
# Tokenize the prompt, show its token pieces, and keep the ids as a tensor.
prompt_token_ids = t("Yesterday is history, tomorrow is a mystery, today")["input_ids"]
print([vinv[token_id] for token_id in prompt_token_ids])
inp = torch.IntTensor(prompt_token_ids)

['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today']


In [15]:
# Greedy (do_sample=False) reference generation, used below as ground truth for
# the speculative runs. NOTE(review): the positional 30, 30 presumably mean
# max_seq_len / max_new_tokens, mirroring speculative_generate's signature —
# confirm against fms.utils.generation.generate.
oracle = generate(model, inp, 30, 30, do_sample=False, use_cache=True)
t.decode(oracle.tolist())

"<s> Yesterday is history, tomorrow is a mystery, today is a gift.\nThat's why they call it the present.\n\nThis quote is a reminder to appreciate the present moment and not"

In [11]:
# Build a 3-head speculator and load its trained weights (mapped onto CPU).
# NOTE(review): `Speculator` is defined in a cell further down this notebook, so
# a fresh top-to-bottom run needs that cell executed before this one.
# torch.load unpickles the file — only load trusted checkpoints.
test = Speculator(n_heads=3)
test.load_state_dict(torch.load("../../../specu_recur_n2.pth", map_location="cpu")["model_state"])

<All keys matched successfully>

In [12]:
sum(map(torch.numel, test.parameters()))  # total parameter count of the speculator

887119872

In [16]:
# Speculative decoding with the loaded speculator; 25 candidate paths per step,
# branching 10 x 3 x 2 across the three heads.
out, steps = speculative_generate(
    model,
    inp,
    test,
    max_seq_len=30,
    max_new_tokens=30,
    top_k=25,
    threshes=[10, 3, 2],
)
print()
print("Steps:", steps)

Verification: ['▁today', '▁is', '▁a', '▁day'] ['▁is', '▁a', '▁gift', '▁of'] 2
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift']

Verification: ['▁gift', '.', '<0x0A>', 'This'] ['.', '<0x0A>', 'That', '▁quote'] 2
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That']

Verification: ['That', "'", 's', '▁why'] ["'", 's', '▁why', '▁they'] 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That', "'", 's', '▁why', '▁they']

Verification: ['▁they', '▁call', '▁it', '▁the'] ['▁call', '▁it', '▁the', '▁present'] 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That', "'

In [10]:
from fms.modules.layernorm import LayerNormParameterized


class Speculator(nn.Module):
    """Multi-head recurrent speculator that drafts several future tokens.

    Head ``i`` concatenates the running state with the embedding of the token
    chosen at the previous step, projects back to ``emb_dim``, applies a
    layernorm and GELU, and scores the vocabulary. The state produced by
    head ``i`` becomes the input state of head ``i + 1``.

    Args:
        emb_dim: width of the incoming state and of every head's hidden state.
        vocab_size: size of each head's embedding table and output vocabulary.
        n_heads: number of chained prediction heads.
    """

    def __init__(self, emb_dim=4096, vocab_size=32000, n_heads=4):
        super().__init__()
        self.nheads = n_heads
        self.emb_dim = emb_dim
        self.vsize = vocab_size
        self.emb = nn.ModuleList([nn.Embedding(vocab_size, emb_dim) for _ in range(n_heads)])
        self.proj = nn.ModuleList([nn.Linear(emb_dim * 2, emb_dim, bias=False) for _ in range(n_heads)])
        self.head = nn.ModuleList([nn.Linear(emb_dim, vocab_size, bias=False) for _ in range(n_heads)])
        self.ln = nn.ModuleList(
            [LayerNormParameterized(emb_dim, elementwise_shift=True, elementwise_scale=True) for _ in range(n_heads)]
        )
        self.a = nn.GELU()
        self.reset_params()

    def reset_params(self):
        """Truncated-normal init (std 1/sqrt(emb_dim)) for embeddings/linears; unit scale, zero shift for layernorms."""
        for m in self.modules():
            if isinstance(m, (nn.Embedding, nn.Linear)):
                nn.init.trunc_normal_(m.weight, 0, 1 / self.emb_dim**0.5)
            elif isinstance(m, LayerNormParameterized):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def generate_tree(self, state, ind, topk=(5, 4, 3), k=25):
        """Expand the heads into a candidate tree and return the k most likely full paths.

        Args:
            state: hidden state to condition on, shape ``(b, 1, emb_dim)``.
            ind: id of the most recent known token, shape ``(b, 1)``.
            topk: per-head branching factor; needs exactly one entry per head.
            k: number of complete paths to return; must not exceed the
                product of ``topk``.

        Returns:
            Long tensor of shape ``(b, k, n_heads)`` of draft token ids, ordered
            by descending summed log-probability.
        """
        assert len(topk) == self.nheads, "topk must give one branching factor per head"
        b = state.size(0)
        # Allocate the running buffers on the input's device: CPU-only buffers
        # make torch.cat / add fail as soon as `state` lives on a GPU.
        out = torch.empty(b, 1, 0, dtype=torch.long, device=state.device)  # b k h
        log_probs = torch.zeros(b, 1, device=state.device)  # b k
        for i in range(self.nheads):
            z = self.emb[i](ind)  # b k d
            z = torch.cat([state, z], dim=2)  # b k 2d
            state = self.a(self.ln[i](self.proj[i](z)))  # b k d
            probs = F.log_softmax(self.head[i](state), dim=2)  # b k v
            probs, preds = probs.topk(topk[i], dim=2)  # b k k'
            # Every existing path branches into topk[i] children.
            out = out.unsqueeze(2).expand(-1, -1, topk[i], -1)  # b k k' h
            out = torch.cat([out, preds.unsqueeze(3)], dim=3)  # b k k' h+1

            # Flatten (path, branch) into a single path axis for the next head.
            out = out.reshape(b, -1, i + 1)  # b kk' h+1
            state = state.unsqueeze(2).expand(-1, -1, topk[i], -1).reshape(b, -1, state.size(-1))  # b kk' d
            ind = preds.reshape(b, -1)  # b kk'
            log_probs = (log_probs.unsqueeze(2) + probs).reshape(b, -1)  # b kk'

        best_guesses = log_probs.topk(k, dim=1)[1]  # b k
        return out.gather(1, best_guesses.unsqueeze(2).expand(-1, -1, self.nheads))  # b k h

    def forward(self, state, inds):
        """Teacher-forced pass over a sequence, producing one logit tensor per head.

        Args:
            state: hidden states, shape ``(b, n, emb_dim)``.
            inds: token ids, shape ``(b, n + n_heads - 1)``; head ``i`` embeds
                the window ``inds[:, i : i + n]``.

        Returns:
            Logits of shape ``(n_heads, b, n, vocab_size)``.
        """
        out = []
        for i in range(self.nheads):
            h_inds = inds[:, i : i + state.size(1)]
            z = self.emb[i](h_inds)  # b n d
            z = torch.cat([state, z], dim=2)  # b n 2d
            state = self.a(self.ln[i](self.proj[i](z)))  # b n d
            out.append(self.head[i](state))  # b n v
        return torch.stack(out, dim=0)  # h b n v

    
def decode_obo(x, vocab=None):
    """Map a tensor of token ids to their token strings, one by one (no detokenization).

    Args:
        x: tensor of token ids of any shape; it is flattened in row-major order.
        vocab: optional id -> token mapping; defaults to the notebook's global ``vinv``.

    Returns:
        List of token strings, one per element of ``x``.
    """
    if vocab is None:
        vocab = vinv
    # reshape(-1) instead of squeeze(): squeeze() turns a single-token tensor into
    # a 0-d tensor whose tolist() is a bare int, which cannot be iterated.
    return [vocab[token_id] for token_id in x.reshape(-1).tolist()]


def _speculation_mask(n_query, n_cached, device=None):
    """Additive causal mask: each of n_query new positions sees the whole cache plus earlier new positions."""
    keep = torch.ones(n_query, n_query + n_cached, device=device)
    keep = keep.tril(diagonal=n_cached)
    # log(1) = 0 keeps a position; log(0) = -inf masks it out.
    return keep.unsqueeze(0).unsqueeze(0).log()


def _merge_kv_cache(past_key_value_states, best_guess, n_wrong):
    """Append the accepted candidate's new keys/values (minus rejected drafts) to the prefix cache.

    Per layer, indices 0-1 are read as the shared prefix keys/values and
    indices 2-3 as the per-candidate new keys/values. NOTE(review): this layout
    comes from the model's masked include_embeds forward; confirm against the
    model implementation.
    """
    merged = []
    for layer in past_key_value_states:
        layer_kv = []
        for tensor_idx in range(2):
            base = layer[tensor_idx]
            new = layer[tensor_idx + 2][best_guess].unsqueeze(0)
            if n_wrong > 0:
                # Drop cache entries belonging to rejected draft tokens.
                new = new[:, :, :-n_wrong]
            merged_kv = torch.cat([base, new], dim=2)
            # kv updates are required for torch.compile with mode='reduce-overhead'
            layer_kv.append(merged_kv.clone(memory_format=torch.contiguous_format).detach())
        merged.append(layer_kv)
    return merged


def speculative_generate(
    model: Union[Callable, torch.nn.Module],
    input_ids: torch.LongTensor,
    smallmodel: torch.nn.Module,
    max_seq_len: int = 2048,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_k: int = 24,
    num_beams: int = 1,
    threshes=(4, 3, 2),
    verbose: bool = True,
):
    """Greedy speculative decoding with a tree-drafting speculator.

    Each step works in three stages:
      1. ``smallmodel.generate_tree`` drafts ``top_k`` candidate continuations
         of ``smallmodel.nheads`` tokens.
      2. The base model scores all candidates in one masked forward pass
         against the shared KV cache.
      3. The candidate with the longest prefix matching the model's greedy
         choices is accepted, plus the model's own next token after that prefix.

    Args:
        model: base model; called with ``include_embeds=True`` and returns
            ``(logits, past_key_value_states, embeds)``.
        input_ids: prompt token ids, shape ``(n,)`` or ``(1, n)`` with ``n >= 2``.
        smallmodel: speculator exposing ``nheads`` and ``generate_tree``.
        max_seq_len: slice applied to the per-step input (which is a single
            token, so this currently has no effect).
        max_new_tokens: generation budget; the result is trimmed to at most
            this many new tokens.
        temperature: accepted for interface compatibility; unused (greedy only).
        top_k: number of candidate continuations verified per step.
        num_beams: must be 1.
        threshes: per-head branching factors passed to ``generate_tree``.
        verbose: print the verified candidate and running output every step.

    Returns:
        ``(result, n_steps)``: ``result`` is the prompt followed by the
        generated tokens, with the same rank as ``input_ids``. ``n_steps`` is
        the number of verification steps taken.

    Raises:
        NotImplementedError: ``num_beams != 1`` or a batch size other than 1.
        RuntimeError: ``input_ids`` is not a tensor.
        ValueError: the prompt has fewer than 2 tokens.
    """
    if num_beams != 1:
        raise NotImplementedError("generate() does yet not support beam search")
    if not isinstance(input_ids, torch.Tensor):
        raise RuntimeError("generate() requires a tensor of token ids as the prefix")

    batched = input_ids.dim() != 1
    if not batched:
        input_ids = input_ids.unsqueeze(0)
    # Candidate selection and cache merging below assume a single sequence.
    if input_ids.size(0) != 1:
        raise NotImplementedError("speculative_generate() only supports a batch size of 1")
    if input_ids.size(1) < 2:
        raise ValueError("speculative_generate() requires a prompt of at least 2 tokens")

    prompt_len = input_ids.size(1)
    result = input_ids
    kwargs = {"past_key_value_states": None, "use_cache": True}

    # Prefill all but the last prompt token. The last prefill embedding plus the
    # last prompt token seed the speculator; that token is re-fed at step one.
    _, past_key_value_states, embeds = model(input_ids[:, :-1], include_embeds=True, **kwargs)
    embeds = embeds[:, -1:]
    kwargs["past_key_value_states"] = past_key_value_states
    next_input = input_ids[:, -1:]
    n_kv_s = past_key_value_states

    n_adds = smallmodel.nheads
    n_gen = 0
    n_steps = 0
    while n_gen < max_new_tokens:
        n_steps += 1
        input_ids = next_input[:, -max_seq_len:]

        # Draft candidates (take the only batch row): top_k x n_adds token ids,
        # each prefixed with the current token -> top_k x (n_adds + 1).
        adds = smallmodel.generate_tree(embeds, input_ids, threshes, top_k)[0]
        input_ids = torch.cat([input_ids.expand(adds.size(0), 1), adds], dim=-1)

        mask = _speculation_mask(input_ids.size(1), n_kv_s[0][0].size(2), device=input_ids.device)
        logits, past_key_value_states, embeds = model.forward(input_ids, include_embeds=True, mask=mask, **kwargs)
        logits = logits[:, -n_adds - 1:, :]
        next_vals = torch.argmax(logits, dim=-1)

        # Draft token j+1 is accepted iff the model's greedy choice at position j
        # matches it and all earlier drafts were accepted (cumprod). The wrapped
        # last column is junk, which the clamp to n_adds discards.
        accepted = input_ids.roll(-1, 1).eq(next_vals).cumprod(1)
        n_correct_all = accepted.sum(1).clamp(0, n_adds)
        best_guess = int(n_correct_all.argmax())
        n_correct = int(n_correct_all[best_guess])

        next_vals = next_vals[best_guess].unsqueeze(0)
        embeds = embeds[best_guess].unsqueeze(0)

        if verbose:
            print("Verification:", decode_obo(input_ids[best_guess]), decode_obo(next_vals), n_correct)

        # Keep the accepted drafts plus the model's own next token.
        next_vals = next_vals[:, :n_correct + 1]
        n_gen += n_correct + 1
        embeds = embeds[:, n_correct].unsqueeze(1)

        n_kv_s = _merge_kv_cache(past_key_value_states, best_guess, n_adds - n_correct)
        kwargs["past_key_value_states"] = n_kv_s

        result = torch.cat((result, next_vals), dim=-1)
        if verbose:
            print("Updated output:", decode_obo(result))
            print()

        next_input = next_vals[:, -1].unsqueeze(-1)

    # A step may accept up to n_adds + 1 tokens, so trim any overshoot past the budget.
    result = result[:, :prompt_len + max_new_tokens]
    if not batched:
        result = result[0]
    return result, n_steps

In [29]:
import torch
# Load saved per-setting step counts and print the mean for each key.
# NOTE(review): keys look like top_k settings (2, 5, 10, 25) — confirm with the
# script that wrote this file.
# Uses its own names so the speculator held in `test` and the `steps` from the
# speculative runs are not overwritten.
step_scores = torch.load("../../../specu_recur_n2_shortgen_scores.pth", map_location="cpu")

for k in step_scores.keys():
    mean_steps = sum(step_scores[k]) / len(step_scores[k])
    print(k, mean_steps)

2 40.92
5 36.18
10 34.325
25 33.585


In [18]:
# Any-length version: handles an arbitrary number of speculator heads (logits.size(1)).

def get_topk_tree(logits, k=50):
    # probs: b h v
    n_adds = logits.size(1)
    probs = logits.softmax(2)
    # Add a no-token option to each head
    probs = torch.cat([torch.ones(probs.size(0),probs.size(1),1), probs], dim=2) # b h 6
    probtable = torch.ones(*([probs.size(0)]+[probs.size(2)]*n_adds)) # b 6 6 6
    # Populate probability table
    for i in range(n_adds):
        dimlist = [-1]+[1]*n_adds
        dimlist[i+1] = probtable.size(i+1)
        probtable *= probs[:,i].view(dimlist)
    # Zero out impossible entries (i.e. nil nil token)
    psize = probtable.size()
    causal = torch.ones(psize[-1],psize[-1]) # 6 6
    causal[0,1:].zero_()
    for i in range(n_adds-1):
        probtable *= causal.view(*(list(causal.size()) + [1]*(i)))
    probtable = probtable.view(psize[0],-1)
    # Zero out all-nil option
    probtable[:,0].zero_()
    # Fetch top-k most probable tree nodes
    v,i = probtable.topk(k, dim=1) # b k
    i = torch.stack(torch.unravel_index(i, psize[1:]), 1).permute(0,2,1)
    # v: b k
    # i: b k h
    return v,i

# Demo on random logits: 2 rows, 3 heads, 5 options per head; top 10 paths.
# A dedicated seeded generator makes the printed output reproducible without
# touching the global RNG.
get_topk_tree(torch.randn(2, 3, 5, generator=torch.Generator().manual_seed(0)), 10)

(tensor([[0.7231, 0.2638, 0.2146, 0.1832, 0.1578, 0.1490, 0.0970, 0.0947, 0.0843,
          0.0789],
         [0.3288, 0.2745, 0.1755, 0.1384, 0.1326, 0.1156, 0.0887, 0.0745, 0.0739,
          0.0622]]),
 tensor([[[2, 0, 0],
          [2, 1, 0],
          [2, 1, 1],
          [2, 3, 0],
          [5, 0, 0],
          [2, 3, 1],
          [2, 4, 0],
          [2, 5, 0],
          [2, 2, 0],
          [2, 4, 1]],
 
         [[5, 0, 0],
          [1, 0, 0],
          [3, 0, 0],
          [5, 1, 0],
          [4, 0, 0],
          [1, 1, 0],
          [2, 0, 0],
          [5, 2, 0],
          [3, 1, 0],
          [1, 2, 0]]]))

In [10]:
# Time an end-to-end speculative run. The third argument is `smallmodel`, so the
# loaded speculator must be passed there; 20, 20 are max_seq_len and max_new_tokens.
assert isinstance(test, nn.Module), "`test` must still hold the Speculator loaded earlier"
start = time.perf_counter()
out, steps = speculative_generate(model, inp, test, 20, 20)
print(time.perf_counter() - start)
print("Steps:", steps)

Speculation: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7, 52,  6, 24,  6]])
Verification: tensor([[ 52,   6,  24,   6, 108]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108]])

Speculation: tensor([[108,   0,   0,   0,   0]])
Verification: tensor([[121,  22,  22,  22, 118]]) 0
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121]])

Speculation: tensor([[121,  24,  35,  73,  60]])
Verification: tensor([[24, 35, 73, 60, 75]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75]])

Speculation: tensor([[75, 28, 24, 60,  0]])
Verification: tensor([[ 28,  24,  60, 112,  35]]) 3
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75,  28,  24,  60, 112]])

Speculation: tensor([[112, 118,  73,  92,   0]])
Verification: tensor([[118,  73,  92, 124,  69]]) 3
Updated

In [8]:
# Token-level check that speculative decoding reproduced the greedy oracle.
# Flatten both tensors so 1-D and (1, n) outputs are handled alike: iterating
# range(len(out)) over a (1, n) tensor would only compare position 0.
spec_tokens = out.reshape(-1)
oracle_tokens = oracle.reshape(-1)
n_compare = min(spec_tokens.numel(), oracle_tokens.numel())
mismatches = (spec_tokens[:n_compare] != oracle_tokens[:n_compare]).nonzero()
assert mismatches.numel() == 0, int(mismatches[0])  # message: first mismatching position
print("Exact match!")

Exact match!


In [6]:
def blackbox_oracle(inp):
    """Simulate a drafter that proposes the next 4 oracle tokens with a random zeroed tail.

    Args:
        inp: 2-D token tensor; its length ``n = inp.size(1)`` marks where
            drafting starts in the global ``oracle``.

    Returns:
        ``(guess, n_available)``. ``guess`` is a clone of
        ``oracle[:, n:n + 4]`` whose last 0-4 entries (chosen uniformly with
        the global torch RNG) are set to 0. ``n_available`` is
        ``min(oracle.size(1), n + 4) - n``.
    """
    start = inp.size(1)
    stop = start + 4
    guess = oracle[:, start:stop].clone()
    n_blank = torch.randint(5, (1,))
    if n_blank > 0:
        guess[:, -n_blank:] = 0
    n_available = min(oracle.size(1), stop) - start
    return guess, n_available

