# 视频3d姿势估计

In [None]:
import cv2
import matplotlib.pyplot as plt
import detectron2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
import os
import time
import torch
import sys
import numpy as np
# Extend the import search path two directories up; presumably so the
# project-local `common` package imported further down resolves.
# NOTE(review): depends on the current working directory - confirm.
sys.path.append("../../")
# Preferred torch device (CUDA when available, otherwise CPU). It is not
# referenced again in the visible part of the file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 1. 读取视频和显示

In [None]:
# os.chdir('/home/li/python/pose-estimation/3d/VideoPose3D/data/pose2d-detectron2/')
# Path of the demo clip, relative to the current working directory.
video_path = '../videos/lizijun.mp4'

def read_video(filepath):
    """Play a video in a resizable OpenCV window until it ends or 'q' is pressed.

    Args:
        filepath: Path (or any source accepted by ``cv2.VideoCapture``) of the
            video to play.

    Raises:
        IOError: If the video source cannot be opened.
    """
    cap = cv2.VideoCapture(filepath)
    try:
        if not cap.isOpened():
            raise IOError("Cannot open video: {}".format(filepath))

        # Frame rate -> delay between frames in milliseconds.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps > 0:
            # waitKey(0) blocks until a key is pressed, so never go below 1 ms.
            pause = max(1, int(1000 / fps))
        else:
            # The capture reported no usable frame rate; fall back to ~25 fps
            # instead of dividing by zero.
            pause = 40

        # Resizable window with a fixed initial size.
        cv2.namedWindow('Video', 0)
        cv2.resizeWindow('Video', 1280, 720)

        while True:
            # Grab the next frame; stop at end of stream.
            ok, frame = cap.read()
            if not ok:
                break
            # Show the frame and let the user quit with 'q'.
            cv2.imshow('Video', frame)
            if cv2.waitKey(pause) & 0xff == ord('q'):
                break
    finally:
        # Always free the capture and the window, even if display fails.
        cap.release()
        cv2.destroyAllWindows()

# read_video(video_path)

## 2. 2d关键点检测器

### 2.1 加载模型

In [None]:
# 切换工作目录
# os.chdir('/home/li/python/pose-estimation/3d/VideoPose3D/data/pose2d-detectron2/')

def init_kps_predictor(config_path, weights_path, cuda=True):
    """Build a Detectron2 ``DefaultPredictor`` for 2D keypoint detection.

    Args:
        config_path: Path of the model YAML config merged into the default
            Detectron2 config.
        weights_path: Path of the model weights stored in ``cfg.MODEL.WEIGHTS``.
        cuda: When falsy, inference is forced onto the CPU. When truthy, the
            device named by the config is kept unless CUDA is unavailable, in
            which case the CPU is used so the predictor can still be created.

    Returns:
        The configured ``DefaultPredictor`` instance.
    """
    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
    cfg.MODEL.WEIGHTS = weights_path
    # Fall back to the CPU when the caller asks for it or when no CUDA device
    # exists on this machine.
    if not cuda or not torch.cuda.is_available():
        cfg.MODEL.DEVICE = 'cpu'
    return DefaultPredictor(cfg)

# Config and weights of the keypoint R-CNN model, relative to the current
# working directory.
model_config_path = './keypoint_rcnn_R_50_FPN_3x.yaml'
model_weights_path = './model_R50.pkl'
# Module-level 2D keypoint predictor, explicitly pinned to the CPU.
kps_predictor = init_kps_predictor(model_config_path, model_weights_path, cuda=False)

### 2.2 2d关键点检测和处理

In [None]:
def normalize_screen_coordinates(X, w, h):
    """Map pixel coordinates to a normalized, aspect-preserving range.

    x values in [0, w] are mapped to [-1, 1]. y values are scaled by the same
    factor, so [0, h] is mapped to [-h/w, h/w].

    Args:
        X: Array whose last axis holds (x, y) pixel coordinates.
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        Array of the same shape as ``X`` with normalized coordinates.
    """
    assert X.shape[-1] == 2

    # Per-axis shift that centres the scaled coordinates around zero.
    offset = [1, h / w]
    scaled = X / w * 2
    return scaled - offset

