In [19]:
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import os
import networkx as nx
import torch_geometric as tg

In [101]:
# import the dataset 
# import the dataset 
dataset = tg.datasets.KarateClub()
data = dataset[0]
print('Kc: ', data)


Kc:  Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])


In [121]:
# construct the model
import torch
import torch.nn.functional as F
from torch_sparse import spspmm

from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    sort_edge_index,
)
from torch_geometric.utils.repeat import repeat

class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture.
        pool_ratios (float or [float], optional): Graph pooling ratio for each
            depth. (default: :obj:`0.5`)
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation for integration of skip connections instead
            summation. (default: :obj:`True`)
        act (torch.nn.functional, optional): The nonlinearity to use.
            (default: :obj:`torch.nn.functional.relu`)
    """
    def __init__(self, in_channels, hidden_channels, out_channels, depth,
                 pool_ratios=0.5, sum_res=True, act=F.relu):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res

        channels = hidden_channels

        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        in_channels = channels if sum_res else 2 * channels

        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(GCNConv(in_channels, channels, improved=True))
        self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))

        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, batch=None):
                
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.down_convs[0](x, edge_index, edge_weight)
        print('x after down-0: ', x.shape)
        x = self.act(x)

        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
                                                       x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            print(f'x after pool-{i-1}: ', x.shape)

            x = self.down_convs[i](x, edge_index, edge_weight)
            print(f'x after down-{i}: ', x.shape)
            x = self.act(x)
            print('final x...')
            print(x)
            if i < self.depth:
                print(f'saving x at {i}')
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            print(f'saving perm at {i-1}')
            perms += [perm]
    
        print('perms len and x len: ', len(perms), len(xs))
        print(' ========== '*3)
        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            
            print(f'using x and perm of {j}')
            print('res...')
            print(res)
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]
            up = torch.zeros_like(res)
            print('up shape: ', up.shape)

            up[perm] = x
            print('up perm...')
            print(up[perm])
            print('perm shape:', perm.shape)
            print('permed up: ', up[perm].shape)
            print('up shape now:  ', up.shape)
            print(up)
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            print('res+up shape: ', x.shape)
            print(x)
            x = self.up_convs[i](x, edge_index, edge_weight)
            print(f'x after up-{i}: ', x.shape)
            x = self.act(x) if i < self.depth - 1 else x

        return x


    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight


    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels}, {self.out_channels}, '
                f'depth={self.depth}, pool_ratios={self.pool_ratios})')


In [122]:
model = GraphUNet(data.num_features,4 , dataset.num_classes, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)

In [123]:

def init_weights(m):
    '''
    initializing the model wghts with values
    drawn from normal distribution.
    else initialize them with 0.
    '''
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
print(model.apply(init_weights))

GraphUNet(34, 4, 4, depth=4, pool_ratios=[0.5, 0.5, 0.5, 0.5])


In [128]:
x = model(data.x, data.edge_index)
x = F.log_softmax(x, dim=1)
loss = F.nll_loss(x[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step


x after down-0:  torch.Size([34, 4])
x after pool-0:  torch.Size([17, 4])
x after down-1:  torch.Size([17, 4])
final x...
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 1.3707e-06],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 6.4371e-07],
        [0.0000e+00, 0.0000e+00, 1.5181e-07, 3.6597e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1888e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 4.9667e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 9.3511e-07],
        [0.0000e+00, 0.0000e+00, 4.5011e-08, 7.5241e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3768e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3460e-07],
        [9.1069e-08, 0.0000e+00, 1.2364e-07, 3.5987e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 7.7538e-07],
        [5.6722e-08, 0.0000e+00, 2.4201e-07, 3.6246e-07],
        [0.0000e+00, 0.0000e+00, 2.1261e-08, 6.9524e-07],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 8.3472e-07],
        [0.0000e+00, 0.0000e+00, 6.1661e-08, 6.8947e-07],
        

<bound method Adam.step of Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.01
    maximize: False
    weight_decay: 0.001
)>

In [32]:
for epoch in range(1, 20):
    model.eval()
    accs= []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = x[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    print(accs)


[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
[0.10714285714285714, 0.116, 0.102]
