In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import joblib
from scipy.misc import imread, imsave, imresize
%matplotlib inline

In [2]:
# Notebook-wide configuration shared by the cells below.
# Edge length handed to extract_patches by the three-argument get_data_for_epoch.
patch_size = 64
# NOTE(review): presumably a random seed; no code visible in this section reads it.
random_state_ = 42

# Number of label classes; the argmin-based ans_preprocess emits labels 0 .. n_classes - 1.
n_classes = 4
# One RGB triple per marked class; label 0 is scored by grayscale_measure_mask rather than by a colour.
marking_colors = np.array([[14, 209, 69], [255, 127, 39], [136, 0, 27]]) # n_classes - 1 rows
# NOTE(review): presumably the fraction of pixels belonging to each class — TODO confirm how it was measured.
class_pixels_density = [0.23823693,  0.00462145,  0.75468649,  0.00245512] # one entry per class (n_classes)
# Per-class sampling weights, normalised to sum to 1; extract_patches multiplies them by average_patches_number.
class_pixels_to_take = np.array([0.5, 2, 0.25, 1]) / sum([0.5, 2, 0.25, 1]) # one entry per class (n_classes), hand-picked

In [3]:
# RGB value that identifies a marked pixel for the single-colour variant below.
marking_color = [32, 192, 64]


def ans_preprocess(image):
    """Build a 0/255 mask of the pixels whose first three channels equal ``marking_color``.

    A 2-D (single-channel) image carries no colour marking, so an all-zero
    array with the image's own shape and dtype is returned for it.

    NOTE(review): a later cell defines another ``ans_preprocess``; when the
    notebook is run top to bottom that definition replaces this one.
    """
    if len(image.shape) == 2:
        return np.zeros_like(image)

    # One boolean plane per colour channel; a pixel is marked only if all three agree.
    channel_hits = [image[:, :, ch] == marking_color[ch] for ch in range(3)]
    is_marked = np.logical_and.reduce(channel_hits)
    return np.where(is_marked, 255, 0)

In [4]:
def grayscale_measure_mask(image):
    """Score how far each pixel is from being gray.

    The score is the sum of squared differences over all three pairs of the
    first three channels: (c0 - c1)^2 + (c1 - c2)^2 + (c2 - c0)^2. It is 0 for
    a pixel whose three channels are equal and grows with colour saturation.

    Parameters
    ----------
    image : array_like
        Array indexed as ``image[row, col, channel]`` with at least three channels.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(rows, cols, 1)``; int64 for integer/bool input,
        otherwise the input's floating dtype.
    """
    channels = np.asarray(image)
    # Widen integer input first: uint8 subtraction and squaring would silently
    # wrap around modulo 256 and make the score meaningless.
    if channels.dtype.kind in "bui":
        channels = channels.astype(np.int64)

    # Width-1 slices keep the trailing axis so the result can be concatenated
    # with other per-class score maps along axis 2.
    first = channels[:, :, 0:1]
    second = channels[:, :, 1:2]
    third = channels[:, :, 2:3]

    # All three distinct channel pairs, including the third-vs-first pair.
    return (first - second) ** 2 + (second - third) ** 2 + (third - first) ** 2

def ans_preprocess(image):
    """Convert a colour-marked answer image into a per-pixel class-label map.

    One score map is built per class: ``grayscale_measure_mask`` for label 0
    and the squared distance to ``marking_colors[k]`` for label ``k + 1``.
    Each pixel receives the label whose score is smallest.

    Returns a uint8 array of shape ``image.shape[:2]``.
    """
    score_maps = [grayscale_measure_mask(image)]
    score_maps.extend(
        np.sum((image - marking_colors[k]) ** 2, axis=2, keepdims=True)
        for k in range(n_classes - 1)
    )

    # Stack the scores along the channel axis and pick the best class per pixel.
    stacked = np.concatenate(score_maps, axis=2)
    return stacked.argmin(axis=2).astype(np.uint8)

In [5]:
def image_augmentation(image):
    """Return 16 flipped/rotated variants of ``image`` stacked in one array.

    Four mirrored versions are produced (identity, flip along axis 0, flip
    along axis 1, flip along both), and each one is rotated by 0, 90, 180 and
    270 degrees in the plane of axes (0, 1). Variants are ordered mirror-major,
    rotation-minor.
    """
    flipped_rows = np.flip(image, 0)
    flipped_cols = np.flip(image, 1)
    mirrored = (image, flipped_rows, flipped_cols, np.flip(flipped_cols, 0))

    variants = [
        np.rot90(view, quarter_turns, (0, 1))
        for view in mirrored
        for quarter_turns in range(4)
    ]
    return np.array(variants)

