In [1]:
import os, sys, json
import shutil
import datetime
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

from omegaconf import OmegaConf
from dataloader.cxr_dataloader import CovidXRDataset

from logger.logger import Logger
from trainer.trainer import Trainer
from utils.util import *

from model.modelloader import COVID_PVTv2, COVID_ViT, load_checkpoint

from trainer.trainer import Trainer

from torch.utils.tensorboard import SummaryWriter

# Resolve the trainer configuration relative to the notebook's working directory.
cwd = os.getcwd()
config_file = 'config/trainer_config.yml'
config_path = os.path.join(cwd, config_file)

# Only the 'trainer' section of the YAML file is used by the rest of the notebook.
train_config = OmegaConf.load(config_path)['trainer']

# `seeding` comes from utils.util; presumably it seeds the RNGs from the config — TODO confirm.
seeding(train_config)

ImportError: cannot import name 'save_model' from 'utils.util' (C:\Users\ASUS\COVID_19\utils\util.py)

# Utility functions

In [None]:
def get_dataset(config):
    """Build the training and validation data loaders described by ``config``.

    Args:
        config: Trainer configuration. Reads ``config.dataloader.train`` and
            ``config.dataloader.val`` (``batch_size``, ``num_workers``) and
            ``config.dataset.name``; the whole object is also forwarded to
            ``CovidXRDataset``.

    Returns:
        tuple: ``(training_generator, val_generator, test_generator, class_dict)``.
        ``test_generator`` is always ``None`` because the test split is not
        loaded here; the slot is kept so existing callers that unpack four
        values keep working. ``class_dict`` is taken from the training dataset.
    """
    train_params = {'batch_size': config.dataloader.train.batch_size,
                    'shuffle': True,
                    'num_workers': config.dataloader.train.num_workers,
                    'pin_memory': True}

    # NOTE(review): the validation loader is shuffled as in the original code;
    # shuffling is usually unnecessary for evaluation — confirm whether it is intended.
    val_params = {'batch_size': config.dataloader.val.batch_size,
                  'shuffle': True,
                  'num_workers': config.dataloader.val.num_workers,
                  'pin_memory': True}

    print('Name of the dataset:', config.dataset.name)
    print('Collected from the description of these github: https://github.com/lindawangg/COVID-Net',
          end='\n{}\n'.format('-' * 50))

    # Datasets for each split that is actually used during training.
    train_dataset = CovidXRDataset(config=config, mode='train')
    val_dataset = CovidXRDataset(config=config, mode='val')
    class_dict = train_dataset.class_dict

    training_generator = DataLoader(train_dataset, **train_params)
    val_generator = DataLoader(val_dataset, **val_params)

    # The test split is intentionally not loaded. Previously the name
    # `test_generator` was returned without ever being assigned, which raised
    # NameError on every call. To enable it, build
    # CovidXRDataset(config=config, mode='test') and wrap it in a DataLoader
    # with shuffle=False using config.dataloader.test.{batch_size,num_workers}.
    test_generator = None

    print('Load data complete!')

    return training_generator, val_generator, test_generator, class_dict

In [None]:
def get_model(name):
    """Instantiate the network selected by ``name``.

    Args:
        name: Model identifier; ``'ViT'`` or ``'PVT_V2'``.

    Returns:
        A freshly constructed ``COVID_ViT`` or ``COVID_PVTv2`` instance.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
            (Previously the function fell through and returned ``None``,
            which only failed later when the model was first used.)
    """
    if name == 'ViT':
        return COVID_ViT()
    if name == 'PVT_V2':
        return COVID_PVTv2()
    raise ValueError(f"Unknown model name {name!r}; expected 'ViT' or 'PVT_V2'")

In [None]:
def select_scheduler_optimizer(model, config):
    """Create the optimizer and the plateau LR scheduler for ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        config: Per-model configuration mapping. Reads
            ``config['optimizer']`` (``type``, ``lr``, ``weight_decay``) and
            ``config['scheduler']`` (``scheduler_factor``, ``scheduler_patience``,
            ``scheduler_min_lr``, ``scheduler_verbose``).

    Returns:
        tuple: ``(optimizer, scheduler)`` where ``scheduler`` is a
        ``ReduceLROnPlateau`` bound to ``optimizer``.

    Raises:
        ValueError: If ``config['optimizer']['type']`` is neither ``'AdamW'``
            nor ``'SGD'``. (Previously ``None`` was handed to the scheduler,
            which failed with an unrelated error.)
    """
    opt = config['optimizer']['type']
    lr = config['optimizer']['lr']
    dec = config['optimizer']['weight_decay']

    if opt == 'AdamW':
        print("Use optimizer AdamW with lr: ", lr)
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=dec)
    elif opt == 'SGD':
        print("Use optimizer SGD with lr: ", lr)
        # Honour the configured weight decay for SGD too; it used to be read
        # from the config but silently dropped for this optimizer.
        optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=dec)
    else:
        raise ValueError(f"Unsupported optimizer type {opt!r}; expected 'AdamW' or 'SGD'")

    scheduler_cfg = config['scheduler']
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
    # it is still accepted by the installed version.
    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=scheduler_cfg['scheduler_factor'],
                                  patience=scheduler_cfg['scheduler_patience'],
                                  min_lr=scheduler_cfg['scheduler_min_lr'],
                                  verbose=scheduler_cfg['scheduler_verbose'])
    return optimizer, scheduler

