## download mpii dataset

In [None]:
# Both archives are downloaded into the current working directory, so they are extracted and
# removed relative to it as well; a hard-coded /content path would abort the chain elsewhere.
!wget -c https://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1.tar.gz && \
mkdir -p /content/data/mpii && \
tar -xzf mpii_human_pose_v1.tar.gz -C /content/data/mpii/ && \
rm mpii_human_pose_v1.tar.gz && \
wget -c https://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1_u12_2.zip && \
unzip -qq -j -o mpii_human_pose_v1_u12_2.zip -d /content/data/mpii/annotations && \
rm mpii_human_pose_v1_u12_2.zip

In [None]:
import gdown


def download_from_colab(file_id: str, output_file: str):
    """Download a shared Google Drive file to ``output_file`` using ``gdown``.

    Despite its name, the function fetches the file from Google Drive.

    Parameters
    ----------
    - file_id: Google Drive file id, inserted as the ``id`` query parameter of the download URL
    - output_file: Destination path of the downloaded file

    Raises
    ------
    - RuntimeError: if ``gdown`` reports that nothing was downloaded
    """
    url = f"https://drive.google.com/uc?id={file_id}"
    # quiet=True hides gdown's own diagnostics, so check the result explicitly: some gdown
    # versions signal a failed download by returning None instead of raising.
    downloaded_path = gdown.download(url, output_file, quiet=True)
    if downloaded_path is None:
        raise RuntimeError(f"Failed to download Google Drive file {file_id!r} to {output_file!r}")

In [None]:
# MPII annotation files hosted on Google Drive, as (Drive file id, destination path) pairs.
annotation_downloads = (
    ("1DZm9_erQd9EbHASYzpaZsjFIKQSkHLi4", "/content/data/mpii/annotations/mpii_gt_val.mat"),
    ("1goXOVSr-ne8_ZaH80KzEbyAPGn67bs5N", "/content/data/mpii/annotations/mpii_train.json"),
    ("1B5laHYDFN8oShENI958nTVobU6jeHgi2", "/content/data/mpii/annotations/mpii_val.json"),
)

# Fetch them one after another, in the order listed above.
for drive_file_id, annotation_path in annotation_downloads:
    download_from_colab(file_id=drive_file_id, output_file=annotation_path)

## download test video

In [None]:
!mkdir -p /content/data/
!wget -O /content/data/video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tom.mp4

## check files

In [None]:
!ls /content/data/mpii/images/033441445.jpg

In [None]:
!head -c 250 /content/data/mpii/annotations/mpii_val.json

## install deps

In [None]:
%pip install -qq -U openmim
!mim install -qq mmengine
!mim install -qq "mmcv>=2.0.0"
!mim install -qq "mmdet>=3.0.0"

In [None]:
%pip install -qq git+https://github.com/jin-s13/xtcocoapi

In [None]:
!git clone https://github.com/open-mmlab/mmpose.git
%cd /content/mmpose
%pip install -qq -r requirements.txt
%pip install -qq -v -e .
%cd /content

In [None]:
%pip install -qq decord

## check installations

In [None]:
# check NVCC version
!nvcc -V

# check GCC version
!gcc --version

# check python in conda environment
!which python

In [None]:
# Check Pytorch installation
import torch
import torchvision


print("torch version:", torch.__version__, torch.cuda.is_available())
print("torchvision version:", torchvision.__version__)

# Check mmcv installation
from mmcv.ops import get_compiler_version, get_compiling_cuda_version


print("cuda version:", get_compiling_cuda_version())
print("compiler information:", get_compiler_version())

# Check MMPose installation
import mmpose


print("mmpose version:", mmpose.__version__)

## run validation using `test.py`

In [None]:
!mkdir -p /content/models && \
mim download mmpose --config td-hm_hrnet-w48_dark-8xb64-210e_mpii-256x256 --dest /content/models/

In [None]:
%%writefile /content/models/td-hm_hrnet-w48_dark-8xb64-210e_mpii-256x256-custom.py
auto_scale_lr = dict(base_batch_size=512)
backend_args = dict(backend='local')
codec = dict(
    heatmap_size=(
        64,
        64,
    ),
    input_size=(
        256,
        256,
    ),
    sigma=2,
    type='MSRAHeatmap',
    unbiased=True)
