# Load libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from pathlib import Path
from collections import defaultdict
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from itertools import chain
from torch.utils.data import ConcatDataset, DataLoader, random_split
from datetime import datetime

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from shaft_force_sensing import ForceSensingDataset
from shaft_force_sensing.models import LitTransformer

%load_ext autoreload
%autoreload 2

## Set hyperparameters

In [None]:
# Fix the global random seed first (Lightning's `seed_everything`) so that a
# fresh "Restart & Run All" is repeatable.
seed_everything(42)

# --- Optimisation settings ---
learning_rate = 1e-4   # passed to the model as `lr`
batch_size = 256       # mini-batch size used by both data loaders
max_epochs = 30        # upper bound on training epochs given to the Trainer

# --- Transformer architecture ---
hidden_size = 128      # passed to the model as `d_model`
num_heads = 8          # passed to the model as `nhead`
num_layers = 3

In [None]:
# Input feature columns, assembled from per-signal sub-lists. The element order
# of `i_cols` is kept exactly as it was when written as one flat list.
_position_cols = [
    'jaw_position', 'wrist_pitch_position', 'wrist_yaw_position', 'roll_position',
]
_velocity_cols = [
    'wrist_pitch_velocity', 'wrist_yaw_velocity', 'jaw_velocity', 'roll_velocity',
]
_effort_cols = [
    'wrist_pitch_effort', 'wrist_yaw_effort', 'roll_effort',
    'jaw_effort', 'insertion_effort', 'yaw_effort', 'pitch_effort',
]
# NOTE(review): presumably torque (t*) and force (f*) channels — confirm with the data source.
_txyz_fxyz_cols = ['tx', 'ty', 'tz', 'fx', 'fy', 'fz']

i_cols = [*_position_cols, *_velocity_cols, *_effort_cols, *_txyz_fxyz_cols]

# Target columns the model is trained to predict.
t_cols = ['ati_fx', 'ati_fy', 'ati_fz']

# Load data and preprocess

In [None]:
# Collect every CSV recording below ../data; sorting keeps the positions used
# further down stable between runs.
data_paths = sorted(Path("../data").rglob("*.csv"))

# Bucket the recordings by the name of the directory that contains them.
groups = defaultdict(list)
for p in data_paths:
    groups[p.parent.name].append(p)

# The last recording (in sorted order) of each directory is held out for testing;
# everything else is a training candidate.
test_paths = [members[-1] for members in groups.values()]
_held_out = set(test_paths)
train_paths = [path for path in data_paths if path not in _held_out]

# Discard the training candidates at positions 3 and 2 (higher index first so the
# second deletion is not shifted by the first).
# TODO(review): these positions depend on the exact contents of ../data — record
# which recordings are being excluded and why.
del train_paths[3]
del train_paths[2]

Normalize the target forces using a global scaler fitted on all training data

In [None]:
# Gather the target forces of every training recording into one array and fit a
# single scaler on it, so all datasets share the same normalisation.
# NOTE(review): the targets are read as the last three CSV columns (header row
# skipped); this is assumed to correspond to `t_cols` — confirm against the files.
forces = np.concatenate(
    [
        np.loadtxt(p, delimiter=",", skiprows=1)[:, -3:]
        for p in tqdm(train_paths)
    ],
    axis=0,
)
golbal_scaler = StandardScaler().fit(forces)

Training set construction

In [None]:
# One ForceSensingDataset per training recording, grouped by parent directory.
train_sets = defaultdict(list)
for p in tqdm(train_paths):
    group = p.parent.name
    # Recordings in the 'Free' directory use a 4x larger stride (20 instead of 5).
    stride = 5 * 4 if group == 'Free' else 5
    dataset = ForceSensingDataset(
        p, i_cols, t_cols, stride,
        nomalizer=golbal_scaler,
    )
    train_sets[group].append(dataset)

# Flatten the per-directory lists into a single dataset.
train_set = ConcatDataset(
    [ds for dsets in train_sets.values() for ds in dsets])

Test set construction

In [None]:
# One ForceSensingDataset per held-out recording, grouped by parent directory.
# Every test recording uses a stride of 1.
stride = 1
test_sets = defaultdict(list)
for p in tqdm(test_paths):
    dataset = ForceSensingDataset(
        p, i_cols, t_cols, stride,
        nomalizer=golbal_scaler,
    )
    test_sets[p.parent.name].append(dataset)

