In [84]:
import torch
from fastbook import *
import matplotlib

matplotlib.rc('image', cmap='Greys')

In [85]:
path = untar_data(URLs.MNIST_SAMPLE)

In [86]:
Path.BASE_PATH = path

In [87]:
path.ls()

(#3) [Path('valid'),Path('labels.csv'),Path('train')]

In [88]:
(path/'train').ls()

(#2) [Path('train/7'),Path('train/3')]

In [89]:
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()

In [90]:
threes_tensor = [tensor(Image.open(o), device='mps') for o in threes]
sevens_tensor = [tensor(Image.open(o), device='mps') for o in sevens]

In [91]:
img_3 = threes_tensor[0]
df = pd.DataFrame(img_3[4:15, 4:22])
df.style.set_properties(**{'font-size': '6pt'}).background_gradient('Greys')

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0,0,0,0,0,0,0,42,118,219,166,118,118,6,0,0,0,0
1,0,0,0,0,0,0,103,242,254,254,254,254,254,66,0,0,0,0
2,0,0,0,0,0,0,18,232,254,254,254,254,254,238,70,0,0,0
3,0,0,0,0,0,0,0,104,244,254,224,254,254,254,141,0,0,0
4,0,0,0,0,0,0,0,0,207,254,210,254,254,254,34,0,0,0
5,0,0,0,0,0,0,0,0,84,206,254,254,254,254,41,0,0,0
6,0,0,0,0,0,0,0,0,0,24,209,254,254,254,171,0,0,0
7,0,0,0,0,0,0,0,0,91,137,253,254,254,254,112,0,0,0
8,0,0,0,0,0,0,40,214,250,254,254,254,254,254,34,0,0,0
9,0,0,0,0,0,0,81,247,254,254,254,254,254,254,146,0,0,0


In [92]:
stacked_threes = torch.stack(threes_tensor).float()/255
stacked_sevens = torch.stack(sevens_tensor).float()/255
stacked_threes.shape

torch.Size([6131, 28, 28])

In [93]:
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_x.shape

torch.Size([12396, 784])

In [94]:
train_y = tensor([1]*len(threes) + [0]*len(sevens), device='mps').unsqueeze(1)
train_y.shape

torch.Size([12396, 1])

In [95]:
dset = list(zip(train_x, train_y))
x, y = dset[0]
x.shape, y.shape

(torch.Size([784]), torch.Size([1]))

In [96]:
validation_threes = torch.stack(
  [tensor(Image.open(o), device='mps') for o in (path/'valid'/'3').ls()]
).float()/255

validation_sevens = torch.stack(
  [tensor(Image.open(o), device='mps') for o in (path/'valid'/'7').ls()]
).float()/255

validation_x = torch.cat([validation_threes, validation_sevens]).view(-1, 28*28)
validation_y = tensor([1]*len(validation_threes) 
                            + [0]*len(validation_sevens), device='mps').unsqueeze(1)

validation_dset = list(zip(validation_x, validation_y))

In [97]:
def init_params(size, std=1.):
  return (torch.randn(size, device='mps')*std).requires_grad_()

In [98]:
dset_loader = DataLoader(dset, batch_size=256)
validation_dset_loader = DataLoader(validation_dset, batch_size=256)
xb, yb = first(dset_loader)
xb.shape, yb.shape

(torch.Size([256, 784]), torch.Size([256, 1]))

In [99]:
def mnist_loss(preds, acts):
  preds = preds.sigmoid()
  return torch.where(acts == 1., 1-preds, preds).mean()

In [100]:
w = init_params((28*28, 1))
b = init_params(1)

In [101]:
def linear(xb): return xb@w + b

In [102]:
preds = linear(train_x[:5])
preds

tensor([[-1.9430],
        [-1.7535],
        [ 4.2589],
        [ 3.8169],
        [-0.5155]], device='mps:0', grad_fn=<AddBackward0>)

In [103]:
mnist_loss(preds, train_y[:5])

tensor(0.4777, device='mps:0', grad_fn=<MeanBackward0>)

In [104]:
def step(xb, yb, model):
  preds = model(xb)
  loss = mnist_loss(preds, yb)
  loss.backward()

In [105]:
def batch_accuracy(xb, yb):
  preds = xb.sigmoid()
  correct = (preds>0.5) == yb
  return correct.float().mean()

In [106]:
step(train_x, train_y, linear)

w.data -= w.grad*1
b.data -= b.grad*1

w.grad.zero_()
b.grad.zero_()

tensor([0.], device='mps:0')

In [107]:
batch_accuracy(linear(train_x), train_y)

tensor(0.4789, device='mps:0')

In [108]:
def validate_epoch(model):
  accs = [batch_accuracy(model(xb), yb) for xb,yb in validation_dset_loader]
  loss = [mnist_loss(model(xb), yb) for xb,yb in validation_dset_loader]
  return torch.stack(accs).mean(), torch.stack(loss).mean()

In [109]:
validate_epoch(linear)

(tensor(0.4672, device='mps:0'),
 tensor(0.5350, device='mps:0', grad_fn=<MeanBackward0>))

In [110]:
def train1(model, params, lr=1., num_epochs=5):
  for i in range(num_epochs):
    for xb, yb in dset_loader:
      step(xb, yb, model)
      for p in params:
        p.data -= p.grad*lr
        p.grad.zero_()
    
    print(validate_epoch(model), end='\n')

In [111]:
def train2(model, params, lr=1., num_epochs=5):
  for i in range(num_epochs):
    step(train_x, train_y, model)
    for p in params:
      p.data -= p.grad*lr
      p.grad.zero_()
    
    print(validate_epoch(model), end='\n')

In [112]:
def train3(model, params, lr=1., num_epochs=5):
  for xb, yb in dset_loader:
    for i in range(num_epochs):
      step(xb, yb, model)
      for p in params:
        p.data -= p.grad*lr
        p.grad.zero_()
        
    print(batch_accuracy(model(xb), yb), end='\n')

In [113]:
def train4(model, lr=1., num_epochs=5):
  for i in range(num_epochs):
    for xb, yb in dset_loader:
      step(xb, yb, model)
      for p in model.parameters():
        p.data -= p.grad*lr
        p.grad.zero_()
    
    print(validate_epoch(model), end='\n')

In [114]:
params = w,b
train1(linear, params, num_epochs=5, lr=1.)

(tensor(0.6547, device='mps:0'), tensor(0.3500, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.8432, device='mps:0'), tensor(0.1645, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9003, device='mps:0'), tensor(0.1030, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9286, device='mps:0'), tensor(0.0752, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9433, device='mps:0'), tensor(0.0610, device='mps:0', grad_fn=<MeanBackward0>))


In [115]:
batch_accuracy(linear(validation_x), validation_y)

tensor(0.9431, device='mps:0')

In [116]:
linear_model = torch.nn.Linear(28*28, 1, device='mps')

In [117]:
batch_accuracy(linear_model(train_x), train_y)

tensor(0.5799, device='mps:0')

In [118]:
train4(linear_model, num_epochs=5, lr=.1)

(tensor(0.5649, device='mps:0'), tensor(0.3626, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.8755, device='mps:0'), tensor(0.1798, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9360, device='mps:0'), tensor(0.1071, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9550, device='mps:0'), tensor(0.0802, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9633, device='mps:0'), tensor(0.0669, device='mps:0', grad_fn=<MeanBackward0>))


In [40]:
dls = DataLoaders(dset_loader, validation_dset_loader)

In [41]:
learn = Learner(dls, torch.nn.Linear(28*28, 1), loss_func=mnist_loss,
                opt_func=SGD, metrics=batch_accuracy)

learn.fit(5, 1.)

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.636928,0.500492,0.495584,00:03
1,0.276627,0.331183,0.658979,00:03
2,0.111369,0.155409,0.86212,00:03
3,0.054769,0.09831,0.916094,00:03
4,0.033432,0.073837,0.935231,00:03


In [119]:
neural_net = torch.nn.Sequential(
  torch.nn.Linear(28*28, 30),
  torch.nn.ReLU(),
  torch.nn.Linear(30, 1),
).to(device='mps')

Sequential(
  (0): Linear(in_features=784, out_features=30, bias=True)
  (1): ReLU()
  (2): Linear(in_features=30, out_features=1, bias=True)
)

In [120]:
learn = Learner(dls, neural_net, opt_func=SGD, loss_func=mnist_loss,
                metrics=batch_accuracy)

In [121]:
learn.fit(5, 1.)

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.166287,0.489857,0.504416,00:04
1,0.083623,0.251175,0.755643,00:03
2,0.042393,0.112369,0.895486,00:03
3,0.02675,0.077265,0.92787,00:03
4,0.019674,0.061569,0.941119,00:03


In [122]:
class Optim:
  def __init__(self, params, lr):
    self.params, self.lr = list(params), lr
    
  def step(self, *args, **kwargs):
    for p in self.params:
      p.data -= p.grad.data*self.lr
  
  def zero(self, *args, **kwargs):
    for p in self.params:
      p.grad.zero_()

In [127]:
def train(dls, model, opt, num_epochs=5):
  for i in range(num_epochs):
    for xb,yb in dls:
      step(xb, yb, model)
      opt.step()
      opt.zero()
      
    print(validate_epoch(model), end='\n')

In [130]:
neural_net = torch.nn.Sequential(
  torch.nn.Linear(28*28, 100),
  torch.nn.ReLU(),
  torch.nn.Linear(100, 1)
).to(device='mps')
opt = Optim(neural_net.parameters(), 1e-2)

train(dset_loader, neural_net, opt, num_epochs=5)

(tensor(0.9378, device='mps:0'), tensor(0.4626, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9452, device='mps:0'), tensor(0.4167, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9433, device='mps:0'), tensor(0.3560, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9462, device='mps:0'), tensor(0.2873, device='mps:0', grad_fn=<MeanBackward0>))
(tensor(0.9511, device='mps:0'), tensor(0.2247, device='mps:0', grad_fn=<MeanBackward0>))
