# Setup data and model

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from abyss_deep_learning.datasets.simulated import alphanum_gen
from abyss_deep_learning.keras.classification import batching_gen, multihot_gen
from abyss_deep_learning.keras.utils import lambda_gen
from keras.backend import clear_session

# Two simulated sample streams over the ten digit characters "0".."9".
# Both calls share every argument except `thickness` and `noise`, so the
# validation stream is rendered differently from the training stream.
# NOTE(review): the meaning of the positional `1` is not visible here --
# confirm against alphanum_gen's signature.
gen_train = alphanum_gen(
    list("0123456789"),
    1,
    scale=7,
    thickness=5,
    bg=True,
    noise=50,
)
gen_val = alphanum_gen(
    list("0123456789"),
    1,
    scale=7,
    thickness=10,
    bg=True,
    noise=100,
)


def pipeline(gen):
    """Wrap a raw sample generator with the preprocessing used for training.

    Each ``(x, y)`` item from ``gen`` is mapped to ``(x - 127.5, [int(y)])``
    through ``lambda_gen``; the result is then handed to ``multihot_gen``
    with ``10`` as its second argument.

    NOTE(review): subtracting 127.5 presumably centres 8-bit pixel values
    around zero (no rescaling is applied), and ``10`` is presumably the
    number of classes -- confirm against the helper implementations.
    """
    def _shift_and_wrap(x, y):
        # The label is converted to int and wrapped in a one-element list.
        return (x - 127.5), [int(y)]

    shifted = lambda_gen(gen, _shift_and_wrap)
    return multihot_gen(shifted, 10)


# Inspect only the first preprocessed validation sample.  The loop variable
# `a` is left bound afterwards and is reused by later cells.
for a in pipeline(gen_val):
    sample_image, sample_label = a[0], a[1]
    print(sample_label)
    print(sample_image.shape, np.min(sample_image), np.max(sample_image))
    plt.figure()
    plt.imshow(sample_image)
    plt.title(str(sample_label))
    break  # one sample is enough for a sanity check

# Test training

In [None]:
# Instantiate model
batch_size = 10  # shared by the batch-fit and generator-fit cells


def create_new_model():
    """Build and return a fresh ``ImageClassifier``.

    The Keras backend session is cleared first so that state left over from
    previously built models does not accumulate between experiments.

    Note that this function cannot release a notebook-level ``model``
    variable; callers drop the old model simply by rebinding the name to the
    returned object.

    Returns:
        An ``ImageClassifier`` configured with an ``'xception'`` backbone,
        input shape ``(128, 75, 3)``, 10 classes, initial learning rate
        ``1e-3``, ``'imagenet'`` initial weights and ``trainable=True``.
    """
    from abyss_deep_learning.keras.classification import ImageClassifier

    clear_session()
    return ImageClassifier(
        backbone='xception', input_shape=(128, 75, 3), classes=10,
        init_lr=1e-3, init_weights='imagenet',
        trainable=True)

## Fit: batch method

In [None]:
from abyss_deep_learning.keras.utils import gen_dump_data

# Materialise fixed-size arrays from the generators: 100 training samples
# and 20 validation samples.  Validation is drawn from `gen_val` so that it
# is evaluated on the validation stream rather than on more training data,
# matching the generator-based fit cell.
x_train, y_train = gen_dump_data(pipeline(gen_train), 100)
validation_data = gen_dump_data(pipeline(gen_val), 20)

# Build a fresh model here so this cell runs on its own; without it `model`
# is undefined when the notebook is executed from top to bottom.
model = create_new_model()

# -log(1 / C) is the cross-entropy of a uniform guess over C classes.
# NOTE(review): assumes `model.classes` is the number of classes -- confirm.
print("Break-even loss is", -np.log(1 / model.classes))
model.fit(
    x_train, y_train,
    validation_data=validation_data,
    batch_size=batch_size, epochs=4)
# Optional: free the dumped arrays before running further experiments.
# del x_train, y_train, validation_data

## Fit: generator method

In [None]:
# Start from a fresh model so results are independent of earlier cells.
model = create_new_model()

print("Break-even loss is", -np.log(1 / model.classes))

# Batch the preprocessed training and validation streams, then train
# directly from the generators.
train_batches = batching_gen(pipeline(gen_train), batch_size=batch_size)
val_batches = batching_gen(pipeline(gen_val), batch_size=batch_size)
model.fit_generator(
    train_batches,
    validation_data=val_batches,
    steps_per_epoch=10,
    validation_steps=1,
    epochs=4,
    verbose=True,
)

## Fit: dataset method

In [None]:
# TODO

# Test serialization

In [None]:
prob1 = model.predict_proba(a[0][np.newaxis, ...])
model.save("/tmp/abcd")
model = ImageClassifier.load("/tmp/abcd")
prob2 = model.predict_proba(a[0][np.newaxis, ...])

!rm "/tmp/abcd"
print("Testing serialization: [{}]".format(np.allclose(prob1, prob2)))