# Train Custom Diffusion Priors

In [1]:
%load_ext autoreload
%autoreload 2


import sys
sys.path.append('../')

In [2]:
from diffusers import DDPMScheduler
import pandas as pd
import torch
from torch.utils.data import DataLoader

from Datasets import RecommenderUserSampler, EmbeddingsDataset
from grid_search import run_grid_search
from prior_models import TransformerEmbeddingDiffusionModelv2
from train_priors import train_diffusion_prior
from utils import map_embeddings_to_ratings, split_recommender_data, set_seeds

## Load Data

Load the data from its corresponding (sub)directory and map image embeddings to observations.
The data in ratings.csv will constitute our observations, and for our purposes, it will 
consist of the triplets $(U_i, S_j, I_k)$, where $U_i$ corresponds to user $i$, $S_j$ encodes whether the user likes $(\text{score}\geq 4)$ or dislikes $(\text{score}< 4)$ the image, and $I_k$ is the $k$-th image.

In [3]:
# Load the precomputed SD1.5 IP-Adapter image embeddings and the ratings table,
# then pass both through map_embeddings_to_ratings (presumably yielding one
# embedding row per rating -- confirm against utils.map_embeddings_to_ratings).
image_features = torch.load("../data/flickr/processed/ip-adapters/SD15/sd15_image_embeddings.pt", weights_only=True)
ratings_df = pd.read_csv("../data/flickr/processed/ratings.csv")
expanded_features = map_embeddings_to_ratings(image_features, ratings_df)

# Fall back to CPU when no GPU is visible instead of failing on the first
# `.to(device)`; this matches the guarded device selection in the training cells.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# Sanity check: dimensions of the expanded embedding tensor.
expanded_features.size()

torch.Size([193208, 1024])

In [5]:
# Keep only users who liked (score >= 4) at least `usr_threshold` images.
usr_threshold = 100

# Number of liked images per worker.
liked_counts = (
    ratings_df[ratings_df["score"] >= 4]
    .groupby("worker_id")["score"]
    .count()
    .reset_index(name="liked_count")
)

# Compute the surviving worker ids once; both names are kept because later
# cells refer to `valid_users` and `valid_worker_id`.
valid_users = liked_counts.loc[liked_counts["liked_count"] >= usr_threshold, "worker_id"].unique()
valid_worker_id = valid_users
filtered_ratings_df = ratings_df[ratings_df["worker_id"].isin(valid_users)].copy()

# Derive the total user count from the data rather than hard-coding it, so the
# report stays correct if ratings.csv changes.
total_users = ratings_df["worker_id"].nunique()
print(f"User loss: {total_users - len(valid_users)}")
print(f"Data loss: {100*(1 - filtered_ratings_df.shape[0]/ratings_df.shape[0])}%")

User loss: 116
Data loss: 7.281789573930686%


In [6]:
# Number of users that survive the like-count filter (computed, not hand-copied from the output above).
len(valid_users)

94

In [8]:
# Re-index the surviving workers to contiguous ids 0..N-1, in the order they
# appear in `valid_worker_id`.
worker_mapping = {old_id: new_id for new_id, old_id in enumerate(valid_worker_id)}

# Guarded, non-in-place rename: re-running this cell must not rename the already
# remapped `worker_id` column a second time (that would create a duplicate
# `old_worker_id` column and break the `.map` below).
if "old_worker_id" not in filtered_ratings_df.columns:
    filtered_ratings_df = filtered_ratings_df.rename(columns={"worker_id": "old_worker_id"})
filtered_ratings_df["worker_id"] = filtered_ratings_df["old_worker_id"].map(worker_mapping)
# The row index is intentionally not reset here.
# NOTE(review): presumably `original_index` used downstream derives from it -- confirm.

# Persist the id mapping and the filtered ratings for later stages.
worker_mapping_df = pd.DataFrame(list(worker_mapping.items()), columns=["old_worker_id", "worker_id"])
worker_mapping_df.to_csv(f"../data/flickr/processed/worker_id_mapping_usrthr_{usr_threshold}.csv", index=False)
filtered_ratings_df.to_csv(f"../data/flickr/processed/filtered_ratings_df_usrthrs_{usr_threshold}.csv", index=False)

