In [1]:
import pandas as pd
import numpy as np
import random
import tqdm
import gc
gc.enable()

%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import download_url, extract_zip, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.nn.conv import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.utils import add_self_loops

### Generate dummy data

In [2]:
fname_pool = ['Amy','Brandon','Carli','Dante','Eleanor','Frank']
lname_pool = ['Adams','Boxer','Charles','Darwin','Egan','Fallon']
heist_pool = ['Vermeer','Rembrandt','Banksy','Faberge Egg']

schedule_size = 100
max_heist_time = 5
num_heist = 5

qual_min = 0
qual_max = 3
curr_min = 0
curr_max = 3

n_slots_min = 2
n_slots_max = 8

id2qual = {
    0: 'Safe Cracker',
    1: 'Getaway Driver',
    2: 'Mastermind',
    3: 'Cat Burglar',
    4: 'Smooth Operator',
    5: 'Hacker'
}

id2currency = {
    0: 'Risk',
    1: 'Workload',
    2: 'Payoff',
    3: 'Artistic Enjoyment',
}

def generate_schedule(num_heists):
    sched = np.zeros((schedule_size,))
    if num_heists == 0:
        return sched
        
    intvl = schedule_size // num_heists # prevents us from squashing all heists at end
    max_idx = intvl
    min_idx = 0
    for i in range(num_heists):
        start_idx = random.randint(min_idx, max_idx)
        end_idx = random.randint(start_idx, max_idx)
        sched[start_idx:end_idx] = 1

        # Update intvl
        max_idx += intvl
        min_idx = end_idx
    
    return sched

In [3]:
thief_size = 110
heist_size = 5
slot_size = 11

In [4]:
class Thief():
    def __init__(self,
                Id: int):
        self.id = Id
        self.name = f"{random.choice(fname_pool)} {random.choice(lname_pool)}"
        self.schedule = generate_schedule(random.randint(0,num_heists))
        self.qualifications = np.random.randint(qual_min, qual_max, size=len(id2qual))
        currency_mask = np.random.randint(0,1,size=len(id2currency))
        self.currencies = np.random.uniform(low=curr_min, high=curr_max, size=len(id2currency))
        self.currencies *= currency_mask

    def get_data(self):
        # Return tensor of all data
        data = torch.zeros((thief_size + heist_size,))
        data[:thief_size] = torch.from_numpy(np.concatenate([self.schedule, self.qualifications, self.currencies])).to(torch.float)
        return data

In [5]:
class Heist():
    def __init__(self,
                Id: int):
        self.id = Id
        self.name = random.choice(heist_pool)
        self.start_time = random.randint(0, schedule_size)
        self.end_time = random.randint(self.start_time, self.start_time + max_heist_time)
        self.crew = {}
        self.n_slots = random.randint(n_slots_min, n_slots_max) # num spots available  
        self.n_slots_left = self.n_slots                        # num spots left
        self.n_slots_required = random.randint(0,self.n_slots)  # num spots left that are required
    
    def get_data(self):
                # Return tensor of all data
        data = torch.zeros((thief_size + heist_size,))
        data[thief_size:] = torch.from_numpy(np.array([
            self.start_time, self.end_time, self.n_slots, self.n_slots_left, self.n_slots_required
        ])).to(torch.float)
        return data

In [6]:
class Slot():
    def __init__(self,
                Id,
                thief_id,
                heist_id):
        self.id = Id
        self.thief_id = thief_id
        self.heist_id = heist_id
        self.required = np.expand_dims(np.array(random.randint(0,1)), axis=0)
        self.qualifications = np.random.randint(qual_min, qual_max, size=len(id2qual))
        currency_mask = np.random.randint(0,1,size=len(id2currency))
        self.currencies = np.random.uniform(low=curr_min, high=curr_max, size=len(id2currency))
        self.currencies *= currency_mask

    def get_data(self):
        return torch.from_numpy(np.concatenate([
            self.required, self.qualifications, self.currencies
        ]))        

In [7]:
# Generate node dfs
num_heists = 20
num_thieves = 50

