In [1]:
import os
import cv2
import numpy as np
from PIL import Image

import torch
import torchvision
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

from post_processing.inference import get_final_preds
from post_processing.transforms import get_affine_transform

from pose_hr_net import get_pose_net

from config.default_configuration import _C as cfg
from config.default_configuration import update_config, COCO_INSTANCE_CATEGORY_NAMES, joints


In [2]:
def get_person_detection_boxes(model, img, threshold=0.5):
    """Run a detector on one image and return the confident 'person' boxes.

    Args:
        model: detector called as ``model([tensor])``; its first output element must
            be a dict with ``'boxes'``, ``'labels'`` and ``'scores'`` tensors (the
            format read below).
        img: image array accepted by ``PIL.Image.fromarray``. The caller in this
            notebook passes the channel-reversed (RGB) video frame.
        threshold: a detection is kept only if its score is strictly greater.

    Returns:
        List of boxes ``[(x0, y0), (x1, y1)]`` labelled ``'person'`` whose score
        exceeds ``threshold``; an empty list when nothing qualifies.
    """
    transformed_img = transforms.ToTensor()(Image.fromarray(img))

    # Send the input to the device holding the detector's weights (the notebook
    # moves the detector to CUDA); fall back to CPU for parameter-less callables.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    with torch.no_grad():  # inference only: no autograd graph needed
        pred = model([transformed_img.to(device)])[0]

    labels = pred['labels'].cpu().numpy()
    boxes = pred['boxes'].cpu().numpy()
    scores = pred['scores'].cpu().numpy()

    person_boxes = []
    for label, box, score in zip(labels, boxes, scores):
        # Filter each detection by its own score. The previous
        # `[pred_score.index(x) ...][-1]` raised IndexError when no score passed the
        # threshold and picked the wrong cut-off when scores were duplicated.
        if score > threshold and COCO_INSTANCE_CATEGORY_NAMES[label] == 'person':
            person_boxes.append([(box[0], box[1]), (box[2], box[3])])

    return person_boxes


def get_pose_estimation_prediction(pose_model, image, center, scale):
    """Estimate keypoints for the image region described by ``center``/``scale``.

    The region is warped (rotation 0) to the model input size
    ``cfg.MODEL.IMAGE_SIZE``, converted to a tensor, normalised with the mean/std
    constants below, and passed through ``pose_model`` without gradients. The
    model output is then decoded by ``get_final_preds`` using the same
    center/scale.

    Returns:
        The ``preds`` array returned by ``get_final_preds`` for a batch of one.
    """
    input_w = int(cfg.MODEL.IMAGE_SIZE[0])
    input_h = int(cfg.MODEL.IMAGE_SIZE[1])
    affine = get_affine_transform(center, scale, 0, cfg.MODEL.IMAGE_SIZE)
    crop = cv2.warpAffine(image, affine, (input_w, input_h), flags=cv2.INTER_LINEAR)

    to_model_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    batch = to_model_tensor(crop).unsqueeze(0)

    pose_model.eval()
    with torch.no_grad():
        heatmaps = pose_model(batch)
        preds, _ = get_final_preds(
            cfg,
            heatmaps.clone().cpu().numpy(),
            np.asarray([center]),
            np.asarray([scale]),
        )
    return preds


def box_to_center_scale(box, model_image_width, model_image_height):
    """Convert a two-corner box into the (center, scale) pair used for the pose crop.

    Args:
        box: ``[(x0, y0), (x1, y1)]``; width is ``x1 - x0`` and height ``y1 - y0``.
            The previous version called these corners bottom-left/top-right.
        model_image_width, model_image_height: pose-model input size, used only
            for its aspect ratio.

    Returns:
        ``center``: float32 array ``[cx, cy]``, the midpoint of the original box.
        ``scale``: float32 array ``[w, h] / 200``, computed after enlarging one
        side so ``w / h`` matches the model aspect ratio, then multiplied by 1.25
        unless ``cx`` is exactly -1.
    """
    x0, y0 = box[0][0], box[0][1]
    width = box[1][0] - x0
    height = box[1][1] - y0

    center = np.array([x0 + width * 0.5, y0 + height * 0.5], dtype=np.float32)

    target_aspect = model_image_width * 1.0 / model_image_height
    # Enlarge whichever side is too short so the box matches the model aspect ratio.
    if width > target_aspect * height:
        height = width * 1.0 / target_aspect
    elif width < target_aspect * height:
        width = height * target_aspect

    pixel_std = 200
    scale = np.array([width * 1.0 / pixel_std, height * 1.0 / pixel_std], dtype=np.float32)
    if center[0] != -1:
        scale = scale * 1.25  # 25% margin around the box
    return center, scale