# Training process

In [None]:
def engine(config):
    """Interactively select a model, assemble the training components, and train.

    Prompts on stdin for the model key (``model_ViT`` or ``model_PVT_V2``) and,
    when ``config.load`` is truthy, for the checkpoint folder to resume from.

    Args:
        config: Trainer configuration. Reads ``config.load``, ``config.cuda``,
            ``config.cwd`` and the per-model section ``config[model_name]``
            (its ``name`` plus the optimizer/scheduler settings); the whole
            object is also passed to ``get_dataset`` and ``Trainer``.

    Returns:
        None. Training is run for its side effects (checkpoints, logs).

    Raises:
        AssertionError: If the typed model key is not one of the two
            supported values.
    """
    now = datetime.datetime.now()
    dt_string = now.strftime("%d_%m_%Y_%H.%M.%S")
    model_name = input("Choose model: {model_ViT}, {model_PVT_V2} ? : ")
    assert model_name in ['model_ViT', 'model_PVT_V2'], "You must decleare the model as in description!"
    print('-'*50)

    model_cfg = config[model_name]

    # Device: resolved up front because the model is moved onto it below.
    # (It used to be assigned only after its first use, raising UnboundLocalError.)
    device = torch.device("cuda:0" if (torch.cuda.is_available() and config.cuda) else "cpu")

    # Dataset
    train_generator, val_generator, _, class_dict = get_dataset(config)

    # Model / Model loader
    model = get_model(model_cfg.name)

    # Optimizer
    optimizer, scheduler = select_scheduler_optimizer(model, model_cfg)

    if config.load:
        # Resume: restore model, optimizer and scheduler from a chosen checkpoint folder.
        print('----- LOADING CHECKPOINTS -----')
        get_checkpoints(model_cfg.name)
        checkpoint_name = input("Choose one of these checkpoints above: ")
        cpkt_fol_name = os.path.join(config.cwd, f'checkpoints/model_{model_cfg.name}/{checkpoint_name}')

        checkpoint_dirmodel = f'{cpkt_fol_name}/model_best_checkpoint.pth'
        model, optimizer, scheduler = load_checkpoint(checkpoint_dirmodel, model, optimizer, scheduler)
    else:
        # Fresh run: checkpoints go to a new, timestamped folder.
        print('----- CREATING NEW MODEL -----')
        cpkt_fol_name = os.path.join(config.cwd, f'checkpoints/model_{model_cfg.name}/date_{dt_string}')

    # Wrap after any checkpoint load, preserving the original load-then-wrap order.
    model = torch.nn.DataParallel(model).to(device)

    # Logger: created before the first log call. The checkpoint-folder message
    # used to be emitted before both `log` and `cpkt_fol_name` existed.
    logname = str('LOG_' + model_cfg.name)
    log = Logger(path=cpkt_fol_name, name=logname).get_logger()

    log.info(f"Checkpoint Folder {cpkt_fol_name} ")
    log.info(f"date and time = {dt_string}")
    log.info(f'pyTorch VERSION:{torch.__version__}', )
    log.info(f'CUDA VERSION:{torch.version.cuda}')

    # Writer: not handed to the Trainer here; kept so the TensorBoard run
    # directory is still set up as before.
    writer = SummaryWriter('./runs/' + f'model_{model_cfg.name}/date_{dt_string}')

    log.info(f'CUDNN VERSION:{torch.backends.cudnn.version()}')
    log.info(f'Number CUDA Devices: {torch.cuda.device_count()}')
    log.info(f'device: {device}')

    # Trainer
    trainer = Trainer(config=config, model=model, optimizer=optimizer,
                      data_loader=train_generator, logger=log,
                      valid_data_loader=val_generator, class_dict=class_dict,
                      lr_scheduler=scheduler,
                      checkpoint_dir=cpkt_fol_name)
    trainer.train()

In [None]:
# Start the interactive training run with the trainer configuration loaded in the first cell.
engine(train_config)