# Init

## Load libs

In [1]:
import numpy as np
import plotly.graph_objs as go
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
from itertools import chain
from datetime import datetime
from torch.utils.data import ConcatDataset, DataLoader, random_split

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    root_mean_squared_error,
    r2_score
)

import shaft_force_sensing.models
from shaft_force_sensing import ForceSensingDataset
from shaft_force_sensing.models import (
    LitSequenceModel,
    LitTransformer,
    LitLTC,
)
from shaft_force_sensing.evaluation import (
    tb_to_numpy,
    add_norm,
    array_bais,
    array_medfilt,
)

%load_ext autoreload
%autoreload 2

## Set hyperparameters

In [2]:
# Seed every RNG Lightning knows about before anything stochastic runs.
seed_everything(42)

# Model / training hyperparameters of the run being evaluated.
max_epochs, batch_size, learning_rate = 30, 256, 1e-4
hidden_size, num_layers, num_heads = 128, 3, 8

[rank: 0] Seed set to 42


In [8]:
# Training-run log directory to evaluate. Set the SFS_CHECKPOINT_DIR
# environment variable to point at another run without editing this cell.
import os

save_path = Path(os.environ.get(
    "SFS_CHECKPOINT_DIR",
    "/scratch/pioneer/users/sxk2514/shaft_force_sensing/logs/20260216_225606",
))

# The run directory holds a sub-directory named after the Lightning model class
# (e.g. <save_path>/LitLTC/version_0/hparams.yaml), so the class name is
# recovered by matching sub-directory names against the package's exports.
# iterdir() order is filesystem dependent, so sort to make the pick stable.
model_matches = sorted(
    entry.name
    for entry in save_path.iterdir()
    if entry.is_dir() and entry.name in shaft_force_sensing.models.__all__
)
assert model_matches, "Model name not found in checkpoint directory."

model_cls = model_matches[0]
print("Found model:", model_cls)
if len(model_matches) > 1:
    print("Warning: several model directories found, using the first:", model_matches)

Found model: LitLTC


In [12]:
# Determine input columns from the checkpoint's hparams (ablation aware).
# `Path` comes from the imports cell; only this cell's extra dependencies are
# imported here.
import yaml
from shaft_force_sensing.training.utils import get_input_cols_for_config

# Default target columns
t_cols = ['ati_fx', 'ati_fy', 'ati_fz']

# Full input column set, used when no hparams.yaml can be found.
FULL_INPUT_COLS = [
    'jaw_position', 'wrist_pitch_position', 'wrist_yaw_position', 'roll_position',
    'wrist_pitch_velocity', 'wrist_yaw_velocity', 'jaw_velocity', 'roll_velocity',
    'wrist_pitch_effort', 'wrist_yaw_effort', 'roll_effort',
    'jaw_effort', 'insertion_effort', 'yaw_effort', 'pitch_effort',
    'tx', 'ty', 'tz', 'fx', 'fy', 'fz'
]


def find_hparams_file(ckpt_dir: Path) -> Path | None:
    """Return the hparams.yaml describing the run in `ckpt_dir`, or None.

    First looks for `version_*/hparams.yaml` directly below `ckpt_dir`, then
    falls back to a recursive search. Both candidate lists are sorted so the
    chosen file does not depend on filesystem iteration order.
    """
    if not ckpt_dir.exists():
        return None
    for version_dir in sorted(ckpt_dir.glob('version_*')):
        candidate = version_dir / 'hparams.yaml'
        if candidate.exists():
            return candidate
    found = sorted(ckpt_dir.rglob('hparams.yaml'))
    return found[0] if found else None


# `save_path` is the checkpoint directory selected in the previous cell.
model_checkpoint_dir = save_path if 'save_path' in globals() else Path('../logs')
print('Inspecting checkpoint directory for hparams:', model_checkpoint_dir)

hparam_file = find_hparams_file(model_checkpoint_dir)

if hparam_file is None:
    print('No hparams.yaml found under', model_checkpoint_dir, '\nUsing full input column set by default')
    config_name = 'Full'
    i_cols = list(FULL_INPUT_COLS)
else:
    print('Found hparams:', hparam_file)
    with open(hparam_file, 'r') as f:
        # safe_load returns None for an empty file; treat that as "no keys".
        h = yaml.safe_load(f) or {}
    # The ablation name has been stored under several keys; take the first set.
    config_name = (
        h.get('ablation_config')
        or h.get('ablation')
        or h.get('input_config')
        or h.get('config')
        or 'Full'
    )
    print('Detected ablation config:', config_name)
    i_cols = get_input_cols_for_config(config_name)

print(f"Using {len(i_cols)} input columns (ablation={config_name})")


