# Demo on the Cora dataset

This notebook demonstrates how to run GAtt on the Cora dataset.

In [1]:
import os

# Set proper working directory
os.chdir("/workspace/")

## Data preparation

In [2]:
import torch
from torch_geometric.datasets import Planetoid


dataset_name = "Cora"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = Planetoid("/workspace/", dataset_name)[0].to(device)

## Model preparation

Model definition

In [3]:
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT_L2(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)
        self.conv2 = GATConv(
            hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6
        )

    def forward(self, x, edge_index, return_att=False):
        if return_att:
            x, att1 = self.conv1(x, edge_index, return_attention_weights=return_att)
            x = F.elu(x)
            x, att2 = self.conv2(x, edge_index, return_attention_weights=return_att)
            self.att = [att1, att2]
        else:
            x = self.conv1(x, edge_index)
            x = F.elu(x)
            x = self.conv2(x, edge_index)
        return x


class GAT_L3(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)
        self.conv2 = GATConv(
            hidden_channels * heads, hidden_channels, heads=1, concat=False, dropout=0.6
        )
        self.conv3 = GATConv(
            hidden_channels, out_channels, heads=1, concat=False, dropout=0.6
        )

    def forward(self, x, edge_index, return_att=False):
        if return_att:
            x, att1 = self.conv1(x, edge_index, return_attention_weights=return_att)
            x = F.elu(x)
            x, att2 = self.conv2(x, edge_index, return_attention_weights=return_att)
            x = F.elu(x)
            x, att3 = self.conv3(x, edge_index, return_attention_weights=return_att)
            self.att = [att1, att2, att3]
        else:
            x = self.conv1(x, edge_index)
            x = F.elu(x)
            x = self.conv2(x, edge_index)
            x = F.elu(x)
            x = self.conv3(x, edge_index)
        return x

Basic training code

In [4]:
from torch_geometric.logging import log

hidden_channels = 32
num_classes = data.y.max().item() + 1
heads = 4

model_l2 = GAT_L2(
    in_channels=data.num_features,
    hidden_channels=hidden_channels,
    out_channels=num_classes,
    heads=heads,
).to(device)
optimizer_l2 = torch.optim.Adam(model_l2.parameters(), lr=0.005)

model_l3 = GAT_L3(
    in_channels=data.num_features,
    hidden_channels=hidden_channels,
    out_channels=num_classes,
    heads=heads,
).to(device)
optimizer_l3 = torch.optim.Adam(model_l3.parameters(), lr=0.001)


def train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(model, data):
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


# Training for GAT_L2
print(f"Training for {dataset_name} with GAT_L2...")
best_val_acc = final_test_acc = 0
for epoch in range(1, 501):
    loss = train(model=model_l2, optimizer=optimizer_l2, data=data)
    train_acc, val_acc, tmp_test_acc = test(model=model_l2, data=data)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
        model_state = model_l2.state_dict()
    if epoch % 100 == 0:
        log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)

# Get the best model state and test acc
model_l2.load_state_dict(model_state)
test_acc = test(model=model_l2, data=data)[-1]
print(f"Test Accuracy of GAT_L2: {test_acc:.4f}\n")

# Training for GAT_L3
print(f"Training for {dataset_name} with GAT_L3...")
best_val_acc = final_test_acc = 0
for epoch in range(1, 501):
    loss = train(model=model_l3, optimizer=optimizer_l3, data=data)
    train_acc, val_acc, tmp_test_acc = test(model=model_l3, data=data)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
        model_state = model_l3.state_dict()
    if epoch % 100 == 0:
        log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)

# Get the best model state and test acc
model_l3.load_state_dict(model_state)
test_acc = test(model=model_l3, data=data)[-1]
print(f"Test Accuracy of GAT_L3: {test_acc:.4f}")

Training for Cora with GAT_L2...
Epoch: 100, Loss: 0.2789, Train: 1.0000, Val: 0.7580, Test: 0.7820
Epoch: 200, Loss: 0.2786, Train: 1.0000, Val: 0.7680, Test: 0.7820
Epoch: 300, Loss: 0.3018, Train: 1.0000, Val: 0.7720, Test: 0.7820
Epoch: 400, Loss: 0.2309, Train: 1.0000, Val: 0.7640, Test: 0.7820
Epoch: 500, Loss: 0.2242, Train: 1.0000, Val: 0.7520, Test: 0.7820
Test Accuracy of GAT_L2: 0.7850

Training for Cora with GAT_L3...
Epoch: 100, Loss: 0.4205, Train: 1.0000, Val: 0.7580, Test: 0.7880
Epoch: 200, Loss: 0.5022, Train: 1.0000, Val: 0.7500, Test: 0.7880
Epoch: 300, Loss: 0.3063, Train: 1.0000, Val: 0.7440, Test: 0.7880
Epoch: 400, Loss: 0.3592, Train: 1.0000, Val: 0.7200, Test: 0.7880
Epoch: 500, Loss: 0.3644, Train: 1.0000, Val: 0.7320, Test: 0.7880
Test Accuracy of GAT_L3: 0.7650


### Acquiring GAtt scores

Here, we will get the edge attribution scores using GAtt. Importing the necessary libraries:

In [5]:
from gatt import get_gatt

### GAtt calculation

We will calculate the edge attribution scores using GAtt. `get_gatt` returns the GAtt scores (e.g., $\phi_{i,j}^v$ in the paper) values for all edges within $L$-hops of the target node $v$, where $L$ is the number of layers in the GAT model.

In [6]:
target_node = 1201

gatt_val_l2, edge_index_l2 = get_gatt(
    target_node=target_node, model=model_l2, data=data, sparse=True
)
print(f"GAtt values for GAT_L2 (showing only the first 5): {gatt_val_l2[:5]}")

GAtt values for GAT_L2 (showing only the first 5): [0.04767515882849693, 0.057931121438741684, 0.0, 0.0, 0.295834481716156]


In [7]:
target_node = 1201

gatt_val_l3, edge_index_l3 = get_gatt(
    target_node=target_node, model=model_l3, data=data, sparse=True
)
print(f"GAtt values for GAT_L3 (showing only the first 5): {gatt_val_l3[:5]}")

GAtt values for GAT_L3 (showing only the first 5): [0.07037428021430969, 0.0, 0.0, 0.0, 0.0968344509601593]
