In [1]:
import pandas as pd
import numpy as np
import random
import tqdm
import gc
gc.enable()

%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import download_url, extract_zip, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.nn.conv import MessagePassing
import torch_geometric.transforms as T
from torch_geometric import EdgeIndex
from torch_geometric.utils import add_self_loops, spmm, is_sparse
from torch_geometric.typing import Adj, OptPairTensor, SparseTensor

### Environment parameters

In [12]:
"""Immutable parameters"""
schedule_size = 100     # size of scheduling window: count discrete blocks of time

# Possible qualifications for a thief
id2qual = {
    0: 'Safe cracking',
    1: 'Getaway driving',
    2: 'Masterminding',
    3: 'Burglary',
    4: 'Smooth Operator',
    5: 'Hacking',
    6: 'Token woman',
    7: 'Flirting with detective',
    8: 'Wisecracking',
    9: 'Distraction'
}

# Possible factors for a thief when considering to take on a job 
id2factors = {
    0: 'Risk',
    1: 'Workload',
    2: 'Payoff',
    3: 'Complexity',
    4: 'Travel'
}

num_quals   = len(id2qual)
num_factors = len(id2factors)
thief_size  = num_quals + num_factors + schedule_size # Cannot change because dimensionality of nodes cannot change
heist_size  = 5                                       # Cannot change because dimensionality of nodes cannot change
slot_size   = num_quals + num_factors + 1             # Cannot change because dimensionality of edges cannot change

In [3]:
"""Mutable parameters"""
fname_pool = ['Amy','Brandon','Carli','Dante','Eleanor','Frank']
lname_pool = ['Adams','Boxer','Charles','Darwin','Egan','Fallon']
heist_pool = ['Vermeer','Rembrandt','Banksy','Faberge','Monet','Michelangelo','Boticelli','Van Gogh']
museum_pool = ['British Museum', 'Smithsonian', 'Louvre', 'Billionaire\'s Private Collection']

max_heist_time = 5 # Max length of heist
max_heist_num  = 5 # Max num of heists a thief can go on (due to union limits)

qual_min = 0   # Minimum qualification level
qual_max = 3

factor_min = 0  # Minimum rating for a job factor 
factor_max = 3

n_slots_min = 2 # Minimum number of slots on a job
n_slots_max = 8

### Define generation functions

In [4]:
def generate_schedule(n_heists):
    """ Generate schedule for a thief """

    # initialize free (0) schedule
    sched = np.zeros((schedule_size,))
    if n_heists == 0:
        return sched
        
    intvl = schedule_size // n_heists # prevents us from squashing all heists at end
    max_idx = intvl
    min_idx = 0
    for i in range(n_heists):
        start_idx = random.randint(min_idx, max_idx)
        end_idx = random.randint(start_idx, max_idx)
        sched[start_idx:end_idx] = 1

        # Update intvl
        max_idx += intvl
        min_idx = end_idx
    
    return sched

In [5]:
class Thief():
    def __init__(self,
                Id: int):
        self.id = Id
        self.name = f"{random.choice(fname_pool)} {random.choice(lname_pool)}"
        self.schedule = generate_schedule(random.randint(0,max_heist_num))
        self.qualifications = np.random.randint(qual_min, qual_max, size=len(id2qual))
        currency_mask = np.random.randint(0,1,size=len(id2factors))
        self.currencies = np.random.uniform(low=factor_min, high=factor_max, size=len(id2factors))
        self.currencies *= currency_mask

    def get_data(self):
        # Return tensor of all data
        data = torch.zeros((thief_size + heist_size,))
        data[:thief_size] = torch.from_numpy(np.concatenate([self.schedule, self.qualifications, self.currencies])).to(torch.float)
        return data

