# Imports

In [None]:
# IPython autoreload in mode 2: imported modules are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
# Move to the parent directory so the relative 'src' and 'input/...' paths used below resolve.
os.chdir('..')
# Extend the import path with 'src' (presumably where `utils` and `sampler` live — verify).
sys.path.append('src')

In [None]:
from PIL import Image
from pathlib import Path
from functools import partial, reduce
from collections import defaultdict
import multiprocessing as mp
from contextlib import contextmanager

import cv2
import numpy as np
from tqdm.auto import tqdm


import utils
from sampler import GdalSampler

# Code

## Narezator (tile slicer)

In [None]:
@contextmanager
def poolcontext(*args, **kwargs):
    """Yield a ``multiprocessing.Pool`` and terminate it when the block exits.

    All positional and keyword arguments are forwarded to ``mp.Pool``.
    Termination happens in a ``finally`` clause, so the worker processes are
    stopped even when the body of the ``with`` statement raises.
    """
    pool = mp.Pool(*args, **kwargs)
    try:
        yield pool
    finally:
        pool.terminate()
    
def mp_func(foo, args, n):
    """Apply ``foo`` to consecutive slices of ``args`` in a process pool.

    ``n`` serves both as the slice length and as the number of worker
    processes. Every slice ``args[k:k + n]`` is handed to ``foo`` as one
    argument. The values returned by ``pool.map`` are discarded.
    """
    starts = range(0, len(args), n)
    batches = [args[start:start + n] for start in starts]
    with poolcontext(processes=n) as workers:
        workers.map(foo, batches)
    
def mp_foo(foo, args):
    """Call ``foo`` with the items of ``args`` unpacked as positional arguments."""
    outcome = foo(*args)
    return outcome

In [None]:
def mp_sampler(dst, i_fn, a_fn, wh, wh_mask, idxs):
    """Write the image/mask tiles with the given indices to PNG files.

    A ``GdalSampler`` is built from ``(i_fn, a_fn, wh, wh_mask)`` and indexed
    with every value of ``idxs``. Each image tile is saved as
    ``dst/imgs/<image name without suffix>/<idx>.png`` and the matching mask
    as ``dst/masks/<image name without suffix>/<idx>.png``.

    Args:
        dst: Output root; must support the ``/`` operator (e.g. ``Path``).
        i_fn: Image path; ``with_suffix`` is called on it, so a ``Path``.
        a_fn: Annotation path, forwarded to ``GdalSampler``.
        wh: Forwarded to ``GdalSampler``.
        wh_mask: Forwarded to ``GdalSampler``.
        idxs: Iterable of sampler indices to export.

    Note:
        The two output directories are created once up front, so they exist
        even when ``idxs`` is empty.
    """
    sampler = GdalSampler(i_fn, a_fn, wh, wh_mask)

    # Both folders depend only on the image name, not on the tile index, so
    # they are created once instead of on every loop iteration.
    stem = i_fn.with_suffix('').name
    img_dir = dst / 'imgs' / stem
    mask_dir = dst / 'masks' / stem
    os.makedirs(str(img_dir), exist_ok=True)
    os.makedirs(str(mask_dir), exist_ok=True)

    for idx in idxs:
        img, mask = sampler[idx]
        tile_name = str(idx) + '.png'

        # assumes the sampler returns the image channels-first (C, H, W) — TODO confirm
        cv2.imwrite(str(img_dir / tile_name), img.transpose(1, 2, 0))
        # The mask gets a trailing axis that is repeated three times before saving.
        cv2.imwrite(str(mask_dir / tile_name), np.expand_dims(mask, -1).repeat(3, -1))

In [None]:
# Directory searched for the annotation (*.json) files in the next cell.
p = Path('input/hm/train')
#p = Path('/home/sokolov/work/webinf/data/kidney/train/')
# Output root handed to mp_sampler by the (commented-out) export cell below.
dst_path = Path('input/train')
# Passed to mp_func as `n` in the export cell (chunk length and pool size).
NUM_PROC = 16
# Used as both `wh` and `wh_mask` for GdalSampler in the export cell.
wh = (2048, 2048)

In [None]:
# Filename filter configured with the substrings '-' and '_ell'
# (exact semantics are defined by utils.filter_ban_str_in_name).
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell'])
ann_fns = utils.get_filenames(p, '*.json', filt)
# Image paths are derived from the annotation paths by swapping the suffix to .tiff.
img_fns = [a.with_suffix('.tiff') for a in ann_fns]
# Display the first image/annotation pair as a quick sanity check.
img_fns[0], ann_fns[0]

In [None]:
#assert  False , 'DO ONCE'
# for i_fn,a_fn in tqdm(zip(img_fns, ann_fns)):
#     const_args = i_fn, a_fn, wh, wh
#     _s = GdalSampler(*const_args)
#     part_samp = partial(mp_sampler, *(dst_path, *const_args))
#     mp_func(part_samp, range(len(_s)), NUM_PROC)

