In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import math
import pandas as pd
import random
import wandb

from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset

from datasets import load_dataset
from transformers import AutoTokenizer
from staticvectors import StaticVectors
from datetime import datetime
from tqdm import tqdm

from models.LanguageTransformer import LanguageTransformer
from data.AutoregressiveLanguage import AutoregressiveLanguageDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute backend once: CUDA when present, otherwise Apple MPS
# (only if it is both available and compiled in), otherwise plain CPU.
# The conditional expression is evaluated lazily, so the MPS probes only
# run when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cpu"
)

In [3]:
# Build a small causal transformer over a placeholder 100-token vocabulary,
# then move it onto the selected device.
model = LanguageTransformer(vocab_size=100, embed_dim=256, num_layers=4, num_heads=4, is_causal=True)
model = model.to(device)

In [None]:
# Load a trained checkpoint and rebuild the model it describes.
#
# The checkpoint location can be overridden with the ALM_CHECKPOINT environment
# variable so the notebook runs on machines other than the author's; when the
# variable is unset the original path is used unchanged.
model_path = os.environ.get(
    "ALM_CHECKPOINT",
    "/Users/josh/Documents/GitHub/diffusion-language-model/checkpoints/autoregressive-language-model/alm-2025_12_29_21_12_epoch1_end.pth",
)

# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects,
# which can execute code embedded in the file. Only load checkpoints you trust.
# NOTE(review): it is presumably required because the checkpoint also stores the
# vocabulary object (chkpt['vocab']) rather than tensors only -- confirm.
chkpt = torch.load(model_path, weights_only=False, map_location=device)

# Hyperparameters the checkpointed model was built with.
model_config = chkpt['model_config']

# Vocabulary object; its length defines the model's vocabulary size.
vocab = chkpt['vocab']

# Re-create the architecture with the stored hyperparameters, then restore weights.
model = LanguageTransformer(
    vocab_size=len(vocab),
    embed_dim=model_config['emb_dim'],
    num_layers=model_config['num_layers'],
    num_heads=model_config['num_heads'],
    is_causal=True
)

model.load_state_dict(chkpt['model_state_dict'])
model.to(device)

In [5]:
# Smoke test: one forward pass on random token ids must yield logits shaped
# (batch, sequence, vocabulary).
#
# The check builds its own throwaway model instead of reusing `model`: when the
# notebook is run top to bottom, `model` has already been rebound to the
# checkpointed network, whose vocabulary size is len(vocab) rather than 100, so
# asserting a trailing dimension of 100 against it would fail.
smoke_vocab_size = 100
smoke_model = LanguageTransformer(
    vocab_size=smoke_vocab_size,
    embed_dim=256,
    num_layers=4,
    num_heads=4,
    is_causal=True
).to(device)

seq = torch.randint(0, smoke_vocab_size, (32, 10)).to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = smoke_model(seq)
assert out.shape == (32, 10, smoke_vocab_size)

In [None]:
# Sampling settings read by the generation loop:
#   max_len     - upper bound on the number of generation steps
#   temperature - divisor applied to the logits before softmax
#   top_p       - cumulative-probability threshold for nucleus filtering
generation_config = dict(max_len=100, temperature=0.25, top_p=0.95)

In [None]:
# Prompt that seeds autoregressive generation; it is converted to token ids
# with vocab.text2idx and prefixed with the '<s>' token before being fed to the model.
input_text: str = "once upon a time, there"

In [None]:
import re


def sample_top_p(logits, temperature, top_p):
    """Draw one token id per row using temperature scaling and nucleus (top-p) sampling.

    Args:
        logits: float tensor of shape (batch, vocab) holding unnormalised scores.
        temperature: positive divisor applied to the logits before softmax;
            values below 1 sharpen the distribution.
        top_p: cumulative-probability threshold. The kept set is the shortest
            prefix of the probability-sorted tokens whose mass reaches ``top_p``.

    Returns:
        Integer tensor of shape (batch, 1) with the sampled token ids.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    # Probability mass that precedes each token in sorted order. A token is kept
    # while that preceding mass is still below top_p, which also keeps the token
    # that crosses the threshold and never indexes past the vocabulary.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    keep = mass_before < top_p
    keep[..., 0] = True  # always keep the most likely token, even if top_p <= 0

    # Zero out everything outside the nucleus and renormalise each row.
    nucleus_probs = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    nucleus_probs = nucleus_probs / nucleus_probs.sum(dim=-1, keepdim=True)

    # Sample a position in sorted order, then map it back to a vocabulary id.
    choice = torch.multinomial(nucleus_probs, 1)
    return sorted_indices.gather(-1, choice)


model.eval()

bos_id = vocab.word2idx['<s>']
eos_id = vocab.word2idx['</s>']

# prepare input: <s> followed by the prompt tokens, as a batch of one
generated_text = torch.tensor([bos_id] + vocab.text2idx(input_text), device=device).unsqueeze(0)

# autoregressive generation
with torch.no_grad():
    for _ in range(generation_config['max_len']):
        # logits for the next token, taken from the last position
        out = model(generated_text)[:, -1, :]

        next_token = sample_top_p(
            out,
            generation_config['temperature'],
            generation_config['top_p'],
        )

        # append generated token and stop at end-of-sequence
        generated_text = torch.cat([generated_text, next_token], dim=1)
        if next_token.item() == eos_id:
            break

# decode generated tokens: drop the leading <s>, and drop the trailing </s> only
# when one was actually produced (generation may also stop at max_len)
token_ids = generated_text.squeeze(0).tolist()[1:]
if token_ids and token_ids[-1] == eos_id:
    token_ids = token_ids[:-1]

generated_string = vocab.idx2text(token_ids)
# remove the space the decoder leaves before punctuation
generated_string = re.sub(r'\s+([.,!?])', r'\1', generated_string)
print(generated_string)

In [None]:
# Teacher-forced inspection: run a fixed prompt through the model and decode the
# greedy (argmax) token at each position.
# Note: this rebinds `input_text`, `generated_text` and `out`.
input_text = "once upon a time"

# <s> + prompt ids, with a leading batch axis
generated_text = torch.tensor(
    [vocab.word2idx['<s>']] + vocab.text2idx(input_text),
    device=device,
)[None, :]

# logits for every position except the last one
out = model(generated_text)[:, :-1]

# greedy token per position; the first prediction is skipped when decoding
greedy_ids = out.argmax(dim=-1).squeeze().tolist()
print(vocab.idx2text(greedy_ids[1:]))
# print(vocab.idx2word[torch.argmax(out).item()])