In [2]:
import sys, os
from glob import glob

sys.path.append('..')

from tqdm import tqdm
from argparse import ArgumentParser
from easydict import EasyDict
from addict import Dict

from toolz import *
from toolz.curried import *
from itertools import islice, product

import numpy as np
import torch
import torchvision
import torch.nn.functional as F

from data.dataLoader import MakeDataLoader

from model.FFTNet import FFTNet
from model.FFTNet import FFTNet_DSNT
from model.UNet import UNet
from model.DeepLab import DeepLab 
from model.TRANS import TRANS 

from utils import GET_OPTIMIZER
from torchmetrics.functional import dice_score 

from visualisation.Visualization2 import visualisation
from scipy.spatial.distance import directed_hausdorff
import skimage

import matplotlib
import matplotlib.pyplot as plt


In [6]:
def parse(task, model, modelName, encoderName, fold):
    """Build the evaluation config and the (untrained) network for one instance.

    Parameters
    ----------
    task, model, modelName, encoderName, fold : str
        Identify the experiment and are copied onto the returned config.
        `model` selects the network family: "FFTNet", "FFTDSNTNet", "UNet",
        any name containing "DeepLab", or "TRANS".

    Returns
    -------
    (config, net)
        `config` is the argparse namespace of defaults plus the fields above
        (and learnSize/codeSize/cov for FFT models); `net` is the freshly
        constructed network (weights are not loaded here).

    Raises
    ------
    ValueError
        If `model` matches none of the supported families.
    """
    parser = ArgumentParser()

    # paths
    parser.add_argument("--dataPath", type=str, default="../data/datasets")
    parser.add_argument("--logPath", type=str, default="../log")
    parser.add_argument("--ckptPath", type=str, default="../ckpt")
    parser.add_argument("--visPath", type=str, default="../vis")

    # training related
    parser.add_argument("--augType", type=str, default="aug0", help = "aug0 | aug1")

    # hyper-parameters
    parser.add_argument("--inputSize", type=int, default=256)
    parser.add_argument("--batchN", type=int, default=8)
    parser.add_argument("--epochN", type=int, default=500)
    parser.add_argument("--optimizer", type=str, default="Adam")
    parser.add_argument("--lr", type=float, default= 3e-4)
    parser.add_argument("--weight_decay", type=float, default= 0)

    # hardware
    parser.add_argument("--cpuN", type=int, default=4)
    parser.add_argument("--gpuN", type=int, default=2, help = "0|1|2|3")

    # parse_known_args returns (namespace, leftovers) and does not error on
    # unrecognised argv entries (presumably the kernel's own arguments when
    # run inside Jupyter); only the namespace is kept.
    config = parser.parse_known_args()[0]

    config.task        = task
    config.model       = model
    config.modelName   = modelName
    config.encoderName = encoderName
    config.fold        = fold

    # Expose only GPU `gpuN` (numbered in PCI bus order) to this process.
    # These variables are read when CUDA initialises, so calling parse again
    # after CUDA is in use cannot switch to a different GPU.
    os.environ["CUDA_DEVICE_ORDER"]    = "PCI_BUS_ID"
    # NOTE(review): CUDA documents CUDA_LAUNCH_BLOCKING as 0/1; a truthy value
    # makes kernel launches synchronous (debug-only, slow). "3" is kept as-is
    # — confirm whether blocking launches are really intended here.
    os.environ['CUDA_LAUNCH_BLOCKING'] = "3"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpuN)

    # Exactly one model family applies.
    if config.model == "FFTNet":
        config.learnSize = 21
        config.codeSize  = 71
        net = FFTNet.NETS(config.encoderName, config.inputSize, config.learnSize)
    elif config.model == "FFTDSNTNet":
        config.learnSize = 21
        config.codeSize  = 71
        config.cov       = 0.1
        net = FFTNet_DSNT.NETS(config.encoderName, config.inputSize, config.learnSize, cov = config.cov)
    elif config.model == "UNet":
        net = UNet.NETS(modelName = config.modelName)
    elif "DeepLab" in config.model:
        # The DeepLab backbone is fixed to resnet50; `encoderName` is not used here.
        net = DeepLab.NETS(modelName = config.modelName, encoderName = "resnet50", outputStride = 8)
    elif config.model == "TRANS":
        net = TRANS.NETS()
    else:
        # Without this, an unknown model surfaced as an UnboundLocalError on `net`.
        raise ValueError(f"parse: unknown model {config.model!r}; expected FFTNet, "
                         "FFTDSNTNet, UNet, a DeepLab variant, or TRANS")

    return config, net


