In [None]:
from fastai.vision.all import *

In [None]:
class my_dataloader:
    """Bundle a training and a validation dataset; build DataLoaders and a model for them.

    Datasets are indexable collections of (x, y) pairs (e.g. a list of tuples).
    """

    def __init__(self, train_dataset=None, valid_dataset=None, batch_size=256, shuffle=True):
        self.tds = train_dataset
        self.vds = valid_dataset
        self.bs = batch_size
        self.shuffle = shuffle  # applied to the training loader only

    def get_dl_dims(self):
        """Return len() of the first training input (its feature count when x is 1-D)."""
        x, _ = self.tds[0]
        return len(x)

    def create_linear_model(self, n_hidden=30):
        """Build a Linear -> ReLU -> Linear network with a single output.

        The input width is taken from `get_dl_dims()`.

        Args:
            n_hidden: width of the hidden layer (defaults to 30, the previous fixed value).
        """
        return nn.Sequential(
            nn.Linear(self.get_dl_dims(), n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 1),
        )

    def get_train_dl(self):
        """DataLoader over the training set, shuffled according to `self.shuffle`."""
        # Options go by keyword: fastai's DataLoader signature is
        # (dataset, bs, num_workers, ...), so a third positional argument would be
        # consumed as num_workers, not shuffle.
        return DataLoader(self.tds, batch_size=self.bs, shuffle=self.shuffle)

    def get_valid_dl(self):
        """DataLoader over the validation set, in dataset order."""
        # Validation order does not change loss/accuracy; keep it deterministic.
        return DataLoader(self.vds, batch_size=self.bs, shuffle=False)

    def get_dataloaders(self):
        """Return the (train_dl, valid_dl) pair."""
        return (self.get_train_dl(), self.get_valid_dl())
    

In [None]:
class BasicOptim:
    """Minimal plain-SGD optimizer: p <- p - lr * grad for every parameter."""

    def __init__(self, params, lr):
        # Materialize the iterable: a generator such as model.parameters() would be
        # exhausted after the first step otherwise.
        self.params, self.lr = list(params), lr

    def step(self, *args, **kwargs):
        """Apply one SGD update in place; parameters without a gradient are left unchanged."""
        for p in self.params:
            # A parameter that took no part in the loss (or after zero_grad) has grad None.
            if p.grad is None:
                continue
            p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        """Drop all stored gradients (set them to None)."""
        for p in self.params:
            p.grad = None

In [None]:
class my_learner:
    """Minimal Learner: trains a model and records one stats row per epoch.

    Each epoch run by `fit` appends [epoch, train_loss, valid_loss, metric] to
    `output_rows`; `show_output` renders the rows as a table.
    """

    def __init__(self, dls=None, training_model=None, optimization_function=None,
                 loss_function=None, lr=1e-4, metrics=None):
        """
        Args:
            dls: (train_dl, valid_dl) pair of iterables yielding (xb, yb) batches.
                None leaves both unset; they must be assigned before training.
            training_model: model trained when `fit` is called.
            optimization_function: callable (params, lr) -> object exposing
                step() and zero_grad().
            loss_function: callable (preds, targets) -> scalar loss tensor.
            lr: learning rate passed to `optimization_function`.
            metrics: 'accuracy' selects `batch_accuracy`; otherwise a callable
                (preds, targets) -> scalar, or None to skip the metric.
        """
        self.train_dl, self.valid_dl = dls if dls is not None else (None, None)
        self.model = training_model
        self.opt_func = optimization_function
        self.loss_func = loss_function
        self.lr = lr
        self.metrics = self.batch_accuracy if metrics == 'accuracy' else metrics

        self.output_cols = ['epoch', 'train_loss', 'valid_loss', 'metrics']
        self.output_rows = []

        self._epoch = 0
        self._tloss = 0.
        self._vloss = 0.
        self._met = 0.

    def add_data(self):
        """Append the current epoch's stats as a new output row."""
        self.output_rows.append([self._epoch, self._tloss, self._vloss, self._met])

    def show_output(self):
        """Return the recorded rows as a pandas Styler with the index hidden."""
        d = pd.DataFrame(columns=self.output_cols, data=self.output_rows)
        styler = d.style
        # Styler.hide_index() was deprecated in pandas 1.4 and removed in 2.0;
        # Styler.hide(axis='index') replaces it. Fall back only on older pandas.
        if hasattr(styler, 'hide'):
            return styler.hide(axis='index')
        return styler.hide_index()

    def train_epoch(self, model):
        """One optimization pass over train_dl, then a loss pass over valid_dl.

        `_tloss` / `_vloss` are set to the sample-weighted mean batch loss of the
        epoch (rounded to 4 decimals) rather than the last batch's loss.
        """
        model.train()
        total, count = 0., 0
        for xb, yb in self.train_dl:
            total += self.calc_gradient(xb, yb, model) * len(xb)
            count += len(xb)
            self.opt.step()
            self.opt.zero_grad()
        if count:
            self._tloss = round(total / count, 4)

        model.eval()
        total, count = 0., 0
        for xb, yb in self.valid_dl:
            total += self._log_validation_loss(xb, yb, model) * len(xb)
            count += len(xb)
        if count:
            self._vloss = round(total / count, 4)

    def train_model(self, model, epochs):
        """Train `model` (or self.model when model is None) for `epochs` epochs."""
        # Compare with None: nn.Module truthiness is not reliable (an empty
        # nn.Sequential has len() 0 and is therefore falsy).
        m = model if model is not None else self.model
        if epochs > 0:
            # The optimizer must own the parameters of the model actually trained.
            self.opt = self.opt_func(m.parameters(), self.lr)
            for i in range(epochs):
                self._epoch = i
                self.train_epoch(m)
                self._met = self.validate_epoch(m)
                self.add_data()

    def validate_epoch(self, model):
        """Return the sample-weighted mean of `metrics` over valid_dl, rounded to 4 dp.

        With no metric configured (or an empty valid_dl) the previous value is kept.
        """
        if self.metrics is None:
            return self._met
        model.eval()
        total, count = 0., 0
        with torch.no_grad():
            for xb, yb in self.valid_dl:
                total += float(self.metrics(model(xb), yb)) * len(xb)
                count += len(xb)
        if count:
            self._met = round(total / count, 4)
        return self._met

    def _log_validation_loss(self, xb, yb, model):
        """Loss on one validation batch, computed without building an autograd graph.

        Stores the rounded value in `_vloss` and returns the unrounded float.
        """
        with torch.no_grad():
            loss = self.loss_func(model(xb), yb).item()
        self._vloss = round(loss, 4)
        return loss

    def calc_gradient(self, xb, yb, model):
        """Forward + backward on one training batch (gradients accumulate in the params).

        Stores the rounded batch loss in `_tloss` and returns the unrounded float.
        """
        loss = self.loss_func(model(xb), yb)
        loss.backward()
        self._tloss = round(loss.item(), 4)
        return loss.item()

    def batch_accuracy(self, xb, yb):
        """Fraction of sigmoid(xb) > 0.5 predictions that match the 0/1 targets yb."""
        preds = xb.sigmoid()
        correct = (preds > 0.5) == yb
        return correct.float().mean()

    def fit(self, epochs, learning_rate=0):
        """Train self.model for `epochs` epochs; a positive learning_rate overrides self.lr."""
        if learning_rate > 0:
            self.lr = learning_rate
        self.train_model(self.model, epochs)
        

