<a href="https://colab.research.google.com/github/cluePrints/fastai-v3-notes/blob/master/fastai3_part2_04_callbacks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:

%load_ext autoreload
%autoreload 2

%matplotlib inline

In [0]:
from fastai import datasets
import torch
import gzip, pickle
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'

def get_data():
  path = datasets.download_data(MNIST_URL, ext='.gz')
  with gzip.open(path) as f:
    ((train_X, train_Y), (valid_X, valid_Y), (test_X, test_Y)) = pickle.load(f, encoding='latin')

  train_X = torch.tensor(train_X)
  train_Y = torch.tensor(train_Y)
  valid_X = torch.tensor(valid_X)
  valid_Y = torch.tensor(valid_Y)
  return train_X, train_Y, valid_X, valid_Y

import operator

def test(a,b,cmp,cname=None):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

def test_eq(a,b): test(a,b,operator.eq,'==')
def near(a,b): return torch.allclose(a, b, rtol=1e-3, atol=1e-5)
def test_near(a,b): test(a,b,near)
  
def accuracy(actual, expected):
  idx = actual.argmax(dim=1)
  return (idx == expected).float().mean()

In [0]:
train_X, train_Y, valid_X, valid_Y = get_data()

In [0]:
batch_size = 64
from torch.utils.data import TensorDataset, DataLoader
train_ds = TensorDataset(train_X, train_Y)
valid_ds = TensorDataset(valid_X, valid_Y)
train_dl = DataLoader(train_ds, batch_size = batch_size, shuffle = True)
valid_dl = DataLoader(valid_ds, batch_size = batch_size, shuffle = False)

In [5]:
len(valid_dl.dataset)

10000

In [0]:
class Databunch():
  def __init__(self, train_dl, valid_dl):
    self.train_dl = train_dl
    self.valid_dl = valid_dl
  
  @property
  def train_ds(self):
    return self.train_dl.dataset

  @property
  def valid_ds(self):
    return self.valid_dl.dataset

databunch = Databunch(train_dl, valid_dl)
test_eq(len(databunch.train_ds), len(train_ds))

In [7]:
import torch.nn.functional as F
from torch import nn, optim

n_inputs = 784
n_hidden = 50
n_classes = len(train_Y.unique())
lr = 0.5

model = nn.Sequential(
   nn.Linear(n_inputs, n_hidden),
   nn.ReLU(),
   nn.Linear(n_hidden, n_classes)
)

class Recorder():
  pass

class Learner():
  def __init__(self, model, databunch, loss_func, opt_class):
    self.model = model
    self.databunch = databunch
    self.loss_func = loss_func
    self.opt_class = opt_class
    
  def fit(self, n_epochs, lr):
    self.recorder = Recorder()
    opt = self.opt_class(self.model.parameters(), lr)

    for epoch in range(n_epochs):
      model.train()
      for (batch_x, batch_y) in self.databunch.train_dl:
        preds = self.model(batch_x)
        loss = self.loss_func(preds, batch_y)
        loss.backward()

        opt.step()
        opt.zero_grad()

      model.eval()
      accuracy_exp_avg = 0
      loss_exp_avg = 0
      for idx, (v_batch_x, v_batch_y) in enumerate(self.databunch.valid_dl):
        with torch.no_grad():
          v_preds = model(v_batch_x)
          loss = self.loss_func(v_preds, v_batch_y)
          avg_coeff = 0.5
          accuracy_exp_avg = accuracy_exp_avg * avg_coeff + accuracy(v_preds, v_batch_y) * (1 - avg_coeff)
          loss_exp_avg     = loss_exp_avg * avg_coeff + loss * (1 - avg_coeff)

      print(f"Epoch {epoch}/{n_epochs}. Validation metrics. Loss: {loss_exp_avg:.3f}, accuracy: {accuracy_exp_avg:.3f}")

    t_preds = model(train_X)
    v_preds = model(valid_X)
    t_acc = accuracy(t_preds, train_Y)
    v_acc = accuracy(v_preds, valid_Y)
    t_loss = self.loss_func(t_preds, train_Y)
    v_loss = self.loss_func(v_preds, valid_Y)
    print(f"Accuracy train: {t_acc:.3f} (loss: {t_loss:.3f}), validation: {v_acc:.3f} (loss: {v_loss:.3f})")
    self.recorder.accuracy = v_acc

learner = Learner(model, databunch, F.cross_entropy, optim.SGD)
learner.fit(5, 0.5)

assert learner.recorder.accuracy > 0.9

Epoch 0/5. Validation metrics. Loss: 0.089, accuracy: 0.989
Epoch 1/5. Validation metrics. Loss: 0.277, accuracy: 0.906
Epoch 2/5. Validation metrics. Loss: 0.300, accuracy: 0.921
Epoch 3/5. Validation metrics. Loss: 0.096, accuracy: 0.986
Epoch 4/5. Validation metrics. Loss: 0.135, accuracy: 0.984
Accuracy train: 0.965 (loss: 0.112), validation: 0.956 (loss: 0.150)