In [6]:
class Heist():
    def __init__(self,
                Id: int):
        self.id = Id
        self.name = random.choice(heist_pool)
        self.start_time = random.randint(0, schedule_size)
        self.end_time = random.randint(self.start_time, self.start_time + max_heist_time)
        self.crew = {}
        self.n_slots = random.randint(n_slots_min, n_slots_max) # num spots available  
        self.n_slots_left = self.n_slots                        # num spots left
        self.n_slots_required = random.randint(0,self.n_slots)  # num spots left that are required
    
    def get_data(self):
                # Return tensor of all data
        data = torch.zeros((thief_size + heist_size,))
        data[thief_size:] = torch.from_numpy(np.array([
            self.start_time, self.end_time, self.n_slots, self.n_slots_left, self.n_slots_required
        ])).to(torch.float)
        return data

In [7]:
class Slot():
    def __init__(self,
                Id,
                thief_id,
                heist_id):
        self.id = Id
        self.thief_id = thief_id
        self.heist_id = heist_id
        self.required = np.expand_dims(np.array(random.randint(0,1)), axis=0)
        self.qualifications = np.random.randint(qual_min, qual_max, size=len(id2qual))
        currency_mask = np.random.randint(0,1,size=len(id2factors))
        self.currencies = np.random.uniform(low=factor_min, high=factor_max, size=len(id2factors))
        self.currencies *= currency_mask

    def get_data(self):
        return torch.from_numpy(np.concatenate([
            self.required, self.qualifications, self.currencies
        ]))        

### Debugging MessagePassing

In [143]:
# Generate node dfs
num_heists = 20
num_thieves = 50

heists_df = pd.DataFrame()
for i in range(num_heists):
    heist_data = Heist(i).get_data()
    tmp_df = pd.DataFrame(heist_data.numpy()).T
    tmp_df.index = [i]
    # tmp_df.columns = ['start_time','end_time','n_slots','n_slots_left','n_slots_required']
    heists_df = pd.concat([heists_df, tmp_df])
heists_df.index.rename('heistId', inplace=True)

thieves_df = pd.DataFrame()
for i in range(num_thieves):
    thief_data = Thief(i).get_data()
    tmp_df = pd.DataFrame(thief_data.numpy()).T
    tmp_df.index = [i]
    # tmp_df.columns = ['start_time','end_time','n_slots','n_slots_left','n_slots_required']
    thieves_df = pd.concat([thieves_df, tmp_df])
thieves_df.index.rename('thiefId', inplace=True)

In [144]:
# Generate slots (future edge_attr) and edge index (future edge_index)
slots_df = pd.DataFrame()
h_lst = []  # list of heist endpoints for edges
t_lst = []  # list of thief endpoints for edges

s_idx = 0
for h_idx, h in heists_df.iterrows():
    for t_idx, t in thieves_df.iterrows():
        for i in range(int(h[heist_size + thief_size - 3])):            # 112 contains n_slots data
            # TODO: currently assign randomly
            if random.uniform(0,1) < 0.5:
                slot_data = Slot(s_idx, t_idx, h_idx).get_data()
                tmp_df = pd.DataFrame(slot_data.numpy()).T
                tmp_df.index = [s_idx]
                slots_df = pd.concat([slots_df, tmp_df])
                s_idx += 1

                # Add endpoints
                h_lst.append(h_idx)
                t_lst.append(t_idx)
slots_df.index.rename('slotId', inplace=True)

# Create edge index
edge_index_thief_to_heist = torch.stack([torch.tensor(t_lst), torch.tensor(h_lst)], dim=0)

In [145]:
# NOTE: set all values to -100 for ease of identification
thieves_df.iloc[:,:] = -100

In [146]:
data = HeteroData()

# Add node indices
data['thief'].node_id = torch.tensor(thieves_df.index)
data['heist'].node_id = torch.tensor(heists_df.index)

# Add node features
data["thief"].x = torch.tensor(thieves_df.values).to(torch.float)
data["heist"].x = torch.tensor(heists_df.values).to(torch.float)

# Add edge indices
data["thief", "slot", "heist"].edge_index = edge_index_thief_to_heist # has shape (2, num_edges)

# Add edge features
data["thief", "slot", "heist"].edge_attr = torch.tensor(slots_df.values).to(torch.float)

