In [1]:
# Sets the fraction of GPU memory JAX preallocates; set here, before `jax` is imported below
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.8

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.8


In [37]:
#| default_exp layers

In [3]:
#|export
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from pathlib import Path

from typing import NamedTuple, Dict, Callable
import dataclasses # adds some dunder methods like __init__ and __repr__
import collections
from operator import attrgetter,itemgetter
from functools import partial

import fastcore.all as fc

import jax
from jax import numpy as jnp
from jax import random as jrand
from jax import grad, value_and_grad, jit, vmap
import lovely_jax as lj
lj.monkey_patch()

In [4]:
import pynvml
def get_memory_free_MiB(gpu_index):
    """Return the free memory of GPU `gpu_index` in MiB (rounded down), as reported by NVML."""
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(int(gpu_index))
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return mem_info.free // 1024 ** 2
    finally:
        # Pair every nvmlInit with nvmlShutdown so repeated calls do not leak NVML state.
        pynvml.nvmlShutdown()

get_memory_free_MiB(0)

11176

In [5]:
# Load the pickled MNIST splits (train and valid; the third split is discarded) and convert them to JAX arrays.
# NOTE: `pickle.load` can execute arbitrary code — only load this file from a trusted source.
path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
x_train, y_train, x_valid, y_valid = map(jnp.array, [x_train, y_train, x_valid, y_valid])

## Layers and Models

Our strategy is based on currying: a layer takes its parameters first, so fixing the parameters gives a function of the inputs alone.

In [6]:
# def linear(params, z)
# arch= [linear(*), relu(*), linear(*)]
# params=[(w1,b1),(), (w2,b2)]
# model = arch(params)
# model = [linear(params,*), relu, linear(params, *)]
# preds = model(inputs)
# Learning: preds -> evaluate -> update
# for each epoch, for each batch
# z = inputs
# z = linear(W1,z)
# z = relu(z)
# z = linear(W2,z)
# outputs = z
# ...
# model = linear(W2,(relu(linear(W1,*))))
# update(msg, model)-> model
# 


In [7]:
class Adder(fc.DisplayedTransform):
    "Transform that adds `num` to its input."
    def __init__(self, num):
        fc.store_attr()

    def encodes(self, x):
        return x + self.num

class Multiplier(fc.DisplayedTransform):
    "Transform that multiplies its input by `num`."
    def __init__(self, num):
        fc.store_attr()

    def encodes(self, x):
        return x * self.num

In [8]:
# Pipelines apply their transforms left to right: (4+3)*2 versus 4*2+3
fc.Pipeline([Adder(3), Multiplier(2)])(4), fc.Pipeline([Multiplier(2),Adder(3)])(4)

(14, 11)

In [39]:
# Toy weights (2x3), bias (3,) and inputs (3 samples x 2 features) for the Linear layer experiments
w = jnp.ones((2,3))*4.
b = jnp.ones((3,))*1
x = jnp.array([[4.,4.], [3.,1.], [-1.,1]])

In [35]:
p = fc.Pipeline([Adder(3), Multiplier(2)])
# Public attributes of a Pipeline (the expression that produces the list shown below)
[name for name in dir(p) if name[0] != "_"]

['add',
 'adder',
 'decode',
 'decode',
 'decodes',
 'default',
 'encodes',
 'fs',
 'init_enc',
 'multiplier',
 'name',
 'num',
 'order',
 'setup',
 'setup',
 'setups',
 'show',
 'split_idx',
 'split_idx',
 'train_setup']

In [10]:
class Layer(fc.DisplayedTransform):
    "Base transform that keeps its parameters in `self.params`."
    def __init__(self, params):
        super().__init__()
        fc.store_attr()

    def get_params(self):
        params = self.params
        return params

In [20]:
class Linear(Layer):
    "Affine layer `x @ w + b`; `params` is the pair `(w, b)`."
    def __init__(self, params): 
        super().__init__(params)
        self.w, self.b = self.params

    # `jax.jit` cannot decorate `encodes` directly: the Linear instance would be traced as the
    # first argument, and it is not a valid JAX type. Jit a pure function of the arrays instead.
    @staticmethod
    @jax.jit
    def _affine(w, b, x): return x@w + b

    def encodes(self, x): return self._affine(self.w, self.b, x)

In [21]:
# Same toy weights, bias and inputs as above, displayed
w = jnp.ones((2,3))*4.
b = jnp.ones((3,))*1
x = jnp.array([[4.,4.], [3.,1.], [-1.,1]])
w,b,x

(Array[2, 3] n=6 x∈[4.000, 4.000] μ=4.000 σ=0. gpu:0 [[4.000, 4.000, 4.000], [4.000, 4.000, 4.000]],
 Array[3] x∈[1.000, 1.000] μ=1.000 σ=0. gpu:0 [1.000, 1.000, 1.000],
 Array[3, 2] n=6 x∈[-1.000, 4.000] μ=2.000 σ=1.826 gpu:0 [[4.000, 4.000], [3.000, 1.000], [-1.000, 1.000]])

In [22]:
# Wrap a single Linear layer in a Pipeline and read its parameters back
pipe = fc.Pipeline([Linear([w,b])])
pipe[0].get_params()

[Array[2, 3] n=6 x∈[4.000, 4.000] μ=4.000 σ=0. gpu:0 [[4.000, 4.000, 4.000], [4.000, 4.000, 4.000]],
 Array[3] x∈[1.000, 1.000] μ=1.000 σ=0. gpu:0 [1.000, 1.000, 1.000]]

