<a href="https://colab.research.google.com/github/makhalid1999/LightGAT/blob/main/LightGAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch_geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torch_geometric: filename=torch_geometric-2.3.1-py3-none-any.whl size=910459 sha256=f55b5827aed73179855c61785d9d1248c7649aa2516e80f2261224c21b6ba340
  Stored in directory: /root/.cache/pip/wheels/ac/dc/30/e2874821ff308ee67dcd7a66dbde912411e19e35a1addda028
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geomet

In [None]:
from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Dataset: KarateClub():
Number of graphs: 1
Number of features: 34
Number of classes: 4


In [None]:
import torch.nn as nn
import torch
import torch.nn.parameter as Parameter
import torch.nn.functional as F

class LightGAT_Layer(nn.Module):
  def __init__(self, num_nodes, num_features):
    self.num_nodes = num_nodes
    self.num_features = num_features
    super(LightGAT_Layer, self).__init__()
    self.a = Parameter.Parameter(torch.randn(1, num_features*2))
    self.a.requires_grad = True
  def forward(self, x, adj):
    attention = torch.zeros(self.num_nodes, self.num_nodes)
    for i in range(self.num_nodes):
      for j in range(self.num_nodes):
        if adj[i, j] != 0:
          b = torch.cat((x[i], x[j]))
          attention[i, j] = torch.mm(self.a, torch.transpose(b.unsqueeze(0), 1, 0))
    e = nn.functional.softmax(attention, dim = 1)
    e_adj = e * adj
    return torch.sparse.mm(e_adj, x)

class LightGAT_Layer2(nn.Module):
  def __init__(self, num_nodes, num_features):
    super(LightGAT_Layer2, self).__init__()
    self.num_nodes = num_nodes
    self.num_features = num_features
    self.a1 = Parameter.Parameter(torch.randn(num_features, num_features))
    self.a1.requires_grad = True
    self.a2 = Parameter.Parameter(torch.randn(num_features, num_features))
    self.a2.requires_grad = True
  def forward(self, x, adj):
    a1_x = torch.mm(self.a1, x)
    a2_x = torch.mm(self.a2, x)
    attention = torch.mm(a1_x, torch.transpose(a2_x, 1, 0))
    e = nn.functional.softmax(attention, dim = 1)
    e_adj = e * adj
    return torch.sparse.mm(e_adj, x)

class LightGCN_Layer(nn.Module):
  def __init__(self, num_nodes):
    super(LightGCN_Layer, self).__init__()
  def forward(self, x, adj):
    return torch.sparse.mm(adj, x)

class GCN_Layer(nn.Module):
  def __init__(self, num_nodes):
    super(GCN_Layer, self).__init__()
    self.W = Parameter.Parameter(torch.eye(num_nodes, num_nodes))
    self.W.requires_grad = True
  def forward(self, x, adj):
    x1 = torch.mm(x, self.W)
    return torch.sparse.mm(adj, x1)

class GAT_Layer(nn.Module):
  def __init__(self, num_nodes, num_features):
    super(GAT_Layer, self).__init__()
    self.num_nodes = num_nodes
    self.num_features = num_features
    self.a = Parameter.Parameter(torch.randn(1, num_features*2))
    self.a.requires_grad = True
    self.W = Parameter.Parameter(torch.eye(num_nodes, num_nodes))
    self.W.requires_grad = True  
  def forward(self, x, adj):
    x1 = torch.mm(x, self.W)    
    attention = torch.zeros(self.num_nodes, self.num_nodes)
    for i in range(self.num_nodes):
      for j in range(self.num_nodes):
        if adj[i, j] != 0:
          b = torch.cat((x1[i], x1[j]))
          attention[i, j] = torch.mm(self.a, torch.transpose(b.unsqueeze(0), 1, 0))
    e = nn.functional.softmax(attention, dim = 1)
    e_adj = e * adj
    return torch.sparse.mm(e_adj, x1) 