def evaluate(net, dataloader, vis, config) :
    """Mean Dice and Hausdorff distance of `net` over every batch of `dataloader`.

    Parameters
    ----------
    net : torch.nn.Module
        Model to score; switched to eval mode here. It is run through
        ``validStep``, so it must return a dict with a ``'logit'`` entry and
        provide ``getLoss``.
    dataloader : iterable of (x, y, _) batches
        Presumably built with ``batchN = 1`` (as in this notebook): ``squeeze()``
        below drops the batch axis — TODO confirm before using larger batches.
    vis : visualisation
        Only used for FFT models, to turn outputs and targets into filled polygons.
    config : namespace
        ``config.model`` selects the FFT vs. argmax decoding path.

    Returns
    -------
    (dice, haus)
        Per-batch averages, each rounded to 3 decimals.

    Raises
    ------
    ValueError
        If `dataloader` yields no batches (the average would be undefined).
    """
    dice = 0
    haus = 0
    n_batches = 0
    net.eval()
    for (x, y, _) in tqdm(dataloader):

        _, logit = validStep(x.type(torch.float32).cuda(),
                             y.type(torch.float32).cuda(),
                             net)

        logit  = logit.detach().cpu().squeeze()
        target = y.detach().cpu().squeeze()

        if "FFT" in config.model :
            # Decode prediction (pad_FFTs -> get_contour) and target (get_contour)
            # and rasterise both with draw_polygon, so they are compared as masks.
            pred = vis.draw_polygon(vis.get_contour(vis.pad_FFTs(logit))).type(torch.long)
            mask = vis.draw_polygon(vis.get_contour(target)).type(torch.long)

        else :
            # Per-pixel argmax along dim 0 (the class axis once the batch axis
            # has been squeezed away — assumes batch size 1).
            pred = logit.argmax(0).type(torch.long)
            mask = target.type(torch.long)

        metrics = getMetric(pred, mask)

        dice += metrics['dice']
        haus += metrics['haus']
        n_batches += 1

    # Average over the batches actually seen rather than len(dataloader): this
    # also works for loaders without __len__ and gives a clear error when empty.
    if n_batches == 0:
        raise ValueError("evaluate: dataloader yielded no batches")

    dice = round(dice / n_batches, 3)
    haus = round(haus / n_batches, 3)

    return dice, haus
        
        
@torch.no_grad()
def validStep(x, y, net):
    """Run `net` on `x` with autograd disabled and score the result against `y`.

    Returns ``(loss, logit)``: ``loss`` is ``net.getLoss(outputs, y)`` and
    ``logit`` is ``outputs['logit']``, where ``outputs = net(x)``.
    """
    outputs_by_name = net(x)
    return net.getLoss(outputs_by_name, y), outputs_by_name['logit']
    