In [9]:
# Split the filtered ratings into train / validation / test frames with a fixed seed.
# NOTE(review): `val_spu` / `test_spu` presumably mean "samples per user" -- confirm in utils.
train_df, val_df, test_df = split_recommender_data(
    ratings_df=filtered_ratings_df, val_spu=10, test_spu=10, seed=42
)

Train set size: 177278
Validation set size: 928
Evaluation set size: 933


In [9]:
# Training rows per (re-indexed) worker, smallest first, to inspect the imbalance.
train_df.worker_id.value_counts(ascending=True)


worker_id
40      201
36      208
52      208
72      210
67      258
      ...  
49     8064
20    11064
22    11343
87    17320
28    17875
Name: count, Length: 94, dtype: int64

In [12]:
# Persist every split: the ratings rows as CSV and the embedding rows selected by
# `original_index` as a tensor file. Train and validation artifacts are written
# under processed/train/, test artifacts under processed/test/.

# -- train --
train_df.to_csv(f"../data/flickr/processed/train/train_usrthrs_{usr_threshold}.csv", index=False)
torch.save(expanded_features[train_df["original_index"]], f"../data/flickr/processed/train/train_ie_usrthrs_{usr_threshold}.pt")

# -- validation --
val_df.to_csv(f"../data/flickr/processed/train/validation_usrthrs_{usr_threshold}.csv", index=False)
torch.save(expanded_features[val_df["original_index"]], f"../data/flickr/processed/train/validation_ie_usrthrs_{usr_threshold}.pt")

# -- test --
test_df.to_csv(f"../data/flickr/processed/test/test_usrthrs_{usr_threshold}.csv", index=False)
torch.save(expanded_features[test_df["original_index"]], f"../data/flickr/processed/test/test_ie_usrthrs_{usr_threshold}.pt")


In [13]:
# Sanity check: the selected training embeddings should have one row per training rating.
expanded_features[train_df["original_index"]].size()

torch.Size([177278, 1024])

In [14]:
# Wrap each split together with the embedding rows selected by its `original_index` column.
train_dataset = EmbeddingsDataset(train_df, image_embeddings=expanded_features[train_df["original_index"]])
val_dataset = EmbeddingsDataset(val_df, image_embeddings=expanded_features[val_df["original_index"]])

In [47]:
from prior_models import CrossAttentionDiffusionPrior

# Seed *before* building the model: the original only seeded in the following
# cell, after the weights had already been initialised, so runs were not
# reproducible. (Assumes set_seeds seeds the torch RNG -- confirm in utils.)
set_seeds(0)

model = CrossAttentionDiffusionPrior(
    img_embed_dim=1024,
    # Derived from the filtered ratings instead of hard-coding 94, so the user
    # count follows `usr_threshold`.
    num_users=filtered_ratings_df["worker_id"].nunique(),
    num_tokens=16,
    n_heads=16,
    num_layers=16,
    dim_feedforward=2048,
).to(device)

In [48]:
# ---- Training setup for the cross-attention prior ----
set_seeds(0)

# Hyper-parameters for this run.
batch_size = 64
samples_per_user = 500
learning_rate = 1e-5

# User-based sampler for training; plain shuffled loader over the validation set.
unique_users = filtered_ratings_df["worker_id"].unique()
train_user_sampler = RecommenderUserSampler(
    train_df,
    num_users=len(unique_users),
    samples_per_user=samples_per_user,
)

train_dataloader = DataLoader(
    train_dataset,
    sampler=train_user_sampler,
    batch_size=batch_size,
)
# Note: despite its name, this loader iterates over the validation dataset.
test_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Optimiser, noise schedule, device and LR schedule.
diffusion_optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)
noise_scheduler = DDPMScheduler(num_train_timesteps=6000, beta_schedule="laplace")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    diffusion_optimizer,
    mode='min',
    patience=5,
    factor=0.5,
)

# Report model size.
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

# Checkpoint path handed to the training call.
savepath = f"../data/flickr/evaluation/diffusion_priors/models/weights/test_xattn_v4.pth"


Total parameters: 4806912
Trainable parameters: 4806912


