<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"></ul></div>

In [None]:
import os
import random

import numpy as np
import skimage.morphology as sm
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot as plt
from skimage import exposure, filters, io

%matplotlib inline

In [None]:
# Root directory of the test images read by flow_from_directory below.
# Set the AIML_TEST_DIR environment variable to point elsewhere instead of editing this path.
test_dir = os.environ.get('AIML_TEST_DIR', '/aiml/data/test/')
batch_size = 64  # images requested per generator batch
pix = 64         # images are resized to pix x pix

In [None]:
def plots(ims, interp=False, titles=None):
    """Draw several images side by side in a single row, on one shared intensity scale.

    Args:
        ims: sequence of images; converted with ``np.array`` before plotting.
        interp: False draws raw pixels (``interpolation='none'``); True uses
            matplotlib's default interpolation.
        titles: optional sequence of per-image titles, indexed like ``ims``.

    Returns:
        None. Nothing is drawn when ``ims`` is empty.
    """
    ims = np.array(ims)
    if ims.size == 0:
        # min()/max() of an empty array raise ValueError, and there is nothing to show anyway
        # (e.g. when the paging window has moved past the end of the batch).
        return
    # One vmin/vmax for every panel so brightness is comparable across images.
    mn, mx = ims.min(), ims.max()
    f = plt.figure(figsize=(12, 24))
    for i, im in enumerate(ims):
        sp = f.add_subplot(1, len(ims), i + 1)
        if titles is not None:
            sp.set_title(titles[i], fontsize=18)
        plt.imshow(im, interpolation=None if interp else 'none', vmin=mn, vmax=mx)

def plot(im, interp=False):
    """Show a single image in its own small (3 x 6 inch) figure.

    Args:
        im: image array accepted by ``plt.imshow``.
        interp: False draws raw pixels (``interpolation='none'``); True uses
            matplotlib's default interpolation.
    """
    mode = 'none' if not interp else None
    plt.figure(figsize=(3, 6), frameon=True)
    plt.imshow(im, interpolation=mode)

# Make grayscale the default colormap for every imshow in this notebook (plt.gray() is
# shorthand for exactly this call). Setting it presumably goes through the current figure,
# so close that figure to avoid an empty plot being rendered by this cell.
plt.set_cmap('gray')
plt.close()

In [None]:
# %load -r 108-162 /aiml/code/python_code/prepare_data.py
# Switches read by train_preprocess / test_preprocess.
# For the numeric entries, a value <= 0 (here -1) means the step is disabled.
global_config = dict(
    # Gaussian blur: *_train_* keys drive training-time randomisation,
    # gaussian_test is the fixed sigma used otherwise.
    gaussian_train_max=-1,
    gaussian_train_min=-1,
    gaussian_test=-1,
    # Square dilation footprint size: same train/test split as above.
    dilation_train_max=-1,
    dilation_train_min=-1,
    dilation_test=-1,
    threshold_otsu=True,      # binarise each image with Otsu's threshold
    rescale_intensity=True,   # final exposure.rescale_intensity pass
    norm_input=False,         # standardise with fixed mean/std before anything else
)

def preprocess_fun(x, gaussian_sigma, dilation_square, threshold_otsu, rescale_intensity, norm_input):
    """Run the configurable preprocessing pipeline on one image.

    Stages run in this order, each only when enabled:
      1. ``norm_input``: standardise as ``(x - mean) / std`` with fixed constants.
      2. ``threshold_otsu``: binarise with Otsu's global threshold (values become 0.0 / 1.0).
      3. ``gaussian_sigma > 0``: rescale intensities, then Gaussian blur with that sigma.
      4. ``dilation_square > 0``: grey-level dilation with a square footprint of that size.
      5. ``rescale_intensity``: a final ``exposure.rescale_intensity`` pass.

    Args:
        x: image array; the dilation stage treats the first two axes as height and width.
        gaussian_sigma: blur sigma; <= 0 disables the blur.
        dilation_square: side of the square dilation footprint; <= 0 disables dilation.
        threshold_otsu: whether to apply Otsu binarisation.
        rescale_intensity: whether to apply the final intensity rescale.
        norm_input: whether to standardise with the fixed mean/std.

    Returns:
        The processed array. When dilation runs, the result has shape (H, W, 1).
    """
    if norm_input:
        # Fixed pixel statistics; presumably computed offline on the training set (TODO confirm).
        std_px = 63.556923
        mean_px = 222.517471
        # Subtract the mean first, then divide by the std
        # (previously `x - mean_px / std_px`, which only subtracted mean/std).
        x = (x - mean_px) / std_px
    if threshold_otsu:
        thresh = filters.threshold_otsu(x)  # a single global threshold value
        x = (x >= thresh) * 1.0  # 1.0 where the pixel reaches the threshold, else 0.0
    if gaussian_sigma > 0:
        x = exposure.rescale_intensity(x)
        x = filters.gaussian(x, sigma=gaussian_sigma)
    if dilation_square > 0:
        # Dilate on the 2-D plane, then restore a trailing channel axis. Height/width are
        # taken from the input rather than hard-coded to 64.
        h, w = x.shape[:2]
        x = sm.dilation(x.reshape(h, w), sm.square(dilation_square))
        x = x.reshape(h, w, 1)
    if rescale_intensity:
        x = exposure.rescale_intensity(x)
    return x

