## 1. Load the infection benchmark (only `Infection_50002d_sp.pt` is used)

In [1]:
import torch

# Load the serialized benchmark object; map_location="cpu" places any stored tensors on the CPU.
# NOTE(review): torch.load unpickles the file, so only point it at trusted files. Recent
# PyTorch releases default to weights_only=True, which may reject non-tensor objects —
# confirm against the installed version.
data = torch.load('/workspace/Infection_50002d_sp.pt', map_location="cpu")

## 2. Configurations

### 1. Regularization

Weight decay: 0, 0.0001, 0.0005, 0.001, 0.005, 0.01

Dropout: 0, 0.2, 0.4, 0.6, 0.8

In [2]:
from models import GAT_L2_intervention

In [4]:
# Build a reproducible 50:50 node-level train/test split and attach it to `data`
# as boolean masks (`data.train_mask`, `data.test_mask`).
from sklearn.model_selection import train_test_split

# Candidate node ids: 0 .. num_nodes - 1
node_indices = torch.arange(data.num_nodes)
# A fixed random_state keeps the split identical across kernel restarts
train_indices, test_indices = train_test_split(node_indices, test_size=0.5, random_state=42)
# Start from all-False masks and switch on the selected nodes
train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
train_mask[train_indices] = True
test_mask[test_indices] = True
# Sanity checks: the two masks must partition the node set
assert not (train_mask & test_mask).any(), "train and test masks overlap"
assert (train_mask | test_mask).all(), "some nodes belong to neither split"
# Attach the masks to the data object, where train() and test() read them
data.train_mask = train_mask
data.test_mask = test_mask

