Describe the bug
Our datasets currently accept a `stage` argument, which is `"training"` by default. Using `stage="validation"` together with a stateful `standardize` transform results in an error in

```python
if self.adapter is not None:
    batch = self.adapter(batch, stage=self.stage)
```

since the running means and standard deviations have never been computed.
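For illustration, here is a minimal, self-contained sketch of the failure mode. `Standardize` and its attributes are hypothetical stand-ins for the actual stateful transform, which only updates its running statistics during the training stage:

```python
import numpy as np

class Standardize:
    """Hypothetical stateful transform: updates its running moments only
    during the "training" stage and reuses them at all other stages."""

    def __init__(self):
        self.mean = None
        self.std = None

    def __call__(self, x, stage):
        if stage == "training":
            # simplified running update: recompute moments from the batch
            self.mean = np.mean(x, axis=0)
            self.std = np.std(x, axis=0)
        if self.mean is None:
            # stage="validation" before any training batch was processed
            raise RuntimeError("Running statistics have not been computed yet.")
        return (x - self.mean) / self.std

transform = Standardize()
batch = np.random.normal(size=(32, 4))
transform(batch, stage="validation")  # raises: no statistics available yet
```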
Expected behavior
I have come to the realization that BatchNorm layers should be part of the approximators and applied to all `inference_conditions`, `summary_variables`, and `inference_variables`. This has the advantage that adapters remain stateless and users do not have to deal with standardizing things explicitly. Still, we should keep the `standardize` transform with static means and stds for special cases.
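To make the proposal concrete, below is a rough sketch of what such a layer inside the approximator could look like, assuming a Keras 3 style layer as used throughout BayesFlow. The class name `Standardization`, the momentum-based update, and the 2D input assumption are mine, not existing API:

```python
import keras

class Standardization(keras.layers.Layer):
    """Hypothetical layer: tracks running moments during training and
    standardizes its input at every stage, in the spirit of BatchNorm."""

    def __init__(self, momentum=0.99, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.momentum = momentum
        self.epsilon = epsilon

    def build(self, input_shape):
        dim = input_shape[-1]
        self.moving_mean = self.add_weight(
            shape=(dim,), initializer="zeros", trainable=False, name="moving_mean"
        )
        self.moving_var = self.add_weight(
            shape=(dim,), initializer="ones", trainable=False, name="moving_var"
        )

    def call(self, x, training=False):
        if training:
            # update the exponential moving averages from the current batch
            mean = keras.ops.mean(x, axis=0)
            var = keras.ops.var(x, axis=0)
            m = self.momentum
            self.moving_mean.assign(m * self.moving_mean + (1.0 - m) * mean)
            self.moving_var.assign(m * self.moving_var + (1.0 - m) * var)
        # always standardize with the running moments, so validation and
        # inference work even though they never update the statistics
        return (x - self.moving_mean) / keras.ops.sqrt(self.moving_var + self.epsilon)
```

The approximator would apply one such layer to each of `inference_conditions`, `summary_variables`, and `inference_variables` before passing them on, so no adapter state is needed.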
Let me know what you think, and I will provide an implementation.