
BatchNorm1d throws exception during eval with batch size of 1 #500

Closed
FusionCarcass opened this issue Jan 26, 2022 · 11 comments · Fixed by #501

@FusionCarcass

I have a model that uses BatchNorm1d after a Linear layer, and it throws the exception below when using a batch size of 1 during eval. The same operation works fine in Python with a batch size of 1 during eval. I believe the correct behavior here is to allow batch sizes of 1 in eval mode.

Unhandled exception: System.Runtime.InteropServices.ExternalException (0x80004005): Expected more than 1 value per channel when training, got input size [1, 1024]
Exception raised from batch_norm at ..\..\torch\csrc\api\include\torch/nn/functional/batchnorm.h:27 (most recent call first):
00007FFB94A5A29200007FFB94A5A230 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFB94A59D1E00007FFB94A59CD0 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00000296EC14AE2D00000296EC14AD40 torch_cpu.dll!torch::nn::RNNOptions::batch_first [<unknown file> @ <unknown line number>]
00000296EC19EE2A00000296EC19EC60 torch_cpu.dll!torch::nn::BatchNormImplBase<2,torch::nn::BatchNorm2dImpl>::forward [<unknown file> @ <unknown line number>]
00007FFB903790C400007FFB90379030 LibTorchSharp.DLL!THSNN_BatchNorm1d_forward [<unknown file> @ <unknown line number>]
00007FFAFBADD15D <unknown symbol address> !<unknown symbol> [<unknown file> @ <unknown line number>]

   at TorchSharp.torch.CheckForErrors()
   at TorchSharp.Modules.BatchNorm1d.forward(Tensor tensor)
   at TorchSharp.Modules.Sequential.forward(Tensor tensor)

Troubleshooting

I have verified that all BatchNorm1d modules are set to eval prior to prediction by running the following code:

InceptionResNetV4 module = new InceptionResNetV4();
module.load(fileName);
module.Eval();
module.apply((Module child) => { child.Eval(); });
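
To double-check, something like the following sketch prints each submodule's mode after Eval(); named_modules() and the training property are assumptions on my part and may be named differently depending on the TorchSharp version:

// Sketch: print the training flag of every submodule after calling Eval().
// named_modules() and the 'training' property are assumptions here; the names
// may differ between TorchSharp versions.
foreach (var (name, child) in module.named_modules()) {
    Console.WriteLine($"{name}: training = {child.training}");
}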

Theories

I am training in Python, then exporting the model to a .dat file and loading it with the module.load method. I verified that the running mean and variance buffers of the BatchNorm1d are successfully loaded. I thought maybe the mean or variance would be null, which might trigger a batching requirement during eval.

I think the issue is either related to loading from a .dat file, or to something not getting passed down to the native torch libraries correctly.

running_mean
[1024], type = Float32, device = cpu
running_var
[1024], type = Float32, device = cpu
num_batches_tracked
[], type = Int64, device = cpu
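
For what it's worth, the dump above came from something roughly like this sketch; named_buffers() and the tensor properties used here are assumptions and may be named differently in a given TorchSharp version:

// Sketch: dump each buffer after loading the .dat file to confirm the running
// statistics came across. named_buffers(), shape, dtype, and device are assumptions
// about the API names here.
foreach (var (name, buffer) in module.named_buffers()) {
    Console.WriteLine($"{name}: [{string.Join(", ", buffer.shape)}], type = {buffer.dtype}, device = {buffer.device}");
}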
@NiklasGustafsson commented Jan 27, 2022

@FusionCarcass -- I'll take a look at it. If you have a smaller repro case, that would be great; otherwise, I'll try to come up with one myself.

Could you check whether having a channel dimension of 1 makes any difference? BatchNorm1d is supposed to take either (N,L) or (N,C,L).
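
For reference, both layouts in TorchSharp look roughly like this (sizes are just illustrative):

// Sketch (sizes are illustrative): the two input layouts BatchNorm1d accepts.
// In both cases the constructor argument is the size of the feature/channel dimension.
using (var bn2d = BatchNorm1d(28))      // for (N, L) input, L = 28 features
using (var bn3d = BatchNorm1d(3)) {     // for (N, C, L) input, C = 3 channels
    bn2d.Eval();
    bn3d.Eval();
    var a = bn2d.forward(torch.randn(4, 28));      // (N, L)
    var b = bn3d.forward(torch.randn(4, 3, 28));   // (N, C, L)
}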

@NiklasGustafsson NiklasGustafsson self-assigned this Jan 27, 2022
@NiklasGustafsson commented Jan 27, 2022

@FusionCarcass -- here's what I'm seeing in Python:

bn1 = torch.nn.BatchNorm1d(28)
p = bn1(torch.randn(16,28))
p = bn1(torch.randn(1,28))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\Miniconda3\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Miniconda3\lib\site-packages\torch\nn\modules\batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "D:\Miniconda3\lib\site-packages\torch\nn\functional.py", line 2280, in batch_norm
    _verify_batch_size(input.size())
  File "D:\Miniconda3\lib\site-packages\torch\nn\functional.py", line 2248, in _verify_batch_size
    raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 28])

Which seems to be the same error you're seeing with TorchSharp.

It only fails if the C dimension is missing. There's no error for this in either Python or .NET:

bn1 = torch.nn.BatchNorm1d(3)
p = bn1(torch.randn(1,3,28))
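
The .NET side of the same check is essentially this (sizes mirror the Python snippet above):

// Same check on the .NET side: (N, C, L) input with N = 1 does not trigger the error.
using (var bn1 = BatchNorm1d(3)) {
    var p = bn1.forward(torch.randn(1, 3, 28));
}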

@FusionCarcass

BatchNorm1d should work with a tensor of shape (1, L) when set to eval mode, since eval uses the running statistics rather than per-batch statistics. Training would require (N, L) tensors where N > 1.

This code works for me in Python with input tensors of shape (1, C, L), which get flattened down to a (1, L) shape before the BatchNorm1d layer where the exception occurs.

class InceptionResNetV4(nn.Module):
    def __init__(self):
        super().__init__()
       
        self.stack = nn.Sequential(
            Stem(1),
            InceptionResNetA(384),
            InceptionResNetA(384),
            InceptionResNetA(384),
            InceptionResNetA(384),
            InceptionResNetA(384),
            nn.AdaptiveMaxPool1d(10),
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(3840, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 2)
        )
        
    def forward(self, x):
        output = self.stack(x)
        print(output.shape)
        return output

I export my model and load it into the following C# model:

public class InceptionResNetV4 : Module {
	private readonly Sequential stack;

	public InceptionResNetV4() : base(string.Empty) {
		this.stack = Sequential(
			new Stem(1),
			new InceptionResNetA(384),
			new InceptionResNetA(384),
			new InceptionResNetA(384),
			new InceptionResNetA(384),
			new InceptionResNetA(384),
			AdaptiveMaxPool1d(10),
			Flatten(),
			Dropout(probability: 0.5f),
			Linear(3840, 1024),
			BatchNorm1d(1024),
			ReLU(),
			Linear(1024, 2)
		);

		this.RegisterComponents();
	}

	public override torch.Tensor forward(torch.Tensor t) {
		return this.stack.forward(t);
	}
}
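
For context, the prediction path that hits the exception is roughly the following (the file name and sequence length are placeholders):

// Roughly the failing call path; "model.dat" and the sequence length 400 are placeholders.
var model = new InceptionResNetV4();
model.load("model.dat");
model.Eval();
using (var input = torch.randn(1, 1, 400)) {   // a single sample: batch size 1, 1 channel
    var output = model.forward(input);          // throws in the BatchNorm1d layer
}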

@FusionCarcass

I modified your example to make it work in Python.

import torch
import torch.nn as nn
import torch.nn.functional as F


bn1 = torch.nn.BatchNorm1d(28)
bn1.eval()
p = bn1(torch.randn(16,28))
p = bn1(torch.randn(1,28))

@NiklasGustafsson

Thanks!

@NiklasGustafsson

Seems like the bug is in Sequential. This blows up in the second block, not the first:

            using (var pool = BatchNorm1d(28)) {
                pool.Eval();
                pool.forward(torch.ones(1, 28));
            }
            using (var pool = BatchNorm1d(28))
            using (var seq = Sequential(pool)) {
                seq.Eval();
                seq.forward(torch.ones(1, 28));
            }

@NiklasGustafsson

Eval() does not seem to be propagated properly to all submodules, whether in a Sequential or a custom module.

I have a fix, and I'm going to push it together with the fix for #499, which is a big one.
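
Conceptually the change is along these lines (a sketch of the idea only, not the actual patch; named_children() is an assumption about the enumerator's name):

// Sketch of the idea only, not the actual TorchSharp patch: when Eval() is called,
// push eval mode down to every child module as well.
// named_children() is an assumption; the enumerator may differ between versions.
static void EvalRecursively(Module module)
{
    module.Eval();
    foreach (var (_, child) in module.named_children())
        EvalRecursively(child);
}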

@NiklasGustafsson NiklasGustafsson linked a pull request Jan 31, 2022 that will close this issue
@PointerGuide

I hope we will see the latest fixes on NuGet soon :)

@NiklasGustafsson

It's coming... :-)

NiklasGustafsson added a commit that referenced this issue Feb 3, 2022
Allow register_parameter to take a null tensor.
@NiklasGustafsson

@241721, @FusionCarcass -- in case you didn't see, 0.96.0 was just released on NuGet with the fix for this bug in it.

@PointerGuide

I saw that! Thank you so much :-) Great job
