In [None]:
import torch
import torchvision
#import torch2trt

In [None]:
import torch
from torch2trt import torch2trt
from torchvision.models.alexnet import alexnet

# Build a pretrained AlexNet, put it in eval mode and move it to the GPU.
# NOTE: `model` and `model_trt` are rebound by later cells in this notebook.
model = alexnet(pretrained=True).eval().cuda()

# Example input on the GPU: an all-ones tensor of shape (1, 3, 224, 224).
x = torch.ones((1, 3, 224, 224)).cuda()

# Convert to TensorRT, feeding the sample data as input.
model_trt = torch2trt(model, [x])

In [None]:
# Load torchvision's pretrained deeplabv3_resnet101 segmentation model.
# NOTE: this rebinds `model`, replacing the AlexNet instance assigned earlier.
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)

In [None]:
# Move the model to the GPU, switch it to eval mode and cast it to float16.
model = model.cuda().eval().half()

In [None]:
class ModelWrapper(torch.nn.Module):
    """Adapter exposing only the ``'out'`` entry of a wrapped model's output.

    The wrapped model is expected to return a mapping from ``forward``; this
    module forwards the input unchanged and returns ``result['out']``.
    """

    def __init__(self, model):
        """Store ``model`` as the ``model`` submodule."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``'out'`` entry."""
        outputs = self.model(x)
        return outputs['out']

In [None]:
# Wrap the segmentation model so forward() yields only the 'out' entry
# (see ModelWrapper.forward), then cast the wrapper to float16.
model_w = ModelWrapper(model).half()

In [None]:
# Sample input for the conversion: all-ones float16 tensor of shape (1, 3, 224, 224) on the GPU.
data = torch.ones((1, 3, 224, 224)).cuda().half()

In [None]:
import torch2trt

# `torch2trt` names the module here (bound by the `import torch2trt` just above),
# so the converter is reached as torch2trt.torch2trt. The wrapped model and the
# float16 sample input are passed with fp16_mode=True.
model_trt = torch2trt.torch2trt(model_w, [data], fp16_mode=True)

In [None]:
# Persist the converted module's state_dict to 'segment_trt.pth' (relative path).
torch.save(model_trt.state_dict(), 'segment_trt.pth')

# Live demo

In [10]:
from torch2trt import TRTModule
import torch

# Create an empty TRTModule; its contents are restored from disk below.
model_trt = TRTModule()

# NOTE(review): this loads 'segment2_trt.pth', whereas the conversion section of
# this notebook saves 'segment_trt.pth' — confirm which file is intended.
model_trt.load_state_dict(torch.load('segment2_trt.pth'))

<All keys matched successfully>

In [11]:

from jetbot import Camera
from IPython.display import display
import ipywidgets.widgets as widgets
from jetbot import bgr8_to_jpeg

# Obtain the jetbot camera via Camera.instance with 224x224 frames, the same
# spatial size as the (1, 3, 224, 224) sample input used for the conversion.
# presumably usb=0 selects the first USB camera — TODO confirm against jetbot.
camera = Camera.instance(width=224, height=224, usb=0)


In [12]:
from jetbot import bgr8_to_jpeg
import traitlets
import ipywidgets

#image_w = ipywidgets.Image()
# JPEG image widget with width/height set to 224.
image_w = widgets.Image(format='jpeg', width=224, height=224)

# Directional link: camera.value is passed through bgr8_to_jpeg and written to image_w.value.
traitlets.dlink((camera, 'value'), (image_w, 'value'), transform=bgr8_to_jpeg)


<traitlets.traitlets.directional_link at 0x7fa6dd1198>

In [13]:
import cv2, PIL
import numpy as np
import torchvision
import torchvision.transforms as T

# Device that preprocess() moves its tensors to.
device = torch.device('cuda')
# Per-channel mean / standard deviation multiplied by 255, i.e. expressed on a
# 0-255 pixel scale (preprocess() does not divide the image by 255 first).
# presumably the usual ImageNet RGB statistics — confirm against the model docs.
mean = 255.0 * np.array([0.485, 0.456, 0.406])
stdev = 255.0 * np.array([0.229, 0.224, 0.225])

# Channel-wise (x - mean) / stdev transform used by preprocess().
normalize = torchvision.transforms.Normalize(mean, stdev)

