# Import libraries

In [30]:
from datetime import datetime
from glob import glob
import copy
import os, random, math

from tqdm import tqdm
from easydict import EasyDict
import pandas as pd
import numpy as np
import cv2
import scipy.io as sio

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from torch.optim.swa_utils import AveragedModel, SWALR

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import matplotlib.pyplot as plt
%matplotlib inline

# Configs

In [2]:
def print_log(inputs: str):
    """Print *inputs* to stdout, prefixed with ``LOG >>> ``."""
    print('LOG >>> {}'.format(inputs))


def seed_everything(seed: int):
    """Seed every RNG used for training and put cuDNN in deterministic mode.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices), and
    exports PYTHONHASHSEED. That variable only affects subprocesses started after
    this call; the current interpreter's hash seed is fixed at startup.

    Args:
        seed: the seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and choose algorithms at run time,
    # which can be non-deterministic and defeats deterministic=True above.
    # This trades some speed for reproducibility.
    torch.backends.cudnn.benchmark = False

    print_log('Complete seed setting!!!')
    
    
def get_device(GPU_NUM: str) -> torch.device:
    """Return ``cuda:<GPU_NUM>`` when CUDA is available, otherwise the CPU.

    ``torch.cuda.is_available()`` returns a bool, so the previous
    ``is_available() > 1`` branch could never be taken. Only the reachable
    logic is kept here, and the result is unchanged.

    Args:
        GPU_NUM: index of the CUDA device to use, e.g. ``'0'``.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        output = torch.device(f'cuda:{GPU_NUM}')
    else:
        output = torch.device('cpu')

    print_log(f'{output} is checked')
    return output


def get_log_name(**kwargs):
    """Build a run name from keyword arguments.

    The result is ``<time>_<key>_<value>_..._<etc>``:
    - ``time`` comes first.
    - Every other keyword except ``etc`` follows as ``_<key>_<value>``, in the
      order given.
    - ``etc`` is appended last, unless it is missing or None.

    Values are converted with ``str``, so non-string ``time``/``etc`` work too.

    Returns:
        The run name string (also printed via ``print_log``).

    Raises:
        KeyError: if ``time`` is not given.
    """
    parts = [str(kwargs['time'])]
    parts.extend(f'{key}_{value}' for key, value in kwargs.items()
                 if key not in ('time', 'etc'))

    etc = kwargs.get('etc')
    if etc is not None:
        parts.append(str(etc))

    output = '_'.join(parts)
    print_log(f'Log name: \n\t{output}')
    return output

In [7]:
args = EasyDict({
    #### Path ####
    'BASE_PATH': '../Datasets/Programmers',
    'CKPT_PATH': './checkpoints',
    'LOG_PATH': './logs',
    
    
    #### Setting ####
    'SEED': 42,
    'GPU_NUM': '0',
    'current_time': datetime.now().strftime('%Y%m%d-%H%M%S'),
    'num_workers': 6,
    'k_folds': 5,
    
    
    #### Training step ####
    'EPOCHS': 5,
    'early_stop': 0,
    'warmup': 0,
    'batch_size': 32,
    'lr': 1e-3,
    'loss': 'MAE',
    'optimizer': 'adam',
    'lr_scheduler': 'CosineAnnealingWarmRestarts',
    
    
    #### ETC ####
    'log_etc': None,
    'is_save': False,
    'use_SAM': False,
})


#### Set Device ####
# Device selection is explicit: get_device() returns cuda:{GPU_NUM} (or cpu).
# CUDA_VISIBLE_DEVICES is deliberately not exported here, for two reasons:
# - it only takes effect if set before CUDA initialises;
# - it renumbers the visible GPUs so the chosen card becomes cuda:0, which
#   would conflict with cuda:{GPU_NUM} for GPU_NUM != '0'.
args['device'] = get_device(args.GPU_NUM)


#### Set SEED ####
# seed_everything() also sets the cuDNN deterministic/benchmark flags, so they
# are not configured separately here. (`cudnn.fastest` is not a PyTorch option.)
seed_everything(args.SEED)


#### Log setting ####
log_list = {
    'time': args.current_time,
    'batch': args.batch_size,
    'lr': args.lr,
    'loss': args.loss,
    'optimizer': args.optimizer,
    'lr_scheduler': args.lr_scheduler,
    'etc': args.log_etc,
}

args['LOG_NAME'] = get_log_name(**log_list)

