In [None]:
import os
import pathlib
import sys
import torch
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import re
import pickle

# Co Contraction Experiments

## Import experimental data

In [None]:
folder_path = "/home/jan/dev/reflex-controller/data/co_contraction_experiment"
# Step size to load; the experiments provide several, e.g. 0.01, 0.1, 0.2, 0.5.
step_number = 0.1

# Expected file names: <name>_<step>.pkl, where <step> must contain a decimal point.
file_pattern = re.compile(r"^(.*)_([0-9]*\.[0-9]+)\.pkl$")

# Maps the <name> part of each matching file to its unpickled content.
loaded_files = {}

for entry in os.listdir(folder_path):
    parsed = file_pattern.match(entry)
    if parsed is None:
        continue  # not an experiment file
    if float(parsed.group(2)) != step_number:
        continue  # recorded with a different step size
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(folder_path, entry), "rb") as handle:
        loaded_files[parsed.group(1)] = pickle.load(handle)

In [None]:
# Show which fields are stored in the "FR_hip_joint" entry.
loaded_files["FR_hip_joint"].keys()

In [None]:
# Display index 0 along the second axis of the "torque" field of "FR_hip_joint".
# NOTE(review): this prints the whole array; consider slicing if it is large.
np.array(loaded_files["FR_hip_joint"]["torque"])[:, 0]

In [None]:
# 3D scatter of two raw-activation channels against a torque channel.
# NOTE(review): the activations come from "FL_hip_joint" while the torque comes
# from "FR_hip_joint" - confirm that mixing the two joints is intended.
fl_activations = np.array(loaded_files["FL_hip_joint"]["raw_activations"])  # converted once, sliced twice
x = fl_activations[:, 0, 0]
y = fl_activations[:, 0, 12]
z = np.array(loaded_files["FR_hip_joint"]["torque"])[:, 0, 0]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

ax.scatter(x, y, z, c="r", marker="o")

# Label the axes with the exact data selection so the figure can be read on its own.
ax.set_xlabel("FL_hip_joint raw_activations[:, 0, 0]")
ax.set_ylabel("FL_hip_joint raw_activations[:, 0, 12]")
ax.set_zlabel("FR_hip_joint torque[:, 0, 0]")
ax.set_title(f"Raw activations vs. torque (step {step_number})")

plt.show()

# Fvmax and Vmax optimization
## Import data

In [None]:
# Directory scanned by load_files_to_dict() for the fvmax/vmax tuning pickles.
# NOTE(review): absolute, machine-specific path; it also rebinds the `folder_path`
# name that the co-contraction loading cell uses.
folder_path = "/home/jan/dev/reflex-controller/data/vmax_fvmax_tuning"

def parse_filename(filename):
    """
    Parses filenames like <joint_name>_fvmax_<float>_vmax_<float>.pkl
    Returns (joint_name, fvmax, vmax) or None if not matching.

    `filename` may be a bare file name or a full path; only the final path
    component is inspected, so directory names (even ones containing
    "_fvmax_") never influence the result.

    Each number is an optionally signed integer or decimal ("2", "-2", "1.5",
    ".5"). Scientific notation is not recognised.
    """
    # The sign applies to both the decimal and the integer alternative.
    number = r"[-+]?(?:\d*\.\d+|\d+)"
    pattern = rf"^(.*?)_fvmax_({number})_vmax_({number})\.pkl$"
    match = re.match(pattern, os.path.basename(filename))
    if match is None:
        return None
    joint_name, fvmax, vmax = match.groups()
    return joint_name, float(fvmax), float(vmax)

def load_files_to_dict(folder_path):
    """
    Load every tuning pickle found directly inside `folder_path`.

    Only entries ending in ".pkl" whose path is accepted by parse_filename()
    are loaded; everything else is skipped silently.

    Returns a nested dictionary
        {joint_name: {(fvmax, vmax): unpickled_data}}
    Files are visited in sorted name order so that the key order of the
    result (which later cells map to subplot rows) is reproducible.
    """
    result = {}
    # os.listdir() returns entries in arbitrary order; sort for determinism.
    for fname in sorted(os.listdir(folder_path)):
        if not fname.endswith(".pkl"):
            continue

        full_path = os.path.join(folder_path, fname)
        parsed = parse_filename(full_path)
        if not parsed:
            continue

        joint_name, fvmax, vmax = parsed

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(full_path, 'rb') as f:
            data = pickle.load(f)

        result.setdefault(joint_name, {})[(fvmax, vmax)] = data

    return result