heists_df = pd.DataFrame()
for i in range(num_heists):
    heist_data = Heist(i).get_data()
    tmp_df = pd.DataFrame(heist_data.numpy()).T
    tmp_df.index = [i]
    # tmp_df.columns = ['start_time','end_time','n_slots','n_slots_left','n_slots_required']
    heists_df = pd.concat([heists_df, tmp_df])
heists_df.index.rename('heistId', inplace=True)

thieves_df = pd.DataFrame()
for i in range(num_thieves):
    thief_data = Thief(i).get_data()
    tmp_df = pd.DataFrame(thief_data.numpy()).T
    tmp_df.index = [i]
    # tmp_df.columns = ['start_time','end_time','n_slots','n_slots_left','n_slots_required']
    thieves_df = pd.concat([thieves_df, tmp_df])
thieves_df.index.rename('thiefId', inplace=True)

In [8]:
# Generate slots (future edge_attr) and edge index (future edge_index)
slots_df = pd.DataFrame()
h_lst = []  # list of heist endpoints for edges
t_lst = []  # list of thief endpoints for edges

s_idx = 0
for h_idx, h in heists_df.iterrows():
    for t_idx, t in thieves_df.iterrows():
        for i in range(int(h[112])):            # 112 contains n_slots data
            # TODO: currently assign randomly
            if random.uniform(0,1) < 0.5:
                slot_data = Slot(s_idx, t_idx, h_idx).get_data()
                tmp_df = pd.DataFrame(slot_data.numpy()).T
                tmp_df.index = [s_idx]
                slots_df = pd.concat([slots_df, tmp_df])
                s_idx += 1

                # Add endpoints
                h_lst.append(h_idx)
                t_lst.append(t_idx)
slots_df.index.rename('slotId', inplace=True)

# Create edge index
edge_index_thief_to_heist = torch.stack([torch.tensor(t_lst), torch.tensor(h_lst)], dim=0)

In [9]:
data = HeteroData()

# Add node indices
data['thief'].node_id = torch.tensor(thieves_df.index)
data['heist'].node_id = torch.tensor(heists_df.index)

# Add node features
data["thief"].x = torch.tensor(thieves_df.values).to(torch.float)
data["heist"].x = torch.tensor(heists_df.values).to(torch.float)

# Add edge indices
data["thief", "slot", "heist"].edge_index = edge_index_thief_to_heist # has shape (2, num_edges)

# Add edge features
data["thief", "slot", "heist"].edge_attr = torch.tensor(slots_df.values).to(torch.float)

# Add reverse edge
data = T.ToUndirected()(data)

In [None]:

src, dst = edge_index
score = (x[src] * x[dst]).sum(dim=-1)

row, col = edge_index
new_edge_attr = self.mlp(torch.cat([x[row], x[col], edge_attr], dim=-1))

In [None]:
"""
Slots and people as nodes
Edge is the assignment decision
Softly represent fullness of a flight through an embedding
- duplicate event doesn't matter
- hard constraints can be solved through environment itself. don't let environment until all seats are filled
- loosely using the term positional embedding. dot product between seats on same flight = 1
- edge between events? 
- temporal embedding: Hope that implicitly it learns it!!!
Recommendation: dot product of describing seat + dot product of positional
- use sinusoidal function for embedding vector???

Nick libertini: two different images, matching them using key points
"""

In [12]:
# Message passing sizes
in_node1_channels = 115
in_node2_channels = 115
in_edge_channels = 11

hidden_channels = 64
out_channels = 1

class GNN(nn.Module):
    def __init__(self, in_node1_channels=in_node1_channels,
                       in_node2_channels=in_node2_channels,
                       in_edge_channels =in_edge_channels,
                       hidden_channels  =hidden_channels,
                       out_channels     =out_channels):
        super(GNN, self).__init__()

        # TODO: add encoding
        # TODO: add edge updates
        # TODO: add edge regression

        # Define message passing layers for each type of node and edge
        self.node1_message_passing = CustomMessagePassing(in_node1_channels, in_edge_channels, hidden_channels)
        self.node2_message_passing = CustomMessagePassing(in_node2_channels, in_edge_channels, hidden_channels)

        # Define final linear transformation layer
        self.linear = nn.Linear(hidden_channels, out_channels) # TODO: hidden_channels * 2?

    def forward(self, data):

        # Extract node features and edge attributes
        node1_x = data['thief'].x
        node2_x = data['heist'].x
        edge_attr = data['thief', 'slot', 'heist'].edge_attr
        edge_index = data['thief', 'slot', 'heist'].edge_index
        edge_rev_index = data['heist', 'rev_slot', 'thief'].edge_index


        # Perform message passing for both node types
        node1_messages = self.node1_message_passing(node1_x, edge_attr, edge_index)
        node2_messages = self.node2_message_passing(node2_x, edge_attr, edge_rev_index)
        # print(node1_messages.shape) # 20, 64
        # print(node2_messages.shape) # 50, 64

        combined_messages = torch.cat([node1_messages, node2_messages], dim=0)
        print(combined_messages.shape)

        # Apply final linear transformation
        out = self.linear(combined_messages)

        return out

