# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
# Move one directory up (presumably to the project root — confirm) so that the
# relative 'input/...' and 'src/...' paths used in later cells resolve.
os.chdir('..')
# Add 'src' to the import path; presumably this is where utils, data, sampler
# and config live — confirm.
sys.path.append('src')

In [None]:
from PIL import Image
from pathlib import Path
from functools import partial, reduce
from collections import defaultdict
import multiprocessing as mp
from contextlib import contextmanager

import cv2
import numpy as np
from tqdm.auto import tqdm


import utils
import data
from sampler import GdalSampler

# Code

## Narezator (tile cutter)

In [None]:
@contextmanager
def poolcontext(*args, **kwargs):
    """Context manager that yields a ``multiprocessing.Pool``.

    All positional and keyword arguments are forwarded to ``mp.Pool``.
    The pool is terminated when the ``with`` block exits, including when
    the block raises.
    """
    pool = mp.Pool(*args, **kwargs)
    try:
        yield pool
    finally:
        # Without try/finally an exception raised in the with-body would
        # skip terminate() and leave the worker processes running.
        pool.terminate()
    
def mp_func(foo, args, n):
    """Map ``foo`` over consecutive slices of ``args`` in a process pool.

    ``args`` is cut into slices of length ``n`` (the last one may be
    shorter) and each slice is passed to ``foo`` as a single argument.
    ``n`` is also the number of pool processes. The results of the map
    are discarded and nothing is returned.
    """
    args_chunks = []
    for start in range(0, len(args), n):
        stop = start + n
        args_chunks.append(args[start:stop])

    with poolcontext(processes=n) as pool:
        pool.map(foo, args_chunks)
    
def mp_foo(foo, args):
    """Call ``foo`` with the items of ``args`` unpacked as positional arguments."""
    result = foo(*args)
    return result

In [None]:
def mp_sampler(dst, i_fn, m_fn, a_fn, wh, wh_mask, idxs):
    """Export the tiles ``idxs`` of one image/mask pair as PNG files.

    A ``GdalSampler`` is built from ``i_fn``, ``m_fn``, ``a_fn``, ``wh`` and
    ``wh_mask``. For every index in ``idxs`` the sampled image and mask are
    resized to half of ``wh`` and written to
    ``dst/imgs/<image stem>/<idx>.png`` and ``dst/masks/<image stem>/<idx>.png``,
    where ``<idx>`` is zero-padded to six digits. The two output folders are
    created up front.

    Args:
        dst: Output root; must support ``/`` path joining (e.g. ``Path``).
        i_fn: Image path; its name without suffix names the output folders.
        m_fn: Mask file, passed through to ``GdalSampler``.
        a_fn: Annotation file, passed through to ``GdalSampler``.
        wh: Tile size passed to the sampler; ``(wh[0] // 2, wh[1] // 2)`` is
            the size of the written files.
        wh_mask: Mask tile size passed to the sampler.
        idxs: Iterable of sampler indices to export.
    """
    sampler = GdalSampler(i_fn, m_fn, a_fn, wh, wh_mask)

    # Folders and target size are the same for every tile: compute and create
    # them once instead of on each loop iteration.
    stem = i_fn.with_suffix('').name
    img_dir = dst / 'imgs' / stem
    mask_dir = dst / 'masks' / stem
    os.makedirs(str(img_dir), exist_ok=True)
    os.makedirs(str(mask_dir), exist_ok=True)
    out_size = (wh[0] // 2, wh[1] // 2)

    for idx in idxs:
        img, mask = sampler[idx]
        tile_name = str(idx).zfill(6) + '.png'

        # Move axis 0 to the end; presumably the sampler returns
        # channel-first images — TODO confirm.
        img = img.transpose(1, 2, 0)
        # Add a trailing axis, repeat it 3 times and scale by 255;
        # presumably the mask holds 0/1 values — TODO confirm.
        mask = 255 * np.expand_dims(mask, -1).repeat(3, -1).astype(np.uint8)

        # Swap the first and third channel before writing; presumably the
        # sampler yields RGB while cv2.imwrite stores BGR — confirm.
        img = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), out_size)
        # NOTE(review): cv2.resize is called with its default interpolation,
        # which may introduce intermediate values on mask edges — confirm
        # this is intended.
        mask = cv2.resize(mask, out_size)

        cv2.imwrite(str(img_dir / tile_name), img)
        cv2.imwrite(str(mask_dir / tile_name), mask)

In [None]:
# Settings for the tile export below; paths are relative to the working directory.
NUM_PROC = 16      # handed to mp_func as `n`
wh = (1024, 1024)  # tile size passed to the sampler
imgs_path = Path('input/hm/train')
masks_path = Path('input/bigmasks/')
dst_path = Path('input/cuts1024x05')
# Alternative source location, kept for reference:
# p = Path('/home/sokolov/work/webinf/data/kidney/train/')

