In [29]:
#export
import torch
from torch import optim
from nbs.dl2.exp.nb_02 import getMnistData, assertNearZero
from nbs.dl2.exp.nb_03 import Dataset, createDataLoaders, accuracy
from nbs.dl2.exp.nb_04 import DataBunch
from nbs.dl2.exp.nb_05 import aggregateSchedulers, createCosineSchedulers, cosineScheduler
from nbs.dl2.exp.nb_06 import normalizeVectors, createBetterConvolutionModel
from nbs.dl2.exp.nb_07D import *

In [30]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline


In [31]:
# Fetch the MNIST splits, then normalise the two feature tensors together.
# (How normalizeVectors derives its statistics is defined in nb_06.)
(xTraining, yTraining,
 xValidation, yValidation) = getMnistData()
(xTrainingNormalized,
 xValidationNormalized) = normalizeVectors(xTraining, xValidation)

In [32]:
# Sanity checks on the normalised features: the mean and (1 - std) of each
# split are handed to assertNearZero (imported from nb_02), i.e. we expect
# mean ~ 0 and std ~ 1 for both training and validation data.
assertNearZero(xTrainingNormalized.mean())
assertNearZero(xValidationNormalized.mean())
assertNearZero(1 - xTrainingNormalized.std())
assertNearZero(1 - xValidationNormalized.std())


In [33]:
# Notebook-wide configuration.
# numberOfClasses feeds the DataBunch and the model factory; batchSize feeds
# createDataLoaders.
numberOfClasses, batchSize = 10, 64
# Passed to createBetterConvolutionModel as the per-layer sizes.
layerSizes = [8, 16, 32, 64, 64]
# Not referenced by any other cell visible in this notebook.
hiddenLayerSize = 75

In [60]:
# Wrap the first 10000 samples of each split in a Dataset, build the matching
# loaders, and bundle everything into a DataBunch for the teacher.
trainingDataSet = Dataset(xTrainingNormalized[:10000], yTraining[:10000])
validationDataSet = Dataset(xValidationNormalized[:10000], yValidation[:10000])
trainingDataLoader, validationDataLoader = createDataLoaders(
    trainingDataSet,
    validationDataSet,
    batchSize,
)
imageDataBunch = DataBunch(
    trainingDataLoader,
    validationDataLoader,
    numberOfClasses,
)

In [57]:
# Two-phase schedules (30% / 70% of training) built from cosine segments.
# NOTE(review): neither scheduler is passed to TrainingSubscriber in the cells
# below, which rely on its default scheduling functions.
phases = [0.3, 0.7]
weightsScheduler = aggregateSchedulers(
    phases,
    createCosineSchedulers(0.3, 0.6, 0.2),
)
biasScheduler = aggregateSchedulers(
    phases,
    createCosineSchedulers(0.9, 1.8, 0.6),
)



In [36]:
class ProcessCancellationException(Exception):
    """Signal that processing of the current data pass should stop early.

    TeacherOptimized._processData catches this exception, abandons the
    remaining batches of the pass and still notifies the subscriber through
    postEpoch.
    """

In [38]:
class TrainingSubscriber(StatisticsSubscriber, HookedSubscriber):
    """Subscriber that performs the optimisation work of the training pass.

    It owns an SGD optimizer (created in preModelTeach), re-computes the
    learning rate before every batch and runs backward/step/zero_grad after
    every batch. Statistics and hook handling come from the base classes
    imported from nb_07D.
    """

    def __init__(self,
                 lossFunction=torch.nn.functional.cross_entropy,
                 schedulingFunctions=None, ):
        """
        :param lossFunction: callable invoked as
            ``lossFunction(predictions, targets)``; its result must support
            ``backward()``.
        :param schedulingFunctions: sequence of callables that map the
            training position (``currentEpoch / totalEpochs``) to a learning
            rate; they are paired in order with the optimizer's parameter
            groups. When omitted, two ``cosineScheduler(1e-1, 1e-6)``
            schedulers are created for this instance.
        """
        super().__init__(name="Training")
        # Build the default per instance instead of using a mutable default
        # argument, which would be created once and shared by all instances.
        if schedulingFunctions is None:
            schedulingFunctions = [cosineScheduler(1e-1, 1e-6), cosineScheduler(1e-1, 1e-6)]
        self._optimizer = None
        self._schedulingFunctions = schedulingFunctions
        self._lossFunction = lossFunction

    def preModelTeach(self, model, epochs):
        """Create the optimizer for ``model`` and remember the epoch budget.

        The initial learning rate is the first scheduling function evaluated
        at position 0.
        """
        super().preModelTeach(model, epochs)
        self._optimizer = optim.SGD(model.parameters(), self._schedulingFunctions[0](0))
        self._totalEpochs = epochs

    def postBatchEvaluation(self, predictions, valdationData):
        """Compute the batch loss, update the model and report the loss.

        :param predictions: model output for the current batch.
        :param valdationData: targets of the current batch (the teacher passes
            the y-batch here).
        """
        super().postBatchEvaluation(predictions, valdationData)
        calculatedLoss = self._lossFunction(predictions, valdationData)
        self._teachModel(calculatedLoss)
        # postBatchLossConsumption is not defined in this class; presumably
        # provided by StatisticsSubscriber - confirm in nb_07D.
        self.postBatchLossConsumption(calculatedLoss)

    def _teachModel(self, loss):
        """Back-propagate ``loss``, apply one optimizer step and clear gradients."""
        loss.backward()
        self._optimizer.step()
        self._optimizer.zero_grad()

    def preBatchEvaluation(self):
        """Refresh the learning rate(s) before the next batch is evaluated."""
        super().preBatchEvaluation()
        self._annealLearningRate()

    def _annealLearningRate(self):
        """Set each parameter group's ``lr`` from its scheduling function.

        zip() stops at the shorter sequence, so scheduling functions without a
        matching parameter group are ignored.
        """
        # self._currentEpoch is not assigned in this class; presumably
        # maintained by a base class - confirm in nb_07D.
        for parameterGroup, schedulingFunction in zip(self._optimizer.param_groups, self._schedulingFunctions):
            parameterGroup['lr'] = schedulingFunction(self._currentEpoch / self._totalEpochs)

