In [None]:
import sys,os
import torch
import yaml
import logging
from pydantic import ValidationError
from typing import List, Tuple

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(os.getcwd()))))))
from datasets.weather_bench import WeatherDataset
from models.VariableEncoder.datasets.dataset import CustomDataset
from models.VariableEncoder.training.configs import TrainingConfig
from models.VariableEncoder.training.configs import TrainingRunConfig


def get_normal_dataset(config: TrainingConfig) -> Tuple[CustomDataset, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Load the training split described by ``config`` and wrap it in a ``CustomDataset``.

    Data is requested from ``WeatherDataset`` between ``config.train_start`` and
    ``config.train_end`` for ``config.levels``, on CUDA when available and on the
    CPU otherwise.

    Args:
        config: training section of the run configuration.

    Returns:
        ``(dataset, mean_std, (src_var_list, tgt_var_list))``.
        ``mean_std`` is passed through unchanged from ``WeatherDataset.load_one``.
        ``src_var_list`` holds the vocabulary codes of every variable (predicted,
        input-only and constant). ``tgt_var_list`` holds only the codes of the
        predicted (air + surface) variables.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Variables the model predicts vs. variables it only consumes as input.
    target_vars = config.air_variable + config.surface_variable
    input_only_vars = config.only_input_variable + config.constant_variable
    all_vars = target_vars + input_only_vars

    weather = WeatherDataset(
        config.train_start,
        config.train_end,
        device=device,
        download_variables=all_vars,
        download_levels=config.levels,
    )
    # Shape noted during development: torch.Size([7309, 100, 1450])

    source, mean_std, var_vocab = weather.load_one(
        config.air_variable,
        config.surface_variable,
        config.only_input_variable,
        config.constant_variable,
        level=config.levels,
    )
    src_var_list = var_vocab.get_code(all_vars)
    tgt_var_list = var_vocab.get_code(target_vars)

    dataset = CustomDataset(
        source,
        config.src_time_len,
        config.tgt_time_len,
        n_only_input=len(input_only_vars),
    )
    return dataset, mean_std, (src_var_list, tgt_var_list)


# Load and validate the run configuration from ../configs/train_config.yaml
# (relative to the notebook's working directory), then build the dataset.
config_path = os.path.join(os.path.dirname(os.getcwd()), 'configs/train_config.yaml')

try:
    with open(config_path) as f:
        config_dict = yaml.safe_load(f)
    config: TrainingRunConfig = TrainingRunConfig.parse_obj(config_dict)
except FileNotFoundError:
    logging.error(f"Config file {config_path} does not exist. Exiting.")
    # Stop here: falling through would hit `config` below, which is either
    # undefined (NameError) or a stale value left over from an earlier run.
    raise
except yaml.YAMLError:
    logging.error(f"Config file {config_path} is not valid YAML. Exiting.")
    raise
except ValidationError as e:
    logging.error(f"Config file {config_path} is not valid. Exiting.\n{e}")
    raise


dataset, mean_std, var_list = get_normal_dataset(config.training)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize(loss: torch.Tensor, title, isAIR = False):
    """Plot loss curve(s) with a title and a legend.

    Args:
        loss: 1-D or 2-D tensor/array. For 2-D input, matplotlib draws each
            column as a separate line.
        title: figure title.
        isAIR: if True, the first two axes are swapped before plotting.
    """
    if isinstance(loss, torch.Tensor):
        # matplotlib converts inputs with np.asarray, which fails on CUDA tensors
        # and on tensors that require grad; detach and move to host memory first.
        loss = loss.detach().cpu()
    if isAIR:
        loss = loss.swapaxes(0, 1)
    print(loss.shape)

    _, ax = plt.subplots()
    lines = ax.plot(np.asarray(loss))
    # Label each line individually so the legend is meaningful for 2-D input.
    if len(lines) == 1:
        lines[0].set_label("rmse loss")
    else:
        for idx, line in enumerate(lines):
            line.set_label(f"rmse loss [{idx}]")
    ax.set_title(title)
    ax.legend()
    plt.show()

In [None]:
def predict(model: "TrainModule", src: torch.Tensor, label: torch.Tensor, device):
    """Run autoregressive inference on one sample and return its reduced loss.

    Relies on notebook globals: ``config`` (training settings), ``var_list``
    (``(src_var_list, tgt_var_list)`` from ``get_normal_dataset``) and the helpers
    ``positional_encoding``, ``get_var_seq`` and ``get_tgt_mask``.
    NOTE(review): those helpers and ``TrainModule`` are not imported in the
    visible cells; make sure an earlier cell imports them so Restart & Run All works.

    Args:
        model: wrapper exposing ``.model`` (the network), ``.setting()`` and
            ``.calculate_sqare_loss``.
        src: a single unbatched source sample; a leading batch dim of 1 is added here.
        label: a single unbatched target sample; a leading batch dim of 1 is added here.
        device: device the inputs are moved to.

    Returns:
        CPU tensor holding the loss averaged over the last (hidden) axis;
        presumably shaped (var, batch, time) — TODO confirm against calculate_sqare_loss.
    """
    model.model.eval()
    model.setting()

    # Inference only: disable autograd so no graph is kept alive across the
    # autoregressive loop.
    with torch.no_grad():
        src = src.to(device).unsqueeze(0)
        label = label.to(device).unsqueeze(0)
        src = src + positional_encoding(src.size(0), config.training.src_time_len, var_list[0].size(0), src.size(-1), device)
        # Decoder input: an all-zero start token followed by the first
        # var_list[1].size(0) label tokens.
        tgt = torch.zeros(label.size(0), 1, label.size(2), device=device)
        tgt = torch.cat([tgt, label[:, :var_list[1].size(0)]], dim=1)

        for i in range(1, config.training.tgt_time_len+1):
            pos = positional_encoding(tgt.size(0), i, var_list[1].size(0), tgt.size(-1), device, has_special_token=True)
            src_seq, tgt_seq = get_var_seq(var_list[0], var_list[1], config.training.src_time_len, i, src.size(0))
            tgt_mask = get_tgt_mask(var_list[1].size(0), i).to(device)

            # NOTE(review): `pos` is built for i time steps, but `tgt` is the previous
            # step's model output (i-1 steps per the shape note below). Unless the
            # model output already contains the next step, this addition will not
            # line up for tgt_time_len > 1. Positional encoding is also re-added to
            # model outputs each step. Confirm both are intended.
            tgt = tgt + pos
            src_seq = src_seq.to(device)
            tgt_seq = tgt_seq.to(device)
            output = model.model(src, tgt, src_seq, tgt_seq, tgt_mask)
            # output.shape = (batch, i * var + 1, hidden)
            print(output.shape)
            tgt = output

        # Drop the last position; the view below needs exactly
        # tgt_time_len * var_list[1].size(0) positions.
        tgt = tgt[:, :-1]

        tgt = tgt.view(tgt.size(0), config.training.tgt_time_len, var_list[1].size(0), tgt.size(-1))
        label = label.view(tgt.size(0), config.training.tgt_time_len, var_list[1].size(0), tgt.size(-1))

        # Assumes calculate_sqare_loss is element-wise, i.e. returns
        # (batch, time, var, hidden) like its inputs — TODO confirm.
        loss = model.calculate_sqare_loss(tgt, label)
        loss = loss.swapaxes(1, 2)        # under that assumption: (batch, var, time, hidden)
        loss = torch.mean(loss, dim=-1)   # (batch, var, time)
        loss = loss.swapaxes(0, 1)        # (var, batch, time)

    # Only the returned loss needs to be on the CPU; the other tensors are
    # locals that go out of scope when the function returns.
    return loss.detach().cpu()