class GAT_Layer2(nn.Module):
  def __init__(self, num_nodes, num_features):
    super(GAT_Layer2, self).__init__()
    self.num_nodes = num_nodes
    self.num_features = num_features
    self.a1 = Parameter.Parameter(torch.randn(num_features, num_features))
    self.a1.requires_grad = True
    self.a2 = Parameter.Parameter(torch.randn(num_features, num_features))
    self.a2.requires_grad = True
    self.W = Parameter.Parameter(torch.eye(num_nodes, num_nodes))
    self.W.requires_grad = True  
  def forward(self, x, adj):
    a1_x = torch.mm(self.a1, x)
    a2_x = torch.mm(self.a2, x)
    attention = torch.mm(a1_x, torch.transpose(a2_x, 1, 0))
    e = nn.functional.softmax(attention, dim = 1)
    e_adj = e * adj
    x1 = torch.mm(x, self.W)
    return torch.sparse.mm(e_adj, x1)

In [None]:
class LightGAT(nn.Module):
  def __init__(self):
    super(LightGAT, self).__init__()
    self.L1 = LightGAT_Layer(34, 34)
    self.L2 = LightGAT_Layer(34, 34)
    self.classifier = nn.Linear(34, 4)
    self.x = Parameter.Parameter(torch.eye(34, 34))
    self.x.requires_grad = True
  def forward(self, adj):
    h1 = self.L1(self.x, adj)
    h2 = self.L2(h1, adj)
    out = F.softmax(self.classifier(h2), dim = 1)
    return out

class LightGAT2(nn.Module):
  def __init__(self):
    super(LightGAT2, self).__init__()
    self.L1 = LightGAT_Layer2(34, 34)
    self.L2 = LightGAT_Layer2(34, 34)
    self.classifier = nn.Linear(34, 4)
    self.x = Parameter.Parameter(torch.eye(34, 34))
    self.x.requires_grad = True
  def forward(self, adj):
    h1 = self.L1(self.x, adj)
    h2 = self.L2(h1, adj)
    out = F.softmax(self.classifier(h2), dim = 1)
    return out

class LightGCN(nn.Module):
  def __init__(self):
    super(LightGCN, self).__init__()
    self.L1 = LightGCN_Layer(34)
    self.L2 = LightGCN_Layer(34)
    self.classifier = nn.Linear(34, 4)
    self.x = Parameter.Parameter(torch.eye(34, 34))
    self.x.requires_grad = True
  def forward(self, adj):
    h1 = self.L1(self.x, adj)
    h2 = self.L2(h1, adj)
    out = F.softmax(self.classifier(h2), dim = 1)
    return out

class GCN(nn.Module):
  def __init__(self):
    super(GCN, self).__init__()
    self.L1 = GCN_Layer(34)
    self.L2 = GCN_Layer(34)
    self.classifier = nn.Linear(34, 4)
    self.x = Parameter.Parameter(torch.eye(34, 34))
    self.x.requires_grad = True
  def forward(self, adj):
    h1 = self.L1(self.x, adj)
    h1.tanh()
    h2 = self.L2(h1, adj)
    h2.tanh()
    out = F.softmax(self.classifier(h2), dim = 1)
    return out

class GAT(nn.Module):
  def __init__(self):
    super(GAT, self).__init__()
    self.L1 = GAT_Layer(34, 34)
    self.L2 = GAT_Layer(34, 34)
    self.classifier = nn.Linear(34, 4)
    self.x = Parameter.Parameter(torch.eye(34, 34))
    self.x.requires_grad = True
  def forward(self, adj):
    h1 = self.L1(self.x, adj)
    h1.tanh()
    h2 = self.L2(h1, adj)
    h2.tanh()
    out = F.softmax(self.classifier(h2), dim = 1)
    return out

