# Init

## Load libs

In [None]:
import numpy as np
import plotly.graph_objs as go
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
from itertools import chain
from datetime import datetime
from torch.utils.data import ConcatDataset, DataLoader, random_split

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    root_mean_squared_error,
    r2_score
)

import shaft_force_sensing.models
from shaft_force_sensing import ForceSensingDataset
from shaft_force_sensing.models import (
    LitSequenceModel,
    LitTransformer,
    LitLTC,
)
from shaft_force_sensing.evaluation import (
    tb_to_numpy,
    add_norm,
    array_bais,
    array_medfilt,
)

%load_ext autoreload
%autoreload 2

## Set hyperparameters

In [None]:
# When True, evaluate on the "teleop_test" split instead of "test".
transfer_learning: bool = False

In [None]:
# Seed every RNG through Lightning for reproducibility.
seed_everything(42)

# Training hyperparameters.
# NOTE(review): these appear unused by the evaluation cells; presumably kept for
# parity with the training notebook — confirm before removing.
max_epochs, batch_size, learning_rate = 30, 256, 1e-4
hidden_size, num_layers, num_heads = 128, 3, 8

In [None]:
# Plot / report labels: the three force components followed by "Norm"
# (presumably the column appended by add_norm — confirm).
axes = [f"F_{component}" for component in "xyz"] + ["Norm"]

# Evaluation

In [None]:
# Ask for the run directory under ../logs and infer the model class from the
# sub-directory whose name matches a class exported by shaft_force_sensing.models.
save_path = Path("../logs") / input("Enter checkpoint directory: ")
model_cls = next(
    (
        p.name
        # sorted() makes the choice deterministic if several entries match
        for p in sorted(save_path.iterdir())
        if p.name in shaft_force_sensing.models.__all__
    ),
    # Explicit default: without it a miss raised NameError on a fresh kernel, or a
    # stale model_cls from a previous run silently passed the check below.
    None,
)
assert model_cls is not None, "Model name not found in checkpoint directory."

In [None]:
# Load the best checkpoint only to recover the normalisation statistics stored on
# the model, then discard the model itself.
ckpt_paths = sorted(save_path.glob("best*.ckpt"))
assert ckpt_paths, f"No 'best*.ckpt' checkpoint found in {save_path}."
# NOTE(review): this takes the lexicographically last match; if several "best"
# checkpoints exist, confirm that ordering picks the intended one.

# Resolve the class on the models package rather than eval(): works for every name
# in shaft_force_sensing.models.__all__ (not only those imported into this notebook)
# and avoids evaluating a string derived from the filesystem.
model: LitSequenceModel = getattr(shaft_force_sensing.models, model_cls).load_from_checkpoint(
    ckpt_paths[-1],
    map_location="cpu",
)

# Rebuild a StandardScaler from the stored statistics so that inverse_transform can
# map normalised ground truth / predictions back to the original scale.
golbal_scaler = StandardScaler()
golbal_scaler.mean_ = model.data_mean.numpy(force=True)
golbal_scaler.scale_ = model.data_std.numpy(force=True)

del model

In [None]:
# Evaluation logs live in a split sub-directory of the checkpoint folder.
# NOTE: this rebinds save_path in terms of itself, so re-running this cell alone
# appends the split twice (".../test/test"); re-run the checkpoint cell first.
save_path = save_path / ("teleop_test" if transfer_learning else "test")

## Single set

In [None]:
# Pick one test sequence (by sorted position) to inspect in detail.
# Only directories are considered, matching the "Loop all sets" cell: plain files
# such as metrics.txt, which is written into this same folder, would otherwise
# shift the indices.
idx = 1
run_dirs = sorted(p for p in save_path.iterdir() if p.is_dir())
path = run_dirs[idx]
gt, pred = tb_to_numpy(path)
path.stem

Denormalization

In [None]:
# Undo the training-time standardisation on both arrays.
# NOTE: rebinds gt/pred in place — re-running this cell alone denormalises twice.
gt, pred = map(golbal_scaler.inverse_transform, (gt, pred))

Smooth

In [None]:
# Median-filter the predictions only (ground truth is left untouched); 71 is the
# kernel size passed to array_medfilt.
# NOTE: rebinds pred in place — re-running this cell alone filters twice.
pred = array_medfilt(
    pred,
    kernel_size=71,
)

Zero offset

In [None]:
# Zero-offset correction of the predictions.
# NOTE(review): 50 is forwarded positionally to array_bais — presumably the number
# of leading samples used to estimate the offset; TODO confirm.
# Rebinds pred in place, so re-running this cell alone applies it twice.
pred = array_bais(
    pred,
    50,
)

