In [933]:
! pip install -Uqq fastbook

In [934]:
import torch
from fastbook import *

In [935]:
# Fetch the MNIST_SAMPLE dataset and keep the path returned by untar_data;
# the listings below show it holds 'train' and 'valid' folders plus labels.csv.
path = untar_data(URLs.MNIST_SAMPLE)

In [936]:
# Use the dataset root as the base for Path reprs (the listings below print
# relative paths such as Path('train/7')).
Path.BASE_PATH = path

In [937]:
# Inspect the top level of the dataset directory.
path.ls()

(#3) [Path('valid'),Path('labels.csv'),Path('train')]

In [938]:
# Inspect the training folder: one sub-folder per digit class.
(path/'train').ls()

(#2) [Path('train/7'),Path('train/3')]

In [939]:
# Read every training image of each digit into a tensor, in sorted filename order.
threes = [tensor(Image.open(img_file)) for img_file in (path/'train'/'3').ls().sorted()]
sevens = [tensor(Image.open(img_file)) for img_file in (path/'train'/'7').ls().sorted()]

In [940]:
# Stack each class into one tensor, scale pixel values by 1/255, concatenate
# threes before sevens, and flatten every image to a row of 28*28 values.
train_x = torch.cat([
    torch.stack(imgs).float() / 255
    for imgs in (threes, sevens)
]).view(-1, 28*28)
train_x.shape

torch.Size([12396, 784])

In [941]:
# Two-column targets in the same order as train_x.
train_y = tensor(
    [1., 0.] * len(threes)      # one (1, 0) pair per three
    + [0., 1.] * len(sevens)    # one (0, 1) pair per seven
).view(-1, 2)                   # regroup the flat list into rows of two
train_y.shape

torch.Size([12396, 2])

In [942]:
def mnist_loss(preds, targ):
    """Mean distance between the sigmoid of `preds` and the targets.

    Each activation is squashed with a sigmoid; where `targ` equals 1 the
    element contributes ``1 - p``, elsewhere it contributes ``p``. The mean
    over all elements is returned as a 0-dim tensor.
    """
    probs = torch.sigmoid(preds)
    per_element = torch.where(targ == 1., 1 - probs, probs)
    return per_element.mean()

In [943]:
def init_params(size, std=1.):
    """Create a trainable tensor of shape `size` drawn from a normal distribution.

    Values come from ``torch.randn`` scaled by `std`; the result has
    ``requires_grad`` switched on in place before being returned.
    """
    values = torch.randn(size) * std
    return values.requires_grad_()

In [944]:
# Randomly initialised parameters: a (28*28, 2) weight matrix (one weight per
# pixel for each of the two output columns) and a length-2 bias.
w = init_params((28*28, 2))
b = init_params(2)

In [945]:
# Sanity-check the weight matrix shape.
w.shape

torch.Size([784, 2])

In [946]:
def mk_linear(w, b, xb):
    """Affine map ``xb @ w + b``.

    `w` and `b` come first so they can be bound with ``partial``, leaving a
    callable that takes only the batch `xb`.
    """
    projected = xb @ w
    return projected + b

In [947]:
# Bind w and b so that `linear(xb)` only needs the input batch.
linear = partial(mk_linear, w, b)

In [948]:
# Raw activations for the first five training rows, shown next to their
# sigmoid-squashed values.
preds = linear(train_x[:5])
preds, preds.sigmoid()

(tensor([[  4.7158,  -2.9719],
         [  8.8701, -11.1608],
         [  4.7474, -12.1990],
         [  1.2906, -11.5354],
         [  1.4465,   0.4475]], grad_fn=<AddBackward0>),
 tensor([[9.9113e-01, 4.8711e-02],
         [9.9986e-01, 1.4220e-05],
         [9.9140e-01, 5.0355e-06],
         [7.8426e-01, 9.7777e-06],
         [8.0946e-01, 6.1003e-01]], grad_fn=<SigmoidBackward0>))

In [949]:
# Targets for the first five training rows.
train_y[:5]

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]])

In [950]:
# Take a five-row mini-batch and compute its activations once.
xb, yb = train_x[:5], train_y[:5]
acts = linear(xb)

# Show the hand-written loss next to PyTorch's BCEWithLogitsLoss on the same activations.
mnist_loss(acts, yb), torch.nn.BCEWithLogitsLoss()(acts, yb)

(tensor(0.1083, grad_fn=<MeanBackward0>),
 tensor(0.1464, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>))

In [951]:
# Fraction of outputs whose thresholded sigmoid (> 0.5) equals the boolean
# target, averaged over every element of the batch (both columns).
((acts.sigmoid() > 0.5) == yb.bool()).float().mean()

tensor(0.9000)