LOG >>> cuda:0 is checked
LOG >>> Complete seed setting!!!
LOG >>> Log name: 
	20210708-062959_batch_32_lr_0.001_loss_MAE_optimizer_adam_lr_scheduler_CosineAnnealingWarmRestarts


# Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    '''
        dataset = Dataset(args, is_train=True)

    Paired-image dataset skeleton. Each item is a dict
    ``{'image': <input>, 'label': <target>}``. When a transform pipeline is set,
    both entries go through it. The label is registered as an extra *image*
    target, so it gets the same Normalize / ToTensorV2 steps as the input.

    Args:
        args: config object; ``args.BASE_PATH`` is read here.
        is_train: selects the training or the evaluation transform pipeline.
    '''

    def __init__(self,
                 args,
                 is_train):

        self.args = args
        self.path = args.BASE_PATH
        self.is_train = is_train

        self.init_data()
        self.init_transform()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = {
            'image': self.images[idx],
            'label': self.labels[idx],
        }

        if self.transforms is not None:
            sample = self.transforms(**sample)

        return sample

    def init_data(self):
        # TODO: fill these two parallel lists from self.path.
        # Items are passed straight to albumentations, so they are presumably
        # HWC image arrays (e.g. loaded with cv2); confirm when implementing.
        self.images = []
        self.labels = []

    def init_transform(self):
        # Train-only augmentations (flips, crops, ...) belong in this list.
        # The train and eval pipelines are currently identical.
        train_only = []
        steps = (train_only if self.is_train else []) + [
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0),
            ToTensorV2(p=1.0),
        ]
        # Registering 'label' as an image target makes albumentations apply the
        # same steps to it. Without this, the label would stay an untransformed
        # array while the input becomes a normalised CHW tensor.
        self.transforms = A.Compose(steps, additional_targets={'label': 'image'})

In [None]:
train_dataset = Dataset(args, is_train=True)
test_dataset = Dataset(args, is_train=False)

# Held-out evaluation loader over the *test* split: one sample per batch,
# fixed order, so predictions line up with test_dataset indices.
test_dataloader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             pin_memory=True)

# Model

In [None]:
class Net(nn.Module):
    """Placeholder model: a single 3x3 convolution from 3 to 3 channels.

    ``padding='same'`` (stride 1) keeps the spatial size, so an
    ``(N, 3, H, W)`` input produces an ``(N, 3, H, W)`` output.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding='same')

    def forward(self, x):
        return self.conv(x)

# Training

In [None]:
class AverageMeter(object):
    """Running average of a per-batch metric.

    The two modes differ in how ``value`` is accumulated:
    - mode='loss': ``value`` is treated as a per-sample mean for the batch. It
      is weighted by ``batch_size`` before being summed, so ``avg`` is the mean
      over all samples seen.
    - mode='acc': ``value`` is summed unweighted and divided by the total sample
      count, so it should be a per-batch *count* (e.g. correct predictions).

    Any other mode leaves the meter untouched on ``update``.
    """

    def __init__(self, mode):
        self.mode = mode
        self.reset()

    def reset(self):
        self.value, self.sum, self.count, self.avg = 0, 0, 0, 0

    def update(self, value, batch_size):
        if self.mode not in ('loss', 'acc'):
            return
        self.value = value
        self.sum += value * batch_size if self.mode == 'loss' else value
        self.count += batch_size
        self.avg = self.sum / self.count
            
            
class EarlyStopping(object):
    """Track the best validation score and signal when training should stop.

    Args:
        mode: 'min' if lower scores are better (e.g. a loss), 'max' if higher
            scores are better.
        ckpt_path: directory for the checkpoint; created if missing.
        filename: checkpoint file name inside ``ckpt_path``.
        is_save: if True, ``model.state_dict()`` is written to the checkpoint
            file every time the score improves.
        early_stop_threshold: stop once the number of consecutive
            non-improving updates exceeds this value, e.g. 10.
    """

    def __init__(self, mode, ckpt_path, filename, is_save=False, early_stop_threshold=10):
        self.mode = mode
        os.makedirs(ckpt_path, exist_ok=True)
        self.ckpt_path = os.path.join(ckpt_path, filename)
        self.is_save = is_save
        self.early_stop_threshold = early_stop_threshold
        self.best_model = None
        self.reset()

    def reset(self):
        """Forget the best score and the non-improvement counter."""
        self.early_stop_cnt = 0
        self.best_score = math.inf if self.mode == 'min' else -math.inf

    def update(self, val, model):
        """Record a new score for ``model``.

        Returns:
            ``(should_stop, best_model)``. ``best_model`` is a snapshot copy of
            the model taken at the best score so far (None before the first
            improvement).

        Raises:
            ValueError: if ``mode`` is neither 'min' nor 'max'.
        """
        if self.mode == 'min':
            improved = val < self.best_score
        elif self.mode == 'max':
            improved = val > self.best_score
        else:
            raise ValueError('Wrong Input! plz input min or max')

        if improved:
            self.early_stop_cnt = 0
            self.best_score = val
            if self.is_save:
                torch.save(model.state_dict(), self.ckpt_path)
            # Keeping a reference to `model` would keep following later
            # training steps, so the "best" model would always be the latest
            # one. Snapshot it instead. This costs one extra copy of the
            # weights on the model's device.
            self.best_model = copy.deepcopy(model)
        else:
            self.early_stop_cnt += 1

        if self.early_stop_cnt > self.early_stop_threshold:
            print_log('Early stopping')
            return True, self.best_model
        return False, self.best_model

In [None]:
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    Each update has two phases:
    - ``first_step`` moves every parameter to ``w + e(w)``, where
      ``e(w) = rho * grad / ||grad||`` and ``||grad||`` is the global L2 norm
      over all parameters.
    - ``second_step`` restores ``w`` and lets the base optimizer apply the
      gradients currently stored. When driven via ``step``, those are the
      gradients computed at the perturbed point.

    Args:
        params: parameters (or param groups) to optimise.
        base_optimizer: optimizer *class* (e.g. ``optim.SGD``), instantiated
            here over this optimizer's param groups.
        rho: perturbation radius; must be non-negative.
        **kwargs: forwarded to ``base_optimizer`` (e.g. ``lr``) and also stored
            in the param-group defaults.
    """
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the param-group list with the base optimizer, so changes made
        # through this wrapper (e.g. an LR scheduler) reach the base optimizer.
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the weights to ``w + e(w)`` using the current gradients.

        Requires gradients from a backward pass at ``w``. Parameters without a
        gradient are left alone. Each applied ``e(w)`` is stored in
        ``self.state[p]["e_w"]`` so ``second_step`` can undo it.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo ``first_step``'s perturbation, then run the base optimizer step.

        NOTE(review): a parameter that has a gradient now but had none during
        ``first_step`` has no stored ``e_w`` and would raise KeyError here.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update.

        Gradients at the current weights must already be populated before this
        call, because ``first_step`` reads ``p.grad``. ``closure`` must
        recompute the loss and call ``backward()``; it runs with gradients
        enabled at the perturbed weights.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the global L2 norm of all gradients.

        This is the L2 norm of the per-parameter L2 norms, gathered on the
        device of the first parameter.

        NOTE(review): ``torch.stack`` raises if no parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

In [None]:
def get_criterion(name):
    """Return a fresh loss module for *name* ('MAE', 'MSE' or 'PSNR').

    Raises:
        Exception: if *name* is not one of the supported losses.
    """
    choices = (
        ('MAE', torch.nn.L1Loss),
        ('MSE', torch.nn.MSELoss),
        # NOTE(review): PSNR is not defined anywhere visible in this notebook,
        # so selecting it raises NameError until a PSNR loss class is added.
        ('PSNR', lambda: PSNR()),
    )
    for key, build in choices:
        if name == key:
            return build()
    raise Exception('Wrong Input! plz input correct name')


def get_optimizer(name, model, args):
    """Create the optimizer *name* over ``model.parameters()``.

    Args:
        name: one of 'sgd', 'adam', 'adamw', 'rmsprop'.
        model: module whose parameters are optimised.
        args: config providing ``lr``, plus two optional settings:
            - ``momentum`` (sgd/rmsprop only). Defaults to 0, which is also
              PyTorch's own default for both.
            - ``use_SAM``. Defaults to True, i.e. wrap the optimizer in SAM.

    Returns:
        A ``SAM`` wrapper around the base optimizer when ``use_SAM`` is truthy,
        otherwise the plain base optimizer. ``SAM.step()`` requires a closure,
        so the plain optimizer is what a normal ``optimizer.step()`` needs.

    Raises:
        Exception: for an unknown *name*.
    """
    momentum = getattr(args, 'momentum', 0)
    if name == 'sgd':
        base_optimizer, kwargs = optim.SGD, dict(lr=args.lr, momentum=momentum)
    elif name == 'adam':
        base_optimizer, kwargs = optim.Adam, dict(lr=args.lr)
    elif name == 'adamw':
        base_optimizer, kwargs = optim.AdamW, dict(lr=args.lr)
    elif name == 'rmsprop':
        base_optimizer, kwargs = optim.RMSprop, dict(lr=args.lr, momentum=momentum)
    else:
        raise Exception('Wrong Input! plz input correct name')

    if getattr(args, 'use_SAM', True):
        return SAM(model.parameters(), base_optimizer, **kwargs)
    return base_optimizer(model.parameters(), **kwargs)


def get_scheduler(name, optimizer, args):
    """Create the learning-rate scheduler *name* for *optimizer*.

    The schedulers are configured in epochs: ``args.EPOCHS`` is used as the
    step count and OneCycleLR gets ``steps_per_epoch=1``. They are therefore
    meant to be stepped once per epoch.

    Raises:
        Exception: for an unknown *name*.
    """
    # Integer-divided periods collapse to 0 for short runs: the default
    # EPOCHS=5 gives EPOCHS // 10 == 0. A zero period makes CosineAnnealingLR
    # raise ZeroDivisionError once stepped, and CyclicLR raise it when
    # constructed. Clamp both to at least one epoch.
    t_max = max(1, args.EPOCHS // 10)
    step_size_up = max(1, args.EPOCHS // 5)

    if name == 'StepLR':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    elif name == 'CosineAnnealingLR':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0)
    elif name == 'CyclicLR':
        try:
            scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.lr / 10, max_lr=args.lr, step_size_up=step_size_up, mode="triangular2")
        except ValueError:
            # CyclicLR raises ValueError when momentum cycling is requested but
            # the optimizer has no momentum setting; retry without cycling it.
            scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.lr / 10, max_lr=args.lr, step_size_up=step_size_up, cycle_momentum=False, mode="triangular2")
    elif name == 'OneCycleLR':
        try:
            scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=1, epochs=args.EPOCHS)
        except ValueError:
            scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, steps_per_epoch=1, epochs=args.EPOCHS, cycle_momentum=False)
    elif name == 'CosineAnnealingWarmRestarts':
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-10, last_epoch=-1)
    elif name == 'CosineAnnealingWarmUpRestarts':
        # NOTE(review): CosineAnnealingWarmUpRestarts is not defined anywhere
        # visible in this notebook; this branch raises NameError until it is.
        scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=10, T_mult=2, eta_max=args.lr, T_up=5, gamma=1.0)
    else:
        raise Exception('Wrong Input! plz input correct name')
    return scheduler

In [None]:
def _kfold_indices(n_samples, n_splits, seed):
    """Yield ``(train_idx, val_idx)`` NumPy index arrays for shuffled K-fold CV.

    How the folds are built:
    - ``range(n_samples)`` is shuffled with ``np.random.RandomState(seed)``.
    - It is cut into ``n_splits`` contiguous folds whose sizes differ by at
      most one; the first ``n_samples % n_splits`` folds get the extra sample.
    - Each fold is the validation set once, and the remaining indices, in
      ascending order, form the training set.

    This is intended to reproduce the folds of
    ``sklearn.model_selection.KFold(n_splits, shuffle=True, random_state=seed)``
    without depending on scikit-learn, which this notebook does not import.

    Raises:
        ValueError: if ``n_splits`` is not between 2 and ``n_samples``.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError(f'Cannot split {n_samples} samples into {n_splits} folds')
    shuffled = np.random.RandomState(seed).permutation(n_samples)
    for val_idx in np.array_split(shuffled, n_splits):
        in_train = np.ones(n_samples, dtype=bool)
        in_train[val_idx] = False
        yield np.flatnonzero(in_train), val_idx


for fold, (train_idx, val_idx) in enumerate(_kfold_indices(len(train_dataset), args.k_folds, args.SEED)):

    # Both loaders read train_dataset; the samplers restrict each loader to its
    # fold's indices (in random order).
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  pin_memory=True,
                                  sampler=train_subsampler)
    val_dataloader = DataLoader(train_dataset,
                                batch_size=args.batch_size,
                                pin_memory=True,
                                sampler=val_subsampler)
    dataloaders = {'train': train_dataloader, 'val': val_dataloader}

    model = Net().to(args.device)

    criterion = get_criterion(args.loss)
    optimizer = get_optimizer(args.optimizer, model, args)
    scheduler = get_scheduler(args.lr_scheduler, optimizer, args)
    # SAM.step() only works with a closure, so the update rule follows the
    # optimizer's actual type rather than the config flag alone.
    use_sam = isinstance(optimizer, SAM)

    ### Logging setting ###
    print(f'LOG >>> \n\tBatch size: {args.batch_size}\n\tLR: {args.lr}\n\tOptimizer: {args.optimizer}\n\tScheduler: {args.lr_scheduler}\n')

    # One TensorBoard run per fold, matching the per-fold checkpoint dirs.
    writer = SummaryWriter(f'{args.LOG_PATH}/{args.LOG_NAME}_fold{fold+1}')

    ### Early stopping setting ###
    early_stopping = EarlyStopping(mode='min',
                                   ckpt_path=f'{args.CKPT_PATH}/{args.LOG_NAME}_fold{fold+1}/',
                                   filename=f'best_model.pth',
                                   is_save=args.is_save,
                                   early_stop_threshold=args.early_stop)
    break_point = False

    for epoch in range(args.EPOCHS):

        ######################################
        ###          Add Metrics           ###
        train_loss = AverageMeter(mode='loss')
        val_loss = AverageMeter(mode='loss')
        ######################################

        for phase in ['train', 'val']:
            training = phase == 'train'
            model.train(training)
            meter = train_loss if training else val_loss

            with tqdm(dataloaders[phase], total=len(dataloaders[phase]), unit='batch') as bar:

                for sample in bar:

                    ### Logging ###
                    bar.set_description(f'{phase} Epoch {epoch}')

                    images = sample['image'].to(args.device, dtype=torch.float)
                    labels = sample['label'].to(args.device, dtype=torch.float)

                    batch_size = images.shape[0]
                    optimizer.zero_grad(set_to_none=True)

                    with torch.set_grad_enabled(training):

                        outputs = model(images)
                        loss = criterion(outputs, labels)

                        if training:
                            # Gradients at the current weights (SAM's first step needs them too).
                            loss.backward()

                            if use_sam:
                                ####################################
                                # SAM: a technique for coping with noisy labels.
                                # Sharpness-Aware Minimization for Efficiently Improving Generalization.
                                # The closure recomputes the loss at the perturbed weights.
                                ####################################
                                def closure():
                                    closure_loss = criterion(model(images), labels)
                                    closure_loss.backward()
                                    return closure_loss

                                optimizer.step(closure)
                            else:
                                optimizer.step()

                    meter.update(loss.detach().item(), batch_size)

                    ### Logging ###
                    bar.set_postfix(**{f'{phase}_loss': meter.avg})

            if phase == 'val' and epoch >= args.warmup:
                break_point, best_model = early_stopping.update(val_loss.avg, model)

        ### Tensorboard ###
        writer.add_scalar('Loss/train', train_loss.avg, epoch)
        writer.add_scalar('Loss/val', val_loss.avg, epoch)

        if scheduler is not None:
            scheduler.step()

        if break_point:
            break

    writer.close()

# Evaluation

In [None]:
# One checkpoint per fold, in fold order. The defaults are the files that
# EarlyStopping writes during training above; they only exist when
# args.is_save is True. Replace with explicit paths to evaluate another run.
ckpt_path = [f'{args.CKPT_PATH}/{args.LOG_NAME}_fold{fold + 1}/best_model.pth'
             for fold in range(args.k_folds)]

In [None]:
total_results = []

for fold in range(args.k_folds):

    # The training code saves model.state_dict(), not a pickled module. So
    # rebuild the architecture and load the weights into it, once per fold
    # rather than once per batch.
    model = Net().to(args.device)
    model.load_state_dict(torch.load(ckpt_path[fold], map_location=args.device))
    model.eval()

    fold_results = []
    with torch.no_grad():
        for sample in test_dataloader:
            # Batches are dicts produced by Dataset.__getitem__.
            images = sample['image'].to(args.device, dtype=torch.float)
            output = model(images)
            fold_results.append(output.cpu().numpy())

    total_results.append(fold_results)

total_results = np.array(total_results)