# Setup data and model

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np

from abyss_deep_learning.utils import config_gpu
# Select GPU id 0 and enable config_gpu's allow_growth / log_device_placement options.
# NOTE(review): this call sits before the Keras-related imports below, presumably so the
# GPU settings are applied first — confirm before reordering.
config_gpu(gpu_ids=[0], allow_growth=True, log_device_placement=True)

from abyss_deep_learning.keras.classification import batching_gen, onehot_gen
from abyss_deep_learning.keras.utils import gen_dump_data, lambda_gen
from abyss_deep_learning.datasets.coco import ImageSemanticSegmentationDataset
from abyss_deep_learning.datasets.translators import AnnotationTranslator
from keras.backend import clear_session


In [None]:
class AnadarkoTrialTranslator(AnnotationTranslator):
    """Translator that accepts only annotations whose ``category_id`` is 1."""

    def filter(self, annotation):
        """Return True when *annotation* carries ``category_id`` equal to 1."""
        # A missing key yields None, which never equals 1.
        return annotation.get('category_id') == 1

    def translate(self, annotation):
        """Return *annotation* unchanged."""
        return annotation
    
# Build the segmentation dataset from the annotation JSON and its image directory.
# AnadarkoTrialTranslator (its filter accepts category_id == 1) is supplied as the translator.
# NOTE(review): num_classes=2 presumably means background plus the one kept category — confirm.
dataset = ImageSemanticSegmentationDataset(
    "/data/abyss/anadarko/test-run/Test Anadarko/all.json",
    image_dir="/data/abyss/anadarko/test-run/Test Anadarko",
    translator=AnadarkoTrialTranslator(),
    num_classes=2,
    cached=True
)

In [None]:
def unshift_image(image):
    """Map an image from the [-1, 1] range back to 8-bit [0, 255] pixels.

    Applies ``image * 127.5 + 127.5``, then rounds to the nearest integer and
    clips to [0, 255] before casting. Rounding keeps floating-point error from
    dropping a grey level during the cast, and clipping makes out-of-range
    inputs saturate instead of wrapping around in uint8.

    Args:
        image: Array-like of floats, nominally in [-1, 1].

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as *image*.
    """
    scaled = np.asarray(image) * 127.5 + 127.5
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

def example_image(model=None):
    """Plot a sampled image next to its per-pixel class map.

    Draws one ``(image, targets)`` pair from the module-level ``dataset``,
    prints the image's shape, dtype and value range, and shows the image in
    the left subplot. The right subplot shows ``targets.argmax(-1)``; when a
    model is given, ``targets`` is replaced by the model's ``predict_proba``
    output for that image.

    Args:
        model: Optional object exposing ``predict_proba``. ``None`` (the
            default) plots the targets returned by the dataset.
    """
    image, targets = dataset.sample()
    print(image.shape, image.dtype, np.min(image), np.max(image))
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.subplot(1, 2, 2)
    # Compare against None explicitly: an object's truthiness is not a
    # reliable signal for "was a model passed".
    if model is not None:
        # predict_proba receives a batch of one; keep its first entry.
        targets = model.predict_proba(image[np.newaxis, ...])[0]
    # Compute the class map once and reuse it for both the printout and the plot.
    class_map = targets.argmax(-1)
    print(np.unique(class_map))
    plt.imshow(class_map)
# Show one sample with the dataset's own targets (no model supplied).
example_image()

# Test training

In [None]:
from abyss_deep_learning.keras.models import FcnCrfSegmenter
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from abyss_deep_learning.keras.utils import tiling_gen, batching_gen

# Batch size setting for training; per the author's note it has to stay at 1 for FcnCrf.
batch_size = 1 # MUST BE 1 for FcnCrf

def create_new_model(init_lr=5e-5):
    """Build a fresh FcnCrfSegmenter initialised from the VGG16 no-top weights.

    Clears the Keras session, constructs the segmenter for
    ``dataset.num_classes`` classes with 5 CRF iterations, obtains the VGG16
    "notop" weight file through ``get_file`` and hands its path to
    ``model.set_weights``.

    Args:
        init_lr: Initial learning rate passed to ``FcnCrfSegmenter``.
            Defaults to 5e-5, the value that used to be hard-coded here.

    Returns:
        Tuple ``(model, callbacks)`` where ``callbacks`` is a list holding a
        ``ReduceLROnPlateau`` and an ``EarlyStopping`` instance.
    """
    from keras.utils import get_file
    from keras_applications.vgg16 import WEIGHTS_PATH_NO_TOP

    # Reset Keras' global session state before building the new model.
    clear_session()
    model = FcnCrfSegmenter(
        classes=dataset.num_classes, crf_iterations=5, init_lr=init_lr)
    weights_path = get_file(
        'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
        WEIGHTS_PATH_NO_TOP,
        cache_subdir='models',
        file_hash='6d6bbae143d832006294945121d1f1fc')
    model.set_weights(weights_path)

    # NOTE(review): neither callback sets `monitor`, so both fall back to Keras'
    # default `val_loss`, which needs validation data to be supplied when
    # fitting — confirm this matches how the callbacks are used.
    callbacks = [
        ReduceLROnPlateau(patience=3, factor=0.5, cooldown=3, verbose=1),
        EarlyStopping(patience=10, restore_best_weights=True, verbose=1)
    ]
    return model, callbacks

## Fit: generator method

In [None]:
# Preview the first (image, target) pair yielded by the shuffled dataset generator.
# `image` and `target` stay bound at notebook level after the loop; `image` is
# reused later by the serialization check.
for image, target in dataset.generator(shuffle_ids=True):
    print(image.shape, image.dtype, target.shape, target.dtype)
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.subplot(1, 2, 2)
    # Collapse the last axis of the target to a per-pixel class index for display.
    plt.imshow(target.argmax(-1))
    break  # only the first sample is needed

In [None]:
# Drop the notebook-level reference to any earlier model before building a new one.
model = None
model, callbacks = create_new_model()
# Cross-entropy of a uniform prediction over the classes: a baseline to compare the loss against.
print("Random output loss is", -np.log(1 / model.classes))
# NOTE(review): `callbacks` is returned above but never passed to fit_generator — confirm
# whether that is intentional before wiring it in.
# The batch size comes from the shared `batch_size` setting instead of a repeated literal.
model.fit_generator(
    batching_gen(
        tiling_gen(dataset.generator(endless=True), (500, 500)),
        batch_size=batch_size),
    steps_per_epoch=50,
    epochs=100, use_multiprocessing=False,
    verbose=True)

# Test serialization

In [None]:
# Round-trip check: predict on the top-left 500x500 crop of `image`, save the model,
# reload it, and predict again on the same crop.
# NOTE(review): "/tmp/abcd" is a fixed, shared path — consider a unique temporary file.
prob1 = model.predict_proba(image[np.newaxis, :500, :500, :])
model.save("/tmp/abcd")
model = FcnCrfSegmenter.load("/tmp/abcd")
prob2 = model.predict_proba(image[np.newaxis, :500, :500, :])
# Class maps side by side: before (left) and after (right) the save/load round trip.
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(prob1[0, ...].argmax(-1)*255)
plt.subplot(1, 2, 2)
plt.imshow(prob2[0, ...].argmax(-1)*255)

# Remove the saved model file (IPython shell escape), then report whether the
# predictions before and after reloading agree within np.allclose tolerance.
!rm "/tmp/abcd"
print("Testing serialization: [{}]".format(np.allclose(prob1, prob2)))