
Module.Load throws Mismatched state_dict sizes exception on BatchNorm1d #510

Closed · FusionCarcass opened this issue Feb 9, 2022 · 9 comments · Fixed by #511

FusionCarcass commented Feb 9, 2022

When loading a model from a .dat file exported from Python, the Module.Load method throws the exception below. I printed out all of the registered parameters, and the only ones missing were the BatchNorm1d ones: running_mean, running_var, and num_batches_tracked.

I tried to work around the problem by registering those parameters with the register_parameter function. That eliminates the exception, but then the bias is not loaded correctly: after registering the other parameters, the bias parameter is still set to torch.zeros(N).

System.ArgumentException
  HResult=0x80070057
  Message=Mismatched state_dict sizes: expected 200, but found 300 entries.
  Source=TorchSharp
  StackTrace:
   at TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict)
   at TorchSharp.torch.nn.Module.load(String location, Boolean strict)
   at OpenHips.Scanners.TorchScanner.LoadFromFile(String fileName) in C:\Users\helpdesk\Desktop\Workspace\repos\open-hips\open-hips-cortex\Scanners\TorchScanner.cs:line 15
   at OpenHips.Program.HandleScan(FileInfo target, FileInfo model) in C:\Users\helpdesk\Desktop\Workspace\repos\open-hips\open-hips-scanner\Program.cs:line 33
   at System.CommandLine.Handler.<>c__DisplayClass2_0`2.<SetHandler>b__0(InvocationContext context)
   at System.CommandLine.Invocation.AnonymousCommandHandler.<>c__DisplayClass2_0.<.ctor>g__Handle|0(InvocationContext context)
   at System.CommandLine.Invocation.AnonymousCommandHandler.<InvokeAsync>d__3.MoveNext()

The load method should probably take the registered buffers into account if running_mean and running_var are not going to be treated as parameters; presumably the extra entries in the file (300 found vs. 200 expected) are these buffers.
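
To illustrate the suggestion, here is a minimal sketch (not Module.load's actual implementation, and assuming TorchSharp's named_parameters()/named_buffers() enumerate the same entries the exporter writes):

using System.Linq;
using TorchSharp;

static class StateDictCheck
{
    // Sketch: count buffers as well as parameters when validating a saved
    // file, since running_mean, running_var, and num_batches_tracked are
    // buffers, not parameters, in PyTorch.
    public static int ExpectedEntries(torch.nn.Module module) =>
        module.named_parameters().Count() + module.named_buffers().Count();
}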

Here are the fixes I tried.

internal class BasicConv1d : Module {
    private readonly Sequential stack;

    public BasicConv1d(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0) : base(String.Empty) {
        BatchNorm1d temp = BatchNorm1d(out_channels);
        //temp.reset_running_stats();
        temp.running_mean = new Parameter(torch.zeros(out_channels, requiresGrad: false), requires_grad: false);
        temp.running_var = new Parameter(torch.zeros(out_channels, requiresGrad: false), requires_grad: false);
        temp.register_parameter("num_batches_tracked", new Parameter(temp.state_dict()["num_batches_tracked"], requires_grad: false));
        temp.register_parameter("running_mean", temp.running_mean);
        temp.register_parameter("running_var", temp.running_var);
        temp.bias = temp.get_parameter("bias"); // Without this line, temp.bias is all zeros after load. With this line, temp.bias is equal to temp.running_var.
        this.stack = Sequential(
            Conv1d(in_channels, out_channels, kernel_size, stride: stride, padding: padding, bias: false),
            temp,
            ReLU(inPlace: true)
        );

        this.RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor t) {
        return this.stack.forward(t);
    }
}
@NiklasGustafsson (Contributor) commented:

Thanks for sharing that. Is there a small repro that exhibits the problem? I thought the repro I came up with for the previous BatchNorm1d bug covered it, but apparently not.

@FusionCarcass (Author) commented Feb 10, 2022:

You fixed the issue from before where BatchNorm1d could not evaluate tensors with a single batch, so that's good! Just need to get my model loaded and I think I'll be good to go.

From my Python training code base:

import torch.nn as nn

class BasicConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
       
        self.stack = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, bias=False, **kwargs),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        return self.stack(x)

import exportsd  # exportsd.py ships with the TorchSharp repo

if __name__ == '__main__':
    # Create model (kernel_size is required by Conv1d; forwarded via **kwargs)
    model = BasicConv1d(1, 32, kernel_size=3)

    # Export model to .dat file for ingestion into TorchSharp
    with open(r"C:\model.dat", "wb") as f:
        exportsd.save_state_dict(model.to("cpu").state_dict(), f)

Then import it into TorchSharp:

// Create the model
var module = new BasicConv1d(1, 32, 3);

// Load the model from the .dat file.
module.load(@"C:\model.dat"); // Exception occurs here

@NiklasGustafsson (Contributor) commented:

Thanks for that. I'll take a look at it later today.

NiklasGustafsson added a commit to NiklasGustafsson/TorchSharp that referenced this issue Feb 10, 2022
Adding 'num_batches_tracked' to BatchNorm{123}d
NiklasGustafsson linked a pull request Feb 10, 2022 that will close this issue
@NiklasGustafsson (Contributor) commented:

@FusionCarcass -- I have a fix in PR, and will release it with another important fix (parameter groups).

@GeorgeS2019 commented:

@FusionCarcass Can you confirm that the bugs are addressed? If possible, please provide a simple example so that others can expand on the workflow of saving a model from PyTorch and loading it in TorchSharp.

@FusionCarcass (Author) commented:

@GeorgeS2019 I can take a look at it. Is this fix pushed out in a new NuGet version, or do I need to build from source? So far I've just been waiting on the NuGet releases.

@GeorgeS2019 commented Feb 19, 2022:

@FusionCarcass Thank you for testing. If possible, we need more use cases for loading PyTorch-saved state dicts with TorchSharp.

@GeorgeS2019 commented:

The PyTorch module:

class BasicConv1d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        self.stack = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, bias=False, **kwargs),
            torch.nn.BatchNorm1d(out_channels),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.stack(x)

The equivalent TorchSharp module:

internal class BasicConv1d : Module
{
    private readonly Module stack;

    public BasicConv1d(int in_channels, int out_channels, int kernel_size = 3, int stride = 1, int padding = 0) : base(String.Empty)
    {
        var temp = BatchNorm1d(out_channels);
        this.stack = Sequential(
            Conv1d(in_channels, out_channels, kernel_size, stride: stride, padding: padding, bias: false),
            temp,
            ReLU(inPlace: true)
        );

        temp.weight = Parameter(torch.randn(temp.weight.shape));
        temp.bias = Parameter(torch.randn(temp.bias.shape));
        if (temp.running_mean is not null) temp.running_mean = torch.randn(temp.running_mean.shape);
        if (temp.running_var is not null) temp.running_var = torch.randn(temp.running_var.shape);

        this.RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor t)
    {
        return this.stack.forward(t);
    }
}

Saving the state dict in PyTorch:

    # Create model
    model = BasicConv1d(1, 32)

    # Export (assumed, mirroring the earlier exportsd call; the file name
    # matches the one loaded by the TorchSharp test below)
    with open("bug510.dat", "wb") as f:
        exportsd.save_state_dict(model.to("cpu").state_dict(), f)

Loading the state dict in TorchSharp:

public void ValidateIssue510()
{
    // Create model
    var model = new BasicConv1d(1, 32);
    model.forward(torch.randn(16, 1, 32));

    var w0 = model.get_parameter("stack.0.weight").clone();
    var w1 = model.get_parameter("stack.1.weight").clone();
    var b1 = model.get_parameter("stack.1.bias").clone();
    var rm = model.get_buffer("stack.1.running_mean").clone();
    var rv = model.get_buffer("stack.1.running_var").clone();
    var nm = model.get_buffer("stack.1.num_batches_tracked").clone();

    model.load("bug510.dat");

    var w0_ = model.get_parameter("stack.0.weight");
    var w1_ = model.get_parameter("stack.1.weight");
    var b1_ = model.get_parameter("stack.1.bias");
    var rm_ = model.get_buffer("stack.1.running_mean");
    var rv_ = model.get_buffer("stack.1.running_var");
    var nm_ = model.get_buffer("stack.1.num_batches_tracked");

    Assert.NotEqual(w0, w0_);
    Assert.NotEqual(w1, w1_);
    Assert.NotEqual(b1, b1_);
    Assert.NotEqual(rm, rm_);
    Assert.NotEqual(rv, rv_);
    Assert.Equal(1, nm.item<long>());
    Assert.Equal(0, nm_.item<long>());
}

@GeorgeS2019 commented Feb 19, 2022:

@FusionCarcass It seems num_batches_tracked was previously missing and has now been added.

num_batches_tracked is used to update running_mean and running_var.
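
A minimal sketch of the role it plays, paraphrasing PyTorch's documented BatchNorm behavior (an assumption about the semantics, not TorchSharp's actual code): when no momentum is set, the running statistics are updated with a cumulative moving average whose weight is 1 / num_batches_tracked.

using TorchSharp;
using static TorchSharp.torch;

static class RunningStats
{
    // Sketch of the update rule that num_batches_tracked participates in
    // (an assumption based on PyTorch's docs, not TorchSharp's source).
    public static (Tensor mean, Tensor variance, long tracked) Update(
        Tensor runningMean, Tensor runningVar, long numBatchesTracked,
        Tensor batchMean, Tensor batchVar, double? momentum)
    {
        numBatchesTracked += 1;
        // With no momentum set, PyTorch falls back to a cumulative moving
        // average, which is where num_batches_tracked comes in.
        double factor = momentum ?? 1.0 / numBatchesTracked;
        var mean = runningMean * (1 - factor) + batchMean * factor;
        var variance = runningVar * (1 - factor) + batchVar * factor;
        return (mean, variance, numBatchesTracked);
    }
}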

// I wonder whether it is possible to register a new parameter that is not defined internally?
temp.register_parameter("num_batches_tracked", new Parameter(temp.state_dict()["num_batches_tracked"], requires_grad: false));
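
For what it's worth, PyTorch stores these as buffers rather than parameters, so register_buffer looks like the closer analogue. A sketch, assuming TorchSharp's register_buffer mirrors PyTorch's semantics:

using TorchSharp;
using static TorchSharp.torch.nn;

// Sketch: register running statistics as buffers, mirroring what PyTorch's
// BatchNorm modules do internally (an assumption, not verified against the fix).
var bn = BatchNorm1d(32);
bn.register_buffer("running_mean", torch.zeros(32));
bn.register_buffer("running_var", torch.ones(32));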
