Ideas:

* Con ghost batches aleatorios (random ghost batches)
* Cuánto tarda la running_mean en converger con: BatchNorm de PyTorch, GhostBatchNorm de PyTorch y el GBN propio (custom)

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

In [656]:
torch.manual_seed(42)  # reproducible samples
# 10,000 samples with 2 features, each drawn independently from U(-3, 5):
# the population mean is 1 and the population std is 8 / sqrt(12) ~= 2.31.
dist = torch.distributions.Uniform(-3, 5)
data = dist.sample([10000, 2])
data_mean, data_std = data.mean(dim=0), data.std(dim=0)
data_mean, data_std

(tensor([1.0004, 0.9691]), tensor([2.3150, 2.3147]))

In [657]:
class GBN(nn.Module):
    """Ghost Batch Normalization over inputs of shape (N, ..., n_in).

    In training mode the batch is cut along dim 0 into consecutive "ghost"
    batches of ``vbs`` samples. The last one is smaller when ``N`` is not a
    multiple of ``vbs``. Each ghost batch is normalized with its own
    per-feature mean and biased variance (as ``nn.BatchNorm1d`` does),
    computed over all of its samples and any intermediate positions. Then
    ``gamma * x_hat + beta`` is applied. The running statistics are updated
    as if the ghost batches had been fed one after another to a batch-norm
    layer with this momentum.

    In eval mode the running statistics are used for normalization and are
    left untouched.

    Args:
        n_in: number of features (size of the last input dimension).
        vbs: virtual ("ghost") batch size, a positive integer.
        momentum: weight of each new ghost-batch statistic in the
            exponential moving averages.
        eps: added to the variance before the square root; keeps the output
            finite for features that are constant within a ghost batch.

    Buffers:
        running_mean: EMA of the per-ghost-batch means.
        running_std: EMA of the per-ghost-batch unbiased standard deviations.
    """

    def __init__(self, n_in, vbs, momentum=0.1, eps=1e-5):
        super().__init__()
        if int(vbs) < 1:
            raise ValueError(f"vbs must be a positive integer, got {vbs!r}")
        self.vbs = int(vbs)
        self.eps = eps
        self.mm = momentum

        self.register_parameter("gamma", nn.Parameter(torch.ones(n_in)))
        self.register_parameter("beta", nn.Parameter(torch.zeros(n_in)))

        self.register_buffer("running_mean", torch.zeros(n_in))
        self.register_buffer("running_std", torch.ones(n_in))

    def forward(self, X):
        if not self.training:
            # Inference: normalize with the accumulated statistics only.
            scale = torch.sqrt(self.running_std ** 2 + self.eps)
            return self.gamma * ((X - self.running_mean) / scale) + self.beta

        if X.size(0) == 0:
            return self.gamma * X + self.beta

        n_rows, n_feat = X.size(0), X.size(-1)
        n_full = n_rows // self.vbs
        split_at = n_full * self.vbs

        pieces, ghost_means, ghost_stds = [], [], []
        if n_full > 0:
            # All complete ghost batches at once: (n_full, vbs * ..., n_feat).
            head = X[:split_at]
            normalized, mean, std = self._normalize_ghost_batches(
                head.reshape(n_full, -1, n_feat))
            pieces.append(normalized.reshape(head.shape))
            ghost_means.append(mean)
            ghost_stds.append(std)
        if split_at < n_rows:
            # Trailing ghost batch with fewer than vbs samples.
            tail = X[split_at:]
            normalized, mean, std = self._normalize_ghost_batches(
                tail.reshape(1, -1, n_feat))
            pieces.append(normalized.reshape(tail.shape))
            ghost_means.append(mean)
            ghost_stds.append(std)
        normalized_batch = torch.cat(pieces, dim=0)

        # Running statistics must not carry autograd history, and are updated
        # in place so the registered buffer objects stay the same.
        with torch.no_grad():
            # Groups with a single row have no unbiased std (std is None);
            # they are left out of the running statistics.
            stats = [(m, s) for m, s in zip(ghost_means, ghost_stds) if s is not None]
            if stats:
                mean_seq = torch.cat([m for m, _ in stats], dim=0)
                std_seq = torch.cat([s for _, s in stats], dim=0)
                k = mean_seq.size(0)
                self.running_mean.copy_(
                    self._calculate_running_metric(self.running_mean, mean_seq, k))
                self.running_std.copy_(
                    self._calculate_running_metric(self.running_std, std_seq, k))

        return self.gamma * normalized_batch + self.beta

    def _normalize_ghost_batches(self, ghost_batches):
        """Normalize each group of a (G, B, C) tensor with its own statistics.

        Returns ``(normalized, mean, std)``. ``normalized`` has the input
        shape, ``mean`` is (G, C), and ``std`` is the (G, C) unbiased
        standard deviation, or ``None`` when ``B < 2`` (undefined).
        """
        mean = ghost_batches.mean(dim=1, keepdim=True)
        # Biased variance for the normalization itself, as nn.BatchNorm1d does.
        var = ghost_batches.var(dim=1, unbiased=False, keepdim=True)
        normalized = (ghost_batches - mean) / torch.sqrt(var + self.eps)
        std = ghost_batches.std(dim=1) if ghost_batches.size(1) > 1 else None
        return normalized, mean.squeeze(1), std

    def _calculate_running_metric(self, running_metric, ghost_metric, num_ghost_batches):
        """Apply ``num_ghost_batches`` sequential EMA updates in closed form.

        Equivalent to repeating ``r = (1 - mm) * r + mm * stat_i`` for each
        ghost-batch statistic in order. The i-th of G statistics therefore
        gets weight ``mm * (1 - mm) ** (G - 1 - i)``, and the previous value
        gets ``(1 - mm) ** G``.

        Args:
            running_metric: current running value, shape (C,).
            ghost_metric: per-ghost-batch statistics, (G, C) or (G, 1, C).
            num_ghost_batches: G.
        """
        num_ghost_batches = int(num_ghost_batches)
        # reshape (not squeeze) so that C == 1 is not collapsed.
        ghost_metric = ghost_metric.reshape(num_ghost_batches, -1)
        decay = 1 - self.mm

        weighted_prev = decay ** num_ghost_batches * running_metric

        exp_idxs = torch.arange(num_ghost_batches - 1, -1, -1,
                                device=ghost_metric.device, dtype=ghost_metric.dtype)
        weights = self.mm * decay ** exp_idxs
        weighted_new = (weights[:, None] * ghost_metric).sum(dim=0)

        return weighted_prev + weighted_new
    

