In [2]:
import pickle
import math

# Poses that were recorded; each name is loaded from "<name>.pkl" below.
data_files = ["index_pinch", "middle_pinch", "ring_pinch", "none", "pinky_pinch"]

# Landmark pairs whose distances make up each feature vector: the thumb tip
# against every other fingertip, then each fingertip (thumb included) against
# the wrist -- nine distances in total, in this fixed order.
_FINGERTIPS = ("INDEX_FINGER_TIP", "MIDDLE_FINGER_TIP", "RING_FINGER_TIP", "PINKY_TIP")
pairwise_points = [("THUMB_TIP", tip) for tip in _FINGERTIPS]
pairwise_points += [(tip, "WRIST") for tip in ("THUMB_TIP",) + _FINGERTIPS]

# Load every pose's recorded samples, keyed by pose name (insertion order
# follows data_files).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load .pkl files produced by this project.
data_dict = {}
for pose_name in data_files:
    pickle_path = f"{pose_name}.pkl"
    with open(pickle_path, "rb") as handle:
        data_dict[pose_name] = pickle.load(handle)

data_dict.keys()

dict_keys(['index_pinch', 'middle_pinch', 'ring_pinch', 'none', 'pinky_pinch'])

In [5]:
# Integer id for each pose: 0, 1, 2, ... in the order the poses were loaded.
keys_to_labels = dict(zip(data_dict, range(len(data_dict))))
def euclidean_dist(a, b, keys=("x", "y")):
    """Return the Euclidean distance between two landmark points.

    Args:
        a, b: Mappings holding a numeric coordinate under each name in ``keys``.
        keys: Coordinate names to include. The default ("x", "y") gives the
            planar 2-D distance; pass e.g. ("x", "y", "z") to include a "z"
            entry if the points have one.

    Returns:
        The square root of the summed squared coordinate differences.
    """
    return math.sqrt(sum((a[k] - b[k]) ** 2 for k in keys))

def preprocess(data, pairs=None):
    """Turn raw hand-landmark samples into distance features, split by handedness.

    Each sample becomes a feature vector of ``euclidean_dist`` values, one per
    landmark pair in ``pairs`` (in that order).

    Args:
        data: Mapping of pose name -> iterable of samples. Each sample is a dict
            with a "landmarks" mapping (landmark name -> point accepted by
            ``euclidean_dist``) and a "handedness" string.
        pairs: Landmark-name pairs to measure; defaults to ``pairwise_points``.

    Returns:
        ``((right_features, right_labels), (left_features, left_labels))``.
        Labels are the pose-name strings (the keys of ``data``), not integer
        ids; map them through ``keys_to_labels`` if ids are needed. Any sample
        whose handedness is not exactly "Right" goes to the left-hand group.
    """
    if pairs is None:
        pairs = pairwise_points

    right_feats, right_labels = [], []
    left_feats, left_labels = [], []
    for pose, samples in data.items():
        for sample in samples:
            landmarks = sample["landmarks"]
            feat_vec = [euclidean_dist(landmarks[a], landmarks[b]) for a, b in pairs]
            if sample["handedness"] == "Right":
                right_feats.append(feat_vec)
                right_labels.append(pose)
            else:
                left_feats.append(feat_vec)
                left_labels.append(pose)

    return (right_feats, right_labels), (left_feats, left_labels)


right, left = preprocess(data_dict)

In [6]:
from sklearn.cluster import KMeans

# One cluster per recorded pose.
n_clusters = len(keys_to_labels)

# Fixed seed: KMeans' centroid initialisation is random, so without it each
# Restart & Run All can yield different (and differently ordered) centers.
KMEANS_SEED = 0

# A single model is fit on both hands' feature vectors. KMeans is unsupervised
# and ignores any `y` given to `fit`, so the pose labels are not passed.
hand_kmeans = KMeans(n_clusters=n_clusters, random_state=KMEANS_SEED)
hand_kmeans.fit(right[0] + left[0])

In [7]:
# Learned cluster centers: one row per cluster, one column per landmark pair,
# in `pairwise_points` order (the order preprocess builds feature vectors in).
# NOTE(review): cluster indices do not line up with keys_to_labels ids. In the
# output below, row 1 has the small thumb-to-pinky distance (column 3), which
# looks like "pinky_pinch" (label 4). Match rows to poses before using them.
hand_kmeans.cluster_centers_

array([[0.05308381, 0.3700525 , 0.41723969, 0.39024684, 0.37723414,
        0.41324361, 0.69598563, 0.70869242, 0.61507006],
       [0.33720596, 0.41143391, 0.37406228, 0.04879374, 0.35137312,
        0.67957793, 0.75993423, 0.71995237, 0.38494225],
       [0.27317758, 0.04913904, 0.31499614, 0.33584301, 0.38314826,
        0.65005824, 0.41986455, 0.66316381, 0.61389367],
       [0.33288283, 0.39859269, 0.05819739, 0.32172614, 0.34419366,
        0.67576879, 0.73676922, 0.38224947, 0.6028739 ],
       [0.26643064, 0.36104879, 0.38666625, 0.37098103, 0.42066665,
        0.65149052, 0.71653372, 0.69513723, 0.58063184]])