In [6]:
def get_valid_patches(img_shape, patch_size, central_points):
    """Flag the centre points whose patch lies completely inside the image.

    A patch around centre ``c`` covers ``[c - patch_size // 2, c - patch_size // 2 + patch_size)``
    along each axis, i.e. exactly the slice bounds used when the patch is cut out.

    Parameters
    ----------
    img_shape : sequence of int
        Extent of the image along each axis addressed by the points.
    patch_size : int
        Edge length of the square patch.
    central_points : array_like of int, shape (n_points, n_axes)
        Candidate patch centres.

    Returns
    -------
    numpy.ndarray of bool, shape (n_points,)
        True where the whole patch fits inside the image.
    """
    centers = np.asarray(central_points)

    # Floor division keeps the bounds integral (and identical on Python 2 and 3),
    # so odd patch sizes are judged against the slice that is really taken.
    start = centers - patch_size // 2
    end = start + patch_size

    # ``end`` is an exclusive slice bound, so a patch may end exactly on the border.
    inside = np.logical_and(start >= 0, end <= np.asarray(img_shape))
    return np.all(inside, axis=-1)

def get_patches_proportion(Y_dir_name):
    """Find, for every class label, the largest per-file pixel count in a directory.

    Each file in ``Y_dir_name`` is read with ``imread`` and the pixels equal to
    each label in ``range(n_classes)`` are counted.
    NOTE(review): this assumes the files already store class labels as pixel
    values rather than marking colours — TODO confirm.

    Parameters
    ----------
    Y_dir_name : str
        Directory containing the answer images.

    Returns
    -------
    list of int
        ``n_classes`` entries; entry ``k`` is the maximum number of pixels with
        label ``k`` found in any single file (0 if the directory is empty).
    """
    # Local accumulator, so the function does not depend on a pre-existing global.
    max_class_pixels = [0] * n_classes

    for fname in os.listdir(Y_dir_name):
        y = imread(os.path.join(Y_dir_name, fname))
        for label in range(n_classes):
            class_pixels = int((y == label).sum())
            max_class_pixels[label] = max(max_class_pixels[label], class_pixels)

    return max_class_pixels

def extract_patches(img, answer, patch_size=64, average_patches_number=100):
    """Randomly sample square patches from an image, balanced by class.

    For every class label, the pixels of ``answer`` carrying that label are
    candidate patch centres. Candidates whose patch would leave the image are
    dropped, the rest are shuffled with ``np.random``, and up to
    ``int(class_pixels_to_take[label] * average_patches_number)`` of them are used.

    Parameters
    ----------
    img : numpy.ndarray
        Source image indexed as ``img[row, col, ...]``.
    answer : numpy.ndarray
        Label map for ``img``; reshaped to its first two dimensions.
    patch_size : int, optional
        Edge length of the square patches.
    average_patches_number : int, optional
        Scale applied to the per-class sampling weights.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Image patches and the matching label patches, stacked along axis 0.
        If nothing could be sampled, both arrays have a leading dimension of 0
        but keep the patch dimensions, so they can still be concatenated.
    """
    answer = answer.reshape(answer.shape[:2])
    # Integer offset from a centre to the patch's top-left corner; true division
    # would yield floats, which are not valid slice bounds.
    half = patch_size // 2

    X = []
    Y = []

    for label in range(n_classes):
        pos = np.argwhere(answer == label)
        pos = pos[get_valid_patches(answer.shape, patch_size, pos)]

        # Boolean indexing above produced a copy, so shuffling does not touch `answer`.
        np.random.shuffle(pos)

        n_samples = int(class_pixels_to_take[label] * average_patches_number)

        for center in pos[:max(n_samples, 0)]:
            top, left = center - half
            bottom, right = top + patch_size, left + patch_size

            X.append(img[top:bottom, left:right])
            Y.append(answer[top:bottom, left:right])

    if not X:
        # Keep the patch dimensions so callers can concatenate results across images.
        empty_x = np.empty((0, patch_size, patch_size) + tuple(img.shape[2:]), dtype=img.dtype)
        empty_y = np.empty((0, patch_size, patch_size), dtype=answer.dtype)
        return empty_x, empty_y

    return np.array(X), np.array(Y)

