# Filter Performance and Stability
> Measure performance of the Standard Filter vs. the Square Root Filter, on CPU vs. GPU, and batched vs. non-batched

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from fastcore.test import *
from fastcore.basics import *
from meteo_imp.utils import *
from meteo_imp.gaussian import *
from meteo_imp.data_preparation import MeteoDataTest
from meteo_imp.kalman.filter import *
from meteo_imp.kalman.filter import get_test_data

import pykalman
from typing import *

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

In [None]:
# Reset the random seed before any random initialisation below
# (reset_seed presumably comes from meteo_imp.utils — TODO confirm).
reset_seed()

In [None]:
# Randomly initialised square-root Kalman filter. The positional sizes are
# presumably (n_dim_obs, n_dim_state, n_dim_contr) — TODO confirm against init_random.
kSR = KalmanFilterSR.init_random(3,4,3)
# Test inputs, unpacked as (data, mask, control). The arguments are presumably
# (n_obs, n_dim_obs, n_dim_contr) — TODO confirm against get_test_data.
data, mask, control = get_test_data(15, 3,3)

In [None]:
# Inspect the shape of the generated test data.
data.shape

torch.Size([2, 15, 3])

In [None]:
# Shape of the `cov` field returned by the square-root filter.
# NOTE(review): the recorded output below has 3 in its second dimension while
# `data` has 15 — presumably a stale output from an earlier run; confirm by re-running.
kSR.filter(data, mask, control).cov.shape

torch.Size([2, 3, 4, 4])

In [None]:
# Shape of the `mean` field returned by the square-root filter. The filter
# result is computed inline because no `filt` variable is defined at notebook
# scope (the bare `filt.mean.shape` raised NameError on a fresh run).
kSR.filter(data, mask, control).mean.shape

torch.Size([2, 3, 4, 1])

In [None]:
# Build a standard KalmanFilter from the square-root filter so both can be run
# on the same inputs (init_from presumably copies the parameters — TODO confirm).
k = KalmanFilter.init_from(kSR)

In [None]:
# Shape of the `cov` field returned by the standard filter on the same inputs.
k.filter(data, mask, control).cov.shape

torch.Size([2, 15, 4, 4])

In [None]:
class KalmanFilterPerformance():
    """Holder for the problem sizes of a Kalman filter performance run.

    Every constructor argument is stored on the instance under the same name
    (`n_obs`, `n_dim_obs`, `n_dim_state`, `n_dim_contr`, `bs`).
    """
    def __init__(self, n_obs, n_dim_obs, n_dim_state, n_dim_contr, bs):
        # `self` was missing from the signature, so instantiating the class with
        # the five sizes raised TypeError and store_attr() had no instance to
        # attach the attributes to.
        store_attr()
        
    

## Performance

In [None]:
def compare_performance(n_obs, n_dim_obs, n_dim_state, n_dim_contr, bs, dtype=torch.float64):
    kf_cuda = KalmanFilter.init_random(n_dim_obs,n_dim_state, dtype=dtype).cuda()
    data_cuda, mask_cuda = get_test_data(n_dim_obs,n_dim_state, bs=bs, device="cuda", dtype=dtype)
    
    print("GPU")
    %timeit -n 1 -r 1 kf_cuda.predict(data_cuda, mask_cuda);

    kf_cuda = KalmanFilter.init_random(n_dim_obs,n_dim_state, dtype=dtype)
    data_cuda, mask_cuda = get_test_data(n_dim_obs,n_dim_state, bs=bs, dtype=dtype)
    print("CPU")
    %timeit -n 1 -r 1 kf.predict(data, mask)
    print("No batches CPU")
    %timeit -n 1 -r 1 [kf.predict(d.unsqueeze(0), m.unsqueeze(0)) for d,m in zip(data, mask)] 
    print("No batches GPU")
    %timeit -n 1 -r 1 [kf_cuda.predict(d.unsqueeze(0), m.unsqueeze(0)) for d,m in zip(data_cuda, mask_cuda)] 

In [None]:
# NOTE(review): compare_performance declares five required positional
# parameters (n_obs, n_dim_obs, n_dim_state, n_dim_contr, bs) but this call
# passes four, so re-running the cell raises TypeError. The recorded output
# presumably comes from an earlier signature — confirm the intended arguments.
compare_performance(100, 2,2,100)

GPU
87.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
CPU
7.83 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches CPU
12.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches GPU
154 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
# NOTE(review): four positional arguments are passed to a function that
# declares five required ones (n_obs, n_dim_obs, n_dim_state, n_dim_contr, bs);
# confirm the intended arguments before re-running.
compare_performance(200, 10,10,200)