In [3]:
# Output folders. run_video writes annotated frames into pose_dir; box_dir and
# images_dir are not read by the visible cells (images_dir equals pose_dir).
pose_dir = 'outputs/poses/'
box_dir = 'outputs/boxes/'
images_dir = pose_dir


class Args:
    """Attribute container passed to ``update_config(cfg, args)`` and read by ``run_video``."""

    # Configuration file and pose-model checkpoint.
    cfg = 'additional_files/config/inference-config.yaml'
    MODEL_FILE = 'additional_files/models/pytorch/pose_coco/pose_hrnet_w32_384x288.pth'

    # Input video, output folder and target inference rate (frames per second).
    videoFile = 'data/spinning.mp4'
    outputDir = 'outputs/'
    inferenceFps = 10
    writeBoxFrames = True

    # Left empty in this notebook.
    modelDir = ''
    logDir = ''
    dataDir = ''
  dataDir = ''

In [4]:
# cuDNN flags taken from cfg.
# NOTE(review): this cell runs before `update_config(cfg, args)` in the next
# cell, so it reads the default configuration rather than values loaded from
# `Args.cfg` - confirm this ordering is intended.
cudnn.enabled = cfg.CUDNN.ENABLED
cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
cudnn.benchmark = cfg.CUDNN.BENCHMARK

In [5]:
args = Args()
update_config(cfg, args)

# Re-apply the cuDNN flags now that cfg holds the values loaded from `args.cfg`;
# the earlier cuDNN cell reads cfg before update_config has run.
cudnn.benchmark = cfg.CUDNN.BENCHMARK
cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
cudnn.enabled = cfg.CUDNN.ENABLED

# Person detector: torchvision Faster R-CNN with pretrained weights, on the GPU, eval mode.
box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
box_model.cuda()
box_model.eval()

# Pose model. The checkpoint is loaded onto the CPU first so loading does not depend
# on the device it was saved from; the model is moved to the GPU afterwards.
pose_model = get_pose_net(cfg, is_train=False)
load_result = pose_model.load_state_dict(
    torch.load(args.MODEL_FILE, map_location='cpu'), strict=False)
# strict=False tolerates key mismatches silently; report them so a checkpoint that
# does not fit the architecture is not used unnoticed (untouched weights stay at init).
if load_result.missing_keys or load_result.unexpected_keys:
    print('WARNING: checkpoint/model key mismatch: {} missing, {} unexpected'.format(
        len(load_result.missing_keys), len(load_result.unexpected_keys)))
pose_model.cuda()
pose_model.eval()
pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)


In [9]:
# Video pipeline: detect people, estimate and draw their poses, and build a GIF.

def _draw_pose(canvas, keypoints, box):
    """Draw keypoints, skeleton edges, keypoint bounding rect and detector box onto canvas in place."""
    for point in keypoints:
        cv2.circle(canvas, (int(point[0]), int(point[1])), 4, (255, 0, 0), 1)

    for joint in joints['coco']['skeleton']:
        pt1, pt2 = keypoints[joint]
        cv2.line(canvas, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1])), (80, 80, 255), 1)

    # cv2.boundingRect only accepts int32/float32 point sets.
    x, y, w, h = cv2.boundingRect(np.asarray(keypoints, dtype=np.float32))
    cv2.rectangle(canvas, (x - 10, y - 10), (x + w + 10, y + h + 10), (80, 80, 255), thickness=1)

    # Detector corners are floats; OpenCV drawing functions expect integer points.
    (x0, y0), (x1, y1) = box
    cv2.rectangle(canvas, (int(x0), int(y0)), (int(x1), int(y1)), color=(180, 180, 0), thickness=1)


def _annotate_frame(image_bgr):
    """Return an annotated copy of a BGR frame, or None when no person is detected."""
    image = image_bgr[:, :, [2, 1, 0]]
    pred_boxes = get_person_detection_boxes(box_model, image, threshold=0.8)
    if not pred_boxes:
        return None

    # Draw on a separate canvas so one person's annotations never leak into the
    # pose-model input for the next person (the frame itself stays untouched).
    canvas = image_bgr.copy()
    pose_input = image if cfg.DATASET.COLOR_RGB else image_bgr
    for box in pred_boxes:
        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
        pose_preds = get_pose_estimation_prediction(pose_model, pose_input, center, scale)
        _draw_pose(canvas, pose_preds[0], box)
    return canvas


