In [None]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

import warnings
warnings.filterwarnings("ignore")

In [None]:
# replace lightning.pytorch with pytorch_lightning
from pytorch_forecasting import BaseModel

In [None]:
from scripts.data_builder import HydroForecastData
from scripts.tft_data import file_checker, open_for_tft, train_val_split
from scripts.model_eval import pred_res_builder

import glob
import pandas as pd
import geopandas as gpd
import xarray as xr
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
from copy import deepcopy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import (TimeSeriesDataSet, TemporalFusionTransformer,
                                 Baseline)
from pytorch_lightning.loggers import TensorBoardLogger

from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss, SMAPE, RMSE, MASE
from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
from sklearn.preprocessing import RobustScaler, MinMaxScaler

# torch.set_float32_matmul_precision('medium')
# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

# On CUDA also report the GPU name and its current memory use (GiB).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:',
          round(torch.cuda.memory_allocated(0) / 1024**3, 1),
          'GB')
    print('Cached:   ',
          round(torch.cuda.memory_reserved(0) / 1024**3, 1),
          'GB')

# Column names of the target and of the meteorological predictors.
hydro_target = 'q_mm_day'
meteo_input = ['prcp_e5l', 't_max_e5l', 't_min_e5l']

# Watershed geometries, indexed by gauge id.
ws_file = gpd.read_file(
    '../geo_data/great_db/geometry/russia_ws.gpkg').set_index('gauge_id')
# ws_file = ws_file[ws_file['new_area'] <= 50000]

### Model

In [None]:
def interpretation_for_gauge(interp_dict: dict,
                             static_parameters: list,
                             encoder_params: list,
                             decoder_params: list):
    """Condense a model interpretation into its most important entries.

    Args:
        interp_dict: mapping that holds 1-D tensors under the keys
            'attention', 'static_variables', 'encoder_variables' and
            'decoder_variables'.
            NOTE(review): presumably the result of
            `interpret_output(..., reduction="sum")` -- confirm with caller.
        static_parameters: names matching `interp_dict['static_variables']`.
        encoder_params: names matching `interp_dict['encoder_variables']`.
        decoder_params: names matching `interp_dict['decoder_variables']`.

    Returns:
        tuple `(position, stat_cols, enc_col, dec_col)` where `position` is
        the 1-based index of the largest attention weight, `stat_cols` lists
        the four static names with the largest share (descending), and
        `enc_col` / `dec_col` are the encoder / decoder names with the
        largest share.
    """

    def to_percentage(values: torch.Tensor):
        # normalise along the last axis so the shares sum to 1
        return values / values.sum(-1).unsqueeze(-1)

    def interp_df(interp_tensor: torch.Tensor,
                  df_columns: list):
        # one-row frame: column per variable, value is its share
        shares = to_percentage(interp_tensor)
        return pd.DataFrame(
            {var: float(val) for var, val in zip(df_columns, shares)},
            index=[0])

    # find most informative days (shift the 0-based position to 1-based)
    _, indices = interp_dict['attention'].sort(descending=True)
    indices = indices[0] + 1

    # get most valuable static parameters; nlargest is evaluated once and
    # only the names are kept (the values were never used)
    static_worth = interp_df(interp_tensor=interp_dict['static_variables'],
                             df_columns=static_parameters)
    stat_cols = list(static_worth.T.nlargest(n=4, columns=0).index)

    # get most valuable encoder parameter
    encoder_worth = interp_df(interp_tensor=interp_dict['encoder_variables'],
                              df_columns=encoder_params)
    enc_col = encoder_worth.idxmax(axis=1)[0]

    # get most valuable decoder parameter
    decoder_worth = interp_df(interp_tensor=interp_dict['decoder_variables'],
                              df_columns=decoder_params)
    dec_col = decoder_worth.idxmax(axis=1)[0]

    return int(indices), stat_cols, enc_col, dec_col

In [None]:
# Preview the table built for one gauge file.
# NOTE(review): no cell above this one assigns `nc_file`; it only exists
# after a later cell has run, so this fails on Restart & Run All.
open_for_tft(nc_files=[nc_file],
             static_path='../geo_data/attributes/geo_vector.csv',
             area_index=ws_file.index,
             meteo_predictors=meteo_input,
             hydro_target=hydro_target,
             allow_nan=True)