In [8]:
class Callback():
  def before_train(self): return True
  def before_validation(self): return True
  
  
class CallbackHandler():
  def __init__(self):
    self.callbacks = []
  
  def before_train(self):
    for callback in self.callbacks: callback.before_train()

  def before_validation(self):
    for callback in self.callbacks: callback.before_validation()
      
  def init_model_aware(self, model):
    for callback in self.callbacks:
      if hasattr(callback, 'model'): callback.model = model
  
class TrainEvalCallback(Callback):
  def __init__(self):
    self.model = None

  def before_train(self):
    self.model.train()
    
  def before_validation(self):
    self.model.eval()

class Learner():
  def __init__(self, model, databunch, loss_func, opt_class):
    self.model = model
    self.databunch = databunch
    self.loss_func = loss_func
    self.opt_class = opt_class
    self.callback_handler = CallbackHandler()
    self.callback_handler.callbacks.append(TrainEvalCallback())
    self.callback_handler.init_model_aware(model)
    
  def fit(self, n_epochs, lr):
    self.recorder = Recorder()
    opt = self.opt_class(self.model.parameters(), lr)

    for epoch in range(n_epochs):
      self.callback_handler.before_train()
      for (batch_x, batch_y) in self.databunch.train_dl:
        preds = self.model(batch_x)
        loss = self.loss_func(preds, batch_y)
        loss.backward()

        opt.step()
        opt.zero_grad()

      accuracy_exp_avg = 0
      loss_exp_avg = 0
      self.callback_handler.before_validation()
      for idx, (v_batch_x, v_batch_y) in enumerate(self.databunch.valid_dl):
        with torch.no_grad():
          v_preds = model(v_batch_x)
          loss = self.loss_func(v_preds, v_batch_y)
          avg_coeff = 0.5
          accuracy_exp_avg = accuracy_exp_avg * avg_coeff + accuracy(v_preds, v_batch_y) * (1 - avg_coeff)
          loss_exp_avg     = loss_exp_avg * avg_coeff + loss * (1 - avg_coeff)

      print(f"Epoch {epoch}/{n_epochs}. Validation metrics. Loss: {loss_exp_avg:.3f}, accuracy: {accuracy_exp_avg:.3f}")

    t_preds = model(train_X)
    v_preds = model(valid_X)
    t_acc = accuracy(t_preds, train_Y)
    v_acc = accuracy(v_preds, valid_Y)
    t_loss = self.loss_func(t_preds, train_Y)
    v_loss = self.loss_func(v_preds, valid_Y)
    print(f"Accuracy train: {t_acc:.3f} (loss: {t_loss:.3f}), validation: {v_acc:.3f} (loss: {v_loss:.3f})")
    self.recorder.accuracy = v_acc

learner = Learner(model, databunch, F.cross_entropy, optim.SGD)
learner.fit(5, 0.5)

assert learner.recorder.accuracy > 0.9

Epoch 0/5. Validation metrics. Loss: 0.074, accuracy: 0.990
Epoch 1/5. Validation metrics. Loss: 0.081, accuracy: 0.993
Epoch 2/5. Validation metrics. Loss: 0.084, accuracy: 0.990
Epoch 3/5. Validation metrics. Loss: 0.079, accuracy: 0.994
Epoch 4/5. Validation metrics. Loss: 0.223, accuracy: 0.955
Accuracy train: 0.919 (loss: 0.378), validation: 0.911 (loss: 0.430)


In [9]:
model = nn.Sequential(
   nn.Linear(n_inputs, n_hidden),
   nn.ReLU(),
   nn.Linear(n_hidden, n_classes)
)

class Callback():
  def before_train(self): return True
  def before_validation(self): return True
  def after_forward(self, x, y): return True
  def before_epoch(self, epoch, n_epoch): return True
  def after_epoch(self): return True
  def after_loss(self, loss): return True
  
from collections import defaultdict
class TrainingState():
  def __init__(self):
    self.training = True
    self.recorder = defaultdict(lambda: 0)
  
class CallbackHandler():
  def __init__(self):
    self.callbacks = []
  
  def before_train(self):
    for callback in self.callbacks: callback.before_train()

  def before_validation(self):
    for callback in self.callbacks: callback.before_validation()
      
  def after_forward(self, x, y):
    for callback in self.callbacks: callback.after_forward(x, y)

  def before_epoch(self, epoch, n_epoch):
    for callback in self.callbacks: callback.before_epoch(epoch, n_epoch)
      
  def after_epoch(self):
    for callback in self.callbacks: callback.after_epoch()
      
  def after_loss(self, loss_value):
    for callback in self.callbacks: callback.after_loss(loss_value)
      
  def init_model_aware(self, model):
    for callback in self.callbacks:
      if hasattr(callback, 'model'): callback.model = model

  def init_state_aware(self, state):
    for callback in self.callbacks:
      if hasattr(callback, 'state'): callback.state = state
        
  def init_loss_func_aware(self, opt):
    for callback in self.callbacks:
      if hasattr(callback, 'loss_func'): callback.loss_func = loss_func
  
