# data

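> Data-loading helpers: PyTorch `DataLoader` wrappers (`get_dls`, `collate_dict`, `DataLoaders`) and a `Batch` container for feeding FashionMNIST batches to JAX.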

In [1]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.2

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.2


In [2]:
#| default_exp data

In [3]:
#|export

from functools import partial
from operator import itemgetter
from typing import NamedTuple, Union

import jax
import jax.numpy as jnp
import lovely_jax as lj
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, default_collate
import torch

In [4]:
# Apply lovely_jax's monkey patch (presumably the source of the summarised
# `Array[...]` reprs shown in later outputs -- confirm against the lovely_jax docs).
lj.monkey_patch()
# Last expression of the cell: displays the backend JAX is currently using.
jax.default_backend()

'gpu'

#### Data

In [5]:
#|export

def get_dls(train_ds, valid_ds, bs, **kwargs):
    """Build a `(train, valid)` pair of `DataLoader`s.

    The training loader shuffles and uses batch size `bs`; the validation
    loader does not shuffle and uses batch size `bs * 2`. Every extra keyword
    argument is forwarded unchanged to both `DataLoader` constructors.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=bs * 2, **kwargs)
    return train_dl, valid_dl

def collate_dict(ds):
    """Return a collate function for datasets whose samples are dicts.

    The returned function collates a list of samples with `default_collate`
    and then extracts the entries named by iterating `ds.features`, in that
    order (a tuple when there are several names, the bare value when there is
    exactly one, as `operator.itemgetter` does).
    """
    pick_features = itemgetter(*ds.features)

    def _collate(samples):
        collated = default_collate(samples)
        return pick_features(collated)

    return _collate

class DataLoaders:
    """Container exposing a training loader as `.train` and a validation loader as `.valid`."""

    def __init__(self, *dls):
        """Keep the first two loaders passed; any further positional loaders are ignored."""
        self.train, self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        """Build loaders from a dict-like of dataset splits.

        Args:
            dd: mapping of split name to dataset. It must contain a 'train'
                key (used to build the collate function), and its first two
                values, in mapping order, are used as the training and
                validation datasets.
            batch_size: training batch size, passed to `get_dls` as `bs`.
            as_tuple: accepted for interface compatibility; not used here.
            **kwargs: forwarded to `get_dls` (and from there to the loaders).
        """
        collate = collate_dict(dd['train'])
        # Forward only the first two splits: passing every split positionally
        # made a third split (e.g. 'test') collide with the `bs` keyword and
        # raise TypeError. This mirrors the `dls[:2]` truncation in __init__.
        splits = list(dd.values())[:2]
        return cls(*get_dls(*splits, bs=batch_size, collate_fn=collate, **kwargs))

In [6]:
# Normalisation constants (presumably the FashionMNIST pixel mean/std -- confirm),
# loader batch size, and number of label classes.
XMEAN, XSTD = 0.28, 0.35
BATCH_SIZE = 500
NUM_CLASSES = 10

# Per-sample pipeline: PIL image -> tensor -> divide by 255 -> normalise -> flatten.
tfm = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Lambda(lambda x: x/255),
    transforms.Normalize(XMEAN, XSTD),
    transforms.Lambda(lambda x: torch.flatten(x)),
])

# Dataset factory; only the `train` flag differs between the two splits.
ds = partial(
    torchvision.datasets.FashionMNIST,
    root="data",
    download=True,
    transform=tfm,
)
train_ds = ds(train=True)
valid_ds = ds(train=False)

# Plain loaders: no shuffling, and the same batch size for both splits
# (unlike `get_dls`, which shuffles training data and doubles the validation batch).
tdl = DataLoader(train_ds, batch_size=BATCH_SIZE)
vdl = DataLoader(valid_ds, batch_size=BATCH_SIZE)
dls = DataLoaders(tdl, vdl)

In [7]:
#|export
# Array types a batch field may hold.  TODO: should this include torch.Tensor?
Tensor = Union[jax.Array, jnp.ndarray, np.ndarray]


class Batch(NamedTuple):
    """A mini-batch: model inputs together with their targets."""

    input: Tensor   # e.g. [B, H, W, C] images; [B, 784] when flattened as in this notebook
    target: Tensor  # [B]

In [8]:
# Take the first training batch and convert each of its fields to a JAX array.
batch = Batch(*(jnp.array(field) for field in next(iter(dls.train))))
batch

Batch(input=Array[500, 784] n=392000 x∈[-0.800, 2.057] μ=0.011 σ=1.006 gpu:0, target=Array[500] i32 x∈[0, 9] μ=4.402 σ=2.838 gpu:0)

In [9]:
#| hide
import nbdev; nbdev.nbdev_export()