# Init

## Load libs

In [None]:
import numpy as np
import plotly.graph_objs as go
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
from itertools import chain
from datetime import datetime
from torch.utils.data import ConcatDataset, DataLoader, random_split

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    root_mean_squared_error,
    r2_score
)

from shaft_force_sensing import ForceSensingDataset
from shaft_force_sensing.models import LitTransformer
from shaft_force_sensing.evaluation import (
    tb_to_numpy,
    add_norm,
    array_bais,
    array_medfilt,
)

%load_ext autoreload
%autoreload 2

## Set hyperparameters

In [None]:
# Fix the global RNG seed first (Lightning's seed_everything) so that the
# cells below behave the same on a fresh "Restart & Run All".
seed_everything(42)

# --- Optimisation settings ---
batch_size = 256
learning_rate = 1e-4
max_epochs = 30

# --- Transformer architecture ---
hidden_size = 128  # forwarded to the model as d_model
num_heads = 8      # forwarded to the model as nhead
num_layers = 3

In [None]:
# Model input channels, grouped by kind. The concatenation order below is the
# order in which the columns are handed to the dataset, so keep it stable.
position_cols = [
    'jaw_position', 'wrist_pitch_position', 'wrist_yaw_position', 'roll_position',
]
velocity_cols = [
    'wrist_pitch_velocity', 'wrist_yaw_velocity', 'jaw_velocity', 'roll_velocity',
]
effort_cols = [
    'wrist_pitch_effort', 'wrist_yaw_effort', 'roll_effort',
    'jaw_effort', 'insertion_effort', 'yaw_effort', 'pitch_effort',
]
# t*/f* channels -- presumably a torque/force reading; confirm against the data source.
tf_cols = ['tx', 'ty', 'tz', 'fx', 'fy', 'fz']

i_cols = [*position_cols, *velocity_cols, *effort_cols, *tf_cols]

# Regression targets: the three 'ati_*' force columns.
t_cols = ['ati_fx', 'ati_fy', 'ati_fz']

## Load data and preprocess

In [None]:
# Collect every CSV recording below ../data; sorting keeps the order stable.
data_paths = sorted(Path("../data").rglob("*.csv"))

# Bucket the recordings by the name of the directory that contains them.
groups = defaultdict(list)
for csv_path in data_paths:
    groups[csv_path.parent.name].append(csv_path)

# Hold out the last recording (in sorted order) of every directory for testing;
# everything else is used for training.
test_paths = [members[-1] for members in groups.values()]
train_paths = list(filter(lambda candidate: candidate not in test_paths, data_paths))

# NOTE(review): the recordings at positions 3 and 2 of the sorted training list
# are discarded by index. Which files these are depends on the directory
# contents -- confirm the intent and consider excluding them by name.
for drop_idx in (3, 2):
    train_paths.pop(drop_idx)

Normalize the target forces using a global scaler fitted on all training data.

In [None]:
# Fit one StandardScaler on the targets of all training recordings.
golbal_scaler = StandardScaler()

# The last three CSV columns are taken as the targets.
# NOTE(review): this assumes they coincide with `t_cols` -- confirm the file layout.
forces = np.concatenate(
    [
        np.loadtxt(csv_path, delimiter=",", skiprows=1)[:, -3:]
        for csv_path in tqdm(train_paths)
    ],
    axis=0,
)
golbal_scaler.fit(forces);

# Training

Training set construction

In [None]:
# One dataset per training recording, bucketed by the recording's directory name.
train_sets = defaultdict(list)
for csv_path in tqdm(train_paths):
    group_name = csv_path.parent.name
    # Recordings in the 'Free' directory get a 4x larger stride (20 instead of 5).
    stride = 5 * 4 if group_name == 'Free' else 5
    train_sets[group_name].append(
        ForceSensingDataset(
            csv_path, i_cols, t_cols,
            stride, nomalizer=golbal_scaler))

# Flatten all groups into a single dataset for training.
per_file_sets = list(chain.from_iterable(train_sets.values()))
train_set = ConcatDataset(per_file_sets)

Ratio check

In [None]:
# NOTE(review): disabled snippet kept for reference. It reads `test_sets`,
# which is only created in the Inference section further down, and it
# overwrites each entry with a ConcatDataset -- presumably `train_sets` was
# intended here. Confirm before re-enabling.
# for group, dsets in test_sets.items():
#     test_sets[group] = ConcatDataset(dsets)

# total_samples = sum(len(dsets) for dsets in test_sets.values())
# for group, dsets in test_sets.items():
#     print(f"{group}: {len(dsets)} samples, {len(dsets)/total_samples*100:.2f}%")
# print(f"Total: {total_samples} samples")

Validation set split

In [None]:
# Rebuild the complete training pool from `train_sets` before splitting.
# Splitting `train_set` and rebinding the same name to its 90% subset would
# make this cell non-idempotent: every re-execution would carve another 10%
# off the training data.
full_train_set = ConcatDataset(
    list(chain.from_iterable(train_sets.values())))

