In [51]:

import os
import numpy as np
import matplotlib.pyplot as plt
import time
import math
import pandas as pd
import ikpy.chain
import ikpy.utils.plot as plot_utils
import itertools
import csv

In [None]:
# # [shoulder_abd(rx), shoulder_flex(ry), shoulder_rot(rz), elbow_flex(ry)]
# joint_limits_lower_bound = np.deg2rad([-180, -90, -90,   0])
# joint_limits_upper_bound = np.deg2rad([50, 180,  90, 150])

# print(joint_limits_lower_bound)
# print(joint_limits_upper_bound)

[-3.14159265 -1.57079633 -1.57079633  0.        ]
[0.87266463 3.14159265 1.57079633 2.61799388]


In [57]:
# Joint order used throughout: shoulder_abd(rx), shoulder_flex(ry), shoulder_rot(rz), elbow_flex(ry).
# Each entry is a [lower, upper] limit in degrees.
# The healthy base case equals the "healthy" level of every joint listed below.
healthy_base_case = [
    [-180, 50],
    [-90, 180],
    [-90, 90],
    [0, 150],
]

# Per-joint limit levels, ordered healthy -> mild -> severe.
shoulder_abd_limits = [[-180, 50], [-110, 30], [-80, 20]]
shoulder_flex_limits = [[-90, 180], [-50, 100], [-20, 60]]
shoulder_rot_limits = [[-90, 90], [-60, 60], [-30, 30]]
elbow_flex_limits = [[0, 150], [0, 90], [0, 40]]

all_joint_options = [shoulder_abd_limits, shoulder_flex_limits, shoulder_rot_limits, elbow_flex_limits]

# One simulated user per combination of levels: 3**4 = 81 users, with the last joint
# (elbow) varying fastest. The result has shape (n_users, n_joints, [lower, upper]).
combinations_iterator = itertools.product(*all_joint_options)
user_joint_limits_deg = np.array([list(combo) for combo in combinations_iterator])

for preview in (user_joint_limits_deg.shape, user_joint_limits_deg[0], user_joint_limits_deg[-1]):
    print(preview)

(81, 4, 2)
[[-180   50]
 [ -90  180]
 [ -90   90]
 [   0  150]]
[[-80  20]
 [-20  60]
 [-30  30]
 [  0  40]]


In [53]:
# Restrict the run to the first NUM_USERS limit combinations, since generation takes several
# seconds per user. Set NUM_USERS = None to keep every combination.
NUM_USERS = 10
user_joint_limits_deg = user_joint_limits_deg[:NUM_USERS]

In [50]:
# Save each user's joint limits (in degrees) so that fROM samples can be joined back by user index.
csv_filename = "joint_limits.csv"
# newline="" lets the csv module control line endings (otherwise Windows gets blank rows).
with open(csv_filename, "w", newline="") as f:
    csv_writer = csv.writer(f)
    csv_writer.writerow(["user_index", "shoulder_abd(rx)_lower", "shoulder_abd(rx)_upper", "shoulder_flex(ry)_lower","shoulder_flex(ry)_upper", "shoulder_rot(rz)_lower","shoulder_rot(rz)_upper", "elbow_flex(ry)_lower", "elbow_flex(ry)_upper"])
    # (n_users, 4 joints, [lower, upper]) -> (n_users, 8). Row-major flattening yields
    # lower/upper pairs joint by joint, which matches the header column order above.
    data = np.reshape(user_joint_limits_deg, (-1, 8))
    for index, joint_limit in enumerate(data):
        csv_writer.writerow([str(index), *joint_limit])


