about test #35

Closed
jiaaihhy opened this issue Oct 2, 2021 · 15 comments

@jiaaihhy

jiaaihhy commented Oct 2, 2021

The SwinIR model has an img_size parameter, e.g. 128, so in the Swin layer input_resolution=(128, 128). When I test, my input image is not (128, 128), and the attention computation contains a check: if self.input_resolution == x_size, the stored mask is used, else attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)). When the image size is not equal to self.input_resolution=(128, 128), what is the mask that is passed in here?

@JingyunLiang
Owner

See #9 for a discussion

@jiaaihhy
Author

jiaaihhy commented Oct 2, 2021

But that still doesn't answer my question. I still don't understand what the mask added when input_resolution != x_size is, or why it needs to be added.

if self.input_resolution == x_size:
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
else:
    attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

@jiaaihhy
Author

jiaaihhy commented Oct 2, 2021

Thank you very much.

@JingyunLiang
Owner

The line if self.input_resolution == x_size: is for training. We initialize the mask as self.attn_mask so that we don't need to recalculate it at every iteration. Note that the training image size is fixed, e.g., 64x64.

The else branch is for testing. We calculate the mask for the given test image. Note that the test image is generally not 64x64.

@jiaaihhy
Author

jiaaihhy commented Oct 2, 2021

At test time, what is the meaning of the mask? Thank you.

@JingyunLiang
Owner

Similar to training, we pad the image after shifting it. You can try not using padding during testing and share the results with me. Thank you.

@paragon1234

I had the same question. Can you please elaborate on why attn_mask is required?
From the code, it seems to be required only for those transformers that operate on shifted windows.

if self.shift_size > 0:
    attn_mask = self.calculate_mask(self.input_resolution)
else:
    attn_mask = None

@JingyunLiang
Owner

Yes, attn_mask is only used for those transformers that operate on shifted windows. Imagine that for a 64x64 input, after shifting 4x4 pixels towards the top-left corner (using torch.roll), the pixels within [0:4,:] and [:,0:4] are moved to [60:64,:] and [:,60:64], respectively. In that case, the pixels within [56:60,:] and [:,56:60] would be forced to attend to these unrelated pixels after the new window partition. This is not what we want, so we use attn_mask to mask them out.
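
The masking described above can be sketched in plain Python. This is a minimal illustration, not the repository's code: the helper names band and shifted_window_mask are hypothetical, nested lists stand in for tensors, and the image height and width are assumed to be multiples of the window size. It labels the shifted image with a 3x3 grid of regions and, inside each window, adds a large negative value (-100 here) to every pair of pixels whose region labels differ, so that softmax gives them almost no weight.

```python
def band(i, size, window_size, shift_size):
    """Return which of the three bands along one axis coordinate i falls in."""
    if i < size - window_size:
        return 0  # windows that contain no wrapped pixels along this axis
    if i < size - shift_size:
        return 1  # part of the last window that was not wrapped around
    return 2      # pixels wrapped around from the opposite edge by the roll


def shifted_window_mask(height, width, window_size, shift_size):
    """Build one (N x N) additive mask per window, with N = window_size ** 2."""
    # Region label of every pixel in the shifted image (values 0..8).
    labels = [
        [3 * band(r, height, window_size, shift_size)
         + band(c, width, window_size, shift_size)
         for c in range(width)]
        for r in range(height)
    ]
    masks = []
    for top in range(0, height, window_size):
        for left in range(0, width, window_size):
            # Flatten the window row by row, as a window partition would.
            flat = [labels[r][c]
                    for r in range(top, top + window_size)
                    for c in range(left, left + window_size)]
            # 0 keeps a pair, -100 suppresses it before softmax.
            masks.append([[0.0 if a == b else -100.0 for b in flat] for a in flat])
    return masks


masks = shifted_window_mask(64, 64, window_size=8, shift_size=4)
interior, corner = masks[0], masks[-1]
print(len(masks))                                       # windows in a 64x64 image
print(any(v != 0.0 for row in interior for v in row))   # is anything masked in an interior window?
print(corner[0][1], corner[0][32])                      # same region vs. wrapped-around region
```

Only windows that touch the bottom or right edge of the shifted image contain more than one region, so only those windows get non-zero mask entries.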

@paragon1234

paragon1234 commented Oct 9, 2021

Thank you for the response. This is an interesting point from an implementation perspective. I am wondering why the transformer should operate only on [56:60,:] and [:,56:60], but not on [0:4,:] and [:,0:4]. Two alternatives:

  1. What if we simply do not apply self-attention to the pixels within [56:64,:] and [:,56:64] for shifted windows? OR
  2. What if the transformer operates on both, without mixing them, i.e., an even smaller window size for this corner case?

In the current implementation, the top 4 rows and last 4 columns are operated on only 50% of the time.

@JingyunLiang
Owner

I didn't test these two cases, but I guess the first may lead to slightly worse performance (this part of the data is discarded), while the second may lead to slightly better performance (making full use of this part of the data). The current implementation is just for simplicity and efficiency.

@JingyunLiang
Owner

Feel free to reopen it if you have more questions.

@paragon1234

Can you please explain why we need a mask for testing when the input resolution is not 48x48?
I tried to change the attention mechanism in the code, and my method did not require a mask. I removed calculate_mask from the code and completed the training phase. However, testing gave an error.
My concerns are: 1) Do I need a mask in my attention? 2) How do I change my attention mechanism to incorporate a mask?

@JingyunLiang
Owner

SwinIR needs the mask whether or not the input resolution is 48x48. The difference is that we use a pre-calculated mask for 48x48 images, because it is stored in the .pth file during training. For non-48x48 images, we need to recompute it.

As for the second concern, what is your error in testing? If you don't need a mask in your own attention, there is no need for masking in testing either. Sorry that I cannot give more help, because I don't know what your changed attention mechanism is.
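
To make the size dependence concrete, here is a small sketch of why a mask built for one resolution cannot be reused for another. The names mask_shape and pick_mask are hypothetical, shape tuples stand in for real mask tensors, and a window size of 8 with input sizes that are multiples of 8 is assumed.

```python
def mask_shape(x_size, window_size=8):
    """Shape (num_windows, N, N) of the attention mask for an H x W input."""
    height, width = x_size
    num_windows = (height // window_size) * (width // window_size)
    tokens = window_size * window_size
    return (num_windows, tokens, tokens)


def pick_mask(x_size, train_size, cached_mask, build_mask):
    """Reuse the cached mask only when the input matches the training size."""
    if x_size == train_size:
        return cached_mask
    return build_mask(x_size)


train_size = (48, 48)
cached = mask_shape(train_size)  # stand-in for the mask stored with the weights
print(pick_mask((48, 48), train_size, cached, mask_shape))  # cached mask is reused
print(pick_mask((64, 64), train_size, cached, mask_shape))  # mask is rebuilt for the new size
```

The first dimension is the number of windows, which grows with the image, so a 48x48 mask has the wrong shape for a larger test image.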

@paragon1234

paragon1234 commented Nov 20, 2021

I am using efficient attention (https://github.com/cmsflash/efficient-attention).
I do not understand why we require the mask, even for 48x48 patches, as per your last reply.
Also, training with efficient attention takes longer than with SwinIR's transformer attention. Maybe a window_size of 8x8 is already very efficient for low-level image processing.

@JingyunLiang
Owner

JingyunLiang commented Nov 21, 2021

From my experience, position encoding does not seem very important for SR. You can try removing it and compare the results. Note that there are two problems you need to address for efficient attention (https://github.com/cmsflash/efficient-attention): 1) Applying softmax to q and k separately may significantly reduce the representation power of the attention matrix, as the rank of the matrix is smaller. 2) It may be tricky to apply masks to it (see cmsflash/efficient-attention#4).
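
The rank argument in point 1 can be written out as follows. This is a sketch under assumptions: $n$ is the number of tokens per window, $d_k$ is the per-head key dimension, and $\rho_q$, $\rho_k$ denote the normalizations applied to $Q$ and $K$ separately. These symbols are introduced here for illustration and do not appear in the thread.

```latex
\begin{aligned}
\text{standard attention:}\quad
  & A = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) \in \mathbb{R}^{n \times n},
  && \operatorname{rank}(A) \le n, \\
\text{factorized attention:}\quad
  & \tilde{A} = \rho_q(Q)\,\rho_k(K)^\top \in \mathbb{R}^{n \times n},
  && \operatorname{rank}(\tilde{A}) \le \min(n, d_k).
\end{aligned}
```

With 8x8 windows, $n = 64$, so whenever $d_k < 64$ the factorized matrix cannot reach full rank, while the row-wise softmax of $QK^\top$ is not bound by $d_k$ in the same way.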