class TrainEvalCallback(Callback):
  def __init__(self):
    self.model = None
    self.state = None

  def before_train(self):
    self.state.training = True
    self.model.train()
    
  def before_validation(self):
    self.state.training = False
    self.model.eval()
    
class MetricsCallback(Callback):
  def __init__(self):
    self.state = None
    self.metrics = {'accuracy': accuracy}
    
  def after_forward(self, x, y):
    for k, metric_func in self.metrics.items():
      metric_value = metric_func(x, y)
      self._record(k, metric_value)

  def before_epoch(self, epoch, n_epoch):
    self.epoch = epoch
    self.n_epochs = n_epoch
      
  def after_loss(self, loss):
    self._record('loss', loss.item())
      
  def after_epoch(self):
    loss_exp_avg = self.state.recorder['val_loss']
    accuracy_exp_avg = self.state.recorder['val_accuracy']
    print(f"Epoch {self.epoch}/{self.n_epochs}. Validation metrics. Loss: {loss_exp_avg:.3f}, accuracy: {accuracy_exp_avg:.3f}")
    
  def _record(self, k, metric_value):
    avg_coeff = 0.5
    name = self._prefix() + k
    self.state.recorder[name] = self.state.recorder[name] * (1 - avg_coeff) + metric_value * avg_coeff
    
  def _prefix(self):
    return 'trn_' if self.state.training else 'val_'


class Learner():
  def __init__(self, model, databunch, loss_func, opt_class):
    self.model = model
    self.databunch = databunch
    self.loss_func = loss_func
    self.opt_class = opt_class
    self.state = TrainingState()
    self.callback_handler = CallbackHandler()
    self.callback_handler.callbacks.append(TrainEvalCallback())
    self.callback_handler.callbacks.append(MetricsCallback())
    self.callback_handler.init_model_aware(self.model)
    self.callback_handler.init_state_aware(self.state)
    self.callback_handler.init_loss_func_aware(self.loss_func)
    
  def fit(self, n_epochs, lr):
    self.opt = self.opt_class(self.model.parameters(), lr)
    
    for epoch in range(n_epochs):
      self.callback_handler.before_epoch(epoch, n_epochs)
      self.single_epoch(epoch, n_epochs)
      self.callback_handler.after_epoch()
      
    self.calculate_final_metrics()

  def single_epoch(self, epoch, n_epochs):
    self.callback_handler.before_train()
    self.run_batches(self.databunch.train_dl)

    self.callback_handler.before_validation()
    self.run_batches(self.databunch.valid_dl)

  def run_batches(self, dataloader):
    for (batch_x, batch_y) in dataloader:
      preds = self.model(batch_x)
      self.callback_handler.after_forward(preds, batch_y)
      loss = self.loss_func(preds, batch_y)
      self.callback_handler.after_loss(loss)
      loss.backward()

      self.opt.step()
      self.opt.zero_grad()
        
  def calculate_final_metrics(self):
    t_preds = model(train_X)
    v_preds = model(valid_X)
    t_acc = accuracy(t_preds, train_Y)
    v_acc = accuracy(v_preds, valid_Y)
    t_loss = self.loss_func(t_preds, train_Y)
    v_loss = self.loss_func(v_preds, valid_Y)
    print(f"Accuracy train: {t_acc:.3f} (loss: {t_loss:.3f}), validation: {v_acc:.3f} (loss: {v_loss:.3f})")

learner = Learner(model, databunch, F.cross_entropy, optim.SGD)
learner.fit(5, 0.5)

assert learner.state.recorder['trn_accuracy'] > 0.9
assert learner.state.recorder['val_accuracy'] > 0.9

Epoch 0/5. Validation metrics. Loss: 0.101, accuracy: 0.987
Epoch 1/5. Validation metrics. Loss: 0.075, accuracy: 0.984
Epoch 2/5. Validation metrics. Loss: 0.073, accuracy: 0.988
Epoch 3/5. Validation metrics. Loss: 0.073, accuracy: 0.990
Epoch 4/5. Validation metrics. Loss: 0.070, accuracy: 0.996
Accuracy train: 0.977 (loss: 0.069), validation: 0.985 (loss: 0.048)


In [0]:
# DataBunch, @property
# Learner
# CAllbackHandler
# TrainEvalCallback
# Runner
# AvgStatsCallback
#--