In [54]:
def generate_user_fROM(user_chain, joint_limits_lower_rad, joint_limits_upper_rad, num_poses=1000, rng=None):
    """
    Generate an fROM point cloud (joint-angle samples) for a *specific user's chain*.

    Random Cartesian targets are drawn from a fixed box workspace and solved with
    ``user_chain.inverse_kinematics``. The resulting angles of the 4 active joints
    (chain link indices 1-4) are kept only if they lie within this user's limits.

    Args:
        user_chain: IK chain (an ikpy ``Chain`` in this notebook) exposing ``links`` and
            ``inverse_kinematics(target_position=..., initial_position=..., orientation_mode=...)``.
        joint_limits_lower_rad: Lower limit per active joint, in radians (length 4).
        joint_limits_upper_rad: Upper limit per active joint, in radians (length 4).
        num_poses: Number of poses to collect. Sampling stops after ``50 * num_poses`` attempts.
        rng: Optional object with a NumPy-style ``uniform(low, high)`` method (e.g. a
            ``np.random.Generator``). Defaults to the global ``np.random`` state, so
            ``np.random.seed`` still controls reproducibility.

    Returns:
        np.ndarray of shape (n, 4) with n <= num_poses. If nothing is accepted the result
        is an empty (0, 4) array. A warning is printed when n < num_poses.
    """
    num_active_joints = 4
    active_joint_indices = list(range(1, 1 + num_active_joints))  # chain link indices 1, 2, 3, 4
    lower = np.asarray(joint_limits_lower_rad, dtype=float)
    upper = np.asarray(joint_limits_upper_rad, dtype=float)
    if rng is None:
        rng = np.random

    # Target workspace box, in the chain's length units (presumably metres; TODO confirm).
    ws_min = np.array([0.2, -0.4, -0.2])
    ws_max = np.array([0.8, 0.4, 0.5])

    valid_poses = []
    attempts = 0
    max_attempts = num_poses * 50

    # IK seed: all zeros except link 4 (the last active joint, elbow) at 10 degrees.
    # Presumably the slight bend keeps the solver away from a straight-arm start; TODO confirm.
    # The seed is clipped into this user's limits because a bounded optimizer may reject an
    # infeasible start (scipy's least_squares raises ValueError for it), which would make
    # every attempt below fail.
    initial_position_full = np.zeros(len(user_chain.links))
    initial_position_full[4] = np.deg2rad(10)
    initial_position_full[active_joint_indices] = np.clip(
        initial_position_full[active_joint_indices], lower, upper
    )

    while len(valid_poses) < num_poses and attempts < max_attempts:
        attempts += 1
        target_point = rng.uniform(low=ws_min, high=ws_max)

        try:
            ik_solution_full = user_chain.inverse_kinematics(
                target_position=target_point,
                initial_position=initial_position_full,
                orientation_mode=None,
            )
        except ValueError:
            # Best effort: skip targets the solver rejects and keep sampling.
            # Only the IK call is guarded, so shape errors in the limit check still surface.
            continue

        pose_angles = ik_solution_full[active_joint_indices]

        # Keep the solution only if it respects THIS user's joint limits.
        if np.all((pose_angles >= lower) & (pose_angles <= upper)):
            valid_poses.append(pose_angles)

    if len(valid_poses) < num_poses:
        print(f"  Warning: Only generated {len(valid_poses)}/{num_poses} poses.")

    # Reshaping keeps the result 2-D even when empty; np.array([]) alone would have shape (0,).
    return np.array(valid_poses, dtype=float).reshape(-1, num_active_joints)

In [58]:
base_urdf_file = "simple_arm.urdf"
active_joint_indices = [1, 2, 3, 4]  # chain link indices of the 4 active joints
all_user_fROMs = {}  # user_name -> array of sampled poses (one row per pose)

# Seed the global NumPy RNG, which generate_user_fROM samples its IK targets from,
# so that re-running this cell reproduces the same dataset.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

print("Dataset generation starts")

