In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import os.path as osp
import torch
import numpy as np
import mmcv
import cv2
from mmengine.utils import track_iter_progress

In [None]:
# Download the example videos into ./resources.
# `wget -O` writes to the given path, so re-running this cell fetches the
# files again and replaces any existing copies.
from mmengine.utils import mkdir_or_exist
mkdir_or_exist('resources')
! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tom.mp4 
! wget -O resources/teacher_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/idol_producer.mp4 
# Alternative student clip; uncomment to use it instead of the one above.
# ! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tsinghua_30fps.mp4 

# Paths consumed by the rest of the notebook.
student_video = 'resources/student_video.mp4'
teacher_video = 'resources/teacher_video.mp4'

In [None]:
# convert the fps of videos to 30
from mmcv import VideoReader

if VideoReader(student_video) != 30:
    # ffmpeg is required to convert the video fps
    # which can be installed via `sudo apt install ffmpeg` on ubuntu
    student_video_30fps = student_video.replace(
        f".{student_video.rsplit('.', 1)[1]}",
        f"_30fps.{student_video.rsplit('.', 1)[1]}"
    )
    !ffmpeg -i {student_video} -vf "minterpolate='fps=30'" {student_video_30fps}
    student_video = student_video_30fps
    
if VideoReader(teacher_video) != 30:
    teacher_video_30fps = teacher_video.replace(
        f".{teacher_video.rsplit('.', 1)[1]}",
        f"_30fps.{teacher_video.rsplit('.', 1)[1]}"
    )
    !ffmpeg -i {teacher_video} -vf "minterpolate='fps=30'" {teacher_video_30fps}
    teacher_video = teacher_video_30fps    

In [None]:
# init pose estimator
from mmpose.apis.inferencers import Pose2DInferencer

# model / detector settings, named so they are easy to swap
pose_model_name = 'rtmpose-t_8xb256-420e_aic-coco-256x192'
det_config_path = 'configs/rtmdet-nano_one-person.py'
det_checkpoint_url = (
    'https://download.openmmlab.com/mmpose/v1/projects/'
    'rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth')

pose_estimator = Pose2DInferencer(
    pose_model_name,
    det_model=det_config_path,
    det_weights=det_checkpoint_url,
)

# switch off flip testing in the pose model's test config
pose_estimator.model.test_cfg['flip_test'] = False

In [None]:
@torch.no_grad()
def get_keypoints_from_frame(image, pose_estimator):
    """Detect one person in a frame and estimate that person's keypoints.

    Args:
        image: The frame; it is passed unchanged to
            ``pose_estimator.detector`` and to the pose data pipeline.
        pose_estimator: Inferencer object providing ``detector``,
            ``pipeline``, ``collate_fn`` and ``model``.

    Returns:
        np.ndarray: Keypoint coordinates with the keypoint score appended
        as the last channel. When nothing is detected, or the first
        detection scores below 0.2, an all-zero float32 array of shape
        (1, 17, 3) is returned instead.
    """
    detection = pose_estimator.detector(image, return_datasample=True)
    person = detection['predictions'][0].pred_instances.numpy()

    # no usable person box -> placeholder result
    no_person = len(person) == 0 or person.scores[0] < 0.2
    if no_person:
        return np.zeros((1, 17, 3), dtype=np.float32)

    # only the first detected box is forwarded to the pose model
    sample = dict(
        img=image,
        bbox=person.bboxes[:1],
        bbox_score=person.scores[:1])
    sample.update(pose_estimator.model.dataset_meta)
    batch = pose_estimator.collate_fn([pose_estimator.pipeline(sample)])

    # run the model pieces by hand instead of the inferencer's own forward
    model = pose_estimator.model
    batch = model.data_preprocessor(batch, False)
    feats = model.extract_feat(batch['inputs'])
    prediction = model.head.predict(
        feats, batch['data_samples'], test_cfg=model.test_cfg)[0]

    # append the score as a trailing channel next to the coordinates
    scores = prediction.keypoint_scores[..., None]
    return np.concatenate((prediction.keypoints, scores), axis=-1)

In [None]:
# pose estimation in two videos
# one keypoint array per frame, joined along the first axis
student_poses = np.concatenate([
    get_keypoints_from_frame(frame, pose_estimator)
    for frame in VideoReader(student_video)
])
teacher_poses = np.concatenate([
    get_keypoints_from_frame(frame, pose_estimator)
    for frame in VideoReader(teacher_video)
])

In [None]:
valid_indices = np.array([0] + list(range(5, 17)))