custom_hooks = [
    dict(type='SyncBuffersHook'),
]
data_mode = 'topdown'
data_root = 'data/mpii/'
dataset_type = 'MpiiDataset'
default_hooks = dict(
    badcase=dict(
        badcase_thr=5,
        enable=False,
        metric_type='loss',
        out_dir='badcase',
        type='BadCaseAnalysisHook'),
    checkpoint=dict(
        interval=10, rule='greater', save_best='PCK', type='CheckpointHook'),
    logger=dict(interval=50, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(enable=False, type='PoseVisualizationHook'))
default_scope = 'mmpose'
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
load_from = None
log_level = 'INFO'
log_processor = dict(
    by_epoch=True, num_digits=6, type='LogProcessor', window_size=50)
model = dict(
    backbone=dict(
        extra=dict(
            stage1=dict(
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_branches=1,
                num_channels=(64, ),
                num_modules=1),
            stage2=dict(
                block='BASIC',
                num_blocks=(
                    4,
                    4,
                ),
                num_branches=2,
                num_channels=(
                    48,
                    96,
                ),
                num_modules=1),
            stage3=dict(
                block='BASIC',
                num_blocks=(
                    4,
                    4,
                    4,
                ),
                num_branches=3,
                num_channels=(
                    48,
                    96,
                    192,
                ),
                num_modules=4),
            stage4=dict(
                block='BASIC',
                num_blocks=(
                    4,
                    4,
                    4,
                    4,
                ),
                num_branches=4,
                num_channels=(
                    48,
                    96,
                    192,
                    384,
                ),
                num_modules=3)),
        in_channels=3,
        init_cfg=dict(
            checkpoint=
            'https://download.openmmlab.com/mmpose/pretrain_models/hrnet_w48-8ef0771d.pth',
            type='Pretrained'),
        type='HRNet'),
    data_preprocessor=dict(
        bgr_to_rgb=True,
        mean=[
            123.675,
            116.28,
            103.53,
        ],
        std=[
            58.395,
            57.12,
            57.375,
        ],
        type='PoseDataPreprocessor'),
    head=dict(
        decoder=dict(
            heatmap_size=(
                64,
                64,
            ),
            input_size=(
                256,
                256,
            ),
            sigma=2,
            type='MSRAHeatmap',
            unbiased=True),
        deconv_out_channels=None,
        in_channels=48,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        out_channels=16,
        type='HeatmapHead'),
    test_cfg=dict(flip_mode='heatmap', flip_test=True, shift_heatmap=True),
    type='TopdownPoseEstimator')
optim_wrapper = dict(optimizer=dict(lr=0.0005, type='Adam'))
param_scheduler = [
    dict(
        begin=0, by_epoch=False, end=500, start_factor=0.001, type='LinearLR'),
    dict(
        begin=0,
        by_epoch=True,
        end=210,
        gamma=0.1,
        milestones=[
            170,
            200,
        ],
        type='MultiStepLR'),
]
resume = False
test_cfg = dict()
test_dataloader = dict(
    batch_size=32,
    dataset=dict(
        ann_file='annotations/mpii_val.json',
        data_mode='topdown',
        data_prefix=dict(img='images/'),
        data_root='data/mpii/',
        headbox_file='data/mpii/annotations/mpii_gt_val.mat',
        pipeline=[
            dict(type='LoadImage'),
            dict(type='GetBBoxCenterScale'),
            dict(input_size=(
                256,
                256,
            ), type='TopdownAffine'),
            dict(type='PackPoseInputs'),
        ],
        test_mode=True,
        type='MpiiDataset'),
    drop_last=False,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(round_up=False, shuffle=False, type='DefaultSampler'))
test_evaluator = dict(type='MpiiPCKAccuracy')
train_cfg = dict(by_epoch=True, max_epochs=210, val_interval=10)
train_dataloader = dict(
    batch_size=64,
    dataset=dict(
        ann_file='annotations/mpii_train.json',
        data_mode='topdown',
        data_prefix=dict(img='images/'),
        data_root='data/mpii/',
        pipeline=[
            dict(type='LoadImage'),
            dict(type='GetBBoxCenterScale'),
            dict(direction='horizontal', type='RandomFlip'),
            dict(shift_prob=0, type='RandomBBoxTransform'),
            dict(input_size=(
                256,
                256,
            ), type='TopdownAffine'),
            dict(
                encoder=dict(
                    heatmap_size=(
                        64,
                        64,
                    ),
                    input_size=(
                        256,
                        256,
                    ),
                    sigma=2,
                    type='MSRAHeatmap',
                    unbiased=True),
                type='GenerateTarget'),
            dict(type='PackPoseInputs'),
        ],
        type='MpiiDataset'),
    num_workers=2,
    persistent_workers=True,
    sampler=dict(shuffle=True, type='DefaultSampler'))
train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(direction='horizontal', type='RandomFlip'),
    dict(shift_prob=0, type='RandomBBoxTransform'),
    dict(input_size=(
        256,
        256,
    ), type='TopdownAffine'),
    dict(
        encoder=dict(
            heatmap_size=(
                64,
                64,
            ),
            input_size=(
                256,
                256,
            ),
            sigma=2,
            type='MSRAHeatmap',
            unbiased=True),
        type='GenerateTarget'),
    dict(type='PackPoseInputs'),
]
val_cfg = dict()
val_dataloader = dict(
    batch_size=32,
    dataset=dict(
        ann_file='annotations/mpii_val.json',
        data_mode='topdown',
        data_prefix=dict(img='images/'),
        data_root='data/mpii/',
        headbox_file='data/mpii/annotations/mpii_gt_val.mat',
        pipeline=[
            dict(type='LoadImage'),
            dict(type='GetBBoxCenterScale'),
            dict(input_size=(
                256,
                256,
            ), type='TopdownAffine'),
            dict(type='PackPoseInputs'),
        ],
        test_mode=True,
        type='MpiiDataset'),
    drop_last=False,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(round_up=False, shuffle=False, type='DefaultSampler'))
