In [1]:
# To enable labeling of videos in notebook
%matplotlib notebook 

import json
import os

import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
import seaborn as sns

from point_labeler import PointLabeler
from point_tracker import PointTracker
from point_tracker import convert_select_point_dict_to_query_points
from tapnet.utils import viz_utils
from video_manager import VideoManager

In [2]:
# Specify folder paths.
# Every path hangs off a single project root, so the notebook can run on another
# machine by setting the GMA_ROOT environment variable. The default reproduces
# the original author's layout exactly.
GMA_ROOT = os.environ.get("GMA_ROOT", "/home/daphne/Documents/GMA")
DATA_DIR = os.path.join(GMA_ROOT, "data")
OUTPUT_DIR = os.path.join(GMA_ROOT, "codes", "output")

video_folder = os.path.join(DATA_DIR, "Preprocessed_Videos")
labeled_keypoints_folder = os.path.join(OUTPUT_DIR, "labelled_points")
merged_keypoints_folder = os.path.join(OUTPUT_DIR, "merged")
tracked_keypoints_folder = os.path.join(OUTPUT_DIR, "tracked_points")
cropped_videos_folder = os.path.join(DATA_DIR, "Preprocessed_Videos_Cropped")
cropped_resized_videos_folder = os.path.join(DATA_DIR, "Preprocessed_Videos_Cropped_Resized")

## Create and set up the VideoManager object

In [3]:
# Create a VideoManager to hold all videos
video_manager = VideoManager()

In [4]:
# Register every video in video_folder as a VideoObject.
# add_pt_data=True also attaches the per-video class/patient data (pt_data);
# the video frames themselves are not loaded at this point.
video_manager.add_all_videos(video_folder, add_pt_data=True)  # Load class data (pt_data), not the videos themselves

In [5]:
# Ids of all registered videos (filenames without extension)
video_manager.get_all_video_ids()

In [6]:
# Video used as the running example in the rest of this notebook
video_id = '35_F-_c'

## Look at patient data

In [7]:
# -- get VideoObject from VideoManager based on video_id (i.e. filename without extension)
video_object = video_manager.get_video_object(video_id)

In [8]:
# Sanity check: the fetched object belongs to the requested video
video_object.video_id

In [9]:
# Patient metadata; presumably populated by add_pt_data=True above — confirm
video_object.patient_data.age_group

In [10]:
video_object.patient_data.health_status

## Label video

In [11]:
# -- get VideoObject from VideoManager based on video_id (i.e. filename without extension)
video_object = video_manager.get_video_object(video_id)

In [12]:
# -- Load video (frames stay in memory until release_video() is called)
video_object.load_video()

In [13]:
# -- Label keypoints in video (once all points selected, run next cell)
# NOTE: interactive step. It relies on the `%matplotlib notebook` backend enabled
# at the top and on manual clicks, so "Run All" cannot reproduce it. The labels
# are written to disk two cells below.
frame_index = 0
video_object.label_and_store_keypoints(frame_index, task='extreme_keypoints')

In [14]:
# -- Labelled keypoints are saved in VideoObject
video_object.keypoint_labels

In [15]:
# -- Save labelled points to file to use again in the future
# (one CSV and one JSON file, both named after the video id)
video_object.save_keypoints_to_csv(os.path.join(labeled_keypoints_folder, f'{video_id}.csv'))
video_object.save_keypoints_to_json(os.path.join(labeled_keypoints_folder, f'{video_id}.json'))

In [None]:
# -- Release video from memory
# NOTE(review): this cell was never executed in the saved run ("In [None]"),
# so the frames loaded above stayed in memory.
video_object.release_video()

## Track labelled points

In [28]:
# -- Create PointTracker to track points
# NOTE(review): the TAPIR checkpoint path is relative to the kernel's working
# directory. Run the notebook from its own folder, or make this path absolute.
tracker = PointTracker('../tapnet/checkpoints/tapir_checkpoint_panning.npy')

In [29]:
# -- Load VideoObject from VideoManager
video_object = video_manager.get_video_object(video_id)

