In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import math
import pandas as pd
import random
import wandb

from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset

from datasets import load_dataset
from transformers import AutoTokenizer
from staticvectors import StaticVectors
from datetime import datetime
from tqdm import tqdm

from models.LanguageTransformer import LanguageTransformer
from data.AutoregressiveLanguage import AutoregressiveLanguageDataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# dynamically select device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [4]:
# load trained model
model_path = "./checkpoints/autoregressive-language-model/alm-2025_12_25_20_12/alm-2025_12_25_20_12_epoch7_end.pth"
chkpt = torch.load(model_path, weights_only=False, map_location=torch.device(device))

# get model configuration
model_config = chkpt['model_config']

# get vocabulary
vocab = chkpt['vocab']
print(chkpt['loss'])

tensor(1.5921, requires_grad=True)


In [25]:
# load language model
model = LanguageTransformer(
    vocab_size=len(vocab),
    embed_dim=model_config['emb_dim'],
    num_layers=model_config['num_layers'],
    num_heads=model_config['num_heads'],
    is_causal=False
)

model.load_state_dict(chkpt['model_state_dict'])
model.to(device)

LanguageTransformer(
  (token_embedding): Embedding(9940, 256)
  (pos_enc): PositionalEncoding()
  (transformer): Transformer(
    (transformer_layers): SequentialTransformerLayers(
      (0): TransformerLayer(
        (attn_layer): MultiheadAttention(
          (q_linear): Linear(in_features=256, out_features=256, bias=True)
          (k_linear): Linear(in_features=256, out_features=256, bias=True)
          (v_linear): Linear(in_features=256, out_features=256, bias=True)
          (concat_linear): Linear(in_features=256, out_features=256, bias=True)
        )
        (feed_forward): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
        (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): TransformerLayer(
        (attn_layer): MultiheadAttention(
          (q_linear): Linear(in_features=256

In [42]:
generation_config = {
    'max_len': 100,
    'temperature': 0.5,
    'top_p': 0.95
}

In [27]:
input_text = "once upon a time, there was a"

In [48]:
import re

model.eval()

# prepare input
generated_text = torch.tensor([vocab.word2idx['<s>']] + vocab.text2idx(input_text), device=device).unsqueeze(0)

# autoregressive generation
for _ in range(generation_config['max_len']):
    with torch.no_grad():
        out = model(generated_text)[:, -1, :]

        # apply temperature scaling
        # out = out / generation_config['temperature']

        # get probabilities
        probs = torch.nn.functional.softmax(out, dim=1)

        # nucleus/top-p filtering
        # sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
        # cumulative_probs = torch.cumsum(sorted_probs, dim=1)
        
        # nucleus coverage
        # nucleus = cumulative_probs < generation_config['top_p']
        # nucleus_size = torch.sum(nucleus).item()
        # nucleus[0,nucleus_size] = True

        # apply coverage
        # nucleus_probs = torch.zeros_like(probs)
        # nucleus_probs = torch.where(nucleus, sorted_probs, sorted_probs)
        # nucleus_probs /= torch.sum(nucleus_probs).item()

        # append generated token
        # next_token_id = torch.multinomial(nucleus_probs, 1).item()
        # next_token = sorted_indices[0][next_token_id].reshape(1, 1)

        # next_token = torch.argmax(out)
        next_token = torch.multinomial(probs, 1)
        print(vocab.idx2word[next_token.item()])

        generated_text = torch.cat([generated_text, next_token], dim=1)
        if next_token.item() == vocab.word2idx['</s>']:
            break

# decode generated tokens
generated_string = vocab.idx2text(generated_text.squeeze().tolist()[1:-1])
generated_string = re.sub(r'\s+([.,!?])', r'\1', generated_string)
print(generated_string)

fearful
.
because
no
last
everything
caused
cat
hide
a
zagged
that
max
truth
.
bad
belonged
each
page
and
it
values
on
ben
was
lied
.
me
means
.
bad
left
waves
.
it
dave
ted
no
you
.
it
make
sara
.

kate
â€œthen
sam
wounds
sam
fall
sobs
means
morning
came
tom
sara
.
ice
.
billy
stitch
.
trophy
.
jelly
eye
loves
hamster
even
love
timmy
.
bacon
gum
water
cried
molly
poison
say
pain
gorilla
winter
days
-
never
his
lily
.
pat
lila
kiss
skips
suggest
to
god
for
billy
things
the
once upon a time, there was a fearful. because no last everything caused cat hide a zagged that max truth. bad belonged each page and it values on ben was lied. me means. bad left waves. it dave ted no you. it make sara.  kate â€œthen sam wounds sam fall sobs means morning came tom sara. ice. billy stitch. trophy. jelly eye loves hamster even love timmy. bacon gum water cried molly poison say pain gorilla winter days - never his lily. pat lila kiss skips suggest to god for billy things


In [57]:
input_text = "bob and sally"
generated_text = torch.tensor([vocab.word2idx['<s>']] + vocab.text2idx(input_text), device=device).unsqueeze(0)
out = model(generated_text)[:, -1, :]
print(vocab.idx2word[torch.argmax(out).item()])

.