def predict_kps(kps_predictor, img):
    """Detect the 2D keypoints of one person in ``img`` and normalize them.

    Args:
        kps_predictor: The Detectron2 2D keypoint predictor. It is called with
            the image and must return a mapping whose "instances" entry
            exposes ``pred_boxes`` and ``pred_keypoints``.
        img: Image array; ``img.shape[0]`` is used as the height and
            ``img.shape[1]`` as the width.

    Returns:
        The (x, y) keypoints of the first detected instance in normalized
        screen coordinates, or an all-NaN (17, 2) array when nothing was
        detected.
    """
    instances = kps_predictor(img)["instances"]

    if len(instances.pred_boxes.tensor) > 0:
        # Only the first instance is kept.
        # NOTE(review): presumably the highest-scoring detection - confirm.
        raw_kps = instances.pred_keypoints[0].cpu().numpy()
    else:
        # NaN placeholder for images that do not contain a human.
        raw_kps = np.full((17, 3), np.nan, dtype=np.float32)

    img_h, img_w = img.shape[0], img.shape[1]
    # Drop the third column (score) and keep only the coordinates.
    return normalize_screen_coordinates(raw_kps[..., :2], w=img_w, h=img_h)

## 3. 3d姿势估计器

### 3.1 加载模型

In [None]:
# prev_dir = os.getcwd()
# os.chdir('/home/li/python/pose-estimation/3d/VideoPose3D')
# print("Change work dir from {} to {}".format(prev_dir, os.getcwd()))
from common.model import *


def get_pose3d_predictor(ckpt_dir, ckpt_name, filter_widths, causal=False):
    """Load a ``TemporalModel`` 3D pose predictor from a checkpoint file.

    Args:
        ckpt_dir: Directory that contains the checkpoint.
        ckpt_name: File name of the checkpoint inside ``ckpt_dir``.
        filter_widths: Filter widths forwarded to ``TemporalModel``; they must
            match the architecture the checkpoint was trained with.
        causal: Forwarded to ``TemporalModel``.

    Returns:
        The model in evaluation mode, moved to the GPU when CUDA is available.
    """
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    print('Loading checkpoint', ckpt_path)
    # The identity map_location keeps every tensor on the CPU while loading.
    state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    print('This model was trained for {} epochs'.format(state['epoch']))

    # NOTE(review): (17, 2, 17) is presumably 17 input joints, 2 coordinates
    # per joint and 17 output joints - confirm against common.model.
    model = TemporalModel(17, 2, 17, filter_widths=filter_widths, causal=causal)
    print('INFO: Receptive field: {} frames'.format(model.receptive_field()))
    model.load_state_dict(state['model_pos'])

    if torch.cuda.is_available():
        model = model.cuda()

    model = model.eval()
    return model


# Checkpoint location of the module-level 3D pose predictor, relative to the
# current working directory.
ckpt_dir = '../../checkpoint/detectron_pt_coco'
ckpt_name = 'arc_1_epoch_40.bin'
# Filter widths handed to the temporal model; they have to match the
# architecture stored in the checkpoint above.
filter_widths = [1,1,1]
pose3d_predictor = get_pose3d_predictor(ckpt_dir, ckpt_name, filter_widths)

### 3.2 2d关键点生成器

In [None]:
from common.camera import *
from common.generators import UnchunkedGenerator

# Indices of the left-side and right-side 2D keypoints. They are handed to
# UnchunkedGenerator together with augment=True.
# NOTE(review): presumably the COCO 17-keypoint ordering - confirm.
kps_left=[1, 3, 5, 7, 9, 11, 13, 15]
kps_right=[2, 4, 6, 8, 10, 12, 14, 16]
def kps_generator(pose3d_predictor, kps):
    """Wrap one 2D keypoint sequence in an ``UnchunkedGenerator``.

    Args:
        pose3d_predictor: Not used by this function; kept in the signature for
            existing callers.
        kps: 2D keypoint sequence, passed to the generator as its only
            sequence.

    Returns:
        An ``UnchunkedGenerator`` with augmentation enabled and no padding or
        causal shift.
    """
    # No padding is requested from the generator (pad=0), so the sequence is
    # consumed exactly as the caller supplies it.
    return UnchunkedGenerator(
        None,
        None,
        [kps],
        pad=0,
        causal_shift=0,
        augment=True,
        kps_left=kps_left,
        kps_right=kps_right,
    )


