In [29]:
! pip install -Uqq fastbook

In [30]:
import torch
from fastbook import *

In [31]:
path = untar_data(URLs.MNIST_SAMPLE)

In [32]:
Path.BASE_PATH = path

In [33]:
path.ls()

(#3) [Path('train'),Path('labels.csv'),Path('valid')]

In [34]:
(path/'train').ls()

(#2) [Path('train/7'),Path('train/3')]

In [35]:
threes = [tensor(Image.open(o)) for o in (path/'train'/'3').ls().sorted()]
sevens = [tensor(Image.open(o)) for o in (path/'train'/'7').ls().sorted()]

In [36]:
train_x = torch.cat([torch.stack(threes).float()/255, torch.stack(sevens).float()/255]).view(-1, 28*28)
train_x.shape

torch.Size([12396, 784])

In [37]:
train_y = tensor([1.,0.]*len(threes) + [0., 1.]*len(sevens)).view(-1, 2)
train_y.shape

torch.Size([12396, 2])

In [38]:
def mnist_loss(preds, targ):
    preds = preds.sigmoid()
    return -torch.where(targ==1., preds, 1-preds).log().mean()

In [39]:
def init_params(size, std=1.):
    return (torch.randn(size)*std).requires_grad_()

In [40]:
w = init_params((28*28, 2))
b = init_params(2)

In [41]:
w.shape

torch.Size([784, 2])

In [42]:
def mk_linear(w, b, xb): return xb@w + b

In [43]:
linear = partial(mk_linear, w, b)

In [44]:
preds = linear(train_x[:5])
preds, preds.sigmoid()

(tensor([[ -4.7881,  -2.7658],
         [ -1.6212,   7.5044],
         [  1.9732,  -5.7859],
         [  1.0460,   0.1084],
         [-13.9734,   8.2622]], grad_fn=<AddBackward0>),
 tensor([[8.2593e-03, 5.9202e-02],
         [1.6503e-01, 9.9945e-01],
         [8.7796e-01, 3.0612e-03],
         [7.4001e-01, 5.2706e-01],
         [8.5395e-07, 9.9974e-01]], grad_fn=<SigmoidBackward0>))

In [45]:
train_y[:5]

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]])

In [46]:
xb, yb = train_x[:5], train_y[:5]
acts = linear(xb)

mnist_loss(acts, yb), torch.nn.BCEWithLogitsLoss()(acts, yb)

(tensor(3.7583, grad_fn=<NegBackward0>),
 tensor(3.7583, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>))

In [47]:
((acts.sigmoid() > 0.5) == yb.bool()).float().mean()

tensor(0.4000)

