In [1]:
import json
import trt_pose.coco
import torch
from torch2trt import TRTModule
import cv2
import torchvision.transforms as transforms
import PIL.Image
import time
from trt_pose.draw_objects import DrawObjects
from trt_pose.parse_objects import ParseObjects


# Pose category description (JSON) consumed by
# trt_pose.coco.coco_category_to_topology below. Update this path if the
# trt_pose repository is checked out somewhere else.
HUMAN_POSE_JSON = 'trt_pose/tasks/human_pose/human_pose.json'

with open(HUMAN_POSE_JSON, 'r') as f:
    human_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(human_pose)

In [2]:
# Serialized TensorRT model state loaded into torch2trt's TRTModule; update
# the path if need be. Alternative model:
#   'densenet121_baseline_att_256x256_B_epoch_160_trt.pth'
# Its name indicates a 256x256 input, so poseProc's frame size should be
# changed to match when switching to it.
OPTIMIZED_MODEL = 'resnet18_baseline_att_224x224_A_epoch_249_trt.pth'

model_trt = TRTModule()
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
trt_state = torch.load(OPTIMIZED_MODEL)
# Left as the cell's last expression so Jupyter displays the load report.
model_trt.load_state_dict(trt_state)

<All keys matched successfully>

In [None]:
# All preprocessing runs on the GPU.
device = torch.device('cuda')
# Standard ImageNet per-channel RGB mean/std, placed on the same device as the
# frames so normalisation happens on-device.
mean = torch.Tensor([0.485, 0.456, 0.406]).to(device)
std = torch.Tensor([0.229, 0.224, 0.225]).to(device)


def preprocess(image):
    """Convert an OpenCV BGR frame into a normalised, batched model input.

    Args:
        image: A 3-channel BGR image array, e.g. a frame returned by
            ``cv2.VideoCapture.read``.

    Returns:
        A float tensor on ``device`` of shape (1, 3, H, W): the image converted
        to RGB, scaled to [0, 1] by ``to_tensor``, then normalised with the
        module-level ``mean``/``std``.
    """
    # The model expects RGB, while OpenCV delivers BGR.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = PIL.Image.fromarray(image)
    # HWC -> CHW float tensor, moved to the GPU.
    image = transforms.functional.to_tensor(image).to(device)
    # Normalise in place; [:, None, None] broadcasts the per-channel stats over H and W.
    image.sub_(mean[:, None, None]).div_(std[:, None, None])
    # Add the leading batch dimension.
    return image[None, ...]

In [None]:
# Post-processing helpers built from the same topology. In poseProc,
# parse_objects turns the model's (cmap, paf) outputs into
# (counts, objects, peaks), and draw_objects is then called with the frame and
# those results before the frame is shown.
parse_objects = ParseObjects(topology)
draw_objects = DrawObjects(topology)

In [None]:
def poseProc(video, width=224, height=224):
    """Run pose estimation on every frame of a video and display the result.

    Each frame is resized to ``width`` x ``height``, preprocessed, and run
    through ``model_trt``. The outputs are parsed with ``parse_objects`` and
    drawn with ``draw_objects``, and the frame is shown in an OpenCV window
    titled 'frame'. Press ``q`` to stop early. When the loop ends, the
    end-to-end throughput (decode + inference + drawing + display) is printed
    as FPS.

    Args:
        video: Source passed to ``cv2.VideoCapture`` (e.g. a file path).
        width: Frame width fed to the model. It should match the resolution
            in the model file name (224 for the resnet18 model, 256 for the
            densenet121 alternative). Defaults to 224.
        height: Frame height fed to the model; same caveat as ``width``.
            Defaults to 224.

    Raises:
        OSError: If the video source cannot be opened.
    """
    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            # Without this check, the FPS line would fail with an unbound t0/t1.
            raise OSError(f"Could not open video source: {video!r}")

        num_frames = 0
        # Flush pending GPU work first so it is not counted in the timing.
        torch.cuda.current_stream().synchronize()
        t0 = time.time()
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # end of stream or read failure
            frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
            data = preprocess(frame)
            cmap, paf = model_trt(data)
            cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
            counts, objects, peaks = parse_objects(cmap, paf)
            # Return value unused; called for its effect on `frame`, which is displayed next.
            draw_objects(frame, counts, objects, peaks)
            cv2.imshow('frame', frame)
            num_frames += 1
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        # Wait for queued GPU work before stopping the clock.
        torch.cuda.current_stream().synchronize()
        elapsed = time.time() - t0
        fps = num_frames / elapsed if elapsed > 0 else 0.0
        print("FPS: " + str(fps))
    finally:
        # Always free the capture and close windows, even if inference raised.
        cap.release()
        cv2.destroyAllWindows()
    

In [None]:
# Input video passed to cv2.VideoCapture inside poseProc; update the path if need be.
video = 'pedestrian.mp4'
poseProc(video)