def getMetric(pred, mask, max_distance=None):
    """Dice coefficient and symmetric Hausdorff distance between two label maps.

    Parameters
    ----------
    pred, mask : torch.Tensor
        Integer label maps of the same shape, treated as binary (0/1) masks:
        the overlap term sums `pred` over pixels where ``mask == 1``, and the
        Hausdorff distance uses the coordinates of every non-zero pixel.
    max_distance : float, optional
        Hausdorff value reported when exactly one of the two masks is empty.
        Defaults to the image diagonal ``sqrt(sum(s**2 for s in mask.shape))``,
        which is exactly ``256*sqrt(2)`` for 256x256 inputs.

    Returns
    -------
    dict
        ``{'dice': 0-d numpy float array, 'haus': number}``. Two empty masks
        count as a perfect match (dice 1, Hausdorff 0).
    """
    pred_cpu = pred.detach().cpu()
    mask_cpu = mask.detach().cpu()

    # Dice = 2 * |P ∩ M| / (|P| + |M|). When both masks are empty this is 0/0,
    # which would yield NaN and poison any average it is added to.
    pred_flat = pred_cpu.reshape(-1)
    mask_flat = mask_cpu.reshape(-1)
    total = torch.sum(pred_flat) + torch.sum(mask_flat)
    if total == 0:
        dice = np.array(1.0, dtype=np.float32)
    else:
        dice = np.array(torch.sum(pred_flat[mask_flat == 1]) * 2.0 / total)

    # Index coordinates of every foreground pixel, shape (N, ndim).
    mask_pts = np.argwhere(mask_cpu.numpy())
    pred_pts = np.argwhere(pred_cpu.numpy())

    if len(mask_pts) == 0 and len(pred_pts) == 0:
        hausdorff = 0
    elif len(mask_pts) == 0 or len(pred_pts) == 0:
        # The distance to an empty set is undefined: report a worst-case penalty.
        if max_distance is None:
            max_distance = float(np.sqrt(np.sum(np.square(mask.shape))))
        hausdorff = max_distance
    else:
        hausdorff = max(directed_hausdorff(mask_pts, pred_pts)[0],
                        directed_hausdorff(pred_pts, mask_pts)[0])

    return {'dice' : dice,
            'haus' : hausdorff}

In [7]:
# Noise sweeps to evaluate. Each inner list holds (noise type, level) pairs
# for one noise type at 5 ascending levels.
# Disabled sweeps, kept for reference:
#   brightness: np.linspace(1.0, 1.5, num = 5)
#   blur:       np.linspace(1,   10,  num = 5, dtype = int)
noisess = [
    list(product([noise_name], noise_levels))
    for noise_name, noise_levels in (
        ("gaussian",    np.linspace(0,   0.2, num = 5)),
        ("salt_pepper", np.linspace(0,   0.3, num = 5)),
        ("contrast",    np.linspace(1.0, 1.5, num = 5)),
        ("motion",      np.linspace(1,   25,  num = 5, dtype = int)),
    )
]

# Every instance is a (task, model, modelName, encoderName, fold) tuple.

# Baselines: DeepLab_lovasz on all three tasks. encoderName is a placeholder:
# the DeepLab network is built with a fixed resnet50 backbone.
baselines = list(product(
    ["ISIC", "RIM_DISC", "RIM_CUP"],  # task
    ["DeepLab_lovasz"],               # model
    ["plus"],                         # modelName
    ["ignore"],                       # encoderName
    ["0"],                            # fold
))

# Proposal: FFTDSNTNet, currently ISIC only ("RIM_DISC", "RIM_CUP" disabled).
# modelName is a placeholder: FFT checkpoints are located by encoderName.
proposals = list(product(
    ["ISIC"],          # task
    ["FFTDSNTNet"],    # model
    ["ignore"],        # modelName
    ["DResNet50"],     # encoderName
    ["0"],             # fold
))

# All instances to evaluate: baselines first, then proposals.
instances = baselines + proposals

In [8]:
# Evaluate every instance under every noise setting. Each result is cached as
# outcomes2/<task>/<model>/<modelName>/<encoderName>/<noise type>/<level>/outcome.npy
# and reused on later runs instead of reloading the checkpoint.
configs  = Dict()
outcomes = Dict()


def _load_net(task, model, modelName, encoderName, fold):
    """Build config + network via `parse`, load its checkpoint and move it to the GPU.

    FFT models keep checkpoints under `encoderName`, all other models under `modelName`.
    """
    config, net = parse(task, model, modelName, encoderName, fold)
    ckptDir = config.encoderName if "FFT" in config.model else config.modelName
    ckpt = torch.load(f"{config.ckptPath}/{config.task}/{config.model}/{ckptDir}/{config.fold}/ckpt.pt")
    net.load_state_dict(ckpt)
    net.eval()
    net.cuda()
    return config, net


