In [None]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the TinyStories dataset
dataset = load_dataset("roneneldan/TinyStories")
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # or any suitable tokenizer
if tokenizer.pad_token is None:
    # padding="max_length" below needs a pad token; add one if the tokenizer lacks it.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Every story is padded/truncated to exactly this many tokens.
max_length = 64

# Preprocess the dataset
def tokenize_function(examples):
    """Tokenize a batch of stories (``examples["text"]``) to fixed-length sequences.

    Args:
        examples: a batched `datasets` slice with a "text" column.

    Returns:
        The tokenizer's output for the batch, padded and truncated to ``max_length``.
    """
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])  # drop raw text; keep only the tokenizer's output columns
tokenized_datasets.set_format("torch")  # Set format to PyTorch tensors

# len(tokenizer) includes added special tokens such as [PAD]; tokenizer.vocab_size is the
# base vocabulary only. Using `vocab_size + 1` was correct only when exactly one token had
# been added, so it broke for tokenizers that already define a pad token.
vocab_size = len(tokenizer)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import math
from continous_diffusion import Diffusion,DiffusionTransformer,Loss,Embedder,CauchySchedule

# Diffusion-transformer hyperparameters.
embed_dim, qkv_dim, num_heads, cond_dim, n_blocks = 128, 1024, 8, 16, 8

# Construct the components in a fixed order. Constructors may draw from the global RNG,
# so reordering them could change the initial weights.
dit = DiffusionTransformer(embed_dim, qkv_dim, num_heads, cond_dim, n_blocks)
embedder = Embedder(vocab_size, embed_dim)

# log(vocab_size) is the entropy (in nats) of a uniform distribution over the vocabulary.
# NOTE(review): the other positional CauchySchedule arguments are unnamed here; see
# continous_diffusion.CauchySchedule for what each one controls.
log_vocab_size = math.log(vocab_size)
schedule = CauchySchedule(0.01, 200, 0, 0.3, log_vocab_size, 0)

loss = Loss(embedder, schedule)
model = Diffusion(dit, loss).to(device)

print(model.n_parameters)

4461088


In [7]:
from torch.optim.lr_scheduler import SequentialLR, ConstantLR, LinearLR, ExponentialLR

# Shuffled mini-batches of 256 pre-tokenized training stories.
train_loader = DataLoader(tokenized_datasets["train"], batch_size=256, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Learning-rate schedule: linear warmup from 1e-3 x lr up to the full lr over 5 scheduler
# steps, then SequentialLR hands over to exponential decay (x gamma per step) at milestone 6.
# NOTE(review): the training loop calls scheduler.step() once every 50 batches, so the lr
# after warmup is roughly lr * 0.9**(batches / 50). Confirm this decay rate is intended.
warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1, total_iters=5)
decay = ExponentialLR(optimizer, gamma=0.9)
scheduler = SequentialLR(optimizer, [warmup, decay], milestones=[6])

In [4]:
# Training settings.
schedule_update_frequency = 2000  # batches between schedule.update_optimal_parameters() calls
lr_step_frequency = 50            # batches between LR-scheduler steps and plot refreshes
num_epochs = 1

from IPython.display import clear_output

# Training loop
model.train()
for epoch in range(num_epochs):
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        tokens = batch['input_ids'].to(device)

        # Forward pass: make_sample produces the inputs fed to the model.
        x, sigma, attn_mask = model.make_sample(tokens)
        prediction = model(x, sigma, attn_mask)
        # Named `batch_loss`, not `loss`: `loss` is the Loss object handed to Diffusion
        # when the model was built. Rebinding it to the per-batch value would leave
        # hidden state that breaks re-running the model-construction code.
        batch_loss = model.loss(tokens, prediction, sigma)

        batch_loss.backward()
        optimizer.step()

        if i % schedule_update_frequency == 0 and i != 0:
            schedule.update_optimal_parameters()

        if i % lr_step_frequency == 0 and i != 0:
            scheduler.step()
            clear_output(wait=True)
            schedule.plot_entropy_time_curve()
            print(f"lr: {scheduler.get_last_lr()}")

        print(f"Step: {i},  Loss: {batch_loss.item()}")


KeyboardInterrupt: 

In [None]:
# Draw a sample from the trained model; the result is this cell's displayed output.
# NOTE(review): the positional arguments are defined by continous_diffusion's
# Diffusion.generate. They are presumably (n_samples, seq_len, n_steps); 64 matches the
# tokenizer max_length used for training. TODO confirm against the library.
model.generate(1,64,1000,device=device)

In [8]:
# Display the schedule's current `optimal_parameters`.
# NOTE(review): presumably these are refit by schedule.update_optimal_parameters() in the
# training loop, and model.schedule is the same object as the `schedule` built with the
# model. Confirm both.
model.schedule.optimal_parameters

Parameter containing:
tensor([ 0.1884,  0.1834,  3.7468, -1.0581])