# Init

## Load libs

In [None]:
import numpy as np
import plotly.graph_objs as go
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
from itertools import chain
from datetime import datetime
from torch.utils.data import ConcatDataset, DataLoader, random_split

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    root_mean_squared_error,
    r2_score
)

from shaft_force_sensing import ForceSensingDataset
from shaft_force_sensing.models import LitTransformer
from shaft_force_sensing.evaluation import (
    tb_to_numpy,
    add_norm,
    array_bais,
    array_medfilt,
)

%load_ext autoreload
%autoreload 2

## Set hyperparameters

In [None]:
# Seed the Python, NumPy and torch RNGs (via Lightning) for reproducibility.
seed_everything(42)

# Optimisation settings.
# NOTE(review): in the visible cells none of these values are passed to
# LitTransformer.load_from_checkpoint, so they presumably do not affect the
# restored model — confirm they are kept only for reference.
max_epochs, batch_size, learning_rate = 30, 256, 1e-4

# Transformer architecture settings.
hidden_size, num_layers, num_heads = 128, 3, 8

In [None]:
# Model input columns, concatenated group by group (order matters: it fixes
# the feature layout the model sees).
i_cols = (
    # *_position columns
    ['jaw_position', 'wrist_pitch_position', 'wrist_yaw_position', 'roll_position']
    # *_velocity columns
    + ['wrist_pitch_velocity', 'wrist_yaw_velocity', 'jaw_velocity', 'roll_velocity']
    # *_effort columns
    + ['wrist_pitch_effort', 'wrist_yaw_effort', 'roll_effort',
       'jaw_effort', 'insertion_effort', 'yaw_effort', 'pitch_effort']
    # tx/ty/tz and fx/fy/fz columns (presumably torques and forces — confirm)
    + ['tx', 'ty', 'tz', 'fx', 'fy', 'fz']
)

# Target columns (presumably ATI force-sensor readings used as ground truth).
t_cols = ['ati_fx', 'ati_fy', 'ati_fz']

## Locate data files and split into train/test sets

In [None]:
# Every CSV file below ../data, in sorted path order.
data_paths = sorted(Path("../data").rglob("*.csv"))

# Bucket the files by the name of the directory that contains them.
groups = defaultdict(list)
for p in data_paths:
    groups[p.parent.name].append(p)

# Hold out the last (sorted) file of each group as that group's test file;
# everything else is a training candidate.
test_paths = [files[-1] for files in groups.values()]
held_out = set(test_paths)
train_paths = [p for p in data_paths if p not in held_out]

# Drop the training files at positions 3 and then 2 (i.e. the original
# indices 2 and 3). Raises IndexError if fewer than four remain.
# NOTE(review): the reason for excluding these files is not visible here, and
# position-based removal silently targets different files if the data
# directory changes — consider excluding them by name instead.
del train_paths[3], train_paths[2]

# Inference

Load the trained model from the best checkpoint of a previous training run.

In [None]:
# Training run whose best checkpoint is evaluated in this notebook.
ckpt_dir = Path("../logs/20260212_150853")

# Resolve the checkpoint *before* creating any output directory, so a missing
# checkpoint fails loudly without leaving an empty log folder behind.
best_ckpts = sorted(ckpt_dir.glob("best*.ckpt"))
if not best_ckpts:
    raise FileNotFoundError(f"No 'best*.ckpt' checkpoint found in {ckpt_dir}")
# NOTE(review): plain lexicographic sort — if the filenames embed unpadded
# numbers (e.g. epoch=9 vs epoch=10), [-1] may not be the intended checkpoint.
# Confirm the naming scheme.
ckpt_path = best_ckpts[-1]

# Fresh, timestamped directory for this inference run's logs.
save_dir = Path("../logs") / datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Restore the trained LitTransformer from the selected checkpoint file.
# (Architecture hyperparameters presumably come from the checkpoint itself —
# confirm LitTransformer calls save_hyperparameters().)
model = LitTransformer.load_from_checkpoint(ckpt_path)

Build one test set per data group, normalized with the statistics stored in the model.

In [None]:
# A StandardScaler that is never fit(): its mean_ / scale_ are copied from the
# normalisation statistics stored on the trained network (presumably the
# training-set statistics). `.numpy(force=True)` detaches and moves the
# tensors to CPU before conversion.
# (Variable name kept as-is because later cells may refer to it.)
golbal_scaler = StandardScaler()
golbal_scaler.mean_, golbal_scaler.scale_ = (
    model.model.data_mean.numpy(force=True),
    model.model.data_std.numpy(force=True),
)

In [None]:
# One dataset per held-out test file, keyed by the file's parent-directory
# (group) name. `nomalizer` (sic) is spelled the way ForceSensingDataset
# expects it.
test_sets = {
    p.parent.name: ForceSensingDataset(p, i_cols, t_cols, nomalizer=golbal_scaler)
    for p in tqdm(test_paths)
}

In [None]:
# Wrap each group's dataset in a DataLoader. shuffle=False keeps batches in
# dataset-index order, so outputs stay aligned with the dataset.
test_loaders = {}
for group_name, group_set in test_sets.items():
    test_loaders[group_name] = DataLoader(group_set, batch_size=1000, shuffle=False)

Run the test loop on each group and log the results to TensorBoard.

In [None]:
# Evaluate the model once per test group. Each group logs to its own
# TensorBoard version directory (<save_dir>/transformer_test/<group>), and the
# metrics returned by Trainer.test are kept in `test_results` instead of being
# discarded, so they can be inspected without going through TensorBoard.
# NOTE(review): TensorBoardLogger is imported from `lightning.pytorch` while
# Trainer comes from `pytorch_lightning`. Mixing the two packages can break
# Lightning's internal isinstance checks (e.g. Trainer.log_dir resolution) —
# confirm both should come from the same package.
test_results = {}
for group, loader in test_loaders.items():
    logger = TensorBoardLogger(save_dir, name="transformer_test", version=group)
    trainer = Trainer(logger=logger)
    # Trainer.test returns one metrics dict per dataloader (here a single one).
    test_results[group] = trainer.test(model=model, dataloaders=loader)