In [10]:
import os
import torch
torch.manual_seed(7)

from sklearn.metrics import accuracy_score
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import pickle
import requests
import os
import torch

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

from torch_geometric.nn import GATConv
from torch_geometric.data import Dataset



  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for torch-sparse (setup.py) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [2]:
# Deffining the dataset

class HW3Dataset(Dataset):
    url = 'https://technionmail-my.sharepoint.com/:u:/g/personal/ploznik_campus_technion_ac_il/EUHUDSoVnitIrEA6ALsAK1QBpphP5jX3OmGyZAgnbUFo0A?download=1'

    def __init__(self, root, transform=None, pre_transform=None):
        super(HW3Dataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['data.pt']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        file_url = self.url.replace(' ', '%20')
        response = requests.get(file_url)

        if response.status_code != 200:
            raise Exception(f"Failed to download the file, status code: {response.status_code}")

        with open(os.path.join(self.raw_dir, self.raw_file_names[0]), 'wb') as f:
            f.write(response.content)

    def process(self):
        raw_path = os.path.join(self.raw_dir, self.raw_file_names[0])
        data = torch.load(raw_path)
        torch.save(data, self.processed_paths[0])

    def len(self):
        return 1

    def get(self, idx):
        return torch.load(self.processed_paths[0])


dataset = HW3Dataset(root='data/hw3/')
data = dataset[0]


Processing...
Done!


In [6]:
class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        self.conv1 = GATConv(dataset.num_features,int(hidden_channels/2))
        self.conv2 = GATConv(int(hidden_channels/2),hidden_channels)
        self.conv3 = GATConv(hidden_channels,dataset.num_classes)
        # adding batch norm layers to generalize the data
        self.bn1 = torch.nn.BatchNorm1d(int(hidden_channels/2))
        self.bn2 = torch.nn.BatchNorm1d(int(hidden_channels))



    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.tanh(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.bn2(x)

        x = F.tanh(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv3(x, edge_index)
        return x

model = GAT(hidden_channels=128, heads=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
# splitting the test into val and test.
val_mask, test_mask = train_test_split(data.val_mask, test_size=0.5, random_state=7)
our_loss = torch.nn.CrossEntropyLoss()

def train(mask):
      model.train()
      optimizer.zero_grad()
      out = model(data.x, data.edge_index)
      # Calculation of the loss
      loss = our_loss(out[mask], data.y[mask].view(-1))
      loss.backward()
      optimizer.step()
      return loss

def test(mask):
      model.eval()
      out = model(data.x, data.edge_index)
      # Prefomring a maximum to find the predicted category
      pred = out.argmax(dim=1)
      return accuracy_score(data.y[mask],pred[mask])

In [7]:
# stop the training if didnt see significant learning for a while
best_acc_val = 0.0
best_model = None
patience = 100
counter = 0

for epoch in range(1, 501):
    loss = train(data.train_mask)
    val_acc = test(val_mask)

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Validation: {val_acc:.4f}')
    if val_acc > best_acc_val:
        best_acc_val = val_acc
        best_model = model
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Stopped training. There was no improvement for {} epochs.".format(patience))
            break

# Loading the best model over the training process
with open('final_model.pkl', 'wb') as f:
    pickle.dump(best_model, f)
print(best_acc_val)

Epoch: 001, Loss: 3.9230, Validation: 0.0743
Epoch: 002, Loss: 3.5000, Validation: 0.2822
Epoch: 003, Loss: 3.1421, Validation: 0.3390
Epoch: 004, Loss: 2.8539, Validation: 0.3522
Epoch: 005, Loss: 2.6542, Validation: 0.3606
Epoch: 006, Loss: 2.5241, Validation: 0.3691
Epoch: 007, Loss: 2.4331, Validation: 0.3796
Epoch: 008, Loss: 2.3366, Validation: 0.3905
Epoch: 009, Loss: 2.2575, Validation: 0.3983
Epoch: 010, Loss: 2.2053, Validation: 0.4105
Epoch: 011, Loss: 2.1580, Validation: 0.4216
Epoch: 012, Loss: 2.1280, Validation: 0.4357
Epoch: 013, Loss: 2.0967, Validation: 0.4482
Epoch: 014, Loss: 2.0708, Validation: 0.4616
Epoch: 015, Loss: 2.0359, Validation: 0.4703
Epoch: 016, Loss: 2.0103, Validation: 0.4776
Epoch: 017, Loss: 1.9806, Validation: 0.4841
Epoch: 018, Loss: 1.9682, Validation: 0.4905
Epoch: 019, Loss: 1.9452, Validation: 0.4988
Epoch: 020, Loss: 1.9368, Validation: 0.5059
Epoch: 021, Loss: 1.9259, Validation: 0.5108
Epoch: 022, Loss: 1.9040, Validation: 0.5143
Epoch: 023

In [8]:
test_acc = test(test_mask)
test_acc

0.5864