In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
import torch
import sys
sys.path.append("../src")

from nanodiffusion.config.diffusion_training_config import DiffusionTrainingConfig

In [15]:
config = DiffusionTrainingConfig(
    dataset="valhalla/emoji-dataset",  # 256x256 is the original resolution
    caption_column="text",
    # Conditioning
    conditional=False,
    cond_embed_dim=768,
    cond_drop_prob=0.2,
    guidance_scale=4.5,
    # Model
    net="unet",
    # Training loop
    batch_size=128,
    resolution=64,  # resize to 64x64
    logger="wandb",  # None,
    sample_every=1000,
    validate_every=1000,
    fid_every=-1,  # disable FID logging
    total_steps=100000,
    num_samples_for_logging=8,
    num_samples_for_fid=1000,
    num_real_samples_for_fid=10000,
)

## Add text embeddings to the dataset and cache them

In [None]:
import os
from torchvision import transforms
from datasets import load_dataset
from nanodiffusion.models.text_encoder import TextEncoder


resize_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

text_encoder = TextEncoder("openai/clip-vit-large-patch14", device="cuda:0")
text_encoder.eval()
def get_text_embeddings(text: str):
    with torch.no_grad():
        return text_encoder([text])[0]

for split in ["train"]:
    dst_path = f"data/emoji_w_text_emb_{split}"
    if os.path.exists(dst_path):
        continue
    captioned_emoji_train = load_dataset("valhalla/emoji-dataset", split=split)
    captioned_emoji_train_w_text_emb = captioned_emoji_train.map(
        lambda x: {
            "text": x["text"],
            "text_emb": get_text_embeddings(x["text"]),
            "image": x["image"].resize((config.resolution, config.resolution))
        },
        batched=False,
    )
    # save to disk
    captioned_emoji_train_w_text_emb.save_to_disk(dst_path)


In [17]:
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from torchvision import transforms

ds = load_from_disk("data/emoji_w_text_emb_train")
# Normalize image pixels
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
normalize_op = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
normalize_img = lambda x: {"image": normalize_op(x["image"]).float(), 
                      "text": x["text"],
                      "text_emb": x["text_emb"]}

# randomly shuffle and split into train and val
ds = ds.map(normalize_img).with_format("torch")
ds = ds.shuffle()
ds_train = ds.select(range(int(len(ds) * 0.8)))
ds_val = ds.select(range(int(len(ds) * 0.2)))

# create dataloader
train_loader = DataLoader(ds_train, batch_size=16, shuffle=True)
val_loader = DataLoader(ds_val, batch_size=16, shuffle=False)





In [18]:
batch = next(iter(train_loader))
batch

