# Setup data and model

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np

from abyss_deep_learning.utils import config_gpu
config_gpu(gpu_ids=[0], allow_growth=True, log_device_placement=True)

from abyss_deep_learning.datasets.simulated import shapes_gen
from abyss_deep_learning.keras.classification import batching_gen, onehot_gen
from abyss_deep_learning.keras.utils import gen_dump_data, lambda_gen
from keras.backend import clear_session


In [None]:
def dataset_adaptor(gen, expand_dims=False, num_classes=4):
    """Adapt a shapes generator so it yields ``(image, target)`` pairs.

    Every item drawn from ``gen`` is unpacked as
    ``(image, name, instances, cats)``. The image is cast to float32 and
    rescaled with ``(x - 127.5) / 127.5`` (so 8-bit input lands in -1..1);
    the target is whatever
    ``instance_to_categorical(instances, cats, num_classes=num_classes)``
    returns. The ``name`` element is not used.

    Args:
        gen: Iterable of ``(image, name, instances, cats)`` tuples.
        expand_dims: When true, a length-1 axis is inserted at position 0 of
            both the image and the target before they are yielded.
        num_classes: Forwarded to ``instance_to_categorical``. Defaults to 4,
            the value that was previously hard-coded here.

    Yields:
        tuple: ``(scaled_image, target)``.
    """
    from abyss_deep_learning.utils import instance_to_categorical

    for image, _name, instances, cats in gen:
        scaled = (image.astype(np.float32) - 127.5) / 127.5
        target = instance_to_categorical(
            instances, cats, num_classes=num_classes)
        if expand_dims:
            yield np.expand_dims(scaled, 0), np.expand_dims(target, 0)
        else:
            yield scaled, target
        
def unshift_image(image):
    """Undo the ``(x - 127.5) / 127.5`` scaling and return a uint8 array.

    Args:
        image: Array-like of values nominally in -1..1.

    Returns:
        numpy.ndarray: ``image * 127.5 + 127.5`` rounded to the nearest
        integer, limited to 0..255 and cast to ``uint8``.
    """
    restored = np.asarray(image) * 127.5 + 127.5
    # Round rather than truncate: float error can leave a value a hair below
    # the original integer, which a bare uint8 cast would turn into value - 1.
    # Clip so that out-of-range inputs saturate instead of wrapping around.
    return np.clip(np.rint(restored), 0, 255).astype(np.uint8)

def example_image(model=None):
    """Plot one generated sample beside its per-pixel class map.

    A single ``(image, targets)`` pair is drawn from
    ``dataset_adaptor(shapes_gen(...))``. The left subplot shows the image
    (converted back with ``unshift_image``); the right subplot shows
    ``argmax`` over the last axis of either the generated targets or, when a
    model is given, of ``model.predict_proba`` on that image. The image
    min/max and the unique class indices are printed.

    Args:
        model: Optional object exposing ``predict_proba``. When ``None`` the
            generated targets are displayed instead of a prediction.
    """
    samples = dataset_adaptor(
        shapes_gen(scale=10, max_shapes=5, nms=0.5, noise=10))
    sample = next(iter(samples), None)
    if sample is None:
        # Nothing was generated, so there is nothing to plot.
        return
    image, targets = sample

    print(np.min(image), np.max(image))
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(unshift_image(image))
    plt.subplot(1, 2, 2)
    # Compare against None explicitly; the truth value of a model object is
    # not a reliable "was a model supplied" signal.
    if model is not None:
        # Add a leading axis for the prediction call, then take its first result.
        targets = model.predict_proba(image[np.newaxis, ...])[0]
    class_map = targets.argmax(-1)
    print(np.unique(class_map))
    plt.imshow(class_map)


example_image()

# Test training

In [None]:
from abyss_deep_learning.keras.models import FcnCrfSegmenter
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Batch size passed to model.fit() in the "batch method" cells below.
batch_size = 1 # MUST BE 1 for FcnCrf

def create_new_model(init_lr=5e-5):
    """Build a fresh ``FcnCrfSegmenter`` and the training callbacks.

    The Keras backend session is cleared first, then a 4-class segmenter with
    5 CRF iterations is created and handed the path of the VGG16 "notop"
    weights file obtained through ``keras.utils.get_file``.

    Args:
        init_lr: Initial learning rate passed to ``FcnCrfSegmenter``.
            Defaults to 5e-5, the value previously hard-coded here.

    Returns:
        tuple: ``(model, callbacks)`` where ``callbacks`` is a list holding a
        ``ReduceLROnPlateau`` and an ``EarlyStopping`` instance.
    """
    from keras.utils import get_file
    from keras_applications.vgg16 import WEIGHTS_PATH_NO_TOP

    # Reset backend state before building a new model. Callers that want to
    # release a previous model must drop their own reference to it; a local
    # assignment inside this function cannot do that.
    clear_session()
    model = FcnCrfSegmenter(classes=4, crf_iterations=5, init_lr=init_lr)
    weights_path = get_file(
        'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
        WEIGHTS_PATH_NO_TOP,
        cache_subdir='models',
        file_hash='6d6bbae143d832006294945121d1f1fc')
    model.set_weights(weights_path)

    callbacks = [
        ReduceLROnPlateau(patience=5, factor=0.5, cooldown=5, verbose=1),
        EarlyStopping(patience=15, verbose=1, restore_best_weights=True),
    ]
    return model, callbacks

## Fit: batch method