In [None]:
# Train one single-gauge TFT per NetCDF file, logging each run to TensorBoard.
# NOTE(review): `nnse` is only defined in a later cell of this notebook (and
# imported from scripts.model_eval further down); one of those must run first.
by_gauge_res = list()
for nc_file in glob.glob('../geo_data/great_db/nc_all_q/*.nc'):
    # gauge id is the file name without its '.nc' suffix
    gauge_id = Path(nc_file).stem

    file = open_for_tft(
        nc_files=[nc_file],
        static_path='../geo_data/attributes/geo_vector.csv',
        area_index=ws_file.index,
        meteo_predictors=meteo_input,
        hydro_target=hydro_target)

    (train_ds, train_loader,
        val_ds, val_loader, val_df,
        scaler) = train_val_split(file)

    # configure network and trainer
    # NOTE(review): validation runs every 3rd epoch and max_epochs is 8, so
    # patience=3 looks unreachable -- confirm this is intended.
    early_stop_callback = EarlyStopping(monitor="val_loss",
                                        min_delta=1e-3, patience=3,
                                        verbose=True, mode="min")
    # log the learning rate
    lr_logger = LearningRateMonitor()
    # logging results to a tensorboard
    logger = TensorBoardLogger(f"./single_gauge_8epoch/{gauge_id}_tft")

    # accelerator='auto' is what the trainer has always been given; the
    # former `accel` variable was never passed on and was derived from
    # comparing a torch.device with a str, so it has been dropped.
    trainer = pl.Trainer(
        max_epochs=8,
        accelerator='auto',
        enable_model_summary=True,
        check_val_every_n_epoch=3,
        gradient_clip_val=0.5,
        log_every_n_steps=3,
        callbacks=[lr_logger, early_stop_callback],
        logger=logger)

    tft = TemporalFusionTransformer.from_dataset(
        train_ds,
        learning_rate=1e-3,
        hidden_size=64,
        dropout=0.4,
        loss=nnse(),
        optimizer='adam')

    # print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

    # fit network
    trainer.fit(tft,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader)

### Result evaluation

In [None]:
from pytorch_forecasting.metrics import RMSE

class nnse(MultiHorizonMetric):
    """Normalised Nash-Sutcliffe efficiency, NNSE = 1 / (2 - NSE).

    NOTE(review): NNSE equals 1 for a perfect fit and falls towards 0 for a
    poor one. If this value is minimised as a training loss the direction
    looks inverted -- confirm whether `1 - NNSE` was intended.
    NOTE(review): the value returned is a single scalar over the whole batch;
    verify that the base class accepts that shape.
    """

    def loss(self, pred, target):
        """Return the NNSE of `pred` against `target` as a scalar tensor."""

        pred = self.to_prediction(pred)
        # sum of squared errors
        denom = torch.sum((target - pred)**2)
        # sum of squared deviations from the mean: the square applies to the
        # deviation, not to the mean (matches the numpy `nse` helper)
        divsr = torch.sum((target - torch.mean(target))**2)
        nse = 1 - torch.div(denom, divsr)
        score = 1 / (2 - nse)

        return score


def nse(pred, target):
    """Nash-Sutcliffe efficiency of `pred` with respect to `target`.

    Computed as one minus the ratio of the summed squared errors to the
    summed squared deviations of `target` from its mean.
    """
    residuals = target - pred
    anomalies = target - np.mean(target)
    return 1 - np.sum(residuals**2) / np.sum(anomalies**2)

In [None]:
# Gauge ids that have a discharge file: the file name without '.nc'.
# Path.stem is used so the result does not depend on the path separator.
q_indexes = [Path(nc_path).stem for nc_path in
             glob.glob('../geo_data/great_db/nc_all_q/*.nc')]

In [None]:
# Inputs for the single-gauge experiment: one NetCDF file, its watershed row
# and the column names used by the model.
hydro_target = 'q_mm_day'
meteo_input = ['prcp_e5l', 't_max_e5l', 't_min_e5l']
nc_file = '../geo_data/great_db/nc_all_q/72519.nc'
test_ws = ws_file[ws_file['name_ru'].eq('р.Ситня - д.Пески')]
static_parameters = [
    'for_pc_sse',
    'crp_pc_sse',
    'inu_pc_ult',
    'ire_pc_sse',
    'lka_pc_use',
    'prm_pc_sse',
    'pst_pc_sse',
    'cly_pc_sav',
    'slt_pc_sav',
    'snd_pc_sav',
    'kar_pc_sse',
    'urb_pc_sse',
    'gwt_cm_sav',
    'lkv_mc_usu',
    'rev_mc_usu',
    'sgr_dk_sav',
    'slp_dg_sav',
    'ws_area',
    'ele_mt_sav',
]

