In [1]:
import torch

In [2]:
!apt-get install -y ffmpeg 

Reading package lists... Done
Building dependency tree       
Reading state information... Done
ffmpeg is already the newest version (7:3.4.8-0ubuntu0.2).
The following packages were automatically installed and are no longer required:
  gyp libc-ares2 libhttp-parser2.7.1 libjs-async libjs-inherits
  libjs-node-uuid libjs-underscore libuv1-dev node-abbrev node-ansi
  node-ansi-color-table node-archy node-async node-balanced-match
  node-block-stream node-brace-expansion node-builtin-modules
  node-combined-stream node-concat-map node-cookie-jar node-delayed-stream
  node-forever-agent node-form-data node-fs.realpath node-fstream
  node-fstream-ignore node-github-url-from-git node-glob node-graceful-fs
  node-hosted-git-info node-inflight node-inherits node-ini
  node-is-builtin-module node-isexe node-json-stringify-safe node-lockfile
  node-lru-cache node-mime node-minimatch node-mkdirp node-mute-stream
  node-node-uuid node-nopt node-normalize-package-data node-npmlog node-once
  node-

In [3]:
import os
import numpy as np
import cv2
import shutil

from detectron2.config import get_cfg
from predictor import VisualizationDemo
from multi_person_tracker import MPT
from multi_person_tracker.data import video_to_images

In [4]:
import xml.etree.ElementTree as elemTree

In [5]:
%matplotlib inline
from matplotlib import pyplot as plt

In [6]:
# Directory that holds the annotation XML files and the videos they reference;
# both are opened below by joining a file name onto this path.
PATH = "../data/103-6"

In [7]:
def init_cfg(
    config_file="../detectron2/configs/COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x.yaml",
    weights="detectron2://COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x/138363331/model_final_997cc7.pkl",
    score_thresh=0.5,
):
    """Build a frozen detectron2 config for the keypoint model.

    Called without arguments it produces the same configuration as before;
    the parameters only expose the values that used to be hard-coded.

    Args:
        config_file: YAML file merged on top of the defaults returned by
            ``get_cfg()``.
        weights: Value stored in ``cfg.MODEL.WEIGHTS``.
        score_thresh: Test-time confidence threshold written to the
            RetinaNet, ROI-heads and Panoptic-FPN entries.

    Returns:
        The config object, after ``freeze()`` has been called on it.
    """
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    cfg.MODEL.WEIGHTS = weights
    # One threshold is written to every built-in head so the setting applies
    # no matter which of them the merged YAML ends up using.
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = score_thresh
    cfg.freeze()
    return cfg

In [8]:
def get_xmls(PATH):
    """Return the names of the ``.xml`` files directly inside ``PATH``.

    Args:
        PATH: Directory to scan; sub-directories are not descended into.

    Returns:
        A list of file names (not full paths) whose name ends in ".xml",
        sorted alphabetically.
    """
    # os.listdir() gives entries in arbitrary order; sorting makes
    # index-based selection such as filenames[1] reproducible across runs
    # and machines.
    return sorted(name for name in os.listdir(PATH) if name.endswith(".xml"))

In [9]:
def get_frame(cap, frame_num):
    """Seek ``cap`` to ``frame_num`` and perform a single read.

    Args:
        cap: Capture object exposing ``set`` and ``read``.
        frame_num: Frame position passed to ``cap.set`` under
            ``cv2.CAP_PROP_POS_FRAMES``.

    Returns:
        Whatever ``cap.read()`` returns after the seek.
    """
    position_property = cv2.CAP_PROP_POS_FRAMES
    cap.set(position_property, frame_num)
    read_result = cap.read()
    return read_result


def get_frames_with_5fps(cap, frame_start, frame_end, fps, resize=True):
    """Sample frames between ``frame_start`` and ``frame_end`` at 5 fps.

    Every ``fps / 5``-th frame is fetched through ``get_frame``, starting at
    ``frame_start`` and never going past ``frame_end`` (inclusive).

    Args:
        cap: Capture object handed to ``get_frame``.
        frame_start: First frame number to sample.
        frame_end: Last frame number that may be sampled.
        fps: Frame rate of the video; must be a positive multiple of 5.
        resize: When true, each frame is resized to 960x540 before the
            colour conversion.

    Returns:
        A list of frames converted from BGR to RGB, or ``None`` (after
        printing a message) when ``fps`` is not a positive multiple of 5.
        Sampling stops at the first frame that cannot be read, so the list
        may be shorter than the requested range.
    """
    # fps <= 0 has to be rejected explicitly: 0 % 5 == 0 (and -5 % 5 == 0),
    # which would give a non-positive step and a loop that never terminates.
    if fps <= 0 or fps % 5 != 0:
        print("wrong frame num")
        return None

    frame_step = int(fps / 5)

    frames = []
    sample_frame_num = frame_start
    while sample_frame_num <= frame_end:
        ok, frame = get_frame(cap, sample_frame_num)
        # A failed read yields no image; stop instead of passing None on to
        # cv2.resize / cv2.cvtColor.
        if not ok or frame is None:
            break
        if resize:
            frame = cv2.resize(frame, dsize=(960, 540))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
        sample_frame_num += frame_step
    return frames