### 3.3 图像渲染函数

In [None]:
from matplotlib.animation import FuncAnimation, writers
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm

# 画图
def render_image(keypoints, pos_3d, skeleton, azim, input_video_frame=None):
    """Render the input frame next to a 3D skeleton plot as one RGB image.

    Args:
        keypoints: 2D keypoints of the frame. Accepted for interface
            compatibility; it does not influence the drawing.
        pos_3d: 3D joint positions, indexed as ``pos_3d[joint, axis]`` with
            axis 0..2.
        skeleton: Object providing ``parents()`` (parent index per joint, -1
            for the root) and ``joints_right()`` (joints drawn in red).
        azim: Azimuth angle forwarded to ``Axes3D.view_init``.
        input_video_frame: Image shown in the left panel. When ``None`` the
            panel is left empty instead of failing.

    Returns:
        A writable ``uint8`` array of shape (height, width, 3) in RGB order.
    """
    fig = plt.figure(figsize=(12, 6), dpi=120)
    canvas = FigureCanvas(fig)
    try:
        # Left panel: the input frame, without axes.
        ax_in = fig.add_subplot(1, 2, 1)
        ax_in.get_xaxis().set_visible(False)
        ax_in.get_yaxis().set_visible(False)
        ax_in.set_axis_off()
        ax_in.set_title('Input')
        if input_video_frame is not None:
            ax_in.imshow(input_video_frame, aspect='equal')

        # Right panel: the 3D reconstruction.
        ax = fig.add_subplot(1, 2, 2, projection='3d')
        ax.view_init(elev=15., azim=azim)
        # Fixed axis ranges.
        radius = 2.0
        ax.set_xlim3d([-radius / 2, radius / 2])
        ax.set_zlim3d([0, radius])
        ax.set_ylim3d([-radius / 2, radius / 2])
        try:
            ax.set_aspect('equal')
        except NotImplementedError:
            # Some Matplotlib releases reject 'equal' on 3D axes; the axis
            # limits set above already cover equally sized ranges.
            pass
        # Hide the tick labels.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        # NOTE(review): newer Matplotlib releases may ignore `dist` - confirm.
        ax.dist = 7.5
        ax.set_title('3D Pose Reconstruction')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

        # One line segment per (joint, parent) pair; the root has no parent.
        parents = skeleton.parents()
        right_joints = set(skeleton.joints_right())
        for j, j_parent in enumerate(parents):
            if j_parent == -1:
                continue
            col = 'red' if j in right_joints else 'black'
            ax.plot([pos_3d[j, 0], pos_3d[j_parent, 0]],
                    [pos_3d[j, 1], pos_3d[j_parent, 1]],
                    [pos_3d[j, 2], pos_3d[j_parent, 2]], zdir='z', c=col)

        # Rasterize, then copy the RGB channels out of the RGBA buffer. The
        # copy is writable and stays valid after the figure is closed.
        canvas.draw()
        rgba = np.asarray(canvas.buffer_rgba())
        image = rgba[..., :3].copy()
    finally:
        # Close this specific figure (not whatever pyplot considers current),
        # also when rendering fails, so figures do not accumulate.
        plt.close(fig)
    return image

### 3.4 3d关键点预测

In [None]:
# Joint indices of the 3D skeleton that predict_3d_pos swaps between the left
# and right side when it undoes the flip of the test-time augmentation.
joints_left = [4, 5, 6, 11, 12, 13] 
joints_right = [1, 2, 3, 14, 15, 16]

class Skeleton:
    """Minimal description of the 17-joint skeleton used for rendering."""

    # Parent joint index for every joint; -1 marks the root.
    _PARENTS = (-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15)
    # Joints reported by joints_right().
    # NOTE(review): differs from the module-level `joints_right` list - confirm
    # that this is intended.
    _RIGHT = (1, 2, 3, 9, 10)

    def parents(self):
        """Return a new array with the parent index of every joint."""
        return np.array(self._PARENTS)

    def joints_right(self):
        """Return a new list with the joint indices treated as right-side."""
        return list(self._RIGHT)