In [None]:
# Evaluate a previously trained single-gauge model from its saved checkpoint.
# gauge id is the file name without its '.nc' suffix
gauge_id = Path(nc_file).stem
file = open_for_tft(
    nc_files=[nc_file],
    static_path='../geo_data/attributes/geo_vector.csv',
    area_index=test_ws.index,
    meteo_predictors=meteo_input,
    hydro_target=hydro_target, allow_nan=True)

(train_ds, train_loader,
    val_ds, val_loader, val_df,
    scaler) = train_val_split(file)

# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss",
                                    min_delta=1e-3, patience=3,
                                    verbose=True, mode="min")
# log the learning rate
lr_logger = LearningRateMonitor()
# logging results to a tensorboard
logger = TensorBoardLogger(f"./single_gauge/{gauge_id}_tft")

# accelerator='auto' is what the trainer has always been given; the former
# `accel` variable was never passed on and was derived from comparing a
# torch.device with a str, so it has been dropped.
trainer = pl.Trainer(
    max_epochs=1,
    accelerator='auto',
    enable_model_summary=True,
    check_val_every_n_epoch=3,
    gradient_clip_val=0.5,
    log_every_n_steps=3,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger)

tft = TemporalFusionTransformer.from_dataset(
    train_ds,
    learning_rate=1e-3,
    hidden_size=64,
    dropout=0.4,
    loss=nnse(),
    optimizer='adam')

# print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# fit network
# trainer.fit(tft,
#             train_dataloaders=train_loader,
#             val_dataloaders=val_loader)

# Sorted so the chosen checkpoint does not depend on directory listing order;
# a missing checkpoint gets an explicit error instead of a bare IndexError.
checkpoints = sorted(glob.glob(
    f'./single_gauge/{gauge_id}_tft/*/*/checkpoints/*.ckpt'))
if not checkpoints:
    raise FileNotFoundError(
        f'No checkpoint found under ./single_gauge/{gauge_id}_tft')
chkpt = checkpoints[0]
resdf, interpretation = pred_res_builder(gauge_id=gauge_id,
                                         hydro_target=hydro_target,
                                         meteo_input=meteo_input,
                                         static_parameters=static_parameters,
                                         model_checkpoint=chkpt,
                                         res_storage='./result/tft_single',
                                         val_df=val_df,
                                         scaler=scaler,
                                         val_ts_ds=val_ds, with_plot=False)

In [None]:
# Restore a trained TFT from a checkpoint; this rebinds `tft`.
tft = TemporalFusionTransformer.load_from_checkpoint(
    './TFT_914/lightning_logs/version_0/checkpoints/epoch=2-step=78123.ckpt'
)

In [None]:
# variable selection
def make_selection_plot(title, values, labels, ax):
    """Draw variable importances as a horizontal bar chart on `ax`.

    Args:
        title: axes title.
        values: 1-D tensor of raw importances; they are normalised to sum
            to 1 and drawn as percentages, smallest first.
        labels: one name per entry of `values`.
        ax: axes to draw on.

    Returns:
        The figure that owns `ax`. Previously a global `fig` was returned,
        which is only correct when that global happens to be this figure.
    """
    order = np.argsort(values)
    values = values / values.sum(-1).unsqueeze(-1)
    ax.barh(np.arange(len(values)), values[order] * 100,
            tick_label=np.asarray(labels)[order])
    ax.set_title(title)
    ax.set_xlabel("Значимость в %")
    # lay out the figure that owns `ax`, not whichever figure is current
    fig = ax.figure
    fig.tight_layout()
    return fig

In [None]:
# 2x2 summary figure: attention profile plus three variable-importance charts.
fig, axs = plt.subplots(figsize=(15, 8),
                        ncols=2,
                        nrows=2)
# attention
attention = interpretation["attention"].detach().cpu()
attention = attention / attention.sum(-1).unsqueeze(-1)
# x axis follows the actual attention length instead of a fixed 365
axs[0, 0].plot(
    np.arange(len(attention)), attention
)
axs[0, 0].set_xlabel("Дней назад")
axs[0, 0].set_ylabel("Значимость")
axs[0, 0].set_title("Значимость")


