In [None]:
import glob
import os
import random
import time
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from tqdm import tqdm

In [None]:
# Experiment configuration shared by the cells below.
# NOTE(review): n_layers, conv_type, data_name and save_model are not read by any
# cell shown here.
args = SimpleNamespace(
    # reproducibility
    seed=2024,
    # optimisation and logging
    learning_rate=1e-3,
    weight_decay=5e-4,
    batch_size=16,
    print_interval=10,
    # model shape
    input_dim=1,
    output_dim=1,
    hidden_size=8,
    n_layers=2,
    n_heads=2,
    out_head=1,
    # schedule and regularisation
    num_epochs=500,
    dropout=0.6,
    patience=10,
    # paths and runtime
    checkpoints_dir='./ckp',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    conv_type='GAT',
    data_path='./data',
    data_name='baci',
    save_model=True,
)

In [None]:
# Seed the Python, NumPy and PyTorch RNGs from the configured seed.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(args.seed)
if args.device == 'cuda':
    torch.cuda.manual_seed(args.seed)

In [None]:
# exist_ok=True already makes this a no-op when the directory exists, so a separate
# os.path.exists() check is redundant (and racy). If a non-directory occupies the
# path, this now fails here instead of later inside torch.save.
os.makedirs(args.checkpoints_dir, exist_ok=True)

In [None]:
def load_split(split, data_dir=args.data_path):
    """Load the feature matrix and GDP target for one data split.

    Args:
        split: split name used in the file names, e.g. 'train' -> X_train.csv / y_train.csv.
        data_dir: directory holding the CSV files (defaults to args.data_path).

    Returns:
        (features, target) DataFrames. Feature columns are named term_0 .. term_{k-1},
        where k is the number of columns actually present in the file (previously a
        hard-coded 3134, which silently misaligns if the file width differs).

    Raises:
        ValueError: if the feature and target files have different row counts.
    """
    features = pd.read_csv(os.path.join(data_dir, f'X_{split}.csv'), header=None)
    features.columns = [f'term_{num}' for num in range(features.shape[1])]
    target = pd.read_csv(os.path.join(data_dir, f'y_{split}.csv'), names=['GDP'], header=None)
    if len(features) != len(target):
        raise ValueError(f'{split}: {len(features)} feature rows but {len(target)} targets')
    return features, target


X_train, y_train = load_split('train')
X_test, y_test = load_split('test')
X_val, y_val = load_split('val')

# Every split must describe the same set of nodes (one feature column per node).
colnames = list(X_train.columns)
assert list(X_test.columns) == colnames and list(X_val.columns) == colnames, \
    'splits have different numbers of feature columns'

# NOTE: torch.load unpickles its input; only load trusted files.
edge_index = torch.load(os.path.join(args.data_path, 'edge_index.pt'))
edge_attr = torch.load(os.path.join(args.data_path, 'edge_attr.pt'))

# Node ids referenced by the edges must correspond to existing feature columns.
assert edge_index.numel() == 0 or int(edge_index.max()) < len(colnames), \
    'edge_index references a node with no feature column'

In [None]:
def to_log_target(target_df):
    """Convert a GDP target DataFrame into a float32 tensor of log(GDP).

    Raises:
        ValueError: if any value is non-positive or missing; np.log would otherwise
            silently produce -inf/NaN targets and poison the MSE loss.
    """
    values = target_df.values
    if not (values > 0).all():
        raise ValueError('GDP targets must be strictly positive (and non-missing) before taking the log')
    return torch.tensor(np.log(values), dtype=torch.float32)


# NOTE: this cell rebinds the DataFrame names to tensors, so it is not idempotent;
# re-run the loading cell before re-running it.
X_train = torch.tensor(X_train.values, dtype=torch.float32)
X_test = torch.tensor(X_test.values, dtype=torch.float32)
X_val = torch.tensor(X_val.values, dtype=torch.float32)

y_train = to_log_target(y_train)
y_val = to_log_target(y_val)
y_test = to_log_target(y_test)

In [None]:
# Display (samples, nodes) for the train / test / val feature tensors.
tuple(split.shape for split in (X_train, X_test, X_val))

In [None]:
num_patients_train = X_train.shape[0]
num_patients_test = X_test.shape[0]
num_patients_val = X_val.shape[0]


def build_graph_tuples(features, targets, num_samples):
    """Return one (node_features, edge_index, edge_attr, target) tuple per sample.

    Every sample shares the same edge_index / edge_attr; only the per-node
    feature row and the target differ between samples.
    """
    return [(features[idx], edge_index, edge_attr, targets[idx]) for idx in range(num_samples)]


