In [65]:
#export
from export.nb_00 import *
from torch import nn, optim, tensor, Tensor, hub
import torch.nn.functional as F
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
import gzip, pickle, torch
from tqdm.notebook import tqdm, trange

In [66]:
#export

mnist_path=Path(r'd:\datasets\data\mnist.pkl.gz')
# mnist_path.ls()

In [67]:
#export
def get_mnist():
    with gzip.open(mnist_path, 'rb') as f:
        ((x_train, y_train),(x_valid, y_valid), _)=pickle.load(f, encoding='latin-1')
    x_train, y_train, x_valid, y_valid= map(tensor, (x_train, y_train, x_valid, y_valid))
    return (x_train, y_train), (x_valid, y_valid)

The most important argument of DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

    * `map style`
    * `iter style`

In [68]:
#export
class Dataset():
    def __init__(self, x, y):
        self.x, self.y=x, y
        
    def __getitem__(self, i): return self.x[i], self.y[i]
    def __len__(self): return len(self.x)

In [69]:
(x_train, y_train), (x_valid, y_valid)= get_mnist()
train_ds, valid_ds= Dataset(x_train, y_train), Dataset(x_valid, y_valid)

In [70]:
bs=128
epochs=3

In [71]:
#export
def collate(b):
    xs,ys = zip(*b)
    return torch.stack(xs),torch.stack(ys)

In [72]:
train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)
valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)

In [85]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in trange(epochs, desc='epochs', smoothing=1):
        # Handle batchnorm / dropout
        model.train()
#         print(model.training)
        for xb,yb in tqdm(train_dl,leave=False, desc='train', smoothing=0.5,):
            loss = loss_func(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
#         print(model.training)
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in tqdm(valid_dl, leave=False, desc='valid', smoothing=0.5):
                pred = model(xb)
                tot_loss += loss_func(pred, yb)
                tot_acc  += accuracy (pred,yb)
        nv = len(valid_dl)
        tqdm.write(f'{epoch}, {tot_loss/nv}, {tot_acc/nv}')
    return tot_loss/nv, tot_acc/nv

In [86]:
#export
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func=func
        
    def forward(self, x): return self.func(x)

In [87]:
class SequentialModel(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        
    def __call__(self, x):
        for l in self.layers: x = l(x)
        return x

In [88]:
layers = [nn.Linear(784,256), nn.ReLU(), nn.Linear(256,10)]

In [89]:
model=SequentialModel(layers)
loss_fn=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
def accuracy(out, yb): return (torch.argmax(out, dim=1)==yb).float().mean()

In [90]:
fit(3, model, loss_fn, optimizer, train_dl, valid_dl)

HBox(children=(FloatProgress(value=0.0, description='epochs', max=3.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='train', max=391.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='valid', max=79.0, style=ProgressStyle(description_width='…

0, 0.3124648630619049, 0.9126780033111572


HBox(children=(FloatProgress(value=0.0, description='train', max=391.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='valid', max=79.0, style=ProgressStyle(description_width='…

1, 0.2591705918312073, 0.9274129867553711


HBox(children=(FloatProgress(value=0.0, description='train', max=391.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='valid', max=79.0, style=ProgressStyle(description_width='…

2, 0.22022294998168945, 0.9387856125831604



(tensor(0.2202), tensor(0.9388))