In [48]:
def step(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()

In [49]:
step(xb,yb,linear)

w.data -= w.grad * 1e-2
b.data -= b.grad * 1e-2

w.grad.zero_()
b.grad.zero_()

mnist_loss(linear(xb), yb)

tensor(3.6194, grad_fn=<NegBackward0>)

In [50]:
def batch_accuracy(preds, targ):
    preds = preds.sigmoid()
    return ((preds > 0.5) == targ.bool()).float().mean()

In [51]:
batch_accuracy(linear(xb), yb)

tensor(0.5000)

In [52]:
preds = linear(xb)
preds, preds.sigmoid()

(tensor([[ -4.5237,  -2.9634],
         [ -1.3709,   7.2907],
         [  2.1943,  -5.9642],
         [  1.3216,  -0.1262],
         [-13.7049,   8.0391]], grad_fn=<AddBackward0>),
 tensor([[1.0733e-02, 4.9108e-02],
         [2.0248e-01, 9.9932e-01],
         [8.9974e-01, 2.5624e-03],
         [7.8946e-01, 4.6850e-01],
         [1.1170e-06, 9.9968e-01]], grad_fn=<SigmoidBackward0>))

In [53]:
tl = DataLoader(list(zip(train_x, train_y)), batch_size=256)

In [54]:
v_threes = [tensor(Image.open(o)) for o in (path/'valid'/'3').ls().sorted()]
v_sevens = [tensor(Image.open(o)) for o in (path/'valid'/'7').ls().sorted()]

valid_x = torch.cat([torch.stack(v_threes).float()/255, torch.stack(v_sevens).float()/255]).view(-1, 28*28)
valid_y = tensor([1.,0.]*len(v_threes) + [0., 1.]*len(v_sevens)).view(-1, 2)

vl = DataLoader(list(zip(valid_x, valid_y)), batch_size=256)

In [55]:
dls = DataLoaders(tl, vl)

In [56]:
def validate_epoch(model):
  acc = [batch_accuracy(model(xb), yb) for xb,yb in vl]
  loss = [mnist_loss(model(xb), yb) for xb,yb in vl]
  return torch.stack(acc).mean(), torch.stack(loss).mean()

In [57]:
def train(dl, model, params, lr=1e-3, num_epochs=5):
  for i in range(num_epochs):
    for xb, yb in dl:
      step(xb, yb, model)
      for p in params:
        p.data -= p.grad * lr
        p.grad.zero_()
  
    print(validate_epoch(model), end='\n')
  

In [58]:
xb, yb = tl.one_batch()
xb.shape, yb.shape

(torch.Size([256, 784]), torch.Size([256, 2]))

In [59]:
preds = linear(xb)
preds

tensor([[-4.5237e+00, -2.9634e+00],
        [-1.3709e+00,  7.2907e+00],
        [ 2.1943e+00, -5.9642e+00],
        [ 1.3216e+00, -1.2618e-01],
        [-1.3705e+01,  8.0391e+00],
        [-2.8022e+01, -1.9884e+00],
        [ 7.3283e+00, -3.5304e-02],
        [-3.7637e+00, -1.5320e+01],
        [-5.6923e+00, -1.7013e+00],
        [ 1.0457e+01,  4.9793e-01],
        [ 3.9436e+00, -1.8608e+00],
        [ 4.8244e+00,  2.9545e+00],
        [-3.1911e+00, -1.0898e+00],
        [ 3.7475e+00,  4.9070e+00],
        [ 6.6870e-01,  1.6085e+00],
        [ 1.7065e+00,  9.2504e+00],
        [ 8.9764e+00, -9.7835e-01],
        [-5.0925e+00, -2.2423e+00],
        [-8.9287e-01,  5.5390e+00],
        [ 5.7465e+00,  3.0316e+00],
        [-6.6608e+00,  4.6812e-01],
        [ 4.3504e+00, -2.7852e+00],
        [-7.3924e+00, -1.3504e+00],
        [ 4.7840e+00, -1.8894e-01],
        [ 5.8777e-01,  3.9847e+00],
        [ 9.5865e+00, -8.3451e+00],
        [-4.9981e+00, -3.8148e+00],
        [ 8.0121e+00, -5.111

In [60]:
preds = preds.sigmoid()
preds

tensor([[1.0733e-02, 4.9109e-02],
        [2.0248e-01, 9.9932e-01],
        [8.9974e-01, 2.5624e-03],
        [7.8946e-01, 4.6850e-01],
        [1.1170e-06, 9.9968e-01],
        [6.7662e-13, 1.2043e-01],
        [9.9934e-01, 4.9117e-01],
        [2.2672e-02, 2.2212e-07],
        [3.3604e-03, 1.5429e-01],
        [9.9997e-01, 6.2197e-01],
        [9.8099e-01, 1.3461e-01],
        [9.9203e-01, 9.5048e-01],
        [3.9502e-02, 2.5166e-01],
        [9.7697e-01, 9.9266e-01],
        [6.6121e-01, 8.3320e-01],
        [8.4638e-01, 9.9990e-01],
        [9.9987e-01, 2.7322e-01],
        [6.1049e-03, 9.6020e-02],
        [2.9052e-01, 9.9608e-01],
        [9.9682e-01, 9.5398e-01],
        [1.2785e-03, 6.1494e-01],
        [9.8726e-01, 5.8127e-02],
        [6.1557e-04, 2.0580e-01],
        [9.9171e-01, 4.5291e-01],
        [6.4285e-01, 9.8174e-01],
        [9.9993e-01, 2.3751e-04],
        [6.7053e-03, 2.1567e-02],
        [9.9967e-01, 3.7493e-01],
        [9.9830e-01, 3.3156e-03],
        [1.674

In [61]:
w1 = init_params((28*28, 2))
b1 = init_params(2)
linear1 = partial(mk_linear, w1, b1)
params = w1,b1

In [62]:
train(tl, linear1, params, num_epochs=5, lr=1.)

(tensor(0.5000), tensor(nan, grad_fn=<MeanBackward0>))
(tensor(0.5000), tensor(nan, grad_fn=<MeanBackward0>))
(tensor(0.5000), tensor(nan, grad_fn=<MeanBackward0>))
(tensor(0.5000), tensor(nan, grad_fn=<MeanBackward0>))
(tensor(0.5000), tensor(nan, grad_fn=<MeanBackward0>))


In [63]:
nums = DataBlock(
  blocks=(ImageBlock, MultiCategoryBlock),
  get_items=get_image_files,
  get_y=parent_label,
  splitter=RandomSplitter(seed=42),
)

dls = nums.dataloaders(path/'train')

In [66]:
x,y = dls.one_batch()
x[0], y[0]

(TensorImage([[[0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               ...,
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.]],
 
              [[0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               ...,
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.]],
 
              [[0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               ...,
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.],
               [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0'),
 TensorMultiCategory([0., 1.], device='cuda:0'))

In [67]:
learn = vision_learner(dls, resnet18, metrics=partial(accuracy_multi, thresh=.5))
learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)

epoch,train_loss,valid_loss,accuracy_multi,time
0,0.044026,0.014599,0.995764,00:05
1,0.014144,0.007773,0.997176,00:04
2,0.00401,0.004452,0.99879,00:04
