Integrating the oneflow backend with torch compile to run faster rcnn #10438

Open
crazy-JiangDongHua opened this issue Feb 28, 2024 · 6 comments

Comments

@crazy-JiangDongHua
Contributor

Description

A log of the development and debugging process and the problems encountered along the way.

@crazy-JiangDongHua
Contributor Author

crazy-JiangDongHua commented Feb 28, 2024

The first major problem: oneflow failed to compile the computation graph passed in by the frontend.

[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] WON'T CONVERT forward /data/home/jiangdonghua/miniconda3/envs/oneflow-dev-gcc9/lib/python3.9/site-packages/torchvision/models/detection/backbone_utils.py line 56 
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] due to: 
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] Traceback (most recent call last):
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING]   File "/data/home/jiangdonghua/miniconda3/envs/oneflow-dev-gcc9/lib/python3.9/site-packages/onefx/graph.py", line 1176, in _target_to_str
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING]     assert isinstance(target, str)
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] torch._dynamo.exc.BackendCompilerFailed: backend='oneflow_backend' raised:
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] AssertionError: 
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] 
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING] 
[2024-02-29 03:15:49,828] torch._dynamo.convert_frame: [WARNING]

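Debugging showed that the cause is the way torch.nn.functional.max_pool2d is implemented: it goes through a boolean dispatch. boolean_dispatch returns a function fn defined inside its closure. If the argument return_indices is True, fn calls max_pool2d_with_indices; otherwise it calls _max_pool2d. As a result, when fx traces the graph, the call recorded for max_pool2d is a call to torch.nn.functional.boolean_dispatch.<locals>.fn.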

max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=6,
    default=False,
    if_true=max_pool2d_with_indices,
    if_false=_max_pool2d,
    module_name=__name__,
    func_name="max_pool2d",
)
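To illustrate why the recorded target carries the closure's name, here is a minimal pure-Python sketch of the dispatch pattern. It is a simplified stand-in, not the actual torch source; the two pool functions and the module name `demo.functional` are hypothetical. The point it shows is that renaming the returned closure through `__name__` leaves `__qualname__` unchanged, so anything that builds a target string from `__qualname__` ends up with `boolean_dispatch.<locals>.fn`.

```python
def boolean_dispatch(arg_name, arg_index, default, if_true, if_false,
                     module_name, func_name):
    """Simplified stand-in for the dispatch helper (not the real torch code)."""
    def fn(*args, **kwargs):
        # Read the flag from kwargs first, then from its positional slot.
        flag = default
        if arg_name in kwargs:
            flag = kwargs[arg_name]
        elif arg_index < len(args):
            flag = args[arg_index]
        if flag:
            return if_true(*args, **kwargs)
        return if_false(*args, **kwargs)

    # Make the closure look like a normal module-level function.
    # This changes __module__ and __name__ but not __qualname__.
    fn.__module__ = module_name
    fn.__name__ = func_name
    return fn


# Hypothetical implementations, standing in for the two real pool functions.
def pool_with_indices(x, return_indices=False):
    return ("values", "indices")


def pool_values_only(x, return_indices=False):
    return "values"


max_pool2d = boolean_dispatch(
    arg_name="return_indices",
    arg_index=1,
    default=False,
    if_true=pool_with_indices,
    if_false=pool_values_only,
    module_name="demo.functional",  # hypothetical module name
    func_name="max_pool2d",
)

print(max_pool2d.__name__)      # max_pool2d
print(max_pool2d.__qualname__)  # boolean_dispatch.<locals>.fn
```
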

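When the torch model is converted to a oneflow model, a call to torch.nn.functional.boolean_dispatch.<locals>.fn is replaced with a call to oneflow.nn.functional.boolean_dispatch.<locals>.fn. oneflow is not implemented this way and has no such function; the call should go directly to oneflow.nn.functional.max_pool2d. Once the cause was known, adding a special case to the conversion step fixed the problem.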
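One possible shape for such a special case is sketched below. This is an assumption about how the mapping could be written, not the fix that was actually applied; `to_oneflow_target` and `_make_dispatched` are hypothetical names introduced only for illustration. The idea is to fall back to the public `__name__` whenever the qualified name contains `<locals>`, since a closure's qualified name cannot be looked up on the oneflow side.

```python
def to_oneflow_target(func) -> str:
    """Hypothetical helper: derive the oneflow-side dotted name for a callable."""
    module = func.__module__
    qualname = func.__qualname__
    if "<locals>" in qualname:
        # Closure returned by a factory such as boolean_dispatch: the qualified
        # name is not importable, so use the public __name__ instead.
        qualname = func.__name__
    # Swap only a leading "torch" package, leaving e.g. "torchvision" alone.
    if module == "torch" or module.startswith("torch."):
        module = "oneflow" + module[len("torch"):]
    return f"{module}.{qualname}"


def _make_dispatched():
    """Builds a closure that mimics what the dispatch factory returns."""
    def fn(*args, **kwargs):
        return None
    fn.__module__ = "torch.nn.functional"
    fn.__name__ = "max_pool2d"
    return fn


dispatched = _make_dispatched()
print(to_oneflow_target(dispatched))  # oneflow.nn.functional.max_pool2d
```

A rule keyed on `<locals>` covers every function built through the same factory rather than only max_pool2d; whether that generality is wanted depends on how the rest of the conversion resolves names.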

@crazy-JiangDongHua
Contributor Author

Then another problem appeared: after compilation finished, model inference raised the following conv TypeError:

TypeError: conv2d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, Parameter, Parameter, tuple of (int, int), tuple of (int, int), tuple of (int, int), int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, Parameter, Parameter, tuple of (int, int), tuple of (int, int), tuple of (int, int), int)

On inspection, this conv is an ordinary torch.nn.Conv2d. Printing stride, padding, and dilation shows normal values, and convs of the same form earlier in the model pass the check:

stride: (1, 1), padding: (1, 1), dilation: (1, 1)

I was stuck on this for a long time and still had no leads.

@mosout

This comment was marked as off-topic.

@crazy-JiangDongHua
Contributor Author

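Further debugging found the cause. When the oneflow backend processes its outputs and converts each oneflow.Tensor to a torch.Tensor, it missed the case where the output is a flow._oneflow_internal.TensorTuple. The oneflow.Tensor objects were therefore handed to torch unconverted, which caused the error. Adding handling for flow._oneflow_internal.TensorTuple fixed it. faster rcnn now runs end to end with dynamic shape disabled.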
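The output handling described here can be sketched with stand-in classes, so that the example runs without oneflow or torch installed. `FlowTensor`, `TorchTensor`, and `FlowTensorTuple` are hypothetical placeholders for oneflow.Tensor, torch.Tensor, and flow._oneflow_internal.TensorTuple, and the sketch assumes the tuple-like container is iterable but is not a subclass of `tuple`, which is why a plain `isinstance(out, tuple)` check would miss it. Raising on unrecognized types is a choice made in this sketch, not something taken from the backend code.

```python
class FlowTensor:
    """Stand-in for oneflow.Tensor (hypothetical, for illustration only)."""
    def __init__(self, data):
        self.data = data


class TorchTensor:
    """Stand-in for torch.Tensor (hypothetical, for illustration only)."""
    def __init__(self, data):
        self.data = data


class FlowTensorTuple:
    """Stand-in for a tuple-like container that is not a tuple subclass."""
    def __init__(self, items):
        self._items = list(items)

    def __iter__(self):
        return iter(self._items)


def to_torch(out):
    """Recursively convert backend outputs to torch-side values."""
    if isinstance(out, FlowTensor):
        return TorchTensor(out.data)
    # The container branch must name the tuple-like type explicitly;
    # checking for tuple alone lets it fall through unconverted.
    if isinstance(out, (tuple, FlowTensorTuple)):
        return tuple(to_torch(item) for item in out)
    # Fail at the conversion boundary instead of deep inside a later op.
    raise TypeError(f"unsupported backend output type: {type(out).__name__}")


converted = to_torch(FlowTensorTuple([FlowTensor(1), FlowTensor(2)]))
print([type(t).__name__ for t in converted])  # ['TorchTensor', 'TorchTensor']
```
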

@mosout
Contributor

mosout commented Feb 29, 2024

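The reported error looks quite far removed from its actual cause. When you open the PR to fix this, add a check: if the output is not a tensor, a tuple, or the dict type that is about to be added, raise an explicit exception so that this is easier to improve later.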

@crazy-JiangDongHua
Contributor Author

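The reported error looks quite far removed from its actual cause. When you open the PR to fix this, add a check: if the output is not a tensor, a tuple, or the dict type that is about to be added, raise an explicit exception so that this is easier to improve later.

OK, I'll add that.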