for index, limits_deg in enumerate(user_joint_limits_deg):
    print(f"user {index}")
    start_time = time.perf_counter()

    # Load a fresh chain per user, because the bounds below are assigned onto its links.
    chain = ikpy.chain.Chain.from_urdf_file(base_urdf_file)

    # limits_deg holds one [lower, upper] row per active joint.
    limits_rad_lower = np.deg2rad(limits_deg[:, 0])
    limits_rad_upper = np.deg2rad(limits_deg[:, 1])

    for joint_idx, lower, upper in zip(active_joint_indices, limits_rad_lower, limits_rad_upper):
        chain.links[joint_idx].bounds = (lower, upper)

    user_fROM_data = generate_user_fROM(
        chain,
        limits_rad_lower,
        limits_rad_upper,
        num_poses=256,
    )

    # The key is the bare index as a string, matching "user_index" in joint_limits.csv.
    user_name = f"{index}"
    all_user_fROMs[user_name] = user_fROM_data

    end_time = time.perf_counter()
    print(f"  Finished {user_name}. Generated {user_fROM_data.shape[0]} poses in {end_time - start_time:.2f}s.")

Dataset generation starts
user 0




  Finished 0. Generated 256 poses in 3.85s.
user 1
  Finished 1. Generated 256 poses in 5.37s.
user 2
  Finished 2. Generated 256 poses in 4.32s.
user 3
  Finished 3. Generated 256 poses in 3.55s.
user 4
  Finished 4. Generated 256 poses in 5.29s.
user 5
  Finished 5. Generated 256 poses in 3.45s.
user 6
  Finished 6. Generated 256 poses in 3.90s.
user 7
  Finished 7. Generated 256 poses in 5.96s.
user 8
  Finished 8. Generated 256 poses in 4.75s.
user 9
  Finished 9. Generated 256 poses in 4.56s.
user 10
  Finished 10. Generated 256 poses in 5.36s.
user 11
  Finished 11. Generated 256 poses in 3.12s.
user 12
  Finished 12. Generated 256 poses in 3.61s.
user 13
  Finished 13. Generated 256 poses in 5.29s.
user 14
  Finished 14. Generated 256 poses in 3.13s.
user 15
  Finished 15. Generated 256 poses in 3.74s.
user 16
  Finished 16. Generated 256 poses in 5.23s.
user 17
  Finished 17. Generated 256 poses in 3.55s.
user 18
  Finished 18. Generated 256 poses in 4.81s.
user 19
  Finished 1

In [39]:
# Inspect the first generated user's samples. The generation loop keys users by their bare
# index ("0", "1", ...), so the old "user_0" key raised KeyError on a fresh kernel.
# Looking up the first stored key works regardless of the naming scheme.
first_user_name = next(iter(all_user_fROMs))
print(all_user_fROMs[first_user_name].shape)
print(all_user_fROMs[first_user_name][:5])

(256, 4)
[[ 8.15316735e-01 -8.07705799e-01 -4.87686427e-02  8.36759440e-01]
 [ 8.72664626e-01 -3.11111735e-01  3.65572104e-05  7.40272810e-01]
 [ 1.91220636e-01 -6.07560576e-01  2.23428610e-01  1.98703074e+00]
 [-1.33207543e-01  8.64182567e-01 -1.94676147e-01  1.87819908e+00]
 [ 2.75530307e-01 -6.69864450e-01  3.03527072e-01  1.33127686e+00]]


In [56]:


# Long-format export: one row per sampled pose, tagged with the user it belongs to.
# Angles are in radians (they were checked against radian limits during generation),
# unlike joint_limits.csv, which stores degrees.
csv_filename = "all_users_fROM_data.csv"
# newline="" lets the csv module control line endings (otherwise Windows gets blank rows).
with open(csv_filename, "w", newline="") as f:
    csv_writer = csv.writer(f)
    csv_writer.writerow(["user_name", "shoulder_abd(rx)", "shoulder_flex(ry)", "shoulder_rot(rz)", "elbow_flex(ry)"])
    for user_name, fROM_data in all_user_fROMs.items():
        csv_writer.writerows([user_name, *pose] for pose in fROM_data)