In [None]:
# Datasets setup: fetch the MNIST_SAMPLE dataset (the 3 and 7 digit folders read below)
# via fastai's untar_data, and register its root as Path.BASE_PATH.
Path.BASE_PATH = path = untar_data(URLs.MNIST_SAMPLE)

In [None]:
# Training dataset tensors.
# File lists are sorted so the sample order is reproducible.
sevens = (path/'train'/'7').ls().sorted()
threes = (path/'train'/'3').ls().sorted()

three_tensors = [tensor(Image.open(img_path)) for img_path in threes]
seven_tensors = [tensor(Image.open(img_path)) for img_path in sevens]

# Each image is a 2-D (d1, d2) grid; dsqr is the pixel count of one flattened image.
d1, d2 = seven_tensors[0].shape
dsqr = d1 * d2

# Stack each digit's images into one tensor and rescale 0-255 pixel values into [0, 1].
stacked_threes, stacked_sevens = (torch.stack(imgs).float() / 255
                                  for imgs in (three_tensors, seven_tensors))

# One flattened row per image; 3s are labelled 1 and 7s are labelled 0.
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, dsqr)
train_y = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)

# The dataset is a list of (flattened image, label) pairs.
train_ds = list(zip(train_x, train_y))

In [None]:
# Validation dataset setup: same preprocessing as the training cell.
# File listings are sorted (as the training cell does) so the sample order is
# reproducible instead of depending on the directory listing order.
valid_3_tens = torch.stack([tensor(Image.open(o))
                            for o in (path/'valid'/'3').ls().sorted()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = torch.stack([tensor(Image.open(o))
                            for o in (path/'valid'/'7').ls().sorted()])
valid_7_tens = valid_7_tens.float()/255

# Flatten each image to a dsqr-long row; label 3s as 1 and 7s as 0 (as in train_y).
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, dsqr)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)

valid_ds = list(zip(valid_x, valid_y))

In [None]:
# Load data from datasets.
# Seed the RNGs first (fastai's set_seed) so that random weight initialisation
# in the model cell below, and presumably loader shuffling, are reproducible
# across Restart & Run All.
SEED = 42
set_seed(SEED)
dls = my_dataloader(train_ds, valid_ds)
dataloaders = dls.get_dataloaders()

In [None]:
# Build the model from the dataloader factory: Linear -> ReLU -> Linear, with the
# input width taken from the first training sample. Despite the method's name,
# this is a small two-layer network, not a purely linear model.
training_model = dls.create_linear_model()

In [None]:
# Loss function for the 3-vs-7 classifier
def mnist_loss(predictions, targets):
    """Mean distance between sigmoid(predictions) and the 0/1 targets.

    For a target of 1 the per-item loss is 1 - p; for any other target it is p,
    where p = sigmoid(prediction).
    """
    probs = predictions.sigmoid()
    per_item = torch.where(targets == 1, 1 - probs, probs)
    return per_item.mean()

In [None]:
# Create the learner: fastai's SGD optimizer, the custom loss, and accuracy as the metric.
learn = my_learner(
    dls=dataloaders,
    training_model=training_model,
    optimization_function=SGD,
    loss_function=mnist_loss,
    lr=0.01,
    metrics='accuracy',
)

In [None]:
# Train for 10 epochs at the learner's configured learning rate.
learn.fit(epochs=10)

In [None]:
# Display the per-epoch log (epoch, train_loss, valid_loss, metrics) recorded by fit().
learn.show_output()