graphs_train = build_graph_tuples(X_train, y_train, num_patients_train)
graphs_test = build_graph_tuples(X_test, y_test, num_patients_test)
graphs_val = build_graph_tuples(X_val, y_val, num_patients_val)

In [None]:
def to_pyg_data(graphs):
    """Wrap (node_features, edge_index, edge_attr, target) tuples as PyG ``Data`` objects.

    Each graph's 1-D node-feature row is reshaped into a (num_nodes, 1) column,
    i.e. one scalar feature per node.
    """
    return [
        Data(x=features.reshape(-1, 1), edge_index=edges, edge_attr=attrs, y=target)
        for features, edges, attrs, target in graphs
    ]


data_train = to_pyg_data(graphs_train)

data_test = to_pyg_data(graphs_test)

data_val = to_pyg_data(graphs_val)

In [None]:
train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(data_test, batch_size=args.batch_size, shuffle=False)
# Evaluation loaders use a fixed order. With shuffling, when the dataset size is not
# a multiple of batch_size, the composition of the smaller last batch changed every
# epoch, so the mean-of-batch validation loss (used for early stopping) was noisy.
val_loader = DataLoader(data_val, batch_size=args.batch_size, shuffle=False)

# Sanity check: inspect the first training batch.
assert len(train_loader) > 0, 'training set is empty'
data = next(iter(train_loader)).to(args.device)

print('Training Batches: ')
print('Step 1:')
print('=======')
print(f'Number of graphs in the current batch: {data.num_graphs}')
print(data)

In [None]:
class GAT(torch.nn.Module):
    """Three-layer graph attention network that produces one output per node.

    Args:
        num_node_features: size of each input node feature vector.
        hidden_channels: hidden size per attention head.
        output_dim: number of output channels per node.
        num_heads: attention heads in the first two layers (their outputs are concatenated).
        out_head: attention heads in the final layer (averaged, since concat=False).
        dropout: dropout probability applied before every GAT layer during training.
    """

    def __init__(self, num_node_features, hidden_channels, output_dim, num_heads, out_head, dropout):
        super().__init__()
        # edge_dim=1: every edge carries a single scalar attribute.
        self.gat1 = GATConv(num_node_features, hidden_channels, edge_dim=1, heads=num_heads)
        # Concatenated heads widen the representation to hidden_channels * num_heads.
        self.gat2 = GATConv(hidden_channels * num_heads, hidden_channels, edge_dim=1, heads=num_heads)
        self.gat3 = GATConv(hidden_channels * num_heads, output_dim, edge_dim=1, concat=False, heads=out_head)

        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr):
        """Run the network on a (possibly batched) graph.

        Returns:
            Per-node outputs: shape (num_nodes,) when output_dim == 1,
            otherwise (num_nodes, output_dim).
        """
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.gat1(x, edge_index, edge_attr))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.gat2(x, edge_index, edge_attr))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat3(x, edge_index, edge_attr)

        # squeeze(-1) drops only the channel dim; a bare squeeze() would also collapse
        # the node dim when the input has a single node.
        return x.squeeze(-1)

In [None]:
# Graph-level regression: the GAT emits one value per node, while each sample (graph)
# has a single log-GDP target. Node outputs are mean-pooled per graph (data.batch maps
# every node to its graph) before the loss. Previously MSELoss broadcast the
# (num_nodes,) output against the (num_graphs, 1) target.
# NOTE(review): pooling hidden embeddings before a linear head is a common alternative.
model = GAT(num_node_features=args.input_dim, hidden_channels=args.hidden_size, output_dim=args.output_dim, num_heads=args.n_heads, out_head=args.out_head, dropout=args.dropout)
model = model.to(args.device)

# args.weight_decay was configured but never passed to the optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
criterion = torch.nn.MSELoss().to(args.device)

num_epochs = args.num_epochs


def predict_graphs(model, batch):
    """Return one prediction per graph in `batch`, shape (num_graphs, output_dim)."""
    node_out = model(batch.x, batch.edge_index, batch.edge_attr)
    return global_mean_pool(node_out.view(batch.x.size(0), -1), batch.batch)


def checkpoint_path(epoch):
    """Path of the checkpoint file for `epoch` inside args.checkpoints_dir."""
    # os.path.join (not f'./{dir}/...') keeps absolute checkpoint dirs working.
    return os.path.join(args.checkpoints_dir, f'{epoch}.pth')


def remove_stale_checkpoints(keep_epoch):
    """Delete every '<epoch>.pth' file in args.checkpoints_dir except keep_epoch's."""
    for path in glob.glob(os.path.join(args.checkpoints_dir, '*.pth')):
        # basename/splitext work on every OS; the old split('\\') only worked on Windows.
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdigit() and int(stem) != keep_epoch:
            os.remove(path)


