Why does using the BatchNorm2d() layer break model loading? #538
Comments
@Metritutus -- it's hard to know exactly what is going wrong. I'll take a look at this later today. Just to clarify -- you are both saving and loading the model using TorchSharp, right?
Indeed. I did find an issue which appeared similar (#510), but it related to loading models exported from Python, whereas I am working solely with TorchSharp.
I'll take a look at it. BatchNorm has been troublesome...
Here's the repro code I put together:

```csharp
[Fact]
public void Validate538()
{
    var module = new Module538(1000, 100);
    var sd = module.state_dict();
    Assert.Equal(7, sd.Count);

    if (File.Exists("bug538.dat")) File.Delete("bug538.dat");
    module.save("bug538.dat");
    module.load("bug538.dat");
    File.Delete("bug538.dat");
}

internal class Module538 : Module
{
    private Module seq;

    public Module538(int in_channels, int out_channels) : base(String.Empty)
    {
        seq = Sequential(Conv2d(1, 32, 3),
                         BatchNorm2d(32),
                         ReLU(),
                         Flatten(),
                         LogSoftmax(1));
        RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor t)
    {
        return this.seq.forward(t);
    }
}
```

This test runs fine for me. That said, this is with post-release code. I don't think we've fixed any bugs related to BatchNorm since the last release, though.
Looks like the fault was mine. I thought I was on the latest version, but it seems I was only on 0.96.0. Upgrading to 0.96.3 appears to have fixed the issue! For your information, after updating I did need to generate a new model; models saved on the older version still produced the error when calling `load()`.
@Metritutus -- yes, that is right. As I mentioned, BatchNorm{1,2,3}d surfaced several issues, and it took a couple of releases to get to the bottom of them all (which I hope we have done at this point).
Added regression test for issue #538
Model generation appears to work without error. When I get to the point of trying to load the model afterwards, however, an `ArgumentException` is thrown. I've tried to find examples online, but am unable to see what I am doing wrong. The value being passed to `BatchNorm2d` matches the number of output channels in the preceding `Conv2d` layer. I've removed most of my other layers for the purposes of simplifying the example:

When calling `model.load()` for a model generated with the above layers, the following exception is thrown:

```
System.ArgumentException: 'Mismatched state_dict sizes: expected 7, but found 4 entries.'
```

Stacktrace:

If I comment out the `BatchNorm2d()` line and generate a new model, that one loads without issue. What am I doing wrong?