# 90% / 10% random train/validation split. The permutation is drawn from the
# global torch RNG, which is seeded by seed_everything() at the top.
train_size = int(0.9 * len(full_train_set))
val_size = len(full_train_set) - train_size
train_set, val_set = random_split(full_train_set, [train_size, val_size])

In [None]:
# Display (n_train, n_val) as a quick sanity check of the split.
tuple(len(split) for split in (train_set, val_set))

Set up dataloaders

In [None]:
# Both loaders share the batch size; only the training loader is shuffled.
loader_kwargs = dict(batch_size=batch_size)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)

Set up model

In [None]:
# Collect the constructor arguments first so they can be inspected in one place.
model_kwargs = dict(
    # Input/output widths follow the column lists.
    input_size=len(i_cols),
    force_output_size=len(t_cols),
    # Architecture and optimiser settings from the hyperparameter cell.
    d_model=hidden_size,
    num_layers=num_layers,
    nhead=num_heads,
    lr=learning_rate,
    # Target statistics of the fitted scaler are handed to the model; the
    # Inference section reads them back as model.model.data_mean / data_std.
    data_mean=golbal_scaler.mean_,
    data_std=golbal_scaler.scale_,
)
model = LitTransformer(**model_kwargs)

In [None]:
# Every training run logs into its own timestamped directory under ../logs.
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = Path("../logs", run_stamp)
save_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Early stopping: abort once "val_loss" has failed to improve for 3
# consecutive validation checks.
# NOTE(review): assumes LitTransformer logs a metric named "val_loss" -- confirm.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=True,
    mode="min"
)

# Checkpointing: keep only the single best model (lowest "val_loss") in save_dir.
# Lightning prefixes every templated metric with "<name>=" by default
# (auto_insert_metric_name=True), so the template must not spell the metric
# names out again; otherwise files are named
# "best-epoch-epoch=05-val_loss-val_loss=0.1234.ckpt".
# The result still matches the "best*.ckpt" pattern used when loading.
checkpoint_callback = ModelCheckpoint(
    dirpath=save_dir,
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    verbose=True,
    filename="best-{epoch:02d}-{val_loss:.4f}"
)

# TensorBoard logger writing below save_dir/transformer_train.
logger = TensorBoardLogger(
    save_dir,
    name="transformer_train")

In [None]:
# Trainer for the fitting run: capped at max_epochs, with early stopping and
# best-model checkpointing attached, logging every 10 steps.
training_callbacks = [early_stop_callback, checkpoint_callback]
trainer = Trainer(
    callbacks=training_callbacks,
    logger=logger,
    log_every_n_steps=10,
    max_epochs=max_epochs,
)

In [None]:
# Run the optimisation loop; validation runs on val_loader.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Inference

Load model from checkpoint

In [None]:
# Fresh, timestamped directory for the logs written by the inference cells.
save_dir = Path("../logs") / datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir.mkdir(parents=True, exist_ok=True)

# Training run whose best checkpoint is evaluated -- edit to point at another run.
ckpt_dir = Path("../logs/20260212_150853")
ckpt_candidates = sorted(ckpt_dir.glob("best*.ckpt"))
if not ckpt_candidates:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError(
        f"No 'best*.ckpt' checkpoint found in {ckpt_dir.resolve()}")
# Use the lexicographically last match.
ckpt_path = ckpt_candidates[-1]

In [None]:
# Restore the network from the selected checkpoint. This rebinds `model`, so
# the cells below use the restored weights rather than the in-memory
# training state.
model = LitTransformer.load_from_checkpoint(ckpt_path)

Test set construction

In [None]:
# Recreate the target scaler from the statistics stored in the checkpoint, so
# inference does not depend on the scaler fitted in the training section.
golbal_scaler = StandardScaler()
stats_holder = model.model
# numpy(force=True) is used to convert the stored tensors to NumPy arrays.
golbal_scaler.mean_ = stats_holder.data_mean.numpy(force=True)
golbal_scaler.scale_ = stats_holder.data_std.numpy(force=True)

In [None]:
# One dataset per held-out recording, keyed by the recording's directory name.
# No stride argument is passed here, so the dataset's default is used.
test_sets = {
    csv_path.parent.name: ForceSensingDataset(
        csv_path, i_cols, t_cols,
        nomalizer=golbal_scaler)
    for csv_path in tqdm(test_paths)
}

In [None]:
# Unshuffled loaders preserve sample order within each recording.
test_loaders = dict()
for group, dset in test_sets.items():
    test_loaders[group] = DataLoader(dset, batch_size=1000, shuffle=False)

Inference

In [None]:
# Evaluate each held-out group with its own Trainer and TensorBoard logger.
# Each group's log goes to save_dir/transformer_test/<group>.
for group, loader in test_loaders.items():
    logger = TensorBoardLogger(save_dir, name="transformer_test", version=group)
    tester = Trainer(logger=logger)
    tester.test(model=model, dataloaders=loader)

# Evaluation

In [None]:
# Directory holding the per-group test logs to evaluate -- edit the run
# timestamp to point at a different inference run.
save_dir = Path("../logs", "20260212_155802", "transformer_test")

