In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional

In [2]:
import os
import sys

# Make the local project checkout importable (it provides the `model` module
# imported further down).
# NOTE(review): this is a machine-specific absolute path; consider deriving it
# from a configurable project root instead.
LLMS_REPO_DIR = "/Users/htkumar/llms"
# Only append once so that re-running this cell does not pile up duplicate
# entries on sys.path.
if LLMS_REPO_DIR not in sys.path:
    sys.path.append(LLMS_REPO_DIR)

In [3]:
from typing import Optional
import torch
import time
from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm

from model import ModelArgs, Transformer

In [4]:
# Directory searched for the "*.pth" checkpoint shards and "params.json".
checkpoints_dir = "llama-2-7b/"
# SentencePiece model file loaded by the tokenizer.
tokenizer_path = "tokenizer.model"

In [5]:
# Load the first (alphabetically sorted) "*.pth" checkpoint onto the CPU and
# report how long the load took.
prev_time = time.time()
checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
if not checkpoints:
    # Fail with an actionable message instead of a bare IndexError below.
    raise FileNotFoundError(f"No '*.pth' checkpoint files found in {checkpoints_dir!r}")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source (or consider passing
# weights_only=True if the installed torch version supports it).
checkpoint = torch.load(checkpoints[0], map_location="cpu")
print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")

Loaded checkpoint in 5.56s


In [6]:
# Read the model hyper-parameters stored next to the checkpoint and echo them.
params_path = Path(checkpoints_dir) / "params.json"
with params_path.open("r") as f:
    params = json.load(f)
print(params)

{'dim': 4096, 'multiple_of': 256, 'n_heads': 32, 'n_layers': 32, 'norm_eps': 1e-05, 'vocab_size': -1}


In [7]:
# Combine the hyper-parameters read from params.json with the runtime limits
# chosen for this notebook (sequence length, batch size, device).
model_args: ModelArgs = ModelArgs(**params, max_seq_len=1024, max_batch_size=10, device="cpu")

In [8]:
# Load the SentencePiece model directly through the constructor.
tokenizer = SentencePieceProcessor(model_file=tokenizer_path)
# params.json carries vocab_size = -1, so take the real size from the tokenizer.
model_args.vocab_size = tokenizer.vocab_size()

In [9]:
# Create new floating-point tensors (and therefore the model weights built
# below) in bfloat16 by default. torch.set_default_tensor_type is deprecated
# and emits a warning; set_default_dtype is its replacement for choosing the
# default floating-point dtype.
torch.set_default_dtype(torch.bfloat16)

  _C._set_default_tensor_type(t)


In [10]:
# Build the Transformer from the assembled arguments and place it on the CPU.
# NOTE(review): weights are presumably created in the default dtype set in the
# previous cell (bfloat16) — confirm against the Transformer implementation.
model = Transformer(model_args).to("cpu")

In [11]:
# Drop the 'rope.freqs' entry before loading so that the strict key check
# below passes. pop() with a default keeps this cell re-runnable: a plain
# `del` raises KeyError the second time because the key is already gone.
checkpoint.pop('rope.freqs', None)
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [12]:
# Inputs and sampling settings for the generation walkthrough.
prompts = [
    'How are you doing',
    'who is zuck',
]
temperature = 0.6   # values > 0 enable sampling; otherwise argmax is used
top_p = 0.9         # cumulative probability threshold passed to _sample_top_p
max_gen_len = 64    # number of tokens to generate beyond the longest prompt
max_seq_len = 1024  # upper bound on prompt + generated length
device = 'cpu'

In [13]:
# Encode every prompt to a list of token ids, with BOS prepended and no EOS.
prompt_tokens = [
    tokenizer.encode(text, out_type=int, add_bos=True, add_eos=False)
    for text in prompts
]
batch_size = len(prompt_tokens)
# Length of the longest encoded prompt in the batch.
max_prompt_len = max(map(len, prompt_tokens))
type(prompt_tokens)

list

In [14]:
# Width of the token buffer: longest prompt plus the generation budget,
# capped at the maximum sequence length.
total_len = min(max_prompt_len + max_gen_len, max_seq_len)
total_len

69

In [15]:
# Allocate a (batch_size, total_len) buffer of token ids, initialised entirely
# with the tokenizer's padding id.
pad_id = tokenizer.pad_id()
tokens = torch.full(
    size=(batch_size, total_len),
    fill_value=pad_id,
    dtype=torch.long,
    device=device,
)
tokens

tensor([[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]])

In [16]:
# Write each prompt's token ids into the start of its row; the remainder of
# the row keeps the padding id. Iterating over `tokens` yields row views, so
# the slice assignment updates the buffer in place.
for row, prompt_ids in zip(tokens, prompt_tokens):
    row[:len(prompt_ids)] = torch.tensor(prompt_ids, dtype=torch.long, device=device)

tokens

tensor([[   1, 1128,  526,  366, 2599,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1],
        [   1, 1058,  338, 1729,  384,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]])

In [17]:
# One "finished" flag per sequence, all initially False.
eos_reached = torch.zeros(batch_size, dtype=torch.bool, device=device)
# True where the buffer already holds a prompt token, False where it is padding.
prompt_token_mask = tokens.ne(pad_id)
prompt_token_mask

tensor([[ True,  True,  True,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, 

In [18]:
def _sample_top_p(probs, p):
    # (B, vocab_size)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # (B, vocab_size)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # (B, vocab_size)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0

    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # (B, 1)
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

In [21]:
# Progress-bar wrapped position iterator. range(1, 2) yields only cur_pos = 1,
# so the loop that consumes it performs a single decoding step.
cur_iterator = tqdm(range(1, 2), desc='Generating tokens')

Generating tokens:   0%|          | 0/1 [00:00<?, ?it/s]

In [23]:
for cur_pos in cur_iterator:
    # Feed only the single token at position cur_pos - 1 together with the
    # current position.
    # NOTE(review): presumably the model keeps a KV cache of earlier
    # positions — confirm against the Transformer implementation.
    step_input = tokens[:, cur_pos - 1:cur_pos]
    with torch.no_grad():
        logits = model(step_input, cur_pos)
    print(logits.shape)

    if not temperature > 0:
        # Greedy decoding: pick the highest-scoring token at the last position.
        next_token = logits[:, -1].argmax(dim=-1)
    else:
        # Temperature-scaled softmax over the last position, then top-p sampling.
        probs = (logits[:, -1] / temperature).softmax(dim=-1)
        next_token = _sample_top_p(probs, top_p)
        print(next_token)
        

torch.Size([2, 1, 32000])
tensor([[23795],
        [30322]])


In [32]:
# Inspect the shape of the sampled tokens before flattening them.
print(next_token.shape)

torch.Size([2, 1])


In [33]:
# Collapse the sampled tokens to a 1-D tensor with one id per sequence.
next_token = next_token.flatten()
next_token.shape

torch.Size([2])

In [37]:
# Position left over from the generation loop (the last value it iterated).
cur_pos

1

In [36]:
# Where position cur_pos still belongs to the prompt, keep the prompt token;
# otherwise use the token produced by the model.
# reshape(-1) guarantees a 1-D (batch,) tensor: with a (batch, 1) input
# torch.where would silently broadcast to (batch, batch) if the flattening
# cell was skipped or run out of order. It is a no-op on 1-D input.
next_token = torch.where(
    prompt_token_mask[:, cur_pos],
    tokens[:, cur_pos],
    next_token.reshape(-1),
)
tokens[:, cur_pos] = next_token
# Display the chosen tokens (a bare expression is only shown when it is the
# last line of the cell).
next_token

tensor([1128, 1058])

In [38]:
# A sequence is finished once the model itself (not the prompt) emits EOS at
# this position. eos_id is a method on the tokenizer and must be called:
# comparing against the bound method never matches, so EOS would never be
# detected.
eos_reached |= (~prompt_token_mask[:, cur_pos]) & (next_token == tokenizer.eos_id())
eos_reached

tensor([False, False])

In [39]:
# Show the token buffer after writing position cur_pos.
tokens

tensor([[   1, 1128,  526,  366, 2599,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1],
        [   1, 1058,  338, 1729,  384,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
           -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]])

In [41]:
# tokens.tolist()

In [44]:
# eos_id is a method; calling it returns the integer id of the EOS token.
tokenizer.eos_id()

2

In [46]:
# Accumulators for the per-prompt token lists and their decoded text.
out_tokens, out_text = [], []

In [48]:
# Collect every row of the token buffer and decode its valid prefix.
# The lists are reset here so that re-running this cell does not append
# duplicate entries.
out_tokens = []
out_text = []
# eos_id is a method on the tokenizer; call it once to get the integer id.
eos_token_id = tokenizer.eos_id()
for current_prompt_tokens in tokens.tolist():
    # The decodable prefix ends at the first padding slot or EOS token,
    # whichever comes first (replaces the hard-coded prompt length of 5, and
    # keeps padding ids out of tokenizer.decode).
    valid_len = next(
        (
            idx
            for idx, tok in enumerate(current_prompt_tokens)
            if tok == pad_id or tok == eos_token_id
        ),
        len(current_prompt_tokens),
    )

    # As before, out_tokens keeps the full row (including padding).
    out_tokens.append(current_prompt_tokens)
    out_text.append(tokenizer.decode(current_prompt_tokens[:valid_len]))

In [49]:
# Decoded text for each sequence in the batch.
out_text

['How are you doing', 'who is zuck']

In [34]:
# Interactive IPython help lookup for torch.where (leftover from exploration;
# NOTE(review): consider removing before sharing the notebook).
torch.where??

[0;31mDocstring:[0m
where(condition, input, other, *, out=None) -> Tensor

Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`.

The operation is defined as:

.. math::
    \text{out}_i = \begin{cases}
        \text{input}_i & \text{if } \text{condition}_i \\
        \text{other}_i & \text{otherwise} \\
    \end{cases}

.. note::
    The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable <broadcasting-semantics>`.

Arguments:
    condition (BoolTensor): When True (nonzero), yield input, otherwise yield other
    input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices
                          where :attr:`condition` is ``True``
    other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices
                          where :attr:`condition` is ``False``

Keyword args:
    out (Tensor, optional): the output tensor.

Returns:


In [28]:
# Indexing the middle axis with -1 selects its last entry and drops that axis.
logits.shape, logits[:, -1, :].shape

(torch.Size([2, 1, 32000]), torch.Size([2, 32000]))

In [30]:
# Toy check: indexing dim 1 with -1 and omitting the trailing slice also
# drops that axis, leaving the remaining dimensions untouched.
a = torch.ones(2, 3, 5)
(a.shape, a[:, -1].shape)

(torch.Size([2, 3, 5]), torch.Size([2, 5]))