In [1]:
import os
import sys
import time
import cv2
import h5py
import torch as t
import ipywidgets as widgets
from warnings import warn
from IPython.display import display, clear_output

sys.path.append(os.path.abspath(os.path.curdir))
from network.rtpose_vgg import get_model
from network.post import decode_pose
from evaluate.coco_eval import get_multiplier, get_outputs, handle_paf_and_heat

  from ._conv import register_converters as _register_converters


## Parameters & Constants

In [2]:
# File extensions (lower-case, with the leading dot) selected when
# `input_ext` is set to "video".
VIDEO_EXT = [
    '.mp4',
    '.avi',
    '.mpg',
    '.mpeg',
    '.mov',
]
# File extensions (lower-case, with the leading dot) selected when
# `input_ext` is set to "image".
IMAGE_EXT = [
    '.jpg',
    '.png',
    '.bmp',
    '.jpeg',
    '.jpe',
    '.tif',
    '.tiff',
]

In [3]:
# --- Where to read from / write to -----------------------------------------
# A single file, or a directory that is walked recursively.
input_data = r"/home/liuqixuan/datasets/UCF-101/"
# Root directory for the generated files.
output_dir = r"/home/liuqixuan/datasets/UCF-101_processed"
# Extension of the rendered output video.
output_ext = ".mp4"
# Pre-trained pose-model weights.
weight_file = r'./network/weight/pose_model.pth'

# --- How inputs are grouped and filtered -----------------------------------
# "1to1": every matching input file yields its own output.
# "nto1": the matching files of each directory are combined into one output.
input_type = "1to1"
# "image" / "video" select the built-in extension lists; alternatively give a
# custom extension (or a list of extensions) such as ".webm".
input_ext = "video"

# --- Processing options ----------------------------------------------------
frame_rate_ratio = 1  # analyze every [n] frames
process_speed = 2  # int, 1 (fastest, lowest quality) to 4 (slowest, highest quality)
resize_fac = 1.0  # scale factor applied to every frame before inference
output_length = None  # int, maximum number of frames read per item; None for the whole input
rebuild_exist_file = False  # True: re-process items whose output file already exists

# --- Notebook display ------------------------------------------------------
show_visualize_process = True  # show a live preview of the rendered canvas in an ipywidgets Image

## Select GPU Devices

In [4]:
# Expose only GPU index 0 to this process.
# NOTE(review): presumably this has to run before the first CUDA call in the
# kernel to take effect — confirm when changing the cell order.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

## Main Process Function

In [5]:
def process(model, oriImg, process_speed):
    """Estimate poses on one image.

    The network is evaluated on the image itself and on a copy reversed along
    axis 1 (the "flipped" image); both predictions are merged by
    ``handle_paf_and_heat`` before ``decode_pose`` extracts the poses.

    Args:
        model: network passed through to ``get_outputs``.
        oriImg: image array indexed with three axes.
        process_speed: forwarded to ``get_multiplier`` to choose the scales.

    Returns:
        The tuple ``(to_plot, canvas, joint_list, person_to_joint_assoc)``
        produced by ``decode_pose``.
    """
    scales = get_multiplier(oriImg, process_speed)
    # Reversed view along axis 1, used for flip augmentation.
    mirrored = oriImg[:, ::-1, :]

    # Inference only: no gradients are needed.
    with t.no_grad():
        paf_plain, heat_plain = get_outputs(scales, oriImg, model, 'rtpose')
        paf_mirror, heat_mirror = get_outputs(scales, mirrored, model, 'rtpose')
        # Merge the predictions of both passes.
        paf, heatmap = handle_paf_and_heat(heat_plain, heat_mirror, paf_plain, paf_mirror)

    thresholds = dict(thre1=0.1, thre2=0.05, thre3=0.5)
    to_plot, canvas, joint_list, person_to_joint_assoc = decode_pose(
        oriImg, thresholds, heatmap, paf)
    return to_plot, canvas, joint_list, person_to_joint_assoc

## Organize I/O Paths

