In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GCNConv, TopKPooling
import matplotlib.pyplot as plt
import networkx as nx

In [4]:
# import the dataset 
dataset = torch_geometric.datasets.Planetoid('Desktop/research/ARETE', "Cora")
data = dataset[0]
print('Kc: ', data)

Kc:  Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling

class GraphUNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GraphUNet, self).__init__()

        # define parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.relu = nn.ReLU()

        self.down_steps = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.up_steps = nn.ModuleList()

#         n_hidden = [1600, 400, 100, 16]
        n_hidden = [16, 16, 16, 16]
        self.hidden_channels = n_hidden

        C_down_in = self.in_channels

        # first three GCNConv and pooling in down step
        for h in n_hidden[:-1]: 
            self.down_steps.append(GCNConv(C_down_in, h, improved=True, cached=True))
            self.pools.append(TopKPooling(h, ratio=0.5))
            C_down_in = h

        # bottle neck
        self.bottleneck = GCNConv(n_hidden[-2], n_hidden[-1], improved=True, cached=True)

        # up_steps
        C_up_in = n_hidden[-2]
        for h in reversed(n_hidden[:-1]):
            self.up_steps.append(GCNConv(C_up_in, h, improved=True, cached=True))
            C_up_in = h

        # final step
        self.final_layer = GCNConv(n_hidden[0], self.out_channels, improved=True, cached=True)


    def forward(self, x, edge_index, noise, batch=None):

        # create edge index matrix from adj matrix
#         edge_index = create_edge_index(adj)
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        # edge_weights: lets keep them 1 initially
        edge_weight = x.new_ones(edge_index.shape[1])

        # encoding i.e. down (gPool+GCNConv) steps
        skip_x = []
        skip_edge_index = []
        skip_edge_weight = []
        skip_perms = []

        for down in range(len(self.down_steps)):
            # down GCNs: in_channels ->1600->400->100
            x = self.down_steps[down](x, edge_index, edge_weight)
            skip_x.append(self.relu(x))
            print(f'down GCN {down}:', x.shape)
            print(f'saving x-{down}')
            
            # down pooling: N ->N/2->N/4->N/8
            x, edge_index, edge_weight, _, perm, _ = self.pools[down](
                                                    x, edge_index, edge_weight)
            print(f'down pool {down}:', x.shape)
            
            skip_edge_index.append(edge_index)
            skip_edge_weight.append(edge_weight)
            skip_perms.append(perm)
            print(f'saving pern -{down}')

        # final GCNConv, x: (N/8, 16)
        x = self.bottleneck(x, edge_index, edge_weight)
        print(f'fianl GCN:', x.shape)
            

        # adding the noise of shape (1,16)
        x = x+noise  # x: (N/8, 16)
        print(f'x+noise:', x.shape)

        for up in range(len(self.up_steps)):
            up = -1-up
#             print('up: ', up)
            res_x = skip_x[up]
#             print('res shape:', res_x.shape)
            zero_mat = torch.zeros_like(res_x)
            perm = skip_perms[up]
            edge_index = skip_edge_index[up]
            edge_weight = skip_edge_weight[up]
#             print('zero_mat shape:', zero_mat.shape)
#             print('perm shape:', perm.shape)
            zero_mat[perm] = x
            x = res_x+zero_mat
            x = self.up_steps[up](x, edge_index, edge_weight)
            x = self.relu(x)
            print(f'up GCN {up}:', x.shape)

        x = self.final_layer(x, edge_index, edge_weight)
        print('final shape:', x.shape)
        
        return x

    def create_edge_index(adj):
        # create a zero edge_index matrix
        eimat = torch.zeros(2, adj.shape[0])
        eimat_col = 0
        for row in adj.shape[0]:
            for col in adj.shape[1]:
                if adj[row][col]!=0:
                    eimat[0][eimat_col] = row
                    eimat[1][eimat_col] = col
        return eimat

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(in_channels:{self.in_channels}, '
                f'hidden_channels:{self.hidden_channels}, out_channels:{self.out_channels}, pool_ratios=0.5)')


In [7]:
model = GraphUNet(data.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)

x = model(data.x, data.edge_index,torch.zeros(1,16))
x = F.log_softmax(x, dim=1)
loss = F.nll_loss(x[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step


down GCN 0: torch.Size([2708, 16])
saving x-0
down pool 0: torch.Size([1354, 16])
saving pern -0
down GCN 1: torch.Size([1354, 16])
saving x-1
down pool 1: torch.Size([677, 16])
saving pern -1
down GCN 2: torch.Size([677, 16])
saving x-2
down pool 2: torch.Size([339, 16])
saving pern -2
fianl GCN: torch.Size([339, 16])
x+noise: torch.Size([339, 16])
up GCN -1: torch.Size([677, 16])
up GCN -2: torch.Size([1354, 16])
up GCN -3: torch.Size([2708, 16])
final shape: torch.Size([2708, 7])


<bound method Adam.step of Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.01
    maximize: False
    weight_decay: 0.001
)>