In [84]:
#export
from export.nb_00 import *
from torch import nn, optim, tensor, Tensor, hub
import torch.nn.functional as F
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
import gzip, pickle, torch

from functools import partial
from typing import Any, Collection, Callable, NewType, List, Union, TypeVar, Optional, Generator, Iterable

In [8]:
#export

import os

# Location of the gzipped MNIST pickle. Set the MNIST_PATH environment variable
# to point elsewhere instead of editing this machine-specific default.
mnist_path = Path(os.environ.get('MNIST_PATH', r'd:\datasets\data\mnist.pkl.gz'))

In [27]:
#export
def get_mnist(path=None):
    """Load the MNIST train/validation splits from a gzipped pickle.

    Args:
        path: file to read; defaults to the notebook-level `mnist_path`
            (looked up at call time).

    Returns:
        `((x_train, y_train), (x_valid, y_valid))`, each array converted with
        `torch.tensor`. The third element stored in the pickle is discarded.

    Security: `pickle.load` can execute arbitrary code; only use a trusted file.
    """
    if path is None:
        path = mnist_path
    with gzip.open(path, 'rb') as f:
        # NOTE(review): 'latin-1' is presumably needed because the archive was
        # pickled under Python 2; confirm for other copies of the file.
        (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')
    x_train, y_train, x_valid, y_valid = map(tensor, (x_train, y_train, x_valid, y_valid))
    return (x_train, y_train), (x_valid, y_valid)

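The most important argument of the `DataLoader` constructor is `dataset`, the dataset object to load data from. PyTorch supports two different types of datasets:

* map-style datasets, which implement `__getitem__` and `__len__` and are indexed by key (the `Dataset` class below is one);
* iterable-style datasets, which implement `__iter__` and stream samples one after another.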

In [33]:
class Dataset:
    """Minimal map-style dataset pairing inputs `x` with targets `y`.

    `ds[i]` returns the pair `(x[i], y[i])`; `len(ds)` is `len(x)`.
    """
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return (self.x[i], self.y[i])

In [34]:
# Load the raw tensors and wrap each split in a map-style Dataset.
(x_train, y_train), (x_valid, y_valid) = get_mnist()
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)

In [36]:
# Mini-batch size and number of passes over the training set.
bs, epochs = 128, 3

In [37]:
def collate(b):
    """Turn a list of `(x, y)` samples into a `(stacked_xs, stacked_ys)` batch."""
    inputs, targets = zip(*b)
    return tuple(map(torch.stack, (inputs, targets)))

# Training batches are drawn in a fresh random order each epoch; validation
# batches always come out in dataset order.
train_dl = DataLoader(train_ds, batch_size=bs,
                      sampler=RandomSampler(train_ds), collate_fn=collate)
valid_dl = DataLoader(valid_ds, batch_size=bs,
                      sampler=SequentialSampler(valid_ds), collate_fn=collate)

In [43]:
# Type aliases for documenting training-loop signatures.
# NewType's name argument must match the variable it is bound to; otherwise
# type checkers such as mypy reject it and its repr shows the wrong name.
Rank0Tensor = NewType('Rank0Tensor', Tensor)            # a 0-dim (scalar) tensor, e.g. a loss
LossFunction = Callable[[Tensor, Tensor], Rank0Tensor]  # (pred, target) -> scalar loss
Model = nn.Module

In [45]:
def _train_one_epoch(model, loss_func, opt, train_dl):
    """Run one optimisation pass over `train_dl` with `model` in train mode."""
    model.train()  # enables batchnorm statistics updates / dropout
    for xb, yb in train_dl:
        loss = loss_func(model(xb), yb)
        loss.backward()
        opt.step()
        opt.zero_grad()


def _evaluate(model, loss_func, metric, valid_dl):
    """Return the per-sample mean of `loss_func` and `metric` over `valid_dl`."""
    model.eval()  # freezes batchnorm statistics / disables dropout
    tot_loss, tot_metric, count = 0., 0., 0
    with torch.no_grad():
        for xb, yb in valid_dl:
            pred = model(xb)
            n = len(xb)
            # loss_func / metric return per-batch means; weight them by batch size
            # so a short final batch doesn't count as much as a full one.
            tot_loss += loss_func(pred, yb) * n
            tot_metric += metric(pred, yb) * n
            count += n
    if count == 0:
        raise ValueError("valid_dl yielded no samples")
    return tot_loss / count, tot_metric / count


def fit(epochs, model, loss_func, opt, train_dl, valid_dl, metric=None):
    """Train `model` for `epochs` epochs, validating after each one.

    Args:
        epochs: number of passes over `train_dl`; must be >= 1.
        model: an `nn.Module`, put in train mode for training and eval mode for validation.
        loss_func: callable `(pred, yb) -> scalar loss`.
        opt: optimizer over `model`'s parameters.
        train_dl, valid_dl: iterables of `(xb, yb)` batches.
        metric: callable `(pred, yb) -> scalar` reported alongside the loss.
            Defaults to the notebook-level `accuracy`, looked up at call time
            so it may be defined in a later cell.

    Prints `epoch, valid_loss, valid_metric` after every epoch.

    Returns:
        `(valid_loss, valid_metric)` from the final epoch, averaged per sample.

    Raises:
        ValueError: if `epochs < 1` (there would be no result to return) or
            `valid_dl` yields no samples.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if metric is None:
        metric = accuracy
    for epoch in range(epochs):
        _train_one_epoch(model, loss_func, opt, train_dl)
        valid_loss, valid_metric = _evaluate(model, loss_func, metric, valid_dl)
        print(epoch, valid_loss, valid_metric)
    return valid_loss, valid_metric

In [71]:
class Lambda(nn.Module):
    """Wrap an arbitrary callable `func` so it can be used as an `nn.Module` layer."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to `x`."""
        result = self.func(x)
        return result

In [72]:
class SequentialModel(nn.Module):
    """Apply `layers` one after another, feeding each output into the next layer.

    Args:
        layers: iterable of `nn.Module`s. They are registered in an
            `nn.ModuleList`, so their parameters appear in `model.parameters()`.
    """
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    # Override `forward`, not `__call__`: nn.Module.__call__ wraps `forward`
    # with hook handling. Overriding `__call__` silently skips registered hooks
    # and leaves `forward()` unimplemented.
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [75]:
layers = [
    nn.Linear(784, 256),  # 784 input features (presumably flattened 28x28 MNIST images) -> 256 hidden units
    nn.ReLU(),
    nn.Linear(256, 10),   # hidden units -> 10 output scores, one per class
]

In [81]:
# NOTE(review): SequentialModel registers the very module objects in `layers`,
# so re-running this cell after training does NOT give fresh weights. Re-run
# the `layers` cell first.
model = SequentialModel(layers)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
)
def accuracy(out, yb):
    """Fraction of rows of `out` whose arg-max column (dim 1) equals the target in `yb`."""
    predicted = out.argmax(dim=1)
    hits = (predicted == yb).float()
    return hits.mean()

In [82]:
# Use the `epochs` setting from the config cell rather than repeating the literal.
fit(epochs, model, loss_fn, optimizer, train_dl, valid_dl)

0 tensor(0.2572) tensor(0.9266)
1 tensor(0.2263) tensor(0.9382)
2 tensor(0.1983) tensor(0.9456)


(tensor(0.1983), tensor(0.9456))