In [646]:
# Hyper-parameters for the custom GBN layer.
n_in = 2          # number of features per sample in `data`
vbs = 10          # virtual (ghost) batch size
eps = 1e-5
momentum = 0.1
mm = momentum

g = GBN(n_in, vbs, momentum=momentum, eps=eps)
# Sanity check: per-feature mean of the normalized output should be ~0.
g(data).mean(dim=0)

tensor([-3.9861e-10,  1.9193e-09], grad_fn=<MeanBackward1>)

In [659]:
%%timeit
for i in range(100):
    g(data)  # in training mode each forward call also updates g's running statistics

205 ms ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [673]:
# Square of GBN's running std, to set beside BatchNorm1d.running_var.
# NOTE: the square of an average of stds is not an average of variances,
# so the two are only approximately comparable.
g.running_std ** 2

tensor([5.5337, 5.0845])

In [661]:
# Reference nn.BatchNorm1d with the same momentum/eps as the custom GBN.
b = nn.BatchNorm1d(num_features=2, eps=1e-5, momentum=0.1)

In [662]:
# Split `data` into consecutive ghost batches of `vbs` rows.
# torch.stack needs equally sized chunks, so when the row count is not a
# multiple of vbs (last chunk smaller) keep the tuple of chunks instead;
# both support len() and integer indexing.
chunks = data.split(vbs)
num_ghost_batches = len(chunks)
ghost_batches = torch.stack(chunks, dim=0) if data.size(0) % vbs == 0 else chunks

In [663]:
%%timeit
for _ in range(100):
    # Feed the ghost batches to `b` one at a time (training mode: each call
    # normalizes with that ghost batch's statistics and updates b's running stats).
    for batch in ghost_batches:
        b(batch)

6.73 s ± 618 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [670]:
# Running variance that `b` accumulated while being fed the ghost batches in the timing loop.
b.running_var

tensor([5.6776, 5.1574])

In [665]:
# Second reference layer, fed the full `data` batch in a single call.
b2 = nn.BatchNorm1d(num_features=2, eps=1e-5, momentum=0.1)

In [666]:
%%timeit
for i in range(100):
    b2(data)  # whole batch at once; in training mode this also updates b2's running stats

22 ms ± 967 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [667]:
# Running mean accumulated by `b2` from full-batch updates; compare with data_mean.
b2.running_mean

tensor([1.0004, 0.9691])