### Batch-wise statistics

In [None]:
import numpy as np
import torch

# def std_(data, mean):
#     return torch.sqrt( torch.mean( (data - mean)**2, dim=0) )

class StatsRecorder:
    """Running per-dimension mean and population standard deviation.

    Statistics are folded in batch by batch, so the complete data set never
    has to be held in memory at once.

    Attributes
    ----------
    nobservations : int
        Number of rows folded in so far.
    ndimensions : int or None
        Number of columns of the observations (None until data is seen).
    mean, std : torch.Tensor or None
        Per-column mean and (biased, i.e. population) standard deviation of
        everything seen so far (None until data is seen).
    """

    def __init__(self, data=None):
        """
        data: torch.Tensor, shape (nobservations, ndimensions), optional
            Initial batch; when omitted the recorder starts empty.
        """
        # Define every attribute up front so that reading them on an empty
        # recorder gives None instead of raising AttributeError.
        self.nobservations = 0
        self.ndimensions = None
        self.mean = None
        self.std = None
        if data is not None:
            self.update(data)

    def update(self, data):
        """
        Fold a new batch into the running statistics.

        data: torch.Tensor, shape (nobservations, ndimensions)
            A 1-D tensor is treated as a single observation.

        Raises ValueError if the number of columns differs from the batches
        seen before. A batch with zero rows is ignored.
        """
        data = torch.atleast_2d(data)
        n = data.shape[0]

        if self.nobservations > 0 and data.shape[1] != self.ndimensions:
            raise ValueError("Data dims don't match prev observations.")

        if n == 0:
            # Mean/std of an empty batch are NaN and would poison the
            # accumulated statistics, so there is nothing to do.
            return

        newmean = torch.mean(data, dim=0)
        newstd = torch.std(data, dim=0, unbiased=False)

        if self.nobservations == 0:
            # First real batch: its statistics are the running statistics.
            self.mean = newmean
            self.std = newstd
            self.nobservations = n
            self.ndimensions = data.shape[1]
            return

        m = float(self.nobservations)
        total = m + n
        w_old = m / total
        w_new = n / total

        # Pooled variance of two groups: weighted within-group variances
        # plus the between-group term from the difference of the means.
        # It must be evaluated with the *old* mean, hence before the mean
        # itself is updated.
        delta = self.mean - newmean
        variance = w_old * self.std**2 + w_new * newstd**2 + w_old * w_new * delta**2

        self.mean = w_old * self.mean + w_new * newmean
        # torch.sqrt (not np.sqrt) keeps the result a tensor on the same
        # device and works for tensors that require grad.
        self.std = torch.sqrt(variance)

        self.nobservations += n

In [None]:
# import numpy as np
# import statsrecorder as sr
# rs = torch.random.RandomState(323)

mystats = StatsRecorder()

# Every batch is also appended to "data", so the running statistics can be
# compared with statistics computed directly over all rows seen so far.
ndims = 42
data = torch.empty((0, ndims))

for i in range(100):
    # Batch with a random number of rows drawn from [10, 103).
    nobserv = torch.randint(10, 103, (1,))
    newdata = torch.randn(int(nobserv), ndims)
    data = torch.vstack((data, newdata))

    # Fold the batch into the recorder.
    mystats.update(newdata)

    # Reference values computed from scratch over the accumulated data.
    expected_mean = data.mean(dim=0)
    expected_std = data.std(dim=0, unbiased=False)

    # The incremental result must agree with the from-scratch result.
    assert torch.allclose(mystats.mean, expected_mean)
    assert torch.allclose(mystats.std, expected_std)

### Test: Scaler statistics in batch

In [None]:
import sys
sys.path.append('../')
import torch 
from torchip.descriptors import DescriptorScaler

In [None]:
scaler = DescriptorScaler(scale_type='scale', scale_min=-1)

# Fit the scaler one batch at a time while keeping every batch, so the
# scaler's value can be compared with one computed over all data at once.
x = []
for _ in range(10):
    batch = torch.normal(mean=1.0, std=2.0, size=(5, 4))
    x.append(batch)
    scaler.fit(batch)

# Printed side by side for a visual comparison: the scaler's sigma
# (presumably its batch-wise standard deviation -- confirm against
# DescriptorScaler) and the population std over the concatenated batches.
print(scaler.sigma)
print(torch.std(torch.concat(x), dim=0, unbiased=False))