val_evaluator = dict(type='MpiiPCKAccuracy')
val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(input_size=(
        256,
        256,
    ), type='TopdownAffine'),
    dict(type='PackPoseInputs'),
]
vis_backends = [
    dict(type='LocalVisBackend'),
]
visualizer = dict(
    name='visualizer',
    type='PoseLocalVisualizer',
    vis_backends=[
        dict(type='LocalVisBackend'),
    ])


In [None]:
!python /content/mmpose/tools/test.py \
        /content/models/td-hm_hrnet-w48_dark-8xb64-210e_mpii-256x256-custom.py \
        /content/models/hrnet_w48_mpii_256x256_dark-0decd39f_20200927.pth

## pose detection on some image

In [None]:
import mmcv
import mmengine
import numpy as np
import torch
from mmcv import imread
from mmengine.registry import init_default_scope
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples


# mmdet is optional at import time; has_mmdet records whether the detector API is available.
try:
    from mmdet.apis import inference_detector, init_detector

    has_mmdet = True
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    has_mmdet = False

# google.colab can only be imported inside Colab; if the import fails we are on a local runtime.
# Catch ImportError only, so unrelated errors (or a KeyboardInterrupt) are not silently swallowed.
try:
    from google.colab.patches import cv2_imshow

    local_runtime = False
except ImportError:
    local_runtime = True

In [None]:
# Skeleton edges: each (a, b) pair holds the indices of two keypoints that are connected by a limb.
# NOTE(review): the indices presumably follow the 16-keypoint MPII order used by the pose model
# configured above — confirm before reusing with another keypoint layout.
# This is a set, so code iterating it must not depend on a particular edge order.
joints = {
    (5, 4),
    (4, 3),
    (0, 1),
    (1, 2),
    (3, 2),
    (3, 6),
    (2, 6),
    (6, 7),
    (7, 8),
    (8, 9),
    (13, 7),
    (12, 7),
    (13, 14),
    (12, 11),
    (14, 15),
    (11, 10),
}

In [None]:
img1 = "/content/data/mpii/images/033441445.jpg"
img2 = "/content/data/mpii/images/000061164.jpg"

In [None]:
pose_config = "/content/models/td-hm_hrnet-w48_dark-8xb64-210e_mpii-256x256-custom.py"
pose_checkpoint = "/content/models/hrnet_w48_mpii_256x256_dark-0decd39f_20200927.pth"
det_config = "/content/mmpose/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py"
det_checkpoint = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True)))

In [None]:
# The detector API comes from mmdet, which is imported optionally earlier in the notebook.
# Fail with a clear message instead of a NameError on init_detector when it is missing.
if not has_mmdet:
    raise ImportError("mmdet is required to build the person detector; please install mmdet first")

# build detector
detector = init_detector(
    det_config,
    det_checkpoint,
    device=device,
)

# build pose estimator
pose_estimator = init_pose_estimator(
    pose_config,
    pose_checkpoint,
    device=device,
    cfg_options=cfg_options,
)

In [None]:
# init visualizer
# Drawing style overrides applied to the visualizer config before it is built
# (presumably keypoint radius and limb line width — confirm against the visualizer's options).
pose_estimator.cfg.visualizer.radius = 3
pose_estimator.cfg.visualizer.line_width = 1
# Switch off flip testing on the already built model (the config written earlier sets flip_test=True).
pose_estimator.test_cfg["flip_test"] = False
visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
# the dataset_meta is loaded from the checkpoint and
# then passed to the model in init_pose_estimator
visualizer.set_dataset_meta(pose_estimator.dataset_meta)

