In [1]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from attention_layer import ModifiedMultiheadAttention
from relational_gcn import RelationalRGCN
from time_embedding import TimeEmbedding
from ddpm_scheduler import DDPMScheduler

# Putting everything together

In this section we put everything together to have a complete diffusion model.

The network will go as follows:

1. The first part of the network performs the unconditional denoising/diffusion. For that, we represent 3D scenes as a matrix X of dimension [B, N, D], with B the batch size, N the number of nodes (objects), and D the dimension of each object's feature vector (location, size, ...). In this part of the network we feed X through a custom multi-head attention layer (see module #1) and obtain an output of dimension [B, N, E], with E the hidden dimension.

2. Parallel to the first part, we also receive a scene graph in the form of a triple (N, C, C): N nodes storing string descriptions of objects ("chair", "table", etc.), C edges given as a list of tuples (id_1, id_2), each indicating an outgoing connection from node id_1 to node id_2, and C connection types given as ints. The string description of each node is embedded using FastText (see module #2), and a relational GCN (see module #3) is built with nodes storing these embeddings and with edges and edge types extracted from the scene graph following the aforementioned structure. This RGCN block has no hidden layers and outputs dimension F.

3. The outputs of the first and second steps, [B, N, E] and [B, N, F], are concatenated along the last axis to form [B, N, E + F]. A time embedding (see module #4) of matching dimension is produced and added to that result. This matrix is then fed into another relational GCN that consists of N nodes (each row becomes a node) and reuses the edges and edge types from the first RGCN. This new RGCN has a few hidden layers and produces an output of dimension [B, N, D], finishing the forward pass.


In [2]:
class GuidedDiffusionNetwork(nn.Module):
    """Denoising network that fuses attention features with scene-graph features.

    The forward pass runs the noisy scene through an attention block, encodes
    the conditioning scene graph with a relational GCN, concatenates both
    results along the feature axis, adds a time embedding, and decodes the sum
    with a second relational GCN that reuses the graph's edges and edge types.

    Args:
        attention_in_dim: Feature width of each object in the input scene. The
            fusion RGCN projects back to this width.
        attention_out_dim: Embedding width of the attention block. The number
            of heads is derived from it as ``attention_out_dim // 4``.
        encoder_in_dim: Feature width of the scene-graph node embeddings.
        encoder_out_dim: Output width of the encoder RGCN.
        encoder_num_relations: Number of edge types; used by both RGCNs.
        encoder_hidden_dim_list: Hidden widths of the encoder RGCN. ``None``
            (the default) means no hidden layers.
        encoder_num_bases: Forwarded to both RGCNs as ``num_bases``.
        encoder_aggr: Forwarded to both RGCNs as ``aggr``.
        encoder_activation: Activation module forwarded to both RGCNs. ``None``
            (the default) creates a fresh ``LeakyReLU(0.2, inplace=True)`` for
            this model instead of sharing one instance between all models.
        encoder_dp_rate: Forwarded to both RGCNs as ``dp_rate``.
        encoder_bias: Forwarded to both RGCNs as ``bias``.
        fusion_hidden_dim_list: Hidden widths of the fusion RGCN. ``None``
            (the default) means no hidden layers.
    """

    def __init__(
        self,
        # Attention block
        attention_in_dim,
        attention_out_dim,
        # Encoder RGCN block
        encoder_in_dim,
        encoder_out_dim,
        encoder_num_relations,
        encoder_hidden_dim_list=None,
        encoder_num_bases=None,
        encoder_aggr='mean',
        encoder_activation=None, # TODO: hyperparam vs. hardcode?
        encoder_dp_rate=0.1,
        encoder_bias=True,
        # Fusion block
        fusion_hidden_dim_list=None,
    ):
        super().__init__()

        # Resolve the defaults per instance: a list or module written directly
        # in the signature would be created once and shared by every model.
        if encoder_hidden_dim_list is None:
            encoder_hidden_dim_list = []
        if fusion_hidden_dim_list is None:
            fusion_hidden_dim_list = []
        if encoder_activation is None:
            encoder_activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        fused_dim = attention_out_dim + encoder_out_dim

        self.attention_module = ModifiedMultiheadAttention(
            input_dim=attention_in_dim,
            embed_dim=attention_out_dim,
            num_heads=attention_out_dim // 4 # TODO: hyperparam vs. hardcode?
        )

        self.encoder_module = RelationalRGCN(
            in_channels=encoder_in_dim,
            h_channels_list=encoder_hidden_dim_list,
            out_channels=encoder_out_dim,
            num_relations=encoder_num_relations,
            num_bases=encoder_num_bases,
            aggr=encoder_aggr,
            activation=encoder_activation,
            dp_rate=encoder_dp_rate,
            bias=encoder_bias
        )

        # The time embedding is added to the concatenated features, so it has
        # to match their width.
        self.time_embedding_module = TimeEmbedding(dim=fused_dim)

        self.fused_rgcn_module = RelationalRGCN(
            in_channels=fused_dim,
            h_channels_list=fusion_hidden_dim_list,
            out_channels=attention_in_dim,
            num_relations=encoder_num_relations,
            num_bases=encoder_num_bases,
            # TODO: mirror encoder params?
            aggr=encoder_aggr,
            activation=encoder_activation,
            dp_rate=encoder_dp_rate,
            bias=encoder_bias
        )

    def forward(self, x, t, obj_cond, edge_cond, relation_cond):
        """Predict the noise for a batch of scenes, conditioned on scene graphs.

        Args:
            x: Noisy scenes; the attention output is unpacked as [B, N, ...].
            t: Timesteps handed to the time-embedding module, which is expected
                to return one embedding row per scene ([B, F]).
            obj_cond: Node features of all graphs stacked along the first axis
                ([B*N, ...]).
            edge_cond: Edge index tensor, passed to both RGCNs.
            relation_cond: Edge types, passed to both RGCNs.

        Returns:
            Output of the fusion RGCN reshaped to [B, N, ...].
        """
        # Step 1: Unconditional denoising/diffusion
        x = self.attention_module(x)

        # Step 2: Scene graph processing
        graph_output = self.encoder_module(obj_cond, edge_cond, relation_cond)

        # The RGCN works on a flat [B*N, ...] node matrix; cut it into B chunks
        # of N rows so it lines up with x. torch.split raises if the row count
        # is not exactly B*N.
        B, N, _ = x.shape
        graph_output = torch.stack(torch.split(graph_output, [N] * B, dim=0), dim=0)

        # Step 3: Concatenation and time embedding
        fused_output = torch.cat([x, graph_output], dim=-1)
        # [B, F] -> [B, 1, F] so the embedding broadcasts over the N objects
        time_embedded = self.time_embedding_module(t)[:, None, :]
        fused_output = fused_output + time_embedded

        # Step 4: Final relational GCN, again on the flat [B*N, ...] layout
        output = self.fused_rgcn_module(
            fused_output.reshape(-1, fused_output.size(-1)),
            edge_cond.view(-1, edge_cond.size(-1)),
            relation_cond.view(-1)
        )

        # Step 5: Reshape the output back to [B, N, ...]
        return output.reshape(B, N, -1)

    # TODO: implement forward_with_cond_scale

## Mock Data

In this section, we will generate an example scene graph and a corresponding matrix X to feed into the network.

In [3]:
# Batch layout: B graphs per batch, N nodes (objects) per graph
B, N = 1, 20

# Scene-graph / RGCN sizes: C features per node, E edges per graph, R edge types
C, E, R = 300, 22, 8

# Attention size: D is the dim of the attention output
D = 100

In [4]:
# Build the network on the mock sizes: the attention branch maps D -> D, the
# scene-graph encoder maps C -> C with no hidden layers, and the fusion RGCN
# gets two hidden layers as wide as the concatenated features (C + D).
model = GuidedDiffusionNetwork(
    # attention branch
    attention_out_dim=D,
    attention_in_dim=D,
    # scene-graph encoder
    encoder_out_dim=C,
    encoder_in_dim=C,
    encoder_hidden_dim_list=[],
    encoder_num_relations=R,
    encoder_num_bases=None,
    encoder_bias=True,
    encoder_dp_rate=0.1,
    encoder_aggr='mean',
    # fusion RGCN
    fusion_hidden_dim_list=[C + D] * 2,
)

print(model)

GuidedDiffusionNetwork(
  (attention_module): ModifiedMultiheadAttention(
    (qkv_proj): Linear(in_features=100, out_features=300, bias=False)
    (o_proj): Linear(in_features=100, out_features=100, bias=False)
  )
  (encoder_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(300, 300, num_relations=8)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (time_embedding_module): TimeEmbedding()
  (fused_rgcn_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(400, 400, num_relations=8)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Dropout(p=0.1, inplace=False)
      (3): RGCNConv(400, 400, num_relations=8)
      (4): LeakyReLU(negative_slope=0.2, inplace=True)
      (5): Dropout(p=0.1, inplace=False)
      (6): RGCNConv(400, 400, num_relations=8)
      (7): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
)


In [5]:
from torch_geometric.data import Data

# Scene Graphs for conditioning
def generate_random_graph(num_nodes=None, feature_dim=None, num_edges=None, num_relations=None):
    """Build one random scene graph to use as mock conditioning data.

    Every size argument is optional; when left as ``None`` it falls back to the
    notebook-level constant of the same role (``N``, ``C``, ``E``, ``R``),
    looked up at call time.

    Args:
        num_nodes: Number of nodes in the graph (default: ``N``).
        feature_dim: Number of features per node (default: ``C``).
        num_edges: Number of directed edges (default: ``E``).
        num_relations: Number of distinct edge types (default: ``R``).

    Returns:
        A ``Data`` object with
        ``x`` of shape [num_nodes, feature_dim] (standard-normal features),
        ``edge_index`` of shape [2, num_edges] (node ids in [0, num_nodes)),
        ``edge_attr`` of shape [num_edges] (edge types in [0, num_relations)).
    """
    num_nodes = N if num_nodes is None else num_nodes
    feature_dim = C if feature_dim is None else feature_dim
    num_edges = E if num_edges is None else num_edges
    num_relations = R if num_relations is None else num_relations

    # --- Initialize nodes: [num_nodes, feature_dim] random features ---
    nodes = torch.randn(num_nodes, feature_dim)

    # --- Initialize edges: [2, num_edges] random (source, target) node ids ---
    edges = torch.randint(num_nodes, (2, num_edges))

    # --- Assign a random edge type to each edge: 1-D tensor of length num_edges ---
    rels = torch.randint(num_relations, (num_edges,))

    # --- Create a graph ---
    return Data(x=nodes, edge_index=edges, edge_attr=rels)

# --- Draw one independent random scene graph per batch element ---
graphs = list(generate_random_graph() for _ in range(B))

print(f"Batch size: {len(graphs)}, dimensions of each graph: {graphs[0]}")

Batch size: 1, dimensions of each graph: Data(x=[20, 300], edge_index=[2, 22], edge_attr=[22])


In [6]:
# Mock scenes to denoise: B scenes, each with N objects of D random features
X = torch.randn((B, N, D))
print(f"Dimensions: X={X.shape}")

Dimensions: X=torch.Size([1, 20, 100])


In [7]:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Pairs each scene tensor with the scene graph that conditions it.

    Args:
        X: Indexable collection of scenes; ``X[i]`` is the i-th scene.
        graphs: Indexable collection of graph objects exposing ``x``,
            ``edge_index`` and ``edge_attr``; ``graphs[i]`` belongs to ``X[i]``.

    Raises:
        ValueError: If there are fewer graphs than scenes.
    """

    def __init__(self, X, graphs):
        super().__init__()
        # __len__ follows X, so a missing graph would only surface later as an
        # IndexError in the middle of an epoch; fail early instead.
        if len(graphs) < len(X):
            raise ValueError(
                f"Expected one graph per scene, got {len(X)} scenes and {len(graphs)} graphs"
            )
        self.X = X
        self.graphs = graphs

    def __len__(self):
        """Number of scenes in the dataset."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the scene at ``index`` together with its graph tensors."""
        x = self.X[index]
        graph = self.graphs[index]
        return {
            'x': x,
            'obj_cond': graph.x,
            'edge_cond': graph.edge_index,
            'relation_cond': graph.edge_attr
        }

    def collate_fn(self, batch):
        """Merge samples into one batch with a single flat node matrix.

        Scenes are stacked along a new batch axis, while node features, edges
        and edge types of all graphs are concatenated. Because the nodes of
        graph ``i`` land after the nodes of all previous graphs, its edge
        indices are shifted by the number of nodes that precede it; without
        that shift every edge would point into the first graph.

        Args:
            batch: List of dictionaries as produced by ``__getitem__``.

        Returns:
            Dictionary with the same keys holding the batched tensors.
        """
        x_batch = torch.stack([item['x'] for item in batch], dim=0)
        obj_cond_batch = torch.cat([item['obj_cond'] for item in batch], dim=0)
        relation_cond_batch = torch.cat([item['relation_cond'] for item in batch], dim=0)

        shifted_edges = []
        node_offset = 0
        for item in batch:
            shifted_edges.append(item['edge_cond'] + node_offset)
            node_offset += item['obj_cond'].size(0)
        edge_cond_batch = torch.cat(shifted_edges, dim=1)

        return {
            'x': x_batch,
            'obj_cond': obj_cond_batch,
            'edge_cond': edge_cond_batch,
            'relation_cond': relation_cond_batch
        }


In [8]:
from torch.utils.data import DataLoader

# Wrap the mock scenes and graphs so they can be served in batches of B
dataset = CustomDataset(X, graphs)
dataloader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    shuffle=True,
    batch_size=B,
)

# Stand-in diffusion timesteps: one random index below 100 per scene
# (these should eventually come from the DDPM scheduler)
t = torch.randint(low=0, high=100, size=(B,))

# Push every batch through the network once and report the tensor shapes
for batch in dataloader:
    x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch = (
        batch[key] for key in ('x', 'obj_cond', 'edge_cond', 'relation_cond')
    )
    print(f"Input Dimensions:\n\tX={x_batch.shape}\n\tt={t.shape}\n\tobj_cond={obj_cond_batch.shape}\n\tedge_cond={edge_cond_batch.shape}\n\trelation_cond={relation_cond_batch.shape}")
    # Forward pass through the model
    output = model(x_batch, t, obj_cond_batch, edge_cond_batch, relation_cond_batch)
    print(f"Output Dimensions:\n\toutput={output.shape}")


Input Dimensions:
	X=torch.Size([1, 20, 100])
	t=torch.Size([1])
	obj_cond=torch.Size([20, 300])
	edge_cond=torch.Size([2, 22])
	relation_cond=torch.Size([22])
Output Dimensions:
	output=torch.Size([1, 20, 400])


## Connecting the GuidedDiffusionNetwork to a DDPM Scheduler

In [28]:
# Flat settings for the DDPM run. This dictionary is the one that is later
# handed to the TensorBoard writer via `add_hparams`.
hparams = dict(
    # Data layout
    B=B,  # graphs per batch
    N=N,  # objects per graph
    D=D,  # features per object in the scene

    # Scene-graph / RGCN sizes
    C=C,  # features per node
    E=E,  # edges per graph
    R=R,  # edge types

    # Encoder RGCN options
    encoder_num_bases=None,
    encoder_aggr='mean',
    encoder_dp_rate=0,
    encoder_bias=True,

    # DDPM scheduler options
    scheduler_timesteps=1000,
    scheduler_sampling_timesteps=None,
    scheduler_loss='l1',
    scheduler_objective='pred_noise',
    scheduler_beta_schedule='cosine',
    scheduler_ddim_sampling_eta=1.0,
    scheduler_min_snr_loss_weight=False,
    scheduler_min_snr_gamma=5,

    # Training loop, optimizer and learning-rate schedule
    epochs=2500,
    optimizer_lr=1e-3,
    optimizer_weight_decay=5e-4,
    lr_scheduler_factor=0.7,
    lr_scheduler_patience=100,
    lr_scheduler_minlr=1e-5,
)

# Settings that hold modules or lists, kept in a dictionary of their own.
architecture_hparams = dict(
    # Encoder RGCN
    encoder_activation=nn.LeakyReLU(negative_slope=0.2, inplace=True),
    encoder_hidden_dim_list=[],
    # Fusion RGCN
    fusion_hidden_dim_list=[],  # [C+D, C+D, D],
)

In [29]:
from torch.utils.tensorboard import SummaryWriter


# --- Serve the (mocked) dataset in shuffled batches of B scenes
dataloader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    shuffle=True,
    batch_size=B,
)


# --- Build the denoising network from the two hyperparameter dictionaries
model = GuidedDiffusionNetwork(
    # attention branch
    attention_out_dim=hparams['D'],
    attention_in_dim=hparams['D'],
    # scene-graph encoder
    encoder_out_dim=hparams['C'],
    encoder_in_dim=hparams['C'],
    encoder_hidden_dim_list=architecture_hparams['encoder_hidden_dim_list'],
    encoder_activation=architecture_hparams['encoder_activation'],
    encoder_num_relations=hparams['R'],
    encoder_num_bases=hparams['encoder_num_bases'],
    encoder_bias=hparams['encoder_bias'],
    encoder_dp_rate=hparams['encoder_dp_rate'],
    encoder_aggr=hparams['encoder_aggr'],
    # fusion RGCN
    fusion_hidden_dim_list=architecture_hparams['fusion_hidden_dim_list'],
)

print(f"Model:\n{model}")

# --- Wrap the network in the DDPM scheduler
scheduler = DDPMScheduler(
    model=model,
    # scene layout
    D=hparams['D'],
    N=hparams['N'],
    # diffusion process
    beta_schedule=hparams['scheduler_beta_schedule'],
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=hparams['scheduler_sampling_timesteps'],
    ddim_sampling_eta=hparams['scheduler_ddim_sampling_eta'],
    # training target and loss
    objective=hparams['scheduler_objective'],
    loss_type=hparams['scheduler_loss'],
    min_snr_gamma=hparams['scheduler_min_snr_gamma'],
    min_snr_loss_weight=hparams['scheduler_min_snr_loss_weight'],
)

print(f"DDPM Scheduler:\n{scheduler}")


# --- Setup training loop ---
from tqdm import tqdm

# Adam over every parameter exposed by the scheduler wrapper
optimizer = torch.optim.Adam(
    params=scheduler.parameters(),
    weight_decay=hparams['optimizer_weight_decay'],
    lr=hparams['optimizer_lr'],
)
# Shrink the learning rate once the monitored loss stops decreasing
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    min_lr=hparams['lr_scheduler_minlr'],
    patience=hparams['lr_scheduler_patience'],
    factor=hparams['lr_scheduler_factor'],
)


# --- Initialize tensorboard ---
# use timestamp to avoid overwriting previous runs
from datetime import datetime
# Timestamp the run directory so earlier runs are not overwritten. Only '-' is
# used as a separator because ':' is not a valid path character on Windows.
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/full-DDPM/overfit-B-{hparams["B"]}-time-{now.strftime("%Y-%m-%d-%H-%M-%S")}')

# --- Training loop ---
# Defined up front so the final hparams entry is still valid when `epochs` is 0.
epoch_loss = float('nan')
try:
    for epoch in tqdm(range(hparams['epochs'])):
        epoch_loss = 0
        for batch in dataloader:
            x_batch = batch['x']
            obj_cond_batch = batch['obj_cond']
            edge_cond_batch = batch['edge_cond']
            relation_cond_batch = batch['relation_cond']

            # The scheduler returns the training loss directly
            loss = scheduler(x_batch, obj_cond_batch, edge_cond_batch, relation_cond_batch)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Mean loss over the batches of this epoch
        epoch_loss /= len(dataloader)

        lr_scheduler.step(epoch_loss)
        writer.add_scalar('Loss/train', epoch_loss, epoch)
        writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

    # log all the hyperparameters and final loss
    writer.add_hparams(hparams, {'Final loss': epoch_loss})
finally:
    # Flush and release the event file even if training is interrupted
    writer.close()

Model:
GuidedDiffusionNetwork(
  (attention_module): ModifiedMultiheadAttention(
    (qkv_proj): Linear(in_features=100, out_features=300, bias=False)
    (o_proj): Linear(in_features=100, out_features=100, bias=False)
  )
  (encoder_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(300, 300, num_relations=8)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (time_embedding_module): TimeEmbedding()
  (fused_rgcn_module): RelationalRGCN(
    (layers): ModuleList(
      (0): RGCNConv(400, 100, num_relations=8)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
)
DDPM Scheduler:
DDPMScheduler(
  (model): GuidedDiffusionNetwork(
    (attention_module): ModifiedMultiheadAttention(
      (qkv_proj): Linear(in_features=100, out_features=300, bias=False)
      (o_proj): Linear(in_features=100, out_features=100, bias=False)
    )
    (encoder_module): RelationalRGCN(
      (layers): ModuleList(
        (0): RGCNConv(300, 300, num_relations=8)


100%|██████████| 2500/2500 [00:23<00:00, 107.26it/s]


## Overfitting to a batch of mocked scenes
In the following, we have a self-contained example of the network overfitting to a batch of mocked scenes. We will generate a batch of B scenes, each with N objects. The network will be trained to reconstruct the same batch of scenes.

In [None]:
# Try overfitting the model on a single batch for a reconstruction task
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

# --- Start from a freshly initialised network
activation = nn.LeakyReLU(inplace=True, negative_slope=0.2)

model = GuidedDiffusionNetwork(
    # attention branch
    attention_out_dim=D,
    attention_in_dim=D,
    # scene-graph encoder
    encoder_out_dim=C,
    encoder_in_dim=C,
    encoder_activation=activation,
    encoder_num_relations=R,
    encoder_num_bases=None,
    encoder_bias=True,
    encoder_dp_rate=0.1,
    encoder_aggr='mean',
)

# --- Fresh dataloader over the same mock dataset
dataloader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    shuffle=True,
    batch_size=B,
)

# --- Setup training loop ---
from tqdm import tqdm

# Run length, Adam step size and Adam weight decay
epochs, lr, weight_decay = 1000, 1e-3, 5e-4

optimizer = torch.optim.Adam(
    params=model.parameters(),
    weight_decay=weight_decay,
    lr=lr,
)
# Shrink the learning rate once the monitored loss stops decreasing
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    min_lr=1e-5,
    patience=25,
    factor=0.7,
)
# Reconstruction objective: mean squared error against X as ground truth
recon_loss_fn = nn.MSELoss()

# --- Initialize tensorboard ---
# use timestamp to avoid overwriting previous runs
from datetime import datetime
# The run directory encodes batch size, learning rate and a start timestamp.
# NOTE(review): no `writer.close()` is visible in this part of the cell; confirm
# the writer is closed after the loop.
now = datetime.now()
writer = SummaryWriter(log_dir=f'runs/full-model/overfit-B-{B}-lr-{lr}-time-{now.strftime("%Y-%m-%d-%H-%M-%S")}')

# --- Training loop ---
# Each epoch iterates the dataloader once, trains the model to reproduce its
# own input (MSE against x_batch), and logs the mean batch loss.
for epoch in tqdm(range(epochs)):
    epoch_loss = 0
    for batch in dataloader:
        x_batch = batch['x']
        obj_cond_batch = batch['obj_cond']
        edge_cond_batch = batch['edge_cond']
        relation_cond_batch = batch['relation_cond']
        # Forward pass through the model
        # NOTE(review): `t` is not defined in this cell; presumably it is the
        # random timestep tensor from the earlier example cell. Confirm, since
        # the cell is described as self-contained.
        output = model(x_batch, t, obj_cond_batch, edge_cond_batch, relation_cond_batch)
        # Compute the loss
        loss = recon_loss_fn(output, x_batch)
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Turn the summed batch losses into the epoch mean
    epoch_loss /= len(dataloader)
    # Log the loss
    writer.add_scalar('Loss/train', epoch_loss, epoch)
    
    # Update the learning rate
    scheduler.step(epoch_loss)