In [952]:
def step(xb, yb, model):
    """Run one forward and backward pass of `model` on the batch (xb, yb).

    The loss is computed with `mnist_loss` and back-propagated, so gradients
    end up in the `.grad` of whatever tensors the model's output depends on.
    No parameter update happens here and nothing is returned.
    """
    batch_loss = mnist_loss(model(xb), yb)
    batch_loss.backward()

In [953]:
# Back-propagate the loss on the current (xb, yb) batch into w.grad and b.grad.
step(xb,yb,linear)

# Manual gradient-descent update with a learning rate of 1e-2, written through
# .data so the update itself is not tracked by autograd.
w.data -= w.grad * 1e-2
b.data -= b.grad * 1e-2

# Reset the gradients in place; backward() accumulates into .grad otherwise.
w.grad.zero_()
b.grad.zero_()

# Loss on the same batch after the update.
mnist_loss(linear(xb), yb)

tensor(0.1064, grad_fn=<MeanBackward0>)

In [954]:
def batch_accuracy(preds, targ):
    """Fraction of elements where ``sigmoid(preds) > 0.5`` agrees with `targ`.

    `targ` is converted to booleans for the comparison and the mean is taken
    over every element, returning a 0-dim float tensor.
    """
    predicted_on = torch.sigmoid(preds) > 0.5
    matches = predicted_on == targ.bool()
    return matches.float().mean()

In [955]:
# Accuracy of the current linear model on the (xb, yb) batch.
batch_accuracy(linear(xb), yb)

tensor(0.9000)

In [956]:
# Recompute activations on the same batch now that w and b have been updated.
preds = linear(xb)
preds, preds.sigmoid()

(tensor([[  4.7431,  -2.9946],
         [  8.8958, -11.1803],
         [  4.7724, -12.2171],
         [  1.3254, -11.5588],
         [  1.4782,   0.4184]], grad_fn=<AddBackward0>),
 tensor([[9.9136e-01, 4.7669e-02],
         [9.9986e-01, 1.3946e-05],
         [9.9161e-01, 4.9452e-06],
         [7.9008e-01, 9.5514e-06],
         [8.1430e-01, 6.0309e-01]], grad_fn=<SigmoidBackward0>))

In [957]:
# Training loader over (x, y) pairs in batches of 256.
# NOTE(review): train_x is built as all threes followed by all sevens and no
# shuffle=True is passed here, so batches presumably follow that class order —
# confirm whether shuffling was intended for training.
tl = DataLoader(list(zip(train_x, train_y)), batch_size=256)

In [958]:
# Validation images, read from the 'valid' folder in sorted filename order.
v_threes = [tensor(Image.open(img_file)) for img_file in (path/'valid'/'3').ls().sorted()]
v_sevens = [tensor(Image.open(img_file)) for img_file in (path/'valid'/'7').ls().sorted()]

# Same layout as the training tensors: scaled by 1/255, threes first, one
# flattened 28*28 row per image, with two-column targets in matching order.
valid_x = torch.cat([
    torch.stack(imgs).float() / 255
    for imgs in (v_threes, v_sevens)
]).view(-1, 28*28)
valid_y = tensor(
    [1., 0.] * len(v_threes)
    + [0., 1.] * len(v_sevens)
).view(-1, 2)

vl = DataLoader(list(zip(valid_x, valid_y)), batch_size=256)

In [959]:
# Bundle the training and validation loaders together.
# NOTE(review): `dls` is not referenced by any of the cells visible here.
dls = DataLoaders(tl, vl)

In [960]:
def validate_epoch(model, dl=None):
    """Evaluate `model` over a data loader and return ``(accuracy, loss)``.

    Args:
        model: callable mapping a batch of inputs to raw activations.
        dl: iterable yielding ``(xb, yb)`` batches. Defaults to the
            notebook-level validation loader `vl`.

    Returns:
        A tuple of 0-dim tensors: the mean of the per-batch `batch_accuracy`
        values and the mean of the per-batch `mnist_loss` values. Both are
        unweighted means over batches, so a short final batch counts as much
        as a full one.
    """
    if dl is None:
        dl = vl
    accs, losses = [], []
    # Evaluation needs no gradients: disabling autograd avoids building a graph
    # for every batch and returns plain tensors with no grad_fn attached.
    with torch.no_grad():
        for xb, yb in dl:
            acts = model(xb)  # one forward pass shared by both metrics
            accs.append(batch_accuracy(acts, yb))
            losses.append(mnist_loss(acts, yb))
    return torch.stack(accs).mean(), torch.stack(losses).mean()

In [961]:
def train(dl, model, params, lr=1e-3, num_epochs=5):
    """Train `model` with plain gradient descent, printing validation metrics per epoch.

    Args:
        dl: iterable yielding ``(xb, yb)`` training batches.
        model: callable passed to `step` for the forward/backward pass.
        params: tensors to update; iterated once per batch, each must have
            `.grad` populated by `step`.
        lr: learning rate multiplying each gradient.
        num_epochs: number of full passes over `dl`.
    """
    for _ in range(num_epochs):
        for xb, yb in dl:
            step(xb, yb, model)
            for p in params:
                # Update through .data, then clear the gradient for the next batch.
                p.data.sub_(p.grad * lr)
                p.grad.zero_()

        # One line per epoch: the tuple returned by validate_epoch.
        print(validate_epoch(model), end='\n')
  

