In [1]:
import os
import torch # was 20230913, now 20231013
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import convert_hf_llama
from transformers import LlamaForCausalLM

In [3]:
from fms.models import get_model

In [4]:
# Build the "llama" / "7b" model through fms.get_model, reading weights from the
# local 7B-F directory (source="meta") and placing the model on the CPU.
# NOTE(review): model_path is relative to the notebook's working directory.
model = get_model(
    "llama", "7b",
    source="meta",
    device_type="cpu",
    model_path="../../../llama_weights/7B-F/",
)

In [5]:
from fms.utils.generation import generate
from transformers import AutoTokenizer
# Tokenizer used for every prompt in this notebook.
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Inverse vocabulary (token id -> token string), used to pretty-print encoded prompts.
vinv = {token_id: token for token, token_id in t.vocab.items()}

In [6]:
# First prompt: tokenize to a list of ids, echo the token pieces for inspection,
# then rebind inp1 to an int32 tensor of those ids for the generation calls.
inp1 = t("Yesterday is history, tomorrow is a mystery, today")["input_ids"]
print([vinv[token_id] for token_id in inp1])
inp1 = torch.tensor(inp1, dtype=torch.int32)

['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today']


In [7]:
# Reference ("oracle") continuation of the first prompt from the base model,
# with sampling disabled and the KV cache enabled.
# Inference only: run under no_grad so autograd does not record a graph (and keep
# activations alive) across the forward passes of the 7B model.
# NOTE(review): 4096 and 30 are presumably the maximum sequence length and the
# number of new tokens -- confirm against fms.utils.generation.generate.
with torch.no_grad():
    oracle = generate(model, inp1, 4096, 30, do_sample=False, use_cache=True)
# Cell output: the decoded text (prompt followed by the generated continuation).
t.decode(oracle.tolist())

"<s> Yesterday is history, tomorrow is a mystery, today is a gift.\nThat's why they call it the present.\n\nThis quote is a reminder to appreciate the present moment and not"

In [8]:
# Second prompt: tokenize to a list of ids, echo the token pieces for inspection,
# then rebind inp2 to an int32 tensor of those ids for the generation calls.
text2 = "Hello, how are you today?"
inp2 = t(text2)["input_ids"]
print([vinv[token_id] for token_id in inp2])
inp2 = torch.tensor(inp2, dtype=torch.int32)

['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?']


In [9]:
# Reference ("oracle") continuation of the second prompt; note this rebinds
# `oracle`, replacing the result computed for the first prompt.
# Inference only: run under no_grad so autograd does not record a graph (and keep
# activations alive) across the forward passes of the 7B model.
# NOTE(review): 4096 and 30 are presumably the maximum sequence length and the
# number of new tokens -- confirm against fms.utils.generation.generate.
with torch.no_grad():
    oracle = generate(model, inp2, 4096, 30, do_sample=False, use_cache=True)
# Cell output: the decoded text (prompt followed by the generated continuation).
t.decode(oracle.tolist())

"<s> Hello, how are you today? I'm doing well, thanks for asking! I'm excited to be here and share some of my thoughts and experiences with you.\n\n"

In [11]:
# Batch of both tokenized prompts (tensors of different lengths), passed to
# speculative_generate below.
inp = [inp1, inp2]

In [12]:
from fms.modules.speculator import Speculator

# Create a 3-head speculator and restore its trained weights onto the CPU.
test = Speculator(n_heads=3)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. The weights live under the "model_state" key of the checkpoint.
test.load_state_dict(
    torch.load("../../../specu_recur_n2.pth", map_location="cpu")["model_state"]
)
# Cell output: total number of parameters in the speculator.
sum(param.numel() for param in test.parameters())

887119872

In [13]:
from fms.utils.generation import speculative_generate

# Speculative decoding over both prompts, using `test` as the speculator for `model`.
# Inference only: run under no_grad so autograd does not record a graph (and keep
# activations alive) across the forward passes of the model and the speculator.
# NOTE(review): 4096 and 30 are presumably the maximum sequence length and the
# number of new tokens; the meaning of top_k / threshes is defined by
# fms.utils.generation.speculative_generate -- confirm there.
# verbose_dict=t.vocab supplies the tokenizer vocabulary for the verbose
# "Speculation" / "Updated output" printout.
with torch.no_grad():
    out, steps = speculative_generate(
        model, inp, test, 4096, 30,
        top_k=5, threshes=[5, 3, 2], verbose_dict=t.vocab,
    )
print()
print("Steps:", steps)

Speculation: ['▁today', '▁is', '▁the', '▁day'] n_correct: 1
Speculation: ['?', '▁I', "'", 'm'] n_correct: 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a']
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', "'", 'm', '▁doing']

Speculation: ['▁a', '▁gift', '.', '<0x0A>'] n_correct: 3
Speculation: ['▁doing', '▁a', '▁great', '▁job'] n_correct: 0
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That']
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', "'", 'm', '▁doing', '▁well']

Speculation: ['That', "'", 's', '▁why'] n_correct: 3
Speculation: ['▁well', '.', '▁I', "'"] n_correct: 0
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁today', '▁is', '▁a', '▁gift', '


Steps: 14
