convert the model to onnx and torchscript #77

Open
sevocrear opened this issue Oct 6, 2021 · 21 comments
@sevocrear

sevocrear commented Oct 6, 2021

Hello, I'm having difficulties converting your model to both ONNX and TorchScript. I've already read the closed issues, but none of them helped. Could you help me? Or maybe someone has succeeded in doing this. Below I'll show the code and the errors I get when I try to convert the model.

2nd_model.pt is the full model saved with PyTorch. It works fine after loading it in Python.

Converting the model to ONNX

model = torch.load("2nd_model.pt").to(device)
model.eval()
torch.onnx.export(
    model,
    sample,
    "model.onnx",
    opset_version=11,
    export_params=True)

Error:

/content/Regression-model/lib/models/laneatt.py:230: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  n_proposals = len(self.anchors)
/content/Regression-model/lib/models/laneatt.py:235: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for batch_idx, img_features in enumerate(features):
/content/Regression-model/lib/models/laneatt.py:82: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  attention = softmax(scores).reshape(x.shape[0], len(self.anchors), -1)
/content/Regression-model/lib/models/laneatt.py:83: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  attention_matrix = torch.eye(attention.shape[1], device=x.device).repeat(x.shape[0], 1, 1)
/content/Regression-model/lib/models/laneatt.py:87: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  batch_anchor_features = batch_anchor_features.reshape(x.shape[0], len(self.anchors), -1)
/content/Regression-model/lib/models/laneatt.py:116: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for proposals, attention_matrix in zip(batch_proposals, batch_attention_matrix):
/content/Regression-model/lib/models/laneatt.py:127: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if proposals.shape[0] == 0:
/usr/local/lib/python3.7/dist-packages/torch/onnx/symbolic_opset9.py:2766: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  "If indices include negative values, the exported graph will produce incorrect results.")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-d8ff0aee41a0> in <module>()
      6     "model.onnx",
      7     opset_version=11,
----> 8     export_params=True)

8 frames
/usr/local/lib/python3.7/dist-packages/torch/onnx/utils.py in _graph_op(g, opname, *raw_args, **kwargs)
    888     if _onnx_shape_inference:
    889         from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
--> 890         torch._C._jit_pass_onnx_node_shape_type_inference(n, _params_dict, opset_version)
    891 
    892     if outputs == 1:

RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp":520, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.

Error 2 (different):

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [110,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [111,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [112,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [113,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [114,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [115,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [116,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [117,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [118,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [119,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [120,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [121,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [122,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [123,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [124,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [125,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [126,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:84: operator(): block: [1,0,0], thread: [127,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/symbolic_opset9.py:2225: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  "If indices include negative values, the exported graph will produce incorrect results.")
Traceback (most recent call last):
  File "converting_model.py", line 29, in <module>
    export_params=True)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/__init__.py", line 208, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 530, in _export
    fixed_batch_size=fixed_batch_size)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 384, in _model_to_graph
    fixed_batch_size=fixed_batch_size, params_dict=params_dict)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 188, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/__init__.py", line 241, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 787, in _run_symbolic_function
    symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 745, in _find_symbolic_in_registry
    return sym_registry.get_registered_op(op_name, domain, opset_version)
  File "/home/sevocrear/.local/lib/python3.6/site-packages/torch/onnx/symbolic_registry.py", line 109, in get_registered_op
    raise RuntimeError(msg)
RuntimeError: Exporting the operator eye to ONNX opset version 11 is not supported. Please open a bug to request ONNX export support for the missing operator.
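
For reference, a common workaround for this (a sketch, not code from this repo) is to build the identity matrix out of operators that opset 11 does support, instead of calling torch.eye:

import torch

def onnx_friendly_eye(n, device):
    # behaves like torch.eye(n, device=device), but uses only ops that
    # export cleanly to ONNX opset 11 (Range, Unsqueeze, Equal, Cast)
    idx = torch.arange(n, device=device)
    return (idx.unsqueeze(0) == idx.unsqueeze(1)).float()

# hypothetical drop-in for the failing line in laneatt.py:
# attention_matrix = onnx_friendly_eye(attention.shape[1], x.device).repeat(x.shape[0], 1, 1)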

Scripting the model

model = torch.load("2nd_model.pt").to(device)
model.eval()
traced_script_module = torch.jit.script(model)

Error:

---------------------------------------------------------------------------
FrontendError                             Traceback (most recent call last)
<ipython-input-17-59a3d0112d6d> in <module>()
      1 model = torch.load("2nd_model.pt").to(device)
      2 model.eval()
----> 3 traced_script_module = torch.jit.script(model)

4 frames
/usr/local/lib/python3.7/dist-packages/torch/jit/annotations.py in check_fn(fn, loc)
    132     if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
    133         raise torch.jit.frontend.FrontendError(
--> 134             loc, f"Cannot instantiate class '{py_ast.body[0].name}' in a script function")
    135     if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
    136         raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")

FrontendError: Cannot instantiate class 'Softmax' in a script function:
  File "/content/Regression-model/lib/models/laneatt.py", line 80
    
        # Add attention features
        softmax = nn.Softmax(dim=1)
                  ~~~~~~~~~~ <--- HERE
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores).reshape(x.shape[0], len(self.anchors), -1)
  • Here is also the Colab link for the code above: link
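
For reference, TorchScript cannot construct nn.Module objects inside a method it is scripting; a minimal sketch of the usual workaround (an assumption, not verified against this repo) is to call the functional API instead:

import torch.nn.functional as F

# instead of building the module inside forward():
#     softmax = nn.Softmax(dim=1)
#     attention = softmax(scores).reshape(...)
# use the functional form, which torch.jit.script can handle:
attention = F.softmax(scores, dim=1).reshape(x.shape[0], len(self.anchors), -1)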
@sevocrear
Author

Good afternoon, could anyone help me? I would really like to know how I can optimize this model. I'd like to run it on a Jetson AGX, so optimization is needed for faster inference.

@Lllllp93

Maybe you should convert the NMS function to plain Python code first. I was able to convert the model to ONNX after rewriting that function.

My problem is that I get this error when I use onnxruntime to run inference on the ONNX model:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (Reshape_914) Op (Reshape) [ShapeInferenceError] Invalid position of 0

Also, I'm new to ONNX, so I have no idea how to fix this.

@sevocrear
Author

sevocrear commented Oct 22, 2021

@Lllllp93, could you please share the code for how you converted the NMS function and how you converted the model to ONNX?

I've also found a website about converting complex models. It might be useful: https://www.fatalerrors.org/a/pytoch-to-tensorrt-transformation-of-complex-models.html

@Lllllp93

> @Lllllp93, could you please share the code for how you converted the NMS function and how you converted the model to ONNX?
>
> I've also found a website about converting complex models. It might be useful: https://www.fatalerrors.org/a/pytoch-to-tensorrt-transformation-of-complex-models.html

That function basically follows @lucastabelini's CUDA source code. I checked the output, and the accuracy seems to match.

def Lane_nms(self, proposals, scores, overlap=50, top_k=4):
    '''
    re-implementation of lane NMS in plain Python
    reference: https://arxiv.org/pdf/2010.12035.pdf
    '''
    keep_index = []
    sorted_score, indices = torch.sort(scores, descending=True)  # from largest to smallest
    r_filters = np.zeros(len(scores))

    for i, indice in enumerate(indices):
        if r_filters[i] == 1:  # skip if this proposal was already filtered out by NMS
            continue
        keep_index.append(indice)
        if len(keep_index) > top_k:  # break if we already have more than top_k
            break
        if i == (len(scores) - 1):  # break if indice is the last one
            break
        sub_indices = indices[i + 1:]
        for sub_i, sub_indice in enumerate(sub_indices):
            r_filter = self.Lane_IOU(proposals[indice, :], proposals[sub_indice, :], overlap)
            if r_filter: r_filters[i + 1 + sub_i] = 1
    num_to_keep = len(keep_index)

def Lane_IOU(self, parent_box, compared_box, threshold):
    '''
    calculate the distance for one pair of proposal lines;
    return True if the distance is less than threshold
    '''
    start_a = (parent_box[2] * self.n_strips + 0.5).int()  # the +0.5 trick makes int() behave like round()
    start_b = (compared_box[2] * self.n_strips + 0.5).int()
    start = torch.max(start_a, start_b)
    end_a = start_a + parent_box[4] - 1 + 0.5 - (((parent_box[4] - 1) < 0).int())
    end_b = start_b + compared_box[4] - 1 + 0.5 - (((compared_box[4] - 1) < 0).int())
    end = torch.min(torch.min(end_a, end_b), torch.tensor(self.n_offsets - 1))
    if (end - start) < 0:
        return False
    dist = 0
    for i in range(5 + start, 77):
        if i > (5 + end):
            break
        if parent_box[i] < compared_box[i]:
            dist += compared_box[i] - parent_box[i]
        else:
            dist += parent_box[i] - compared_box[i]
    return dist < (threshold * (end - start + 1))

Note: this implementation is slower than the original implementation.

@sevocrear
Author

sevocrear commented Oct 25, 2021

@Lllllp93, thanks for the code! But I noticed that one of your functions didn't have a return statement, so I added it.
Finally, I was able to convert the model to ONNX, and it runs with onnxruntime. I use:

  • Python 3.8
  • torch 1.8.2
  • CUDA 10.2

Here is the relevant part of the laneatt.py file:

    def forward(self, x, conf_threshold=0.2, nms_thres=50.0, nms_topk=5):
        print(nms_thres, conf_threshold)
        batch_features = self.feature_extractor(x)

        

        batch_features = self.conv1(batch_features)
        batch_anchor_features = self.cut_anchor_features(batch_features)
        # Join proposals from all images into a single proposals features batch
        batch_anchor_features = batch_anchor_features.view(-1, self.anchor_feat_channels * self.fmap_h)


        # Add attention features
        scores = self.attention_layer(batch_anchor_features)
        attention = self.softmax(scores).reshape(x.shape[0], len(self.anchors), -1)
        # attention_matrix = torch.diag(torch.ones(attention.shape[1], device=x.device)).repeat(x.shape[0], 1, 1)
        attention_matrix = torch.eye(attention.shape[1], device=x.device).repeat(x.shape[0], 1, 1)
        non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
        attention_matrix[:] = 0
        attention_matrix[non_diag_inds[:, 0], non_diag_inds[:, 1], non_diag_inds[:, 2]] = attention.flatten()
        batch_anchor_features = batch_anchor_features.reshape(x.shape[0], len(self.anchors), -1)
        attention_features = torch.bmm(torch.transpose(batch_anchor_features, 1, 2),
                                       torch.transpose(attention_matrix, 1, 2)).transpose(1, 2)
        attention_features = attention_features.reshape(-1, self.anchor_feat_channels * self.fmap_h)
        batch_anchor_features = batch_anchor_features.reshape(-1, self.anchor_feat_channels * self.fmap_h)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=1)



        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)

        # Undo joining
        cls_logits = cls_logits.reshape(x.shape[0], -1, cls_logits.shape[1])
        reg = reg.reshape(x.shape[0], -1, reg.shape[1])

        # Add offsets to anchors
        reg_proposals = torch.zeros((*cls_logits.shape[:2], 5 + self.n_offsets), device=x.device)
        reg_proposals += self.anchors
        reg_proposals[:, :, :2] = cls_logits
        reg_proposals[:, :, 4:] += reg



        # Apply nms 
        proposals_list = self.nms(reg_proposals, attention_matrix, nms_thres, nms_topk, conf_threshold)
        return proposals_list

    def nms(self, batch_proposals, batch_attention_matrix, nms_thres, nms_topk, conf_threshold):
        softmax = nn.Softmax(dim=1)
        proposals_list = []
        for proposals, attention_matrix in zip(batch_proposals, batch_attention_matrix):
            anchor_inds = torch.arange(batch_proposals.shape[1], device=proposals.device)
            # The gradients do not have to (and can't) be calculated for the NMS procedure
            with torch.no_grad():
                scores = softmax(proposals[:, :2])[:, 1]
                if conf_threshold is not None:
                    # apply confidence threshold
                    above_threshold = scores > conf_threshold
                    proposals = proposals[above_threshold]
                    scores = scores[above_threshold]
                    anchor_inds = anchor_inds[above_threshold]
                if proposals.shape[0] == 0:
                    proposals_list.append((proposals[[]], self.anchors[[]], attention_matrix[[]], None))
                    continue
                # keep, num_to_keep, _ = nms(proposals, scores, overlap=nms_thres, top_k=nms_topk)
                keep, num_to_keep, _ = self.Lane_nms(proposals, scores, overlap=nms_thres, top_k=nms_topk)
                keep = keep[:num_to_keep]
            proposals = proposals[keep]
            anchor_inds = anchor_inds[keep]
            attention_matrix = attention_matrix[anchor_inds]
            proposals_list.append((proposals, self.anchors[keep], attention_matrix, anchor_inds))

        return proposals_list


    def Lane_nms(self,proposals,scores,overlap=50, top_k=4):
        keep_index = []
        print(scores)
        sorted_score, indices = torch.sort(scores, descending=True)  # from largest to smallest
        r_filters = np.zeros(len(scores))

        for i,indice in enumerate(indices):
            if r_filters[i]==1: # skip if this proposal was already filtered out by NMS
                continue
            keep_index.append(indice)
            if len(keep_index)>top_k: # break if more than top_k
                break
            if i == (len(scores)-1):# break if indice is the last one
                break
            sub_indices = indices[i+1:]
            for sub_i,sub_indice in enumerate(sub_indices):
                r_filter = self.Lane_IOU(proposals[indice,:],proposals[sub_indice,:],overlap)
                if r_filter: r_filters[i+1+sub_i]=1 
        num_to_keep = len(keep_index)
        keep_index = list(map(lambda x: x.item(), keep_index))
        return keep_index, num_to_keep, None
        
    def Lane_IOU(self,parent_box, compared_box, threshold):
        '''
        calculate the distance for one pair of proposal lines;
        return True if the distance is less than threshold
        '''
        start_a = (parent_box[2] * self.n_strips + 0.5).int()  # the +0.5 trick makes int() behave like round()
        start_b = (compared_box[2] * self.n_strips + 0.5).int()
        start = torch.max(start_a,start_b)
        end_a = start_a + parent_box[4] - 1 + 0.5 - (((parent_box[4] - 1)<0).int())
        end_b = start_b + compared_box[4] - 1 + 0.5 - (((compared_box[4] - 1)<0).int())
        end = torch.min(torch.min(end_a,end_b),torch.tensor(self.n_offsets-1))
        # end = torch.min(torch.min(end_a,end_b),torch.FloatTensor(self.n_offsets-1, device = torch.device('cpu')))
        if (end - start)<0:
            return False
        dist = 0
        for i in range(5+start,5 + end.int()):
            # if i>(5+end):
            #     break
            if parent_box[i] < compared_box[i]:
                dist += compared_box[i] - parent_box[i]
            else:
                dist += parent_box[i] - compared_box[i]
        return dist < (threshold * (end - start + 1))

    def loss(self, proposals_list, targets, cls_loss_weight=10):
        ...

Here is the part of the code showing how I export the model and run it with ONNX:

torch.onnx.export(model,               # model being run
                images,                         # model input (or a tuple for multiple inputs)
                "laneatt.onnx",   # where to save the model (can be a file or file-like object)
                export_params=True,        # store the trained parameter weights inside the model file
                opset_version=11)


onnx_model = onnx.load("laneatt.onnx")
onnx.checker.check_model(onnx_model)

ort_session = onnxruntime.InferenceSession("laneatt.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(images)}
ort_outs = ort_session.run(None, ort_inputs)
# compare ONNX Runtime and PyTorch results
onnx_output = [tuple(map(lambda x: torch.tensor(x), ort_outs))]
prediction = model.decode(onnx_output, as_lanes=True)

But the results are different from those of the original NMS implementation.

This is the visualization of the model output with the original NMS: img
This is the visualization of the model output with the custom NMS given by @Lllllp93: img
conf_thresh and nms_thresh are the same.

So now I'm investigating why the results are so different (though 3 of the lines are almost exactly the same). It would be great if you could point out the differences between your NMS implementation and @lucastabelini's. Also, it's 20 times slower than the original one (just as a note). Tested on my PC with an NVIDIA RTX 2080Ti 11 GB, Intel Core i9-9820X CPU @ 3.30GHz × 20, and 62.5 GiB RAM.

@Lllllp93

@sevocrear, the return should be:

    return torch.tensor(keep_index), num_to_keep

As for the difference between the two implementations, I guess maybe you visualized all of the lanes after NMS. Note that the keep index and num_to_keep in Lane_nms() can be used to limit the number of lanes; normally I just keep the top 4 lanes in my task.

As I mentioned before, this NMS implementation doesn't take advantage of the GPU, which means its running time grows linearly with the number of lanes. So you may want to switch back to the CUDA NMS for training. But for inference on a CPU or ARM, I guess there is no big difference; it depends on which hardware platform the network is deployed to.

In addition, I just found that the post-processing (decode) was included in the ONNX model, and I can run it with onnxruntime after removing the post-processing part.

The comparison of the two NMS implementations is shown below, FYI.

Original NMS:
torch_00540

My implementation:
onnx_00540

@sevocrear
Author

sevocrear commented Oct 26, 2021

> In addition, I just found that the post-processing (decode) was included in the ONNX model, and I can run it with onnxruntime after removing the post-processing part.

@Lllllp93 How did you do that? Running decode with onnxruntime?

@Lllllp93

> In addition, I just found that the post-processing (decode) was included in the ONNX model, and I can run it with onnxruntime after removing the post-processing part.
>
> @Lllllp93 How did you do that? Running decode with onnxruntime?

@sevocrear I mean that the reason I couldn't run it with onnxruntime before is that decode was also included in the ONNX model. I can run it with onnxruntime after I remove decode from the ONNX graph. Sorry for the misunderstanding.

@tersekmatija

Hey @sevocrear, how exactly did you manage to export the model to ONNX? I've tried removing NMS and exporting the model without it (I'd apply it after computing the output), but I keep getting this error:

RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp":520, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.

@sevocrear
Author

> Hey @sevocrear, how exactly did you manage to export the model to ONNX? I've tried removing NMS and exporting the model without it (I'd apply it after computing the output), but I keep getting this error:
>
> RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp":520, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.

Hi @tersekmatija, have you removed NMS entirely? In my implementation, I just changed the CUDA NMS to a Python NMS function (which, of course, is much slower). After that I could convert the model to ONNX. But the inference speed was awful, so I use the original Python code with some changes for my project. @Lllllp93 and I described the changed parts and requirements above.

@hwang12345

> Hey @sevocrear, how exactly did you manage to export the model to ONNX? I've tried removing NMS and exporting the model without it (I'd apply it after computing the output), but I keep getting this error:
>
> RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp":520, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.

Have you solved this problem?

@hwang12345

hwang12345 commented Feb 9, 2022

> [quotes @sevocrear's comment of Oct 25, 2021 in full; see above]

Hey @Lllllp93 @sevocrear, I used your code to try the conversion, but I got an error and I can't find the reason. Please help!

Traceback (most recent call last):                                              
  File "main.py", line 63, in <module>
    main()
  File "main.py", line 52, in main
    runner.train()
  File "/home/hw/project/LaneDetection/lane-att/lib/runner.py", line 59, in train
    outputs = model(images, **self.cfg.get_train_parameters())
  File "/home/hw/anaconda3/envs/laneatt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/hw/project/LaneDetection/lane-att/lib/models/laneatt.py", line 126, in forward
    proposals_list = self.nms(
  File "/home/hw/project/LaneDetection/lane-att/lib/models/laneatt.py", line 150, in nms
    keep, num_to_keep, _ = self.Lane_nms(
  File "/home/hw/project/LaneDetection/lane-att/lib/models/laneatt.py", line 180, in Lane_nms
    r_filter = self.Lane_IOU(proposals[indice, :], proposals[sub_indice, :], overlap)
  File "/home/hw/project/LaneDetection/lane-att/lib/models/laneatt.py", line 195, in Lane_IOU
    end = torch.min(torch.min(end_a, end_b), torch.tensor(self.n_offsets - 1))
RuntimeError: Expected object of scalar type float but got scalar type long int for argument 'other'

@hwang12345

Hey @lucastabelini, sorry to bother you. I'm having some difficulties converting the model to ONNX, so I have an idea: remove the NMS from the project, since my project only needs to detect one lane. Do you think this idea is OK?

@tersekmatija

> Hi @tersekmatija, have you removed NMS entirely? In my implementation, I just changed the CUDA NMS to a Python NMS function (which, of course, is much slower). After that I could convert the model to ONNX. But the inference speed was awful, so I use the original Python code with some changes for my project. @Lllllp93 and I described the changed parts and requirements above.

Yes @sevocrear. I just tried removing the NMS in the forward pass; this is the snippet:

        # Add offsets to anchors
        reg_proposals = torch.zeros((*cls_logits.shape[:2], 5 + self.n_offsets), device=x.device)
        reg_proposals += self.anchors
        reg_proposals[:, :, :2] = cls_logits
        reg_proposals[:, :, 4:] += reg
        # nms removed here
        return reg_proposals

But I am still getting the error. Do you know which versions of Python and Torch you used, in case it depends on that?

@lucastabelini
Owner

> Hey @lucastabelini, sorry to bother you. I'm having some difficulties converting the model to ONNX, so I have an idea: remove the NMS from the project, since my project only needs to detect one lane. Do you think this idea is OK?

If you only need to detect one lane boundary (i.e., one line), then you can do that. Just take the one with the highest confidence score.
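
For reference, a minimal sketch of that idea, assuming the forward pass was edited to return the raw reg_proposals tensor of shape (batch, n_anchors, 5 + n_offsets) as in the snippets above:

import torch.nn.functional as F

# reg_proposals: (batch, n_anchors, 5 + n_offsets), from the NMS-free forward pass
scores = F.softmax(reg_proposals[0, :, :2], dim=1)[:, 1]  # positive-class probability per anchor
best_lane = reg_proposals[0, scores.argmax()]             # single highest-confidence proposal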

@hwang12345

> Hey @lucastabelini, sorry to bother you. I'm having some difficulties converting the model to ONNX, so I have an idea: remove the NMS from the project, since my project only needs to detect one lane. Do you think this idea is OK?
>
> If you only need to detect one lane boundary (i.e., one line), then you can do that. Just take the one with the highest confidence score.

@lucastabelini OK, I will try that in my project, thanks!

@hwang12345

> [quotes @tersekmatija's comment above in full]
@tersekmatija If you try to convert the model to ONNX, then as far as I know just removing NMS is not enough: torch.eye is not supported yet.

@tersekmatija

@hwang12345 It seems that @sevocrear uses it in his forward method and he managed to export it.

@hwang12345

@tersekmatija I tried the code @sevocrear posted, but I did not succeed.

@Lllllp93

@hwang12345 It seems you are using torch.min() to compare tensors of different data types, as the error message shows; self.n_offsets in your code likely produces a long tensor. You can just convert it to float, e.g. torch.tensor(self.n_offsets - 1, dtype=torch.float32).
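
A minimal sketch of that fix inside Lane_IOU, assuming end_a and end_b are float tensors as in the code above:

# make the scalar bound match the dtype of end_a/end_b so torch.min()
# does not mix float and long tensors
end = torch.min(torch.min(end_a, end_b),
                torch.tensor(self.n_offsets - 1, dtype=end_a.dtype))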

@zpge

zpge commented May 15, 2024

> [quotes @sevocrear's comment of Oct 25, 2021 in full; see above]

Hello. Can you provide the complete code?
