
Why does using the BatchNorm2d() layer break model loading? #538

Closed
Metritutus opened this issue Feb 27, 2022 · 6 comments · Fixed by #540

Metritutus commented Feb 27, 2022

Model generation appears to work without error. However, when I later try to load the model, an ArgumentException is thrown.

I've tried to find examples online, but am unable to see what I am doing wrong. The value being passed to BatchNorm2d matches the number of output channels in the preceding Conv2d layer.

I've removed most of my other layers for the purposes of simplifying the example:

 Layers = Sequential(Conv2d(1, 32, 3),
                     BatchNorm2d(32),
                     ReLU(),
                     Flatten(),
                     LogSoftmax(1));

When calling model.load() for a model generated with the above layers, the following exception is thrown:
System.ArgumentException: 'Mismatched state_dict sizes: expected 7, but found 4 entries.'

Stacktrace:

   at TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict)
   at TorchSharp.torch.nn.Module.load(String location, Boolean strict)
   at <REDACTED>
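
For reference, the failing sequence is just a plain save followed by a load; a minimal sketch (the file name and the model variable are placeholders for my actual code):

    // Save completes without error, but loading the same file throws the
    // ArgumentException above whenever BatchNorm2d is present.
    model.save("model.dat");
    model.load("model.dat");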

If I comment out the BatchNorm2d() line and generate a new model, that one loads without issue.

What am I doing wrong?


NiklasGustafsson commented Feb 28, 2022

@Metritutus -- it's hard to know exactly what is going wrong. I'll take a look at this later today.

Just to clarify -- you are both saving and loading the model using TorchSharp, right?


Metritutus commented Feb 28, 2022

> @Metritutus -- it's hard to know exactly what is going wrong. I'll take a look at this later today.
>
> Just to clarify -- you are both saving and loading the model using TorchSharp, right?

Indeed. I did find an issue which appeared similar (#510), but it related to loading models exported from Python, whereas I am working solely with TorchSharp.

@NiklasGustafsson

I'll take a look at it. BatchNorm has been troublesome...

@NiklasGustafsson

Here's the repro code I put together:

    // usings needed to compile the repro:
    using System;
    using System.IO;
    using Xunit;
    using TorchSharp;
    using static TorchSharp.torch.nn;

    [Fact]
    public void Validate538()
    {
        var module = new Module538(1000, 100);

        var sd = module.state_dict();

        // Conv2d and BatchNorm2d together contribute 7 state_dict entries.
        Assert.Equal(7, sd.Count);

        if (File.Exists("bug538.dat")) File.Delete("bug538.dat");

        // Round-trip: save the freshly created module, then load it back.
        module.save("bug538.dat");
        module.load("bug538.dat");

        File.Delete("bug538.dat");
    }

    internal class Module538 : Module
    {
        private Module seq;

        public Module538(int in_channels, int out_channels) : base(String.Empty)
        {
            seq = Sequential(Conv2d(1, 32, 3),
                             BatchNorm2d(32),
                             ReLU(),
                             Flatten(),
                             LogSoftmax(1));

            // Registers seq (and its parameters/buffers) with this module.
            RegisterComponents();
        }

        public override torch.Tensor forward(torch.Tensor t)
        {
            return this.seq.forward(t);
        }
    }

This test runs fine for me.

That said, this is with post-release code. I don't think we've fixed any bugs related to BatchNorm since the last release, though.
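
For what it's worth, the 7 entries the test expects break down as: Conv2d contributes weight and bias, and BatchNorm2d contributes weight, bias, running_mean, running_var, and num_batches_tracked. A quick sketch to see this yourself (module is the Module538 instance from the test above):

    // Print the name and shape of every parameter/buffer in the state_dict.
    foreach (var entry in module.state_dict())
    {
        Console.WriteLine($"{entry.Key}: [{string.Join(", ", entry.Value.shape)}]");
    }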


Metritutus commented Feb 28, 2022

Looks like the fault was mine. I thought I was on the latest version; however, it seems I was only on 0.96.0. Upgrading to 0.96.3 appears to have fixed the issue!

For your information, after updating I did need to generate a new model; models saved with the older version still produced the error when calling model.load() on the newer version. From this, I'm assuming something was previously going wrong when saving the model (presumably the three BatchNorm2d buffer entries were missing, which would account for the 'expected 7, but found 4' mismatch). It may be that this was fixed by the fix for #510, which landed in 0.96.1 and which I should have noticed earlier, sorry!

@NiklasGustafsson

@Metritutus -- yes, that is right. As I mentioned, BatchNorm{123} surfaced several issues, and it took a couple of releases to get to the bottom of them all (which I hope we have done at this point).

NiklasGustafsson added a commit to NiklasGustafsson/TorchSharp that referenced this issue Feb 28, 2022
@NiklasGustafsson NiklasGustafsson linked a pull request Feb 28, 2022 that will close this issue
NiklasGustafsson added a commit that referenced this issue Feb 28, 2022