## Datasets

### General ones

In [None]:
class Dataset:
    """Base dataset over the files below ``root`` that match a glob ``pattern``.

    Subclasses implement ``load_item``. ``process_item`` is a post-processing
    hook (identity by default) that ``__getitem__`` applies to every loaded
    item.
    """

    def __init__(self, root, pattern):
        """Collect and sort the matching files; fail if there are none."""
        self.root = Path(root)
        self.pattern = pattern
        # Sorting fixes the index -> file mapping.
        self.files = sorted(self.root.glob(self.pattern))
        self._is_empty('There is no matching files!')

    def apply_filter(self, filter_fn):
        """Replace ``self.files`` with ``filter_fn(self.files)``; the result must be non-empty."""
        self.files = filter_fn(self.files)
        self._is_empty()

    def _is_empty(self, msg='There is no item in dataset!'):
        """Raise ``AssertionError(msg)`` when the dataset holds no files."""
        # Attach msg so the failure explains itself instead of being a bare assert.
        assert len(self.files) > 0, msg

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return self.process_item(self.load_item(idx))

    def load_item(self, idx):
        """Load the raw item at position ``idx``; must be provided by subclasses."""
        raise NotImplementedError

    def process_item(self, item):
        """Post-process a loaded item; the base implementation returns it unchanged."""
        return item
    
class ImageDataset(Dataset):
    """Dataset whose raw items are the matched files opened with ``Image.open``."""

    def load_item(self, idx):
        """Open the file stored at position ``idx`` and return the PIL image."""
        return Image.open(str(self.files[idx]))
    
class PairDataset:
    """Pair two equally long datasets: item ``idx`` is ``(ds1[idx], ds2[idx])``."""

    def __init__(self, ds1, ds2):
        self.ds1 = ds1
        self.ds2 = ds2
        self.check_len()

    def check_len(self):
        """Assert that both wrapped datasets have the same length."""
        assert len(self.ds1) == len(self.ds2)

    def __len__(self):
        # Both lengths are equal (checked at construction), so ds1 is representative.
        return len(self.ds1)

    def __getitem__(self, idx):
        first = self.ds1.__getitem__(idx)
        second = self.ds2.__getitem__(idx)
        return first, second

class ConcatDataset:
    """
    Flat view over several datasets, indexed as one long dataset.

    A lookup table from global index to owning dataset is built up front, so
    fetching an item is one dict lookup plus one call into that dataset
    (to avoid recursive calls, like in torchvision variant).
    """
    def __init__(self, dss):
        self.length = 0
        # global index -> (position of the owning dataset in dss,
        #                  global index of that dataset's first item)
        self.ds_map = {}
        for ds_idx, ds in enumerate(dss):
            offset = self.length
            for j in range(len(ds)):
                self.ds_map[offset + j] = ds_idx, offset
            self.length += len(ds)
        self.dss = dss

    def load_item(self, idx):
        """Return the item at global position ``idx``.

        Raises:
            StopIteration: if ``idx`` is at or beyond the total length
                (kept as-is for callers that rely on it).
        """
        if idx >= self.__len__():
            raise StopIteration
        ds_idx, offset = self.ds_map[idx]
        # Convert the global index to an index local to the owning dataset.
        return self.dss[ds_idx].__getitem__(idx - offset)

    def _is_empty(self, msg='There is no item in dataset!'):
        """Raise ``AssertionError(msg)`` when the concatenation holds no items."""
        # This class has no `files` attribute; emptiness is decided by the
        # accumulated length.
        assert self.length > 0, msg

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.load_item(idx)

def expander(x):
    """Convert ``x`` to an array; anything that is not 3-D gets a trailing axis of size 3.

    Arrays that already have three dimensions are returned as converted.
    Otherwise a new last axis is appended and the data is repeated three
    times along it.
    """
    arr = np.array(x)
    if arr.ndim == 3:
        return arr
    with_channel = arr[..., np.newaxis]
    return np.repeat(with_channel, 3, axis=-1)

### Pam specific

In [None]:
class SegmentDataset:
    """Image/mask pairs collected from parallel per-image folders of PNG files.

    The entries found under ``imgs_path`` and ``masks_path`` are paired in the
    order ``utils.get_filenames`` returns them. Each pair of folders becomes a
    ``PairDataset`` of two ``ImageDataset`` objects, and all pairs are chained
    into one ``ConcatDataset``. With ``mode_train=True`` every loaded image and
    mask is passed through ``expander``; otherwise items stay PIL images.
    """

    def __init__(self, imgs_path, masks_path, mode_train=True):
        self.img_folders = utils.get_filenames(imgs_path, '*', lambda x: False)
        self.masks_folders = utils.get_filenames(masks_path, '*', lambda x: False)
        self.mode_train = mode_train

        pairs = []
        for img_folder, mask_folder in zip(self.img_folders, self.masks_folders):
            members = (ImageDataset(img_folder, '*.png'), ImageDataset(mask_folder, '*.png'))
            if self.mode_train:
                # Install the array conversion as each member's post-processing hook.
                for member in members:
                    member.process_item = expander
            pairs.append(PairDataset(*members))

        self.dataset = ConcatDataset(pairs)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def _view(self, idx):
        """Return a 50/50 ``Image.blend`` of the image and the mask at ``idx``."""
        return Image.blend(*self.__getitem__(idx), .5)
    
