In [221]:
# Record the Python interpreter version used for this run.
import sys
print(sys.version)

3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]


In [222]:
# Record the installed PyTorch version string.
import torch
print(torch.__version__)

2.1.1+cu118


In [223]:
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
import numpy as np

In [224]:
# Build the Cora dataset object through Planetoid.
# '/tmp/cora' is presumably the on-disk root directory and 'cora' the dataset
# name - TODO confirm against the Planetoid signature.
cora_dataset = Planetoid('/tmp/cora', 'cora')

In [225]:
# Take the entry at index 0 of the dataset; the debug cells below show it is a
# single torch_geometric Data object with 2708 nodes.
cora_data = cora_dataset[0]

In [226]:
# For debug use only
# Report the number of nodes and edges exposed by the Data object.
num_nodes = cora_data.num_nodes
print('cora has {} nodes'.format(num_nodes))

num_edges = cora_data.num_edges
print('cora has {} edges'.format(num_edges))

cora has 2708 nodes
cora has 10556 edges


In [227]:
# For debug use only
# Show the Data object's summary repr and the device its feature tensor is on.
print(cora_data)
print(cora_data.x.device)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
cpu


In [228]:
# For debug use only: slice the feature matrix with each split mask and report
# how many nodes fall into the train / validation / test splits.
cora_x_train, cora_x_val, cora_x_test = (
    cora_data.x[split_mask]
    for split_mask in (cora_data.train_mask, cora_data.val_mask, cora_data.test_mask)
)

print("number of nodes in cora train set,", cora_x_train.shape[0])
print("number of nodes in cora val set,", cora_x_val.shape[0])
print("number of nodes in cora test set,", cora_x_test.shape[0])

number of nodes in cora train set, 140
number of nodes in cora val set, 500
number of nodes in cora test set, 1000


In [229]:
# For debug use only: inspect the label tensor and its class distribution.
print(cora_data.y)
print(cora_data.y.shape)

# Copy the labels to host memory once so the counting below works whichever
# device cora_data currently lives on.
labels = cora_data.y.cpu().numpy()

# Distinct class ids present in the dataset.
s = set(labels.tolist())

# Per-class node counts. The length follows the largest label actually present
# instead of a hard-coded 7, so this no longer raises IndexError on datasets
# with more classes. The float cast keeps the printed format of the original
# np.zeros-based histogram.
histogram = np.bincount(labels).astype(float)
print(s)
print(histogram)

tensor([3, 4, 4,  ..., 3, 3, 3])
torch.Size([2708])
{0, 1, 2, 3, 4, 5, 6}
[351. 217. 418. 818. 426. 298. 180.]


In [230]:
# For debug use only
# Inspect the node feature matrix: its shape, a ten-row slice (rows 170-179),
# then the feature / node / node-type counts reported by the Data object and
# the Python type of the object itself.
print(cora_data.x.shape)
print(cora_data.x[170:180])
print(cora_data.num_features)
print(cora_data.num_nodes)
print(cora_data.num_node_types)
print(type(cora_data))

torch.Size([2708, 1433])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
1433
2708
1
<class 'torch_geometric.data.data.Data'>


In [231]:
# Define the GAT model
class GAT(torch.nn.Module):
    """Two-layer graph attention network built from two ``GATConv`` layers.

    ``hidden_channels`` is the embedding dimension produced per attention head
    by the first layer; ``forward`` asserts that the first layer's output width
    equals ``hidden_channels * num_heads``. The second layer is constructed
    with ``concat=False`` and ``num_classes`` output channels.

    Dropout with probability ``dropout_rate`` is applied to the input features
    and to the activated hidden features, and the same rate is also passed to
    both ``GATConv`` layers.
    """

    def __init__(self, in_channels, hidden_channels,
                 num_heads, dropout_rate, num_classes):
        super().__init__()

        # Stored because forward() needs them for dropout and the shape check.
        self.num_heads = num_heads
        self.hidden_channels = hidden_channels
        self.dropout_rate = dropout_rate

        self.conv1 = GATConv(
            in_channels,
            hidden_channels,
            heads=num_heads,
            dropout=dropout_rate,
        )
        self.conv2 = GATConv(
            hidden_channels * num_heads,
            num_classes,
            dropout=dropout_rate,
            concat=False,
        )

    def _drop(self, features):
        """Apply feature dropout; only active while the module is in training mode."""
        return F.dropout(features, p=self.dropout_rate, training=self.training)

    def forward(self, x, edge_index):
        """Run both attention layers and return the second layer's output."""
        hidden = self.conv1(self._drop(x), edge_index)
        # Sanity check: the first layer is expected to emit one block of
        # hidden_channels features per attention head.
        assert hidden.shape[-1] == self.hidden_channels * self.num_heads

        hidden = self._drop(F.elu(hidden))
        return self.conv2(hidden, edge_index)