# Add reverse edge
data = T.ToUndirected()(data)

In [147]:
# Encoder sizes
message_hidden_channels = 64
out_channels = 1


In [188]:
class GNN(nn.Module):
    def __init__(self, node1_in_channels = heist_size + thief_size,
                       node2_in_channels = heist_size + thief_size,
                       edge_in_channels  = slot_size,
                       hidden_channels   = message_hidden_channels,
                       out_channels      =out_channels):
        super(GNN, self).__init__()
        
        # Define message passing layers for each type of node and edge
        self.node1_message_passing = CustomMessagePassing(node1_in_channels)
        self.node2_message_passing = CustomMessagePassing(node2_in_channels)
    
    def forward(self, data):

        # Extract node features and edge attributes
        node1_x = data['thief'].x
        node2_x = data['heist'].x
        edge_attr = data['thief','slot','heist'].edge_attr
        edge_index = data['thief', 'slot', 'heist'].edge_index
        edge_rev_index = data['heist', 'rev_slot', 'thief'].edge_index

        # Embed node and edge features
        opt_x: OptPairTensor = (node1_x, node2_x)
        opt_x2: OptPairTensor = (node2_x, node1_x)

        # Perform message passing for both node types
        node1_messages = self.node1_message_passing(opt_x2, edge_rev_index, edge_attr)
        # NOTE: we need to pass in OptPairTensor(neighbor_x, central_x) and the reverse edges
        # node2_messages = self.node2_message_passing(opt_x, edge_index)
        print("node1: ", node1_x.shape, " --> ", node1_messages.shape)
        # print("node2: ", node2_x.shape, " --> ", node2_messages.shape)

        return node1_messages

