In [1]:
import os
import torch # was 20230913, now 20231013
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import LLaMAConfig, LLaMA

In [3]:
# Build the base LLaMA model from a positional config. The module tree printed
# below shows a 32000-token vocabulary, 4096-wide embeddings and 32 blocks.
# NOTE(review): the meaning of the remaining positional values (1e-6, 32, 0, 32)
# is not visible here - confirm their order against LLaMAConfig.
modelc = LLaMAConfig(32000, 4096, 1e-6, 32, 0, 32)
model = LLaMA(modelc)
# Switch to eval mode; as the last expression of the cell this also displays
# the module tree.
model.eval()

LLaMA(
  (shared): WordEmbedding(
    (emb): Embedding(32000, 4096)
    (head): Linear(in_features=4096, out_features=32000, bias=False)
  )
  (layers): ModuleList(
    (0-31): 32 x LLaMABlock(
      (ln): LayerNormParameterized()
      (ff_ln): LayerNormParameterized()
      (attn): MultiHeadAttention(
        (query): Linear(in_features=4096, out_features=4096, bias=False)
        (key): Linear(in_features=4096, out_features=4096, bias=False)
        (value): Linear(in_features=4096, out_features=4096, bias=False)
        (dense): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (ff_sub_layer): GatedLinearUnit(
        (w1): Linear(in_features=4096, out_features=11008, bias=False)
        (wg): Linear(in_features=4096, out_features=11008, bias=False)
        (a): SiLU()
        (w2): Linear(in_features=11008, out_features=4096, bias=False)
      )
    )
  )
  (dec_norm): LayerNormParameterized()
)

In [4]:
# Load the checkpoint and keep only its 'model_state' entry (the state dict
# that is later passed to model.load_state_dict).
# NOTE: torch.load unpickles the file, so only point it at a trusted checkpoint.
d = torch.load("../../../llama_7b_ckp.pth")['model_state']

In [5]:
# Rename every checkpoint entry whose key mentions "dec_process" so that its
# first dotted component becomes "layers" (the name of the model's ModuleList).
# The keys are snapshotted first because `d` is mutated while looping.
keylist = list(d.keys())
for key in keylist:
    if "dec_process" not in key:
        continue
    fields = key.split(".")
    fields[0] = "layers"
    value = d.pop(key)
    d[".".join(fields)] = value

In [6]:
# strict=False tolerates checkpoint entries the model has no slot for; the
# returned report (shown below) lists only 'rope.freqs' as unexpected and no
# missing keys.
model.load_state_dict(d, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['rope.freqs'])

In [7]:
from fms.utils.generation import generate
from transformers import AutoTokenizer
# Tokenizer used to encode the prompt and decode generated ids.
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Inverse vocabulary (token id -> token string) for readable debug printing.
vinv = {v:k for k,v in t.vocab.items()}

In [8]:
# Tokenize the prompt, show its individual token strings, then wrap the id
# list in a 1-D int tensor (`inp`) for the generation calls further down.
prompt_ids = t("Hello, how are you today?")["input_ids"]
print([vinv[tok] for tok in prompt_ids])
inp = torch.IntTensor(prompt_ids)

['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?']


In [9]:
# Reference ("oracle") continuation produced by the base model alone with
# sampling disabled; the speculative output is compared against it later.
# NOTE(review): the two positional 20s are presumably max_seq_len and
# max_new_tokens - confirm against fms.utils.generation.generate.
oracle = generate(model, inp, 20, 20, do_sample=False, use_cache=True)
t.decode(oracle.tolist())

'<s> Hello, how are you today? I hope you are doing well. I am doing well. I am happy to be here with you'

In [11]:
# Instantiate a 3-head speculator and load its weights from the checkpoint's
# "model_state" entry, mapped onto the CPU.
# NOTE: torch.load unpickles the file - trusted checkpoints only.
test = Speculator(n_heads=3)
test.load_state_dict(torch.load("../../../specu_dist.pth", map_location="cpu")["model_state"])

<All keys matched successfully>

In [12]:
# Total number of parameter elements in the speculator.
sum(p.numel() for p in test.parameters())

844107776

In [14]:
# Run speculative decoding with the speculator `test`; the positional 30, 30
# bind to max_seq_len and max_new_tokens. The function prints a trace of every
# step, followed here by the number of steps it took.
out, steps = speculative_generate(model, inp, test, 30, 30, top_k=25)
print()
print("Steps:", steps)

torch.Size([1, 1, 4096])
Topk@1: ['▁I', '<0x0A>', '▁We', '▁My', '▁This']
Topk@2: ['’', "'", '▁you', '▁hope', '▁is']
Topk@3: ['▁you', '▁is', 'm', 's', '▁I']
Verification: ['?', '▁I', '▁hope', '▁you'] ['▁I', '▁hope', '▁you', '▁are'] 3
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', '▁hope', '▁you', '▁are']

torch.Size([1, 1, 4096])
Topk@1: ['re', '▁doing', '▁well', '▁having', '▁all']
Topk@2: ['▁well', '▁doing', '▁having', '▁great', '.']
Topk@3: ['▁well', '.', '▁great', '▁a', '▁good']
Verification: ['▁are', '▁doing', '▁well', '.'] ['▁doing', '▁well', '.', '▁I'] 3
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', '▁hope', '▁you', '▁are', '▁doing', '▁well', '.', '▁I']

torch.Size([1, 1, 4096])
Topk@1: ['▁you', '’', ',', "'", '▁am']
Topk@2: ['▁is', ',', 'm', '▁I', '▁you']
Topk@3: ['▁I', ',', '▁a', '▁you', '▁to']
Verification: ['▁I', '▁you', '▁is', '▁I'] ['▁am', '▁are', '▁a', '▁am'] 0
Updated output: ['<s>', '▁Hello', ',', 

In [10]:
from fms.modules.layernorm import LayerNormParameterized


class Speculator(nn.Module):
    """Multi-head module mapping one embedding to ``n_heads`` sets of vocabulary logits.

    A GELU-gated MLP turns each input vector into ``n_heads`` vectors of the
    same width; the input is added back to each of them, the sums are
    layer-normalised, and every head projects its vector onto the vocabulary
    with its own matrix.

    Args:
        emb_dim: width of the input embeddings.
        vocab_size: number of logits produced by each head.
        n_heads: number of prediction heads.
    """

    def __init__(self, emb_dim=4096, vocab_size=32000, n_heads=4):
        super().__init__()
        self.nheads = n_heads
        self.emb_dim = emb_dim
        self.vsize = vocab_size
        # Hidden width of the gated MLP: 2 * 2.6875 * emb_dim rounded down to a
        # multiple of 256. w_in is twice as wide because it yields value and gate.
        inner_dim = int((emb_dim * 2.6875 * 2) // 256) * 256
        self.w_in = nn.Parameter(torch.empty(emb_dim, inner_dim * 2))  # d 2z
        self.a = nn.GELU()
        self.w_out = nn.Parameter(torch.empty(inner_dim, emb_dim * n_heads))  # z hd
        self.ln = LayerNormParameterized(emb_dim, elementwise_shift=False, elementwise_scale=True)
        self.head = nn.Parameter(torch.empty(n_heads, emb_dim, vocab_size))  # h d v
        self.reset_params()

    def reset_params(self):
        """Fill w_in, w_out and head with truncated-normal values (in that order)."""
        mlp_std = (1 / 2.6875) ** (1 / 6) / self.emb_dim**0.5
        for weight in (self.w_in, self.w_out):
            nn.init.trunc_normal_(weight, 0, mlp_std)
        nn.init.trunc_normal_(self.head, 0, 1 / self.emb_dim**0.5)

    def forward(self, x):
        """Return logits of shape (b, n, h, v) for an input ``x`` of shape (b, n, d)."""
        # Debug trace of the incoming shape, emitted on every call.
        print(x.shape)
        batch, seq = x.size(0), x.size(1)
        value, gate = torch.chunk(x.matmul(self.w_in), 2, dim=2)
        hidden = value * self.a(gate)
        per_head = hidden.matmul(self.w_out).view(batch, seq, self.nheads, self.emb_dim)  # b n h d
        # Residual connection: the same input is added to every head's vector.
        per_head = self.ln(per_head + x.unsqueeze(2))
        return torch.einsum("bnhd,hdv->bnhv", per_head, self.head)  # b n h v
    
    
def get_topk_tree(logits, k=50):
    # probs: b h v
    n_adds = logits.size(1)
    probs = logits.softmax(2)
    # Generate probabilities for all combos of predictions
    probtable = [
        probs[:,i].view(
            *([-1] + [1]*i + [probs.size(2)] + [1]*(n_adds-i-1))
        ).expand(
            *([-1] + [probs.size(2)]*n_adds)
        )
        for i in range(n_adds)
    ]
    probtable = torch.stack(probtable, 0).prod(0) # b v v v...
    psize = probtable.size()
    probtable = probtable.view(psize[0],-1)
    # Fetch top-k most probable tree nodes
    v,i = probtable.topk(k, dim=1) # b k
    i = torch.stack(torch.unravel_index(i, psize[1:]), 1)
    # v: b k
    # i: b h k
    return v,i

    
def decode_obo(x):
    """Decode a tensor of token ids into a flat list of token strings, one by one.

    Every id is looked up in the module-level ``vinv`` (token id -> token string)
    mapping. The tensor is flattened rather than squeezed: for a single-element
    input ``squeeze()`` yields a 0-d tensor whose ``tolist()`` is a bare int,
    which cannot be iterated. Flattening always yields a list.

    Args:
        x: integer tensor of token ids of any shape.

    Returns:
        List of token strings in row-major order of ``x``.
    """
    return [vinv[z] for z in x.reshape(-1).tolist()]


def speculative_generate(
    model: Union[Callable, torch.nn.Module],
    input_ids: torch.LongTensor,
    smallmodel: torch.nn.Module,
    max_seq_len: int = 2048,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_k: int = 25,
    num_beams: int = 1,
):
    """Greedy speculative decoding: a speculator proposes tokens, the base model verifies them.

    Each step asks ``smallmodel`` for per-head token candidates, selects the
    ``top_k`` most probable candidate sequences with ``get_topk_tree``, scores
    all of them with ``model`` in a single call, and keeps the candidate whose
    leading tokens agree longest with the base model's own argmax predictions.
    A trace of every step is printed.

    Args:
        model: base model; called with ``include_embeds=True`` and unpacked as
            ``(logits, past_key_value_states, embeds)``.
        input_ids: prompt token ids, 1-D or already batched. Only the first
            batch entry of the speculator output is used.
        smallmodel: speculator exposing ``nheads``; called on the most recent
            embedding.
        max_seq_len: applied as a tail slice to each step's input.
        max_new_tokens: the loop stops once at least this many tokens have been
            added; a step can accept several tokens, so the result may overshoot.
        temperature: unused (it only appears in the disabled sampling code).
        top_k: number of candidate sequences verified per step.
        num_beams: must be 1.

    Returns:
        ``(result, n_steps)``: prompt plus generated ids (un-batched again when
        the prompt was 1-D) and the number of loop iterations.

    Raises:
        NotImplementedError: if ``num_beams != 1``.
        RuntimeError: if ``type(input_ids)`` is not ``torch.Tensor``.
    """
    # Sampling is hard-wired off; only the greedy (argmax) path is implemented.
    do_sample = False
    batched = False
    if num_beams != 1:
        raise NotImplementedError("generate() does yet not support beam search")
    if type(input_ids) == torch.Tensor:
        if input_ids.dim() != 1:
            batched = True
    else:
        raise RuntimeError("generate() requires a tensor of token ids as the prefix")

    if not batched:
        input_ids = input_ids.unsqueeze(0)

    result = input_ids
    next_input = input_ids
    kwargs = dict()
    kwargs["past_key_value_states"] = None
    kwargs["use_cache"] = True

    # Prefill: run the base model over everything except the last prompt token
    # to obtain the kv-cache and embeddings; the held-back last token becomes
    # the first input of the verification loop.
    output = model(input_ids[:,:-1], include_embeds=True, **kwargs)
    _, past_key_value_states, embeds = output
    kwargs["past_key_value_states"] = past_key_value_states
    next_input = next_input[:,-1:]
    
    n_gen = 0
    n_steps = 0
    n_kv_s = past_key_value_states
    while n_gen < max_new_tokens:
        n_steps += 1
        input_ids = next_input[:, -max_seq_len:]
        
        # Speculator logits for the most recent embedding, reduced to the 5
        # best candidates per head (values in `probs`, token ids in `topk`).
        probs = smallmodel(embeds[:,-1].unsqueeze(1)).squeeze(1) # b h v
        probs, topk = probs.topk(5, dim=2) # b h 5
        n_adds = smallmodel.nheads
        for i in range(n_adds):
            print(f"Topk@{i+1}:", decode_obo(topk[0,i]))
        
        # Build probability table
        topk_v, topk_i = get_topk_tree(probs, top_k)
        
        # Assemble batch of tree branches
        # topk_i indexes into each head's 5 candidates; gather maps those
        # positions back to token ids.
        adds = topk.gather(2, topk_i).transpose(1,2) # b k h
        adds = adds[0] # For now, non-batching and take only first b entry
        # One row per candidate: the current token followed by its n_adds guesses.
        input_ids = torch.cat([input_ids.expand(top_k,1), adds], dim=-1) 
#         print("Speculations:")
#         for i in range(top_k):
#             print(decode_obo(input_ids[i]))
        
        # Additive attention mask with one row per new token and one column per
        # cached-plus-new position: the shifted lower triangle keeps the cache
        # and earlier new tokens, and .log() turns 1 into 0 and 0 into -inf.
        # NOTE(review): assumes dim 2 of the cached tensor is the sequence
        # length - confirm against the model's cache layout.
        mask = torch.ones(input_ids.size(1),input_ids.size(1)+n_kv_s[0][0].size(2))
        mask = mask.tril(diagonal=mask.size(1)-mask.size(0))
        mask = mask.unsqueeze(0).unsqueeze(0).log()
        
#         input_ids = input_ids[0].unsqueeze(0).expand(25,-1)
        
        output = model.forward(input_ids, include_embeds=True, mask=mask, **kwargs)
        
        logits, past_key_value_states, embeds = output
        # Keep the predictions for the last n_adds+1 positions.
        logits = logits[:, -n_adds-1:, :]

        if do_sample:
            # get logits from last value in sequence and scale
#             logits = logits / temperature
#             if top_k:
#                 v, _ = torch.topk(logits, top_k)
#                 logits[logits < v[:, [-1]]] = -float("inf")

#             probs = F.softmax(logits, dim=-1)
#             next_val = torch.multinomial(probs, num_samples=1)
            assert False
        else:
            next_vals = torch.argmax(logits, dim=-1)
        
        # Check correctness of smallmodel predictions
        # Rolling left by one aligns each guessed token with the base model's
        # prediction for that slot; cumprod keeps only the leading run of matches.
        test = input_ids.roll(-1, 1).eq(next_vals).cumprod(1)
        
        # The roll wraps the first token into the last column, so the count is
        # capped at n_adds before picking the candidate with the longest match.
        n_correct = test.sum(1).clamp(0,n_adds)
        best_guess = n_correct.argmax()
        
#         for i in range(top_k):
#             print(decode_obo(input_ids[i]), decode_obo(next_vals[i]), test[i].tolist(), n_correct[i].item())
        
        next_vals = next_vals[best_guess].unsqueeze(0)
        n_correct = n_correct[best_guess]
        embeds = embeds[best_guess].unsqueeze(0)
        
        print("Verification:", decode_obo(input_ids[best_guess]), decode_obo(next_vals), n_correct.item())
        
        # Toss any wrong smallmodel outputs
        # The accepted guesses plus the base model's own next token are kept.
        next_vals = next_vals[:,:n_correct+1]
        n_gen += n_correct+1
        embeds = embeds[:,:n_correct+1]
            
        n_wrong = n_adds - n_correct
        # kv updates are required for torch.compile with
        # mode='reduce-overhead'
        # NOTE(review): the indexing suggests entries 0/1 of each layer hold the
        # existing cache and entries 2/3 the tensors added by this call -
        # confirm against the model. Only the winning candidate's new entries
        # are appended, minus the rejected tail positions.
        n_kv_s = []
        for layer_idx in range(len(past_key_value_states)):
            n_kv_s.append([])
            for tensor_idx in range(2):
                base = past_key_value_states[layer_idx][tensor_idx]
                new = past_key_value_states[layer_idx][tensor_idx+2][best_guess].unsqueeze(0)
                if n_wrong > 0:
                    new = new[:,:,:-n_wrong]
                base = torch.cat([base, new], dim=2)
                n_kv_s[layer_idx].append(
                    base.clone(memory_format=torch.contiguous_format).detach()
                )
                # torch._dynamo.mark_dynamic(n_kv_s[layer_idx][tensor_idx], 2)
        kwargs["past_key_value_states"] = n_kv_s

        result = torch.cat((result, next_vals), dim=-1)
        print("Updated output:", decode_obo(result))
        print()

        # The last accepted token seeds the next verification step.
        next_input = next_vals[:,-1].unsqueeze(-1)

    if not batched:
        result = result[0]
    return result, n_steps

In [18]:
# ANY-LENGTH VERSION

def get_topk_tree(logits, k=50):
    # probs: b h v
    n_adds = logits.size(1)
    probs = logits.softmax(2)
    # Add a no-token option to each head
    probs = torch.cat([torch.ones(probs.size(0),probs.size(1),1), probs], dim=2) # b h 6
    probtable = torch.ones(*([probs.size(0)]+[probs.size(2)]*n_adds)) # b 6 6 6
    # Populate probability table
    for i in range(n_adds):
        dimlist = [-1]+[1]*n_adds
        dimlist[i+1] = probtable.size(i+1)
        probtable *= probs[:,i].view(dimlist)
    # Zero out impossible entries (i.e. nil nil token)
    psize = probtable.size()
    causal = torch.ones(psize[-1],psize[-1]) # 6 6
    causal[0,1:].zero_()
    for i in range(n_adds-1):
        probtable *= causal.view(*(list(causal.size()) + [1]*(i)))
    probtable = probtable.view(psize[0],-1)
    # Zero out all-nil option
    probtable[:,0].zero_()
    # Fetch top-k most probable tree nodes
    v,i = probtable.topk(k, dim=1) # b k
    i = torch.stack(torch.unravel_index(i, psize[1:]), 1).permute(0,2,1)
    # v: b k
    # i: b k h
    return v,i

# Smoke test on random logits of shape (2, 3, 5), asking for the 10 best prefixes.
get_topk_tree(torch.randn(2,3,5), 10)

(tensor([[0.7231, 0.2638, 0.2146, 0.1832, 0.1578, 0.1490, 0.0970, 0.0947, 0.0843,
          0.0789],
         [0.3288, 0.2745, 0.1755, 0.1384, 0.1326, 0.1156, 0.0887, 0.0745, 0.0739,
          0.0622]]),
 tensor([[[2, 0, 0],
          [2, 1, 0],
          [2, 1, 1],
          [2, 3, 0],
          [5, 0, 0],
          [2, 3, 1],
          [2, 4, 0],
          [2, 5, 0],
          [2, 2, 0],
          [2, 4, 1]],
 
         [[5, 0, 0],
          [1, 0, 0],
          [3, 0, 0],
          [5, 1, 0],
          [4, 0, 0],
          [1, 1, 0],
          [2, 0, 0],
          [5, 2, 0],
          [3, 1, 0],
          [1, 2, 0]]]))

In [10]:
# Time one speculative run. The speculator (`test`) must be passed as the third
# positional argument of speculative_generate; without it the first 20 would be
# bound to `smallmodel` and the call would fail when the speculator is invoked.
start = time.time()
out, steps = speculative_generate(model, inp, test, 20, 20)
print(time.time()-start)
print("Steps:", steps)

Speculation: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7, 52,  6, 24,  6]])
Verification: tensor([[ 52,   6,  24,   6, 108]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108]])

