In [1]:
# Implementing the same thing, but with a list

In [2]:
import torch
import torch.nn as nn
from torchinfo import summary

import torch_geometric
from torch_geometric.data import Data, Dataset, DataLoader

from scipy.spatial.distance import cdist
import networkx as nx

import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl

from dataset_list import *
from dataset import *

In [3]:
# Dataset: build the training set (train=True) and derive the train / validation loaders from it.
train_set = CustomizableMNIST(root='./data', train=True, download=True)

# presumably the fraction of train_set held out for validation — confirm in split_and_shuffle_data
val_set_ratio = 0.2
# NOTE(review): `shuffle` is not passed to anything in this cell — confirm whether it is still needed
shuffle = True
batch_size = 32

# The helper returns the two loaders in (train, validation) order, as unpacked here.
train_loader, valid_loader = split_and_shuffle_data(train_set, val_set_ratio, batch_size)

Initializing CustomizableMNIST...
Training set
Init done.



In [4]:
# Load the pre-computed graph collection from disk and wrap it for batching.
filename = 'graph_collection_dataset.pkl'
# One-off generation step, kept for reference (uncomment to rebuild the file):
#graph_collection = compute_graph_collection_and_save_pickle_file(train_set, filename, ratio=0.01)
# NOTE(review): the helper name indicates that `filename` is unpickled; unpickling can execute
# arbitrary code, so only load files generated locally — confirm the file's origin.
graph_collection_dataset = load_graph_collection_from_pickle_file(filename) 
graph_train_dataset = GraphDataset(graph_collection_dataset)
# NOTE(review): the loader is built from the raw collection, not from graph_train_dataset —
# confirm this is intended.
dataloader = GraphDataLoader(graph_collection_dataset, batch_size=32)



In [40]:
class GCN_model(nn.Module):
    """Single graph-convolution layer for a batch of graphs sharing one topology.

    The forward pass computes ``fc(D^-1/2 * A * D^-1/2 * H)`` where ``H`` holds the
    node features, ``A`` is the adjacency described by the edge index of the FIRST
    graph of the batch (all graphs are assumed to share that structure) and ``D``
    is the degree vector derived from the same edge index.
    """

    def __init__(self, num_layer, num_in_features=1, num_out_features=1):
        """
            num_layer: currently unused; kept so existing callers keep working
            num_in_features: size of the last dimension of the input node features
            num_out_features: size of the last dimension of the output node features
        """
        super(GCN_model, self).__init__()
        self.fc = nn.Linear(in_features=num_in_features, out_features=num_out_features)

    def forward(self, list_of_graph):
        """
            list_of_graph: sequence of exactly three items
                (nodes_feat_list, edges_index_list, graph_label_list) with
                nodes_feat_list: tensor of shape (batch_size, num_nodes, num_features)
                edges_index_list: tensor of shape (batch_size, 2, num_edges)
                graph_label_list: unpacked but not used by this layer

            Returns a tensor of shape (batch_size, num_nodes, num_out_features).
        """
        # unpack data (the labels are not needed to propagate node features)
        nodes_feat_list, edges_index_list, _graph_label_list = list_of_graph

        # every graph shares the structure of the first one, so one edge index is enough
        shared_edges_index = edges_index_list[0, :, :]

        # D ** (-1 / 2), shape (num_nodes,), expanded to (num_nodes, 1) so that it
        # broadcasts over the batch and feature dimensions
        inv_sqrt_degree = self.compute_degree(
            edges_index_list, num_of_nodes=nodes_feat_list.shape[1]
        ).unsqueeze(-1)

        # first Hadamard product: h_l * D ** (-1 / 2)
        scaled_features = nodes_feat_list * inv_sqrt_degree

        # aggregation: sum the scaled features over each node's neighborhood
        aggregated_features = self.aggregate_neighbors(scaled_features, shared_edges_index)

        # second Hadamard product
        normalized_features = aggregated_features * inv_sqrt_degree

        # learnable linear transform of the aggregated features
        return self.fc(normalized_features)

    def aggregate_neighbors(self, nodes_feature_tens, edges_index):
        """
            Sum, for every node i, the features of all nodes j such that (i, j) is an edge.

            nodes_feature_tens: tensor of shape (batch_size, num_nodes, num_features)
                (a 2-D (batch_size, num_nodes) tensor is accepted as well)
            edges_index: integer tensor of shape (2, num_edges); row 0 holds the target
                node i of each edge, row 1 the neighbor j whose features are added

            Returns a tensor with the same shape and dtype as nodes_feature_tens.
        """
        # features of the neighbor at the end of every edge:
        # (batch_size, num_nodes, ...) --> (batch_size, num_edges, ...)
        neighbor_features = nodes_feature_tens[:, edges_index[1, :]]

        # Accumulate along the node dimension. index_add_ applies the same 1-D index to
        # every batch element and every feature, so the whole batch is aggregated; a
        # scatter_add_ with a (1, num_edges) index would only fill the first batch row.
        aggregated = torch.zeros_like(nodes_feature_tens)
        aggregated.index_add_(1, edges_index[0, :], neighbor_features)

        return aggregated

    def compute_degree(self, edges_index, num_of_nodes):
        """
            Compute D ** (-1 / 2) from the edge index of the first graph of the batch.

            edges_index: integer tensor of shape (batch_size, 2, number of edges)
            num_of_nodes: number of nodes of the graph

            Returns a floating point tensor of shape (num_of_nodes,). Nodes that never
            appear as an edge target get 0 instead of an infinite value.
            Raises AssertionError when the edge index refers to a node >= num_of_nodes.
        """
        # bincount keeps position k aligned with node k, even when a node has no edge
        degree = torch.bincount(edges_index[0, 0, :], minlength=num_of_nodes)
        assert degree.shape[0] == num_of_nodes, f'Expected degree matrix with shape=({num_of_nodes}) got {degree.shape}.'

        inv_sqrt_degree = degree.to(torch.get_default_dtype()).pow(-0.5)
        # 0 ** (-1 / 2) is inf: an isolated node must contribute nothing instead
        return torch.where(degree > 0, inv_sqrt_degree, torch.zeros_like(inv_sqrt_degree))