In [232]:
def train(model, data, optimizer, loss_fn):
    """Perform one optimisation step on the training nodes.

    Puts ``model`` in training mode, runs a forward pass on ``data.x`` and
    ``data.edge_index``, evaluates ``loss_fn`` on the rows selected by
    ``data.train_mask``, back-propagates and steps ``optimizer``.

    Returns:
        The loss tensor computed for this step (before the parameter update).
    """
    model.train()
    optimizer.zero_grad()

    train_idx = data.train_mask
    logits = model(data.x, data.edge_index)
    step_loss = loss_fn(logits[train_idx], data.y[train_idx])

    step_loss.backward()
    optimizer.step()
    return step_loss

In [233]:
@torch.no_grad()
def evaluate(model, data, test_mask, loss_fn):
    """Compute accuracy and loss on the training mask and on ``test_mask``.

    The model is switched to eval mode and run once on ``data.x`` /
    ``data.edge_index``; predictions are the argmax over the last dimension.

    Returns:
        ``(accuracy_list, loss_list)``: two 2-element lists of Python floats,
        where index 0 corresponds to ``data.train_mask`` and index 1 to
        ``test_mask``.
    """
    model.eval()

    logits = model(data.x, data.edge_index)
    predicted = logits.argmax(dim=-1)

    accuracy_list, loss_list = [], []
    for mask in (data.train_mask, test_mask):
        targets = data.y[mask]
        hits = predicted[mask].eq(targets).float()
        accuracy_list.append(hits.mean().item())
        loss_list.append(loss_fn(logits[mask], targets).item())

    return accuracy_list, loss_list

In [234]:
def summarize(model):
    """Print each named parameter with its size, then the total element count."""
    print(f"Model Summary: {type(model).__name__}\n")

    # Collect (name, size, element count) once, then report.
    entries = [
        (pname, tensor.size(), tensor.numel())
        for pname, tensor in model.named_parameters()
    ]
    for pname, size, _ in entries:
        print(pname, size)

    total = sum(count for _, _, count in entries)
    print(f"\nTotal number of params: {total}")

In [245]:
def train_model(model, data, optimizer, loss_fn, save_name, patience=100):
    """Train ``model`` with early stopping and restore the best saved weights.

    Each epoch runs one ``train`` step followed by ``evaluate`` on the training
    and validation masks. Whenever the validation loss reaches a new minimum
    OR the validation accuracy reaches a new maximum, the current weights are
    written to ``save_name`` and snapshotted in memory. Training stops once
    ``patience`` consecutive epochs pass without either improvement. Before
    returning, the model is reloaded from the in-memory snapshot so it holds
    the same weights as the last checkpoint written to ``save_name``.

    Note: this function reads the module-level names ``device``,
    ``num_epochs`` and ``log_freq``; they must be defined before it is called.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``.
        data: Object providing ``x``, ``edge_index``, ``y``, ``train_mask``
            and ``val_mask``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(logits, targets)``.
        save_name: Path the best ``state_dict`` is saved to.
        patience: Number of consecutive non-improving epochs tolerated.

    Returns:
        None. Progress and the final summary are printed.
    """
    print(f"Using {device}\n")
    print(f"model device: {next(model.parameters()).device}")

    # Early stopping initialization
    best_val_loss = float('inf')
    best_val_acc = 0.0
    patience_counter = 0
    best_epoch = 0
    best_model_state = None

    # Evaluate before training
    acc_list, loss_list = evaluate(model, data, data.val_mask, loss_fn)
    print("Before training: ")
    print(f"Train Acc: {acc_list[0]:.4f}, Train Loss: {loss_list[0]:.4f}, Val Acc: {acc_list[1]:.4f}, Val Loss: {loss_list[1]:.4f}\n")

    # Start training
    for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
        loss = train(model, data, optimizer, loss_fn)
        acc_list, loss_list = evaluate(model, data,
                                       data.val_mask, loss_fn)

        # Update early stopping criteria
        val_loss = loss_list[1]
        val_acc = acc_list[1]
        if val_loss < best_val_loss or val_acc > best_val_acc:
            best_val_loss = min(best_val_loss, val_loss)
            best_val_acc = max(best_val_acc, val_acc)
            patience_counter = 0
            best_epoch = epoch
            # state_dict() hands back references to the live parameters, so
            # clone them; otherwise later optimizer steps would overwrite the
            # snapshot we intend to restore.
            best_model_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
            torch.save(best_model_state, save_name)
        else:
            patience_counter += 1

        # Check if patience limit is reached
        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

        # Logging
        if (epoch % log_freq == 0) or (epoch + 1 == num_epochs):
            print(f"Epoch: {epoch+1}, Loss: {loss:.4f}")
            print(f"    Eval: Train Acc: {acc_list[0]:.4f}, Train Loss: {loss_list[0]:.4f}, Val Acc: {acc_list[1]:.4f}, Val Loss: {loss_list[1]:.4f}")

    # Restore the best snapshot. Previously best_model_state was never
    # assigned, so this branch could not run and the model kept the weights of
    # the final epoch.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Model restored to the best state from epoch {best_epoch + 1}")

    print(f"\nTraining completed.\nBest Validation at Epoch: {best_epoch + 1}\nBest Val Acc: {best_val_acc:.4f}, Best Val Loss: {best_val_loss:.4f}")

