Non-RGB SPAN models #25

Closed
RunDevelopment opened this issue Jan 29, 2024 · 13 comments
Labels
enhancement (New feature or request), solved (issue has been solved)

Comments

@RunDevelopment
Contributor

RunDevelopment commented Jan 29, 2024

Hi @muslll!

I just read through the SPAN code again, and wondered whether span even supports anything other than RGB images as input. If I understand PyTorch tensors correctly, then this line:

x = (x - self.mean) * self.img_range

will fail for non-RGB inputs because self.mean is defined like this:

self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

and will always have 3 channels.

So torch.Tensor(rgb_mean).view(1, 3, 1, 1) should probably be changed to torch.Tensor(rgb_mean).view(1, in_channels, 1, 1). Alternatively, we could use the same approach as SwinIR:

https://github.com/muslll/neosr/blob/master/neosr/archs/swinir_arch.py#L775-L779

        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)

Correction: The changes suggested above are not backwards compatible, because single-channel images are currently broadcast against the 3-channel mean. So we need to keep the current behavior for in_chans in (1, 3).

What do you think?
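
To make the correction concrete, here is a minimal sketch of a mean selection that keeps the current behavior for 1- and 3-channel inputs and falls back to a zero mean otherwise. The helper name make_mean is hypothetical and not part of neosr; only the mean values and tensor shapes are taken from the snippets above.

```python
import torch

RGB_MEAN = (0.4488, 0.4371, 0.4040)


def make_mean(in_chans: int) -> torch.Tensor:
    """Hypothetical helper: choose a mean tensor that broadcasts for any channel count."""
    if in_chans in (1, 3):
        # Current behavior: a 3-channel mean. A 1-channel input is broadcast
        # to 3 channels by the subtraction, exactly as it is today.
        return torch.tensor(RGB_MEAN).view(1, 3, 1, 1)
    # Any other channel count: a zero mean of shape (1, 1, 1, 1) broadcasts
    # against every input and leaves its channel count unchanged.
    return torch.zeros(1, 1, 1, 1)


for chans in (1, 3, 4):
    x = torch.zeros(1, chans, 8, 8)
    print(chans, tuple((x - make_mean(chans)).shape))
```

With this sketch, 1- and 3-channel inputs still come out with 3 channels, while a 4-channel input keeps its 4 channels instead of raising an error.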

@umzi2
Contributor

umzi2 commented Jan 29, 2024

It at least runs the training and gives the expected result. I can train a 1-channel model today and show the result.

@RunDevelopment
Contributor Author

You're right. It will work for 1-channel models:

>>> m = torch.zeros(1, 3, 1, 1)
>>> i = torch.zeros(1, 1, 200, 100)
>>> (i - m).shape
torch.Size([1, 3, 200, 100])

I mean, the first thing a 1-channel model does is to convert the input image to RGB, but the rest of the model seemingly has no issue with that. ¯\_(ツ)_/¯

It doesn't work for RGBA models though:

>>> i = torch.zeros(1, 4, 200, 100)
>>> (i - m).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1
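
The two REPL sessions above follow from the broadcasting rule: shapes are compared dimension by dimension, and two sizes are compatible only if they are equal or one of them is 1. The sketch below re-implements that rule in plain Python for illustration. The function broadcast_shape is a hypothetical helper, not a PyTorch API.

```python
from typing import Tuple


def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the broadcast result shape of two shapes, or raise ValueError."""
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)

    out = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"size {x} does not match size {y} at dimension {dim}")
    return tuple(out)


mean_shape = (1, 3, 1, 1)
print(broadcast_shape((1, 1, 200, 100), mean_shape))  # (1, 3, 200, 100)
print(broadcast_shape((1, 3, 200, 100), mean_shape))  # (1, 3, 200, 100)
```

Calling broadcast_shape((1, 4, 200, 100), mean_shape) raises ValueError at dimension 1, which mirrors the RuntimeError for RGBA inputs.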

@muslll
Owner

muslll commented Jan 29, 2024

Hi @RunDevelopment, this has been discussed a lot over the past few days in the EE Discord. The consensus was that the x = (x - self.mean) * self.img_range step in forward affects stability. I thought about adding a bool to disable it, which would also solve the grayscale training issue you mentioned. What do you think?

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 feature_channels=48,
                 upscale=upscale,
                 bias=True,
                 norm=False, # new bool
                 img_range=1.0,
                 rgb_mean=(0.4488, 0.4371, 0.4040)
                 ):
        super(span, self).__init__()

        in_channels = num_in_ch
        out_channels = num_out_ch
        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        self.norm = norm

        self.conv_x = [...]
        
    def forward(self, x):
        if self.norm:
            self.mean = self.mean.type_as(x)
            x = (x - self.mean) * self.img_range

        [...]

@RunDevelopment
Contributor Author

Sounds good. I don't know much about AI, so I trust you and the others on EE when you say that removing this won't cause problems.

As for spandrel parameter detection: since this parameter is a boolean, detection (in a backwards compatible way) is easy. We can use the same trick I used for Real-CUGAN pro models. Basically, we optionally register a small tensor as a buffer, and the presence of the tensor determines the parameter value.
In this case, we register the tensor only if norm=False. The tensor will then be present only on SPAN models without normalization, so everything is backwards compatible.

What do you think?
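
A minimal sketch of the marker-buffer trick described above, assuming a toy module in place of the real SPAN architecture. The class TinySpan and the buffer name "no_norm" are hypothetical and chosen only for illustration; register_buffer and state_dict are standard torch.nn.Module methods.

```python
import torch
from torch import nn


class TinySpan(nn.Module):
    """Toy stand-in for the real architecture; only the norm flag handling is shown."""

    def __init__(self, norm: bool = True):
        super().__init__()
        self.norm = norm
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        if not norm:
            # Marker buffer (hypothetical name). Buffers are persistent by
            # default, so this key is saved only for models trained without
            # normalization. Older checkpoints never contain it.
            self.register_buffer("no_norm", torch.zeros(1))


old_style = TinySpan(norm=True).state_dict()
new_style = TinySpan(norm=False).state_dict()
print("no_norm" in old_style)  # False
print("no_norm" in new_style)  # True
```

Because the key is absent from every checkpoint saved before the flag existed, treating a missing key as norm=True keeps those checkpoints loading with their original behavior.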

@muslll
Owner

muslll commented Jan 29, 2024

Looks good. Committed 👍

muslll added the "solved" label on Jan 29, 2024
@RunDevelopment
Contributor Author

Sorry for not making this clear enough @muslll, but the Real-CUGAN trick I mentioned has to be implemented by neosr as well. If neosr doesn't register the tensor to signify the value of norm, there is no way for us to detect this.
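
On the loading side, detection then reduces to a key lookup in the checkpoint's state dict. The sketch below assumes a marker key named "no_norm" and uses plain dicts with placeholder values in place of real checkpoints. The function detect_norm and all key names are hypothetical, not spandrel's actual code.

```python
from typing import Any, Mapping

MARKER_KEY = "no_norm"  # hypothetical buffer name written by the training code


def detect_norm(state_dict: Mapping[str, Any]) -> bool:
    """Infer the norm hyperparameter from the presence of the marker key."""
    # Marker absent: old checkpoint or norm=True. Marker present: norm=False.
    return MARKER_KEY not in state_dict


# Placeholder checkpoints; real ones would map key names to tensors.
old_checkpoint = {"conv.weight": None, "conv.bias": None}
new_checkpoint = {"conv.weight": None, "conv.bias": None, "no_norm": None}

print(detect_norm(old_checkpoint))  # True
print(detect_norm(new_checkpoint))  # False
```

If the training code never writes the marker, both kinds of checkpoint have identical keys and the loader has nothing to distinguish them by.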

@RunDevelopment
Contributor Author

RunDevelopment commented Jan 29, 2024

On that note: your CUGAN implementation for pro models has the same issue. Since the pro parameter isn't stored in the .pth, there is no way to detect them as pro models. They would even be incompatible with the official Real-CUGAN code.

Should I make a separate issue for this?

@muslll
Owner

muslll commented Jan 29, 2024

I see, my bad. Committed here.

@muslll
Owner

muslll commented Jan 29, 2024

> Should I make a separate issue for this?

No need, I just committed here. I also named the tensor isnorm instead, to avoid a conflict with the main function param.

edit: fix

@RunDevelopment
Contributor Author

> edit: fix

Still wrong. As I said here:

> we register the tensor only if norm=False.

Right now, you are registering the tensor for the old behavior. But we need to register it for the new behavior and for it only.

@muslll
Owner

muslll commented Jan 29, 2024

I see. Is this right now?

@RunDevelopment
Contributor Author

Yes, this works! Thank you @muslll!

Also, you don't have to make a property and so on like I did with Real-CUGAN. I only did that there because the hyperparameter and the buffer had the same name...

@muslll
Owner

muslll commented Jan 29, 2024

Great. Thanks 👍


3 participants