In [4]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from attention_layer import ModifiedMultiheadAttention
from encoder_rgcn import EncoderRGCN
from text_encoder import FastTextEncoder
from time_embedding import TimeEmbedding
from ddpm_scheduler import DDPMScheduler

# Putting everything together

In this section we put everything together to have a complete diffusion model.

The network will go as follows:

1. The first part of the network is the unconditional denoising/diffusion. For that, we represent 3D scenes as a matrix X of dimension [B, N, D], with B the batch size, N the number of nodes (objects) and D the dimension of each object's feature vector storing its information (location, size, ...). In this part of the network we feed X through a custom multihead attention layer (see module #1) and obtain output data of dimensions [B, N, E], with E the hidden dimension.

2. Parallel to the first part, we also receive a scene graph in the form of a triple (N, C, C): N nodes storing string descriptions of objects ("chair", "table", etc.), C edges given as a list of tuples (id_1, id_2), each indicating an outgoing connection from node id_1 to node id_2, and C connection types given as ints. The string description of each node is embedded using FastText (see module #2), and a relational GCN (see module #3) is built with nodes storing these embeddings and with the edges and edge types extracted from the scene graph following the aforementioned structure. This RGCN block has no hidden layers and outputs dimension F.

3. The outputs from the first and second steps, [B, N, E] and [B, N, F], are concatenated along the last axis to form [B, N, E + F]. A time embedding (see module #4) of matching dimension is produced and added to that result. This matrix is then fed into another relational GCN that consists of N nodes (each row becomes a node) and reuses the edges and edge types from the first RGCN. This new RGCN has a few hidden layers and produces an output of dimension [B, N, D], finishing the forward pass.


In [None]:
class GuidedDiffusionNetwork(nn.Module):
    """Scene-graph-conditioned denoising network (work in progress).

    The forward pass combines three sub-modules:

    1. ``ModifiedMultiheadAttention`` applied to the noisy scene matrix ``x``.
    2. ``EncoderRGCN`` (no hidden layers) applied to FastText embeddings of
       the scene-graph node labels.
    3. ``TimeEmbedding`` whose ``dim`` is ``attention_out_dim +
       encoder_out_dim`` and whose output is added to the concatenation of
       (1) and (2).

    The final relational GCN that should map the fused features back to the
    input dimension is not implemented yet; ``forward`` currently returns the
    fused features.

    Args:
        attention_in_dim: ``input_dim`` of the attention block.
        attention_out_dim: ``target_dim`` of the attention block.
        encoder_in_dim: ``in_channels`` of the encoder RGCN.
        encoder_out_dim: ``out_channels`` of the encoder RGCN.
        encoder_num_relations: ``num_relations`` of the encoder RGCN.
        encoder_num_bases: ``num_bases`` of the encoder RGCN.
        encodeer_aggr: ``aggr`` of the encoder RGCN. NOTE: the name is
            misspelled ("encodeer"); it is kept as-is so existing keyword
            callers do not break.
        encoder_dp_rate: ``dp_rate`` of the encoder RGCN.
        encoder_bias: ``bias`` of the encoder RGCN.
    """

    def __init__(
        self,
        # Attention block
        attention_in_dim,
        attention_out_dim,
        # Encoder RGCN block
        encoder_in_dim,
        encoder_out_dim,
        encoder_num_relations,
        encoder_num_bases=None,
        encodeer_aggr='mean',
        encoder_dp_rate=0.1,
        encoder_bias=True,
    ):
        super().__init__()

        self.attention_module = ModifiedMultiheadAttention(
            input_dim=attention_in_dim,
            target_dim=attention_out_dim,
            # Clamp to at least one head: `attention_in_dim // 4` is 0 for
            # input dimensions below 4.
            num_heads=max(1, attention_in_dim // 4),  # TODO: hyperparam vs. hardcode?
        )

        self.encoder_module = EncoderRGCN(
            in_channels=encoder_in_dim,
            h_channels_list=[],  # TODO: no hidden layers?
            out_channels=encoder_out_dim,
            num_relations=encoder_num_relations,
            num_bases=encoder_num_bases,
            aggr=encodeer_aggr,
            dp_rate=encoder_dp_rate,
            bias=encoder_bias,
        )

        self.time_embedding_module = TimeEmbedding(
            dim=attention_out_dim + encoder_out_dim
        )

    # TODO: adapt the forward method to match DDPM scheduler:
    #       model.forward(x, t, obj_cond, edge_cond, relation_cond)
    def forward(self, x, scene_graph, time, edge_index=None, edge_type=None):
        """Run the (currently partial) forward pass.

        Args:
            x: Noisy scene tensor fed to the attention block
                (presumably ``[B, N, D]`` -- TODO confirm).
            scene_graph: Either a triple ``(node_labels, edge_index,
                edge_type)``, or -- when ``edge_index`` and ``edge_type`` are
                passed explicitly -- just the iterable of node label strings.
            time: Diffusion timestep(s) passed to the time embedding module.
            edge_index: Optional graph connectivity, forwarded unchanged to
                the encoder RGCN.
            edge_type: Optional per-edge relation ids, forwarded unchanged to
                the encoder RGCN.

        Returns:
            The fused feature tensor (attention output concatenated with the
            graph encoding on the last axis, plus the time embedding). This is
            a placeholder until the final relational GCN is implemented.

        Raises:
            ValueError: If only one of ``edge_index`` / ``edge_type`` is
                given, or if neither is given and ``scene_graph`` cannot be
                unpacked into three elements.
        """
        # Resolve the graph structure. These names were previously read
        # without being defined anywhere in this method.
        if edge_index is None and edge_type is None:
            try:
                node_labels, edge_index, edge_type = scene_graph
            except (TypeError, ValueError) as err:
                raise ValueError(
                    "scene_graph must be a (node_labels, edge_index, "
                    "edge_type) triple when edge_index and edge_type are "
                    "not passed explicitly"
                ) from err
        elif edge_index is None or edge_type is None:
            raise ValueError(
                "edge_index and edge_type must be provided together"
            )
        else:
            node_labels = scene_graph

        # Step 1: Unconditional denoising/diffusion
        x = self.attention_module(x)

        # Step 2: Scene graph processing
        embedded_graph = torch.stack(
            [FastTextEncoder.encode(word) for word in node_labels]
        )
        # Drop the singleton axis at position 1 of the stacked embeddings
        # (no-op if that axis is not of size 1).
        embedded_graph = embedded_graph.squeeze(1)
        graph_output = self.encoder_module(embedded_graph, edge_index, edge_type)

        # torch.cat needs operands of equal rank. If the graph encoding has
        # fewer leading dims than x (e.g. no batch axis), broadcast it over
        # them. expand() creates a view, not a copy.
        if graph_output.dim() < x.dim():
            graph_output = graph_output.expand(*x.shape[:-1], -1)

        # Step 3: Concatenation and time embedding
        fused_output = torch.cat([x, graph_output], dim=-1)
        time_embedded = self.time_embedding_module(time)
        # Out-of-place add: `+=` would reject any broadcast that changes the
        # shape of fused_output.
        # NOTE(review): if TimeEmbedding returns [B, dim] rather than
        # [B, 1, dim], this broadcasts against the node axis -- confirm.
        fused_output = fused_output + time_embedded

        # TODO: Step 4: Final relational GCN
        output = fused_output  # TODO: placeholder

        return output
