In [None]:
import os
import sys
# Move the working directory one level up and make 'src' importable; later cells use
# paths relative to the working directory ('input/...') and `import utils`.
# NOTE(review): os.chdir('..') is relative, so re-running this cell without a kernel
# restart moves up one more level each time.
os.chdir('..')
sys.path.append('src')

In [None]:
import os
from datetime import datetime

from pathlib import Path
import shutil
import numpy as np

import utils
from collections import defaultdict

In [None]:
def get_images(root):
    '''
        Yield, for every entry found in root/'imgs', the list of paths located directly inside it.
    '''
    # The listing of root/'imgs' is materialized before the first item is yielded.
    for subfolder in list((root/'imgs').glob('*')):
        yield list(subfolder.glob('*'))
        
def get_maskname_for_img(img_name):
    '''
        Map <data>/<imgs_dir>/<subfolder>/<file> to <data>/masks/<subfolder>/<file>.
    '''
    imgs_dir = img_name.parent.parent
    data_dir = imgs_dir.parent
    # Part of the path below the images folder: <subfolder>/<file>.
    tail = img_name.relative_to(imgs_dir)
    return data_dir / 'masks' / tail

def create_split(filenames, pct=.05):
    '''
        Randomly pick int(len(filenames) * pct) items (without replacement) as the split part.

        Returns (main_part, split): main_part keeps the original order of the
        items that were not picked, split holds the picked items.
    '''
    n_split = int(len(filenames) * pct)
    split = np.random.choice(filenames, n_split, replace=False).tolist()

    main_part = []
    for fn in filenames:
        if fn in split:
            continue
        main_part.append(fn)
    return main_part, split

def create_split_from_polys(filenames, split_names):
    '''
        Partition filenames by whether their .name is listed in split_names.

        Returns (rest, selected), both in the original order.
    '''
    rest, selected = [], []
    for fn in filenames:
        if fn.name in split_names:
            selected.append(fn)
        else:
            rest.append(fn)
    return rest, selected

def select_samples_from_polys(name):
    '''
        Pick up to 20 sample file names whose polygon centroid lies inside the split polygon.

        The split polygon is built from the first record of input/split_jsons/<name>.json,
        the candidate records come from input/hm/train/<name>.json.
        A selected record with index i is reported as the zero-padded name '<i:06d>.png'.
        Prints the number of selected names and returns them as a list.
    '''
    split_records = utils.jread(Path(f'input/split_jsons/{name}.json'))
    val_poly = utils.json_record_to_poly(split_records[0])[0]

    records = utils.jread(Path(f'input/hm/train/{name}.json'))
    val_names = []
    for idx, record in enumerate(records):
        poly = utils.json_record_to_poly(record)[0]
        if not val_poly.contains(poly.centroid):
            continue
        # Hard cap: at most 20 samples are taken per image.
        if len(val_names) >= 20:
            continue
        val_names.append(str(idx).zfill(6) + '.png')
    print(len(val_names))
    return val_names

def copy_split(split, root, dst_path):
    '''
        Copy the file `split` into dst_path, keeping its path relative to `root`.
        Missing destination folders are created.
    '''
    target = dst_path / split.relative_to(root)
    os.makedirs(target.parent, exist_ok=True)
    shutil.copy(split, target)
    
def create_save_splits(root, dst_path, split_pct=None):
    '''
        Split every image subfolder of `root` into a train and a val part and copy both to `dst_path`.

        `root` is a folder path with 2 folders inside: imgs, masks.
        For each subfolder in imgs (i.e. 1e2425f28):
            * if split_pct is given, its files are split randomly by split_pct:
                split_pct = 0.05
                len(p1) ~ .95 * len(p)
                len(p2) ~ .05 * len(p)
            * if split_pct is None, the val files are chosen by
              select_samples_from_polys from the polygons in input/split_jsons.
        p1 is copied to dst_path/train, p2 to dst_path/val; the mask of every
        image (see get_maskname_for_img) is copied alongside it. Paths below
        `root` are preserved.

        No timestamp is added here: put it into `dst_path` before calling if
        every run should get its own folder.
    '''
    for img_cuts in get_images(root):
        if not img_cuts:
            # An empty subfolder (or a stray file in imgs/) has nothing to split;
            # indexing img_cuts[0] below would raise IndexError.
            continue

        folder_name = img_cuts[0].parent.name
        print(folder_name)

        if split_pct is not None:
            print('splitting randomly by percent')
            # Use the function argument: the previous `pct=val_pct` referred to a
            # notebook global that may be undefined or stale.
            split_imgs_1, split_imgs_2 = create_split(img_cuts, pct=split_pct)
        else:
            print('splitting by predefined polygons in input/split_jsons')
            split_names = select_samples_from_polys(folder_name)
            print('selected:', split_names)
            split_imgs_1, split_imgs_2 = create_split_from_polys(img_cuts, split_names)

        print(len(split_imgs_1), len(split_imgs_2))

        for part_name, part_imgs in (('train', split_imgs_1), ('val', split_imgs_2)):
            for i in part_imgs:
                m = get_maskname_for_img(i)
                copy_split(i, root, dst_path/part_name)
                copy_split(m, root, dst_path/part_name)


