# Results
> Functions for plotting results and comparing the Kalman Filter with MDS and ERA

In [None]:
#| hide
#| default_exp kalman.results

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import altair as alt

In [None]:
import altair

In [None]:
#| export
from fastcore.test import *
from fastcore.basics import *
from meteo_imp.utils import *
from meteo_imp.gaussian import *
from meteo_imp.kalman.filter import *
from meteo_imp.kalman.filter import get_test_data
from meteo_imp.data import *
from meteo_imp.kalman.training import *
from meteo_imp.kalman.training import _n_tuple
from fastcore.transform import *
from fastai.learner import *
from pyprojroot import here

import pykalman
from typing import *

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

from timeit import timeit
import polars as pl
import altair as alt

from tqdm.auto import tqdm

import io
from contextlib import redirect_stderr
import random
from math import floor

from dataclasses import dataclass
from functools import partial
from itertools import zip_longest

Load data to test the functions

In [None]:
reset_seed()

In [None]:
hai = pd.read_parquet(hai_big_path)
hai_era = pd.read_parquet(hai_era_big_path)

In [None]:
dls = imp_dataloader(hai, hai_era, var_sel = 'TA', gap_len=5, block_len=100, control_lags = [1], n_rep=1, bs=1).cpu()

In [None]:
item = MeteoImpItem(1,2, 'TA', gap_len=10)
items = [item]

In [None]:
targ = orig_target(dls, items)[0]

In [None]:
input, _ = one_batch_with_items(dls, items)

## Imputation Methods

### Kalman Filter

Take a `MeteoImp Df` and gap-fill it using the given model

In [None]:
model = torch.load(here("analysis/results/trained_models/1_gap_varying_6-336_v1.pickle"))

In [None]:
preds, targs = predict_items(model, dls=dls, items = items)
# put the observed values back, so the prediction only replaces the missing (mask == False) entries
preds[0].mean[targs[0].mask] = targs[0].data[targs[0].mask]

#### KalmanImputation

In [None]:
input[1].shape

torch.Size([1, 100, 9])

In [None]:
preds_raw = model(input)

In [None]:
len(preds_raw.mean[0])

10

In [None]:
#| export
def _extract_var(preds, var_idx, max_len):
    "extract prediction only from one var"
    preds_new = []
    for b_pred in preds:
        b_pred_new = []
        for pred in b_pred:
            if len(pred) == 1: pred_new = pred
            elif len(pred) == max_len: pred_new = pred[var_idx:var_idx+1] if pred.dim() == 1 else pred[var_idx:var_idx+1, var_idx:var_idx+1]
            else: raise ValueError("supports only gaps for 1 or all variables")
            b_pred_new.append(pred_new)
        preds_new.append(b_pred_new)
    return preds_new
         

In [None]:
_extract_var(preds_raw.mean, 1, 7)

[[tensor([-0.7555], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.7765], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8153], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8653], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8660], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8712], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8551], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8420], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8442], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([-0.8536], dtype=torch.float64, grad_fn=<IndexBackward0>)]]

In [None]:
_extract_var(preds_raw.cov, 1, 7)

[[tensor([[0.0227]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0307]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0359]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0392]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0409]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0409]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0393]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0361]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0310]], dtype=torch.float64, grad_fn=<IndexBackward0>),
  tensor([[0.0230]], dtype=torch.float64, grad_fn=<IndexBackward0>)]]

In [None]:
items_all = [MeteoImpItem(1,2, list(hai.columns), gap_len=3)] * 3

In [None]:
input_all, _ = one_batch_with_items(dls, items_all)

In [None]:
preds_raw_all = model(input_all)

In [None]:
_extract_var(preds_raw_all.mean, 3, 9)