In [30]:
# -- Define starting frame and tracking task (which keypoints to track)
frame_index = 0
task = 'extreme_keypoints'

In [31]:
# -- Load video
video_object.load_video()

In [112]:
# Load the merged keypoint labels into the VideoObject.
# NOTE(review): execution count 112 means this ran after later cells. The
# tracking cell below is also given merged_keypoints_folder, so this explicit
# load may be redundant — confirm.
video_object.load_keypoint_labels_from_folder(merged_keypoints_folder, 'extreme_keypoints', 'json')

In [113]:
video_object.keypoint_labels

In [32]:
# -- Track points from frame_index onwards.
# Labelled points are (re)loaded from merged_keypoints_folder in JSON format.
# Passing the folder is not necessary if the labels are already in the VideoObject.
try:
    video_object.track_points(
        tracker,
        frame_index,  # from which index to track
        task,  # which keypoints to track: all, or only the extreme ones (used for cropping; fewer points -> faster tracking)
        merged_keypoints_folder,  # optional: folder to load the labelled points from
        'json'  # stored label format, as in load_keypoint_labels_from_folder
    )
except ValueError as e:
    # Surface the failure instead of silently continuing. The following cells
    # save video_object.tracking_data to disk and must not run after a failed track.
    print(f"Tracking failed for video {video_id!r}: {e}")
    raise

In [33]:
# -- Tracked points are saved in self.tracking_data
video_object.tracking_data

In [34]:
# -- Save tracked points (CSV and JSON, both into tracked_keypoints_folder)
video_object.save_tracked_points_to_csv(tracked_keypoints_folder)
video_object.save_tracked_points_to_json(tracked_keypoints_folder)

In [35]:
# -- Release video from memory
video_object.release_video()

## Load tracked points and crop videos accordingly

In [59]:
# -- Load VideoObject from VideoManager
video_object = video_manager.get_video_object(video_id)

In [60]:
# -- Compute extreme coordinates according to tracked points for appropriate cropping
video_object.update_extreme_coordinates(tracked_keypoints_folder)

In [61]:
# -- Extreme coordinates are stored in self.extreme_coordinates
# (keys used later in this notebook: 'leftmost', 'rightmost', 'topmost', 'bottommost')
video_object.extreme_coordinates

In [62]:
# Reload the merged labels so they can be shifted into cropped coordinates below
video_object.load_keypoint_labels_from_folder(merged_keypoints_folder, 'extreme_keypoints', 'json')

In [63]:
video_object.keypoint_labels

In [64]:
# Shift every labelled keypoint by a fixed offset towards the crop origin.
# NOTE(review): the offsets 20 (x) and 30 (y) are hard-coded placeholders rather
# than values derived from the crop. A later cell recomputes this shift with
# the computed crop origin (min_x_crop, min_y_crop).
# TODO: remove this cell or derive the offsets from the crop.
cropped_keypoint_labels = {
    frame: {
        keypoint: {
            'x': coords['x'] - 20,
            'y': coords['y'] - 30
        }
        for keypoint, coords in frame_data.items()
    }
    for frame, frame_data in video_object.keypoint_labels.items()
}

In [65]:
cropped_keypoint_labels

In [66]:
# -- Crop videos according to extreme coordinates
# The cropped video goes to cropped_videos_folder. With resize=True, a resized
# copy presumably goes to resize_folder, and load_and_release_video=True lets the
# method load and free the frames itself — confirm in VideoObject.
video_object.crop_and_resize_video(cropped_videos_folder, resize=True, resize_folder=cropped_resized_videos_folder, load_and_release_video=True)

## Visualize labelled points (TODO: update this section)

In [14]:
# -- Get the VideoObject for video_id from the VideoManager
video_object = video_manager.get_video_object(video_id)

In [15]:
# Keypoints at the body's extremities; this is the subset tracked with
# task='extreme_keypoints' and drawn by draw_points.
extreme_keypoints = [
    'head top',
    'left elbow', 'right elbow',
    'left wrist', 'right wrist',
    'left knee', 'right knee',
    'left ankle', 'right ankle',
]

