In [2]:
import os
import pickle
import networkx as nx
import numpy as np
from typing import Tuple,List

import torch
from torch_geometric.utils import dense_to_sparse,unbatch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.distributions import Bernoulli

In [3]:
from conditional_rate_matching.data.graph_dataloaders_config import CommunitySmallGConfig
from conditional_rate_matching.data.graph_dataloaders_config import GraphDataloaderGeometricConfig

In [4]:
def read_graph_lists(graph_data_config:GraphDataloaderGeometricConfig)->Tuple[List[nx.Graph], List[nx.Graph]]:
    """
    Load a pickled list of networkx graphs and split it into train and test lists.

    parameters
    ----------
        graph_data_config: config providing ``data_dir``, ``dataset_name`` and
            ``test_split``; graphs are read from ``<data_dir>/<dataset_name>.pkl``.

    return
    ------
        Tuple[List[nx.Graph], List[nx.Graph]]: train_graph_list, test_graph_list.
        The first ``int(test_split * len(graph_list))`` graphs form the test list
        and the remaining ones the train list (no shuffling is done here).
    """
    file_path = os.path.join(graph_data_config.data_dir, graph_data_config.dataset_name)
    # NOTE: pickle.load can execute arbitrary code -- only load trusted dataset files.
    with open(file_path + '.pkl', 'rb') as f:
        graph_list = pickle.load(f)
    test_size = int(graph_data_config.test_split * len(graph_list))
    train_graph_list, test_graph_list = graph_list[test_size:], graph_list[:test_size]
    return train_graph_list, test_graph_list

def _graphs_to_data(graph_list:List[nx.Graph], max_num_nodes:int)->List[Data]:
    """
    Convert networkx graphs into torch_geometric Data objects with max_num_nodes nodes each.

    Every graph gets identity (one-hot) node features of size max_num_nodes, so a graph
    with fewer nodes carries extra isolated nodes after its real ones.
    """
    data_list = []
    for graph in graph_list:
        adj = torch.Tensor(nx.to_numpy_array(graph))
        # zero-padding the adjacency up to max_num_nodes would add no edges, so the
        # edge list is taken directly from the unpadded matrix
        edge_index = dense_to_sparse(adj)[0]
        nodes_attributes = torch.eye(max_num_nodes)
        data_list.append(Data(x=nodes_attributes, edge_index=edge_index))
    return data_list


def create_geometric_dataset(train_graph_list:List[nx.Graph],test_graph_list:List[nx.Graph])->Tuple[List[Data], List[Data], int]:
    """
    Take lists of networkx graphs and create torch_geometric datasets.

    All graphs (train and test) are given the same number of nodes, equal to the
    largest graph over both splits, with identity node features of that size.

    parameters
    ----------
        train_graph_list, test_graph_list: lists of networkx graphs

    return
    ------
        train_dataset: List[Data]
        test_dataset: List[Data]
        num_node_features: int, the common node-feature size (= max number of nodes)
    """
    node_counts = [graph.number_of_nodes() for graph in [*train_graph_list, *test_graph_list]]
    # default=0 keeps an empty split (e.g. test_split == 0) from raising ValueError
    max_num_nodes = max(node_counts, default=0)
    num_node_features = max_num_nodes

    train_data = _graphs_to_data(train_graph_list, max_num_nodes)
    test_data = _graphs_to_data(test_graph_list, max_num_nodes)
    return train_data, test_data, num_node_features

class GraphGeometricDataloader:
    """
    Build train/test torch_geometric DataLoaders from the pickled graph dataset described by a config.

    The node-feature size found in the data is kept as ``num_node_features`` and is
    also written back into ``config.dimensions``.
    """

    def __init__(self,config:GraphDataloaderGeometricConfig):
        self.config = config
        graph_lists = read_graph_lists(config)
        train_dataset, test_dataset, feature_size = create_geometric_dataset(*graph_lists)

        self.num_node_features = feature_size
        # side effect: the caller's config object is updated with the data dimension
        self.config.dimensions = feature_size

        loader_batch_size = config.batch_size
        self.train_dataloader = DataLoader(train_dataset, batch_size=loader_batch_size)
        self.test_dataloader = DataLoader(test_dataset, batch_size=loader_batch_size)

    def train(self):
        """Return the training DataLoader (no shuffle argument is passed when building it)."""
        return self.train_dataloader

    def test(self):
        """Return the test DataLoader."""
        return self.test_dataloader

