Training this model on other datasets #17

Closed
imcjx opened this issue Mar 30, 2022 · 8 comments

@imcjx

imcjx commented Mar 30, 2022

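Hi, I'm applying this model to a different dataset (19 classes) and changed the number of classes in the config file, but it still won't run. I get the following error: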

Traceback (most recent call last):
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/train.py", line 131, in <module>
    main()
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/train.py", line 127, in main
    SegPipline(cfg, distributed, not args.no_validate).run()
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/pipline/segpipline.py", line 99, in run
    self.runner.run(self.data_loaders, self.cfg.workflow, self.cfg.max_epochs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
    **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/models/segmentors/base.py", line 152, in train_step
    losses = self(**data_batch)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 109, in new_func
    return old_func(*args, **kwargs)
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/models/segmentors/base.py", line 122, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/models/segmentors/encoder_decoder.py", line 92, in forward_train
    loss_aux = self._auxiliary_head_forward(x, seg_label, return_loss=True)
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/models/segmentors/encoder_decoder.py", line 81, in _auxiliary_head_forward
    return self.aux_head.forward_train(x, seg_label)
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/models/heads/mctrans_aux_head.py", line 49, in forward_train
    outputs = self.ca(inputs_flatten)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/models/trans/transformer.py", line 145, in forward
    query = layer(query, src)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/csip-102/CaiJiaXin/MCTrans-master/mctrans/models/trans/transformer.py", line 122, in forward
    tgt2 = self.cross_attn(tgt, src, src)[0]
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 1038, in forward
    attn_mask=attn_mask)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 5025, in multi_head_attention_forward
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
RuntimeError: shape '[-1, 152, 16]' is invalid for input of size 2752512 

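I found a similar problem among the closed issues, but following the fix suggested there didn't work for me either. Looking forward to your answer.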
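The numbers in the error message can be checked by hand. In PyTorch's `multi_head_attention_forward` (default sequence-first layout), `bsz` is read from dimension 1 of the query, and `head_dim = d_model // num_heads`. The sketch below assumes the `d_model=128` and `nhead=8` from the mctrans_vgg32_d5.py config posted in this thread; `can_view` is a hypothetical helper that mimics the reshape check, not a PyTorch function.

```python
# Re-derive the numbers in:
#   k.contiguous().view(-1, bsz * num_heads, head_dim)
#   RuntimeError: shape '[-1, 152, 16]' is invalid for input of size 2752512
# Assumption: d_model=128 and nhead=8, as in the mctrans_vgg32_d5.py config.


def can_view(numel: int, *fixed_dims: int) -> bool:
    """True if a tensor with `numel` elements can be reshaped to (-1, *fixed_dims)."""
    block = 1
    for d in fixed_dims:
        block *= d
    return numel % block == 0


d_model, nhead = 128, 8
head_dim = d_model // nhead      # 16: the last dim in the error message
bsz = 152 // nhead               # 19: batch size PyTorch read from the query (dim 1)
key_numel = 2752512              # element count of the projected key tensor

key_rows = key_numel // d_model  # key_len * key_batch

print(head_dim, bsz)                                # 16 19
print(can_view(key_numel, bsz * nhead, head_dim))   # False -> the RuntimeError
print(key_rows, key_rows % bsz)                     # 21504 15
```

The inferred batch size of 19 happens to equal the new `num_classes`, and since 21504 is not divisible by 19, no key length is compatible with a key batch of 19. This suggests the query and key disagree on their batch dimension, i.e. a dimension-order mismatch rather than a wrong class count, though the exact cause depends on the actual data shapes.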

@JiYuanFeng
Owner

JiYuanFeng commented Mar 30, 2022 via email

@imcjx
Author

imcjx commented Mar 30, 2022

The config file I'm using is mctrans_vgg32_d5.py:

model = dict(
    type='EncoderDecoder',
    pretrained=None,
    encoder=dict(
        type="VGG",
        in_channel=3,
        init_channels=32,
        num_blocks=2),
    center=dict(
        type="MCTrans",
        d_model=128,
        nhead=8,
        d_ffn=512,
        dropout=0.1,
        act="relu",
        n_levels=3,
        n_points=4,
        n_sa_layers=6),
    decoder=dict(
        type="UNetDecoder",
        in_channels=[32, 64, 128, 128, 128],
    ),
    seg_head=dict(
        type="BasicSegHead",
        in_channels=32,
        num_classes=19, # 6
        post_trans=[dict(type="Activations", softmax=True),
                    dict(type="AsDiscrete", argmax=True)],
        losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
                dict(type="CrossEntropyLoss"),
                dict(type="FLoss")]),
    aux_head=dict(
        type="MCTransAuxHead",
        d_model=128,
        d_ffn=512,
        act="relu",
        num_classes=19, # 6
        in_channles=[32, 64, 128, 128, 128],
        losses=[dict(type="MCTransAuxLoss", sigmoid=True, loss_weight=0.1)]),
)
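Since the class count has to be changed in more than one place, a small consistency check can catch mismatches early. This is a hypothetical helper, not part of MCTrans; it only reads keys visible in the config above, and the check that `aux_head.d_model` equals `center.d_model` is an assumption about how the aux head consumes the transformer features.

```python
def check_model_cfg(model: dict) -> None:
    """Raise ValueError on inconsistencies that are easy to miss when editing num_classes.

    Hypothetical helper; it only inspects keys used in mctrans_vgg32_d5.py.
    """
    center, seg_head, aux_head = model["center"], model["seg_head"], model["aux_head"]
    # Both heads should be configured for the same number of classes.
    if seg_head["num_classes"] != aux_head["num_classes"]:
        raise ValueError(
            f"seg_head.num_classes={seg_head['num_classes']} but "
            f"aux_head.num_classes={aux_head['num_classes']}")
    # Multi-head attention splits d_model evenly across heads (head_dim = d_model // nhead).
    if center["d_model"] % center["nhead"] != 0:
        raise ValueError(
            f"d_model={center['d_model']} is not divisible by nhead={center['nhead']}")
    # Assumption: the aux head attends over the center's features, so widths should agree.
    if aux_head["d_model"] != center["d_model"]:
        raise ValueError(
            f"aux_head.d_model={aux_head['d_model']} != center.d_model={center['d_model']}")


cfg = dict(
    center=dict(d_model=128, nhead=8),
    seg_head=dict(num_classes=19),
    aux_head=dict(d_model=128, num_classes=19),
)
check_model_cfg(cfg)  # passes for the values posted above
```

Passing these checks does not guarantee that tensor shapes line up at runtime: the error above is raised with `num_classes` set consistently to 19 in both heads.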

@JiYuanFeng
Owner

JiYuanFeng commented Mar 30, 2022 via email

@imcjx
Author

imcjx commented Mar 30, 2022

I'm already using MONAI 0.5.3.

@JiYuanFeng
Owner

Hi, could you tell me the shape of your data right before it enters MCTrans? If possible, please also share one image and its mask so I can debug it on my end.
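When reporting shapes for debugging, it helps to include the label values as well. The sketch below is hypothetical (`describe_batch` is not part of MCTrans); it assumes images of shape (B, C, H, W) and masks of shape (B, 1, H, W), the single-channel target layout MONAI's `DiceLoss(to_onehot_y=True)` expects, and flags label values outside `[0, num_classes)`.

```python
import numpy as np


def describe_batch(img: np.ndarray, mask: np.ndarray, num_classes: int) -> dict:
    """Collect the facts worth sharing when debugging a shape error.

    Assumes img is (B, C, H, W) and mask is (B, 1, H, W) with integer class ids.
    """
    labels = np.unique(mask)
    return {
        "img_shape": tuple(img.shape),
        "mask_shape": tuple(mask.shape),
        # Spatial sizes must agree between image and mask.
        "spatial_match": tuple(img.shape[2:]) == tuple(mask.shape[2:]),
        # One-hot conversion of the target expects a single label channel.
        "mask_single_channel": mask.ndim == 4 and mask.shape[1] == 1,
        "labels": [int(v) for v in labels],
        # Values outside [0, num_classes) will typically break one-hot based losses.
        "out_of_range": [int(v) for v in labels if not 0 <= v < num_classes],
    }


img = np.zeros((2, 3, 64, 64), dtype=np.float32)
mask = np.zeros((2, 1, 64, 64), dtype=np.int64)
mask[0, 0, :4, :4] = 255  # "don't care" pixels
info = describe_batch(img, mask, num_classes=19)
print(info["out_of_range"])  # [255]
```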

@imcjx
Copy link
Author

imcjx commented Mar 31, 2022

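Hi, thank you very much for your helpful reply. The problem is solved; it was a dimension issue. However, some pixels in my dataset's GT have the value 255; these pixels should be excluded from the metrics, and they cause problems in the losses. Only the cross-entropy loss has an ignore_index parameter that lets me skip these pixels; DiceLoss, FLoss and MCTransAuxLoss have no such parameter. Is there a way to handle this?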
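For reference, cross-entropy in PyTorch already supports `ignore_index`, and the same effect can be obtained in a Dice loss by zeroing both the predicted probabilities and the one-hot target at ignored pixels before summing. The function below is a minimal sketch under that assumption: `masked_dice_loss` is a hypothetical name, it is neither MONAI's `DiceLoss` nor MCTrans code, and it takes a (B, H, W) target rather than MONAI's (B, 1, H, W).

```python
import torch
import torch.nn.functional as F


def masked_dice_loss(logits: torch.Tensor, target: torch.Tensor,
                     ignore_index: int = 255, eps: float = 1e-5) -> torch.Tensor:
    """Soft multi-class Dice loss that skips pixels labelled `ignore_index`.

    logits: (B, C, H, W) raw scores; target: (B, H, W) int64 class ids.
    Hypothetical sketch, not MONAI's DiceLoss or MCTrans code.
    """
    num_classes = logits.shape[1]
    valid = target != ignore_index                        # (B, H, W) bool
    # Give ignored pixels a dummy class so one_hot accepts them; they are zeroed below.
    safe_target = target.masked_fill(~valid, 0)
    onehot = F.one_hot(safe_target, num_classes).permute(0, 3, 1, 2).float()
    probs = logits.softmax(dim=1)
    mask = valid.unsqueeze(1).to(probs.dtype)             # (B, 1, H, W), broadcasts over C
    probs, onehot = probs * mask, onehot * mask
    dims = (0, 2, 3)                                      # sum over batch and space, per class
    intersection = (probs * onehot).sum(dims)
    denominator = probs.sum(dims) + onehot.sum(dims)
    dice = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice.mean()


# Usage: combine with cross-entropy, which already supports ignore_index.
logits = torch.randn(2, 19, 32, 32, requires_grad=True)
target = torch.randint(0, 19, (2, 32, 32))
target[:, :4, :] = 255                                    # "don't care" pixels
loss = masked_dice_loss(logits, target) + F.cross_entropy(logits, target, ignore_index=255)
loss.backward()
```

Because ignored pixels contribute nothing to either the intersection or the denominator, their logits receive no gradient from this term. The same masking idea would need to be applied inside FLoss and MCTransAuxLoss separately.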

@JiYuanFeng
Owner

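Hi, you can modify these losses to apply a mask when computing the loss. Alternatively, you can simply remap the labels in the dataset, e.g. convert 255 to 0. This is simple, but those pixels will then take part in the computation.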

if self.label_map is not None:
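The remapping option can be sketched as a lookup applied to each label array before it reaches the losses. `remap_labels` is a hypothetical helper and the dict format (old id to new id) is an assumption; the repository's own `label_map` handling may be structured differently.

```python
import numpy as np


def remap_labels(label: np.ndarray, label_map: dict) -> np.ndarray:
    """Return a copy of `label` with each key of `label_map` replaced by its value."""
    out = label.copy()
    for old, new in label_map.items():
        # Match against the original array so chained maps like {1: 2, 2: 3} do not cascade.
        out[label == old] = new
    return out


gt = np.array([[0, 3, 255],
               [255, 18, 1]], dtype=np.uint8)
print(remap_labels(gt, {255: 0}).tolist())  # [[0, 3, 0], [0, 18, 1]]
```

As noted above, remapped pixels become ordinary class-0 pixels, so every loss and metric will count them.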

@imcjx
Author

imcjx commented Apr 1, 2022

Thank you very much for your answer.

@imcjx imcjx closed this as completed Apr 1, 2022