# Full keypoint set, head to feet, with left/right pairs kept adjacent.
all_body_keypoints = [
    'nose',
    'head bottom', 'head top',
    'left ear', 'right ear',
    'left shoulder', 'right shoulder',
    'left elbow', 'right elbow',
    'left wrist', 'right wrist',
    'left hip', 'right hip',
    'left knee', 'right knee',
    'left ankle', 'right ankle',
]

In [16]:
def create_bodypart_colormap(body_keypoints, palette="hls"):
    """Map each keypoint name to its own colour from a seaborn palette.

    Args:
        body_keypoints: Sequence of keypoint names. The i-th name receives the
            i-th palette colour, so list order determines the colours.
        palette: Name of the seaborn palette to sample (default "hls").

    Returns:
        dict mapping keypoint name -> colour as returned by sns.color_palette.
    """
    # Request exactly one colour per keypoint and pair them up in order.
    colors = sns.color_palette(palette, len(body_keypoints))
    return dict(zip(body_keypoints, colors))

colormap = create_bodypart_colormap(all_body_keypoints)

In [80]:
# Inspect the keypoint -> colour mapping
colormap

In [77]:
def draw_points(points_dict, frame, keypoints=None, colors=None):
    """Plot keypoints on top of a single video frame.

    Args:
        points_dict: Mapping keypoint name -> (x, y) in pixel coordinates of
            `frame`. Names absent from the mapping, stored as None, or with an
            x or y of None are skipped.
        frame: Image to draw on (passed to ``Axes.imshow``).
        keypoints: Keypoint names to draw. Defaults to the notebook-level
            ``extreme_keypoints`` list.
        colors: Mapping keypoint name -> matplotlib colour. Defaults to the
            notebook-level ``colormap``.

    Returns:
        (fig, ax) so callers can annotate or save the figure further.
    """
    if keypoints is None:
        keypoints = extreme_keypoints
    if colors is None:
        colors = colormap

    # A single axes filling the figure (the old 1x2 grid left the right half empty).
    fig, ax_image = plt.subplots(figsize=(10, 5))
    ax_image.imshow(frame)

    for keypoint in keypoints:
        point = points_dict.get(keypoint)
        # Skip keypoints with no stored value or with missing coordinates.
        if point is None or point[0] is None or point[1] is None:
            continue
        ax_image.plot(point[0], point[1], 'o', color=colors[keypoint])

    ax_image.axis('off')
    plt.draw()
    return fig, ax_image

In [61]:
# Load the tracked trajectories saved earlier (JSON files in tracked_keypoints_folder)
video_object.load_tracked_points_from_folder(tracked_keypoints_folder, 'json')

In [62]:
# Presumably keypoint -> per-frame list of {'x', 'y', 'visible'} (inferred from
# how the next cells index it) — confirm against VideoObject
video_object.tracking_data[0]

In [63]:
video_object.load_video()

In [75]:
# Build keypoint -> [x, y] for frame `index`, keeping only keypoints whose track
# reaches that frame and is marked visible there; also grab the frame itself.
index = 0
points_to_draw = {}
for name, track in video_object.tracking_data[0].items():
    if len(track) <= index:
        continue
    sample = track[index]
    if sample['visible'] == True:
        points_to_draw[name] = [sample['x'], sample['y']]
frame = video_object.video[index]

In [76]:
# Print the raw (x, y) of every tracked keypoint at frame `index`, including
# points that are not visible (unlike points_to_draw). Tracks that end before
# `index` are reported instead of raising IndexError.
for key, value in video_object.tracking_data[0].items():
    print('--')
    if len(value) <= index:
        print(key, f'no tracked point at frame {index}')
        continue
    print(key, value[index]['x'], value[index]['y'])

In [65]:
points_to_draw

In [78]:
# Visualize tracked points (draw_points only draws keypoints listed in extreme_keypoints)
draw_points(points_to_draw, frame)

