# Dataset

In [1]:
import numpy as np
import math

from torch.utils.data import random_split

## Calculating Mean & Std

Calculates mean and std of dataset.

In [2]:
def get_norm(dataset):
    mean = dataset.data.mean(axis=(0, 1, 2)) / 255.
    std = dataset.data.std(axis=(0, 1, 2)) / 255.

    return mean, std

## Split Dataset

Splits dataset into multiple subsets.

### TODO

* bias

In [3]:
def random_split_by_dist(dataset, size: int, params: dict = {}):
    """Returns subsets of `dataset`

    Parameters
    ----------
    dataset: datasets
        By torchvision.datasets.
    size: int
        Number (Length) of subsets.
    params: dict
        Contains `distFunc` which returns np.array.
        Sum of returned array SHOULD be 1.
    """

    assert size != 0, "`size` > 0"

    if 'distFunc' not in params:
        params['distFunc'] = uniform

    distFunc = params['distFunc']

    # calculates distribution
    dist = distFunc(size, params)  # dist: np.array
    assert math.isclose(sum(dist), 1.), "sum of `dist` shoule be 1."

    N = len(dataset)
    result = np.full(size, N) * dist
    result = result.astype('int')  # to integers
    # adjustment for that summation of `result` SHOULD be `N`
    result[-1] = N - sum(result[:-1])

    return random_split(dataset, result)

In [4]:
def uniform(size: int, params: dict = {}):
    return np.ones(size) / size

In [5]:
def normal(size: int, params: dict = {}):
    """Returns normal (Gaussian) distribution

    Uses `abs` to restrict to non-zeros.

    In fact, it is not a normal distribution because there are only
    positive elements in `result`.

    See https://numpy.org/doc/stable/reference/random/generated/numpy.random.normal.html .

    Parameters
    ----------
    size: int
        Number (Length) of chunks.
        Same as length of returned np.array.
    params: dict
        Contains 'loc', 'scale', 'lower' and 'upper'.
        The latter two are lower-bound and upper-bound respectively.
    """

    if 'loc' not in params:
        params['loc'] = 0.
    if 'scale' not in params:
        params['scale'] = 1.
    if 'lower' not in params:
        params['lower'] = 0.
    if 'upper' not in params:
        params['upper'] = None

    loc, scale, lower, upper = params['loc'], params['scale'], params['lower'], params['upper']

    result = np.random.normal(loc, scale, size)
    result = abs(result)  # `result` SHOULD be only positive.
    result = result.clip(lower, upper)
    return result / sum(result)

In [6]:
def pareto(size: int, params: dict = {}):
    """Returns Pareto distribution

    See https://numpy.org/doc/stable/reference/random/generated/numpy.random.pareto.html .

    Parameters
    ----------
    size: int
        Number (Length) of chunks.
        Same as length of returned np.array.
    params: dict
        contains 'alpha', 'lower' and 'upper'.
        The latter two are lower-bound and upper-bound respectively.
    """

    if 'alpha' not in params:
        params['alpha'] = 1.16  # by 80-20 rule, log(5)/log(4)
    if 'lower' not in params:
        params['lower'] = 0.
    if 'upper' not in params:
        params['upper'] = None

    alpha, lower, upper = params['alpha'], params['lower'], params['upper']

    result = np.random.pareto(alpha, size)
    result = result.clip(lower, upper)
    return result / sum(result)

# main

In [7]:
if __name__ == "__main__":
    from pprint import pprint

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    """Test `get_norm`"""
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    trainDataset = dset.CIFAR10(root='cifar', train=True, download=True, transform=transform)
    pprint(get_norm(trainDataset))

    """Test `adv_random_split`"""
    pprint(random_split_by_dist(trainDataset, 10, params={}))

Files already downloaded and verified
(array([0.49139968, 0.48215841, 0.44653091]),
 array([0.24703223, 0.24348513, 0.26158784]))
[<torch.utils.data.dataset.Subset object at 0x7febad4b5f50>,
 <torch.utils.data.dataset.Subset object at 0x7febae6c9d50>,
 <torch.utils.data.dataset.Subset object at 0x7febae5f9190>,
 <torch.utils.data.dataset.Subset object at 0x7febae5f9c90>,
 <torch.utils.data.dataset.Subset object at 0x7febae5f9dd0>,
 <torch.utils.data.dataset.Subset object at 0x7fec6e53bed0>,
 <torch.utils.data.dataset.Subset object at 0x7fec6e53bf10>,
 <torch.utils.data.dataset.Subset object at 0x7fec6e53bcd0>,
 <torch.utils.data.dataset.Subset object at 0x7fec6e53b990>,
 <torch.utils.data.dataset.Subset object at 0x7fec6e53b9d0>]
