In [1]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the TinyStories dataset
dataset = load_dataset("roneneldan/TinyStories")
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # or any suitable tokenizer
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Preprocess the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])  # Remove original text, keep only tokens
tokenized_datasets.set_format("torch")  # Set format to PyTorch tensors

vocab_size=tokenizer.vocab_size+1


Repo card metadata block was not found. Setting CardData to empty.


In [2]:
import math
from continous_diffusion.diffusion import Diffusion
from continous_diffusion.model import TransformerModel
from continous_diffusion.loss import Loss
from continous_diffusion.embedding import Embedder
from continous_diffusion.scheduling import CauchySchedule
from continous_diffusion.conditioning import TimeConditioning

embed_dim=256
num_heads=4
cond_dim=16
n_blocks=4

dit=TransformerModel(embed_dim,num_heads,cond_dim,n_blocks)
embedder=Embedder(vocab_size,embed_dim)
schedule=CauchySchedule(0.01,20,1,1,math.log(vocab_size),0)
loss=Loss(embedder,schedule)
conditioning=TimeConditioning(cond_dim,cond_dim)
model=Diffusion(dit,loss,conditioning)

In [3]:
# DataLoader
train_loader = DataLoader(tokenized_datasets["train"], batch_size=8, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4)

# Training loop
model.train()
for epoch in range(1):  # Define num_epochs
    for batch in train_loader:
        tokens = batch['input_ids']
        x,sigma=model.make_sample(tokens)
        print(x.shape)
        prediction=model(x,sigma)
        # Forward pass
        loss = model.loss(tokens,prediction,sigma)

        # Backward pass and optimization
        optimizer.zero_grad()  # Define your optimizer
        loss.backward()
        optimizer.step()

        # Log, print, or save as needed
        print(f"Loss: {loss.item()}")


torch.Size([8, 512, 256])
Loss: 1.9408994913101196
torch.Size([8, 512, 256])
Loss: 3.1746482849121094
torch.Size([8, 512, 256])
Loss: 2.605680465698242
torch.Size([8, 512, 256])
Loss: 3.8398118019104004
torch.Size([8, 512, 256])
Loss: 3.7289814949035645
torch.Size([8, 512, 256])
Loss: 2.976043701171875
torch.Size([8, 512, 256])
Loss: 2.5450997352600098
torch.Size([8, 512, 256])
Loss: 4.529530048370361
torch.Size([8, 512, 256])
Loss: 1.8267449140548706
torch.Size([8, 512, 256])
Loss: 1.8130924701690674
torch.Size([8, 512, 256])
Loss: 3.344564199447632
torch.Size([8, 512, 256])
Loss: 3.671360969543457
torch.Size([8, 512, 256])
Loss: 1.5351699590682983
torch.Size([8, 512, 256])
Loss: 1.085735559463501
torch.Size([8, 512, 256])
Loss: 2.6575260162353516
torch.Size([8, 512, 256])
Loss: 3.197154998779297
torch.Size([8, 512, 256])