class GAT2(nn.Module):
  def __init__(self):
    super(GAT2, self).__init__()
    self.L1 = GAT_Layer2(34, 34)
    self.L2 = GAT_Layer2(34, 34)
    self.classifier = nn.Linear(34, 4)
    self.x = Parameter.Parameter(torch.eye(34, 34))
    self.x.requires_grad = True
  def forward(self, adj):
    h1 = self.L1(self.x, adj)
    h1.tanh()
    h2 = self.L2(h1, adj)
    h2.tanh()
    out = F.softmax(self.classifier(h2), dim = 1)
    return out

In [None]:
import time

data = dataset[0]

def adj(data):
  a = torch.zeros(34, 34)
  for i in data.edge_index.T:
    a[i.tolist()[0],i.tolist()[1]] = 1
  return a

model = LightGAT()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    adj_m = adj(data)/torch.transpose(adj(data).sum(dim = 0).unsqueeze(0), 1, 0)
    out = model(adj_m)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

    accuracy = {}
    # Calculate training accuracy on our four examples
    predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
    target_classes = data.y[data.train_mask]
    accuracy['train'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())
    
    # Calculate validation accuracy on the whole graph
    predicted_classes = torch.argmax(out, axis=1)
    target_classes = data.y
    accuracy['val'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())

    return loss, accuracy

for epoch in range(100):
    start = time.time()
    loss, accuracy = train(data)
    print('Loss: ' + str(loss.item()) + ' Training Accuracy: ' + str(accuracy['train'].item()) + ' Validation Accuracy: ' + str(accuracy['val'].item()))

print('Total Time: ' + str(time.time() - start))

Loss: 1.386607050895691 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3865458965301514 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3864940404891968 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3864490985870361 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.386408805847168 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3863710165023804 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3863340616226196 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3862966299057007 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3862571716308594 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.386213779449463 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.386163592338562 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.386103

In [None]:
import time

data = dataset[0]

def adj(data):
  a = torch.zeros(34, 34)
  for i in data.edge_index.T:
    a[i.tolist()[0],i.tolist()[1]] = 1
  return a

model = LightGAT2()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    adj_m = adj(data)/torch.transpose(adj(data).sum(dim = 0).unsqueeze(0), 1, 0)
    out = model(adj_m)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

    accuracy = {}
    # Calculate training accuracy on our four examples
    predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
    target_classes = data.y[data.train_mask]
    accuracy['train'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())
    
    # Calculate validation accuracy on the whole graph
    predicted_classes = torch.argmax(out, axis=1)
    target_classes = data.y
    accuracy['val'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())

    return loss, accuracy

for epoch in range(100):
    start = time.time()
    loss, accuracy = train(data)
    print('Loss: ' + str(loss.item()) + ' Training Accuracy: ' + str(accuracy['train'].item()) + ' Validation Accuracy: ' + str(accuracy['val'].item()))

print('Total Time: ' + str(time.time() - start))

Loss: 1.3864359855651855 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.386380672454834 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.3863245248794556 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.386281967163086 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.386240005493164 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.3861724138259888 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.3860726356506348 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.3858708143234253 Training Accuracy: 0.25 Validation Accuracy: 0.38235294818878174
Loss: 1.3856756687164307 Training Accuracy: 0.25 Validation Accuracy: 0.4117647111415863
Loss: 1.3854405879974365 Training Accuracy: 0.5 Validation Accuracy: 0.4117647111415863
Loss: 1.385114073753357 Training Accuracy: 0.5 Validation Accuracy: 0.4117647111415863
Loss: 1.38467502593

In [None]:
import time

data = dataset[0]

def adj(data):
  a = torch.zeros(34, 34)
  for i in data.edge_index.T:
    a[i.tolist()[0],i.tolist()[1]] = 1
  return a

model = LightGCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.    
    adj_m = adj(data)/torch.transpose(adj(data).sum(dim = 0).unsqueeze(0), 1, 0)
    out = model(adj_m)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

    accuracy = {}
    # Calculate training accuracy on our four examples
    predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
    target_classes = data.y[data.train_mask]
    accuracy['train'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())
    
    # Calculate validation accuracy on the whole graph
    predicted_classes = torch.argmax(out, axis=1)
    target_classes = data.y
    accuracy['val'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())

    return loss, accuracy

for epoch in range(100):
    start = time.time()
    loss, accuracy = train(data)
    print('Loss: ' + str(loss.item()) + ' Training Accuracy: ' + str(accuracy['train'].item()) + ' Validation Accuracy: ' + str(accuracy['val'].item()))

print('Total Time: ' + str(time.time() - start))

Loss: 1.384352445602417 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3784607648849487 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.3724428415298462 Training Accuracy: 0.75 Validation Accuracy: 0.4117647111415863
Loss: 1.3660521507263184 Training Accuracy: 0.75 Validation Accuracy: 0.4117647111415863
Loss: 1.3590610027313232 Training Accuracy: 0.75 Validation Accuracy: 0.3529411852359772
Loss: 1.3513187170028687 Training Accuracy: 0.75 Validation Accuracy: 0.529411792755127
Loss: 1.3427104949951172 Training Accuracy: 1.0 Validation Accuracy: 0.5588235259056091
Loss: 1.3331363201141357 Training Accuracy: 1.0 Validation Accuracy: 0.5882353186607361
Loss: 1.3225035667419434 Training Accuracy: 1.0 Validation Accuracy: 0.5882353186607361
Loss: 1.3107249736785889 Training Accuracy: 1.0 Validation Accuracy: 0.6176470518112183
Loss: 1.2977209091186523 Training Accuracy: 1.0 Validation Accuracy: 0.6176470518112183
Loss: 1.283425211906433 Tr

In [None]:
import time

data = dataset[0]

def adj(data):
  a = torch.zeros(34, 34)
  for i in data.edge_index.T:
    a[i.tolist()[0],i.tolist()[1]] = 1
  return a

model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    adj_m = adj(data)/torch.transpose(adj(data).sum(dim = 0).unsqueeze(0), 1, 0)
    out = model(adj_m)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

    accuracy = {}
    # Calculate training accuracy on our four examples
    predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
    target_classes = data.y[data.train_mask]
    accuracy['train'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())
    
    # Calculate validation accuracy on the whole graph
    predicted_classes = torch.argmax(out, axis=1)
    target_classes = data.y
    accuracy['val'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())

    return loss, accuracy

for epoch in range(100):
    start = time.time()
    loss, accuracy = train(data)
    print('Loss: ' + str(loss.item()) + ' Training Accuracy: ' + str(accuracy['train'].item()) + ' Validation Accuracy: ' + str(accuracy['val'].item()))

print('Total Time: ' + str(time.time() - start))

Loss: 1.3955817222595215 Training Accuracy: 0.25 Validation Accuracy: 0.29411765933036804
Loss: 1.364160418510437 Training Accuracy: 1.0 Validation Accuracy: 0.7058823704719543
Loss: 1.3065834045410156 Training Accuracy: 1.0 Validation Accuracy: 0.6470588445663452
Loss: 1.1752618551254272 Training Accuracy: 1.0 Validation Accuracy: 0.6470588445663452
Loss: 0.9786765575408936 Training Accuracy: 1.0 Validation Accuracy: 0.6470588445663452
Loss: 0.8519307971000671 Training Accuracy: 1.0 Validation Accuracy: 0.5882353186607361
Loss: 0.7763817310333252 Training Accuracy: 1.0 Validation Accuracy: 0.5882353186607361
Loss: 0.7468534708023071 Training Accuracy: 1.0 Validation Accuracy: 0.529411792755127
Loss: 0.7444478273391724 Training Accuracy: 1.0 Validation Accuracy: 0.5882353186607361
Loss: 0.7440365552902222 Training Accuracy: 1.0 Validation Accuracy: 0.5882353186607361
Loss: 0.7438129782676697 Training Accuracy: 1.0 Validation Accuracy: 0.5882353186607361
Loss: 0.7437161803245544 Trainin