In [None]:
# First create the generators that we will pull the data from.
# expand_dims=False: samples are yielded without an extra leading axis.
gen_train = dataset_adaptor(shapes_gen(scale=10, max_shapes=3, nms=0.3, noise=10), expand_dims=False)
gen_val = dataset_adaptor(shapes_gen(scale=10, max_shapes=3, nms=0.5, noise=15), expand_dims=False)
# Dump data from the generators
# (second argument is presumably the number of samples to draw -- confirm
# against gen_dump_data).
x_train, y_train = gen_dump_data(gen_train, 50)
validation_data = gen_dump_data(gen_val, 10)
# Drop the notebook-level reference to any earlier model before a new one is built.
model = None
model, callbacks = create_new_model()
# -log(1 / classes) is the cross-entropy of a uniform prediction; useful as a
# baseline assuming the model trains with a cross-entropy loss (confirm).
print("Random output loss is", -np.log(1 / model.classes))

In [None]:
# Fit on the dumped arrays, then plot a prediction for a fresh sample.
# set_trainable(True) presumably makes every layer trainable -- confirm
# against FcnCrfSegmenter.
model.set_trainable(True)
model.fit(
    x_train, y_train,
    validation_data=validation_data,
    batch_size=batch_size, epochs=100,
    callbacks=callbacks)
example_image(model)

In [None]:
# Should you want to train only parts of the model
# Sequence used here: select the 'crf' part, recompile, then set the
# learning rate, and fit for a further 10 epochs on the same arrays.
# The same `callbacks` list as in the previous fit is reused.
model.set_trainable('crf')
model.recompile()
model.set_lr(5e-2) # CRF only
model.fit(
    x_train, y_train,
    validation_data=validation_data,
    batch_size=batch_size, epochs=10,
    callbacks=callbacks)
example_image(model)

In [None]:
# Remove the dumped arrays from the notebook namespace; the sections that
# follow draw their data from generators instead.
del x_train, y_train, validation_data

## Fit: generator method

In [None]:
# First create the generators that we will pull the data from.
# Requires expand_dims=True
# (each yielded image/target then carries a leading length-1 axis).
gen_train = dataset_adaptor(shapes_gen(scale=10, max_shapes=3, nms=0.3, noise=10), expand_dims=True)
gen_val = dataset_adaptor(shapes_gen(scale=10, max_shapes=3, nms=0.5, noise=15), expand_dims=True)
# Drop the notebook-level reference to the previous model before rebuilding.
model = None
model, callbacks = create_new_model()

print("Random output loss is", -np.log(1 / model.classes))
# NOTE(review): `callbacks` is created above but not passed to this call,
# unlike the batch-method cells -- confirm whether that is intentional.
# NOTE(review): gen_train / gen_val are plain Python generators; check how
# fit_generator handles them together with use_multiprocessing=True.
model.fit_generator(
    gen_train,
    validation_data=gen_val,
    steps_per_epoch=50, validation_steps=10,
    epochs=100, use_multiprocessing=True,
    verbose=True)

## Fit: dataset method

In [None]:
from abyss_deep_learning.datasets.misc import CachedGenClassificationDataset
# Make two datasets (image data, classification task) that dump data from gen_train and gen_val
# and make it available via the standard abyss Dataset API calls.
# gen_train / gen_val are the generator objects created in the "generator
# method" cell (expand_dims=True), already partly consumed by fit_generator.

dataset_train = CachedGenClassificationDataset(gen_train, n_samples=50)
dataset_val = CachedGenClassificationDataset(gen_val, n_samples=20)
# Show the first item from the training dataset as a sanity check.
# `image` and `target` stay defined after the loop; the serialization cell
# further down reuses `image`.
for image, target in dataset_train.generator():
    print(image.shape, image.dtype, target.shape, target.dtype)
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(unshift_image(image[0]))
    plt.subplot(1, 2, 2)
    plt.imshow(target[0].argmax(-1))
    break

In [None]:
# Rebuild the model from scratch and train it through the dataset objects.
model = None
model, callbacks = create_new_model()
print("Random output loss is", -np.log(1 / model.classes))
# NOTE(review): `callbacks` is not passed here either -- confirm intent.
model.fit_dataset(
    dataset_train, dataset_val=dataset_val,
    steps_per_epoch=50, validation_steps=10,
    epochs=100, use_multiprocessing=True,
    verbose=True)

# Test serialization

In [None]:
# Round-trip check: predict, save, reload, predict again, and compare.
# `image` is the variable left over from the loop in the dataset-method
# cell, so that cell must have been run first.
prob1 = model.predict_proba(image)
model.save("/tmp/abcd")
model = FcnCrfSegmenter.load("/tmp/abcd")
prob2 = model.predict_proba(image)
# Show the argmax class maps before (left) and after (right) the reload.
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(prob1[0, ...].argmax(-1)*255)
plt.subplot(1, 2, 2)
plt.imshow(prob2[0, ...].argmax(-1)*255)

# The saved file is left in place; uncomment the next line to remove it.
# !rm "/tmp/abcd"
# Prints True when both prediction arrays agree within np.allclose tolerances.
print("Testing serialization: [{}]".format(np.allclose(prob1, prob2)))

In [None]:
# Print the summary of the object held in model.model_.
model.model_.summary()

In [None]:
# Look up the layer named 'upsample' and display its weights as the cell output.
layer = model.model_.get_layer('upsample')
layer.weights
# Unfinished plotting snippet, kept disabled:
# plt.figure()

# plt.imshow(layer.get_weights()[0]