
(InvalidArgument) Sum of Attr(num_or_sections) must be equal to the input's size along the split dimension. #12

Closed
guoqsGary opened this issue Jul 20, 2023 · 8 comments

@guoqsGary

```
Traceback (most recent call last):
  File "tools/train.py", line 183, in <module>
    main()
  File "tools/train.py", line 179, in main
    run(FLAGS, cfg)
  File "tools/train.py", line 135, in run
    trainer.train(FLAGS.eval)
  File "/home/gy/workspace/work/RT-DETR/rtdetr_paddle/ppdet/engine/trainer.py", line 377, in train
    outputs = model(data)
  File "/home/gy/miniconda3/envs/detr-like/lib/python3.8/site-packages/paddle/fluid/dygraph/layers.py", line 1012, in __call__
    return self.forward(*inputs, **kwargs)
  File "/home/gy/workspace/work/RT-DETR/rtdetr_paddle/ppdet/modeling/architectures/meta_arch.py", line 60, in forward
    out = self.get_loss()
  File "/home/gy/workspace/work/RT-DETR/rtdetr_paddle/ppdet/modeling/architectures/detr.py", line 113, in get_loss
    return self._forward()
  File "/home/gy/workspace/work/RT-DETR/rtdetr_paddle/ppdet/modeling/architectures/detr.py", line 87, in _forward
    out_transformer = self.transformer(body_feats, pad_mask, self.inputs)
  File "/home/gy/miniconda3/envs/detr-like/lib/python3.8/site-packages/paddle/fluid/dygraph/layers.py", line 1012, in __call__
    return self.forward(*inputs, **kwargs)
  File "/home/gy/workspace/work/RT-DETR/rtdetr_paddle/ppdet/modeling/transformers/rtdetr_transformer.py", line 419, in forward
    get_contrastive_denoising_training_group(gt_meta,
  File "/home/gy/workspace/work/RT-DETR/rtdetr_paddle/ppdet/modeling/transformers/utils.py", line 296, in get_contrastive_denoising_training_group
    dn_positive_idx = paddle.split(dn_positive_idx,
  File "/home/gy/miniconda3/envs/detr-like/lib/python3.8/site-packages/paddle/tensor/manipulation.py", line 1982, in split
    return _C_ops.split(input, num_or_sections, dim)
ValueError: (InvalidArgument) Sum of Attr(num_or_sections) must be equal to the input's size along the split dimension. But received Attr(num_or_sections) = [84], input(X)'s shape = [2166784], Attr(dim) = 0.
  [Hint: Expected sum_of_section == input_axis_dim, but received sum_of_section:84 != input_axis_dim:2166784.] (at /paddle/paddle/phi/infermeta/unary.cc:3285)
```

The output of `print(dn_positive_idx.shape)` is [2166784], and the output of `print([n * num_group for n in num_gts])` is [84].
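When `num_or_sections` is a list, `paddle.split(x, num_or_sections, axis=0)` requires the section sizes to add up to the length of `x` along the split axis. Here the sections add up to 84, but `dn_positive_idx` has 2166784 elements. The plain-Python sketch below mirrors that check so the invariant is easy to see. `split_by_sections` is a hypothetical helper, not part of Paddle or PaddleDetection.

```python
def split_by_sections(items, sections):
    """Split a flat sequence into consecutive chunks of the given sizes.

    Mirrors the check paddle.split performs when num_or_sections is a list:
    the sizes must add up to the length along the split axis.
    """
    total = sum(sections)
    if total != len(items):
        raise ValueError(
            f"sum of sections ({total}) != length of input ({len(items)})")
    chunks, start = [], 0
    for size in sections:
        chunks.append(list(items[start:start + size]))
        start += size
    return chunks


# Expected case: one chunk per image, num_gt * num_group entries each.
print(split_by_sections(list(range(5)), [2, 3]))  # [[0, 1], [2, 3, 4]]

# The situation from the traceback: sections [84] but a much longer input.
try:
    split_by_sections([0] * 2166784, [84])
except ValueError as err:
    print(err)
```

The section list is computed directly from `num_gts`, so when this check fails, the value to investigate is the length of the input.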

@lyuwenyu
Owner

dn_positive_idx looks a bit strange. It is derived from pad_gt_mask, so print the shapes after lines 18, 192, and 194 and take a look.

@guoqsGary
Author

guoqsGary commented Jul 21, 2023

```python
# Excerpt from get_contrastive_denoising_training_group in
# ppdet/modeling/transformers/utils.py, with debug prints added.

# pad gt to max_num of a batch
bs = len(targets["gt_class"])
input_query_class = paddle.full(
    [bs, max_gt_num], num_classes, dtype='int32')
input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
pad_gt_mask = paddle.zeros([bs, max_gt_num])
print("pad_gt_mask at line 280", pad_gt_mask.shape)

for i in range(bs):
    num_gt = num_gts[i]
    if num_gt > 0:
        input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
        input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
        pad_gt_mask[i, :num_gt] = 1
# each group has positive and negative queries.
print("pad_gt_mask at line 289", pad_gt_mask.shape)

input_query_class = input_query_class.tile([1, 2 * num_group])
input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
print("pad_gt_mask at line 293", pad_gt_mask.shape)

# positive and negative mask
negative_gt_mask = paddle.zeros([bs, max_gt_num * 2, 1])
negative_gt_mask[:, max_gt_num:] = 1
negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
positive_gt_mask = 1 - negative_gt_mask

# contrastive denoising training positive index
positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
dn_positive_idx = paddle.nonzero(positive_gt_mask)[:, 1]
print("dn_positive_idx at line 304", dn_positive_idx.shape)
print("308", [n * num_group for n in num_gts])

dn_positive_idx = paddle.split(dn_positive_idx,
                               [n * num_group for n in num_gts])
# total denoising queries
num_denoising = int(max_gt_num * 2 * num_group)
```

Inside get_contrastive_denoising_training_group, I printed the shapes related to pad_gt_mask and dn_positive_idx. The output is:

```
pad_gt_mask at line 280 [1, 42]
pad_gt_mask at line 289 [1, 42]
pad_gt_mask at line 293 [1, 168]
dn_positive_idx at line 304 [1478656]
308 [84]
```
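The expected size of dn_positive_idx can be worked out from the printed shapes. pad_gt_mask grows from [1, 42] to [1, 168], so 2 * num_group = 4 and num_group = 2. With one image holding 42 ground-truth boxes, the positive mask should therefore contain 42 * 2 = 84 ones, which matches the section list [84]. The sketch below mirrors the mask construction from the snippet using Python lists. It is an illustration, not the PaddleDetection code, and it assumes that `paddle.nonzero` scans the mask in row-major order. `expected_positive_indices` is a hypothetical name.

```python
def expected_positive_indices(num_gts, max_gt_num, num_group):
    """Pure-Python mirror of the positive-mask construction in the snippet.

    Returns what paddle.nonzero(positive_gt_mask)[:, 1] should contain,
    assuming nonzero scans the mask in row-major order.
    """
    indices = []
    for num_gt in num_gts:  # one row per image in the batch
        # pad_gt_mask row: 1 for real boxes, 0 for padding, tiled 2 * num_group times
        pad = ([1] * num_gt + [0] * (max_gt_num - num_gt)) * (2 * num_group)
        # negative_gt_mask row: second half of every 2 * max_gt_num block is negative
        neg = ([0] * max_gt_num + [1] * max_gt_num) * num_group
        positive = [(1 - n) * p for n, p in zip(neg, pad)]
        # column index of every nonzero entry, which is what nonzero(...)[:, 1] keeps
        indices.extend(j for j, v in enumerate(positive) if v != 0)
    return indices


# First run above: pad_gt_mask [1, 42] -> [1, 168], so num_group = 2.
idx = expected_positive_indices([42], max_gt_num=42, num_group=2)
print(len(idx))  # 84, matching the sections [n * num_group for n in num_gts]
```

A length of 1478656 is far from this expected count. That points at the mask tensor or at `nonzero` itself, not at the section list.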

@lyuwenyu
Owner

nonzero

The result of the line `dn_positive_idx = paddle.nonzero(positive_gt_mask)[:, 1]` looks abnormal.

Print the actual values of positive_gt_mask and dn_positive_idx and take a look.

@guoqsGary
Author

pad_gt_mask at line 280 [1, 11]
pad_gt_mask at line 289 [1, 11]
pad_gt_mask at line 293 [1, 198]
positive_gt_mask at line 303 [1, 198]
Values:
Tensor(shape=[1, 198], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[ 0.00310632, 0.14101093, 0.12103835, 0.02154503, -0.07726379,
-0.14637139, -0.03424893, 0.08380885, -0.09107973, 0.06767586,
-0.00615196, 0.05792801, 0.09927218, 0.16896574, 0.01404864,
0.04292662, 0.01954217, 0.09794760, 0.07740788, 0.06004503,
-0.35705784, 0.07120756, 0.02380743, -0.14642087, 0.13815553,
-0.16340810, -0.08960244, 0.04147810, -0.33778837, 0.01930248,
0.01650336, -0.03779846, 0.13665493, -0.04679525, -0.00215192,
-0.12053833, 0.00453085, 0.00240263, 0.04109331, -0.06165434,
0.01337579, 0.10240862, 0.05134562, -0.06095485, -0.04340591,
0.10876471, 0.00403623, 0.11194342, -0.11809937, -0.00234561,
0.03818209, -0.13671142, -0.02147811, 0.06945695, 0.10818230,
-0.01067771, 0.10701979, -0.05308023, 0.08857229, 0.10946196,
-0.07160667, 0.02844960, 0.10553841, 0.02046511, -0.04444174,
0.08384745, 0.02021224, -0.01091145, 0.00044059, -0.01279152,
-0.12776904, 0.12965322, -0.12100679, 0.13552734, 0.06030656,
-0.08390035, -0.10023104, -0.01297596, 0.07216468, -0.01284016,
-0.11601792, 0.22335936, 0.06061524, -0.01920354, -0.20133021,
-0.04863635, 0.01202199, 0.06056166, 0.06536793, -0.09030059,
0.13137969, 0.13094848, -0.06637057, -0.01327492, 0.16215605,
-0.07760047, 0.22065499, 0.16439630, 0.09158885, 0.00475553,
-0.00816963, -0.01253985, 0.06993420, -0.10915629, -0.01680921,
0.00923347, 0.00042252, -0.10916576, 0.09148672, 0.12017623,
0.05146683, -0.04138303, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, -0.00386819, 0.05617883,
-0.13188867, 0.06380036, 0.08549953, -0.07800592, 0.03591472,
0.04316943, 0.08014700, 0.02599740, 0.06113206, 0.01989053,
0.11121507, -0.06143811, 0.01651115, 0.19473866, -0.08239598,
0.00521296, -0.12147374, 0.03743923, 0.12685654, 0.00977078,
0.18656254, 0.16259746, -0.04821961, -0.06283052, 0.02976335,
0.01802769, 0.00984512, -0.09385370, -0.04777906, -0.05843711,
-0.12113749, 0.15310720, 0.16464706, 0.01821084, 0.03444537,
0.11447179, -0.14957972, 0.16943902, -0.10179990, 0.20357345,
0.21170834, 0.05378019, -0.06139834, -0.12831408, 0.08167726,
-0.02788442, -0.00992010, -0.02202041, 0.02871853, -0.07218811,
0.05534431, 0.17242326, 0.16148648, 0.05731316, -0.00394048,
0.02770376, -0.00773144, 0.00052531, -0.08097828, -0.04605101,
0.07508193, 0.04552646, -0.00618507, 0.02365397, 0.02251674,
0.10373530, -0.13113993, 0.03352785]])
dn_positive_idx at line 305 [1478656]
Values:
Tensor(shape=[1478656], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[0, 0, 0, ..., 0, 0, 0])
309 [99]
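The positive_gt_mask printed above holds arbitrary floats such as 0.00310632 and -0.07726379. In the snippet it is the product of (1 - negative_gt_mask) and pad_gt_mask, and both of those are filled only with 0 and 1. Every entry should therefore be exactly 0 or 1, and for num_gts = [11] with num_group = 9 the mask should contain 99 ones. The hypothetical `check_binary_mask` below works on a flattened list of values (for example one obtained via `tensor.numpy().ravel().tolist()`). A check like this would flag that kind of corruption before `nonzero` and `split` are reached.

```python
def check_binary_mask(values, expected_ones=None):
    """Return a list of problems found in a flattened 0/1 mask.

    An empty list means every entry is 0 or 1 and, if expected_ones is given,
    the number of ones matches it.
    """
    problems = []
    bad = [v for v in values if v not in (0, 1)]  # 0.0 and 1.0 also pass
    if bad:
        problems.append(f"{len(bad)} entries are neither 0 nor 1, e.g. {bad[0]!r}")
    ones = sum(1 for v in values if v == 1)
    if expected_ones is not None and ones != expected_ones:
        problems.append(f"expected {expected_ones} ones, found {ones}")
    return problems


# A few values copied from the printed positive_gt_mask:
print(check_binary_mask([0.00310632, 0.14101093, 0.12103835, 0.0]))
```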

@lyuwenyu
Owner

lyuwenyu commented Jul 21, 2023

positive_gt_mask is just the tensor you printed above, right?

When I run `paddle.nonzero(positive_gt_mask)[:, 1]` on it, the result looks fine.

[screenshot]

```
# pip list | grep "paddlepaddle"
paddlepaddle-gpu                   2.4.2.post116

# nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_May__3_18:49:52_PDT_2022
Cuda compilation tools, release 11.7, V11.7.64
Build cuda_11.7.r11.7/compiler.31294372_0
```

@guoqsGary
Author

guoqsGary commented Jul 21, 2023

Right.
But in the original code, the result of `paddle.nonzero(positive_gt_mask)[:, 1]` (dn_positive_idx) is this:
Tensor(shape=[1478656], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[0, 0, 0, ..., 0, 0, 0])

I ran an experiment in a standalone test script and could not reproduce your output:

aa = paddle.to_tensor([[ 0.00310632, 0.14101093, 0.12103835, 0.02154503, -0.07726379,
-0.14637139, -0.03424893, 0.08380885, -0.09107973, 0.06767586,
-0.00615196, 0.05792801, 0.09927218, 0.16896574, 0.01404864,
0.04292662, 0.01954217, 0.09794760, 0.07740788, 0.06004503,
-0.35705784, 0.07120756, 0.02380743, -0.14642087, 0.13815553,
-0.16340810, -0.08960244, 0.04147810, -0.33778837, 0.01930248,
0.01650336, -0.03779846, 0.13665493, -0.04679525, -0.00215192,
-0.12053833, 0.00453085, 0.00240263, 0.04109331, -0.06165434,
0.01337579, 0.10240862, 0.05134562, -0.06095485, -0.04340591,
0.10876471, 0.00403623, 0.11194342, -0.11809937, -0.00234561,
0.03818209, -0.13671142, -0.02147811, 0.06945695, 0.10818230,
-0.01067771, 0.10701979, -0.05308023, 0.08857229, 0.10946196,
-0.07160667, 0.02844960, 0.10553841, 0.02046511, -0.04444174,
0.08384745, 0.02021224, -0.01091145, 0.00044059, -0.01279152,
-0.12776904, 0.12965322, -0.12100679, 0.13552734, 0.06030656,
-0.08390035, -0.10023104, -0.01297596, 0.07216468, -0.01284016,
-0.11601792, 0.22335936, 0.06061524, -0.01920354, -0.20133021,
-0.04863635, 0.01202199, 0.06056166, 0.06536793, -0.09030059,
0.13137969, 0.13094848, -0.06637057, -0.01327492, 0.16215605,
-0.07760047, 0.22065499, 0.16439630, 0.09158885, 0.00475553,
-0.00816963, -0.01253985, 0.06993420, -0.10915629, -0.01680921,
0.00923347, 0.00042252, -0.10916576, 0.09148672, 0.12017623,
0.05146683, -0.04138303, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, -0.00386819, 0.05617883,
-0.13188867, 0.06380036, 0.08549953, -0.07800592, 0.03591472,
0.04316943, 0.08014700, 0.02599740, 0.06113206, 0.01989053,
0.11121507, -0.06143811, 0.01651115, 0.19473866, -0.08239598,
0.00521296, -0.12147374, 0.03743923, 0.12685654, 0.00977078,
0.18656254, 0.16259746, -0.04821961, -0.06283052, 0.02976335,
0.01802769, 0.00984512, -0.09385370, -0.04777906, -0.05843711,
-0.12113749, 0.15310720, 0.16464706, 0.01821084, 0.03444537,
0.11447179, -0.14957972, 0.16943902, -0.10179990, 0.20357345,
0.21170834, 0.05378019, -0.06139834, -0.12831408, 0.08167726,
-0.02788442, -0.00992010, -0.02202041, 0.02871853, -0.07218811,
0.05534431, 0.17242326, 0.16148648, 0.05731316, -0.00394048,
0.02770376, -0.00773144, 0.00052531, -0.08097828, -0.04605101,
0.07508193, 0.04552646, -0.00618507, 0.02365397, 0.02251674,
0.10373530, -0.13113993, 0.03352785]])
    

bb = paddle.nonzero(aa)[:, 1]

print(bb)

Output:
Tensor(shape=[0], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[])
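A pure-Python reference gives a quick comparison that does not depend on the GPU build. `paddle.nonzero` returns one [row, column] pair per nonzero element. The [1, 198] tensor aa is mostly nonzero floats, so `paddle.nonzero(aa)[:, 1]` should be a non-empty tensor of column indices, and the empty `Tensor(shape=[0])` above cannot be correct. `nonzero_coords` and `nonzero_columns` are illustrative helpers, not Paddle APIs.

```python
def nonzero_coords(matrix):
    """Reference for paddle.nonzero(x) on a 2-D input given as nested lists:
    one (row, column) pair per nonzero element, in row-major order."""
    return [(i, j)
            for i, row in enumerate(matrix)
            for j, value in enumerate(row)
            if value != 0]


def nonzero_columns(matrix):
    """Reference for paddle.nonzero(x)[:, 1]: only the column indices."""
    return [j for _, j in nonzero_coords(matrix)]


# Shortened stand-in for the test tensor `aa` (a few of its values).
aa = [[0.00310632, 0.14101093, 0.0, 0.0, -0.00386819, 0.05617883]]
print(nonzero_columns(aa))  # [0, 1, 4, 5]
```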

```
# pip list | grep "paddlepaddle"
paddlepaddle-gpu              2.4.2
# nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0
```

@lyuwenyu
Owner

lyuwenyu commented Jul 21, 2023

If you are also on CUDA 11.7, try installing paddlepaddle-gpu 2.4.2.post116.

If that still doesn't work, post this code snippet along with your version information at https://github.com/PaddlePaddle/Paddle/issues and ask there.
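The two environments differ in the installed wheel: 2.4.2.post116 on one side and plain 2.4.2 on the other. Under Paddle's naming for GPU wheels, the `.post116` suffix marks a build for CUDA 11.6. `nvcc -V` only reports the CUDA toolkit installed on the machine, not the CUDA version the Paddle wheel was compiled against. The minimal sketch below prints the latter. It assumes `paddle.version.cuda()` and `paddle.version.cudnn()` as exposed by Paddle 2.x, and `paddle_build_info` itself is a hypothetical helper.

```python
import importlib


def paddle_build_info():
    """Return the installed Paddle version and the CUDA/cuDNN versions the
    wheel was built with, or None if Paddle is not installed."""
    try:
        paddle = importlib.import_module("paddle")
        version = importlib.import_module("paddle.version")
    except ImportError:
        return None
    return {
        "version": paddle.__version__,
        "cuda": version.cuda(),    # CUDA version the wheel was compiled against
        "cudnn": version.cudnn(),  # cuDNN version the wheel was compiled against
    }


if __name__ == "__main__":
    info = paddle_build_info()
    print(info if info is not None else "paddle is not installed")
```

After reinstalling, `paddle.utils.run_check()` can additionally confirm that the installed GPU build runs on the machine.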

@guoqsGary
Author

Thanks a lot! After updating the Paddle version, the problem is solved.
