In [5]:
%reload_ext autoreload
%autoreload 2

import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch_geometric.datasets import LRGBDataset
from torch_geometric.loader import DataLoader

from models import GCN
from utils import compute_dirichlet_energy, train, test

In [2]:
device = torch.device("cpu" if torch.backends.mps.is_available() else "cpu")

# Load dataset
dataset = LRGBDataset(root="./data/LRGB", name="Peptides-func")

train_dataset = LRGBDataset(root="./data/LRGB", name="Peptides-func", split="train")
val_dataset = LRGBDataset(root="./data/LRGB", name="Peptides-func", split="test")

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
train_loader_for_energy = DataLoader(train_dataset, batch_size=10873, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)


In [9]:
# model = GCN(dataset.num_node_features, hidden_channels=64, out_channels=dataset.num_classes).to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# dirichlet_energies = []

In [10]:
for epoch in range(1, 10):
    train()
    train_acc = test(train_loader)
    val_acc = test(val_loader)
    print(f'Epoch: {epoch}, Train: {train_acc:.4f}, Val: {val_acc:.4f}')
    
    # compute the energy for the full graph
    data = next(iter(train_loader_for_energy))
    data = data.to(device)
    out = model(data.x.float(), data.edge_index, data.batch, use_pooling=False)
    energy = compute_dirichlet_energy(out.detach(), data.edge_index, device)
    print(f'Energy: {energy:.4f}')
    dirichlet_energies.append(energy)

Epoch: 1, Train: 0.5148, Val: 0.5165
Energy: 5539598.0000
Epoch: 2, Train: 0.5150, Val: 0.5169
Energy: 7219669.0000
Epoch: 3, Train: 0.5308, Val: 0.5371
Energy: 14369955.0000
Epoch: 4, Train: 0.5265, Val: 0.5285
Energy: 20832504.0000
Epoch: 5, Train: 0.5554, Val: 0.5676
Energy: 57291572.0000
Epoch: 6, Train: 0.5701, Val: 0.5822
Energy: 55465064.0000
Epoch: 7, Train: 0.5634, Val: 0.5719
Energy: 66085992.0000
Epoch: 8, Train: 0.5779, Val: 0.5813
Energy: 71935080.0000
Epoch: 9, Train: 0.5772, Val: 0.5847
Energy: 117058696.0000


In [None]:
# Function to run one experiment
def run_experiment(train_loader, val_loader, train_loader_for_energy, model, device, epochs=10):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    dirichlet_energies = []
    train_accuracies = []
    val_accuracies = []
    
    for epoch in range(1, epochs + 1):
        train(model, train_loader, optimizer, device)  # Assuming you have this defined elsewhere
        train_acc = test(model, train_loader, device)
        val_acc = test(model, val_loader, device)
        
        # Store the accuracies
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        
        print(f'Epoch: {epoch}, Train: {train_acc:.4f}, Val: {val_acc:.4f}')
        
        # Compute the energy for the full graph
        data = next(iter(train_loader_for_energy))
        data = data.to(device)
        out = model(data.x.float(), data.edge_index, data.batch, use_pooling=False)
        energy = compute_dirichlet_energy(out.detach(), data.edge_index, device)
        
        # Store the Dirichlet energy
        dirichlet_energies.append(energy)
        print(f'Energy: {energy:.4f}')
    
    return train_accuracies, val_accuracies, dirichlet_energies

# Run the experiment multiple times and compute confidence intervals
def compute_confidence_intervals(runs, alpha=0.95):
    mean = np.mean(runs, axis=0)
    std_err = np.std(runs, axis=0) / np.sqrt(len(runs))
    ci = std_err * 1.96  # For 95% confidence interval
    return mean, ci

# Number of times to run the experiment
n_runs = 5

# To store the results of multiple runs
train_accuracies_runs = []
val_accuracies_runs = []
dirichlet_energies_runs = []

for i in range(n_runs):
    print(f'Run {i+1}/{n_runs}')
    model = GCN(dataset.num_node_features, hidden_channels=64, out_channels=dataset.num_classes).to(device)
    train_acc, val_acc, dirichlet_energy = run_experiment(train_loader, val_loader, train_loader_for_energy, model, device)
    
    # Store results of this run
    train_accuracies_runs.append(train_acc)
    val_accuracies_runs.append(val_acc)
    dirichlet_energies_runs.append(dirichlet_energy)

