In [1]:
'''
M2u-net implementation with tensorRT
'''

'\nM2u-net implementation with tensorRT\n'

In [11]:
import torchvision.models as models
import torch
import torch.onnx
import os
import bioimage

from interactive_m2unet import M2UnetInteractiveModel
import numpy as np
import imageio
import albumentations as A
from skimage.filters import threshold_otsu
from skimage.measure import label
# Uncomment to specify the gpu number
# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import torch
torch.backends.cudnn.benchmark = True
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import torch
import torchvision
from glob import glob
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transform
from torch.utils.data import DataLoader,Dataset
from torchvision.utils import make_grid
import math


# Load the ImageNet-pretrained ResNet-50 and put it in inference mode.
# `pretrained=True` is deprecated since torchvision 0.13 (this cell used to emit
# the deprecation warning); the explicit weights enum selects the same checkpoint.
resnet50 = models.resnet50(
    weights=models.ResNet50_Weights.IMAGENET1K_V1,
    progress=False,
).eval()


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


In [12]:
'''
Modified m2unet for quantization
'''

def conv_bn(inp, oup, stride):
    """Build a 3x3 convolution -> batch norm -> ReLU stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the 3x3 convolution (padding is fixed to 1).

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    relu = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, relu)


def conv_1x1_bn(inp, oup):
    """Build a pointwise (1x1) convolution -> batch norm -> ReLU stack.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    relu = nn.ReLU(inplace=True)
    return nn.Sequential(pointwise, norm, relu)