start_time = time.time()

train_losses = []
val_losses = []
es_counter = 0
best_loss = np.inf
best_epoch = 0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for data in train_loader:
        data = data.to(args.device)
        optimizer.zero_grad()
        loss = criterion(predict_graphs(model, data), data.y.view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    average_train_loss = total_loss / len(train_loader)
    train_losses.append(average_train_loss)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(args.device)
            val_loss += criterion(predict_graphs(model, data), data.y.view(-1, 1)).item()

    average_val_loss = val_loss / len(val_loader)
    # One entry per epoch. Previously the running sum was appended after every batch,
    # so the curve mixed batches with epochs and early stopping compared raw sums.
    val_losses.append(average_val_loss)

    if epoch % args.print_interval == 0:
        print(f'Epoch: {epoch:03d}, Train loss: {average_train_loss:.4f}, Validation Loss: {average_val_loss:.4f}')

    if average_val_loss < best_loss:
        best_loss = average_val_loss
        best_epoch = epoch
        es_counter = 0
        # Only the best model is loaded later: save on improvement, drop the others.
        torch.save(model.state_dict(), checkpoint_path(epoch))
        remove_stale_checkpoints(keep_epoch=best_epoch)
    else:
        es_counter += 1

    if es_counter >= args.patience:
        print('Early Stopping!!')
        break

elapsed_time = time.time() - start_time

print(f"Time used for training: {elapsed_time:.2f} seconds")

In [None]:
def plot_loss(train_value, test_value):
    """Plot the training and validation loss curves side by side on a new figure.

    Args:
        train_value: sequence of recorded training losses.
        test_value: sequence of recorded validation losses (the parameter keeps its
            original name for backward compatibility).
    """
    # A fresh figure via plt.subplots avoids drawing onto whatever figure the
    # pyplot state machine currently considers active.
    fig, (ax_train, ax_valid) = plt.subplots(1, 2, figsize=(12, 4))

    ax_train.plot(train_value, label='Train Loss')
    ax_train.set(title='Train Loss', xlabel='Iteration', ylabel='Loss')

    ax_valid.plot(test_value, label='Valid Loss')
    ax_valid.set(title='Valid Loss', xlabel='Iteration', ylabel='Loss')

    fig.tight_layout()
    plt.show()

In [None]:
# Visualise the recorded training / validation losses.
plot_loss(train_value=train_losses, test_value=val_losses)

In [None]:
print(f'Loading {best_epoch}th epoch')
test_model = GAT(num_node_features=args.input_dim, hidden_channels=args.hidden_size, output_dim=args.output_dim, num_heads=args.n_heads, out_head=args.out_head, dropout=args.dropout)
# os.path.join instead of f'./{dir}/...' so absolute checkpoint dirs also resolve.
checkpoint_file = os.path.join(args.checkpoints_dir, f'{best_epoch}.pth')
# map_location lets a checkpoint written on one device load on another (e.g. GPU -> CPU).
test_model.load_state_dict(torch.load(checkpoint_file, map_location=args.device))
# Inference moves every batch to args.device, so the model must live there as well;
# previously it stayed on the CPU and failed with a device mismatch on CUDA.
test_model = test_model.to(args.device)
test_model.eval()

In [None]:
def inference(model, criterion, args, loader=None):
    """Evaluate `model` on a loader of PyG batches and report the mean batch loss.

    The model's per-node outputs are mean-pooled per graph so that each prediction
    lines up with its graph-level target. Previously the node outputs were broadcast
    against the per-graph targets.

    Args:
        model: network mapping (x, edge_index, edge_attr) to per-node outputs.
        criterion: loss applied to (predictions, targets) of shape (num_graphs, 1).
        args: config namespace; only ``args.device`` is read.
        loader: batches to evaluate; defaults to the notebook-level ``test_loader``.

    Returns:
        The average per-batch loss as a float. It is also printed.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if loader is None:
        loader = test_loader
    if len(loader) == 0:
        raise ValueError('inference loader is empty')

    # Disable dropout regardless of the mode the caller left the model in.
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(args.device)
            node_out = model(batch.x, batch.edge_index, batch.edge_attr)
            preds = global_mean_pool(node_out.view(batch.x.size(0), -1), batch.batch)
            total_loss += criterion(preds, batch.y.view(-1, 1)).item()

    avg_test_loss = total_loss / len(loader)
    print(f"Test loss: {avg_test_loss:.4f}")
    return avg_test_loss

In [None]:
# Report the test-set loss of the best checkpoint.
inference(model=test_model, criterion=criterion, args=args)