In [39]:
class TeacherOptimized:
    """Runs the epoch loop and delegates all per-phase work to subscribers.

    Each epoch performs one pass over the training data (handled by the
    training subscriber) followed by one pass over the validation data under
    ``torch.no_grad()`` (handled by the validation subscriber).
    """

    def __init__(self,
                 dataBunch,
                 trainingSubscriber: TrainingSubscriber,
                 validationSubscriber: ValidationSubscriber):
        """
        :param dataBunch: object exposing ``trainingDataSet`` and
            ``validationDataSet``; both are iterated as (x, y) batches.
        :param trainingSubscriber: subscriber notified during training passes.
        :param validationSubscriber: subscriber notified during validation passes.
        """
        self._dataBunch = dataBunch
        self._trainingSubscriber = trainingSubscriber
        self._validationSubscriber = validationSubscriber

    def teachModel(self, model, numberOfEpochs):
        """Train and validate ``model`` for ``numberOfEpochs`` epochs.

        Subscribers receive ``postModelTeach`` even when an epoch raises
        (for example on a keyboard interrupt), mirroring the ``finally`` that
        guarantees ``postEpoch`` in ``_processData``; the exception still
        propagates to the caller.
        """
        self._notifiyPreTeach(model, numberOfEpochs)
        try:
            for epoch in range(numberOfEpochs):
                self._trainModel(model, epoch)
                self._validateModel(model, epoch)
        finally:
            self._notifiyPostTaught()

    def _notifiyPreTeach(self, model, epochs):
        """Tell both subscribers that teaching is about to start."""
        self._trainingSubscriber.preModelTeach(model, epochs)
        self._validationSubscriber.preModelTeach(model, epochs)

    def _notifiyPostTaught(self):
        """Tell both subscribers that teaching has ended.

        The validation subscriber is notified even if the training
        subscriber's notification raises.
        """
        try:
            self._trainingSubscriber.postModelTeach()
        finally:
            self._validationSubscriber.postModelTeach()

    def _trainModel(self, model, epoch):
        """Run one pass over the training data with the training subscriber."""
        self._processData(model,
                          self._dataBunch.trainingDataSet,
                          epoch,
                          self._trainingSubscriber)

    def _validateModel(self, model, epoch):
        """Run one pass over the validation data with gradients disabled."""
        with torch.no_grad():
            self._processData(model,
                              self._dataBunch.validationDataSet,
                              epoch,
                              self._validationSubscriber)

    def _processData(self,
                     model,
                     dataLoader,
                     epoch,
                     processingSubscriber: Subscriber):
        """Feed every batch of ``dataLoader`` through ``model``.

        The subscriber is notified before the pass, around every batch and,
        always, after the pass. A ProcessCancellationException raised by the
        subscriber or the model ends the pass early without being treated as
        an error.
        """
        processingSubscriber.preEpoch(epoch, dataLoader)
        try:
            for _xDataBatch, _yDataBatch in dataLoader:
                processingSubscriber.preBatchEvaluation()
                _predictions = model(_xDataBatch)
                processingSubscriber.postBatchEvaluation(_predictions, _yDataBatch)
        except ProcessCancellationException:
            # Deliberate: cancellation only stops the remaining batches.
            pass
        finally:
            processingSubscriber.postEpoch(epoch)


In [40]:
# Subscriber handed to the teacher for the validation pass (class comes from nb_07D).
validationSubscriber = ValidationSubscriber()

In [41]:
# Training subscriber with its default loss (cross entropy) and default
# scheduling functions.
trainingSubscriber = TrainingSubscriber()

In [42]:
# Use the TeacherOptimized defined in this notebook; previously the imported
# TeacherEnhanced was instantiated, so the class above was never exercised.
teacher = TeacherOptimized(imageDataBunch,
                           trainingSubscriber,
                           validationSubscriber)


In [61]:
# Build a fresh model from the configured class count and layer sizes.
convolutionalModelSR1 = createBetterConvolutionModel(numberOfClasses, layerSizes)


In [62]:
# Baseline: accuracy of the model on the validation subset before teachModel
# is called. Left as a bare expression so the cell displays the value.
accuracy(convolutionalModelSR1(validationDataSet.xVector), validationDataSet.yVector)

tensor(0.1064)

In [63]:
# Teach the model for 3 epochs; the subscribers print the per-epoch statistics below.
teacher.teachModel(convolutionalModelSR1, 3)

Epoch #0 Training: Loss 2.3020851612091064 Accuracy 0.10712499916553497
Epoch #0 Validation: Loss 0.0 Accuracy 0.1197916641831398
Epoch #1 Training: Loss 2.0847008228302 Accuracy 0.3006249964237213
Epoch #1 Validation: Loss 0.0 Accuracy 0.55629962682724
Epoch #2 Training: Loss 0.8825256824493408 Accuracy 0.7307500243186951
Epoch #2 Validation: Loss 0.0 Accuracy 0.8175843358039856