In [None]:
# Display names of the evaluated channels: three force components plus the
# norm column appended by add_norm.
axes = [
    'F_x',
    'F_y',
    'F_z',
    'Norm',
]

## Single set

In [None]:
# Index of the group to inspect, counted in sorted directory-name order.
idx = 1

# Consider only the per-group log directories, and sort them. A plain glob("*")
# yields entries in arbitrary order and would also return regular files such
# as the metrics.txt that the metrics cell writes into this same directory.
group_dirs = sorted(entry for entry in save_dir.iterdir() if entry.is_dir())
path = group_dirs[idx]

# Ground truth and prediction arrays as read back from the TensorBoard log.
gt, pred = tb_to_numpy(path)
path.stem

Denormalization

In [None]:
# Undo the target normalisation on both arrays.
# NOTE: this overwrites gt/pred in place, so running the cell twice
# de-normalises twice -- re-run the loading cell first.
gt, pred = map(golbal_scaler.inverse_transform, (gt, pred))

Smooth

In [None]:
# Smooth the prediction with the project's median-filter helper, window of 71.
# NOTE(review): presumably applied per column along the sample axis -- confirm
# in array_medfilt.
pred = array_medfilt(pred, kernel_size=71)

Zero offset

In [None]:
# Remove the offset from the prediction via the project's array_bais helper.
# NOTE(review): the argument 50 is presumably the number of samples used to
# estimate the offset -- confirm in array_bais.
pred = array_bais(pred, 50)

Add norm

In [None]:
# Extend both arrays with the norm channel (the 4th entry of `axes`).
# NOTE: overwrites gt/pred, so running the cell twice adds the column twice.
gt, pred = add_norm(gt), add_norm(pred)

Time plot

In [None]:
# One stacked subplot per channel, all sharing the x axis.
n_rows = gt.shape[1]
fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, subplot_titles=axes)

for channel, name in enumerate(axes):
    subplot_row = channel + 1  # plotly rows are 1-based
    gt_trace = go.Scatter(y=gt[:, channel], mode="lines", name=f"{name} (gt)")
    pred_trace = go.Scatter(y=pred[:, channel], mode="lines", name=f"{name} (pred)")
    fig.add_trace(gt_trace, row=subplot_row, col=1)
    fig.add_trace(pred_trace, row=subplot_row, col=1)

fig.update_layout(
    height=250 * n_rows,
    title="Ground Truth vs Prediction",
    showlegend=True,
)
fig.show()

## Loop all sets

In [None]:
# Post-processed (ground truth, prediction) pair for every test group.
data = dict()

# iterdir() yields entries in arbitrary order; sorting makes the group order
# (and therefore the order of the metrics report) reproducible. Regular files
# in the directory are skipped.
group_dirs = sorted(entry for entry in save_dir.iterdir() if entry.is_dir())

for path in tqdm(group_dirs):
    group = path.stem

    # Load the arrays logged during testing.
    gt, pred = tb_to_numpy(path)

    # Back to physical units.
    gt = golbal_scaler.inverse_transform(gt)
    pred = golbal_scaler.inverse_transform(pred)

    # Smooth and offset-correct the prediction only; ground truth is untouched.
    pred = array_medfilt(pred, kernel_size=71)
    pred = array_bais(pred, 50)

    # Append the norm channel so that the columns line up with `axes`.
    gt = add_norm(gt)
    pred = add_norm(pred)

    data[group] = (gt, pred)

In [None]:
# Pool every group into one extra "All" entry.
# Skip an existing 'All' entry so that re-running this cell does not fold the
# previous aggregate back in and double-count every sample.
group_names = [group for group in data if group != 'All']
gt_all = np.concatenate([data[group][0] for group in group_names], axis=0)
pred_all = np.concatenate([data[group][1] for group in group_names], axis=0)
data['All'] = (gt_all, pred_all)

In [None]:
# Append a per-group metrics report to save_dir/metrics.txt.
# The file is opened in append mode, so re-running this cell adds another
# copy of the report rather than replacing the previous one.
with open(save_dir / "metrics.txt", "a") as metrics_file:
    for group, (gt, pred) in data.items():
        # Per-channel metrics; NRMSE is the RMSE normalised by the
        # ground-truth range (max - min) of that channel.
        gt_range = np.max(gt, axis=0) - np.min(gt, axis=0)
        rmse = root_mean_squared_error(gt, pred, multioutput='raw_values')
        nrmse = rmse / gt_range
        r2_scores = r2_score(gt, pred, multioutput='raw_values')

        print(f"Group: {group}", file=metrics_file)
        for i, name in enumerate(axes):
            # Adjacent f-strings keep the record on one line. Backslash
            # continuations inside a single string literal would embed the
            # source indentation between the fields.
            print(
                f"{name}: "
                f"Range={gt_range[i]:.4f}, "
                f"RMSE={rmse[i]:.4f}, "
                f"NRMSE={nrmse[i]*100:.2f}%, "
                f"R2={r2_scores[i]*100:.2f}",
                file=metrics_file)
        print("-" * 10, file=metrics_file)