In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from keras.applications.xception import preprocess_input

from abyss_deep_learning.datasets.coco import ImageClassificationDataset
from abyss_deep_learning.datasets.translators import CocoCaptionTranslator
from abyss_deep_learning.keras.classification import caption_map_gen, onehot_gen
from abyss_deep_learning.keras.models import ImageClassifier
from abyss_deep_learning.keras.utils import lambda_gen, batching_gen

In [None]:
# Input location: the annotation file name and the `data` folder that sits
# under the directory the notebook is started from.
JSON = 'merged_non_x.json'
CURR_DIR = os.getcwd()
DATA_DIR = os.path.join(CURR_DIR, 'data')

In [None]:
# Open the annotation file from DATA_DIR without caching, reading labels
# through a CocoCaptionTranslator.
ds = ImageClassificationDataset(
    os.path.join(DATA_DIR, JSON),
    translator=CocoCaptionTranslator(),
    cached=False,
)

In [None]:
# Caption codes that are kept as their own classes:
#     F:  Fault
#     RI: Root Intrusion
#     BG: Background

ds_captions = [
    'F',
    'RI',
    'BG',
]

In [None]:
# Collapse every caption that is not listed in ds_captions into the catch-all
# label Other Faults (OF). The annotation dicts are edited in place.

for ann in ds.coco.anns.values():
    if ann['caption'] in ds_captions:
        continue
    ann['caption'] = 'OF'

In [None]:
# Register the remaining caption Other Fault (OF) as the last class.
# Guarded so that re-running this cell cannot append a duplicate entry; a
# duplicate would change the index assigned to 'OF' in any mapping that is
# enumerated from this list.

if 'OF' not in ds_captions:
    ds_captions.append('OF')

In [None]:
# Pair each caption with its position in ds_captions.
caption_map = dict(zip(ds_captions, range(len(ds_captions))))
print(caption_map)

In [None]:
# Fraction of the samples that goes to the training set.
split_ratio = 0.8
# Random ordering of the dataset ids; the split below is taken from this array.
idx_shuffled = np.random.permutation(ds.data_ids)
# Size the cut-off from the array that is actually being split, so it stays
# correct even if ds.coco.imgs and ds.data_ids do not hold the same number
# of entries.
idx_split = int(np.floor(split_ratio * len(idx_shuffled)))

In [None]:
# The first idx_split shuffled ids (split_ratio of the data) form the training set
train_ids = idx_shuffled[:idx_split]
# Everything from idx_split onwards forms the validation set. The slice starts
# at idx_split itself (not idx_split + 1) so the sample on the boundary is not
# silently left out of both sets.
val_ids = idx_shuffled[idx_split:]

In [None]:
# One sample generator per split, each limited to that split's ids and
# configured with endless=True and shuffle_ids=True.
train_gen = ds.generator(
    data_ids=list(train_ids),
    endless=True,
    shuffle_ids=True,
)
val_gen = ds.generator(
    data_ids=list(val_ids),
    endless=True,
    shuffle_ids=True,
)

# Parameters for model initialization and fitting

In [None]:
# Target image size used when resizing samples in the pipeline.
num_rows, num_cols = 299, 299

# Optimisation settings.
init_lr = 1e-5
epochs = 2
batch_size = 5

# Number of whole batches in each split (integer division drops any remainder).
validation_steps = len(val_ids) // batch_size
steps_per_epoch = len(train_ids) // batch_size

In [None]:
def func(x, y):
    """Clear the fault labels that conflict with a background label.

    Positions in ``y`` follow the caption map built in this notebook
    (F=0, RI=1, BG=2, OF=3). When background (index 2) is set, root
    intrusion (index 1) and other fault (index 3) are reset to 0.

    Args:
        x: Sample data; passed through untouched.
        y: Label vector indexed by caption map position. It must support
            slice assignment of a scalar (e.g. a NumPy array) and is
            modified in place.

    Returns:
        The tuple ``(x, y)``.
    """
    if y[2]:
        # A step-2 slice over [1, 4) selects indices 1 (RI) and 3 (OF).
        # A stop of 3 would only ever reach index 1 and leave OF set.
        y[1:4:2] = 0
    return x, y

def pipeline(gen, caption_map):
    """Chain the generator stages that turn dataset samples into batches.

    Stages, in order: ``caption_map_gen`` with ``caption_map``, ``onehot_gen``
    sized to the number of captions, the ``func`` label clean-up, image
    resizing followed by Xception ``preprocess_input``, and ``batching_gen``
    using the notebook-level ``batch_size``.

    Args:
        gen: Source generator; presumably yields (image, caption) pairs as
            produced by ``ds.generator`` -- confirm against the dataset class.
        caption_map: Mapping from caption to class index.

    Returns:
        The generator returned by ``batching_gen``.
    """
    def resize_and_preprocess(x, y):
        # cv2.resize takes dsize as (width, height), i.e. (columns, rows).
        resized = cv2.resize(x, (num_cols, num_rows))
        return preprocess_input(resized), y

    mapped = caption_map_gen(gen, caption_map)
    encoded = onehot_gen(mapped, len(caption_map))
    cleaned = lambda_gen(encoded, func)
    prepared = lambda_gen(cleaned, resize_and_preprocess)
    return batching_gen(prepared, batch_size)

In [None]:
# Display the caption -> index mapping as this cell's output.
caption_map

In [None]:
# Pull a single batch from a training pipeline and preview its first sample
# together with the label vector of that sample.
image, caption = next(iter(pipeline(train_gen, caption_map)))
print(caption[0])
# Xception's preprocess_input scales pixels to [-1, 1]; map them back to
# [0, 1] so imshow renders the picture instead of clipping it.
plt.imshow((image[0] + 1) / 2);

In [None]:
# Classifier with one output per caption, using a sigmoid output activation
# and binary cross-entropy loss.
model = ImageClassifier(
    classes=len(caption_map),
    output_activation='sigmoid',
    loss='binary_crossentropy',
    init_lr=init_lr,
)

In [None]:
# Train on a fresh pipeline over the training generator and validate on a
# fresh pipeline over the validation generator.
model.fit_generator(
    pipeline(train_gen, caption_map),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=pipeline(val_gen, caption_map),
    validation_steps=validation_steps,
)