In [None]:
# Source folder with the cuts to split; create_save_splits reads its 'imgs' and 'masks' subfolders.
root = Path('input/CUTS/cuts_B_1536x33/')
# Timestamp (year_month_day_hour_minute_second) so that every run gets its own output folder.
timestamp = '{:%Y_%b_%d_%H_%M_%S}'.format(datetime.now())
# root.parent.parent is 'input', so the output goes to input/SPLITS/split2048x25_<timestamp>.
# NOTE(review): the 'split2048x25' label does not match the '1536x33' in the source folder name — confirm it is intended.
dst_path = root.parent.parent / f'SPLITS/split2048x25_{timestamp}/'
# Uncomment (together with the commented call in the next cell) to split randomly by percentage.
#val_pct = 0.05

In [None]:
# Random split by percentage (needs val_pct from the previous cell):
#create_save_splits(root, dst_path, val_pct)
# Without a percentage the val samples are chosen from the polygons in input/split_jsons.
create_save_splits(root, dst_path)

# Split scores

In [None]:
import pickle

# Load the scores object from disk.
# NOTE(review): split_scores indexes it as scores[<folder stem>][<list of ints>], so it is
# presumably a dict of numpy arrays — confirm.
# pickle.load can execute arbitrary code; only load files from a trusted source.
with open('input/scores_color.pkl', 'rb') as f:
    scores = pickle.load(f)

In [None]:
def split_scores(ps, scores):
    '''
        Select, for every folder in `ps`, the scores of the images present in that folder.

        ps: iterable of folder paths; each folder holds files named '<int>.png'.
        scores: mapping from folder stem to a container that supports indexing
            with a list of ints (presumably a numpy array — TODO confirm).

        Returns a dict {folder stem: scores[folder stem][indices]} where the
        indices are the integer file stems in ascending order.
        Raises ValueError if a .png stem is not an integer and KeyError if a
        folder stem is missing from `scores`.
    '''
    split = {}
    for p in ps:
        # glob() returns entries in arbitrary, filesystem-dependent order; sort the
        # indices so the selected scores line up with the sorted file names and the
        # result is reproducible across machines.
        idxs = sorted(int(i.stem) for i in p.glob('*.png'))
        split[p.stem] = scores[p.stem][idxs]
    return split

In [None]:
# Select the scores of the images that ended up in the val part of the split and save them.
# NOTE(review): this path has no timestamp suffix, unlike the dst_path built in the split cell —
# point it at the actual split folder.
ps = list(Path('input/SPLITS/split2048x25/val/imgs/').glob('*'))
val_split = split_scores(ps, scores)
# NOTE(review): 'scores__color_val.pkl' has a double underscore, unlike 'scores_color_train.pkl'
# written by the next cell — confirm which name the consumers expect.
with open('scores__color_val.pkl', 'wb') as f:
    pickle.dump(val_split, f)

In [None]:
# Select the scores of the images in the train part of the split and save them.
# NOTE(review): the result is stored in `val_split` again, so after this cell that name
# holds the train scores and the val scores from the previous cell are no longer in memory.
ps = list(Path('input/SPLITS/split2048x25/train/imgs/').glob('*'))
val_split = split_scores(ps, scores)
with open('scores_color_train.pkl', 'wb') as f:
    pickle.dump(val_split, f)

# Split images

In [None]:
from functools import partial
import shutil

In [None]:
# Show the entries of this cuts dataset (cell output only).
# NOTE(review): `root` is reassigned two cells below before anything is copied.
root = Path('input/CUTS/cuts_FP_1536x33/imgs')
list(root.glob('*'))

In [None]:
# One entry per split: the folder stems listed in an entry are copied to 'val' for that
# split and every other folder goes to 'train' (see the copy loop below).
split_stems = [
        ['0486052bb', 'e79de561c'],
        ['2f6ecfcdf', 'afa5e8098'],
        ['1e2425f28', '8242609fa'],
        ['cb2d976f4', 'c68fe75ea'],
        ]

In [None]:
# Source dataset and a timestamped destination folder for the splits built below.
root = Path('input/backs/grid_x50/')
timestamp = '{:%Y_%b_%d_%H_%M_%S}'.format(datetime.now())
dst = Path(f'input/backs/{timestamp}')

