In [1]:
import torch
import torch.nn as nn

from transformers import PreTrainedTokenizerFast

MAX_SEQ_LEN = 256
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="models/tokenizer.json",
    pad_token="[PAD]",
    unk_token="[UNK]",
    eos_token="<|endoftext|>",
    max_len = MAX_SEQ_LEN,
    add_prefix_space=False
)

VOCAB_SIZE = tokenizer.vocab_size
DEVICE="cuda"

In [2]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super(TransformerBlock, self).__init__()
        self.norm_layer_1 = nn.LayerNorm(embed_dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, 
            num_heads=n_heads,
            bias=False,
            dropout=0.1,
            batch_first=True
        )

        self.norm_layer_2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*8//3,),
            nn.Linear(embed_dim*8//3, embed_dim,),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, ):
        # X --> (BATCH_SIZE, CONTEXT_LENGTH, EMBED_DIM) 
        x_norm = self.norm_layer_1(x)
        attn_mask = torch.triu(
            torch.zeros(
                (x_norm.size(1), x_norm.size(1))
            ), 
            diagonal=0
        ).to(x.device)
        attn_mask[attn_mask>0] = -torch.inf

        x_norm, _ = self.attention(
            x_norm, 
            x_norm, 
            x_norm, 
            attn_mask=attn_mask, 
            is_causal=True
        )
        x = x + x_norm

        x_norm = self.norm_layer_2(x)
        x_norm = self.ffn(x_norm)
        x_norm = self.dropout(x_norm)

        return x + x_norm

In [3]:
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len=MAX_SEQ_LEN, n_layers=5, n_heads=4):
        super(TransformerLM, self).__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.position_emb = nn.Embedding(max_seq_len, embed_dim)

        self.dropout = nn.Dropout(0.1)

        self.transfomers = nn.Sequential(
            *[
                TransformerBlock(
                    embed_dim=embed_dim,
                    n_heads=n_heads
                ) for _ in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x:torch.Tensor):
        x = self.dropout(
            self.token_emb(x) + self.position_emb(torch.arange(x.size(1), device=x.device))
        )
        x = self.transfomers((x))
        x = self.norm(x)
        x = self.out(x)

        return x.reshape((x.shape[0], x.shape[2], x.shape[1]))

In [4]:
model = TransformerLM(
    vocab_size=VOCAB_SIZE,
    embed_dim=512,
    max_seq_len=MAX_SEQ_LEN,
    n_layers=4,
    n_heads=16,
).to(DEVICE)

model.load_state_dict(
    torch.load("models/model_0_1", weights_only=True)
)

model.eval();

In [5]:

def _get_next_token(
        model,
        tokens,
        temperature,
        k=10
    ):
    
    with torch.no_grad():
        pred_tokens = model(tokens)[:,:,-1]/temperature
        order = torch.argsort(pred_tokens, dim=1,descending=True)
        
        token_probs = torch.softmax(pred_tokens, dim=1)[0]
        token_probs[order[k:]] = 0
        token_probs/=token_probs.sum()
        
        next_token = token_probs.multinomial(1)[0]
        #next_token = torch.softmax(pred_tokens[:, :, next_token_pos], dim=1).argmax().reshape(-1,1)
    return next_token
    

In [6]:
def get_story_2(model, tokenizer, seed_text, max_generated,device="cuda", temperature=1, top_k=10):
    base = tokenizer(
        seed_text,
        return_tensors="pt",
    )
    
    n_generated = 0
    tokens = base["input_ids"].to(device)
    
    while n_generated < max_generated:
        next_token = _get_next_token(
            model,
            tokens,
            temperature=temperature,
            k=top_k
        )

        if (next_token == tokenizer.eos_token_id):
            print("Reached end of text!")
            break

        tokens = torch.cat([tokens, next_token.reshape(-1,1)], dim=1)
        n_generated += 1
    
    return tokenizer.decode(tokens.squeeze())

In [24]:
get_story_2(model, tokenizer, "johnny boy is a ", 15, top_k=5, temperature=30, device=DEVICE)

'johnny  boy  is  a    pad  pl  deter  mommy  wit  rhinocerier  extraiest  onto  package  jack - X  cra'

In [25]:
get_story_2(model, tokenizer, "jack was a little kid that", 15, top_k=5, temperature=30, device=DEVICE)

'jack  was  a  little  kid  that  had4  squir  raining  ad  sne  treweween  cats  poison  mint  buy ģ  dogs'

In [26]:
get_story_2(model, tokenizer, "once upon a time, ", 15, top_k=5, temperature=30, device=DEVICE)

'once  upon  a  time,    smileasses🍌  n æ留  task  fix  yourself O  buck  hugquest  freddy ғ'

In [16]:
import torch.functional as F
import numpy as np

def get_kth_most_likely_token(
        input_tokens,  
        model, 
        tokenizer, 
        k
    ):
    outputs = model(input_tokens["input_ids"])[0,:,-1]
    prob_over_tokens = torch.softmax(outputs, dim=0)

    order = torch.argsort(prob_over_tokens, dim=0, descending=True)
    next_token = order[k]

    output_tokens = input_tokens.copy()
    output_tokens["input_ids"] = torch.cat((output_tokens["input_ids"].to("cpu"), torch.tensor([[next_token]])), dim=1)
    output_tokens["attention_mask"] = torch.cat((output_tokens["attention_mask"].to("cpu"), torch.tensor([[1]])), dim=1)
    output_tokens["last_token_prob"] = prob_over_tokens[next_token]
    output_tokens["log_prob"] += torch.log(prob_over_tokens[next_token])

    return output_tokens

In [17]:
input_txt = "once upon a time"
input_tokens = tokenizer(
    input_txt,
    return_tensors="pt"
).to(DEVICE)
input_tokens["log_prob"] = 0.

for i in range(5):
    input_tokens = get_kth_most_likely_token(
        input_tokens.to(DEVICE),
        model,
        tokenizer,
        1
    )
    print(
        tokenizer.decode(
            input_tokens["input_ids"][0],
            skip_special_tokens=True
        )
    )

once  upon  a  time  clean
once  upon  a  time  clean  dad
once  upon  a  time  clean  dad  ran
once  upon  a  time  clean  dad  ran µ
once  upon  a  time  clean  dad  ran µ "


In [18]:
import numpy as np
def print_beams(beams, tokenizer):
    for i, b in enumerate(beams):
        print(f"Beam {i}, Prob {b["log_prob"]:.3f}:", tokenizer.decode(b["input_ids"][0], skip_special_tokens=True))
    print("------")

def do_search_beam(
        input_tokens_in,
        model,
        tokenizer,
        n_beam=5,
        beam_length=10
):
    input_tokens_in["log_prob"] = 0.

    beams = [None for _ in range(n_beam)]
    for c_k in range(n_beam):
        beams[c_k] = input_tokens_in
        beams[c_k] = get_kth_most_likely_token(beams[c_k].to(DEVICE), model, tokenizer, k=c_k)

    print_beams(beams, tokenizer)

    for c_pos in range(beam_length-1):
        beams_all = [None for _ in range(n_beam**2)]
        log_probs_all = np.zeros(n_beam**2)

        for c_beam in range(n_beam):
            for c_k in range(n_beam):
                beams_all[c_beam * n_beam + c_k] = get_kth_most_likely_token(beams[c_beam].to(DEVICE), model, tokenizer, c_k)
                log_probs_all[c_beam * n_beam + c_k] = beams_all[c_beam * n_beam + c_k]["log_prob"]
    
        sorted_index = np.argsort(np.array(log_probs_all)*-1)
        for c_k in range(n_beam):
            beams[c_k] = beams_all[sorted_index[c_k]]

        print_beams(beams, tokenizer)
    return beams[0]
            

In [21]:
input_txt = "jack was a little kid that"
input_tokens = tokenizer(
    input_txt,
    return_tensors="pt"
).to(DEVICE)

n_beams=5
best_beam = do_search_beam(input_tokens, model, tokenizer, n_beam=n_beams)
print()
print(
    tokenizer.decode(
        best_beam["input_ids"][0],
        skip_special_tokens=True
    )
)

Beam 0, Prob -1.379: jack  was  a  little  kid  that  lively
Beam 1, Prob -1.805: jack  was  a  little  kid  that  leop
Beam 2, Prob -2.023: jack  was  a  little  kid  that."
Beam 3, Prob -2.416: jack  was  a  little  kid  that  beet
Beam 4, Prob -2.663: jack  was  a  little  kid  that  wind
------
Beam 0, Prob -13.374: jack  was  a  little  kid  that  lively  dress
Beam 1, Prob -14.236: jack  was  a  little  kid  that  leop  small
Beam 2, Prob -12.842: jack  was  a  little  kid  that."  small
Beam 3, Prob -14.509: jack  was  a  little  kid  that  beet  free
Beam 4, Prob -12.498: jack  was  a  little  kid  that  wind  small
------
Beam 0, Prob -25.909: jack  was  a  little  kid  that  wind  smallor
Beam 1, Prob -26.715: jack  was  a  little  kid  that  lively  dress  amazed
Beam 2, Prob -25.919: jack  was  a  little  kid  that."  smallle
Beam 3, Prob -26.656: jack  was  a  little  kid  that  leop  small  amazed
Beam 4, Prob -29.166: jack  was  a  little  kid  that  beet  free  amazed
-