In [49]:
# Train the cross-attention prior with the v-prediction objective and collect
# the per-epoch losses (return_losses=True).
train_loss, val_loss = train_diffusion_prior(
    model=model,
    noise_scheduler=noise_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    optimizer=diffusion_optimizer,
    scheduler=scheduler,
    num_unique_users=len(unique_users),
    objective="v_prediction",
    device=device,
    num_epochs=2001,  # upper bound; `patience` is presumably used for early stopping -- confirm
    patience=20,
    savepath=savepath,
    return_losses=True,
    verbose=True,
)

Epoch 1/2001, Time Elapsed: 8.70s, Train Loss: 0.7554, Val Loss: 0.8231, Grad Norm: 1.1109
Epoch 2/2001, Time Elapsed: 17.46s, Train Loss: 0.7332, Val Loss: 0.7943, Grad Norm: 0.8083
Epoch 3/2001, Time Elapsed: 26.06s, Train Loss: 0.7047, Val Loss: 0.7688, Grad Norm: 1.0431
Epoch 4/2001, Time Elapsed: 34.69s, Train Loss: 0.6919, Val Loss: 0.7561, Grad Norm: 1.0565
Epoch 5/2001, Time Elapsed: 43.27s, Train Loss: 0.6842, Val Loss: 0.7558, Grad Norm: 1.0783
Epoch 6/2001, Time Elapsed: 51.90s, Train Loss: 0.6779, Val Loss: 0.7389, Grad Norm: 1.0854
Epoch 7/2001, Time Elapsed: 60.46s, Train Loss: 0.6724, Val Loss: 0.7344, Grad Norm: 1.1183
Epoch 8/2001, Time Elapsed: 69.15s, Train Loss: 0.6673, Val Loss: 0.7306, Grad Norm: 1.1258
Epoch 9/2001, Time Elapsed: 77.83s, Train Loss: 0.6631, Val Loss: 0.7272, Grad Norm: 1.1120
Epoch 10/2001, Time Elapsed: 86.47s, Train Loss: 0.6608, Val Loss: 0.7264, Grad Norm: 1.1082
Epoch 11/2001, Time Elapsed: 94.98s, Train Loss: 0.6574, Val Loss: 0.7187, Grad 

In [25]:
# Seed *before* building the model: the original only seeded in the following
# cell, after the weights had already been initialised, so runs were not
# reproducible. (Assumes set_seeds seeds the torch RNG -- confirm in utils.)
set_seeds(0)

diffusion_prior_model = TransformerEmbeddingDiffusionModelv2(
    img_embed_dim=1024,
    # NOTE(review): the filtered data above contains 94 users, yet 122 is used here
    # (and encoded as "nu122" in a checkpoint name further down). It is kept as-is so
    # existing checkpoints stay loadable -- confirm whether 122 is still intended.
    num_users=122,
    n_heads=16,
    num_tokens=1,
    num_user_tokens=4,
    num_layers=8,
    dim_feedforward=2048,
    whether_use_user_embeddings=True
).to(device)



In [26]:
import math

# ---- Re-normalise the embeddings and rebuild the training setup ----
set_seeds(0)

# Scale every embedding to L2 norm sqrt(d) (norm floored at 1e-8 to avoid division
# by zero), then clip each component to [-3.2, 3.2] and rescale into [-1, 1].
d = image_features.size(-1)
norms = image_features.norm(dim=-1, keepdim=True).clamp(min=1e-8)
image_features_normed = image_features / norms * math.sqrt(d)
emb_final = image_features_normed.clamp(-3.2, 3.2) / 3.2

# Note: this overwrites `expanded_features` with the normalised embeddings.
expanded_features = map_embeddings_to_ratings(emb_final, ratings_df)

# Hyper-parameters for this run.
batch_size = 64
samples_per_user = 80
learning_rate = 1e-5
unique_users = filtered_ratings_df["worker_id"].unique()

# Datasets rebuilt on top of the normalised embeddings.
train_dataset = EmbeddingsDataset(
    train_df,
    image_embeddings=expanded_features[train_df["original_index"]],
)
val_dataset = EmbeddingsDataset(
    val_df,
    image_embeddings=expanded_features[val_df["original_index"]],
)

