In [1]:
from pcdet.config import cfg, cfg_from_yaml_file
from pcdet.models import build_network
from pcdet.datasets import build_dataloader
from pcdet.utils import common_utils
import os
import numpy as np
import torch
import onnx
import onnxruntime as ort
import torch.nn as nn

from typing import Sequence, NamedTuple


####### load model #######
# cfg_file = "./cfgs/dsvt_models/dsvt_plain_1f_onestage.yaml"
# cfg_from_yaml_file(cfg_file, cfg)
# if os.path.exists('./deploy_files')==False:
#     os.mkdir('./deploy_files')
# log_file = './deploy_files/log_trt.log'

cfg_file = "./pillar/config.yaml"
cfg_from_yaml_file(cfg_file, cfg)
if os.path.exists('./deploy_pillar_sfaw_3d_origin')==False:
    os.mkdir('./deploy_pillar_sfaw_3d_origin')
log_file = './deploy_pillar_sfaw_3d_origin/log_trt.log'

logger = common_utils.create_logger(log_file, rank=0)
test_set, test_loader, sampler = build_dataloader(
    dataset_cfg=cfg.DATA_CONFIG,
    class_names=cfg.CLASS_NAMES,
    batch_size=1,
    dist=False, workers=8, logger=logger, training=False
)

model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set)
ckpt = "./pillar/model.pth"

model.load_params_from_file(filename=ckpt, logger=logger, to_cpu=False, pre_trained_path=None)
model.eval()
model.cuda()
####### load model #######

####### read input #######
batch_dict = torch.load("/mnt/nas2/users/eslim/onnx/sfaw/input_dict.pth", map_location="cuda")
inputs = batch_dict
####### read input #######