# Convert lists to numpy arrays for easier handling
train_accuracies_runs = np.array(train_accuracies_runs)
val_accuracies_runs = np.array(val_accuracies_runs)
dirichlet_energies_runs = np.array(dirichlet_energies_runs)

# Compute confidence intervals
train_acc_mean, train_acc_ci = compute_confidence_intervals(train_accuracies_runs)
val_acc_mean, val_acc_ci = compute_confidence_intervals(val_accuracies_runs)
energy_mean, energy_ci = compute_confidence_intervals(dirichlet_energies_runs)

# Print results with confidence intervals
print(f"Train Accuracy Mean: {train_acc_mean[-1]:.4f} ± {train_acc_ci[-1]:.4f}")
print(f"Val Accuracy Mean: {val_acc_mean[-1]:.4f} ± {val_acc_ci[-1]:.4f}")
print(f"Dirichlet Energy Mean: {energy_mean[-1]:.4f} ± {energy_ci[-1]:.4f}")


Run 1/5
Epoch: 1, Train: 0.5148, Val: 0.5165
Energy: 5663834.5000
Epoch: 2, Train: 0.5149, Val: 0.5165
Energy: 7262938.0000
Epoch: 3, Train: 0.5400, Val: 0.5474
Energy: 12170196.0000
Epoch: 4, Train: 0.5515, Val: 0.5603
Energy: 18332170.0000
Epoch: 5, Train: 0.5552, Val: 0.5581
Energy: 33488366.0000
Epoch: 6, Train: 0.5664, Val: 0.5719
Energy: 38871068.0000
Epoch: 7, Train: 0.5774, Val: 0.5903
Energy: 56150020.0000
Epoch: 8, Train: 0.5630, Val: 0.5723
Energy: 59156160.0000
Epoch: 9, Train: 0.5781, Val: 0.5783
Energy: 69394736.0000
Epoch: 10, Train: 0.5566, Val: 0.5598
Energy: 71511432.0000
Run 2/5
Epoch: 1, Train: 0.5148, Val: 0.5165
Energy: 6496665.0000
Epoch: 2, Train: 0.4990, Val: 0.5036
Energy: 10153779.0000
Epoch: 3, Train: 0.5268, Val: 0.5311
Energy: 13188036.0000
Epoch: 4, Train: 0.5508, Val: 0.5628
Energy: 17920750.0000
Epoch: 5, Train: 0.5300, Val: 0.5393
Energy: 23944084.0000
Epoch: 6, Train: 0.5515, Val: 0.5594
Energy: 30876156.0000
Epoch: 7, Train: 0.5646, Val: 0.5783


In [None]:
# plots
import matplotlib.pyplot as plt
import numpy as np

# Assuming the confidence interval code and data from previous steps

# Number of epochs
epochs = np.arange(1, len(train_acc_mean) + 1)

# Create subplots to visualize training accuracy, validation accuracy, and Dirichlet energy
fig, axs = plt.subplots(3, 1, figsize=(8, 12))

# Plot for training accuracy
axs[0].errorbar(epochs, train_acc_mean, yerr=train_acc_ci, fmt='-o', label='Train Accuracy', capsize=5)
axs[0].set_title('Training Accuracy with Confidence Interval')
axs[0].set_xlabel('Epochs')
axs[0].set_ylabel('Accuracy')
axs[0].legend()
axs[0].grid(True)

# Plot for validation accuracy
axs[1].errorbar(epochs, val_acc_mean, yerr=val_acc_ci, fmt='-o', label='Validation Accuracy', capsize=5, color='orange')
axs[1].set_title('Validation Accuracy with Confidence Interval')
axs[1].set_xlabel('Epochs')
axs[1].set_ylabel('Accuracy')
axs[1].legend()
axs[1].grid(True)

# Plot for Dirichlet energy
axs[2].errorbar(epochs, energy_mean, yerr=energy_ci, fmt='-o', label='Dirichlet Energy', capsize=5, color='green')
axs[2].set_title('Dirichlet Energy with Confidence Interval')
axs[2].set_xlabel('Epochs')
axs[2].set_ylabel('Energy')
axs[2].legend()
axs[2].grid(True)

# Adjust layout for better readability
plt.tight_layout()

# Show the plots
plt.show()