def preprocess(camera_value):
    """Turn a camera frame into a normalized, batched tensor on ``device``.

    Steps: resize to 224x224, convert BGR to RGB, reorder the axes with
    ``transpose((2, 0, 1))``, convert to a float tensor, apply the
    module-level ``normalize`` transform, move to the module-level
    ``device`` and prepend a batch axis.

    Args:
        camera_value: image array accepted by ``cv2.resize``
            (assumes an HxWx3 BGR frame from the camera — TODO confirm).

    Returns:
        A float tensor on ``device`` with a leading axis of length 1.
    """
    resized_bgr = cv2.resize(camera_value, (224, 224))
    rgb_chw = cv2.cvtColor(resized_bgr, cv2.COLOR_BGR2RGB).transpose((2, 0, 1))
    normalized = normalize(torch.from_numpy(rgb_chw).float())
    on_device = normalized.to(device)
    return on_device[None, ...]


In [14]:
import numpy as np
def decode_segmap(image, nc=21):
    """Colorize a 2-D map of class indices.

    Each pixel whose label lies in ``[0, nc)`` receives the palette colour of
    that label; every other pixel (negative, non-integral, or >= ``nc``) stays
    black.

    Args:
        image: 2-D array of class indices (e.g. the argmax over class scores).
        nc: number of leading palette entries to use. Values larger than the
            palette size are clamped to it instead of raising ``IndexError``.

    Returns:
        ``uint8`` array of shape ``image.shape + (3,)`` with channels ordered
        R, G, B.
    """
    label_colors = np.array([(0, 0, 0),  # 0=background
               # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
               (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
               # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
               (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
               # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
               (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
               # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
               (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)],
              dtype=np.uint8)

    image = np.asarray(image)
    # Never index past the end of the palette, whatever nc the caller passes.
    nc = min(nc, len(label_colors))

    labels = image.astype(np.intp)
    # Only exact integer labels in [0, nc) are coloured; the equality test keeps
    # non-integral float values black, as the per-class `image == l` test did.
    valid = (labels == image) & (labels >= 0) & (labels < nc)

    rgb = np.zeros(image.shape + (3,), dtype=np.uint8)
    # Single palette lookup instead of one masked write per class and channel.
    rgb[valid] = label_colors[labels[valid]]
    return rgb

In [15]:
# JPEG image widget that execute() writes the encoded segmentation mask into.
seg_image = ipywidgets.Image(format='jpeg', width=224, height=224)
# Show the camera widget and the segmentation widget side by side.
display(widgets.VBox([
    widgets.HBox([image_w, seg_image]),
]))


VBox(children=(HBox(children=(Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x…

In [16]:
def execute(change):
    """Segment one frame and show the colorized mask in ``seg_image``.

    Usable as a traitlets observer for ``camera.value`` (it is registered with
    ``camera.observe``) or called directly with ``{'new': frame}``.

    Args:
        change: mapping whose ``'new'`` entry is the frame to process.

    Returns:
        The colour mask produced by ``decode_segmap``; it is also JPEG-encoded
        and assigned to ``seg_image.value``.
    """
    image = change['new']
    # Process the frame delivered with the notification instead of re-reading
    # camera.value, which may already hold a different frame.
    output = model_trt(preprocess(image).half())[0].detach().cpu().float().numpy()
    # argmax over the first axis of the first batch item
    # (assumes that axis enumerates the classes — TODO confirm).
    mask = decode_segmap(output.argmax(0))
    # NOTE(review): decode_segmap stacks channels as R, G, B while the encoder is
    # named bgr8_to_jpeg — displayed colours may have R and B swapped; confirm.
    seg_image.value = bgr8_to_jpeg(mask)
    return mask
    
# Run the pipeline once on the current frame (`mask` receives whatever execute()
# returns), then register execute() to be called on every camera.value change.
mask = execute({'new': camera.value})
camera.observe(execute, names='value')

In [8]:
# Register execute() as an observer of camera.value.
# NOTE(review): this repeats the registration made right after execute() is
# defined, and this cell's execution count is out of sequence with its
# neighbours — confirm whether the cell is still needed.
camera.observe(execute, names='value')

In [17]:
# Tear down the live demo: stop the camera, then detach the execute() observer.
camera.stop()
camera.unobserve(execute, names='value')


In [None]:
import time

# Throughput benchmark: iterations per second of preprocess + forward pass.
# NOTE(review): this times `model_w` (the PyTorch wrapper), not `model_trt` —
# confirm that is the model meant to be measured here.
num_iters = 100

# no_grad: inference only, so autograd bookkeeping is excluded from the timing.
with torch.no_grad():
    # Drain pending GPU work so it is not attributed to the timed region.
    torch.cuda.current_stream().synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    for i in range(num_iters):
        output = model_w(preprocess(camera.value).half())
    # Wait for the queued GPU work to finish before stopping the clock.
    torch.cuda.current_stream().synchronize()
    t1 = time.perf_counter()

print(num_iters / (t1 - t0))