{'image': tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
           [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
           [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
           ...,
           [ 0.9922,  0.9843,  0.9843,  ...,  0.8980,  1.0000,  1.0000],
           [ 1.0000,  0.9922,  0.9843,  ...,  0.9843,  0.9686,  0.9686],
           [ 1.0000,  1.0000,  0.9922,  ...,  0.9529,  0.9529,  0.9529]],
 
          [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
           [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
           [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
           ...,
           [ 1.0000,  1.0000,  1.0000,  ...,  0.8824,  1.0000,  1.0000],
           [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
           [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000]],
 
          [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
    

In [19]:
from nanodiffusion.diffusion.diffusion_model_components import create_diffusion_model_components

model_components = create_diffusion_model_components(config)

Creating model unet with resolution 64 and in_channels 3 and cond_embed_dim 768
model params: 33.11 M


In [None]:

from nanodiffusion.diffusion.diffusion_training_loop import training_loop

num_examples_trained = training_loop(
    model_components, train_loader, val_loader, config
)

Training on cuda:0
Creating checkpoint directory: logs/train/2025-04-15_16-37-39
Setting up logger: wandb
Logging to Weights & Biases project: nano-diffusion


VBox(children=(Label(value='126.542 MB of 126.542 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
learning_rate,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██████████
loss,███▇▆▆▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
num_batches_trained,▁▁▁▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇██
num_examples_trained,▁▁▁▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇██
test_samples_step,▁█
val_loss,▁

0,1
learning_rate,0.0001
loss,0.03747
num_batches_trained,1400.0
num_examples_trained,179328.0
test_samples_step,1000.0
val_loss,0.03146


Step: 0, Examples: 128, Loss: 0.9972, LR: 0.000000
Sampling a torch.Size([8, 3, 64, 64]) array in 1000 steps. Initial avg: -0.0024878503754734993


100%|██████████| 1000/1000 [00:33<00:00, 29.61it/s, std=1.55]


Model saved at logs/train/2025-04-15_16-37-39/model_checkpoint_step_0.pth
Step: 50, Examples: 6528, Loss: 0.9923, LR: 0.000005
Step: 100, Examples: 12928, Loss: 0.9566, LR: 0.000010
Step: 150, Examples: 19328, Loss: 0.8870, LR: 0.000015
Step: 200, Examples: 25728, Loss: 0.7636, LR: 0.000020
Step: 250, Examples: 32128, Loss: 0.6393, LR: 0.000025
Step: 300, Examples: 38528, Loss: 0.5131, LR: 0.000030
Step: 350, Examples: 44928, Loss: 0.3998, LR: 0.000035
Step: 400, Examples: 51328, Loss: 0.3162, LR: 0.000040
Step: 450, Examples: 57728, Loss: 0.2282, LR: 0.000045
Step: 500, Examples: 64128, Loss: 0.1461, LR: 0.000050
Step: 550, Examples: 70528, Loss: 0.0916, LR: 0.000055
Step: 600, Examples: 76928, Loss: 0.0680, LR: 0.000060
Step: 650, Examples: 83328, Loss: 0.0468, LR: 0.000065
Step: 700, Examples: 89728, Loss: 0.0412, LR: 0.000070
Step: 750, Examples: 96128, Loss: 0.0564, LR: 0.000075
Step: 800, Examples: 102528, Loss: 0.0238, LR: 0.000080
Step: 850, Examples: 108928, Loss: 0.0476, LR: 

100%|██████████| 1000/1000 [00:33<00:00, 29.42it/s, std=0.602]


Step: 1050, Examples: 134528, Loss: 0.0219, LR: 0.000100
Step: 1100, Examples: 140928, Loss: 0.0357, LR: 0.000100
Step: 1150, Examples: 147328, Loss: 0.0672, LR: 0.000100
Step: 1200, Examples: 153728, Loss: 0.0329, LR: 0.000100
Step: 1250, Examples: 160128, Loss: 0.0244, LR: 0.000100
Step: 1300, Examples: 166528, Loss: 0.0143, LR: 0.000100
Step: 1350, Examples: 172928, Loss: 0.0178, LR: 0.000100
Step: 1400, Examples: 179328, Loss: 0.0160, LR: 0.000100
Step: 1450, Examples: 185728, Loss: 0.0282, LR: 0.000100
Step: 1500, Examples: 192128, Loss: 0.0147, LR: 0.000100
Step: 1550, Examples: 198528, Loss: 0.0355, LR: 0.000100
Step: 1600, Examples: 204928, Loss: 0.0219, LR: 0.000100
Step: 1650, Examples: 211328, Loss: 0.0159, LR: 0.000100
Step: 1700, Examples: 217728, Loss: 0.0219, LR: 0.000100
Step: 1750, Examples: 224128, Loss: 0.0255, LR: 0.000100
Step: 1800, Examples: 230528, Loss: 0.0347, LR: 0.000100
Step: 1850, Examples: 236928, Loss: 0.0235, LR: 0.000100
Step: 1900, Examples: 243328, L

100%|██████████| 1000/1000 [00:33<00:00, 29.57it/s, std=0.679]


Step: 2050, Examples: 262528, Loss: 0.0145, LR: 0.000100
Step: 2100, Examples: 268928, Loss: 0.0523, LR: 0.000100
Step: 2150, Examples: 275328, Loss: 0.0459, LR: 0.000100
Step: 2200, Examples: 281728, Loss: 0.0178, LR: 0.000100
Step: 2250, Examples: 288128, Loss: 0.0230, LR: 0.000100
Step: 2300, Examples: 294528, Loss: 0.0101, LR: 0.000100
Step: 2350, Examples: 300928, Loss: 0.0243, LR: 0.000100
Step: 2400, Examples: 307328, Loss: 0.0118, LR: 0.000100
Step: 2450, Examples: 313728, Loss: 0.0142, LR: 0.000100
Step: 2500, Examples: 320128, Loss: 0.0169, LR: 0.000100
Step: 2550, Examples: 326528, Loss: 0.0355, LR: 0.000100
Step: 2600, Examples: 332928, Loss: 0.0181, LR: 0.000100
Step: 2650, Examples: 339328, Loss: 0.0241, LR: 0.000100
Step: 2700, Examples: 345728, Loss: 0.0176, LR: 0.000100
Step: 2750, Examples: 352128, Loss: 0.0198, LR: 0.000100
Step: 2800, Examples: 358528, Loss: 0.0245, LR: 0.000100
Step: 2850, Examples: 364928, Loss: 0.0168, LR: 0.000100
Step: 2900, Examples: 371328, L

100%|██████████| 1000/1000 [00:33<00:00, 29.53it/s, std=1.01]


Step: 3050, Examples: 390528, Loss: 0.0156, LR: 0.000100
Step: 3100, Examples: 396928, Loss: 0.0419, LR: 0.000100
Step: 3150, Examples: 403328, Loss: 0.0517, LR: 0.000100
Step: 3200, Examples: 409728, Loss: 0.0305, LR: 0.000100
Step: 3250, Examples: 416128, Loss: 0.0156, LR: 0.000100
Step: 3300, Examples: 422528, Loss: 0.0151, LR: 0.000100