@torch.no_grad()
def _calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray):

    stu_kpts = torch.from_numpy(stu_kpts[:, None, valid_indices])
    tch_kpts = torch.from_numpy(tch_kpts[None, :, valid_indices])
    stu_kpts = stu_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1],
                               stu_kpts.shape[2], 3)
    tch_kpts = tch_kpts.expand(stu_kpts.shape[0], tch_kpts.shape[1],
                               stu_kpts.shape[2], 3)

    matrix = torch.stack((stu_kpts, tch_kpts), dim=4)
    if torch.cuda.is_available():
        matrix = matrix.cuda()
    # only consider visible keypoints
    mask = torch.logical_and(matrix[:, :, :, 2, 0] > 0.3,
                             matrix[:, :, :, 2, 1] > 0.3)
    matrix[~mask] = 0.0

    matrix_ = matrix.clone()
    matrix_[matrix == 0] = 256
    x_min = matrix_.narrow(3, 0, 1).min(dim=2).values
    y_min = matrix_.narrow(3, 1, 1).min(dim=2).values
    matrix_ = matrix.clone()
    x_max = matrix_.narrow(3, 0, 1).max(dim=2).values
    y_max = matrix_.narrow(3, 1, 1).max(dim=2).values

    matrix_ = matrix.clone()
    matrix_[:, :, :, 0] = (matrix_[:, :, :, 0] - x_min) / (
        x_max - x_min + 1e-4)
    matrix_[:, :, :, 1] = (matrix_[:, :, :, 1] - y_min) / (
        y_max - y_min + 1e-4)
    matrix_[:, :, :, 2] = (matrix_[:, :, :, 2] > 0.3).float()
    xy_dist = matrix_[..., :2, 0] - matrix_[..., :2, 1]
    score = matrix_[..., 2, 0] * matrix_[..., 2, 1]

    similarity = (torch.exp(-50 * xy_dist.pow(2).sum(dim=-1)) *
                  score).sum(dim=-1) / (
                      score.sum(dim=-1) + 1e-6)
    num_visible_kpts = score.sum(dim=-1)
    similarity = similarity * torch.log(
        (1 + (num_visible_kpts - 1) * 10).clamp(min=1)) / np.log(161)

    similarity[similarity.isnan()] = 0

    return similarity

In [None]:
# build a mirrored copy of the student poses: adjacent keypoint pairs
# (1, 2), (3, 4), ... are swapped (presumably the left/right joints) and the
# x coordinate is reflected
flip_indices = np.array(
    [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15])
student_poses_flip = np.take(student_poses, flip_indices, axis=1)
student_poses_flip[..., 0] = 191.5 - student_poses_flip[..., 0]

# similarity against the teacher, without and with the flip
similarity1 = _calculate_similarity(teacher_poses, student_poses)
similarity2 = _calculate_similarity(teacher_poses, student_poses_flip)

# keep whichever of the two is larger for every frame pair
similarity = torch.maximum(similarity1, similarity2)

In [None]:
# visualize the similarity
# rows of the matrix are student frames, columns are teacher frames
# there is an apparent diagonal in the figure
# we can select matched video snippets with this diagonal
plt.imshow(similarity.cpu().numpy())
plt.xlabel('teacher frame')
plt.ylabel('student frame')
plt.title('frame-to-frame pose similarity')
plt.colorbar(label='similarity');