In [6]:
def organize_1to1_io_paths(input_data, input_ext, output_dir, output_ext):
    """Pair every supported input file with its own output path.

    Args:
        input_data: path of a single file, or of a directory that is walked
            recursively.
        input_ext: container of accepted extensions (lower-case, with the
            leading dot); files are matched on their lower-cased extension.
        output_dir: root directory of the outputs. The sub-directory layout
            below ``input_data`` is mirrored underneath it.
        output_ext: extension given to every output file.

    Returns:
        dict with two parallel lists, ``"input"`` and ``"output"``.

    Raises:
        FileNotFoundError: ``input_data`` does not exist.
        AssertionError: ``input_data`` is a single file with an unsupported
            extension.

    Files with an unsupported extension inside a directory are skipped with a
    warning.
    """
    if not os.path.exists(input_data):
        raise FileNotFoundError("File not exist in {}".format(input_data))
    io_paths = {"input": [], "output": []}
    if os.path.isdir(input_data):
        for root, dirs, files in os.walk(input_data):
            # Sorting in place fixes the order in which os.walk descends, so
            # the item order no longer depends on the filesystem.
            dirs.sort()
            rel_path = os.path.relpath(root, input_data)
            # relpath() yields "." for the top directory; skip it so outputs
            # do not carry a redundant "./" segment.
            if rel_path == os.curdir:
                target_dir = output_dir
            else:
                target_dir = os.path.join(output_dir, rel_path)
            for file in sorted(files):
                name, ext = os.path.splitext(file)
                if ext.lower() in input_ext:
                    io_paths["input"].append(os.path.join(root, file))
                    io_paths["output"].append(os.path.join(target_dir, name + output_ext))
                else:
                    warn("Unsupported format: %s" % file)
    else:
        name, ext = os.path.splitext(input_data)
        assert ext.lower() in input_ext, "Unsupported format: %s" % input_data
        output_path = os.path.join(output_dir, os.path.basename(name) + output_ext)
        io_paths["input"].append(input_data)
        io_paths["output"].append(output_path)
    return io_paths

In [7]:
def organize_Nto1_io_paths(input_data, input_ext, output_dir, output_ext):
    """Group the supported files of each directory into one output item.

    Args:
        input_data: path of a single file, or of a directory that is walked
            recursively.
        input_ext: container of accepted extensions (lower-case, with the
            leading dot); files are matched on their lower-cased extension.
        output_dir: root directory of the outputs.
        output_ext: extension given to every output file.

    Returns:
        dict with two parallel lists: ``"input"`` holds one sorted list of
        file paths per directory that contains supported files, ``"output"``
        the matching output path (``<output_dir>/<relative dir><output_ext>``).
        Files found directly in ``input_data`` are written to
        ``<output_dir>/<name of input_data><output_ext>``; note that a
        sub-directory carrying the same name as ``input_data`` would map to
        the same output path.

    Raises:
        FileNotFoundError: ``input_data`` does not exist.
        AssertionError: ``input_data`` is a single file with an unsupported
            extension.
    """
    if not os.path.exists(input_data):
        raise FileNotFoundError("File not exist in {}".format(input_data))
    io_paths = {"input": [], "output": []}
    if os.path.isdir(input_data):
        for root, dirs, files in os.walk(input_data):
            # Sorting in place fixes the order in which os.walk descends, so
            # the item order no longer depends on the filesystem.
            dirs.sort()
            rel_path = os.path.relpath(root, input_data)
            if rel_path == os.curdir:
                # relpath() yields "." for the top directory, which would
                # turn into an output literally named "." + output_ext
                # (e.g. "..mp4"); use the directory's own name instead.
                rel_path = os.path.basename(os.path.abspath(input_data))
            image_list = []
            for file in files:
                name, ext = os.path.splitext(file)
                if ext.lower() in input_ext:
                    image_list.append(os.path.join(root, file))
                else:
                    warn("Unsupported format: %s" % file)
            if image_list:
                output_path = os.path.join(output_dir, rel_path + output_ext)
                io_paths["input"].append(sorted(image_list))
                io_paths["output"].append(output_path)
    else:
        name, ext = os.path.splitext(input_data)
        assert ext.lower() in input_ext, "Unsupported format: %s" % input_data
        output_path = os.path.join(output_dir, os.path.basename(name) + output_ext)
        io_paths["input"].append([input_data])
        io_paths["output"].append(output_path)
    return io_paths

## Data Loader

In [8]:
def load_video_frames(video_path, output_length=None, frame_rate_ratio=1):
    """Lazily yield frames of a video as returned by ``cv2.VideoCapture.read``.

    Args:
        video_path: path handed to ``cv2.VideoCapture``.
        output_length: maximum number of frames to *read*; ``None`` uses the
            frame count reported by the capture.
        frame_rate_ratio: only frames whose index is a multiple of this value
            are yielded; the others are read and dropped.

    Yields:
        One frame per selected index.

    Raises:
        AssertionError: the video cannot be opened. Because this is a
            generator, the check runs when the first frame is requested.
    """
    cam = cv2.VideoCapture(video_path)
    try:
        assert cam.isOpened(), "Open Video %s Failed!" % video_path
        video_length = int(cam.get(cv2.CAP_PROP_FRAME_COUNT))
        if output_length is None:
            output_length = video_length
        i = 0  # index of the frame about to be read
        while cam.isOpened() and i < output_length:
            ret_val, image = cam.read()
            if not ret_val:
                break
            if i % frame_rate_ratio == 0:
                yield image
            i += 1
    finally:
        # Also reached when the consumer stops early or closes the generator,
        # so the capture handle is always released.
        cam.release()


