In [None]:
class MADBatchNorm1d(nn.Module):
    """Batch normalisation that scales by the mean absolute deviation (MAD).

    For a 2D input of shape ``(batch, features)`` each feature column is
    centred by its batch mean and divided by ``MAD + eps``, where
    ``MAD = mean(|x - mean|)`` over the batch dimension.

    Args:
        num_features: Number of feature columns in the input.
        eps: Constant added to the MAD before dividing.
        momentum: Weight given to the current batch statistics when updating
            the running estimates. ``None`` selects a cumulative (simple)
            average over all batches seen so far.
        affine: If ``True``, the layer owns a learnable per-feature ``weight``
            (initialised to one) and ``bias`` (initialised to zero) that are
            applied after normalisation.
        track_running_stats: If ``True``, running estimates of the mean and
            MAD are kept as buffers and used in eval mode. If ``False``,
            batch statistics are always used.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False, track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            # Registering None keeps the attributes defined without adding
            # anything to parameters() or the state_dict.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_mad', torch.ones(num_features))
            self.running_mean: torch.Tensor
            self.running_mad: torch.Tensor
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def extra_repr(self):
        """Return the constructor arguments for the module's printed form."""
        return (
            f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, "
            f"affine={self.affine}, track_running_stats={self.track_running_stats}"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` feature-wise by mean and MAD.

        Args:
            x: Input tensor of shape ``(batch, features)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 2-dimensional.
        """
        if x.dim() != 2:
            raise ValueError("Input must be 2D (batch, features)")

        if self.training or not self.track_running_stats:
            mean = x.mean(dim=0)
            mad = (x - mean).abs().mean(dim=0)

            if self.track_running_stats:
                # Running statistics are bookkeeping only; keep them out of
                # the autograd graph.
                with torch.no_grad():
                    self.num_batches_tracked.add_(1)
                    if self.momentum is None:
                        # Cumulative average: weight the new batch by 1/n.
                        factor = 1.0 / float(self.num_batches_tracked)
                    else:
                        factor = self.momentum
                    self.running_mean.mul_(1 - factor).add_(factor * mean)
                    self.running_mad.mul_(1 - factor).add_(factor * mad)
            normed = (x - mean) / (mad + self.eps)
        else:
            normed = (x - self.running_mean) / (self.running_mad + self.eps)

        if self.affine:
            normed = normed * self.weight + self.bias

        return normed