# Nested mapping {joint_name: {(fvmax, vmax): unpickled data}} for the tuning runs.
# NOTE(review): this rebinds the builtin name `dict` for the rest of the kernel
# session, so `dict(...)` can no longer be called; renaming requires updating
# every later cell that indexes `dict[...]`.
dict = load_files_to_dict(folder_path)

In [None]:
# Index of each joint along the last axis of the recorded "joint_position"
# arrays. The position in this tuple is the index.
joint_mapping = {
    name: index
    for index, name in enumerate((
        'FL_hip_joint', 'FR_hip_joint', 'RL_hip_joint', 'RR_hip_joint',
        'FL_thigh_joint', 'FR_thigh_joint', 'RL_thigh_joint', 'RR_thigh_joint',
        'FL_calf_joint', 'FR_calf_joint', 'RL_calf_joint', 'RR_calf_joint',
    ))
}

# Target value per joint: used as the `target` of the performance index and
# drawn as the dashed goal line in the plots. Only FL_* and RL_* joints have
# an entry.
goal_positions = {
    'FL_hip_joint': 0.60,
    'RL_hip_joint': 0.62,

    'FL_thigh_joint': 2.36,
    'RL_thigh_joint': 3.72,

    'FL_calf_joint': -1.16,
    'RL_calf_joint': -1.17,
}

def composite_performance_index(x, t, target, band=0.02, weights=(0.1,10,10,0.5), negative=False):
    """
    Score a response trace as a weighted sum of four penalties.

    Lower values are better (callers sort ascending and keep the first entries).

    Parameters
        x        : sampled response values (indexable, at least 2 samples).
        t        : time stamps matching `x`.
        target   : value the response should reach.
        band     : relative tolerance forwarded to true_settling_time().
        weights  : (settling time, overshoot, wigglyness, steady-state error).
        negative : if True, overshoot is measured at min(x) instead of max(x).

    Returns
        The weighted sum of the four terms.
    """
    w_settle, w_overshoot, w_wiggle, w_sse = weights

    # Settling time; fall back to the last time stamp if the lookup indexes
    # outside of `t`.
    try:
        settle_time = true_settling_time(x, t, target, band)
    except IndexError:
        settle_time = t[-1]

    # Distance between the extreme sample and the target.
    # NOTE(review): because of the absolute value, a trace that never reaches
    # the target is penalised here as well - confirm this is intended.
    overshoot = np.abs(target - np.min(x)) if negative else np.abs(np.max(x) - target)

    # Wigglyness: summed magnitude of the discrete second derivative, scaled
    # by the first sampling interval.
    second_diff = np.diff(x, n=2) / np.diff(t[:-1])**2
    wigglyness = np.abs(second_diff).sum() * (t[1] - t[0])

    # Steady-state error: deviation of the final sample from the target.
    steady_state_error = np.abs(x[-1] - target)

    return (
        w_settle*settle_time
        + w_overshoot*overshoot
        + w_wiggle*wigglyness
        + w_sse*steady_state_error
    )

def true_settling_time(x, t, target, band=0.02):
    """
    Return the first time from which `x` stays inside the tolerance band.

    A sample is inside the band when |x - target| < band * |target|, i.e. the
    band is relative to the target. With target == 0 the band has zero width,
    so the signal is reported as never settled.

    Parameters
        x      : 1-D sequence or array of response samples.
        t      : time stamps, indexable like `x`.
        target : value the response should settle at.
        band   : relative half-width of the tolerance band.

    Returns
        t[i] for the smallest i such that every sample from i onward lies
        inside the band; t[-1] if the final sample is outside the band (never
        settled) or `x` is empty.
    """
    err = np.abs(np.asarray(x) - target)
    band_val = band * np.abs(target)

    # Indices of samples outside the band. The comparison is negated (instead
    # of using ">=") so NaN samples count as outside.
    outside = np.flatnonzero(~(err < band_val))

    if outside.size == 0:
        # Settled from the very first sample; an empty trace has no sample to
        # report, so fall back to the final time stamp.
        return t[0] if err.size else t[-1]

    # The trace is settled right after the last out-of-band sample. This
    # single pass replaces re-checking the whole tail for every start index.
    first_settled = int(outside[-1]) + 1
    if first_settled >= err.size:
        return t[-1]  # Never settled
    return t[first_settled]