def get_video_size(video_path, output_length=None):
    """Return ``(frame_count, height, width)`` reported for a video.

    Args:
        video_path: path handed to ``cv2.VideoCapture``.
        output_length: optional upper bound applied to the frame count.

    Raises:
        AssertionError: the video cannot be opened.
    """
    capture = cv2.VideoCapture(video_path)
    assert capture.isOpened(), "Open Video %s Failed!" % video_path
    frame_count, height, width = (
        int(capture.get(prop))
        for prop in (cv2.CAP_PROP_FRAME_COUNT,
                     cv2.CAP_PROP_FRAME_HEIGHT,
                     cv2.CAP_PROP_FRAME_WIDTH)
    )
    capture.release()
    # Clamp to the requested length, if any.
    if output_length is not None and output_length < frame_count:
        frame_count = output_length
    return frame_count, height, width

In [9]:
def load_image_frames(image_list, output_length=None, frame_rate_ratio=1):
    """Lazily yield the images of a sequence, converted with COLOR_BGR2RGB.

    Args:
        image_list: non-empty list of image paths, in playback order.
        output_length: only the first ``output_length`` entries are
            considered; ``None`` uses the whole list.
        frame_rate_ratio: only entries whose index is a multiple of this value
            are loaded and yielded.

    Yields:
        One converted image per selected entry.

    Raises:
        AssertionError: ``image_list`` is empty. Because this is a generator,
            the check runs when the first frame is requested.
        IOError: an image cannot be read.

    NOTE(review): ``load_video_frames`` yields frames exactly as read, without
    this colour conversion — confirm which channel order the downstream
    model and video writer expect.
    """
    image_count = len(image_list)
    assert image_count > 0
    if output_length is None:
        output_length = image_count
    for i, path in enumerate(image_list):
        if i >= output_length:
            break
        if i % frame_rate_ratio == 0:
            image = cv2.imread(path)
            if image is None:
                # imread signals failure by returning None instead of
                # raising; fail here with the offending path.
                raise IOError("Read Image %s Failed!" % path)
            yield cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def get_image_size(image_list, output_length=None):
    """Return ``(frame_count, height, width)`` for an image sequence.

    The height and width are taken from the first image only.

    Args:
        image_list: non-empty list of image paths.
        output_length: optional upper bound applied to the frame count.

    Raises:
        AssertionError: ``image_list`` is empty.
        IOError: the first image cannot be read.
    """
    image_count = len(image_list)
    assert image_count > 0
    first_image = cv2.imread(image_list[0])
    if first_image is None:
        # imread signals failure by returning None instead of raising.
        raise IOError("Read Image %s Failed!" % image_list[0])
    frame_count = image_count
    if output_length is not None:
        frame_count = min(frame_count, output_length)
    height, width = first_image.shape[:2]
    return frame_count, height, width

## Load Model

In [10]:
# Build the network and restore the weights *before* wrapping it, so the
# state-dict is loaded into the bare model.
model = get_model('vgg19')
model.load_state_dict(t.load(weight_file))
# Wrap for multi-GPU execution, move to the GPU, use float32 and switch to
# evaluation mode.
model = t.nn.DataParallel(model).cuda().float().eval()
print("Model Ready!")

Bulding VGG19
Model Ready!


## Init I/O Paths 
Uses `organize_1to1_io_paths` or `organize_Nto1_io_paths`, depending on `input_type`.

In [11]:
# Resolve `input_ext` into the list of accepted extensions.
if input_ext == "image":
    _input_ext_ = IMAGE_EXT
elif input_ext == "video":
    _input_ext_ = VIDEO_EXT
else:
    # Custom extension(s): accept a single string or a list/tuple of strings.
    _custom_ext_ = list(input_ext) if isinstance(input_ext, (list, tuple)) else [input_ext]
    # Files are matched on their lower-cased extension including the dot, so
    # bring the user-supplied values into the same form.
    _input_ext_ = [
        e.lower() if (not e or e.startswith(".")) else "." + e.lower()
        for e in _custom_ext_
    ]

# Choose how inputs are grouped into output items.
if input_type == "1to1":
    _organize_ = organize_1to1_io_paths
else:
    if str(input_type).lower() != "nto1":
        # Keep the previous fall-through behaviour, but make typos visible.
        warn("Unknown input_type %r, falling back to 'nto1'" % (input_type,))
    _organize_ = organize_Nto1_io_paths

io_paths = _organize_(input_data, _input_ext_, output_dir, output_ext)
total_item = len(io_paths["input"])
print("Items count: ", total_item)