In [7]:
def patch_preproces(patches):
    """Cast patches to float32 and map the value range [0, 255] onto [-0.5, 0.5].

    The final ``transpose(0, 1, 2)`` leaves the axis order unchanged and only
    succeeds for 3-D input.
    """
    scaled = patches.astype(np.float32) / 255 - 0.5
    return scaled.transpose(0, 1, 2)

In [8]:
def get_data_for_epoch(X, Y):
    """Read every image pair from disk and sample one epoch of training patches.

    Images are read from the directory ``X_path`` and the same-named answer
    images from ``Y_path``; answers are turned into label maps with
    ``ans_preprocess`` before patches are extracted.

    NOTE(review): ``X_path`` and ``Y_path`` are not defined in the visible part
    of the notebook and must exist as globals — TODO confirm. The ``X`` and
    ``Y`` parameters are accepted but not used.
    NOTE(review): a later cell defines a three-argument ``get_data_for_epoch``
    that replaces this one when the notebook is run top to bottom.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Float32 patches scaled to [-0.5, 0.5] with axes reordered by
        ``transpose(0, 3, 1, 2)``, and the label patches flattened to one row each.
    """
    X_patches = []
    Y_patches = []

    # Sorted so the file order, and hence the sampling under a fixed seed, is reproducible.
    for fname in sorted(os.listdir(X_path)):
        x = imread(os.path.join(X_path, fname))
        y = imread(os.path.join(Y_path, fname))
        y = ans_preprocess(y)

        new_x, new_y = extract_patches(x, y, patch_size)
        X_patches.append(new_x)
        Y_patches.append(new_y)

    X_patches = np.concatenate(X_patches)
    Y_patches = np.concatenate(Y_patches)

    # The transpose requires 4-D patches, i.e. images that carry a channel axis.
    X_patches = (X_patches.astype(np.float32) / 255 - 0.5).transpose(0, 3, 1, 2)
    Y_patches = Y_patches.reshape(Y_patches.shape[0], -1)

    return X_patches, Y_patches

In [9]:
def preproces(patches):
    """Cast a 4-D batch to float32, rescale [0, 255] to [-0.5, 0.5] and move axis 3 to position 1."""
    as_float = patches.astype(np.float32)
    centered = as_float / 255 - 0.5
    return np.transpose(centered, (0, 3, 1, 2))

In [20]:
def process_data(X, Y):
    """Prepare image and label batches for the network.

    Both inputs get a new trailing axis. ``X`` is then passed through
    ``preproces``; ``Y`` has that axis moved to position 1 and is flattened to
    one row per sample.
    NOTE(review): the three-slice indexing implies 3-D inputs, presumably
    (n_samples, height, width) — TODO confirm.
    """
    images = X[:, :, :, np.newaxis]
    labels = Y[:, :, :, np.newaxis]

    images = preproces(images)
    flat_labels = np.transpose(labels, (0, 3, 1, 2)).reshape(labels.shape[0], -1)
    return images, flat_labels

In [21]:
def get_data_for_epoch(X, Y, image_for_epoch):
    """Sample one epoch of patches from a random subset of in-memory images.

    ``image_for_epoch`` distinct indices along the first axis of ``X`` are
    drawn without replacement; patches of edge ``patch_size`` are extracted
    from each chosen image/answer pair and pooled.
    NOTE(review): ``tqdm`` is not imported in the visible import cell and is
    presumably brought in by another cell — TODO confirm.

    Returns the patches as float32 in [-0.5, 0.5] with a singleton axis
    inserted at position 1, and the label patches flattened to one row each.
    """
    chosen = np.random.choice(X.shape[0], image_for_epoch, False)

    sampled = [extract_patches(X[idx], Y[idx], patch_size) for idx in tqdm(chosen)]

    X_patches = np.concatenate([pair[0] for pair in sampled])
    Y_patches = np.concatenate([pair[1] for pair in sampled])

    # Rescale, then append a trailing axis and move it to position 1.
    rescaled = X_patches.astype(np.float32) / 255 - 0.5
    X_patches = rescaled[:, :, :, np.newaxis].transpose(0, 3, 1, 2)
    Y_patches = Y_patches.reshape(Y_patches.shape[0], -1)

    return X_patches, Y_patches