In [None]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler
from scenes_dataset import ScenesDataset, DatasetConstants

In [None]:
import json


def _load_scenes(path):
    """Read a scenes JSON file and return the value stored under its ``'scenes'`` key.

    JSON text is UTF-8 by specification, so the encoding is passed explicitly
    instead of relying on the platform default (which is not UTF-8 everywhere).
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)['scenes']


# Load the train / validation splits
train_data = _load_scenes('datasets/data/train.json')
val_data = _load_scenes('datasets/data/val.json')

## Model and training setup

In [None]:
B = 64 # num of scenes in batch

# Scene hyperparams
N = 20 # num of nodes

# RGCN hyperparams
C = 300 # dim of node features
R = 23+1 # num of relations

# Attention hyperparams
D = 315 # dim of objects from the scene

# Full configuration for one training run. The same keys are overridden per trial
# by the hyperparameter search further below.
hparams = {
    'batch_size': B, # num of graphs in batch
    
    # --- Attention hyperparams ---
    'attention_out_dim': D,
    'attention_num_heads': 5, # must be a divisor of D (checked below)
    
    # --- Encoder RGCN hyperparams ---
    # Hidden dims are passed as the *string* repr of a tuple, e.g. "()" or "(300, 300)".
    'encoder_out_dim': C,
    'encoder_hidden_dims': f"{()}", # (C, C),
    'encoder_num_bases': None,
    'encoder_aggr': 'mean',
    'encoder_activation': 'tanh',
    'encoder_dp_rate': 0.,
    'encoder_bias': True,
    
    # --- Fusion RGCN hyperparams ---
    'fusion_hidden_dims': f"{()}", # (C+D, C+D, D),
    'fusion_num_bases': None,
    'fusion_aggr': 'mean',
    'fusion_activation': 'tanh',
    'fusion_dp_rate': 0.,
    'fusion_bias': True,
    
    # Scheduler hyperparams
    'scheduler_timesteps': 1000,
    'scheduler_loss': 'l1',
    'scheduler_beta_schedule': 'cosine',
    # Note: not needed for now
    # 'scheduler_sampling_timesteps': None,
    # "scheduler_objective": 'pred_noise',
    # 'scheduler_ddim_sampling_eta': 1.0,
    # 'scheduler_min_snr_loss_weight': False,
    # 'scheduler_min_snr_gamma': 5,
    
    # Classifier-free guidance parameters
    'cfg_cond_drop_prob': 0.1,
    
    # Training and optimizer hyperparams
    'epochs': 5000,
    'optimizer_lr': 1e-3,
    'optimizer_weight_decay': 5e-5,
    'lr_scheduler_factor': 0.8,
    'lr_scheduler_patience': 20,
    'lr_scheduler_minlr': 8e-5,
}

# Fail fast on an invalid head count instead of deep inside model construction.
assert D % hparams['attention_num_heads'] == 0, (
    f"attention_num_heads ({hparams['attention_num_heads']}) must divide D ({D})"
)

In [None]:
import os
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader
from tqdm import tqdm

if torch.cuda.is_available():
    device = torch.device('cuda')
# Not all operations support MPS yet so this option is not available for now
# elif torch.has_mps:
#     device = torch.device('mps')
else:
    device = torch.device('cpu')


# --- Load the data
range_matrix = DatasetConstants.get_range_matrix().to(device)

train_dataset = ScenesDataset(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True)

val_dataset = ScenesDataset(val_data)
# Evaluation gains nothing from shuffling, so the validation order stays fixed.
val_dataloader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False)


# --- Instantiate the model
model = GuidedDiffusionNetwork(
    attention_N=N,
    attention_D=D,
    attention_out_dim=hparams['attention_out_dim'],
    attention_num_heads=hparams['attention_num_heads'],
    
    rgcn_num_relations=R,
    
    encoder_in_dim=C,
    encoder_out_dim=hparams['encoder_out_dim'],
    encoder_hidden_dims=hparams['encoder_hidden_dims'],
    encoder_num_bases=hparams['encoder_num_bases'],
    encoder_aggr=hparams['encoder_aggr'],
    encoder_activation=hparams['encoder_activation'],
    encoder_dp_rate=hparams['encoder_dp_rate'],
    encoder_bias=hparams['encoder_bias'],
    
    fusion_hidden_dims=hparams['fusion_hidden_dims'],
    fusion_num_bases=hparams['fusion_num_bases'],
    fusion_aggr=hparams['fusion_aggr'],
    fusion_activation=hparams['fusion_activation'],
    fusion_dp_rate=hparams['fusion_dp_rate'],
    fusion_bias=hparams['fusion_bias'],
    
    cond_drop_prob=hparams['cfg_cond_drop_prob']
)

print(f"Model:\n{model}")

scheduler = DDPMScheduler(
    model=model,
    N=N,
    D=D,
    range_matrix = range_matrix,
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5
)

print(f"DDPM Scheduler:\n{scheduler}")

# Move to device (before the optimizer is built, so it sees the final parameters)
model = model.to(device)
scheduler = scheduler.to(device)


# --- Setup training loop ---
optimizer = torch.optim.Adam(
    scheduler.parameters(), 
    lr=hparams['optimizer_lr'], 
    weight_decay=hparams['optimizer_weight_decay']
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=hparams['lr_scheduler_factor'], 
    patience=hparams['lr_scheduler_patience'], 
    min_lr=hparams['lr_scheduler_minlr']
)


# --- Initialize tensorboard ---
# use timestamp to avoid overwriting previous runs
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/full-DDPM/train-time:{now.strftime("%Y-%m-%d-%H:%M:%S")}')

# Checkpoint of the model with the lowest validation loss seen so far.
BEST_MODEL_PATH = os.path.join('models', 'best-model.pt')
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)


def _batch_to_device(batch):
    """Move one collated graph batch to `device` and unpack the scheduler inputs.

    Returns ``(x, obj_cond, edge_index, edge_attr)`` in the positional order the
    scheduler is called with. ``batch.x`` is read as ``[B*N, D]`` and reshaped to
    ``[B, N, D]`` using ``batch.num_graphs`` as ``B``.
    """
    x_batch = batch.x.to(device).view(batch.num_graphs, N, D)
    return (
        x_batch,
        batch.cond.to(device),
        batch.edge_index.to(device),
        batch.edge_attr.to(device),
    )


best_loss = float('inf')
val_loss = float('nan')  # stays NaN only if hparams['epochs'] == 0

try:
    for epoch in tqdm(range(hparams['epochs'])):
        # --- Training pass ---
        scheduler.train()
        train_loss = 0.0
        for batch in train_dataloader:
            loss = scheduler(*_batch_to_device(batch))

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_dataloader)

        # NOTE(review): the LR schedule follows the *training* loss; stepping on
        # val_loss is the more common choice for ReduceLROnPlateau.
        lr_scheduler.step(train_loss)
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

        # --- Validation pass ---
        scheduler.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                val_loss += scheduler(*_batch_to_device(batch)).item()
        val_loss /= len(val_dataloader)
        writer.add_scalar('Loss/val', val_loss, epoch)

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write(f"Saved best model with val loss {best_loss}")

    # log all the hyperparameters, the last epoch's val loss and the best val loss
    writer.add_hparams(hparams, {'Final loss': val_loss, 'Best val loss': best_loss})
finally:
    writer.close()

## Hyperparameter tuning

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import uuid
from torch.utils.tensorboard import SummaryWriter
import optuna

# Choose the compute device: CUDA when present, otherwise CPU.
# MPS (Apple silicon) is deliberately not offered because not every operation
# the model relies on is supported there yet.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_scheduler(hparams):
    """Train one DDPM model for a hyperparameter-search trial and return its score.

    Builds the train/val dataloaders, the model, the DDPM scheduler, the optimizer
    and the LR scheduler from ``hparams``. It then trains for ``hparams['epochs']``
    epochs, evaluates on the validation split after every epoch, and logs to
    TensorBoard under a trial-unique directory.

    Args:
        hparams: dict with the same keys as the notebook-level ``hparams`` dict.

    Returns:
        The lowest mean validation loss over all epochs. The study minimizes this
        value; scoring on training loss would reward overfitting configurations.
    """
    # The cell header imports ``torch.utils.data.DataLoader``, but the batches are
    # consumed as PyG batches (``batch.num_graphs``, ``batch.edge_index``), exactly
    # as in the training cell, so the PyG loader is used here.
    from torch_geometric.loader import DataLoader as GraphDataLoader

    # --- Load the data
    # (previously `ScenesDataset(dataset)`, which read the local name `dataset`
    # before assignment and raised UnboundLocalError on every trial)
    train_loader = GraphDataLoader(
        ScenesDataset(train_data), batch_size=hparams['batch_size'], shuffle=True
    )
    val_loader = GraphDataLoader(
        ScenesDataset(val_data), batch_size=hparams['batch_size'], shuffle=False
    )
    range_matrix = DatasetConstants.get_range_matrix().to(device)

    # Same constructor arguments as the training cell.
    model = GuidedDiffusionNetwork(
        attention_N=N,
        attention_D=D,
        attention_out_dim=hparams['attention_out_dim'],
        attention_num_heads=hparams['attention_num_heads'],
        
        rgcn_num_relations=R,
        
        encoder_in_dim=C,
        encoder_out_dim=hparams['encoder_out_dim'],
        encoder_hidden_dims=hparams['encoder_hidden_dims'],
        encoder_num_bases=hparams['encoder_num_bases'],
        encoder_aggr=hparams['encoder_aggr'],
        encoder_activation=hparams['encoder_activation'],
        encoder_dp_rate=hparams['encoder_dp_rate'],
        encoder_bias=hparams['encoder_bias'],
        
        fusion_hidden_dims=hparams['fusion_hidden_dims'],
        fusion_num_bases=hparams['fusion_num_bases'],
        fusion_aggr=hparams['fusion_aggr'],
        fusion_activation=hparams['fusion_activation'],
        fusion_dp_rate=hparams['fusion_dp_rate'],
        fusion_bias=hparams['fusion_bias'],
        
        cond_drop_prob=hparams['cfg_cond_drop_prob']
    )

    scheduler = DDPMScheduler(
        model=model,
        N=N,
        D=D,
        range_matrix=range_matrix,
        timesteps=hparams['scheduler_timesteps'],
        sampling_timesteps=None,
        loss_type=hparams['scheduler_loss'],
        objective='pred_noise',
        beta_schedule=hparams['scheduler_beta_schedule'],
        ddim_sampling_eta=1.0,
        min_snr_loss_weight=False,
        min_snr_gamma=5
    )

    # Move to device before building the optimizer (recommended PyTorch order).
    model = model.to(device)
    scheduler = scheduler.to(device)

    optimizer = torch.optim.Adam(
        scheduler.parameters(), 
        lr=hparams['optimizer_lr'], 
        weight_decay=hparams['optimizer_weight_decay']
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=hparams['lr_scheduler_factor'], 
        patience=hparams['lr_scheduler_patience'], 
        min_lr=hparams['lr_scheduler_minlr']
    )

    def _unpack(batch):
        """Move a batch to `device`; reshape x from [B*N, D] to [B, N, D]."""
        return (
            batch.x.to(device).view(batch.num_graphs, N, D),
            batch.cond.to(device),
            batch.edge_index.to(device),
            batch.edge_attr.to(device),
        )

    # Generate a unique id for each trial
    trial_uuid = str(uuid.uuid4())
    writer = SummaryWriter(log_dir=f'runs/full-DDPM/hparamtuning-{trial_uuid}')

    best_val_loss = float('inf')
    try:
        for epoch in tqdm(range(hparams['epochs'])):
            # --- Training pass ---
            scheduler.train()
            train_loss = 0.0
            for batch in train_loader:
                loss = scheduler(*_unpack(batch))

                # Backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_loss /= len(train_loader)

            # LR schedule follows the training loss, mirroring the training cell.
            lr_scheduler.step(train_loss)
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

            # --- Validation pass ---
            scheduler.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    val_loss += scheduler(*_unpack(batch)).item()
            val_loss /= len(val_loader)
            writer.add_scalar('Loss/val', val_loss, epoch)

            best_val_loss = min(best_val_loss, val_loss)

        # log all the hyperparameters and the trial's score
        writer.add_hparams(hparams, {'Best val loss': best_val_loss})
    finally:
        writer.close()

    return best_val_loss

# Define the objective function for hyperparameter optimization
def objective(trial):
    """Optuna objective: sample one configuration, train it, and return its score.

    The sampled values are layered over a *copy* of the notebook-level ``hparams``.
    The global dict is left untouched, so cells run after the study still see the
    original configuration rather than whatever the last trial sampled.

    Args:
        trial: the ``optuna.Trial`` used to sample the search space.

    Returns:
        The value returned by ``train_scheduler`` for the sampled configuration.
    """
    # Define the hyperparameters to tune and their search spaces
    search_space = {
        'batch_size': trial.suggest_categorical('batch_size', [4, 16, 32, 64, 128]),
    
        'attention_out_dim': trial.suggest_categorical('attention_out_dim', [16, 32, 48]),
        'attention_num_heads': trial.suggest_categorical('attention_num_heads', [2, 4, 8, 16]),
    
        'encoder_out_dim': trial.suggest_categorical('encoder_out_dim', [2, 10, 20, 30]),
        'encoder_hidden_dims': trial.suggest_categorical('encoder_hidden_dims', [f'{()}', f'{(30,)}', f'{(15,)}', f'{(60,)}', f'{(30, 30)}']),
        'encoder_num_bases': trial.suggest_categorical('encoder_num_bases', [None, 2, 4, 8]),
        'encoder_aggr': trial.suggest_categorical('encoder_aggr', ['mean', 'sum', 'max']),
        'encoder_activation': trial.suggest_categorical('encoder_activation', ['leakyrelu', 'relu', 'silu']),
        'encoder_dp_rate': trial.suggest_float('encoder_dp_rate', 0.0, 0.5),
        'encoder_bias': trial.suggest_categorical('encoder_bias', [True, False]),
    
        'fusion_hidden_dims': trial.suggest_categorical('fusion_hidden_dims', [f'{()}', f'{(8,)}', f'{(18,)}', f'{(24,)}', f'{(48,)}', f'{(48, 24)}', f'{(48, 32)}', f"{(48, 32, 24)}"]),
        'fusion_num_bases': trial.suggest_categorical('fusion_num_bases', [None, 2, 4, 8]),
        'fusion_aggr': trial.suggest_categorical('fusion_aggr', ['mean', 'sum', 'max']),
        'fusion_activation': trial.suggest_categorical('fusion_activation', ['leakyrelu', 'relu', 'silu']),
        'fusion_dp_rate': trial.suggest_float('fusion_dp_rate', 0.0, 0.5),
        'fusion_bias': trial.suggest_categorical('fusion_bias', [True, False]),
    
        'scheduler_timesteps': trial.suggest_categorical('scheduler_timesteps', [1000, 2000, 5000]),
        'scheduler_loss': trial.suggest_categorical('scheduler_loss', ['l1', 'l2']),
        'scheduler_beta_schedule': trial.suggest_categorical('scheduler_beta_schedule', ['cosine', 'linear']),
    
        'cfg_cond_drop_prob': trial.suggest_float('cfg_cond_drop_prob', 0.0, 0.5),
    
        'epochs': 750,
        'optimizer_lr': trial.suggest_float('optimizer_lr', 1e-5, 5e-3, log=True),
        'optimizer_weight_decay': trial.suggest_float('optimizer_weight_decay', 1e-6, 1e-3, log=True),
        'lr_scheduler_factor': trial.suggest_float('lr_scheduler_factor', 0.5, 0.9),
        'lr_scheduler_patience': trial.suggest_int('lr_scheduler_patience', 50, 200),
        'lr_scheduler_minlr': trial.suggest_float('lr_scheduler_minlr', 0.00001, 0.001, log=True),
    }

    # Per-trial configuration; the global `hparams` is not mutated.
    trial_hparams = {**hparams, **search_space}

    return train_scheduler(trial_hparams)

def run_hparam_tuning(n_trials=50):
    """Create a minimizing Optuna study and run ``objective`` on it.

    Args:
        n_trials: number of trials to run (default 50, the previously fixed value).

    Returns:
        The finished ``optuna.Study``.
    """
    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=n_trials)
    
    return study

result_study = run_hparam_tuning()