In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Callable, Union

In [2]:
from fms.models.llama import LLaMAConfig, LLaMA

In [3]:
modelc = LLaMAConfig(32000, 4096, 1e-6, 32, 0, 32)
model = LLaMA(modelc)

In [4]:
d = torch.load("../../../llama_7b_ckp.pth")['model_state']

In [5]:
keylist = list(d.keys())
for key in keylist:
    if "dec_process" in key:
        value = d.pop(key)
        fields = key.split(".")
        fields[0] = "layers"
        d[".".join(fields)] = value

In [6]:
model.load_state_dict(d, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['rope.freqs'])

In [7]:
# from fms.utils.generation import generate
from transformers import AutoTokenizer
t = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
vinv = {v:k for k,v in t.vocab.items()}

In [17]:
inp = t("Hello! How are you today?")["input_ids"]
inp2 = t("Where will you be tomorrow?")["input_ids"]
inp = torch.stack([torch.IntTensor(inp),torch.IntTensor(inp2)])
for line in inp:
    print([vinv[x.item()] for x in line])

['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?']
['<s>', '▁Where', '▁will', '▁you', '▁be', '▁tom', 'orrow', '?']


In [45]:
out, kv, embeds = generate(model, inp, 8, 8, do_sample=True, use_cache=True)
for line in out:
    print([vinv[x] for x in line.tolist()])

torch.Size([2, 8, 4096]) <s> Hello! How are you today? How
torch.Size([2, 9, 4096]) <s> Hello! How are you today? How is
torch.Size([2, 10, 4096]) <s> Hello! How are you today? How is life
torch.Size([2, 11, 4096]) <s> Hello! How are you today? How is life?
torch.Size([2, 12, 4096]) <s> Hello! How are you today? How is life?

torch.Size([2, 13, 4096]) <s> Hello! How are you today? How is life?
I
torch.Size([2, 14, 4096]) <s> Hello! How are you today? How is life?
I hope
torch.Size([2, 15, 4096]) <s> Hello! How are you today? How is life?
I hope you
['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?', '▁How', '▁is', '▁life', '?', '<0x0A>', 'I', '▁hope', '▁you']
['<s>', '▁Where', '▁will', '▁you', '▁be', '▁tom', 'orrow', '?', '<0x0A>', 'In', '▁the', '▁', '1', '2', '▁days', '▁of']


In [52]:
line = [vinv[x] for x in out[0].tolist()]
line

['<s>',
 '▁Hello',
 '!',
 '▁How',
 '▁are',
 '▁you',
 '▁today',
 '?',
 '▁How',
 '▁is',
 '▁life',
 '?',
 '<0x0A>',
 'I',
 '▁hope',
 '▁you']

In [55]:
e = line[:-1]
t = line[1:]

print(e[:-1])
print(t[1:])

['<s>', '▁Hello', '!', '▁How', '▁are', '▁you', '▁today', '?', '▁How', '▁is', '▁life', '?', '<0x0A>', 'I']
['!', '▁How', '▁are', '▁you', '▁today', '?', '▁How', '▁is', '▁life', '?', '<0x0A>', 'I', '▁hope', '▁you']


In [43]:
out.size(), embeds.size(), kv[0][0].size()

(torch.Size([2, 16]), torch.Size([2, 15, 4096]), torch.Size([2, 32, 15, 128]))

In [51]:
new_kv = torch.stack([torch.cat([x.transpose(1,2) for x in kv_], dim=2) for kv_ in kv], dim=2).flatten(2)
new_kv.size()

torch.Size([2, 15, 262144])

In [26]:
t.decode(out[1].tolist())

"<s> Where will you be tomorrow? This year, I will be at the National Conference on Volunteer Engagement in Indianapolis, Indiana.\nIf you're looking for ways to engage more volunteers on campus for your non-profit or for-profit business, you'll want to be there as well. Here are five reasons to attend:\n1. The 2017 NCVAE is a great place to meet people.\nIf you want to meet people from across the country, this is the perfect conference for you. You'll have the opportunity to meet other non-profits in person and learn what they'"

In [44]:
def generate(
    model: Union[Callable, torch.nn.Module],
    input_ids: torch.LongTensor,
    max_seq_len: int = 2048,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_k: int = 10,
    do_sample: bool = True,
    num_beams: int = 1,
    use_cache: bool = False,
):
    """
    A trivial generate function that can be used for validation/testing in
    cases where HF is not available.
    We could add implementations for other types of generation, but this is
    enough for making sure a model is working.
    Does not implement batching nor beam search, but those could be added.

    Args:
        model: A function or nn.Module that takes a batch of input_ids and
            returns logits
        prefix: A tensor of token IDs.
        max_seq_len: the sequence length of the model
        max_new_tokens: max tokens to generate
        temperature: temperature of softmax when sampling
        top_k: only search among top k tokens
        do_sample: multinomial sampling. False for greedy.
        num_beams: TODO: support beam search
        use_cache: requires that the model accept use_cache and
            past_key_value_states args in forward method.
    """
    batched = False
    if num_beams != 1:
        raise NotImplementedError("generate() does yet not support beam search")
    if type(input_ids) == torch.Tensor:
        if input_ids.dim() != 1:
            batched = True
    else:
        raise RuntimeError("generate() requires a tensor of token ids as the prefix")

    if not batched:
        input_ids = input_ids.unsqueeze(0)

    embeds = None
    result = input_ids
    next_input = input_ids
    kwargs = dict()
    kwargs["past_key_value_states"] = None
    kwargs["use_cache"] = use_cache

    for _ in range(max_new_tokens):
        input_ids = next_input[:, -max_seq_len:]
        output = model.forward(input_ids, include_embeds=True, **kwargs)
        if use_cache:
            logits, past_key_value_states, z = output
            # kv updates are required for torch.compile with
            # mode='reduce-overhead'
            n_kv_s = []
            for layer_idx in range(len(past_key_value_states)):
                n_kv_s.append([])
                for tensor_idx in range(len(past_key_value_states[layer_idx])):
                    n_kv_s[layer_idx].append(
                        past_key_value_states[layer_idx][tensor_idx]
                        .clone(memory_format=torch.contiguous_format)
                        .detach()
                    )
                    # torch._dynamo.mark_dynamic(n_kv_s[layer_idx][tensor_idx], 2)
            kwargs["past_key_value_states"] = n_kv_s
        else:
            logits, z = output
        logits = logits[:, -1, :]

        if do_sample:
            # get logits from last value in sequence nad scale
            logits = logits / temperature
            if top_k:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float("inf")

            probs = F.softmax(logits, dim=-1)
            next_val = torch.multinomial(probs, num_samples=1)
        else:
            next_val = torch.argmax(logits, dim=-1).unsqueeze(0).t()

        result = torch.cat((result, next_val), dim=-1)
        if embeds is None:
            embeds = z
        else:
            embeds = torch.cat((embeds, z), dim=-2)
        print(embeds.size(), t.decode(result[0].tolist()))

        if use_cache:
            next_input = next_val
        else:
            next_input = result

    if not batched:
        result = result[0]
    return result, n_kv_s, embeds