In [1]:
import os
import torch # was 20230913, now 20231013
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import convert_hf_llama
from transformers import LlamaForCausalLM

In [3]:
from fms.models import get_model

In [4]:
# Location of the LLaMA-7B weights loaded below with source="meta".
# Override with the LLAMA_WEIGHTS_DIR environment variable; the default is resolved
# relative to the kernel's current working directory, so it only works from this folder.
LLAMA_WEIGHTS_DIR = os.environ.get("LLAMA_WEIGHTS_DIR", "../../../llama_weights/7B-F/")

# Build the fms "llama" architecture, "7b" variant, on CPU.
# NOTE(review): source="meta" presumably selects the original Meta checkpoint layout
# (as opposed to HF) — confirm against fms.models.get_model.
model = get_model(
    "llama",
    "7b",
    model_path=LLAMA_WEIGHTS_DIR,
    device_type="cpu",
    source="meta",
)

In [5]:
from fms.utils.generation import generate
from transformers import AutoTokenizer
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Reverse vocabulary (token id -> token string), used below to show how prompts are split.
vinv = {token_id: token for token, token_id in t.vocab.items()}

In [6]:
# Prompt 1 as an int32 id tensor; echo its token strings so the split is visible.
inp = torch.IntTensor(t("Yesterday is history, tomorrow")["input_ids"])
print([vinv[token_id] for token_id in inp.tolist()])

['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow']


In [7]:
# Prompt 2 as an int32 id tensor; echo its token strings so the split is visible.
inp2 = torch.IntTensor(t("Hello, how are you today?")["input_ids"])
print([vinv[token_id] for token_id in inp2.tolist()])

['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?']


In [8]:
# Prompt 3 as an int32 id tensor; echo its token strings so the split is visible.
inp3 = torch.IntTensor(t("The largest land dwelling animal is")["input_ids"])
print([vinv[token_id] for token_id in inp3.tolist()])

['<s>', '▁The', '▁largest', '▁land', '▁dwell', 'ing', '▁animal', '▁is']


In [9]:
# Baseline ("oracle") continuation of prompt 1: no sampling (do_sample=False), cache enabled.
# NOTE(review): the two positional 30s are presumably max_seq_len / max_new_tokens —
# verify against fms.utils.generation.generate.
oracle = generate(model, inp, 30, 30, use_cache=True, do_sample=False)
t.decode(oracle.tolist())

"<s> Yesterday is history, tomorrow is a mystery, but today is a gift.\nThat's why they call it the present.\n\nThis quote is a reminder to"

In [10]:
# Baseline ("oracle") continuation of prompt 2: no sampling (do_sample=False), cache enabled.
# NOTE(review): the two positional 30s are presumably max_seq_len / max_new_tokens — verify.
oracle = generate(model, inp2, 30, 30, use_cache=True, do_sample=False)
t.decode(oracle.tolist())

"<s> Hello, how are you today? I'm doing well, thanks for asking! I'm excited to be here and share some of my thoughts and experiences with you.\n\n"

In [11]:
# Baseline ("oracle") continuation of prompt 3: no sampling (do_sample=False), cache enabled.
# NOTE(review): the two positional 30s are presumably max_seq_len / max_new_tokens — verify.
oracle = generate(model, inp3, 30, 30, use_cache=True, do_sample=False)
t.decode(oracle.tolist())

'<s> The largest land dwelling animal is the African elephant, which can weigh up to 6 tons (12,000 lbs or 5,40'

In [12]:
print(*(inp, inp2, inp3))  # raw id tensors of the three prompts, space-separated

tensor([    1,   612, 18358,   338,  4955, 29892,  6454, 22396],
       dtype=torch.int32) tensor([    1, 15043, 29892,   920,   526,   366,  9826, 29973],
       dtype=torch.int32) tensor([    1,   450, 10150,  2982, 24013,   292, 13019,   338],
       dtype=torch.int32)


In [13]:
from fms.modules.speculator import Speculator

# Speculator checkpoint path; override with the SPECULATOR_CKPT environment variable.
# The default is resolved relative to the kernel's current working directory.
SPECULATOR_CKPT = os.environ.get("SPECULATOR_CKPT", "../../../specu_recur_n2.pth")

test = Speculator(n_heads=3)
# NOTE(review): torch.load unpickles the file and can execute arbitrary code — only load
# trusted checkpoints. Consider weights_only=True if the checkpoint holds only tensors/primitives.
test.load_state_dict(torch.load(SPECULATOR_CKPT, map_location="cpu")["model_state"])
# Parameter count of the loaded speculator (displayed as the cell result).
sum(p.numel() for p in test.parameters())

