In [1]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt

In [34]:
# Configuration: where the label JSONs live and the 29 annotated bone classes.
# The label directory can be overridden with the LABEL_ROOT environment variable
# instead of editing this absolute path; the default is the original location.
LABEL_ROOT = os.environ.get("LABEL_ROOT", "/data/ephemeral/home/data/train/outputs_json")
CLASSES = [
    'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
    'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
    'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
    'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
    'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
    'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

# Class groups, one heatmap per group. Derived from CLASSES by slicing so the
# two lists cannot drift apart when a class is renamed or reordered.
CLASSES_grouped = [
    CLASSES[:19],    # finger-1 .. finger-19
    CLASSES[19:27],  # Trapezium .. Pisiform
    CLASSES[27:],    # Radius, Ulna
]
assert sum(CLASSES_grouped, []) == CLASSES, "CLASSES_grouped must partition CLASSES in order"

# Per-class lists of point arrays, one dict per hand (keys are CLASSES).
all_points_right = {cls: [] for cls in CLASSES}
all_points_left = {cls: [] for cls in CLASSES}

# Every *.json file anywhere under LABEL_ROOT, as a path relative to LABEL_ROOT.
# Sorted so the order does not depend on os.walk's traversal order.
jsons = sorted({
    os.path.relpath(os.path.join(root, fname), start=LABEL_ROOT)
    for root, _dirs, files in os.walk(LABEL_ROOT)
    for fname in files
    if os.path.splitext(fname)[1].lower() == ".json"
})
_labelnames = np.array(jsons)

# Heatmaps for fingers, back of the hand (carpals), and forearm (radius/ulna) — 손가락, 손등, 팔 히트맵

In [33]:
# Collect every annotated point per class, split by hand, then save one hexbin
# density map per (hand, class group).
IMAGE_SIZE = 2048  # plot extent in pixels (was hard-coded as 2048); presumably the X-ray size — TODO confirm

# Re-initialise the accumulators here so that re-running this cell does not
# append the same points a second time.
all_points_right = {cls: [] for cls in CLASSES}
all_points_left = {cls: [] for cls in CLASSES}

for idx, label_name in enumerate(_labelnames):
    label_path = os.path.join(LABEL_ROOT, label_name)

    with open(label_path, "r", encoding="utf-8") as f:
        ann = json.load(f)["annotations"]

    # NOTE(review): the hand is inferred from the file's position in the sorted
    # list (even -> right, odd -> left). This assumes files come in right/left
    # pairs in that sort order — TODO confirm. A directory with a single file
    # would flip the assignment for every file after it.
    target = all_points_right if idx % 2 == 0 else all_points_left

    # Walk all annotations rather than a fixed range(29). Files with fewer
    # entries no longer raise IndexError, and extra entries are not dropped.
    # Labels outside CLASSES are ignored, as before.
    for entry in ann:
        if entry['label'] in target:
            target[entry['label']].append(np.array(entry['points']))

for hand, all_points in zip(['Right', 'Left'], [all_points_right, all_points_left]):
    for group_idx, group_classes in enumerate(CLASSES_grouped):
        group_arrays = [pts for cls in group_classes for pts in all_points.get(cls, [])]
        if not group_arrays:
            continue

        # Stack the per-annotation arrays into one array; column 0 is x, column 1 is y.
        all_points_group = np.concatenate(group_arrays)

        fig, ax = plt.subplots(figsize=(IMAGE_SIZE / 100, IMAGE_SIZE / 100))
        try:
            hb = ax.hexbin(all_points_group[:, 0], all_points_group[:, 1],
                           gridsize=100, cmap='inferno', alpha=0.6, mincnt=1)
            fig.colorbar(hb, ax=ax, label='Frequency')
            ax.set_title(f'{hand} Hand - Group {group_idx+1}')
            ax.set_xlabel('x (px)')
            ax.set_ylabel('y (px)')
            ax.set_xlim(0, IMAGE_SIZE)
            # Pixel y grows downward, so invert the axis to match the X-ray's
            # orientation. The original ylim(0, 2048) drew the hand upside down.
            ax.set_ylim(IMAGE_SIZE, 0)
            fig.tight_layout()
            fig.savefig(f'{hand.lower()}_hand_group_{group_idx+1}_heatmap.png')
        finally:
            # Always release the figure, even if plotting fails, to avoid leaking memory across the loop.
            plt.close(fig)