class CustomMessagePassing(MessagePassing):
    def __init__(self, node_channels):
        super(CustomMessagePassing, self).__init__(aggr='mean')

    def propagate(
        self,
        edge_index: Adj,
        size=None,
        **kwargs,
    ):
        decomposed_layers = 1 if self.explain else self.decomposed_layers

        for hook in self._propagate_forward_pre_hooks.values():
            res = hook(self, (edge_index, size, kwargs))
            if res is not None:
                edge_index, size, kwargs = res

        mutable_size = self._check_input(edge_index, size)

        # Run "fused" message and aggregation (if applicable).
        fuse = False
        if self.fuse and not self.explain:
            if is_sparse(edge_index):
                fuse = True
            elif (not torch.jit.is_scripting()
                  and isinstance(edge_index, EdgeIndex)):
                if (self.SUPPORTS_FUSED_EDGE_INDEX
                        and edge_index.is_sorted_by_col):
                    fuse = True

        if fuse:
            coll_dict = self._collect(self._fused_user_args, edge_index,
                                      mutable_size, kwargs)

            msg_aggr_kwargs = self.inspector.collect_param_data(
                'message_and_aggregate', coll_dict)
            for hook in self._message_and_aggregate_forward_pre_hooks.values():
                res = hook(self, (edge_index, msg_aggr_kwargs))
                if res is not None:
                    edge_index, msg_aggr_kwargs = res
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
            for hook in self._message_and_aggregate_forward_hooks.values():
                res = hook(self, (edge_index, msg_aggr_kwargs), out)
                if res is not None:
                    out = res

            update_kwargs = self.inspector.collect_param_data(
                'update', coll_dict)
            out = self.update(out, **update_kwargs)
        else:  # Otherwise, run both functions in separation.
            if decomposed_layers > 1:
                user_args = self._user_args
                decomp_args = {a[:-2] for a in user_args if a[-2:] == '_j'}
                decomp_kwargs = {
                    a: kwargs[a].chunk(decomposed_layers, -1)
                    for a in decomp_args
                }
                decomp_out = []

            for i in range(decomposed_layers):
                if decomposed_layers > 1:
                    for arg in decomp_args:
                        kwargs[arg] = decomp_kwargs[arg][i]

                coll_dict = self._collect(self._user_args, edge_index,
                                          mutable_size, kwargs)

                msg_kwargs = self.inspector.collect_param_data(
                    'message', coll_dict)
                for hook in self._message_forward_pre_hooks.values():
                    res = hook(self, (msg_kwargs, ))
                    if res is not None:
                        msg_kwargs = res[0] if isinstance(res, tuple) else res
                out = self.message(**msg_kwargs)
                for hook in self._message_forward_hooks.values():
                    res = hook(self, (msg_kwargs, ), out)
                    if res is not None:
                        out = res

                if self.explain:
                    explain_msg_kwargs = self.inspector.collect_param_data(
                        'explain_message', coll_dict)
                    out = self.explain_message(out, **explain_msg_kwargs)

                aggr_kwargs = self.inspector.collect_param_data(
                    'aggregate', coll_dict)
                for hook in self._aggregate_forward_pre_hooks.values():
                    res = hook(self, (aggr_kwargs, ))
                    if res is not None:
                        aggr_kwargs = res[0] if isinstance(res, tuple) else res
                out = self.aggregate(out, **aggr_kwargs)

                for hook in self._aggregate_forward_hooks.values():
                    res = hook(self, (aggr_kwargs, ), out)
                    if res is not None:
                        out = res

                update_kwargs = self.inspector.collect_param_data(
                    'update', coll_dict)
                out = self.update(out, **update_kwargs)

                if decomposed_layers > 1:
                    decomp_out.append(out)

            if decomposed_layers > 1:
                out = torch.cat(decomp_out, dim=-1)

        for hook in self._propagate_forward_hooks.values():
            res = hook(self, (edge_index, mutable_size, kwargs), out)
            if res is not None:
                out = res

        print("Message kwargs: ", msg_kwargs)
        print("Aggregation kwargs: ", aggr_kwargs)

        return out

    def forward(self, x, edge_index, edge_attr):
        # Perform message passing for both types of nodes simultaneously
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        return x_j

    def aggregate(self, inputs, index, ptr = None, dim_size = None):
        # inputs has shape [E, hidden]
        # index has shape [E]
        # NOTE: output has shape [N1, H] where N1 = num of nodes of opposite type
        print("Aggregation inputs: dim_size=", dim_size, ", inputs=", inputs.shape)

        return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
                                dim=self.node_dim)

In [189]:
# Instantiate the GNN model
model = GNN()

# Forward pass with the data
output = model(data)