In [None]:
def visualize_img(img_path, detector, pose_estimator, visualizer, show_interval, out_file):
    """Detect boxes in one image, estimate the poses inside them and draw the result.

    The drawing is done through ``visualizer.add_datasample`` with ground truth, heatmaps and
    bounding boxes disabled; nothing is returned.

    Parameters
    ----------
    - img_path: Image passed to the detector, the pose estimator and ``mmcv.imread``
    - detector: Detector used by ``inference_detector``
    - pose_estimator: Pose model used by ``inference_topdown``
    - visualizer: Visualizer that receives the merged pose predictions
    - show_interval: Forwarded to the visualizer as ``wait_time``
    - out_file: Forwarded to the visualizer as ``out_file``
    """
    # The detector has to run under its own registry scope.
    detector_scope = detector.cfg.get("default_scope", "mmdet")
    if detector_scope is not None:
        init_default_scope(detector_scope)

    detections = inference_detector(detector, img_path).pred_instances.cpu().numpy()

    # Attach the score to every box, keep label 0 (presumably the person class — confirm)
    # with a score above 0.3, then remove overlapping boxes with NMS and drop the score column.
    scored_boxes = np.concatenate((detections.bboxes, detections.scores[:, None]), axis=1)
    keep_mask = np.logical_and(detections.labels == 0, detections.scores > 0.3)
    scored_boxes = scored_boxes[keep_mask]
    person_boxes = scored_boxes[nms(scored_boxes, 0.3)][:, :4]

    # Top-down pose estimation inside the kept boxes, merged into a single data sample.
    merged_samples = merge_data_samples(inference_topdown(pose_estimator, img_path, person_boxes))

    # The visualizer is given the image in RGB channel order.
    rgb_image = mmcv.imread(img_path, channel_order="rgb")

    drawing_options = dict(
        data_sample=merged_samples,
        draw_gt=False,
        draw_heatmap=False,
        draw_bbox=False,
        show=False,
        wait_time=show_interval,
        out_file=out_file,
        kpt_thr=0.3,
    )
    visualizer.add_datasample("result", rgb_image, **drawing_options)

In [None]:
visualize_img(img1, detector, pose_estimator, visualizer, show_interval=0, out_file=None)

vis_result = visualizer.get_image()

In [None]:
# Show the rendered result: through a temporary PNG and IPython's display on a local runtime,
# or through Colab's cv2_imshow patch otherwise.
if local_runtime:
    import os.path as osp
    import tempfile

    import cv2
    from IPython.display import Image, display

    # The temporary directory (and the PNG inside it) is removed when the with-block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        file_name = osp.join(tmpdir, "pose_results.png")
        # Reverse the channel axis before writing, same RGB2BGR flip as in the Colab branch.
        cv2.imwrite(file_name, vis_result[:, :, ::-1])
        display(Image(file_name))
else:
    cv2_imshow(vis_result[:, :, ::-1])  # RGB2BGR to fit cv2

In [None]:
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
from PIL import Image


def visualize_keypoints(img, image_path, keypoints, joints):
    """Plot keypoints, their indices and the connecting skeleton lines on top of an image.

    Parameters
    ----------
    - img: Image array of shape (H, W) or (H, W, C); only used when ``image_path`` is None
    - image_path: Path of an image to load; takes precedence over ``img`` when not None
    - keypoints: Indexable whose first entry is a sequence of (x, y) keypoints, e.g. shape (1, K, 2)
    - joints: Iterable of (i, j) keypoint-index pairs to connect with a line

    Raises
    ------
    - ValueError: if both ``img`` and ``image_path`` are None
    """
    if image_path is not None:
        # Copy the pixels out so the file handle is closed again before plotting.
        with Image.open(image_path) as image_file:
            img = image_file.copy()
        h, w = img.height, img.width
    elif img is not None:
        # Slice instead of unpacking three values so grayscale (H, W) arrays work as well.
        h, w = img.shape[:2]
    else:
        raise ValueError("Either `img` or `image_path` must be provided")

    # Drop the leading axis: only the first set of keypoints is drawn.
    keypoints = keypoints[0]

    # Create figure and axes
    fig, ax = plt.subplots()

    # Display the image
    ax.imshow(img)

    # Plot keypoints, each labelled with its index
    for i, (x, y) in enumerate(keypoints):
        ax.scatter(x, y, c="r", marker="o")
        ax.text(x, y, f"{i}", color="r", fontsize=8)

    # Plot joints as green line segments between the two referenced keypoints
    for i, j in joints:
        x_i, y_i = keypoints[i]
        x_j, y_j = keypoints[j]
        line = mlines.Line2D([x_i, x_j], [y_i, y_j], color="g")
        ax.add_line(line)

    # Set axis limits; the y axis runs from h down to 0 so the origin sits in the top-left corner
    ax.set_xlim(0, w)
    ax.set_ylim(h, 0)

    # Show the plot
    plt.show()

## measure metrics between 2 images

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