In [7]:
"""
Define a function that trains the GAT model for a given number of epochs.
The function takes as input the model, the dataset, the optimizer, the number of epochs.
We perform full-batch training.
The function returns the trained model and the training loss, and the training accuracy.
"""
def train(model, data, optimizer, epochs: int):
    """Train ``model`` with full-batch updates on the nodes in ``data.train_mask``.

    Every epoch runs one forward pass over the whole graph
    (``model(data.x, data.edge_index)``), computes the cross-entropy loss on the
    training nodes only, back-propagates, and takes one optimizer step.

    Args:
        model: Module called as ``model(x, edge_index)``; its output is indexed
            per node and fed to ``CrossEntropyLoss``.
        data: Object exposing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        optimizer: Optimizer that updates ``model``'s parameters.
        epochs: Number of full-batch epochs; must be at least 1.

    Returns:
        Tuple ``(model, loss, acc)``. ``loss`` is the training loss of the last
        epoch as a detached 0-dim tensor and ``acc`` is the training accuracy of
        the last epoch as a float. Both are measured in training mode, on the
        forward pass that precedes the last optimizer step.

    Raises:
        ValueError: If ``epochs`` is smaller than 1 (there would be no loss or
            accuracy to report).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be a positive integer, got {epochs}")

    # Training mode for the whole loop
    model.train()
    criterion = torch.nn.CrossEntropyLoss()

    # The mask, the training labels and the node count do not change between
    # epochs, so they are looked up once instead of on every iteration.
    mask = data.train_mask
    targets = data.y[mask]
    num_train = mask.sum().item()

    for _ in range(epochs):
        optimizer.zero_grad()
        # Forward pass over the full graph, then keep only the training nodes
        train_out = model(data.x, data.edge_index)[mask]
        loss = criterion(train_out, targets)
        # Accuracy is bookkeeping only, so keep it out of the autograd graph
        with torch.no_grad():
            acc = (train_out.argmax(dim=1) == targets).sum().item() / num_train
        loss.backward()
        optimizer.step()

    # Detach so callers receive a plain value with no autograd state attached
    return model, loss.detach(), acc

"""
Make a function that tests the GAT model.
Use the test mask to only compute the loss and accuracy for the test nodes.
The function takes as input the model and the dataset.
"""
@torch.no_grad()
def test(model, data):
    # Set the model to evaluation mode
    model.eval()
    # Define the test loss
    loss = 0
    # Define the test accuracy
    acc = 0
    # Define the criterion
    criterion = torch.nn.CrossEntropyLoss()
    # Get the test data
    x, edge_index, y = data.x, data.edge_index, data.y
    # Get the output of the model
    out = model(x, edge_index)
    # Compute the loss
    loss = criterion(out[data.test_mask], y[data.test_mask])
    # Compute the test accuracy
    acc = (out[data.test_mask].argmax(dim=1) == y[data.test_mask]).sum().item() / y[data.test_mask].size(0)
    # Return the test loss and the test accuracy
    return loss, acc

Weight decay

In [11]:
def train_model_weight_decay(data, weight_decay, *, epochs=500, lr=0.005, seed=None, save_dir='/workspace'):
    """Train, evaluate and save one GAT model for a given Adam weight decay.

    Args:
        data: Dataset object passed through to ``train`` and ``test``; ``data.y``
            is also used to derive the number of output classes.
        weight_decay: ``weight_decay`` value handed to ``torch.optim.Adam``.
        epochs: Number of full-batch training epochs (default 500).
        lr: Adam learning rate (default 0.005).
        seed: If given, ``torch.manual_seed(seed)`` is called before the model is
            built so repeated runs start from the same state. ``None`` (the
            default) leaves the random state untouched.
        save_dir: Directory the trained model is written to (default
            ``'/workspace'``).

    Returns:
        None. The metrics are printed and the model is saved to
        ``{save_dir}/GAT_infection_2L1H_weight_decay_{weight_decay}.pt``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Number of classes = largest label + 1
    out_channels = data.y.max().item() + 1
    # NOTE(review): in_channels is fixed at 2; presumably this matches the
    # feature width of data.x — confirm against the dataset.
    model = GAT_L2_intervention(in_channels=2, hidden_channels=8, out_channels=out_channels, heads=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Fit on the training nodes, then score on the test nodes
    model, loss, acc = train(model=model, data=data, optimizer=optimizer, epochs=epochs)
    test_loss, test_acc = test(model=model, data=data)

    run_name = f"GAT_infection_2L1H_weight_decay_{weight_decay}"
    print(f"Model: {run_name}, Loss: {loss:.4f}, Train Accuracy: {acc:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
    # NOTE: this pickles the whole module object, not just its state_dict
    torch.save(model, f'{save_dir}/{run_name}.pt')

In [12]:
# for weight_decay in [0, 0.0001, 0.0005, 0.001, 0.005, 0.01]:
#     train_model_weight_decay(data=data, weight_decay=weight_decay)

Model: GAT_infection_2L1H_weight_decay_0, Loss: 0.1726, Train Accuracy: 0.9348, Test Loss: 0.1776, Test Accuracy: 0.9304
Model: GAT_infection_2L1H_weight_decay_0.0001, Loss: 0.1609, Train Accuracy: 0.9496, Test Loss: 0.1671, Test Accuracy: 0.9432
Model: GAT_infection_2L1H_weight_decay_0.0005, Loss: 0.1142, Train Accuracy: 0.9512, Test Loss: 0.1177, Test Accuracy: 0.9480
Model: GAT_infection_2L1H_weight_decay_0.001, Loss: 0.1091, Train Accuracy: 0.9508, Test Loss: 0.1101, Test Accuracy: 0.9476
Model: GAT_infection_2L1H_weight_decay_0.005, Loss: 0.2144, Train Accuracy: 0.9236, Test Loss: 0.2197, Test Accuracy: 0.9204
Model: GAT_infection_2L1H_weight_decay_0.01, Loss: 0.2491, Train Accuracy: 0.9016, Test Loss: 0.2521, Test Accuracy: 0.8980


Dropout

In [13]:
def train_model_dropout(data, dropout, *, epochs=500, lr=0.005, seed=None, save_dir='/workspace'):
    """Train, evaluate and save one GAT model for a given dropout setting.

    Args:
        data: Dataset object passed through to ``train`` and ``test``; ``data.y``
            is also used to derive the number of output classes.
        dropout: Value forwarded as the ``dropout`` argument of
            ``GAT_L2_intervention``.
        epochs: Number of full-batch training epochs (default 500).
        lr: Adam learning rate (default 0.005).
        seed: If given, ``torch.manual_seed(seed)`` is called before the model is
            built so repeated runs start from the same state. ``None`` (the
            default) leaves the random state untouched.
        save_dir: Directory the trained model is written to (default
            ``'/workspace'``).

    Returns:
        None. The metrics are printed and the model is saved to
        ``{save_dir}/GAT_infection_2L1H_dropout_{dropout}.pt``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Number of classes = largest label + 1
    out_channels = data.y.max().item() + 1
    # NOTE(review): in_channels is fixed at 2; presumably this matches the
    # feature width of data.x — confirm against the dataset.
    model = GAT_L2_intervention(in_channels=2,
                                hidden_channels=8,
                                out_channels=out_channels,
                                heads=1,
                                dropout=dropout)
    # Weight decay is held at 0 so dropout is the only regularizer being varied
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    # Fit on the training nodes, then score on the test nodes
    model, loss, acc = train(model=model, data=data, optimizer=optimizer, epochs=epochs)
    test_loss, test_acc = test(model=model, data=data)

    run_name = f"GAT_infection_2L1H_dropout_{dropout}"
    print(f"Model: {run_name}, Loss: {loss:.4f}, Train Accuracy: {acc:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
    # NOTE: this pickles the whole module object, not just its state_dict
    torch.save(model, f'{save_dir}/{run_name}.pt')

In [14]:
# for dropout in [0, 0.2, 0.4, 0.6, 0.8]:
#     train_model_dropout(data=data, dropout=dropout)

Model: GAT_infection_2L1H_dropout_0, Loss: 0.1452, Train Accuracy: 0.9504, Test Loss: 0.1492, Test Accuracy: 0.9476
Model: GAT_infection_2L1H_dropout_0.2, Loss: 0.2579, Train Accuracy: 0.9216, Test Loss: 0.1621, Test Accuracy: 0.9476
Model: GAT_infection_2L1H_dropout_0.4, Loss: 0.4582, Train Accuracy: 0.8036, Test Loss: 0.2922, Test Accuracy: 0.8600
Model: GAT_infection_2L1H_dropout_0.6, Loss: 0.6253, Train Accuracy: 0.7652, Test Loss: 0.3000, Test Accuracy: 0.8600
Model: GAT_infection_2L1H_dropout_0.8, Loss: 0.7349, Train Accuracy: 0.7672, Test Loss: 0.4399, Test Accuracy: 0.8356


### 2. General hyperparameters

In [16]:
def train_model_general_hyperparameters(data, hidden_channels, *, epochs=500, lr=0.005, seed=None, save_dir='/workspace'):
    """Train, evaluate and save one GAT model for a given hidden width.

    Args:
        data: Dataset object passed through to ``train`` and ``test``; ``data.y``
            is also used to derive the number of output classes.
        hidden_channels: Value forwarded as the ``hidden_channels`` argument of
            ``GAT_L2_intervention``.
        epochs: Number of full-batch training epochs (default 500).
        lr: Adam learning rate (default 0.005).
        seed: If given, ``torch.manual_seed(seed)`` is called before the model is
            built so repeated runs start from the same state. ``None`` (the
            default) leaves the random state untouched.
        save_dir: Directory the trained model is written to (default
            ``'/workspace'``).

    Returns:
        None. The metrics are printed and the model is saved to
        ``{save_dir}/GAT_infection_2L1H_hidden_channels_{hidden_channels}.pt``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Number of classes = largest label + 1
    out_channels = data.y.max().item() + 1
    # NOTE(review): in_channels is fixed at 2; presumably this matches the
    # feature width of data.x — confirm against the dataset.
    model = GAT_L2_intervention(in_channels=2, hidden_channels=hidden_channels, out_channels=out_channels, heads=1)
    # No weight decay: only the hidden width varies in this sweep
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0)

    # Fit on the training nodes, then score on the test nodes
    model, loss, acc = train(model=model, data=data, optimizer=optimizer, epochs=epochs)
    test_loss, test_acc = test(model=model, data=data)

    run_name = f"GAT_infection_2L1H_hidden_channels_{hidden_channels}"
    print(f"Model: {run_name}, Loss: {loss:.4f}, Train Accuracy: {acc:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
    # NOTE: this pickles the whole module object, not just its state_dict
    torch.save(model, f'{save_dir}/{run_name}.pt')

In [18]:
# for hidden_channels in [8, 16, 32, 64]:
#     train_model_general_hyperparameters(data=data, hidden_channels=hidden_channels)

Model: GAT_infection_2L1H_hidden_channels_8, Loss: 0.1289, Train Accuracy: 0.9488, Test Loss: 0.1328, Test Accuracy: 0.9448
Model: GAT_infection_2L1H_hidden_channels_16, Loss: 0.1102, Train Accuracy: 0.9512, Test Loss: 0.1135, Test Accuracy: 0.9480
Model: GAT_infection_2L1H_hidden_channels_32, Loss: 0.1009, Train Accuracy: 0.9512, Test Loss: 0.1036, Test Accuracy: 0.9480
Model: GAT_infection_2L1H_hidden_channels_64, Loss: 0.1262, Train Accuracy: 0.9608, Test Loss: 0.1300, Test Accuracy: 0.9608
