# Homework 3
### Marco Sicklinger, April 2021

### Prerequisites: module imports & function definitions


In [13]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

## Assignment 1
A custom loss function with L1-norm regularization of the model weights can be defined as follows.

In [12]:
def loss_ReL1(model, decay = .1, loss_ft = None):
    """
    Add an L1-norm penalty on the model weights to an already computed loss.

    Parameters
    ----------
    model:  class[torch.nn.Module]
            model whose parameters are penalized; every parameter whose name
            contains the substring 'weight' contributes its L1 norm
    decay:  Number
            regularization strength multiplying the sum of the L1 norms
    loss_ft:    torch.Tensor or None
                unregularized loss value computed by the caller

    Returns
    -------
    torch.Tensor or None
        `loss_ft + decay * sum_i ||W_i||_1`, or `None` if `loss_ft` is `None`
    """
    if loss_ft is None:
        return None

    # Keep the penalty as a tensor: converting each norm to a Python float
    # (e.g. via `.item()`) detaches it from the autograd graph, so the
    # regularization would change the loss value but contribute no gradient.
    l1_penalty = sum(par.abs().sum() for name, par in model.named_parameters() if 'weight' in name)
    return loss_ft + decay*l1_penalty

Its usage can be found in `02-sgd-training.ipynb`.  
A word of caution: parameters are selected by checking whether their name contains the substring "weight". If the user who implements the `model` gives some layers names containing the word "weight", the biases of those layers are also included in the regularization term.

## Assignment 2

Early stopping can be implemented as a function object (a callable class) as follows.

In [15]:
class early_stopping:
    """
    Early stopping procedure based on the generalized loss

        GL = loss / min_loss - 1,

    i.e. the relative increase of the current loss with respect to the best
    (lowest) loss seen so far. Whenever a new best loss is found, the model's
    `state_dict` is saved to disk; when GL exceeds the tolerance, the public
    attribute `early_stop` is set to `True` so that the caller can stop
    training.

    Losses are assumed to be non-negative Python numbers (e.g. obtained via
    `loss.item()`); the ratio is not meaningful for negative losses.
    """

    def __init__(self, wordy = False, tolerance = .5, saving_path = 'early_stopping.pt', messenger = print):
        """
        Members
        -------
        _wordy: bool
                boolean for printing (True) or not (False) updates
        _path:  str
                string indicating where to save best model
        _messenger: function
                    function used to print class messages
        _tolerance: Number
                    maximum allowed generalized loss before stopping
        _current:   Number
                    most recent loss passed to the object (`None` before the first call)
        _min:   Number
                variable for saving the best loss computed up to current epoch
        _generalized_loss:  Number
                            generalized loss computed at the most recent call
        early_stop: bool
                    boolean set to `True` if threshold is reached, `False` otherwise

        Parameters
        ----------
        wordy:  bool
                boolean passed as `True` if class messages are requested, `False` otherwise
        tolerance:  Number
                    generalized loss above which `early_stop` is set to `True`
        saving_path:    str
                        string indicating a path where to save best model
        messenger:  function
                    function used to print class messages
        """

        self._wordy = wordy
        self._path = saving_path
        self._messenger = messenger
        self._tolerance = tolerance
        self._current = None
        # lowercase `np.inf`: the `np.Inf` alias was removed in NumPy 2.0
        self._min = np.inf
        self._generalized_loss = 0
        self.early_stop = False

    def __call__(self, loss, model):
        """
        Update the early stopping state with the loss of the current epoch.

        Parameters
        ----------
        loss:   Number
                computed loss function at current epoch
        model:  class[torch.nn.Module]
                model used for learning; saved whenever `loss` is the best
                loss seen so far
        """

        # first call: the passed loss is trivially the best one so far
        if self._current is None:
            self._current = loss
            self._min = loss
            self._generalized_loss = 0
            self.save_model(loss, model)
            return

        self._current = loss
        # relative change with respect to the best loss before this epoch
        self._generalized_loss = self.generalized_loss(loss, self._min)

        # loss improves on the best one: update the minimum and save the model
        if loss < self._min:
            if self._wordy:
                self._messenger('early_stopping class message: good step - generalized loss {:.3f}'.format(self._generalized_loss))
            self._min = loss
            self.save_model(loss, model)
        # no improvement: if the tolerance is exceeded, set `early_stop` to
        # `True` so as to allow the user to stop the procedure.
        # NOTE: a NaN loss satisfies neither comparison and is ignored.
        elif loss >= self._min:
            if self._wordy:
                self._messenger('early_stopping class message: bad step - generalized loss {:.3f}'.format(self._generalized_loss))
            if self._generalized_loss > self._tolerance:
                self.early_stop = True

    def save_model(self, loss, model):
        """
        Save the model's `state_dict` to the configured path.

        Parameters
        ----------
        loss:   Number
                computed loss function at current epoch (not used by this
                implementation, kept for interface compatibility)
        model:  class[torch.nn.Module]
                model used for learning
        """

        torch.save(model.state_dict(), self._path)

    def generalized_loss(self, loss, Min):
        """
        Compute the generalized loss `loss/Min - 1`.

        Parameters
        ----------
        loss:   Number
                computed loss function
        Min:    Number
                minimum loss function computed up to current epoch

        Returns
        -------
        Number
            `loss/Min - 1`; if `Min` is zero, `0.0` when `loss` is also zero
            and `inf` otherwise (instead of dividing by zero)
        """
        if Min == 0:
            return 0.0 if loss == 0 else np.inf
        return loss/Min-1