In [1]:
from rsna_retro.imports import *
from rsna_retro.metadata import *
from rsna_retro.preprocess import *

Loading imports


In [2]:
# https://forums.fast.ai/t/calculating-our-own-image-stats-imagenet-stats-cifar-stats-etc/40355/4
# https://gist.github.com/thomasbrandon/ad5b1218fc573c10ea4e1f0c63658469#file-running_stats-py-L31

import torch
from torch import Tensor
from typing import Iterable
from fastprogress import progress_bar

class RunningStatistics:
    """Streaming mean / variance / std (and optionally min / max) over batches of tensors.

    Every `update` call flattens the trailing `n_dims` dimensions of the batch into a single
    sample axis and folds it into the running totals. The leading dimensions are kept, so all
    statistics have the shape of those leading dimensions. Batches are combined with the
    pairwise update of the sum of squared deviations, so no batch has to be kept in memory.

    `var` and `std` are population statistics (the sum of squared deviations divided by `n`).

    Args:
        n_dims: number of trailing dimensions of each batch that are treated as samples.
        record_range: when True, also track the running `min` and `max`.
    """
    def __init__(self, n_dims:int=2, record_range=False):
        self._n_dims,self._range = n_dims,record_range
        self.n,self.sum,self.min,self.max = 0,None,None,None
        # Set on the first update: leading shape of the data and running sum of squared deviations.
        self._shape,self._nvar = None,None

    def update(self, data:Tensor):
        """Fold one batch into the running statistics.

        Args:
            data: tensor whose trailing `n_dims` dimensions are samples; its leading
                dimensions must match those of every previous batch.

        Raises:
            AssertionError: if the leading dimensions differ from those of earlier batches.
        """
        # reshape (rather than view) so that non-contiguous inputs are accepted as well.
        data = data.reshape(*list(data.shape[:-self._n_dims]) + [-1])
        with torch.no_grad():
            new_n,new_sum = data.shape[-1],data.sum(-1)
            # Biased variance times n is exactly this batch's sum of squared deviations.
            # (The default Bessel-corrected variance times n would overstate it by n/(n-1)
            # and is NaN for a single-sample batch.)
            new_nvar = data.var(-1, unbiased=False).mul_(new_n)
            if self.n == 0:
                self.n = new_n
                self._shape = data.shape[:-1]
                self.sum = new_sum
                self._nvar = new_nvar
                if self._range:
                    self.min = data.min(-1)[0]
                    self.max = data.max(-1)[0]
            else:
                assert data.shape[:-1] == self._shape, f"Mismatched shapes, expected {self._shape} but got {data.shape[:-1]}."
                # Pairwise merge: M2 = M2_a + M2_b + delta^2 * n_a*n_b/(n_a+n_b), delta = mean_b - mean_a.
                # t below equals (n_b * delta)^2, and ratio/(n_a+n_b) * t is the cross term.
                ratio = self.n / new_n
                t = (self.sum / ratio).sub_(new_sum).pow_(2)
                self._nvar.add_(new_nvar).add_(t, alpha=ratio / (self.n + new_n))
                self.sum.add_(new_sum)
                self.n += new_n
                if self._range:
                    self.min = torch.min(self.min, data.min(-1)[0])
                    self.max = torch.max(self.max, data.max(-1)[0])

    @property
    def mean(self):
        """Running mean, or None before the first update."""
        return self.sum / self.n if self.n > 0 else None
    @property
    def var(self):
        """Running population variance, or None before the first update."""
        return self._nvar / self.n if self.n > 0 else None
    @property
    def std(self):
        """Running population standard deviation, or None before the first update."""
        return self.var.sqrt() if self.n > 0 else None

    def __repr__(self):
        def _fmt_t(t):
            # Statistics are None until the first update (and min/max unless record_range).
            if t is None: return "None"
            if t.numel() > 5: return f"tensor of ({','.join(map(str,t.shape))})"
            def _fmt_nested(t:Tensor):
                # 0-d tensors cannot be iterated; format them as a plain number.
                if t.ndim == 0: return f"{t:.3g}"
                return '[' + ','.join([_fmt_nested(v) for v in t]) + ']'
            return _fmt_nested(t)
        rng_str = f", min={_fmt_t(self.min)}, max={_fmt_t(self.max)}" if self._range else ""
        return f"RunningStatistics(n={self.n}, mean={_fmt_t(self.mean)}, std={_fmt_t(self.std)}{rng_str})"

In [3]:
# List the files under path_jpg256 and shuffle them with a fixed seed, so that the
# subset sampled for the statistics below is the same on every run.
# A private Random instance is used so the global random state is left untouched.
train_fns = path_jpg256.ls()
random.Random(42).shuffle(train_fns)

In [4]:
# Loading pipeline built from PILCTScan.create, ToTensor and IntToFloatTensor.
# NOTE(review): presumably opens a file as an image, converts it to a tensor and rescales it to float — confirm against rsna_retro / fastai.
pipe = Pipeline(funcs=[PILCTScan.create, ToTensor, IntToFloatTensor])

In [5]:
# Sanity check: run the pipeline on the first file and display the mean of the result.
pipe(train_fns[0]).mean()

tensor(0.1248)

In [6]:
# n_dims=2: update() flattens the last two dimensions of each input into the sample axis,
# so the statistics are kept separately for each leading index.
rs = RunningStatistics(n_dims=2)

In [None]:
# Accumulate statistics over the first 5000 entries of train_fns.
# A file that fails to load or update is skipped so one bad file does not abort the run,
# but the file name and the exception are reported so the failure can be investigated.
for fn in progress_bar(train_fns[:5000]):
    try:
        rs.update(pipe(fn))
    except Exception as e:
        print(f"Skipping {fn}: {e!r}")

In [8]:
# Display the running mean and standard deviation accumulated in rs.
rs.mean, rs.std

(tensor([0.1594, 0.0766, 0.0605]), tensor([0.3011, 0.2521, 0.2160]))