
SwinIR classical Image SR Training using dataparallel #84

Closed
paragon1234 opened this issue Sep 24, 2021 · 11 comments


paragon1234 commented Sep 24, 2021

While training SwinIR for classical image SR with DataParallel:
**python main_train_psnr.py --opt options/swinir/train_swinir_sr_classical.json**

The training loop in main_train_psnr.py calls model.test(), which fails with:

Traceback (most recent call last):
  File "main_train_psnr.py", line 256, in <module>
    main()
  File "main_train_psnr.py", line 228, in main
    model.test()
  File "F:\models\model_plain.py", line 186, in test
    self.netG_forward()
  File "F:\models\model_plain.py", line 143, in netG_forward
    self.E = self.netG(self.L)
  File "C:\Users\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\parallel\data_parallel.py", line 154, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "C:\Users\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "F:\models\network_swinir.py", line 807, in forward
    x = self.conv_after_body(self.forward_features(x)) + x
  File "F:\models\network_swinir.py", line 793, in forward_features
    x = layer(x, x_size)
  File "C:\Users\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "F:\models\network_swinir.py", line 485, in forward
    return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
  File "C:\Users\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "F:\models\network_swinir.py", line 405, in forward
    x = blk(x, x_size)
  File "C:\Users\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "F:\models\network_swinir.py", line 258, in forward
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
  File "F:\models\network_swinir.py", line 43, in window_partition
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
RuntimeError: shape '[1, 84, 8, 127, 8, 180]' is invalid for input of size 124480800

The header of the file has the comment 'training code for MSRResNet', so it seems the file needs to be modified.
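For reference, the failing reshape can be reproduced standalone. The sketch below mirrors the window_partition helper from network_swinir.py; the 678x1020 input size is an assumption chosen because it yields the same '[1, 84, 8, 127, 8, C]' target shape as the traceback (C is reduced to 3 here to keep the tensors small):

```python
import torch

def window_partition(x, window_size):
    # Same reshape as in network_swinir.py: split (B, H, W, C) into
    # non-overlapping windows of shape (num_windows*B, ws, ws, C).
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

ws = 8
ok = torch.zeros(1, 672, 1016, 3)   # 672 and 1016 are multiples of 8 -> works
print(window_partition(ok, ws).shape)  # torch.Size([10668, 8, 8, 3])

bad = torch.zeros(1, 678, 1020, 3)  # 678 and 1020 are NOT multiples of 8
try:
    window_partition(bad, ws)
except RuntimeError as e:
    print(e)  # shape '[1, 84, 8, 127, 8, 3]' is invalid for input of size ...
```

Because view() requires the element count to match exactly, any H or W that is not a multiple of window_size triggers the RuntimeError above.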

@paragon1234 paragon1234 changed the title SwinIR classical Image ST Training using dataparallel SwinIR classical Image SR Training using dataparallel Sep 24, 2021

JingyunLiang commented Sep 29, 2021

It's not a problem with the training code for MSRResNet, because both models share the same training scheme. Does DistributedDataParallel work for you? What are your input image sizes? You can print the image size at the beginning of the network; it should always be a multiple of window_size.
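One quick way to surface this while debugging is to assert the alignment at the top of the network's forward(). This is a hypothetical helper for illustration, not part of the repo:

```python
import torch

def assert_window_aligned(x, window_size):
    # Hypothetical debugging helper: fail fast with a readable message
    # if the spatial dims of a (B, C, H, W) tensor aren't window-aligned.
    _, _, h, w = x.shape
    assert h % window_size == 0 and w % window_size == 0, \
        f"input {h}x{w} is not a multiple of window_size={window_size}"

assert_window_aligned(torch.zeros(1, 3, 672, 1016), 8)  # passes silently
```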


paragon1234 commented Sep 29, 2021

The system I use does not support DistributedDataParallel.
I split the DIV2K dataset into training (800 images) and testing (100 images). The training code runs fine; the issue is only in the testing code. Why does the testing code give an error?
If I comment out the testing part (the loop containing model.test()), the training code runs fine.
The testing images come from the same dataset as the training images, so why is there an invalid dimension during testing?

JingyunLiang (Collaborator) commented:

Thanks for your question. It's a bug: the testing image is not padded to a multiple of window_size. Please use the latest network_swinir.py. See commit 81a6547.
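The idea behind the fix can be sketched as follows (an illustration of the approach, not the exact commit): reflect-pad the bottom and right of the input so H and W become multiples of window_size, then crop the output back after the forward pass.

```python
import torch
import torch.nn.functional as F

def pad_to_window_multiple(x, window_size):
    # Reflect-pad a (B, C, H, W) tensor on the bottom and right so that
    # H and W become exact multiples of window_size.
    _, _, h, w = x.shape
    mod_pad_h = (window_size - h % window_size) % window_size
    mod_pad_w = (window_size - w % window_size) % window_size
    return F.pad(x, (0, mod_pad_w, 0, mod_pad_h), mode='reflect')

x = torch.zeros(1, 3, 678, 1020)
print(pad_to_window_multiple(x, 8).shape)  # torch.Size([1, 3, 680, 1024])
```

The padded border is cropped from the (upscaled) output, so it never reaches the saved image.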

paragon1234 commented:

Thank you, sir, for the update.
Can you please elaborate on what window_size is and why it is needed?

JingyunLiang (Collaborator) commented:

See JingyunLiang/SwinIR#9 for the discussion.

paragon1234 commented:

Thank you

paragon1234 commented:

While going through the code changes, I failed to understand the following:

  1. The code change is in the forward() function of the SwinIR class, so it should affect both testing and training. However, the problem was only with the testing code; the training code worked fine. Why did training work without the fix? If my thinking is correct, the fix should be somewhere else, where only the testing data is operated on.

  2. (With batch_size=4) when I checked the dimensions of x in forward() of SwinIR, it is torch.Size([4, 3, 48, 48]). 48x48 indicates a patch. Where is this patch created? My understanding is that forward() of SwinIR is the first function called, so the patch should be created there.
    Also, does this indicate that a batch consists of 4 patches of size 48x48? Shouldn't it be 4 full images from the training data?

  3. The paper "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" describes a hierarchical representation built with patch merging layers. I can see a PatchMerging class, but it is never called. Does your architecture use a hierarchical representation?

@paragon1234 paragon1234 reopened this Oct 2, 2021
JingyunLiang (Collaborator) commented:

1. The training stage only gets 48x48 image patches as input, so there are no image padding problems.

2. The patches are created in the dataloader during training. We use 48x48 patches for batch training.

3. No. We don't use it, but we keep it in the code for possible future use.
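For readers following along, the dataloader's patch extraction amounts to a random paired crop of LR and HR images. A minimal sketch (the function name, shapes, and scale factor are illustrative, not the repo's actual code):

```python
import random
import numpy as np

def random_paired_crop(img_L, img_H, lr_patch_size, scale):
    # Illustrative version of what the training dataloader does: pick a
    # random lr_patch_size x lr_patch_size crop from the LR image and the
    # corresponding (lr_patch_size * scale) crop from the HR image.
    h, w = img_L.shape[:2]
    top = random.randint(0, h - lr_patch_size)
    left = random.randint(0, w - lr_patch_size)
    patch_L = img_L[top:top + lr_patch_size, left:left + lr_patch_size]
    patch_H = img_H[top * scale:(top + lr_patch_size) * scale,
                    left * scale:(left + lr_patch_size) * scale]
    return patch_L, patch_H

img_L = np.zeros((120, 180, 3))   # toy LR image
img_H = np.zeros((480, 720, 3))   # matching HR image at scale 4
pL, pH = random_paired_crop(img_L, img_H, 48, 4)
print(pL.shape, pH.shape)  # (48, 48, 3) (192, 192, 3)
```

Because the LR patch size (48) is itself a multiple of window_size (8), training batches never hit the alignment problem that full-size testing images do.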

paragon1234 commented:

  1. Can you please point to the part of the training dataloader that creates the patches? There is a function patches_from_image() in utils_image.py, but it is never called.

  2. Is there something else you do to get attention across scales, or does cascading transformer blocks do it automatically, just as convolution increases the receptive field?

JingyunLiang (Collaborator) commented:

2. See the line commented `# randomly crop the L patch` in the training dataset code.

3. No. We just stack layers to increase the receptive field.

paragon1234 commented:

Thank you, sir.

3. It seems others have over-engineered a hierarchical representation for something that happens by default.
