<a href="https://colab.research.google.com/github/duynht/Greedy_InfoMax/blob/master/Gradient_isolated_GNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# README
This notebook houses preliminary experiments for applying gradient-isolated training to Graph Neural Networks

# Graph package installation 

In [0]:
!pip install torch-scatter
!pip install torch-sparse
!pip install torch-cluster
!pip install torch-geometric





In [0]:
!pip install tensorboardX



# Import

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import DataLoader
from torch_geometric.data import Data

import torch_geometric.transforms as T

from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

import os.path as osp
from PIL import Image

# Create a Custom Graph Dataset from STL10
Each image is divided into a $2\times2$ grid. Each patch of the grid is a node of the graph.


In [0]:
class GraphSTL10(InMemoryDataset):
  def __init__(self, root, split):
    self.split = split
    if (split == 'train'):
      self.dataset = datasets.STL10(root='/tmp/stl10_train', split='train', download=True)
    if (split == 'test'):
      self.dataset = datasets.STL10(root='/tmp/stl10_test', split='test', download=True)
    super(GraphSTL10, self).__init__(root)
    # self.num_classes = 10
    self.data, self.slices = torch.load(self.processed_paths[0])

  @property
  def raw_file_names(self):
    return []

  @property
  def processed_file_names(self):
    if (self.split == 'train'):
      return ['graphstl10_train.pt']
    if (self.split == 'test'):
      return ['graphstl10_test.pt']
    
    return []

  def download(self):
    pass

  def process(self):
    def crop(image,pc_height,pc_width):
      im_width, im_height = image.size
      for i in range(im_height//pc_height):
        for j in range(im_width//pc_width):
          box = (j*pc_width, i*pc_height, (j+1)*pc_width, (i+1)*pc_height)
          yield image.crop(box)

    data_list = []

    preprocess = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    resnet50 = models.resnet50(pretrained=True)
    resnet50.fc = nn.Identity()

    for param in resnet50.parameters():
      param.requires_grad = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    resnet50.to(device)

    source_nodes = [i for i in range(0,4) for j in range(0,4)]
    target_nodes = [j for i in range(0,4) for j in range(0,4)]
    edge_index = torch.tensor([source_nodes, target_nodes],dtype = torch.long)

    for imid, (image, label) in enumerate(self.dataset):
      # neighbors = np.arange(imid, imid+4)
      # mask = [1]*len(neighbors)
      im_height, im_width = image.size
      pc_height, pc_width = im_height//2, im_width//2
      node_features = []
      for pid,piece in enumerate(crop(image, pc_height, pc_width)):
        patch = Image.new('RGB', (pc_width, pc_height), 255)
        patch.paste(piece)
        patch = preprocess(patch)
        patch = patch.view(1,*patch.shape)
        patch = patch.to(device)
        patch = resnet50.forward(patch).to(torch.device('cpu'))
        node_features.append(torch.tensor(patch))

      x = torch.cat(node_features)

      y = torch.tensor(label).unsqueeze(0)

      data = Data(x=x, edge_index=edge_index.clone(), y=y)

      data_list.append(data)

    data, slices = self.collate(data_list)
    torch.save((data, slices), self.processed_paths[0])
  

# Define custom MessagePassing [WIP]


In [0]:
class CustomConv(pyg_nn.MessagePassing):
  def __init__(self, in_channels, out_channels):
    super(CustomConv, self).__init__(aggr='add') #'add' aggregation
    self.lin = nn.Linear(in_channels, out_channels)
    self.lin_self = nn.Linear(in_channels, out_channels)

  def forward(self, x, edge_index):
    # x has shape [N, in_channels]
    # edge_index has shape [2, E]

    # Transform node feature matrix
    self_x = self.lin_self(x)
    # x = self.lin(x)

    return self_x + self.propagate(edge_index, size=(x.size(0), x.size(0)), x=self.lin(x))
  
  def message(self, x_i, x_j, edge_index, size):
    # Compute messages
    # x_j has shape [E, out_channels]
    # TODO:
    row, col = edge_index
    deg = pyg_utils.degree(row, size[0], dtype=x_j.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    return x_j

  def update(self, aggr_out):
    # aggr_out has shape [N, out_channels]
    return aggr_out

# Define the Graph Neural Network

In [0]:
class VisionGNN(nn.Module):
  def __init__(self, input_dim, hidden_dim, output_dim):
    super(VisionGNN, self).__init__()
        
    self.dropout = 0.25
    self.num_layers = 2
    self.hidden = [input_dim, 512, hidden_dim]
    # self.resnet = models.resnet50(pretrained=True)
    # self.resnet.fc = nn.Identity()

    self.convs = nn.ModuleList()
    self.lns = nn.ModuleList()

    for l in range(self.num_layers):
      self.convs.append(self.build_conv_model(self.hidden[l], self.hidden[l+1]))
      if (l + 1 < self.num_layers):
        self.lns.append(nn.LayerNorm(self.hidden[l+1]))

    # post-message-passing
    self.post_mp = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim), nn.Dropout(0.25),
        nn.Linear(hidden_dim, output_dim))

  def build_conv_model(self, input_dim, hidden_dim):
      # return CustomConv(input_dim, hidden_dim)
      return pyg_nn.GCNConv(input_dim, hidden_dim)
      
  def forward(self, data):
    x, edge_index = data.x, data.edge_index
    if data.num_node_features == 0:
      x = torch.ones(data.num_nodes, 1)

    for i in range(self.num_layers):
      x = self.convs[i](x, edge_index)
      emb = x
      x = F.relu(x)
      x = F.dropout(x, p=self.dropout, training = self.training)
      if not i == self.num_layers - 1:
        x = self.lns[i](x)

    # x = pyg_nn.global_mean_pool(x, data.batch)
    x = pyg_nn.global_add_pool(x, data.batch)

    x = self.post_mp(x)

    return emb, F.log_softmax(x, dim=1)

  def loss(self, pred, label):
    # Negative log-likelihood
    return F.nll_loss(pred, label)

In [0]:
stl10_train = GraphSTL10('/graphstl10/',split='train')
stl10_test = GraphSTL10('/graphstl10/',split='test')

Files already downloaded and verified
Files already downloaded and verified


In [0]:
# train
trainloader = DataLoader(stl10_train, batch_size=64, shuffle=True)
testloader = DataLoader(stl10_test, batch_size=64, shuffle=True)

model = VisionGNN(input_dim=stl10_train.num_node_features, hidden_dim=64, output_dim = 10)
optimizer = optim.Adam(params=model.parameters(), lr=0.003)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

epochs = 200
for e in range(epochs):
  running_loss = 0
  model.train()
  for batch in trainloader:
    optimizer.zero_grad()
    batch = batch.to(device)
    emb, logits = model(batch)
    labels = batch.y
    loss = model.loss(logits, labels)
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()

  running_loss /= len(trainloader.dataset)

# test
  if (e+1) % 10 == 0:
    accuracy = 0
    model.eval()
    for batch in testloader:
      with torch.no_grad():
        batch = batch.to(device)
        emb, logits = model(batch)
        pred = logits.argmax(dim=1)
        labels = batch.y
      
      accuracy += pred.eq(labels).sum().item()

    accuracy /= len(testloader.dataset)

    print("Epoch {}/{}. Loss: {:.4f}. Test accuracy: {:.4f}".format(e+1, epochs, running_loss, accuracy))
      


Epoch 10/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 20/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 30/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 40/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 50/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 60/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 70/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 80/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 90/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 100/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 110/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 120/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 130/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 140/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 150/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 160/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 170/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 180/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 190/200. Loss: 0.0364. Test accuracy: 0.1000
Epoch 200/200. Loss: 0.0364. Test accura