make_selection_plot(
    "Значимость физико-географических характеристик",
    interpretation["static_variables"].detach().cpu(),
    static_parameters,
    ax=axs[0, 1]);
make_selection_plot(
    "Значимость переменных кодировщика",
    interpretation["encoder_variables"].detach().cpu(),
    [*meteo_input, hydro_target],
    ax=axs[1, 0]);
make_selection_plot(
    "Значимость переменных декодировщика",
    interpretation["decoder_variables"].detach().cpu(),
    meteo_input,
    ax=axs[1, 1]);
fig.savefig('../conclusions/images/interp_model.png',
            dpi=650, bbox_inches='tight')


In [None]:
# Collect the individual interpretation figures by name.
figs = {}
# attention
fig, ax = plt.subplots()
attention = interpretation["attention"].detach().cpu()
attention = attention / attention.sum(-1).unsqueeze(-1)
# x axis follows the actual attention length instead of a fixed 365
ax.plot(
    np.arange(len(attention)), attention
)
ax.set_xlabel("Дней назад")
ax.set_ylabel("Значимость")
ax.set_title("Значимость")
figs["attention"] = fig

# variable selection
def make_selection_plot(title, values, labels):
    """Create a new figure with a horizontal bar chart of importances.

    `values` is normalised to sum to 1 and drawn in percent, smallest first;
    the figure height grows with the number of entries. Returns the figure.

    NOTE(review): this redefinition replaces the four-argument
    `make_selection_plot(title, values, labels, ax)` from an earlier cell.
    """
    fig, ax = plt.subplots(figsize=(7, len(values) * 0.25 + 2))
    ranking = np.argsort(values)
    shares = values / values.sum(-1).unsqueeze(-1)
    positions = np.arange(len(shares))
    ordered_labels = np.asarray(labels)[ranking]
    ax.barh(positions, shares[ranking] * 100, tick_label=ordered_labels)
    ax.set_title(title)
    ax.set_xlabel("Значимость в %")
    plt.tight_layout()
    return fig

# One importance figure per variable group, stored under the group name.
figs["static_variables"] = make_selection_plot(
    "Значимость физико-географических характеристик",
    interpretation["static_variables"].detach().cpu(),
    static_parameters,
)
figs["encoder_variables"] = make_selection_plot(
    "Значимость переменных кодировщика",
    interpretation["encoder_variables"].detach().cpu(),
    [*meteo_input, hydro_target],
)
figs["decoder_variables"] = make_selection_plot(
    "Значимость переменных декодировщика",
    interpretation["decoder_variables"].detach().cpu(),
    meteo_input,
)
# NOTE(review): this opens an empty 2x2 figure and rebinds `fig` / `axs`;
# nothing is drawn on it in this cell -- looks like a leftover.
fig, axs = plt.subplots(figsize=(15, 8), ncols=2, nrows=2)


In [None]:
# Watersheds whose gauge id appears in the partial-discharge gauge layer.
partial_gauges = gpd.read_file(
    '../geo_data/great_db/geometry/gauges_partial_q.gpkg'
)
# ids are compared as strings
partial_gauges.index = partial_gauges['gauge_id'].astype(str)
partial_mask = ws_file.index.isin(partial_gauges.index)
partial_ws = ws_file[partial_mask]
partial_ws

In [None]:
# Watersheds whose gauge id appears in the lost-gauge layer.
lost_gauges = gpd.read_file(
    '../geo_data/great_db/geometry/lost_gauges.gpkg'
)
# ids are compared as strings
lost_gauges.index = lost_gauges['gauge_id'].astype(str)
lost_mask = ws_file.index.isin(lost_gauges.index)
lost_ws = ws_file[lost_mask]

In [None]:
# Run the trained multi-gauge checkpoint on every "lost" gauge that has a
# concatenated NetCDF file and store the per-gauge metrics in one CSV.
meteo_input = ['prcp_e5l',  't_max_e5l', 't_min_e5l']
hydro_target = 'q_mm_day'
static_parameters = ['for_pc_sse', 'crp_pc_sse',
                     'inu_pc_ult', 'ire_pc_sse',
                     'lka_pc_use', 'prm_pc_sse',
                     'pst_pc_sse', 'cly_pc_sav',
                     'slt_pc_sav', 'snd_pc_sav',
                     'kar_pc_sse', 'urb_pc_sse',
                     'gwt_cm_sav', 'lkv_mc_usu',
                     'rev_mc_usu', 'sgr_dk_sav',
                     'slp_dg_sav', 'ws_area',
                     'ele_mt_sav']