In [None]:
@torch.no_grad()
def get_keypoints(img_path, detector, pose_estimator):
    """Return keypoints with their scores for the first pose result of an image.

    Parameters
    ----------
    - img_path: Image handed to the detector and the pose estimator (callers pass paths and frames)
    - detector: Detector used by ``inference_detector``
    - pose_estimator: Pose model used by ``inference_topdown``

    Returns
    -------
    - Array made of the predicted keypoints with the keypoint scores appended on the last axis
    """
    # The detector has to run under its own registry scope.
    detector_scope = detector.cfg.get("default_scope", "mmdet")
    if detector_scope is not None:
        init_default_scope(detector_scope)

    detections = inference_detector(detector, img_path).pred_instances.cpu().numpy()

    # Attach the score to every box, keep label 0 (presumably the person class — confirm)
    # with a score above 0.3, then remove overlapping boxes with NMS and drop the score column.
    scored_boxes = np.concatenate((detections.bboxes, detections.scores[:, None]), axis=1)
    keep_mask = np.logical_and(detections.labels == 0, detections.scores > 0.3)
    scored_boxes = scored_boxes[keep_mask]
    person_boxes = scored_boxes[nms(scored_boxes, 0.3)][:, :4]

    # Only the first pose result is used.
    first_pose = inference_topdown(pose_estimator, img_path, person_boxes)[0].pred_instances
    scores_as_column = first_pose.keypoint_scores[..., None]
    return np.concatenate((first_pose.keypoints, scores_as_column), axis=-1)

In [None]:
def normalize_vector(vector):
    """Scale ``vector`` so that its norm (as given by ``np.linalg.norm``) becomes 1.

    An input whose norm is exactly zero is returned unchanged to avoid dividing by zero.
    """
    magnitude = np.linalg.norm(vector)
    return vector / magnitude if magnitude != 0 else vector


def compute_cosine_similarity(vector1, vector2):
    """Return the dot product of ``vector1`` and ``vector2``.

    The result is the cosine of the angle between the vectors only when both inputs already
    have unit length; no normalization is performed here.
    """
    dot_product = np.dot(vector1, vector2)
    return dot_product


def compute_cosine_similarity_for_joints(keypoints1, keypoints2, joints):
    """Compare two poses limb by limb using the cosine similarity of their limb vectors.

    Parameters
    ----------
    - keypoints1: Keypoint coordinates of the first pose, indexable by keypoint index
    - keypoints2: Keypoint coordinates of the second pose, indexable by keypoint index
    - joints: Iterable of (i, j) keypoint-index pairs; the limb vector points from i to j

    Returns
    -------
    - List of per-limb similarities, in iteration order of ``joints``
    - Average of the per-limb similarities
    - Smallest absolute per-limb similarity
    """
    # Each limb is normalized first so that the dot product is the cosine of the angle.
    limb_similarities = [
        compute_cosine_similarity(
            normalize_vector(keypoints1[j] - keypoints1[i]),
            normalize_vector(keypoints2[j] - keypoints2[i]),
        )
        for i, j in joints
    ]

    mean_similarity = sum(limb_similarities) / len(limb_similarities)
    # NOTE(review): the absolute value makes opposite-pointing limbs (-1) count as 1 — confirm intent.
    smallest_abs_similarity = np.min(np.abs(limb_similarities))

    return limb_similarities, mean_similarity, smallest_abs_similarity


def compute_cosine_similarity_for_joints_with_weights(keypoints1, keypoints2, weights, joints):
    """Compare two poses limb by limb, weighting every limb by the confidence of its endpoints.

    ``weights`` is indexed by keypoint, matching how the callers in this notebook pass keypoint
    scores. The weight of a limb is the smaller of its two endpoint weights, so a limb counts
    only as much as its least reliable keypoint. Weights are deliberately not looked up by the
    position of the limb inside ``joints``: that position is arbitrary when ``joints`` is a set
    and does not correspond to a keypoint.

    Parameters
    ----------
    - keypoints1: Keypoint coordinates of the first pose, indexable by keypoint index
    - keypoints2: Keypoint coordinates of the second pose, indexable by keypoint index
    - weights: Per-keypoint weights (e.g. keypoint scores), indexable by keypoint index
    - joints: Iterable of (i, j) keypoint-index pairs; the limb vector points from i to j

    Returns
    -------
    - List of weighted per-limb similarities, in iteration order of ``joints``
    - Weighted average similarity (0.0 when the limb weights sum to zero)
    - Smallest absolute weighted per-limb similarity
    """
    similarities = []
    total_weight = 0.0

    for i, j in joints:
        limb1 = normalize_vector(keypoints1[j] - keypoints1[i])
        limb2 = normalize_vector(keypoints2[j] - keypoints2[i])

        limb_weight = min(weights[i], weights[j])
        total_weight += limb_weight
        similarities.append(limb_weight * compute_cosine_similarity(limb1, limb2))

    # Normalize by the weights that were actually applied; guard against an all-zero weighting.
    if total_weight == 0:
        average_cos_similarity = 0.0
    else:
        average_cos_similarity = sum(similarities) / total_weight

    return similarities, average_cos_similarity, np.min(np.abs(similarities))

In [None]:
import numpy as np
from scipy.spatial.distance import cdist


