In [1]:
import os
import torch
torch.manual_seed(7)

from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import pickle
import requests
import os
import torch

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

from torch_geometric.nn import GATConv
from torch_geometric.data import Dataset



[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/107.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m102.4/107.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.2/209.2 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for torch-sparse (setup.py) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone


In [2]:
# Deffining the dataset

class HW3Dataset(Dataset):
    url = 'https://technionmail-my.sharepoint.com/:u:/g/personal/ploznik_campus_technion_ac_il/EUHUDSoVnitIrEA6ALsAK1QBpphP5jX3OmGyZAgnbUFo0A?download=1'

    def __init__(self, root, transform=None, pre_transform=None):
        super(HW3Dataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['data.pt']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        file_url = self.url.replace(' ', '%20')
        response = requests.get(file_url)

        if response.status_code != 200:
            raise Exception(f"Failed to download the file, status code: {response.status_code}")

        with open(os.path.join(self.raw_dir, self.raw_file_names[0]), 'wb') as f:
            f.write(response.content)

    def process(self):
        raw_path = os.path.join(self.raw_dir, self.raw_file_names[0])
        data = torch.load(raw_path)
        torch.save(data, self.processed_paths[0])

    def len(self):
        return 1

    def get(self, idx):
        return torch.load(self.processed_paths[0])


dataset = HW3Dataset(root='data/hw3/')
data = dataset[0]


Processing...
Done!


In [3]:
class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        self.conv1 = GATConv(dataset.num_features,int(hidden_channels/2))
        self.conv2 = GATConv(int(hidden_channels/2),hidden_channels)
        self.conv3 = GATConv(hidden_channels,dataset.num_classes)
        # adding batch norm layers to generalize the data
        self.bn1 = torch.nn.BatchNorm1d(int(hidden_channels/2))
        self.bn2 = torch.nn.BatchNorm1d(int(hidden_channels))



    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.tanh(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.bn2(x)

        x = F.tanh(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv3(x, edge_index)
        return x

model = GAT(hidden_channels=128, heads=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.004, weight_decay=5e-4)
# splitting the test into val and test.
val_mask, test_mask = train_test_split(data.val_mask, test_size=0.5, random_state=7)
our_loss = torch.nn.CrossEntropyLoss()

def train(mask):
      model.train()
      optimizer.zero_grad()
      out = model(data.x, data.edge_index)
      # Calculation of the loss
      loss = our_loss(out[mask], data.y[mask].view(-1))
      loss.backward()
      optimizer.step()
      return loss

def test(mask):
      model.eval()
      out = model(data.x, data.edge_index)
      # Prefomring a maximum to find the predicted category
      pred = out.argmax(dim=1)
      return accuracy_score(data.y[mask],pred[mask])

In [4]:
# stop the training if didnt see significant learning for a while
best_acc_val = 0.0
best_model = None
patience = 100
counter = 0

for epoch in range(1, 501):
    loss = train(data.train_mask)
    val_acc = test(val_mask)

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Validation: {val_acc:.4f}')
    if val_acc > best_acc_val:
        best_acc_val = val_acc
        best_model = model
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Stopped training. There was no improvement for {} epochs.".format(patience))
            break

# Loading the best model over the training process
with open('final_model.pkl', 'wb') as f:
    pickle.dump(best_model, f)
print(best_acc_val)

Epoch: 001, Loss: 4.0165, Validation: 0.0680
Epoch: 002, Loss: 3.6903, Validation: 0.1775
Epoch: 003, Loss: 3.3923, Validation: 0.2485
Epoch: 004, Loss: 3.1388, Validation: 0.2868
Epoch: 005, Loss: 2.9157, Validation: 0.3080
Epoch: 006, Loss: 2.7450, Validation: 0.3289
Epoch: 007, Loss: 2.6175, Validation: 0.3500
Epoch: 008, Loss: 2.5136, Validation: 0.3717
Epoch: 009, Loss: 2.4237, Validation: 0.3939
Epoch: 010, Loss: 2.3475, Validation: 0.4177
Epoch: 011, Loss: 2.2787, Validation: 0.4396
Epoch: 012, Loss: 2.2262, Validation: 0.4531
Epoch: 013, Loss: 2.1810, Validation: 0.4622
Epoch: 014, Loss: 2.1441, Validation: 0.4687
Epoch: 015, Loss: 2.1067, Validation: 0.4749
Epoch: 016, Loss: 2.0776, Validation: 0.4821
Epoch: 017, Loss: 2.0530, Validation: 0.4873
Epoch: 018, Loss: 2.0200, Validation: 0.4913
Epoch: 019, Loss: 2.0022, Validation: 0.4978
Epoch: 020, Loss: 1.9881, Validation: 0.5044
Epoch: 021, Loss: 1.9577, Validation: 0.5089
Epoch: 022, Loss: 1.9454, Validation: 0.5095
Epoch: 023

In [5]:
test_acc = test(test_mask)
test_acc

0.5866