In [1]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model_alternative import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler

In [2]:
from torch_geometric.data import Data
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset over a list of scene dicts (as loaded from JSON).

    Each scene must provide the keys ``scene_matrix``, ``graph_objects``,
    ``graph_edges`` and ``graph_relationships``. They are converted to tensors
    in ``__getitem__`` and batched by ``collate_fn``, which is meant to be
    passed to ``DataLoader(collate_fn=...)``.
    """

    def __init__(self, scenes):
        super(CustomDataset, self).__init__()
        self.scenes = scenes

    def __len__(self):
        return len(self.scenes)

    def __getitem__(self, index):
        """Return one scene as a dict of tensors.

        Returns:
            dict with keys
              'x'             -- float32 tensor built from ``scene_matrix``
              'obj_cond'      -- float32 tensor built from ``graph_objects``
              'edge_cond'     -- int64 tensor built from ``graph_edges``
              'relation_cond' -- int64 tensor built from ``graph_relationships``
        """
        scene = self.scenes[index]

        scene_matrix = torch.tensor(scene["scene_matrix"], dtype=torch.float32)
        graph_objects = torch.tensor(scene["graph_objects"], dtype=torch.float32)
        graph_edges = torch.tensor(scene["graph_edges"], dtype=torch.long)
        graph_relationships = torch.tensor(scene["graph_relationships"], dtype=torch.long)

        return {
            'x': scene_matrix,
            'obj_cond': graph_objects,
            'edge_cond': graph_edges,
            'relation_cond': graph_relationships
        }

    def collate_fn(self, batch):
        """Collate a list of ``__getitem__`` outputs into a single batch.

        ``x`` and ``obj_cond`` are stacked along a new leading batch dimension,
        so every scene in the batch must have identical shapes for them.
        ``edge_cond`` is concatenated along dim 1 (assumes a (2, E_i) layout
        per scene -- TODO confirm) and ``relation_cond`` along dim 0.

        The node indices in each scene's ``edge_cond`` are local to that scene
        (0 .. num_nodes-1). Before concatenating, they are shifted by the total
        number of nodes in all preceding scenes, the same way PyG mini-batching
        increments ``edge_index``. After the shift, edges of scene i address
        scene i's rows once node features are flattened to
        (B * num_nodes, C). Without it, every edge would address the first
        scene's nodes.
        NOTE(review): assumes the model flattens ``obj_cond`` in batch order
        before message passing -- confirm against model_alternative.
        """
        x_batch = torch.stack([item['x'] for item in batch], dim=0)
        obj_cond_batch = torch.stack([item['obj_cond'] for item in batch], dim=0)

        shifted_edges = []
        node_offset = 0
        for item in batch:
            shifted_edges.append(item['edge_cond'] + node_offset)
            # Number of nodes in this scene = rows of its node-feature matrix
            node_offset += item['obj_cond'].shape[0]
        edge_cond_batch = torch.cat(shifted_edges, dim=1)

        relation_cond_batch = torch.cat([item['relation_cond'] for item in batch], dim=0)

        return {
            'x': x_batch,
            'obj_cond': obj_cond_batch,
            'edge_cond': edge_cond_batch,
            'relation_cond': relation_cond_batch
        }



In [3]:
import json

# Load the list of scenes from the JSON dataset; each entry of 'scenes' is
# later wrapped by CustomDataset.
# The file is opened as UTF-8 explicitly (JSON is UTF-8 by spec) rather than
# relying on the platform's locale encoding, which differs e.g. on Windows.
with open('test_small.json', 'r', encoding='utf-8') as file:
    train_data = json.load(file)['scenes']

In [4]:
# Per-feature value bounds of the real data, passed to DDPMScheduler as
# `range_matrix` (presumably for normalising scene features -- confirm in
# ddpm_scheduler). Feature layout, 3 + 9 + 3 = 15 values:
#   location (3) | normalized axes (9, bounded in [-1, 1]) | size (3)
location_max = torch.tensor([3.285, 3.93, 0.879])
location_min = torch.tensor([-3.334, -2.619, -1.329])

normalized_axes_max = torch.ones(9)
normalized_axes_min = torch.full((9,), -1.0)

size_max = torch.tensor([4.878, 2.655, 2.305])
size_min = torch.tensor([0.232, 0.14, 0.094])

range_max = torch.cat([location_max, normalized_axes_max, size_max])
range_min = torch.cat([location_min, normalized_axes_min, size_min])

# Row 0 = maxima, row 1 = minima -> shape (2, 15)
range_matrix = torch.stack([range_max, range_min])

In [5]:
B = 3  # number of scenes per batch

# --- Scene hyperparams ---
N = 20  # number of objects in a scene
D = 15  # feature dimension of each scene object

# --- Condition hyperparams ---
C = 300     # node-feature dimension of the conditioning graph
R = 23 + 1  # number of relation types

hparams = dict(
    batch_size=B,  # number of graphs per batch
    layer_2_dim=10,

    # --- RGCN hyperparams ---
    # String form of an empty tuple (no hidden layers).
    # Alternative that was tried: (C+D, C+D, D)
    rgc_hidden_dims="()",
    rgc_num_bases=5,  # alternative: None
    rgc_aggr='mean',
    rgc_activation='tanh',
    rgc_dp_rate=0.0,
    rgc_bias=True,

    # --- Attention hyperparams ---
    # Per-head widths are set explicitly by the *_head_dims entries, so the
    # head count is independent of them.
    attention_self_head_dims=10,
    attention_num_heads=3,
    attention_cross_head_dims=30,

    # --- Scheduler hyperparams ---
    scheduler_timesteps=1000,
    scheduler_loss='l2',
    scheduler_beta_schedule='cosine',
    # Currently passed to DDPMScheduler as fixed values instead:
    #   sampling_timesteps=None, objective='pred_noise', ddim_sampling_eta=1.0,
    #   min_snr_loss_weight=False, min_snr_gamma=5

    # --- Classifier-free guidance ---
    cfg_cond_drop_prob=0.1,

    # --- Training and optimizer hyperparams ---
    epochs=2000,
    optimizer_lr=1e-3,
    optimizer_weight_decay=5e-5,
    lr_scheduler_factor=0.4,
    lr_scheduler_patience=30,
    lr_scheduler_minlr=2e-4,
)


In [6]:
# Keyword groups handed to GuidedDiffusionNetwork in the model cell.
general_params = dict(
    num_obj=N,       # objects per scene
    obj_cond_dim=C,  # node-feature dimension of the conditioning graph
)

# Attention settings. The hparams keys for the per-head widths end in "s";
# the model-facing keys do not.
attention_params = dict(
    attention_self_head_dim=hparams['attention_self_head_dims'],
    attention_num_heads=hparams['attention_num_heads'],
    attention_cross_head_dim=hparams['attention_cross_head_dims'],
)

# RGCN settings: everything comes from hparams except the relation count R.
rgc_params = dict(
    rgc_hidden_dims=hparams['rgc_hidden_dims'],
    rgc_num_relations=R,
    rgc_num_bases=hparams['rgc_num_bases'],
    rgc_aggr=hparams['rgc_aggr'],
    rgc_activation=hparams['rgc_activation'],
    rgc_dp_rate=hparams['rgc_dp_rate'],
    rgc_bias=hparams['rgc_bias'],
)

In [7]:
from torch.utils.data import DataLoader

# Use CUDA when present, otherwise CPU. MPS is intentionally not considered
# because not all operations support it yet (if re-enabled, prefer
# torch.backends.mps.is_available() over the deprecated torch.has_mps).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

range_matrix = range_matrix.to(device)

train_dataset = CustomDataset(train_data)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=hparams['batch_size'],
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
)

# --- Denoising network ---
model = GuidedDiffusionNetwork(
    layer_1_dim=D,
    layer_2_dim=hparams['layer_2_dim'],
    general_params=general_params,
    attention_params=attention_params,
    rgc_params=rgc_params,
    cond_drop_prob=hparams['cfg_cond_drop_prob'],
)
print(f"Model:\n{model}")

# --- DDPM scheduler wrapping the network ---
# Options not exposed through hparams are passed as fixed literals.
scheduler = DDPMScheduler(
    model=model,
    N=N,
    D=D,
    range_matrix=range_matrix,
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5,
)
print(f"DDPM Scheduler:\n{scheduler}")

# Move both modules to the selected device.
model = model.to(device)
scheduler = scheduler.to(device)

# --- Loss pass over the training data ---
# NOTE(review): this loop only computes and prints the scheduler's loss per
# batch. No backward pass or optimizer step happens here, despite the
# epoch/optimizer entries in hparams.
for batch in train_dataloader:
    x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch = (
        batch[key].to(device)
        for key in ('x', 'obj_cond', 'edge_cond', 'relation_cond')
    )

    loss = scheduler(x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch)
    print(loss)

Model:
GuidedDiffusionNetwork(
  (block1): GuidedDiffusionBlock(
    (time_embedding_module): TimeEmbedding(
      (layers): Sequential(
        (0): Linear(in_features=14, out_features=15, bias=True)
        (1): Tanh()
        (2): Linear(in_features=15, out_features=15, bias=True)
      )
    )
    (max_pool): MaxPool1d(kernel_size=(20,), stride=(20,), padding=0, dilation=1, ceil_mode=False)
    (rgc_module): RelationalRGCN(
      (layers): ModuleList(
        (0): RGCNConv(15, 15, num_relations=24)
        (1): Tanh()
      )
    )
    (self_attention_module): SelfMultiheadAttention(
      (qkv_proj): Linear(in_features=15, out_features=90, bias=False)
      (o_proj): Linear(in_features=30, out_features=15, bias=False)
      (layer_norm): LayerNorm((20, 15), eps=1e-05, elementwise_affine=True)
    )
    (cross_attention_module): CrossMultiheadAttention(
      (q_proj): Linear(in_features=15, out_features=90, bias=False)
      (kv_proj): Linear(in_features=300, out_features=180, bia