# Model Training Patterns

## Stochastic Gradient Descent
- On large datasets, stochastic gradient descent (SGD) is applied to mini-batches of the data
- This is called mini-batch SGD. Extensions of SGD (e.g. Adam, Adagrad) are the de facto optimisers used in modern machine learning frameworks

- SGD requires training to take place iteratively on small batches, therefore training happens in a loop
- SGD finds a minimum iteratively rather than through a closed-form solution, so we have to detect whether the model has converged
- As a result, the error (called the loss) on the training dataset has to be monitored
- Overfitting can happen if the model complexity is higher than can be afforded by the size and coverage of the dataset
- It is difficult to know if the complexity is too high until we actually train the model on the dataset
- Therefore, evaluation needs to be done within the training loop, and error metrics must also be monitored on a withheld split of the training data (the validation set)
- Because the training and validation datasets have both been used in the training loop, it is necessary to withhold yet another split of the training dataset, called the test set
- Final metrics are reported on the test set
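
The loop described above can be sketched in plain Python. This is a minimal illustration, not a reference implementation: the toy task (`y = 3x + 2` plus noise), the helper names `make_data`, `mse` and `train`, and all hyperparameters are assumptions chosen for the example.

```python
import random

def make_data(n, seed=0):
    """Synthetic 1-D regression data (assumed toy task): y = 3x + 2 plus noise."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.uniform(-1.0, 1.0)
        data.append((x, 3.0 * x + 2.0 + rng.gauss(0.0, 0.1)))
    return data

def mse(data, w, b):
    """Mean squared error of the linear model w*x + b on a dataset."""
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)

def train(train_set, val_set, epochs=50, batch_size=16, lr=0.1, seed=0):
    """Mini-batch SGD; records (training loss, validation error) per epoch."""
    rng = random.Random(seed)
    examples = list(train_set)  # copy so the caller's list is not reordered
    w, b = 0.0, 0.0
    history = []
    for _ in range(epochs):
        rng.shuffle(examples)
        for start in range(0, len(examples), batch_size):
            batch = examples[start:start + batch_size]
            # Gradients of the batch MSE with respect to w and b.
            grad_w = sum(2.0 * (w * x + b - y) * x for x, y in batch) / len(batch)
            grad_b = sum(2.0 * (w * x + b - y) for x, y in batch) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
        # Evaluation happens inside the training loop, once per epoch.
        history.append((mse(examples, w, b), mse(val_set, w, b)))
    return w, b, history

data = make_data(300)
train_set, val_set, test_set = data[:200], data[200:250], data[250:]
w, b, history = train(train_set, val_set)
test_error = mse(test_set, w, b)  # the test set is only touched once, at the end
```

The validation error in `history` is what would drive decisions such as when to stop, while `test_error` is the figure that gets reported.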

## Design Pattern 11: Useful Overfitting

- Want to intentionally overfit on the training dataset
- Perform training without regularisation, dropout, validation dataset or early stopping

### Problem

- The goal of an ML model is to generalise well, i.e. to make good predictions on unseen data
- If the model overfits, its ability to generalise suffers and so do future predictions

- For example, imagine we have a system that models the physical environment
- The model carries out iterative, numerical calculations to calculate the precise state of the system
- Suppose all observations have a finite number of possibilities, e.g. temperature is limited to 60 - 80 degrees Celsius in increments of 0.01
- We can then create a training dataset for the ML system consisting of the complete input space, and calculate the labels using the physical model
- Splitting the training dataset would be counterproductive, because we would then be expecting the model to learn parts of the input space it will not have seen in the training dataset

### Solutions

- In the above scenario, there is no "unseen" data that needs to be generalised to, since all possible inputs have been tabulated
- If all possible inputs to a model can be tabulated there is no such thing as overfitting

- Typically, overfitting to the training dataset in this way causes the model to give misguided predictions on new, unseen datapoints
- The difference here is that we know in advance there won't be unseen data

### Why it Works

- If all possible inputs can be tabulated and trained on, an overfit model will make the same predictions as the "true" model, so overfitting is not a concern

- Overfitting is useful when:
    - There is no noise, so the labels are accurate for all instances
    - You have the complete dataset at your disposal (you have all the examples there are). In this case, overfitting becomes interpolating the dataset
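
As a rough illustration of the temperature example, the sketch below enumerates the complete input space and "trains" by pure memorisation. The function `physical_model` is a hypothetical stand-in for the expensive numerical simulation, and the lookup table is an extreme stand-in for an overfit ML model; neither comes from the original notes.

```python
import math

def physical_model(temp_c):
    """Hypothetical stand-in for an expensive, noise-free numerical simulation."""
    return math.sin(temp_c / 10.0) + 0.01 * temp_c

# Enumerate the complete input space: 60.00 to 80.00 in steps of 0.01.
# Working in integer hundredths avoids floating-point drift in the grid.
grid = [h / 100 for h in range(6000, 8001)]
labels = [physical_model(t) for t in grid]

# A deliberately overfit "model": it memorises every (input, label) pair.
memorised = {round(t * 100): y for t, y in zip(grid, labels)}

def predict(temp_c):
    """Look up the memorised label for a temperature on the 0.01 grid."""
    return memorised[round(temp_c * 100)]

# Every possible input was seen during training, so there is nothing left
# to generalise to and the memorised model matches the true one everywhere.
max_error = max(abs(predict(t) - physical_model(t)) for t in grid)
```

Because there is no noise and no unseen input, a training error of zero here is exactly the desired outcome rather than a warning sign.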

## Design Pattern 12: Checkpoints

- With checkpoints we store the full state of the model periodically so that we have partially trained models available
- These models can serve as the final model (in the case of early stopping) or as a starting point for continued training

### Problem

- The more complex a model is, the more data is needed to train it effectively
- More complex models tend to have more tunable parameters
- As model size increases, it takes longer to fit on one batch of data
- As the amount of data increases, the number of batches increases
- In terms of computational complexity this is a double whammy

- When training for a long time, the chance of machine failure increases
- If there is a problem we would like to resume from an intermediate point

### Solution

- At the end of every epoch, save the model state
- If a machine failure occurs, we can restart and resume from the saved state
- Make sure the full model state is saved, not just the model
- Once training is complete and the model is exported, usually only the information required to make a prediction is saved

- It is good to save information about the training loop as well, e.g.
    - Learning rate in a learning rate scheduler
    - Batch number
- Saving the full model state so that training can be resumed is called checkpointing
- The model state changes with every batch, but saving at every batch adds too much overhead, so save at the end of each epoch

##### Example to save model state in PyTorch

```
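import torch

# `epoch`, `model`, `optimiser`, `loss` and `PATH` come from the surrounding training loop
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),          # learned parameters
    'optimiser_state_dict': optimiser.state_dict(),  # e.g. momentum buffers
    'loss': loss,
    # ...
}, PATH)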
```
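
The resume side of the pattern can be sketched without any framework. The example below is an assumption-laden toy: `pickle` stands in for the framework's own save and load calls, the "training" step is a made-up update on a single number, and `train_with_checkpoints` and `fail_at_epoch` are hypothetical names used only to simulate a crash.

```python
import os
import pickle
import tempfile

def train_with_checkpoints(path, total_epochs, fail_at_epoch=None):
    """Toy training loop that saves its full state at the end of every epoch."""
    # Resume from the last checkpoint if one exists, otherwise start fresh.
    if os.path.exists(path):
        with open(path, "rb") as f:
            state = pickle.load(f)
    else:
        state = {"epoch": 0, "weight": 0.0, "learning_rate": 0.5}

    for epoch in range(state["epoch"], total_epochs):
        if fail_at_epoch is not None and epoch == fail_at_epoch:
            raise RuntimeError("simulated machine failure")
        # Stand-in for one epoch of training: move the weight towards 1.0.
        state["weight"] += state["learning_rate"] * (1.0 - state["weight"])
        state["learning_rate"] *= 0.9  # scheduler state is part of the checkpoint
        state["epoch"] = epoch + 1
        with open(path, "wb") as f:
            pickle.dump(state, f)
    return state

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
try:
    train_with_checkpoints(path, total_epochs=5, fail_at_epoch=3)
except RuntimeError:
    pass  # the job crashed after completing three epochs
resumed = train_with_checkpoints(path, total_epochs=5)

# An uninterrupted run, for comparison with the resumed one.
reference_path = os.path.join(tempfile.mkdtemp(), "reference.pkl")
uninterrupted = train_with_checkpoints(reference_path, total_epochs=5)
```

Because the learning rate is stored alongside the weight, the resumed run ends in the same state as the uninterrupted one. In PyTorch the corresponding steps would be `torch.load` followed by `load_state_dict` on both the model and the optimiser.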

### Why it Works

- Most ML frameworks can resume from a saved checkpoint
- Checkpoints are designed mainly for resilience; their availability, however, opens up other use cases
- Partially trained models are usually more generalisable than models created in later iterations
    - [See here](https://playground.tensorflow.org/)
    
### Trade-Offs and Alternatives

- Saving checkpoints allows us to implement early stopping and fine-tuning capabilities

#### Early Stopping

- Typically, the longer you train, the lower your training loss
- At a certain point the loss on the validation set may stop decreasing
- If you begin to overfit, the error on the validation set may increase
- It is handy to look at the validation error at the end of every epoch and stop the training process when the validation error is higher than that of the previous epoch
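
The stopping rule above can be written in a few lines. This is a minimal sketch of that simple rule only; the function name `early_stopping_epoch` and the validation errors are made up for illustration.

```python
def early_stopping_epoch(val_errors):
    """Return the index of the epoch at which training would stop.

    Simple rule: stop as soon as the validation error is higher than
    that of the previous epoch.
    """
    for epoch in range(1, len(val_errors)):
        if val_errors[epoch] > val_errors[epoch - 1]:
            return epoch
    return len(val_errors) - 1  # rule never triggered: train to the end

# Made-up validation errors, one per epoch.
val_errors = [0.90, 0.62, 0.48, 0.41, 0.43, 0.39, 0.37, 0.40]
stop_epoch = early_stopping_epoch(val_errors)  # stops at epoch 4 (0.43 > 0.41)
```

With checkpoints saved every epoch, the model kept would be the one from the epoch before the increase.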

#### Checkpoint Selection

- It's not uncommon for the validation error to decrease, increase slightly, then decrease again; therefore early stopping as soon as the validation loss increases may not be optimal
- This is because training initially focuses on common cases, then begins to look at rarer cases [Paper](https://arxiv.org/abs/1912.02292)
- Therefore training should continue for a while longer, and the best checkpoint can be selected afterwards
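
One way to act on this is to keep training, save a checkpoint per epoch, and pick the best one afterwards. The sketch below assumes one validation error per saved checkpoint; the numbers and the helper names `first_increase` and `select_checkpoint` are illustrative only.

```python
def first_increase(val_errors):
    """Epoch at which naive early stopping would trigger, or None if it never does."""
    for epoch in range(1, len(val_errors)):
        if val_errors[epoch] > val_errors[epoch - 1]:
            return epoch
    return None

def select_checkpoint(val_errors):
    """Index of the saved checkpoint with the lowest validation error."""
    return min(range(len(val_errors)), key=lambda epoch: val_errors[epoch])

# Made-up validation errors: a dip, a small rise, then a further decrease.
val_errors = [0.90, 0.62, 0.48, 0.41, 0.43, 0.39, 0.37, 0.40]

naive_stop = first_increase(val_errors)      # triggers at epoch 4
best_epoch = select_checkpoint(val_errors)   # the best checkpoint is epoch 6
```

In this made-up run, stopping at the first increase would keep a checkpoint with validation error 0.41, whereas training longer and selecting afterwards reaches 0.37.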

#### Regularisation

- Instead of early stopping or checkpoint selection, it can be helpful to add L2 regularisation to your model so that the validation error does not increase.
- Instead, both the training loss and the validation error plateau. We term such a training loop where both training and validation plateau a well-behaved training loop

- Regularisation might be better than early stopping because it allows you to use the entire dataset to change the weights of the model
- With early stopping, the decision of where to stop is made on the validation set, so the data in that set is not used to update the weights
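
To make the L2 penalty concrete, here is a small sketch using full-batch gradient descent on a linear model. The dataset, the penalty strength `lam`, and the function names are assumptions for illustration; real frameworks usually expose this as a weight-decay or regulariser option instead.

```python
def l2_regularised_loss(weights, data, lam):
    """Mean squared error plus an L2 penalty lam * sum(w^2)."""
    mse = sum(
        (sum(w * x for w, x in zip(weights, xs)) - y) ** 2 for xs, y in data
    ) / len(data)
    return mse + lam * sum(w * w for w in weights)

def fit(data, n_features, lam, lr=0.05, steps=2000):
    """Full-batch gradient descent on the L2-regularised loss."""
    weights = [0.0] * n_features
    for _ in range(steps):
        # Gradient of the penalty term: d/dw (lam * w^2) = 2 * lam * w.
        grads = [2.0 * lam * w for w in weights]
        for xs, y in data:
            err = sum(w * x for w, x in zip(weights, xs)) - y
            for j, x in enumerate(xs):
                grads[j] += 2.0 * err * x / len(data)
        weights = [w - lr * g for w, g in zip(weights, grads)]
    return weights

# Tiny made-up dataset: (features, label) pairs.
data = [((1.0, 0.0), 2.0), ((0.0, 1.0), -1.0), ((1.0, 1.0), 1.2), ((2.0, 1.0), 2.9)]
plain = fit(data, 2, lam=0.0)   # no regularisation
ridge = fit(data, 2, lam=0.5)   # L2 penalty shrinks the weights
```

The penalty pulls the weights towards zero, which limits how closely the model can chase the training data and is what keeps the validation error from climbing.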

#### Two-Splits

- It is recommended to split the data into two parts: a training set and an evaluation set
- The evaluation set plays the part of the test dataset during training
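
A two-way split might look like the sketch below. The 80/20 ratio, the fixed seed, and the helper name `two_way_split` are assumptions for the example rather than recommendations from the notes.

```python
import random

def two_way_split(examples, eval_fraction=0.2, seed=42):
    """Shuffle once, then split into (training set, evaluation set)."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # fixed seed keeps the split reproducible
    n_eval = int(len(shuffled) * eval_fraction)
    return shuffled[n_eval:], shuffled[:n_eval]

# Integers stand in for real training examples.
train_set, eval_set = two_way_split(range(1000))
```

With no separate validation set carved out, everything outside `eval_set` is available for updating the weights.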

- A larger training dataset allows for a more complex model, and the more accurate the model can get
- Using regularisation rather than early stopping or checkpoint selection allows you to use a larger training dataset
- During the experimentation phase (e.g. hyperparameter and model architecture exploration), early stopping should be turned off
- This ensures the model has enough capacity to learn the predictive patterns

- When training a model for production, be prepared for continuous evaluation and model retraining

#### Fine Tuning