Let's try some variations of the model from the paper. First we'll experiment with GAT models with a single convolutional layer.

In [7]:
import torch
import torch_geometric

In [9]:
class SingleLayerGAT(torch.nn.Module):
    def __init__(self, n_features, n_classes):
        super().__init__()
        self.conv = torch_geometric.nn.GATConv(heads=8, out_channels=n_classes, in_channels=n_features, dropout=.6)
        self.act = torch.nn.Softmax(dim=1)
        
    def forward(self, x, edge_index):
        return self.act(self.conv(x, edge_index))

In [13]:
for dataset in ['citeseer', 'cora']:
    data = torch_geometric.datasets.Planetoid(root=f'../data/{dataset}', name=dataset)
    model = SingleLayerGAT(data.num_features, data.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=.005, weight_decay=5e-4)

    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = torch.nn.functional.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            print(out.argmax(dim=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() / int(data.train_mask.sum()))

    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item() / int(data.test_mask.sum())
    print('\n\n*****************************************************************************************************\n')
    print(f'                                         {dataset} ')
    print(f'                                         Total Epochs: 200')
    print(f'                                         Test Accuracy: {acc:.4f}')
    print('\n*****************************************************************************************************\n\n')

0.03333333333333333
0.725
0.7666666666666667
0.775
0.7583333333333333
0.7666666666666667
0.8333333333333334
0.7083333333333334
0.7833333333333333
0.7416666666666667
0.75
0.8083333333333333
0.8583333333333333
0.7583333333333333
0.8
0.7833333333333333
0.825
0.7583333333333333
0.8166666666666667
0.7916666666666666


*****************************************************************************************************

                                         citeseer 
                                         Total Epochs: 200
                                         Test Accuracy: 0.6560

*****************************************************************************************************


0.05714285714285714
0.7142857142857143
0.8285714285714286
0.8214285714285714
0.85
0.7928571428571428
0.8142857142857143
0.8071428571428572
0.7714285714285715
0.8285714285714286
0.8285714285714286
0.8285714285714286
0.8
0.7785714285714286
0.8428571428571429
0.8642857142857143
0.85
0.85
0.8428571428571429

As we can see, this model performs significantly than what we see in the paper (83% and 72.5% accuracy for Cora and Citeseer, respecitvely).

We can experiment some more, though. Let's try adding an ELU before the sotmax in our model.

In [14]:
class SingleLayerGAT(torch.nn.Module):
    def __init__(self, n_features, n_classes):
        super().__init__()
        self.conv = torch_geometric.nn.GATConv(heads=8, out_channels=n_classes, in_channels=n_features, dropout=.6)
        self.act1 = torch.nn.ELU()
        self.act2 = torch.nn.Softmax(dim=1)
        
    def forward(self, x, edge_index):
        return self.act2(self.act1(self.conv(x, edge_index)))

In [15]:
for dataset in ['citeseer', 'cora']:
    data = torch_geometric.datasets.Planetoid(root=f'../data/{dataset}', name=dataset)
    model = SingleLayerGAT(data.num_features, data.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=.005, weight_decay=5e-4)

    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = torch.nn.functional.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            print(out.argmax(dim=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() / int(data.train_mask.sum()))

    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item() / int(data.test_mask.sum())
    print('\n\n*****************************************************************************************************\n')
    print(f'                                         {dataset} ')
    print(f'                                         Total Epochs: 200')
    print(f'                                         Test Accuracy: {acc:.4f}')
    print('\n*****************************************************************************************************\n\n')

0.016666666666666666
0.8
0.8333333333333334
0.8083333333333333
0.7666666666666667
0.8083333333333333
0.7583333333333333
0.7416666666666667
0.7666666666666667
0.7583333333333333
0.7833333333333333
0.75
0.825
0.7583333333333333
0.8416666666666667
0.85
0.725
0.8833333333333333
0.8333333333333334
0.8416666666666667


*****************************************************************************************************

                                         citeseer 
                                         Total Epochs: 200
                                         Test Accuracy: 0.6620

*****************************************************************************************************


0.007142857142857143
0.7857142857142857
0.8214285714285714
0.85
0.8214285714285714
0.8785714285714286
0.8785714285714286
0.8571428571428571
0.85
0.8857142857142857
0.8928571428571429
0.8428571428571429
0.8785714285714286
0.8642857142857143
0.8285714285714286
0.8857142857142857
0.8571428571428571
0.87142

The performance looks roughly the same as the previous model. We can see that in both cases the models are achieving fairly high accuracy on the training data, though it is unclear whether this might be indicative of overfitting without more detailed analysis.

The fact that a single-layer model performs worse than a two-layer model suggests that adding more layers may further boost the performance. Let's give it a shot.

In [18]:
class ThreeLayerGAT(torch.nn.Module):
    def __init__(self, n_features, n_classes):
        super().__init__()
        self.conv1 = torch_geometric.nn.GATConv(heads=8, out_channels=8, in_channels=n_features, dropout=.6)
        self.act1 = torch.nn.ELU()
        self.conv2 = torch_geometric.nn.GATConv(heads=8, out_channels=8, in_channels=64, dropout=.6)
        self.act2 = torch.nn.ELU()
        self.conv3 = torch_geometric.nn.GATConv(heads=1, out_channels=n_classes, in_channels=64, dropout=.6)
        self.act3 = torch.nn.Softmax(dim=1)
        
    def forward(self, x, edge_index):
        return self.act3(self.conv3(self.act2(self.conv2(self.act1(self.conv1(x, edge_index)), edge_index)), edge_index))

In [19]:
for dataset in ['citeseer', 'cora']:
    data = torch_geometric.datasets.Planetoid(root=f'../data/{dataset}', name=dataset)
    model = ThreeLayerGAT(data.num_features, data.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=.005, weight_decay=5e-4)

    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = torch.nn.functional.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            print(out.argmax(dim=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() / int(data.train_mask.sum()))

    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    acc = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item() / int(data.test_mask.sum())
    print('\n\n*****************************************************************************************************\n')
    print(f'                                         {dataset} ')
    print(f'                                         Total Epochs: 200')
    print(f'                                         Test Accuracy: {acc:.4f}')
    print('\n*****************************************************************************************************\n\n')

0.2
0.7416666666666667
0.775
0.8
0.7916666666666666
0.8333333333333334
0.8166666666666667
0.7916666666666666
0.7833333333333333
0.7916666666666666
0.725
0.8
0.7916666666666666
0.8583333333333333
0.8
0.8083333333333333
0.8333333333333334
0.8333333333333334
0.825
0.8833333333333333


*****************************************************************************************************

                                         citeseer 
                                         Total Epochs: 200
                                         Test Accuracy: 0.6280

*****************************************************************************************************


0.07857142857142857
0.7928571428571428
0.8928571428571429
0.8785714285714286
0.8642857142857143
0.8714285714285714
0.8714285714285714
0.85
0.9285714285714286
0.8357142857142857
0.85
0.9214285714285714
0.8571428571428571
0.8428571428571429
0.8785714285714286
0.8714285714285714
0.8857142857142857
0.8857142857142857
0.9
0.907142857142857