In [35]:
import math
import os
import random
import re
from datetime import datetime

# set early, before transformers/tokenizers are imported below
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
import torch
import wandb
from datasets import load_dataset
from staticvectors import StaticVectors
from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from tqdm import tqdm
from transformers import AutoTokenizer

from models.LanguageDiffusion import LanguageDiffusion
from data.DiffuseLanguage import DiffuseLanguageDataset

In [36]:
# Pick the compute backend: CUDA if present, otherwise Apple MPS (only when it
# is both available at runtime and compiled into this PyTorch build), else CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if (torch.backends.mps.is_available() and torch.backends.mps.is_built())
    else "cpu"
)

In [37]:
# Restore a trained checkpoint onto the selected device.
# weights_only=False lets torch.load unpickle arbitrary Python objects
# (presumably needed because the vocab object is stored in it) -- only load
# checkpoint files you trust.
model_path = "./checkpoints/diffusion-language-model/dlm-2025_12_19_22_12_epoch9_end.pth"
chkpt = torch.load(model_path, weights_only=False, map_location=torch.device(device))

# architecture hyperparameters and the vocabulary saved alongside the weights
model_config = chkpt['model_config']
vocab = chkpt['vocab']

# loss value recorded in the checkpoint
print(chkpt['loss'])

tensor(2.6319, device='mps:0', requires_grad=True)


In [38]:
# Rebuild the network from the hyperparameters stored in the checkpoint,
# restore its trained weights and place it on the selected device.
model = LanguageDiffusion(
    vocab_size=len(vocab),
    embed_dim=model_config['emb_dim'],
    num_layers=model_config['num_layers'],
    num_heads=model_config['num_heads'],
)
model.load_state_dict(chkpt['model_state_dict'])

# nn.Module.to() moves parameters in place and returns the same module;
# the bare `model` line displays its architecture summary.
model = model.to(device)
model

LanguageDiffusion(
  (token_embedding): Embedding(9944, 256)
  (pos_enc): PositionalEncoding()
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-15): 16 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Linear(in_features=256, out_features=9944, bias=True)
)

In [67]:
# Sampling hyperparameters for diffusion generation:
#   max_len        -- number of token positions in the generated sequence
#   sampling_steps -- number of denoising iterations
#   temperature    -- logits are divided by this before the softmax
#   top_p          -- intended nucleus (top-p) probability mass
generation_config = dict(
    max_len=200,
    sampling_steps=100,
    temperature=0.9,
    top_p=0.9,
)

In [80]:
def top_p_filter(probs, top_p):
    """Apply nucleus (top-p) filtering to a probability tensor.

    Along the last dimension, keeps the smallest set of highest-probability
    entries whose cumulative mass reaches ``top_p`` (the single most likely
    entry is always kept), zeroes all other entries and renormalises.

    Args:
        probs: probabilities over the vocabulary in the last dimension.
        top_p: nucleus mass; values >= 1 disable filtering.

    Returns:
        A tensor with the same shape as ``probs`` whose last-dim rows sum to 1.
    """
    if top_p >= 1.0:
        return probs
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # mass strictly before each entry: an entry is dropped only if the
    # nucleus was already complete without it (keeps the threshold-crossing token)
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    drop_sorted = mass_before >= top_p
    drop_sorted[..., 0] = False  # never drop the most likely token
    # map the keep/drop decision from sorted order back to vocabulary order
    drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
    kept = probs.masked_fill(drop, 0.0)
    return kept / kept.sum(dim=-1, keepdim=True)