# 预测3d坐标
def predict_3d_pos(test_generator, predictor):
    """Run the 3D pose model on the first batch produced by a generator.

    Args:
        test_generator: Object providing ``next_epoch()`` (yielding 3-tuples
            whose last item is the 2D keypoint batch as a NumPy array) and
            ``augment_enabled()``.
        predictor: The 3D pose model; called with a float32 tensor.

    Returns:
        The prediction for the first batch as a NumPy array with the leading
        axis squeezed out, or ``None`` when the generator yields nothing. With
        augmentation enabled the flipped and non-flipped predictions are
        averaged.
    """
    predictor.eval()
    on_gpu = torch.cuda.is_available()

    with torch.no_grad():
        for _, _unused, keypoints_2d in test_generator.next_epoch():
            model_input = torch.from_numpy(keypoints_2d.astype('float32'))
            if on_gpu:
                model_input = model_input.cuda()

            # Positional model.
            output = predictor(model_input)

            # Test-time augmentation (if enabled): entry 1 of the batch is the
            # flipped copy. Mirror x back, swap the left/right joints, then
            # average it with the non-flipped entry.
            if test_generator.augment_enabled():
                output[1, :, :, 0] *= -1
                swap_dst = joints_left + joints_right
                swap_src = joints_right + joints_left
                output[1, :, swap_dst] = output[1, :, swap_src]
                output = output.mean(dim=0, keepdim=True)

            # Only the first batch is processed.
            return output.squeeze(0).cpu().numpy()

### 3.5 视频3d姿势估计

In [None]:
def _read_rgb_frame(cap):
    """Return the next convertible frame of ``cap`` as RGB, or None at the end."""
    while True:
        ok, frame = cap.read()
        if not ok:
            return None
        try:
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        except cv2.error:
            # Skip frames OpenCV cannot convert instead of aborting the run.
            continue


