In [18]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler

## Mock data

In this section, we will generate an example scene graph and a corresponding matrix X to feed into the network.

In [19]:
# Training data hyperparams
B = 32 # num of graphs in batch

# Scene hyperparams
N = 20 # num of nodes
D = 16 # dim of objects from the scene

# Scene Condition hyperparams
C = 30 # dim of node features
E = 22 # num of edges
R = 8 + 1 # num of edge types (including 'unknown' type)

In [20]:
from torch_geometric.data import Data

# Scene Graphs for conditioning
def generate_random_graph(is_one_hot=None):
    # --- Initialize nodes ---
    if is_one_hot is not None:
        nodes = torch.zeros(N, C) # creates N x D tensor of (random) node features
        # Initialize nodes with one-hot encoding
        for i in range(N):
            nodes[i, torch.randint(C, (1,))] = 1
    else:
        nodes = torch.randn(N, C) # creates N x D tensor of (random) node features
    # Initialize nodes with one-hot encoding

    # --- Initialize edges --- 
    edges = torch.randint(N, (2, E)) # creates 2 x E tensor of (random) edges

    # --- Introduce different types of edges ---
    rels = torch.randint(R - 1, (E,)) + 1 # creates E x 1 tensor of (random) edge types excluding 'unknown' type with value 0

    # --- Create a graph ---
    graph = Data(x=nodes, edge_index=edges, edge_attr=rels)
    
    return graph

# --- Initialize batch ---
graphs = [generate_random_graph(is_one_hot=True) for _ in range(B)]

print(f"Batch size: {len(graphs)}, dimensions of each graph: {graphs[0]}")

# print(graphs[0].edge_attr)

Batch size: 32, dimensions of each graph: Data(x=[20, 30], edge_index=[2, 22], edge_attr=[22])


In [21]:
# Scenes for denoising
# X = torch.randn(B, N, D) # creates B x N x D tensor of (random) node features
# TODO: Create X as ones for testing now
Xs = [torch.ones(N, D) * (i / float(B)) for i in range(B)]
X = torch.stack(Xs, dim=0)
print(f"Dimensions: X={X.shape}")

Dimensions: X=torch.Size([32, 20, 16])


In [22]:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, X, graphs):
        super(CustomDataset, self).__init__()
        self.X = X
        self.graphs = graphs

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        x = self.X[index]
        graph = self.graphs[index]
        return {
            'x': x,
            'obj_cond': graph.x,
            'edge_cond': graph.edge_index,
            'relation_cond': graph.edge_attr
        }

    def collate_fn(self, batch):
        x_batch = torch.stack([item['x'] for item in batch], dim=0)
        obj_cond_batch = torch.cat([item['obj_cond'] for item in batch], dim=0)
        edge_cond_batch = torch.cat([item['edge_cond'] for item in batch], dim=1)
        relation_cond_batch = torch.cat([item['relation_cond'] for item in batch], dim=0)

        return {
            'x': x_batch,
            'obj_cond': obj_cond_batch,
            'edge_cond': edge_cond_batch,
            'relation_cond': relation_cond_batch
        }

## Hyperparameter tuning

In this section, we will tune the hyperparameters of the network.

In [23]:
hparams = {
    'batch_size': B, # num of graphs in batch
    
    # --- Attention hyperparams ---
    'attention_out_dim': D,
    'attention_num_heads': 5, # must be a divisor of D
    
    # --- Encoder RGCN hyperparams ---
    'encoder_out_dim': C,
    'encoder_hidden_dims': f"{()}", # (C, C),
    'encoder_num_bases': None,
    'encoder_aggr': 'mean',
    'encoder_activation': 'leakyrelu',
    'encoder_dp_rate': 0.,
    'encoder_bias': True,
    
    # --- Fusion RGCN hyperparams ---
    'fusion_hidden_dims': f"{()}", # (C+D, C+D, D),
    'fusion_num_bases': None,
    'fusion_aggr': 'mean',
    'fusion_activation': 'leakyrelu',
    'fusion_dp_rate': 0.,
    'fusion_bias': True,
    
    # Scheduler hyperparams
    'scheduler_timesteps': 1000,
    'scheduler_loss': 'l1',
    'scheduler_beta_schedule': 'cosine',
    # Note: not needed for now
    # 'scheduler_sampling_timesteps': None,
    # "scheduler_objective": 'pred_noise',
    # 'scheduler_ddim_sampling_eta': 1.0,
    # 'scheduler_min_snr_loss_weight': False,
    # 'scheduler_min_snr_gamma': 5,
    
    # Classifier-free guidance parameters
    'cfg_cond_drop_prob': 0.1,
    
    # Training and optimizer hyperparams
    'epochs': 3000,
    'optimizer_lr': 1e-3,
    'optimizer_weight_decay': 5e-4,
    'lr_scheduler_factor': 0.7,
    'lr_scheduler_patience': 100,
    'lr_scheduler_minlr': 0.00001,
}