[[tensor([-0.5650], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([-0.5887], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([-0.6259], dtype=torch.float64, grad_fn=<SliceBackward0>)],
 [tensor([-0.5650], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([-0.5887], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([-0.6259], dtype=torch.float64, grad_fn=<SliceBackward0>)],
 [tensor([-0.5650], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([-0.5887], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([-0.6259], dtype=torch.float64, grad_fn=<SliceBackward0>)]]

In [None]:
buffer_pred(preds_raw_all.mean, input_all[1]).shape

torch.Size([3, 100, 9])

In [None]:
_extract_var(preds_raw_all.cov, 3, 9)

[[tensor([[0.1189]], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([[0.1717]], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([[0.1251]], dtype=torch.float64, grad_fn=<SliceBackward0>)],
 [tensor([[0.1189]], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([[0.1717]], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([[0.1251]], dtype=torch.float64, grad_fn=<SliceBackward0>)],
 [tensor([[0.1189]], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([[0.1717]], dtype=torch.float64, grad_fn=<SliceBackward0>),
  tensor([[0.1251]], dtype=torch.float64, grad_fn=<SliceBackward0>)]]

In [None]:
#| export
class PredictLossVar:
    """Loss (negative log likelihood) of a single variable, computed separately for each batch element.

    All variables other than `var` are marked as observed in the mask, and the
    predicted means/covariances are reduced to `var` with `_extract_var`.
    This means only `var` contributes to the loss.
    """
    def __init__(self, only_gap:bool, var: int):
        self.loss_func = KalmanLoss(only_gap)
        self.var = var

    def __call__(self, preds, targs):
        mask = targs[1]
        n_vars = mask.shape[-1]
        others = [j for j in range(n_vars) if j != self.var]
        var_mask = mask.clone()
        var_mask[:, :, others] = True # pretend every other variable is present
        targs_var = (targs[0], var_mask, targs[2])
        preds_var = (
            _extract_var(preds.mean, self.var, n_vars),
            _extract_var(preds.cov, self.var, n_vars),
        )
        n_batch = len(preds_var[0])
        return [self.loss_func(_n_tuple(preds_var, b), _n_tuple(targs_var, b)) for b in range(n_batch)]

In [None]:
PredictLossVar(True, 3)(preds_raw_all, input_all)

[tensor(-0.2252, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(-0.2252, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(-0.2252, dtype=torch.float64, grad_fn=<MeanBackward0>)]

In [None]:
input_all[1].shape

torch.Size([3, 100, 9])

In [None]:
#| export
class PredictLikelihoodVar:
    """Likelihood of a single variable (`var`), one value per batch element.

    Each value is `exp(-loss)` averaged with `.mean()`. What exactly is averaged
    (presumably time steps) depends on the shape returned by `KalmanLoss` with
    `reduction='none'`.
    """
    def __init__(self, only_gap:bool, var: int):
        # no reduction, so the loss is not aggregated inside `KalmanLoss`
        self.loss_func = KalmanLoss(only_gap, reduction_inbatch='none', reduction = 'none')
        self.var = var
    def __call__(self, preds, targs):
        sel_idx = [idx for idx in range(targs[1].shape[-1]) if idx != self.var]
        mask_new = targs[1].clone()
        mask_new[:, :, sel_idx] = True # make all other variables present
        targs_new = (targs[0], mask_new, targs[2])
        # reduce predicted means and covariances to the selected variable only
        preds_new = (_extract_var(preds.mean, self.var, targs[1].shape[-1]), _extract_var(preds.cov, self.var, targs[1].shape[-1]))
        likelihoods = []
        for i in range(len(preds_new[0])):
            # NOTE(review): the `[0]` here and `loss[0]` below depend on the output structure of
            # `KalmanLoss(reduction='none')` — confirm which axes are being selected
            loss = self.loss_func(_n_tuple(preds_new, i), _n_tuple(targs_new, i))[0]
            # the loss is a negative log likelihood, so exp(-loss) is the likelihood
            lh = (torch.exp(-loss[0])).mean()
            likelihoods.append(lh)
        return likelihoods

In [None]:
PredictLikelihoodVar(True, 3)(preds_raw_all, input_all)

[tensor(1.1569, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(1.1569, dtype=torch.float64, grad_fn=<MeanBackward0>),
 tensor(1.1569, dtype=torch.float64, grad_fn=<MeanBackward0>)]

In [None]:
#| export
class MultiMetrics():
    """Bundle several named metrics; calling it returns `{name: metric(preds, targs)}`."""
    def __init__(self, **metrics):
        self.metrics = metrics

    def __call__(self, preds, targs):
        results = {}
        for name, metric in self.metrics.items():
            results[name] = metric(preds, targs)
        return results

In [None]:
#| export
class KalmanImputation:
    """Gap filling with a single trained Kalman filter `model`."""
    name = "Kalman Filter"
    def __init__(self, model): store_attr()

    def __call__(self, item, dls):
        """Predict `item` and return the predicted means, with observed values copied back."""
        preds, targs = predict_items(self.model, dls=dls, items = [item])
        pred, targ = preds[0], targs[0]
        observed = targ.mask
        pred.mean[observed] = targ.data[observed]
        return pred.mean

    def preds_all(self, items, dls):
        """Raw `predict_items` output for all `items`."""
        return predict_items(self.model, dls=dls, items = items)

    def preds_all_loss(self, items, dls, var):
        """Predictions plus the per-item loss restricted to variable index `var`."""
        loss_fn = PredictLossVar(only_gap=self.model.pred_only_gap, var = var)
        return predict_items(self.model, dls=dls, items = items, metric_fn = loss_fn)

    def preds_all_metrics(self, items, dls, metrics):
        """Predictions plus arbitrary `metrics` passed to `predict_items` as `metric_fn`."""
        return predict_items(self.model, dls=dls, items = items, metric_fn = metrics)

In [None]:
k_imp = KalmanImputation(model)

In [None]:
pred = k_imp(item, dls=dls)

In [None]:
display_as_row({'pred': pred[45:55],'data': targ.data[45:55], 'mask': targ.mask[45:55]}, hide_idx=False)

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-04 02:30:00,2.3467,0.0,304.731,0.952,5.39,96.3,0.0,43.08,1.31
2000-01-04 03:00:00,2.1803,0.0,304.731,0.971,5.59,96.3,0.0,43.08,1.31
2000-01-04 03:30:00,1.8728,0.0,304.731,0.958,5.67,96.29,0.0,43.08,1.31
2000-01-04 04:00:00,1.4769,0.0,304.731,0.938,6.61,96.26,0.0,43.08,1.31
2000-01-04 04:30:00,1.471,0.0,315.005,0.785,6.65,96.26,0.0,43.08,1.31
2000-01-04 05:00:00,1.4303,0.0,315.005,0.529,6.12,96.27,0.0,43.08,1.3
2000-01-04 05:30:00,1.5573,0.0,315.005,0.45,5.58,96.26,0.29,43.13,1.3
2000-01-04 06:00:00,1.6616,0.0,315.005,0.499,4.7,96.25,0.0,43.13,1.29
2000-01-04 06:30:00,1.6438,0.0,315.005,0.382,5.07,96.23,0.0,43.13,1.29
2000-01-04 07:00:00,1.5697,0.0,315.005,0.313,5.23,96.24,0.0,43.13,1.29

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-04 02:30:00,2.19,0.0,304.731,0.952,5.39,96.3,0.0,43.08,1.31
2000-01-04 03:00:00,2.27,0.0,304.731,0.971,5.59,96.3,0.0,43.08,1.31
2000-01-04 03:30:00,2.32,0.0,304.731,0.958,5.67,96.29,0.0,43.08,1.31
2000-01-04 04:00:00,2.34,0.0,304.731,0.938,6.61,96.26,0.0,43.08,1.31
2000-01-04 04:30:00,2.24,0.0,315.005,0.785,6.65,96.26,0.0,43.08,1.31
2000-01-04 05:00:00,2.0,0.0,315.005,0.529,6.12,96.27,0.0,43.08,1.3
2000-01-04 05:30:00,1.94,0.0,315.005,0.45,5.58,96.26,0.29,43.13,1.3
2000-01-04 06:00:00,2.07,0.0,315.005,0.499,4.7,96.25,0.0,43.13,1.29
2000-01-04 06:30:00,2.04,0.0,315.005,0.382,5.07,96.23,0.0,43.13,1.29
2000-01-04 07:00:00,2.03,0.0,315.005,0.313,5.23,96.24,0.0,43.13,1.29

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-04 02:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 03:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 03:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 04:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 04:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 05:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 05:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 06:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 06:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 07:00:00,False,True,True,True,True,True,True,True,True


#### Kalman Imputation - Specialized Variables

In [None]:
#| export
class KalmanImputationVar:
    """Kalman filter gap filling that uses a dedicated model for each variable."""
    name = "Kalman Filter"
    def __init__(self, models # dataframe with 2 columns `model` and `var`
                ): store_attr()

    def _select_model(self, var):
        """First model registered for `var` in `self.models`."""
        candidates = self.models.loc[self.models['var'] == var, 'model']
        return candidates.iloc[0]

    def __call__(self, var, item, dls):
        """Predict `item` with the model for `var`; observed values are copied back into the means."""
        preds, targs = predict_items(self._select_model(var), dls=dls, items = [item])
        pred, targ = preds[0], targs[0]
        pred.mean[targ.mask] = targ.data[targ.mask]
        return pred.mean

    def preds_all(self, var:str, items:list, dls):
        """Raw `predict_items` output for `items`, using the model for `var`."""
        return predict_items(self._select_model(var), dls=dls, items = items)

In [None]:
k_impVar = KalmanImputationVar(pd.DataFrame({'model': [model], 'var': 'TA'}))

In [None]:
pred = k_impVar(var= 'TA', item=item, dls=dls)

In [None]:
display_as_row({'pred': pred[45:55],'data': targ.data[45:55], 'mask': targ.mask[45:55]}, hide_idx=False)

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-04 02:30:00,2.3467,0.0,304.731,0.952,5.39,96.3,0.0,43.08,1.31
2000-01-04 03:00:00,2.1803,0.0,304.731,0.971,5.59,96.3,0.0,43.08,1.31
2000-01-04 03:30:00,1.8728,0.0,304.731,0.958,5.67,96.29,0.0,43.08,1.31
2000-01-04 04:00:00,1.4769,0.0,304.731,0.938,6.61,96.26,0.0,43.08,1.31
2000-01-04 04:30:00,1.471,0.0,315.005,0.785,6.65,96.26,0.0,43.08,1.31
2000-01-04 05:00:00,1.4303,0.0,315.005,0.529,6.12,96.27,0.0,43.08,1.3
2000-01-04 05:30:00,1.5573,0.0,315.005,0.45,5.58,96.26,0.29,43.13,1.3
2000-01-04 06:00:00,1.6616,0.0,315.005,0.499,4.7,96.25,0.0,43.13,1.29
2000-01-04 06:30:00,1.6438,0.0,315.005,0.382,5.07,96.23,0.0,43.13,1.29
2000-01-04 07:00:00,1.5697,0.0,315.005,0.313,5.23,96.24,0.0,43.13,1.29

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-04 02:30:00,2.19,0.0,304.731,0.952,5.39,96.3,0.0,43.08,1.31
2000-01-04 03:00:00,2.27,0.0,304.731,0.971,5.59,96.3,0.0,43.08,1.31
2000-01-04 03:30:00,2.32,0.0,304.731,0.958,5.67,96.29,0.0,43.08,1.31
2000-01-04 04:00:00,2.34,0.0,304.731,0.938,6.61,96.26,0.0,43.08,1.31
2000-01-04 04:30:00,2.24,0.0,315.005,0.785,6.65,96.26,0.0,43.08,1.31
2000-01-04 05:00:00,2.0,0.0,315.005,0.529,6.12,96.27,0.0,43.08,1.3
2000-01-04 05:30:00,1.94,0.0,315.005,0.45,5.58,96.26,0.29,43.13,1.3
2000-01-04 06:00:00,2.07,0.0,315.005,0.499,4.7,96.25,0.0,43.13,1.29
2000-01-04 06:30:00,2.04,0.0,315.005,0.382,5.07,96.23,0.0,43.13,1.29
2000-01-04 07:00:00,2.03,0.0,315.005,0.313,5.23,96.24,0.0,43.13,1.29

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-04 02:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 03:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 03:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 04:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 04:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 05:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 05:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 06:00:00,False,True,True,True,True,True,True,True,True
2000-01-04 06:30:00,False,True,True,True,True,True,True,True,True
2000-01-04 07:00:00,False,True,True,True,True,True,True,True,True


### MDS

MDS is implemented in R (REddyProc), so we need to call R from Python via `rpy2`.

In [None]:
#| export
import rpy2.robjects
from rpy2.robjects import r
import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri

In [None]:
r.print('here::here("R/REddyProc_tools.R")')

[1] "here::here(\"R/REddyProc_tools.R\")"


0
"'here::here(""R/REddyProc_tools.R"")'"


In [None]:
#| export
# absolute path of the project's R helper script (presumably defines `fill_gaps_EProc`, used below)
path_r = str(here("R/REddyProc_tools.R"))

r.source(path_r); # trailing `;` only suppresses the notebook output

In [None]:
#| export
def importr_install(pkg):
    """Import the R package `pkg` via rpy2; if that fails, install it from CRAN and import again.

    A failed installation surfaces as the exception raised by the final `importr` call.
    """
    try:
        importr(pkg)
    except Exception:  # not a bare `except:`, so KeyboardInterrupt / SystemExit still propagate
        utils = importr('utils')
        utils.chooseCRANmirror(ind=1)  # pick the first CRAN mirror without an interactive prompt
        utils.install_packages(pkg)
        importr(pkg)

In [None]:
#| export
def setupR():
    """Prepare the embedded R session for MDS gap filling.

    Loads `tidyverse` and `REddyProc` (installing them if importing fails) and
    `lubridate`. Sources the project's `R/REddyProc_tools.R` helpers and defines
    the R function `toR_timestamp`.
    """
    importr_install('tidyverse')
    importr_install('REddyProc')
    importr('lubridate') # NOTE(review): loaded without `importr_install`; presumably installed with tidyverse — confirm
    path_r = str(here("R/REddyProc_tools.R"))
    r.source(path_r) # R functions
    # `toR_timestamp` turns the `TIMESTAMP_END` column (sent from Python as strings) back into R datetimes
    r("""toR_timestamp <- function(x){
   x$TIMESTAMP_END = as_datetime(x$TIMESTAMP_END) 
    x
     }""")

In [None]:
#| export
def pd2R(x):
    """Convert a pandas object into the corresponding R object using rpy2's pandas converter."""
    py_to_r = ro.default_converter + pandas2ri.converter
    with py_to_r:
        return ro.conversion.py2rpy(x)

def R2pd(x):
    """Convert an R object back into pandas using rpy2's pandas converter."""
    r_to_py = ro.default_converter + pandas2ri.converter
    with r_to_py:
        return ro.conversion.rpy2py(x)

In [None]:
#| export
setupR() # load the R dependencies and helper functions

#### Experiment

The time series needs to be at least 90 days long, so we add 45 days before and after the gap.

Converting timestamps directly between R and Python is problematic, so we convert them to strings in Python and back to datetimes in R.

In [None]:
#| export
def add_buffer(index, inner_index, n):
    """Return a contiguous slice of `index` around `inner_index`, padded by about `n // 2` entries per side.

    If the padding would run past either end of `index`, the part that does not fit
    is moved to the other side, so the window keeps (roughly) the same length.

    Args:
        index: full index to slice (e.g. the index of the whole dataset).
        inner_index: run of `index` that must be kept; only its first and last
            labels are looked up, so it is assumed to be contiguous in `index`.
        n: total number of buffer entries to add around `inner_index`.

    Returns:
        The slice of `index`. Its length is asserted to be greater than `n`.
    """
    first = int(np.argwhere(index == inner_index[0]).squeeze())
    last = int(np.argwhere(index == inner_index[-1]).squeeze())
    start = first - n // 2
    # exclusive end; never cut off the last element of `inner_index` (happens when n // 2 == 0)
    end = max(last + n // 2, last + 1)
    if start < 0:  # not enough data before: move the surplus after
        end += -start
        start = 0
    if end > len(index):  # not enough data after: move the surplus before, without going below 0
        start = max(start - (end - len(index)), 0)
        end = len(index)

    window = index[start:end]

    assert len(window) > n, f"index too short: got {len(window)} entries, need more than {n}"

    return window

In [None]:
add_buffer(hai.index, hai.index[:100], 50)
add_buffer(hai.index, hai.index[-50:], 50);

In [None]:
#| export
def item2REddy(item, var, df):
    """Build the REddyProc input: rows of `df` around `item` plus a `gap` flag column.

    `gap` is 1 where `item.mask[var]` is False (missing) and 0 elsewhere, including
    the added context rows. The index (named `time`) becomes a string `TIMESTAMP_END`
    column, because timestamps are parsed again on the R side.
    """
    # about 90 days of context: 90 days * 24 h * 2 half-hourly observations
    context_index = add_buffer(df.index, item.data.index, 90 * 24 * 2)
    gap_flag = (~item.mask[var]).astype(int)
    # `assign` aligns on the index, so context rows get NaN, then 0
    REddy_df = df.loc[context_index].assign(gap = gap_flag).fillna({'gap': 0})
    as_table = REddy_df.reset_index()
    as_table = as_table.astype({'time': str})
    return as_table.rename(columns={'time': 'TIMESTAMP_END'})

In [None]:
# build the REddyProc input around `targ`, convert it to R and run the MDS gap filling for TA
REddy_df = item2REddy(targ, 'TA', hai)

REddy_df_r = r.toR_timestamp(pd2R(REddy_df))

filled = R2pd(r.fill_gaps_EProc(REddy_df_r, "TA"))

R[write to console]: New sEddyProc class for site 'ID'

R[write to console]: Initialized variable 'TA' with 10 real gaps for gap filling.

R[write to console]: Limited MDS algorithm for gap filling of 'TA.gap_0' with LUT(SW_IN only) and MDC.

R[write to console]: Look up table with window size of 7 days with SW_IN

R[write to console]: 10

R[write to console]: Finished gap filling of 'TA' in 0 seconds. Artificial gaps filled: 4419, real gaps filled: 10, unfilled (long) gaps: 0.



In [None]:
filled

Unnamed: 0,TA_orig,TA_f,TA_fqc,TA_fall,TA_fall_qc,TA_fnum,TA_fsd,TA_fmeth,TA_fwin
1,-0.60,-0.60,0.0,-0.60,,,,,
2,-0.65,-0.65,0.0,-0.65,,,,,
3,-0.58,-0.58,0.0,-0.58,,,,,
4,-0.51,-0.51,0.0,-0.51,,,,,
5,-0.49,-0.49,0.0,-0.49,,,,,
...,...,...,...,...,...,...,...,...,...
4415,3.30,3.30,0.0,3.30,,,,,
4416,3.10,3.10,0.0,3.10,,,,,
4417,3.05,3.05,0.0,3.05,,,,,
4418,3.05,3.05,0.0,3.05,,,,,


In [None]:
REddy_df.set_index("TIMESTAMP_END") 

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS,gap
TIMESTAMP_END,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
2000-01-01 00:30:00,-0.60,0.0,302.475,0.222,2.05,96.63,0.0,43.00,1.44,0.0
2000-01-01 01:00:00,-0.65,0.0,302.475,0.122,2.53,96.58,0.0,43.00,1.43,0.0
2000-01-01 01:30:00,-0.58,0.0,301.677,0.090,3.15,96.56,0.0,43.00,1.43,0.0
2000-01-01 02:00:00,-0.51,0.0,301.677,0.110,3.12,96.56,0.0,43.00,1.45,0.0
2000-01-01 02:30:00,-0.49,0.0,301.677,0.102,3.04,96.57,0.0,43.00,1.44,0.0
...,...,...,...,...,...,...,...,...,...,...
2000-04-01 23:30:00,3.30,0.0,283.856,1.092,2.50,95.06,0.0,44.96,4.32,0.0
2000-04-02 00:00:00,3.10,0.0,283.856,1.013,2.09,95.04,0.0,44.96,4.22,0.0
2000-04-02 00:30:00,3.05,0.0,283.856,0.964,2.42,95.04,0.0,44.96,4.13,0.0
2000-04-02 01:00:00,3.05,0.0,283.856,0.933,3.04,95.04,0.0,44.96,4.04,0.0


In [None]:
# index the MDS output by the original timestamps and keep only the item's rows
filled = filled.set_index(pd.to_datetime(REddy_df.TIMESTAMP_END))

filled_item = filled.loc[targ.data.index]

# copy the item and overwrite only the missing TA values with the gap-filled `TA_f`
pred = targ.data.copy()
pred.loc[~targ.mask["TA"], "TA"] = filled_item["TA_f"][~targ.mask["TA"]]

In [None]:
pred[45:55], targ.data[45:55]

(                           TA  SW_IN    LW_IN    VPD    WS     PA     P  \
 time                                                                      
 2000-01-04 02:30:00  1.853612    0.0  304.731  0.952  5.39  96.30  0.00   
 2000-01-04 03:00:00  1.837083    0.0  304.731  0.971  5.59  96.30  0.00   
 2000-01-04 03:30:00  1.823985    0.0  304.731  0.958  5.67  96.29  0.00   
 2000-01-04 04:00:00  1.809805    0.0  304.731  0.938  6.61  96.26  0.00   
 2000-01-04 04:30:00  1.794282    0.0  315.005  0.785  6.65  96.26  0.00   
 2000-01-04 05:00:00  1.783447    0.0  315.005  0.529  6.12  96.27  0.00   
 2000-01-04 05:30:00  1.771525    0.0  315.005  0.450  5.58  96.26  0.29   
 2000-01-04 06:00:00  1.760362    0.0  315.005  0.499  4.70  96.25  0.00   
 2000-01-04 06:30:00  1.748410    0.0  315.005  0.382  5.07  96.23  0.00   
 2000-01-04 07:00:00  1.737933    0.0  315.005  0.313  5.23  96.24  0.00   
 
                        SWC    TS  
 time                              
 2000-01-04 02

In [None]:
#| export
def gap_fill_item(item, REddy_df, var, filled):
    """Return a copy of `item.data` where the missing values of `var` come from the REddyProc output.

    `filled` is the REddyProc result, row-aligned with `REddy_df`. Its `{var}_f`
    column holds the gap-filled series.
    """
    timestamps = pd.to_datetime(REddy_df.TIMESTAMP_END)
    filled_on_item = filled.set_index(timestamps).loc[item.data.index]
    is_gap = ~item.mask[var]

    pred = item.data.copy()
    pred.loc[is_gap, var] = filled_on_item[f"{var}_f"][is_gap]
    return pred

In [None]:
gap_fill_item(targ, REddy_df, "TA", filled)


Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-03 04:00:00,0.81,0.0,304.148,0.000,3.53,97.00,0.0,43.09,1.37
2000-01-03 04:30:00,0.95,0.0,304.382,0.000,3.99,96.98,0.0,43.09,1.37
2000-01-03 05:00:00,1.09,0.0,304.382,0.000,3.61,96.96,0.0,43.09,1.37
2000-01-03 05:30:00,1.18,0.0,304.382,0.009,3.90,96.93,0.0,43.09,1.37
2000-01-03 06:00:00,1.35,0.0,304.382,0.061,4.17,96.91,0.0,43.09,1.38
...,...,...,...,...,...,...,...,...,...
2000-01-05 03:30:00,4.62,0.0,330.202,1.162,6.53,95.91,0.0,44.12,1.82
2000-01-05 04:00:00,4.51,0.0,330.202,1.636,5.76,95.94,0.0,44.12,1.85
2000-01-05 04:30:00,4.11,0.0,299.320,1.746,5.79,95.98,0.0,44.12,1.85
2000-01-05 05:00:00,3.77,0.0,299.320,2.065,6.00,96.03,0.0,44.12,1.83


#### MDSImputation

In [None]:
#| export
class MDSImputation:
    """Gap filling of `var` with REddyProc's MDS, using `df` as the context around each item."""
    name = "MDS"
    def __init__(self, var, df):
        store_attr()
        self.out = io.StringIO() # collects whatever R writes to stderr

    def __call__(self, item):
        var = self.var
        REddy_df = item2REddy(item, var, self.df)
        r_input = r.toR_timestamp(pd2R(REddy_df))
        with redirect_stderr(self.out):
            r_output = r.fill_gaps_EProc(r_input, var)
            filled = R2pd(r_output)
        return gap_fill_item(item, REddy_df, var, filled)

In [None]:
mds_imp = MDSImputation('TA', hai)

In [None]:
mds_imp(targ)

Unnamed: 0_level_0,TA,SW_IN,LW_IN,VPD,WS,PA,P,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-03 04:00:00,0.81,0.0,304.148,0.000,3.53,97.00,0.0,43.09,1.37
2000-01-03 04:30:00,0.95,0.0,304.382,0.000,3.99,96.98,0.0,43.09,1.37
2000-01-03 05:00:00,1.09,0.0,304.382,0.000,3.61,96.96,0.0,43.09,1.37
2000-01-03 05:30:00,1.18,0.0,304.382,0.009,3.90,96.93,0.0,43.09,1.37
2000-01-03 06:00:00,1.35,0.0,304.382,0.061,4.17,96.91,0.0,43.09,1.38
...,...,...,...,...,...,...,...,...,...
2000-01-05 03:30:00,4.62,0.0,330.202,1.162,6.53,95.91,0.0,44.12,1.82
2000-01-05 04:00:00,4.51,0.0,330.202,1.636,5.76,95.94,0.0,44.12,1.85
2000-01-05 04:30:00,4.11,0.0,299.320,1.746,5.79,95.98,0.0,44.12,1.85
2000-01-05 05:00:00,3.77,0.0,299.320,2.065,6.00,96.03,0.0,44.12,1.83


### ERA Imputation

In [None]:
#| export
class ERAImputation:
    """Use the ERA-I control variables (without their lags) as the prediction."""
    name = "ERA-I"
    def __call__(self, item):
        control = item.control.copy()
        current = [c for c in control.columns if not c.endswith("_lag_1")]
        pred = control.filter(current).rename(columns=lambda c: c.replace("_ERA", ""))
        # variables without an ERA counterpart cannot be predicted: fill them with NaN
        missing = [c for c in item.data.columns if c not in pred.columns]
        for c in missing:
            pred[c] = np.nan
        return pred

In [None]:
targ.data.columns

Index(['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], dtype='object')

In [None]:
era_imp = ERAImputation()

In [None]:
era_imp(targ)

Unnamed: 0_level_0,TA,SW_IN,VPD,PA,P,WS,LW_IN,SWC,TS
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
2000-01-03 04:00:00,1.702,0.0,0.693,97.016,0.000,2.948,304.148,,
2000-01-03 04:30:00,1.762,0.0,0.691,97.002,0.026,2.989,304.382,,
2000-01-03 05:00:00,1.736,0.0,0.697,96.988,0.000,3.015,304.382,,
2000-01-03 05:30:00,1.711,0.0,0.703,96.974,0.000,3.041,304.382,,
2000-01-03 06:00:00,1.686,0.0,0.709,96.960,0.000,3.066,304.382,,
...,...,...,...,...,...,...,...,...,...
2000-01-05 03:30:00,4.945,0.0,0.774,95.910,0.252,6.107,330.202,,
2000-01-05 04:00:00,4.789,0.0,0.758,95.926,0.252,6.189,330.202,,
2000-01-05 04:30:00,4.632,0.0,0.741,95.942,0.044,6.270,299.320,,
2000-01-05 05:00:00,4.251,0.0,0.766,95.988,0.000,5.923,299.320,,


## Metrics

In [None]:
#| export
class MaskedMetric:
    """Apply `metric(data, pred)` only on the block of rows/columns that contain missing values.

    Rows and columns are selected where `targ.mask` has at least one False. Every
    cell of that block must be missing (the gap is a rectangle); otherwise observed
    values would enter the metric. If the prediction is all NaN in the block
    (e.g. a variable a method cannot predict), returns `np.array([np.nan])`.
    """
    def __init__(self, metric): store_attr()
    def __call__(self, targ, pred):
        if isinstance(pred, NormalsDf): pred = pred.mean
        row_sel, col_sel = ~targ.mask.all(1), ~targ.mask.all(0)
        block = targ.mask.loc[row_sel, col_sel]
        assert block.size > 0, "no gap found in target"
        # every cell of the gap block must be missing; `.any().all()` missed mixed blocks
        assert not block.to_numpy().any(), "gap is not a rectangle"
        data, pred = targ.data.loc[row_sel,col_sel], pred.loc[row_sel,col_sel]
        return self.metric(data, pred) if not np.isnan(pred).all().all() else np.array([np.nan])

In [None]:
#| export
from sklearn.metrics import mean_squared_error

In [None]:
#| export
def rmse(y_true, y_pred):
    """Root mean squared error, one value per output column."""
    per_column_mse = mean_squared_error(y_true, y_pred, multioutput='raw_values')
    return np.sqrt(per_column_mse)

In [None]:
#| export
def normalize(x, mean, std):
    """Standardize `x`: subtract `mean`, then divide by `std`."""
    centered = x - mean
    return centered / std
class NormalizedMetric:
    """Wrap a metric so target and prediction are standardized with `mean`/`std` first.

    If `var` is given, only `mean[var]` and `std[var]` are used.
    """
    def __init__(self, metric: MaskedMetric, mean, std):
        self.mean, self.std = np.array(mean), np.array(std)
        self.metric = metric
    def __call__(self, targ, pred, var:int|None = None):
        if var is None:
            mean, std = self.mean, self.std
        else:
            mean, std = self.mean[var], self.std[var]
        targ = targ.copy()
        targ.data = normalize(targ.data, mean, std)
        if isinstance(pred, NormalsDf):
            pred = pred.mean
        return self.metric(targ, normalize(pred, mean, std))

In [None]:
#| export
rmse_mask = MaskedMetric(rmse) # per-column RMSE restricted to the gap rows/columns

In [None]:
rmse_mask(targ, pred)

array([0.3698457])

### Validation Imputations

Check that the imputed values make sense.

This is the RMSE of ERA for `TA` over the whole dataset:

In [None]:
np.sqrt((hai['TA'] - hai_era.loc[hai.index]['TA_ERA']).pow(2).mean())

1.8708585984196093

In [None]:
rmse_mask(targ, ERAImputation()(targ))

array([0.29919709])

In [None]:
# RMSE of ERA-I on 100 randomly drawn items (drawn with replacement; overwrites `items` and `targ`)
err_era = []
for _ in range(100):
    items = random.choices(dls.items, k = 1)
    targ = orig_target(dls, items)[0]
    err_era.append(rmse_mask(targ, ERAImputation()(targ)))

In [None]:
np.array(err_era).mean()

1.3839633283333717

## Comparison

Take a variable to be filled, make an artificial gap of a given length, fill it with the 3 methods and return the metrics for each of them; repeat this `n_rep` times.

In [None]:
#| export
import random
import polars as pl
from tqdm.auto import tqdm

### prep visualization

In [None]:
#| export
def format_gap_len(
    gap_len: int # gap length in num observations (30 mins)
):
    """Nice formatting for gap lengths.

    Hours are shown exactly ("2.5 h" for 5 observations). This keeps distinct gap
    lengths on distinct labels; `_as_category` builds a `CategoricalDtype` from
    these labels, and categories must be unique. Day and week counts are rounded.
    """
    gap_h = gap_len / 2 # to hours
    # "5" for whole hours, "2.5" for half hours; `round` would give "0 h" for a single observation
    hours = str(int(gap_h)) if float(gap_h).is_integer() else str(gap_h)
    if gap_h < 24:
        return f"{hours} h"
    elif gap_h < 24 * 7: # days
        gap_d = round(gap_h / 24)
        label = "days" if gap_d > 1 else "day"
        return f"{gap_d} {label} ({hours} h)"
    else: # weeks
        gap_w = round(gap_h / (24 * 7))
        label = "weeks" if gap_w > 1 else "week"
        return f"{gap_w} {label} ({hours} h)"

In [None]:
# expected labels for hours, a single day, several days and a week
test_eq(format_gap_len(10), "5 h")
test_eq(format_gap_len(48), "1 day (24 h)")
test_eq(format_gap_len(90), "2 days (45 h)")
test_eq(format_gap_len(336), "1 week (168 h)")

In [None]:
#| export
# ordered categories: fix the display/sort order of variables (used by `_as_category`)
var_type = CategoricalDtype(categories=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], ordered=True)

# ordered categories for the imputation methods (matches the `name` attribute of each imputation class)
method_type = CategoricalDtype(categories=["Kalman Filter", "ERA-I", "MDS"], ordered=True)

In [None]:
#| export
def _as_category(df: pd.DataFrame):
    """Convert `var`, `gap_len_f` (and `method`, if present) to ordered categoricals and sort the rows."""
    gap_labels = [format_gap_len(g) for g in np.sort(df.gap_len.unique())]
    gap_len_type = CategoricalDtype(categories = gap_labels, ordered=True)
    df = df.assign(
        var = df['var'].astype(var_type),
        gap_len_f = df['gap_len_f'].astype(gap_len_type))
    has_method = 'method' in df.columns
    if has_method:
        df = df.assign(method = df['method'].astype(method_type))
    sort_cols = ['var', 'gap_len', 'method'] if has_method else ['var', 'gap_len']
    return df.sort_values(sort_cols)

In [None]:
#| export
def prep_df(df):
    """Add the `gap_len_f` label, make categoricals, then express `gap_len` in hours."""
    labelled = _as_category(df.assign(gap_len_f = df.gap_len.apply(format_gap_len)))
    # convert observations to hours only after the categories were built from the raw lengths
    return labelled.assign(gap_len = labelled.gap_len // 2)

### Aggregation

In [None]:
hai.std(axis=0)['TA']

7.924610764367695

In [None]:
get_stats(hai)

(tensor([8.3339e+00, 1.2096e+02, 3.1150e+02, 3.3807e+00, 3.1800e+00, 9.5962e+01,
         4.3427e-02, 3.4843e+01, 7.9349e+00], dtype=torch.float64),
 tensor([  7.9246, 204.0026,  41.9557,   4.3684,   1.6254,   0.8552,   0.2803,
           8.9131,   5.6586], dtype=torch.float64))

$$ \text{RMSE}_\text{stand} = \frac{1}{\sigma}\text{RMSE}$$

In [None]:
#| export
def timeseriesAgg(targ, pred, *args):
    """Keep the raw target and prediction, each wrapped in a one-element list."""
    return dict(pred=[pred], targ=[targ])

class RMSEAgg:
    """Aggregate to rmse and to rmse divided by the standard deviation of the variable in `df`."""
    def __init__(self, df):
        self.std = df.std(axis=0)
    def __call__(self, targ, pred, var):
        err = rmse_mask(targ, pred).item()
        return {'rmse': err, 'rmse_stand': err / self.std[var]}

### Imp Methods

In [None]:
#| export
class ImpComparison():
    """Compare Kalman Filter, ERA-I and MDS imputation on the same random gaps"""
    def __init__(self, models: pd.DataFrame, df, control, block_len, rmse=True, time_series = False):
        store_attr()
        self.k_imp = KalmanImputationVar(models)
        self.era_imp = ERAImputation()
        self.mds_imp = MDSImputation("", df)
        # the order of the methods is the order of the result rows within each repetition
        self.methods = [self.k_imp, self.era_imp, self.mds_imp]
        self.aggrs = ([RMSEAgg(df)] if rmse else []) + ([timeseriesAgg] if time_series else [])

    def _compare_single(self, gap_len, var, n_rep):
        """Compares `n_rep` times the imputation methods, for gap in `var` with len `gap_len`"""
        dls = imp_dataloader(self.df, self.control, var_sel = var, gap_len=gap_len, block_len=self.block_len, control_lags = [1], n_rep=1, bs=1).cpu()
        # point the MDS imputer to the variable with the gap
        self.mds_imp.var = var

        rows = []
        for rep in tqdm(range(n_rep), leave=False):
            item = random.choice(dls.items)
            preds_kalman, targs = self.k_imp.preds_all(var = var, items = [item], dls=dls)
            pred_kalman, targ = preds_kalman[0], targs[0]
            for method in self.methods:
                # the Kalman prediction has already been computed above: reuse it
                pred = pred_kalman if method is self.k_imp else method(targ)
                row = {'method': method.name, 'var': var, 'gap_len': gap_len, 'idx_rep': rep}
                for aggr in self.aggrs:
                    row |= aggr(targ, pred, var)
                rows.append(row)
        return pd.DataFrame(rows)

    def compare(self, gap_len, var, n_rep, raw=False):
        """Compare imputation performance for all combination of parameters.

        With `raw=False` the results are prepared with `prep_df` (categoricals, gap length in hours)."""
        combos = list(product_dict(gap_len=tuplify(gap_len), var=tuplify(var)))
        results = pd.concat([self._compare_single(**combo, n_rep=n_rep) for combo in tqdm(combos)])
        return results if raw else prep_df(results)
        return prep_df(pd.concat(out)) if not raw else pd.concat(out)
        

In [None]:
#| export
base_path = here("analysis/results/trained_models")

def l_model(x, base_path=base_path):
    """Load the trained model saved in the file `x` inside `base_path`"""
    # NOTE: `torch.load` unpickles the file, only use it on trusted model files
    model_file = base_path / x
    return torch.load(model_file)

In [None]:
models_var = pd.DataFrame.from_records([
    {'var': 'TA',    'model': l_model("TA_specialized_gap_6-336_v3_0.pickle",base_path)},
    {'var': 'SW_IN', 'model': l_model("SW_IN_specialized_gap_6-336_v2_0.pickle",base_path)},
    {'var': 'LW_IN', 'model': l_model("LW_IN_specialized_gap_6-336_v1.pickle",base_path)},
    {'var': 'VPD',   'model': l_model("VPD_specialized_gap_6-336_v2_0.pickle",base_path)},
    {'var': 'WS',    'model': l_model("WS_specialized_gap_6-336_v1.pickle",base_path)},
    {'var': 'PA',    'model': l_model("PA_specialized_gap_6-336_v3_0.pickle",base_path)},
    {'var': 'P',     'model': l_model("1_gap_varying_6-336_v3.pickle",base_path)},
    {'var': 'TS',    'model': l_model("TS_specialized_gap_6-336_v2_0.pickle",base_path)},
    {'var': 'SWC',   'model': l_model("SWC_specialized_gap_6-336_v2_1.pickle",base_path)},
])

In [None]:
comp = ImpComparison(models = models_var, df = hai, control = hai_era, block_len = 446)

In [None]:
data_results = comp.compare(gap_len = [12, 24, 48, 336], var=["TA", "SW_IN"], n_rep=3) 

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
data_results.head()

Unnamed: 0,method,var,gap_len,idx_rep,rmse,rmse_stand,gap_len_f
0,Kalman Filter,TA,6,0,0.310825,0.039223,6 h
3,Kalman Filter,TA,6,1,0.52928,0.066789,6 h
6,Kalman Filter,TA,6,2,0.330725,0.041734,6 h
1,ERA-I,TA,6,0,1.329099,0.167718,6 h
4,ERA-I,TA,6,1,1.468124,0.185261,6 h


In [None]:
data_results.method.unique()

['Kalman Filter', 'ERA-I', 'MDS']
Categories (3, object): ['Kalman Filter' < 'ERA-I' < 'MDS']

In [None]:
data_results.columns

Index(['method', 'var', 'gap_len', 'idx_rep', 'rmse', 'rmse_stand',
       'gap_len_f'],
      dtype='object')

### Kalman Filters

In [None]:
#| export
class KalmanImpComparison():
    """Compare different Kalman filters on the same random gaps.

    `models` has one row per model, with a `model` column and optionally:
    - `var`: only the models whose `var` is the variable with the gap are used
    - `gap_single_var`: if False, the gap is put in all the variables, not only in `var`
    All the other columns of `models` are copied to the results, to identify the model.
    """
    def __init__(self,
                 models: pd.DataFrame, # model, gap_single_var, 
                 df: pd.DataFrame,
                 control: pd.DataFrame,
                 block_len:int,
                 rmse:bool=True,
                 time_series:bool = False):
        store_attr()
        self.imps = models.assign(imp = models.model.map(lambda model: KalmanImputation(model)))
        self.columns = list(df.columns)
        aggrs = []
        if rmse: aggrs.append(RMSEAgg(df))
        if time_series: aggrs.append(timeseriesAgg) 
        self.aggrs = aggrs

    def _items_for(self, imp: pd.Series, items_orig: list) -> list:
        """Items predicted by `imp`: the original ones, or, if `imp.gap_single_var` is False, the same items with the gap in all the variables"""
        if imp.get('gap_single_var', True):
            return items_orig.copy()
        return [MeteoImpItem(i = it.i, shift = it.shift, var_sel = self.columns, gap_len = it.gap_len) for it in items_orig]

    def _compare_single(self, n_rep:int, gap_len:int, var:str):
        """Compares `n_rep` times the imputation methods, for gap in `var` with len `gap_len`"""
        # only the models for `var`, if the models table has a `var` column
        imps = self.imps[self.imps['var'] == var] if 'var' in self.imps.columns else self.imps

        dls = imp_dataloader(self.df, self.control,
                             var_sel = var,
                             gap_len=gap_len, block_len=self.block_len, control_lags = [1], n_rep=1, bs=1,
                             shifts = gen_shifts(50),
                            ).cpu()

        # the items are drawn once, so that all the models are compared on the same gaps
        items_orig = random.choices(dls.items, k=n_rep)
        # position of `var` in the columns of `df`, the same for all the models
        var_idx = _index_var(self.df, [var])[0]

        outs = []
        for _, imp in imps.iterrows():
            items = self._items_for(imp, items_orig)
            metrics_fn = MultiMetrics(loss = PredictLossVar(only_gap=True, var = var_idx), likelihood = PredictLikelihoodVar(only_gap=True, var = var_idx) )

            for i, item in enumerate(items):
                pred, targ, metric = imp.imp.preds_all_metrics(items = [item], dls=dls, metrics=metrics_fn)
                pred, targ = pred[0], targ[0]
                # keep only the variable with the gap
                pred = pred.mean.iloc[:, [var_idx]]
                # NOTE(review): `control` has its own columns (ERA variables and their lags), so the position of
                # `var` in `df` may not select the matching control column; kept as in the original code
                targ = MeteoImpDf(targ.data.iloc[:, [var_idx]], targ.mask.iloc[:, [var_idx]], targ.control.iloc[:, [var_idx]])
                out = {
                    'var': var,
                    'loss': metric['loss'][0].item(),
                    'likelihood': metric['likelihood'][0].item(),
                    'gap_len': gap_len,
                    'idx_rep': i,
                } | imp.drop(index=["model", "imp"]).to_dict()
                for aggr in self.aggrs:
                    out = out | aggr(targ, pred, var)
                outs.append(out)
        return pd.DataFrame(outs)

    def compare(self, n_rep:int, gap_len:int|list[int], var:str|list[str], raw:bool=False):
        """Compare imputation performance for all combination of parameters.

        With `raw=False` (default) the results are prepared with `prep_df`."""
        arg_sets = list(product_dict(gap_len=tuplify(gap_len), var=tuplify(var)))
        results = pd.concat([self._compare_single(**arg_set, n_rep=n_rep) for arg_set in tqdm(arg_sets)])
        return results if raw else prep_df(results)
        

In [None]:
models_nc = pd.DataFrame({'model': [ l_model("1_gap_varying_336_no_control_v1.pickle"), l_model("1_gap_varying_6-336_v3.pickle")],
                          'type':   [ 'No Control',                                       'Use Control'                         ]})                                        

In [None]:
models_nc

Unnamed: 0,model,type
0,"Kalman Filter\n N dim obs: 9,\n N dim state: 18,\n N dim contr: 14",No Control
1,"Kalman Filter\n N dim obs: 9,\n N dim state: 18,\n N dim contr: 14",Use Control


In [None]:
kcomp = KalmanImpComparison(models_nc, hai, hai_era, 100)

In [None]:
#| export
from meteo_imp.kalman.training import _index_var

In [None]:
k_results = kcomp.compare(n_rep =3, gap_len = [6, 12, 24], var = list(hai.columns))

  0%|          | 0/27 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
k_results

## Plotting

### Utils

In [None]:
#| export
import altair as alt

#### Faceting

In [None]:
#| export
def facet_wrap(data: pd.DataFrame, # full dataset
               plot_fn,# function that makes the plot, takes 2 arguments: data and y_label
               col: str, # column to facet
               y_labels: list[str]|str|None =None, # custom labels y axis, one per facet (a single string is used for all)
               n_cols=3, # number of plots per row
               y_resolve='independent' # currently unused: the y scales are not explicitly resolved
              ):
    """One plot per unique value of `col`, made with `plot_fn` and arranged in rows of `n_cols` plots.

    Each plot is titled with its value of `col`; its y label is the matching entry of `y_labels`
    or, if `y_labels` is None, the value of `col`.
    """
    col_vals = data[col].unique()
    # a bare string would otherwise be indexed character by character
    if isinstance(y_labels, str): y_labels = [y_labels] * len(col_vals)
    plot_list = [alt.hconcat() for _ in range(0, len(col_vals), n_cols)]
    for i, col_v in enumerate(col_vals):
        y_label = y_labels[i] if y_labels is not None else col_v
        plot = plot_fn(data[data[col]==col_v].copy(), y_label).properties(title=str(col_v))
        plot_list[i // n_cols] |= plot
    return alt.vconcat(*plot_list).resolve_scale(xOffset='independent')
    

In [None]:
#| export
def facet_grid(data: pd.DataFrame, # full dataset
               plot_fn,# function that makes the plot, takes 2 arguments: data and y_label
               col: str, # column to facet,
               row: str,
               y_labels: list[str]|None = None, # custom labels y axis, one per value of `row`
              ):
    """Grid of plots: one row per unique value of `row`, one column per unique value of `col`.

    All the plots in a row share the y label of that row; rows without a label use the `col` value.
    """
    row_vals = data[row].unique()
    n_cols = len(data[col].unique())
    labels = listify(y_labels)
    plots = []
    for i, row_val in enumerate(row_vals):
        y_label = labels[i] if i < len(labels) else None
        row_data = data[data[row]==row_val].copy()
        # `facet_wrap` wants one label per plot: repeat the row label instead of passing the bare
        # string, which would be indexed character by character
        col_labels = None if y_label is None else [y_label] * len(row_data[col].unique())
        plot = facet_wrap(row_data, plot_fn, col, col_labels, n_cols=n_cols).properties(title=row_val)
        plots.append(plot)
    return alt.vconcat(*plots)

In [None]:
from itertools import product

In [None]:
test_data = pd.DataFrame(list(product(['0','1'], ['a', 'b'])), columns = ['row', 'col'])
test_data['text'] = test_data.row + test_data.col

In [None]:
def test_plot(data, *args):
    """Minimal plot to try the faceting functions: shows the `text` column as text"""
    chart = alt.Chart(data)
    return chart.mark_text().encode(text='text')

In [None]:
facet_wrap(test_data, test_plot, col = 'row')

In [None]:
facet_grid(test_data, test_plot, col = 'col', row='row')

#### Format

In [None]:
#| export
# color scales shared by all plots, so that each method/variable always has the same color
method_scale = alt.Scale(
    domain=["Kalman Filter", "ERA-I", "MDS"],
    scheme='dark2',
)
meteo_scale = alt.Scale(
    domain=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'],
    scheme='dark2',
)

In [None]:
#| export
def _get_labels(data, y, y_labels):
    """Get optimal labels depending on the `y`"""
    if y_labels is None:
        if y == "rmse_stand": return [f"Standardized RMSE" for var in data['var'].unique()]
        elif y == "rmse": return [f"RMSE {var} [{units_big[var]}]" for var in data['var'].unique()]
        elif y == "mean": return [f"{var} [{units_big[var]}]" for var in data['var'].unique()]
        else: return [y for var in data['var'].unique()]
    else:
        return y_labels

In [None]:
#| export
@patch
def pipe(self: alt.Chart|alt.VConcatChart|alt.HConcatChart, f: Callable):
    """Apply `f` to the chart and return its result, so that `chart.pipe(f)` is `f(chart)`.

    Used to chain formatting functions (e.g. `plot_formatter`) at the end of a chart expression.
    """
    # fastcore's `@patch` attaches this function as a `pipe` method to the classes in the `self` annotation
    return f(self)

@dataclass
class PlotFormatter():
    """Format altair plot by setting font sizes/legend position """
    font_size: int = 18 # axis labels and titles, legend labels
    legend_font_size: int = 18 # legend title
    title_font_size: int = 20 # plot titles
    # annotated, so that they are dataclass fields and can be set in the constructor
    legend_label_limit: int = 300
    legend_symbol_size: int = 150

    def __call__(self, plot: alt.Chart):
        """Return `plot` with the legend at the bottom and the configured font sizes"""
        return (plot
            .configure_legend(orient="bottom", labelFontSize=self.font_size, titleFontSize=self.legend_font_size)
            .configure_axis(labelFontSize=self.font_size, titleFontSize=self.font_size )
            .configure_title(fontSize=self.title_font_size))
    
    @property
    def color_legend(self):
        """Settings for color legend"""
        return alt.Legend(labelLimit=self.legend_label_limit, symbolSize=self.legend_symbol_size)
    
plot_formatter = PlotFormatter()

### The Plot 

In [None]:
def remove_outliers(data: pd.DataFrame, var: str, groupby, k: float = 1.5) -> pd.DataFrame:
    """Simple but maybe computationally inefficent to remove outliers from boxplot.

    Keeps the rows whose `var` is within [Q1 - k*IQR, Q3 + k*IQR] of its group (`groupby`),
    the usual boxplot whisker rule. Rows with a missing `var` are dropped.
    """
    grouped = data.groupby(groupby, observed=True)[var]
    # quartiles of each group, broadcast back to every row of the group
    q1 = grouped.transform(lambda s: s.quantile(.25))
    q3 = grouped.transform(lambda s: s.quantile(.75))
    iqr = q3 - q1
    keep = data[var].between(q1 - k * iqr, q3 + k * iqr)
    return data[keep]

In [None]:
def custom_box_plot(data, x, y, color, xoffset):
    """Boxplot for altair without outliers: one bar from the first to the third quartile of `y`.

    `x`, `y`, `color` and `xoffset` are field names (an altair type suffix such as `:N` is allowed).
    """
    def _field(shorthand): return shorthand.split(':')[0]
    # one box per x position, color and offset; dict.fromkeys removes duplicates (e.g. color == xoffset) keeping the order
    groupby = list(dict.fromkeys(_field(f) for f in (x, color, xoffset)))
    bar = alt.Chart(data).mark_bar().encode(
            x=x,
            y='q1:Q',
            y2='q3',
            color=color,
            xOffset=xoffset
        ).transform_aggregate(
            q1=f'q1({_field(y)})',
            q3=f'q3({_field(y)})',
            groupby=groupby
    )
    
    return bar

In [None]:
# custom_box_plot(data_results, 'gap_len', 'rmse', 'method', 'method')

In [None]:
#| export
def _the_plot(data, y_label='rmse', y = 'rmse'):
    """Boxplot of `y` by gap length, one box per imputation method"""
    methods = ["Kalman Filter", "ERA-I", "MDS"]
    x_enc = alt.X('gap_len:N', title='Gap length [h]', axis=alt.Axis(labelAngle=0))
    y_enc = alt.Y(y, title=y_label, axis=alt.Axis(grid=True))
    color_enc = alt.Color('method:N', title = "Method", scale=method_scale, legend=plot_formatter.color_legend)
    # xOffset puts the boxes of the methods side by side within each gap length
    offset_enc = alt.XOffset('method', scale=alt.Scale(domain=methods))
    chart = alt.Chart(data).mark_boxplot(extent="min-max")
    return chart.encode(x=x_enc, y=y_enc, color=color_enc, xOffset=offset_enc).properties(width=250, height=200)

In [None]:
#| export
def the_plot(data):
    """RMSE boxplots by gap length and method, one plot per variable"""
    labels = [f"RMSE {var} [{units_big[var]}]" for var in data['var'].unique()]
    chart = facet_wrap(data, _the_plot, "var", y_labels = labels)
    return chart.pipe(plot_formatter)

In [None]:
the_plot(data_results)

In [None]:
#| export
def the_plot_stand(data):
    """Boxplots of the standardized RMSE of each variable, one box per imputation method"""
    return alt.Chart(data).mark_boxplot(extent="min-max").encode(
        # the x axis is the variable, not the gap length
        x = alt.X('var', title='Variable', axis=alt.Axis(labelAngle=0)),
        y = alt.Y('rmse_stand', title="Standardized RMSE", axis=alt.Axis(grid=True)),
        color=alt.Color('method', title="Method", scale=method_scale, legend=plot_formatter.color_legend),
        xOffset=alt.XOffset('method:N', scale= alt.Scale(domain=method_scale.domain)),
    ).properties(width=600, height=300).pipe(plot_formatter)

In [None]:
the_plot_stand(data_results)

In [None]:
#| export
def the_plot_stand2(data):
    """Boxplots of the standardized RMSE by gap length and method, one plot per variable (3 per row)"""
    # `alt.Facet` is the wrapped-facet channel: `columns` only applies to it, not to `alt.Row`
    return _the_plot(data, y='rmse_stand', y_label="Standardized RMSE").facet(
        alt.Facet('var', sort=list(var_type.categories)), columns=3)

In [None]:
dict(zip(scale_meteo.domain, scale_meteo.range))

In [None]:
the_plot_stand2(data_results)

In [None]:
alt.Color('color')

In [None]:
x = alt.X('gap_len'); color = alt.Color('method'); xOffset = alt.XOffset('method'); y = alt.Y('rmse_stand')

In [None]:
data_results.groupby([x.shorthand, color.shorthand, xOffset.shorthand]).agg({y.shorthand: ['median', lambda x: x.quantile(.25), lambda x: x.quantile(.75)]}).columns

In [None]:
data_results

In [None]:
data_results.groupby(['gap_len', 'method', 'var']).agg(
        median = pd.NamedAgg(y.shorthand, 'median'),
        q1 = pd.NamedAgg(y.shorthand,lambda x: x.quantile(.25)),
        q3 = pd.NamedAgg(y.shorthand,lambda x: x.quantile(.75))).reset_index()

In [None]:
#| export
def custom_boxplot_nooutlier(data: pd.DataFrame,x:alt.X,y:alt.Y, color:alt.Color, xOffset:alt.XOffset):
    """Boxplot without outliers nor whiskers: a bar from Q1 to Q3 with a white tick at the median.

    The statistics of the `y` field are computed for each combination of `var` and `method`.
    """
    # the column name, without an optional altair type suffix (e.g. 'rmse_stand:Q')
    y_col = y.shorthand.split(':')[0]
    # observed=True: with categorical columns, skip the combinations not present in `data` (they would be NaN rows)
    data = data.groupby(['var', 'method'], observed=True).agg(
        median = pd.NamedAgg(y_col, 'median'),
        q1 = pd.NamedAgg(y_col, lambda x: x.quantile(.25)),
        q3 = pd.NamedAgg(y_col, lambda x: x.quantile(.75))).reset_index().astype({'var': str})
    y_label= 'Standardized RMSE'
    bar = alt.Chart(data).mark_bar(size=14).encode(alt.Y('q1', title=y_label), alt.Y2('q3', title=y_label), x, color, xOffset)
    tick = alt.Chart(data).mark_tick(color='white', size=14).encode(alt.Y('median'), x, xOffset)
    return bar + tick

In [None]:
#| export
def the_plot_stand3(data):
    """Standardized RMSE of each variable as boxes without outliers, one per method"""
    x_enc = alt.X('var:N', title='variable', axis=alt.Axis(labelAngle=0), scale=alt.Scale(domain=meteo_scale.domain))
    y_enc = alt.Y('rmse_stand')
    color_enc = alt.Color('method:N', scale=method_scale)
    offset_enc = alt.XOffset('method')
    chart = custom_boxplot_nooutlier(data, x_enc, y_enc, color_enc, offset_enc)
    return chart.properties(width=600)

In [None]:
the_plot_stand3(data_results)

Violin plots don't work yet; the code is kept for reference.

In [None]:
def _the_plot_violin(data, y_label):
        """Attempted violin plot of `rmse` per method, one column per gap length.

        The density of `rmse` is estimated with `transform_density` for each method and gap length.
        NOTE: documented in the notebook as not working; kept for reference only.
        """
        return alt.Chart(data).mark_area(orient='horizontal').encode(
        # NOTE(review): the x axis encodes the density, but the title says 'gap_len [h]'
        x = alt.X('density:Q', title='gap_len [h]', axis=alt.Axis(labelAngle=0)),
        # y domain starts slightly below 0 (-20% of the mean rmse)
        y = alt.Y('rmse', title=y_label, scale=alt.Scale(domainMin=-.2 *data['rmse'].mean())),
        color=alt.Color('method:N', scale=alt.Scale(domain=["Kalman Filter", "ERA-I", "MDS"], scheme='dark2')),
        xOffset=alt.XOffset('method', scale=alt.Scale(domain=["Kalman Filter", "ERA-I", "MDS"])),
        # NOTE(review): a `column` facet inside the charts concatenated by `facet_wrap` may be why this fails — unverified
        column=alt.Column('gap_len')
    ).transform_density(
    'rmse',
    as_=['rmse', 'density'],
    groupby=['method', 'gap_len']
).properties(width=250, height=200)

In [None]:
facet_wrap(data_results, _the_plot_violin, "var", y_labels = [f"RMSE {var} [{units_big[var]}]" for var in data_results['var'].unique()])

### Gap Length

In [None]:
gap_len = KalmanImpComparison(models_var, hai, hai_era, block_len=100).compare(gap_len = [2,10,30], var=['TA', 'SW_IN', 'VPD'], n_rep=3)

In [None]:
gap_len.gap_len_f

In [None]:
#| export
def agg_gap_len(data):
    """Median and quartiles (`Q1`, `Q3`) of the RMSE for each variable and gap length.

    Returns a flat DataFrame with columns `var`, `gap_len`, `median`, `Q1`, `Q3`.
    """
    # observed=True: with a categorical `var` only the combinations present in `data` are returned,
    # not one NaN row per unused category (which would become empty facets in the plots)
    return (data.groupby(["var", "gap_len"], observed=True)['rmse']
            .agg(median='median',
                 Q1=lambda x: x.quantile(.25),
                 Q3=lambda x: x.quantile(.75))
            .reset_index())

In [None]:
gap_len_agg = agg_gap_len(gap_len)

In [None]:
gap_len_agg

In [None]:
#| export
def _get_era_rmse(df, control):
    """Average RMSE for ERA.

    RMSE over all the time steps of `df` between each variable and its ERA counterpart in `control`.
    Returns a DataFrame with columns `var` (categorical) and `era_rmse`.
    """
    # drop the lagged control columns and rename e.g. `TA_ERA` -> `TA` to match the columns of `df`
    era = control[[name for name in control.columns if not name.endswith("_lag_1")]]
    era = era.rename(columns=lambda name: name.replace("_ERA", "")).loc[df.index]
    squared_err = (df[era.columns] - era) ** 2
    rmse = np.sqrt(squared_err.mean(axis=0))
    result = rmse.to_frame().reset_index()
    result.columns = ["var", "era_rmse"]
    return result.astype({'var': var_type})

In [None]:
rmse_df = _get_era_rmse(hai, hai_era)

In [None]:
rmse_df

In [None]:
gap_len_agg

In [None]:
gap_len_agg = pd.merge(gap_len_agg, rmse_df, on='var')

In [None]:
#| export
def _plot_gap_len(data, y_label):
    """Median RMSE vs gap length with the interquartile range as a band.

    If `data.era_rmse` is available (not NaN) it is drawn as a dashed horizontal rule.
    """
    median = alt.Chart(data).mark_line().encode(
        x = alt.X('gap_len', title="Gap length [h]", axis=alt.Axis(labelAngle=0)),
        y = alt.Y('median', title=y_label),
        color=alt.Color('var', scale=meteo_scale)
    ).properties(height=150, width=200)
    # same color scale as the line, so that band and line of a variable have the same color
    quartiles = alt.Chart(data).mark_errorband().encode(
        x = 'gap_len', y = alt.Y('Q1', title=y_label), y2= 'Q3', color=alt.Color('var', scale=meteo_scale))
    plot = median + quartiles
    era_rmse = data.era_rmse.iloc[0] if len(data) else np.nan
    if not pd.isna(era_rmse):
        plot += alt.Chart().mark_rule(strokeDash=[2,4]).encode(y=alt.datum(era_rmse))
    return plot

In [None]:
_plot_gap_len(gap_len_agg[gap_len_agg['var'] == 'TA'], "label")

In [None]:
facet_wrap(gap_len_agg, _plot_gap_len, 'var')

In [None]:
#| export
def plot_gap_len(data, df, control):
    """RMSE vs gap length, one plot per variable, with the ERA RMSE of each variable as reference"""
    aggregated = agg_gap_len(data)
    era = _get_era_rmse(df, control)
    aggregated = aggregated.merge(era, on='var', how='left')
    labels = [f"RMSE {var} [{units_big[var]}]" for var in aggregated['var'].unique()]
    return facet_wrap(aggregated, _plot_gap_len, 'var', y_labels = labels).pipe(plot_formatter)

In [None]:
plot_gap_len(gap_len, hai, hai_era)

### Compare

Plotting functions to compare the performance of two or more different conditions.

In [None]:
#| export
def _plot_compare(data, y_label='rmse', compare:str="", y = "rmse_stand", scale_domain=None):
    """Boxplot of `y` by gap length, one box (and color) per value of the `compare` column.

    `scale_domain` fixes order and colors of the compared values (default: values in `data`).
    """
    # a plain list for the altair scale: `unique()` returns an ndarray, or a pandas Categorical
    # for categorical columns, which may not be JSON-serializable
    domain = scale_domain if scale_domain is not None else data[compare].unique().tolist()
    return alt.Chart(data).mark_boxplot(extent="min-max").encode(
        x = alt.X('gap_len:N', title='Gap length [h]', axis=alt.Axis(labelAngle=0)),
        y = alt.Y(y, title=y_label, scale= alt.Scale(zero=False)),
        color=alt.Color(compare, scale=alt.Scale(domain = domain, scheme='accent'),legend=plot_formatter.color_legend),
        xOffset=alt.XOffset(compare, scale=alt.Scale(domain = domain)),
    ).properties(width=250, height=200)

In [None]:
#| export
def plot_compare(data: pd.DataFrame,
                 compare: str,
                 y:str = "rmse_stand",
                 scale_domain: Sequence|None = None,
                 y_labels:Sequence|None = None
                ) -> alt.Chart:
    """Boxplots of `y` by gap length comparing the values of the `compare` column, one plot per variable.

    `scale_domain` fixes order/colors of the compared values; `y_labels` (one per variable)
    defaults to a label that depends on `y`.
    """
    y_labels = _get_labels(data, y, y_labels) 
    # `y` must be passed on, otherwise every plot shows the default "rmse_stand" whatever `y` is
    plot_fn = partial(_plot_compare, compare=compare, y=y, scale_domain=scale_domain)
    return facet_wrap(data, plot_fn, "var", y_labels = y_labels).pipe(plot_formatter)

In [None]:
plot_compare(k_results, compare="type")

The y axis can also show other metrics, such as the loss or the likelihood.

In [None]:
plot_compare(k_results, compare="type", y="loss")

### Timeseries

In [None]:
comp_ts = ImpComparison(models = models_var, df = hai, control = hai_era, block_len = 446, rmse=False, time_series = True)

In [None]:
res_ts = comp_ts.compare(gap_len = [12,24,336], var=["TA", 'SW_IN'], n_rep=1) 

In [None]:
res_ts.columns

In [None]:
test_ts = res_ts.query('method == "Kalman Filter" and var == "TA"')

In [None]:
row_res = test_ts.iloc[0]

In [None]:
pred_all = row_res.pred[0]
targ =  row_res.targ[0]
var = row_res['var']

In [None]:
type(pred_all)

In [None]:
targ.mask[var]

In [None]:
list(row_res.index)

In [None]:
np.argmin(row_res.targ[0].mask['TA'])

In [None]:
row_res.targ[0].mask['TA'].index[47]

In [None]:
#| export
def unnest_predictions(row_res: pd.Series, ctx_len:int=50):
    """Unnest predictions/target for each gap to plot timeseries.

    Returns one row per time step in a window of about `ctx_len` steps centered on the gap
    (clamped to the available data), with columns `mean`, `std`, `measurement`, `is_present`,
    the index of the series and all the other fields of `row_res` (except `pred` and `targ`).
    `mean` is NaN outside the gap. Assumes a single contiguous gap in `var`.
    """
    pred = row_res.pred[0]
    targ = row_res.targ[0]
    var = row_res['var']
    if isinstance(pred, NormalsDf):
        mean, std = pred.mean[var], pred.std[var]
    else:
        # predictions without a distribution have no standard deviation
        mean, std = pred[var], np.nan

    measurement = targ.data[var]
    is_present = targ.mask[var]

    n_steps = len(is_present)
    gap_len = (~is_present).sum()
    gap_start = np.argmin(is_present) # first missing step
    pad = (ctx_len - gap_len) // 2
    # clamp the window to the data: a negative start would wrap around to the end of the
    # series and an end after the last step would raise an IndexError
    ctx_start = is_present.index[max(gap_start - pad, 0)]
    ctx_end = is_present.index[min(gap_start + gap_len + pad, n_steps - 1)]

    window = slice(ctx_start, ctx_end)
    is_present = is_present[window]
    measurement = measurement[window]
    # show the predictions only in the gap; `mask` returns a new series, so the stored prediction is not modified
    mean = mean[window].mask(is_present)
    if not isinstance(std, float): std = std[window]

    other_cols = {name: row_res[name] for name in row_res.index if name not in ('pred', 'targ')}

    out = pd.DataFrame({'mean': mean, 'std': std, 'measurement': measurement, 'is_present': is_present} | other_cols)
    return out.reset_index()
    

In [None]:
unnest_predictions(row_res)

In [None]:
res_ts_plot = pd.concat([unnest_predictions(row) for _,row in res_ts.iterrows()])

In [None]:
#| export
from meteo_imp.kalman.training import def_selection, plot_points, plot_line, plot_error

In [None]:
plot_points(res_ts_plot.query('var == "TA" and method == "Kalman Filter" and idx_rep == 0 and gap_len==6.'), y= "measurement")

In [None]:
plot_missing_area(res_ts_plot.query('var == "TA" and method == "Kalman Filter" and idx_rep == 0 and gap_len==6.'))

In [None]:
plot_line(res_ts_plot.query('var == "TA" and idx_rep == 0 and gap_len==6.'), y= "mean", color='method', scale=alt.Scale())

In [None]:
plot_error(res_ts_plot.query('var == "TA" and idx_rep == 0 and gap_len==6.').copy(), y= "mean", color='method', scale=alt.Scale())

In [None]:
#| export
def _plot_timeseries(data_plot, y_label="", scale_color=method_scale, compare = 'method', err_band=True):
    """data for one variable and one gap.

    Layers: measurement points, gap area, optional error band and prediction lines colored by `compare`.
    """
    # the measurements are the same for all compared options: take them from the first one
    first_option = data_plot[compare].unique()[0]
    data_measure = data_plot[data_plot[compare] == first_option]

    layers = [
        plot_points(data_measure, y= "measurement", y_label=y_label),
        plot_missing_area(data_measure),
    ]
    if err_band:
        layers.append(plot_error(data_plot, y= "mean", color=compare, scale=scale_color, y_label=y_label))
    layers.append(plot_line(data_plot, y= "mean", color=compare, color_title = "Method",
                            scale=scale_color, y_label=y_label, props={'height': 200, 'width': 300}))
    return alt.layer(*layers)

In [None]:
_plot_timeseries(res_ts_plot.query('var == "TA" and idx_rep == 0 and gap_len==6.').copy(), "TA")

In [None]:
_plot_timeseries(res_ts_plot.query('var == "SW_IN" and idx_rep == 0 and gap_len==12.').copy(), "SW_IN")

In [None]:
data_ts = pd.concat([unnest_predictions(row) for _,row in res_ts.iterrows()])

In [None]:
#| export
def plot_timeseries(data, idx_rep:int|str|None=None, gap_len:int|None = None, max_idx:int = 3,
                    ctx_len:dict|None = None,
                    scale_color=method_scale, compare='method'):
    """Plot measurements and predictions around the gaps, one row of plots per variable.

    Exactly one of `idx_rep` and `gap_len` must be given:
    - `idx_rep` (an int, or 'random' for a random one): one column per gap length of that repetition
    - `gap_len`: one column per repetition, for the first `max_idx` repetitions
    `ctx_len` maps the gap length [h] to the number of time steps shown around the gap
    (default `{6.: 50, 12.: 50, 168.: 336+48}`; other gap lengths use the gap plus 48 steps).
    """
    if ctx_len is None: ctx_len = {6.: 50, 12.: 50, 168.: 336+48}
    if idx_rep == 'random': idx_rep = int(data['idx_rep'].sample().iloc[0])
    if gap_len is None and idx_rep is not None:
        data_plot = data.query(f'idx_rep == {idx_rep}').copy()
        facet_var = 'gap_len_f'
    elif gap_len is not None and idx_rep is None:
        data_plot = data.query(f'gap_len == {gap_len} and idx_rep < {max_idx}').copy()
        facet_var = 'idx_rep'
    else:
        raise ValueError(f"One and only one of idx_rep, gap_len should be None. got {idx_rep}, {gap_len}")

    def _ctx(gap_len_h):
        # fallback for gap lengths not in `ctx_len`: 2 steps per hour (as 168 h -> 336 steps) plus 48 steps of context
        return ctx_len.get(gap_len_h, int(2 * gap_len_h + 48))

    data_plot = pd.concat([unnest_predictions(row, _ctx(row.gap_len)) for _, row in data_plot.iterrows()])
    data_plot = data_plot.astype({'idx_rep': str, 'gap_len_f': str})
    data_plot['gap_len_f'] = "Gap " + data_plot['gap_len_f']
    # the plots show the values of the variables (not the RMSE): label with name and unit,
    # computed from the plotted data so that the labels follow the order of the rows
    y_labels = _get_labels(data_plot, 'mean', None)
    return (facet_grid(data_plot, partial(_plot_timeseries, scale_color=scale_color, compare=compare),
                      row="var", col=facet_var, y_labels=y_labels)
            .pipe(plot_formatter))
        
    

In [None]:
plot_timeseries(res_ts, idx_rep='random')

### Scatter plot

In [None]:
#| export
def _plot_scatter(df, only_present=True, x = "value", y="mean", x_label="", y_label = "", color = 'method', scale=method_scale, props = None):
    """Scatter plot of `y` against `x`, with color and shape by the `color` column.

    `props` are passed to `.properties` (default: none). `only_present` is currently ignored.
    """
    # TODO remove only_present
    return alt.Chart(df).mark_point().encode(
        x = alt.X(x, title=x_label),    
        y = alt.Y(y, title = y_label),
        color=alt.Color(color, scale= scale),
        shape = color
    ).properties(
        **(props or {})
    )

    

In [None]:
_plot_scatter(data_ts.query('var == "TA" and gap_len==12. and idx_rep ==0').copy(), x="measurement", y= "mean", color='method', scale=alt.Scale())

## Table

### The Table

A table with one row per variable and gap length, and the mean and standard deviation of the RMSE of each method in the columns.

In [None]:
t = data_results.groupby(['method', 'var', 'gap_len']).agg({'rmse': ['mean', 'std']}).unstack(level=0)

t_idx = t.columns.droplevel()
t_idx.names = ['RMSE', None]

t2 = t.copy()
t2.columns = t_idx

t2 = t2.sort_index(axis=1, level=1).swaplevel(axis=1)
t2

In [None]:
row = t2.iloc[0]

In [None]:
row

In [None]:
np.argmin(row.iloc[[0,2,4]]) * 2

In [None]:
#| export
def highlight_min_method(row, props, cols): 
    """CSS for one table row: `props` on the cells at positions `cols` (the means) holding their minimum, '' elsewhere"""
    cols = list(cols)
    candidates = row.iloc[cols]
    # only the cells in `cols` can be highlighted: a std equal to the best mean must not be bold
    is_min = np.zeros(len(row), dtype=bool)
    is_min[cols] = (candidates == candidates.min()).to_numpy()
    return np.where(is_min, props, '')

def style_the_table(style, cols=(0, 2, 4)):
    """Bold the lowest mean (columns at positions `cols`) in each row and format numbers with 3 decimals ('-' for missing)"""
    return (style.apply(highlight_min_method, props="font-weight: bold", cols=cols, axis=1)
                .format_index(precision=0).format(precision=3, na_rep='-'))

In [None]:
# #| export
# renames_table_latex = {name: f"\parbox{{2.1cm}}{{{val}}}" for name, val in 
#                  {'SW_IN': "Shortwave radiation incoming \\textbf{SW IN} [\si{W/m^2}]",
#                'LW_IN': 'Longwave radiation incoming \\textbf{LW IN} [$W/m^2$]',
#                'TA': "Air Temperature \\textbf{TA} [$°C$]",
#                'VPD': "Vapuour Pressure Deficit \\textbf{VPD} [$hPa$]",
#                'PA': "Air Pressure \\textbf{PA} [$hPa$]",
#                'P': "Precipitation \\textbf{P} [$mm$]",
#                'WS': "Wind Speed \\textbf{WS} [$m/s$]",
#           }.items()}

In [None]:
#| export
# LaTeX labels (bold name and unit in a fixed-width \parbox) of the variables for the tables
# NOTE: overwritten by the `renames_table_latex` defined in the next cell (without \parbox)
renames_table_latex = {name: f"\\parbox{{2.1cm}}{{{val}}}" for name, val in {
    'SW_IN': "\\textbf{SW\\_IN} [\\si{W/m^2}]",
    'LW_IN': '\\textbf{LW\\_IN} [\\si{W/m^2}]',
    'TA': "\\textbf{TA} [\\si{°C}]",
    'VPD': "\\textbf{VPD} [\\si{hPa}]",
    'PA': "\\textbf{PA} [\\si{hPa}]",
    'P': "\\textbf{P} [\\si{mm}]",
    'WS': "\\textbf{WS} [\\si{m/s}]",
    'TS': "\\textbf{TS} [\\si{°C}]",
    # `\\%` (not `\%`, an invalid escape sequence) for a literal backslash-percent
    'SWC': "\\textbf{SWC} [\\si{\\%}]",
}.items()}

In [None]:
#| export
# LaTeX labels (bold name and unit) of the variables, used for the rows of the tables
renames_table_latex = {name: f"{{{val}}}" for name, val in {
    'SW_IN': "\\textbf{SW\\_IN} [\\si{W/m^2}]",
    'LW_IN': '\\textbf{LW\\_IN} [\\si{W/m^2}]',
    'TA': "\\textbf{TA} [\\si{°C}]",
    'VPD': "\\textbf{VPD} [\\si{hPa}]",
    'PA': "\\textbf{PA} [\\si{hPa}]",
    'P': "\\textbf{P} [\\si{mm}]",
    'WS': "\\textbf{WS} [\\si{m/s}]",
    'TS': "\\textbf{TS} [\\si{°C}]",
    # `\\%` (not `\%`, an invalid escape sequence) for a literal backslash-percent
    'SWC': "\\textbf{SWC} [\\si{\\%}]",
}.items()}

In [None]:
#| export
# LaTeX labels without units (for standardized metrics): bold variable name in a fixed-width \parbox
renames_table_latex_stand = {
    name: f"\\parbox{{2.1cm}}{{\\textbf{{{label}}}}}"
    for name, label in {
        'SW_IN': "SW\\_IN", 'LW_IN': "LW\\_IN", 'TA': "TA", 'VPD': "VPD", 'PA': "PA",
        'P': "P", 'WS': "WS", 'TS': "TS", 'SWC': "SWC",
    }.items()
}

In [None]:
t2.style.pipe(style_the_table)

In [None]:
t3 = t2.rename(index = renames_table_latex).style.apply(highlight_min_method, props="font-weight: bold", axis=1, cols=[0,2,4]).format_index(precision=0).format(precision=3)

In [None]:
t3

In [None]:
print(t3.to_latex(convert_css=True, hrules=True, clines="skip-last;data", column_format="p{2.1cm}c|rr|rr|rr", caption="caption", label="table"))

In [None]:
#| export
def the_table(data, y='rmse', y_name="RMSE"):
    """Mean and std of `y` per variable and gap length (rows), for each method (columns).

    The columns are a (method, statistic) MultiIndex.
    """
    per_method = data.groupby(['method', 'var', 'gap_len']).agg({y: ['mean', 'std']})
    table = per_method.unstack(level=0)
    # drop the `y` level: the columns become (statistic, method)
    table.columns = table.columns.droplevel(0).set_names([y_name, None])
    table.index = table.index.set_names(["Variable", "Gap [$h$]"])
    # group the columns by method, then move the method to the outer level
    return table.sort_index(axis=1, level=1).swaplevel(axis=1)

In [None]:
the_table(data_results)

In [None]:
#| export
def the_table_latex(table, file, caption="", label="", stand=False):
    """Write `table` (as made by `the_table`) to `file` as a LaTeX table and return `file`.

    The lowest mean of each row is bold; `stand=True` labels the variables without units.
    """
    renames = renames_table_latex if not stand else renames_table_latex_stand
    styled = table.rename(index = renames).style.pipe(style_the_table).format(na_rep="-", precision=3)
    latex = styled.to_latex(convert_css=True, hrules=True, clines="skip-last;data",
                            column_format="p{2.1cm}c|rr|rr|rr", caption=caption, label=label, position_float="centering")
    # explicit utf-8: the labels contain non-ASCII characters (e.g. `°`) that the platform default encoding may not support
    with open(file, 'w', encoding='utf-8') as f:
        f.write(latex)
    return file

In [None]:
the_table_latex(the_table(data_results), "test_table.tex")

### Table compare

In [None]:
#| export
# order of the statistics in the comparison tables
err_type = CategoricalDtype(categories=["se", "std", "mean", "diff."], ordered=True)
def table_compare(data, compare:str, y = 'rmse_stand', compare_ascending=True):
    """Mean, std and standard error of `y` per variable and gap length, for each value of `compare`.

    The extra `diff.` column is the mean of the first compared value (in the groupby order)
    minus the mean of the second, so it is meaningful for exactly 2 options.
    `compare_ascending` is currently unused.
    """
    stats = {y: ['mean', 'std', ('se', 'sem')]}
    table = (data.groupby([compare, 'var', 'gap_len'])
                 .agg(stats)
                 .unstack(level=0)
                 .droplevel(level=0, axis=1))
    # the first two columns are the means of the first two compared values
    table["diff."] = (table.iloc[:, 0] - table.iloc[:, 1])
    col_frame = table.columns.set_names(['RMSE Standardized', compare]).to_frame()
    # categorical level: the statistics are sorted in the `err_type` order, not alphabetically
    table.columns = pd.MultiIndex.from_frame(col_frame.astype({'RMSE Standardized': err_type}))
    table.index.names = ["Variable", "Gap [$h$]"]
    return table.sort_index(axis=1, level=1, ascending=False).swaplevel(axis=1)

In [None]:
table_compare(k_results, 'type')

In [None]:
#| export
def table_compare_latex(table, file, caption="", label=""):
    """Render a `table_compare` result as LaTeX and write it to `file`.

    The index is renamed with `renames_table_latex_stand` and styled with
    `style_the_table` applied to columns 0 and 3.

    Parameters
    ----------
    table : pd.DataFrame
        Table as returned by `table_compare`.
    file : str or path-like
        Destination of the LaTeX source; overwritten if it already exists.
    caption, label : str
        Passed through to `Styler.to_latex`.

    Returns
    -------
    The `file` argument, unchanged.
    """
    # NOTE(review): cols=[0, 3] presumably point at the 'mean' column of each compared option — confirm
    styled = (table.rename(index=renames_table_latex_stand)
              .style.pipe(partial(style_the_table, cols=[0, 3])))
    # 2 index columns, two groups of 3 statistics, then the single "diff." column
    latex = styled.to_latex(convert_css=True, hrules=True, clines="skip-last;data",
                            column_format="p{2.1cm}c|rrr|rrr|r", caption=caption, label=label,
                            position_float="centering")
    # explicit UTF-8 so the written bytes do not depend on the platform's locale encoding
    with open(file, 'w', encoding='utf-8') as f:
        f.write(latex)
    return file

### Table Compare 3

Same as above, but for comparing 3 options instead of 2.

In [None]:
#| export
def table_compare3(data, compare:str, y = 'rmse_stand', compare_ascending=True):
    """Summarise `y` (mean and std) for each value of column `compare`.

    Unlike the two-option comparison there is no difference column, so any
    number of compared options is supported.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the columns `compare`, 'var', 'gap_len' and `y`.
    compare : str
        Column whose values become the outer column level.
    y : str
        Metric column to summarise.
    compare_ascending : bool
        Sort direction of the compared options in the columns. It used to be
        ignored and always ascending, which matches the default.

    Returns
    -------
    pd.DataFrame indexed by ("Variable", "Gap [$h$]") with (option, statistic) columns.
    """
    data = (data
            .groupby([compare, 'var', 'gap_len'])
            .agg({y: ['mean', 'std']})
            .unstack(level=0)
            .droplevel(level=0, axis=1))
    # set names on a fresh index instead of mutating `data.columns` in place
    data.columns = data.columns.set_names(['RMSE Standardized', compare])
    data.index.names = ["Variable", "Gap [$h$]"]
    return data.sort_index(axis=1, level=1, ascending=compare_ascending).swaplevel(axis=1)

In [None]:
#| export
def table_compare3_latex(table, file, caption="", label=""):
    """Render a `table_compare3` result as LaTeX and write it to `file`.

    The index is renamed with `renames_table_latex_stand` and styled with
    `style_the_table` applied to columns 0, 2 and 4.

    Parameters
    ----------
    table : pd.DataFrame
        Table as returned by `table_compare3`.
    file : str or path-like
        Destination of the LaTeX source; overwritten if it already exists.
    caption, label : str
        Passed through to `Styler.to_latex`.

    Returns
    -------
    The `file` argument, unchanged.
    """
    # NOTE(review): cols=[0, 2, 4] presumably point at the 'mean' column of each of the 3 options — confirm
    styled = (table.rename(index=renames_table_latex_stand)
              .style.pipe(partial(style_the_table, cols=[0, 2, 4])))
    # 2 index columns, then three groups of (mean, std)
    latex = styled.to_latex(convert_css=True, hrules=True, clines="skip-last;data",
                            column_format="p{2.1cm}c|rr|rr|rr", caption=caption, label=label,
                            position_float="centering")
    # explicit UTF-8 so the written bytes do not depend on the platform's locale encoding
    with open(file, 'w', encoding='utf-8') as f:
        f.write(latex)
    return file

### Gap length table

In [None]:
# Mean RMSE per (variable, gap length); `gap_len` is a results DataFrame loaded earlier in the notebook
g = gap_len.groupby(['var', 'gap_len']).agg({'rmse': ['mean']})

In [None]:
# display the aggregated frame
g

In [None]:
# Exploration: reshape `g` so rows are (variable, statistic name) and columns are gap lengths
# (prototype of `table_gap_len`)
(g.droplevel(level=0, axis=1)
 .reset_index()
 .melt(id_vars=['var', 'gap_len'], var_name='rmse')
 .pivot(index = ['var', 'rmse'], columns=['gap_len'])
 .droplevel(level=0, axis=1)
)

In [None]:
#| export
def table_gap_len(data, y = 'rmse', y_name="RMSE"):
    """Summarise metric `y` per variable (rows) and gap length (columns).

    Parameters
    ----------
    data : pd.DataFrame
        Must contain the columns 'var', 'gap_len' and `y`.
    y : str
        Metric column to summarise. Previously only 'rmse' worked, because the
        melted statistic column was always named 'rmse'.
    y_name : str
        Display name of the statistic index level (default "RMSE").

    Returns
    -------
    pd.DataFrame indexed by ("Variable", `y_name`), where the second level
    holds 'mean' / 'std'. Columns are gap lengths, named "Gap [$h$]".
    """
    stats = data.groupby(['var', 'gap_len']).agg({y: ['mean', 'std']}).droplevel(level=0, axis=1)
    t = (stats
         .reset_index()
         # long format, one row per (var, gap_len, statistic); the statistic column
         # is named after `y` so that the pivot below finds it for any metric
         .melt(id_vars=['var', 'gap_len'], var_name=y)
         .pivot(index=['var', y], columns=['gap_len'])
         # drop the implicit 'value' level left by pivot
         .droplevel(level=0, axis=1))

    t.columns.names = ["Gap [$h$]"]
    t.index.names = ("Variable", y_name)

    # columns are a single-level index of gap lengths, so sort them directly
    return t.sort_index(axis=1)

In [None]:
# Preview the gap length table for the `gap_len` results
table_gap_len(gap_len)

In [None]:
#| export
def table_gap_len_latex(table, file, caption="", label=""):
    """Render a `table_gap_len` result as LaTeX and write it to `file`.

    Gap-length column labels are formatted as integers. The index is renamed
    with `renames_table_latex_stand`, and numbers are shown with 3 decimals
    ('-' for missing values).

    Parameters
    ----------
    table : pd.DataFrame
        Table as returned by `table_gap_len`. It is not modified.
    file : str or path-like
        Destination of the LaTeX source; overwritten if it already exists.
    caption, label : str
        Passed through to `Styler.to_latex`.

    Returns
    -------
    The `file` argument, unchanged.
    """
    # relabel a copy: assigning to `table.columns` would change the caller's frame,
    # and a second call would then fail formatting the (now str) labels with ':.0f'
    labels = [f"{col:.0f}" for col in table.columns]
    table = table.set_axis(labels, axis=1)
    styled = table.rename(index=renames_table_latex_stand).style.format(precision=3, na_rep='-')
    # one centred column per gap length
    gap_cols = 'c' * len(table.columns)
    latex = styled.to_latex(convert_css=True, hrules=True, clines="skip-last;data",
                            column_format="p{2.1cm}l|" + gap_cols, caption=caption, label=label,
                            position_float="centering")
    # explicit UTF-8 so the written bytes do not depend on the platform's locale encoding
    with open(file, 'w', encoding='utf-8') as f:
        f.write(latex)
    return file

In [None]:
# Preview the integer-formatted gap length labels used by table_gap_len_latex
[f"{col:.0f}" for col in list(table_gap_len(gap_len).columns)]

In [None]:
from tempfile import tempdir
from pathlib import Path

In [None]:
# `tempfile.tempdir` stays None until `gettempdir()` has been called somewhere, so
# `Path(tempdir)` can raise TypeError; ask for the temporary directory explicitly.
from tempfile import gettempdir
table_gap_len_latex(table_gap_len(gap_len), Path(gettempdir()) / "test_table.tex")