# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
os.chdir('..')
sys.path.append('src')

In [None]:
from PIL import Image
from pathlib import Path
from functools import partial, reduce
from collections import defaultdict
import multiprocessing as mp
from contextlib import contextmanager

import cv2
import numpy as np
from tqdm.auto import tqdm


import utils
import data
import sampler

import matplotlib.pyplot as plt
%matplotlib inline

## Stats

### Train
- input/hm/train/4ef6695ce.tiff 1 [167.48280243 131.99601198 172.1840991 ] [53.6715833  74.18087042 50.01145663]
- input/hm/train/b9a3865fc.tiff 3 [179.12058181 156.16663604 190.57224837] [57.33262864 66.68898569 53.26343919]
- input/hm/train/e79de561c.tiff 3 [168.63503256 140.3692677  179.29400798] [47.20465075 63.66449426 39.38715776]
- input/hm/train/8242609fa.tiff 3 [168.39762381 146.67506496 179.53406931] [68.85119601 75.34678935 66.46572523]
- input/hm/train/cb2d976f4.tiff 3 [156.64766308 137.08461345 163.56018869] [85.22797112 86.17162812 85.83934117]
- input/hm/train/26dc41664.tiff 1 [148.37026934 117.8351192  152.91550503] [74.70058493 81.79527692 73.60989621]
- input/hm/train/b2dc8411c.tiff 3 [145.70774605 130.53835743 153.51077155] [90.19669735 90.07284394 91.13878548]
- input/hm/train/afa5e8098.tiff 3 [144.79319853 121.24794863 154.07814847] [69.64398246 70.21465338 69.66028281]
- input/hm/train/0486052bb.tiff 3 [155.04365695 140.8247772  163.27546087] [86.20611987 87.26473343 86.73144338]
- input/hm/train/1e2425f28.tiff 1 [155.27908076 109.30259728 155.74962968] [60.41321818 74.58395251 58.64478221]
- input/hm/train/c68fe75ea.tiff 1 [177.38737838 146.33150453 185.17719842] [34.14849425 52.73887964 25.65912719]
- input/hm/train/aaa6a05cc.tiff 3 [168.53653974 146.01495027 178.91997248] [70.9594629  80.8427825  67.05318295]
- input/hm/train/54f2eec69.tiff 3 [159.57231947 134.360238   164.28528555] [69.91997402 77.88897022 67.87669295]
- input/hm/train/095bf7a1f.tiff 1 [141.92589491 113.2553824  144.47921706] [77.36870257 81.80458049 77.13849731]
- input/hm/train/2f6ecfcdf.tiff 3 [150.00407259 134.96295758 157.40463786] [89.23076809 89.6362252  89.93819569]


### Test
- input/hm/test/57512b7f1.tiff 1 [151.83879508 129.67205882 155.49909348] [78.47813674 85.79231764 77.58990357]
- input/hm/test/2ec3f1bb9.tiff 3 [176.16665974 150.29656892 187.40280662] [59.38260377 69.46464358 55.26716755]
- input/hm/test/aa05346ff.tiff 1 [166.86605324 137.62296148 174.57705023] [45.55560103 56.21724257 43.40574615]
- input/hm/test/3589adb90.tiff 3 [172.07196767 155.13386244 181.57171249] [72.44817365 78.52393954 70.44612831]
- input/hm/test/d488c759a.tiff 1 [144.37003214 112.24489041 151.02178863] [71.41163807 79.79139525 70.18780297]


# Code

## Narezator

### Objects

In [None]:
@contextmanager
def poolcontext(*args, **kwargs):
    """Yield a ``multiprocessing.Pool`` and always terminate it on exit.

    All positional and keyword arguments are forwarded unchanged to
    ``mp.Pool``.  The pool is terminated when the ``with`` body finishes,
    including when the body raises, so worker processes are not left behind.
    """
    pool = mp.Pool(*args, **kwargs)
    try:
        yield pool
    finally:
        # Runs on normal exit and on exceptions alike.
        pool.terminate()
    
def mp_func(foo, args, n):
    """Apply ``foo`` to consecutive slices of ``args`` in a worker pool.

    ``args`` is cut into slices of at most ``n`` items, so ``n`` is both the
    slice length and the number of worker processes.  Every slice is handed
    to ``foo`` as a single argument; the values returned by ``foo`` are
    discarded.
    """
    batches = []
    for start in range(0, len(args), n):
        batches.append(args[start:start + n])
    with poolcontext(processes=n) as workers:
        workers.map(foo, batches)
    
def mp_foo(foo, args):
    """Call ``foo`` with the items of ``args`` unpacked as positional arguments."""
    result = foo(*args)
    return result

In [None]:
def to_gray(i):
    """Replace the last axis of ``i`` by its mean, repeated three times.

    The mean is taken over the last axis (kept as a length-1 axis) and that
    axis is then repeated 3 times, so the result always ends in an axis of
    length 3.
    """
    channel_mean = np.mean(i, axis=-1, keepdims=True)
    return channel_mean.repeat(3, -1)

def mp_sampler(dst, i_fn, m_fn, a_fn, wh, b_fn, idxs):
    """Cut the samples ``idxs`` of one image and save them as downscaled PNGs.

    For every index the sampler returns an (image, mask, border) triple.
    Each array is moved from channel-first to channel-last layout, resized
    to ``(wh[0] // SCALE, wh[1] // SCALE)`` and written to
    ``dst/{imgs,masks,borders}/<image name without suffix>/<idx:06d>.png``.

    Args:
        dst: root output directory (a path object supporting ``/``).
        i_fn: path of the source image; its name without suffix names the
            per-image output folders.
        m_fn: mask path, forwarded to ``sampler.GdalSampler``.
        a_fn: annotation path, forwarded to ``sampler.GdalSampler``.
        wh: two-element window size, forwarded to ``sampler.GdalSampler``.
        b_fn: border path, forwarded to ``sampler.GdalSampler``.
        idxs: iterable of sample indices handled by this call.

    Returns:
        None. The function is used for its file-writing side effects.
    """
    SCALE = 4
    # cv2.resize target size, identical for image, mask and border tiles.
    out_size = (wh[0] // SCALE, wh[1] // SCALE)

    s = sampler.GdalSampler(i_fn, m_fn, a_fn, wh, b_fn)

    # The output folders depend only on the image, so create them once
    # instead of on every loop iteration.
    slide_name = i_fn.with_suffix('').name
    img_dir = dst / 'imgs' / slide_name
    mask_dir = dst / 'masks' / slide_name
    border_dir = dst / 'borders' / slide_name
    for out_dir in (img_dir, mask_dir, border_dir):
        os.makedirs(str(out_dir), exist_ok=True)

    def _label_to_png(arr):
        # Channel-first label -> channel-last uint8 image with the last axis
        # repeated three times.  The factor 255 presumably maps a {0, 1}
        # label to {0, 255} -- TODO confirm against the sampler output.
        arr = arr.transpose(1, 2, 0)
        arr = 255 * arr.repeat(3, -1).astype(np.uint8)
        # Nearest-neighbour keeps label values intact while downscaling.
        return cv2.resize(arr, out_size, interpolation=cv2.INTER_NEAREST)

    for idx in idxs:
        i, m, b = s[idx]
        orig_name = str(idx).zfill(6) + '.png'

        cv2.imwrite(str(border_dir / orig_name), _label_to_png(b))

        i = i.transpose(1, 2, 0)
        # Swap the first and third channel before handing the tile to cv2.imwrite.
        i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
        i = cv2.resize(i, out_size, interpolation=cv2.INTER_AREA)

        cv2.imwrite(str(img_dir / orig_name), i)
        cv2.imwrite(str(mask_dir / orig_name), _label_to_png(m))
    return

In [None]:
imgs_path = Path('input/hm/train')
masks_path = Path('input/masks/bigmasks/')
borders_path = Path('input/masks/borders/')
dst_path = Path('input/CUTS/cuts_B_1536x25')

# imgs_path = Path('input/scleros_glomi/')
# masks_path = Path('input/scleros_glomi/scle_masks/')
# dst_path = Path('input/scleros_glomi/scle_cuts_1024/')

NUM_PROC = 16
wh = (1024 + 512, 1024 + 512)
#wh = (1024, 1024)

In [None]:
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell', '_sc'])
ann_fns = sorted(utils.get_filenames(imgs_path, '*.json', filt))
masks_fns = sorted(utils.get_filenames(masks_path, '*.tiff', filt))
borders_fns = sorted(utils.get_filenames(borders_path, '*.tiff', filt))
img_fns = sorted([a.with_suffix('.tiff') for a in ann_fns])
#img_fns, ann_fns, masks_fns, borders_fns

In [None]:
#assert  False , 'DO ONCE'
# zip() stops at the shortest input, so lists of different length would
# silently drop images or pair files that do not belong together.
_file_lists = (img_fns, masks_fns, ann_fns, borders_fns)
if len({len(fns) for fns in _file_lists}) != 1:
    raise ValueError(f'File lists differ in length: {[len(fns) for fns in _file_lists]}')

for i_fn, m_fn, a_fn, b_fn in tqdm(zip(*_file_lists), total=len(img_fns)):
    const_args = i_fn, m_fn, a_fn, wh
    # This sampler is only used to learn how many samples the image has;
    # every worker builds its own sampler inside mp_sampler.
    _s = sampler.GdalSampler(*const_args)
    part_samp = partial(mp_sampler, dst_path, *const_args, b_fn)
    mp_func(part_samp, range(len(_s)), NUM_PROC)
    #break

### Backgrounds

In [None]:
imgs_path = Path('input/hm/train')
masks_path = Path('input/masks/bigmasks/')
borders_path = Path('input/masks/borders/')
dst_path = Path('input/backs/backs_x25_cortex_b')
#NUM_PROC = 16
poly_type = 'Cortex'
#poly_type = 'Medulla'
#poly_type = None
NN = None#100
wh = (1024, 1024)
pct = .2


In [None]:
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell'])
ann_fns = sorted(utils.get_filenames(imgs_path, '*.json', filt))
masks_fns = sorted(utils.get_filenames(masks_path, '*.tiff', filt))
borders_fns = sorted(utils.get_filenames(borders_path, '*.tiff', filt))
img_fns = sorted([a.with_suffix('.tiff') for a in ann_fns])
#img_fns, ann_fns, masks_fns

In [None]:
idx = 0
img_path = img_fns[idx]
mask_path = masks_fns[idx] 
img_anot_struct_path = img_path.parent / (img_path.stem + '-anatomical-structure.json')
recs = utils.jread(str(ann_fns[idx]))

In [None]:
#assert  False , 'DO ONCE'
# zip() stops at the shortest input, so lists of different length would
# silently drop images or pair files that do not belong together.
_file_lists = (img_fns, masks_fns, ann_fns, borders_fns)
if len({len(fns) for fns in _file_lists}) != 1:
    raise ValueError(f'File lists differ in length: {[len(fns) for fns in _file_lists]}')

SCALE = 4
# cv2.resize target size shared by image, mask and border tiles.
out_size = (wh[0]//SCALE, wh[1]//SCALE)

for i_fn, m_fn, a_fn, b_fn in tqdm(zip(*_file_lists), total=len(img_fns)):
    img_anot_struct_path = i_fn.parent / (i_fn.stem + '-anatomical-structure.json')
    recs = utils.jread(str(a_fn))
    # Number of tiles to draw: the fixed NN, or the fraction `pct` of the
    # number of annotation records of this image.
    ni = NN if NN is not None else int(len(recs) * pct)

    if poly_type is not None:
        polys = utils.get_polygons_by_type(utils.jread(img_anot_struct_path), poly_type)
    else:
        polys = None

    s = sampler.BackgroundSampler(i_fn, m_fn, polys, wh, ni, border_path=b_fn)

    slide_name = i_fn.with_suffix('').name
    img_dir = dst_path / 'imgs' / slide_name
    mask_dir = dst_path / 'masks' / slide_name
    border_dir = dst_path / 'borders' / slide_name
    for out_dir in (img_dir, mask_dir, border_dir):
        os.makedirs(str(out_dir), exist_ok=True)

    for idx, (i,m,b) in enumerate(s):
        # Tiles whose mean intensity is below 10 or above 245 are skipped in
        # roughly 80% of the cases; the random draw only happens for them.
        tile_mean = i.mean()
        if (tile_mean < 10 or tile_mean > 245) and (np.random.random() > .2):
            continue

        orig_name = (str(idx).zfill(6) + '.png')

        # Channel-first -> channel-last, last axis repeated three times, x255.
        b = b.transpose(1,2,0)
        b = 255 * b.repeat(3,-1).astype(np.uint8)
        b = cv2.resize(b, out_size, interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(str(border_dir / orig_name), b)

        i = i.transpose(1,2,0)
        # Swap the first and third channel before handing the tile to cv2.imwrite.
        i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
        i = cv2.resize(i, out_size, interpolation=cv2.INTER_AREA)

        m = m.transpose(1,2,0)
        m = 255 * m.repeat(3,-1).astype(np.uint8)
        m = cv2.resize(m, out_size, interpolation=cv2.INTER_NEAREST)

        cv2.imwrite(str(img_dir / orig_name), i)
        cv2.imwrite(str(mask_dir / orig_name), m)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# The (min, max) slider abbreviation is inclusive on both ends, so the last
# selectable index must be len(s) - 1.
@interact(idx=(0, len(s) - 1),continuous_update=False)
def view(idx):
    """Show the first element of sample ``idx`` of ``s`` as a PIL image."""
    return Image.fromarray(s[idx][0].transpose(1,2,0))

### Polygons

In [None]:
imgs_path = Path('input/hm/test/')
dst_path = Path('input/ssl_cortex')
#NUM_PROC = 16
poly_type = 'Cortex'
#poly_type = 'Medulla'
#poly_type = None
NN = 100
wh = (1024, 1024)

In [None]:
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell'])
img_fns = sorted(utils.get_filenames(imgs_path, '*.tiff', filt))
img_fns

In [None]:
#assert  False , 'DO ONCE'
SCALE = 4
# cv2.resize target size of the saved tiles.
out_size = (wh[0]//SCALE, wh[1]//SCALE)

# Only the image list is needed here; the annotation file is looked up next
# to each image instead of being taken from separately built lists.
for i_fn in tqdm(img_fns):
    img_anot_struct_path = i_fn.parent / (i_fn.stem + '-anatomical-structure.json')

    if NN is not None:
        ni = NN
    else:
        # Fraction `pct` of the number of annotation records of this image.
        # NOTE(review): assumes `<image>.json` exists next to the image and
        # that `pct` is defined -- confirm before running with NN = None.
        recs = utils.jread(str(i_fn.with_suffix('.json')))
        ni = int(len(recs) * pct)

    if poly_type is not None:
        polys = utils.get_polygons_by_type(utils.jread(img_anot_struct_path), poly_type)
    else:
        polys = None
    s = sampler.PolySampler(i_fn, polys, wh, ni)

    img_dir = dst_path / 'imgs' / i_fn.with_suffix('').name
    os.makedirs(str(img_dir), exist_ok=True)

    for idx, i in enumerate(s):
        # Tiles whose mean intensity is below 10 or above 245 are skipped in
        # roughly 80% of the cases; the random draw only happens for them.
        tile_mean = i.mean()
        if (tile_mean < 10 or tile_mean > 245) and (np.random.random() > .2):
            continue

        orig_name = (str(idx).zfill(6) + '.png')
        img_name = img_dir / orig_name

        i = i.transpose(1,2,0)
        # Swap the first and third channel before handing the tile to cv2.imwrite.
        i = cv2.cvtColor(i, cv2.COLOR_BGR2RGB)
        i = cv2.resize(i, out_size, interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(img_name), i)

## Extra dataset

In [None]:
l = Path('input/extra_data/external_datasetA_orig/masks/').rglob('*.png')
l = list(l)

In [None]:
import shutil

In [None]:
dst = Path('input/extra_data/external_datasetA_/')

In [None]:
# cnt = 0
# for m in l:
#     s = np.array(Image.open(m)).sum() / 255
#     if s > 100:
#         shutil.copy(str(m), str(dst / 'masks/b' / m.name))
#         img_name = m.parent.parent.parent / 'imgs/a/' / m.name
        
#         shutil.copy(str(img_name), str(dst / 'imgs/b' / m.name))
#         #break
#     #print(i.sum()/255)
#     #break

In [None]:
# for m in l:
#     i = Image.open(m)
#     i = np.expand_dims(np.array(i), 2).repeat(3, 2)
#     i = Image.fromarray((255*i).astype(np.uint8))
#     p = m.parent.parent / 'b'/ m.name
#     i.save(p)
#     #print(i.size)
#     #break

In [None]:
#root = Path('input/CUTS/cuts2048x25/')
root = Path('input/extra_data/external_datasetA_/')
sd = data.SegmentDataset(root / 'imgs', root / 'masks', mode_train=False)
len(sd)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

@interact(idx=(0, len(sd) - 1), continuous_update=False)
def view(idx):
    """Return whatever ``sd._view`` produces for sample ``idx``."""
    rendered = sd._view(idx)
    return rendered

## Scleros

In [None]:
root = Path('input/scleros_glomi/')
anns = list(root.glob('*.json'))
anns

In [None]:
import rasterio as rio

In [None]:
idx=0
n = anns[idx].stem
img_name = f'input/hm/test/{n}.tiff'
img = rio.open(img_name)
mask_arr = np.zeros(img.shape, dtype=np.uint8)

In [None]:
j = utils.jread(anns[idx])
polys = [utils.json_record_to_poly(r)[0] for r in j]

In [None]:
for poly in polys:
    poly_pts = poly.exterior.xy
    poly_pts = np.expand_dims(np.array(poly_pts).astype(np.int32).T,0)
    cv2.fillPoly(mask_arr, poly_pts, 255)

In [None]:
dst_path =  Path(f'input/scleros_glomi/scle_masks/{n}_sc.tiff')
utils.save_tiff_uint8_single_band(mask_arr, dst_path)

In [None]:
# dst_path =  Path(f'input/scleros_glomi/scle_masks/b9a3865fc.tiff')
# img_name = f'input/hm/train/b9a3865fc.tiff'

In [None]:

# merge_name = dst_path.parent / (dst_path.stem + '_merge.tiff')
# utils.tiff_merge_mask(img_name, dst_path, merge_name)

In [None]:
dst_path

In [None]:
name = 'd488c759a'

In [None]:
m1 = rio.open(f'input/bigmasks/{name}.tiff').read()

In [None]:
m2 = rio.open(f'input/scleros_glomi/scle_masks/{name}_sc.tiff').read()

In [None]:
m1.max(), m1.shape,  m1.dtype, m2.max(), m2.shape, m2.dtype 

In [None]:
m = m1 + m2
m.shape, m.max()

In [None]:
utils.save_tiff_uint8_single_band((m[0]*255).astype(np.uint8), f'input/scleros_glomi/scle_masks/{name}.tiff')

## Datasets

In [None]:
root = Path('input/CUTS/cuts_B_1536x25/')
sd = data.BorderSegmentDataset(root, mode_train=False)
len(sd)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

@interact(idx=(0, len(sd) - 1), continuous_update=False)
def view(idx):
    """Return whatever ``sd._view`` produces for sample ``idx`` with ``border=True``."""
    rendered = sd._view(idx, border=True)
    return rendered

In [None]:
root = Path('input/SPLITS/split1024x25/train/')
sd = data.SegmentDataset(root / 'imgs', root / 'masks', mode_train=False)
len(sd)

In [None]:
tds = data.TagSegmentDataset(root / 'imgs', root / 'masks', mode_train=True)
len(tds)

In [None]:
tot = 0
for _,(_,c) in tds:
    tot+=c

In [None]:
tot/len(tds)

## Weighted Datasets

In [None]:
import scipy
import json
import shapely
import pickle
import rasterio as rio

In [None]:
def gpoly(poly):
    """Return the exterior ring of ``poly`` as two NumPy arrays ``(x, y)``."""
    xs, ys = poly.exterior.xy
    return np.array(xs), np.array(ys)

def get_poly_raster(img_name, p):
    """Read the bounding-box window of polygon ``p`` from raster ``img_name``.

    Args:
        img_name: path of the raster file (converted with ``str``).
        p: geometry whose ``bounds`` are ``(x1, y1, x2, y2)``.

    Returns:
        The window as read by the dataset, with axes reordered from
        ``(0, 1, 2)`` to ``(1, 2, 0)``.
    """
    x1, y1, x2, y2 = p.bounds
    # Use the dataset as a context manager so the file handle is released as
    # soon as the window has been read.
    with rio.open(str(img_name)) as ds:
        rast = ds.read(window=((y1, y2), (x1, x2)))
    return rast.transpose(1, 2, 0)

def get_poly_mask(p):
    """Rasterise the exterior ring of ``p`` with its bounding box at the origin.

    The ring is shifted so that its smallest x and y become 0 and is then
    filled with 255 in a zero-initialised ``(height, width, 3)`` uint8
    buffer, where height and width are the truncated extents of the ring.
    Only the exterior ring is drawn.

    Returns:
        The filled ``numpy.ndarray`` buffer.
    """
    x, y = gpoly(p)
    # Move the bounding box of the ring to (0, 0).
    x = x - x.min()
    y = y - y.min()

    buf = np.zeros((int(y.max()), int(x.max()), 3), dtype=np.uint8)

    # cv2.fillPoly expects int32 points with a leading "polygon" axis.
    pts = np.array([x, y]).T.astype(np.int32)
    cv2.fillPoly(buf, np.expand_dims(pts, 0), (255, 255, 255))

    return buf

def gen_stats(img_name, polys):
    """Yield an ``(area, color)`` pair for every polygon in ``polys``.

    ``area`` is the square root of the polygon area; ``color`` is the mean of
    the raster values selected by the polygon mask.
    NOTE(review): the boolean indexing requires the mask and the raster
    window to have the same shape -- confirm for non-integer bounds.
    """
    for poly in tqdm(polys):
        inside = get_poly_mask(poly) > 0
        pixels = get_poly_raster(img_name, poly)[inside]
        yield poly.area ** .5, pixels.mean()

def get_poly_scores(img_name, polys):
    """Score every polygon by how unusual its mean colour is among ``polys``.

    For each polygon the percentile rank of its colour statistic among all
    polygons is mapped to ``2 * |rank / 100 - 0.5|``: 0 at the median and
    approaching 1 at either extreme.  Only the colour statistic contributes;
    the area statistic produced by ``gen_stats`` is currently not used.

    Args:
        img_name: raster path forwarded to ``gen_stats``.
        polys: iterable of polygons forwarded to ``gen_stats``.

    Returns:
        ``numpy.ndarray`` with one score per polygon (empty for no polygons).
    """
    # `import scipy` alone does not guarantee that the `scipy.stats`
    # submodule is loaded, so import it explicitly here.
    from scipy import stats

    all_stats = np.array(list(gen_stats(img_name, polys)))
    if all_stats.size == 0:
        # Nothing to rank; column slicing below would fail on an empty array.
        return np.array([])
    colors = all_stats[:, 1]

    scores = [2 * abs(stats.percentileofscore(colors, c) / 100 - .5) for c in colors]
    return np.array(scores)

def read_polys(data):
    """Build a shapely polygon from the first coordinate ring of every record.

    Args:
        data: iterable of records providing
            ``record['geometry']['coordinates'][0]``.

    Returns:
        List of polygons.  Records whose polygon cannot be constructed are
        reported with ``print`` and skipped, so the list may be shorter than
        ``data``.
    """
    # `import shapely` alone does not guarantee that `shapely.geometry` is
    # loaded, so import the class explicitly.
    from shapely.geometry import Polygon

    polys = []
    for d in data:
        cd = d['geometry']['coordinates'][0]
        try:
            poly = Polygon(cd)
        except Exception as e:
            # Best effort: report the bad record and move on instead of
            # appending a stale polygon from a previous iteration.
            print(e, d)
            continue
        polys.append(poly)
    return polys

def get_polys(img_name):
    """Load the polygons stored in the ``.json`` file next to ``img_name``."""
    json_path = img_name.with_suffix('.json')
    with open(json_path, 'r') as fh:
        records = json.load(fh)
    return read_polys(records)

class WeightedDataset:
    """Dataset wrapper that returns a randomly drawn item on every access.

    The index passed to ``__getitem__`` is ignored; an item is drawn with
    probability proportional to ``(1 + score) ** 2``.

    Args:
        dataset: indexable object with ``len``.
        scores: one number per dataset item (any sequence or NumPy array).
        replacement: forwarded to ``numpy.random.choice`` as ``replace``.
    """
    def __init__(self, dataset, scores, replacement=True):
        assert len(dataset) == len(scores), (len(dataset), len(scores))
        # Accept plain sequences as well as NumPy arrays.
        scores = (1 + np.asarray(scores, dtype=float)) ** 2
        # Normalised so the weights can be used as probabilities.
        self.scores = scores / scores.sum()
        self.dataset = dataset
        self.replacement = replacement
        self.idxs = list(range(len(self.dataset)))
        #self.sampler = WeightedRandomSampler(self.scores, len(self.scores), replacement=replacement)

    def __getitem__(self, _):
        num_samples = 1
        idx = np.random.choice(self.idxs, num_samples, self.replacement, self.scores)
        return self.dataset[idx[0]]

    def __len__(self): return len(self.dataset)

In [None]:
imgs = list(Path('input/hm/train/').glob('*.tiff'))
res = {}
for img_name in imgs:
    polys = get_polys(img_name)
    scores = get_poly_scores(img_name, polys)
    res[img_name.stem] = scores

In [None]:
with open('scores_color.pkl', 'wb') as f:
    pickle.dump(res, f)

In [None]:
def IA(axs):
    """Iterate over a 2-D grid of axes row by row.

    The number of columns is taken from the first row.
    """
    yield from (axs[r][c] for r in range(len(axs)) for c in range(len(axs[0])))

In [None]:
f, axs = plt.subplots(4,4, sharex=True, sharey=True, figsize=(16,16))
ia = IA(axs)

for k,v in res.items():
    ax = next(ia)
    ax.hist(v, bins=50);

In [None]:
ds = data.ImageDataset('input/CUTS/cuts2048x25/imgs/afa5e8098/', "*.png")

In [None]:
wds = WeightedDataset(dataset=ds, scores=scores)

In [None]:
wds[0]

## SSL Dataset

In [None]:

import _data
import albumentations as albu

In [None]:
datasets = data.build_datasets(cfg, dataset_types=['TRAIN','VALID', 'SSL'])
sds = datasets['SSL']
len(sds)

In [None]:
a,b = sds[1]

In [None]:
da = denorm(a, mean=cfg.TRANSFORMERS.MEAN, std=cfg.TRANSFORMERS.STD)
da = da.numpy().squeeze().transpose(1,2,0) * 255
da = da.astype(np.uint8)
db = denorm(b, mean=cfg.TRANSFORMERS.MEAN, std=cfg.TRANSFORMERS.STD)
db = db.numpy().squeeze().transpose(1,2,0) * 255
db = db.astype(np.uint8)

a.shape, da.shape

In [None]:
Image.fromarray(da)

In [None]:
Image.fromarray(db)

In [None]:
import torch

In [None]:
dl = torch.utils.data.DataLoader(sds)

In [None]:
i = iter(dl)

In [None]:
a,b = next(i)

In [None]:
a.shape

In [None]:
mask = (a > 0.5) + (a < 0)
mask.shape

In [None]:
a[mask]

In [None]:
ds = data.ImageDataset('input/backs_x25_medula/imgs/', '*/*.png')
ds.process_item = _data.expander

In [None]:
a = augs.get_aug('light_scale', cfg.TRANSFORMERS)

In [None]:
class TransformSSLDataset:
    """Pair every item of ``dataset`` with a transformed version of itself.

    ``transforms`` is called as ``transforms(image=item, mask=None)`` and its
    ``'image'`` entry is returned next to the untouched item.  When
    ``transforms`` is None an empty ``albu.Compose`` pipeline is used.
    """

    def __init__(self, dataset, transforms, is_masked=False):
        self.dataset = dataset
        if transforms is None:
            transforms = albu.Compose([])
        self.transforms = transforms
        self.is_masked = is_masked

    def __getitem__(self, idx):
        original = self.dataset.__getitem__(idx)
        augmented = self.transforms(image=original, mask=None)['image']
        return original, augmented

    def __len__(self):
        return len(self.dataset)
    
class SSLDataset:
    """Dataset wrapper that applies a random crop, flip and 90-degree rotation.

    Args:
        dataset: indexable object with ``len`` whose items are passed to the
            augmentation pipeline as ``image``.
        crop_size: arguments unpacked into ``albu.RandomCrop``.
    """
    def __init__(self, dataset, crop_size):
        #TODO scale?
        self.d4 = albu.Compose([albu.RandomCrop(*crop_size), albu.Flip(), albu.RandomRotate90()])
        self.dataset = dataset

    def __getitem__(self, idx):
        i = self.dataset[idx]
        i = self.d4(image=i)['image']
        return i

    def __len__(self):
        # Mirrors the wrapped dataset so the wrapper can be used wherever a
        # sized dataset is expected.
        return len(self.dataset)

In [None]:
crop_size = (256,256)
tds = SSLDataset(ds, crop_size)

In [None]:
a = tds[0]
Image.fromarray(a).resize((512,512))

In [None]:
Image.fromarray(b).resize((512,512))

## Dataloaders

In [None]:
from config import cfg, cfg_init
from pprint import pprint

from callbacks import  denorm

In [None]:
cfg_init('src/configs/unet_gelb.yaml')
cfg['TRANSFORMERS']['TRAIN']['AUG'] = 'light_scale'

cfg['PARALLEL']['DDP'] = False
cfg['DATA']['TRAIN']['PRELOAD'] = False
cfg['DATA']['TRAIN']['MULTIPLY']["rate"] = 2
#cfg['DATA']['TRAIN']['DATASETS'] = ['train1024x25']
#cfg['VALID']['BATCH_SIZE'] = 4

In [None]:
pprint(cfg)

In [None]:
datasets = data.build_datasets(cfg, dataset_types=['TRAIN','VALID'])
tds = datasets['TRAIN']
vds = datasets['VALID']
len(tds)

In [None]:
dls = data.build_dataloaders(cfg, datasets)

In [None]:
a,b,c = tds[0]

In [None]:
plt.imshow(b.squeeze())

In [None]:
plt.imshow(c.squeeze())

In [None]:
def show_img(tds, idx):
    """Return sample ``idx`` of ``tds`` as a de-normalised PIL image.

    The image is the first element of the sample.  Its shape and dtype are
    printed, the normalisation is undone with ``cfg.TRANSFORMERS.MEAN`` and
    ``cfg.TRANSFORMERS.STD``, and the result is scaled by 255 and cast to
    uint8.
    """
    # Only the image is needed; indexing instead of two-way unpacking also
    # works for samples that carry more than two elements.
    img = tds[idx][0]
    print(img.shape, img.dtype)
    img = denorm(img, cfg.TRANSFORMERS.MEAN, cfg.TRANSFORMERS.STD)
    # Channel-first tensor -> channel-last NumPy array.
    img = img.squeeze().permute(1,2,0).cpu().numpy()
    img = (img * 255.).astype(np.uint8)
    return Image.fromarray(img)

In [None]:
show_img(tds, 0)

In [None]:
import random
from _data import make_datasets_folds

In [None]:
N_FOLDS = 4
datasets_as_folds = make_datasets_folds(cfg, datasets, N_FOLDS, shuffle=False)

In [None]:
datasets_as_folds

In [None]:
for dss in datasets_as_folds:
    dls = data.build_dataloaders(cfg, dss, pin=True, drop_last=False)
    tdl = dls['TRAIN']
    print(tdl, len(tdl))
    for b in tdl:
        pass

In [None]:
tot = 0
for dss in datasets_as_folds:
    for k, v in dss.items():
        print(len(v))
        if k == 'TRAIN':
            tot += len(v)

# MODEL TESTS

In [None]:
def pimg(img):
    """Convert a tensor to a PIL image after scaling its values by 255.

    Size-1 dimensions are squeezed away, the tensor is moved to the CPU and
    the scaled values are cast to uint8.
    """
    arr = img.squeeze().cpu().numpy()
    scaled = arr * 255.
    return Image.fromarray(scaled.astype(np.uint8))

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch


In [None]:
from model import load_model, FoldModel
mps = Path('output/2021_Apr_02_10_43_49_PAMBUH/').rglob('*.pth')
mps = sorted(list(mps))
mps

In [None]:
import segmentation_models_pytorch as smp
from collections import OrderedDict
from loss import dice_loss

In [None]:
m1 = smp.MAnet(encoder_name='se_resnet50')#MAnet(encoder_name='timm-res2net50_26w_4s')
#m1 = smp.MAnet()
#m1 = smp.MAnet(encoder_name='timm-efficientnet-b4')

In [None]:
state_dict = torch.load(str(mps[-1]))['model_state']
# new_state_dict = OrderedDict()
# for k, v in state_dict.items():
#     if k.startswith('module'):
#         k = k.lstrip('module')[1:]
#         new_state_dict[k] = v


In [None]:
m1.load_state_dict(state_dict)

In [None]:
m1 = m1.cuda()
m1 = m1.eval()

In [None]:
thrs1 = np.logspace(0, 1, num=10)/100

thrs2 = 1-np.logspace(0, 1, num=10)/100
thrs2 = thrs2[::-1]

thrs = np.concatenate([thrs1, np.arange(.2,.9, .05), thrs2])
thrs

In [None]:
import ttach
m2 = ttach.SegmentationTTAWrapper(m1, ttach.aliases.d4_transform())

In [None]:
x,y = tds[0]
x = x#[:,:256,:256]
x = x.view(1, *x.shape).cuda()
p = m2(x)
p.sigmoid().max()

In [None]:
t = p[0,0].detach().cpu().numpy()
t.shape

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(t)

In [None]:
plt.hist(t.ravel(), bins=50);

# DICE CV

In [None]:
dices = []
cnt = 0
for x,y in dls['VALID']:
    #x = x.view(1, *x.shape).cuda()
    with torch.no_grad():
        pred = m1(x.cuda()).sigmoid().cpu()
    #pred = (pred > .5).float()
    preds = [pred > thr for thr in thrs]
    #dice = dice_loss(pred, y)
    dice = [dice_loss(p, y) for p in preds]
    dices.append(dice)
dices = np.array(dices)    

In [None]:
dices_mean = dices.mean(0)
dices_mean

In [None]:
dices_mean.max()

In [None]:
plt.plot(dices_mean)

In [None]:
thrs[np.argmax(dices_mean)]

In [None]:
dices_mean

# Edge

In [None]:
import torch
from loss import EdgeLoss

import matplotlib.pyplot  as plt
%matplotlib inline

In [None]:
datasets

In [None]:
dls = data.build_dataloaders(cfg, datasets)

In [None]:
dl = dls['VALID']
idl = iter(dl)

In [None]:
xb,yb = next(idl)
xb.shape, yb.shape

In [None]:
pb = m1(xb.cuda()).sigmoid()
yb = yb.cuda()

In [None]:
yb = yb[:3,:,:64,:64]
pb = pb[:3,:,:64,:64]
pb.shape, pb.max()

# RESAVE

In [None]:
import rasterio as rio

In [None]:
path = Path('input/hm/train/')
imgs = list(path.glob('*.tiff'))
imgs

In [None]:
def save_tiff_uint8_3_band(img, path):
    """Write a uint8 array as a deflate-compressed 3-band GeoTIFF.

    Args:
        img: uint8 array with three dimensions; the last two are used as
            height and width, and the file is created with ``count=3``.
        path: destination file path.

    Raises:
        AssertionError: if ``img.dtype`` is not ``np.uint8``.
    """
    assert img.dtype == np.uint8
    # A maximum of at most 1 suggests the data was not scaled to 0..255.
    if img.max() <= 1. : print(f"Warning: saving tiff with max value is <= 1, {path}")
    _, h, w = img.shape
    # The context manager closes the writer even if `write` raises.
    with rio.open(path, 'w', driver='GTiff', height=h, width=w, count=3, dtype=np.uint8, interleave='band', compress='deflate') as dst:
        dst.write(img)

In [None]:
# Re-save single-band TIFFs as 3-band files and print per-channel statistics.
# NOTE: the unconditional `break` right after opening the first file leaves
# the loop immediately, so everything below it is currently never executed.
for img in imgs:
    ds = rio.open(str(img))
    break
    if ds.count == 1:
        print(f'Single channel: {img}')
        # Each sub-dataset is read into one of the three output channels.
        dss = ds.subdatasets
        i = np.zeros((3, *ds.shape), dtype=np.uint8)
        # NOTE(review): the loop variable reuses the name `ds`, shadowing the
        # dataset opened above.
        for j, ds in enumerate(dss):
            ds = rio.open(ds)
            i[j]  = ds.read()
        # The 3-band copy goes into an `upd/` folder next to the source image.
        new_name = str(img.parent) + f'/upd/{img.name}'
        save_tiff_uint8_3_band(i, new_name)
    else:
        print(f'3 channels {img}')
        i = ds.read()
        
    # Per-channel mean and standard deviation over axes 1 and 2.
    print(i.mean((1,2)), i.std((1,2)))
    #break

In [None]:
tds = rio.open('input/hm/test/aa05346ff.tiff')

In [None]:
path = Path('input/hm/train/')
imgs = list(path.glob('*.tiff'))
imgs

In [None]:
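# Hand-written notes, one value per training image (presumably the pixel
# size / scale of each image -- TODO confirm). Kept as comments because the
# raw lines are not valid Python.
# b9a3865fc .65
# e79de561c .5
# 8242609fa .65
# cb2d976f4 .65
# b2dc8411c .65
# afa5e8098 .5
# 0486052bb .65
# aaa6a05cc .65
# 2f6ecfcdf .65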

In [None]:
for i in imgs:
    ds = rio.open(str(i))
    #print(i.name, ds.tags())