# Name filter handed to utils.get_filenames.
# NOTE(review): presumably drops entries whose name contains any of the `bans` substrings — confirm in utils.
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell', '_sc'])
masks_fns = sorted(utils.get_filenames(root / 'masks', '*', filt))
#borders_fns = sorted(utils.get_filenames(root / 'borders', '*', filt))
# Image paths are derived from the mask paths (same name, sibling 'imgs' folder), so the
# two sorted lists pair up element by element in the zip of the copy loop.
img_fns = sorted([m.parent.parent/'imgs'/m.name for m in masks_fns])
#img_fns, masks_fns

In [None]:
# Build one train/val dataset per entry of split_stems under the timestamped `dst` folder.
# Only imgs and masks are copied here.
for split in split_stems:
    # Folder name = first character of each of the two val stems, e.g. '0' + 'e' -> '0e'.
    name = split[0][0] + split[1][0]
    path = dst / name
    train_path, val_path = path/'train', path/'val'
    # os.makedirs without exist_ok raises FileExistsError if the folder is already there.
    os.makedirs(str(path))
    os.makedirs(str(train_path))
    os.makedirs(str(val_path))

    for imgs, masks in zip(img_fns, masks_fns):
        # Folders whose stem is listed in `split` are held out for validation.
        # A cell-local name is used on purpose: rebinding `dst_path` here would
        # overwrite the destination that the create_save_splits cell relies on.
        part_path = val_path if imgs.stem in split else train_path

        imgs_dst = part_path / 'imgs' / imgs.stem
        masks_dst = part_path / 'masks' / imgs.stem

        # copytree creates the missing parent folders ('imgs', 'masks') itself and
        # fails if the destination folder already exists.
        shutil.copytree(imgs, imgs_dst)
        shutil.copytree(masks, masks_dst)

# TEST SPLIT

In [None]:
# Splits for the test experiment. The copy loop below only uses the first stem of each
# pair: it is the folder copied to 'test', and its first two characters name the output folder.
# NOTE(review): this rebinds `split_stems` from the "Split images" section, and '2f6ecfcdf'
# appears as the second stem of the last two pairs — confirm both are intended.
split_stems = [
['4ef6695ce', 'b9a3865fc'],
['e79de561c', '8242609fa'],
['26dc41664', 'cb2d976f4'],
['afa5e8098', 'b2dc8411c'],
['1e2425f28', '0486052bb'],
['c68fe75ea', 'aaa6a05cc'],
['54f2eec69', '2f6ecfcdf'],
['095bf7a1f', '2f6ecfcdf']
]

In [None]:
# Source dataset for the test split.
root = Path('input/CUTS/cuts_B_1536x33/')
# NOTE(review): `timestamp` is not used by the copy loop in the next cell, which writes to a fixed path.
timestamp = '{:%Y_%b_%d_%H_%M_%S}'.format(datetime.now())

# Name filter handed to utils.get_filenames.
# NOTE(review): presumably drops entries whose name contains any of the `bans` substrings — confirm in utils.
filt = partial(utils.filter_ban_str_in_name, bans=['-', '_ell', '_sc'])
masks_fns = sorted(utils.get_filenames(root / 'masks', '*', filt))
# NOTE(review): borders are listed independently of the masks; the zip in the copy loop
# assumes both sorted lists contain the same names in the same order — confirm.
borders_fns = sorted(utils.get_filenames(root / 'borders', '*', filt))
# Image paths are derived from the mask paths (same name, sibling 'imgs' folder).
img_fns = sorted([m.parent.parent/'imgs'/m.name for m in masks_fns])
#img_fns, masks_fns

In [None]:

# Build one sub_data/test dataset per entry of split_stems under input/SPLITS/SMTH.
for split in split_stems:
    # Output folder is named after the first two characters of the held-out stem.
    path = Path(f'input/SPLITS/SMTH/{split[0][:2]}')
    train_path, test_path = path/'sub_data', path/'test'
    # os.makedirs without exist_ok raises FileExistsError if the folder is already there,
    # so a previous run has to be removed before this cell is executed again.
    os.makedirs(str(path))
    os.makedirs(str(train_path))
    os.makedirs(str(test_path))

    # NOTE(review): borders_fns is listed independently of masks_fns; zip assumes the
    # three sorted lists line up element by element — confirm.
    for imgs, masks, borders in zip(img_fns, masks_fns, borders_fns):
        # Only the first stem of the pair is held out as test; everything else is sub_data.
        # A cell-local name is used on purpose: rebinding `dst_path` here would
        # overwrite the destination that the create_save_splits cell relies on.
        part_path = test_path if imgs.stem == split[0] else train_path

        imgs_dst = part_path / 'imgs' / imgs.stem
        masks_dst = part_path / 'masks' / imgs.stem
        borders_dst = part_path / 'borders' / imgs.stem

        # copytree creates the missing parent folders itself and fails if the
        # destination folder already exists.
        shutil.copytree(imgs, imgs_dst)
        shutil.copytree(masks, masks_dst)
        shutil.copytree(borders, borders_dst)