def compute_oks(prediction_poses, ground_truth_poses, sigma=10):
    """Compute an average OKS-style keypoint similarity between predicted and ground truth poses.

    Every pose is rescaled with ``normalize_vector`` first, so the keypoint distances and the
    per-keypoint tolerances are expressed in normalized units rather than pixels. Keypoints are
    compared index by index (keypoint k of the prediction against keypoint k of the ground
    truth), which keeps every per-pose score within [0, 1] and gives 1.0 for identical poses.

    Parameters
    ----------
    - prediction_poses: Sequence of predicted poses (each pose is an Nx3 array of x, y, score rows)
    - ground_truth_poses: Sequence of ground truth poses (each pose is an Nx3 array)
    - sigma: Factor applied to the normalized ground truth scores to obtain the per-keypoint
      tolerance of the Gaussian (default is 10)

    Returns
    -------
    - Average score across all compared poses; empty poses contribute 0, and 0.0 is returned
      when there is nothing to compare
    """
    num_frames = min(len(prediction_poses), len(ground_truth_poses))
    if num_frames == 0:
        return 0.0

    total_oks = 0.0

    # zip stops at the shorter input, i.e. after num_frames pairs.
    for predicted, ground_truth in zip(prediction_poses, ground_truth_poses):
        # Check if the poses are valid
        if len(predicted) == 0 or len(ground_truth) == 0:
            continue

        # NOTE(review): the normalization spans the whole Nx3 array, so the score column is
        # rescaled together with the coordinates — confirm this is intended.
        predicted_pose = normalize_vector(predicted)
        ground_truth_pose = normalize_vector(ground_truth)

        # Distance between corresponding keypoints only, not between all keypoint pairs.
        d = np.linalg.norm(predicted_pose[:, :2] - ground_truth_pose[:, :2], axis=1)
        # The lower bound keeps the division below finite for zero scores.
        sigmas = sigma * np.maximum(ground_truth_pose[:, 2], 1e-14)
        e = np.exp(-0.5 * (d / sigmas) ** 2)

        total_oks += np.mean(e)

    # Calculate the average OKS across all frames
    return total_oks / num_frames

In [None]:
get_keypoints(img1, detector, pose_estimator).shape

In [None]:
keypoints_with_scores_1 = get_keypoints(img1, detector, pose_estimator)
keypoints_with_scores_2 = get_keypoints(img2, detector, pose_estimator)

In [None]:
keypoints_with_scores_1.squeeze().shape

In [None]:
compute_oks(
    keypoints_with_scores_1,
    keypoints_with_scores_2,
)

In [None]:
compute_cosine_similarity_for_joints(
    keypoints_with_scores_1[:, :, 0:2].squeeze(), keypoints_with_scores_2[:, :, 0:2].squeeze(), joints
)

In [None]:
compute_cosine_similarity_for_joints_with_weights(
    keypoints_with_scores_1[:, :, 0:2].squeeze(),
    keypoints_with_scores_2[:, :, 0:2].squeeze(),
    keypoints_with_scores_1[:, :, 2].squeeze(),
    joints,
)

## measure metrics between video and image

In [None]:
import cv2
from decord import VideoReader, cpu, gpu
from tqdm import tqdm

In [None]:
vr = VideoReader("/content/data/video.mp4", ctx=cpu(0))
print("video frames:", len(vr))

In [None]:
target_frame_index = 470
target_frame = vr[target_frame_index].asnumpy()
keypoints_target = get_keypoints(target_frame, detector, pose_estimator)

In [None]:
visualize_keypoints(vr[target_frame_index].asnumpy(), None, keypoints_target[:, :, 0:2], joints)

In [None]:
# Per-frame similarity between every video frame and the target frame.
oks_distance = []
cosine_joints_similarity_avg = []
cosine_joints_similarity_min = []
cosine_weight_joints_similarity_avg = []
cosine_weight_joints_similarity_min = []

# Reuse the already decoded target frame instead of decoding it twice more just for its size.
h, w = target_frame.shape[:2]

# decord returns RGB frames while OpenCV's writer expects BGR, so convert before writing.
target_frame_bgr = cv2.cvtColor(target_frame, cv2.COLOR_RGB2BGR)
target_xy = keypoints_target[:, :, 0:2].squeeze()

# Overlay position: horizontally centred on the side-by-side frame, computed once for a
# fixed-width template (the value is always rendered with four decimals below).
text_width, text_height = cv2.getTextSize("OKS: 0.0000", cv2.FONT_HERSHEY_DUPLEX, 1, 2)[0]
text_org = ((w * 2 - text_width) // 2, text_height)

# Lowercase "mp4v" is the tag the MP4 container accepts; with "MP4V" OpenCV's FFmpeg backend
# warns and falls back to it.  (alternative codec: *'DIVX')
output = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), vr.get_avg_fps(), (w * 2, h))

