In [None]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler
from scenes_dataset import ScenesDataset, DatasetConstants

In [None]:
import json

# Load the training and validation scenes from their JSON files
def _load_scenes(path):
    """Return the value stored under the ``'scenes'`` key of the JSON file at ``path``.

    Args:
        path: Location of the JSON file to read.

    Returns:
        Whatever the file's top-level ``'scenes'`` entry decodes to.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the decoded document has no ``'scenes'`` key.
    """
    # Spell out the encoding: without it open() uses the locale default, which
    # is not UTF-8 on every platform and can corrupt or reject non-ASCII JSON.
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)['scenes']


train_data = _load_scenes('datasets/data/train.json')
val_data = _load_scenes('datasets/data/val.json')

## Hyperparameter tuning

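In this section, we tune the hyperparameters of the network with Optuna: each trial trains a model with a sampled configuration and is scored by its best validation loss.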

In [None]:
# Batch
B = 3  # number of scenes per batch

# Scene
N = 20  # number of objects per scene
D = 15  # dimensionality of each scene object

# Time
T = 14

# Conditioning
C = 300  # dimensionality of the node features
R = 23 + 1  # number of relations

# Default configuration. Insertion order is kept identical to the grouping
# below: general, RGCN, attention, scheduler, classifier-free guidance, training.
hparams = dict(
    batch_size=B,
    time_dim=T,

    # --- RGCN ---
    # Hidden dims are stored as the string form of a tuple, e.g. (C+D, C+D, D).
    rgc_hidden_dims=str(()),
    rgc_num_bases=5,  # alternative: None
    rgc_aggr='mean',
    rgc_activation='tanh',
    rgc_dp_rate=0.,
    rgc_bias=True,

    # --- Attention ---
    attention_self_head_dims=10,
    attention_num_heads=3,
    attention_cross_head_dims=30,

    # --- Diffusion scheduler ---
    # The remaining scheduler options (sampling timesteps, objective, DDIM eta,
    # min-SNR weighting and gamma) are not configurable here for now.
    scheduler_timesteps=1000,
    scheduler_loss='l2',
    scheduler_beta_schedule='cosine',

    # --- Classifier-free guidance ---
    cfg_cond_drop_prob=0.3,

    # --- Training and optimizer ---
    epochs=2000,
    optimizer_lr=1e-3,
    optimizer_weight_decay=5e-5,
    lr_scheduler_factor=0.8,
    lr_scheduler_patience=20,
    lr_scheduler_minlr=8e-5,
)


In [None]:
import optuna
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader
import uuid

# Run on CUDA when it is present and on the CPU otherwise.
# MPS is intentionally not selected: not all operations support it yet.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# --- Load the data
range_matrix = DatasetConstants.get_range_matrix().to(device)

# Wrap the raw scene lists in datasets first ...
train_dataset = ScenesDataset(train_data)
val_dataset = ScenesDataset(val_data)

# ... then expose them through shuffled loaders using the default batch size.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=hparams['batch_size'],
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=hparams['batch_size'],
    shuffle=True,
)


def _unpack_batch(batch):
    """Move one batched graph to ``device`` and reshape it for the scheduler.

    Args:
        batch: A batch yielded by the scenes ``DataLoader``. Its ``x``, ``cond``,
            ``edge_index``, ``edge_attr`` and ``num_graphs`` attributes are read.

    Returns:
        Tuple ``(x, obj_cond, edge_index, relations)``. ``x`` is viewed as
        ``[num_graphs, N, D]`` and ``obj_cond`` as ``[num_graphs, N, C]``; the
        other two tensors are only moved to ``device``.
    """
    num_scenes = batch.num_graphs
    # The loader concatenates scenes along dim 0, so split that axis back into
    # one entry per scene (this requires every scene to hold exactly N objects).
    x = batch.x.to(device).view(num_scenes, N, D)
    obj_cond = batch.cond.to(device).view(num_scenes, N, C)
    edge_index = batch.edge_index.to(device)
    relations = batch.edge_attr.to(device)
    return x, obj_cond, edge_index, relations


def train_scheduler(hparams):
    """Train one model/diffusion-scheduler pair and return its best validation loss.

    Builds a ``GuidedDiffusionNetwork`` wrapped in a ``DDPMScheduler`` from
    ``hparams``, trains it for ``hparams['epochs']`` epochs with Adam and
    ``ReduceLROnPlateau``, logs to TensorBoard under
    ``runs/full-DDPM/hparamtuning-<uuid>`` and saves the weights of the best
    validation epoch to ``models/hparamtuning/val-model-<uuid>.pt``.

    Args:
        hparams: Dict of hyperparameters. Keys read here: ``batch_size``,
            ``time_dim``, ``rgc_*``, ``attention_*``, ``scheduler_timesteps``,
            ``scheduler_loss``, ``scheduler_beta_schedule``,
            ``cfg_cond_drop_prob``, ``epochs``, ``optimizer_lr``,
            ``optimizer_weight_decay`` and ``lr_scheduler_*``. The whole dict is
            also logged through ``SummaryWriter.add_hparams``.

    Returns:
        The lowest mean per-batch validation loss seen over all epochs.
    """
    import os

    from tqdm import tqdm

    # --- Build the loaders for this run ---
    # They are created here (instead of reusing the module-level loaders) so
    # that hparams['batch_size'] is actually applied; otherwise a tuned batch
    # size would be logged but never used.
    train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True)
    # Evaluation does not benefit from shuffling; a fixed order keeps the
    # batch composition identical across epochs.
    val_loader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False)

    general_params = {
        "num_obj": N,
        "obj_cond_dim": C,
        'layer_1_dim': D,
        'layer_2_dim': D + hparams['time_dim'],
        "time_dim": hparams['time_dim'],
    }

    attention_params = {
        "attention_self_head_dim": hparams['attention_self_head_dims'],
        "attention_num_heads": hparams['attention_num_heads'],
        "attention_cross_head_dim": hparams['attention_cross_head_dims']
    }

    rgc_params = {
        # NOTE(review): this is forwarded as the *string* form of a tuple
        # (e.g. "()" or "(30, 30)"); presumably the model parses it - confirm.
        "rgc_hidden_dims": hparams['rgc_hidden_dims'],
        "rgc_num_relations": R,
        "rgc_num_bases": hparams['rgc_num_bases'],
        "rgc_aggr": hparams['rgc_aggr'],
        "rgc_activation": hparams['rgc_activation'],
        "rgc_dp_rate": hparams['rgc_dp_rate'],
        "rgc_bias": hparams['rgc_bias']
    }

    # --- Instantiate the model and the diffusion scheduler wrapping it ---
    model = GuidedDiffusionNetwork(
        general_params=general_params,
        attention_params=attention_params,
        rgc_params=rgc_params,
        cond_drop_prob=hparams['cfg_cond_drop_prob']
    )

    scheduler = DDPMScheduler(
        model=model,
        N=N,
        D=D,
        range_matrix=range_matrix,
        timesteps=hparams['scheduler_timesteps'],
        sampling_timesteps=None,
        loss_type=hparams['scheduler_loss'],
        objective='pred_noise',
        beta_schedule=hparams['scheduler_beta_schedule'],
        ddim_sampling_eta=1.0,
        min_snr_loss_weight=False,
        min_snr_gamma=5
    )

    # Move to device
    model = model.to(device)
    scheduler = scheduler.to(device)

    # --- Optimizer and learning-rate schedule ---
    optimizer = torch.optim.Adam(
        scheduler.parameters(),
        lr=hparams['optimizer_lr'],
        weight_decay=hparams['optimizer_weight_decay']
    )

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=hparams['lr_scheduler_factor'],
        patience=hparams['lr_scheduler_patience'],
        min_lr=hparams['lr_scheduler_minlr']
    )

    # --- Per-run output locations ---
    # A random UUID keeps this run's logs and checkpoint separate from earlier runs.
    trial_uuid = str(uuid.uuid4())
    checkpoint_dir = 'models/hparamtuning'
    # torch.save does not create missing directories, so make sure it exists.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = f'{checkpoint_dir}/val-model-{trial_uuid}.pt'

    writer = SummaryWriter(log_dir=f'runs/full-DDPM/hparamtuning-{trial_uuid}')

    best_loss = float('inf')
    try:
        for epoch in tqdm(range(hparams['epochs'])):
            # --- Training loop ---
            scheduler.train()
            train_loss = 0
            for batch in train_loader:
                # Calling the scheduler on a batch returns the training loss.
                loss = scheduler(*_unpack_batch(batch))

                # Backprop
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            train_loss /= len(train_loader)

            # The plateau schedule is driven by the training loss, not the
            # validation loss.
            lr_scheduler.step(train_loss)
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

            # --- Validation loop ---
            scheduler.eval()
            val_loss = 0
            with torch.no_grad():
                for batch in val_loader:
                    val_loss += scheduler(*_unpack_batch(batch)).item()

            val_loss /= len(val_loader)
            writer.add_scalar('Loss/val', val_loss, epoch)

            # Keep only the weights of the best validation epoch so far.
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved best model with val loss {best_loss}")

        # log all the hyperparameters and final loss
        writer.add_hparams(hparams, {'Best loss': best_loss})
    finally:
        # Flush and release the event file even when a trial fails midway.
        writer.close()

    return best_loss


# Define the objective function for hyperparameter optimization
def objective(trial):
    """Optuna objective: sample one configuration, train it, return its best validation loss.

    Args:
        trial: The Optuna trial used to sample values via its ``suggest_*`` methods.

    Returns:
        The value returned by ``train_scheduler`` for the sampled configuration
        (the study minimizes it).
    """
    # Define the hyperparameters to tune and their search spaces.
    # Entries given as plain values ('rgc_activation', 'scheduler_loss',
    # 'epochs') are held fixed for every trial.
    search_space = {
        'batch_size': trial.suggest_categorical('batch_size', [4, 8, 16, 32]),
        'time_dim': trial.suggest_categorical('time_dim', [6, 14, 30, 44]),

        # Tuples are passed as their string form so each choice is a plain string.
        'rgc_hidden_dims': trial.suggest_categorical('rgc_hidden_dims', [f'{()}', f'{(30,)}', f'{(15,)}', f'{(60,)}', f'{(30, 30)}']),
        'rgc_num_bases': trial.suggest_categorical('rgc_num_bases', [None, 2, 4, 8, 16]),
        'rgc_aggr': trial.suggest_categorical('rgc_aggr', ['mean', 'sum', 'max']),
        'rgc_activation': 'tanh',
        'rgc_dp_rate': trial.suggest_float('rgc_dp_rate', 0.0, 0.5),
        'rgc_bias': trial.suggest_categorical('rgc_bias', [True, False]),

        'attention_self_head_dims': trial.suggest_categorical('attention_self_head_dims', [10, 20, 30, 40]),
        'attention_num_heads': trial.suggest_categorical('attention_num_heads', [1, 2, 3, 4, 5]),
        'attention_cross_head_dims': trial.suggest_categorical('attention_cross_head_dims', [15, 20, 30, 40, 45]),

        'scheduler_timesteps': trial.suggest_categorical('scheduler_timesteps', [1000, 2000, 5000]),
        'scheduler_loss': 'l2',  # 'l1' is currently not part of the search
        'scheduler_beta_schedule': trial.suggest_categorical('scheduler_beta_schedule', ['cosine', 'linear']),

        'cfg_cond_drop_prob': trial.suggest_float('cfg_cond_drop_prob', 0.1, 0.5),

        'epochs': 300,
        'optimizer_lr': trial.suggest_float('optimizer_lr', 1e-5, 5e-3, log=True),
        'optimizer_weight_decay': trial.suggest_float('optimizer_weight_decay', 1e-6, 1e-3, log=True),
        'lr_scheduler_factor': trial.suggest_float('lr_scheduler_factor', 0.5, 0.9),
        'lr_scheduler_patience': trial.suggest_int('lr_scheduler_patience', 20, 60),
        # NOTE(review): this range overlaps 'optimizer_lr', so a trial can draw
        # a minimum LR above its starting LR - confirm that is intended.
        'lr_scheduler_minlr': trial.suggest_float('lr_scheduler_minlr', 0.00001, 0.001, log=True),
    }

    # Merge into a fresh dict instead of updating `hparams` in place: the
    # notebook-level defaults stay untouched after tuning (e.g. 'epochs' is not
    # silently left at 300) and trials do not share mutable state.
    trial_hparams = {**hparams, **search_space}

    return train_scheduler(trial_hparams)

def run_hparam_tuning(n_trials=100):
    """Run an Optuna study that minimizes ``objective`` and return it.

    Args:
        n_trials: Number of trials to run. Defaults to 100, the value that was
            previously hard-coded.

    Returns:
        The finished ``optuna`` study object.
    """
    study = optuna.create_study(direction='minimize')
    # Trials run sequentially; Study.optimize also accepts n_jobs (-1 for all CPUs).
    study.optimize(objective, n_trials=n_trials)

    return study

# Start the search. Every trial runs a full training loop, so this cell is long-running.
result_study = run_hparam_tuning()

In [None]:
# Pull the winning configuration and its objective value out of the study.
best_hparams, best_loss = result_study.best_params, result_study.best_value

# Summary first: the loss, then the whole configuration on a single line ...
print(f"Best loss: {best_loss}")
print(f"Best hyperparameters: {best_hparams}")
# ... followed by one "name: value" line per tuned hyperparameter.
for param_name in best_hparams:
    print(f"{param_name}: {best_hparams[param_name]}")