In [3]:
%matplotlib inline

In [None]:
class TeacherWithHooks:
    """Train a layered CNN with SGD, annealed learning rates and activation hooks.

    While training, a forward hook is attached to every entry of
    ``cnnModel.layers``; it records the mean and standard deviation of that
    layer's output for every training batch so the activation statistics can
    be plotted afterwards. Learning rates and per-batch training losses are
    recorded too.
    """

    def __init__(self,
                 lossFunction=Functional.cross_entropy,
                 accuracyFunction=accuracy,
                 schedulingFunctions=None
                 ):
        """
        Args:
            lossFunction: callable ``(predictions, targets) -> loss tensor``.
            accuracyFunction: callable ``(predictions, targets) -> scalar``.
            schedulingFunctions: one scheduler per optimizer parameter group;
                each is called with the fraction of training completed and
                returns the learning rate to use. Defaults to two
                ``cosineScheduler(1e-1, 1e-6)`` built fresh for this teacher.
        """
        if schedulingFunctions is None:
            # Built per instance: a list literal default would be shared by every teacher.
            schedulingFunctions = [cosineScheduler(1e-1, 1e-6), cosineScheduler(1e-1, 1e-6)]
        self.lossFunction = lossFunction
        self.accuracyFunction = accuracyFunction
        self.schedulingFunctions = list(schedulingFunctions)
        self._initStorage()

    def _initStorage(self):
        """Reset every recorded metric and the list of hook handles."""
        self.learningRates, self.losses = [], []
        self.modelsMeans, self.modelsStandardDeviations = [], []
        self.registeredHooks = []

    def teachModel(self, cnnModel, dataBunch, numberOfEpochs):
        """Train ``cnnModel`` and print loss/accuracy after every epoch.

        Args:
            cnnModel: module exposing ``parameters()`` and an iterable ``layers``.
            dataBunch: object with ``trainingDataLoader`` and ``validationDataLoader``.
            numberOfEpochs: number of passes over the training data.
        """
        self.convolutionModel = cnnModel
        self.numberOfEpochs = numberOfEpochs
        self._beginTraining()
        # One list per layer; the forward hooks append to these.
        self.modelsMeans = [[] for _ in cnnModel.layers]
        self.modelsStandardDeviations = [[] for _ in cnnModel.layers]
        # The schedule advances per batch, so it needs the total batch count.
        self.iteration = 0
        self.numberOfIterations = numberOfEpochs * len(dataBunch.trainingDataLoader)
        self.optimizer = optim.SGD(cnnModel.parameters(), self.schedulingFunctions[0](0))
        for epoch in range(numberOfEpochs):
            self.epoch = epoch
            trainingLoss, trainingAccuracy = self._trainModel(cnnModel, dataBunch.trainingDataLoader)
            print("Epoch #{} Training: Loss {} Accuracy {}".format(epoch, trainingLoss, trainingAccuracy))

            validationLoss, validationAccuracy = self._validateModel(cnnModel, dataBunch.validationDataLoader)
            print("Epoch #{} Validation: Loss {} Accuracy {}".format(epoch, validationLoss, validationAccuracy))
            print("")

    def plotLearningRates(self):
        """Plot the learning rate recorded after every training batch."""
        plotter.plot(self.learningRates)

    def plotLosses(self):
        """Plot the loss recorded for every training batch."""
        plotter.plot(self.losses)

    def plotMeans(self, numberOfBatches=200):
        """Plot each layer's output mean over the first ``numberOfBatches`` recorded batches."""
        for layerOutputMeans in self.modelsMeans: plotter.plot(layerOutputMeans[:numberOfBatches])
        plotter.legend(range(len(self.modelsMeans)))

    def plotStandardDeviations(self, numberOfBatches=200):
        """Plot each layer's output std over the first ``numberOfBatches`` recorded batches."""
        for layerOutputSD in self.modelsStandardDeviations: plotter.plot(layerOutputSD[:numberOfBatches])
        plotter.legend(range(len(self.modelsStandardDeviations)))

    def _beginTraining(self):
        # Detach any hooks still attached from an earlier run before forgetting their handles.
        self._unregisterHooks()
        self._initStorage()

    def _addStats(self, index, model, inputParameters, outputParameters):
        """Forward hook: record mean/std of layer ``index``'s output as Python floats."""
        # .item() keeps the stored values plain floats (plottable even when output lives on a GPU).
        self.modelsMeans[index].append(outputParameters.data.mean().item())
        self.modelsStandardDeviations[index].append(outputParameters.data.std().item())

    def _registerHooks(self):
        """Attach ``_addStats`` as a forward hook on every layer of the model."""
        for index, modelLayer in enumerate(self.convolutionModel.layers):
            self.registeredHooks.append(modelLayer.register_forward_hook(partial(self._addStats, index)))

    def _unregisterHooks(self):
        """Remove every hook attached by ``_registerHooks``."""
        for hook in self.registeredHooks: hook.remove()
        self.registeredHooks = []

    def _anealLearningRate(self):
        """Set each parameter group's learning rate from its scheduler."""
        # Fraction of all training batches completed so far; advancing per batch
        # (not per epoch) keeps the schedule moving inside every epoch.
        position = self.iteration / max(self.numberOfIterations, 1)
        for parameterGroup, schedulingFunction in zip(self.optimizer.param_groups, self.schedulingFunctions):
            parameterGroup['lr'] = schedulingFunction(position)

    def _trainModel(self, cnnModel, trainingDataSet):
        """Run one training epoch; return (mean loss, mean accuracy) as floats."""
        def _teachModel(loss):
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.iteration += 1

            # capture metrics
            self.learningRates.append(self.optimizer.param_groups[-1]['lr'])
            self.losses.append(loss.item())

        cnnModel.train()
        return self._proccessDataSet(cnnModel,
                                     trainingDataSet,
                                     preEpoch=self._registerHooks,
                                     preEvaluation=self._anealLearningRate,
                                     postEpoch=self._unregisterHooks,
                                     postEvaluation=_teachModel
                                     )

    def _validateModel(self, cnnModel, validationDataSet):
        """Evaluate without gradients in eval mode; return (mean loss, mean accuracy)."""
        cnnModel.eval()
        with torch.no_grad():
            return self._proccessDataSet(cnnModel, validationDataSet)

    def _proccessDataSet(self,
                         cnnModel,
                         dataLoader,
                         preEpoch=lambda: None,
                         preEvaluation=lambda: None,
                         postEvaluation=lambda loss: None,
                         postEpoch=lambda: None
                         ):
        """Iterate ``dataLoader`` once, invoking the callbacks around each batch.

        Returns:
            (mean loss, mean accuracy) over the batches as floats, or a pair of
            NaNs when the loader yields no batches.
        """
        accumulatedLoss, accumulatedAccuracy, numberOfBatches = 0., 0., 0
        preEpoch()
        try:
            for _xDataSet, _yDataSet in dataLoader:
                preEvaluation()
                _predictions = cnnModel(_xDataSet)
                loss = self.lossFunction(_predictions, _yDataSet)
                postEvaluation(loss)
                # Accumulate plain numbers: summing the loss tensor itself would keep
                # every batch's autograd graph alive for the whole epoch.
                accumulatedLoss += loss.item()
                accumulatedAccuracy += float(self.accuracyFunction(_predictions, _yDataSet))
                numberOfBatches += 1
        finally:
            # Runs even if a batch raises, so hooks are never left attached.
            postEpoch()
        if numberOfBatches == 0:
            return float('nan'), float('nan')
        return accumulatedLoss / numberOfBatches, accumulatedAccuracy / numberOfBatches