def train_preprocess(x):
    """Training-time preprocessing hook (passed as ImageDataGenerator's preprocessing_function).

    Reads the module-level ``global_config``. When ``gaussian_train_max`` / ``dilation_train_max``
    is > 0, the corresponding parameter is drawn at random on every call (augmentation);
    otherwise the fixed ``*_test`` value is used, so the result matches the test-time pipeline.

    Args:
        x: one image as handed over by the generator (presumably (64, 64, 1) here — TODO confirm).

    Returns:
        The image after ``preprocess_fun``.
    """
    cfg = global_config  # only read here, so no `global` statement is needed
    gaussian_train_max = cfg['gaussian_train_max']
    gaussian_train_min = cfg['gaussian_train_min']
    gaussian_test = cfg['gaussian_test']
    dilation_train_max = cfg['dilation_train_max']
    dilation_train_min = cfg['dilation_train_min']
    dilation_test = cfg['dilation_test']
    rescale_intensity = cfg['rescale_intensity']
    threshold_otsu = cfg['threshold_otsu']
    norm_input = cfg['norm_input']

    if gaussian_train_max > 0:
        # Sample sigma uniformly from [gaussian_train_min, gaussian_train_max]; previously the
        # configured min was ignored. A non-positive min (the -1 "unset" default) is clamped to
        # 0, which reproduces the old random.random() * gaussian_train_max draw exactly.
        gaussian_sigma = random.uniform(max(gaussian_train_min, 0), gaussian_train_max)
    else:
        gaussian_sigma = gaussian_test

    if dilation_train_max > 0:
        # Inclusive integer draw; a result <= 0 means preprocess_fun skips dilation for this image.
        dilation_square = random.randint(dilation_train_min, dilation_train_max)
    else:
        dilation_square = dilation_test

    x = preprocess_fun(x, gaussian_sigma, dilation_square,
                       threshold_otsu, rescale_intensity, norm_input)

    return x

def test_preprocess(x):
    """Deterministic (evaluation-time) preprocessing hook.

    Runs ``preprocess_fun`` with the fixed ``gaussian_test`` / ``dilation_test`` values and
    the boolean switches from the module-level ``global_config``; no randomness involved.

    Args:
        x: one image as handed over by the generator.

    Returns:
        The image after ``preprocess_fun``.
    """
    cfg = global_config
    return preprocess_fun(
        x,
        cfg['gaussian_test'],
        cfg['dilation_test'],
        cfg['threshold_otsu'],
        cfg['rescale_intensity'],
        cfg['norm_input'],
    )


In [None]:
# Images from the test directory go through the deterministic test-time pipeline.
# With the current global_config (every *_train_max <= 0) train_preprocess gives the same
# result, but it starts randomising as soon as a training-augmentation range is enabled.
# Swap in train_preprocess deliberately if you want to preview augmentations instead.
test_data = ImageDataGenerator(preprocessing_function=test_preprocess)
test_gen = test_data.flow_from_directory(test_dir, color_mode='grayscale', target_size=(pix, pix), batch_size=batch_size)

In [None]:
# Builtin next() uses the iterator protocol; .next() is a Keras-2-only alias.
# reshape(-1, ...) lets numpy infer the batch size: the generator returns fewer than
# batch_size images when the directory (or the last batch of an epoch) is smaller.
images = next(test_gen)[0]
images = images.reshape(-1, pix, pix)
n = 0  # start index of the 5-image window shown by the cells below

In [None]:
# Advance the 5-image window and show it. Wrap back to the start of the batch instead of
# running past the end, where the slice would be empty and there would be nothing to plot.
n = (n + 5) % len(images)
plots(images[n:n+5])

In [None]:
# Print the (min, max) intensity of each image in the current window.
for img in images[n:n + 5]:
    lo, hi = np.min(img), np.max(img)
    print(lo, hi)

膨胀（dilation）效果演示：对当前窗口的 5 张图像使用 2×2 方形结构元素做形态学膨胀。

In [None]:
# Grey-level dilation with a 2x2 square footprint on the current window (size from `pix`, not a hard-coded 64).
plots([sm.dilation(x.reshape(pix, pix), sm.square(2)) for x in images[n:n+5]])

高斯滤波（Gaussian blur）效果演示：先拉伸每张图像的强度范围，再以 sigma=1 做高斯平滑。

In [None]:
# Stretch each image's intensities, then smooth with a sigma=1 Gaussian.
window = images[n:n + 5]
plots([filters.gaussian(exposure.rescale_intensity(img), sigma=1) for img in window])