# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
# Move from the notebook directory to the repository root exactly once:
# re-running this cell must not keep climbing the directory tree.
if not os.path.isdir('src'):
    os.chdir('..')
# Make the project's `src` modules importable without piling up duplicates.
if 'src' not in sys.path:
    sys.path.append('src')

In [None]:
from PIL import Image
from pathlib import Path
from functools import partial, reduce
from collections import defaultdict
import multiprocessing as mp
from contextlib import contextmanager

import cv2
import numpy as np
from tqdm.auto import tqdm


import utils
import data
import sampler

# Code

## Narezator (tile cutter)

### Objects

In [None]:
@contextmanager
def poolcontext(*args, **kwargs):
    """Yield a ``multiprocessing.Pool`` and always shut it down afterwards.

    All arguments are forwarded to ``mp.Pool``. The pool is terminated and
    joined even if the body of the ``with`` statement raises, so worker
    processes are never leaked.
    """
    pool = mp.Pool(*args, **kwargs)
    try:
        yield pool
    finally:
        pool.terminate()
        pool.join()


def mp_func(foo, args, n, chunk_size=None):
    """Apply ``foo`` to consecutive chunks of ``args`` in a pool of ``n`` processes.

    Args:
        foo: picklable callable that receives one chunk (a slice of ``args``).
        args: sliceable sequence (list, range, ...).
        n: number of worker processes; also the default chunk length.
        chunk_size: optional chunk length; ``None`` keeps the historical
            behaviour of using ``n``.

    Returns:
        List of ``foo`` results, one per chunk, in chunk order (empty list for
        empty ``args``, without starting any worker process).
    """
    step = n if chunk_size is None else chunk_size
    chunks = [args[k:k + step] for k in range(0, len(args), step)]
    if not chunks:
        return []
    with poolcontext(processes=n) as pool:
        return pool.map(foo, chunks)


def mp_foo(foo, args):
    """Call ``foo`` with the tuple ``args`` unpacked as positional arguments."""
    return foo(*args)

In [None]:
def to_gray(i):
    """Replace the last (channel) axis with its mean, tiled back to 3 channels.

    The output has shape ``(..., 3)`` and the floating dtype produced by
    ``np.mean``; callers cast it back to an integer type themselves.
    """
    channel_mean = np.mean(i, axis=-1, keepdims=True)
    return channel_mean.repeat(3, axis=-1)

def _prepare_pair(img, mask, size, gray):
    """Convert one channel-first sampler image and its mask into resized HWC arrays for cv2.imwrite."""
    img = img.transpose(1, 2, 0)  # CHW -> HWC
    # Swap channels 0 and 2 (BGR2RGB); cv2.imwrite interprets HWC data as BGR.
    img = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), size)
    if gray:
        img = to_gray(img).astype(np.uint8)
    # Mask -> 3-channel uint8 scaled by 255 (assumes 0/1 mask values — TODO confirm).
    mask = 255 * np.expand_dims(mask, -1).repeat(3, -1).astype(np.uint8)
    # NOTE(review): cv2.resize defaults to bilinear interpolation, so mask edges
    # get intermediate values; use cv2.INTER_NEAREST if strictly binary masks are needed.
    mask = cv2.resize(mask, size)
    return img, mask