class CustomMessagePassing(MessagePassing):
    def __init__(self, node_channels,
                       edge_channels,
                       out_channels):
        super(CustomMessagePassing, self).__init__(aggr='mean')

        # Define linear transformations for message passing
        # self.node_linear = nn.Linear(in_channels, out_channels)
        self.neighbor_linear = nn.Linear(node_channels + edge_channels, out_channels)
        # self.edge_linear = nn.Linear(edge_channels + 2*in_channels, out_channels)
        self.update_linear = nn.Linear(out_channels, out_channels)

    def forward(self, x, edge_attr, edge_index):
        # Perform message passing for both types of nodes simultaneously
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_i has shape [E, thief_dim]
        # x_j has shape [E, heist_dim]
        # edge_attr has shape [E, slot_dim]

        # Apply linear transformation for edge messages
        tmp = torch.cat([x_j,edge_attr],dim=1)
        return self.neighbor_linear(tmp)

    
    def aggregate(self, inputs, index, ptr = None, dim_size = None):
        # inputs has shape [E, hidden]
        # index has shape [E]
        # NOTE: output has shape [N1, H] where N1 = num of nodes of opposite type
        return self.aggr_module(inputs, index, ptr=ptr, dim_size=None, # TODO: had to set dim_size to None
                                dim=self.node_dim)

    def update(self, aggr_out):
        # Apply linear transformation for node messages
        return self.update_linear(aggr_out)


# NOTE 1: use the reverse edges
# NOTE 2: need to set dim_size in agg to None

In [13]:
# Instantiate the GNN model
model = GNN(in_node1_channels, in_node2_channels, in_edge_channels, hidden_channels, out_channels)

# Forward pass with the data
output = model(data)

torch.Size([70, 64])


In [22]:
# Encoder sizes
encoder_node1_channels = [115, 128, 64]
encoder_node2_channels = [115, 128, 64]
encoder_edge_channels = [11, 16, 8]

# Message passing sizes
message_hidden_channels = 64
out_channels = 1

class GNN(nn.Module):
    def __init__(self, encoder_node1_channels: list = encoder_node1_channels,
                       encoder_node2_channels: list = encoder_node2_channels,
                       encoder_edge_channels : list = encoder_edge_channels,
                       hidden_channels  =hidden_channels,
                       out_channels     =out_channels):
        super(GNN, self).__init__()

        # TODO: add edge updates
        # TODO: add edge regression

        node1_in, node1_h, node1_out = encoder_node1_channels
        node2_in, node2_h, node2_out = encoder_node2_channels
        edge_in, edge_h, edge_out = encoder_edge_channels # NOTE: the output of the embedding is the input size to the message passing
        self.node1_encoder = Encoder(node1_in, node1_h, node1_out)
        self.node2_encoder = Encoder(node2_in, node2_h, node2_out)
        self.edge_encoder  = Encoder(edge_in, edge_h, edge_out)

        # Define message passing layers for each type of node and edge
        self.node1_message_passing = CustomMessagePassing(node1_out, edge_out, hidden_channels)
        self.node2_message_passing = CustomMessagePassing(node2_out, edge_out, hidden_channels)

        # Define final linear transformation layer
        self.linear = nn.Linear(hidden_channels, out_channels) # TODO: hidden_channels * 2?

    def forward(self, data):

        # Extract node features and edge attributes
        node1_x = data['thief'].x
        node2_x = data['heist'].x
        edge_attr = data['thief', 'slot', 'heist'].edge_attr
        edge_index = data['thief', 'slot', 'heist'].edge_index
        edge_rev_index = data['heist', 'rev_slot', 'thief'].edge_index

        # Embed node and edge features
        node1_x = self.node1_encoder(node1_x)
        node2_x = self.node2_encoder(node2_x)
        edge_attr = self.edge_encoder(edge_attr)
        # print(node1_x.shape) # 50, 64
        # print(node2_x.shape) # 20, 64
        # print(edge_attr.shape) # E, 8

        # Perform message passing for both node types
        node1_messages = self.node1_message_passing(node1_x, edge_attr, edge_index)
        node2_messages = self.node2_message_passing(node2_x, edge_attr, edge_rev_index)
        # print(node1_messages.shape) # 20, 64
        # print(node2_messages.shape) # 50, 64
        
        # Perform message passing for edge

        combined_messages = torch.cat([node1_messages, node2_messages], dim=0)

        # Apply final linear transformation
        out = self.linear(combined_messages)

        return out

class Encoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def forward(self, x):
        return self.mlp(x)

class CustomMessagePassing(MessagePassing):
    def __init__(self, node_channels,
                       edge_channels,
                       out_channels):
        super(CustomMessagePassing, self).__init__(aggr='mean')

        # Define linear transformations for message passing
        # self.node_linear = nn.Linear(in_channels, out_channels)
        self.neighbor_linear = nn.Linear(node_channels + edge_channels, out_channels)
        # self.edge_linear = nn.Linear(edge_channels + 2*in_channels, out_channels)
        self.update_linear = nn.Linear(out_channels, out_channels)

    def forward(self, x, edge_attr, edge_index):
        # Perform message passing for both types of nodes simultaneously
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_i has shape [E, thief_dim]
        # x_j has shape [E, heist_dim]
        # edge_attr has shape [E, slot_dim]

        # Apply linear transformation for edge messages
        tmp = torch.cat([x_j,edge_attr],dim=1)
        return self.neighbor_linear(tmp)

        # TODO: will this have 0 neighbors, since I only pass in one node type at a time?
    
    def aggregate(self, inputs, index, ptr = None, dim_size = None):
        # inputs has shape [E, hidden]
        # index has shape [E]
        # NOTE: output has shape [N1, H] where N1 = num of nodes of opposite type
        return self.aggr_module(inputs, index, ptr=ptr, dim_size=None, # TODO: had to set dim_size to None
                                dim=self.node_dim)

    def update(self, aggr_out):
        # Apply linear transformation for node messages
        return self.update_linear(aggr_out)


# NOTE 1: use the reverse edges
# NOTE 2: need to set dim_size in agg to None

In [23]:
# Instantiate the GNN model
model = GNN()

# Forward pass with the data
output = model(data)

torch.Size([1934, 64])
torch.Size([1934, 64])


In [49]:
# Encoder sizes
encoder_node1_channels = [115, 128, 64]
encoder_node2_channels = [115, 128, 64]
encoder_edge_channels = [11, 16, 8]

# Message passing sizes
message_hidden_channels = 64
out_channels = 1

