In [10]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
# Prefer the first GPU but fall back to CPU, so the notebook also runs on
# machines without CUDA. A bare torch.device('cuda:0') would raise as soon as
# a model or tensor is moved to it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so that DataLoader shuffling and weight
# initialisation are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

from sklearn.metrics import r2_score, mean_absolute_error
import pickle

from models.SCINet import SCINet
from tools.tools import GridSearch

In [9]:
# Dataset selection. data_name picks the data/<data_name>/ directory; keep it
# in sync with the data_name used for checkpoints in the training cell.
data_name = 'LG'  # 'LG' or 'AirPollution'
use_masked = True  # True or False
masked = '_masked' if use_masked else ''  # filename suffix for the masked variant


def load_split(split: str):
    """Load ``data/<data_name>/<split>_dataset<masked>.pickle``.

    Args:
        split: One of ``'train'``, ``'val'``, ``'test'``.

    Returns:
        The unpickled dataset object, which is passed straight to ``DataLoader``.

    NOTE: ``pickle.load`` can execute arbitrary code; only load files you
    produced yourself or otherwise trust.
    """
    with open(f'data/{data_name}/{split}_dataset{masked}.pickle', 'rb') as f:
        return pickle.load(f)


train_dataset = load_split('train')
val_dataset = load_split('val')
test_dataset = load_split('test')

# Only the training loader is shuffled; evaluation order stays deterministic.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Training configuration shared by every grid-search run.
lr = 1e-3  # Adam learning rate
eval_metrics = [r2_score, mean_absolute_error]  # sklearn metric callables handed to GridSearch
criterion = nn.MSELoss()  # loss optimised during training

In [None]:
# Selects the checkpoint sub-directory; keep in sync with the dataset loaded
# in the data cell.
data_name = 'LG'  # 'LG' or 'AirPollution'
# Used in checkpoint and loss-plot filenames. Fall back to a default so an
# empty entry does not produce files such as 'checkpoints/LG/_temp.pt'.
model_name = input('Enter the name of the model: ').strip() or 'SCINet'

# Fixed SCINet constructor arguments shared by every grid point.
# NOTE(review): presumably T = input window length and n = number of input
# features; confirm against models/SCINet.py.
basic_params = {
    'T': 60,
    'n': 7
}

# Hyper-parameter values explored by GridSearch.train_by_grid.
param_grid = {
    'skip_hidden_size': [12, 16, 24, 32],
    'T_modified': [20, 30, 40],
    'skip': [8, 10, 12]
}

model_trainer = GridSearch(criterion, eval_metrics, device,
                           temp_save_path=f'checkpoints/{data_name}/{model_name}_temp.pt')
# Model selection / early stopping must use the VALIDATION split. Passing
# test_loader here would leak the test set into hyper-parameter selection;
# test_loader is reserved for the final evaluation cell.
# The best checkpoint is saved next to the temp checkpoint, under the
# per-dataset directory, so runs on different datasets do not overwrite
# each other.
model2_best = model_trainer.train_by_grid(SCINet, basic_params, param_grid, Adam, train_loader,
                                          val_loader, lr, patience=5, epochs=50,
                                          save_filename=f'checkpoints/{data_name}/{model_name}_best.pt')


In [None]:
# Final evaluation on the held-out test split. NOTE(review): presumably this
# scores the best model kept by train_by_grid; confirm in tools/tools.py.
model_trainer.test(test_loader)

# Plot the recorded loss curves and write them to <model_name>_losses.png.
loss_plot_title = f'{model_name}'
loss_plot_file = f'{model_name}_losses.png'
model_trainer.plot_losses(plot_title=loss_plot_title, save_filename=loss_plot_file)