In [None]:
from ultralytics import YOLO
import os
import torch
from tqdm import tqdm

In [None]:
# Pretrained YOLOv8 pose-estimation model used for every frame below
# (presumably the medium "m" variant, judging by the weights filename).
# NOTE(review): ultralytics typically downloads the weights on first use if the
# file is not present locally — confirm for offline runs.
model = YOLO('yolov8m-pose.pt')

In [None]:
# Input/output roots. Frames are read from <frame_path>/<severity>/<video>/<image>
# and rendered skeletons are written to the same relative path under
# destination_path. The defaults are machine-specific absolute paths; override
# them with environment variables instead of editing this cell.
destination_path = os.environ.get('HPE_DESTINATION_PATH', '/Users/aravdhoot/Parkinson-Project/hpe_images_new')
frame_path = os.environ.get('PD_FRAME_PATH', '/Users/aravdhoot/PD/frames_updated')

In [None]:
import cv2
import matplotlib.pyplot as plt

In [None]:
def scale(x_list, y_list):
    """Shift coordinates away from the reference point (320, 180).

    The factor is ``320 / (max(y_list) - min(y_list))``. Each x is moved by
    ``abs(x - 320) * factor`` and each y by ``abs(y - 180) * factor``.

    NOTE(review): because of ``abs``, points on either side of the reference
    are shifted in the same (positive) direction, so this is not a scale about
    the centre. The function is not called anywhere in the visible code.
    Raises ValueError for an empty ``y_list`` and ZeroDivisionError when all
    y values are equal.

    Returns:
        (new_x_list, new_y_list) as lists.
    """
    factor = 320 / (max(y_list) - min(y_list))
    shifted_x = [x + abs(x - 320) * factor for x in x_list]
    shifted_y = [y + abs(y - 180) * factor for y in y_list]
    return shifted_x, shifted_y

In [None]:
def preprocess_keypoints(results, center_x=320.0, target_height=360.0):
    """Normalize the first detected person's keypoints into a fixed canvas frame.

    Reads ``results[0].keypoints.xy[0]`` and ``results[0].keypoints.conf[0]``.
    The index ``[0]`` selects the first person only. It translates and
    uniformly scales the keypoints so that:

    * the visible keypoints span exactly ``0 .. target_height`` vertically, and
    * they are horizontally centred on ``center_x``.

    x and y share a single scale factor, so the pose's aspect ratio is kept.

    A coordinate equal to 0.0 marks an undetected keypoint. It is excluded from
    the bounding box and left at 0.0 in the output, and that keypoint's
    confidence is forced to 0.0.

    If the pose cannot be scaled, the unnormalized coordinates are returned,
    still with missing keypoints' confidence zeroed. This happens when there
    is no visible x, no visible y, or zero vertical extent. Previously the
    no-visible case went through a bare ``except`` and returned raw tensors,
    and the zero-extent case produced NaN/inf.

    Args:
        results: ultralytics pose results (indexable; only the fields above are read).
            NOTE(review): assumes ``keypoints.conf`` is not None, i.e. the model
            outputs per-keypoint confidences — true for pose weights, presumably.
        center_x: horizontal centre of the output canvas (default: half of 640).
        target_height: vertical span of the normalized pose (default: 360).

    Returns:
        ``(final_list, x_list, y_list)``: ``final_list`` is a list of
        ``(x, y, conf)`` tuples. ``x_list`` and ``y_list`` hold the coordinates
        that go with it. All values are plain Python floats.
    """
    keypoints = results[0].keypoints
    x_raw = [float(point[0]) for point in keypoints.xy[0]]
    y_raw = [float(point[1]) for point in keypoints.xy[0]]
    conf_list = [float(c) for c in keypoints.conf[0].tolist()]

    # A zero coordinate means "not detected"; such keypoints get zero confidence.
    missing_x = {i for i, x in enumerate(x_raw) if x == 0.0}
    missing_y = {i for i, y in enumerate(y_raw) if y == 0.0}
    for i in missing_x | missing_y:
        conf_list[i] = 0.0

    visible_x = [x for i, x in enumerate(x_raw) if i not in missing_x]
    visible_y = [y for i, y in enumerate(y_raw) if i not in missing_y]

    # Without a non-degenerate vertical extent there is no meaningful scale factor.
    if not visible_x or not visible_y or max(visible_y) == min(visible_y):
        return list(zip(x_raw, y_raw, conf_list)), x_raw, y_raw

    y_min = min(visible_y)
    factor = target_height / (max(visible_y) - y_min)
    x_mid = (min(visible_x) + max(visible_x)) / 2

    # y: top of the pose -> 0, bottom -> target_height.
    # x: same factor, bounding-box midpoint -> center_x.
    x_list = [0.0 if i in missing_x else center_x + (x - x_mid) * factor
              for i, x in enumerate(x_raw)]
    y_list = [0.0 if i in missing_y else (y - y_min) * factor
              for i, y in enumerate(y_raw)]

    final_list = list(zip(x_list, y_list, conf_list))
    return final_list, x_list, y_list

