In [1]:
import os
import pickle
import numpy as np
from tqdm import tqdm

import torch
from graph_models import SpatioTemporalGNN
from torch_geometric_temporal.signal import StaticGraphTemporalSignal

In [2]:
# Experiment configuration: model choice, optimisation hyper-parameters, seed, device.
model_type = 'STGCN'

num_epochs = 500
# NOTE(review): with this learning rate the logged train loss moves only ~2 units out of
# ~77,000 over the first 10 epochs; consider normalising the targets or raising the rate.
learning_rate = 1e-4

# Seed torch's RNGs (all devices) so the model's weight initialisation is reproducible
# under Restart & Run All; some cuDNN kernels may still be nondeterministic.
seed = 42
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# Load the preprocessed graph dataset. It is expected to expose .features, .targets and
# .edge_index and to be indexable/iterable as snapshots (that is how the next cells use it).
# The input directory is not created here: a missing path should fail loudly at open().
data_dir = '../../dataset/graph'
file_path = os.path.join(data_dir, 'dataset.pkl')

# SECURITY: pickle.load can execute arbitrary code -- only load dataset files you produced or trust.
with open(file_path, 'rb') as f:
    dataset = pickle.load(f)

In [4]:
# Chronological split: train on everything before the last (val_length + test_length)
# snapshots, then one contiguous window each for validation and test.
# NOTE(review): the final snapshot of `dataset` is excluded from every split (the `- 1`);
# confirm this is intentional (e.g. its target is incomplete).
total_length = len(dataset.features) - 1
val_length = 1
test_length = 1

# Boundaries: [0, val_start) -> train, [val_start, test_start) -> val, [test_start, total_length) -> test.
val_start = total_length - (test_length + val_length)
test_start = total_length - test_length
assert val_start > 0, "not enough snapshots for a non-empty training split"


def _make_split(start, end):
    """Wrap snapshots [start, end) of `dataset` as a StaticGraphTemporalSignal on the same graph."""
    # NOTE(review): edge weights are dropped (edge_weight=None) although `dataset` carries
    # edge_attr; the training cell only passes edge_index to the model -- confirm intended.
    return StaticGraphTemporalSignal(
        edge_index=dataset.edge_index,
        edge_weight=None,
        features=dataset.features[start:end],
        targets=dataset.targets[start:end],
    )


train_dataset = _make_split(0, val_start)
val_dataset = _make_split(val_start, test_start)
test_dataset = _make_split(test_start, total_length)

print(next(iter(dataset)))
# snapshot_count avoids materialising every (large) snapshot just to count them.
print("train_dataset: ", train_dataset.snapshot_count)
print("val_dataset: ", val_dataset.snapshot_count)
print("test_dataset: ", test_dataset.snapshot_count)

Data(x=[42840, 40, 28], edge_index=[2, 299550], edge_attr=[299550], y=[42840, 28])
train_dataset:  66
val_dataset:  1
test_dataset:  1


In [5]:
# Build the spatio-temporal GNN, taking input/output sizes from the first snapshot
# instead of hard-coding them.
num_hidden_features = 64
kernel_size = 3  # temporal conv width (shows up as kernel_size=(1, 3) in the printed model)
K = 3            # Chebyshev filter order (ChebConv K=3 in the printed model)

first_snapshot = dataset[0]
# The training loss subtracts the prediction from snapshot.y, so the output width should
# match the target's last axis.
num_output_features = first_snapshot.y.shape[-1]

model = SpatioTemporalGNN(
    model_type=model_type,
    num_nodes=first_snapshot.x.shape[0],
    num_node_features=first_snapshot.x.shape[1],
    num_hidden_features=num_hidden_features,
    num_output_features=num_output_features,
    kernel_size=kernel_size,
    K=K,
).to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