# User-based sampler for training; plain shuffled loader over the validation set.
train_user_sampler = RecommenderUserSampler(
    train_df,
    num_users=len(unique_users),
    samples_per_user=samples_per_user,
)
train_dataloader = DataLoader(train_dataset, sampler=train_user_sampler, batch_size=batch_size)
# Note: despite its name, this loader iterates over the validation dataset.
test_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Optimiser, noise schedule, device and LR schedule.
diffusion_optimizer = torch.optim.AdamW(
    diffusion_prior_model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)
noise_scheduler = DDPMScheduler(num_train_timesteps=6000, beta_schedule="laplace")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    diffusion_optimizer,
    mode='min',
    patience=5,
    factor=0.5,
)

# Report model size.
total_params = sum(param.numel() for param in diffusion_prior_model.parameters())
trainable_params = sum(
    param.numel() for param in diffusion_prior_model.parameters() if param.requires_grad
)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

# Checkpoint path handed to the training call.
savepath = f"../data/flickr/evaluation/diffusion_priors/models/weights/test_rebecca_og_norm_v3.pth"

Total parameters: 68756480
Trainable parameters: 68756480


In [12]:
# ---- Training setup: lr 1e-4 with DDPMScheduler's default beta schedule ----
set_seeds(0)

# Hyper-parameters for this run.
batch_size = 64
samples_per_user = 80
learning_rate = 1e-4

# User-based sampler for training; plain shuffled loader over the validation set.
unique_users = filtered_ratings_df["worker_id"].unique()
train_user_sampler = RecommenderUserSampler(
    train_df,
    num_users=len(unique_users),
    samples_per_user=samples_per_user,
)

train_dataloader = DataLoader(train_dataset, sampler=train_user_sampler, batch_size=batch_size)
# Note: despite its name, this loader iterates over the validation dataset.
test_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Optimiser, noise schedule, device and LR schedule.
diffusion_optimizer = torch.optim.AdamW(
    diffusion_prior_model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)
noise_scheduler = DDPMScheduler(num_train_timesteps=6000)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    diffusion_optimizer,
    mode='min',
    patience=5,
    factor=0.5,
)

# Report model size.
total_params = sum(weight.numel() for weight in diffusion_prior_model.parameters())
trainable_params = sum(
    weight.numel() for weight in diffusion_prior_model.parameters() if weight.requires_grad
)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

# Checkpoint path handed to the training call; the file name spells out the model configuration.
savepath = f"../data/flickr/evaluation/diffusion_priors/models/weights/sd15_ied1024_nu122_nh16_nit1_nut4_nl8_dff2048.pth"


Total parameters: 68756480
Trainable parameters: 68756480


In [28]:
# Train the transformer prior on the noise-prediction objective and collect the
# per-epoch losses (return_losses=True).
train_loss, val_loss = train_diffusion_prior(
    model=diffusion_prior_model,
    noise_scheduler=noise_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    optimizer=diffusion_optimizer,
    scheduler=scheduler,
    num_unique_users=len(unique_users),
    # Spelled with a hyphen to match the other training call and the grid-search
    # `objective` values in this notebook; the previous "noise_pred" (underscore)
    # could be mismatched by a string comparison inside train_diffusion_prior.
    objective="noise-pred",
    device=device,
    num_epochs=2001,  # upper bound; `patience` is presumably used for early stopping -- confirm
    patience=20,
    savepath=savepath,
    return_losses=True,
    verbose=True,
)

