# Real-time hand pose estimation with NVIDIA TensorRT model
This is an adaptation of the human pose estimation notebook to the new [hand pose model](https://github.com/NVIDIA-AI-IOT/trt_pose_hand) that NVIDIA
published recently (17.12.2020). Get the new repository and the [hand model weights](https://drive.google.com/file/d/1NCVo0FiooWccDzY7hCc5MAKaoUpts3mo).

In [1]:
import os
import cv2
import json
import ipywidgets
import jetson.utils
import PIL.Image
import numpy as np
import torch
import time
import trt_pose.coco

from jetutils import SimpleTimer
from IPython.display import display
from sidecar import Sidecar
from torch2trt import TRTModule
import torchvision.transforms as transforms

In [13]:
# Path of the serialized TensorRT model (created from the raw weights when missing).
MODEL_PATH = 'models/hand_pose_resnet18_att_244_244_trt.pth'
# Model input resolution in pixels.
MODEL_W, MODEL_H = 224, 224

# Keypoint names and skeleton of the hand, in COCO category format.
with open('trt_pose_hand/preprocess/hand_pose.json', 'r') as f:
    hand_pose = json.load(f)

topology = trt_pose.coco.coco_category_to_topology(hand_pose)

# Link colours (BGR), keyed by the finger prefix of the keypoint name a link starts from.
COLOR_MAP = dict(
    palm=(255, 255, 255),
    thumb=(180, 180, 255),
    index=(180, 255, 180),
    middle=(255, 180, 180),
    ring=(255, 180, 255),
    baby=(180, 180, 180),
)

On the first run only, we need to convert the downloaded model weights into a TensorRT model.

In [3]:
if not os.path.exists(MODEL_PATH):
    # First run: build the PyTorch model, load the downloaded weights,
    # convert it with torch2trt and cache the result at MODEL_PATH.
    import trt_pose.models
    import torch2trt
    totals = SimpleTimer()
    timer = SimpleTimer()
    with totals:
        MODEL_WEIGHTS = 'models/hand_pose_resnet18_att_244_244.pth'
        num_parts = len(hand_pose['keypoints'])
        num_links = len(hand_pose['skeleton'])
        with timer:
            # The second head gets 2 * num_links channels (presumably an x and a
            # y component of the part affinity field per link).
            model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()
        print('model with {} parts and {} links generated in {} seconds'.format(num_parts,
                                                                                num_links,
                                                                                timer.time))
        with timer:
            model.load_state_dict(torch.load(MODEL_WEIGHTS))
        print('model weights loaded in {} seconds'.format(timer.time))
        with timer:
            # Trace the model with a dummy batch of the model input size.
            trt_model = torch2trt.torch2trt(model,
                                            [torch.zeros((1, 3, MODEL_H, MODEL_W)).cuda()],
                                            fp16_mode=True,
                                            max_workspace_size=1<<20)
        print('model converted to TensorRT in {} seconds'.format(timer.time))
        torch.save(trt_model.state_dict(), MODEL_PATH)
    print('total time {} seconds'.format(totals.time))
else:
    # Later runs: load the cached TensorRT model directly.
    timer = SimpleTimer()
    with timer:
        trt_model = TRTModule()
        trt_model.load_state_dict(torch.load(MODEL_PATH))
    print('model loaded in {} seconds'.format(timer.time))

model loaded in 29.26948960100708 seconds


In [14]:
import threading
import traitlets
from trt_pose.parse_objects import ParseObjects

from collections import namedtuple
# Plain records handed from keypoint extraction to drawing.
Keypoint = namedtuple('Keypoint', ['name', 'x', 'y'])
KPLink = namedtuple('KPLink', ['x0', 'y0', 'x1', 'y1', 'color'])
# Module-wide timer; HandPose.run_model reports its fps.
_timer = SimpleTimer()

class HandPose(traitlets.HasTraits):
    """Run the TensorRT hand pose model on BGR frames and draw the detected hands.

    Assign a frame to ``input_frame`` and call :meth:`run_model`; the annotated
    frame is published through the ``output_frame`` trait so it can be linked
    to display widgets.
    """
    input_frame = traitlets.Any()
    output_frame = traitlets.Any()

    # Size of the frames handled by this instance; keypoints are scaled to it.
    image_width = traitlets.Integer()
    image_height = traitlets.Integer()
    # Size of the model input.
    width = traitlets.Integer(default_value=MODEL_W)
    height = traitlets.Integer(default_value=MODEL_H)

    draw_background = traitlets.Bool(default_value=True)
    draw_skeleton = traitlets.Bool(default_value=True)
    draw_labels = traitlets.Bool(default_value=False)

    # Index of x / y: peak[0] is treated as the row (y) and peak[1] as the
    # column (x); the same indices pick height / width from a numpy image shape.
    _X = 1
    _Y = 0

    def __init__(self, model, topology, keypoints,
                 *args, **kwargs):
        """Create the pose pipeline.

        :param model: callable returning ``(cmap, paf)`` for a normalized
            input batch (e.g. a ``TRTModule``).
        :param topology: link topology from ``coco_category_to_topology``;
            columns 2 and 3 hold the keypoint indices each link connects.
        :param keypoints: keypoint names, indexed like the model's parts.
        """
        super(HandPose, self).__init__(*args, **kwargs)
        self.input_frame = np.empty((self.image_height, self.image_width, 3), dtype=np.uint8)
        self.output_frame = np.empty((self.image_height, self.image_width, 3), dtype=np.uint8)
        self._running = False
        self._model = model
        self._parse = ParseObjects(topology)
        self._topology = topology
        self._keypoints = keypoints
        # Each link is coloured after the finger of its first keypoint
        # (names are split at '_' and the prefix looked up in COLOR_MAP).
        self._link_colors = {}
        for a in range(topology.shape[0]):
            link = (int(topology[a][2]), int(topology[a][3]))
            self._link_colors[link] = COLOR_MAP[keypoints[link[0]].split('_')[0]]
        self._device = torch.device('cuda')
        # Per-channel normalization constants in RGB order (the standard ImageNet values).
        self._mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
        self._std = torch.Tensor([0.229, 0.224, 0.225]).cuda()

    def _to_pixel(self, peak):
        """Scale a peak given as fractions of the frame size to integer pixel (x, y)."""
        return (round(float(peak[HandPose._X]) * self.image_width),
                round(float(peak[HandPose._Y]) * self.image_height))

    def _get_hand_keypoints(self, hands, index, peaks):
        """Return the detected keypoints of hand ``index`` in pixel coordinates.

        Parts the parser did not assign to this hand (negative peak index)
        are skipped.
        """
        keypoints = []
        hand = hands[0][index]
        for j in range(hand.shape[0]):
            k = int(hand[j])
            if k < 0:
                continue
            x, y = self._to_pixel(peaks[0][j][k])
            keypoints.append(Keypoint(name=self._keypoints[j], x=x, y=y))
        return keypoints

    def _get_keypoint_links(self, hands, index, peaks):
        """Return the skeleton links of hand ``index`` whose both ends were detected."""
        links = []
        hand = hands[0][index]
        for k in range(self._topology.shape[0]):
            c_a = int(self._topology[k][2])
            c_b = int(self._topology[k][3])
            k_a = int(hand[c_a])
            k_b = int(hand[c_b])
            if k_a < 0 or k_b < 0:
                continue
            x0, y0 = self._to_pixel(peaks[0][c_a][k_a])
            x1, y1 = self._to_pixel(peaks[0][c_b][k_b])
            links.append(KPLink(x0=x0, y0=y0, x1=x1, y1=y1,
                                color=self._link_colors.get((c_a, c_b), (255, 255, 255))))
        return links

    def _draw(self, image, counts, hands, peaks):
        """Draw every detected hand onto ``image`` in place.

        :return: debug text listing the keypoints of each hand.
        """
        dbg_parts = []
        for i in range(int(counts[0])):
            dbg_parts.append('\n\nhand {}:'.format(i))
            keypoints = self._get_hand_keypoints(hands, i, peaks)
            if self.draw_skeleton:
                for link in self._get_keypoint_links(hands, i, peaks):
                    cv2.line(image, (link.x0, link.y0), (link.x1, link.y1), link.color, 2)
            for keypoint in keypoints:
                dbg_parts.append('\n' + str(keypoint))
                if self.draw_labels:
                    cv2.putText(image, '{}'.format(keypoint.name),
                                (keypoint.x + 5 + 2, keypoint.y + 5),
                                cv2.FONT_HERSHEY_PLAIN, 1, (150, 150, 150), 1)
                cv2.circle(image, (keypoint.x, keypoint.y), 5, (240, 240, 240), 1)
        return ''.join(dbg_parts)

    def _preprocess(self, frame):
        """Turn a BGR frame into a normalized 1xCxHxW tensor on the model device."""
        # cv2.resize expects the destination size as (width, height).
        image = cv2.resize(frame, (self.width, self.height))
        # Frames are BGR (OpenCV / camera) while the normalization constants
        # are in RGB order, so swap the channels before normalizing.
        image = cv2.cvtColor(image.astype(np.uint8, copy=False), cv2.COLOR_BGR2RGB)
        image_tensor = transforms.functional.to_tensor(PIL.Image.fromarray(image)).to(self._device)
        image_tensor.sub_(self._mean[:, None, None]).div_(self._std[:, None, None])
        return image_tensor[None, ...]

    def run_model(self):
        """Process ``input_frame`` and publish the annotated frame to ``output_frame``.

        :return: debug text with the current fps and the keypoints of every hand.
        """
        with _timer:
            # Read the trait once so background and model input use the same frame.
            frame = self.input_frame
            if self.draw_background:
                output_frame = frame.copy()
            else:
                output_frame = np.zeros((self.image_height, self.image_width, 3), dtype=np.uint8)

            input_data = self._preprocess(frame)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                cmap, paf = self._model(input_data)
            cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
            h_counts, hands, peaks = self._parse(cmap, paf)

            # Finally draw the hand skeletons over the output frame.
            dbg_str = self._draw(output_frame, h_counts, hands, peaks)
            self.output_frame = output_frame
        return 'fps {}{}'.format(_timer.fps, dbg_str)

    @classmethod
    def validate(cls, model, topo, kp, img_path):
        """Run the pipeline once on the image at ``img_path`` and display the result.

        :raises FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        data = cv2.imread(img_path)
        if data is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image {}'.format(img_path))
        instance = cls(model, topo, kp, image_height=data.shape[cls._Y], image_width=data.shape[cls._X])
        instance.input_frame = data
        print(instance.run_model())
        display(PIL.Image.fromarray(cv2.cvtColor(instance.output_frame, cv2.COLOR_BGR2RGB)))
        


It is good practice to validate the model and the HandPose abstraction above before actual real-time usage. This makes changes and improvements a lot easier.

In [None]:
keypoint_names = hand_pose['keypoints']
# Push one still image through the whole pipeline before touching the camera.
HandPose.validate(trt_model, topology, keypoint_names, 'hand.png')
print(keypoint_names)

Only then initialize the camera stream, so that in case of errors there is no need to clean up GStreamer.

In [20]:
from jetutils import GstCamera, bgr8_to_jpeg
# Camera helper from the local jetutils module (presumably GStreamer-backed);
# new BGR frames are published through its `value` trait.
camera = GstCamera()

In [21]:
# Size the pose pipeline to the camera frames.
pose_model = HandPose(
    trt_model,
    topology,
    hand_pose['keypoints'],
    image_height=camera.height,
    image_width=camera.width,
)

Then we create our usual output display sidecar again.

In [22]:
# Side-by-side JPEG views of the raw and the processed camera frames.
frame_size = dict(width=camera.width, height=camera.height)
image_original = ipywidgets.Image(format='jpeg', **frame_size)
image_processed = ipywidgets.Image(format='jpeg', **frame_size)
# Show black frames until the first camera frame arrives.
blank_jpeg = bgr8_to_jpeg(np.zeros((camera.height, camera.width, 3), dtype=np.uint8))
image_original.value = blank_jpeg
image_processed.value = blank_jpeg

debug_out = ipywidgets.Textarea(
    value='',
    disabled=True,
    layout=ipywidgets.Layout(width='640px', height='520px'),
)
images_out = ipywidgets.HBox([image_original, image_processed])

# Toggles for the HandPose draw_* options (same initial values as the traits).
select_background = ipywidgets.ToggleButton(value=True, description='background')
select_skeleton = ipywidgets.ToggleButton(value=True, description='skeleton')
select_labels = ipywidgets.ToggleButton(value=False, description='labels')
control_box = ipywidgets.HBox(
    [ipywidgets.Label(value='Draw '), select_background, select_skeleton, select_labels]
)

all_box = ipywidgets.VBox([images_out, control_box, debug_out])
_sidecar = Sidecar(title='output')
with _sidecar:
    display(all_box)

With traitlets it is very convenient to link the traits of objects together.

In [23]:
# Raw frames go to the left view, annotated frames to the right one.
traitlets.dlink((camera, 'value'), (image_original, 'value'), transform=bgr8_to_jpeg)
traitlets.dlink((pose_model, 'output_frame'), (image_processed, 'value'), transform=bgr8_to_jpeg)

# Each toggle button drives one drawing option of the pose model.
for _toggle, _trait in ((select_background, 'draw_background'),
                        (select_skeleton, 'draw_skeleton'),
                        (select_labels, 'draw_labels')):
    traitlets.dlink((_toggle, 'value'), (pose_model, _trait))
del _toggle, _trait


def process(change):
    """Camera observer: run the pose model on the new frame and show its debug text."""
    frame = change['new']
    pose_model.input_frame = frame
    debug_out.value = pose_model.run_model()


camera.observe(process, names='value')
camera.running = True

In [24]:
# Stop producing frames first, then detach all observers registered on the
# camera (including the `process` callback).
camera.running = False
camera.unobserve_all()

In [25]:
# Drop the notebook's references. Note that pose_model still holds the model
# (HandPose keeps it in self._model), so the model is not released by this alone.
del camera, trt_model