In [None]:
# Discover the input files. `filt` binds the banned substrings '-' and '_ell'
# to utils.filter_ban_str_in_name and is handed to utils.get_filenames.
filt = partial(
    utils.filter_ban_str_in_name,
    bans=['-', '_ell'],
)
ann_fns = sorted(utils.get_filenames(imgs_path, '*.json', filt))
masks_fns = sorted(utils.get_filenames(masks_path, '*.tiff', filt))
# Image paths are derived from the annotation paths by swapping the suffix.
img_fns = sorted(a.with_suffix('.tiff') for a in ann_fns)
img_fns, ann_fns, masks_fns  # notebook display of the three lists

In [None]:
# for i,m in _s:
#     print(i.shape, m.shape)

In [None]:
# NOTE: this cell writes every tile to disk; it is meant to be run once.
# assert False, 'DO ONCE'
# zip() has no length, so tqdm needs an explicit total to show progress;
# min() mirrors zip stopping at the shortest list.
for i_fn, m_fn, a_fn in tqdm(
    zip(img_fns, masks_fns, ann_fns),
    total=min(len(img_fns), len(masks_fns), len(ann_fns)),
):
    # The same `wh` is used for both the image and the mask tile size.
    const_args = i_fn, m_fn, a_fn, wh, wh
    # Sampler built here only to learn how many tiles this file has.
    _s = GdalSampler(*const_args)
    part_samp = partial(mp_sampler, dst_path, *const_args)
    mp_func(part_samp, range(len(_s)), NUM_PROC)
    # break  # uncomment to stop after the first file

## Datasets

In [None]:
# Open the exported tiles as a dataset; mode_train=False is passed through
# to data.SegmentDataset.
root = Path('input/cuts1024x05/')
sd = data.SegmentDataset(
    root / 'imgs',
    root / 'masks',
    mode_train=False,
)
len(sd)  # notebook display

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# A (min, max) slider is inclusive at both ends, so the upper bound must be
# the last index rather than len(sd) (assumes valid indices are
# 0..len(sd)-1 — confirm). continuous_update is a slider option, not an
# argument of view(), so it is set on an explicit IntSlider. The start value
# matches the midpoint the (0, len(sd)) abbreviation would have picked.
@interact(idx=widgets.IntSlider(
    min=0,
    max=max(len(sd) - 1, 0),
    value=len(sd) // 2,
    continuous_update=False,
))
def view(idx):
    """Return ``sd._view(idx)`` for the index chosen on the slider."""
    return sd._view(idx)

## Dataloaders

In [None]:
from config import cfg, cfg_init
from pprint import pprint

In [None]:
# Print the config object as imported, before cfg_init() is called below.
pprint(cfg)

In [None]:
# Initialise cfg from the unet.yaml file (cfg_init comes from the config
# module), then override a few entries for this notebook session.
cfg_init('src/configs/unet.yaml')
cfg['TRANSFORMERS']['TRAIN']['AUG'] = 'light'  # train augmentation setting
cfg['PARALLEL']['DDP'] = False  # presumably turns off distributed training — confirm
cfg['DATA']['TRAIN']['PRELOAD'] = False  # presumably skips preloading train data — confirm

In [None]:
# Build the datasets from cfg and pick the 'TRAIN' and 'VALID' entries.
datasets = data.build_datasets(cfg)
tds = datasets['TRAIN']
vds = datasets['VALID']
len(tds)  # notebook display

In [None]:
%%timeit -n 10 -r 10
# Time fetching the first training item (10 loops per run, 10 runs).
tds[0]

In [None]:
# `i` is not bound at notebook top level in the cells above, so fetch a sample
# here; assumes a dataset item unpacks as (image, mask) — confirm.
i, m = tds[0]
i.shape, i.dtype, i.max(), i.mean(), i.std()

In [None]:
# Fetch the mask explicitly so this cell does not rely on a leftover `m`;
# assumes the second element of a dataset item is the mask — confirm.
m = tds[0][1]
m.shape, m.dtype, m.max()#, m.mean(), m.std()

In [None]:
# Build the dataloaders for the datasets and keep the 'TRAIN' one.
dls = data.build_dataloaders(cfg, datasets, pin=True, drop_last=False)
tdl = dls['TRAIN']

In [None]:
%%timeit -n 2 -r 2
# Time a full pass over the training loader; the batches are discarded.
for xb, yb in tdl:
    pass
    #break

In [None]:
# Names bound inside a %%timeit cell are not kept in the notebook namespace,
# so pull one batch explicitly before inspecting it.
xb, yb = next(iter(tdl))
xb.shape, xb.dtype, xb.mean(), xb.std()

In [None]:
# Fetch the targets explicitly so this cell does not rely on a leftover `yb`;
# the loader yields (xb, yb) pairs, as in the timing loop above.
yb = next(iter(tdl))[1]
yb.shape, yb.dtype, yb.max()