# Train Custom SDXL Diffusion Priors

In [1]:
%load_ext autoreload
%autoreload 2


import sys
sys.path.append('../')

In [2]:
from diffusers import DDPMScheduler
import pandas as pd
import torch
from torch.utils.data import DataLoader

from Datasets import RecommenderUserSampler, EmbeddingsDataset
from grid_search import run_grid_search
from prior_models import TransformerEmbeddingDiffusionModelv2
from train_priors import train_diffusion_prior
from utils import map_embeddings_to_ratings, split_recommender_data, set_seeds

## Load Data

Load the data in its corresponding (sub)directory and map image embeddings to observations.
The data in ratings.csv will constitute our observations, and for our purposes, it will 
consist of the triplets $(U_i, S_j, I_k)$, where $U_i$ corresponds user $i$, $S_j$ encodes wheter user likes $(\text{ score}\geq 4)$ or dislikes the image $(\text{ score}< 4)$ and $I_k$ is the $k$-th image.

In [3]:
image_features = torch.load("../data/flickr/processed/ip-adapters/SDXL/sdxl_image_embeddings.pt", weights_only=True)
ratings_df = pd.read_csv("../data/flickr/processed/ratings.csv")
expanded_features = map_embeddings_to_ratings(image_features, ratings_df)
device = "cuda"

In [7]:
liked_counts = (
    ratings_df[ratings_df["score"] >= 4]
    .groupby("worker_id")["score"]
    .count()
    .reset_index(name="liked_count")
)
valid_users = liked_counts[liked_counts["liked_count"] >= 20]["worker_id"].unique()
valid_worker_id = liked_counts[liked_counts["liked_count"] >= 20]["worker_id"].unique()
filtered_ratings_df = ratings_df[ratings_df["worker_id"].isin(valid_users)].copy()
print(f"User loss: {210-len(valid_users)}")
print(f"Data loss: {1 - filtered_ratings_df.shape[0]/ratings_df.shape[0]}%")

User loss: 22
Data loss: 0.00664051178005054%


In [8]:
worker_mapping = {old_id: new_id for new_id, old_id in enumerate(valid_worker_id)}
filtered_ratings_df.rename(columns={"worker_id": "old_worker_id"}, inplace=True)
filtered_ratings_df["worker_id"] = filtered_ratings_df["old_worker_id"].map(worker_mapping)
#filtered_ratings_df = filtered_ratings_df.reset_index(drop=True)
worker_mapping_df = pd.DataFrame(list(worker_mapping.items()), columns=["old_worker_id", "worker_id"])
worker_mapping_df.to_csv("../data/flickr/processed/worker_id_mapping.csv", index=False)
filtered_ratings_df.to_csv("../data/flickr/processed/filtered_ratings_df.csv", index=False)

In [9]:
train_df, val_df, test_df = split_recommender_data(
    ratings_df=filtered_ratings_df,
    val_spu=10,
    test_spu=10,
    seed=42
)

Train set size: 188273
Validation set size: 1823
Evaluation set size: 1829


In [21]:
train_df.to_csv("../data/flickr/processed/train/train.csv", index=False)
val_df.to_csv("../data/flickr/processed/train/validation.csv", index=False)
test_df.to_csv("../data/flickr/processed/test/test.csv", index=False)

train_ie = expanded_features[train_df.original_index]
val_ie = expanded_features[val_df.original_index]
test_ie = expanded_features[test_df.original_index]

torch.save(train_ie, "../data/flickr/processed/train/train_ie.pth")
torch.save(val_ie, "../data/flickr/processed/train/val_ie.pth")
torch.save(test_ie, "../data/flickr/processed/test/test_ie.pth")

In [11]:
train_dataset = EmbeddingsDataset(
    train_df,
    image_embeddings=expanded_features[train_df.original_index]
)

val_dataset = EmbeddingsDataset(
    val_df,
    image_embeddings=expanded_features[val_df.original_index]
)

In [12]:
diffusion_prior_model = TransformerEmbeddingDiffusionModelv2(
    img_embed_dim=1280,
    num_users=188,    # So user embedding covers your entire user set
    n_heads=16,
    num_tokens=1,
    num_user_tokens=4,
    num_layers=8,
    dim_feedforward=2048,
    whether_use_user_embeddings=True
).to(device)



