Description
I have a model that uses BatchNorm1d after a Linear layer, and it throws the exception below when run with a batch size of 1 in eval mode. The same operation works fine in Python (PyTorch) with a batch size of 1 during eval. I believe the correct behavior is to allow batch sizes of 1 in eval mode.
Unhandled exception: System.Runtime.InteropServices.ExternalException (0x80004005): Expected more than 1 value per channel when training, got input size [1, 1024]
Exception raised from batch_norm at ..\..\torch\csrc\api\include\torch/nn/functional/batchnorm.h:27 (most recent call first):
00007FFB94A5A292 00007FFB94A5A230 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFB94A59D1E 00007FFB94A59CD0 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00000296EC14AE2D 00000296EC14AD40 torch_cpu.dll!torch::nn::RNNOptions::batch_first [<unknown file> @ <unknown line number>]
00000296EC19EE2A 00000296EC19EC60 torch_cpu.dll!torch::nn::BatchNormImplBase<2,torch::nn::BatchNorm2dImpl>::forward [<unknown file> @ <unknown line number>]
00007FFB903790C4 00007FFB90379030 LibTorchSharp.DLL!THSNN_BatchNorm1d_forward [<unknown file> @ <unknown line number>]
00007FFAFBADD15D <unknown symbol address> !<unknown symbol> [<unknown file> @ <unknown line number>]
at TorchSharp.torch.CheckForErrors()
at TorchSharp.Modules.BatchNorm1d.forward(Tensor tensor)
at TorchSharp.Modules.Sequential.forward(Tensor tensor)
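For reference, a minimal sketch that should reproduce the failure. The layer sizes here are hypothetical, and the Sequential/Linear/BatchNorm1d factory calls assume TorchSharp's torch.nn API; any Linear followed by BatchNorm1d should show the same problem:

using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

// Hypothetical sizes; any Linear -> BatchNorm1d pair reproduces the issue.
var net = Sequential(
    ("fc", Linear(2048, 1024)),
    ("bn", BatchNorm1d(1024)));

net.Eval();                       // eval mode should switch BatchNorm to running statistics

var input = randn(1, 2048);       // batch size of 1
var output = net.forward(input);  // throws the ExternalException above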
Troubleshooting
I have verified that all BatchNorm1d modules are set to eval mode prior to prediction by running the following code:
InceptionResNetV4 module = new InceptionResNetV4();
module.load(fileName);                              // load the weights exported from Python
module.Eval();                                      // set the root module to eval mode
module.apply((Module child) => { child.Eval(); }); // and set every child module to eval as well
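To double-check, each submodule's training flag can be dumped as well. This is a sketch, assuming the Module.training property and GetName() as exposed by current TorchSharp; module is the instance from the snippet above:

using System;

// Print the training flag of every submodule to confirm eval mode propagated.
module.apply((Module child) =>
{
    Console.WriteLine($"{child.GetName()}: training = {child.training}");
});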
Theories
I am training in Python, then exporting the model to a .dat file and loading it with the module.load method. I verified that the running_mean and running_var buffers of the BatchNorm1d modules are successfully loaded; I had thought one of them might be null, which could trigger the training-mode batching requirement during eval.
I now think the issue is either related to loading from a .dat file, or to the eval state not getting passed down to the native torch libraries correctly. The loaded buffers are shown below:
running_mean
[1024], type = Float32, device = cpu
running_var
[1024], type = Float32, device = cpu
num_batches_tracked
[], type = Int64, device = cpu
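For completeness, a dump like the one above can be produced with something along these lines (a sketch, assuming Module.named_buffers() as exposed by current TorchSharp):

using System;

// Enumerate the BatchNorm buffers after load to confirm that
// running_mean / running_var / num_batches_tracked came across.
foreach (var (name, buffer) in module.named_buffers())
{
    Console.WriteLine(name);
    Console.WriteLine($"[{string.Join(",", buffer.shape)}], type = {buffer.dtype}, device = {buffer.device}");
}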