# Flatten the per-directory lists into a single dataset.
test_set = ConcatDataset(
    [ds for dsets in test_sets.values() for ds in dsets])

Ratio check

In [None]:
# Disabled sanity check (kept for reference): reports how many test samples each
# group contributes and its percentage of the total.
# NOTE: if re-enabled, the first loop rebinds every `test_sets[group]` from a list
# of datasets to a single ConcatDataset, so it must only run after `test_set` has
# been built from those lists.
# for group, dsets in test_sets.items():
#     test_sets[group] = ConcatDataset(dsets)

# total_samples = sum(len(dsets) for dsets in test_sets.values())
# for group, dsets in test_sets.items():
#     print(f"{group}: {len(dsets)} samples, {len(dsets)/total_samples*100:.2f}%")
# print(f"Total: {total_samples} samples")

Validation set split

In [None]:
# Hold out 10% of the training samples for validation.
#
# The split always starts from the complete training set, rebuilt here from
# `train_sets`, and never from the current `train_set`. Splitting `train_set`
# in place would make this cell non-idempotent: each re-run would split the
# already reduced subset again and silently shrink the training data.
full_train_set = ConcatDataset(
    list(chain.from_iterable(train_sets.values())))

train_size = int(0.9 * len(full_train_set))
val_size = len(full_train_set) - train_size  # remainder, so the two sizes always sum up

# NOTE: `random_split` draws from the global RNG seeded at the top of the
# notebook, so re-running only this cell yields a different (same-sized) partition.
# NOTE(review): the split is per sample, so training and validation samples come
# from the same recordings; if neighbouring samples overlap in time, `val_loss`
# may be optimistic — confirm.
train_set, val_set = random_split(full_train_set, [train_size, val_size])

In [None]:
# Number of samples in the train / validation / test splits, in that order.
tuple(len(split) for split in (train_set, val_set, test_set))

Set up dataloaders

In [None]:
# Mini-batch loaders: training batches are reshuffled every epoch, validation
# batches are served in a fixed order.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_set,
    batch_size=batch_size,
    shuffle=False,
)

# Training

In [None]:
# Collect the network configuration in one place, then instantiate the model.
model_kwargs = dict(
    # Input/output widths follow the column lists defined earlier.
    input_size=len(i_cols),
    force_output_size=len(t_cols),
    # Architecture hyperparameters.
    d_model=hidden_size,
    nhead=num_heads,
    num_layers=num_layers,
    # Optimiser learning rate.
    lr=learning_rate,
    # Mean / scale of the fitted target scaler.
    # NOTE(review): presumably used by the model to map predictions back to
    # physical units — confirm in LitTransformer.
    data_mean=golbal_scaler.mean_,
    data_std=golbal_scaler.scale_,
)
model = LitTransformer(**model_kwargs)

In [None]:
# Every execution of this cell gets its own timestamped directory under ../logs.
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = Path("../logs") / run_stamp
save_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Early stopping: end training once `val_loss` has not improved for 3
# consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=True,
    mode="min"
)

# Checkpointing: keep only the single best model (lowest `val_loss`) in `save_dir`.
# The filename template already spells out the metric names, so Lightning's
# automatic "name=" prefixing is turned off; with the default the files would be
# called "best-epoch-epoch=03-val_loss-val_loss=0.1234.ckpt".
checkpoint_callback = ModelCheckpoint(
    dirpath=save_dir,
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    verbose=True,
    filename="best-epoch-{epoch:02d}-val_loss-{val_loss:.4f}",
    auto_insert_metric_name=False
)

# TensorBoard logger writing below `save_dir`.
# NOTE(review): this logger is imported from `lightning.pytorch` while the
# Trainer and callbacks come from `pytorch_lightning`; confirm that mixing the
# two packages is intended.
logger = TensorBoardLogger(
    save_dir,
    name="transformer_logs")

In [None]:
# Trainer wiring: automatic device selection, bounded number of epochs, and the
# logger and callbacks configured above.
trainer = Trainer(
    accelerator="auto",
    max_epochs=max_epochs,
    log_every_n_steps=10,  # write training metrics every 10 optimisation steps
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=logger,
)

In [None]:
# Run the training loop: fit `model` on `train_loader` and validate on
# `val_loader`. Early stopping and checkpointing are handled by the callbacks
# registered on `trainer`.
trainer.fit(model, train_loader, val_loader)

# Inference