In [240]:
# device = torch.device("cpu")
# device = torch.device("cuda")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters consumed by the model-construction cell:
# num_heads and emb_dim1 are passed to GAT as num_heads and hidden_channels,
# dropout_rate as the dropout probability, and lr as the Adam learning rate.
num_heads = 8
dropout_rate = 0.4
emb_dim1 = 8
lr = 0.005

# Number of distinct labels found in y; the assert pins the expected value for Cora.
cora_num_classes = len(cora_data.y.unique())
assert cora_num_classes == 7

In [253]:
# Training-loop settings read by train_model as module-level names.
num_epochs = 1000
log_freq = 50

# Seed PyTorch's RNG before the model is built so weight initialisation and
# dropout masks start from the same state every time this cell is run.
# NOTE(review): some GPU kernels may still be nondeterministic, so results on
# CUDA can differ slightly between runs even with a fixed seed.
seed = 42
torch.manual_seed(seed)

cora_model = GAT(cora_data.num_features, emb_dim1, num_heads, dropout_rate, 
            cora_num_classes).to(device)
# Keep the returned object so the device move is explicit rather than relying
# on .to() mutating cora_data in place.
cora_data = cora_data.to(device)

# L2 regularisation strength, applied through Adam's weight_decay.
lambda_l2 = 0.001
optimizer = torch.optim.Adam(cora_model.parameters(), lr=lr, weight_decay=lambda_l2)
loss_fn = nn.CrossEntropyLoss()

In [254]:
# Train with early stopping; train_model saves the model's state_dict to
# "cora_model_01.pth" each time its validation criteria improve.
train_model(cora_model, cora_data, optimizer, loss_fn, "cora_model_01.pth", patience=100)
# Reload the last checkpoint written during training so cora_model holds the
# saved weights rather than whatever the final epoch left behind.
cora_model.load_state_dict(torch.load("cora_model_01.pth", map_location=device))
# NOTE(review): the model was already moved to `device` when it was created,
# so this looks redundant - confirm before removing.
cora_model = cora_model.to(device)

Using cuda

model devide: cuda:0
Before training: 
Train Acc: 0.1214, Train Loss: 1.9582, Val Acc: 0.1840, Val Loss: 1.9537



Training Epochs:   1%|          | 11/1000 [00:00<00:09, 102.32it/s]

Epoch: 1, Loss: 1.9646
    Eval: Train Acc: 0.7000, Train Loss: 1.7718, Val Acc: 0.4500, Val Loss: 1.8411


Training Epochs:   7%|▋         | 72/1000 [00:00<00:05, 155.29it/s]

Epoch: 51, Loss: 0.1733
    Eval: Train Acc: 1.0000, Train Loss: 0.0207, Val Acc: 0.7740, Val Loss: 0.7283


Training Epochs:  13%|█▎        | 126/1000 [00:00<00:05, 168.73it/s]

Epoch: 101, Loss: 0.1912
    Eval: Train Acc: 1.0000, Train Loss: 0.0096, Val Acc: 0.7800, Val Loss: 0.7419


Training Epochs:  19%|█▊        | 187/1000 [00:01<00:04, 189.31it/s]

Epoch: 151, Loss: 0.1388
    Eval: Train Acc: 1.0000, Train Loss: 0.0073, Val Acc: 0.7660, Val Loss: 0.8033


Training Epochs:  20%|██        | 205/1000 [00:01<00:04, 164.86it/s]

Epoch: 201, Loss: 0.1056
    Eval: Train Acc: 1.0000, Train Loss: 0.0068, Val Acc: 0.7560, Val Loss: 0.8198
Early stopping triggered at epoch 206

Training completed.
Best Validation at Epoch: 106
Best Val Acc: 0.7880, Best Val Loss: 0.7116





In [255]:
# Evaluate after training
# evaluate() returns two lists (accuracy, loss); index 0 is the training mask
# and index 1 is the mask passed in, here the validation mask.
acc_list, loss_list = evaluate(cora_model, cora_data, cora_data.val_mask, loss_fn)
print("After training: ")
print(f"Train Acc: {acc_list[0]:.4f}, Train Loss: {loss_list[0]:.4f}, Val Acc: {acc_list[1]:.4f}, Val Loss: {loss_list[1]:.4f}\n")

After training: 
Train Acc: 1.0000, Train Loss: 0.0087, Val Acc: 0.7880, Val Loss: 0.7232