In [81]:
# Draw the query points on the original (uncropped) frame.
# convert_select_point_dict_to_query_points takes {keypoint: [x, y], ...} and
# returns rows of (t, y, x) in the original video's pixel coordinates, which is
# why column 2 is plotted as x and column 1 as y below.
select_frame = index
query_points = convert_select_point_dict_to_query_points(select_frame, points_to_draw)
# NOTE(review): assumes query_points rows follow the iteration order of
# points_to_draw, so the colours line up with keypoints — confirm in point_tracker.
color_list = [colormap[keypoint] for keypoint in points_to_draw.keys()]

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(np.array(video_object.video[select_frame]))
ax.scatter(query_points[:, 2], query_points[:, 1], marker="o", color=color_list, s=25)
ax.set_title(f"Query points on original frame {select_frame}")
plt.show(block=False)

## Transform points to cropped-video coordinates

In [85]:
# Create a VideoManager for the cropped videos
cropped_video_manager = VideoManager()
# Register every cropped video as a VideoObject (add_pt_data=True also attaches pt_data)
cropped_video_manager.add_all_videos(cropped_videos_folder, add_pt_data=True)  # Load class data (pt_data), not the videos themselves

In [101]:
cropped_video_manager.get_all_video_ids()

In [102]:
# Id of the cropped counterpart of video_id (appears to be 'cropped_vid_' + video_id)
cropped_video_id = 'cropped_vid_35_F-_c'

In [94]:
video_object = video_manager.get_video_object(video_id)

In [95]:
# Recompute the crop window so points can be shifted into cropped-video coordinates.
# NOTE(review): this mirrors the crop presumably applied by
# VideoObject.crop_and_resize_video (15% margin of the larger body extent plus
# one pixel). Keep the two in sync — TODO confirm against the VideoObject code.
video_object.update_extreme_coordinates(tracked_keypoints_folder)
extremes = video_object.extreme_coordinates

# Take min/max explicitly. In image coordinates y grows downwards, so 'topmost'
# is normally the SMALLER y. Sorting gives correct bounds whichever convention
# extreme_coordinates uses and keeps both margins non-negative.
x_min, x_max = sorted((extremes['leftmost'], extremes['rightmost']))
y_min, y_max = sorted((extremes['topmost'], extremes['bottommost']))

margin = 0.15  # 15% margin
x_margin = round(margin * (x_max - x_min))
y_margin = round(margin * (y_max - y_min))
largest_margin = max(x_margin, y_margin)

# NOTE(review): requires the full video to be loaded (load_video() in an earlier
# cell); metadata.shape is unpacked as (height, width).
height, width = video_object.video.metadata.shape
min_x_crop, max_x_crop = (max(0, x_min - largest_margin - 1),
                          min(width, x_max + largest_margin + 1))
min_y_crop, max_y_crop = (max(0, y_min - largest_margin - 1),
                          min(height, y_max + largest_margin + 1))

# query_points rows are (t, y, x): keep t and shift y and x by the crop origin.
query_points_crop = np.asarray(query_points, dtype=float) - np.array([0.0, min_y_crop, min_x_crop])

In [114]:
# Express the labelled keypoints in cropped-video pixel coordinates by
# subtracting the crop origin (min_x_crop, min_y_crop) from every point.
cropped_keypoint_labels = {}
for frame_key, frame_points in video_object.keypoint_labels.items():
    shifted = {}
    for name, xy in frame_points.items():
        shifted[name] = {'x': xy['x'] - min_x_crop, 'y': xy['y'] - min_y_crop}
    cropped_keypoint_labels[frame_key] = shifted

In [115]:
cropped_keypoint_labels

In [96]:
# Rows are (t, y, x), now relative to the crop origin
print(query_points_crop)

In [103]:
cropped_video_object = cropped_video_manager.get_video_object(cropped_video_id)

In [109]:
# Load the cropped video's frames for plotting
cropped_video_object.load_video()

