In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
# Run the notebook from the project root so relative paths such as 'src/...',
# 'input/...' and 'output/...' resolve. Both steps are guarded so re-running
# this cell does not climb another directory or append 'src' twice.
# NOTE(review): assumes the notebook's own folder has no 'src' sub-directory.
if not os.path.isdir('src'):
    os.chdir('..')
if 'src' not in sys.path:
    sys.path.append('src')

In [None]:
from config import cfg, cfg_init
from pprint import pprint

from callbacks import  denorm

In [None]:
from PIL import Image
from pathlib import Path
from functools import partial, reduce
from collections import defaultdict
import multiprocessing as mp
from contextlib import contextmanager

import cv2
import numpy as np
import torch
from tqdm.auto import tqdm

import utils
import data
import model as nn_model

import matplotlib.pyplot as plt
%matplotlib inline

# Local

In [None]:
# Load the base experiment config, then override a few entries for this
# interactive session. The overrides must come after cfg_init.
cfg_init('src/configs/unet_gelb.yaml')
# Augmentation preset name for the TRAIN transformers.
cfg['TRANSFORMERS']['TRAIN']['AUG'] = 'light_scale'

# presumably disables distributed data parallel for single-process notebook use — confirm
cfg['PARALLEL']['DDP'] = False
cfg['DATA']['TRAIN']['PRELOAD'] = False
cfg['DATA']['TRAIN']['MULTIPLY']["rate"] = 2
# Optional overrides kept for quick experiments:
#cfg['DATA']['TRAIN']['DATASETS'] = ['train1024x25']
#cfg['VALID']['BATCH_SIZE'] = 4

In [None]:
# Build the TRAIN and VALID datasets from cfg, then wrap them in dataloaders.
# `dls` is indexed by split name later on (e.g. dls['VALID']).
datasets = data.build_datasets(cfg, dataset_types=['TRAIN','VALID'])
dls = data.build_dataloaders(cfg, datasets)

In [None]:
# Folder holding the trained weights of the run being evaluated.
root = 'output/2021_Apr_02_10_43_49_PAMBUH/models/'
# The model is only used for inference in this notebook, so switch it to
# evaluation mode explicitly. eval() returns the module itself and is
# harmless if load_model already did it.
model = nn_model.load_model(cfg, root).cuda().eval()

# Threshold

In [None]:
from loss import dice_loss

In [None]:
# Candidate thresholds: a log-spaced tail near 0 (0.01 .. 0.1), a uniform grid
# in the middle, and the mirrored log-spaced tail near 1 in ascending order.
thrs1 = np.logspace(0, 1, num=10) / 100
thrs2 = (1 - thrs1)[::-1]
thrs = np.concatenate((thrs1, np.arange(.2, .9, .025), thrs2))

In [None]:
# Sweep every candidate threshold over the validation loader and collect one
# dice_loss value per (batch, threshold).
# Needs `torch`, `model`, `dls`, `thrs` and `dice_loss` from earlier cells.
dices = []
for x, y in dls['VALID']:
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model(x.cuda()).sigmoid().cpu()
    # Threshold one at a time rather than holding every binarised copy in memory.
    dices.append([dice_loss(pred > thr, y) for thr in thrs])

# Rows are batches, columns are thresholds; average over batches.
dices = np.array(dices)
dices_mean = dices.mean(0)
# Best mean value and the threshold that achieved it.
dices_mean.max(), thrs[np.argmax(dices_mean)]

In [None]:
# Plot against the actual threshold values: the grid is non-uniform, so
# plotting against the array index would distort the curve.
fig, ax = plt.subplots()
ax.plot(thrs, dices_mean)
ax.set(xlabel='threshold', ylabel='mean dice_loss', title='Validation dice_loss vs. threshold');

# Infer

In [None]:
# Expose GPUs 0-3 through the CUDA_VISIBLE_DEVICES environment variable.
# NOTE(review): an earlier cell already calls .cuda() in this process, so this
# setting likely only affects processes started afterwards — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
from run_inference import start_inf

In [None]:
def get_split_by_idx(idx):
    """Return the training .tiff paths that belong to cross-validation split `idx`.

    Scans 'input/hm/train' for '*.tiff' files and keeps those whose stem is
    listed in the hard-coded split table below.

    Args:
        idx: index into the split table (0..3).

    Returns:
        List of pathlib.Path objects, in directory listing order.
    """
    fold_stems = (
        ('0486052bb', 'e79de561c'),
        ('cb2d976f4', 'c68fe75ea'),
        ('2f6ecfcdf', 'afa5e8098'),
        ('1e2425f28', '8242609fa'),
    )
    selected = []
    for tiff_path in Path('input/hm/train').glob('*.tiff'):
        # The table is indexed per file, exactly as the check is applied.
        if tiff_path.stem in fold_stems[idx]:
            selected.append(tiff_path)
    return selected

In [None]:
def process_split(idx, model_folder):
    """Run inference with the models in `model_folder` and return the predictions.

    NOTE: `idx` is currently ignored. The image list is hard-coded to a single
    file instead of coming from get_split_by_idx(idx).

    Args:
        idx: split index (unused, see note above).
        model_folder: folder handed to start_inf as the model location.

    Returns:
        dict built from whatever start_inf returns; later cells iterate it as
        image path -> predicted mask.
    """
    device_ids = [0]
    images = [Path('input/hm/train/e79de561c.tiff')]  # get_split_by_idx(idx)

    predictions = start_inf(
        model_folder,
        images,
        device_ids,
        0,                # threshold
        len(device_ids),  # num_processes: one per GPU
        False,            # save_predicts
        True,             # use_tta
        False,            # to_rle
    )
    return dict(predictions)

In [None]:
# Folder with the trained models for this experiment; passed to process_split.
model_folder = Path('output/ffpe_splits/e7/')

In [None]:
# Run inference. Note: process_split currently ignores its idx argument and
# always predicts 'input/hm/train/e79de561c.tiff', so the 3 has no effect here.
res_masks = process_split(3, model_folder)

In [None]:
# Write every predicted mask to disk as a single-band 8-bit TIFF, keeping the
# source image's file name.
for img_path, pred_mask in res_masks.items():
    out_path = f'output/ffpe_splits/e7/predicts/cv/masks/{img_path.name}'
    utils.save_tiff_uint8_single_band(pred_mask, out_path, bits=8)

# Dice

In [None]:
from loss import dice_loss
import rasterio as rio
import torch

In [None]:
def calc_all_dices(res_masks, thrs):
    """Score each predicted mask against its ground truth at every threshold.

    For every entry of `res_masks` the ground-truth raster
    'input/masks/bigmasks/<name>' is read, the prediction is scaled by 1/255,
    binarised at each threshold and passed to dice_loss together with the
    binarised ground truth (gt > 0).

    Args:
        res_masks: mapping of image path (must provide .name and .stem) to a
            numpy prediction array. Values are divided by 255 — presumably
            uint8 probabilities in 0..255; confirm against start_inf.
        thrs: indexable sequence of thresholds (e.g. a numpy array).

    Returns:
        dict mapping image stem -> numpy array with one dice_loss value per
        threshold, in the order of `thrs`.
    """
    dices = {}
    for k, mask in res_masks.items():
        # Close the raster handle as soon as the pixels are read.
        with rio.open(f'input/masks/bigmasks/{k.name}') as src:
            gt = src.read()
        # float16 keeps the scaled copy of a large mask small.
        mask = mask.astype(np.float16)
        mask/=255.

        mask = torch.from_numpy(mask)
        # Binarise the ground truth once; it does not depend on the threshold.
        gt = torch.from_numpy(gt) > 0

        ds = []
        for thr in tqdm(thrs):
            th_mask = (mask > thr)
            dice = dice_loss(gt, th_mask)
            ds.append(dice)
            print(k.stem, round(thr, 3), round(dice, 4))

        ds = np.array(ds)
        dices[k.stem] = ds
        # Best value for this image and the threshold that produced it.
        print(k.stem, ds.max(), thrs[np.argmax(ds)])

    return dices

def calc_common_dice(dices, thrs):
    """Average the per-image best thresholds.

    Despite the name, the result is a threshold, not a dice value: for each
    image the threshold with the highest score is picked, and the mean of
    those thresholds is returned.

    Args:
        dices: mapping of image id -> sequence of scores, one per threshold.
        thrs: thresholds aligned with each score sequence.

    Returns:
        Mean of the best thresholds across all images.
    """
    winners = [thrs[np.argmax(scores)] for scores in dices.values()]
    return np.mean(winners)

def get_thrs():
    """Return the uniform threshold grid: from 0.2 up to (not including) 0.9 in steps of 0.025."""
    start, stop, step = .2, .9, .025
    return np.arange(start, stop, step)

In [None]:
# Score the full-size predictions on the uniform threshold grid.
# Note: this rebinds `thrs` and `dices` from the "Threshold" section above;
# `dices` is now a dict of image stem -> per-threshold scores.
thrs = get_thrs()
dices = calc_all_dices(res_masks, thrs)

In [None]:
# Mean of the per-image best thresholds (the result is a threshold, not a dice value).
calc_common_dice(dices, thrs)

In [None]:
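# Recorded results per image: image id, best dice, best threshold.
# Images are grouped by cross-validation split.
# 0486052bb 0.9485 0.65
# e79de561c 0.9395 0.575
#
# cb2d976f4 0.9590 0.55
# c68fe75ea 0.9069 0.65
#
# 2f6ecfcdf 0.9663 0.525
# afa5e8098 0.9363 0.65
#
# 1e2425f28 0.9558 0.475
# 8242609fa 0.965 0.5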

In [None]:
# Average of two best-dice values; they match the scores noted for
# 0486052bb and e79de561c.
(0.948515772819 + 0.939515) / 2

In [None]:
# Average of two best thresholds; they match the thresholds noted for
# 0486052bb (about 0.65) and e79de561c (0.575).
(0.6499 + 0.575 ) / 2

In [None]:
# Threshold notes per split (split index: threshold):
# 0: 0.6125
# 1: 0.6
# 2: 0.5

# Merge

In [None]:
# For each inferred image, combine the ground-truth mask and the predicted
# mask into one TIFF via utils.tiff_merge_mask.
for img_path in res_masks.keys():
    file_name = img_path.name
    gt_mask_path = f'input/masks/bigmasks/{file_name}'
    pred_mask_path = f'output/ffpe_splits/e7/predicts/cv/masks/{file_name}'
    combined_path = f'output/ffpe_splits/e7/predicts/cv/combined/{file_name}'

    utils.tiff_merge_mask(img_path, gt_mask_path, combined_path, pred_mask_path)

In [None]:
# output/ffpe_splits/