# Training script for NAMT-10 Hackathon
**Learning outcome:** Train a classification model (one vs. all)
<br> 
To do so, use the CheXpertDataLoader from the first day to train and evaluate your model.
<br>
<br>

Some challenges you should keep in mind:
1. What can you do to handle data imbalance? 
2. What can you do to learn from few samples and how can you prevent overfitting? 


In [4]:
import logging
import os
import random
import sys
import argparse
import json
import pickle
import time
import inspect
from pathlib import Path
import numpy as np
import torch
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.tensorboard import SummaryWriter

from tqdm.notebook import tqdm
import timm
import timm.optim


from sklearn.metrics import (
    accuracy_score)
                         
from key2med.data.loader import CheXpertDataLoader, ColorCheXpertDataLoader, StudiesDataLoader
from key2med.models.CBRTiny import CBRTiny


In [3]:
# Module-level logger for this notebook/script.
logger = logging.getLogger(name=__name__)

## Define some basic functions
### Functions for saving / loading pickle objects
During training, we want to save the current epoch, the training and validation losses, etc., to monitor the model's performance.

In [4]:
def save_obj_pkl(path, obj):
    """Serialize ``obj`` into the file at ``path`` using the highest pickle protocol."""
    with open(path, mode='wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def load_obj_pkl(path):
    """Return the object stored in the pickle file at ``path``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode='rb') as handle:
        restored = pickle.load(handle)
    return restored

### Functions for evaluating your model
You can either write your own evaluation metrics and store them in a metric_dict, or use predefined metrics from sklearn.
<br>
**Which metrics are suitable to evaluate your model performance?**

In [5]:
# Define evaluation metrics you want to use
def get_metric_dict(preds, y_true, y_pred):
    """Compute evaluation metrics for the one-vs-all predictions.

    Args:
        preds: per-class softmax probabilities as collected by ``eval_model``
            (currently unused; available for probability-based metrics).
        y_true: ground-truth class labels.
        y_pred: predicted class labels.

    Returns:
        dict mapping metric name to value; ``'Acc'`` is the accuracy as a float.
    """
    # The dict must be created here; the original referenced an undefined name.
    metric_dict = {}
    # float() keeps the value a plain Python number for json.dump.
    metric_dict['Acc'] = float(accuracy_score(y_true, y_pred))
    return metric_dict

# Define an evaluation function
def eval_model(args, model, dataloader, dataset='valid'):
    """Evaluate ``model`` on ``dataloader`` and write the results to ``args['output_dir']``.

    Args:
        args: settings dict; reads ``'device'``, ``'output_dir'`` (expected to end
            with a path separator) and ``'class_positive'``.
        model: classifier whose outputs are two logits per sample.
        dataloader: iterable of ``(inputs, targets)`` batches.
        dataset: split name used in the output file names.

    Writes ``results_{dataset}_{class}.json`` (the metric dict) and
    ``results_{dataset}_{class}.pkl`` (``[metric_dict, y_true, y_preds, preds]``).

    Returns:
        The metric dict produced by ``get_metric_dict``.
    """
    model.eval()

    preds, y_preds, y_true = [], [], []
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(args['device'])
            # presumably targets are (batch, 1); squeeze to one label per sample
            y_true.extend(targets.squeeze(dim=1).cpu().numpy())
            cur_preds = torch.nn.functional.softmax(model(inputs), dim=-1).cpu().numpy()
            preds.extend(cur_preds)
            # Column 1 is treated as the positive class; threshold its probability at 0.5.
            y_preds.extend((cur_preds[:, 1] > 0.5).astype(int))

    preds, y_preds, y_true = np.asarray(preds), np.asarray(y_preds), np.asarray(y_true)
    metric_dict = get_metric_dict(preds, y_true, y_preds)

    result_base = f'{args["output_dir"]}results_{dataset}_{args["class_positive"]}'
    with open(result_base + '.json', 'w', encoding='utf-8') as file:
        json.dump(metric_dict, file, indent=2)
    with open(result_base + '.pkl', 'wb') as file:
        pickle.dump([metric_dict, y_true, y_preds, preds], file)
    return metric_dict

## Define training settings 
For example, define the data path, batch size, number of epochs, etc.
<br>
Also specify here the class you are working with (Edema, Atelectasis, Cardiomegaly, Consolidation, Pleural Effusion)

## Set default settings

In [15]:
args = {'seed': 42, # Set seed number if you want to reproduce results, else set it to None
        'data_dir': '/home/admin_ml/NAMT/data/CheXpert-v1.0-small', # path to Chexpert data
        'class_positive': 'Edema', # Set class name you are working on for one vs. all classification
        'channel_in': 3, # Number of input channels (3 because of color for pre-trained imagenet)
        'model_to_load_dir': None, # model path, if you want to continue training
        'output_dir': None,
        'num_epochs': 10, # number of epochs for training the model
        'max_steps': -1, # Total number of training steps. Low number for debugging, negative number for no limit.
        'do_train': True, 
        'max_dataloader_size': None, # Set to low number for debugging. default = None (no limit) 
        'view': 'Frontal', # For DataLoader, do you want to load Frontal, Lateral or both views?
        'batch_size': 24,
        'num_workers': 4, # For DataLoader
        'lr': 1e-3, # initial learning rate 
        'wd': 1e-6, # weight decay 
        'do_eval': True, # set to True if validation and test data should be evaluated after training
        'eval_steps': 500, # Number of batches/steps to eval. Default 500.
        'no_cuda': False,
       }

# Base path = parent of the directory the notebook was started from.
# ``_dh`` only exists inside IPython/Jupyter, so fall back to the current
# working directory when this runs as a plain Python script.
_start_dir = globals().get('_dh', [os.getcwd()])[0]
args['basePath'] = os.path.dirname(os.path.realpath(_start_dir)) + os.sep

# Directory paths are later extended by plain string concatenation,
# so make sure they end with a separator.
if not args['output_dir']:  # None or empty string -> default location
    args['output_dir'] = f'{args["basePath"]}training_output{os.sep}'
elif not args['output_dir'].endswith(os.sep):
    args['output_dir'] += os.sep

if args['model_to_load_dir'] and not args['model_to_load_dir'].endswith(os.sep):
    args['model_to_load_dir'] += os.sep

# set device (cuda if available or cpu)
args['device'] = torch.device('cuda:0' if (torch.cuda.is_available() and not args['no_cuda']) else 'cpu')

args

{'seed': 42,
 'data_dir': '/home/admin_ml/NAMT/data/CheXpert-v1.0-small',
 'class_positive': 'Edema',
 'channel_in': 3,
 'model_to_load_dir': None,
 'output_dir': '/media/admin_ml/D0-P1/2022_Namt/NAMT-10-master/training_output/',
 'num_epochs': 10,
 'max_steps': -1,
 'do_train': True,
 'max_dataloader_size': None,
 'view': 'Frontal',
 'batch_size': 24,
 'num_workers': 4,
 'lr': 0.001,
 'wd': 1e-06,
 'do_eval': True,
 'eval_steps': 500,
 'no_cuda': False,
 'basePath': '/media/admin_ml/D0-P1/2022_Namt/NAMT-10-master/',
 'device': device(type='cuda', index=0)}

Set up logging

In [None]:
# Set up logging: print records to stdout and also write them to <output_dir>/train.log.
# basicConfig raises ValueError when ``filename`` and ``handlers`` are combined,
# so both destinations are given as handlers. FileHandler does not create
# missing directories, so the output directory is created first.
os.makedirs(args['output_dir'], exist_ok=True)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,  # root default is WARNING, which would hide logger.info(...)
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(args['output_dir'] + 'train.log', mode='w'),
    ],
    force=True,  # replace handlers left over from a previous run of this cell
)

Set the seed (if args['seed'] is not None) so you can reproduce your results later on.

In [11]:
# Seeding for reproducibility.
# https://pytorch.org/docs/stable/notes/randomness.html
# ``g`` and ``worker_init_fn`` only take effect when passed to a torch DataLoader
# (generator=g, worker_init_fn=worker_init_fn).
# NOTE(review): the ColorCheXpertDataLoader call below does not receive them;
# pass them through if its API allows.
g = None  # torch.Generator for DataLoader shuffling; seeded below when a seed is given
worker_init_fn = None  # per-worker seeding hook; defined below when a seed is given
if args['seed'] is not None:
    logger.info(f'Applying seed {args["seed"]}')
    torch.manual_seed(args['seed'])
    random.seed(args['seed'])
    np.random.seed(args['seed'])
    g = torch.Generator()
    g.manual_seed(args['seed'])

    def worker_init_fn(worker_id):
        """Seed numpy and ``random`` in a DataLoader worker from torch's per-worker seed."""
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

## Reading in the data
Use the ColorCheXpertDataLoader from key2med.data.loader.

In [None]:
# Train/valid/test loaders for one-vs-all classification of args['class_positive'].
# (CheXpertDataLoader is the 1-channel alternative to ColorCheXpertDataLoader.)
dataloader = ColorCheXpertDataLoader(
    # data source and split selection
    data_path=args['data_dir'],
    splits="train_valid_test",
    frontal_lateral_values=[args['view']],
    max_size=args['max_dataloader_size'],
    # labels: only the target class; its uncertain labels are mapped to one
    label_filter=[args['class_positive']],
    uncertain_to_one=[args['class_positive']],
    uncertain_to_zero=[],
    # image preprocessing / augmentation
    img_resize=224,
    channels=args['channel_in'],
    do_random_transform=True,
    # loading behaviour
    batch_size=args['batch_size'],
    n_workers=args['num_workers'],
    use_cache=False,
    in_memory=True,
    plot_stats=False,
)

## Define the model you want to use for training
You can use here your own implemented model or the CBRTiny model by Raghu et al. (https://arxiv.org/pdf/1902.07208.pdf), which is already implemented (from key2med.models.CBRTiny import CBRTiny)

In [12]:
# CBRTiny with two outputs (one vs. all), moved to the configured device
model = CBRTiny(
    num_classes=2,
    channel_in=args['channel_in'],
).to(args['device'])

NameError: name 'CBRTiny' is not defined

Another option is to use the timm library, where many well-known models are already implemented.
<br>
You can find the timm documentation here: https://fastai.github.io/timmdocs/
<br>
For example, you can list all available pretrained EfficientNet models like this:

In [None]:
# All timm model names matching 'eff*' for which pretrained weights exist
avail_pretrained_models = timm.list_models('eff*', pretrained=True)
# Show how many were found together with the first five names
(len(avail_pretrained_models), avail_pretrained_models[:5])

Here we use, for example, the efficientnet_b0 model pretrained on ImageNet (set num_classes to 2 for one vs. all classification):

In [None]:
# EfficientNet-B0 with ImageNet-pretrained weights and two outputs (one vs. all)
model = timm.create_model(
    'efficientnet_b0',
    pretrained=True,
    num_classes=2,
    in_chans=args['channel_in'],
).to(args['device'])

If you have already started the training, you can load the last checkpoint like this:

In [13]:
if args['model_to_load_dir'] is not None:
    # Prefer a best-validation checkpoint if one exists. Otherwise fall back to
    # 'last_model.pth', which the training cell writes at the end of training.
    checkpoint_path = os.path.join(args['model_to_load_dir'], 'best_model.pth')
    if not os.path.isfile(checkpoint_path):
        checkpoint_path = os.path.join(args['model_to_load_dir'], 'last_model.pth')
    logger.info(f'Loading checkpoint {checkpoint_path}')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=args['device'])
    model.load_state_dict(checkpoint['model_state_dict'])

## Define optimizer and learning rate scheduler
With the timm library you can also use many predefined optimizers.
<br>
List all available optimizers:

In [None]:
# Names of the optimizer classes exported by timm.optim (the Lookahead wrapper is excluded)
sorted(
    name for name, _ in inspect.getmembers(timm.optim, inspect.isclass)
    if name != 'Lookahead'
)

Here, we use for example the *AdamP* Optimizer

In [None]:
# AdamP optimizer via timm's factory.
# The arguments are passed positionally because their keyword names changed
# across timm versions: ``optimizer_name``/``learning_rate`` in early 0.4.x,
# ``opt``/``lr`` afterwards. Unrecognized keywords are forwarded to the
# optimizer class, which then raises TypeError.
optimizer = timm.optim.create_optimizer_v2(
    model,
    'AdamP',     # optimizer name
    args['lr'],  # learning rate
    args['wd'],  # weight decay
)

Now we define a learning rate scheduler using PyTorch (for more information, see https://pytorch.org/docs/stable/optim.html).
<br>
We can use, for example, the One Cycle learning rate scheduler:

In [None]:
# Total number of optimizer steps: every batch of every epoch, unless a
# non-negative args['max_steps'] (useful for debugging) caps it explicitly.
if args['max_steps'] < 0:
    num_steps = args['num_epochs'] * len(dataloader.train)
else:
    num_steps = args['max_steps']

# One-cycle schedule spanning exactly num_steps scheduler steps
scheduler = OneCycleLR(optimizer, max_lr=args['lr'], total_steps=num_steps)

## Use tensorboard for monitoring your model performance
To do this, you need to set up a *SummaryWriter* that stores the current epoch count, current learning rate, training and validation loss, and current time.
<br>
Also, we store everything in the dictionary *writer_dict*, so you can create your own plots at the end.

In [None]:
Path(args['output_dir']).mkdir(parents=True, exist_ok=True)

# TensorBoard event files go to <output_dir>/runs. output_dir already ends
# with a separator, so join the path instead of concatenating another os.sep.
writer = SummaryWriter(os.path.join(args['output_dir'], 'runs'))

# Plain-Python copy of the logged values (pickled after training) for custom plots.
writer_dict = {
    'epochs': [],      # epoch index at the start of each epoch
    'lr': [],          # learning rate after each optimizer step
    'loss_train': [],  # training loss of each batch
    'loss_valid': [],  # mean validation loss at each evaluation
    'walltime': [],    # time.time() after each optimizer step
}

## Define loss function
You can for example use the cross entropy loss for a classification task

In [None]:
# Cross-entropy loss over the two logits, averaged over the batch.
# NOTE: unweighted; passing ``weight=`` is one option for handling class imbalance.
loss_function = torch.nn.CrossEntropyLoss(reduction='mean')

## Write training loop

In [None]:
def _mean_validation_loss(model, loader):
    """Return the mean per-batch ``loss_function`` value of ``model`` over ``loader``.

    Runs without gradient tracking and restores the model's previous train/eval
    mode afterwards, so training continues with dropout/batch-norm in train mode.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(args['device'])
            targets = targets.squeeze(dim=1).long().to(args['device'])
            total_loss += loss_function(model(inputs), targets).item()
    model.train(was_training)
    return total_loss / len(loader)


if args['do_train']:
    model.train()
    steps = 0
    steps_since_last_eval = 0
    mean_loss = None  # latest validation loss; stays None if validation never ran
    best_valid_loss = float('inf')
    logger.info('Start Training')
    for epoch in tqdm(range(args['num_epochs'])):
        if steps >= num_steps:
            break  # step budget was used up in an earlier epoch

        writer_dict['epochs'].append(epoch)
        writer.add_scalar('utils/epochs', epoch, steps) # for tensorboard

        for batch in tqdm(dataloader.train, leave=False):
            steps += 1
            steps_since_last_eval += 1

            inputs, targets = batch
            inputs = inputs.to(args['device'])
            # presumably targets are (batch, 1); CrossEntropyLoss expects (batch,) class indices
            targets = targets.squeeze(dim=1).long().to(args['device'])
            outputs = model(inputs)
            loss = loss_function(outputs, targets)

            loss.backward()
            optimizer.step()
            scheduler.step()  # the one-cycle schedule is stepped once per batch
            optimizer.zero_grad()

            writer_dict['walltime'].append(time.time())
            lr = optimizer.param_groups[0]['lr']
            writer_dict['lr'].append(lr)
            writer.add_scalar('utils/lr', lr, steps)
            loss = loss.item()
            writer_dict['loss_train'].append(loss)
            writer.add_scalar('loss/train', loss, steps)

            if steps_since_last_eval >= args['eval_steps']:
                steps_since_last_eval = 0
                if dataloader.validate is not None:
                    mean_loss = _mean_validation_loss(model, dataloader.validate)
                    writer_dict['loss_valid'].append(mean_loss)
                    writer.add_scalar('loss/valid', mean_loss, steps) # for tensorboard
                    # Keep the checkpoint with the lowest validation loss seen so far.
                    if mean_loss < best_valid_loss:
                        best_valid_loss = mean_loss
                        torch.save({
                                    'step': steps,
                                    'model_state_dict': model.state_dict(),
                                    'optimizer_state_dict': optimizer.state_dict(),
                                    'valid_loss': mean_loss,
                                   }, args['output_dir']+'best_model.pth')

            if steps >= num_steps:
                break

    with open(args['output_dir']+'last_loss.txt', 'w') as file:
        print(f'loss {mean_loss}\n step {steps}', file=file)

    torch.save({
                'step': steps,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
               }, args['output_dir']+'last_model.pth')

    save_obj_pkl(args['output_dir']+'tensorboard_writer.pkl', writer_dict)
    writer.close()
    logger.info(f'End of training')

## Evaluate your model
If you want to evaluate your model on the validation and test sets, you have to set args['do_eval'] to True.
<br>
Then the evaluation function will be called after the training is completed. 

In [None]:
if args['do_eval']:
    # Evaluate each available split: validation first, then test.
    eval_splits = (('valid', dataloader.validate), ('test', dataloader.test))
    for split_name, split_loader in eval_splits:
        if split_loader is None:
            continue
        logger.info(f'Start evaluation {split_name}')
        eval_model(args, model, split_loader, dataset=split_name)
