In [1]:
import os
import pickle
import numpy as np
from tqdm import tqdm

import torch
from graph_models.stgcn import STGCN
from torch_geometric_temporal.signal import StaticGraphTemporalSignal

In [2]:
# Experiment configuration: model choice, optimisation hyper-parameters, RNG seed and device.
model_type = 'STGCN'

num_epochs = 500
learning_rate = 1e-4
seed = 42  # fixed so weight initialisation and training runs are reproducible

# torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
torch.manual_seed(seed)
np.random.seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# Load the pickled graph temporal dataset from ../../dataset/graph/dataset.pkl.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced or trust.
save_dir = '../../dataset/graph'
os.makedirs(save_dir, exist_ok=True)

file_path = os.path.join(save_dir, "dataset.pkl")
with open(file_path, 'rb') as fh:
    dataset = pickle.load(fh)

In [4]:
# Chronological split of `dataset`: among the first `total_length` snapshots, the last
# `test_length` form the test set, the `val_length` before them the validation set, and
# everything earlier the training set.
# NOTE(review): `total_length` excludes the final snapshot of `dataset` -- confirm this is intended.
total_length = len(dataset.features) - 1
val_length = 1
test_length = 1

train_end = total_length - (test_length + val_length)
val_end = total_length - test_length
assert train_end > 0, f"not enough snapshots for the split: {len(dataset.features)}"


def _slice_signal(start, stop):
    """Return a StaticGraphTemporalSignal over snapshots [start, stop) of `dataset`.

    Only `edge_index` is carried over; edge weights are dropped (edge_weight=None).
    """
    return StaticGraphTemporalSignal(
        edge_index=dataset.edge_index,
        edge_weight=None,
        features=dataset.features[start:stop],
        targets=dataset.targets[start:stop],
    )


train_dataset = _slice_signal(0, train_end)
val_dataset = _slice_signal(train_end, val_end)
test_dataset = _slice_signal(val_end, total_length)

print(next(iter(dataset)))
# snapshot_count gives the size without materialising every snapshot tensor as list(...) would.
print("train_dataset: ", train_dataset.snapshot_count)
print("val_dataset: ", val_dataset.snapshot_count)
print("test_dataset: ", test_dataset.snapshot_count)

Data(x=[42840, 40, 28], edge_index=[2, 299550], edge_attr=[299550], y=[42840, 28])
train_dataset:  66
val_dataset:  1
test_dataset:  1


In [5]:
# Build the model sized from one sample snapshot: dims 0 and 1 of its `x` are passed as
# num_nodes and feature_dim. The snapshot is constructed only once.
num_nodes, feature_dim = dataset[0].x.shape[:2]
if model_type == 'STGCN':
    model = STGCN(
        num_nodes=num_nodes,
        feature_dim=feature_dim,
    )
else:
    # Without this branch an unknown model_type surfaced later as a NameError on `model`.
    raise ValueError(f"Unsupported model_type: {model_type!r}")
model = model.to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