SpatioTemporalGNN(
  (stgnn): STConv(
    (_temporal_conv1): TemporalConv(
      (conv_1): Conv2d(40, 64, kernel_size=(1, 3), stride=(1, 1))
      (conv_2): Conv2d(40, 64, kernel_size=(1, 3), stride=(1, 1))
      (conv_3): Conv2d(40, 64, kernel_size=(1, 3), stride=(1, 1))
    )
    (_graph_conv): ChebConv(64, 64, K=3, normalization=sym)
    (_temporal_conv2): TemporalConv(
      (conv_1): Conv2d(64, 28, kernel_size=(1, 3), stride=(1, 1))
      (conv_2): Conv2d(64, 28, kernel_size=(1, 3), stride=(1, 1))
      (conv_3): Conv2d(64, 28, kernel_size=(1, 3), stride=(1, 1))
    )
    (_batch_norm): BatchNorm2d(42840, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


In [None]:
# Train with early stopping on validation MSE, reload the best checkpoint, evaluate it on
# the test split, and store the results in a directory tagged with the test MSE.
save_dir = f'../../result/graph/{model_type}'
os.makedirs(save_dir, exist_ok=True)
checkpoint_path = os.path.join(save_dir, 'model.pt')

patience = 15  # epochs without a val-loss improvement before stopping
early_stopping_counter = 0
best_val_loss = float('inf')


def _predict(snapshot):
    """Forward one snapshot (already on `device`) through `model`.

    The node-feature tensor is permuted with (2, 0, 1) and given a leading batch axis of
    size 1 before the call.
    """
    return model(snapshot.x.permute(2, 0, 1).unsqueeze(0), snapshot.edge_index)


def _mse(y_hat, y):
    """Mean squared error between the squeezed prediction and the target tensor."""
    # NOTE(review): `.squeeze()` removes every size-1 axis; if the prediction still has an
    # extra axis afterwards it broadcasts against `y` instead of raising. Confirm the
    # SpatioTemporalGNN output shape matches snapshot.y.
    return torch.mean((y_hat.squeeze() - y) ** 2)


def _train_one_epoch(split):
    """Run one optimisation pass over `split`; return the mean per-snapshot loss."""
    model.train()
    total_loss = 0.0
    for snapshot in split:
        snapshot = snapshot.to(device)
        loss = _mse(_predict(snapshot), snapshot.y)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Release device tensors before the next snapshot is moved to the device.
        del snapshot, loss
    return total_loss / split.snapshot_count


@torch.no_grad()
def _evaluate(split):
    """Return the mean per-snapshot loss of `model` on `split` without tracking gradients."""
    model.eval()
    total_loss = 0.0
    for snapshot in split:
        snapshot = snapshot.to(device)
        total_loss += _mse(_predict(snapshot), snapshot.y).item()
        del snapshot
    return total_loss / split.snapshot_count


torch.cuda.empty_cache()
for epoch in tqdm(range(num_epochs)):
    train_loss = _train_one_epoch(train_dataset)
    val_loss = _evaluate(val_dataset)
    # tqdm.write keeps the progress bar intact instead of interleaving it with print output.
    tqdm.write(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}')

    # Early stopping: checkpoint on improvement, stop after `patience` epochs without one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            tqdm.write("Early stopped")
            break

# A NaN validation loss never compares below inf, so no checkpoint would have been written
# in this run; loading would then silently pick up a stale model.pt from an earlier run.
if best_val_loss == float('inf'):
    raise RuntimeError("No checkpoint was saved in this run: validation loss never improved on inf (NaN losses?)")

# load best model
torch.cuda.empty_cache()
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
print("Model loaded")

# test: collect every snapshot's label/prediction so test_length > 1 does not overwrite them
model.eval()
test_loss = 0.0
labels, preds = [], []
with torch.no_grad():
    for snapshot in test_dataset:
        snapshot = snapshot.to(device)
        y_hat = _predict(snapshot)
        test_loss += _mse(y_hat, snapshot.y).item()

        preds.append(y_hat.cpu().numpy())
        labels.append(snapshot.y.cpu().numpy())

        del snapshot, y_hat

test_loss = test_loss / test_dataset.snapshot_count
print(f"Test Loss: {test_loss}")
# Concatenating along axis 0 yields exactly the per-snapshot arrays when there is a
# single test snapshot (the current configuration).
result = {
    "label": np.concatenate(labels),
    "pred": np.concatenate(preds),
    "mse": test_loss,
}

with open(os.path.join(save_dir, 'result.pkl'), "wb") as f:
    pickle.dump(result, f)

# Tag the run directory with its test MSE; refuse to clobber an existing run with the same tag
# (results stay in save_dir in that case).
final_dir = f'../../result/graph/{model_type}_{result["mse"]:.2f}'
if os.path.exists(final_dir):
    raise FileExistsError(f"{final_dir} already exists; results left in {save_dir}")
os.rename(save_dir, final_dir)

  0%|          | 1/500 [00:24<3:22:16, 24.32s/it]

Epoch 1/500, Train Loss: 77010.73301373106, Val Loss: 111791.0625


  0%|          | 2/500 [00:49<3:24:18, 24.62s/it]

Epoch 2/500, Train Loss: 77009.72141335228, Val Loss: 111790.1171875


  1%|          | 3/500 [01:13<3:24:11, 24.65s/it]

Epoch 3/500, Train Loss: 77009.38648200757, Val Loss: 111789.5234375


  1%|          | 4/500 [01:38<3:25:17, 24.83s/it]

Epoch 4/500, Train Loss: 77009.2099905303, Val Loss: 111789.421875


  1%|          | 5/500 [02:03<3:24:46, 24.82s/it]

Epoch 5/500, Train Loss: 77009.05303030302, Val Loss: 111789.2421875


  1%|          | 6/500 [02:28<3:24:00, 24.78s/it]

Epoch 6/500, Train Loss: 77008.89879261363, Val Loss: 111788.9140625


  1%|▏         | 7/500 [02:53<3:24:20, 24.87s/it]

Epoch 7/500, Train Loss: 77008.744140625, Val Loss: 111788.515625


  2%|▏         | 8/500 [03:18<3:23:16, 24.79s/it]

Epoch 8/500, Train Loss: 77008.59339488637, Val Loss: 111788.40625


  2%|▏         | 9/500 [03:43<3:23:24, 24.86s/it]

Epoch 9/500, Train Loss: 77008.43158143939, Val Loss: 111787.859375


  2%|▏         | 10/500 [04:07<3:22:57, 24.85s/it]

Epoch 10/500, Train Loss: 77008.27852746213, Val Loss: 111787.6875