def plot_frame(frame, keypoint_coord=None):
    """Show ``frame`` with matplotlib, optionally marking one point on it.

    Args:
        frame: Image passed to ``plt.imshow``.
        keypoint_coord: Optional indexable whose first two items are the x
            and y position of a marker drawn on top of the image.
    """
    plt.imshow(frame)
    draw_marker = keypoint_coord is not None
    if draw_marker:
        marker_x = keypoint_coord[0]
        marker_y = keypoint_coord[1]
        plt.scatter(marker_x, marker_y, marker="o")
    plt.show()

In [10]:
# Names (not full paths) of the .xml files found directly in PATH.
filenames = get_xmls(PATH)

In [None]:
# Pick one annotation file and open the video it refers to.
xml_filename = filenames[1]
tree = elemTree.parse(os.path.join(PATH, xml_filename))
video_filename = tree.find("filename").text
video_path = os.path.join(PATH, video_filename)
cap = cv2.VideoCapture(video_path)
# VideoCapture does not raise on a missing/unreadable file; fail early here
# instead of later with an obscure error on the first frame.
if not cap.isOpened():
    raise IOError(f"could not open video: {video_path}")
print(video_filename)

# Frame per second and total frame count from the XML header.
# (The count is kept under its own name so that `frames` below only ever
# refers to the list of sampled images.)
fps = int(tree.find("header").find("fps").text)
total_frames = int(tree.find("header").find("frames").text)
print(f"fps {fps} frames {total_frames}")

# Frame width, height (kept as the raw XML text; only printed here)
width = tree.find("size").find("width").text
height = tree.find("size").find("height").text
print(f"width: {width} height: {height}")

# Frame start, end pairs of fall down action
action_ranges = []
action_frames = tree.find("object").find("action").findall("frame")
for action_frame in action_frames:
    action_ranges.append((int(action_frame.find("start").text),
                          int(action_frame.find("end").text)))
print(action_ranges)

# Keyframe information, keyframe number and ground truth coordinate
# NOTE(review): the coordinate is presumably expressed in the original
# width x height resolution printed above; rescale it before plotting on the
# 960x540 frames produced by get_frames_with_5fps -- confirm.
keyframe = int(tree.find("object").find("position").find("keyframe").text)
keypoint = tree.find("object").find("position").find("keypoint")
keypoint_coord = (int(keypoint.find("x").text), int(keypoint.find("y").text))

print(f"keyframe {keyframe}, coord: {keypoint_coord}")

#     # Show keyframe
#     _, frame = get_frame(cap, keyframe)
#     print(frame.shape)
#     plot_frame(frame, keypoint_coord)

# # Show all frames in the first action range
# frames = get_frames_with_5fps(cap, action_ranges[0][0], action_ranges[0][1], fps=fps)
# print(len(frames))
# # input_frames = torch.Tensor(frames)
# for frame in frames[-100:]:
#     plot_frame(frame)

# Sample the first action range at 5 fps. The capture is released as soon as
# the frames are in memory, whether or not sampling succeeded.
try:
    frames = get_frames_with_5fps(cap, action_ranges[0][0], action_ranges[0][1], fps=fps)
finally:
    cap.release()
if frames is None:
    # get_frames_with_5fps signals an unsupported frame rate by returning None.
    raise ValueError(f"cannot sample 5 fps from a {fps} fps video")

# Object detection from detectron
# NOTE(review): get_frames_with_5fps returns RGB frames, while detectron2's
# stock demo predictor documents BGR input -- confirm which channel order the
# local `predictor.VisualizationDemo` expects.
output_dir = "./output"
# Make sure the target directory exists before the first save.
os.makedirs(output_dir, exist_ok=True)
cfg = init_cfg()
demo = VisualizationDemo(cfg)
for idx, frame in enumerate(frames):
    predictions, visualized_output = demo.run_on_image(frame)
#     plot_frame(visualized_output.get_image()[:, :, ::-1], keypoint_coord)
    visualized_output.save(f"{output_dir}/keyframe_{video_filename}_{idx}_keypoints.jpg")
#     print(predictions)

103-6_cam01_swoon01_place04_day_summer.mp4
fps 30 frames 8568
width: 3840 height: 2160
[(6316, 6682), (6682, 6980)]
keyframe 6316, coord: (2276, 879)


model_final_997cc7.pkl:  79%|███████▉  | 248M/313M [00:18<00:11, 5.61MB/s]   

In [None]:
# def inference(path, event_range_tuple):
#     imgsz = 960
#     mot = MPT(
#         display=True,
#         detector_type='yolo',  # 'maskrcnn'
#         batch_size=1,
#         detection_threshold=0.7,
#         yolo_img_size=imgsz,
#     )
#     image_folder = video_to_images(path, event_range_tuple=event_range_tuple)
#     result = mot(image_folder, output_file='sample.mp4')
    
# #     shutil.rmtree(image_folder)
    
#     return result

In [None]:
# xml_filename = filenames[0]
# tree = elemTree.parse(os.path.join(PATH, xml_filename))
# video_filename = tree.find("filename").text
# ss = tree.find("event").find("starttime").text
# t = tree.find("event").find("duration").text
# image_folder = os.path.join(PATH, video_filename)
# result = inference(image_folder, (ss, t))