GPU
2.04 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
CPU
7.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches CPU
13.5 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches GPU
2.07 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


### Float64

In [None]:
# NOTE(review): four positional arguments are passed to a function that
# declares five required ones (n_obs, n_dim_obs, n_dim_state, n_dim_contr, bs);
# confirm the intended arguments before re-running.
# torch.float64 is also the declared default of `dtype`.
compare_performance(100, 2,2,100, dtype=torch.float64)

GPU
100 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
CPU
8.29 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches CPU
13.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches GPU
159 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
# NOTE(review): four positional arguments are passed to a function that
# declares five required ones (n_obs, n_dim_obs, n_dim_state, n_dim_contr, bs);
# confirm the intended arguments before re-running.
# torch.float64 is also the declared default of `dtype`.
compare_performance(200, 10,10,200, dtype=torch.float64)

GPU
2.22 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
CPU
8.35 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches CPU
13.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
No batches GPU
2.01 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


## Stability

In [None]:
import polars as pl
import altair as alt
from altair import datum

In [None]:
def fuzz_filter_SR(n_iter=10, n_obs=50):
    """Fuzz the standard filter against the square-root filter.

    After resetting the seed to 27, repeats `n_iter` times: create a random
    `KalmanFilter`, derive a `KalmanFilterSR` from it, generate test data with
    `n_obs` observations and run both filters on it. For every time step the
    mean absolute difference between the standard filter's `cov` and
    `C @ C.mT` (with `C` the square-root filter's `cov`) is recorded.

    Returns a polars DataFrame with one row per (run, time step) and the
    columns 't' (time step), 'n' (run index) and 'MAE'.
    """
    reset_seed(27)
    records = []
    for run in range(n_iter):
        # Creation order is kept as-is so the random draws stay the same.
        std_filter = KalmanFilter.init_random(10, 5, 8)
        sr_filter = KalmanFilterSR.init_from(std_filter)
        data, mask, control = get_test_data(n_obs, 10, 8)
        std_result = std_filter.filter(data, mask, control)
        sr_result = sr_filter.filter(data, mask, control)
        for step in range(n_obs):
            cov_std = std_result.cov[:, step]
            cov_factor = sr_result.cov[:, step]
            # Multiply the factor by its transpose to compare with the
            # standard filter's covariance.
            cov_rebuilt = cov_factor @ cov_factor.mT
            mae = (cov_std - cov_rebuilt).abs().mean().item()
            records.append({'t': step, 'n': run, 'MAE': mae})
    return pl.DataFrame(records)

In [None]:
# Run the fuzzing comparison with n_iter=70 random filters and n_obs=62
# observations each.
err_raw = fuzz_filter_SR(70, 62)

In [None]:
# Summarise the per-run errors at each time step `t`: median, the two quartiles
# and the worst case of the MAE across all fuzzing runs.
# NOTE(review): recent polars releases spell this method `group_by` — confirm
# against the polars version this notebook is pinned to.
err = (
    err_raw
    .groupby('t')
    .agg(
        [
            pl.col('MAE').median().alias("median"),
            pl.col('MAE').quantile(0.75).alias("Q3"),
            pl.col('MAE').quantile(0.25).alias("Q1"),
            pl.col('MAE').max().alias("max"),
        ]
    )
)

In [None]:
# One line layer per summary statistic of the MAE over the time step `t`.
# The median layer carries the axis configuration (titles, log scale,
# scientific tick format); the constant `strokeDash=datum(...)` encodings label
# each layer.
median = alt.Chart(err.to_pandas()).mark_line(color="black").encode(
    x=alt.X('t', title="Number of Iterations"),
    y=alt.Y('median', axis=alt.Axis(format=".1e"), scale=alt.Scale(type="log"), title="MAE"),
    strokeDash=datum("median"),
)

Q1 = alt.Chart(err.to_pandas()).mark_line(color='dimgray', strokeDash=[4,6]).encode(x='t', y='Q1', strokeDash=datum("quantile"))
Q3 = alt.Chart(err.to_pandas()).mark_line(color='dimgray', strokeDash=[4,6]).encode(x='t', y='Q3', strokeDash=datum("quantile"))
# Named `max_err` rather than `max` so the builtin `max` is not shadowed for
# the rest of the notebook session.
max_err = alt.Chart(err.to_pandas()).mark_line(color='black', strokeDash=[2,2]).encode(x='t', y='max', strokeDash=datum("max"))
(Q1 + Q3 + max_err + median).interactive().properties(title="Standard Filter vs Square Root Filter (Mean Absolute Error of state covariances)")