def build_datasets(mode_train=True, root='input/train/1024'):
    """Build the dictionary of datasets, keyed by split name.

    Args:
        mode_train: Forwarded to ``SegmentDataset``.
        root: Directory holding the ``imgs`` and ``masks`` sub-directories.
            Defaults to ``'input/train/1024'``, the location that used to be
            hard-coded, so existing calls behave as before.

    Returns:
        ``{'TRAIN': <SegmentDataset>}``.
    """
    root = Path(root)
    sd = SegmentDataset(root / 'imgs', root / 'masks', mode_train=mode_train)
    return {'TRAIN': sd}

In [None]:
# mode_train=False leaves the items as PIL images, which SegmentDataset._view
# needs for Image.blend.
root = Path('input/train/1024')
sd = SegmentDataset(root / 'imgs', root / 'masks', mode_train=False)
# Total number of image/mask pairs.
len(sd)

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Slider bounds are inclusive, so the last valid index is len(sd) - 1.
# continuous_update has to be set on the slider widget itself; passed as a
# keyword to `interact` it does not configure the slider.
@interact(idx=widgets.IntSlider(min=0, max=max(len(sd) - 1, 0), continuous_update=False))
def view(idx):
    """Show the image/mask blend for item ``idx`` of ``sd``."""
    return sd._view(idx)

## Dataloaders

In [None]:
def create_dataloader(dataset, sampler, shuffle, batch_size, num_workers, drop_last, pin):
    """Wrap ``dataset`` in a ``DataLoader`` configured from the given options.

    ``pin`` is passed as ``pin_memory``; ``collate_fn`` is explicitly ``None``.
    All other arguments map to the ``DataLoader`` keyword of the same name.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': pin,
        'drop_last': drop_last,
        'collate_fn': None,
        'sampler': sampler,
    }
    return DataLoader(dataset, **loader_options)

In [None]:
def build_dataloaders(datasets, samplers=None, batch_sizes=None, num_workers=1, drop_last=False, pin=False):
    """Create one dataloader per entry of ``datasets``.

    Args:
        datasets: Mapping ``kind -> dataset``.
        samplers: Optional mapping ``kind -> sampler``. A missing mapping,
            a missing key or a ``None`` value all mean "no sampler".
        batch_sizes: Optional mapping ``kind -> batch size``. A missing
            mapping, a missing key or a ``None`` value all mean batch size 1.
        num_workers: Forwarded to ``create_dataloader`` for every kind.
        drop_last: Forwarded to ``create_dataloader`` for every kind.
        pin: Forwarded to ``create_dataloader`` for every kind.

    Returns:
        Dict ``kind -> dataloader`` with the same keys as ``datasets``.
    """
    # The documented defaults are None; treat them as "nothing configured"
    # instead of subscripting None.
    samplers = {} if samplers is None else samplers
    batch_sizes = {} if batch_sizes is None else batch_sizes

    dls = {}
    for kind, dataset in datasets.items():
        sampler = samplers.get(kind)
        # Shuffle only the 'TRAIN' split, and only when no sampler is given
        # (DataLoader treats sampler and shuffle as mutually exclusive).
        shuffle = sampler is None and kind == 'TRAIN'
        batch_size = batch_sizes.get(kind)
        if batch_size is None:
            batch_size = 1
        dls[kind] = create_dataloader(dataset, sampler, shuffle, batch_size, num_workers, drop_last, pin)

    return dls

In [None]:
import torch
from torch.utils.data import DataLoader

In [None]:
# Default build (mode_train=True): loaded images and masks go through `expander`.
datasets = build_datasets()
# The 'TRAIN' sampler is None, so build_dataloaders turns shuffling on for that split.
dls = build_dataloaders(datasets,samplers={'TRAIN':None}, batch_sizes={'TRAIN':32}, num_workers=4, pin=True, drop_last=False)

In [None]:
tdl = dls['TRAIN']
# Fetch a single batch and stop; `b` keeps that batch for inspection below.
for b in tdl:
    break

In [None]:
# Shape and dtype of the first (b[0]) and second (b[1]) element of the fetched batch.
b[0].shape, b[0].dtype, b[1].shape, b[1].dtype