In [26]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import uuid
from torch.utils.tensorboard import SummaryWriter
import optuna

# --- Load the (mocked) data
dataset = CustomDataset(X, graphs)
dataloader = DataLoader(dataset, batch_size=hparams['batch_size'], shuffle=True, collate_fn=dataset.collate_fn)

def train_scheduler(hparams):
    model = GuidedDiffusionNetwork(
        attention_in_dim=D,
        attention_out_dim=hparams['attention_out_dim'],
        attention_num_heads=hparams['attention_num_heads'],
        
        rgcn_num_relations=R,
        
        encoder_in_dim=C,
        encoder_out_dim=hparams['encoder_out_dim'],
        encoder_hidden_dims=hparams['encoder_hidden_dims'],
        encoder_num_bases=hparams['encoder_num_bases'],
        encoder_aggr=hparams['encoder_aggr'],
        encoder_activation=hparams['encoder_activation'],
        encoder_dp_rate=hparams['encoder_dp_rate'],
        encoder_bias=hparams['encoder_bias'],
        
        fusion_hidden_dims=hparams['fusion_hidden_dims'],
        fusion_num_bases=hparams['fusion_num_bases'],
        fusion_aggr=hparams['fusion_aggr'],
        fusion_activation=hparams['fusion_activation'],
        fusion_dp_rate=hparams['fusion_dp_rate'],
        fusion_bias=hparams['fusion_bias'],
        
        cond_drop_prob=hparams['cfg_cond_drop_prob']
    )

    scheduler = DDPMScheduler(
        model=model,
        N=N,
        D=D,
        timesteps=hparams['scheduler_timesteps'],
        sampling_timesteps=None,
        loss_type=hparams['scheduler_loss'],
        objective='pred_noise',
        beta_schedule=hparams['scheduler_beta_schedule'],
        ddim_sampling_eta=1.0,
        min_snr_loss_weight=False,
        min_snr_gamma=5
    )
    
    optimizer = torch.optim.Adam(
        scheduler.parameters(), 
        lr=hparams['optimizer_lr'], 
        weight_decay=hparams['optimizer_weight_decay']
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=hparams['lr_scheduler_factor'], 
        patience=hparams['lr_scheduler_patience'], 
        min_lr=hparams['lr_scheduler_minlr']
    )
    
    # Generate a unique id for each trial
    trial_uuid = str(uuid.uuid4())
    writer = SummaryWriter(log_dir=f'runs/full-DDPM/hparamtuning-{trial_uuid}')

    # --- Training loop ---
    best_loss = float('inf')
    for epoch in tqdm(range(hparams['epochs'])):
        epoch_loss = 0
        for batch in dataloader:
            x_batch = batch['x']
            obj_cond_batch = batch['obj_cond']
            edge_cond_batch = batch['edge_cond']
            relation_cond_batch = batch['relation_cond']
            
            loss = scheduler(x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch)
            
            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            
        epoch_loss /= len(dataloader)
        
        best_loss = min(best_loss, epoch_loss)
            
        lr_scheduler.step(epoch_loss)
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

    # log all the hyperparameters and final loss
    writer.add_hparams(hparams, {'Best loss': best_loss})
    writer.close()
    
    return best_loss