def video_pose(filepath,
               ckpt_dir = '../../checkpoint/detectron_pt_coco',
               ckpt_name = 'arc_1_epoch_40.bin',
               filter_widths = (1, 1, 1),
               show=False):
    """Estimate a 3D pose for every frame of a video and write 'output.mp4'.

    For each frame a window of 2D keypoints centred on that frame is fed to
    the 3D model. At the start of the video the window is padded by repeating
    the first keypoints, at the end by repeating the last ones. Every rendered
    image (input frame next to the 3D skeleton) is written to 'output.mp4' and
    optionally shown in a window.

    Args:
        filepath: Path of the input video.
        ckpt_dir: Directory of the 3D model checkpoint.
        ckpt_name: File name of the 3D model checkpoint.
        filter_widths: Filter widths of the temporal model. Their product is
            used as the number of frames in the keypoint window.
            NOTE(review): the window logic assumes that product is odd.
        show: When true, display every rendered image; 'q' stops early.

    Returns:
        None. Prints "Video is too short!" and returns early when the video
        has fewer frames than half of the window.

    Raises:
        IOError: If the video cannot be opened.
    """
    filter_widths = list(filter_widths)

    # Load the 3D pose estimator.
    pose3d_predictor = get_pose3d_predictor(ckpt_dir, ckpt_name, filter_widths)

    # Initialize the 2D keypoint detector.
    model_config_path = './keypoint_rcnn_R_50_FPN_3x.yaml'
    model_weights_path = './model_R50.pkl'
    kps_predictor = init_kps_predictor(model_config_path, model_weights_path, cuda=False)

    # Number of frames the model consumes per prediction, and the number of
    # frames on each side of the centre frame.
    receive_field = 1
    for width in filter_widths:
        receive_field *= width
    half = receive_field // 2

    # Loop-invariant rendering inputs.
    wh = (1080, 720)
    rot = np.array([0.14070565, -0.15007018, -0.7552408, 0.62232804], dtype=np.float32)
    skeleton = Skeleton()
    azim = np.array(70., dtype=np.float32)

    cap = cv2.VideoCapture(filepath)
    output_avi = None
    try:
        if not cap.isOpened():
            raise IOError("Cannot open video: {}".format(filepath))

        # Requested resolution (same property ids 3 and 4 as before).
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, wh[0])
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, wh[1])

        # Frame rate; fall back to 25 fps when the capture reports none so the
        # delay computation cannot divide by zero.
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if video_fps <= 0:
            video_fps = 25.0
        # waitKey(0) would block until a key press, so keep at least 1 ms.
        pause = max(1, int(1000 / video_fps))

        if show:
            cv2.namedWindow('Video', 0)
            cv2.resizeWindow('Video', 960, 480)

        # Output video file.
        fourcc = cv2.VideoWriter_fourcc(*'MP4V')
        output_avi = cv2.VideoWriter('output.mp4', fourcc, video_fps, wh)

        # frame_list[0] is always the frame rendered next; kps_list is the
        # keypoint window whose centre (index `half`) belongs to that frame.
        frame_list = []
        kps_list = []

        # The generator is created with pad=0, so the first `half` frames are
        # read ahead to serve as right-hand context of the first frame.
        print("Preparing...")
        while len(frame_list) < half:
            frame = _read_rgb_frame(cap)
            if frame is None:
                print("Video is too short!")
                return
            frame_list.append(frame)
            kps_list.append(predict_kps(kps_predictor, frame))

        # Left padding: repeat the first keypoints so the first frame sits in
        # the centre of the window.
        if kps_list:
            kps_list = [kps_list[0]] * half + kps_list

        print("Starting to predict 3d pose...")
        fps_time = time.time()
        while True:
            # Read one more frame of right-hand context while frames remain.
            frame = _read_rgb_frame(cap)
            if frame is not None:
                frame_list.append(frame)
                kps_list.append(predict_kps(kps_predictor, frame))

            # Every frame has been rendered.
            if not frame_list:
                break

            # Right padding once the video has ended: repeat the last keypoints.
            while len(kps_list) < receive_field:
                kps_list.append(kps_list[-1])

            frame = frame_list[0]

            # Build the generator for the 3D model and predict.
            kps_2d = np.stack(kps_list)
            generator = kps_generator(pose3d_predictor, kps_2d)
            predictions = predict_3d_pos(generator, pose3d_predictor)

            predictions = camera_to_world(predictions, R=rot, t=0)
            # We don't have the trajectory, but at least we can rebase the height.
            predictions[:, :, 2] -= np.min(predictions[:, :, 2])
            pos_3d = predictions[0]

            # 2D keypoints of the centre frame, mapped back with the same
            # frame size predict_kps used for the normalization.
            frame_h, frame_w = frame.shape[0], frame.shape[1]
            pos_2d = image_coordinates(kps_2d[half][..., :2], w=frame_w, h=frame_h)

            elapsed = time.time() - fps_time
            render_fps = 1.0 / elapsed if elapsed > 0 else 0.0

            # Render the image.
            result_image = render_image(pos_2d, pos_3d=pos_3d, skeleton=skeleton,
                                        azim=azim, input_video_frame=frame)
            # Convert first: cvtColor returns a fresh writable array, whereas
            # the rendered image may be a read-only buffer view. Green is the
            # same in RGB and BGR, so the overlay colour is unchanged.
            result_image = cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR)
            cv2.putText(result_image, "FPS: %f" % (render_fps), (10, 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            # Resize and write.
            output_avi.write(cv2.resize(result_image, wh))

            if show:
                cv2.imshow('Video', result_image)
                if cv2.waitKey(pause) & 0xff == ord('q'):
                    break

            # Slide the window: drop the rendered frame and the oldest keypoints.
            del frame_list[0]
            del kps_list[0]
            fps_time = time.time()
    finally:
        # Release everything, also on early return or error.
        if output_avi is not None:
            output_avi.release()
        cap.release()
        cv2.destroyAllWindows()
# os.chdir('/home/li/python/pose-estimation/3d/VideoPose3D/data/pose2d-detectron2/')
# Run the whole pipeline on a sample clip. Filter widths [3, 3, 3] multiply to
# a 27-frame window, paired here with the 'arc_27' checkpoint.
video_path = '../videos/huaban_01-08.mp4'
video_pose(video_path, ckpt_name = 'arc_27_epoch_40.bin', filter_widths=[3, 3 ,3])
print("Finish prediction...") 

1. 读取27帧图像并检测出2d关键点坐标
2. 构造2d关键点生成器
3. 预测3d关键点坐标
4. 显示