Aggregation inputs: dim_size= 50 , inputs= torch.Size([2560, 120])
Message kwargs:  {'x_i': tensor([[-100., -100., -100.,  ..., -100., -100., -100.],
        [-100., -100., -100.,  ..., -100., -100., -100.],
        [-100., -100., -100.,  ..., -100., -100., -100.],
        ...,
        [-100., -100., -100.,  ..., -100., -100., -100.],
        [-100., -100., -100.,  ..., -100., -100., -100.],
        [-100., -100., -100.,  ..., -100., -100., -100.]]), 'x_j': tensor([[0., 0., 0.,  ..., 4., 4., 0.],
        [0., 0., 0.,  ..., 4., 4., 0.],
        [0., 0., 0.,  ..., 4., 4., 0.],
        ...,
        [0., 0., 0.,  ..., 5., 5., 2.],
        [0., 0., 0.,  ..., 5., 5., 2.],
        [0., 0., 0.,  ..., 5., 5., 2.]]), 'edge_attr': tensor([[1., 1., 2.,  ..., 0., 0., 0.],
        [1., 2., 2.,  ..., 0., 0., 0.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 2., 1.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        [0., 2., 2.,  ..., 0., 0., 0.]])}
Aggregation kwa

### Manual verification

In [156]:
# Message passing result for model above for node1 (thief)
print(output[0])

""" Can we get to the same result manually? 
Remember:
- thief = node1
- heist = node2
- edge_index = data['thief','slot','heist].edge_index

Message method calculates messages for EVERY edge
Agg method determines which ones to take for agg
- 'index' kwarg is edge_index[1] and has values from 0-49
"""

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.00

In [157]:
# Get reverse edge index
edge_rev_index = data['heist','rev_slot','thief'].edge_index[1]

# Get data
node1_x = data['thief'].x
node2_x = data['heist'].x

In [158]:
# Get indices where the edge is connected to thief 0
idx = (edge_rev_index==0).nonzero()

# Get the other endpoints of edges connected to thief 0 -> heists that thief 0 is connected to
heist_idx = data['heist','rev_slot','thief'].edge_index[0][idx]

# Calculate mean values of these heists (since agg method = mean)
heist_sum = 0
for i in heist_idx:
    heist_sum += node2_x[i]

In [159]:
print(heist_sum / len(heist_idx))

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  

### Building up GNN

In [183]:
# Encoder sizes
encoder_node1_channels = [heist_size + thief_size, 128, 32]
encoder_node2_channels = [heist_size + thief_size, 128, 32]
encoder_edge_channels = [slot_size, 16, 8]

# Message passing sizes
message_hidden_channels = 64
out_channels = 1

In [235]:
class GNN(nn.Module):
    def __init__(self, encoder_node1_channels  : list = encoder_node1_channels,
                       encoder_node2_channels  : list = encoder_node2_channels,
                       encoder_edge_channels   : list = encoder_edge_channels,
                       message_hidden_channels : int  = message_hidden_channels, # Latent space message passing happens in
                       out_channels            : int  = out_channels,
                       n_passes                : int  = 3):
        super(GNN, self).__init__()

        node1_in, node1_h, node1_out = encoder_node1_channels
        node2_in, node2_h, node2_out = encoder_node2_channels
        edge_in,  edge_h,  edge_out  = encoder_edge_channels
        self.node1_encoder = Encoder(node1_in, node1_h, node1_out)
        self.node2_encoder = Encoder(node2_in, node2_h, node2_out)
        self.edge_encoder  = Encoder(edge_in, edge_h, edge_out)

        # Define message passing layers -> take in the output sizes of the encoders
        self.n_passes = n_passes
        self.node1_message_passing = CustomMessagePassing(node1_out, message_hidden_channels, edge_size=edge_out)
        self.node2_message_passing = CustomMessagePassing(node2_out, message_hidden_channels, edge_size=edge_out)
        self.edge_message_passing  = EdgeMessagePassing(edge_out, node1_out, node2_out, message_hidden_channels)

        # Define final linear transformation layer
        self.decoder = nn.Sequential(
            nn.Linear(edge_out, 2*edge_out),
            nn.ReLU(),
            nn.Linear(2*edge_out, out_channels)
        )

        self.decoder2 = nn.Sequential(
            nn.Linear(edge_out + node1_out + node2_out, 2*edge_out),
            nn.ReLU(),
            nn.Linear(2*edge_out, out_channels)
        )
    
    def forward(self, data):

        # Extract node features and edge attributes
        node1_x = data['thief'].x
        node2_x = data['heist'].x
        edge_attr = data['thief','slot','heist'].edge_attr
        edge_index = data['thief', 'slot', 'heist'].edge_index
        edge_rev_index = data['heist', 'rev_slot', 'thief'].edge_index

        # Embed node and edge features
        node1_x = self.node1_encoder(node1_x)
        node2_x = self.node2_encoder(node2_x)
        edge_x = self.edge_encoder(edge_attr)
        # print(node1_x.shape) # 50, 32
        # print(node2_x.shape) # 20, 32
        # print(edge_x.shape) # E, 8

        # Preserve original embeddings
        node1_x_original = node1_x
        node2_x_original = node2_x
        edge_x_original  = edge_x

        # Perform n message passing
        for i in range(self.n_passes):
            opt_x: OptPairTensor = (node1_x, node2_x)
            opt_x2: OptPairTensor = (node2_x, node1_x)

            node1_x = self.node1_message_passing(edge_rev_index, opt_x2, edge_x)
            node2_x = self.node2_message_passing(edge_index, opt_x, edge_x)
            edge_x  = self.edge_message_passing(edge_index, edge_x, node1_x, node2_x)

        print("Original --> Embedding --> Message passing")
        print("node1: ", data['thief'].x.shape, " --> ", node1_x_original.shape, " --> ", node1_x.shape)
        print("node2: ", data['heist'].x.shape, " --> ", node2_x_original.shape, " --> ", node2_x.shape)
        print("edge: ", edge_attr.shape, " --> ", edge_x_original.shape, " --> ", edge_x.shape)

        # Combine with preserved embeddings
        node1_x += node1_x_original
        node2_x += node2_x_original
        edge_x += edge_x_original

        # Final decoding, version 0 -> just pass edge through mlp
        out = self.decoder(edge_x)
        print(out.shape)

        # Final decoding, version 1 -> dot product endpoints, then pass through mlp
        src, dst = edge_index
        out = (node1_x[src] * node2_x[dst]).sum(dim=-1)
        print(out.shape)

        # Final decoding, version 2 -> pass edge + endpoints through mlp
        out = self.decoder2(torch.cat([edge_x, node1_x[src], node2_x[dst]], dim=1))
        print(out.shape)

        return out
    
class Encoder(nn.Module):
    """ 2 layer MLP """
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def forward(self, x):
        return self.mlp(x)

class CustomMessagePassing(MessagePassing):
    def __init__(self, node_channels,
                       message_hidden_channels,
                       edge_size=None):
        super(CustomMessagePassing, self).__init__(aggr='mean')
        
        """ 
        Message passing layer shape: node_channels + edge_size -> message_hidden_channels -> node_channels
        - outputs same size for iterated passing
        """

        # Define linear transformations for message passing
        in_size = node_channels + (edge_size if edge_size is not None else 0)
        self.neighbor_linear = nn.Linear(in_size, message_hidden_channels)
        self.update_linear = nn.Linear(message_hidden_channels, node_channels) # TODO: consider changing dimensions

    def forward(self, edge_index, x, edge_attr=None):
        # Perform message passing for both types of nodes simultaneously
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr=None):
        # x_i has shape [E, central_node_dim]
        # x_j has shape [E, neighbor_node_dim]
        # edge_attr has shape [E, slot_dim]

        # Apply linear transformation for edge messages
        if edge_attr is not None:
            tmp = torch.cat([x_j,edge_attr],dim=1)
        else: tmp = x_j
        return self.neighbor_linear(tmp)
    
    def update(self, aggr_out):
        # Apply linear transformation for node messages
        return self.update_linear(aggr_out)
    
