In [1]:
# Implementing the same thing, but with a list

In [2]:
import torch
import torch.nn as nn
from torchinfo import summary

import torch_geometric
from torch_geometric.data import Data, Dataset, DataLoader

from scipy.spatial.distance import cdist
import networkx as nx

import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl

from dataset import *

In [3]:
# Dataset: build the MNIST training set, then split it into train/validation loaders.
# CustomizableMNIST and split_and_shuffle_data are not defined in this notebook;
# presumably they come from `from dataset import *` — verify against dataset.py.
train_set = CustomizableMNIST(root='./data', train=True, download=True)

# Split / loader settings.
val_set_ratio = 0.2
# NOTE(review): `shuffle` is assigned here but is not passed to
# split_and_shuffle_data below — confirm whether it is meant to be used.
shuffle = True
batch_size = 32

train_loader, valid_loader = split_and_shuffle_data(train_set, val_set_ratio, batch_size)

Initializing CustomizableMNIST...
Training set
Init done.



In [4]:
# Fetch sample 0 and its target, then build its adjacency matrix and the
# normalized adjacency used to create the demo graph in the next cell.
# compute_adj_mat / norm_adjacency are not defined in this notebook;
# presumably they come from `from dataset import *` — TODO confirm.
random_image, target = train_set.get_item_numpy(0)
adj_mat = compute_adj_mat(random_image)
norm_adj_mat = norm_adjacency(adj_mat)

# Disabled visual check of the sample image:
#print(random_image.squeeze(2).shape)
#plt.imshow(random_image.squeeze(2))
#plt.show()

In [5]:
# converting to graph: directed networkx graph built from the normalized adjacency matrix
graph = nx.from_numpy_array(norm_adj_mat, create_using=nx.DiGraph)
# NOTE(review): this prints every edge of the graph (6724 for this sample,
# see the count printed further down); consider printing a count or a slice.
print(graph.edges)