class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expansion, 3x3 depthwise conv, 1x1 linear projection.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: multiplier applied to ``inp`` (then rounded) to get the
            hidden width. When it equals 1 the 1x1 expansion stage is left out.

    The input is added to the output only when ``stride == 1`` and
    ``inp == oup``. That addition goes through a ``FloatFunctional`` module
    rather than the ``+`` operator.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        # Module wrapper around the skip-connection addition used in forward().
        self.add = nn.quantized.FloatFunctional()

        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw: expand to the hidden width (bottleneck with expansion layer)
            stages += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        # dw: depthwise 3x3, one group per channel
        stages += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        # pw-linear: project to the output width, no activation afterwards
        stages += [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the input back when the skip connection is enabled."""
        out = self.conv(x)
        if not self.use_res_connect:
            return out
        return self.add.add(x, out)
        
class Encoder(nn.Module):
    """
    14 layers of MobileNetv2 as encoder part

    The stack is one stem ``conv_bn`` layer followed by 13 inverted residual
    blocks, stored as ``self.layers`` (an ``nn.Sequential``). No ``forward`` is
    defined here; the class only builds and holds the layers.
    """
    def __init__(self):
        super().__init__()
        # Each row: t (expand ratio), c (output channels), n (repeats), s (stride of first repeat)
        stage_specs = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
        ]
        # Stem: 3 input channels -> 32 channels at stride 2.
        in_channels = 32
        stack = [conv_bn(3, 32, 2)]
        for t, c, n, s in stage_specs:
            for repeat in range(n):
                # Only the first block of a stage uses the stage stride.
                block_stride = s if repeat == 0 else 1
                stack.append(InvertedResidual(in_channels, c, block_stride, expand_ratio=t))
                in_channels = c
        self.layers = nn.Sequential(*stack)
                
class DecoderBlock(nn.Module):
    """
    Decoder block: upsample and concatenate with features maps from the encoder part

    Args:
        up_in_c: channels of the tensor that gets upsampled.
        x_in_c: channels of the skip tensor it is concatenated with.
        upsamplemode: interpolation mode handed to ``nn.Upsample``.
        expand_ratio: expand ratio of the inverted residual block applied
            after the concatenation.

    The block outputs ``(up_in_c + x_in_c) // 2`` channels.
    """
    def __init__(self,up_in_c,x_in_c,upsamplemode='bilinear',expand_ratio=0.15):
        super().__init__()
        # Module wrapper around the channel concatenation used in forward().
        self.cat = nn.quantized.FloatFunctional()

        # H, W -> 2H, 2W
        self.upsample = nn.Upsample(scale_factor=2, mode=upsamplemode, align_corners=False)
        merged_c = up_in_c + x_in_c
        self.ir1 = InvertedResidual(merged_c, merged_c // 2, stride=1, expand_ratio=expand_ratio)

    def forward(self,up_in,x_in):
        """Upsample ``up_in``, concatenate it with ``x_in`` on dim 1, then apply ``ir1``."""
        upsampled = self.upsample(up_in)
        merged = self.cat.cat([upsampled, x_in], dim=1)
        return self.ir1(merged)
    
class LastDecoderBlock(nn.Module):
    """Final decoder stage: upsample, concatenate, inverted residual, 1x1 output conv.

    Args:
        x_in_c: channel count of the concatenated tensor fed to the inverted
            residual block.
        upsamplemode: interpolation mode handed to ``nn.Upsample``.
        expand_ratio: expand ratio of the inverted residual block.
        output_channels: number of channels produced by the final 1x1 conv.
        activation: ``'sigmoid'``, ``'softmax'`` (over dim 1), or
            ``'linear'``/``None`` for no activation.

    Raises:
        NotImplementedError: for any other ``activation`` value.
    """
    def __init__(self,x_in_c,upsamplemode='bilinear',expand_ratio=0.15, output_channels=1, activation='linear'):
        super().__init__()
        # H, W -> 2H, 2W
        self.upsample = nn.Upsample(scale_factor=2, mode=upsamplemode, align_corners=False)
        self.ir1 = InvertedResidual(x_in_c, 16, stride=1, expand_ratio=expand_ratio)
        head = [nn.Conv2d(16, output_channels, 1, 1, 0, bias=True)]
        # Module wrapper around the channel concatenation used in forward().
        self.cat = nn.quantized.FloatFunctional()

        if activation not in ('sigmoid', 'softmax', 'linear', None):
            raise NotImplementedError('Activation {} not implemented'.format(activation))
        if activation == 'sigmoid':
            head.append(nn.Sigmoid())
        elif activation == 'softmax':
            head.append(nn.Softmax(dim=1))
        self.conv = nn.Sequential(*head)

    def forward(self,up_in,x_in):
        """Upsample ``up_in``, concatenate with ``x_in`` on dim 1, then apply ``ir1`` and ``conv``."""
        upsampled = self.upsample(up_in)
        merged = self.cat.cat([upsampled, x_in], dim=1)
        return self.conv(self.ir1(merged))
    
class M2UNet(nn.Module):
    """U-shaped network: sliced encoder stages plus inverted-residual decoder blocks.

    Args:
        encoder: module whose first child is an indexable sequence of at
            least 14 layers (as built by ``Encoder``).
        upsamplemode: interpolation mode used by every decoder block.
        output_channels: channels of the final output map.
        activation: output activation passed to ``LastDecoderBlock``.
        expand_ratio: expand ratio of the decoder inverted residual blocks.
    """

    def __init__(self,encoder,upsamplemode='bilinear',output_channels=1, activation="linear", expand_ratio=0.15):
        super().__init__()
        backbone = list(encoder.children())[0]
        # Encoder, split into four stages whose outputs feed the skip connections
        self.conv1 = backbone[0:2]
        self.conv2 = backbone[2:4]
        self.conv3 = backbone[4:7]
        self.conv4 = backbone[7:14]
        # Decoder; each block takes (channels to upsample, channels of the skip tensor)
        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(
            33,
            upsamplemode,
            expand_ratio,
            output_channels=output_channels,
            activation=activation,
        )
        # initialize weights
        self._init_params()

    def _init_params(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear module in the network."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
                if module.bias is not None:
                    module.bias.data.zero_()
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                continue
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()

    def forward(self,x):
        """Run the encoder stages, then decode while merging the matching encoder outputs."""
        skip1 = self.conv1(x)
        skip2 = self.conv2(skip1)
        skip3 = self.conv3(skip2)
        bottom = self.conv4(skip3)
        up = self.decode4(bottom, skip3)
        up = self.decode3(up, skip2)
        up = self.decode2(up, skip1)
        # The last block concatenates with the network input itself.
        return self.decode1(up, x)
        
class M2UNet_q2(nn.Module):
    """Variant of the U-shaped network with a QuantStub at the input and a DeQuantStub at the output.

    Args:
        encoder: module whose first child is an indexable sequence of at
            least 14 layers (as built by ``Encoder``).
        upsamplemode: interpolation mode used by every decoder block.
        output_channels: channels of the final output map.
        activation: output activation passed to ``LastDecoderBlock``.
        expand_ratio: expand ratio of the decoder inverted residual blocks.
    """

    def __init__(self,encoder,upsamplemode='bilinear',output_channels=1, activation="linear", expand_ratio=0.15):
        super().__init__()
        backbone = list(encoder.children())[0]
        # Stub applied to the input at the start of forward()
        self.quant = torch.quantization.QuantStub()

        # Encoder, split into four stages whose outputs feed the skip connections
        self.conv1 = backbone[0:2]
        self.conv2 = backbone[2:4]
        self.conv3 = backbone[4:7]
        self.conv4 = backbone[7:14]
        # Decoder; each block takes (channels to upsample, channels of the skip tensor)
        self.decode4 = DecoderBlock(96, 32, upsamplemode, expand_ratio)
        self.decode3 = DecoderBlock(64, 24, upsamplemode, expand_ratio)
        self.decode2 = DecoderBlock(44, 16, upsamplemode, expand_ratio)
        self.decode1 = LastDecoderBlock(
            33,
            upsamplemode,
            expand_ratio,
            output_channels=output_channels,
            activation=activation,
        )

        # Stub applied to the output at the end of forward()
        self.dequant = torch.quantization.DeQuantStub()

        # initialize weights
        self._init_params()

    def _init_params(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear module in the network."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
                if module.bias is not None:
                    module.bias.data.zero_()
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                continue
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()

    def forward(self,x):
        """Pass the input through the quant stub, the encoder/decoder, and the dequant stub."""
        x = self.quant(x)

        skip1 = self.conv1(x)
        skip2 = self.conv2(skip1)
        skip3 = self.conv3(skip2)
        bottom = self.conv4(skip3)
        up = self.decode4(bottom, skip3)
        up = self.decode3(up, skip2)
        up = self.decode2(up, skip1)
        # The last block concatenates with the (stubbed) network input itself.
        up = self.decode1(up, x)

        return self.dequant(up)
        
def m2unet(output_channels=1,expand_ratio=0.15, activation="linear", **kwargs):
    """Build an ``M2UNet`` on a fresh ``Encoder`` with bilinear upsampling.

    Args:
        output_channels: channels of the final output map.
        expand_ratio: expand ratio of the decoder inverted residual blocks.
        activation: output activation name forwarded to the model.
        **kwargs: accepted but not used.

    Returns:
        The constructed ``M2UNet`` instance.
    """
    return M2UNet(
        Encoder(),
        upsamplemode='bilinear',
        output_channels=output_channels,
        activation=activation,
        expand_ratio=expand_ratio,
    )

def m2unet_q2(output_channels=1,expand_ratio=0.15, activation="linear", **kwargs):
    """Build an ``M2UNet_q2`` on a fresh ``Encoder`` with bilinear upsampling.

    Args:
        output_channels: channels of the final output map.
        expand_ratio: expand ratio of the decoder inverted residual blocks.
        activation: output activation name forwarded to the model.
        **kwargs: accepted but not used.

    Returns:
        The constructed ``M2UNet_q2`` instance.
    """
    return M2UNet_q2(
        Encoder(),
        upsamplemode='bilinear',
        output_channels=output_channels,
        activation=activation,
        expand_ratio=expand_ratio,
    )



In [39]:
# Load model
model = m2unet_q2()
model_name = 'm2unet_50epochs_v2_pre_quant'
PATH = './models/' + model_name

# map_location="cpu" lets the checkpoint be read even if it was saved from GPU
# tensors and no matching CUDA device is present, and avoids allocating GPU
# memory just to copy the weights into this CPU model.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
model.eval()


M2UNet_q2(
  (quant): QuantStub()
  (conv1): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): InvertedResidual(
      (add): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (conv2): Sequential(
    (2): InvertedResidual(
      (add): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv): Sequentia

In [40]:
# Random batch used as the example input for the ONNX exports below.
BATCH_SIZE = 32

dummy_input = torch.randn((BATCH_SIZE, 3, 512, 512))

In [12]:
# Export the loaded M2U-Net to ONNX, using dummy_input as the example input.
torch.onnx.export(
    model,
    dummy_input,
    "m2unet_pre_quant_pytorch.onnx",
    verbose=False,
)




In [13]:
# The exported graph keeps the shape of the example input. The TensorRT section
# further down prepares BATCH_SIZE x 3 x 224 x 224 images for this network, so
# export with an input of exactly that size instead of relying on whatever
# `dummy_input` currently holds (512x512 after the cell above).
resnet_dummy_input = torch.randn(BATCH_SIZE, 3, 224, 224)
torch.onnx.export(resnet50, resnet_dummy_input, "resnet50_pytorch.onnx", verbose=False)



In [53]:
# Verify we can run the model with dummy input
BATCH_SIZE = 1
dummy_input = torch.randn((BATCH_SIZE, 3, 1024, 1024))

# NOTE: this rebinds `model` to a freshly initialised network (the checkpoint
# loaded earlier is not applied). Module.to() moves the module in place, so
# `model` and `model_cuda` refer to the same GPU-resident object.
model = m2unet_q2()
model.to("cuda")
model_cuda = model.eval()


In [54]:
# Copy the dummy input to the GPU so it lives on the same device as model_cuda.
test_input = dummy_input.to(torch.device("cuda"))

In [55]:
# Single forward pass without gradient tracking; the result is copied back to
# the host and converted to a NumPy array.
with torch.no_grad():
    predictions = np.array(model_cuda(test_input).to("cpu"))

predictions.shape

(1, 1, 1024, 1024)

In [56]:
%%timeit

# The timed statement covers the forward pass plus the .cpu() copy and the
# NumPy conversion, so the reported time includes the device-to-host transfer.
with torch.no_grad():
    predictions = np.array(model_cuda(test_input).cpu())

12.1 ms ± 277 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [57]:


import copy

# nn.Module.half() converts a module in place. Calling it on model_cuda directly
# would also turn `model_cuda` (and `model`, the same object) into FP16, so the
# FP32 cells above would fail on a re-run. Cast an independent copy instead.
model_cuda_half = copy.deepcopy(model_cuda).half()
input_half = test_input.half()

with torch.no_grad():
    preds = np.array(model_cuda_half(input_half).cpu()) # Warm Up

    

In [58]:
%%timeit

# Same measurement as the earlier timing cell, but with the half-precision
# model and input; the .cpu() copy and NumPy conversion are inside the timed statement.
with torch.no_grad():
    preds = np.array(model_cuda_half(input_half).cpu())



10.6 ms ± 69 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
import os

# Terminate this kernel's process immediately so TRT doesn't fight with PyTorch
# for GPU memory. os._exit() skips normal interpreter cleanup, and every name
# defined above (imports, model classes, variables) is gone once the kernel
# restarts - cells below must re-import / re-define what they need.
os._exit(0) # exits the current kernel process only

In [3]:

# Settings for the TensorRT section.
BATCH_SIZE = 32

import numpy as np

# Precision switch: selects the NumPy dtype stored in target_dtype, and the
# trtexec flags used further down.
USE_FP16 = True
if USE_FP16:
    target_dtype = np.float16
else:
    target_dtype = np.float32

In [4]:
from skimage import io
from skimage.transform import resize
from matplotlib import pyplot as plt
import numpy as np

# Download one sample photo, resize it to 224x224 and repeat it BATCH_SIZE
# times along a new leading axis to form a float32 batch.
url='https://images.dog.ceo/breeds/retriever-golden/n02099601_3004.jpg'
img = resize(io.imread(url), (224, 224))
input_batch = np.repeat(
    np.expand_dims(np.array(img, dtype=np.float32), axis=0),
    BATCH_SIZE,
    axis=0,
)

input_batch.shape

(32, 224, 224, 3)

In [5]:
import torch
from torchvision.transforms import Normalize

def preprocess_image(img, dtype=np.float16):
    """Normalise one image per channel and return it channel-first.

    Args:
        img: NumPy array with the channel axis last (H, W, C); it is
            rearranged to (C, H, W) before normalisation.
        dtype: NumPy dtype of the returned array. Defaults to ``np.float16``,
            which is what this function always returned before the parameter
            existed.

    Returns:
        The normalised image as a NumPy array of ``dtype``.
    """
    # Mean/std values that torchvision's ImageNet-pretrained models expect.
    norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # (H, W, C) -> (C, W, H) -> (C, H, W)
    result = norm(torch.from_numpy(img).transpose(0,2).transpose(1,2))
    return np.array(result, dtype=dtype)

# Use the dtype chosen by USE_FP16 so the inputs match the precision the engine
# is built with (previously float16 was used even when USE_FP16 was False).
preprocessed_images = np.array([preprocess_image(image, dtype=target_dtype) for image in input_batch])

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
import tensorrt

In [8]:
# step out of Python for a moment to convert the ONNX model to a TRT engine using trtexec
# With USE_FP16 the engine is built with --fp16 and FP16 CHW input/output
# formats; otherwise no precision or I/O format flags are passed.
if USE_FP16:
    !trtexec --onnx=resnet50_pytorch.onnx --saveEngine=resnet_engine_pytorch.trt  --explicitBatch --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --fp16
else:
    !trtexec --onnx=resnet50_pytorch.onnx --saveEngine=resnet_engine_pytorch.trt  --explicitBatch

&&&& RUNNING TensorRT.trtexec [TensorRT v8003] # trtexec --onnx=resnet50_pytorch.onnx --saveEngine=resnet_engine_pytorch.trt --explicitBatch --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --fp16
[12/15/2022-10:05:02] [I] === Model Options ===
[12/15/2022-10:05:02] [I] Format: ONNX
[12/15/2022-10:05:02] [I] Model: resnet50_pytorch.onnx
[12/15/2022-10:05:02] [I] Output:
[12/15/2022-10:05:02] [I] === Build Options ===
[12/15/2022-10:05:02] [I] Max batch: explicit
[12/15/2022-10:05:02] [I] Workspace: 16 MiB
[12/15/2022-10:05:02] [I] minTiming: 1
[12/15/2022-10:05:02] [I] avgTiming: 8
[12/15/2022-10:05:02] [I] Precision: FP32+FP16
[12/15/2022-10:05:02] [I] Calibration: 
[12/15/2022-10:05:02] [I] Refit: Disabled
[12/15/2022-10:05:02] [I] Sparsity: Disabled
[12/15/2022-10:05:02] [I] Safe mode: Disabled
[12/15/2022-10:05:02] [I] Restricted mode: Disabled
[12/15/2022-10:05:02] [I] Save engine: resnet_engine_pytorch.trt
[12/15/2022-10:05:02] [I] Load engine: 
[12/15/2022-10:05:02] [I] NVTX

In [13]:
# Load model and convert to onnx
# NOTE: the kernel was terminated above, so the model-definition cell must be
# re-run before this one for m2unet_q2 to exist.

# Load model
model = m2unet_q2()
model_name = 'm2unet_pre_quant_12_22'
PATH = './models/' + model_name

# map_location="cpu" lets the checkpoint be read even if it was saved from GPU
# tensors, and keeps PyTorch from allocating GPU memory here (the GPU is meant
# to be left to TensorRT in this part of the notebook).
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(PATH, map_location="cpu"))
model.eval()

M2UNet_q2(
  (quant): QuantStub()
  (conv1): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): InvertedResidual(
      (add): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (conv2): Sequential(
    (2): InvertedResidual(
      (add): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv): Sequentia

In [14]:
# Random batch used as the example input for the ONNX export below.
BATCH_SIZE = 32

dummy_input = torch.randn((BATCH_SIZE, 3, 224, 224))

In [21]:
# Export the model to ONNX (opset 12), using dummy_input as the example input.
# verbose=True prints a description of the exported graph.
torch.onnx.export(
    model,
    dummy_input,
    "m2unet.onnx",
    verbose=True,
    opset_version=12,
)

Exported graph: graph(%input.1 : Float(32, 3, 224, 224, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu),
      %decode1.conv.0.weight : Float(1, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=1, device=cpu),
      %decode1.conv.0.bias : Float(1, strides=[1], requires_grad=1, device=cpu),
      %onnx::Conv_479 : Float(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu),
      %onnx::Conv_480 : Float(32, strides=[1], requires_grad=0, device=cpu),
      %onnx::Conv_482 : Float(32, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=0, device=cpu),
      %onnx::Conv_483 : Float(32, strides=[1], requires_grad=0, device=cpu),
      %onnx::Conv_485 : Float(16, 32, 1, 1, strides=[32, 1, 1, 1], requires_grad=0, device=cpu),
      %onnx::Conv_486 : Float(16, strides=[1], requires_grad=0, device=cpu),
      %onnx::Conv_488 : Float(96, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=0, device=cpu),
      %onnx::Conv_489 : Float(96, strides=[1], requires_grad=0, device=cpu),
    

In [22]:
# Build a TensorRT engine from m2unet.onnx with trtexec and save it to disk.
# NOTE(review): unlike the ResNet-50 conversion above, this command does not
# branch on USE_FP16 and passes no --fp16 flag - confirm that an FP32 engine
# is what is wanted here.
!trtexec --onnx=m2unet.onnx --saveEngine=m2unet_engine_pytorch.trt  --explicitBatch

&&&& RUNNING TensorRT.trtexec [TensorRT v8003] # trtexec --onnx=m2unet.onnx --saveEngine=m2unet_engine_pytorch.trt --explicitBatch
[12/15/2022-11:30:39] [I] === Model Options ===
[12/15/2022-11:30:39] [I] Format: ONNX
[12/15/2022-11:30:39] [I] Model: m2unet.onnx
[12/15/2022-11:30:39] [I] Output:
[12/15/2022-11:30:39] [I] === Build Options ===
[12/15/2022-11:30:39] [I] Max batch: explicit
[12/15/2022-11:30:39] [I] Workspace: 16 MiB
[12/15/2022-11:30:39] [I] minTiming: 1
[12/15/2022-11:30:39] [I] avgTiming: 8
[12/15/2022-11:30:39] [I] Precision: FP32
[12/15/2022-11:30:39] [I] Calibration: 
[12/15/2022-11:30:39] [I] Refit: Disabled
[12/15/2022-11:30:39] [I] Sparsity: Disabled
[12/15/2022-11:30:39] [I] Safe mode: Disabled
[12/15/2022-11:30:39] [I] Restricted mode: Disabled
[12/15/2022-11:30:39] [I] Save engine: m2unet_engine_pytorch.trt
[12/15/2022-11:30:39] [I] Load engine: 
[12/15/2022-11:30:39] [I] NVTX verbosity: 0
[12/15/2022-11:30:39] [I] Tactic sources: Using default tactic sources