# get the keys with the best performance measures among thigh joints.
def get_top_n_keys(dict, n):
    """
    Pick the parameter keys that rank best for both thigh joints.

    For "FL_thigh_joint" and "RL_thigh_joint" every key in `dict[joint]` is
    scored with composite_performance_index() on samples 1500:2000 of its
    "joint_position" record (lower is better). The ranked lists are then
    intersected over a growing prefix until at least `n` keys are shared.

    Parameters
        dict : nested mapping {joint_name: {key: {"joint_position": ...}}}.
               (The name shadows the builtin; kept for caller compatibility.)
        n    : number of keys wanted.

    Returns
        List of shared keys, ordered by the sum of their per-joint ranks
        (best first). It can hold more than `n` keys when several keys enter
        the intersection in the same step, and fewer than `n` when the joints
        do not share enough keys.
    """
    ranked_keys = {}

    # create performance measures
    for joint_name in ["FL_thigh_joint", "RL_thigh_joint"]:
        scored = []
        for key, run in dict[joint_name].items():
            trajectory = np.array(run["joint_position"])[1500:2000, 0, joint_mapping[joint_name]]
            # Use the sample index as time axis; sizing it from the slice keeps
            # recordings shorter than 2000 samples from breaking the metric.
            metric = composite_performance_index(trajectory, np.arange(len(trajectory)), goal_positions[joint_name])
            scored.append((key, metric))

        scored.sort(key=lambda item: item[1])
        ranked_keys[joint_name] = [key for key, _ in scored]

    # search for intersecting keys, until the top n keys are chosen.
    # Stop once the prefixes cover the full rankings: beyond that point the
    # intersection cannot grow, and looping on would never terminate.
    max_depth = max(len(keys) for keys in ranked_keys.values())
    common = set()
    depth = 1
    while len(common) < n and depth <= max_depth:
        common = set.intersection(*(set(keys[:depth]) for keys in ranked_keys.values()))
        depth += 1

    # Deterministic, best-first ordering instead of arbitrary set order.
    rank_of = {
        joint_name: {key: position for position, key in enumerate(keys)}
        for joint_name, keys in ranked_keys.items()
    }
    return sorted(common, key=lambda k: sum(ranks[k] for ranks in rank_of.values()))

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import mplcursors

print(len(dict['FL_calf_joint'].keys()))

fig, ax = plt.subplots(6, 1, figsize=(10, 40))

lines = []

for i, joint_name in enumerate(dict.keys()):
    for key in dict[joint_name]:
        line, = ax[i].plot(np.array(dict[joint_name][((1.5, 4.5))]["joint_position"])[1500:2000, 0, joint_mapping[joint_name]], label=f"{key}")
        ax[i].set_title(joint_name)
        ax[i].grid(True)
        lines.append(line)
        ax[i].axhline(y=goal_positions[joint_name], color="r", linestyle='--', label=f'y={goal_positions[joint_name]}')

crs = mplcursors.cursor(lines, hover=True)
@crs.connect("add")
def on_add(sel):
    sel.annotation.set_text(sel.artist.get_label())

plt.show()

In [None]:
%matplotlib widget
fig, ax = plt.subplots(6, 1, figsize=(10, 40))

lines = []

top_keys = get_top_n_keys(dict, 4)

for i, joint_name in enumerate(dict.keys()):
    for key in top_keys:
        line, = ax[i].plot(np.array(dict[joint_name][key]["joint_position"])[1500:2000, 0, joint_mapping[joint_name]], label=f"{key}")
        ax[i].plot(np.array(dict[joint_name][(1.5, 4.5)]["joint_position"])[1500:2000, 0, joint_mapping[joint_name]], label=f"{key}", linewidth=2, color="black")
        ax[i].set_title(joint_name)
        ax[i].grid(True)
        lines.append(line)

crs = mplcursors.cursor(lines, hover=True)
@crs.connect("add")
def on_add(sel):
    sel.annotation.set_text(sel.artist.get_label())

plt.show()