In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import math
import pandas as pd
import random
import wandb

from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset

from datasets import load_dataset
from transformers import AutoTokenizer
from staticvectors import StaticVectors
from datetime import datetime
from tqdm import tqdm

from models.LanguageTransformer import LanguageTransformer
from data.LanguageDataset import PoetryDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# set appropriate device: prefer CUDA, then Apple-silicon MPS, otherwise CPU.
# (Falling back unconditionally to 'mps' yields an unusable device on machines
# that have neither CUDA nor MPS.)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [None]:
# tokenizer and static word vectors used to build the model's embedding table
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
word2vec = StaticVectors("neuml/word2vec")

# Look up one word2vec vector per tokenizer vocabulary entry and move the
# resulting float32 matrix to the active device. This runs at notebook scope,
# so the objects are referenced directly (there is no `self` here).
# NOTE(review): assumes iterating get_vocab() yields tokens in id order so that
# row i corresponds to token id i - confirm for the tokenizer in use.
word2vec_embeddings = torch.tensor(word2vec.embeddings(tokenizer.get_vocab())).type(torch.float32).to(device)

In [18]:
# restore the saved checkpoint onto the active device
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load checkpoint files that come from a trusted source.
model_path = "./checkpoints/diffusion-language-model/dlm-2025_12_17_12_12/4"
chkpt = torch.load(
    model_path,
    map_location=torch.device(device),
    weights_only=False,
)

# unpack the train, model, and generation configuration stored in the checkpoint
train_config, model_config, generation_config = (
    chkpt[section]
    for section in ('train_config', 'model_config', 'generation_config')
)

In [6]:
# load language model
# The architecture hyper-parameters come from the checkpoint's model_config.
# Vocabulary size and the initial embedding table come from the notebook-level
# `tokenizer` / `word2vec_embeddings`; no dataset object is constructed in this
# notebook, so nothing here may depend on one.
model = LanguageTransformer(
    vocab_size=len(tokenizer.get_vocab()),
    embed_dim=model_config['emb_dim'],
    num_layers=model_config['num_layers'],
    num_heads=model_config['num_heads'],
    word_emb=word2vec_embeddings
)

# load the trained weights from the checkpoint, then move the model to the device
model.load_state_dict(chkpt['model_state_dict'])
model.to(device)

LanguageTransformer(
  (token_embedding): Embedding(28996, 300)
  (pos_enc): PositionalEncoding()
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=300, out_features=300, bias=True)
        )
        (linear1): Linear(in_features=300, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=300, bias=True)
        (norm1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=300, out_features=2048, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2048, out_features=28996, bias=True)
  )
)

In [7]:
# prompt text that the generation cell encodes and continues
input_text = (
    "Let the bird of loudest lay "
    "On the sole Arabian tree "
    "Herald sad and trumpet be,"
)

In [17]:
def sample_top_p(logits, temperature, top_p):
    """Sample one token id per row using temperature scaling and nucleus (top-p) filtering.

    Args:
        logits: tensor whose last dimension indexes the vocabulary.
        temperature: divisor applied to the logits before the softmax.
        top_p: cumulative-probability threshold of the nucleus.

    Returns:
        Tensor of sampled token ids with the last dimension reduced to size 1.
    """
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Drop a token only when the mass accumulated *before* it already exceeds
    # top_p. This keeps the token that crosses the threshold, so the most likely
    # token always survives and the renormalisation never divides by zero.
    outside_nucleus = (cumulative_probs - sorted_probs) > top_p
    sorted_probs = sorted_probs.masked_fill(outside_nucleus, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    # sample a position in the sorted order, then map it back to a vocabulary id
    sampled_position = torch.multinomial(sorted_probs, 1)
    return sorted_indices.gather(-1, sampled_position)


model.eval()

# generated tokens (with [CLS] token)
prompt_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False).to(device)
cls_prefix = torch.tensor([[tokenizer.cls_token_id]], device=device)
generated_text = torch.cat([cls_prefix, prompt_ids], dim=1)

# Token ids that end generation. A tokenizer may leave eos_token_id unset
# (None), in which case comparing against it alone would never stop the loop.
# NOTE(review): assumes [SEP] marks the end of a training sequence - confirm
# against how the training data was encoded.
stop_token_ids = {tokenizer.eos_token_id, tokenizer.sep_token_id} - {None}

# autoregressive generation
with torch.no_grad():
    for _ in range(generation_config['max_length']):
        # logits for the position following the last generated token
        out = model(generated_text)[:, -1, :]

        next_token = sample_top_p(
            out,
            generation_config['temperature'],
            generation_config['top_p'],
        )

        # append generated token
        generated_text = torch.cat([generated_text, next_token], dim=1)
        if next_token.item() in stop_token_ids:
            break

print(tokenizer.decode(generated_text.squeeze(0), skip_special_tokens=True))

Let the bird of loudest lay On the sole Arabian tree Herald sad and trumpet be,, the To the isto,ofy yet we Oh. mi ground the some mortal alone dream with of ride little know knows gentle Or you, ' and of I think Ande you no river was false the shall ' be love own You