# NOTE(review): absolute, machine-specific path -- consider a config constant.
lost_gauge_checkpoint = '/workspaces/my_dissertation/forecast/TFT_914/lightning_logs/version_0/checkpoints/epoch=2-step=78123.ckpt'

res_list = list()
# Path.stem gives the gauge id (file name without '.nc') on any platform.
for nc_file in [file for file
                in glob.glob('../geo_data/great_db/nc_concat/*.nc')
                if Path(file).stem in lost_ws.index]:
    gauge_id = Path(nc_file).stem
    file = open_for_tft(
        nc_files=[nc_file],
        static_path='../geo_data/attributes/geo_vector.csv',
        area_index=lost_ws.index,
        meteo_predictors=meteo_input,
        hydro_target=hydro_target, allow_nan=True)
    try:
        (train_ds, train_loader,
         val_ds, val_loader, val_df,
         scaler) = train_val_split(file)

        res_df, interpretation = pred_res_builder(gauge_id=gauge_id,
                                                  hydro_target=hydro_target,
                                                  meteo_input=meteo_input,
                                                  static_parameters=static_parameters,
                                                  model_checkpoint=lost_gauge_checkpoint,
                                                  res_storage='./result/lost_gauge',
                                                  val_df=val_df,
                                                  scaler=scaler,
                                                  val_ts_ds=val_ds, with_plot=False)
        res_list.append(res_df)
    except AssertionError:
        print(f'No available data for {gauge_id}')

# pd.concat raises on an empty list, so only write when something succeeded
if res_list:
    pd.concat(res_list).to_csv('./result/tft_lost_gauge.csv', index=False)
else:
    print('No lost-gauge results were produced; nothing written.')

### Blind forecast

In [None]:
# Blind-forecast setup: column names, the partial-discharge watersheds and
# the table of each matching NetCDF file.
hydro_target = 'q_mm_day'
meteo_input = ['prcp_e5l', 't_max_e5l', 't_min_e5l']
static_parameters = [
    'for_pc_sse', 'crp_pc_sse', 'inu_pc_ult', 'ire_pc_sse',
    'lka_pc_use', 'prm_pc_sse', 'pst_pc_sse', 'cly_pc_sav',
    'slt_pc_sav', 'snd_pc_sav', 'kar_pc_sse', 'urb_pc_sse',
    'gwt_cm_sav', 'lkv_mc_usu', 'rev_mc_usu', 'sgr_dk_sav',
    'slp_dg_sav', 'ws_area', 'ele_mt_sav',
]

partial_gauges = gpd.read_file(
    '../geo_data/great_db/geometry/gauges_partial_q.gpkg'
)
partial_gauges.index = partial_gauges['gauge_id'].astype(str)
partial_ws = ws_file[ws_file.index.isin(partial_gauges.index)]

res_list = list()
# keep only files whose name (without '.nc') is a partial-watershed gauge id
partial_nc_files = [
    path for path in glob.glob('../geo_data/great_db/nc_concat/*.nc')
    if path.split('/')[-1][:-3] in partial_ws.index
]
# NOTE(review): the loop only loads each table; `file`, `nc_file` and
# `gauge_id` keep the values from the last iteration.
for nc_file in partial_nc_files:
    gauge_id = nc_file.split('/')[-1][:-3]
    file = open_for_tft(nc_files=[nc_file],
                        static_path='../geo_data/attributes/geo_vector.csv',
                        area_index=partial_ws.index,
                        meteo_predictors=meteo_input,
                        hydro_target=hydro_target,
                        allow_nan=True)

In [None]:
# Preview the table for the current `nc_file`, restricted to the partial
# watersheds. `nc_file` is whatever an earlier cell assigned last.
open_for_tft(nc_files=[nc_file],
             static_path='../geo_data/attributes/geo_vector.csv',
             area_index=partial_ws.index,
             meteo_predictors=meteo_input,
             hydro_target=hydro_target,
             allow_nan=True)

In [None]:
# Static catchment attributes, one row per gauge; the index is cast to str.
static_attributes = pd.read_csv(
    '../geo_data/attributes/geo_vector.csv',
    index_col='gauge_id',
)
static_attributes.index = static_attributes.index.astype(str)
static_attributes

