# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
os.chdir('..')
sys.path.append('src')

In [None]:
from PIL import Image
from pathlib import Path
from functools import partial, reduce
from collections import defaultdict
import multiprocessing as mp
from contextlib import contextmanager

import cv2
import numpy as np
from tqdm.auto import tqdm


import utils
import data
import sampler

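## Stats

Each entry lists: file path, band count of the TIFF (1 or 3), per-channel mean, per-channel standard deviation.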

### Train
- input/hm/train/4ef6695ce.tiff 1 [167.48280243 131.99601198 172.1840991 ] [53.6715833  74.18087042 50.01145663]
- input/hm/train/b9a3865fc.tiff 3 [179.12058181 156.16663604 190.57224837] [57.33262864 66.68898569 53.26343919]
- input/hm/train/e79de561c.tiff 3 [168.63503256 140.3692677  179.29400798] [47.20465075 63.66449426 39.38715776]
- input/hm/train/8242609fa.tiff 3 [168.39762381 146.67506496 179.53406931] [68.85119601 75.34678935 66.46572523]
- input/hm/train/cb2d976f4.tiff 3 [156.64766308 137.08461345 163.56018869] [85.22797112 86.17162812 85.83934117]
- input/hm/train/26dc41664.tiff 1 [148.37026934 117.8351192  152.91550503] [74.70058493 81.79527692 73.60989621]
- input/hm/train/b2dc8411c.tiff 3 [145.70774605 130.53835743 153.51077155] [90.19669735 90.07284394 91.13878548]
- input/hm/train/afa5e8098.tiff 3 [144.79319853 121.24794863 154.07814847] [69.64398246 70.21465338 69.66028281]
- input/hm/train/0486052bb.tiff 3 [155.04365695 140.8247772  163.27546087] [86.20611987 87.26473343 86.73144338]
- input/hm/train/1e2425f28.tiff 1 [155.27908076 109.30259728 155.74962968] [60.41321818 74.58395251 58.64478221]
- input/hm/train/c68fe75ea.tiff 1 [177.38737838 146.33150453 185.17719842] [34.14849425 52.73887964 25.65912719]
- input/hm/train/aaa6a05cc.tiff 3 [168.53653974 146.01495027 178.91997248] [70.9594629  80.8427825  67.05318295]
- input/hm/train/54f2eec69.tiff 3 [159.57231947 134.360238   164.28528555] [69.91997402 77.88897022 67.87669295]
- input/hm/train/095bf7a1f.tiff 1 [141.92589491 113.2553824  144.47921706] [77.36870257 81.80458049 77.13849731]
- input/hm/train/2f6ecfcdf.tiff 3 [150.00407259 134.96295758 157.40463786] [89.23076809 89.6362252  89.93819569]


### Test
- input/hm/test/57512b7f1.tiff 1 [151.83879508 129.67205882 155.49909348] [78.47813674 85.79231764 77.58990357]
- input/hm/test/2ec3f1bb9.tiff 3 [176.16665974 150.29656892 187.40280662] [59.38260377 69.46464358 55.26716755]
- input/hm/test/aa05346ff.tiff 1 [166.86605324 137.62296148 174.57705023] [45.55560103 56.21724257 43.40574615]
- input/hm/test/3589adb90.tiff 3 [172.07196767 155.13386244 181.57171249] [72.44817365 78.52393954 70.44612831]
- input/hm/test/d488c759a.tiff 1 [144.37003214 112.24489041 151.02178863] [71.41163807 79.79139525 70.18780297]


# Code

## Narezator (tile cutter)

### Objects

In [None]:
@contextmanager
def poolcontext(*args, **kwargs):
    """Yield a ``multiprocessing.Pool`` and terminate it when the block exits.

    All positional and keyword arguments are forwarded unchanged to
    ``mp.Pool``.

    Yields:
        The freshly created pool.
    """
    pool = mp.Pool(*args, **kwargs)
    try:
        yield pool
    finally:
        # Runs on normal exit *and* when the ``with`` body raises, so worker
        # processes are never left behind after a failed job.
        pool.terminate()
    
def mp_func(foo, args, n):
    """Apply ``foo`` to consecutive slices of ``args`` in a process pool.

    ``args`` is cut into consecutive slices of at most ``n`` items; each slice
    is passed to ``foo`` as a single argument through ``pool.map``.  The pool
    is started with ``n`` processes.  Whatever ``foo`` returns is discarded.
    """
    slice_starts = range(0, len(args), n)
    batches = [args[start:start + n] for start in slice_starts]
    with poolcontext(processes=n) as workers:
        workers.map(foo, batches)
    
def mp_foo(foo, args):
    """Call ``foo`` with the items of ``args`` unpacked as positional arguments."""
    result = foo(*args)
    return result

In [None]:
def to_gray(i):
    """Average ``i`` over its last axis and repeat the result 3 times along it."""
    mean_channel = np.mean(i, axis=-1, keepdims=True)
    return mean_channel.repeat(3, axis=-1)

def mp_sampler(dst, i_fn, m_fn, a_fn, wh, wh_mask, idxs, scale=4):
    """Export the tiles ``idxs`` of one slide as downscaled image/mask PNGs.

    A ``sampler.GdalSampler`` is built from ``(i_fn, m_fn, a_fn, wh, wh_mask)``
    and indexed with every entry of ``idxs``.  Each ``(image, mask)`` pair is
    written to ``dst/imgs/<slide>/<idx>.png`` and ``dst/masks/<slide>/<idx>.png``
    where ``<slide>`` is the name of ``i_fn`` without its suffix and ``<idx>``
    is the index zero-padded to six digits.

    Args:
        dst: Output root (path object supporting ``/``).
        i_fn: Image path (path object; ``with_suffix`` is called on it).
        m_fn: Mask path, forwarded to the sampler.
        a_fn: Annotation path, forwarded to the sampler.
        wh: ``(width, height)`` forwarded to the sampler; also the size that
            is divided by ``scale`` to get the saved resolution.
        wh_mask: Mask tile size, forwarded to the sampler.
        idxs: Iterable of sampler indices to export.
        scale: Integer downscale factor applied to both image and mask
            (defaults to the previously hard-coded 4).

    Note:
        The two output directories are created once up front, so they exist
        even when ``idxs`` is empty.
    """
    s = sampler.GdalSampler(i_fn, m_fn, a_fn, wh, wh_mask)

    # Everything below is identical for every tile, so compute it once.
    slide_name = i_fn.with_suffix('').name
    img_dir = dst / 'imgs' / slide_name
    mask_dir = dst / 'masks' / slide_name
    os.makedirs(str(img_dir), exist_ok=True)
    os.makedirs(str(mask_dir), exist_ok=True)
    out_size = (wh[0] // scale, wh[1] // scale)

    for idx in idxs:
        i, m = s[idx]
        orig_name = str(idx).zfill(6) + '.png'

        # Move axis 0 to the end for OpenCV.
        # assumes the sampler returns channel-first (C, H, W) arrays - TODO confirm
        i = i.transpose(1, 2, 0)
        m = m.transpose(1, 2, 0)

        # Swap the first and third channel before cv2.imwrite.
        # presumably the sampler yields RGB while imwrite expects BGR - verify
        i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)

        # Replicate the mask to 3 channels and scale it by 255.
        # presumably the mask is binary 0/1 - TODO confirm
        m = 255 * m.repeat(3, -1).astype(np.uint8)

        # Area averaging for the image; nearest-neighbour for the mask so no
        # intermediate values are introduced.
        i = cv2.resize(i, out_size, interpolation=cv2.INTER_AREA)
        m = cv2.resize(m, out_size, interpolation=cv2.INTER_NEAREST)

        cv2.imwrite(str(img_dir / orig_name), i)
        cv2.imwrite(str(mask_dir / orig_name), m)

In [None]:
imgs_path = Path('input/hm/train')
masks_path = Path('input/bigmasks/')
#p = Path('/home/sokolov/work/webinf/data/kidney/train/')
dst_path = Path('input/CUTS/cuts1024x25')
NUM_PROC = 16
wh = (1024,1024)

In [None]:
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell'])
ann_fns = sorted(utils.get_filenames(imgs_path, '*.json', filt))
masks_fns = sorted(utils.get_filenames(masks_path, '*.tiff', filt))
img_fns = sorted([a.with_suffix('.tiff') for a in ann_fns])
#img_fns, ann_fns, masks_fns

In [None]:
t = np.ones((8,8,3), dtype=np.uint8)
t.shape, t.dtype

In [None]:
# The third positional parameter of cv2.resize is `dst`, not the interpolation
# flag, so the flag must be passed by keyword (3 == cv2.INTER_AREA).
cv2.resize(t, (2,2), interpolation=cv2.INTER_AREA)

In [None]:
# for i,m in _s:
#     print(i.shape, m.shape)

In [None]:
#assert  False , 'DO ONCE'
# One-off export: for every (image, mask, annotation) triple, write all tiles
# of the slide under dst_path via mp_sampler, using NUM_PROC worker processes.
# NOTE(review): zip() pairs the three sorted lists by position and stops at the
# shortest one - confirm masks_fns lines up one-to-one with img_fns/ann_fns.
for i_fn, m_fn, a_fn in tqdm(zip(img_fns, masks_fns, ann_fns)):
    # Same `wh` is used for both the image and the mask tile size.
    const_args = i_fn, m_fn, a_fn, wh, wh
    # Built in the parent process only so len(_s) gives the number of indices.
    _s = sampler.GdalSampler(*const_args)
    # Bind every argument except the index chunk, which mp_func supplies.
    part_samp = partial(mp_sampler, *(dst_path, *const_args))
    mp_func(part_samp, range(len(_s)), NUM_PROC)
    #break

### Backgrounds

In [None]:
imgs_path = Path('input/hm/train')
masks_path = Path('input/bigmasks/')
dst_path = Path('input/backs020_x25')
#NUM_PROC = 16
wh = (1024,1024)
pct = .2


In [None]:
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell'])
ann_fns = sorted(utils.get_filenames(imgs_path, '*.json', filt))
masks_fns = sorted(utils.get_filenames(masks_path, '*.tiff', filt))
img_fns = sorted([a.with_suffix('.tiff') for a in ann_fns])
#img_fns, ann_fns, masks_fns

In [None]:
idx = 0
img_path = img_fns[idx]
mask_path = masks_fns[idx] 
img_anot_struct_path = img_path.parent / (img_path.stem + '-anatomical-structure.json')
recs = utils.jread(str(ann_fns[idx]))

In [None]:

ni = int(len(recs) * pct)
polys = utils.get_cortex_polygons(utils.jread(img_anot_struct_path))
s = sampler.BackgroundSampler(img_path, mask_path, polys, wh, wh, ni)

In [None]:
#assert  False , 'DO ONCE'
# One-off export of background tiles: for every slide, build a
# BackgroundSampler and write each (image, mask) pair it yields as PNGs under
# dst_path/imgs/<slide>/ and dst_path/masks/<slide>/.
# NOTE(review): zip() stops at the shortest list - confirm masks_fns lines up
# one-to-one with img_fns/ann_fns.
for i_fn, m_fn, a_fn in tqdm(zip(img_fns, masks_fns, ann_fns)):
    # Downscale factor applied to the saved tiles.
    SCALE = 4
    
    # Companion file '<stem>-anatomical-structure.json' next to the image.
    img_anot_struct_path = i_fn.parent / (i_fn.stem + '-anatomical-structure.json')
    recs = utils.jread(str(a_fn))
    # pct times the number of annotation records, truncated to int.
    # presumably the number of background tiles to sample - confirm in sampler
    ni = int(len(recs) * pct)
    polys = utils.get_cortex_polygons(utils.jread(img_anot_struct_path))
    s = sampler.BackgroundSampler(i_fn, m_fn, polys, wh, wh, ni)
    
    img_dir = dst_path / 'imgs' / i_fn.with_suffix('').name
    os.makedirs(str(img_dir), exist_ok=True)

    mask_dir = dst_path / 'masks' / i_fn.with_suffix('').name
    os.makedirs(str(mask_dir), exist_ok=True)
    #print(i_fn)
    for idx, (i,m) in enumerate(s):
        #print(i.shape, m.shape)
        # File name is the running index zero-padded to six digits.
        orig_name = (str(idx).zfill(6) + '.png')
        img_name = img_dir / orig_name 
        mask_name = mask_dir /orig_name
        
        # Move axis 0 to the end (assumes channel-first input - TODO confirm).
        i = i.transpose(1,2,0)
        #i = i.mean(-1, keepdims=True).astype(np.uint8).repeat(3,-1)
        #print(i.shape, i.dtype, m.shape, m.dtype)
        # Swap first and third channel before cv2.imwrite.
        i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
        
        m = m.transpose(1,2,0)
        # Replicate the mask to 3 channels and scale by 255
        # (presumably a 0/1 mask - TODO confirm).
        m = 255 * m.repeat(3,-1).astype(np.uint8)
        
        
        # Area averaging for the image, nearest-neighbour for the mask.
        i = cv2.resize(i, (wh[0]//SCALE, wh[1]//SCALE), interpolation=cv2.INTER_AREA)
        m = cv2.resize(m, (wh[0]//SCALE, wh[1]//SCALE), interpolation=cv2.INTER_NEAREST)
        
        cv2.imwrite(str(img_name), i)
        cv2.imwrite(str(mask_name), m)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Slider bounds are inclusive, so the upper bound must be the last valid
# index, len(s) - 1 (previously len(s), one past the end).
@interact(idx=(0, len(s) - 1),continuous_update=False)
def view(idx):
    """Show the image of sample ``idx`` from ``s`` with axis 0 moved last."""
    return Image.fromarray(s[idx][0].transpose(1,2,0))

## Datasets

In [None]:
root = Path('input/backs020_x25/')
sd = data.SegmentDataset(root / 'imgs', root / 'masks', mode_train=False)
len(sd)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

@interact(idx=(0, len(sd)-1),continuous_update=False)
def view(idx):
    """Return whatever ``sd._view`` produces for dataset item ``idx``."""
    rendered = sd._view(idx)
    return rendered

In [None]:
root = Path('input/SPLITS/split1024x25/train/')
sd = data.SegmentDataset(root / 'imgs', root / 'masks', mode_train=False)
len(sd)

In [None]:
tds = data.TagSegmentDataset(root / 'imgs', root / 'masks', mode_train=True)
len(tds)

In [None]:
tot = 0
for _,(_,c) in tds:
    tot+=c

In [None]:
tot/len(tds)

## Dataloaders

In [None]:
from config import cfg, cfg_init
from pprint import pprint

from callbacks import  denorm

In [None]:
cfg_init('src/configs/unet_gelb.yaml')
cfg['TRANSFORMERS']['TRAIN']['AUG'] = 'light_scale'

cfg['PARALLEL']['DDP'] = False
cfg['DATA']['TRAIN']['PRELOAD'] = False
cfg['DATA']['TRAIN']['MULTIPLY']["rate"] = 2
#cfg['DATA']['TRAIN']['DATASETS'] = ['train1024x25']

In [None]:
pprint(cfg)

In [None]:
datasets = data.build_datasets(cfg, dataset_types=['TRAIN','VALID'])

In [None]:
datasets = data.build_datasets(cfg)
tds = datasets['TRAIN']
vds = datasets['VALID']
len(tds)

In [None]:
def show_img(tds, idx):
    """Return item ``idx`` of ``tds`` as a PIL image after de-normalisation.

    The image tensor's shape and dtype are printed; the second element of the
    item is ignored.  The image is de-normalised with the mean/std from
    ``cfg.TRANSFORMERS``, squeezed, permuted with ``(1, 2, 0)``, scaled by 255
    and cast to ``uint8``.
    """
    tensor, _unused_mask = tds[idx]
    print(tensor.shape, tensor.dtype)
    restored = denorm(tensor, cfg.TRANSFORMERS.MEAN, cfg.TRANSFORMERS.STD)
    hwc = restored.squeeze().permute(1, 2, 0).cpu().numpy()
    as_bytes = (hwc * 255.).astype(np.uint8)
    return Image.fromarray(as_bytes)

In [None]:
show_img(tds, 0)

In [None]:
import random

In [None]:
from _data import make_datasets_folds

In [None]:
N_FOLDS = 4
datasets_as_folds = make_datasets_folds(cfg, datasets, N_FOLDS, shuffle=False)

In [None]:
datasets_as_folds

In [None]:
for dss in datasets_as_folds:
    dls = data.build_dataloaders(cfg, dss, pin=True, drop_last=False)
    tdl = dls['TRAIN']
    print(tdl, len(tdl))
    for b in tdl:
        pass

In [None]:
tot = 0
for dss in datasets_as_folds:
    for k, v in dss.items():
        print(len(v))
        if k == 'TRAIN':
            tot += len(v)

In [None]:
tot

In [None]:
act_len = len(tds)//2

In [None]:
idx = 47
show_img(tds, idx)

In [None]:
show_img(tds, idx + act_len)

In [None]:
def pimg(img):
    """Turn a tensor into a PIL image: squeeze, move to CPU, scale by 255, cast to uint8."""
    arr = img.squeeze().cpu().numpy()
    scaled = arr * 255.
    return Image.fromarray(scaled.astype(np.uint8))

In [None]:
import torch
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from model import load_model, FoldModel
mps = Path('output/2021_Mar_11_23_19_45_PAMBUH').rglob('*.pth')
mps = list(mps)
mps

In [None]:
m1 = load_model(1, mps[-1])

In [None]:
ms = [load_model(1, mp) for mp in mps if '499' in str(mp)]
ms

In [None]:
class FoldModel(torch.nn.Module):
    """Ensemble that averages the outputs of several models.

    Args:
        models: Iterable of ``torch.nn.Module`` instances that accept the same
            input and return tensors of identical shape.
    """

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members as submodules, so .parameters(),
        # .to(device), .eval() and state_dict() reach them; a plain Python
        # list is invisible to nn.Module.
        self.ms = torch.nn.ModuleList(models)

    def forward(self, x):
        """Run every member on ``x`` and return the mean over members."""
        res = torch.stack([m(x) for m in self.ms])
        return res.mean(0)

In [None]:
fold_model = FoldModel(ms)

In [None]:
len(vds)

In [None]:
idx = 120
show_img(vds, idx)

In [None]:
i,m = vds[idx]
i.shape, i.mean(), i.std()

In [None]:
i = i.view(1,*i.shape).repeat(2,1,1,1)
i.shape

In [None]:
with torch.no_grad():
    res = torch.sigmoid(fold_model(i))

In [None]:
plt.hist(res.cpu().numpy().flatten(), bins=50);

In [None]:
res.shape

In [None]:
pimg(res[0]>.7)

In [None]:
pimg(m)

In [None]:
%%timeit -n 10 -r 10
tds[0]

In [None]:
i.shape, i.dtype, i.max(), i.mean(), i.std()

In [None]:
m.shape, m.dtype, m.max()#, m.mean(), m.std()

In [None]:
dls = data.build_dataloaders(cfg, datasets, pin=True, drop_last=False)
tdl = dls['TRAIN']

In [None]:
dls

In [None]:
for xb, yb in dls['VALID2']:
    break

In [None]:
%%timeit -n 2 -r 2
for xb, yb in tdl:
    pass
    #break

In [None]:
xb.shape, xb.dtype, xb.mean(), xb.std()

In [None]:
yb.shape, yb.dtype, yb.max()

In [None]:
import rasterio as rio

In [None]:
path = Path('input/hm/train/')
imgs = list(path.glob('*.tiff'))
imgs

In [None]:
def save_tiff_uint8_3_band(img, path):
    """Write a 3-band ``uint8`` array to ``path`` as a deflate-compressed GeoTIFF.

    Args:
        img: Rank-3 ``uint8`` array; its last two axes are used as height and
            width, and the file is created with three band-interleaved bands.
        path: Destination file path passed to ``rio.open``.

    Raises:
        AssertionError: If ``img.dtype`` is not ``np.uint8``.
    """
    assert img.dtype == np.uint8
    # A maximum of 1 or less hints at a 0..1 image that was never rescaled.
    if img.max() <= 1. : print(f"Warning: saving tiff with max value is <= 1, {path}")
    _, h, w = img.shape
    # The context manager closes the dataset even if write() raises.
    with rio.open(path, 'w', driver='GTiff', height=h, width=w, count=3,
                  dtype=np.uint8, interleave='band', compress='deflate') as dst:
        dst.write(img)

In [None]:
# For every TIFF: if it reports a single band, stack its subdatasets into a
# 3-band uint8 array and save it under an 'upd/' folder next to the source;
# otherwise just read it. Per-channel mean and std are printed in both cases.
# NOTE(review): none of the datasets opened in this loop are closed.
for img in imgs:
    ds = rio.open(str(img))
    #break
    if ds.count == 1:
        print(f'Single channel: {img}')
        dss = ds.subdatasets
        i = np.zeros((3, *ds.shape), dtype=np.uint8)
        # NOTE(review): the loop variable rebinds `ds`, so after this branch
        # `ds` refers to the last subdataset, not the file opened above.
        for j, ds in enumerate(dss):
            ds = rio.open(ds)
            i[j]  = ds.read()
        # NOTE(review): nothing here creates the 'upd' directory - confirm it
        # exists before running.
        new_name = str(img.parent) + f'/upd/{img.name}'
        save_tiff_uint8_3_band(i, new_name)
    else:
        print(f'3 channels {img}')
        i = ds.read()
        
    # Statistics over axes 1 and 2, i.e. one value per entry along axis 0.
    print(i.mean((1,2)), i.std((1,2)))
    #break

In [None]:
ds.profile

In [None]:
tds = rio.open(new_name)

In [None]:
tds.shape, tds.count

In [None]:
tds.compression

In [None]:
tds.profile

In [None]:
i.shape

In [None]:
i.mean((1,2))

In [None]:
t.shape, t.dtype, t.mean()

In [None]:
t.shape, t.dtype, t.mean()

In [None]:
t.shape, t.dtype, t.mean()