try:
    for i in tqdm(range(len(vr))):
        # the video reader will handle seeking and skipping in the most efficient manner
        frame = vr[i].asnumpy()
        # NOTE(review): the models receive the RGB frame as-is, consistent with how
        # keypoints_target was computed; array inputs are presumably expected in BGR — confirm.
        keypoints_frame = get_keypoints(frame, detector, pose_estimator)
        frame_xy = keypoints_frame[:, :, 0:2].squeeze()

        oks_distance.append(
            compute_oks(
                keypoints_frame,
                keypoints_target,
            )
        )

        _, dist_avg, dist_min = compute_cosine_similarity_for_joints(frame_xy, target_xy, joints)

        _, dist_w_avg, dist_w_min = compute_cosine_similarity_for_joints_with_weights(
            frame_xy,
            target_xy,
            keypoints_frame[:, :, 2].squeeze(),
            joints,
        )

        cosine_joints_similarity_avg.append(dist_avg)
        cosine_joints_similarity_min.append(dist_min)
        cosine_weight_joints_similarity_avg.append(dist_w_avg)
        cosine_weight_joints_similarity_min.append(dist_w_min)

        # Current frame on the left, target frame on the right.
        out_frame = cv2.hconcat([cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), target_frame_bgr])

        cv2.putText(
            out_frame,
            text=f"OKS: {oks_distance[-1]:.4f}",
            org=text_org,
            fontFace=cv2.FONT_HERSHEY_DUPLEX,
            fontScale=1.0,
            color=(0, 0, 0),
            thickness=2,
        )
        output.write(out_frame)
finally:
    # Finalize the file even when inference fails or the cell is interrupted.
    output.release()

In [None]:
plt.plot(oks_distance, label="oks_distance")
plt.plot(cosine_joints_similarity_avg, label="cosine_joints_similarity_avg")
# plt.plot(cosine_joints_similarity_min, label = "cosine_joints_similarity_min")
plt.plot(cosine_weight_joints_similarity_avg, label="cosine_weight_joints_similarity_avg")
# plt.plot(cosine_weight_joints_similarity_min, label = "cosine_weight_joints_similarity_min")

plt.axvline(x=target_frame_index, color="red", label="target frame")

plt.legend()
plt.show()

## add metrics to video

In [None]:
# from enum import Enum
# from typing import Callable

# try:
#     import cv2
# except ImportError:
#     raise Exception("Looks like you don't have cv2 installed. "
#                     "Please check https://pypi.org/project/opencv-python/ on how to install it."
#                     "Or: pip install opencv-python")

# from time import time, sleep


# class FrameChange(Enum):
#     NoChange = 0
#     Next = 1
#     Seek = 2


# class CV2VideoPlayer:
#     KEY_CODE_STOP = 27
#     KEY_CODE_TOGGLE_PLAY = 32
#     KEY_CODE_PREV_FRAME = ord('a')
#     KEY_CODE_NEXT_FRAME = ord('d')

#     CMD_NEXT_FRAME = "next"
#     CMD_NOOP = "noop"

#     # __cap:
#     __frame_count: int
#     __fps: int
#     __on_frame_callback: Callable
#     __on_stop_callback: Callable
#     __previous_frame_display_timestamp: float
#     __window_name: str = 'VideoPlayer'

#     def __init__(self, filename: str, on_frame: Callable, on_stop: Callable):
#         self.__cap = cv2.VideoCapture(filename)
#         if not self.__cap.isOpened():
#             raise Exception(f"Could not read {filename}")

#         self.__frame_count = int(self.__cap.get(cv2.CAP_PROP_FRAME_COUNT))
#         self.__fps = self.__cap.get(cv2.CAP_PROP_FPS)

#         self.__playback_rate = 1.0

#         self.__current_frame = 0
#         self.__on_frame_callback = on_frame
#         self.__on_stop_callback = on_stop
#         self.__previous_frame_display_timestamp = 0

#         self.__status = 'play'

#         self.__setup_ui()

#     def __setup_ui(self):
#         def on_change_frame(x):
#             if x == self.__current_frame:
#                 return

#             self.__current_frame = x
#             if self.__status == 'paused':
#                 self.__status = 'seek_frame'

#         def on_change_playback_rate(x):
#             if x == 0:
#                 x = 1
#             self.__playback_rate = x / 100.

#         cv2.namedWindow(self.__window_name)
#         cv2.createTrackbar('Frame', self.__window_name, 0, self.__frame_count - 1, on_change_frame)
#         cv2.setTrackbarPos('Frame', self.__window_name, 0)

#         cv2.createTrackbar('Playback Speed', self.__window_name, 1, 400, on_change_playback_rate)
#         cv2.setTrackbarPos('Playback Speed', self.__window_name, int(self.__playback_rate * 100))

