In [1]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model_alternative import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler
from scenes_dataset import ScenesDataset, DatasetConstants

In [2]:
import json

def _load_scenes(path):
    """Return the list stored under the 'scenes' key of the JSON file at `path`.

    The file is opened as UTF-8 explicitly: JSON text is UTF-8, and relying on the
    platform default encoding (e.g. cp1252 on Windows) can mis-decode non-ASCII data.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)['scenes']


# Load the train / validation scenes from their JSON files
train_data = _load_scenes('datasets/data/train.json')
val_data = _load_scenes('datasets/data/val.json')

In [3]:
B = 3  # num of scenes in batch

# Scene hyperparams
N = 20  # num of objects in scene
D = 15  # dim of objects from the scene

# Condition hyperparams
C = 300  # dim of node features
R = 23 + 1  # num of relations

# Every value is a plain int/float/str/bool so the whole dict can be passed to
# SummaryWriter.add_hparams at the end of training.
hparams = {
    'batch_size': B,  # num of graphs in batch
    # NOTE(review): this was documented as "must be a divisor of 300", but 29 does not
    # divide 300 (300 / 29 ≈ 10.34). Either that constraint or this value is stale —
    # confirm the real requirement in GuidedDiffusionNetwork.
    'layer_2_dim': 29,

    # --- RGCN hyperparams ---
    # The string "()" (identical to the previous f"{()}"): an empty tuple of hidden dims,
    # which yields a single RGCNConv layer in the printed model. Alternative: "(C+D, C+D, D)".
    # Presumably kept as a string so add_hparams, which only logs scalars/strings, accepts it.
    'rgc_hidden_dims': "()",
    'rgc_num_bases': 5,  # Alternative: None
    'rgc_aggr': 'mean',
    'rgc_activation': 'tanh',
    'rgc_dp_rate': 0.,
    'rgc_bias': True,

    # --- Attention hyperparams ---
    'attention_self_head_dims': 10,
    'attention_num_heads': 3,
    'attention_cross_head_dims': 30,

    # --- Scheduler hyperparams ---
    'scheduler_timesteps': 1000,
    'scheduler_loss': 'l2',
    'scheduler_beta_schedule': 'cosine',
    # Not tuned for now; these values are currently hard-coded in the DDPMScheduler(...) call:
    # 'scheduler_sampling_timesteps': None,
    # "scheduler_objective": 'pred_noise',
    # 'scheduler_ddim_sampling_eta': 1.0,
    # 'scheduler_min_snr_loss_weight': False,
    # 'scheduler_min_snr_gamma': 5,

    # --- Classifier-free guidance parameters ---
    'cfg_cond_drop_prob': 0.,

    # --- Training and optimizer hyperparams ---
    'epochs': 2000,
    'optimizer_lr': 1e-3,
    'optimizer_weight_decay': 5e-5,
    'lr_scheduler_factor': 0.8,
    'lr_scheduler_patience': 20,
    'lr_scheduler_minlr': 8e-5,
}


In [4]:
# Scene-level sizes shared by the model's sub-modules.
general_params = dict(num_obj=N, obj_cond_dim=C)

# Attention settings. The model expects singular "..._dim" keys whereas hparams
# stores plural "..._dims" keys, so each name is remapped explicitly.
_ATTENTION_KEY_MAP = (
    ("attention_self_head_dim", "attention_self_head_dims"),
    ("attention_num_heads", "attention_num_heads"),
    ("attention_cross_head_dim", "attention_cross_head_dims"),
)
attention_params = {
    model_key: hparams[hparams_key] for model_key, hparams_key in _ATTENTION_KEY_MAP
}

# Relational GCN settings: everything comes straight from hparams except the
# number of relation types, which is the constant R. Key order is kept stable.
rgc_params = {
    "rgc_hidden_dims": hparams["rgc_hidden_dims"],
    "rgc_num_relations": R,
    **{
        key: hparams[key]
        for key in ("rgc_num_bases", "rgc_aggr", "rgc_activation", "rgc_dp_rate", "rgc_bias")
    },
}

In [5]:
# A single fixed noise tensor shaped like a batch of scenes, [B, N, D]. It is sampled
# once (not per timestep) and passed to the scheduler for every training batch of every
# epoch, sliced to the actual batch size. This is presumably an overfitting sanity check
# (the checkpoint is saved as overfit-model.pt).
# NOTE(review): standard DDPM training draws fresh noise at every step — drop the fixed
# noise for a real training run.
# A dedicated seeded generator makes the tensor reproducible across kernel restarts
# without reseeding the global RNG used by model init and data shuffling.
NOISE_SEED = 0
noise = torch.randn(B, N, D, generator=torch.Generator().manual_seed(NOISE_SEED))

In [7]:
import os
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader
from tqdm import tqdm

if torch.cuda.is_available():
    device = torch.device('cuda')
# Not all operations support MPS yet so this option is not available for now
# elif torch.has_mps:
#     device = torch.device('mps')
else:
    device = torch.device('cpu')


# --- Load the data
range_matrix = DatasetConstants.get_range_matrix().to(device)

train_dataset = ScenesDataset(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True)

val_dataset = ScenesDataset(val_data)
# The validation loss is an average over all batches, so shuffling gains nothing here.
val_dataloader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False)

# Per-epoch losses are divided by the number of batches below; fail with a clear
# message instead of a ZeroDivisionError if a split is empty.
assert len(train_dataloader) > 0, "train_data produced no batches"
assert len(val_dataloader) > 0, "val_data produced no batches"

# --- Instantiate the model
model = GuidedDiffusionNetwork(
    layer_1_dim=D,
    layer_2_dim=hparams['layer_2_dim'],
    general_params=general_params,
    attention_params=attention_params,
    rgc_params=rgc_params,
    cond_drop_prob=hparams['cfg_cond_drop_prob']
)

print(f"Model:\n{model}")

scheduler = DDPMScheduler(
    model=model,
    N=N,
    D=D,
    range_matrix=range_matrix[:, C:],
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5
)

print(f"DDPM Scheduler:\n{scheduler}")

# Move to device
model = model.to(device)
scheduler = scheduler.to(device)

# `noise` is created on the CPU in an earlier cell while every batch tensor lives on
# `device`; move it once so the scheduler gets same-device inputs (on CUDA the CPU
# tensor would presumably fail when combined with x_batch inside the scheduler).
fixed_noise = noise.to(device)


# --- Setup training loop ---
optimizer = torch.optim.Adam(
    scheduler.parameters(),
    lr=hparams['optimizer_lr'],
    weight_decay=hparams['optimizer_weight_decay']
)

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=hparams['lr_scheduler_factor'],
    patience=hparams['lr_scheduler_patience'],
    min_lr=hparams['lr_scheduler_minlr']
)


def _prepare_batch(batch):
    """Move a PyG mini-batch to `device` and regroup its node tensors per scene.

    Returns ``(x, obj_cond, edge_index, edge_attr)`` in the positional order the
    scheduler is called with. ``x`` is columns ``C:`` of ``batch.x`` reshaped from
    [B*N, D] to [B, N, D]; ``obj_cond`` is ``batch.cond`` reshaped from [B*N, C]
    to [B, N, C]; here B is ``batch.num_graphs`` (the last batch may be smaller).
    """
    x = batch.x.to(device)[:, C:].view(batch.num_graphs, N, D)
    obj_cond = batch.cond.to(device).view(batch.num_graphs, N, C)
    edge_index = batch.edge_index.to(device)
    edge_attr = batch.edge_attr.to(device)
    return x, obj_cond, edge_index, edge_attr


# --- Initialize tensorboard ---
# use timestamp to avoid overwriting previous runs
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/full-DDPM/train-time:{now.strftime("%Y-%m-%d-%H:%M:%S")}')

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)

best_loss = float('inf')
# Losses of the last *completed* epoch; NaN until one finishes.
train_loss = float('nan')
val_loss = float('nan')

try:
    for epoch in tqdm(range(hparams['epochs'])):
        # --- Training loop ---
        scheduler.train()
        running_loss = 0.0
        for batch in train_dataloader:
            x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch = _prepare_batch(batch)

            # The fixed noise is sliced because the last batch may hold fewer scenes than B.
            loss = scheduler(
                x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch,
                noise=fixed_noise[:batch.num_graphs, :, :],
            )

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_dataloader)

        lr_scheduler.step(train_loss)
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(model.state_dict(), 'models/overfit-model.pt')
            # tqdm.write keeps the progress bar intact instead of interleaving with it.
            tqdm.write(f"Saved best model with train loss {best_loss}")

        # --- Validation loop ---
        # No `noise` argument here, so the scheduler uses whatever it does by default.
        scheduler.eval()
        running_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                loss = scheduler(*_prepare_batch(batch))
                running_loss += loss.item()
        val_loss = running_loss / len(val_dataloader)
        writer.add_scalar('Loss/val', val_loss, epoch)
finally:
    # Runs even on KeyboardInterrupt, so the hyperparameters and final metrics are
    # logged and the event file is flushed. 'Final loss' is the last validation loss.
    writer.add_hparams(hparams, {'Final loss': val_loss, 'Best train loss': best_loss})
    writer.close()

Model:
GuidedDiffusionNetwork(
  (block1): GuidedDiffusionBlock(
    (time_embedding_module): TimeEmbedding()
    (max_pool): MaxPool1d(kernel_size=(20,), stride=(20,), padding=0, dilation=1, ceil_mode=False)
    (rgc_module): RelationalRGCN(
      (layers): ModuleList(
        (0): RGCNConv(15, 15, num_relations=24)
        (1): Tanh()
      )
    )
    (self_attention_module): SelfMultiheadAttention(
      (qkv_proj): Linear(in_features=15, out_features=90, bias=False)
      (o_proj): Linear(in_features=30, out_features=15, bias=False)
      (layer_norm): LayerNorm((20, 15), eps=1e-05, elementwise_affine=True)
    )
    (cross_attention_module): CrossMultiheadAttention(
      (q_proj): Linear(in_features=15, out_features=90, bias=False)
      (kv_proj): Linear(in_features=300, out_features=180, bias=False)
      (o_proj): Linear(in_features=90, out_features=15, bias=False)
      (layer_norm): LayerNorm((20, 15), eps=1e-05, elementwise_affine=True)
    )
  )
  (linear1): Linear(in_fea

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 1/2000 [00:00<10:37,  3.13it/s]

Saved best model with train loss 0.6869428684888792


  0%|          | 2/2000 [00:00<10:17,  3.23it/s]

Saved best model with train loss 0.3448775373091382


  0%|          | 3/2000 [00:00<10:42,  3.11it/s]

Saved best model with train loss 0.2589108509219382


  0%|          | 4/2000 [00:01<10:27,  3.18it/s]

Saved best model with train loss 0.227265632965348


  0%|          | 6/2000 [00:01<09:59,  3.33it/s]

Saved best model with train loss 0.22469961615510223


  0%|          | 7/2000 [00:02<10:07,  3.28it/s]

Saved best model with train loss 0.20778145813498616


  0%|          | 8/2000 [00:02<10:21,  3.21it/s]

Saved best model with train loss 0.2044878453016281


  0%|          | 9/2000 [00:02<10:21,  3.20it/s]

Saved best model with train loss 0.18026175008209283


  0%|          | 10/2000 [00:03<10:20,  3.21it/s]

Saved best model with train loss 0.16927399107616795


  1%|          | 12/2000 [00:03<10:12,  3.25it/s]

Saved best model with train loss 0.15753345323865078


  1%|          | 15/2000 [00:04<09:48,  3.37it/s]

Saved best model with train loss 0.14077337474123505


  1%|          | 16/2000 [00:04<09:49,  3.36it/s]

Saved best model with train loss 0.13741632130705128


  1%|          | 19/2000 [00:05<10:09,  3.25it/s]

Saved best model with train loss 0.12230118721230956


  1%|          | 21/2000 [00:06<09:57,  3.31it/s]

Saved best model with train loss 0.11615954916000612


  1%|          | 22/2000 [00:06<09:54,  3.33it/s]

Saved best model with train loss 0.11536219389724338


  1%|▏         | 29/2000 [00:08<09:34,  3.43it/s]

Saved best model with train loss 0.09076068698128392


  2%|▏         | 35/2000 [00:10<09:39,  3.39it/s]

Saved best model with train loss 0.09047558298719323


  2%|▏         | 43/2000 [00:12<09:37,  3.39it/s]

Saved best model with train loss 0.07517469388794554


  3%|▎         | 53/2000 [00:15<09:35,  3.38it/s]

Saved best model with train loss 0.07448521955323613


  3%|▎         | 55/2000 [00:16<09:34,  3.39it/s]

Saved best model with train loss 0.06409392599016428


  3%|▎         | 62/2000 [00:18<09:22,  3.44it/s]

Saved best model with train loss 0.06320816036005897


  3%|▎         | 67/2000 [00:19<09:24,  3.42it/s]

Saved best model with train loss 0.06299344278594926


  4%|▎         | 73/2000 [00:21<09:17,  3.46it/s]

Saved best model with train loss 0.06215042442329659


  4%|▍         | 76/2000 [00:22<09:20,  3.43it/s]

Saved best model with train loss 0.05618949331851168


  4%|▍         | 86/2000 [00:25<09:24,  3.39it/s]

Saved best model with train loss 0.052687199770911665


  4%|▍         | 90/2000 [00:26<09:20,  3.41it/s]

Saved best model with train loss 0.05228165175750359


  5%|▍         | 92/2000 [00:27<09:22,  3.39it/s]

Saved best model with train loss 0.04248432045386843


  5%|▍         | 98/2000 [00:28<09:17,  3.41it/s]

Saved best model with train loss 0.03727280735692456


  7%|▋         | 133/2000 [00:39<09:10,  3.39it/s]

Saved best model with train loss 0.03207480369813063


 11%|█         | 223/2000 [01:05<08:36,  3.44it/s]

Saved best model with train loss 0.025768770205149474


 13%|█▎        | 263/2000 [01:16<08:20,  3.47it/s]

Saved best model with train loss 0.024917346078994845


 16%|█▌        | 321/2000 [01:33<08:14,  3.40it/s]

Saved best model with train loss 0.024267007238885835


 19%|█▊        | 373/2000 [01:49<08:23,  3.23it/s]

Saved best model with train loss 0.02258915068920363


 26%|██▌       | 524/2000 [02:34<07:09,  3.44it/s]

Saved best model with train loss 0.021564048868873396


 29%|██▉       | 589/2000 [02:53<06:49,  3.44it/s]

Saved best model with train loss 0.021390371174697786


 36%|███▌      | 723/2000 [03:32<06:14,  3.41it/s]


KeyboardInterrupt: 