In [2]:
import torch
from utilities.helpers import *
from algorithms.gat import GAT
from algorithms.graphsage import GraphSAGE
from algorithms.gcn import GCN

In [3]:
# Choose the torch device used for every model below (models are moved onto it
# with `.to(device)`), then load the Cora dataset onto that device.
# The first run downloads the Planetoid raw files (see the log this cell prints).
# `data.x` is used below as the node-feature matrix (its column count is the model
# input size) and `data.y` as the label vector (its unique values give num_classes).
# NOTE(review): device-selection policy (e.g. CUDA if available) lives in
# utilities.helpers.get_device — confirm there.
device = get_device()
data = load_cora(device)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [4]:
%%capture captured_output


input_dim = data.x.shape[1]
num_classes = len(torch.unique(data.y))

accuracies = []
for num_epochs in [100*(i+1) for i in range(10)]:
    model = GCN(input_dim, 64, num_classes, 2).to(device)
    model = train(data, model, num_epochs)
    accuracy = evaluate(model, data)
    accuracies.append(accuracy)

In [5]:
# Accuracies from the GCN sweep directly above, one per epoch budget (100, 200, ..., 1000).
# NOTE: `accuracies` is rebound by every experiment cell in this notebook, so this
# shows whichever sweep ran last — run it immediately after the GCN cell.
accuracies

[0.711, 0.78, 0.816, 0.821, 0.817, 0.814, 0.815, 0.823, 0.812, 0.816]

In [6]:
%%capture captured_output


input_dim = data.x.shape[1]
num_classes = len(torch.unique(data.y))

accuracies = []
for num_epochs in [100*(i+1) for i in range(10)]:
    model = GraphSAGE(input_dim, 64, num_classes, 2, 'mean').to(device)
    model = train(data, model, num_epochs)
    accuracy = evaluate(model, data)
    accuracies.append(accuracy)

In [7]:
# Accuracies from the GraphSAGE sweep directly above, one per epoch budget (100, 200, ..., 1000).
# NOTE: `accuracies` is rebound by every experiment cell in this notebook, so this
# shows whichever sweep ran last — run it immediately after the GraphSAGE cell.
accuracies

[0.565, 0.779, 0.809, 0.797, 0.814, 0.804, 0.802, 0.797, 0.811, 0.798]

In [8]:
%%capture captured_output


input_dim = data.x.shape[1]
num_classes = len(torch.unique(data.y))

accuracies = []
for num_epochs in [100*(i+1) for i in range(10)]:
    model = GAT(input_dim, 8, num_classes, 8, alpha=0.2).to(device)
    model = train(data, model, num_epochs)
    accuracy = evaluate(model, data)
    accuracies.append(accuracy)

In [9]:
# Accuracies from the GAT sweep directly above, one per epoch budget (100, 200, ..., 1000).
# NOTE: `accuracies` is rebound by every experiment cell in this notebook, so this
# shows whichever sweep ran last — run it immediately after the GAT cell.
accuracies

[0.789, 0.794, 0.808, 0.809, 0.813, 0.818, 0.814, 0.824, 0.816, 0.821]