Inspecting checkpoint directory for hparams: /scratch/pioneer/users/sxk2514/shaft_force_sensing/logs/20260216_225606
Found hparams: /scratch/pioneer/users/sxk2514/shaft_force_sensing/logs/20260216_225606/LitLTC/version_0/hparams.yaml
Detected ablation config: No_Hex10
Using 15 input columns (ablation=No_Hex10)


## Load data and preprocess

In [13]:
# Group the CSV recordings by their parent directory name. The last file of
# each group (in sorted path order) is held out for testing; all other files
# are training candidates.
data_paths = sorted(Path("../data").rglob("*.csv"))
assert data_paths, "No CSV files found under ../data"

# NOTE(review): directories with the same name at different depths end up in
# the same group because only `parent.name` is used as the key.
groups = defaultdict(list)
for p in data_paths:
    groups[p.parent.name].append(p)

test_paths = [lst[-1] for lst in groups.values()]
test_path_set = set(test_paths)
train_paths = [p for p in data_paths if p not in test_path_set]

# Training files excluded by hand, given as positions in `train_paths`.
# NOTE(review): which recordings these are depends on the contents of ../data;
# excluding by file name would be more robust.
EXCLUDED_TRAIN_IDX = (2, 3)
for idx in sorted(EXCLUDED_TRAIN_IDX, reverse=True):
    # Remove the highest position first so earlier removals do not shift it.
    train_paths.pop(idx)

# Inference

Load the trained model from its best checkpoint.

In [14]:
# Resolve the model class from the package instead of eval()-ing a name that
# was read from the filesystem. The discovery cell already checked that
# `model_cls` is listed in `shaft_force_sensing.models.__all__`.
model_class = getattr(shaft_force_sensing.models, model_cls)

# Use the lexicographically last "best*.ckpt" file.
# NOTE(review): lexicographic order matches training order only if the file
# names embed comparable (e.g. zero-padded) numbers — confirm naming scheme.
best_ckpts = sorted(save_path.glob("best*.ckpt"))
assert best_ckpts, f"No best*.ckpt checkpoint found in {save_path}"

model: LitSequenceModel = model_class.load_from_checkpoint(best_ckpts[-1])

Build the test sets, normalized with the mean/std stored on the loaded model.

In [15]:
# Rebuild a StandardScaler from the mean/std stored on the loaded model rather
# than fitting a new one. (The `golbal` spelling is kept: later cells use it.)
golbal_scaler = StandardScaler()
golbal_scaler.mean_, golbal_scaler.scale_ = (
    stat.numpy(force=True) for stat in (model.data_mean, model.data_std)
)

In [16]:
# One test dataset per held-out file, keyed by its group (parent directory
# name) and built with the checkpoint-derived scaler as its normalizer.
test_sets = {
    p.parent.name: ForceSensingDataset(p, i_cols, t_cols, nomalizer=golbal_scaler)
    for p in tqdm(test_paths)
}

100%|██████████| 3/3 [00:01<00:00,  2.24it/s]


In [17]:
# Unshuffled loaders so batches come out in dataset order.
test_loaders = {}
for group_name, group_dataset in test_sets.items():
    test_loaders[group_name] = DataLoader(group_dataset, batch_size=1000, shuffle=False)

Inference

In [18]:
# Run the test loop separately for each held-out group. Each group logs to its
# own TensorBoard directory, <save_path>/test/<group>, and the metrics returned
# by `Trainer.test` are kept in `test_results` instead of being discarded.
test_results = {}
for group, loader in test_loaders.items():
    logger = TensorBoardLogger(
        save_path,
        name='test',
        version=group
    )

    # Re-running this cell adds another event file to the same directory
    # rather than replacing the old one; warn so stale logs are noticed.
    existing_events = sorted(Path(logger.log_dir).glob('events.out.tfevents.*'))
    if existing_events:
        print(f"Warning: {logger.log_dir} already contains "
              f"{len(existing_events)} TensorBoard event file(s) from a previous run.")

    trainer = Trainer(logger=logger)
    # `Trainer.test` returns a list with one dict of logged metrics per dataloader.
    test_results[group] = trainer.test(
        model=model,
        dataloaders=loader
    )

/home/sxk2514/.conda/envs/ltc311/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sxk2514/.conda/envs/ltc311/lib/python3.11/site ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/sxk2514/.conda/envs/ltc311/lib/python3.11/site-packages/pytorch_lightning/utiliti

Testing DataLoader 0: 100%|██████████| 60/60 [04:01<00:00,  0.25it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 76/76 [05:31<00:00,  0.23it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 127/127 [08:47<00:00,  0.24it/s]