Epoch 1/2001, Time Elapsed: 2.26s, Train Loss: 1.6912, Val Loss: 1.0891, Grad Norm: 7.7146
Epoch 2/2001, Time Elapsed: 4.70s, Train Loss: 1.1028, Val Loss: 0.9466, Grad Norm: 5.0367
Epoch 3/2001, Time Elapsed: 7.23s, Train Loss: 1.0253, Val Loss: 0.9138, Grad Norm: 4.4878
Epoch 4/2001, Time Elapsed: 9.75s, Train Loss: 0.9836, Val Loss: 0.8918, Grad Norm: 3.9995
Epoch 5/2001, Time Elapsed: 12.26s, Train Loss: 0.9566, Val Loss: 0.8733, Grad Norm: 3.7429
Epoch 6/2001, Time Elapsed: 14.79s, Train Loss: 0.9345, Val Loss: 0.8610, Grad Norm: 3.5570
Epoch 7/2001, Time Elapsed: 17.42s, Train Loss: 0.9140, Val Loss: 0.8551, Grad Norm: 3.3351
Epoch 8/2001, Time Elapsed: 20.23s, Train Loss: 0.8980, Val Loss: 0.8401, Grad Norm: 3.1974
Epoch 9/2001, Time Elapsed: 22.74s, Train Loss: 0.8837, Val Loss: 0.8387, Grad Norm: 3.0305
Epoch 10/2001, Time Elapsed: 25.27s, Train Loss: 0.8731, Val Loss: 0.8288, Grad Norm: 2.9270
Epoch 11/2001, Time Elapsed: 27.77s, Train Loss: 0.8615, Val Loss: 0.8170, Grad Nor

In [21]:
# Run the transformer prior on the "noise-pred" objective using whatever loaders,
# optimiser and schedulers are currently defined; per-epoch losses are returned.
train_loss, val_loss = train_diffusion_prior(
    model=diffusion_prior_model,
    noise_scheduler=noise_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    optimizer=diffusion_optimizer,
    scheduler=scheduler,
    num_unique_users=len(unique_users),
    objective="noise-pred",
    device=device,
    num_epochs=2001,  # upper bound; `patience` is presumably used for early stopping -- confirm
    patience=20,
    savepath=savepath,
    return_losses=True,
    verbose=True,
)

Epoch 1/2001, Time Elapsed: 2.26s, Train Loss: 0.2101, Val Loss: 0.1565, Grad Norm: 1.3618
Epoch 2/2001, Time Elapsed: 4.65s, Train Loss: 0.2088, Val Loss: 0.1529, Grad Norm: 1.3454
Epoch 3/2001, Time Elapsed: 7.16s, Train Loss: 0.2068, Val Loss: 0.1546, Grad Norm: 1.3326
Epoch 4/2001, Time Elapsed: 9.42s, Train Loss: 0.2061, Val Loss: 0.1528, Grad Norm: 1.3260
Epoch 5/2001, Time Elapsed: 11.93s, Train Loss: 0.2065, Val Loss: 0.1483, Grad Norm: 1.3230
Epoch 6/2001, Time Elapsed: 14.45s, Train Loss: 0.2057, Val Loss: 0.1524, Grad Norm: 1.3188
Epoch 7/2001, Time Elapsed: 16.72s, Train Loss: 0.2037, Val Loss: 0.1507, Grad Norm: 1.3096
Epoch 8/2001, Time Elapsed: 18.99s, Train Loss: 0.2031, Val Loss: 0.1474, Grad Norm: 1.3023
Epoch 9/2001, Time Elapsed: 21.53s, Train Loss: 0.2038, Val Loss: 0.1533, Grad Norm: 1.3042
Epoch 10/2001, Time Elapsed: 23.79s, Train Loss: 0.2031, Val Loss: 0.1524, Grad Norm: 1.3004
Epoch 11/2001, Time Elapsed: 26.05s, Train Loss: 0.2023, Val Loss: 0.1473, Grad Nor

KeyboardInterrupt: 

## Or we may run large-scale experiments

In [10]:
import math

# ---- Data preparation for the grid search ----
set_seeds(0)

# Scale every embedding to L2 norm sqrt(d) (norm floored at 1e-8 to avoid division
# by zero), then clip each component to [-3.2, 3.2] and rescale into [-1, 1].
d = image_features.size(-1)
norms = image_features.norm(dim=-1, keepdim=True).clamp(min=1e-8)
image_features_normed = image_features / norms * math.sqrt(d)
emb_final = image_features_normed.clamp(-3.2, 3.2) / 3.2

# Note: this overwrites `expanded_features` with the normalised embeddings.
expanded_features = map_embeddings_to_ratings(emb_final, ratings_df)

# Defaults kept in the namespace; the grid below carries its own values for these.
batch_size = 64
samples_per_user = 80
learning_rate = 1e-5
unique_users = filtered_ratings_df["worker_id"].unique()

