In [1]:
import os
import pickle
import numpy as np
from tqdm import tqdm

import torch
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
from graph_models import SpatioTemporalGNN

In [None]:
# Experiment configuration: the single source of truth for the cells below.
# model_type also selects which pre-built dataset directory the load cell reads,
# so it must match the model that is constructed later (A3TGCN in this notebook).
model_type = 'A3TGCN'
num_epochs = 500
learning_rate = 1e-4
SEED = 42

# Seed the RNGs that affect weight init / training so Restart & Run All is reproducible.
# torch.manual_seed seeds the CPU and all CUDA devices.
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [2]:
# A3TGCN has its own pre-built dataset; every other model type shares the 'else' one.
save_dir = (
    '../../dataset/graph/A3TGCN'
    if model_type == 'A3TGCN'
    else '../../dataset/graph/else'
)
# Creates the directory if it is missing; opening dataset.pkl below still fails
# if the dataset itself has not been built yet.
os.makedirs(save_dir, exist_ok=True)

# load
# NOTE: pickle.load can execute arbitrary code — only load trusted, locally built files.
file_path = os.path.join(save_dir, "dataset.pkl")
with open(file_path, 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

# Show the first snapshot as a quick sanity check of tensor shapes.
first_snapshot = next(iter(dataset))
print(first_snapshot)

Data(x=[42840, 40, 28], edge_index=[2, 299550], edge_attr=[299550], y=[42840, 28])


In [3]:
# Chronological train / val / test split of the snapshot sequence.
# NOTE(review): the `- 1` leaves the final snapshot out of every split — confirm intended.
total_length = len(dataset.features) - 1
val_length = 1
test_length = 1

val_start = total_length - (val_length + test_length)
test_start = total_length - test_length


def _signal_slice(start, stop):
    """Wrap snapshots [start, stop) of `dataset` as a new StaticGraphTemporalSignal.

    All splits share the dataset's edge_index. Edge weights are not carried over
    (edge_weight=None).
    NOTE(review): the loaded dataset does carry edge weights (edge_attr in the
    printed snapshot) — confirm dropping them is intended.
    """
    return StaticGraphTemporalSignal(
        edge_index=dataset.edge_index,
        edge_weight=None,
        features=dataset.features[start:stop],
        targets=dataset.targets[start:stop],
    )


train_dataset = _signal_slice(None, val_start)
val_dataset = _signal_slice(val_start, test_start)
test_dataset = _signal_slice(test_start, total_length)

# Count snapshots by iterating each signal once.
for label, split in (
    ("train_dataset: ", train_dataset),
    ("val_dataset: ", val_dataset),
    ("test_dataset: ", test_dataset),
):
    print(label, sum(1 for _ in split))

train_dataset:  66
val_dataset:  1
test_dataset:  1


In [4]:
# Build the model and optimizer. model_type, num_epochs, learning_rate and device
# come from the configuration cell and are intentionally NOT re-assigned here.
# Re-assigning model_type after the dataset cell has run would pair this model
# with data that was loaded for a different model type.
out_channels = 32  # hidden_dim

if model_type == 'A3TGCN':
    # Guard against a model/data mismatch: the A3TGCN dataset's per-snapshot
    # features are 3-D (the printed sample shows x=[42840, 40, 28]).
    assert dataset.features[0].ndim == 3, (
        f"model_type='A3TGCN' expects 3-D snapshot features, got shape "
        f"{tuple(dataset.features[0].shape)}"
    )

model = SpatioTemporalGNN(
    model_type=model_type,
    in_channels=dataset.features[0].shape[1],
    out_channels=out_channels,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

cpu


In [6]:
# Train with early stopping on validation MSE, then evaluate the best checkpoint on the test split.
patience = 15  # consecutive epochs without a new best val loss before stopping
early_stopping_counter = 0
best_val_loss = float('inf')
checkpoint_path = f'{model_type}.pt'


def _predict(snapshot):
    """Forward one snapshot through `model`.

    StaticGraphTemporalSignal puts edge weights on each snapshot as `edge_attr`
    (None when the signal was built with edge_weight=None). The same call is
    used for every split, so train/val/test feed the model identical inputs.
    """
    return model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)


def _mean_mse(signal):
    """Mean over snapshots of the per-snapshot MSE between model output and `y`.

    Returns a 0-d tensor. It holds the autograd graph of every snapshot unless
    called under torch.no_grad().

    Raises:
        ValueError: if `signal` yields no snapshots.
    """
    total = 0
    n_snapshots = 0
    for snapshot in signal:
        snapshot = snapshot.to(device)
        total = total + torch.mean((_predict(snapshot) - snapshot.y) ** 2)
        n_snapshots += 1
    if n_snapshots == 0:
        raise ValueError('cannot compute a loss over an empty signal')
    return total / n_snapshots


for epoch in tqdm(range(num_epochs)):
    # train: full-batch — average the loss over all training snapshots and take a
    # single optimizer step per epoch.
    # NOTE(review): every snapshot's autograd graph stays alive until backward().
    # If memory is tight and the model keeps no state across calls, a per-snapshot
    # (loss / n).backward() yields the same gradient with a lower peak.
    model.train()
    optimizer.zero_grad()  # before backward, so grads left by an interrupted run are discarded
    train_loss = _mean_mse(train_dataset)
    train_loss.backward()
    optimizer.step()
    train_loss = train_loss.item()

    # val
    model.eval()
    with torch.no_grad():
        val_loss = _mean_mse(val_dataset).item()

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}')

    # early stopping (plain floats, so no tensors are kept alive across epochs)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        print(f'Model saved at Epoch {epoch+1} with validation loss: {val_loss}')
    else:
        early_stopping_counter += 1
        print(f'EarlyStopping counter: {early_stopping_counter} out of {patience}')
        if early_stopping_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

# load best model
if best_val_loss == float('inf'):
    # Nothing was saved in this run (e.g. val loss was always NaN, or num_epochs == 0).
    # Refuse to silently load a stale checkpoint from an earlier run.
    raise RuntimeError(f'No checkpoint was written to {checkpoint_path} during this run')
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
print("Loaded the best model")

# test
model.eval()
results = {}
test_losses = []

with torch.no_grad():
    for step, snapshot in enumerate(test_dataset):
        snapshot = snapshot.to(device)
        y_hat = _predict(snapshot)
        test_losses.append(torch.mean((y_hat - snapshot.y) ** 2))

        y_pred = y_hat.detach().cpu().numpy()
        y_label = snapshot.y.detach().cpu().numpy()
        results[step] = {"label": y_label.tolist(), "pred": y_pred.tolist()}

if not test_losses:
    raise ValueError('test_dataset yielded no snapshots')
test_loss = (sum(test_losses) / len(test_losses)).item()

print("MSE: {:.4f}".format(test_loss))
results["mse"] = test_loss

  0%|          | 0/500 [00:13<?, ?it/s]


KeyboardInterrupt: 