### Batch-wise statistics

In [1]:
import numpy as np
import torch

def std_(data, mean):
    """Population standard deviation of `data` along dim 0, about `mean`.

    Squared deviations are averaged with a divide-by-N mean (not N - 1).
    `mean` is broadcast against `data` and used as given, so it does not
    have to be the sample mean of `data` itself.
    """
    deviations = data - mean
    return deviations.pow(2).mean(dim=0).sqrt()

class StatsRecorder:
    """Running per-column mean and standard deviation over batches of data.

    Statistics are population moments (divide by N) taken along dim 0. Each
    new batch is merged into the running mean/variance with the pairwise
    (Chan et al.) combination formula, so earlier batches never need to be
    kept in memory.

    Attributes (mean/var/std/ndimensions exist once a non-empty batch has
    been merged):
        mean: per-column running mean.
        var: per-column running population variance.
        std: per-column running population standard deviation, sqrt(var).
        nobservations: number of rows merged so far.
        ndimensions: number of columns, fixed by the first non-empty batch.
    """

    def __init__(self, data=None):
        """
        data: optional torch.Tensor of shape (nobservations, ndimensions);
            a 1-D tensor is treated as a single observation.
        """
        self.nobservations = 0
        if data is not None:
            self.update(data)

    @staticmethod
    def _moments(data):
        """Return the per-column mean and population variance of 2-D `data`."""
        mean = torch.mean(data, dim=0)
        # Variance is computed directly rather than squaring a std, which
        # avoids a pointless sqrt/square round trip before merging.
        var = torch.mean((data - mean) ** 2, dim=0)
        return mean, var

    def update(self, data):
        """Merge a batch of observations into the running statistics.

        data: torch.Tensor of shape (nobservations, ndimensions); a 1-D
            tensor is treated as a single observation. Empty batches are
            ignored.

        Raises:
            ValueError: if the column count differs from earlier batches.
        """
        data = torch.atleast_2d(data)
        if self.nobservations > 0 and data.shape[1] != self.ndimensions:
            raise ValueError("Data dims don't match prev observations.")

        n = data.shape[0]
        if n == 0:
            # Nothing to merge: the batch moments would be 0/0 = NaN and
            # would poison the running statistics.
            return

        newmean, newvar = self._moments(data)

        if self.nobservations == 0:
            # First batch: its moments are the running moments.
            self.mean = newmean
            self.var = newvar
            self.ndimensions = data.shape[1]
        else:
            m = float(self.nobservations)
            total = m + n
            oldmean = self.mean
            self.mean = m / total * oldmean + n / total * newmean
            # Pooled population variance: weighted within-batch variances
            # plus the between-batch spread of the two means.
            self.var = (
                m / total * self.var
                + n / total * newvar
                + m * n / total ** 2 * (oldmean - newmean) ** 2
            )

        # torch.sqrt keeps the result a tensor on the same device; np.sqrt
        # would route through NumPy, which fails for CUDA/grad tensors.
        self.std = torch.sqrt(self.var)
        self.nobservations += n

In [2]:
# Cross-check StatsRecorder against statistics recomputed on the full,
# concatenated data after every batch.
torch.manual_seed(323)  # reproducible batch sizes and values

mystats = StatsRecorder()

# Hold all observations in "data" to check for correctness.
ndims = 42
data = torch.empty((0, ndims))

for i in range(10):
    # .item() turns the 1-element tensor into a plain int usable as a size.
    nobserv = torch.randint(10, 103, (1,)).item()
    newdata = torch.randn(nobserv, ndims)
    data = torch.vstack((data, newdata))

    # Update stats recorder object
    mystats.update(newdata)

    # Check stats recorder object is doing its business right.
    full_mean = torch.mean(data, dim=0)
    assert torch.allclose(mystats.mean, full_mean)
    assert torch.allclose(mystats.std, std_(data, full_mean))
    print(torch.max(torch.abs(mystats.mean - full_mean)))

tensor(0.)
tensor(1.8626e-08)
tensor(1.4901e-08)
tensor(2.9802e-08)
tensor(1.8626e-08)
tensor(1.8626e-08)
tensor(2.2352e-08)
tensor(1.4901e-08)
tensor(1.4901e-08)
tensor(1.4435e-08)