In [24]:
# Display the pipeline (its repr shows the stored params)
pipe

Pipeline: Linear -- {'params': [Array[2, 3] n=6 x∈[4.000, 4.000] μ=4.000 σ=0. gpu:0 [[4.000, 4.000, 4.000], [4.000, 4.000, 4.000]], Array[3] x∈[1.000, 1.000] μ=1.000 σ=0. gpu:0 [1.000, 1.000, 1.000]]}

#### Dead end

`jax.jit` cannot trace an `fc.Pipeline`: the call below fails with a `ConcretizationTypeError`.

In [26]:
# Dead end on purpose: `jax.jit(pipe)` fails while tracing the fastcore Pipeline (the captured
# run raised a ConcretizationTypeError). Catch and show the error instead of leaving a crashing
# cell, so Restart & Run All still completes.
try:
    jax.jit(pipe)(x)
except Exception as e:  # deliberate: this cell only demonstrates the failure
    print(f"{type(e).__name__}: {str(e).partition(chr(10))[0]}")

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
The error occurred while tracing the function <unknown> for jit. This value became a tracer due to JAX operations on these lines:

  operation a[35m:f32[][39m = reduce_min[axes=(0, 1)] b
    from line /tmp/ipykernel_408654/2722338522.py:1 (<module>)

  operation a[35m:f32[][39m = reduce_max[axes=(0, 1)] b
    from line /tmp/ipykernel_408654/2722338522.py:1 (<module>)

  operation a[35m:f32[][39m = reduce_min[axes=(0, 1)] b
    from line /tmp/ipykernel_408654/2722338522.py:1 (<module>)

  operation a[35m:f32[][39m = reduce_max[axes=(0, 1)] b
    from line /tmp/ipykernel_408654/2722338522.py:1 (<module>)

  operation a[35m:f32[][39m = reduce_min[axes=(0, 1)] b
    from line /tmp/ipykernel_408654/2722338522.py:1 (<module>)

(Additional originating lines are not shown.)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

## With pure functions

In [None]:
#|export
@jax.jit
def linear(params:jnp.array, x:jnp.array) -> jnp.array:
    "Jit-compiled affine map `x @ w + b`, where `params` is the pair `(w, b)`."
    weight, bias = params
    projected = x @ weight
    return projected + bias

In [None]:
#|export
@jax.jit
def relu(x:jnp.array) -> jnp.array:
    "Jit-compiled element-wise rectifier: max(0, x)."
    clipped = jnp.maximum(0, x)
    return clipped

In [None]:
# Add implementations to the `Multiplier` transform: decorating a function named
# `setups`/`encodes` with a fastcore Transform class registers it in that class's type dispatch.
@Multiplier
def setups(self, num): self.num = num
@Multiplier
def encodes(self, x): return x * self.num

In [None]:
# Rebinds `pipe`: add 2, then multiply by 3
pipe = fc.Pipeline([Adder(2), Multiplier(3)])
pipe(4)

In [None]:
def f(*a, **b):
    "Print the positional and keyword arguments received."
    message = f"args: {a}; kwargs: {b}"
    print(message)
f(1,n="out")

In [None]:
# GLOBAL variables -------------- ¯\_(ツ)_/¯
# `frame_stack`: stack of active Frames, pushed/popped by the transformed init/apply functions.
# `key`: a module-level PRNG key; not read by the visible cells.
frame_stack = [] 
key = jax.random.PRNGKey(42)
# ------------------------------------------

### A random key generator

In [None]:
def _with_key(func):
    "Turn `func(key, *args, **kwargs)` into a method that passes `self.key` as the key."
    def method(self, *args, **kwargs):
        return func(self.key, *args, **kwargs)
    return method

class RNG:
    "Stateful wrapper around a JAX PRNG key; `next()` advances it and returns a fresh subkey."
    def __init__(self, key:jax.random.PRNGKey):    self.key = key
    def __repr__(self):         return f'{type(self).__name__}({self.key!r})'
    def split(self, num=2):     return [RNG(k) for k in jax.random.split(self.key, num)]
    def __iter__(self): return self  # complete the iterator protocol (infinite stream of subkeys)
    def __next__(self): return self.next()
    def next(self):
        self.key, k = jax.random.split(self.key)
        return k
    # NOTE: these sample with the *current* key without advancing it, so repeated calls
    # return identical samples; call `next()` between draws to get fresh randomness.
    # `jrand` is the jax.random module (imported above); `random` is not bound to it.
    uniform = _with_key(jrand.uniform)
    normal = _with_key(jrand.normal)

# Register RNG as a pytree (its only leaf is the key) so it can pass through jit/vmap.
jax.tree_util.register_pytree_node(
    RNG,
    lambda rng: ([rng.key], None),
    lambda aux, values: RNG(values[0]),
)

def rng(seed):  return RNG(jrand.PRNGKey(seed))
rng(42).uniform((2,2))

In [None]:
kgen = RNG(jax.random.PRNGKey(10))
# Draw ten successive subkeys; each draw advances the generator's internal key
for i in range(10): print(next(kgen))

In [None]:
#|export

@dataclasses.dataclass
class Frame: # From haiku documentation
  """Tracks what's going on during a call of a transformed function.

  Attributes:
    params: mapping from parameter name/path to array.
    is_initialising: True while running the init function (parameters get created).
    key: PRNG key used to sample new parameters.
    module_counts: how many modules of each class have been created so far.
    call_stack: names of the module methods currently being called.
  """
  params: Dict[str, jnp.ndarray]
  is_initialising: bool = False
  # Built per instance by a factory instead of a class-level array default: dataclasses may
  # reject unhashable default values (Python 3.11+), and this avoids allocating a device
  # array when the class is defined.
  key: jax.random.PRNGKey = dataclasses.field(default_factory=lambda: jax.random.PRNGKey(0))

  # Keeps track of how many modules of each class have been created so far.
  # Used to assign new modules unique names.
  module_counts: Dict[str, int] = dataclasses.field(
      default_factory=lambda: collections.defaultdict(int))

  # Keeps track of the entire path to the current module method call.
  # Module methods, when called, will add themselves to this stack.
  # Used to give each parameter a unique name corresponding to the
  # method scope it is in.
  call_stack: list = dataclasses.field(default_factory=list)

  def create_param_path(self, identifier) -> str:
    """Creates a unique path for this param, e.g. '~/Linear_0/__call__/w'."""
    return '/'.join(['~'] + self.call_stack + [identifier])

  def create_unique_module_name(self, module_name: str) -> str:
    """Assigns a unique name to the module by appending its number to its name."""
    number = self.module_counts[module_name]
    self.module_counts[module_name] += 1
    return f"{module_name}_{number}"



def current_frame():
  """Return the innermost active Frame, or an empty list when no transformed
  function is running (so the result is falsy outside `transform`)."""
  if frame_stack:
    return frame_stack[-1]
  return []

current_frame()

In [None]:
#|export
class Module:
  """Base class for modules; each instance gets a name that is unique within the current frame."""
  def __init__(self):
    frame = current_frame()
    if not frame:
      # Outside `transform` there is no active frame (current_frame() returns an empty list);
      # fail with an explicit message rather than an AttributeError on a list.
      raise RuntimeError(
          f"{self.__class__.__name__} must be created inside a function wrapped by `transform`.")
    self._unique_name = frame.create_unique_module_name(self.__class__.__name__)

In [None]:
#|export
def module_method(f):
  """A decorator for Module methods.

  While `f` runs, the module's unique name and the method name are pushed onto the
  current frame's call stack (the scope used by `Frame.create_param_path`).
  """
  import functools

  @functools.wraps(f)
  def wrapped(self, *args, **kwargs):
    """A version of f that lets the frame know it's being called."""
    # Self is the instance to which this method is attached.
    module_name = self._unique_name
    call_stack = current_frame().call_stack
    call_stack.append(module_name)
    call_stack.append(f.__name__)
    try:
      return f(self, *args, **kwargs)
    finally:
      # Pop even when `f` raises; otherwise later calls would inherit a stale scope.
      assert call_stack.pop() == f.__name__
      assert call_stack.pop() == module_name
  return wrapped

In [None]:
#|export
class Transformed(NamedTuple):
  """Pair of functions produced by `transform`."""
  init: Callable
  apply: Callable


def transform(f, seed: int = 42) -> Transformed:
  """Turn `f` (which builds Modules and calls `get_param`) into `init`/`apply` functions.

  Args:
    f: the function to transform; it must create its Modules inside its body.
    seed: seed of the key generator; every `init` call draws a new key from it.

  Returns:
    `Transformed(init, apply)`. `init(*args, **kwargs)` runs `f` in initialising mode and
    returns the params dict it created; `apply(params, *args, **kwargs)` runs `f` with
    `params` and returns its output.
  """
  kgen = RNG(jrand.PRNGKey(seed))  # `jrand` is jax.random; plain `random` is not bound to it

  def init_f(*args, **kwargs):
    frame_stack.append(Frame({}, is_initialising=True, key=kgen.next()))
    try:
      f(*args, **kwargs)
    finally:
      # Always pop, so a failing call does not leave a stale frame on the global stack.
      frame = frame_stack.pop()
    return frame.params

  def apply_f(params, *args, **kwargs):
    frame_stack.append(Frame(params))
    try:
      return f(*args, **kwargs)
    finally:
      frame_stack.pop()

  return Transformed(init_f, apply_f)

In [None]:
#|export
def get_param(identifier, shape):
  """Return the calling module's parameter `identifier`, creating it during init.

  The parameter is stored under the frame's scoped path (`Frame.create_param_path`), so
  different modules that both ask for e.g. 'w' get separate entries. During init it is
  sampled from a standard normal (float32) with its own subkey.
  """
  frame = current_frame()
  param_path = frame.create_param_path(identifier)
  if frame.is_initialising:
    # Split so every parameter gets its own subkey; reusing one key would give all
    # same-shaped parameters identical values.
    frame.key, subkey = jax.random.split(frame.key)
    frame.params[param_path] = jax.random.normal(subkey, shape, dtype=jnp.float32)
  return frame.params[param_path]

In [None]:
class Linear(Module):
  """Affine layer `x @ w + b` with `width` outputs; the input size is read from `x`."""

  def __init__(self, width):
    super().__init__()
    self._width = width

  @module_method
  def __call__(self, x):
    n_in = x.shape[-1]
    weight = get_param('w', shape=(n_in, self._width))
    bias = get_param('b', shape=(self._width,))
    return x @ weight + bias

In [None]:
#|export
class ReLU(Module):
  """Element-wise rectifier max(0, x); has no parameters."""
  @module_method
  def __call__(self, x):
    rectified = jnp.maximum(0, x)
    return rectified

In [None]:
class MLP(Module):
  """Stack of `Linear` layers with ReLU between them (no activation after the last one)."""

  def __init__(self, widths):
    super().__init__()
    self._widths = widths

  @module_method
  def __call__(self, x):
    out = x  # with no widths the input is returned unchanged
    for width in self._widths:
      # Each Linear creates/looks up its own 'w' and 'b'; no get_param calls needed here.
      out = Linear(width)(x)
      x = jax.nn.relu(out)
    return out  # pre-activation output of the last layer

In [None]:
# `init` builds the parameter dict; `forward(params, x)` runs an MLP with widths 512, 50 and 1
init, forward = transform(lambda x: MLP([512,50,1])(x))

In [None]:
# No transformed function is running here, so this shows the empty-list placeholder
current_frame()

In [None]:
# Run the model once in initialising mode on the training inputs to create its parameters
params = init(x_train)
params

In [None]:
# Modules can only be built inside a transformed function (they need an active frame),
# so wrap the single layer with `transform` and run init/apply explicitly.
lin_init, lin_apply = transform(lambda x: Linear(10)(x))
lin_params = lin_init(x_train)
out = lin_apply(lin_params, x_train)
out

In [None]:
class Model(nn.Module):
    "Linear -> ReLU -> Linear, applied in order by `__call__`."
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # A plain list: these layers are not registered as submodules, so
        # `self.parameters()` does not see them; training code walks `self.layers` directly.
        self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]

    def __call__(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

In [None]:
# NOTE(review): `m` and `nh` (input and hidden sizes) are not defined in the visible cells, and
# `x_train` was converted to a jnp array above while these are torch layers — confirm both.
model = Model(m, nh, 10)
pred = model(x_train)
pred.shape

### Cross entropy loss

First, we will need to compute the softmax of our activations. This is defined by:

$$\hbox{softmax(x)}_{i} = \frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + \cdots + e^{x_{n-1}}}$$

or more concisely:

$$\hbox{softmax(x)}_{i} = \frac{e^{x_{i}}}{\sum\limits_{0 \leq j \lt n} e^{x_{j}}}$$ 

In practice, we will need the log of the softmax when we calculate the loss.

In [None]:
def log_softmax(x):
    "Naive log-softmax over the last dim: log(exp(x) / sum(exp(x)))."
    exps = x.exp()
    return (exps / exps.sum(-1, keepdim=True)).log()

In [None]:
# Log-softmax of the predictions computed above
log_softmax(pred)

Note that the formula 

$$\log \left ( \frac{a}{b} \right ) = \log(a) - \log(b)$$ 

gives a simplification when we compute the log softmax:

In [None]:
def log_softmax(x):
    "Log-softmax over the last dim using log(a/b) = log(a) - log(b)."
    log_norm = x.exp().sum(-1, keepdim=True).log()
    return x - log_norm

Then, there is a way to compute the log of the sum of exponentials in a more stable way, called the [LogSumExp trick](https://en.wikipedia.org/wiki/LogSumExp). The idea is to use the following formula:

$$\log \left ( \sum_{j=1}^{n} e^{x_{j}} \right ) = \log \left ( e^{a} \sum_{j=1}^{n} e^{x_{j}-a} \right ) = a + \log \left ( \sum_{j=1}^{n} e^{x_{j}-a} \right )$$

where $a$ is the maximum of the $x_{j}$.

In [None]:
def logsumexp(x):
    """Numerically stable log(sum(exp(x))) over the last dimension.

    The maximum is subtracted before exponentiating so large activations do not overflow.
    Works for inputs of any rank >= 1; the last dimension is reduced away.
    """
    m = x.max(-1).values
    # `[..., None]` re-adds the reduced dim for broadcasting, whatever the input rank
    return m + (x - m[..., None]).exp().sum(-1).log()

This way, we will avoid an overflow when taking the exponential of a big activation. In PyTorch, this is already implemented for us. 

In [None]:
def log_softmax(x):
    "Stable log-softmax over the last dim via PyTorch's logsumexp."
    return x - torch.logsumexp(x, dim=-1, keepdim=True)

In [None]:
# Check our logsumexp against PyTorch's, then compute log-probabilities.
# NOTE(review): `test_close` is not imported in the visible cells (fastcore.test provides one) — confirm.
test_close(logsumexp(pred), pred.logsumexp(-1))
sm_pred = log_softmax(pred)
sm_pred

The cross entropy loss for some target $x$ and some prediction $p(x)$ is given by:

$$ -\sum x\, \log p(x) $$

But since our $x$s are one-hot encoded (actually, they're just the integer indices), this can be rewritten as $-\log(p_{i})$, where $i$ is the index of the desired target.

This can be done using numpy-style [integer array indexing](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#integer-array-indexing). Note that PyTorch supports all the tricks in the advanced indexing methods discussed in that link.

In [None]:
# Targets of the first three training samples
y_train[:3]

In [None]:
# Log-probabilities of the first three samples at hand-picked columns (presumably the targets shown above)
sm_pred[0,5],sm_pred[1,0],sm_pred[2,4]

In [None]:
# Integer-array indexing: rows 0..2, columns given by the first three targets
sm_pred[[0,1,2], y_train[:3]]

In [None]:
def nll(input, target):
    "Mean negative log-likelihood: -input[i, target[i]] averaged over rows."
    rows = range(target.shape[0])
    picked = input[rows, target]
    return -picked.mean()

In [None]:
# Mean negative log-likelihood over the whole training set
loss = nll(sm_pred, y_train)
loss

Then use PyTorch's implementation.

In [None]:
# Compare with PyTorch. NOTE(review): `F` (torch.nn.functional) is not imported in the visible cells.
test_close(F.nll_loss(F.log_softmax(pred, -1), y_train), loss, 1e-3)

In PyTorch, `F.log_softmax` and `F.nll_loss` are combined in one optimized function, `F.cross_entropy`.

In [None]:
# cross_entropy combines log_softmax and nll_loss in one call
test_close(F.cross_entropy(pred, y_train), loss, 1e-3)

## Basic training loop

Basically the training loop repeats over the following steps:
- get the output of the model on a batch of inputs
- compare the output to the labels we have and compute a loss
- calculate the gradients of the loss with respect to every parameter of the model
- update said parameters with those gradients to make them a little bit better

In [None]:
# NOTE(review): requires `F = torch.nn.functional`, which is not imported in the visible cells
loss_func = F.cross_entropy

In [None]:
# Predictions of the untrained model on the first mini-batch
bs=50                  # batch size

xb = x_train[0:bs]     # a mini-batch from x
preds = model(xb)      # predictions
preds[0], preds.shape

In [None]:
# Targets for the same mini-batch
yb = y_train[0:bs]
yb

In [None]:
# Loss of the untrained model on this batch
loss_func(preds, yb)

In [None]:
# Predicted class = index of the largest activation in each row
preds.argmax(dim=1)

In [None]:
#|export
def accuracy(out, yb):
    "Fraction of rows whose argmax (over dim 1) equals the target."
    correct = out.argmax(dim=1) == yb
    return correct.float().mean()

In [None]:
# Accuracy of the untrained model on this batch
accuracy(preds, yb)

In [None]:
# Training hyper-parameters
lr = 0.5   # learning rate
epochs = 3 # how many epochs to train for

In [None]:
#|export
def report(loss, preds, yb):
    "Print the loss and the batch accuracy, both to two decimals."
    acc = accuracy(preds, yb)
    print(f'{loss:.2f}, {acc:.2f}')

In [None]:
# Baseline loss and accuracy on the first batch, before any training
xb,yb = x_train[:bs],y_train[:bs]
preds = model(xb)
report(loss_func(preds, yb), preds, yb)

In [None]:
# Manual training loop: forward pass, loss, backprop, then an in-place SGD step on every layer
# that has a weight, followed by zeroing its gradients.
# NOTE(review): `n` (upper bound of the training indices, presumably len(x_train)) is not defined in the visible cells.
for epoch in range(epochs):
    for i in range(0, n, bs):
        s = slice(i, min(n,i+bs))
        xb,yb = x_train[s],y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        # update the weights without recording the ops in the autograd graph
        with torch.no_grad():
            for l in model.layers:
                if hasattr(l, 'weight'):
                    l.weight -= l.weight.grad * lr
                    l.bias   -= l.bias.grad   * lr
                    l.weight.grad.zero_()
                    l.bias  .grad.zero_()
    # loss and accuracy of the last batch of the epoch
    report(loss, preds, yb)

## Using parameters and optim

### Parameters

In [None]:
# Assigning a module as an attribute registers it as a child of `m1`
m1 = nn.Module()
m1.foo = nn.Linear(3,4)
m1

In [None]:
# Registered children as (name, module) pairs
list(m1.named_children())

In [None]:
# Without list(), named_children() is a lazy generator
m1.named_children()

In [None]:
# Parameters of all registered children
list(m1.parameters())

In [None]:
class MLP(nn.Module):
    "Two-layer perceptron: Linear -> ReLU -> Linear."
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # attribute assignment registers the children in this order: l1, l2, relu
        self.l1 = nn.Linear(n_in,nh)
        self.l2 = nn.Linear(nh,n_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)

In [None]:
# NOTE(review): `m` and `nh` are not defined in the visible cells — confirm they are set earlier.
model = MLP(m, nh, 10)
model.l1

In [None]:
# Display the model structure
model

In [None]:
# Children in registration order
for name,l in model.named_children(): print(f"{name}: {l}")

In [None]:
# Shapes of every registered parameter
for p in model.parameters(): print(p.shape)

In [None]:
def fit():
    """Train the global `model` for `epochs` epochs of plain SGD (step `lr`) over `x_train`/`y_train`,
    reporting the last batch's loss and accuracy after each epoch.
    NOTE(review): relies on globals `n`, `bs`, `lr`, `epochs`, `loss_func`; `n` is not defined in the visible cells."""
    for epoch in range(epochs):
        for i in range(0, n, bs):
            s = slice(i, min(n,i+bs))
            xb,yb = x_train[s],y_train[s]
            preds = model(xb)
            loss = loss_func(preds, yb)
            loss.backward()
            with torch.no_grad():
                # update every parameter registered on the model, then clear the gradients
                for p in model.parameters(): p -= p.grad * lr
                model.zero_grad()
        report(loss, preds, yb)

In [None]:
# Train the MLP above
fit()

Behind the scenes, PyTorch overrides the `__setattr__` function in `nn.Module`, so that submodules you assign as attributes are registered as children and their parameters are included in the model's parameters.

In [None]:
class MyModule:
    "Minimal re-creation of nn.Module's child registration via `__setattr__`."
    def __init__(self, n_in, nh, n_out):
        self._modules = {}
        self.l1 = nn.Linear(n_in,nh)
        self.l2 = nn.Linear(nh,n_out)

    def __setattr__(self, k, v):
        # Public attributes are recorded as submodules; `_`-prefixed ones (like `_modules`) are not.
        if not k.startswith("_"):
            self._modules[k] = v
        super().__setattr__(k, v)

    def __repr__(self):
        return f'{self._modules}'

    def parameters(self):
        for child in self._modules.values():
            for param in child.parameters():
                yield param

In [None]:
# The repr lists the modules registered by __setattr__. NOTE(review): `m`, `nh` not defined in visible cells.
mdl = MyModule(m,nh,10)
mdl

In [None]:
# Parameters gathered from both registered Linear layers
for p in mdl.parameters(): print(p.shape)

### Registering modules

In [None]:
from functools import reduce

We can use the original `layers` approach, but we have to register the modules.

In [None]:
# Plain list of layers, registered explicitly by Model below. NOTE(review): `m`, `nh` not defined in visible cells.
layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]

In [None]:
class Model(nn.Module):
    "Applies `layers` in order, registering each one as a child module."
    def __init__(self, layers):
        super().__init__()
        self.layers = layers
        # A plain list is invisible to nn.Module, so register every layer by name.
        for idx, layer in enumerate(self.layers):
            self.add_module(f'layer_{idx}', layer)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
# The registered layers now show up in the repr
model = Model(layers)
model

In [None]:
# Forward pass on the current global batch `xb`
model(xb).shape

### nn.ModuleList

`nn.ModuleList` does this for us.

In [None]:
class SequentialModel(nn.Module):
    "Applies `layers` in order; nn.ModuleList handles the registration."
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

In [None]:
# Same layers, registered through nn.ModuleList
model = SequentialModel(layers)
model

In [None]:
# Train the global `model` (now the SequentialModel)
fit()

### nn.Sequential

`nn.Sequential` is a convenient class which does the same as the above:

In [None]:
# nn.Sequential chains the layers in order
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [None]:
# Train, then evaluate on the global batch `xb, yb`
fit()
loss_func(model(xb), yb), accuracy(model(xb), yb)

In [None]:
# Display the nn.Sequential structure
model

### optim

In [None]:
class Optimizer():
    """Minimal SGD: `p -= lr * p.grad` for every parameter that has a gradient.

    Args:
        params: iterable of tensors to optimise (materialised into a list).
        lr: learning rate.
    """
    def __init__(self, params, lr=0.5): self.params,self.lr=list(params),lr

    def step(self):
        with torch.no_grad():
            for p in self.params:
                # parameters that took no part in the loss have no gradient yet
                if p.grad is not None: p -= p.grad * self.lr

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None: p.grad.zero_()

In [None]:
# Fresh model for the hand-written optimizer
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [None]:
# Optimizer over the model's parameters (default lr=0.5)
opt = Optimizer(model.parameters())

In [None]:
# Same training loop, with the parameter update and gradient reset delegated to `opt`
for epoch in range(epochs):
    for i in range(0, n, bs):
        s = slice(i, min(n,i+bs))
        xb,yb = x_train[s],y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
    report(loss, preds, yb)

PyTorch already provides this exact functionality in `optim.SGD` (it also handles stuff like momentum, which we'll look at later)

In [None]:
from torch import optim

In [None]:
def get_model():
    "Build a fresh Linear-ReLU-Linear model (sizes `m`, `nh`, 10) and an SGD optimizer with step `lr`."
    layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]
    model = nn.Sequential(*layers)
    opt = optim.SGD(model.parameters(), lr=lr)
    return model, opt

In [None]:
# Loss of a freshly built model on the current batch
model,opt = get_model()
loss_func(model(xb), yb)

In [None]:
# Same loop using PyTorch's optim.SGD
for epoch in range(epochs):
    for i in range(0, n, bs):
        s = slice(i, min(n,i+bs))
        xb,yb = x_train[s],y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
    report(loss, preds, yb)

## Dataset and DataLoader

### Dataset

It's clunky to iterate through minibatches of x and y values separately:

```python
    xb = x_train[s]
    yb = y_train[s]
```

Instead, let's do these two steps together, by introducing a `Dataset` class:

```python
    xb,yb = train_ds[s]
```

In [None]:
#|export
class Dataset():
    "Pairs inputs `x` with targets `y`; indexing returns the matching `(x, y)` tuple."
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        xi, yi = self.x[i], self.y[i]
        return xi, yi

In [None]:
# Wrap the train and validation splits
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
assert len(train_ds)==len(x_train)
assert len(valid_ds)==len(x_valid)

In [None]:
# Slicing the dataset returns a batch of inputs and targets together
xb,yb = train_ds[0:5]
assert xb.shape==(5,28*28)
assert yb.shape==(5,)
xb,yb

In [None]:
# Fresh model and optimizer
model,opt = get_model()

In [None]:
# Same loop, slicing inputs and targets together through the Dataset
for epoch in range(epochs):
    for i in range(0, n, bs):
        xb,yb = train_ds[i:min(n,i+bs)]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
    report(loss, preds, yb)

### DataLoader

Previously, our loop iterated over batches (xb, yb) like this:

```python
for i in range(0, n, bs):
    xb,yb = train_ds[i:min(n,i+bs)]
    ...
```

Let's make our loop much cleaner, using a data loader:

```python
for xb,yb in train_dl:
    ...
```

In [None]:
class DataLoader():
    "Yield consecutive slices of `bs` items from `ds` (the last one may be shorter)."
    def __init__(self, ds, bs):
        self.ds = ds
        self.bs = bs

    def __iter__(self):
        for start in range(0, len(self.ds), self.bs):
            stop = start + self.bs
            yield self.ds[start:stop]

In [None]:
# Batch both splits with the same batch size
train_dl = DataLoader(train_ds, bs)
valid_dl = DataLoader(valid_ds, bs)

In [None]:
# First validation batch
xb,yb = next(iter(valid_dl))
xb.shape

In [None]:
# Its targets
yb

In [None]:
# Show the first image of the batch with its label.
# NOTE(review): `.view` is a torch method; the data were converted to jnp arrays above (jnp uses `.reshape`).
plt.imshow(xb[0].view(28,28))
yb[0]

In [None]:
# Fresh model and optimizer
model,opt = get_model()

In [None]:
def fit():
    """Train the global `model` with `opt` for `epochs` epochs over `train_dl`,
    reporting the last training batch's loss and accuracy after each epoch."""
    for epoch in range(epochs):
        for xb,yb in train_dl:
            preds = model(xb)
            loss = loss_func(preds, yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
        # report this epoch's last batch: `preds` must come from this loop, not a global
        report(loss, preds, yb)

In [None]:
# Train, then evaluate on the global batch `xb, yb`
fit()
loss_func(model(xb), yb), accuracy(model(xb), yb)

### Random sampling

We want our training set to be in a random order, and that order should differ each iteration. But the validation set shouldn't be randomized.

In [None]:
import random

In [None]:
class Sampler():
    "Iterate over the indices `0..len(ds)-1`, shuffled with the `random` module when `shuffle`."
    def __init__(self, ds, shuffle=False):
        self.n = len(ds)
        self.shuffle = shuffle

    def __iter__(self):
        order = [*range(self.n)]
        if self.shuffle:
            random.shuffle(order)
        return iter(order)

In [None]:
from itertools import islice

In [None]:
# Sequential (unshuffled) sampler over the training set
ss = Sampler(train_ds)

In [None]:
# First five indices, pulled one by one
it = iter(ss)
for o in range(5): print(next(it))

In [None]:
# Same, with islice
list(islice(ss, 5))

In [None]:
# Shuffled sampler: a new order on every iteration
ss = Sampler(train_ds, shuffle=True)
list(islice(ss, 5))

In [None]:
import fastcore.all as fc

In [None]:
class BatchSampler():
    "Group the indices yielded by `sampler` into chunks of `bs` (optionally dropping a short last chunk)."
    def __init__(self, sampler, bs, drop_last=False): fc.store_attr()

    def __iter__(self):
        chunks = fc.chunked(iter(self.sampler), self.bs, drop_last=self.drop_last)
        yield from chunks

In [None]:
# First five batches of 4 shuffled indices
batchs = BatchSampler(ss, 4)
list(islice(batchs, 5))

In [None]:
def collate(b):
    "Turn an iterable of `(x, y)` pairs into a tuple of stacked tensors `(xs, ys)`."
    xs, ys = zip(*b)
    return tuple(map(torch.stack, (xs, ys)))

In [None]:
class DataLoader():
    "For each batch of indices from `batchs`, fetch the items from `ds` and merge them with `collate_fn`."
    def __init__(self, ds, batchs, collate_fn=collate): fc.store_attr()

    def __iter__(self):
        for idxs in self.batchs:
            yield self.collate_fn(self.ds[i] for i in idxs)

In [None]:
# Shuffle only the training set
train_samp = BatchSampler(Sampler(train_ds, shuffle=True ), bs)
valid_samp = BatchSampler(Sampler(valid_ds, shuffle=False), bs)

In [None]:
# Loaders that fetch items one by one and collate them
train_dl = DataLoader(train_ds, batchs=train_samp)
valid_dl = DataLoader(valid_ds, batchs=valid_samp)

In [None]:
# First validation image and its label. NOTE(review): `.view` assumes a torch tensor.
xb,yb = next(iter(valid_dl))
plt.imshow(xb[0].view(28,28))
yb[0]

In [None]:
# Shapes of the collated batch
xb.shape,yb.shape

In [None]:
# Fresh model and optimizer
model,opt = get_model()

In [None]:
# Train with the sampler-based loaders
fit()

### Multiprocessing DataLoader

In [None]:
import torch.multiprocessing as mp
from fastcore.basics import store_attr

In [None]:
# Indexing with a list of indices returns the whole batch at once
train_ds[[3,6,8,1]]

In [None]:
# Same call, spelled out
train_ds.__getitem__([3,6,8,1])

In [None]:
# Mapping __getitem__ over index batches: the pattern the worker pool uses below
for o in map(train_ds.__getitem__, ([3,6],[8,1])): print(o)

In [None]:
class DataLoader():
    # Parallel loader: each batch of indices from `batchs` is fetched with `ds.__getitem__` in a worker process.
    # NOTE(review): `collate_fn` is stored but never applied; this relies on `ds[list_of_idxs]` already
    # returning a whole batch. `Pool.map` computes every batch before the first one is yielded.
    def __init__(self, ds, batchs, n_workers=1, collate_fn=collate): fc.store_attr()
    def __iter__(self):
        with mp.Pool(self.n_workers) as ex: yield from ex.map(self.ds.__getitem__, iter(self.batchs))

In [None]:
# Two worker processes
train_dl = DataLoader(train_ds, batchs=train_samp, n_workers=2)
it = iter(train_dl)

In [None]:
# First batch from the parallel loader
xb,yb = next(it)
xb.shape,yb.shape

### PyTorch DataLoader

In [None]:
#|export
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, BatchSampler

In [None]:
# PyTorch's samplers: random order for training, sequential for validation
train_samp = BatchSampler(RandomSampler(train_ds),     bs, drop_last=False)
valid_samp = BatchSampler(SequentialSampler(valid_ds), bs, drop_last=False)

In [None]:
# PyTorch DataLoader driven by our batch samplers and collate function
train_dl = DataLoader(train_ds, batch_sampler=train_samp, collate_fn=collate)
valid_dl = DataLoader(valid_ds, batch_sampler=valid_samp, collate_fn=collate)

In [None]:
# Fresh model, train, evaluate on the global batch `xb, yb`
model,opt = get_model()
fit()
loss_func(model(xb), yb), accuracy(model(xb), yb)

PyTorch can auto-generate the BatchSampler for us:

In [None]:
# Passing `bs` and a sampler lets DataLoader build the BatchSampler itself
train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)
valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)

PyTorch can also generate the Sequential/RandomSamplers too:

In [None]:
# `shuffle` picks Random/SequentialSampler automatically; two worker processes per loader
train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True, num_workers=2)
valid_dl = DataLoader(valid_ds, bs, shuffle=False, num_workers=2)

In [None]:
# Fresh model, train, evaluate on the global batch `xb, yb`
model,opt = get_model()
fit()

loss_func(model(xb), yb), accuracy(model(xb), yb)

Our dataset actually already knows how to sample a batch of indices all at once:

In [None]:
# A list of indices already yields a whole batch
train_ds[[4,6,7]]

...that means that we can actually skip the batch_sampler and collate_fn entirely:

In [None]:
# `train_samp`/`valid_samp` yield whole batches of indices and `ds[idxs]` returns a whole batch,
# so disable automatic batching (`batch_size=None`). Otherwise the default batch_size=1 would
# collate each batch again and add a leading dimension of size 1.
train_dl = DataLoader(train_ds, sampler=train_samp, batch_size=None)
valid_dl = DataLoader(valid_ds, sampler=valid_samp, batch_size=None)

In [None]:
# Shapes of one batch from the sampler-only loader
xb,yb = next(iter(train_dl))
xb.shape,yb.shape

## Validation

You **always** should also have a [validation set](http://www.fast.ai/2017/11/13/validation-sets/), in order to identify if you are overfitting.

We will calculate and print the validation loss at the end of each epoch.

(Note that we always call `model.train()` before training, and `model.eval()` before inference, because these are used by layers such as `nn.BatchNorm2d` and `nn.Dropout` to ensure appropriate behaviour for these different phases.)

In [None]:
#|export
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train `model` for `epochs` epochs, evaluating on `valid_dl` after each one.

    Prints `epoch, validation loss, validation accuracy` per epoch (both averaged per
    sample, i.e. weighted by batch size) and returns the last epoch's `(loss, accuracy)`.
    Returns `(nan, nan)` when `epochs` is 0 or `valid_dl` yields no samples.
    """
    val_loss, val_acc = float('nan'), float('nan')
    for epoch in range(epochs):
        model.train()
        for xb,yb in train_dl:
            loss = loss_func(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        with torch.no_grad():
            tot_loss,tot_acc,count = 0.,0.,0
            for xb,yb in valid_dl:
                pred = model(xb)
                n = len(xb)
                count += n
                # weight by batch size so a short last batch is not over-counted
                tot_loss += loss_func(pred,yb).item()*n
                tot_acc  += accuracy(pred,yb).item()*n
        # guard against an empty validation set instead of dividing by zero
        val_loss = tot_loss/count if count else float('nan')
        val_acc = tot_acc/count if count else float('nan')
        print(epoch, val_loss, val_acc)
    return val_loss, val_acc

In [None]:
#|export
def get_dls(train_ds, valid_ds, bs, **kwargs):
    "Return `(train_dl, valid_dl)`: shuffled training batches of `bs`, validation batches of `2*bs`."
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=bs*2, **kwargs)
    return train_dl, valid_dl

Now, our whole process of obtaining the data loaders and fitting the model can be run in 3 lines of code:

In [None]:
# Fresh loaders (validation batches are twice as large) and a fresh model/optimizer
train_dl,valid_dl = get_dls(train_ds, valid_ds, bs)
model,opt = get_model()

In [None]:
# Train for 5 epochs, timing the run; `fit` returns the last validation (loss, accuracy)
%time loss,acc = fit(5, model, loss_func, opt, train_dl, valid_dl)

## Export -

In [None]:
# Write the `#|export` cells to the library module
import nbdev; nbdev.nbdev_export()