In [1]:
import os

import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler
from scenes_dataset import ScenesDataset, DatasetConstants

In [2]:
import json

def load_scenes(json_path):
    """Read a dataset JSON file and return the value under its top-level 'scenes' key.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        Whatever is stored under the 'scenes' key of the parsed document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the document has no 'scenes' key.
    """
    # JSON text is UTF-8 by specification; open() would otherwise fall back to
    # the platform's locale encoding, which is not UTF-8 everywhere.
    with open(json_path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)['scenes']


# Load the train / validation scenes from their JSON files
train_data = load_scenes('datasets/data/train.json')
val_data = load_scenes('datasets/data/val.json')

In [3]:
B = 3  # scenes per batch

# Scene hyperparams
N = 20  # objects per scene
D = 15  # feature dimension of each object in the scene

# Condition hyperparams
C = 300  # feature dimension of each condition node
R = 23 + 1  # number of relations

# Flat dictionary of every tunable value of the run. All values are plain
# ints / floats / strings / bools.
hparams = dict(
    batch_size=B,
    layer_2_dim=29,

    # --- RGCN hyperparams ---
    # Kept as text: an empty tuple rendered to the string "()".
    # A non-empty alternative would be e.g. (C+D, C+D, D).
    rgc_hidden_dims=str(()),
    rgc_num_bases=5,  # Alternative: None
    rgc_aggr='mean',
    rgc_activation='tanh',
    rgc_dp_rate=0.,
    rgc_bias=True,

    # --- Attention hyperparams ---
    attention_self_head_dims=10,
    attention_num_heads=3,
    attention_cross_head_dims=30,

    # --- Scheduler hyperparams ---
    scheduler_timesteps=1000,
    scheduler_loss='l2',
    scheduler_beta_schedule='cosine',
    # Note: not needed for now
    # scheduler_sampling_timesteps=None,
    # scheduler_objective='pred_noise',
    # scheduler_ddim_sampling_eta=1.0,
    # scheduler_min_snr_loss_weight=False,
    # scheduler_min_snr_gamma=5,

    # --- Classifier-free guidance parameters ---
    cfg_cond_drop_prob=0.7,

    # --- Training and optimizer hyperparams ---
    epochs=2000,
    optimizer_lr=1e-3,
    optimizer_weight_decay=5e-5,
    lr_scheduler_factor=0.8,
    lr_scheduler_patience=20,
    lr_scheduler_minlr=8e-5,
)


In [4]:
# Three keyword groups handed to the network constructor, assembled from the
# scene constants and the flat `hparams` dictionary.

# Scene-level sizes
general_params = dict(num_obj=N, obj_cond_dim=C)

# Attention settings. Note the key rename: hparams uses the plural
# '..._head_dims', the parameter group uses the singular '..._head_dim'.
attention_params = dict(
    attention_self_head_dim=hparams['attention_self_head_dims'],
    attention_num_heads=hparams['attention_num_heads'],
    attention_cross_head_dim=hparams['attention_cross_head_dims'],
)

# Relational graph convolution settings. Every entry except the relation count
# is copied from hparams under the same key.
rgc_params = dict(
    rgc_hidden_dims=hparams['rgc_hidden_dims'],
    rgc_num_relations=R,
    **{
        key: hparams[key]
        for key in ('rgc_num_bases', 'rgc_aggr', 'rgc_activation', 'rgc_dp_rate', 'rgc_bias')
    },
)

In [5]:
# Debugging aid only (!)

# Standard-normal noise with the same shape as a batch of scenes, [B, N, D]
noise = torch.randn((B, N, D))

In [6]:
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader

# Run on CUDA when it is available, otherwise fall back to the CPU.
# Not all operations support MPS yet so this option is not available for now
# (it would be selected via torch.has_mps -> torch.device('mps')).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# --- Load the data
# Range matrix supplied by the dataset constants, moved to the compute device;
# it is handed to the DDPM scheduler below.
range_matrix = DatasetConstants.get_range_matrix().to(device)

# Training scenes are reshuffled every epoch.
train_dataset = ScenesDataset(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True)

# Validation is evaluation only, so the batch order is kept fixed: shuffling
# brings no benefit there and only changes which scenes share a batch from one
# epoch to the next, making the logged validation loss less repeatable.
val_dataset = ScenesDataset(val_data)
val_dataloader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False)

# --- Instantiate the model
# The denoising network: layer sizes plus the three parameter groups built above.
network_kwargs = dict(
    layer_1_dim=D,
    layer_2_dim=hparams['layer_2_dim'],
    general_params=general_params,
    attention_params=attention_params,
    rgc_params=rgc_params,
    cond_drop_prob=hparams['cfg_cond_drop_prob'],
)
model = GuidedDiffusionNetwork(**network_kwargs)

print(f"Model:\n{model}")

# The DDPM scheduler receives the network itself. Only the timesteps, loss type
# and beta schedule come from hparams; the remaining options are fixed here.
diffusion_kwargs = dict(
    model=model,
    N=N,
    D=D,
    range_matrix=range_matrix,
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5,
)
scheduler = DDPMScheduler(**diffusion_kwargs)

print(f"DDPM Scheduler:\n{scheduler}")

# Move to device (both objects are created first, then moved)
model = model.to(device)
scheduler = scheduler.to(device)


# --- Setup training loop ---
from tqdm import tqdm

# Adam over the parameters exposed by the DDPM scheduler
# (presumably this includes the wrapped network's weights -- verify in ddpm_scheduler).
adam_config = dict(
    lr=hparams['optimizer_lr'],
    weight_decay=hparams['optimizer_weight_decay'],
)
optimizer = torch.optim.Adam(scheduler.parameters(), **adam_config)

