# imports

In [None]:
import os
import torch
from sentencepiece import SentencePieceProcessor
from model import *
import matplotlib.pyplot as plt
import torch.nn.functional as F

# paths

In [6]:
tokenizer_path = './tokenizer.model'
# this is the same tokenizer as found in llama2

In [7]:
checkpoint_path = './out/batch128.pt'
checkpoint_path = './instruct_out/ckpt.pt'
checkpoint_path = './instruct_out/ckpt.pt'

In [8]:
device = 'cuda:0'

# load tokenizer

In [9]:
tokenizer = SentencePieceProcessor(model_file=tokenizer_path)

# load model

In [10]:
def load_model(checkpoint_path, device, unwanted_prefix='_orig_mod', add_lora=False):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint['model_args'], ModelArgs):
        config = checkpoint['model_args']
    else:
        config = ModelArgs(**checkpoint['model_args'])
    model = Transformer(config)
    if add_lora:
        lora_rank = 2
        lora_dropout = 0.1
        lora_alpha = 1.0
        lora_targets = ['wk', 'wq', 'wo', 'wv']
        apply_lora(
            model, 
            targets=lora_targets,
            rank=lora_rank,
            dropout=lora_dropout,
            alpha=lora_alpha
        )
    print(f"Number of parameters: {sum([p.nelement() for p in model.parameters()])}")
    state_dict = checkpoint['model']
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.to(device)
    return model, checkpoint

In [12]:
instruct_model, ckpt = load_model(
    checkpoint_path='./fine_tuning_instruct/ckpt.pt',
    device=device,
    unwanted_prefix='',
    add_lora=True
)

add lora to layers.0.attention.wq
add lora to layers.0.attention.wk
add lora to layers.0.attention.wv
add lora to layers.0.attention.wo
add lora to layers.1.attention.wq
add lora to layers.1.attention.wk
add lora to layers.1.attention.wv
add lora to layers.1.attention.wo
add lora to layers.2.attention.wq
add lora to layers.2.attention.wk
add lora to layers.2.attention.wv
add lora to layers.2.attention.wo
add lora to layers.3.attention.wq
add lora to layers.3.attention.wk
add lora to layers.3.attention.wv
add lora to layers.3.attention.wo
add lora to layers.4.attention.wq
add lora to layers.4.attention.wk
add lora to layers.4.attention.wv
add lora to layers.4.attention.wo
add lora to layers.5.attention.wq
add lora to layers.5.attention.wk
add lora to layers.5.attention.wv
add lora to layers.5.attention.wo
add lora to layers.6.attention.wq
add lora to layers.6.attention.wk
add lora to layers.6.attention.wv
add lora to layers.6.attention.wo
add lora to layers.7.attention.wq
add lora to la

# sampling

### generation

In [78]:

max_new_tokens = 400 # number of tokens generated in each sample
temperature = 0.3 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 10 # retain only the top_k most likely tokens, clamp others to have 0 probability


In [85]:
def generate_paragraph(model, prompt):
    tokenized_prompt = [tokenizer.bos_id()] + tokenizer.encode(prompt)# bos=True, eos=False)
    tokenized_prompt = (torch.tensor(tokenized_prompt, dtype=torch.long, device=device)[None, ...])
    #paragraph = tokenized_prompt.flatten().tolist()

    paragraph = []
    context_tokens = tokenized_prompt
    for _ in range(max_new_tokens):
        context_tokens = context_tokens[:, -min(model.params.max_seq_len, context_tokens.size(1)):]
        output = model(context_tokens)
        
        logits = output[:, -1, :]
        
        temp_scaled_logits = logits / temperature
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        context_tokens = torch.cat((context_tokens, next_token), dim=1)
        #print(, type(next_token))
        #paragraph.extend(next_token.flatten().tolist())
        paragraph.append(next_token.item())
        if next_token.item() == tokenizer.eos_id():
            print('\n eos \n')
            break
        if tokenizer.decode(paragraph[-3:]) == 'The end.':
            print('\n The end \n')
            break
    return context_tokens, paragraph, tokenizer.decode(paragraph)

In [86]:
prompt = 'Write a short story (3-5 paragraphs) which only uses very simple words that a 3 year old child would understand. In the story, try to at some point use the verb "hope", the noun "search" and the adjective "comfortable". Remember to only use simple words!\n'

In [87]:
_, tokens, paragraph = generate_paragraph(instruct_model, prompt)
print(paragraph)

Once writing the story, the child asked their mommy for help. Their mommy said she would be happy to help and she took a pen and paper. She showed the little one how to make words with their hands. The child was so excited to learn!
They practiced writing their own words and soon enough they had made a whole story. They showed it to their mommy and she was so proud of them. The story had a lot of new words that the child loved to learn. Everyone was so happy!
The story ended and the child was very proud of their accomplishment. They were so glad that they had tried to write a story. And now that the story was finished, they could go to bed and have sweet dreams. Once upon a time, in a small town, there was a little girl named Lily. Lily lived in a comfortable little house made of wood. She loved to play outside with her friends and explore the woods near her house.
One sunny day, Lily and her friends were playing hide and seek. Lily found a dark cave and decided to hide inside. As she 