In [44]:
from torch import nn
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


    
class GCN(torch.nn.Module):
    """
    Three-layer GCN that returns one mean-pooled embedding per graph.

    parameters
    ----------
        num_node_features: size of the input node features
        hidden_channels: width of every GCNConv layer and of the returned embedding

    Note: the constructor calls ``torch.manual_seed(12345)``, which resets the global
    torch RNG every time a model is built.
    """
    def __init__(self, num_node_features,hidden_channels):
        super().__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, batch):
        """
        parameters
        ----------
            x: node features
            edge_index: edge list passed to every GCNConv layer
            batch: graph index of every node, used for mean pooling

        return
        ------
            graph embeddings of shape [batch_size, hidden_channels]
        """
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Dropout on the graph embeddings (no classifier head is applied here)
        x = F.dropout(x, p=0.5, training=self.training)
        return x


from torch_geometric.utils import dense_to_sparse,unbatch
from torch_geometric.data import Data

def sample_to_geometric(X,number_of_nodes=20):
    """
    Obtain a representation suitable for GNNs defined with the torch_geometric library.

    parameters
    ----------
        X: batch of flattened adjacency matrices, reshapeable to
            (batch_size, number_of_nodes, number_of_nodes); non-zero entries are edges
        number_of_nodes: number of nodes in every graph of the batch

    return
    ------
        nodes_attributes: (batch_size * number_of_nodes, number_of_nodes) identity features
            stacked once per graph
        edge_index: edge list from dense_to_sparse on the batched adjacency
            (assumes PyG's batched dense_to_sparse offsets node ids per graph -- confirm
            for the installed version)
        batch: (batch_size * number_of_nodes,) graph index of every node
    """
    batch_size = X.shape[0]
    # reshape directly instead of the legacy torch.Tensor(...) constructor, which
    # rejects CUDA and non-float32 tensors
    adj = X.reshape(batch_size, number_of_nodes, number_of_nodes)
    edge_index = dense_to_sparse(adj)[0]
    # create node features and the batch vector on X's device so GPU inputs work
    nodes_attributes = torch.eye(number_of_nodes, device=X.device).repeat((batch_size, 1))
    batch = torch.arange(batch_size, device=X.device).repeat_interleave(number_of_nodes)
    return nodes_attributes, edge_index, batch

class EmbedFC(nn.Module):
    """
    Small MLP embedding: Linear(input_dim, emb_dim) -> GELU -> Linear(emb_dim, emb_dim).

    The input is flattened to (-1, input_dim) first, so e.g. a 1-D tensor of batch_size
    scalars with input_dim=1 becomes an output of shape (batch_size, emb_dim).
    """
    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs such as transposes are accepted
        x = x.reshape(-1, self.input_dim)
        return self.model(x)
    
    