####### DSVT #######
class AllDSVTBlocksTRT(nn.Module):
    def __init__(self, dsvtblocks_list, layer_norms_list):
        super().__init__()
        self.layer_norms_list = layer_norms_list
        self.dsvtblocks_list = dsvtblocks_list
    def forward(
        self,
        pillar_features, 
        set_voxel_inds_tensor_shift_0,
        set_voxel_inds_tensor_shift_1,
        set_voxel_masks_tensor_shift_0, 
        set_voxel_masks_tensor_shift_1,
        pos_embed_tensor,
    ):
        outputs = pillar_features

        residual = outputs
        blc_id = 0
        set_id = 0
        set_voxel_inds = set_voxel_inds_tensor_shift_0[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_0[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        set_id = 1
        set_voxel_inds = set_voxel_inds_tensor_shift_0[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_0[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        
        outputs = self.layer_norms_list[blc_id](residual + outputs)

        residual = outputs
        blc_id = 1
        set_id = 0
        set_voxel_inds = set_voxel_inds_tensor_shift_1[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_1[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        set_id = 1
        set_voxel_inds = set_voxel_inds_tensor_shift_1[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_1[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        
        outputs = self.layer_norms_list[blc_id](residual + outputs)

        residual = outputs
        blc_id = 2
        set_id = 0
        set_voxel_inds = set_voxel_inds_tensor_shift_0[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_0[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        set_id = 1
        set_voxel_inds = set_voxel_inds_tensor_shift_0[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_0[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        
        outputs = self.layer_norms_list[blc_id](residual + outputs)

        residual = outputs
        blc_id = 3
        set_id = 0
        set_voxel_inds = set_voxel_inds_tensor_shift_1[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_1[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        set_id = 1
        set_voxel_inds = set_voxel_inds_tensor_shift_1[set_id:set_id+1].squeeze(0)
        set_voxel_masks = set_voxel_masks_tensor_shift_1[set_id:set_id+1].squeeze(0)
        pos_embed = pos_embed_tensor[blc_id:blc_id+1, set_id:set_id+1].squeeze(0).squeeze(0)
        inputs = (outputs, set_voxel_inds, set_voxel_masks, pos_embed, True)
        outputs = self.dsvtblocks_list[blc_id].encoder_list[set_id](*inputs)
        
        outputs = self.layer_norms_list[blc_id](residual + outputs)

        return outputs
####### DSVT #######

####### torch to onnx #######
with torch.no_grad():
    DSVT_Backbone = model.backbone_3d
    dsvtblocks_list = DSVT_Backbone.stage_0
    layer_norms_list = DSVT_Backbone.residual_norm_stage_0
    inputs = model.vfe(inputs)
    voxel_info = DSVT_Backbone.input_layer(inputs)
    set_voxel_inds_list = [[voxel_info[f'set_voxel_inds_stage{s}_shift{i}'] for i in range(2)] for s in range(1)]
    set_voxel_masks_list = [[voxel_info[f'set_voxel_mask_stage{s}_shift{i}'] for i in range(2)] for s in range(1)]
    pos_embed_list = [[[voxel_info[f'pos_embed_stage{s}_block{b}_shift{i}'] for i in range(2)] for b in range(4)] for s in range(1)]

    pillar_features = inputs['voxel_features']
    alldsvtblockstrt_inputs = (
        pillar_features,
        set_voxel_inds_list[0][0],
        set_voxel_inds_list[0][1],
        set_voxel_masks_list[0][0],
        set_voxel_masks_list[0][1],
        torch.stack([torch.stack(v, dim=0) for v in pos_embed_list[0]], dim=0),
    )

    jit_mode = "trace"
    input_names = [
        'src',
        'set_voxel_inds_tensor_shift_0', 
        'set_voxel_inds_tensor_shift_1', 
        'set_voxel_masks_tensor_shift_0', 
        'set_voxel_masks_tensor_shift_1',
        'pos_embed_tensor'
    ]
    output_names = ["output",]
    input_shapes = {
        "src": {
            "min_shape": [24629, 192],
            "opt_shape": [24629, 192],
            "max_shape": [24629, 192],
        },
        "set_voxel_inds_tensor_shift_0": {
            "min_shape": [2, 1156, 36],
            "opt_shape": [2, 1156, 36],
            "max_shape": [2, 1156, 36],
        },
        "set_voxel_inds_tensor_shift_1": {
            "min_shape": [2, 834, 36],
            "opt_shape": [2, 834, 36],
            "max_shape": [2, 834, 36],
        },
        "set_voxel_masks_tensor_shift_0": {
            "min_shape": [2, 1156, 36],
            "opt_shape": [2, 1156, 36],
            "max_shape": [2, 1156, 36],
        },
        "set_voxel_masks_tensor_shift_1": {
            "min_shape": [2, 834, 36],
            "opt_shape": [2, 834, 36],
            "max_shape": [2, 834, 36],
        },
        "pos_embed_tensor": {
            "min_shape": [4, 2, 24629, 192],
            "opt_shape": [4, 2, 24629, 192],
            "max_shape": [4, 2, 24629, 192],
        },
    }


    dynamic_axes = {
        "src": {
            0: "voxel_number",
        },
        "set_voxel_inds_tensor_shift_0": {
            1: "set_number_shift_0",
        },
        "set_voxel_inds_tensor_shift_1": {
            1: "set_number_shift_1",
        },
        "set_voxel_masks_tensor_shift_0": {
            1: "set_number_shift_0",
        },
        "set_voxel_masks_tensor_shift_1": {
            1: "set_number_shift_1",
        },
        "pos_embed_tensor": {
            2: "voxel_number",
        },
        "output": {
            0: "voxel_number",
        }
    }

    base_name = "./deploy_pillar_sfaw_3d_origin/deploy_pillar_sfaw_3d_origin"
    ts_path = f"{base_name}.ts"
    onnx_path = f"{base_name}.onnx"

    allptransblocktrt = AllDSVTBlocksTRT(dsvtblocks_list, layer_norms_list).eval().cuda()
    torch.onnx.export(
        allptransblocktrt,
        alldsvtblockstrt_inputs,
        onnx_path, input_names=input_names,
        output_names=output_names, dynamic_axes=dynamic_axes,
        opset_version=14,
    )
    # test onnx
    ort_session = ort.InferenceSession(onnx_path)
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    
    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(pillar_features),
                  ort_session.get_inputs()[1].name: to_numpy(set_voxel_inds_list[0][0]),
                  ort_session.get_inputs()[2].name: to_numpy(set_voxel_inds_list[0][1]),
                  ort_session.get_inputs()[3].name: to_numpy(set_voxel_masks_list[0][0]),
                  ort_session.get_inputs()[4].name: to_numpy(set_voxel_masks_list[0][1]),
                  ort_session.get_inputs()[5].name: to_numpy(torch.stack([torch.stack(v, dim=0) for v in pos_embed_list[0]], dim=0)),}
    import time
    from onnxconverter_common import float16
    
    start_time = time.time()
    ort_outs = ort_session.run(None, ort_inputs) 
    print(time.time()-start_time)
    base_model = onnx.load(onnx_path)
    model_fp16 = float16.convert_float_to_float16(base_model)
    onnx.save(model_fp16, "./deploy_pillar_sfaw_bm/dsvt_3d_onnx_16.onnx")
    onnx_path = "./deploy_pillar_sfaw_bm/dsvt_3d_onnx_16.onnx"
    print(ort.get_device())
    ort_session = ort.InferenceSession(onnx_path)
    
    
    # # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(pillar_features).astype(np.float16),
                  ort_session.get_inputs()[1].name: to_numpy(set_voxel_inds_list[0][0]).astype(np.int64),
                  ort_session.get_inputs()[2].name: to_numpy(set_voxel_inds_list[0][1]).astype(np.int64),
                  ort_session.get_inputs()[3].name: to_numpy(set_voxel_masks_list[0][0]),
                  ort_session.get_inputs()[4].name: to_numpy(set_voxel_masks_list[0][1]),
                  ort_session.get_inputs()[5].name: to_numpy(torch.stack([torch.stack(v, dim=0) for v in pos_embed_list[0]], dim=0)).astype(np.float16),}
    
    start_time = time.time()
    ort_outs = ort_session.run(None, ort_inputs) 
    print(time.time()-start_time)
    # model_fp16 = float16.convert_float_to_float16(ort_session)
####### torch to onnx #######
ort_outs

####### torch to trt engine #######
# trtexec --onnx={path to onnx} --saveEngine={path to save trtengine} \
# --memPoolSize=workspace:4096 --verbose --buildOnly --device=1 --fp16 \
# --tacticSources=+CUDNN,+CUBLAS,-CUBLAS_LT,+EDGE_MASK_CONVOLUTIONS \
# --minShapes=src:3000x192,set_voxel_inds_tensor_shift_0:2x170x36,set_voxel_inds_tensor_shift_1:2x100x36,set_voxel_masks_tensor_shift_0:2x170x36,set_voxel_masks_tensor_shift_1:2x100x36,pos_embed_tensor:4x2x3000x192 \
# --optShapes=src:20000x192,set_voxel_inds_tensor_shift_0:2x1000x36,set_voxel_inds_tensor_shift_1:2x700x36,set_voxel_masks_tensor_shift_0:2x1000x36,set_voxel_masks_tensor_shift_1:2x700x36,pos_embed_tensor:4x2x20000x192 \
# --maxShapes=src:35000x192,set_voxel_inds_tensor_shift_0:2x1500x36,set_voxel_inds_tensor_shift_1:2x1200x36,set_voxel_masks_tensor_shift_0:2x1500x36,set_voxel_masks_tensor_shift_1:2x1200x36,pos_embed_tensor:4x2x35000x192 \
####### torch to trt engine #######

fatal: detected dubious ownership in repository at '/mnt/nas2/users/eslim/workspace/OpenPCDet'
To add an exception for this directory, call:

	git config --global --add safe.directory /mnt/nas2/users/eslim/workspace/OpenPCDet
2024-04-16 13:20:19,989   INFO  Loading Custom dataset.
2024-04-16 13:20:19,991   INFO  Total samples for CUSTOM dataset: 0
2024-04-16 13:20:22,342   INFO  ==> Loading parameters from checkpoint ./pillar/model.pth to GPU
2024-04-16 13:20:22,493   INFO  ==> Checkpoint trained from version: pcdet+0.6.0+255db8f
2024-04-16 13:20:22,726   INFO  ==> Done (loaded 391/391)


verbose: False, log level: Level.ERROR

0.05576634407043457




CPU
0.23626136779785156


[array([[-1.098e+00, -8.893e-02,  1.570e-01, ..., -1.048e+00,  3.809e-02,
         -1.986e-01],
        [-1.097e+00, -7.098e-02,  1.605e-01, ..., -1.083e+00,  4.245e-02,
         -1.924e-01],
        [-1.087e+00, -1.011e-01,  1.786e-01, ..., -1.056e+00,  2.637e-02,
         -1.595e-01],
        ...,
        [-3.608e-01,  2.130e-01, -1.586e-01, ..., -1.233e+00, -1.638e-01,
          1.677e-01],
        [-6.187e-01,  3.523e-01, -1.115e-01, ..., -1.372e+00, -1.095e-01,
         -5.441e-02],
        [-6.133e-01,  3.215e-01, -1.253e-02, ..., -1.245e+00,  6.313e-04,
         -1.378e-01]], dtype=float16)]

In [7]:
# to_numpy(torch.stack([torch.stack(v, dim=0) for v in pos_embed_list[0]], dim=0))

ModuleNotFoundError: No module named 'onnxruntime'

In [4]:
pillar_features.size()

torch.Size([962, 192])

In [6]:
set_voxel_inds_list[0][0].size()

torch.Size([2, 40, 36])

In [7]:
set_voxel_inds_list[0][1].size()

torch.Size([2, 34, 36])

In [10]:
print("ONNX version:", onnx.__version__)


ONNX version: 1.16.0


In [11]:
print("PyTorch version:", torch.__version__)


PyTorch version: 2.0.1+cu118