In [None]:
# Quick check: True when the static attribute table has no entries.
static_attributes.empty

In [None]:
# Build the table for the current `nc_file` by hand: time series from the
# NetCDF file plus that gauge's static attributes as constant columns.
static_attributes = pd.read_csv('../geo_data/attributes/geo_vector.csv',
                                index_col='gauge_id')
static_attributes.index = static_attributes.index.astype(str)
# gauge id is the file name without its '.nc' suffix
gauge_id = Path(nc_file).stem
res_file = list()
try:
    gauge_static = static_attributes.loc[[gauge_id], :]
except KeyError:
    print(f'No data for {gauge_id} !')

else:
    # the context manager closes the NetCDF handle once the data has been
    # copied into a DataFrame
    with xr.open_dataset(nc_file) as ds:
        file = ds.to_dataframe()
    # file['date'] = file.index
    file = file.reset_index()
    file['time_idx'] = file.index

    # broadcast each static attribute to every row
    for col in gauge_static.columns:
        file[col] = gauge_static.loc[gauge_id, col]

    res_file.append(file)

In [None]:
# Table for the current `nc_file`, restricted to the lost-gauge watersheds.
# `nc_file` is whatever an earlier cell assigned last.
f = open_for_tft(nc_files=[nc_file],
                 static_path='../geo_data/attributes/geo_vector.csv',
                 area_index=lost_ws.index,
                 meteo_predictors=meteo_input,
                 hydro_target=hydro_target,
                 allow_nan=True)

In [None]:
# Keep the model columns, drop incomplete rows and rescale every numeric
# column to the range [1, 10].
# NOTE(review): this rebinds `f` and `scaler`, so re-running the cell rescales
# already scaled data.
if 'index' in f.columns:
    f = f.rename(columns={'index': 'date'})

numeric_cols = [hydro_target, *meteo_input, *static_parameters]
f = (f[['date', 'time_idx', 'gauge_id', *numeric_cols]]
     .dropna()
     .reset_index(drop=True))

scaler = MinMaxScaler(feature_range=(1, 10))
f[numeric_cols] = scaler.fit_transform(f[numeric_cols])

### Run this to get results

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
import pytorch_lightning as pl
from scripts.tft_data import open_for_tft, train_val_split
from scripts.model_eval import nnse, pred_res_builder

import glob
import geopandas as gpd
import pandas as pd
from tqdm import tqdm

import torch
torch.set_float32_matmul_precision('medium')


# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

# On CUDA also report the GPU name and its current memory use (GiB).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:',
          round(torch.cuda.memory_allocated(0) / 1024**3, 1),
          'GB')
    print('Cached:   ',
          round(torch.cuda.memory_reserved(0) / 1024**3, 1),
          'GB')


# Column names for this section; the target here is 'lvl_mbs'.
hydro_target = 'lvl_mbs'
meteo_input = ['prcp_e5l', 't_max_e5l', 't_min_e5l']
static_parameters = [
    'for_pc_sse', 'crp_pc_sse', 'inu_pc_ult', 'ire_pc_sse',
    'lka_pc_use', 'prm_pc_sse', 'pst_pc_sse', 'cly_pc_sav',
    'slt_pc_sav', 'snd_pc_sav', 'kar_pc_sse', 'urb_pc_sse',
    'gwt_cm_sav', 'lkv_mc_usu', 'rev_mc_usu', 'sgr_dk_sav',
    'slp_dg_sav', 'ws_area', 'ele_mt_sav',
]

# Watershed geometries, indexed by gauge id.
ws_file = gpd.read_file(
    '../geo_data/great_db/geometry/russia_ws.gpkg').set_index('gauge_id')
# ws_file = ws_file[ws_file['new_area'] <= 50000]



In [None]:

###############################################################################
# Load every level file into one table and split it into train / validation.
# NOTE(review): this call passes `meteo_input=` while every other
# `open_for_tft` call in the notebook passes `meteo_predictors=` -- confirm
# which keyword the helper currently accepts.
level_nc_files = glob.glob('../geo_data/great_db/nc_all_h/*.nc')
file = open_for_tft(nc_files=level_nc_files,
                    static_path='../geo_data/attributes/geo_vector.csv',
                    area_index=ws_file.index,
                    meteo_input=meteo_input,
                    hydro_target=hydro_target,
                    shuffle_static=False,
                    with_static=True)