# Datasets built on top of the normalised embeddings.
train_dataset = EmbeddingsDataset(
    train_df,
    image_embeddings=expanded_features[train_df["original_index"]],
)
val_dataset = EmbeddingsDataset(
    val_df,
    image_embeddings=expanded_features[val_df["original_index"]],
)

In [11]:
# Hyper-parameter grid handed to run_grid_search. Key order is preserved from the
# original definition. The value lists multiply out to 4608 combinations.
param_grid = {
    # diffusion / architecture
    "timesteps": [1000, 2000],
    "layers": [8, 16],
    "heads": [16, 32],
    "dim_feedforward": [2048],
    "num_image_tokens": [1, 2],
    "num_user_tokens": [4, 8],
    # optimisation (alternatives tried earlier: 'sgd' optimizer, 'cosine' scheduler)
    "learning_rate": [1e-4, 1e-5],
    "optimizers": ["adamw"],
    "schedulers": ["reduce_on_plateau"],
    "batch_size": [64],
    # noise schedule and sampling
    "noise_schedule": ["laplace", "linear", "squaredcos_cap_v2"],
    "samples_per_user": [80, 130, 200, 300, 400, 500],
    "clip_sample": [False, True],
    "rescale_betas": [False],
    # training objective and conditioning
    "objective": ["v_prediction", "noise-pred"],
    "use_ue": [True],
    "img_embed_dim": [1024],
}

# Output directory for this experiment's weights (an earlier run used ".../weights/experiment_2").
savedir = "../data/flickr/evaluation/diffusion_priors/models/weights/experiment_old_architecture_new_norm"

In [14]:
import os

# Check the directory the grid search below actually writes to (`savedir`), rather
# than a hard-coded path from an earlier experiment.
os.path.exists(savedir)

True

In [12]:

# Launch the sweep over `param_grid`, passing `savedir` as the output location.
# `unique_users` receives the number of distinct (re-indexed) workers.
run_grid_search(
    train_df=train_df,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    param_grid=param_grid,
    savedir=savedir,
    unique_users=len(unique_users),
)

Hyperparameter combinations:   0%|          | 0/4608 [00:00<?, ?it/s]



Running configuration: timesteps=1000, layers=8, heads=16, image_tokens=1, user_tokens=4, learning_rate=0.0001, clip_sample=False, rescale_betas=False, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=laplace, samples_per_user=80, objective=v_prediction, use_ue=True


Hyperparameter combinations:   0%|          | 1/4608 [08:21<642:12:12, 501.83s/it]

Early stopping with best val loss: 0.15440688530604044!
Running configuration: timesteps=1000, layers=8, heads=16, image_tokens=1, user_tokens=4, learning_rate=0.0001, clip_sample=False, rescale_betas=False, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=laplace, samples_per_user=80, objective=noise-pred, use_ue=True


Hyperparameter combinations:   0%|          | 2/4608 [15:29<585:57:24, 457.98s/it]

Early stopping with best val loss: 0.11446110159158707!
Running configuration: timesteps=1000, layers=8, heads=16, image_tokens=1, user_tokens=4, learning_rate=0.0001, clip_sample=True, rescale_betas=False, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=laplace, samples_per_user=80, objective=v_prediction, use_ue=True


Hyperparameter combinations:   0%|          | 3/4608 [24:53<648:07:22, 506.68s/it]

Early stopping with best val loss: 0.14889527161916097!
Running configuration: timesteps=1000, layers=8, heads=16, image_tokens=1, user_tokens=4, learning_rate=0.0001, clip_sample=True, rescale_betas=False, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=laplace, samples_per_user=80, objective=noise-pred, use_ue=True


Hyperparameter combinations:   0%|          | 4/4608 [32:13<614:20:58, 480.38s/it]

Early stopping with best val loss: 0.11399615307648976!
Running configuration: timesteps=1000, layers=8, heads=16, image_tokens=1, user_tokens=4, learning_rate=0.0001, clip_sample=False, rescale_betas=False, optimizer=adamw, scheduler=reduce_on_plateau, batch_size=64, noise_schedule=laplace, samples_per_user=130, objective=v_prediction, use_ue=True


Hyperparameter combinations:   0%|          | 4/4608 [46:19<888:40:42, 694.88s/it]


KeyboardInterrupt: 