In [110]:
# Overlay the shifted query points on the same frame of the cropped video; with
# a correct crop origin they land on the same body parts as in the original.
cropped_frame = np.array(cropped_video_object.video[select_frame])
xs, ys = query_points_crop[:, 2], query_points_crop[:, 1]

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(cropped_frame)
ax.scatter(xs, ys, marker="o", color=color_list, s=25)
plt.show(block=False)

## Visualize cropped coordinates

In [11]:
# Redefined here because this section appears to run in a fresh kernel session
# (its execution counts restart).
all_body_keypoints = ['nose',
                           'head bottom', 'head top',
                           'left ear', 'right ear',
                           'left shoulder', 'right shoulder',
                           'left elbow', 'right elbow',
                           'left wrist', 'right wrist',
                           'left hip', 'right hip',
                           'left knee', 'right knee',
                           'left ankle', 'right ankle']

def create_bodypart_colormap(body_keypoints, palette="hls"):
    """Map each keypoint name to its own colour from a seaborn palette.

    Args:
        body_keypoints: Sequence of keypoint names. The i-th name receives the
            i-th palette colour, so list order determines the colours.
        palette: Name of the seaborn palette to sample (default "hls").

    Returns:
        dict mapping keypoint name -> colour as returned by sns.color_palette.
    """
    # Request exactly one colour per keypoint and pair them up in order.
    colors = sns.color_palette(palette, len(body_keypoints))
    return dict(zip(body_keypoints, colors))

colormap = create_bodypart_colormap(all_body_keypoints)

In [12]:
# Create a VideoManager for the cropped videos
cropped_video_manager = VideoManager()
# Register every cropped video as a VideoObject (add_pt_data=True also attaches pt_data)
cropped_video_manager.add_all_videos(cropped_videos_folder, add_pt_data=True)  # Load class data (pt_data), not the videos themselves

In [13]:
# Id of the cropped counterpart of the example video
cropped_video_id = 'cropped_vid_35_F-_c'

In [14]:
# Load keypoint labels that were already shifted into cropped-video coordinates.
# NOTE(review): the absolute path and video id are hard-coded, and this folder
# ('labeled') differs from labeled_keypoints_folder ('labelled_points') — confirm.
labels_dir = '/home/daphne/Documents/GMA/codes/output/labeled'  # renamed from `dir`, which shadowed the builtin
file_path = os.path.join(labels_dir, "35_F-_c.extreme_keypoints.cropped.json")
# JSON is UTF-8; don't depend on the platform's locale encoding.
with open(file_path, 'r', encoding='utf-8') as f:
    cropped_keypoint_labels = json.load(f)

In [15]:
# Top-level keys are frame indices stored as strings (JSON object keys are always strings)
cropped_keypoint_labels

In [16]:
# JSON object keys are strings, so the frame index is looked up as '0'.
select_frame = '0'
keypoints = cropped_keypoint_labels[select_frame].values()

# (N, 2) array of [x, y] per keypoint, in the key order of
# cropped_keypoint_labels[select_frame]; the colour list below relies on this.
# reshape keeps the array 2-D even when the frame has no labelled keypoints,
# so coords[:, 0] still works.
coords = np.array([[point['x'], point['y']] for point in keypoints]).reshape(-1, 2)

In [17]:
print(coords)

In [18]:
# x coordinates of all keypoints
coords[:,0]

In [19]:
cropped_video_object = cropped_video_manager.get_video_object(cropped_video_id)

In [20]:
cropped_video_object.load_video()

In [22]:
cropped_video_object.patient_data

In [23]:
# One colour per keypoint, in the same key order used to build `coords`
color_list = [colormap[keypoint] for keypoint in cropped_keypoint_labels[select_frame].keys()]

In [25]:
# Overlay the cropped-coordinate labels on the matching cropped-video frame.
# select_frame is a string JSON key, so convert it before indexing frames.
frame_number = int(select_frame)
cropped_frame = np.array(cropped_video_object.video[frame_number])

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(cropped_frame)
ax.scatter(coords[:, 0], coords[:, 1], marker="o", color=color_list, s=25)
plt.show(block=False)