Add norm

In [None]:
# Apply add_norm to both arrays (ground truth first, then prediction).
# NOTE: re-running this cell alone applies add_norm a second time.
gt, pred = add_norm(gt), add_norm(pred)

Time plot

In [None]:
# Overlay ground truth and prediction for each axis in stacked subplots that share
# the x-axis (sample index, since only y values are passed).
n_rows = gt.shape[1]
# Rows are titled and filled from `axes`; fail loudly instead of mislabelling
# panels or leaving empty ones when the column count does not match.
assert n_rows == len(axes), f"Expected {len(axes)} columns {axes}, got {n_rows}."
fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, subplot_titles=axes)

for row, (name, gt_col, pred_col) in enumerate(zip(axes, gt.T, pred.T), start=1):
    fig.add_trace(go.Scatter(y=gt_col, mode="lines", name=f"{name} (gt)"), row=row, col=1)
    fig.add_trace(go.Scatter(y=pred_col, mode="lines", name=f"{name} (pred)"), row=row, col=1)

# Label the shared x-axis once, on the bottom subplot.
fig.update_xaxes(title_text="Sample index", row=n_rows, col=1)
fig.update_layout(height=250 * n_rows, title="Ground Truth vs Prediction", showlegend=True)
fig.show()

## Loop all sets

In [None]:
# Load every test sequence (one sub-directory each) and apply the same
# post-processing as the single-set section, storing (gt, pred) per group name.
medfilt_kernel_size = 71  # kernel size for array_medfilt on predictions
bias_param = 50           # forwarded to array_bais (zero-offset step); TODO document meaning


def _postprocess(gt, pred):
    """Return the post-processed ``(gt, pred)`` pair for one sequence.

    Both arrays are denormalised with ``golbal_scaler``; the prediction is then
    median-filtered and offset-corrected; finally ``add_norm`` is applied to both.
    """
    gt = golbal_scaler.inverse_transform(gt)
    pred = golbal_scaler.inverse_transform(pred)

    pred = array_medfilt(pred, kernel_size=medfilt_kernel_size)
    pred = array_bais(pred, bias_param)

    return add_norm(gt), add_norm(pred)


run_dirs = [p for p in sorted(save_path.iterdir()) if p.is_dir()]
data = {}
for path in tqdm(run_dirs, desc="Loading test sets"):
    group = path.stem
    gt, pred = _postprocess(*tb_to_numpy(path))
    data[group] = (gt, pred)

In [None]:
# Pool every sequence into an aggregate 'All' group. Any existing 'All' entry is
# skipped so re-running this cell does not concatenate the pooled data into itself.
per_group = [pair for key, pair in data.items() if key != 'All']
assert per_group, "No test sets loaded; run the 'Loop all sets' cell first."
gt_all = np.concatenate([gt_arr for gt_arr, _ in per_group], axis=0)
pred_all = np.concatenate([pred_arr for _, pred_arr in per_group], axis=0)
data['All'] = (gt_all, pred_all)

In [None]:
# Per-axis metrics for every group (including the pooled 'All' group), written to
# metrics.txt in save_path and echoed below. The file is rewritten on each run so
# re-executing this cell does not append duplicate blocks.
report_lines = []
for group, (gt, pred) in data.items():
    # NRMSE normalises RMSE by the ground-truth range (max - min) of each column.
    gt_range = np.ptp(gt, axis=0)
    rmse = root_mean_squared_error(gt, pred, multioutput='raw_values')
    # A constant ground-truth column has zero range: report NaN rather than inf.
    nrmse = np.divide(rmse, gt_range, out=np.full_like(rmse, np.nan), where=gt_range > 0)
    r2_scores = r2_score(gt, pred, multioutput='raw_values')

    report_lines.append(f"Group: {group}")
    for name, range_i, rmse_i, nrmse_i, r2_i in zip(axes, gt_range, rmse, nrmse, r2_scores):
        # Implicit string concatenation: the previous backslash-continued f-string
        # embedded the source indentation into every output line.
        report_lines.append(
            f"{name}: Range={range_i:.4f}, RMSE={rmse_i:.4f}, "
            f"NRMSE={nrmse_i*100:.2f}%, R2={r2_i*100:.2f}"
        )
    report_lines.append("-" * 10)

report = "\n".join(report_lines) + "\n"
(save_path / "metrics.txt").write_text(report)
print(report)