In [None]:
import time

data = dataset[0]

def adj(data):
  a = torch.zeros(34, 34)
  for i in data.edge_index.T:
    a[i.tolist()[0],i.tolist()[1]] = 1
  return a

model = GAT()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    adj_m = adj(data)/torch.transpose(adj(data).sum(dim = 0).unsqueeze(0), 1, 0)
    out = model(adj_m)  # Perform a single forward pass.  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

    accuracy = {}
    # Calculate training accuracy on our four examples
    predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
    target_classes = data.y[data.train_mask]
    accuracy['train'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())
    
    # Calculate validation accuracy on the whole graph
    predicted_classes = torch.argmax(out, axis=1)
    target_classes = data.y
    accuracy['val'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())

    return loss, accuracy

for epoch in range(100):
    start = time.time()
    loss, accuracy = train(data)
    print('Loss: ' + str(loss.item()) + ' Training Accuracy: ' + str(accuracy['train'].item()) + ' Validation Accuracy: ' + str(accuracy['val'].item()))

print('Total Time: ' + str(time.time() - start))

Loss: 1.3866838216781616 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.38661527633667 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.3865504264831543 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.3864741325378418 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.386365294456482 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.3861886262893677 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.385871171951294 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.3852218389511108 Training Accuracy: 0.25 Validation Accuracy: 0.14705882966518402
Loss: 1.3837295770645142 Training Accuracy: 0.5 Validation Accuracy: 0.1764705926179886
Loss: 1.3803622722625732 Training Accuracy: 0.5 Validation Accuracy: 0.23529411852359772
Loss: 1.3739759922027588 Training Accuracy: 0.5 Validation Accuracy: 0.2647058963775635
Loss: 1.36456036567

In [None]:
import time

data = dataset[0]

def adj(data):
  a = torch.zeros(34, 34)
  for i in data.edge_index.T:
    a[i.tolist()[0],i.tolist()[1]] = 1
  return a

model = GAT2()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    adj_m = adj(data)/torch.transpose(adj(data).sum(dim = 0).unsqueeze(0), 1, 0)
    out = model(adj_m)  # Perform a single forward pass.  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

    accuracy = {}
    # Calculate training accuracy on our four examples
    predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2
    target_classes = data.y[data.train_mask]
    accuracy['train'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())
    
    # Calculate validation accuracy on the whole graph
    predicted_classes = torch.argmax(out, axis=1)
    target_classes = data.y
    accuracy['val'] = torch.mean(
        torch.where(predicted_classes == target_classes, 1, 0).float())

    return loss, accuracy

for epoch in range(100):
    start = time.time()
    loss, accuracy = train(data)
    print('Loss: ' + str(loss.item()) + ' Training Accuracy: ' + str(accuracy['train'].item()) + ' Validation Accuracy: ' + str(accuracy['val'].item()))

print('Total Time: ' + str(time.time() - start))

Loss: 1.3863977193832397 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3862807750701904 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3861726522445679 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3860305547714233 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3858325481414795 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3855502605438232 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3851741552352905 Training Accuracy: 0.25 Validation Accuracy: 0.11764705926179886
Loss: 1.3845819234848022 Training Accuracy: 0.25 Validation Accuracy: 0.1764705926179886
Loss: 1.3827365636825562 Training Accuracy: 0.75 Validation Accuracy: 0.5
Loss: 1.3800532817840576 Training Accuracy: 0.5 Validation Accuracy: 0.23529411852359772
Loss: 1.376124382019043 Training Accuracy: 0.5 Validation Accuracy: 0.20588235557079315
Loss: 1.3694018125534058 Train