In [1]:
import fastbook
from fastai.vision.all import *
from fastbook import *
matplotlib.rc('image', cmap='Greys')

In [76]:
def load_data():
    path = untar_data(URLs.MNIST_SAMPLE)
    Path.BASE_PATH = path
    threes = (path/'train'/'3').ls().sorted()
    sevens = (path/'train'/'7').ls().sorted()
    seven_tensors = [tensor(Image.open(o)) for o in sevens]
    three_tensors = [tensor(Image.open(o)) for o in threes]
    stacked_sevens = torch.stack(seven_tensors).float()/255
    stacked_threes = torch.stack(three_tensors).float()/255
    valid_3_tens = torch.stack([tensor(Image.open(o)) 
                                for o in (path/'valid'/'3').ls()])
    valid_3_tens = valid_3_tens.float()/255
    valid_7_tens = torch.stack([tensor(Image.open(o)) 
                                for o in (path/'valid'/'7').ls()])
    valid_7_tens = valid_7_tens.float()/255
    train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
    train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)
    dset = list(zip(train_x,train_y))
    valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
    valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
    valid_dset = list(zip(valid_x, valid_y))
    return dset, valid_dset

In [80]:
def create_dataloader():
    dset, valid_dset = load_data()
    dl, valid_dl = DataLoader(dset, batch_size=256), DataLoader(valid_dset, batch_size=256)
    return dl, valid_dl

In [100]:
dl, valid_dl = create_dataloader()
simpleNet = nn.Sequential(nn.Linear(28*28, 30),
                         nn.ReLU(),
                         nn.Linear(30, 15),
                         nn.ReLU(),
                         nn.Linear(15, 5),
                         nn.ReLU(),
                         nn.Linear(5,1))

opt = BasicOptim(simpleNet.parameters(), lr=0.01)

def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()

In [101]:
learn = Chearner(dl, valid_dl, simpleNet, opt, mnist_loss, batch_accuracy)
learn.train_model(50)

0.4932 0.4932 0.4932 0.4932 0.4932 0.4932 0.4932 0.4932 0.4932 0.4932 0.4932 0.5035 0.7384 0.9025 0.95 0.9565 0.9375 0.9233 0.9326 0.9424 0.956 0.9614 0.9658 0.9643 0.9653 0.9682 0.9692 0.9692 0.9692 0.9696 0.9696 0.9696 0.9696 0.9706 0.9716 0.9721 0.9721 0.9721 0.9721 0.9736 0.9736 0.9736 0.9736 0.9736 0.9736 0.9736 0.9736 0.9745 0.9745 0.975 

In [82]:
class BasicOptim:
    
    def __init__(self, params, lr): self.params, self.lr = list(params), lr
        
    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr
            
    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None

In [96]:
class Chearner:
    
    def __init__(self, dl, valid_dl, model, opt, loss_func, metric): 
        self.dl, self.valid_dl = dl, valid_dl
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.metric = metric
    
    def validate_epoch(self):
        accs = [self.metric(self.model(xb), yb) for xb, yb in self.valid_dl]
        return round(torch.stack(accs).mean().item(), 4)
    
    def calc_grad(self, xb, yb):
        preds = self.model(xb)
        loss = self.loss_func(preds, yb)
        loss.backward()
    
    def train_epoch(self, model, opt):
        for xb,yb in self.dl:
            self.calc_grad(xb,yb)
            self.opt.step()
            self.opt.zero_grad()
            
    def train_model(self, epochs):
        for i in range(epochs):
            train_epoch(self.model)
            print(self.validate_epoch(), end= ' ')