def mp_sampler(dst, i_fn, m_fn, a_fn, wh, wh_mask, idxs, scale=4, gray=True):
    """Export the samples ``idxs`` of one slide as PNG image/mask tile pairs.

    Designed to be mapped over index chunks by ``mp_func``; every call builds
    its own ``sampler.GdalSampler`` from the file paths.

    Args:
        dst: output root (``Path``). Tiles are written to
            ``dst/'imgs'/<slide>/NNNNNN.png`` and ``dst/'masks'/<slide>/NNNNNN.png``,
            where ``<slide>`` is ``i_fn`` without its suffix.
        i_fn, m_fn, a_fn: image, mask and annotation paths for ``GdalSampler``.
        wh, wh_mask: window sizes forwarded to ``GdalSampler``; ``wh`` also
            fixes the saved tile size as ``(wh[0] // scale, wh[1] // scale)``.
        idxs: sample indices; each index also names its output file
            (zero-padded to 6 digits).
        scale: downscale factor for saved tiles (default 4).
        gray: save the image as 3-channel grayscale (default True).

    Raises:
        IOError: if OpenCV fails to write a tile (``cv2.imwrite`` returns False).
    """
    s = sampler.GdalSampler(i_fn, m_fn, a_fn, wh, wh_mask)
    size = (wh[0] // scale, wh[1] // scale)

    # Output directories depend only on the slide, so create them once.
    slide = i_fn.with_suffix('').name
    img_dir = dst / 'imgs' / slide
    mask_dir = dst / 'masks' / slide
    os.makedirs(str(img_dir), exist_ok=True)
    os.makedirs(str(mask_dir), exist_ok=True)

    for idx in idxs:
        img, mask = _prepare_pair(*s[idx], size, gray)
        fname = str(idx).zfill(6) + '.png'
        # cv2.imwrite reports failure through its return value, not an exception.
        for path, arr in ((img_dir / fname, img), (mask_dir / fname, mask)):
            if not cv2.imwrite(str(path), arr):
                raise IOError(f'failed to write {path}')

In [None]:
imgs_path = Path('input/hm/train')
masks_path = Path('input/bigmasks/')
#p = Path('/home/sokolov/work/webinf/data/kidney/train/')
dst_path = Path('input/CUTS/cuts1024x25_gray')
NUM_PROC = 16
wh = (1024,1024)

In [None]:
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell'])
ann_fns = sorted(utils.get_filenames(imgs_path, '*.json', filt))
masks_fns = sorted(utils.get_filenames(masks_path, '*.tiff', filt))
img_fns = sorted([a.with_suffix('.tiff') for a in ann_fns])
img_fns, ann_fns, masks_fns

In [None]:
# for i,m in _s:
#     print(i.shape, m.shape)

In [None]:
#assert  False , 'DO ONCE'
# Cut every (image, mask, annotation) triple into tiles, NUM_PROC workers per slide.
# The three file lists are paired purely by sorted order; zip() would silently
# drop the unmatched tail if one list is missing a file.
if not (len(img_fns) == len(masks_fns) == len(ann_fns)):
    raise ValueError(
        f'file lists differ in length: {len(img_fns)} images, '
        f'{len(masks_fns)} masks, {len(ann_fns)} annotations')

for i_fn, m_fn, a_fn in tqdm(zip(img_fns, masks_fns, ann_fns), total=len(img_fns)):
    const_args = i_fn, m_fn, a_fn, wh, wh
    # Opened here only to learn how many samples the slide yields; each worker
    # builds its own sampler inside mp_sampler.
    _s = sampler.GdalSampler(*const_args)
    part_samp = partial(mp_sampler, dst_path, *const_args)
    mp_func(part_samp, range(len(_s)), NUM_PROC)

### Backgrounds

In [None]:
imgs_path = Path('input/hm/train')
masks_path = Path('input/bigmasks/')
dst_path = Path('input/backs030_x25_gray')
#NUM_PROC = 16
wh = (1024,1024)
pct = .3


In [None]:
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell'])
ann_fns = sorted(utils.get_filenames(imgs_path, '*.json', filt))
masks_fns = sorted(utils.get_filenames(masks_path, '*.tiff', filt))
img_fns = sorted([a.with_suffix('.tiff') for a in ann_fns])
img_fns, ann_fns, masks_fns

In [None]:
{i.stem:0 for i in img_fns}

In [None]:
idx = 0
img_path = img_fns[idx]
mask_path = masks_fns[idx] 
img_anot_struct_path = img_path.parent / (img_path.stem + '-anatomical-structure.json')
recs = utils.jread(str(ann_fns[idx]))

In [None]:

ni = int(len(recs) * pct)
polys = utils.get_cortex_polygons(utils.jread(img_anot_struct_path))
s = sampler.BackgroundSampler(img_path, mask_path, polys, wh, wh, ni)

In [None]:
#assert  False , 'DO ONCE'
# Export background tiles for every slide, downscaled by SCALE and optionally
# converted to 3-channel grayscale.
SCALE = 4
TO_GRAY = True
out_size = (wh[0] // SCALE, wh[1] // SCALE)

# The three file lists are paired purely by sorted order; zip() would silently
# drop the unmatched tail if one list is missing a file.
if not (len(img_fns) == len(masks_fns) == len(ann_fns)):
    raise ValueError(
        f'file lists differ in length: {len(img_fns)} images, '
        f'{len(masks_fns)} masks, {len(ann_fns)} annotations')

for i_fn, m_fn, a_fn in tqdm(zip(img_fns, masks_fns, ann_fns), total=len(img_fns)):
    img_anot_struct_path = i_fn.parent / (i_fn.stem + '-anatomical-structure.json')
    recs = utils.jread(str(a_fn))
    # Background sample count scales with the number of annotation records
    # (presumably one record per object — TODO confirm).
    ni = int(len(recs) * pct)
    polys = utils.get_cortex_polygons(utils.jread(img_anot_struct_path))
    # Kept as the global `s`: the interactive viewer cell reads it afterwards.
    s = sampler.BackgroundSampler(i_fn, m_fn, polys, wh, wh, ni)

    img_dir = dst_path / 'imgs' / i_fn.with_suffix('').name
    mask_dir = dst_path / 'masks' / i_fn.with_suffix('').name
    os.makedirs(str(img_dir), exist_ok=True)
    os.makedirs(str(mask_dir), exist_ok=True)

    for idx, (i, m) in enumerate(s):
        orig_name = str(idx).zfill(6) + '.png'
        img_name = img_dir / orig_name
        mask_name = mask_dir / orig_name

        # CHW -> HWC, then swap channels 0 and 2 (BGR2RGB) because
        # cv2.imwrite interprets HWC data as BGR.
        i = i.transpose(1, 2, 0)
        i = cv2.resize(cv2.cvtColor(i, cv2.COLOR_BGR2RGB), out_size)
        if TO_GRAY:
            i = to_gray(i).astype(np.uint8)

        # Mask -> 3-channel uint8 scaled by 255 (assumes 0/1 values — TODO confirm).
        m = 255 * np.expand_dims(m, -1).repeat(3, -1).astype(np.uint8)
        m = cv2.resize(m, out_size)

        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(img_name), i):
            raise IOError(f'failed to write {img_name}')
        if not cv2.imwrite(str(mask_name), m):
            raise IOError(f'failed to write {mask_name}')

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# A (min, max) tuple is an inclusive slider range, so max must be len - 1.
# Extra kwargs to `interact` that are not parameters of `view` are silently
# ignored, so continuous_update has to be set on the slider widget itself.
@interact(idx=widgets.IntSlider(min=0, max=max(len(s) - 1, 0), continuous_update=False))
def view(idx):
    """Display the image (first element, CHW -> HWC) of sample `idx` from the global sampler `s`."""
    return Image.fromarray(s[idx][0].transpose(1, 2, 0))

## Datasets

In [None]:
root = Path('input/backs030/')
sd = data.SegmentDataset(root / 'imgs', root / 'masks', mode_train=False)
len(sd)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# A (min, max) tuple is an inclusive slider range, so max must be len - 1.
# Extra kwargs to `interact` that are not parameters of `view` are silently
# ignored, so continuous_update has to be set on the slider widget itself.
@interact(idx=widgets.IntSlider(min=0, max=max(len(sd) - 1, 0), continuous_update=False))
def view(idx):
    """Render sample `idx` of the global dataset `sd` via its `_view` helper."""
    return sd._view(idx)

In [None]:
sd[0]

In [None]:
sd

In [None]:
root = Path('input/SPLITS/split1024x25/train/')
sd = data.SegmentDataset(root / 'imgs', root / 'masks', mode_train=False)
len(sd)

In [None]:
tds = data.TagSegmentDataset(root / 'imgs', root / 'masks', mode_train=True)
len(tds)

In [None]:
tot = 0
for _,(_,c) in tds:
    tot+=c

In [None]:
tot/len(tds)

## Dataloaders

In [None]:
from config import cfg, cfg_init
from pprint import pprint

from callbacks import  denorm

In [None]:
cfg_init('src/configs/unet_gelb.yaml')
cfg['TRANSFORMERS']['TRAIN']['AUG'] = 'light_scale'

cfg['PARALLEL']['DDP'] = False
cfg['DATA']['TRAIN']['PRELOAD'] = False
cfg['DATA']['TRAIN']['MULTIPLY']["rate"] = 2
#cfg['DATA']['TRAIN']['DATASETS'] = ['train1024x25']

In [None]:
pprint(cfg)

In [None]:
datasets = data.build_datasets(cfg, dataset_types=['TRAIN','VALID'])

In [None]:
datasets = data.build_datasets(cfg)
tds = datasets['TRAIN']
vds = datasets['VALID']
len(tds)

In [None]:
def show_img(tds, idx):
    """Return the image of dataset item `idx` as a PIL image, undoing normalization.

    Prints the raw tensor's shape and dtype. Uses the global `cfg` for the
    normalization mean/std passed to `denorm`. Assumes the item is
    (image, mask) with a channel-first image tensor (TODO confirm).
    """
    img, _mask = tds[idx]
    print(img.shape, img.dtype)
    img = denorm(img, cfg.TRANSFORMERS.MEAN, cfg.TRANSFORMERS.STD)
    img = img.squeeze().permute(1, 2, 0).cpu().numpy()
    # De-normalized values can fall slightly outside [0, 1]. Casting
    # out-of-range floats to uint8 wraps around or is platform-dependent,
    # so clip first.
    img = np.clip(img * 255., 0, 255).astype(np.uint8)
    return Image.fromarray(img)

In [None]:
show_img(tds, 0)

In [None]:
import random

In [None]:
from _data import make_datasets_folds

In [None]:
N_FOLDS = 4
datasets_as_folds = make_datasets_folds(cfg, datasets, N_FOLDS, shuffle=False)

In [None]:
datasets_as_folds

In [None]:
for dss in datasets_as_folds:
    dls = data.build_dataloaders(cfg, dss, pin=True, drop_last=False)
    tdl = dls['TRAIN']
    print(tdl, len(tdl))
    for b in tdl:
        pass

In [None]:
tot = 0
for dss in datasets_as_folds:
    for k, v in dss.items():
        print(len(v))
        if k == 'TRAIN':
            tot += len(v)

In [None]:
tot

In [None]:
act_len = len(tds)//2

In [None]:
idx = 47
show_img(tds, idx)

In [None]:
show_img(tds, idx + act_len)

In [None]:
def pimg(img):
    """Render a single-image tensor as a PIL image.

    Expects values in [0, 1] (boolean tensors work too). Anything outside that
    range is clipped, so the uint8 cast cannot wrap around.
    """
    arr = img.squeeze().cpu().numpy()
    arr = np.clip(arr * 255., 0, 255).astype(np.uint8)
    return Image.fromarray(arr)

In [None]:
import torch
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from model import load_model, FoldModel
mps = Path('output/2021_Feb_20_18_03_52_PAMBUH/models/').glob('*.pth')
mps = list(mps)
mps

In [None]:
m1 = load_model(1, mps[-1])

In [None]:
ms = [load_model(1, mp) for mp in mps]

In [None]:
class FoldModel(torch.nn.Module):
    """Ensemble wrapper that averages the outputs of several models.

    Note: this notebook-local definition shadows the ``FoldModel`` imported
    from ``model``.

    Args:
        models: iterable of ``torch.nn.Module`` instances that accept the same
            input and return tensors of identical shape.
    """

    def __init__(self, models):
        super().__init__()
        # ModuleList (not a plain list) registers the members as submodules, so
        # .eval()/.to()/.parameters()/state_dict() reach every fold model.
        self.ms = torch.nn.ModuleList(models)

    def forward(self, x):
        """Return the element-wise mean of all member outputs for input ``x``."""
        return torch.stack([m(x) for m in self.ms]).mean(0)

In [None]:
fold_model = FoldModel(ms)

In [None]:
idx = 22
show_img(vds, idx)

In [None]:
i,m = vds[idx]
i.shape, i.mean(), i.std()

In [None]:
i = i.view(1,*i.shape).repeat(2,1,1,1)
i.shape

In [None]:
with torch.no_grad():
    res = torch.sigmoid(fold_model(i))

In [None]:
plt.hist(res.cpu().numpy().flatten(), bins=50);

In [None]:
res.shape

In [None]:
pimg(res[0]>.7)

In [None]:
pimg(m)

In [None]:
150/255

In [None]:
%%timeit -n 10 -r 10
tds[0]

In [None]:
i.shape, i.dtype, i.max(), i.mean(), i.std()

In [None]:
m.shape, m.dtype, m.max()#, m.mean(), m.std()

In [None]:
dls = data.build_dataloaders(cfg, datasets, pin=True, drop_last=False)
tdl = dls['TRAIN']

In [None]:
dls

In [None]:
for xb, yb in dls['VALID2']:
    break

In [None]:
%%timeit -n 2 -r 2
for xb, yb in tdl:
    pass
    #break

In [None]:
xb.shape, xb.dtype, xb.mean(), xb.std()

In [None]:
yb.shape, yb.dtype, yb.max()