<a href="https://colab.research.google.com/github/kunwarAbhay/NLP/blob/main/GNN/GNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.1


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics.classification import MulticlassAccuracy

import networkx as nx
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [3]:
edges = [[2, 1],
        [3, 1], [3, 2],
        [4, 1], [4, 2], [4, 3],
        [5, 1],
        [6, 1],
        [7, 1], [7, 5], [7, 6],
        [8, 1], [8, 2], [8, 3], [8, 4],
        [9, 1], [9, 3],
        [10, 3],
        [11, 1], [11, 5], [11, 6],
        [12, 1],
        [13, 1], [13, 4],
        [14, 1], [14, 2], [14, 3], [14, 4],
        [17, 6], [17, 7],
        [18, 1], [18, 2],
        [20, 1], [20, 2],
        [22, 1], [22, 2],
        [26, 24], [26, 25],
        [28, 3], [28, 24], [28, 25],
        [29, 3],
        [30, 24], [30, 27],
        [31, 2], [31, 9],
        [32, 1], [32, 25], [32, 26], [32, 29],
        [33, 3], [33, 9], [33, 15], [33, 16], [33, 19], [33, 21], [33, 23], [33, 24], [33, 30], [33, 31], [33, 32],
        [34, 9], [34, 10], [34, 14], [34, 15], [34, 16], [34, 19], [34, 20], [34, 21], [34, 23], [34, 24], [34, 27], [34, 28], [34, 29], [34, 30], [34, 31], [34, 32], [34, 33]]

classes = [1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0]

In [4]:
num_nodes = len(classes)
num_edges = len(edges)
num_classes = 4

num_nodes, num_edges, num_classes

(34, 78, 4)

In [5]:
torch.manual_seed(1234)

<torch._C.Generator at 0x7a4ebbf02710>

In [6]:
class GCN(nn.Module):
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.W1 = nn.Linear(input_dim, 4, bias=False)
    self.W2 = nn.Linear(4, 4, bias=False)
    self.W3 = nn.Linear(4, 2, bias=False)
    self.tanh = nn.Tanh()
    self.classifier = nn.Linear(2, output_dim)

  def forward(self, adjacency_matrix, node_features):
    output_features = self.W1(torch.matmul(adjacency_matrix, node_features))
    output_features = self.tanh(output_features)

    output_features = self.W2(torch.matmul(adjacency_matrix, output_features))
    output_features = self.tanh(output_features)

    output_features = self.W3(torch.matmul(adjacency_matrix, output_features))
    output_features = self.tanh(output_features)

    output = self.classifier(output_features)

    return output

In [7]:
model = GCN(num_nodes, num_classes)

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
accuracy_fn = MulticlassAccuracy(num_classes=num_classes)

In [9]:
A = torch.eye(num_nodes) # Adjacency Matrix

for [src, tgt] in edges:
  src, tgt = src - 1, tgt - 1 # convert to 0-based indexes
  A[src][tgt] = 1
  A[tgt][src] = 1

X = A # Feature Matrix | Here Feature Matrix is same as Adjacency Matrix

y = torch.tensor(classes) # Actual Node Labels

In [10]:
mask = torch.zeros(num_nodes)

visited = set()

for i, c in enumerate(classes):
  if c in visited:
    mask[i] = 0
  else:
    visited.add(c)
    mask[i] = 1

mask

tensor([1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [11]:
# Training
epochs = 1200

for epoch in tqdm(range(epochs)):
  model.train()

  # Forward Pass
  y_logits = model(A, X)

  y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)

  # Calculate the loss and acccuracy
  y_logits_masked = [y_logits[i] for i in range(len(mask)) if mask[i] == 1]
  y_masked = [y[i] for i in range(len(mask)) if mask[i] == 1]

  y_logits_masked = torch.stack(y_logits_masked)
  y_masked = torch.tensor(y_masked)

  loss = criterion(y_logits_masked,y_masked)

  accuracy = accuracy_fn(y_pred, y)

  # Update the parameters
  optimizer.zero_grad()

  loss.backward()

  optimizer.step()

  if epoch % 100 == 0:
    print(f"Epoch: {epoch} | Loss: {loss} | Accuracy: {accuracy}")

  0%|          | 0/1200 [00:00<?, ?it/s]

Epoch: 0 | Loss: 1.3931266069412231 | Accuracy: 0.25
Epoch: 100 | Loss: 0.813456118106842 | Accuracy: 0.4314102530479431
Epoch: 200 | Loss: 0.4334444999694824 | Accuracy: 0.5189102292060852
Epoch: 300 | Loss: 0.38085147738456726 | Accuracy: 0.5605769157409668
Epoch: 400 | Loss: 0.3711123466491699 | Accuracy: 0.6137820482254028
Epoch: 500 | Loss: 0.36276090145111084 | Accuracy: 0.6137820482254028
Epoch: 600 | Loss: 0.3584613502025604 | Accuracy: 0.6762820482254028
Epoch: 700 | Loss: 0.3557548522949219 | Accuracy: 0.6762820482254028
Epoch: 800 | Loss: 0.3530580401420593 | Accuracy: 0.5384615659713745
Epoch: 900 | Loss: 0.04132305458188057 | Accuracy: 0.7804487347602844
Epoch: 1000 | Loss: 0.018633658066391945 | Accuracy: 0.7804487347602844
Epoch: 1100 | Loss: 0.011970417574048042 | Accuracy: 0.7804487347602844
