# Training script for NAMT-10 Hackathon
**Learning outcome:** Train a classification model (one vs. all)
<br> 
To do this, use the CheXpertDataLoader from the first day for training and evaluating your model.
<br>
<br>

Some challenges to keep in mind:
1. What can you do to handle data imbalance? 
2. What can you do to learn from few samples and how can you prevent overfitting? 
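For challenge 1, one common option is to oversample the minority class so that batches are roughly balanced. A minimal sketch using PyTorch's *WeightedRandomSampler*, assuming you can collect the binary training labels as a list (*train_labels* and *make_balanced_sampler* below are hypothetical) and that your data loading code lets you pass a *sampler* to the underlying *torch.utils.data.DataLoader*:

```python
import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(labels):
    """Build a sampler that draws both classes with roughly equal probability.

    labels: sequence of integer class labels (0 = rest, 1 = positive class).
    Each sample gets weight 1 / (size of its class), so every class has the
    same total weight and the rare class is oversampled.
    """
    labels = torch.as_tensor(labels, dtype=torch.long)
    class_counts = torch.bincount(labels)                  # number of samples per class
    sample_weights = 1.0 / class_counts[labels].double()   # inverse class frequency per sample
    return WeightedRandomSampler(sample_weights,
                                 num_samples=len(labels),  # one epoch's worth of draws
                                 replacement=True)          # needed to repeat rare samples


# Hypothetical, heavily imbalanced label list: 90 negatives, 10 positives
train_labels = [0] * 90 + [1] * 10
sampler = make_balanced_sampler(train_labels)
# e.g. torch.utils.data.DataLoader(train_dataset, batch_size=32, sampler=sampler)
```

Because oversampling shows the few positive images many times per epoch, it pairs well with data augmentation to limit overfitting (challenge 2).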


In [4]:
import logging
import os
import random
import sys
import argparse
import json
import pickle
import time
import inspect
from pathlib import Path
import numpy as np
import torch
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.tensorboard import SummaryWriter

from tqdm.notebook import tqdm
import timm
import timm.optim


from sklearn.metrics import accuracy_score

# import DataLoader
from ... import ... 
# import model, e.g.
from key2med.models.CBRTiny import CBRTiny


In [3]:
# Set logger 
logger = logging.getLogger(__name__)

## Define some basic functions
### Functions for saving / loading pickle objects
During training, we want to save the current epoch, training and validation loss, etc., to monitor our model's performance.

In [4]:
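def save_obj_pkl(path, obj):
    # Serialize obj (e.g. writer_dict) to path with pickle
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

def load_obj_pkl(path):
    # Load an object previously saved with save_obj_pkl
    with open(path, 'rb') as f:
        return pickle.load(f)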

### Functions for evaluating your model
You can either write your own evaluation metrics and store them in a *metric_dict*, or use predefined metrics from sklearn.
<br>
**Which metrics are suitable to evaluate your model performance?**
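With imbalanced one-vs-all data, accuracy alone can be misleading: a model that always predicts the majority class still scores high. A minimal sketch of a metric dictionary built from sklearn functions (*binary_metrics* is a hypothetical helper, not part of this notebook), assuming 0/1 labels and that *y_score* holds the predicted probability of the positive class:

```python
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             f1_score, roc_auc_score)


def binary_metrics(y_true, y_pred, y_score):
    """Metrics for one-vs-all classification.

    y_true:  ground-truth labels (0/1)
    y_pred:  predicted labels (0/1)
    y_score: predicted probability of the positive class (needed for AUROC)
    """
    return {
        'Acc': accuracy_score(y_true, y_pred),
        'BalancedAcc': balanced_accuracy_score(y_true, y_pred),  # mean recall of both classes
        'F1': f1_score(y_true, y_pred, zero_division=0),          # F1 of the positive class
        'AUROC': roc_auc_score(y_true, y_score),                  # threshold-independent
    }


# A model that never predicts the positive class:
metrics = binary_metrics(y_true=[0, 0, 0, 1],
                         y_pred=[0, 0, 0, 0],
                         y_score=[0.1, 0.2, 0.3, 0.9])
# Acc is 0.75, but BalancedAcc is 0.5 and F1 is 0.0; AUROC is 1.0
```

Here AUROC shows the scores rank the positive case correctly, while F1 and balanced accuracy reveal that the default 0.5 threshold never flags it.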

In [5]:
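# Define evaluation function
def eval_model(args, model, dataloader, dataset='valid'):
    """Evaluate model on dataloader and return a dict of metrics."""
    model.eval()

    preds = []    # predicted probabilities of the positive class
    y_preds = []  # predicted class labels
    y_true = []   # ground-truth class labels
    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in dataloader:
            inputs, targets = batch
            ...  # e.g. convert targets to class indices (depends on your dataloader)
            outputs = model(inputs.to(args['device']))
            probs = torch.softmax(outputs, dim=1)  # assumes 2 output logits (num_classes=2)
            y_true += targets.cpu().tolist()
            preds += probs[:, 1].cpu().tolist()
            y_preds += probs.argmax(dim=1).cpu().tolist()

    metric_dict = {}
    metric_dict['Acc'] = accuracy_score(y_true, y_preds)
    ... # Add more metrics here
    return metric_dict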

## Define training settings 
For example, define the data path, batch size, number of epochs, etc.
<br>
Also specify the class you are working on (Edema, Atelectasis, Cardiomegaly, Consolidation, or Pleural Effusion).

## Set default settings
Here you can define the settings you will need later for reading the dataset and training your model, such as the number of epochs, the learning rate, the data directory, and the class you are working on.

In [2]:
args = {'data_dir': '/data/MEDICAL/datasets/CheXpert_small/CheXpert-v1.0-small', # path to CheXpert data
        'class_positive': 'Edema', # Set class name you are working on for one vs. all classification
        'num_epochs': 10, # number of epochs for training the model
        'lr': 1e-3, # initial learning rate 
        # ...
       }
args

{'data_dir': '/data/MEDICAL/datasets/CheXpert_small/CheXpert-v1.0-small',
 'class_positive': 'Edema',
 'num_epochs': 10,
 'lr': 0.001}
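Later cells read keys that this dictionary does not define yet (*output_dir*, *channel_in*, *device*, *wd* and *model_to_load_dir*), so running them as-is raises a *KeyError*. A minimal sketch that fills in defaults for missing keys; the helper *with_training_defaults* and every default value are assumptions to adapt to your setup:

```python
import os
import torch


def with_training_defaults(args):
    """Return a copy of args with defaults for keys used later in this notebook.

    Values already present in args take precedence. All defaults are
    illustrative assumptions, not recommended settings.
    """
    defaults = {
        # hypothetical directory; the trailing separator also keeps
        # args['output_dir'] + 'train.log' pointing inside the directory
        'output_dir': os.path.join('.', 'output', ''),
        'channel_in': 1,   # assumption: grayscale X-rays; set to 3 if your dataloader returns RGB
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'wd': 1e-5,        # weight decay for the optimizer (hypothetical value)
        'model_to_load_dir': None,  # set to a checkpoint directory to resume training
    }
    return {**defaults, **args}


example_args = with_training_defaults({'class_positive': 'Edema', 'lr': 1e-3})
```

In the notebook you would call *args = with_training_defaults(args)* right after the settings cell.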

Set up logging

In [None]:
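# Set up logging: print to stdout and write to <output_dir>/train.log
os.makedirs(args['output_dir'], exist_ok=True)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,  # the default level (WARNING) would hide logger.info() messages
    handlers=[logging.StreamHandler(sys.stdout),
              # basicConfig rejects filename= together with handlers=, so add a FileHandler instead
              logging.FileHandler(os.path.join(args['output_dir'], 'train.log'), mode='w')],
)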

## Reading in the data
Use your dataloader from Hackathon #1.

In [None]:
dataloader = CheXpertDataLoader(
        data_path=args['data_dir'], 
        splits="train_valid_test",
        ...
    )

## Define the model you want to use for training
You can use your own model here, or the CBRTiny model by Raghu et al. (https://arxiv.org/pdf/1902.07208.pdf), which is already implemented (from key2med.models.CBRTiny import CBRTiny).

In [None]:
model = CBRTiny(num_classes=2, channel_in=args['channel_in']).to(args['device'])

Another option is to use the timm library, in which many well-known models are already implemented.
<br>
You can find the timm documentation here: https://fastai.github.io/timmdocs/
<br>
For example, you can list all available pretrained models whose names start with "eff" (such as the EfficientNet family) like this:

In [None]:
avail_pretrained_models = timm.list_models('eff*',pretrained=True)
# List number of all found models and the first five models
len(avail_pretrained_models), avail_pretrained_models[:5]

Here we use, for example, the efficientnet_b0 model pretrained on ImageNet (set num_classes to 2 for one vs. all classification):

In [None]:
model = timm.create_model('efficientnet_b0', num_classes=2, in_chans=args['channel_in'], pretrained=True).to(args['device'])

If you have already started training, you can load a saved checkpoint like this:

In [13]:
if args['model_to_load_dir'] is not None:
    checkpoint = torch.load(os.path.join(args['model_to_load_dir'], 'best_model.pth'),
                            map_location=args['device'])  # load onto the configured device
    model.load_state_dict(checkpoint['model_state_dict'])
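The loading code above expects a checkpoint dictionary with a *'model_state_dict'* entry. A minimal sketch of saving and restoring checkpoints in that format, assuming you also want to store the optimizer state and the epoch so training can resume; *save_checkpoint*, *load_checkpoint* and the extra keys are hypothetical, not part of key2med or timm:

```python
import os
import tempfile

import torch


def save_checkpoint(model, optimizer, epoch, out_dir, filename='best_model.pth'):
    """Save model and optimizer state in the dict format the loading code expects."""
    os.makedirs(out_dir, exist_ok=True)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               os.path.join(out_dir, filename))


def load_checkpoint(model, optimizer, ckpt_dir, filename='best_model.pth', device='cpu'):
    """Restore model (and optionally optimizer) state; returns the stored epoch."""
    checkpoint = torch.load(os.path.join(ckpt_dir, filename), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']


# Usage with a tiny stand-in model, round-tripped through a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_model = torch.nn.Linear(4, 2)
    demo_opt = torch.optim.SGD(demo_model.parameters(), lr=0.1)
    save_checkpoint(demo_model, demo_opt, epoch=3, out_dir=tmp_dir)
    restored = torch.nn.Linear(4, 2)
    restored_epoch = load_checkpoint(restored, None, tmp_dir)
```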

## Define optimizer and learning rate scheduler
The timm library also provides many predefined optimizers.
<br>
List all available optimizers:

In [None]:
[cls_name for cls_name, cls_obj in inspect.getmembers(timm.optim)
 if inspect.isclass(cls_obj) and cls_name != 'Lookahead']

Here we use, for example, the *AdamP* optimizer:

In [None]:
optimizer = timm.optim.create_optimizer_v2(model,
                                           opt='AdamP',
                                           lr=args['lr'],
                                           weight_decay=args['wd'])

Next, define a learning rate scheduler. You can use one of the PyTorch schedulers (see https://pytorch.org/docs/stable/optim.html) or write your own function:

In [None]:
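# Option 1: your own schedule, i.e. a function returning the learning rate for a given step
def get_lr(step, base_lr, total_steps):
    ...

# Option 2: a PyTorch scheduler, e.g. OneCycleLR (imported above), stepped once per batch
scheduler = OneCycleLR(optimizer,
                       max_lr=args['lr'],
                       epochs=args['num_epochs'],
                       steps_per_epoch=len(dataloader.train))  # assumes dataloader.train supports len()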
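If you write your own schedule, a common choice is a linear warm-up followed by cosine decay. A minimal sketch in plain Python; the signature and the *warmup_steps* and *min_lr* parameters are assumptions, not a PyTorch or timm API:

```python
import math


def get_lr(step, base_lr, total_steps, warmup_steps=0, min_lr=0.0):
    """Learning rate for a given 0-based training step.

    Linear warm-up from base_lr / warmup_steps to base_lr over warmup_steps,
    then cosine decay from base_lr to min_lr until total_steps.
    """
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # stay at min_lr after total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

To drive a PyTorch optimizer with it, you can wrap it in *torch.optim.lr_scheduler.LambdaLR*, whose *lr_lambda* returns a factor that multiplies the optimizer's initial learning rate, e.g. *LambdaLR(optimizer, lr_lambda=lambda s: get_lr(s, 1.0, total_steps, warmup_steps))*.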

## Use tensorboard for monitoring your model performance
To do this, set up a *SummaryWriter* that logs the current epoch, learning rate, training and validation loss, and wall time.
<br>
We also store everything in the dictionary *writer_dict*, so you can create your own plots at the end.

In [None]:
outputPath = '...'
writer = SummaryWriter(os.path.join(outputPath, 'runs'))
writer_dict = {
    'epochs': [],
    'lr': [],
    'loss_train': [],
    'loss_valid': [],
    'walltime': [],
}
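The training loop writes each value twice, once to TensorBoard and once to *writer_dict*. A minimal sketch of a helper that keeps both in sync; *log_scalars* is hypothetical and only relies on *SummaryWriter.add_scalar(tag, scalar_value, global_step, walltime)*, using this notebook's tag names such as *'utils/lr'* and *'loss/valid'*:

```python
import time


def log_scalars(writer, writer_dict, step, **scalars):
    """Log scalars to TensorBoard and append them to writer_dict.

    Keyword names must be keys of writer_dict, e.g. lr=..., loss_train=....
    Tags follow this notebook's naming: 'loss_train' -> 'loss/train',
    any other key -> 'utils/<key>'.
    """
    now = time.time()
    for key, value in scalars.items():
        writer_dict[key].append(value)
        tag = 'loss/' + key[len('loss_'):] if key.startswith('loss_') else 'utils/' + key
        writer.add_scalar(tag, value, step, walltime=now)
    writer_dict['walltime'].append(now)


class _RecordingWriter:
    """Stand-in for SummaryWriter so the helper can be tried without TensorBoard."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, tag, value, step, walltime=None):
        self.calls.append((tag, value, step))


demo_writer = _RecordingWriter()
demo_dict = {'epochs': [], 'lr': [], 'loss_train': [], 'loss_valid': [], 'walltime': []}
log_scalars(demo_writer, demo_dict, 1, lr=1e-3, loss_train=0.69)
```

In the loop you would call, for example, *log_scalars(writer, writer_dict, steps, lr=lr, loss_train=loss.item())*.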

## Define loss function
For a classification task you can, for example, use the cross-entropy loss.

In [None]:
# Define cross entropy loss
loss_function = torch.nn.CrossEntropyLoss()
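Another way to handle the class imbalance is to weight the loss. A minimal sketch that computes inverse-frequency class weights and passes them to *torch.nn.CrossEntropyLoss* through its *weight* argument; *inverse_frequency_weights* is a hypothetical helper and *train_labels* a hypothetical list of 0/1 training labels collected from your dataloader:

```python
import torch


def inverse_frequency_weights(labels, num_classes=2):
    """Class weights w_c = N / (num_classes * n_c); rare classes get larger weights."""
    labels = torch.as_tensor(labels, dtype=torch.long)
    counts = torch.bincount(labels, minlength=num_classes).float()
    return labels.numel() / (num_classes * counts.clamp(min=1))  # clamp avoids division by zero


train_labels = [0] * 90 + [1] * 10  # hypothetical: 10 % positives
class_weights = inverse_frequency_weights(train_labels)  # about 0.56 for class 0, 5.0 for class 1
loss_function = torch.nn.CrossEntropyLoss(weight=class_weights)
```

The weight tensor has to live on the same device as the model outputs, so move it with *.to(args['device'])* before creating the loss when training on a GPU.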

## Write training loop

In [None]:
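model.train()
steps = 0
steps_since_last_eval = 0
eval_every = 500  # hypothetical value: run validation every N training steps
logger.info('Start Training')
for epoch in tqdm(range(args['num_epochs'])):

    writer_dict['epochs'].append(epoch)
    writer.add_scalar('utils/epochs', epoch, steps) # for TensorBoard

    for batch in tqdm(dataloader.train, leave=False):
        steps += 1
        steps_since_last_eval += 1

        inputs, targets = batch
        ...  # e.g. move inputs and targets to args['device']
        outputs = model(inputs)
        loss = loss_function(outputs, targets)

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # append to writer (for TensorBoard) and writer_dict for monitoring your model performance
        lr = optimizer.param_groups[0]['lr']
        writer_dict['lr'].append(lr)
        writer.add_scalar('utils/lr', lr, steps)
        writer_dict['loss_train'].append(loss.item())
        writer.add_scalar('loss/train', loss.item(), steps)
        ...

        # evaluate your model on the validation set every eval_every steps
        if dataloader.validate is not None and steps_since_last_eval >= eval_every:
            steps_since_last_eval = 0
            model.eval()
            mean_loss = 0

            with torch.no_grad():  # no gradients are needed for validation
                for val_batch in dataloader.validate:
                    ...  # compute the loss per batch and average it into mean_loss

            writer_dict['loss_valid'].append(mean_loss)
            writer.add_scalar('loss/valid', mean_loss, steps) # for TensorBoard
            model.train()  # switch back to training mode

# Save your model:
torch.save(...)

writer.close()
logger.info('End of training')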
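For challenge 2 (overfitting), a common complement to data augmentation and weight decay is early stopping: stop once the validation loss has not improved for a number of evaluations, and keep the best checkpoint. A minimal sketch in plain Python; *EarlyStopping*, *patience* and *min_delta* are hypothetical names, not part of the libraries used in this notebook:

```python
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    patience:  number of consecutive evaluations without improvement to tolerate
    min_delta: minimum decrease in loss that counts as an improvement
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.num_bad_evals = 0

    def step(self, val_loss):
        """Record a validation loss; return (is_best, should_stop)."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.num_bad_evals = 0
            return True, False
        self.num_bad_evals += 1
        return False, self.num_bad_evals >= self.patience


stopper = EarlyStopping(patience=2)
history = [stopper.step(loss) for loss in [0.70, 0.60, 0.65, 0.62]]
# history == [(True, False), (True, False), (False, False), (False, True)]
```

In the validation block you would call *is_best, should_stop = stopper.step(mean_loss)*, save a checkpoint when *is_best* is true, and break out of both training loops when *should_stop* is true.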

## Evaluate your model
To evaluate your model on the validation and test sets, use the *eval_model* function defined at the beginning of this notebook.

In [None]:
if dataloader.validate is not None:
    logger.info('Start evaluation valid')
    metrics_valid = eval_model(args, model, dataloader.validate, dataset='valid')
    logger.info(f'Validation metrics: {metrics_valid}')

if dataloader.test is not None:
    logger.info('Start evaluation test')
    metrics_test = eval_model(args, model, dataloader.test, dataset='test')
    logger.info(f'Test metrics: {metrics_test}')