In [116]:
path = untar_data(URLs.MNIST_SAMPLE)
threes = (path/'train/3').ls().sorted()
sevens = (path/'train/7').ls().sorted()
sevens_tensors = [tensor(Image.open(o)) for o in sevens]
three_tensors = [tensor(Image.open(o)) for o in threes]

stacked_sevens = torch.stack(sevens_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255

valid_3_tns = torch.stack([tensor(Image.open(o)) for o in (path/'valid/3').ls()])
valid_3_tns = valid_3_tns.float()/255

valid_7_tns = torch.stack([tensor(Image.open(o)) for o in (path/'valid/7').ls()])
valid_7_tns = valid_7_tns.float()/255

train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28 * 28)

train_y = tensor([1] * len(threes) + [0]* len(sevens)).unsqueeze(1)
dset = list(zip(train_x, train_y))

valid_x = torch.cat([valid_3_tns, valid_7_tns]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tns) + [0]* len(valid_7_tns)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))

dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)

dls = DataLoaders(dl, valid_dl)

In [117]:
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).float().mean()

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds > 0.5) == yb
    return correct.float().mean()

In [118]:
from fastai.vision.all import *

In [317]:
class Model:
    def __init__(self, dls, arch, opt_func, loss_func, metrics):
        self.train, self.valid = dls
        self.arch = arch
        self.opt_func = opt_func
        self.parameters = self.arch.parameters()
        self.opt_func = opt_func(self.arch.parameters(), 0.001)
        self.loss_func = loss_func
        self.metrics = metrics
        
        self.current_loss = None
    
    def arch_(self, data):
        return self.arch(data)
    
    def metrics_(self, data, y):
        return self.metrics(self.arch_(data), y)
    
    def loss(self, x, y):
        return self.loss_func(self.arch_(x), y)
    
    def calc_grad(self, x, y):
        preds = self.arch_(x)
        loss = self.loss_func(preds, y)
        loss.backward()
        self.current_loss = loss

    def train_epoch(self):
        for x, y in self.train:
            self.calc_grad(x, y)
            self.opt_func.step()
            self.opt_func.zero_grad()
            
    def validate_epoch(self):
        accs = [self.metrics_(x, y) for x, y in self.valid]
        return round(torch.stack(accs).mean().item(), 4)
    
    def fit(self, epochs):
        print(f"epoch  |  train_loss  |  metrics")
        for i in range(epochs):
            self.train_epoch()
            print(f"{i}    {round(self.current_loss.item(), 2)}    {self.validate_epoch()}")

    def predict(self, x):
        pred = round(self.arch_(x).item(),3)
        return 1 if pred > .5 else 0

In [328]:
m = Model(dls, nn.Linear(28*28, 1), Adam, mnist_loss, batch_accuracy)
m.fit(10)

epoch  |  train_loss  |  metrics
0    0.68    0.4932
1    0.24    0.9646
2    0.12    0.9662
3    0.09    0.9662
4    0.07    0.9667
5    0.06    0.9681
6    0.05    0.9691
7    0.04    0.9711
8    0.04    0.972
9    0.03    0.9725


<function fastai.optimizer.Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-05, wd=0.01, decouple_wd=True)>

In [320]:
m.predict(valid_x[-1])

0

In [321]:
valid_y[-1]

tensor([0])

In [221]:
train, valid = dls

In [227]:
for x, y in train:
    print(x)
    break

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [195]:
m = Model(dls, nn.Linear(28*28, 1), SGD, mnist_loss, batch_accuracy)

In [196]:
m.arch_(train_x[0])

tensor([0.0408], grad_fn=<AddBackward0>)

In [197]:
m.metrics_(train_x, train_y)

tensor(0.7007)

In [198]:
m.loss(train_x[0], train_y[0])

tensor(0.4898, grad_fn=<MeanBackward0>)

In [204]:
w, b = m.arch.parameters()
print(b.grad)

tensor([-0.4998])


In [207]:
m.calc_grad(train_x[0], train_y[0])

In [208]:
x, y = dls

In [209]:
w, b = m.arch.parameters()
print(b.grad)

tensor([-0.7497])