Speculation: tensor([[108,   0,   0,   0,   0]])
Verification: tensor([[121,  22,  22,  22, 118]]) 0
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121]])

Speculation: tensor([[121,  24,  35,  73,  60]])
Verification: tensor([[24, 35, 73, 60, 75]]) 4
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75]])

Speculation: tensor([[75, 28, 24, 60,  0]])
Verification: tensor([[ 28,  24,  60, 112,  35]]) 3
Updated output: tensor([[  0,   1,   2,   3,   4,   5,   6,   7,  52,   6,  24,   6, 108, 121,
          24,  35,  73,  60,  75,  28,  24,  60, 112]])

Speculation: tensor([[112, 118,  73,  92,   0]])
Verification: tensor([[118,  73,  92, 124,  69]]) 3
Updated

In [8]:
# Compare the speculative output with the oracle token by token. Both tensors
# are flattened first: iterating over range(len(out)) only walks the leading
# dimension, which checks a single token for a (1, n) tensor and fails on the
# two-index lookup for a 1-D one. Speculative decoding may accept a few tokens
# beyond the requested length, so only the common prefix is compared.
out_flat = out.reshape(-1)
oracle_flat = oracle.reshape(-1)
n_common = min(out_flat.numel(), oracle_flat.numel())
assert n_common > 0, "nothing to compare"
for i in range(n_common):
    assert out_flat[i]==oracle_flat[i], i
print("Exact match!")

Exact match!


In [6]:
def blackbox_oracle(inp):
    """Simulate an imperfect speculator using the known ``oracle`` continuation.

    Copies up to the next 4 oracle tokens that follow ``inp`` (indexed along
    dim 1) and overwrites a random number (0 to 4) of trailing tokens with 0.

    Returns:
        (guess, n_available): the partly zeroed copy and the number of oracle
        tokens that actually exist after ``inp`` within the 4-token window.
    """
    start = inp.size(1)
    stop = start + 4
    guess = oracle[:, start:stop].clone()
    n_wrong = torch.randint(5, (1,))
    if n_wrong > 0:
        guess[:, -n_wrong:] = 0
    n_available = min(oracle.size(1), stop) - start
    return guess, n_available