In [None]:
def display_keypoints(final_list, destination_path, keypoints=False, display=False,
                      conf_threshold=0.5, required_count=12, width=640, height=360):
    """Draw a stick-figure skeleton on a white canvas and save it to disk.

    A keypoint counts as visible only if its confidence is strictly greater
    than ``conf_threshold``. Skeleton edges with a non-visible endpoint are not
    drawn. Nothing is drawn or saved unless exactly ``required_count``
    keypoints are visible.

    Args:
        final_list: sequence of ``(x, y, conf)`` triples, one per keypoint,
            in the 1-based order referenced by ``skeletons`` below. That order
            presumably matches the 17 COCO keypoints — TODO confirm.
        destination_path: output file passed to ``Figure.savefig``.
        keypoints: if True, also scatter the visible keypoints.
        display: if True, show the figure before it is closed.
        conf_threshold: confidence a keypoint must exceed to be visible.
        required_count: exact number of visible keypoints needed to render.
        width, height: canvas size in pixels (the keypoints' coordinate frame).
    """
    import numpy as np

    # 1-based keypoint index pairs to connect with line segments.
    skeletons = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]

    visible = [(pt[0], pt[1]) if pt[2] > conf_threshold else None for pt in final_list]

    # NOTE(review): this requires *exactly* `required_count` visible keypoints,
    # so frames with more (e.g. face keypoints also visible) are skipped as well.
    # Confirm that is intended rather than ">=".
    if sum(pt is not None for pt in visible) != required_count:
        return

    fig, ax = plt.subplots()
    try:
        ax.imshow(np.full((height, width, 3), 255, dtype=np.uint8))
        ax.axis('off')
        if keypoints:
            # Scatter only visible points; passing None coordinates to scatter is unreliable.
            shown = [pt for pt in visible if pt is not None]
            ax.scatter([pt[0] for pt in shown], [pt[1] for pt in shown])
        for start, end in skeletons:
            a, b = visible[start - 1], visible[end - 1]
            if a is None or b is None:
                continue
            ax.plot([a[0], b[0]], [a[1], b[1]], color='black', solid_capstyle='round', linewidth=5)
        fig.savefig(destination_path)
        if display:
            plt.show()
    finally:
        # Always release the figure; this runs once per frame in a long loop.
        plt.close(fig)

In [None]:
def _list_entries(parent, directories):
    """Return the sorted, non-hidden entries of `parent` that are directories (or regular files)."""
    is_wanted = os.path.isdir if directories else os.path.isfile
    # Skip dotfiles such as macOS `.DS_Store`, which would otherwise be treated
    # as a severity/video folder or fed to the model as an image.
    return [
        name for name in sorted(os.listdir(parent))
        if not name.startswith('.') and is_wanted(os.path.join(parent, name))
    ]


# Mirror <frame_path>/<severity>/<video>/<image> into destination_path, writing one
# rendered skeleton per frame in which a person is detected.
os.makedirs(destination_path, exist_ok=True)
for severity in _list_entries(frame_path, directories=True):
    severity_in = os.path.join(frame_path, severity)
    severity_out = os.path.join(destination_path, severity)
    os.makedirs(severity_out, exist_ok=True)
    for video in _list_entries(severity_in, directories=True):
        video_in = os.path.join(severity_in, video)
        video_out = os.path.join(severity_out, video)
        os.makedirs(video_out, exist_ok=True)
        for image in tqdm(_list_entries(video_in, directories=False), desc=f'{severity}/{video}'):
            results = model(os.path.join(video_in, image), verbose=False)
            detected = results[0].keypoints
            # No person detected (empty keypoints) -> nothing to draw for this frame.
            if detected is None or detected.xy.numel() == 0:
                continue
            final_list, _, _ = preprocess_keypoints(results)
            display_keypoints(final_list, os.path.join(video_out, image))