# Training script for NAMT-10 Hackathon
**Learning outcome:** Train a classification model (one vs. all)
<br> 
To do so, use the CheXpertDataLoader from the first day to train and evaluate your model.
<br>
<br>

Some challenges you should keep in mind:
1. What can you do to handle data imbalance? <font color="green">*Oversampling minority class, weighted cross-entropy loss*</font> 
2. What can you do to learn from few samples and how can you prevent overfitting? <font color="green">*Data augmentation, pre-trained models, early stopping*</font> 


<font color="green">*In baseline script only accuracy_score is imported from sklearn.metrics*</font> 

In [34]:
import logging
import os
import random
import sys
import argparse
import json
import pickle
import time
import inspect
from pathlib import Path
import numpy as np
import torch
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.tensorboard import SummaryWriter

from tqdm.notebook import tqdm
import timm
import timm.optim


from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    precision_recall_curve,
    roc_curve,
    auc,
    classification_report)

                 
from key2med.data.loader import CheXpertDataLoader, ColorCheXpertDataLoader, StudiesDataLoader
from key2med.models.CBRTiny import CBRTiny
   

In [35]:
# Logger for this notebook/module, named after the module
logger = logging.getLogger(name=__name__)

## Define some basic functions
### Functions for saving / loading pickle objects
During training we want to save the current epoch, training and validation loss, etc. to monitor our model performance

