Currently creates a lot of NaNs #9

Closed · p-sodmann opened this issue Feb 28, 2022 · 1 comment
@p-sodmann (Contributor) commented Feb 28, 2022

I am a physician and don't know what I am doing, but this helps to prevent the NaNs; maybe it helps you to create a more sophisticated solution.

```python
import torch
import torch.nn.functional as F

def forward(self, inp):
    # Squared inputs, used to build the per-window input-norm term.
    out = inp.square()
    if torch.any(torch.isnan(out)):
        raise Exception("out")

    if self.groups == 1:
        out = out.sum(1, keepdim=True)

    # Convolving with an all-ones kernel sums the squares over each window.
    norm = F.conv2d(
        out,
        torch.ones_like(self.weight[:1, :1]),
        None,
        self.stride,
        self.padding,
        self.dilation) + 1e-6
    if torch.any(torch.isnan(norm)):
        raise Exception("norm")

    # prevent 0 and inf
    q = torch.exp(-self.q / (self.q_scale**2 + 0.1))
    if torch.any(torch.isnan(q)):
        raise Exception("q")

    # Normalize each kernel by its L2 norm; q and the epsilon keep the
    # denominator strictly positive.
    weight = self.weight / (
        self.weight.square().sum(dim=(1, 2, 3), keepdim=True).sqrt()
        + q + 1e-6)
    if torch.any(torch.isnan(weight)):
        raise Exception("weight")

    # Convolve with the normalized kernel and divide by the input-norm
    # term. Since norm is non-negative, (norm**2).sqrt() equals norm.
    out = F.conv2d(
        inp,
        weight,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups
    ) / ((norm**2).sqrt() + 1e-1)

    if torch.any(torch.isnan(out)):
        raise Exception("out2")

    # Comment these lines out for vanilla cosine similarity.
    # It's ~200x faster.
    # Split the activation into magnitude and sign so the magnitude can
    # be raised to the learned power p without NaNs at zero.
    abs = (out.square() + 1e-6).sqrt()
    sign = out / abs
    # prevent 0 and inf
    p = torch.exp(self.p / (self.p_scale**2 + 0.1))
    out = abs ** p
    out = out * sign
    if torch.any(torch.isnan(out)):
        raise Exception("out3")
    return out
```
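
For context on why the epsilons matter: cosine similarity divides by the norm of the input window, so an all-zero patch produces 0/0. A minimal sketch of that failure mode and the guard used above, assuming nothing beyond stock PyTorch:

```python
import torch

x = torch.zeros(5)   # an all-zero input window
w = torch.randn(5)   # an arbitrary kernel

# Unguarded cosine similarity: 0 / 0 -> NaN.
print((w @ x) / (x.square().sum().sqrt() * w.square().sum().sqrt()))

# With a small epsilon in the denominator, as in the snippet above,
# the zero window yields a clean 0 instead.
eps = 1e-6
print((w @ x) / (x.square().sum().sqrt() * w.square().sum().sqrt() + eps))
```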
@brohrer (Owner) commented Mar 10, 2022

Thanks for taking a look at this, Philipp. I appreciate it, since I know it's out of your comfort zone. The presence of NaNs is a problem. Checking for them as you're doing here is a good way to make the code more robust, but it can also slow things down a little, so I'm hoping to identify the root cause and nip the NaNs in the bud rather than checking on every pass.

I haven't run into any NaNs yet myself, but suspect that's because we're using different data sets. Do you remember which data set you were working with when this popped up?
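
One way to hunt for the root cause without leaving permanent isnan checks in the forward pass is PyTorch's autograd anomaly mode, which raises at the first backward operation that returns a NaN and points back to the forward op responsible. A minimal sketch, not code from this repo; `relu` followed by `sqrt` stands in for whichever op actually goes bad:

```python
import torch

with torch.autograd.detect_anomaly():
    x = torch.tensor([-1.0, 2.0], requires_grad=True)
    # relu sends -1.0 to 0, and the gradient of sqrt at 0 is infinite,
    # so backward computes inf * 0 = NaN for that element.
    y = x.relu().sqrt().sum()
    y.backward()  # raises, naming the backward op that produced the NaN
```

Anomaly mode adds noticeable overhead, so it's a debugging tool rather than something to keep enabled in training runs.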

brohrer closed this as completed Apr 13, 2022