In [1]:
import os

import tqdm
import lmdb

import numpy as np

from scipy.misc import imread, imsave, imresize

In [2]:
def labeled_img_preprocess_binary_case(image):
    """Threshold a label image into a two-class mask.

    Values strictly below 128 map to class 0; all remaining values map to
    class 1.

    Parameters
    ----------
    image : numpy.ndarray
        Label image to threshold.

    Returns
    -------
    numpy.ndarray
        ``int8`` array with the same shape as ``image`` containing 0 or 1.
    """
    below_threshold = image < 128
    return np.where(below_threshold, 0, 1).astype(np.int8)

In [3]:
def input_img_preprocess(image):
    """Convert an image to ``float32`` and map its values via ``v / 255 - 0.5``.

    Parameters
    ----------
    image : numpy.ndarray
        Input image; it is not modified.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of the same shape as ``image``.
    """
    # astype() hands back a fresh copy, so the in-place arithmetic below
    # never touches the caller's array.
    scaled = image.astype(np.float32)
    scaled /= 255
    scaled -= 0.5
    return scaled

In [4]:
def shapes_preprocess(img):
    """Rearrange an image into a single-item, channels-first 4-D array.

    A 2-D input first gets a trailing axis of length 1. A leading axis of
    length 1 is then added and the axes are permuted with
    ``transpose(0, 3, 1, 2)``, so a ``(H, W, C)`` input comes out as
    ``(1, C, H, W)``.

    Parameters
    ----------
    img : numpy.ndarray
        2-D or 3-D array.

    Returns
    -------
    numpy.ndarray
        4-D array with the axes ordered as described above.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    batched = img[np.newaxis]
    return batched.transpose(0, 3, 1, 2)

In [5]:
def get_valid_patches(img_shape, patch_size, central_points):
    """Flag the centre points whose surrounding patch fits inside the image.

    For each point the patch spans ``[p - patch_size // 2,
    p - patch_size // 2 + patch_size)`` along every axis. A point is valid
    when that half-open range lies within ``[0, img_shape)`` on all axes.

    Parameters
    ----------
    img_shape : sequence of int
        Size of the image along each axis.
    patch_size : int
        Edge length of the patch.
    central_points : array-like of int, shape (n_points, n_axes)
        Candidate patch centres, one row per point.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(n_points,)``; ``True`` where the patch
        fits entirely inside the image.
    """
    central_points = np.asarray(central_points)

    # Floor division keeps the bounds integral, so an odd patch_size does
    # not produce half-pixel offsets.
    start = central_points - patch_size // 2
    end = start + patch_size

    # `end` is an exclusive bound, so a patch may finish exactly on the
    # image border (end == size) and still be fully inside.
    inside = np.logical_and(start >= 0, end <= np.array(img_shape))
    return np.all(inside, axis=-1)


def extract_patches(input_img, labeled_img, n_classes=2, patch_size=64, class_patches_number=100):
    """Sample square patches centred on pixels of each class.

    For every label in ``range(n_classes)`` the pixels of ``labeled_img``
    equal to that label are collected, those whose patch would not fit in
    the image are dropped (via ``get_valid_patches``), and up to
    ``class_patches_number`` of the remaining centres are picked at random
    using the global ``np.random`` state. The same window is cut from
    ``input_img`` and ``labeled_img`` and each cut is passed through
    ``shapes_preprocess``.

    Parameters
    ----------
    input_img : numpy.ndarray
        Image to cut input patches from; indexed along its first two axes.
    labeled_img : numpy.ndarray
        Label image to cut target patches from and to pick centres from.
    n_classes : int, optional
        Number of labels to sample centres for.
    patch_size : int, optional
        Edge length of the patch.
    class_patches_number : int, optional
        Maximum number of patches taken per label.

    Returns
    -------
    tuple of (list, list)
        ``(X, Y)``: the input patches and the matching label patches, in
        the same order.
    """
    X = []
    Y = []
    # Integer offset: slice bounds must be integers, and true division
    # would turn them into floats on Python 3.
    half_size = patch_size // 2

    for label in range(n_classes):
        positions = np.argwhere(labeled_img == label)

        accepted_patches_mask = get_valid_patches(labeled_img.shape, patch_size, positions)
        positions = positions[accepted_patches_mask]
        # Shuffle BEFORE truncating. np.argwhere returns positions in raster
        # order, so truncating first would always keep the same leading
        # run of neighbouring pixels and the shuffle would only reorder them.
        np.random.shuffle(positions)
        positions = positions[:class_patches_number]

        for position in positions:
            start = position - half_size
            end = start + patch_size

            rows = slice(start[0], end[0])
            cols = slice(start[1], end[1])

            X.append(shapes_preprocess(input_img[rows, cols]))
            Y.append(shapes_preprocess(labeled_img[rows, cols]))

    return X, Y

In [6]:
def get_imgs_from_folder(fpath):
    """Read every entry of a folder with ``imread``, in sorted name order.

    Parameters
    ----------
    fpath : str
        Directory to list.

    Returns
    -------
    list
        One ``imread`` result per directory entry, ordered by file name.
    """
    file_names = sorted(os.listdir(fpath))
    return [imread(os.path.join(fpath, file_name)) for file_name in file_names]

In [7]:
def get_data(path_X, path_Y, n_classes, patch_size, class_patches_number):
    """Load paired images from two folders and turn them into patch arrays.

    Images from ``path_X`` and ``path_Y`` are paired by sorted file name
    order. Each input image goes through ``input_img_preprocess``, each
    label image through ``labeled_img_preprocess_binary_case``, and patches
    are cut from the pair with ``extract_patches``.

    Parameters
    ----------
    path_X : str
        Folder with the input images.
    path_Y : str
        Folder with the label images.
    n_classes : int
        Forwarded to ``extract_patches``.
    patch_size : int
        Forwarded to ``extract_patches``.
    class_patches_number : int
        Forwarded to ``extract_patches``.

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray)
        ``(X, Y)``: all input patches concatenated along the first axis, and
        all label patches concatenated and flattened to one row per patch.

    Raises
    ------
    ValueError
        If the two folders hold a different number of files, or if no
        patch could be extracted at all.
    """
    X = get_imgs_from_folder(path_X)
    Y = get_imgs_from_folder(path_Y)

    # zip() would silently drop the surplus files and could pair inputs with
    # the wrong labels, so a count mismatch is reported instead.
    if len(X) != len(Y):
        raise ValueError(
            'Number of input images ({}) does not match number of labeled images ({})'.format(
                len(X), len(Y)
            )
        )

    patches_X, patches_Y = [], []
    # zip() has no length, so the total is passed explicitly for the progress bar.
    for x, y in tqdm.tqdm_notebook(zip(X, Y), total=len(X)):
        x = input_img_preprocess(x)
        y = labeled_img_preprocess_binary_case(y)
        subpatches_X, subpatches_Y = extract_patches(x, y, n_classes, patch_size, class_patches_number)
        patches_X.extend(subpatches_X)
        patches_Y.extend(subpatches_Y)

    # np.concatenate raises an unspecific ValueError on an empty list.
    if not patches_X:
        raise ValueError('No patches could be extracted from the given folders')

    X = np.concatenate(patches_X)
    Y = np.concatenate(patches_Y)

    # One flat label vector per patch.
    Y = Y.reshape(Y.shape[0], -1)
    return X, Y

In [12]:
def create_storage(name, data):
    """Write the items of an array into an LMDB database.

    Item ``i`` (``data[i]``) is stored as its raw bytes under the ASCII key
    formed by ``i`` zero-padded to 8 digits. All items are written in a
    single write transaction.

    Parameters
    ----------
    name : str
        Path of the LMDB environment to open.
    data : numpy.ndarray
        Array whose first axis enumerates the items to store.

    Returns
    -------
    int
        Number of items written (``data.shape[0]``).
    """
    N = data.shape[0]
    # Reserve four times the raw payload, but never less than 10 MiB
    # (10485760 bytes), so tiny or empty arrays still leave room for
    # LMDB's own bookkeeping pages.
    map_size = max(4 * data.nbytes, 10485760)
    env = lmdb.open(name, map_size=map_size)

    try:
        with env.begin(write=True) as txn:
            for i in range(N):
                str_id = '{:08}'.format(i)
                txn.put(str_id.encode('ascii'), data[i].tobytes())
    finally:
        # Release the environment handle even if a write fails.
        env.close()
    return N

In [None]:
def prepare_data(n_classes, patch_size, random_state, class_patches_number,
                 path_X=None, path_Y=None,
                 input_storage='input_images.lmdb', labeled_storage='labeled_images.lmdb'):
    """Extract patches from the image folders and store them in two LMDB databases.

    Seeds the global ``np.random`` state, builds the patch arrays with
    ``get_data`` and writes them out with ``create_storage``.

    Parameters
    ----------
    n_classes : int
        Forwarded to ``get_data``.
    patch_size : int
        Forwarded to ``get_data``.
    random_state : int
        Seed passed to ``np.random.seed``.
    class_patches_number : int
        Forwarded to ``get_data``.
    path_X : str, optional
        Folder with the input images. ``None`` selects the original
        hard-coded location.
    path_Y : str, optional
        Folder with the label images. ``None`` selects the original
        hard-coded location.
    input_storage : str, optional
        LMDB path for the input patches.
    labeled_storage : str, optional
        LMDB path for the label patches.

    Returns
    -------
    int
        Number of items written to the label storage.
    """
    np.random.seed(random_state)

    # Defaults preserve the previous behaviour; they are machine-specific,
    # so callers on another machine should pass their own folders.
    if path_X is None:
        path_X = '/home/efim/study/10 semester/course work/all_data/binary_data/ceramics/NLM'
    if path_Y is None:
        path_Y = '/home/efim/study/10 semester/course work/all_data/binary_data/ceramics/CAC'

    X, Y = get_data(path_X, path_Y, n_classes, patch_size, class_patches_number)

    create_storage(input_storage, X)
    N = create_storage(labeled_storage, Y)

    return N