Items count:  502


## Calling Process

In [12]:
# --- Progress widgets -------------------------------------------------------
caption = widgets.Label("Ready to work!")
msg = widgets.Label('0/0, process time: 0.0s, total time: 0.0s')
bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=1.0,
    description='[0/0]',
    bar_style='',  # 'success', 'info', 'warning', 'danger' or ''
    orientation='horizontal'
)
# `imgbox` is always defined (None when the preview is off) so the checks
# further down cannot raise NameError.
imgbox = widgets.Image(format='jpg') if show_visualize_process else None
_children_ = [caption, widgets.HBox([msg, bar])]
if imgbox is not None:
    _children_.append(imgbox)
proc_info = widgets.VBox(_children_)
display(proc_info)

# --- Process every item -----------------------------------------------------
ignore_item = 0
# The loop variable is deliberately not called `input_data`: that would
# overwrite the configuration value and break a re-run of earlier cells.
for item_idx, (input_item, output_path) in enumerate(zip(io_paths["input"], io_paths["output"])):
    if os.path.isfile(output_path):
        if rebuild_exist_file:
            title = '[{}/{}]Rebuild {} from {}'
        else:
            print('[{}/{}]{} already exist, pass'.format(item_idx, total_item, output_path))
            ignore_item += 1
            continue
    else:
        title = '[{}/{}]Build {} from {}'

    if isinstance(input_item, str):  # process video
        source_position = input_item
        loader = load_video_frames(input_item, output_length, frame_rate_ratio)
        length, h, w = get_video_size(input_item, output_length)
    elif isinstance(input_item, list):  # process images
        source_position = os.path.dirname(input_item[0])
        loader = load_image_frames(input_item, output_length, frame_rate_ratio)
        length, h, w = get_image_size(input_item, output_length)
    else:
        raise TypeError("Expected string or list(string), but got %s" % type(input_item))
    caption.value = title.format(item_idx, total_item, output_path, source_position)

    # Reset per item so `finally` never touches an undefined handle or one
    # left over from the previous item.
    out = None
    out_h5 = None
    try:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        output_fps = 15
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        height = int(resize_fac * h)
        width = int(resize_fac * w)
        caption.value += "\nsource:{}x{}  target:{}x{}".format(h, w, height, width)
        if imgbox is not None:
            imgbox.width = width
            imgbox.height = height
        # Video writer plus an HDF5 side-car file holding the raw joint data.
        out = cv2.VideoWriter(output_path, fourcc, output_fps, (width, height))
        out_h5 = h5py.File(output_path + ".h5", mode="w")
        out_h5["height"] = height
        out_h5["width"] = width
        t0 = time.time()
        for frame_idx, image in enumerate(loader):
            t1 = time.time()
            # generate image with body parts
            resized_image = cv2.resize(image, (0, 0), fx=1 * resize_fac, fy=1 * resize_fac,
                                       interpolation=cv2.INTER_CUBIC)
            to_plot, canvas, joint_list, person_to_joint_assoc = process(model, resized_image, process_speed)
            # save outputs
            out.write(canvas)
            frame_h5 = out_h5.create_group("frame%d" % frame_idx)
            frame_h5.create_dataset("joint_list", data=joint_list)
            frame_h5.create_dataset("person_to_joint_assoc", data=person_to_joint_assoc)
            t2 = time.time()
            # print messages
            msg.value = '{}  process time:{:.3f}s  total time:{:.3f}s'.format(
                time.strftime('%H:%M:%S'), (t2 - t1), (t2 - t0))
            done = frame_idx + 1
            bar.description = '[{}/{}]'.format(done, length)
            # `length` comes from container metadata and can be 0 or too
            # small; avoid dividing by zero and keep the bar within [0, 1].
            bar.value = min(1.0, done / length) if length > 0 else 0.0
            if imgbox is not None:
                # NOTE(review): the canvas is channel-swapped before JPEG
                # encoding while the very same canvas is written to the video
                # unswapped — confirm the preview colours are as intended.
                canvas = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
                # imencode already returns an ndarray; tobytes() replaces the
                # deprecated tostring() and needs no `np` name in scope.
                imgbox.value = cv2.imencode('.jpg', canvas)[1].tobytes()
    finally:
        # Stop the frame generator and release whatever was actually opened.
        loader.close()
        if out is not None:
            out.release()
        if out_h5 is not None:
            out_h5.close()

clear_output()
print("Prosessed {} items, ignore {} existing items. Saved into {}".format(
    total_item - ignore_item, ignore_item, output_dir))
print("All work are Finished！")
exit()  # clean GPU memories

Prosessed 0 items, ignore 502 existing items. Saved into /home/liuqixuan/datasets/actions/val
All work are Finished！