class EdgeMessagePassing(nn.Module):
    def __init__(self, edge_channels,
                       node1_channels,
                       node2_channels,
                       hidden_channels):
        super().__init__()

        self.mlp = nn.Sequential(
             nn.Linear(edge_channels + node1_channels + node2_channels, hidden_channels),
             nn.ReLU(),
             nn.Linear(hidden_channels, edge_channels) # outputs same size for iterated passing
        )

    def forward(self, edge_index, edge_attr, node1, node2):
        # No need for aggr, update since num nodes for each edge is constant at 2
        row, col  = edge_index
        new_edge_attr = self.mlp(torch.cat([node1[row], node2[col], edge_attr], dim=-1))
        return new_edge_attr

In [236]:
model = GNN()

output = model(data)

Original --> Embedding --> Message passing
node1:  torch.Size([50, 120])  -->  torch.Size([50, 32])  -->  torch.Size([50, 32])
node2:  torch.Size([20, 120])  -->  torch.Size([20, 32])  -->  torch.Size([20, 32])
edge:  torch.Size([2560, 16])  -->  torch.Size([2560, 8])  -->  torch.Size([2560, 8])
torch.Size([2560, 1])
torch.Size([50, 32]) tensor([ 0,  0,  1,  ..., 48, 48, 49])
torch.Size([2560])
torch.Size([2560, 1])
