In [1]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler

# Putting everything together

In this section we put everything together to have a complete diffusion model.

The network will go as follows:

TODO: explain the final version

<!-- 1. The first part of the network is the unconditional denoising/diffusion. For that, we represent 3D scenes as a matrix X of dimension [B, N, D], with B the batch size, N the number of nodes (objects) and D the dimension of each object storing its information (location, size, ...). In this part of the network we feed X through a custom multi-head attention layer (see module #1) and get as output data of dimension [B, N, E], with E the hidden dimension.

2. Parallel to the first part, we also receive a scene graph in the form of a triple (N, C, C): N nodes storing string descriptions of objects ("chair", "table", etc.), C edges given as a list of tuples (id_1, id_2) indicating an outgoing connection from node id_1 to node id_2, and C connection types given as ints. The string description of each node is embedded using FastText (see module #2), and a relational GCN (see module #3) is built with nodes storing these embeddings and with edges and edge types extracted from the scene graph following the aforementioned structure. This RGCN block has no hidden layers and outputs dimension F.

3. Outputs from the first and second step [B, N, E] and [B, N, F] are concatenated in the last axis to form [B, N, E + F]. Time embedding (see module #4) of matching dimension is produced and added to that result. This matrix is then fed into another relational GCN that consists of N nodes (each row becomes a node) and reuses the edges and edge types from the first RGCN. This new RGCN has a few hidden layers and results in dimension [B, N, D], finishing the forward pass. -->


## Mock Data

In this section, we will generate an example scene graph and a corresponding matrix X to feed into the network.

In [2]:
# --- Batch layout ---
B = 4   # graphs (scenes) per batch
N = 20  # nodes (objects) per graph

# --- Scene-graph / RGCN sizes ---
C = 6      # width of each node feature vector
E = 22     # edges per graph
R = 1 + 8  # edge types: 8 relation types plus one 'unknown' type

# --- Attention sizes ---
D = 10  # width of each object vector in the scene matrix

In [3]:
# Build the demo network. Keyword arguments are grouped by the sub-module
# they configure and then splatted into a single constructor call.
attention_kwargs = dict(
    attention_in_dim=D,
    attention_out_dim=D,
    attention_num_heads=D // 2,
)

# Hidden dims are handed over as the *string* form of a tuple
# (presumably parsed inside GuidedDiffusionNetwork -- TODO confirm).
encoder_kwargs = dict(
    encoder_in_dim=C,
    encoder_out_dim=C,
    encoder_hidden_dims=f"{()}",
    encoder_num_bases=None,
    encoder_aggr='mean',
    encoder_activation="leakyrelu",
    encoder_dp_rate=0.1,
    encoder_bias=True,
)

fusion_kwargs = dict(
    fusion_hidden_dims=f"{(C+D, C+D)}",
    fusion_num_bases=None,
    fusion_aggr='mean',
    fusion_activation="leakyrelu",
    fusion_dp_rate=0.1,
    fusion_bias=True,
)

model = GuidedDiffusionNetwork(
    **attention_kwargs,
    rgcn_num_relations=R,
    **encoder_kwargs,
    **fusion_kwargs,
    cond_drop_prob=0.2,
)

print(model)

GuidedDiffusionNetwork(
  (attention_module): ModifiedMultiheadAttention(
    (qkv_proj): Linear(in_features=10, out_features=30, bias=False)
    (o_proj): Linear(in_features=10, out_features=10, bias=False)
  )
  (encoder_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(6, 6, num_relations=9)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (time_embedding_module): TimeEmbedding(
    (layers): Sequential(
      (0): Linear(in_features=16, out_features=16, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=16, out_features=16, bias=True)
    )
  )
  (fused_rgcn_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(16, 16, num_relations=9)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Dropout(p=0.1, inplace=False)
      (3): RGCNConv(16, 16, num_relations=9)
      (4): LeakyReLU(negative_slope=0.2, inplace=True)
      (5): Dropout(p=0.1, inplace=False)
      (6): RGCNConv(16, 10, num_r

In [4]:
from torch_geometric.data import Data

# Scene Graphs for conditioning
def generate_random_graph(is_one_hot=None):
    """Create one random scene graph to use as mock conditioning input.

    Sizes come from the notebook-level constants: N (nodes), C (node feature
    width), E (edges) and R (number of edge types including 'unknown').

    Args:
        is_one_hot: If truthy, every node gets a one-hot feature vector whose
            active index is drawn uniformly at random. If falsy (None or
            False), node features are drawn from a standard normal instead.

    Returns:
        A torch_geometric ``Data`` object with
        ``x`` of shape [N, C], ``edge_index`` of shape [2, E] and
        ``edge_attr`` of shape [E] holding edge types in [1, R - 1].
    """
    # --- Initialize nodes ---
    # Truthiness check (not `is not None`) so that is_one_hot=False really
    # selects the random-normal branch.
    if is_one_hot:
        nodes = torch.zeros(N, C)  # N x C, filled in below
        # One random active column per row, set in a single indexed assignment
        nodes[torch.arange(N), torch.randint(C, (N,))] = 1
    else:
        nodes = torch.randn(N, C)  # N x C tensor of random node features

    # --- Initialize edges ---
    # 2 x E tensor of random (source, target) node ids in [0, N)
    edges = torch.randint(N, (2, E))

    # --- Introduce different types of edges ---
    # E random edge types in [1, R - 1]; value 0 is reserved for 'unknown'
    rels = torch.randint(R - 1, (E,)) + 1

    # --- Create a graph ---
    return Data(x=nodes, edge_index=edges, edge_attr=rels)

# --- Build the mock batch: B independent one-hot scene graphs ---
graphs = []
for _graph_idx in range(B):
    graphs.append(generate_random_graph(is_one_hot=True))

print(f"Batch size: {len(graphs)}, dimensions of each graph: {graphs[0]}")

# print(graphs[0].edge_attr)

Batch size: 4, dimensions of each graph: Data(x=[20, 6], edge_index=[2, 22], edge_attr=[22])


In [5]:
# Mock scenes to denoise: scene i is an N x D matrix filled with the constant
# i / B, so each scene in the batch has its own distinct value.
# Random alternative: X = torch.randn(B, N, D)  # B x N x D random features
# TODO: constant scenes are used for testing only for now
Xs = [torch.full((N, D), i / float(B)) for i in range(B)]
X = torch.stack(Xs)  # stacks along a new leading (batch) axis -> [B, N, D]
print(f"Dimensions: X={X.shape}")

Dimensions: X=torch.Size([4, 20, 10])


In [6]:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Pairs each scene matrix with the scene graph that conditions it.

    Args:
        X: Indexable collection of scenes; ``X[i]`` is the scene for item i.
        graphs: Indexable collection of graph objects exposing ``x``,
            ``edge_index`` and ``edge_attr`` attributes; ``graphs[i]`` is the
            conditioning graph of item i.
    """

    def __init__(self, X, graphs):
        super().__init__()
        self.X = X
        self.graphs = graphs

    def __len__(self):
        """Number of items, taken from the length of ``X``."""
        return len(self.X)

    def __getitem__(self, index):
        """Return scene ``index`` and the three parts of its graph as a dict."""
        x = self.X[index]
        graph = self.graphs[index]
        return {
            'x': x,
            'obj_cond': graph.x,
            'edge_cond': graph.edge_index,
            'relation_cond': graph.edge_attr
        }

    def collate_fn(self, batch):
        """Merge a list of items into one batch.

        Scenes are stacked along a new leading axis. The graphs are merged
        into one big disconnected graph: node features and edge types are
        concatenated, and each graph's edge indices are shifted by the number
        of nodes that precede it so they keep pointing at that graph's own
        rows in the concatenated node matrix.

        Args:
            batch: List of dicts as produced by ``__getitem__``.

        Returns:
            Dict with the same keys as the items, holding the batched tensors.
        """
        x_batch = torch.stack([item['x'] for item in batch], dim=0)
        obj_cond_batch = torch.cat([item['obj_cond'] for item in batch], dim=0)

        # Without the offset every graph's edges would index rows 0..n-1,
        # i.e. the nodes of the first graph only.
        shifted_edges = []
        node_offset = 0
        for item in batch:
            # `+` builds a new tensor, so the stored graph is not modified
            shifted_edges.append(item['edge_cond'] + node_offset)
            node_offset += item['obj_cond'].size(0)
        edge_cond_batch = torch.cat(shifted_edges, dim=1)

        relation_cond_batch = torch.cat([item['relation_cond'] for item in batch], dim=0)

        return {
            'x': x_batch,
            'obj_cond': obj_cond_batch,
            'edge_cond': edge_cond_batch,
            'relation_cond': relation_cond_batch
        }


In [7]:
from torch.utils.data import DataLoader

# Example usage: wrap the mock scenes and graphs, then push one batch through the model
dataset = CustomDataset(X, graphs)
dataloader = DataLoader(
    dataset,
    batch_size=B,
    shuffle=True,
    collate_fn=dataset.collate_fn,
)

# Random time indices in [0, 1000), one per scene -> tensor of shape (B,)
# (these should eventually come from the DDPM scheduler)
t = torch.randint(1000, (B,))

for batch in dataloader:
    # Pull the four batched tensors out of the collated dict
    x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch = (
        batch[key] for key in ('x', 'obj_cond', 'edge_cond', 'relation_cond')
    )
    print(f"Input Dimensions:\n\tX={x_batch.shape}\n\tt={t.shape}\n\tobj_cond={obj_cond_batch.shape}\n\tedge_cond={edge_cond_batch.shape}\n\trelation_cond={relation_cond_batch.shape}")

    # First forward pass: drop probability 0 (presumably conditioning is always kept)
    print(f"Pass with cond_drop_prob=0.0")
    output = model(x_batch, t, obj_cond_batch, edge_cond_batch, relation_cond_batch, cond_drop_prob=0.)

    # Second forward pass: drop probability 1 (presumably conditioning is always dropped);
    # this overwrites `output` from the first pass
    print(f"Pass with cond_drop_prob=1.0")
    output = model(x_batch, t, obj_cond_batch, edge_cond_batch, relation_cond_batch, cond_drop_prob=1.)

    print(f"Output Dimensions:\n\toutput={output.shape}")

Input Dimensions:
	X=torch.Size([4, 20, 10])
	t=torch.Size([4])
	obj_cond=torch.Size([80, 6])
	edge_cond=torch.Size([2, 88])
	relation_cond=torch.Size([88])
Pass with cond_drop_prob=0.0
Pass with cond_drop_prob=1.0
Output Dimensions:
	output=torch.Size([4, 20, 10])


## Connecting the GuidedDiffusionNetwork to a DDPM Scheduler

In [8]:
# Flat hyperparameter dict for the DDPM overfitting run, filled in group by
# group. The same dict is later passed to the TensorBoard writer's
# `add_hparams`; the hidden-dim tuples are stored in string form.
hparams = {
    'batch_size': B,  # num of graphs in batch
}

# --- Attention hyperparams ---
hparams.update({
    'attention_out_dim': D,
    'attention_num_heads': 5,  # must be a divisor of D
})

# --- Encoder RGCN hyperparams ---
hparams.update({
    'encoder_out_dim': C,
    'encoder_hidden_dims': f"{()}",  # alternative: (C, C)
    'encoder_num_bases': None,
    'encoder_aggr': 'mean',
    'encoder_activation': 'leakyrelu',
    'encoder_dp_rate': 0.0,
    'encoder_bias': True,
})

# --- Fusion RGCN hyperparams ---
hparams.update({
    'fusion_hidden_dims': f"{()}",  # alternative: (C+D, C+D, D)
    'fusion_num_bases': None,
    'fusion_aggr': 'mean',
    'fusion_activation': 'leakyrelu',
    'fusion_dp_rate': 0.0,
    'fusion_bias': True,
})

# --- Scheduler hyperparams ---
hparams.update({
    'scheduler_timesteps': 1000,
    'scheduler_loss': 'l1',
    'scheduler_beta_schedule': 'cosine',
    # Note: not needed for now
    # 'scheduler_sampling_timesteps': None,
    # "scheduler_objective": 'pred_noise',
    # 'scheduler_ddim_sampling_eta': 1.0,
    # 'scheduler_min_snr_loss_weight': False,
    # 'scheduler_min_snr_gamma': 5,
})

# --- Classifier-free guidance parameters ---
hparams.update({
    'cfg_cond_drop_prob': 0.1,
})

# --- Training and optimizer hyperparams ---
hparams.update({
    'epochs': 3000,
    'optimizer_lr': 1e-3,
    'optimizer_weight_decay': 5e-4,
    'lr_scheduler_factor': 0.7,
    'lr_scheduler_patience': 100,
    'lr_scheduler_minlr': 1e-5,
})

In [9]:
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# --- Pick the compute device ---
if torch.cuda.is_available():
    device = torch.device('cuda')
# Not all operations support MPS yet so this option is not available for now
# elif torch.has_mps:
#     device = torch.device('mps')
else:
    device = torch.device('cpu')

# --- Load the (mocked) data
dataloader = DataLoader(dataset, batch_size=hparams['batch_size'], shuffle=True, collate_fn=dataset.collate_fn)


# --- Instantiate the model
model = GuidedDiffusionNetwork(
    attention_in_dim=D,
    attention_out_dim=hparams['attention_out_dim'],
    attention_num_heads=hparams['attention_num_heads'],
    
    rgcn_num_relations=R,
    
    encoder_in_dim=C,
    encoder_out_dim=hparams['encoder_out_dim'],
    encoder_hidden_dims=hparams['encoder_hidden_dims'],
    encoder_num_bases=hparams['encoder_num_bases'],
    encoder_aggr=hparams['encoder_aggr'],
    encoder_activation=hparams['encoder_activation'],
    encoder_dp_rate=hparams['encoder_dp_rate'],
    encoder_bias=hparams['encoder_bias'],
    
    fusion_hidden_dims=hparams['fusion_hidden_dims'],
    fusion_num_bases=hparams['fusion_num_bases'],
    fusion_aggr=hparams['fusion_aggr'],
    fusion_activation=hparams['fusion_activation'],
    fusion_dp_rate=hparams['fusion_dp_rate'],
    fusion_bias=hparams['fusion_bias'],
    
    cond_drop_prob=hparams['cfg_cond_drop_prob']
)

print(f"Model:\n{model}")

# The scheduler wraps the model; calling it on a batch returns the training loss
scheduler = DDPMScheduler(
    model=model,
    N=N,
    D=D,
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5
)

print(f"DDPM Scheduler:\n{scheduler}")

# Move to device
model = model.to(device)
scheduler = scheduler.to(device)


# --- Setup training loop ---
optimizer = torch.optim.Adam(
    scheduler.parameters(), 
    lr=hparams['optimizer_lr'], 
    weight_decay=hparams['optimizer_weight_decay']
)
# Shrinks the learning rate when the epoch loss stops improving
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=hparams['lr_scheduler_factor'], 
    patience=hparams['lr_scheduler_patience'], 
    min_lr=hparams['lr_scheduler_minlr']
)


# --- Initialize tensorboard ---
# Use a timestamp to avoid overwriting previous runs. Only hyphens are used in
# the directory name because ':' is not a valid path character on Windows.
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/full-DDPM/overfit-time-{now.strftime("%Y-%m-%d-%H-%M-%S")}')

# --- Training loop ---
# try/finally makes sure the writer is closed even if training is
# interrupted (e.g. by KeyboardInterrupt).
try:
    epoch_loss = float('nan')  # stays NaN only when hparams['epochs'] == 0
    for epoch in tqdm(range(hparams['epochs'])):
        epoch_loss = 0.0
        for batch in dataloader:
            x_batch = batch['x'].to(device)
            obj_cond_batch = batch['obj_cond'].to(device)
            edge_cond_batch = batch['edge_cond'].to(device)
            relation_cond_batch = batch['relation_cond'].to(device)

            loss = scheduler(x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Mean loss over the batches of this epoch
        epoch_loss /= len(dataloader)

        lr_scheduler.step(epoch_loss)
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

    # log all the hyperparameters and final loss
    writer.add_hparams(hparams, {'Final loss': epoch_loss})
finally:
    writer.close()

Model:
GuidedDiffusionNetwork(
  (attention_module): ModifiedMultiheadAttention(
    (qkv_proj): Linear(in_features=10, out_features=30, bias=False)
    (o_proj): Linear(in_features=10, out_features=10, bias=False)
  )
  (encoder_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(6, 6, num_relations=9)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (time_embedding_module): TimeEmbedding(
    (layers): Sequential(
      (0): Linear(in_features=16, out_features=16, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=16, out_features=16, bias=True)
    )
  )
  (fused_rgcn_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(16, 10, num_relations=9)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
)
DDPM Scheduler:
DDPMScheduler(
  (model): GuidedDiffusionNetwork(
    (attention_module): ModifiedMultiheadAttention(
      (qkv_proj): Linear(in_features=10, out_features=30, bias=False)
      (o

100%|██████████| 3000/3000 [00:08<00:00, 345.57it/s]


In [10]:
# Sample from the model to see if it works with the same conditioning that was used for overfitting
# --- Load the (mocked) data, tiled COPY times so several samples are drawn per conditioning graph
# NOTE: this rebinds `dataset` and `dataloader` to the tiled versions.
COPY = 50
dataset = CustomDataset(X.repeat(COPY, 1, 1), graphs * COPY)
dataloader = DataLoader(dataset, batch_size=B * COPY, shuffle=True, collate_fn=dataset.collate_fn)

for batch in dataloader:
    # The scheduler was moved to `device`, so its inputs and the reference
    # scenes it is compared against must be moved there as well.
    x_batch = batch['x'].to(device)
    obj_cond_batch = batch['obj_cond'].to(device)
    edge_cond_batch = batch['edge_cond'].to(device)
    relation_cond_batch = batch['relation_cond'].to(device)
    
    # Sample from the model (use the same conditioning as the overfitting).
    # No gradients are needed for sampling.
    with torch.no_grad():
        output = scheduler.sample(obj_cond_batch, edge_cond_batch, relation_cond_batch, cond_scale=8.0)
    
    print(f"Original scene:\n{x_batch}")
    print(f"Sampled scene:\n{output}")
    
    # Measure MSE between the original and sampled scenes
    mse = torch.nn.functional.mse_loss(x_batch, output)
    print(f"MSE: {mse}")

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Overfitting to a batch of mocked scenes
In the following, we have a self-contained example of the network overfitting to a batch of mocked scenes. We will generate a batch of B scenes, each with N objects. The network will be trained to reconstruct the same batch of scenes.

In [None]:
# # Try overfitting the model on a single batch for a reconstruction task
# from torch.utils.tensorboard import SummaryWriter
# from torch.utils.data import DataLoader

# # --- Reset the model
# model = GuidedDiffusionNetwork(
#     attention_in_dim=hparams['D'],
#     attention_out_dim=hparams['D'],
#     attention_num_heads=architecture_hparams['attention_num_heads'],
#     encoder_in_dim=hparams['C'],
#     encoder_out_dim=hparams['C'],
#     encoder_num_relations=hparams['R'],
#     encoder_num_bases=hparams['encoder_num_bases'],
#     encoder_hidden_dim_list=architecture_hparams['encoder_hidden_dim_list'],
#     encoder_aggr=hparams['encoder_aggr'],
#     encoder_activation=architecture_hparams['encoder_activation'],
#     encoder_dp_rate=hparams['encoder_dp_rate'],
#     encoder_bias=hparams['encoder_bias'],
#     fusion_hidden_dim_list=architecture_hparams['fusion_hidden_dim_list'],
#     cond_drop_prob=hparams['cfg_cond_drop_prob']
# )

# # Reset the dataloader
# dataloader = DataLoader(dataset, batch_size=B, shuffle=True, collate_fn=dataset.collate_fn)

# # --- Setup training loop ---
# from tqdm import tqdm

# epochs = 1000

# lr = 1e-3
# weight_decay = 5e-4 # weight decay
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=25, min_lr=0.00001)
# # loss for reconstruction (we use X as the ground truth)
# recon_loss_fn = torch.nn.MSELoss()

# # --- Initialize tensorboard ---
# # use timestamp to avoid overwriting previous runs
# from datetime import datetime
# now = datetime.now()
# writer = SummaryWriter(log_dir=f'runs/full-model/overfit-B-{B}-lr-{lr}-time-{now.strftime("%Y-%m-%d-%H-%M-%S")}')

# # --- Training loop ---
# for epoch in tqdm(range(epochs)):
#     epoch_loss = 0
#     for batch in dataloader:
#         x_batch = batch['x']
#         obj_cond_batch = batch['obj_cond']
#         edge_cond_batch = batch['edge_cond']
#         relation_cond_batch = batch['relation_cond']
#         # Forward pass through the model
#         output = model(x_batch, t, obj_cond_batch, edge_cond_batch, relation_cond_batch)
#         # Compute the loss
#         loss = recon_loss_fn(output, x_batch)
#         # Backprop
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         epoch_loss += loss.item()

#     epoch_loss /= len(dataloader)
#     # Log the loss
#     writer.add_scalar('Loss/train', epoch_loss, epoch)
    
#     # Update the learning rate
#     scheduler.step(epoch_loss)

 29%|██▉       | 292/1000 [00:41<01:41,  6.96it/s]


KeyboardInterrupt: 