In [1]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler

# Putting everything together

In this section we put everything together to have a complete diffusion model.

The network will go as follows:

TODO: explain the final version

<!-- 1. The first part of the network works is the unconditional denoising/diffusion. For that, we represent 3D scenes as a matrix X of dimension [B, N, D] with B batch size, N number of nodes (objects) and D Dimension of each object storing its information (location, size, ...). In this part of the network we feed X through a custom multihead attention layer (see module #1) and get as output data of dimensions [B, N, E] with E hidden Dimension.

2. Parallel to the first part, we also receive a scene graph in form of a triple (N, C, C) with N nodes storing string description of objects ("chair", "table", etc.) and C edges as list of tuples (id_1, id_2) indicating an outgoing connection from node id_1 to node id_2 as well as C connection types as ints. The string description of each node is embedded using FastText (see module #2) and an relational GCN (see module #3) is built with nodes storing these embeddings, and edges as well as edge types extracted from the scene graph following the aforementioned structure. This RGCN block has no hidden layers and outputs dimension F.

3. Outputs from the first and second step [B, N, E] and [B, N, F] are concatenated in the last axis to form [B, N, E + F]. Time embedding (see module #4) of matching dimension is produced and added to that result. This matrix is then fed into another relational GCN that consists of N nodes (each row becomes a node) and reuses the edges and edge types from the first RGCN. This new RGCN has a few hidden layers and results in dimension [B, N, D], finishing the forward pass. -->


In [6]:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, X, graphs):
        super(CustomDataset, self).__init__()
        self.X = X
        self.graphs = graphs

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        x = self.X[index]
        graph = self.graphs[index]
        return {
            'x': x,
            'obj_cond': graph.x,
            'edge_cond': graph.edge_index,
            'relation_cond': graph.edge_attr
        }

    def collate_fn(self, batch):
        x_batch = torch.stack([item['x'] for item in batch], dim=0)
        obj_cond_batch = torch.cat([item['obj_cond'] for item in batch], dim=0)
        edge_cond_batch = torch.cat([item['edge_cond'] for item in batch], dim=1)
        relation_cond_batch = torch.cat([item['relation_cond'] for item in batch], dim=0)

        return {
            'x': x_batch,
            'obj_cond': obj_cond_batch,
            'edge_cond': edge_cond_batch,
            'relation_cond': relation_cond_batch
        }


In [8]:
# Read the data from the JSON file
import pandas as pd

def load_data(name):
    df = pd.read_json(f'datasets/data/{name}.json')
    return df['scenes'].values

df_val = load_data('val')

In [14]:
import json

# Load data from JSON file
with open('datasets/data/val.json', 'r') as file:
    data = json.load(file)['scenes']

data[0]

{'scene_id': '742e8f15-be0a-294e-9ebb-6c72dbcb9662',
 'scene_matrix': [[0.07241161167621613,
   -0.00645842170342803,
   -0.02096150815486908,
   0.11923903226852417,
   0.08361779153347015,
   0.035234011709690094,
   0.08475619554519653,
   -0.0008645294001325965,
   -0.07096676528453827,
   -0.0402718260884285,
   -0.09819656610488892,
   -0.1440766453742981,
   0.01132181566208601,
   -0.06274738907814026,
   0.10340331494808197,
   0.11818590015172958,
   -0.18197298049926758,
   0.08702360093593597,
   -0.05749018117785454,
   -0.07095494121313095,
   0.20678642392158508,
   0.1583821028470993,
   0.19579824805259705,
   0.06903882324695587,
   0.11693169176578522,
   0.11191493272781372,
   -0.007782831788063049,
   0.10218502581119537,
   0.1366354078054428,
   0.29891061782836914,
   0.12411278486251831,
   0.2010473608970642,
   -0.15790300071239471,
   -0.018323445692658424,
   0.08680214732885361,
   0.03868967294692993,
   -0.06567062437534332,
   0.12225886434316635,
   -

In [9]:
df_val

array([{'scene_id': '742e8f15-be0a-294e-9ebb-6c72dbcb9662', 'scene_matrix': [[0.072411611676216, -0.0064584217034280005, -0.020961508154869003, 0.119239032268524, 0.08361779153347, 0.035234011709690004, 0.084756195545196, -0.0008645294001320001, -0.070966765284538, -0.040271826088428005, -0.09819656610488801, -0.144076645374298, 0.011321815662086001, -0.06274738907814001, 0.10340331494808101, 0.11818590015172901, -0.18197298049926702, 0.087023600935935, -0.057490181177854004, -0.07095494121313001, 0.20678642392158503, 0.158382102847099, 0.19579824805259702, 0.06903882324695501, 0.11693169176578501, 0.11191493272781301, -0.007782831788063001, 0.10218502581119501, 0.136635407805442, 0.29891061782836903, 0.124112784862518, 0.20104736089706401, -0.15790300071239402, -0.018323445692658, 0.08680214732885301, 0.038689672946929, -0.065670624375343, 0.122258864343166, -0.18285796046257002, -0.07926345616579, 0.025466959923505002, -0.10095076262950801, -0.004142294637858, 0.021703977137804, -0.1

## Model and Train setup

In [8]:
B = 4 # num of scenes in batch

N = 20 # num of nodes

# RGCN hyperparams
C = 6 # dim of node features
E = 22 # num of edges
R = 8 + 1 # num of edge types (including 'unknown' type)

# Attention hyperparams
D = 10 # dim of objects from the scene

hparams = {
    'batch_size': B, # num of graphs in batch
    
    # --- Attention hyperparams ---
    'attention_out_dim': D,
    'attention_num_heads': 5, # must be a divisor of D
    
    # --- Encoder RGCN hyperparams ---
    'encoder_out_dim': C,
    'encoder_hidden_dims': f"{()}", # (C, C),
    'encoder_num_bases': None,
    'encoder_aggr': 'mean',
    'encoder_activation': 'leakyrelu',
    'encoder_dp_rate': 0.,
    'encoder_bias': True,
    
    # --- Fusion RGCN hyperparams ---
    'fusion_hidden_dims': f"{()}", # (C+D, C+D, D),
    'fusion_num_bases': None,
    'fusion_aggr': 'mean',
    'fusion_activation': 'leakyrelu',
    'fusion_dp_rate': 0.,
    'fusion_bias': True,
    
    # Scheduler hyperparams
    'scheduler_timesteps': 1000,
    'scheduler_loss': 'l1',
    'scheduler_beta_schedule': 'cosine',
    # Note: not needed for now
    # 'scheduler_sampling_timesteps': None,
    # "scheduler_objective": 'pred_noise',
    # 'scheduler_ddim_sampling_eta': 1.0,
    # 'scheduler_min_snr_loss_weight': False,
    # 'scheduler_min_snr_gamma': 5,
    
    # Classifier-free guidance parameters
    'cfg_cond_drop_prob': 0.1,
    
    # Training and optimizer hyperparams
    'epochs': 3000,
    'optimizer_lr': 1e-3,
    'optimizer_weight_decay': 5e-4,
    'lr_scheduler_factor': 0.7,
    'lr_scheduler_patience': 100,
    'lr_scheduler_minlr': 0.00001,
}

In [10]:
from torch.utils.tensorboard import SummaryWriter

if torch.cuda.is_available():
    device = torch.device('cuda')
# Not all operations support MPS yet so this option is not available for now
# elif torch.has_mps:
#     device = torch.device('mps')
else:
    device = torch.device('cpu')

# --- Load the (mocked) data
dataloader = DataLoader(dataset, batch_size=hparams['batch_size'], shuffle=True, collate_fn=dataset.collate_fn)


# --- Instantiate the model
model = GuidedDiffusionNetwork(
    attention_in_dim=D,
    attention_out_dim=hparams['attention_out_dim'],
    attention_num_heads=hparams['attention_num_heads'],
    
    rgcn_num_relations=R,
    
    encoder_in_dim=C,
    encoder_out_dim=hparams['encoder_out_dim'],
    encoder_hidden_dims=hparams['encoder_hidden_dims'],
    encoder_num_bases=hparams['encoder_num_bases'],
    encoder_aggr=hparams['encoder_aggr'],
    encoder_activation=hparams['encoder_activation'],
    encoder_dp_rate=hparams['encoder_dp_rate'],
    encoder_bias=hparams['encoder_bias'],
    
    fusion_hidden_dims=hparams['fusion_hidden_dims'],
    fusion_num_bases=hparams['fusion_num_bases'],
    fusion_aggr=hparams['fusion_aggr'],
    fusion_activation=hparams['fusion_activation'],
    fusion_dp_rate=hparams['fusion_dp_rate'],
    fusion_bias=hparams['fusion_bias'],
    
    cond_drop_prob=hparams['cfg_cond_drop_prob']
)

print(f"Model:\n{model}")

scheduler = DDPMScheduler(
    model=model,
    N=N,
    D=D,
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5
)

print(f"DDPM Scheduler:\n{scheduler}")

# Move to device
model = model.to(device)
scheduler = scheduler.to(device)


# --- Setup training loop ---
from tqdm import tqdm

optimizer = torch.optim.Adam(
    scheduler.parameters(), 
    lr=hparams['optimizer_lr'], 
    weight_decay=hparams['optimizer_weight_decay']
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=hparams['lr_scheduler_factor'], 
    patience=hparams['lr_scheduler_patience'], 
    min_lr=hparams['lr_scheduler_minlr']
)


# --- Initialize tensorboard ---
# use timestamp to avoid overwriting previous runs
from datetime import datetime
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/full-DDPM/overfit-time:{now.strftime("%Y-%m-%d-%H:%M:%S")}')

# --- Training loop ---
for epoch in tqdm(range(hparams['epochs'])):
    epoch_loss = 0
    for batch in dataloader:
        x_batch = batch['x'].to(device)
        obj_cond_batch = batch['obj_cond'].to(device)
        edge_cond_batch = batch['edge_cond'].to(device)
        relation_cond_batch = batch['relation_cond'].to(device)
        
        loss = scheduler(x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch)
        
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
    epoch_loss /= len(dataloader)
        
    lr_scheduler.step(epoch_loss)
    writer.add_scalar('Loss/train', epoch_loss, epoch)
    writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

# log all the hyperparameters and final loss
writer.add_hparams(hparams, {'Final loss': epoch_loss})

writer.close()

Model:
GuidedDiffusionNetwork(
  (attention_module): ModifiedMultiheadAttention(
    (qkv_proj): Linear(in_features=10, out_features=30, bias=False)
    (o_proj): Linear(in_features=10, out_features=10, bias=False)
  )
  (encoder_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(6, 6, num_relations=9)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (time_embedding_module): TimeEmbedding(
    (layers): Sequential(
      (0): Linear(in_features=16, out_features=16, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=16, out_features=16, bias=True)
    )
  )
  (fused_rgcn_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(16, 10, num_relations=9)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
)
DDPM Scheduler:
DDPMScheduler(
  (model): GuidedDiffusionNetwork(
    (attention_module): ModifiedMultiheadAttention(
      (qkv_proj): Linear(in_features=10, out_features=30, bias=False)
      (o

100%|██████████| 3000/3000 [00:09<00:00, 327.50it/s]