(train_ds, train_loader, val_ds, val_loader, val_df, scaler) = \
    train_val_split(file, with_static=True)
# collector for the per-gauge results
res = list()
res = list()


In [None]:
# Score the multi-gauge level model on every validation gauge and store the
# combined table. Results go into a fresh list so the cell can be re-run.
gauge_results = list()
failed_gauges = list()
for gauge in val_df.gauge_id.unique():
    try:
        gauge_res = pred_res_builder(gauge_id=gauge,
                                     res_storage='./result/tft_level_multi_256/',
                                     model_checkpoint='/workspaces/my_dissertation/forecast/lvl_prediction_multi_gauge_NEXT/lightning_logs/version_0/checkpoints/epoch=11-step=200928.ckpt',
                                     hydro_target=hydro_target,
                                     meteo_input=meteo_input,
                                     static_parameters=static_parameters,
                                     val_ts_ds=val_ds, val_df=val_df,
                                     scaler=scaler,
                                     with_plot=False)[0]
    except Exception as err:
        # still best-effort, but the failure is reported instead of hidden and
        # KeyboardInterrupt is no longer swallowed
        failed_gauges.append(gauge)
        print(f'Skipping gauge {gauge}: {type(err).__name__}: {err}')
    else:
        gauge_results.append(gauge_res)

# pd.concat raises on an empty list, so only write when something succeeded
if gauge_results:
    res = pd.concat(gauge_results)
    res.to_csv('./result/tft_level_multi_256.csv',
               index=False)
else:
    print('No gauge could be evaluated; nothing written.')

In [None]:
# Median of the 'NSE' column in a stored result table.
pd.read_csv(
    './result/tft_shuffled_static_pred_17epoch_RMSE_256.csv'
).loc[:, 'NSE'].median()


In [None]:
# Per-gauge evaluation against the TFT_914 / version_6 checkpoint.
# NOTE(review): this cell looks stale and will not run as written:
#   * `validation` is not defined anywhere in the visible notebook;
#   * `pred_res_builder` is called without the checkpoint, target, predictor,
#     scaler and storage arguments that every other call passes;
#   * other cells unpack its result as a pair (table, interpretation), while
#     here the whole result is appended and then given to pd.concat;
#   * `best_tft` is loaded but not passed to `pred_res_builder`.
best_tft = TemporalFusionTransformer.load_from_checkpoint(
    '/workspaces/my_dissertation/forecast/TFT_914/lightning_logs/version_6/checkpoints/epoch=20-step=246057.ckpt')
res = list()
for gauge in tqdm(val_df.gauge_id.unique()):
    res.append(pred_res_builder(gauge_id=gauge,
                                val_ts_ds=validation,
                                with_plot=False))
# `res` changes from a list to a DataFrame here
res = pd.concat(res)
res.to_csv('./result/tft_predictions_6.csv',
           index=False)

In [None]:
# Interpret the multi-gauge model on the validation data of a single gauge.
best_tft = TemporalFusionTransformer.load_from_checkpoint(
    '/workspaces/my_dissertation/forecast/multi_gauge_256/lightning_logs/version_0/checkpoints/epoch=20-step=274596.ckpt')


pred_id = '5461'
# NOTE(review): `hydro_target` is whatever an earlier cell set last (the
# results section sets 'lvl_mbs') while this reads a file from nc_all_q --
# confirm the intended target column.
file = open_for_tft(
    nc_files=glob.glob(f'../geo_data/great_db/nc_all_q/{pred_id}.nc'),
    static_path='../geo_data/attributes/geo_vector.csv',
    area_index=ws_file.index,
    meteo_predictors=meteo_input,
    hydro_target=hydro_target)

(train_ds, train_loader,
    val_ds, val_loader, val_df,
    scaler) = train_val_split(file)

# validation samples of the selected gauge only
gauge_loader = val_ds.filter(
    lambda x: x.gauge_id == pred_id).to_dataloader(train=False,
                                                   batch_size=128,
                                                   num_workers=8)
prediction = best_tft.predict(gauge_loader, mode="raw", return_x=True)
# Fields are read by name rather than by position. The previous 5-way unpack
# bound `x` to the third field, which per the pytorch-forecasting >= 1.0
# `Prediction` layout (output, x, index, decoder_lengths, y) is `index`, not
# the network input -- confirm against the installed version.
raw_predictions, x = prediction.output, prediction.x
interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
best_tft.plot_interpretation(interpretation)