In [None]:
# default_exp learn

# learn
> Classes and functions for training and predicting.

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#hide
from IPython.core.display import display, HTML
# Display-only tweak: inject CSS that sets the `.container` element to 85% width.
display(HTML("<style>.container { width:85% !important; }</style>"))

In [None]:
#export
from lemonade.setup import * 
from lemonade.metrics import * 
from fastai.imports import *

In [None]:
#hide
from nbdev.showdoc import *

## Helpers 

## `save_to_checkpoint()` -

In [None]:
#export
def save_to_checkpoint(epoch_index, model, optimizer, path):
    '''Write `epoch_index` plus the model and optimizer state_dicts to `{path}/checkpoint.tar`.

    The directory `path` is created (including parents) when it does not exist yet.
    An existing checkpoint file at that location is overwritten.
    '''
    chkpt_file = f'{path}/checkpoint.tar'

    if not os.path.isdir(path):
        Path(path).mkdir(parents=True, exist_ok=True)

    # Keys are read back by name when the checkpoint is loaded, so they must stay as-is.
    state = dict(
        epoch_index=epoch_index,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(state, chkpt_file)
    print(f'Checkpointed to "{chkpt_file}"')

## `load_from_checkpoint()` -

In [None]:
#export
def load_from_checkpoint(model, path, optimizer=None, for_inference=False):
    '''Restore state from `{path}/checkpoint.tar`.

    - `for_inference=True`: load only the model weights and return `model`.
    - `for_inference=False`: additionally load the optimizer state and return
      `(epoch_index, model, optimizer)`; `optimizer` is required in this mode.

    The model is moved to `DEVICE` (a module-level name defined outside this notebook).
    Raises `ValueError` when the optimizer is needed but was not supplied.
    '''
    # Fail before touching the disk instead of with an AttributeError half-way through.
    if not for_inference and optimizer is None:
        raise ValueError('`optimizer` is required when `for_inference` is False')

    chkpt_file = f'{path}/checkpoint.tar'
    print(f'From "{chkpt_file}", loading model ...')
    # map_location lets a checkpoint written on one device (e.g. GPU) be restored
    # on whatever DEVICE is in this session (e.g. a CPU-only host).
    chkpt = torch.load(chkpt_file, map_location=DEVICE)
    model.load_state_dict(chkpt['model_state_dict'])
    model = model.to(DEVICE)

    if for_inference:
        return model

    print('loading optimizer and epoch_index ...')
    optimizer.load_state_dict(chkpt['optimizer_state_dict'])
    return chkpt['epoch_index'], model, optimizer

## `get_loss_fn()` - 

In [None]:
#export
def get_loss_fn(pos_wts):
    '''Build an `nn.BCEWithLogitsLoss` that uses `pos_wts` as its `pos_weight`, placed on `DEVICE`'''
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_wts)
    return loss_fn.to(DEVICE)

## class `RunHistory` - 

In [None]:
#exports
class RunHistory:
    '''Container for the results of training (`fit`) and prediction (`predict`) runs.

    `labels` are the label names; each per-epoch frame gets the columns `['loss', *labels]`.
    '''
    def __init__(self, labels):
        columns = ['loss', *labels]

        # One independent frame per split. A chained `a = b = c = DataFrame(...)`
        # would make all three names point at the same object.
        self.train = pd.DataFrame(columns=columns)
        self.valid = pd.DataFrame(columns=columns)
        self.test  = pd.DataFrame(columns=columns)

        # Ground-truth (`y_*`) and predicted (`yhat_*`) values per split.
        # Separate lists so that mutating one in place cannot leak into the others.
        self.y_train, self.yhat_train = [], []
        self.y_valid, self.yhat_valid = [], []
        self.y_test,  self.yhat_test  = [], []

        # Per-label summary table; starts empty.
        self.prediction_summary = pd.DataFrame()

**Note**
- `y_*` and `yhat_*` hold the ground-truth and predicted values, respectively, from
    - the last epoch of `fit()` for `train` and `valid`, 
    - the last run of `predict()` for `test` 