def _frames_to_gif(frame_paths, gif_path):
    """Read the given frame files back in order and save them as an animated GIF."""
    import imageio

    if not frame_paths:
        print('no annotated frames were written; skipping ' + gif_path)
        return
    images = [imageio.imread(path) for path in frame_paths]
    imageio.mimsave(gif_path, images)


def run_video(input_video):
    """Annotate person poses on sampled frames of a video and save them plus a GIF.

    Every ``round(fps / args.inferenceFps)``-th frame is processed. Frames with at
    least one detected person are written to ``pose_dir`` as
    ``pose<32-digit frame index>.jpg``. Those frames, in order, are then combined
    into ``<args.outputDir>/movie.gif``.

    Raises:
        IOError: the video cannot be opened or a frame cannot be written.
        ValueError: the video fps is lower than ``args.inferenceFps``.
    """
    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs(pose_dir, exist_ok=True)
    os.makedirs(args.outputDir, exist_ok=True)

    vidcap = cv2.VideoCapture(input_video)
    if not vidcap.isOpened():
        raise IOError('could not open video file ' + str(input_video))

    written_frames = []
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        if fps < args.inferenceFps:
            # exit() would shut down / raise SystemExit in the notebook kernel.
            raise ValueError('desired inference fps is ' + str(args.inferenceFps)
                             + ' but video fps is ' + str(fps))
        every_nth_frame = round(fps / args.inferenceFps)

        count = 0
        success, image_bgr = vidcap.read()
        while success:
            if count % every_nth_frame == 0:
                annotated = _annotate_frame(image_bgr)
                if annotated is not None:
                    frame_path = os.path.join(pose_dir, 'pose%s.jpg' % str(count).zfill(32))
                    if not cv2.imwrite(frame_path, annotated):
                        raise IOError('failed to write ' + frame_path)
                    written_frames.append(frame_path)
            success, image_bgr = vidcap.read()
            count += 1
    finally:
        vidcap.release()

    # Use only this run's frames, in capture order. os.listdir is unordered and
    # would also pick up stale frames from earlier runs.
    _frames_to_gif(written_frames, os.path.join(args.outputDir, 'movie.gif'))
    

In [10]:
# Process the configured video; writes frames to pose_dir and the GIF to the output folder.
run_video(input_video=args.videoFile)


In [9]:
# # Single-image variant (disabled): annotate one BGR image and save it as single_image.png
# def run(image_bgr):
#     image = image_bgr[:, :, [2, 1, 0]]
#     # object detection box
#     pred_boxes = get_person_detection_boxes(box_model, image, threshold=0.8)
#     image_bgr_box = image_bgr.copy()
#     for box in pred_boxes:
#         cv2.rectangle(image_bgr_box, box[0], box[1], color=(0, 255, 0), thickness=1)
#         if pred_boxes:
#             # pose estimation
#             center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
#             image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
#             pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)    
# 
#             for _, mat in enumerate(pose_preds[0]):
#                 x_coord, y_coord = int(mat[0]), int(mat[1])
#                 cv2.circle(image_bgr, (x_coord, y_coord), 4, (255, 0, 0), 1)
# 
#             for i, joint in enumerate(joints['coco']['skeleton']):
#                 pt1, pt2 = pose_preds[0][joint]
#                 cv2.line(image_bgr, (int(pt1[0]), int(pt1[1])), (int(pt2[0]), int(pt2[1])), (80, 80, 255), 1)
# 
#             x,y,w,h = cv2.boundingRect(pose_preds[0])
#             cv2.rectangle(image_bgr, (x-10,y-10), (x+w+10,y+h+10), (80, 80, 255), thickness=1)            
#             cv2.rectangle(image_bgr, box[0], box[1], color=(180, 180, 0), thickness=1)
# 
#     cv2.imwrite('single_image.png', image_bgr)


In [10]:
# import time
# 
# image_bgr = cv2.imread('data/22.png')
# 
# start_time = time.time()
# run(image_bgr)
# stop_time = time.time()
# print('elapsed time: {}s'.format(stop_time-start_time))


elapsed time: 19.432769060134888s