In [13]:
set_seeds(0)
batch_size = 64
samples_per_user = 50
learning_rate = 1e-4
unique_users = filtered_ratings_df["worker_id"].unique()
train_user_sampler = RecommenderUserSampler(train_df, num_users=len(unique_users), samples_per_user=samples_per_user)

train_dataloader = DataLoader(train_dataset, sampler=train_user_sampler, batch_size=batch_size)
test_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

diffusion_optimizer = torch.optim.AdamW(diffusion_prior_model.parameters(), lr=learning_rate, weight_decay=1e-5)
noise_scheduler = DDPMScheduler(num_train_timesteps=6000)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(diffusion_optimizer, 'min', patience=5, factor=0.5)

total_params = sum(p.numel() for p in diffusion_prior_model.parameters())
trainable_params = sum(p.numel() for p in diffusion_prior_model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")


savepath = f"../data/flickr/evaluation/diffusion_priors/sdxl_ied1024_nu188_nh16_nit1_nut4_nl8_dff1024_uetrue.pth"

Total parameters: 97092864
Trainable parameters: 97092864


In [14]:
train_loss, val_loss = train_diffusion_prior(
                model=diffusion_prior_model,
                noise_scheduler=noise_scheduler,
                train_dataloader=train_dataloader,
                val_dataloader=test_dataloader,
                optimizer=diffusion_optimizer,
                scheduler=scheduler,
                num_unique_users=len(unique_users),
                objective="noise-pred",
                device=device,
                num_epochs=2001,      # Ensure config.num_epochs is defined
                patience=50,
                savepath=savepath,
                return_losses=True,
                verbose=True
            )

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1/2001, Time Elapsed: 6.15s, Train Loss: 1.4693, Val Loss: 1.2755, Grad Norm: 5.3657
Epoch 2/2001, Time Elapsed: 12.23s, Train Loss: 1.1523, Val Loss: 1.0927, Grad Norm: 3.1856
Epoch 3/2001, Time Elapsed: 18.29s, Train Loss: 0.9786, Val Loss: 0.9156, Grad Norm: 2.5721
Epoch 4/2001, Time Elapsed: 24.34s, Train Loss: 0.8328, Val Loss: 0.7829, Grad Norm: 2.3149
Epoch 5/2001, Time Elapsed: 30.41s, Train Loss: 0.7254, Val Loss: 0.6746, Grad Norm: 2.1344
Epoch 6/2001, Time Elapsed: 36.35s, Train Loss: 0.6461, Val Loss: 0.5914, Grad Norm: 2.0181
Epoch 7/2001, Time Elapsed: 42.20s, Train Loss: 0.5841, Val Loss: 0.5280, Grad Norm: 1.9133
Epoch 8/2001, Time Elapsed: 48.01s, Train Loss: 0.5336, Val Loss: 0.4853, Grad Norm: 1.8305
Epoch 9/2001, Time Elapsed: 53.88s, Train Loss: 0.4940, Val Loss: 0.4433, Grad Norm: 1.7693
Epoch 10/2001, Time Elapsed: 59.85s, Train Loss: 0.4577, Val Loss: 0.4084, Grad Norm: 1.7067
Epoch 11/2001, Time Elapsed: 65.75s, Train Loss: 0.4280, Val Loss: 0.3674, Grad 

## Or we may run large-scale experiments

In [22]:
param_grid = {
    'timesteps': [6000],
    'layers': [8],
    'heads': [32],
    'dim_feedforward':[1024, 2048],
    'num_image_tokens': [1],
    'num_user_tokens': [4],
    'learning_rate': [1e-4],
    #'optimizers': ['adamw', 'sgd'],
    'optimizers': ['adamw'],
    #'schedulers': ['reduce_on_plateau', 'cosine'],
    'schedulers': ['reduce_on_plateau'],
    'batch_size': [64, 128],
    #'noise_schedule': ['linear', "squaredcos_cap_v2"],
    'noise_schedule': ['linear'],
    'samples_per_user': [50, 80, 110, 140],
    'objective':["noise-pred"],
    'use_ue': [True],
    'img_embed_dim': [1280]
}

savedir = "../data/flickr/evaluation/diffusion_priors/models"

In [23]:
run_grid_search(
    train_df=train_df,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    param_grid=param_grid,
    savedir=savedir,
    unique_users = len(unique_users)
                )



Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=50, objective=noise-pred, use_ue=True


Hyperparameter combinations:   6%|▋         | 1/16 [13:51<3:27:56, 831.78s/it]

Early stopping with best val loss: 0.07189387866649134!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=80, objective=noise-pred, use_ue=True


Hyperparameter combinations:  12%|█▎        | 2/16 [41:16<5:05:36, 1309.77s/it]

Early stopping with best val loss: 0.06602152019482234!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=110, objective=noise-pred, use_ue=True


Hyperparameter combinations:  19%|█▉        | 3/16 [1:14:29<5:51:23, 1621.80s/it]

Early stopping with best val loss: 0.0629817265374907!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=140, objective=noise-pred, use_ue=True


Hyperparameter combinations:  25%|██▌       | 4/16 [1:44:05<5:36:31, 1682.67s/it]

Early stopping with best val loss: 0.0573450588332168!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=50, objective=noise-pred, use_ue=True


Hyperparameter combinations:  31%|███▏      | 5/16 [2:00:04<4:20:40, 1421.83s/it]

Early stopping with best val loss: 0.07419471666216851!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=80, objective=noise-pred, use_ue=True


Hyperparameter combinations:  38%|███▊      | 6/16 [2:15:31<3:28:56, 1253.64s/it]

Early stopping with best val loss: 0.068526175369819!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=110, objective=noise-pred, use_ue=True


Hyperparameter combinations:  44%|████▍     | 7/16 [2:41:18<3:22:26, 1349.57s/it]

Early stopping with best val loss: 0.062235010663668315!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=140, objective=noise-pred, use_ue=True


Hyperparameter combinations:  50%|█████     | 8/16 [3:08:31<3:11:57, 1439.74s/it]

Early stopping with best val loss: 0.061932186037302016!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=50, objective=noise-pred, use_ue=True


Hyperparameter combinations:  56%|█████▋    | 9/16 [3:24:25<2:30:15, 1287.87s/it]

Early stopping with best val loss: 0.0725581044780797!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=80, objective=noise-pred, use_ue=True


Hyperparameter combinations:  62%|██████▎   | 10/16 [3:55:37<2:26:49, 1468.22s/it]

Early stopping with best val loss: 0.06175178456409224!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=110, objective=noise-pred, use_ue=True


Hyperparameter combinations:  69%|██████▉   | 11/16 [4:23:55<2:08:12, 1538.50s/it]

Early stopping with best val loss: 0.05620144847138175!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=linear, samples_per_user=140, objective=noise-pred, use_ue=True


Hyperparameter combinations:  75%|███████▌  | 12/16 [5:14:44<2:13:12, 1998.17s/it]

Early stopping with best val loss: 0.057925735311261536!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=50, objective=noise-pred, use_ue=True


Hyperparameter combinations:  81%|████████▏ | 13/16 [5:33:13<1:26:25, 1728.65s/it]

Early stopping with best val loss: 0.07841355130076408!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=80, objective=noise-pred, use_ue=True


Hyperparameter combinations:  88%|████████▊ | 14/16 [5:50:00<50:21, 1510.80s/it]  

Early stopping with best val loss: 0.06790840278069178!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=110, objective=noise-pred, use_ue=True


Hyperparameter combinations:  94%|█████████▍| 15/16 [6:17:10<25:46, 1546.57s/it]

Early stopping with best val loss: 0.0736171322564284!
Running configuration: timesteps=6000, layers=8, heads=32, image_tokens=1, user_tokens=4, learning_rate=0.0001, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=128, noise_schedule=linear, samples_per_user=140, objective=noise-pred, use_ue=True


Hyperparameter combinations: 100%|██████████| 16/16 [6:48:01<00:00, 1530.07s/it]

Early stopping with best val loss: 0.062047504385312396!
Experimentation complete. Results saved to results.csv at ../data/flickr/evaluation/diffusion_priors/models