class GNN(nn.Module):
    def __init__(self, encoder_node1_channels: list = encoder_node1_channels,
                       encoder_node2_channels: list = encoder_node2_channels,
                       encoder_edge_channels : list = encoder_edge_channels,
                       hidden_channels  =hidden_channels,
                       out_channels     =out_channels):
        super(GNN, self).__init__()

        node1_in, node1_h, node1_out = encoder_node1_channels
        node2_in, node2_h, node2_out = encoder_node2_channels
        edge_in, edge_h, edge_out = encoder_edge_channels # NOTE: the output of the embedding is the input size to the message passing
        self.node1_encoder = Encoder(node1_in, node1_h, node1_out)
        self.node2_encoder = Encoder(node2_in, node2_h, node2_out)
        self.edge_encoder  = Encoder(edge_in, edge_h, edge_out)

        # Define message passing layers for each type of node and edge
        self.node1_message_passing = CustomMessagePassing(node1_out, edge_out, hidden_channels)
        self.node2_message_passing = CustomMessagePassing(node2_out, edge_out, hidden_channels)
        self.edge_message_passing  = EdgeMessagePassing(edge_out, hidden_channels, hidden_channels, 2*hidden_channels, hidden_channels)

        # Define final linear transformation layer
        self.linear = nn.Linear(hidden_channels, out_channels) # TODO: hidden_channels * 2?

    def forward(self, data):

        # Extract node features and edge attributes
        node1_x = data['thief'].x
        node2_x = data['heist'].x
        edge_attr = data['thief', 'slot', 'heist'].edge_attr
        edge_index = data['thief', 'slot', 'heist'].edge_index
        edge_rev_index = data['heist', 'rev_slot', 'thief'].edge_index

        # Embed node and edge features
        node1_x = self.node1_encoder(node1_x)
        node2_x = self.node2_encoder(node2_x)
        edge_x = self.edge_encoder(edge_attr)
        # print(node1_x.shape) # 50, 64
        # print(node2_x.shape) # 20, 64
        # print(edge_attr.shape) # E, 8

        # Perform message passing for both node types
        node1_messages = self.node1_message_passing(node1_x, edge_x, edge_index)
        node2_messages = self.node2_message_passing(node2_x, edge_x, edge_rev_index)
        # print(node1_messages.shape) # 20, 64
        # print(node2_messages.shape) # 50, 64
        
        # Perform message passing for edge
        edge_messages = self.edge_message_passing(edge_x, node2_messages, node1_messages, edge_index)
        # Reverse node2 and node1 because we're using messages, not direct embeddings
        assert edge_messages.shape == (edge_x.shape[0], hidden_channels)
        # TODO: is 64 too big for edge message?

        # combined_messages = torch.cat([node1_messages, node2_messages], dim=0)
        # print(combined_messages.shape)

        # Apply final linear transformation
        out = self.linear(edge_messages)

        return out

class Encoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    
    def forward(self, x):
        return self.mlp(x)

class CustomMessagePassing(MessagePassing):
    def __init__(self, node_channels,
                       edge_channels,
                       out_channels):
        super(CustomMessagePassing, self).__init__(aggr='mean')

        # Define linear transformations for message passing
        self.neighbor_linear = nn.Linear(node_channels + edge_channels, out_channels)
        self.update_linear = nn.Linear(out_channels, out_channels)

    def forward(self, x, edge_attr, edge_index):
        # Perform message passing for both types of nodes simultaneously
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_i has shape [E, thief_dim]
        # x_j has shape [E, heist_dim]
        # edge_attr has shape [E, slot_dim]

        # Apply linear transformation for edge messages
        tmp = torch.cat([x_j,edge_attr],dim=1)
        return self.neighbor_linear(tmp)

    
    def aggregate(self, inputs, index, ptr = None, dim_size = None):
        # inputs has shape [E, hidden]
        # index has shape [E]
        # NOTE: output has shape [N1, H] where N1 = num of nodes of opposite type
        return self.aggr_module(inputs, index, ptr=ptr, dim_size=None, # TODO: had to set dim_size to None
                                dim=self.node_dim)

    def update(self, aggr_out):
        # Apply linear transformation for node messages
        return self.update_linear(aggr_out)

class EdgeMessagePassing(nn.Module):
    def __init__(self, edge_channels,
                       node1_channels,
                       node2_channels,
                       hidden_channels, # input as 2 * hidden_channels above
                       out_channels):   # input as hidden_channels above
        super().__init__()

        self.mlp = nn.Sequential(
             nn.Linear(edge_channels + node1_channels + node2_channels, hidden_channels),
             nn.ReLU(),
             nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, edge_attr, node1, node2, edge_index):
        # No need for aggr, update since num nodes for each edge is constant at 2
        row, col  = edge_index
        new_edge_attr = self.mlp(torch.cat([node1[row], node2[col], edge_attr], dim=-1))
        return new_edge_attr

# NOTE 1: use the reverse edges
# NOTE 2: need to set dim_size in agg to None

In [50]:
# Instantiate the GNN model
model = GNN()

# Forward pass with the data
output = model(data)

In [54]:
total_params = sum(p.numel() for p in model.parameters())
total_params

90057

### Training