In [8]:
import torch
from iTransformer import iTransformer
from utils.helpers import ModelConfig, calc_percentiles
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter



In [9]:
# Lookback length (passed to ModelConfig as lookback_len) and forecast horizon (pred_length).
window_size, pred_length = 132, 12

# CSV path of the electricity dataset; handed to the validation loader in the evaluation cell.
electricity = "datasets/electricity/electricity_small.csv"

In [12]:
# Build the device, model, optimizer, LR scheduler and TensorBoard writer used by later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LEARNING_RATE = 1e-4  # Adam learning rate (same value as the original 0.0001)
LR_STEP_SIZE = 10     # number of scheduler.step() calls between LR decays
LR_GAMMA = 0.1        # multiplicative LR decay factor applied every LR_STEP_SIZE steps

config = ModelConfig(lookback_len=window_size, pred_length=pred_length)
# Move the model to its device *before* constructing the optimizer, as the torch.optim docs
# recommend, so the optimizer is built over the parameters that will actually be trained.
model = iTransformer(config).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)
writer = SummaryWriter(log_dir='logs')

In [None]:
# Evaluate a saved checkpoint on the validation set and log the per-batch mean of
# MSE / P10 / P50 / P90. Uses `model`, `device`, `pred_length`, `electricity` and `writer`
# from the cells above.
try:
    from tqdm.auto import tqdm  # progress bar; the loop below used `tqdm` without importing it
except ImportError:  # degrade gracefully: iterate without a progress bar
    def tqdm(iterable, **_kwargs):
        return iterable

CHECKPOINT_PATH = 'model_epoch_1.pt'
EVAL_LOG_DIR = 'eval_logs'

# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on a CPU-only box).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# TODO: `my_loading_function` is not defined anywhere in this notebook; replace it with the
# real validation DataLoader factory before running.
valid_dataloader = my_loading_function(electricity)

mse_loss_fn = torch.nn.MSELoss()  # build once instead of once per batch
global_step_eval = 1
total_loss_eval = 0.0
total_p10_eval = 0.0
total_p50_eval = 0.0
total_p90_eval = 0.0
num_batches = 0  # counted explicitly so loaders without __len__ also work

with torch.no_grad():
    for inputs, target_covariates in tqdm(valid_dataloader, desc="Epoch: Validating"):
        inputs = inputs.to(device)
        # Targets must live on the same device as the predictions, otherwise MSELoss fails on GPU.
        target_covariates = target_covariates.to(device)

        # NOTE(review): the model output is indexed by the prediction horizon (the original
        # hard-coded [12], equal to pred_length) -- confirm against iTransformer.forward.
        predictions = model(inputs)[pred_length]

        # For target-specific fine-tuning, compare only the first variate instead:
        #   predictions[:, :, 0] vs target_covariates[:, :, 0]
        total_loss_eval += mse_loss_fn(predictions, target_covariates).item()  # all variates
        # NOTE(review): calc_percentiles' return type (float vs tensor) is not visible here.
        total_p10_eval += calc_percentiles(predictions, target_covariates, 10)
        total_p50_eval += calc_percentiles(predictions, target_covariates, 50)
        total_p90_eval += calc_percentiles(predictions, target_covariates, 90)
        num_batches += 1

# TODO: additional metrics -- MASE, MAPE, sMAPE.

if num_batches == 0:
    raise ValueError("valid_dataloader yielded no batches; cannot compute validation metrics")

mean_metrics = {
    'MSE': total_loss_eval / num_batches,
    'P10': total_p10_eval / num_batches,
    'P50': total_p50_eval / num_batches,
    'P90': total_p90_eval / num_batches,
}

print(f"MSE Loss: {mean_metrics['MSE']}\nP10: {mean_metrics['P10']}\n"
      f"P50: {mean_metrics['P50']}\nP90: {mean_metrics['P90']}")

# Close the setup cell's training writer before the name is rebound (otherwise it leaks), then
# log eval scalars; the context manager flushes and closes the eval writer even on error.
writer.close()
with SummaryWriter(log_dir=EVAL_LOG_DIR) as writer:
    for metric_name, metric_value in mean_metrics.items():
        writer.add_scalar(f'{metric_name}/valid', metric_value, global_step_eval)