In [115]:
"""
This prediction script is the Python version of run.c
"""

import torch
from torch.nn import functional as F
from model import Transformer, ModelArgs
from tokenizer import Tokenizer
from datetime import datetime
import time
import sys
from IPython.display import clear_output

In [9]:
# Read the checkpoint onto the CPU and pull out the hyper-parameters it was saved with.
# NOTE(review): torch.load unpickles the file, so only open checkpoints from a trusted source.
checkpoint = torch.load('stories15M.pt', map_location='cpu')
model_args = checkpoint["model_args"]

# Rebuild the network from exactly the configuration stored in the checkpoint.
gptconf = ModelArgs(**model_args)
model = Transformer(gptconf)

# Some checkpoints carry an "_orig_mod." prefix on their parameter names
# (where it comes from has not been tracked down yet). Rename those entries
# in place so the keys line up with the freshly built model.
state_dict = checkpoint["model"]
unwanted_prefix = "_orig_mod."
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for key in prefixed_keys:
    state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)

# Kept as the last expression so the cell displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [37]:
# Sanity check: display the vocabulary size of the loaded model.
model.vocab_size

32000

In [135]:
# Maximum sequence length stored in the model's params; the generation loop
# crops its running input to this many tokens.
max_seq_len = model.params.max_seq_len
max_seq_len

256

In [86]:
# Build the tokenizer from the local 'tokenizer.model' file
# (presumably the Llama-2 sentencepiece model — TODO confirm against tokenizer.py).
tokenizer = Tokenizer('tokenizer.model')

In [127]:
# Generation settings read by the cells below.
prompt = ''          # text to start from; empty string = unconditional generation
steps = 1024         # upper bound on the number of generation-loop iterations
temperature = 0.0    # 0.0 = greedy (argmax, deterministic); 1.0 = sample from the unscaled logits; don't set higher
topp = 1.0           # sampling truncation threshold; only values strictly between 0 and 1 enable it, 1.0 = off

In [128]:
# Encode the prompt when there is one; an empty prompt leaves nothing to replay.
# The two positional False flags are passed straight to encode — presumably
# "no BOS" / "no EOS"; TODO confirm against tokenizer.py.
prompt_tokens = tokenizer.encode(prompt, False, False) if prompt else []
num_prompt_tokens = len(prompt_tokens)
prompt_tokens

[]

In [136]:
BOS = 1  # token id 1 (=BOS), as used by the Llama-2 sentencepiece tokenizer; it delimits sequences

# Inference only: put the module in eval mode (disables dropout); autograd is
# switched off around the loop below.
model.eval()

start = None   # perf_counter timestamp, set after the first iteration because that one can be slower
nxt = 0        # will store the next token in the sequence
token = BOS    # generation is seeded with BOS
pos = 0        # number of loop iterations completed so far
token_str = ''
tokens = []    # every token accepted so far (prompt + sampled), without the leading BOS
X = torch.tensor([[token]], dtype=torch.long)  # running model input, shape (B=1, T)

with torch.no_grad():
    while pos < steps:
        if pos < num_prompt_tokens:
            # Still replaying the prompt: the token is forced, so the forward
            # pass (whose logits would be thrown away) is skipped.
            nxt = prompt_tokens[pos]
        else:
            # Keep only the final time step so both branches below work on a
            # (B, C) tensor regardless of how many steps the model returns.
            logits = model(X)[:, -1, :]
            if temperature == 0.0:
                # greedy decoding
                nxt = torch.argmax(logits, dim=-1).item()
            else:
                # scale by temperature and convert to (normalized) probabilities
                probs = F.softmax(logits / temperature, dim=-1)  # (B, C)
                if 0 < topp < 1:
                    # Nucleus (top-p) sampling: keep the smallest set of most
                    # likely tokens whose cumulative probability reaches topp.
                    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
                    cumulative = torch.cumsum(sorted_probs, dim=-1)
                    # A token is dropped when the mass *before* it already
                    # reaches topp, so the most likely token is always kept.
                    outside_nucleus = (cumulative - sorted_probs) >= topp
                    sorted_probs = sorted_probs.masked_fill(outside_nucleus, 0.0)
                    # multinomial treats unnormalized rows as weights
                    choice = torch.multinomial(sorted_probs, num_samples=1)  # (B, 1)
                    nxt = sorted_idx.gather(-1, choice).item()
                else:
                    nxt = torch.multinomial(probs, num_samples=1).item()
        pos += 1

        # data-dependent terminating condition: the BOS (1) token delimits sequences
        if nxt == BOS:
            break

        # init the timer here because the first iteration can be slower
        if start is None:
            start = time.perf_counter()

        # following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
        tokens.append(nxt)
        token_str = tokenizer.decode(tokens)
        # redraw the whole text in place instead of appending a line per token
        clear_output(wait=True)
        print(token_str)

        # append the new token to the running sequence, (B, T+1), and keep at
        # most the last max_seq_len positions
        X = torch.cat((X, torch.tensor([[nxt]], dtype=torch.long)), dim=1)
        X = X[:, -max_seq_len:]

# report achieved tok/s (pos-1 because the timer starts after first iteration)
if pos > 1 and start is not None:
    elapsed = time.perf_counter() - start
    # guard against a zero interval so the report can never divide by zero
    if elapsed > 0:
        print(f"\n\nachieved tok/s: {(pos-1)/elapsed:6f}\n")

Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, red ball in the sky. It was the sun! She thought it was so pretty.
Lily wanted to play with the ball, but it was too high up in the sky. She tried to jump and reach it, but she couldn't. Then, she had an idea. She would use a stick to knock the ball down.
Lily found a stick and tried to hit the ball. But the stick was too short. She tried again and again, but she couldn't reach it. She felt sad.
Suddenly, a kind man came by and saw Lily. He asked her what was wrong. Lily told him about the ball. The man smiled and said, "I have a useful idea!" He took out a long stick and used it to knock the ball down. Lily was so happy! She thanked the man and they played together in the sunshine.


achieved tok/s: 27.625000