In [36]:
def save_obj_pkl(path, obj):
    """Serialize *obj* into the file at *path* with pickle's highest protocol."""
    with open(path, mode='wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load_obj_pkl(path):
    """Return the object pickled in the file at *path*.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded

### Functions for evaluating your model
You can either write your own evaluation metrics and store them in a metric_dict, or use pre-defined metrics from sklearn.
<br>
**Which metrics are suitable to evaluate your model performance?**
<br>
<font color="green">*In baseline script, only accuracy is used*</font> 

In [37]:
# Define evaluation metrics you want to use
def get_metric_dict(preds, y_true, y_pred):
    """Collect binary-classification metrics into a dict.

    Args:
        preds: 2-D array of per-class scores; column 1 is used as the
            positive-class score for the ROC curve / AUC.
        y_true: array of ground-truth labels (assumed 0/1).
        y_pred: array of predicted labels (assumed 0/1).

    Returns:
        dict with keys Num_samples, Num, Acc, bAcc, Precision, Recall, F1,
        Sensitivity, Specificity and AUC (in that order).
    """
    # zero_division=0: report 0 instead of warning when a ratio is undefined
    common = {'y_true': y_true, 'y_pred': y_pred, 'zero_division': 0}

    metrics = {}
    metrics['Num_samples'] = len(y_true)
    metrics['Num'] = int(y_true.sum())  # sum of labels = number of positives for 0/1 labels
    metrics['Acc'] = accuracy_score(y_true, y_pred)
    metrics['bAcc'] = balanced_accuracy_score(y_true, y_pred)
    metrics['Precision'] = precision_score(**common)
    metrics['Recall'] = recall_score(**common)
    metrics['F1'] = f1_score(**common)
    # In binary classification, recall of the positive class is "sensitivity" and recall of the
    # negative class is "specificity"; the latter is obtained by inverting both label arrays.
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
    metrics['Sensitivity'] = recall_score(**common)
    metrics['Specificity'] = recall_score(y_true=~(y_true > 0), y_pred=~(y_pred > 0), zero_division=0)
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, preds[:, 1])
    metrics['AUC'] = auc(false_pos_rate, true_pos_rate)
    return metrics

# Define an evaluation function
def eval_model(args, model, dataloader, dataset='valid'):
    """Run *model* over *dataloader*, compute metrics and write them to disk.

    Predictions are softmax probabilities; a sample counts as positive when the
    probability of class 1 exceeds 0.5. Results are written to
    ``<output_dir>results_<dataset>_<class_positive>.json`` (metrics only) and
    ``.pkl`` (``[metric_dict, y_true, y_preds, preds]``).

    Args:
        args: settings dict; reads 'device', 'output_dir' and 'class_positive'.
        model: classifier returning 2-class logits (assumed shape (batch, 2)).
        dataloader: iterable of (inputs, targets) batches; targets are squeezed
            on dim 1 (assumed shape (batch, 1)).
        dataset: split name used in the output file names.

    Returns:
        The metric dict that was written to disk.
    """
    model.eval()

    preds, y_preds, y_true = [], [], []
    # Inference only: no_grad avoids building the autograd graph (saves memory and time).
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(args['device'])
            y_true.extend(targets.squeeze(dim=1).cpu().numpy())
            probs = torch.nn.functional.softmax(model(inputs), dim=-1).cpu().numpy()
            preds.extend(probs)
            y_preds.extend((probs[:, 1] > 0.5).astype(int))

    preds, y_preds, y_true = np.asarray(preds), np.asarray(y_preds), np.asarray(y_true)
    metric_dict = get_metric_dict(preds, y_true, y_preds)

    base_name = args['output_dir'] + f'results_{dataset}_{args["class_positive"]}'
    with open(base_name + '.json', 'w', encoding='utf-8') as file:
        json.dump(metric_dict, file, indent=2)
    with open(base_name + '.pkl', 'wb') as file:
        pickle.dump([metric_dict, y_true, y_preds, preds], file)
    return metric_dict

## Set default training settings
For example define data path, batch size, number of epochs, etc.
<br>
Also specify here the class you are working with (Edema, Atelectasis, Cardiomegaly, Consolidation, Pleural Effusion)
<br>
<font color="green">*In baseline script do_weight_loss_even, do_upsample, do_early_stopping and early_stopping_patience don't exist*</font> 

In [38]:
args = {'seed': 42, # Set seed number if you want to reproduce results, else set it to None
        'data_dir': '/home/admin_ml/NAMT/data/CheXpert-v1.0-small', # path to Chexpert data
        'class_positive': 'Edema', # Set class name you are working on for one vs. all classification
        'channel_in': 3, # Number of input channels (3 because of color for pre-trained imagenet)
        'freeze': True, # freeze model weights (if you use pre-trained model)
        'model_to_load_dir': None, # model path, if you want to continue training
        'output_dir': None,
        'num_epochs': 10, # number of epochs for training the model
        'max_steps': 100, # Total number of training steps. Low number for debugging, negative number for no limit.
        'do_train': True, 
        'max_dataloader_size': None, # Set to low number for debugging. default = None (no limit) 
        'view': 'Frontal', # For DataLoader, do you want to load Frontal, Lateral or both views?
        'batch_size': 24,
        'num_workers': 4, # For DataLoader
        'lr': 1e-3, # initial learning rate 
        'wd': 1e-6, # weight decay 
        'do_eval': True, # set to True if validation and test data should be evaluated after training
        'eval_steps': 500, # Number of batches/steps to eval. Default 500.
        'do_weight_loss_even': True, # Set to true if you want to use weighted loss function
        'do_upsample': True,
        'no_cuda': False,
        'do_early_stopping': False, # Set to true, if you want to use early stopping 
        'early_stopping_patience': 10 # Stop training if the validation loss has not decreased for this many consecutive evaluations
       }


# basePath: parent directory of the notebook's start directory. `_dh` is a global that only
# IPython/Jupyter provides, so fall back to the current working directory in a plain script.
args['basePath'] = os.path.dirname(os.path.realpath(globals().get('_dh', [os.getcwd()])[0])) + os.sep

# File names are appended to these directories by string concatenation, so make sure they
# end with a path separator. An empty/None output_dir falls back to the default location.
if not args['output_dir']:
    args['output_dir'] = f'{args["basePath"]}training_output{os.sep}'
elif not args['output_dir'].endswith(os.sep):
    args['output_dir'] += os.sep

if args['model_to_load_dir'] and not args['model_to_load_dir'].endswith(os.sep):
    args['model_to_load_dir'] += os.sep

# set device (cuda if available and not disabled, else cpu)
args['device'] = torch.device('cuda:0' if (torch.cuda.is_available() and not args['no_cuda']) else 'cpu')

args

{'seed': 42,
 'data_dir': '/home/admin_ml/NAMT/data/CheXpert-v1.0-small',
 'class_positive': 'Edema',
 'channel_in': 3,
 'freeze': True,
 'model_to_load_dir': None,
 'output_dir': '/home/admin_ml/NAMT/training_output/',
 'num_epochs': 10,
 'max_steps': 100,
 'do_train': True,
 'max_dataloader_size': None,
 'view': 'Frontal',
 'batch_size': 24,
 'num_workers': 4,
 'lr': 0.001,
 'wd': 1e-06,
 'do_eval': True,
 'eval_steps': 500,
 'do_weight_loss_even': True,
 'do_upsample': True,
 'no_cuda': False,
 'do_early_stopping': False,
 'early_stopping_patience': 10,
 'basePath': '/home/admin_ml/NAMT/',
 'device': device(type='cuda', index=0)}

Set up logging

In [39]:
# Set up logging to stdout and to <output_dir>train.log.
# logging.basicConfig raises ValueError when `handlers` is combined with `filename`/`filemode`,
# so the file handler is created explicitly. The output directory may not exist yet, so
# create it first; delay=True only opens the log file once a record is actually written.
# NOTE: basicConfig does nothing if the root logger already has handlers configured.
os.makedirs(args['output_dir'], exist_ok=True)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,  # the notebook logs its progress with logger.info
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(args['output_dir'] + 'train.log', mode='w', delay=True),
    ],
)

Set seed if args['seed'] is not None, to reproduce your results later on

In [40]:
# Seed Python, NumPy and PyTorch RNGs so results can be reproduced.
# https://pytorch.org/docs/stable/notes/randomness.html
worker_init_fn = None  # NOTE(review): never set to a function and not passed to the data loader below
g = None  # torch.Generator seeded with args['seed']; stays None when no seed is given
if args['seed'] is not None:
    logger.info(f'Applying seed {args["seed"]}')
    torch.manual_seed(args['seed'])
    random.seed(args['seed'])
    np.random.seed(args['seed'])
    g = torch.Generator()
    g.manual_seed(args['seed'])

2022-02-14 13:47:23 Qilab-003 __main__[545493] INFO Applying seed 42


## Reading in the data
Use the ColorCheXpertDataLoader from key2med.data.loader
<br>
<font color="green">*In baseline script use_upsampling and upsample_labels don't exist*</font> 

In [41]:
# Train/valid/test loaders for one-vs-all classification of args['class_positive'].
# (Use CheXpertDataLoader instead of ColorCheXpertDataLoader for 1-channel input.)
dataloader = ColorCheXpertDataLoader(
    data_path=args['data_dir'],
    splits="train_valid_test",
    # which samples / labels to use
    frontal_lateral_values=[args['view']],
    label_filter=[args['class_positive']],
    # presumably controls how uncertain labels of the class are mapped - confirm in key2med
    uncertain_to_one=[args['class_positive']],
    uncertain_to_zero=[],
    # upsampling of the target class to counter class imbalance
    use_upsampling=args['do_upsample'],
    upsample_labels=[args['class_positive']],
    # image preprocessing
    img_resize=224,
    channels=args['channel_in'],
    do_random_transform=True,
    # loading behaviour
    batch_size=args['batch_size'],
    n_workers=args['num_workers'],
    max_size=args['max_dataloader_size'],
    use_cache=True,
    in_memory=False,
    plot_stats=False,
)

2022-02-14 13:47:24 Qilab-003 key2med.data.datasets[545493] INFO Found labels in /home/admin_ml/NAMT/data/CheXpert-v1.0-small/train.csv: ['Edema']
Reading label csv file /home/admin_ml/NAMT/data/CheXpert-v1.0-small/train.csv: 223414it [00:05, 42463.49it/s]
2022-02-14 13:47:31 Qilab-003 key2med.data.datasets[545493] INFO Reading data from cache file /home/admin_ml/NAMT/data/CheXpert-v1.0-small/cache/53d9231c52e3f5e3e18101b1f7ab219745cb9d358f16ce0398d4c8d5082761c1
2022-02-14 13:47:31 Qilab-003 key2med.data.datasets[545493] INFO Found data of size (3, 224, 224) in file /home/admin_ml/NAMT/data/CheXpert-v1.0-small/cache/53d9231c52e3f5e3e18101b1f7ab219745cb9d358f16ce0398d4c8d5082761c1
2022-02-14 13:47:31 Qilab-003 key2med.data.datasets[545493] INFO Found labels in /home/admin_ml/NAMT/data/CheXpert-v1.0-small/train.csv: ['Edema']
Reading label csv file /home/admin_ml/NAMT/data/CheXpert-v1.0-small/train.csv: 223414it [00:04, 48511.23it/s]
2022-02-14 13:47:36 Qilab-003 key2med.data.datasets[54

216,683 samples for training
19,040 samples for validation
202 samples for testing


## Define the model you want to use for training
You can use here your own implemented model or the CBRTiny model by Raghu et al. (https://arxiv.org/pdf/1902.07208.pdf), which is already implemented (from key2med.models.CBRTiny import CBRTiny)

In [42]:
# CBRTiny with a 2-class output for one-vs-all classification
model = CBRTiny(num_classes=2, channel_in=args['channel_in'])
model = model.to(args['device'])

Another option is to use the timm library, where many well-known models are already implemented.
<br>
You can find the timm documentation here: https://fastai.github.io/timmdocs/
<br>
For example, you can list all available pretrained EfficientNet models like this:

In [43]:
avail_pretrained_models = timm.list_models(filter='eff*', pretrained=True)
# Display how many models matched, together with the first five names
(len(avail_pretrained_models), avail_pretrained_models[:5])

(15,
 ['efficientnet_b0',
  'efficientnet_b1',
  'efficientnet_b1_pruned',
  'efficientnet_b2',
  'efficientnet_b2_pruned'])

Here we use, for example, the efficientnet_b0 model pretrained on ImageNet (set num_classes to 2 for one-vs-all classification):

In [44]:
# EfficientNet-B0 with pretrained weights and a 2-class head for one-vs-all classification
model = timm.create_model(
    'efficientnet_b0',
    pretrained=True,
    num_classes=2,
    in_chans=args['channel_in'],
)
model = model.to(args['device'])

2022-02-14 13:47:36 Qilab-003 timm.models.helpers[545493] INFO Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth)


If you have already started training, you can resume from the best checkpoint saved so far (`best_model.pth` in `args['model_to_load_dir']`) like this:

In [45]:
# Optionally resume from the best checkpoint saved by a previous training run.
if args['model_to_load_dir'] is not None:
    # map_location lets a checkpoint saved on GPU be loaded when running on CPU.
    checkpoint = torch.load(os.path.join(args['model_to_load_dir'], 'best_model.pth'),
                            map_location=args['device'])
    model.load_state_dict(checkpoint['model_state_dict'])

<font color="green">**What can you do to prevent overfitting?**
<br>
If you use a pre-trained model, you can for example only train the classification layer and freeze all other model weights</font>

In [46]:
# Optionally train only the classification head; outputs go to a 'frozen' or 'unfrozen'
# sub-directory so both variants can be kept side by side.
if not args['freeze']:
    args['output_dir'] += 'unfrozen' + os.sep
else:
    logger.info(f'Freezing model')
    args['output_dir'] += 'frozen' + os.sep
    # NOTE(review): the head is recognised by 'classifier' in the parameter name; for models
    # whose head is named differently (e.g. CBRTiny - confirm) every parameter gets frozen.
    for name, param in model.named_parameters():
        if 'classifier' in name:
            continue
        param.requires_grad_(False)

2022-02-14 13:47:36 Qilab-003 __main__[545493] INFO Freezing model


## Define optimizer and learning rate scheduler
The timm library also provides many pre-defined optimizers.
<br>
List all available optimizers:

In [47]:
# Names of the optimizer classes exposed by timm.optim (the Lookahead wrapper is excluded)
[name for name, member in inspect.getmembers(timm.optim) if inspect.isclass(member) and name != 'Lookahead']

['AdaBelief',
 'Adafactor',
 'Adahessian',
 'AdamP',
 'AdamW',
 'Nadam',
 'NovoGrad',
 'NvNovoGrad',
 'RAdam',
 'RMSpropTF',
 'SGDP']

Here, we use for example the *AdamP* Optimizer

In [48]:
# AdamP optimizer from timm, configured with the learning rate and weight decay from args
optimizer = timm.optim.create_optimizer_v2(
    model,
    optimizer_name='AdamP',
    weight_decay=args['wd'],
    learning_rate=args['lr'],
)

Now we define a learning rate scheduler using PyTorch (for more information see https://pytorch.org/docs/stable/optim.html)
<br>
For example, we can use the One Cycle learning rate scheduler:

In [49]:
# Total number of optimizer steps: the full run (num_epochs * batches per epoch), or
# args['max_steps'] when it is non-negative (for debugging). The value is capped at the
# full-run length so the one-cycle schedule actually completes its annealing phase;
# the training loop also stops after num_steps steps.
max_possible_steps = args['num_epochs'] * len(dataloader.train)
if args['max_steps'] < 0:
    num_steps = max_possible_steps
else:
    num_steps = min(args['max_steps'], max_possible_steps)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                max_lr=args['lr'],
                                                total_steps=num_steps)

## Use tensorboard for monitoring your model performance
To do this, you need to set up a *SummaryWriter* that stores the current epoch count, current learning rate, training and validation loss, and current time.
<br>
Also, we store everything in the dictionary *writer_dict*, so you can create your own plots at the end.

In [50]:
# Create the output directory and a TensorBoard writer logging into <output_dir>/runs
# (os.path.join avoids the doubled separator of plain concatenation).
Path(args['output_dir']).mkdir(parents=True, exist_ok=True)

writer = SummaryWriter(os.path.join(args['output_dir'], 'runs'))
# Plain-Python copy of the logged values, pickled after training for custom plots.
writer_dict = {
    'epochs': [],      # epoch index, appended at the start of each epoch
    'lr': [],          # learning rate after each optimizer step
    'loss_train': [],  # training loss of each step
    'loss_valid': [],  # mean validation loss of each evaluation
    'walltime': [],    # time.time() after each optimizer step
}

## Define loss function
You can for example use the cross entropy loss for a classification task
<br>
<br>
<font color="green">**What can you do here for handling class imbalance?**
<br>
For handling class imbalance, you can for example use a weighted cross entropy loss.
<br>
You can find more details on the implementation here: https://discuss.pytorch.org/t/weights-in-weighted-loss-nn-crossentropyloss/69514/2</font> 


In [51]:
# Class weights for the weighted cross-entropy loss (index 0 = negative, 1 = positive).
# Each class gets weight 1 - (its own share of the training labels), i.e. the share of the
# other class, so the rarer class receives the larger weight (assumes 0/1 labels).
# https://discuss.pytorch.org/t/weights-in-weighted-loss-nn-crossentropyloss/69514/2
args['loss_weights'] = [1, 1]
if args['do_weight_loss_even']:
    train_labels = dataloader.train.dataset.all_labels()  # fetched once instead of twice
    num_total = len(train_labels)
    num_pos = sum(train_labels)
    num_neg = num_total - num_pos
    args['loss_weights'] = [1 - (x / num_total) for x in [num_neg, num_pos]]

# Move the weights to the training device (works on CPU / with no_cuda, unlike .cuda()).
args['loss_weights'] = torch.tensor(args['loss_weights']).float().to(args['device'])
logger.info(f'LOSS WEIGHTS: {args["loss_weights"]}')

# Define cross entropy loss
loss_function = torch.nn.CrossEntropyLoss(weight=args['loss_weights'])

2022-02-14 13:47:43 Qilab-003 __main__[545493] INFO LOSS WEIGHTS: tensor([0.4616, 0.5384], device='cuda:0')


## Write training loop

<font color="green">**What can you do here to prevent overfitting?**
<br>
For example, you can use early stopping or only save the best model found during validation and use it for inference.</font> 

In [52]:
def _validation_loss(model, loader, loss_function, device):
    """Return the mean of the per-batch losses of *model* over *loader*.

    Runs in eval mode without gradient tracking; the caller is responsible for
    switching the model back to train mode.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.squeeze(dim=1).long().to(device)
            total += loss_function(model(inputs), targets).item()
    return total / len(loader)


def _save_checkpoint(path, step, model, optimizer):
    """Save the step count and the model/optimizer state dicts to *path*."""
    torch.save({
                'step': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
               }, path)


best_loss = np.inf  # lowest validation loss so far (to keep the best model)
eval_steps_since_last_better_model = 0  # evaluations since best_loss last improved (early stopping)
mean_loss = None  # latest validation loss; stays None if no evaluation ran (e.g. max_steps < eval_steps)
if args['do_train']:

    model.train()
    steps = 0
    steps_since_last_eval = 0
    stop_training = False  # ends the epoch loop too, not just the batch loop
    logger.info('Start Training')
    for epoch in tqdm(range(args['num_epochs'])):

        writer_dict['epochs'].append(epoch)
        writer.add_scalar('utils/epochs', epoch, steps)  # for tensorboard

        for batch in tqdm(dataloader.train, leave=False):
            steps += 1
            steps_since_last_eval += 1

            inputs, targets = batch
            inputs = inputs.to(args['device'])
            targets = targets.squeeze(dim=1).long().to(args['device'])
            outputs = model(inputs)
            loss = loss_function(outputs, targets)

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            writer_dict['walltime'].append(time.time())
            lr = optimizer.param_groups[0]['lr']
            writer_dict['lr'].append(lr)
            writer.add_scalar('utils/lr', lr, steps)
            loss = loss.item()  # plain float for logging / pickling
            writer_dict['loss_train'].append(loss)
            writer.add_scalar('loss/train', loss, steps)

            if steps_since_last_eval >= args['eval_steps']:
                steps_since_last_eval = 0
                if dataloader.validate is not None:
                    mean_loss = _validation_loss(model, dataloader.validate, loss_function, args['device'])
                    writer_dict['loss_valid'].append(mean_loss)
                    writer.add_scalar('loss/valid', mean_loss, steps)  # for tensorboard
                    model.train()

                    if mean_loss < best_loss:
                        # New best model: remember it and reset the early-stopping counter
                        best_loss = mean_loss
                        eval_steps_since_last_better_model = 0

                        with open(args['output_dir'] + 'best_loss.txt', 'w') as file:
                            print(f'loss {best_loss}\n step {steps}', file=file)

                        _save_checkpoint(args['output_dir'] + 'best_model.pth', steps, model, optimizer)
                    else:
                        # Validation loss did not improve: count evaluations for early stopping
                        eval_steps_since_last_better_model += 1

                if (args['do_early_stopping']
                        and eval_steps_since_last_better_model >= args['early_stopping_patience']):
                    logger.info(f'Early stopping at step {steps}')
                    stop_training = True

            # Never step past num_steps: OneCycleLR was created with total_steps=num_steps.
            if steps >= num_steps:
                stop_training = True
            if stop_training:
                break

        if stop_training:
            break

    with open(args['output_dir'] + 'last_loss.txt', 'w') as file:
        print(f'loss {mean_loss}\n step {steps}', file=file)

    _save_checkpoint(args['output_dir'] + 'last_model.pth', steps, model, optimizer)

    save_obj_pkl(args['output_dir'] + 'tensorboard_writer.pkl', writer_dict)
    writer.close()
    logger.info(f'End of training')

2022-02-14 13:47:43 Qilab-003 __main__[545493] INFO Start Training


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

  0%|          | 0/9029 [00:00<?, ?it/s]

2022-02-14 13:47:58 Qilab-003 __main__[545493] INFO End of training


## Evaluate your model
To evaluate your model on the validation and test sets, set args['do_eval'] to True.
<br>
Then the evaluation function will be called after the training is completed. 

In [53]:
# Evaluate the model as it is at the end of training on every available split;
# eval_model writes the results into args['output_dir'].
if args['do_eval']:
    for split_name in ('valid', 'test'):
        split_loader = dataloader.validate if split_name == 'valid' else dataloader.test
        if split_loader is not None:
            logger.info(f'Start evaluation {split_name}')
            eval_model(args, model, split_loader, dataset=split_name)


2022-02-14 13:47:58 Qilab-003 __main__[545493] INFO Start evaluation valid
2022-02-14 13:48:16 Qilab-003 __main__[545493] INFO Start evaluation test