887119872

In [14]:
from fms.utils.generation import speculative_generate

# Speculative decoding of prompt 1 with the loaded speculator. Passing verbose_dict makes the
# call print each speculation and the updated output as it runs (see captured output).
# NOTE(review): meaning of the positional 30s, top_k and threshes should be checked in
# fms.utils.generation.speculative_generate.
out, steps = speculative_generate(
    model, inp, test, 30, 30,
    top_k=5, threshes=[10, 3, 2], verbose_dict=t.vocab,
)
print("\nSteps:", steps)

Speculation: ['orrow', '▁is', '▁the', '▁day'] n_correct: 1
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a']

Speculation: ['▁a', '▁mystery', ',', '▁but'] n_correct: 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁but', '▁today']

Speculation: ['▁today', '▁is', '▁a', '▁day'] n_correct: 2
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁but', '▁today', '▁is', '▁a', '▁gift']

Speculation: ['▁gift', '.', '<0x0A>', 'I'] n_correct: 2
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁but', '▁today', '▁is', '▁a', '▁gift', '.', '<0x0A>', 'That']

Speculation: ['That', "'", 's', '▁why'] n_correct: 3
Updated output: ['<s>', '▁Y', 'esterday', '▁is', '▁history', ',', '▁tom', 'orrow', '▁is', '▁a', '▁mystery', ',', '▁but', '▁today', '▁is', '▁a', '▁gift', '.'

In [15]:
from fms.utils.generation import speculative_generate

# Speculative decoding of prompt 2; same settings as the prompt-1 run.
out, steps = speculative_generate(
    model, inp2, test, 30, 30,
    top_k=5, threshes=[10, 3, 2], verbose_dict=t.vocab,
)
print("\nSteps:", steps)

Speculation: ['?', '▁I', "'", 'm'] n_correct: 3
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', "'", 'm', '▁doing']

Speculation: ['▁doing', '▁a', '▁great', '▁job'] n_correct: 0
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', "'", 'm', '▁doing', '▁well']

Speculation: ['▁well', '.', '▁I', "'"] n_correct: 0
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', "'", 'm', '▁doing', '▁well', ',']

Speculation: [',', '▁thanks', '▁for', '▁the'] n_correct: 2
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', "'", 'm', '▁doing', '▁well', ',', '▁thanks', '▁for', '▁asking']

Speculation: ['▁asking', '!', '▁I', "'"] n_correct: 3
Updated output: ['<s>', '▁Hello', ',', '▁how', '▁are', '▁you', '▁today', '?', '▁I', "'", 'm', '▁doing', '▁well', ',', '▁thanks', '▁for', '▁asking', '!', '▁I', "'", 'm']

Speculation: ['m', '▁not', '▁sure', '▁how'] n_correct: 0
Updated outp

In [16]:
from fms.utils.generation import speculative_generate

# Speculative decoding of prompt 3; same settings as the prompt-1 run.
out, steps = speculative_generate(
    model, inp3, test, 30, 30,
    top_k=5, threshes=[10, 3, 2], verbose_dict=t.vocab,
)
print("\nSteps:", steps)

Speculation: ['▁is', '▁the', '▁ele', 'ph'] n_correct: 1
Updated output: ['<s>', '▁The', '▁largest', '▁land', '▁dwell', 'ing', '▁animal', '▁is', '▁the', '▁African']

Speculation: ['▁African', '▁ele', 'ph', 'ant'] n_correct: 3
Updated output: ['<s>', '▁The', '▁largest', '▁land', '▁dwell', 'ing', '▁animal', '▁is', '▁the', '▁African', '▁ele', 'ph', 'ant', ',']

Speculation: [',', '▁which', '▁is', '▁the'] n_correct: 1
Updated output: ['<s>', '▁The', '▁largest', '▁land', '▁dwell', 'ing', '▁animal', '▁is', '▁the', '▁African', '▁ele', 'ph', 'ant', ',', '▁which', '▁can']

Speculation: ['▁can', '▁reach', '▁up', '▁to'] n_correct: 0
Updated output: ['<s>', '▁The', '▁largest', '▁land', '▁dwell', 'ing', '▁animal', '▁is', '▁the', '▁African', '▁ele', 'ph', 'ant', ',', '▁which', '▁can', '▁we']

Speculation: ['▁we', 'igh', '▁up', '▁to'] n_correct: 3
Updated output: ['<s>', '▁The', '▁largest', '▁land', '▁dwell', 'ing', '▁animal', '▁is', '▁the', '▁African', '▁ele', 'ph', 'ant', ',', '▁which', '▁can', '▁we