In [92]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [93]:
from torch_geometric.datasets import Planetoid 
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
import torch.distributions as D

import matplotlib.pyplot as plt

name_data = 'Cora'
dataset = Planetoid(root='data', name=name_data)
dataset.transform = T.NormalizeFeatures()

print(f"Number of Classes in {name_data}:", dataset.num_classes)
print(f"Number of Node Features in {name_data}:", dataset.num_node_features)

Number of Classes in Cora: 7
Number of Node Features in Cora: 1433


In [94]:
from models.gmodel import VGAT, GATELBO

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# I put the most basic one, make sure to adjust this
prior_distribution = D.Normal(
    loc=torch.zeros(1).to(device),
    scale=torch.tensor(1).to(device),
)

model = VGAT(
    in_features=dataset.num_features,
    hidden_features=64,
    n_heads=8,
    prior_distribution=prior_distribution,
    num_classes=dataset.num_classes,
    leaky_relu=0.2,
).to(device)
print(f"# parameters: {sum(p.nelement() for p in model.parameters())}")

data = dataset[0].to(device)
print(data)
A = 1.0 - torch.sparse.LongTensor(
    data.edge_index, # where to put
    torch.ones(data.edge_index.shape[1]).to(device),
    torch.Size((data.x.shape[0], data.x.shape[0])),
).to_dense()
A[A.bool()] = float("-Inf")
print(torch.all(A == A.T))
A = A.unsqueeze(0)
print((A != 0).float().mean())
print(A.shape)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

# parameters: 184604
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
tensor(True)
tensor(0.9986)
torch.Size([1, 2708, 2708])


In [95]:
A

tensor([[[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, 0.,  ..., -inf, -inf, -inf],
         [-inf, 0., -inf,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, 0.],
         [-inf, -inf, -inf,  ..., -inf, 0., -inf]]])

In [96]:
criterion = GATELBO()

In [97]:
def train(model, optimizer, data, A, n_samples=1):
    model.train()
    logits, kl_divergence = model(data.x, A, n_samples=n_samples)
    elbo, nll = criterion(
        logits[data.train_mask], data.y[data.train_mask], kl_divergence
    )
    
    optimizer.zero_grad()
    elbo.backward()
    optimizer.step()
    
    return elbo.item(), (elbo - nll).item(), nll.item()


def test(model, data, A, n_samples=1):
    accs = []

    model.eval()
    with torch.no_grad():
        logits, _ = model(data.x, A, n_samples=n_samples)

    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = F.softmax(logits[mask], dim=-1).mean(dim=1).argmax(dim=1)
        acc = pred.eq(data.y[mask]).float().mean()
        accs.append(acc)
    
    return accs

In [98]:
LOGGING_FORMAT = "Epoch: {}, train_elbo: {}, train_kl: {}, train_nll: {}, train_acc: {:.4f}, val_acc: {:.4f}, test_acc: {:.4f}"

def train_loop(model, optimizer, data, A, epochs, n_samples=1):

    best_val_acc = test_acc = 0

    for epoch in range(1, epochs):
        elbo, kl, nll = train(model, optimizer, data, A, n_samples=n_samples)

        train_acc, val_acc, curr_test_acc = test(model, data, A, n_samples=n_samples)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = curr_test_acc
        
        if epoch % 10 == 0:
            print(LOGGING_FORMAT.format(epoch, elbo, kl, nll, train_acc, best_val_acc, test_acc))


In [99]:
train_loop(model, optimizer, data, A, epochs=1000, n_samples=2)

Epoch: 10, train_elbo: 254978.8125, train_kl: 254708.171875, train_nll: 270.6416015625, train_acc: 0.5500, val_acc: 0.4680, test_acc: 0.4740
Epoch: 20, train_elbo: 250359.78125, train_kl: 250089.578125, train_nll: 270.20660400390625, train_acc: 0.7143, val_acc: 0.4840, test_acc: 0.4780
Epoch: 30, train_elbo: 245888.078125, train_kl: 245618.09375, train_nll: 269.982421875, train_acc: 0.6071, val_acc: 0.5360, test_acc: 0.5280


KeyboardInterrupt: 