def _valid_loader(config, noise=None):
    """Validation loader (batch size 1) for `config`.

    `noise` is forwarded as `noiseType` only when given, so the loader built
    without it keeps MakeDataLoader's own default.
    """
    kwargs = dict(inputSize = config.inputSize,
                  task      = config.task,
                  mode      = "valid",
                  fold      = int(config.fold),
                  maskType  = ("fourier", config.codeSize) if "FFT" in config.model else "original",
                  batchN    = 1,
                  cpuN      = config.cpuN,
                  augType   = config.augType)
    if noise is not None:
        kwargs["noiseType"] = noise
    return MakeDataLoader(f"{config.dataPath}/{config.task}/processed", **kwargs)


for task, model, modelName, encoderName, fold in instances:

    for noises in noisess:
        print(task, model, modelName, encoderName, noises[0][0])
        for noise in noises:

            path = f'outcomes2/{task}/{model}/{modelName}/{encoderName}/{noise[0]}/{noise[1]}'
            outcomeFile = f'{path}/outcome.npy'

            # Check the file itself, not just the directory: an interrupted run
            # can leave the directory behind without outcome.npy.
            if os.path.exists(outcomeFile):
                # Outcomes are saved as plain float arrays [dice, haus]; no pickle needed.
                outcome = np.load(outcomeFile, allow_pickle=False)

            else:
                # Fresh network per setting, so no state carries over between runs.
                config, net = _load_net(task, model, modelName, encoderName, fold)

                # 21 / 71 match the learnSize / codeSize that parse sets for FFT models.
                vis = visualisation(21, 71, config.inputSize)

                dataLoader = _valid_loader(config, noise)

                if config.model == "FFTDSNTNet":
                    # Scale initialisation uses a loader built without `noiseType`.
                    net.initialize_Scale(_valid_loader(config))

                # Store as an array so fresh and cached outcomes have the same type.
                outcome = np.asarray(evaluate(net, dataLoader, vis, config))

                os.makedirs(path, exist_ok=True)
                np.save(outcomeFile, outcome)

            outcomes[task][model][modelName][encoderName][noise[0]][noise[1]] = outcome
            print(outcome)


ISIC DeepLab_lovasz plus ignore gaussian
[ 0.903 18.6  ]
[ 0.849 23.733]
[ 0.81  29.047]
[ 0.777 35.459]
[ 0.749 37.796]
ISIC DeepLab_lovasz plus ignore salt_pepper
[ 0.903 18.6  ]
[ 0.876 20.533]
[ 0.834 25.179]
[ 0.796 30.174]
[ 0.759 34.817]
ISIC DeepLab_lovasz plus ignore contrast
[ 0.903 18.597]
[ 0.903 18.62 ]
[ 0.903 19.006]
[ 0.9   19.391]
[ 0.898 19.752]
ISIC DeepLab_lovasz plus ignore motion
[ 0.903 18.597]
[ 0.903 18.66 ]
[ 0.901 19.438]
[ 0.897 19.892]
[ 0.893 20.525]
RIM_DISC DeepLab_lovasz plus ignore gaussian
[ 0.96  10.637]
[ 0.945 14.304]
[ 0.94  15.366]
[ 0.932 16.426]
[ 0.924 17.515]
RIM_DISC DeepLab_lovasz plus ignore salt_pepper
[ 0.96  10.637]
[ 0.949 14.227]
[ 0.942 15.385]
[ 0.933 16.728]
[ 0.924 17.868]
RIM_DISC DeepLab_lovasz plus ignore contrast
[ 0.96  10.626]
[ 0.96  10.953]
[ 0.96  11.372]
[ 0.96  11.512]
[ 0.96  11.061]
RIM_DISC DeepLab_lovasz plus ignore motion
[ 0.96  10.626]
[ 0.959 11.206]
[ 0.957 11.865]
[ 0.953 13.312]
[ 0.946 13.978]
RIM_CUP DeepLa