In [None]:
@torch.no_grad()
def select_piece_from_similarity(similarity):
    """Pick the diagonal of the similarity matrix with the largest sum.

    Args:
        similarity (torch.Tensor): 2-D matrix whose rows are student frames
            and whose columns are teacher frames.

    Returns:
        dict: ``stu_start`` / ``tch_start`` (first row / column of the chosen
        diagonal), ``length`` (number of entries on it) and ``similarity``
        (numpy vector read along that diagonal after an 11-wide max-pool
        over the column axis).
    """
    m, n = similarity.size()

    # label each cell with its diagonal: 0 is the bottom-left corner and the
    # label grows towards the top-right corner
    rows = torch.arange(m).view(-1, 1).to(similarity)
    cols = torch.arange(n).view(1, -1).to(similarity)
    diag_ids = (m - 1) - rows + cols
    unique_ids, cell_to_diag = torch.unique(diag_ids, return_inverse=True)

    # accumulate the similarity of every cell into its diagonal
    diag_sums = similarity.new_zeros(unique_ids.size(0))
    diag_sums.scatter_add_(0, cell_to_diag.view(-1), similarity.view(-1))

    # rule out diagonals near both corners
    head_margin = min(m, n) // 4
    diag_sums[:head_margin] = 0
    # (-min(m, n)) // 4 floors towards -inf, so this tail margin can be one
    # diagonal wider than the head margin
    diag_sums[-min(m, n) // 4:] = 0
    best = diag_sums.argmax().item()

    # column minus row of every cell on the winning diagonal
    offset = best - m + 1

    pooled = torch.nn.functional.max_pool2d(
        similarity[None], (1, 11), stride=(1, 1), padding=(0, 5))[0]
    per_frame = pooled.diagonal(offset=offset).cpu().numpy()

    return dict(
        stu_start=max(0, -offset),
        tch_start=max(0, offset),
        length=len(per_frame),
        similarity=per_frame)

In [None]:
# pick the best-matching diagonal of the similarity matrix; the result holds
# the start frame in each video, the snippet length and the similarity values
# read along that diagonal
matched_piece_info = select_piece_from_similarity(similarity)

In [None]:
# overlay the selected diagonal (red) on the similarity map;
# x runs over teacher frames and y over student frames
plt.imshow(similarity.cpu().numpy())
plt.plot((matched_piece_info['tch_start'],
          matched_piece_info['tch_start'] + matched_piece_info['length'] - 1),
         (matched_piece_info['stu_start'],
          matched_piece_info['stu_start'] + matched_piece_info['length'] - 1),
         'r', label='matched snippet')
plt.xlabel('teacher frame')
plt.ylabel('student frame')
plt.legend();

# Generate Output Video

In [None]:
from typing import Tuple

def resize_image_to_fixed_height(image: np.ndarray,
                                 fixed_height: int) -> np.ndarray:
    """Scale an image to a given height, keeping its aspect ratio.

    Args:
        image (np.ndarray): Source image laid out as [H, W, C].
        fixed_height (int): Height of the returned image in pixels.

    Returns:
        The resized image, [fixed_height, new_width, C]; ``new_width`` is the
        proportionally scaled source width truncated to an int.
    """
    src_height, src_width = image.shape[:2]
    # cv2.resize takes the target size as (width, height)
    target_width = int(src_width * (fixed_height / src_height))
    return cv2.resize(image, (target_width, fixed_height))

def blend_images(img1: np.ndarray,
                 img2: np.ndarray,
                 blend_ratios: Tuple[float, float] = (1, 1)) -> np.ndarray:
    """Compute a weighted sum of two images and return it as uint8.

    uint8 inputs are first mapped to float32 in [0, 1]; inputs of any other
    dtype are used as they are. The weighted sum is clipped to [0, 1] and
    scaled back to the 0-255 range.

    Args:
        img1 (np.ndarray): First image, [H, W, C].
        img2 (np.ndarray): Second image, [H, W, C].
        blend_ratios (tuple): Weights applied to ``img1`` and ``img2``
            respectively.

    Returns:
        The blended uint8 image, [H, W, C].
    """

    def _as_unit_range(image: np.ndarray) -> np.ndarray:
        is_uint8 = image.dtype == np.uint8
        return image.astype(np.float32) / 255.0 if is_uint8 else image

    first = _as_unit_range(img1) * blend_ratios[0]
    second = _as_unit_range(img2) * blend_ratios[1]
    mixed = np.clip(first + second, 0, 1)

    return (mixed * 255).astype(np.uint8)

def get_smoothed_kpt(kpts, index, sigma=5):
    """Return the keypoints of one frame with temporally smoothed coordinates.

    The (x, y) of every keypoint is replaced by a weighted mean over a window
    of up to ``sigma`` frames centred on ``index`` (shorter at the ends of the
    sequence). A frame's weight is its keypoint score multiplied by
    ``exp(-d**2 / 2)``, where ``d`` is its distance in frames from ``index``.
    The returned scores are those of frame ``index``, unchanged.

    ``kpts`` is never modified: the result is a fresh array, so callers may
    edit it in place without corrupting the sequence that later calls read.

    Args:
        kpts (np.ndarray): Keypoint sequence of shape [N, 17, 3] whose last
            axis is (x, y, score).
        index (int): Frame whose smoothed keypoints are requested.
        sigma (int): Odd window length in frames. It only sets the window
            size; the width of the exponential weighting is fixed.

    Returns:
        np.ndarray: Array of shape [17, 3] for frame ``index``.

    Raises:
        AssertionError: If ``kpts`` is not [N, 17, 3] or ``sigma`` is even.
    """
    assert kpts.shape[1] == 17
    assert kpts.shape[2] == 3
    assert sigma % 2 == 1

    num_kpts = len(kpts)
    half_window = sigma // 2

    start_idx = max(0, index - half_window)
    end_idx = min(num_kpts, index + half_window + 1)

    # Work on copies: basic indexing returns views, and writing through a
    # view would overwrite the caller's keypoint sequence.
    piece = kpts[start_idx:end_idx].copy()
    smoothed_kpt = kpts[index].copy()

    # Split the piece into coordinates and scores
    coords, scores = piece[..., :2], piece[..., 2]

    # Weight of each frame in the window by its distance from `index`
    frame_offsets = np.arange(len(scores)) + start_idx - index
    gaussian_ratio = np.exp(-frame_offsets**2 / 2)

    # Combine the distance weight with the keypoint scores
    scores *= gaussian_ratio[:, None]

    # Score-weighted mean of the coordinates (1e-4 guards against zero weight)
    smoothed_coords = (coords * scores[..., None]).sum(axis=0) / (
        scores[..., None].sum(axis=0) + 1e-4)

    smoothed_kpt[..., :2] = smoothed_coords

    return smoothed_kpt

In [None]:
# running total score and the value most recently shown on screen
score = 0
last_vis_score = 0

# the writer is created once the size of the first output frame is known
video_writer = None
output_file = 'output.mp4'

# aliases (not copies) of the per-frame pose arrays
stu_kpts, tch_kpts = student_poses, teacher_poses

In [None]:
from mmengine.structures import InstanceData

tch_video_reader = VideoReader(teacher_video)
stu_video_reader = VideoReader(student_video)
# skip the frames that precede the matched snippet in each video
for _ in range(matched_piece_info['tch_start']):
    next(tch_video_reader)
for _ in range(matched_piece_info['stu_start']):
    next(stu_video_reader)

for i in track_iter_progress(range(matched_piece_info['length'])):
    tch_frame = mmcv.bgr2rgb(next(tch_video_reader))
    stu_frame = mmcv.bgr2rgb(next(stu_video_reader))
    tch_frame = resize_image_to_fixed_height(tch_frame, 300)
    stu_frame = resize_image_to_fixed_height(stu_frame, 300)

    # Take private copies before shifting. The drawing offsets below are
    # applied in place, and without a copy they could end up in the shared
    # pose arrays, which are read again to smooth the following frames (and
    # whenever this cell is re-run).
    stu_kpt = get_smoothed_kpt(
        stu_kpts, matched_piece_info['stu_start'] + i, 5).copy()
    tch_kpt = get_smoothed_kpt(
        tch_kpts, matched_piece_info['tch_start'] + i, 5).copy()

    # draw pose
    # move both skeletons down by 300 - 256 px on the 300 x 256 canvas and the
    # teacher right by 256 - 192 px
    stu_kpt[..., 1] += (300 - 256)
    tch_kpt[..., 0] += (256 - 192)
    tch_kpt[..., 1] += (300 - 256)
    stu_inst = InstanceData(
        keypoints=stu_kpt[None, :, :2],
        keypoint_scores=stu_kpt[None, :, 2])
    tch_inst = InstanceData(
        keypoints=tch_kpt[None, :, :2],
        keypoint_scores=tch_kpt[None, :, 2])

    stu_out_img = pose_estimator.visualizer._draw_instances_kpts(
        np.zeros((300, 256, 3)), stu_inst)
    tch_out_img = pose_estimator.visualizer._draw_instances_kpts(
        np.zeros((300, 256, 3)), tch_inst)
    # student skeleton at full weight, teacher skeleton at 0.3
    out_img = blend_images(
        stu_out_img, tch_out_img, blend_ratios=(1, 0.3))

    # draw score
    score_frame = matched_piece_info['similarity'][i]
    score += score_frame * 1000
    # the displayed number only changes once the total has grown by > 1500
    if score - last_vis_score > 1500:
        last_vis_score = score
    pose_estimator.visualizer.set_image(out_img)
    pose_estimator.visualizer.draw_texts(
        'score: ', (60, 30),
        font_sizes=15,
        colors=(255, 255, 255),
        vertical_alignments='bottom')
    # the font size follows the current frame's similarity (floor of 0.4)
    pose_estimator.visualizer.draw_texts(
        f'{int(last_vis_score)}', (115, 30),
        font_sizes=30 * max(0.4, score_frame),
        colors=(255, 255, 255),
        vertical_alignments='bottom')
    out_img = pose_estimator.visualizer.get_image()

    # concatenate: student frame | pose canvas | teacher frame
    concatenated_image = np.hstack((stu_frame, out_img, tch_frame))
    if video_writer is None:
        # created lazily because the frame size is only known here;
        # cv2 expects the size as (width, height)
        video_writer = cv2.VideoWriter(output_file,
                                       cv2.VideoWriter_fourcc(*'mp4v'),
                                       30,
                                       (concatenated_image.shape[1],
                                        concatenated_image.shape[0]))
    video_writer.write(mmcv.rgb2bgr(concatenated_image))

  

In [None]:
# release the writer to finish the output file; it is still None when no
# frame was written
if video_writer is not None:
    video_writer.release()  