- `train`, `valid` and `test` hold the loss and per-label AUROC scores calculated at the end of each epoch
    - for `test` there is only a single epoch in each run

## fit & predict

### `BCEWithLogitsLoss` & `torch.sigmoid`

- Using `BCEWithLogitsLoss` because it is [more numerically stable than using a plain Sigmoid followed by a BCELoss](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)
- It also accommodates multi-label classification and class-imbalanced datasets through its `pos_weight` argument
- But this means that the model does not apply a final sigmoid at the output layer, since that's done inside the loss function
- So we need to apply `torch.sigmoid` to the `yhat`s before using them to calculate accuracy
- 2 good discussions on this topic
    - [How to interpret the probability of classes in binary classification?](https://discuss.pytorch.org/t/how-to-interpret-the-probability-of-classes-in-binary-classification/45709)
    - [BCEWithLogitsLoss and model accuracy calculation](https://discuss.pytorch.org/t/bcewithlogitsloss-and-model-accuracy-calculation/59293/2)

### `train()` -

In [None]:
#export
def train(model, train_dl, train_loss_fn, optimizer, lazy=True):
    '''Run one training pass over `train_dl`.

    For every batch: forward pass, loss, backward pass, optimizer step, then gradients are cleared.
    When `lazy` is true, each element of `xb` is moved with its own `to_gpu(non_block=True)`
    method (defined outside this notebook) and `yb` is moved to `DEVICE` before the forward pass.

    Returns `(train_loss, yhat_train, y_train, model)` where
    - `train_loss` is the sum of the per-batch losses (not averaged),
    - `yhat_train` are the sigmoid-ed model outputs for all batches, on the CPU,
    - `y_train` are the matching targets, on the CPU.
    '''
    # Collect per-batch results and concatenate once at the end. Calling `torch.cat`
    # inside the loop would re-copy everything accumulated so far on every batch.
    yhat_chunks, y_chunks = [], []
    train_loss = 0.
    model.train()

    for xb, yb in train_dl:
        if lazy: xb, yb = [x.to_gpu(non_block=True) for x in xb], yb.to(DEVICE, non_blocking=True)
        y_hat  = model(xb)
        loss   = train_loss_fn(y_hat, yb)

        train_loss += loss.item()
        # The model outputs logits, so sigmoid is applied here. Detach before the
        # CPU copy so the copy is not recorded in the autograd graph.
        yhat_chunks.append(torch.sigmoid(y_hat.detach().cpu()))
        y_chunks.append(yb.detach().cpu())

        loss.backward()
        optimizer.step()
        model.zero_grad(set_to_none=True)

    # The leading empty tensor keeps the result well-defined for an empty loader
    # and gives the same dtype promotion as starting from `Tensor([])`.
    yhat_train = torch.cat([Tensor([]), *yhat_chunks])
    y_train    = torch.cat([Tensor([]), *y_chunks])

    return train_loss, yhat_train, y_train, model

### `evaluate()` -

In [None]:
#export
def evaluate(model, eval_dl, eval_loss_fn, lazy=True):
    '''Run the model over `eval_dl` without gradient tracking - used for validation and prediction.

    When `lazy` is true, each element of `xb` is moved with its own `to_gpu(non_block=True)`
    method (defined outside this notebook) and `yb` is moved to `DEVICE` before the forward pass.

    Returns `(eval_loss, yhat_eval, y_eval)` where
    - `eval_loss` is the sum of the per-batch losses (not averaged),
    - `yhat_eval` are the sigmoid-ed model outputs for all batches, on the CPU,
    - `y_eval` are the matching targets, on the CPU.
    '''
    # Collect per-batch results and concatenate once at the end. Calling `torch.cat`
    # inside the loop would re-copy everything accumulated so far on every batch.
    yhat_chunks, y_chunks = [], []
    eval_loss = 0.
    model.eval()

    with torch.no_grad():
        for xb, yb in eval_dl:
            if lazy: xb, yb = [x.to_gpu(non_block=True) for x in xb], yb.to(DEVICE, non_blocking=True)
            y_hat  = model(xb)

            eval_loss += (eval_loss_fn(y_hat, yb)).item()
            # The model outputs logits, so sigmoid is applied here.
            yhat_chunks.append(torch.sigmoid(y_hat.detach().cpu()))
            y_chunks.append(yb.detach().cpu())

    # The leading empty tensor keeps the result well-defined for an empty loader
    # and gives the same dtype promotion as starting from `Tensor([])`.
    yhat_eval = torch.cat([Tensor([]), *yhat_chunks])
    y_eval    = torch.cat([Tensor([]), *y_chunks])

    return eval_loss, yhat_eval, y_eval

### `fit()` -

In [None]:
#export
def _append_epochs(frame, rows):
    '''Return `frame` extended by `rows` (one list per epoch), keeping `frame`'s columns'''
    new_rows = pd.DataFrame(rows, columns=frame.columns)
    # Nothing to concatenate onto yet; this also avoids concatenating with an empty frame.
    if frame.empty: return new_rows
    return pd.concat([frame, new_rows], ignore_index=True)


def fit(epochs, history, model, train_loss_fn, valid_loss_fn, optimizer, accuracy_fn,
        train_dl, valid_dl, lazy=True, to_chkpt_path=None, from_chkpt_path=None, verbosity=0.75):
    '''Fit model for `epochs` epochs and return results in `history`.

    - `from_chkpt_path`: when given, model, optimizer and the last epoch index are restored
      from that checkpoint first and epoch numbering continues from there.
    - `to_chkpt_path`: when given, a checkpoint is written once, after the last epoch.
    - `accuracy_fn(y, yhat)` must return one score per label (the values are stored next to
      the loss and printed as "aurocs").
    - `verbosity`: roughly the fraction of epochs for which a progress row is printed.
    - `lazy`: forwarded unchanged to `train()` and `evaluate()`.

    Per-epoch mean losses and scores are appended to `history.train` / `history.valid`;
    the targets and predictions of the last epoch are stored in `history.y_*` / `history.yhat_*`.

    Raises `ValueError` when `epochs` is less than 1.
    '''
    # Without at least one epoch there are no results to store further down.
    if epochs < 1:
        raise ValueError(f'`epochs` must be at least 1, got {epochs}')

    if from_chkpt_path:
        last_epoch, model, optimizer = load_from_checkpoint(model, from_chkpt_path, optimizer, for_inference=False)
        start_epoch = last_epoch+1
    else:
        start_epoch = 0
    end_epoch = start_epoch+(epochs-1)
    # Evenly spaced epoch indices to report on; clamp so a negative verbosity prints nothing.
    n_printed = max(int(epochs*verbosity), 0)
    print_epochs = set(np.linspace(start_epoch, end_epoch, n_printed, endpoint=True, dtype=int).tolist())
    train_history, valid_history = [], []

    print('{:>5} {:>16} {:^20} {:>25} {:^20}'.format('epoch |', 'train loss |', 'train aurocs', 'valid loss |', 'valid aurocs'))
    print('{:-^100}'.format('-'))

    for epoch in range (start_epoch, start_epoch+epochs):

        train_loss, yhat_train, y_train, model = train(model, train_dl, train_loss_fn, optimizer, lazy)
        valid_loss, yhat_valid, y_valid = evaluate(model, valid_dl, valid_loss_fn, lazy)

        # train()/evaluate() return summed batch losses; turn them into per-batch means.
        train_loss,   valid_loss   = train_loss/len(train_dl), valid_loss/len(valid_dl)
        train_aurocs, valid_aurocs = accuracy_fn(y_train, yhat_train), accuracy_fn(y_valid, yhat_valid)
        train_history.append([train_loss, *train_aurocs])
        valid_history.append([valid_loss, *valid_aurocs])

        if epoch in print_epochs:
            # Only the first 4 scores are shown to keep the row width bounded.
            row  = f'{epoch:>5} |'
            row += f'{train_loss:>15.3f} | '
            row += f'{[f"{a:.3f}" for a in train_aurocs[:4]]}'
            row += f'{valid_loss:>19.3f} | '
            row += f'{[f"{a:.3f}" for a in valid_aurocs[:4]]}'
            # Strip the quotes (and the commas that follow them) left by the list repr.
            print(re.sub("',*", "", row))

    if to_chkpt_path: save_to_checkpoint(end_epoch, model, optimizer, to_chkpt_path)

    h = history
    h.y_train, h.yhat_train, h.y_valid, h.yhat_valid = y_train, yhat_train, y_valid, yhat_valid
    # `DataFrame.append` was removed in pandas 2.0, so the frames are extended via `pd.concat`.
    h.train = _append_epochs(h.train, train_history)
    h.valid = _append_epochs(h.valid, valid_history)

    return h

### `predict()` -

In [None]:
#export
def predict(history, model, test_loss_fn, accuracy_fn, test_dl, chkpt_path, lazy=True):
    '''Load the model from `chkpt_path`, evaluate it on `test_dl` and store the results in `history`.

    `history.test` is replaced by a single-row frame (mean loss followed by the scores returned
    by `accuracy_fn`); `history.y_test` / `history.yhat_test` receive the targets and predictions.
    Returns `history`.
    '''
    model = load_from_checkpoint(model, chkpt_path, for_inference=True)
    summed_loss, yhat_test, y_test = evaluate(model, test_dl, test_loss_fn, lazy)

    # evaluate() returns the summed batch losses; report the per-batch mean.
    mean_loss = summed_loss/len(test_dl)
    aurocs = accuracy_fn(y_test, yhat_test)
    print(f'test loss = {mean_loss}')
    print(f'test aurocs = {aurocs}')

    result_row = [mean_loss, *aurocs]
    history.test = pd.DataFrame([result_row], columns=history.test.columns)
    history.y_test = y_test
    history.yhat_test = yhat_test

    return history

## Plotting

### `plot_loss()` -

In [None]:
#export
def plot_loss(history_df, title='Loss', axis=None):
    '''Plot the `loss` column of `history_df` against the epoch index.

    `axis` is the matplotlib axes to draw on; when it is omitted a new 8x5 figure is created.
    `title` is used both as the axes title and as the legend label.
    '''
    # Identity check: `== None` would dispatch to the object's own `__eq__`.
    if axis is None:
        fig = plt.figure(figsize=(8,5))
        axis = fig.add_axes([0,0,1,1])

    axis.plot(range(len(history_df)), history_df['loss'], label=title)

    axis.set_title(title)
    axis.set_xlabel('epochs')
    axis.set_ylabel('loss')
    axis.legend(loc=0)

### `plot_losses()` -

In [None]:
#export
def plot_losses(train_history, valid_history):
    '''Plot the train and valid losses side by side (train on the left, valid on the right)'''
    fig, axes = plt.subplots(1,2, figsize=(15,5))

    plot_loss(train_history, title='Train Loss', axis=axes[0])
    plot_loss(valid_history, title='Valid Loss', axis=axes[1])

    # Adjust the layout only after titles and labels exist, so that they are taken into account.
    fig.tight_layout()

### `plot_aurocs()` -

In [None]:
#export
def plot_aurocs(history_df, title='AUROC Scores', axis=None):
    '''Plot one curve per label column of `history_df` against the epoch index.

    The first column of `history_df` is skipped (it holds the loss); every remaining column is
    drawn as its own line, with the value of the last epoch appended to the legend label.
    `axis` is the matplotlib axes to draw on; when it is omitted a new 8x5 figure is created.
    '''
    # Identity check: `== None` would dispatch to the object's own `__eq__`.
    if axis is None:
        fig = plt.figure(figsize=(8,5))
        axis = fig.add_axes([0,0,1,1])

    epochs = range(len(history_df))
    for lbl in history_df.columns[1:]:
        scores = history_df[lbl]
        # An empty history has no final value to report; fall back to the bare label.
        legend_label = f'{lbl} (final: {scores.iat[-1]:.3f})' if len(scores) else f'{lbl}'
        axis.plot(epochs, scores, label=legend_label)

    axis.set_title(title)
    axis.set_xlabel('epochs')
    axis.set_ylabel('auroc scores')
    axis.legend(loc=0)

### `plot_train_valid_aurocs()` -

In [None]:
#export
def plot_train_valid_aurocs(train_history, valid_history):
    '''Plot the train and valid AUROC scores side by side (train on the left, valid on the right)'''
    fig, axes = plt.subplots(1,2, figsize=(15,5))

    plot_aurocs(train_history, title='Train - AUROC Scores', axis=axes[0])
    plot_aurocs(valid_history, title='Valid - AUROC Scores', axis=axes[1])

    # Adjust the layout only after titles and labels exist, so that they are taken into account.
    fig.tight_layout()

## Summarize

### `plot_fit_results()` - 

In [None]:
#export
def plot_fit_results(history, labels):
    '''Draw every plot for a finished fit: train/valid ROC curves, then losses, then AUROC scores.

    Reads `y_train`, `yhat_train`, `y_valid`, `yhat_valid`, `train` and `valid` from `history`.
    '''
    train_roc = MultiLabelROC(history.y_train, history.yhat_train, labels)
    valid_roc = MultiLabelROC(history.y_valid, history.yhat_valid, labels)
    plot_train_valid_rocs(train_roc.ROCs, valid_roc.ROCs, labels, multilabel=True)

    plot_losses(history.train, history.valid)
    plot_train_valid_aurocs(history.train, history.valid)

### `summarize_prediction()` -

In [None]:
#export
def summarize_prediction(history, labels, plot=True):
    '''Summarize a prediction run and store the result in `history.prediction_summary`.

    Optionally plots the test ROC curves, then builds one row per label holding the AUROC score,
    the optimal threshold and the value returned by `auroc_ci` (stored as `auroc_95_ci`).
    The summary is printed and `history` is returned.
    '''
    rocs = MultiLabelROC(history.y_test, history.yhat_test, labels)
    if plot:
        plot_rocs(rocs.ROCs, labels, title='Test ROC curves', multilabel=True)

    print('\nPrediction Summary ...')

    # Column `idx` of y_test / yhat_test is taken to belong to the label at position `idx`.
    summary_rows = [
        [
            rocs.ROCs[label].auroc,
            rocs.ROCs[label].optimal_thresh(),
            auroc_ci(history.y_test[:,idx], history.yhat_test[:,idx]),
        ]
        for idx, label in enumerate(labels)
    ]

    history.prediction_summary = pd.DataFrame(
        summary_rows,
        index=labels,
        columns=['auroc_score', 'optimal_threshold', 'auroc_95_ci'],
    )
    print(history.prediction_summary)
    return history

## `count_parameters()` -

In [None]:
#export
def count_parameters(model, printout=False):
    '''Count the parameters of `model`.

    Returns `(total, trainable, non_trainable)`, where a parameter counts as trainable
    when its `requires_grad` flag is set. The counts are printed only when `printout` is true.
    '''
    total = trainable = 0
    # Single pass over the parameters instead of one pass per count.
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad: trainable += n
    non_trainable = total - trainable

    if printout:
        print(f'total: {total:,}, trainable: {trainable:,}, non_trainable: {non_trainable:,}')
    return total, trainable, non_trainable

## Export -

In [None]:
#hide
from nbdev.export import *
# Run nbdev's export step; the "Converted ..." lines that follow appear to be its output.
notebook2script()

Converted 00_setup.ipynb.
Converted 01_preprocessing_clean.ipynb.
Converted 02_preprocessing_vocab.ipynb.
Converted 03_preprocessing_transform.ipynb.
Converted 04_data.ipynb.
Converted 05_metrics.ipynb.
Converted 06_learn.ipynb.
Converted 07_models.ipynb.
Converted 08_experiment.ipynb.
Converted 99_quick_walkthru.ipynb.
Converted 99_running_exps.ipynb.
Converted index.ipynb.