@torch.no_grad()  # pure inference: no autograd graph needed
def generate(model, vocab, config, device):
    """Generate one sequence by iterative masked-diffusion sampling.

    Starts from ``config['max_len']`` mask tokens and runs
    ``config['sampling_steps']`` steps. At step ``i`` the noise level ``t``
    falls linearly from 1 to 0; each position in the re-mask pool is masked
    with probability ``t``, the model predicts every position, and only the
    masked positions are overwritten with tokens sampled at
    ``config['temperature']`` (top-p filtered by ``config['top_p']`` when given).

    Returns:
        LongTensor of token ids with shape (1, max_len) on ``device``.
    """
    steps = config['sampling_steps']
    mask_id = vocab.word2idx['<m>']

    tokens = torch.full((1, config['max_len']), mask_id, dtype=torch.long, device=device)
    remask_pool = torch.ones_like(tokens, dtype=torch.bool)
    confidence = torch.ones_like(tokens, dtype=torch.float)
    # schedule length must follow `steps`; a hard-coded 100 raised IndexError
    # for more steps and truncated the schedule for fewer
    sample_t = torch.linspace(0, 1, steps=steps)

    for i in tqdm(range(steps)):
        t = sample_t[steps - i - 1].item()

        # NOTE(review): remask_pool starts all-True and is only ever OR-ed,
        # so it stays all-True and this confidence-based re-masking currently
        # has no effect -- confirm the intended update rule for the pool.
        remask_pool = remask_pool | (confidence < random.uniform(0, 1))

        # mask each pool position independently with probability t
        mask = (torch.rand(tokens.shape, device=device) < t) & remask_pool
        model_input = tokens.masked_fill(mask, mask_id)

        logits = model(model_input) / config['temperature']
        probs = nn.functional.softmax(logits, dim=-1)
        sampling_probs = top_p_filter(probs, config.get('top_p', 1.0))

        sampled = torch.multinomial(
            sampling_probs.view(-1, sampling_probs.size(-1)), num_samples=1
        ).view(probs.size(0), probs.size(1))
        # only masked positions take the newly sampled token
        tokens = torch.where(mask, sampled, tokens)

        # (unfiltered) model probability of each sampled token
        confidence = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)

    return tokens


model.eval()
generated_text = generate(model, vocab, generation_config, device)

# decode and re-attach punctuation that the vocab splits off ("ball ." -> "ball.")
generated_string = vocab.idx2text(generated_text.squeeze().tolist())
generated_string = re.sub(r'\s+([.,!?])', r'\1', generated_string)
print(generated_string)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:05<00:00, 19.20it/s]

lily and ben are in the park. they like to run and jump in the park. they see a big hill. they want to go inside. they climb up the hill. but it is very dangerous. it put on a big hole on the hill. they see a hole in the hole and a hole. ben puts and jump up. they move. then they shakes the hole.  but there is wrong. it runs out the cage.  " no, but the hole is faster. " ben says. " let me see them. we are brave. they look at them and say, " do n't worry, bird. we did not to hurt us. "  " i am a hole. lily is a bird. you are safe. how can we play with you? "  ben smiles and says, " you are not alert. okay. you are going to help you. we are dangerous. we hope the bird are very dangerous. " lily smile. they hug and





In [47]:
# Corrupt a known sentence: every token is independently replaced by the
# <m> mask token with probability 0.25, then show the corrupted text.
text_input = "mark was playing with his toys and his big red ball."
generated_text = torch.tensor([vocab.text2idx(text_input)], dtype=torch.long).to(device)
mask = torch.rand(generated_text.shape).lt(0.25).to(device)
masked_input = generated_text.masked_fill(mask, vocab.word2idx['<m>'])
print(vocab.idx2text(masked_input.squeeze().tolist()))

mark was playing with <m> <m> and his big red ball .


In [48]:
# Fill the masked positions of `masked_input` in a single model pass.
# A low temperature sharpens the distribution toward the most likely token.
FILL_TEMPERATURE = 0.25

with torch.no_grad():  # inference only: don't build an autograd graph
    out = model(masked_input) / FILL_TEMPERATURE
    probs = nn.functional.softmax(out, dim=-1)

# per-position probability of the most likely token
print(torch.max(probs, dim=-1).values)

sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1).view(probs.size(0), probs.size(1))
# only positions that were masked take the sampled token; others keep the original
generated_text = torch.where(mask.to(device), sampled_tokens, generated_text)
generated_string = vocab.idx2text(generated_text.squeeze().tolist())
print(generated_string)

tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.6812, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000]], device='mps:0', grad_fn=<MaxBackward0>)
mark was playing with his ball and his big red ball .