#     def __handle_keyboard_input(self):
#         key = cv2.waitKey(1)
#         if key == self.KEY_CODE_TOGGLE_PLAY:
#             if self.__status == 'paused':
#                 self.__status = 'play'
#             else:
#                 self.__status = 'paused'
#         elif key == self.KEY_CODE_PREV_FRAME:
#             self.__status = 'prev_frame'
#         elif key == self.KEY_CODE_NEXT_FRAME:
#             self.__status = 'next_frame'
#         elif key == self.KEY_CODE_STOP:
#             self.stop()

#     def __calculate_current_frame(self) -> FrameChange:
#         if self.__status == 'play':
#             if self.__current_frame == self.__frame_count:
#                 self.__status = 'paused'
#                 return FrameChange.NoChange

#             now = time()
#             if self.__previous_frame_display_timestamp == 0:
#                 self.__previous_frame_display_timestamp = now
#                 return FrameChange.NoChange

#             frame_display_time = 1 / self.__fps / self.__playback_rate
#             time_delta = now - self.__previous_frame_display_timestamp

#             frame_delta = int(time_delta / frame_display_time)

#             if frame_delta > 0:
#                 self.__current_frame += frame_delta
#                 self.__previous_frame_display_timestamp = now
#                 if frame_delta > 1:
#                     return FrameChange.Seek
#                 else:
#                     return FrameChange.Next
#         elif self.__status == 'next_frame':
#             if self.__current_frame < self.__frame_count:
#                 self.__current_frame += 1
#                 self.__status = 'paused'
#                 return FrameChange.Next
#         elif self.__status == 'prev_frame':
#             if self.__current_frame > 0:
#                 self.__current_frame -= 1
#                 self.__status = 'paused'
#                 return FrameChange.Seek
#         elif self.__status == 'seek_frame':
#             self.__status = 'paused'
#             return FrameChange.Seek
#         elif self.__status == 'paused':
#             self.__previous_frame_display_timestamp = 0

#         return FrameChange.NoChange

#     def __show_current_frame(self):
#         ret, im = self.__cap.read()
#         if im is None:
#             return
#         r = 720.0 / im.shape[1]
#         dim = (720, int(im.shape[0] * r))
#         im = cv2.resize(im, dim, interpolation=cv2.INTER_AREA)
#         cv2.imshow(self.__window_name, im)

#     def on_timer(self):
#         if self.__status == "stopped":
#             return

#         self.__handle_keyboard_input()
#         frame_change = self.__calculate_current_frame()

#         if frame_change == FrameChange.NoChange:
#             return
#         elif frame_change == FrameChange.Next:
#             pass  # by default capture device reads next frame
#         elif frame_change == FrameChange.Seek:
#             self.__cap.set(cv2.CAP_PROP_POS_FRAMES, self.__current_frame)

#         self.__show_current_frame()
#         cv2.setTrackbarPos('Frame', self.__window_name, self.__current_frame)
#         self.__on_frame_callback(self.__current_frame / self.__fps)

#     def stop(self):
#         self.__cap.release()
#         del self.__cap
#         cv2.destroyWindow(self.__window_name)
#         self.__status = "stopped"
#         self.__on_stop_callback()


# def attach_video_player_to_figure(figure, filename: str, on_frame: Callable, **callback_args):
#     try:
#         from matplotlib.figure import Figure
#     except ImportError:
#         raise Exception("Looks like you don't have matplotlib installed.")

#     if not isinstance(figure, Figure):
#         raise ValueError("`figure` should be a matplotlib Figure instance")

#     timer = None

#     def on_stop():
#         timer.stop()

#     video_player = CV2VideoPlayer(filename, lambda timestamp: on_frame(timestamp, **callback_args), on_stop)

#     timer = figure.canvas.new_timer(interval=10,
#                                     callbacks=[(video_player.on_timer, [], {})])
#     timer.start()


# def start_video_player(filename: str, on_frame: Callable):
#     _should_stop = False

#     def on_stop():
#         nonlocal _should_stop
#         _should_stop = True

#     video_player = CV2VideoPlayer(filename, on_frame, on_stop)

#     while not _should_stop:
#         video_player.on_timer()
#         sleep(0.01)

In [None]:
# import matplotlib.pyplot as plt

# x = []


# def on_frame(video_timestamp, line):
#     x.append(video_timestamp)

#     line.set_data(x, oks_distance[:len(x)])
#     line.axes.relim()
#     line.axes.autoscale_view()
#     line.axes.figure.canvas.draw()


# def main():
#     fig, ax = plt.subplots()
#     plt.xlim(-15, 15)
#     plt.axvline(x=0, color='k', linestyle='--')

#     line, = ax.plot([], [], color='blue')

#     attach_video_player_to_figure(fig, "/content/data/video.mp4", on_frame, line=line)

#     plt.show()


# main()

## visualize keypoints

In [None]:
visualize_keypoints(None, img1, keypoints_with_scores_1[:, :, 0:2], joints)

In [None]:
visualize_keypoints(None, img2, keypoints_with_scores_2[:, :, 0:2], joints)