# Define the objective function for hyperparameter optimization
def objective(trial):
    # Define the hyperparameters to tune and their search spaces
    search_space = {
        'batch_size': trial.suggest_categorical('batch_size', [4, 16, 32]),
    
        'attention_out_dim': trial.suggest_categorical('attention_out_dim', [16, 32, 48]),
        'attention_num_heads': trial.suggest_categorical('attention_num_heads', [2, 4, 8, 16]),
    
        'encoder_out_dim': trial.suggest_categorical('encoder_out_dim', [2, 10, 20, 30]),
        'encoder_hidden_dims': trial.suggest_categorical('encoder_hidden_dims', [f'{()}', f'{(30,)}', f'{(15,)}', f'{(60,)}', f'{(30, 30)}']),
        'encoder_num_bases': trial.suggest_categorical('encoder_num_bases', [None, 2, 4, 8]),
        'encoder_aggr': trial.suggest_categorical('encoder_aggr', ['mean', 'sum', 'max']),
        'encoder_activation': trial.suggest_categorical('encoder_activation', ['leakyrelu', 'relu', 'silu']),
        'encoder_dp_rate': trial.suggest_float('encoder_dp_rate', 0.0, 0.5),
        'encoder_bias': trial.suggest_categorical('encoder_bias', [True, False]),
    
        'fusion_hidden_dims': trial.suggest_categorical('fusion_hidden_dims', [f'{()}', f'{(8,)}', f'{(18,)}', f'{(24,)}', f'{(48,)}', f'{(48, 24)}', f'{(48, 32)}', f"{(48, 32, 24)}"]),
        'fusion_num_bases': trial.suggest_categorical('fusion_num_bases', [None, 2, 4, 8]),
        'fusion_aggr': trial.suggest_categorical('fusion_aggr', ['mean', 'sum', 'max']),
        'fusion_activation': trial.suggest_categorical('fusion_activation', ['leakyrelu', 'relu', 'silu']),
        'fusion_dp_rate': trial.suggest_float('fusion_dp_rate', 0.0, 0.5),
        'fusion_bias': trial.suggest_categorical('fusion_bias', [True, False]),
    
        'scheduler_timesteps': trial.suggest_categorical('scheduler_timesteps', [1000, 2000, 5000]),
        'scheduler_loss': trial.suggest_categorical('scheduler_loss', ['l1', 'l2']),
        'scheduler_beta_schedule': trial.suggest_categorical('scheduler_beta_schedule', ['cosine', 'linear']),
    
        'cfg_cond_drop_prob': trial.suggest_float('cfg_cond_drop_prob', 0.0, 0.5),
    
        'epochs': 1250,
        'optimizer_lr': trial.suggest_float('optimizer_lr', 1e-5, 5e-3, log=True),
        'optimizer_weight_decay': trial.suggest_float('optimizer_weight_decay', 1e-6, 1e-3, log=True),
        'lr_scheduler_factor': trial.suggest_float('lr_scheduler_factor', 0.5, 0.9),
        'lr_scheduler_patience': trial.suggest_int('lr_scheduler_patience', 50, 200),
        'lr_scheduler_minlr': trial.suggest_float('lr_scheduler_minlr', 0.00001, 0.001, log=True),
    }
    
    hparams.update(search_space)
    
    return train_scheduler(hparams)

def run_hparam_tuning():
    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=100, n_jobs=-1)
    
    return study

result_study = run_hparam_tuning()

[I 2023-06-04 14:55:19,205] A new study created in memory with name: no-name-09139770-614f-4449-a0ff-4f79afee5019
  0%|          | 0/1250 [00:00<?, ?it/s]
[A


[A[A[A

[A[A





[A[A[A[A[A[A




[A[A[A[A[A



  0%|          | 1/1250 [00:00<03:10,  6.57it/s]

[A[A




[A[A[A[A[A



[A[A[A[A
[A


[A[A[A





[A[A[A[A[A[A




[A[A[A[A[A



[A[A[A[A

  0%|          | 2/1250 [00:00<03:14,  6.42it/s]
[A


[A[A[A





[A[A[A[A[A[A




[A[A[A[A[A



[A[A[A[A

  0%|          | 3/1250 [00:00<03:14,  6.42it/s]
[A




[A[A[A[A[A



[A[A[A[A


[A[A[A

[A[A





  0%|          | 4/1250 [00:00<03:05,  6.73it/s]




[A[A[A[A[A
[A



[A[A[A[A

[A[A


  0%|          | 5/1250 [00:00<02:49,  7.33it/s]




[A[A[A[A[A





[A[A[A[A[A[A
[A

[A[A




[A[A[A[A[A


  0%|          | 6/1250 [00:00<02:54,  7.15it/s]



[A[A[A[A





[A[A[A[A[A[A

[A[A




[A[A[A[A[A
[A


  1%|          

In [41]:
best_hparams = result_study.best_params
best_loss = result_study.best_value

print(f"Best loss: {best_loss}")
print(f"Best hyperparameters: {best_hparams}")
# print every hyperparameter and its value in a separate line
for key, value in best_hparams.items():
    print(f"{key}: {value}")

Best loss: 0.14199760556221008
Best hyperparameters: {'batch_size': 4, 'attention_out_dim': 48, 'attention_num_heads': 2, 'encoder_out_dim': 10, 'encoder_hidden_dims': '(15,)', 'encoder_num_bases': 4, 'encoder_aggr': 'sum', 'encoder_activation': 'silu', 'encoder_dp_rate': 0.07646450520670617, 'encoder_bias': True, 'fusion_hidden_dims': '(18,)', 'fusion_num_bases': None, 'fusion_aggr': 'max', 'fusion_activation': 'leakyrelu', 'fusion_dp_rate': 0.046351967280011154, 'fusion_bias': False, 'scheduler_timesteps': 2000, 'scheduler_loss': 'l1', 'scheduler_beta_schedule': 'linear', 'cfg_cond_drop_prob': 0.30141644746465146, 'optimizer_lr': 0.0016730837588311111, 'optimizer_weight_decay': 7.748251613968323e-06, 'lr_scheduler_factor': 0.6866331533478531, 'lr_scheduler_patience': 77, 'lr_scheduler_minlr': 0.00036801862369884345}
batch_size: 4
attention_out_dim: 48
attention_num_heads: 2
encoder_out_dim: 10
encoder_hidden_dims: (15,)
encoder_num_bases: 4
encoder_aggr: sum
encoder_activation: silu