# Multiply the learning rate by `factor` once the monitored loss has stopped
# decreasing for `patience` epochs, never going below `min_lr`.
plateau_config = dict(
    mode='min',
    factor=hparams['lr_scheduler_factor'],
    patience=hparams['lr_scheduler_patience'],
    min_lr=hparams['lr_scheduler_minlr'],
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_config)


# --- Initialize tensorboard ---
# Use a timestamp so each run gets its own log directory and previous runs are
# not overwritten.
from datetime import datetime
now = datetime.now()
# The directory name deliberately contains no ':' -- a colon is not a legal
# character in a Windows path component, so the previous "train-time:HH:MM:SS"
# naming could not be created there.
writer = SummaryWriter(log_dir=f'runs/full-DDPM/train-time-{now.strftime("%Y-%m-%d-%H-%M-%S")}')

# Lowest validation loss seen so far; starts at +inf so the first epoch always improves on it.
best_loss = float('inf')

def _prepare_batch(batch):
    """Move one mini-batch to `device` and reshape it into per-scene tensors.

    `batch.x` is read as [B*N, D] and `batch.cond` as [B*N, C]; both are
    reshaped to [B, N, D] / [B, N, C] with B = `batch.num_graphs`. The edge
    index and the relation attributes are only moved to the device.

    Args:
        batch: A mini-batch yielded by one of the data loaders.

    Returns:
        Tuple `(x, obj_cond, edge_cond, relation_cond)` in the positional order
        the scheduler is called with.
    """
    num_scenes = batch.num_graphs
    x = batch.x.to(device).view(num_scenes, N, D)
    obj_cond = batch.cond.to(device).view(num_scenes, N, C)
    edge_cond = batch.edge_index.to(device)
    relation_cond = batch.edge_attr.to(device)
    return x, obj_cond, edge_cond, relation_cond


# torch.save does not create missing directories; make sure the checkpoint
# folder exists before the first "best model" is written.
os.makedirs('models', exist_ok=True)

# Validation loss of the most recent fully completed epoch (None until one finishes).
last_val_loss = None

try:
    for epoch in tqdm(range(hparams['epochs'])):
        # --- Training loop ---
        # Debug aid: enable torch.autograd.set_detect_anomaly(True) here if needed.
        scheduler.train()
        epoch_loss = 0
        for batch in train_dataloader:
            # Debug aid: append noise=noise[:batch.num_graphs, :, :] to use the fixed noise tensor.
            loss = scheduler(*_prepare_batch(batch))

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Mean of the per-batch training losses
        epoch_loss /= len(train_dataloader)

        # The LR schedule is driven by the training loss, not the validation loss.
        lr_scheduler.step(epoch_loss)
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

        # To checkpoint on the training loss instead (overfitting experiments),
        # compare epoch_loss with best_loss here and save to 'models/overfit-model.pt'.

        # --- Validation loop ---
        scheduler.eval()
        with torch.no_grad():
            epoch_loss = 0
            for batch in val_dataloader:
                loss = scheduler(*_prepare_batch(batch))
                epoch_loss += loss.item()

        # Mean of the per-batch validation losses
        epoch_loss /= len(val_dataloader)
        writer.add_scalar('Loss/val', epoch_loss, epoch)
        last_val_loss = epoch_loss

        # Keep only the weights with the lowest validation loss so far.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), 'models/val-model.pt')
            print(f"Saved best model with val loss {best_loss}")
finally:
    # Runs on normal completion AND when training is interrupted (e.g. a
    # KeyboardInterrupt from stopping the cell), so the hyperparameters are
    # still recorded and the event file is flushed and closed.
    try:
        if last_val_loss is not None:
            # log all the hyperparameters and the last completed validation loss
            writer.add_hparams(hparams, {'Final loss': last_val_loss})
    finally:
        writer.close()

Model:
GuidedDiffusionNetwork(
  (time): GuidedDiffusionTime(
    (time_embedding_module): TimeEmbedding()
  )
  (rgc1): GuidedDiffusionRGC(
    (rgc_module): RelationalRGCN(
      (layers): ModuleList(
        (0): RGCNConv(29, 29, num_relations=24)
        (1): Tanh()
      )
    )
  )
  (block1): GuidedDiffusionBlock(
    (max_pool): MaxPool1d(kernel_size=(12,), stride=(12,), padding=0, dilation=1, ceil_mode=False)
    (self_attention_module): SelfMultiheadAttention(
      (qkv_proj): Linear(in_features=29, out_features=90, bias=False)
      (o_proj): Linear(in_features=30, out_features=29, bias=False)
      (layer_norm): LayerNorm((20, 29), eps=1e-05, elementwise_affine=True)
    )
    (cross_attention_module): CrossMultiheadAttention(
      (q_proj): Linear(in_features=29, out_features=90, bias=False)
      (kv_proj): Linear(in_features=300, out_features=180, bias=False)
      (o_proj): Linear(in_features=90, out_features=29, bias=False)
      (layer_norm): LayerNorm((20, 29), eps

  0%|          | 0/2000 [00:00<?, ?it/s]

With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With

  0%|          | 1/2000 [00:01<46:48,  1.40s/it]

With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
Saved best model with val loss 0.55282355149586

  0%|          | 2/2000 [00:02<46:30,  1.40s/it]

With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> False
Saved best model with val loss 0.30880971948305763
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> Fals

  0%|          | 2/2000 [00:03<52:25,  1.57s/it]

With conditional prob 0.7 we drop? -> False
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True
With conditional prob 0.7 we drop? -> True





KeyboardInterrupt: 