[(0, 0), (0, 1), (0, 28), (0, 29), (1, 0), (1, 1), (1, 2), (1, 28), (1, 29), (1, 30), (2, 1), (2, 2), (2, 3), (2, 29), (2, 30), (2, 31), (3, 2), (3, 3), (3, 4), (3, 30), (3, 31), (3, 32), (4, 3), (4, 4), (4, 5), (4, 31), (4, 32), (4, 33), (5, 4), (5, 5), (5, 6), (5, 32), (5, 33), (5, 34), (6, 5), (6, 6), (6, 7), (6, 33), (6, 34), (6, 35), (7, 6), (7, 7), (7, 8), (7, 34), (7, 35), (7, 36), (8, 7), (8, 8), (8, 9), (8, 35), (8, 36), (8, 37), (9, 8), (9, 9), (9, 10), (9, 36), (9, 37), (9, 38), (10, 9), (10, 10), (10, 11), (10, 37), (10, 38), (10, 39), (11, 10), (11, 11), (11, 12), (11, 38), (11, 39), (11, 40), (12, 11), (12, 12), (12, 13), (12, 39), (12, 40), (12, 41), (13, 12), (13, 13), (13, 14), (13, 40), (13, 41), (13, 42), (14, 13), (14, 14), (14, 15), (14, 41), (14, 42), (14, 43), (15, 14), (15, 15), (15, 16), (15, 42), (15, 43), (15, 44), (16, 15), (16, 16), (16, 17), (16, 43), (16, 44), (16, 45), (17, 16), (17, 17), (17, 18), (17, 44), (17, 45), (17, 46), (18, 17), (18, 18), (18, 1

In [6]:
def create_nodes_feat_from_image(image):
    """
        Flatten an image into a list with one entry per pixel.

        image: numpy array; its first two dimensions give the pixel count
        returns: list of length image.shape[0] * image.shape[1], in the
                 order produced by image.reshape
    """
    n_rows, n_cols = image.shape[0], image.shape[1]
    return list(image.reshape(n_rows * n_cols))

def fill_graph_nodes_feat_list(graph, nodes_feat_list):
    """
        Store each feature on the node with the matching position index,
        under the 'pix_int' attribute key.

        graph: object exposing graph.nodes[i] as a mutable mapping
        nodes_feat_list: iterable of per-node feature values
    """
    node_id = 0
    for pixel_value in nodes_feat_list:
        graph.nodes[node_id]['pix_int'] = pixel_value
        node_id += 1

In [7]:
# Sanity checks of the two helpers above on the sample image.
nodes_feat_list = create_nodes_feat_from_image(random_image)
# Number of node features (one per pixel).
print(len(nodes_feat_list))
# Shape of the node-feature tensor after adding a trailing feature axis.
print(torch.tensor(nodes_feat_list).unsqueeze(1).shape)
# Attach the features to the graph nodes, then read back node 0.
fill_graph_nodes_feat_list(graph, nodes_feat_list)
print(graph.nodes[0]['pix_int'])
# Edge list rearranged into a (2, num_edges) tensor, followed by the edge count.
print(torch.tensor(list(graph.edges)).permute(1,0))
print(len(list(graph.edges)))

784
torch.Size([784, 1])
-0.42421296
tensor([[  0,   0,   0,  ..., 783, 783, 783],
        [  0,   1,  28,  ..., 755, 782, 783]])
6724


In [8]:
class GraphData_fromMNIST(Data):
    """Graph view of a single 2-D image: one node per pixel, with an edge
    between every pair of pixels at Euclidean grid distance <= sqrt(2)
    (the 8 surrounding pixels plus a self-loop).

    Attributes set on the instance:
        nodes_feature: tensor of shape (N, 1) with the flattened pixel values,
            N = image.shape[0] * image.shape[1].
        edges_index: tensor of shape (2, num_edges) with the (source, target)
            node index of every edge.
        graph_label: the label passed to the constructor, stored unchanged.
    """

    def __init__(self, image, label):
        """
            image: numpy array, 2-D (rows, cols)
            label: label of the image; stored as-is
        """
        super().__init__()

        adj_mat = self.compute_adjacency_mat(image)
        norm_adj_mat = self.norm_adjacency(adj_mat)
        # Only the non-zero pattern of the normalized matrix ends up on the
        # instance (as edges_index); the edge weights are not stored.
        graph = nx.from_numpy_array(norm_adj_mat, create_using=nx.DiGraph)

        # feature of node shape (N, 1)
        # Use the class's own helper so the class does not depend on a
        # notebook-level function of a similar name being defined.
        nodes_feature = self.create_nodes_feat_list_from_image(image)
        self.nodes_feature = torch.tensor(nodes_feature).unsqueeze(1)
        # tensor of size (2, num_edges)
        self.edges_index = torch.tensor(list(graph.edges)).permute(1,0)
        # label int
        self.graph_label = label

    def compute_adjacency_mat(self, image):
        """
            Binary adjacency matrix of the pixel grid.

            image: numpy array; only image.shape[0] and image.shape[1] are read
            returns: float array of shape (N, N), 1.0 where two pixels are at
                     grid distance <= sqrt(2) (self included), else 0.0
        """
        # indexing='ij' makes node k correspond to pixel (k // cols, k % cols),
        # i.e. the same order as the row-major flattening used for the node
        # features. For square images this yields the same matrix as the
        # default 'xy' ordering; for non-square images it keeps nodes and
        # features aligned.
        rows, cols = np.meshgrid(
            np.arange(image.shape[0]), np.arange(image.shape[1]), indexing='ij'
        )
        coord = np.stack((rows, cols), axis=2).reshape(-1, 2)
        dist = cdist(coord, coord)
        adj = ( dist <= np.sqrt(2) ).astype(float)
        return adj

    def norm_adjacency(self, adj):
        """
            Symmetric degree normalization D^(-1/2) @ adj @ D^(-1/2), where D
            is the diagonal matrix of the column sums of adj.

            adj: numpy array, square
            returns: numpy array with the same shape as adj
        """
        deg = np.asarray(adj, dtype=float).sum(axis=0)
        # D is diagonal, so its inverse square root is an element-wise
        # operation on the degree vector; this avoids a dense matrix
        # inversion. Nodes with zero degree get a zero scale instead of
        # making the computation fail on a singular matrix.
        inv_sqrt_deg = np.zeros_like(deg)
        positive = deg > 0
        inv_sqrt_deg[positive] = np.sqrt(1.0 / deg[positive])
        # Row-scale then column-scale: equivalent to D^(-1/2) @ adj @ D^(-1/2).
        return inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]

    @staticmethod
    def create_nodes_feat_list_from_image(image):
        """
            Flatten an image into a list with one entry per pixel.

            image: numpy array
            returns: list of length image.shape[0] * image.shape[1]
        """
        flattened_image = image.reshape(image.shape[0] * image.shape[1])
        return list(flattened_image)

    @staticmethod
    def fill_graph_nodes_feat_list(graph, nodes_feat_list):
        """ 
            Store each feature on the node with the matching index, under
            the 'pix_int' attribute.

            graph: networkx graph type
            nodes_feat_list: list
        """
        for i, node_feat in enumerate(nodes_feat_list):
            graph.nodes[i]['pix_int'] = node_feat


class GraphDataset(Dataset):
    """Dataset over a collection of graph objects.

    Each graph is expected to expose `nodes_feature`, `edges_index` and
    `graph_label` attributes. Indexing returns the
    (nodes_feature, edges_index) pair of one graph; the labels are kept in
    `graph_label_list`, in the same order.
    """

    def __init__(self, graph_collection):
        """
            graph_collection: iterable of graph objects (consumed once)
        """
        super().__init__()

        nodes_feat_list = []
        edges_index_list = []
        graph_label_list = []

        for graph in graph_collection:
            nodes_feat_list.append(graph.nodes_feature)
            edges_index_list.append(graph.edges_index)
            # The graph objects store their label as `graph_label`.
            graph_label_list.append(graph.graph_label)

        self.nodes_feature_list = nodes_feat_list
        self.edges_index_list = edges_index_list
        # Keep the labels: they were collected but dropped before.
        self.graph_label_list = graph_label_list

    def __len__(self):
        # Must read the instance attribute: the list built in __init__ is a
        # local there and is not visible from this method.
        return len(self.nodes_feature_list)

    def __getitem__(self, index):
        return self.nodes_feature_list[index], self.edges_index_list[index]

    # torch_geometric documents len()/get() as the methods a Dataset
    # subclass implements; provided here as thin aliases of the above.
    def len(self):
        return len(self.nodes_feature_list)

    def get(self, idx):
        return self.nodes_feature_list[idx], self.edges_index_list[idx]

class GraphDataLoader(DataLoader):
    """DataLoader that first wraps a graph collection in a GraphDataset."""

    def __init__(self, graph_collection, batch_size=1, shuffle=False):
        """
            graph_collection: collection of graph objects, passed to GraphDataset
            batch_size: forwarded to the parent DataLoader
            shuffle: forwarded to the parent DataLoader
        """
        super(GraphDataLoader, self).__init__(
            GraphDataset(graph_collection),
            batch_size,
            shuffle,
        )
        

In [9]:
# Build a single graph object from sample 0 as a quick check of the class.
# squeeze() drops size-1 axes before the image is handed to the constructor.
test, label = train_set.get_item_numpy(0)
test_data = GraphData_fromMNIST(test.squeeze(), label)

In [10]:
# Shape of the edge-index tensor: (2, num_edges).
print(test_data.edges_index.shape)

torch.Size([2, 6724])


In [11]:
# Build one graph object per (image, label) pair yielded by iterating train_set.
# NOTE(review): the cell that follows rebuilds the same collection with an
# explicit loop before pickling it; one of the two constructions is redundant.
# NOTE(review): the class docstrings expect numpy arrays — confirm what type
# iterating train_set yields.
graph_collection = [GraphData_fromMNIST(image.squeeze(), label) for image, label in train_set]


In [None]:
# Build one graph object per (image, label) pair yielded by train_set.
# NOTE(review): `image` and `label` remain bound as notebook globals after the loop.
graph_collection = []
for image, label in train_set:
    graph_collection.append(GraphData_fromMNIST(image.squeeze(), label))

# Serialize the whole collection to 'graph_collection_dataset' (relative path).
with open('graph_collection_dataset', 'wb') as f:
    pkl.dump(graph_collection, f)

In [None]:
# Wrap the graph collection in the dataset class defined above.
graph_train_dataset = GraphDataset(graph_collection)

In [None]:
class GCN_model(nn.Module):
    """Placeholder for the graph convolutional network; not implemented yet."""

    def __init__(self):
        super(GCN_model, self).__init__()
        # TODO: define the layers here.
        # __init__ must return None: returning any other object makes
        # instantiation fail with a TypeError, so the stub returns nothing.

    def forward(self, data):
        """
            data: input to the model (format still to be defined)

            Raises NotImplementedError until the forward pass is written.
        """
        # Raise instead of returning the exception class, so that calling the
        # unfinished model fails loudly rather than yielding a bogus output.
        raise NotImplementedError