In [1]:
import os
import torch # was 20230913, now 20231013
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [None]:
from fms.models.llama import convert_hf_llama
from transformers import LlamaForCausalLM

In [3]:
from fms.models import get_model

In [4]:
# Load the "llama" / "7b" model on the CPU from a checkpoint directory given
# relative to the notebook's working directory.
# NOTE(review): source="meta" presumably selects the Meta-format checkpoint
# loader -- confirm against fms.models.get_model.
model = get_model(
    "llama",
    "7b",
    source="meta",
    device_type="cpu",
    model_path="../../../llama_weights/7B-F/",
)

In [5]:
from fms.utils.generation import generate
from transformers import AutoTokenizer
# Tokenizer used to encode the prompt and decode generated ids.
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Reverse lookup (token id -> token string), used below to print token pieces.
vinv = dict((token_id, token) for token, token_id in t.vocab.items())

In [6]:
# Encode the prompt; at this point `inp` is the tokenizer's "input_ids" entry.
inp = t("Yesterday is history, tomorrow is a mystery, today")["input_ids"]
# Show the token pieces the prompt was split into.
print([vinv[token_id] for token_id in inp])
# Rebind `inp` to a 32-bit integer tensor for the generation calls below.
inp = torch.tensor(inp, dtype=torch.int32)

['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today']


In [12]:
# Reference continuation from the base model alone, with sampling disabled and
# the cache enabled; the speculative run below can be compared against it.
# NOTE(review): the two positional 30s are presumably length limits -- confirm
# against the signature of fms.utils.generation.generate.
oracle = generate(
    model,
    inp,
    30,
    30,
    use_cache=True,
    do_sample=False,
)
# Last expression of the cell: display the decoded text.
t.decode(oracle.tolist())

"<s> Yesterday is history, tomorrow is a mystery, today is a gift.\nThat's why they call it the present.\n\nThis quote is a reminder to appreciate the present moment and not"

In [8]:
from fms.modules.speculator import Speculator

# Build a 3-head speculator and restore its weights from the checkpoint's
# "model_state" entry, mapping all tensors to the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
test = Speculator(n_heads=3)
test.load_state_dict(
    torch.load("../../../specu_recur_n2.pth", map_location="cpu")["model_state"]
)
# Last expression of the cell: total number of parameter elements.
sum(param.numel() for param in test.parameters())

887119872

In [15]:
from fms.utils.generation import speculative_generate

# Run generation with the speculator attached; the call returns the generated
# output together with a step count, which is printed below.
# NOTE(review): the positional 30s and the meaning of `threshes` / `top_k` are
# not visible here -- confirm against fms.utils.generation.speculative_generate.
out, steps = speculative_generate(
    model,
    inp,
    test,
    30,
    30,
    threshes=[10, 3, 2],
    top_k=5,
    verbose_dict=t.vocab,
)
print()
print("Steps:", steps)

Speculation: ['▁today', '▁is', '▁the', '▁day'] n_correct: 1
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a']

Speculation: ['▁a', '▁gift', '.', '<0x0A>'] n_correct: 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That']

Speculation: ['That', "'", 's', '▁why'] n_correct: 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That', "'", 's', '▁why', '▁they']

Speculation: ['▁they', '▁call', '▁it', '▁the'] n_correct: 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That', "'", 's', '▁why', '▁they', '▁call', '▁it', '▁the', '▁present']

Speculation: ['▁present', '▁day', '▁