STGCN(
  (stconv_blocks): ModuleList(
    (0): STConv(
      (_temporal_conv1): TemporalConv(
        (conv_1): Conv2d(40, 64, kernel_size=(1, 3), stride=(1, 1))
        (conv_2): Conv2d(40, 64, kernel_size=(1, 3), stride=(1, 1))
      )
      (_graph_conv): ChebConv(64, 16, K=3, normalization=sym)
      (_temporal_conv2): TemporalConv(
        (conv_1): Conv2d(16, 64, kernel_size=(1, 3), stride=(1, 1))
        (conv_2): Conv2d(16, 64, kernel_size=(1, 3), stride=(1, 1))
      )
      (_batch_norm): BatchNorm2d(42840, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): STConv(
      (_temporal_conv1): TemporalConv(
        (conv_1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))
        (conv_2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))
      )
      (_graph_conv): ChebConv(64, 16, K=3, normalization=sym)
      (_temporal_conv2): TemporalConv(
        (conv_1): Conv2d(16, 64, kernel_size=(1, 3), stride=(1, 1))
        (conv_2): Conv2d(16, 64, kerne

In [6]:
# Display the shape of the first snapshot's node-feature tensor.
dataset[0].x.size()

torch.Size([42840, 40, 28])

In [None]:
# Train with early stopping on validation MSE, reload the best checkpoint, evaluate it on the
# test split, and save the results under a directory tagged with the test MSE.
save_dir = f'../../result/graph/{model_type}'
os.makedirs(save_dir, exist_ok=True)
checkpoint_path = os.path.join(save_dir, 'model.pt')

patience = 15  # epochs without validation improvement tolerated before stopping
early_stopping_counter = 0
best_val_loss = float('inf')
# Guards against loading a missing or stale checkpoint, e.g. when val loss is NaN every epoch.
checkpoint_saved = False


def _forward(snapshot):
    """Move one snapshot to `device`, run the model, and return (y_hat, mse_loss, y).

    The model receives `x` with its first two axes swapped plus a leading batch axis, which
    is the layout this notebook has always fed it (TODO confirm axis semantics vs. STGCN).
    """
    snapshot = snapshot.to(device)
    y_hat = model(snapshot.x.permute(1, 0, 2).unsqueeze(0), snapshot.edge_index)
    loss = torch.mean((y_hat.squeeze() - snapshot.y) ** 2)
    return y_hat, loss, snapshot.y


def _run_epoch(signal, train):
    """Return the mean per-snapshot MSE over `signal`; when `train` is True also step the optimizer."""
    if signal.snapshot_count == 0:
        raise ValueError("cannot average loss over an empty signal")
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for snapshot in signal:
            _, loss, _ = _forward(snapshot)
            total_loss += loss.item()
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            del loss  # release the autograd graph before the next snapshot
    return total_loss / signal.snapshot_count


torch.cuda.empty_cache()
for epoch in tqdm(range(num_epochs)):
    train_loss = _run_epoch(train_dataset, train=True)
    val_loss = _run_epoch(val_dataset, train=False)
    # tqdm.write keeps the progress bar intact instead of interleaving raw prints with it.
    tqdm.write(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}')

    # Early stopping: checkpoint on improvement, stop after `patience` epochs without one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), checkpoint_path)
        checkpoint_saved = True
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            tqdm.write("Early stopped")
            break

if not checkpoint_saved:
    raise RuntimeError("validation loss never improved (NaN?); no checkpoint was saved in this run")

# Load the best model.
torch.cuda.empty_cache()
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
print("Model loaded")

# Test: keep every snapshot's label/prediction rather than only the last one.
if test_dataset.snapshot_count == 0:
    raise ValueError("test_dataset is empty")
model.eval()
labels, preds = [], []
test_loss = 0.0
with torch.no_grad():
    for snapshot in test_dataset:
        y_hat, loss, y = _forward(snapshot)
        test_loss += loss.item()
        preds.append(y_hat.detach().cpu().numpy())
        labels.append(y.detach().cpu().numpy())
        del y_hat, loss, y
test_loss = test_loss / test_dataset.snapshot_count
print(f"Test Loss: {test_loss}")

# "label"/"pred" keep the original format (last test snapshot); "labels"/"preds" hold all of them.
result = {
    "label": labels[-1],
    "pred": preds[-1],
    "labels": labels,
    "preds": preds,
    "mse": test_loss,
}

with open(os.path.join(save_dir, 'result.pkl'), "wb") as f:
    pickle.dump(result, f)

# Tag the run directory with the test MSE. When that name is already taken (e.g. a re-run
# with the same rounded MSE), append a counter instead of letting os.rename fail.
base_dir = f'../../result/graph/{model_type}_{result["mse"]:.2f}'
final_dir, suffix = base_dir, 1
while os.path.exists(final_dir):
    final_dir = f'{base_dir}_{suffix}'
    suffix += 1
os.rename(save_dir, final_dir)
save_dir = final_dir  # keep save_dir pointing at the directory that exists after the rename

  0%|          | 0/500 [00:00<?, ?it/s]