In [41]:
# Smallest possible model: one scalar feature in, one scalar feature out.
model_test = GCN_model(num_layer=1, num_in_features=1, num_out_features=1)

In [42]:
# Smoke test: push the first batch of the loader through the model and check the output shape.
data = next(iter(dataloader))
test = model_test(data)
# The recorded run printed torch.Size([32, 784, 1]).
print(test.shape)

torch.Size([32, 784, 1])


In [None]:
# torch.random.manual_seed(123)

# x = torch.randint(10, (1,10))

# y, count = torch.unique(x, return_counts=True)

# print(x)
# print(count ** (-1 / 2))
# print(count)

In [None]:
# torch.random.manual_seed(40)

# N = 3
# B = 1
# E = 7

# h_l1 = torch.zeros((B, E, 1))

# ed_idx = torch.tensor([[0, 0, 0, 1, 1, 2, 2],
#                        [0, 1, 2, 0, 1, 0, 2]])

# h_l = torch.tensor([[[10], 
#                      [20], 
#                      [30]]], dtype=h_l1.dtype)

# print("h_l", h_l.shape)
# print("h_l1", h_l1.shape)

# expanded_h_l = h_l[:,ed_idx[1,:],:]
# print("expanded_h_l", expanded_h_l.shape)

# index = ed_idx[0,:].unsqueeze(0).unsqueeze(-1)
# print("index", index.shape)
# print(index)

# print(h_l1)
# print(expanded_h_l)

# h_l1.scatter_add_(1, index=index, src=expanded_h_l)
# h_l1 = h_l1[:,0:N,:]
# print(h_l1)

In [None]:
# Scratch check of broadcasting between a (2, 3, 1) tensor and a (3, 1) tensor.
torch.random.manual_seed(42)  # fixed seed so the random tensor is reproducible

# random integers in [0, 10), shape (2, 3, 1)
test = torch.randint(10, (2,3,1))

# one factor per row, shape (3, 1)
truc = torch.tensor([[3],
                     [1],
                     [2]])
print(truc.shape)

torch.Size([3, 1])


In [None]:
# Show both operands, then their product: the (3, 1) tensor is broadcast over the
# leading dimension of `test`, so every row is scaled by its own factor.
for shown in (test, truc.shape, truc, test * truc):
    print(shown)

tensor([[[2],
         [7],
         [6]],

        [[4],
         [6],
         [5]]])
torch.Size([3, 1])
tensor([[3],
        [1],
        [2]])
tensor([[[ 6],
         [ 7],
         [12]],

        [[12],
         [ 6],
         [10]]])
