In [97]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
from torch_geometric.data import Data
from torch_geometric.nn import GATConv # official GAT implementation in PyG
from torch_geometric.datasets import Planetoid 
import torch_geometric.transforms as T
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt 

# Load the Planetoid citation graph and row-normalise node features on access.
name_data = 'Cora'
dataset = Planetoid(root='data', name=name_data, transform=T.NormalizeFeatures())

# Summarise the label space and the input feature dimensionality.
n_classes = dataset.num_classes
n_node_features = dataset.num_node_features
print(f"Number of Classes in {name_data}:", n_classes)
print(f"Number of Node Features in {name_data}:", n_node_features)

Number of Classes in Cora: 7
Number of Node Features in Cora: 1433


In [137]:
from models.gmodel import GAT, VGAT


# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Instantiate the project-local GAT with the dataset's input/output sizes,
# then move its parameters onto the selected device.
model = GAT(
    in_features=dataset.num_features,
    hidden_features=64,
    n_heads=8,
    num_classes=dataset.num_classes,
    dropout=0.6,
    leaky_relu=0.2,
)
model = model.to(device)

# Move the (single) graph of the dataset onto the selected device.
data = dataset[0].to(device)
print(data)

# Build a dense additive mask from the edge list: 0 where an edge exists and
# -inf everywhere else (presumably added to attention logits inside the model
# before the softmax -- TODO confirm against models.gmodel).
#
# torch.sparse_coo_tensor replaces the legacy typed constructor
# torch.sparse.LongTensor: it infers the device from its inputs and accepts
# float values. Testing `adjacency == 0` (instead of `1 - adjacency`) also
# keeps edges that appear more than once in edge_index, whose entries are
# summed by to_dense().
num_nodes = data.x.shape[0]
adjacency = torch.sparse_coo_tensor(
    data.edge_index,  # (row, col) positions of the non-zero entries
    torch.ones(data.edge_index.shape[1], device=data.edge_index.device),
    (num_nodes, num_nodes),
).to_dense()
A = torch.zeros_like(adjacency).masked_fill(adjacency == 0, float("-inf"))
# NOTE(review): the diagonal is -inf unless edge_index contains self-loops,
# so a node cannot attend to itself -- confirm this is intended.

# Sanity checks: mask symmetry, fraction of masked-out pairs, final shape.
print(torch.all(A == A.T))
A = A.unsqueeze(0)  # leading singleton dimension
print((A != 0).float().mean())
print(A.shape)

# torch.optim has no `AdamQ` (AttributeError at runtime); AdamW is the closest
# valid optimizer name. Swap in torch.optim.Adam if L2-style (coupled) weight
# decay is wanted instead of decoupled decay.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=5e-4)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
tensor(True)
tensor(0.9986)
torch.Size([1, 2708, 2708])


In [138]:
# Inspect the mask: 0. marks a pair connected by an edge, -inf marks every other pair.
A

tensor([[[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, 0.,  ..., -inf, -inf, -inf],
         [-inf, 0., -inf,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, 0.],
         [-inf, -inf, -inf,  ..., -inf, 0., -inf]]])

In [142]:
def train(model, optimizer, data, A):
    """Perform a single full-batch optimisation step.

    Args:
        model: Module called as ``model(data.x, A)`` that returns per-node logits.
        optimizer: Optimizer wrapping the model's parameters.
        data: Graph object exposing ``x``, ``y`` and ``train_mask``.
        A: Mask forwarded unchanged to the model.

    Returns:
        The cross-entropy loss over the training nodes, as a Python float,
        measured before the parameter update.
    """
    model.train()

    # Only nodes selected by the training mask contribute to the loss.
    train_idx = data.train_mask
    train_logits = model(data.x, A)[train_idx]
    train_targets = data.y[train_idx]
    step_loss = F.cross_entropy(train_logits, train_targets)

    # Clear stale gradients, back-propagate, then update the parameters.
    optimizer.zero_grad()
    step_loss.backward()
    optimizer.step()

    return step_loss.item()


def test(model, data, A):
    accs = []

    model.eval()
    with torch.no_grad():
        logits = model(data.x, A)

    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].argmax(1)
        acc = pred.eq(data.y[mask]).float().mean()
        accs.append(acc)
    
    return accs

In [145]:
LOGGING_FORMAT = 'Epoch: {}, train_loss: {}, val_acc: {:.4f}, test_acc: {:.4f}'


def train_loop(model, optimizer, data, A, epochs):
    """Train for ``epochs`` epochs with validation-based model selection.

    After every epoch the model is evaluated; whenever the validation accuracy
    strictly improves, the test accuracy of that same epoch is recorded. A
    progress line is printed every 10 epochs.

    Args:
        model: Model passed through to ``train`` and ``test``.
        optimizer: Optimizer passed through to ``train``.
        data: Graph object passed through to ``train`` and ``test``.
        A: Mask passed through to ``train`` and ``test``.
        epochs: Number of training epochs to run (epochs are numbered from 1).

    Returns:
        ``(best_val_acc, test_acc)``: the best validation accuracy seen and the
        test accuracy measured at that epoch; ``(0, 0)`` if no epoch ran.
    """
    best_val_acc = test_acc = 0

    # Inclusive upper bound so exactly `epochs` iterations run; the previous
    # range(1, epochs) stopped one epoch short.
    for epoch in range(1, epochs + 1):
        loss = train(model, optimizer, data, A)

        _, val_acc, curr_test_acc = test(model, data, A)
        if val_acc > best_val_acc:
            # Report the test accuracy at the best-validation epoch only.
            best_val_acc = val_acc
            test_acc = curr_test_acc

        if epoch % 10 == 0:
            print(LOGGING_FORMAT.format(epoch, loss, best_val_acc, test_acc))

    return best_val_acc, test_acc


In [146]:
# Launch training with epochs=1000; the recorded run below was stopped by a manual interrupt.
train_loop(model, optimizer, data, A, epochs=1000)

Epoch: 10, train_loss: 1.9150930643081665, val_acc: 0.5080, test_acc: 0.5270
Epoch: 20, train_loss: 1.878787636756897, val_acc: 0.5620, test_acc: 0.5760
Epoch: 30, train_loss: 1.8617841005325317, val_acc: 0.5780, test_acc: 0.5760
Epoch: 40, train_loss: 1.801270842552185, val_acc: 0.5880, test_acc: 0.5850
Epoch: 50, train_loss: 1.7311216592788696, val_acc: 0.5880, test_acc: 0.5850
Epoch: 60, train_loss: 1.6968421936035156, val_acc: 0.5880, test_acc: 0.5850
Epoch: 70, train_loss: 1.5875798463821411, val_acc: 0.5940, test_acc: 0.5880
Epoch: 80, train_loss: 1.513552188873291, val_acc: 0.5960, test_acc: 0.5990
Epoch: 90, train_loss: 1.4785003662109375, val_acc: 0.6060, test_acc: 0.6080
Epoch: 100, train_loss: 1.419276475906372, val_acc: 0.6080, test_acc: 0.6100
Epoch: 110, train_loss: 1.3837158679962158, val_acc: 0.6080, test_acc: 0.6100
Epoch: 120, train_loss: 1.2157044410705566, val_acc: 0.6100, test_acc: 0.6080
Epoch: 130, train_loss: 1.1617473363876343, val_acc: 0.6100, test_acc: 0.6080

KeyboardInterrupt: 