In [962]:
# Pull a single batch from the training loader and check its shapes.
xb, yb = tl.one_batch()
xb.shape, yb.shape

(torch.Size([256, 784]), torch.Size([256, 2]))

In [963]:
# Raw activations of the linear model for the whole batch.
preds = linear(xb)
preds

tensor([[ 4.7431e+00, -2.9946e+00],
        [ 8.8958e+00, -1.1180e+01],
        [ 4.7724e+00, -1.2217e+01],
        [ 1.3254e+00, -1.1559e+01],
        [ 1.4782e+00,  4.1838e-01],
        [-2.8312e+00,  5.8213e+00],
        [ 4.0131e+00, -1.0081e+01],
        [ 3.9838e+00, -3.9072e+00],
        [ 5.9184e+00, -6.1768e-01],
        [ 5.9953e+00, -1.0417e+01],
        [ 7.7429e+00, -4.2494e+00],
        [ 1.0468e+01, -6.0428e+00],
        [ 3.5956e+00, -8.0650e-01],
        [ 1.7528e+01, -3.2700e+00],
        [ 2.4495e+00, -8.7076e+00],
        [ 7.9191e+00,  2.1297e+00],
        [ 8.8135e+00, -1.0012e+01],
        [ 1.1741e+01, -8.0418e+00],
        [ 5.0707e+00,  3.4889e+00],
        [ 1.1734e+01,  6.7275e-01],
        [ 7.1007e+00, -6.6241e+00],
        [ 4.8031e+00,  3.3805e+00],
        [ 1.0117e+01, -2.0969e+01],
        [ 4.2099e+00, -3.3816e+00],
        [ 1.0809e+01, -8.8754e+00],
        [ 2.3257e+00, -4.2809e+00],
        [ 1.7353e+00, -7.1853e+00],
        [ 1.2514e+01,  3.603

In [964]:
# Squash the batch activations into (0, 1). Recomputing from linear(xb) instead
# of re-assigning from `preds` itself keeps this cell idempotent: re-running it
# no longer applies sigmoid to already-squashed values.
preds = linear(xb).sigmoid()
preds

tensor([[9.9136e-01, 4.7669e-02],
        [9.9986e-01, 1.3946e-05],
        [9.9161e-01, 4.9452e-06],
        [7.9008e-01, 9.5514e-06],
        [8.1430e-01, 6.0309e-01],
        [5.5662e-02, 9.9705e-01],
        [9.8224e-01, 4.1867e-05],
        [9.8173e-01, 1.9700e-02],
        [9.9732e-01, 3.5031e-01],
        [9.9752e-01, 2.9911e-05],
        [9.9957e-01, 1.4072e-02],
        [9.9997e-01, 2.3692e-03],
        [9.7329e-01, 3.0864e-01],
        [1.0000e+00, 3.6614e-02],
        [9.2052e-01, 1.6529e-04],
        [9.9964e-01, 8.9376e-01],
        [9.9985e-01, 4.4862e-05],
        [9.9999e-01, 3.2164e-04],
        [9.9376e-01, 9.7037e-01],
        [9.9999e-01, 6.6212e-01],
        [9.9918e-01, 1.3262e-03],
        [9.9186e-01, 9.6709e-01],
        [9.9996e-01, 7.8218e-10],
        [9.8537e-01, 3.2876e-02],
        [9.9998e-01, 1.3977e-04],
        [9.1098e-01, 1.3642e-02],
        [8.5008e-01, 7.5705e-04],
        [1.0000e+00, 9.7350e-01],
        [8.4781e-01, 4.8919e-06],
        [9.999

In [1017]:
# Fresh randomly initialised parameters, separate from the w/b that were
# already nudged by the manual update earlier in the notebook.
w1 = init_params((28*28, 2))
b1 = init_params(2)
linear1 = partial(mk_linear, w1, b1)
# The tensors `train` updates after each batch.
params = w1,b1

In [1018]:
# Train for five epochs with a learning rate of 1; each printed tuple is
# (validation accuracy, validation loss) as returned by validate_epoch.
train(tl, linear1, params, num_epochs=5, lr=1.)

(tensor(0.7580), tensor(0.2435, grad_fn=<MeanBackward0>))
(tensor(0.8529), tensor(0.1523, grad_fn=<MeanBackward0>))
(tensor(0.9088), tensor(0.0984, grad_fn=<MeanBackward0>))
(tensor(0.9301), tensor(0.0751, grad_fn=<MeanBackward0>))
(tensor(0.9416), tensor(0.0632, grad_fn=<MeanBackward0>))