class SimpleTemporalGCN(torch.nn.Module):
    """
    Time-conditioned GCN that produces one output per ordered node pair of each graph.

    A batch of flattened adjacency matrices is converted with sample_to_geometric, node
    embeddings come from three GCNConv layers, and each pair (i, j) is encoded from the
    concatenated embeddings of nodes i and j plus a time embedding.

    parameters
    ----------
        num_nodes: number of nodes N in every graph of the batch
        num_node_features: input size of the first GCNConv; sample_to_geometric builds
            identity node features of size N, so this is expected to equal num_nodes
        hidden_channels: width D of the GCN layers and of the pair encoding
        time_dimension: output size of the time embedding

    Note: the constructor calls ``torch.manual_seed(12345)``, which resets the global torch RNG.
    """
    def __init__(self,num_nodes,num_node_features,hidden_channels,time_dimension=19):
        # plain super(): super(GCN, self) raises TypeError since this class does not subclass GCN
        super().__init__()
        torch.manual_seed(12345)

        self.number_of_nodes = num_nodes
        self.num_node_features = num_node_features

        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

        self.timeembed1 = EmbedFC(1, time_dimension)

        self.edge_encoding_0 = Linear(2*hidden_channels,hidden_channels)
        self.edge_encoding = Linear(hidden_channels+time_dimension,1)

    def forward(self, X, time):
        """
        parameters
        ----------
            X: batch of flattened adjacency matrices, reshaped to (batch_size, N, N)
            time: one scalar per graph (batch_size values in total); EmbedFC flattens it
                to (batch_size, 1)

        return
        ------
            tensor of shape (batch_size, N, N, 1)
        """
        x, edge_index, _ = sample_to_geometric(X, number_of_nodes=self.number_of_nodes)
        time_emb = self.timeembed1(time)  # (batch_size, time_dimension)

        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)

        # 2. Create outer concatenation for the edge encoding.
        # sample_to_geometric lays out exactly N nodes per graph contiguously, so
        # un-batching is a plain reshape.
        N = self.number_of_nodes
        x = x.reshape(-1, N, x.size(-1))  # (batch_size, N, D)

        x_i = x.unsqueeze(2)  # (batch_size, N, 1, D)
        x_j = x.unsqueeze(1)  # (batch_size, 1, N, D)

        # Concatenate the expanded tensors along the last dimension
        x = torch.cat((x_i.expand(-1, -1, N, -1), x_j.expand(-1, N, -1, -1)), dim=-1)  # (batch_size, N, N, 2*D)
        x = self.edge_encoding_0(x)  # (batch_size, N, N, D)

        # Broadcast the time embedding over every node pair
        time_emb = time_emb.unsqueeze(1).unsqueeze(2).expand(-1, N, N, -1)  # (batch_size, N, N, time_dimension)

        x = torch.cat((x, time_emb), dim=-1)  # (batch_size, N, N, D + time_dimension)
        x = self.edge_encoding(x)  # (batch_size, N, N, 1)

        return x

In [45]:
batch_size = 3
hidden_channels = 64
time_encoding = 9
number_of_nodes = 20

graph_config = CommunitySmallGConfig(batch_size=batch_size)
graph_dataloader = GraphGeometricDataloader(graph_config)
# sample_to_geometric feeds identity features of size number_of_nodes to the model
assert graph_dataloader.num_node_features == number_of_nodes

# The (X, time) calls below need the time-conditioned model. GCN only accepts
# (num_node_features, hidden_channels), so GCN(20, ..., hidden_channels=...) raised TypeError.
model = SimpleTemporalGCN(
    number_of_nodes,
    graph_dataloader.num_node_features,
    hidden_channels=hidden_channels,
    time_dimension=time_encoding,
)

batch = next(iter(graph_dataloader.train()))
batch

DataBatch(x=[60, 20], edge_index=[2, 146], batch=[60], ptr=[4])

In [51]:
batch_size = 4
# one independent edge probability per entry of the flattened adjacency, shared by the batch
X = Bernoulli(probs=torch.rand(number_of_nodes ** 2)).sample(sample_shape=(batch_size,))
# one scalar time per sampled graph
time = torch.rand(batch_size)

In [52]:
edge_output = model(X, time)
# one output per ordered node pair of every sampled graph
assert edge_output.shape == (batch_size, number_of_nodes, number_of_nodes, 1)
edge_output.shape

torch.Size([4, 20, 20, 1])

torch.Size([40, 20])

In [6]:
# pass the node count explicitly rather than relying on sample_to_geometric's default of 20
nodes_attributes, edge_index, batch = sample_to_geometric(X, number_of_nodes=number_of_nodes)
assert nodes_attributes.shape == (batch_size * number_of_nodes, number_of_nodes)
assert batch.shape == (batch_size * number_of_nodes,)

NameError: name 'X' is not defined

In [88]:
# NOTE(review): `output` is not defined anywhere in this notebook (leftover kernel state
# from a deleted cell), so this cell fails on Restart & Run All -- recompute or remove it.
model.edge_encoding(output).shape

torch.Size([2, 20, 20, 1])

In [74]:
# NOTE(review): `output` is not defined anywhere in this notebook (leftover kernel state),
# and this call passes no `batch` vector to `unbatch` -- check its signature; fix or remove.
unbatch(output,)

torch.Size([2, 20, 1, 64])

In [58]:
# Batched pairwise inner products: outer_product[b, i, j] = sum_k result[b, i, k] * result[b, j, k]
# (a Gram matrix per batch element, despite the variable name).
# NOTE(review): `result` is not defined anywhere in this notebook (leftover kernel state),
# so this cell fails on Restart & Run All.
outer_product = torch.einsum("bik,bjk->bij",result,result)

torch.Size([2, 20, 20])