In [None]:
import json
import math
import os

from keras.backend import clear_session
from keras.applications.xception import preprocess_input
import matplotlib.pyplot as plt
import numpy as np
import cv2

from abyss_deep_learning.datasets.coco import ImageClassificationDataset
from abyss_deep_learning.datasets.translators import CocoCaptionTranslator
from abyss_deep_learning.keras.classification import caption_map_gen, onehot_gen, hamming_loss
from abyss_deep_learning.keras.utils import lambda_gen, batching_gen
from abyss_deep_learning.keras.models import ImageClassifier

In [None]:
# Name of the COCO-style annotation file used for training
TRAIN = "train.json"
# Data lives in a 'data' folder under the directory the notebook is started from
# (for this experiment: .../jupyter/kent-experiments/exp_001)
DATA_DIR = os.path.join(os.getcwd(), 'data')

In [None]:
# Open the training annotations as a classification dataset (uncached),
# with CocoCaptionTranslator supplied as the translator
ds = ImageClassificationDataset(
    os.path.join(DATA_DIR, TRAIN),
    translator=CocoCaptionTranslator(),
    cached=False,
)

In [None]:
# Collapse the label set to two classes for root / non-root classification:
# annotations captioned 'RI' are kept, every other caption becomes 'NR'.

for ann in ds.coco.anns.values():
    if ann['caption'] == 'RI':
        continue
    ann['caption'] = 'NR'

In [None]:
# Sanity check: draw one (image, caption) pair from the dataset,
# print the caption and display the image below the cell
image, caption = ds.sample()
print(caption)
plt.imshow(image)

In [None]:
# Class label -> integer class index ('NR' = non-root, 'RI' = root)
caption_map = {label: index for index, label in enumerate(('NR', 'RI'))}

In [None]:
# Reverse lookup: integer class index -> class label
caption_map_r = dict(zip(caption_map.values(), caption_map.keys()))

## Data set splitting

In [None]:
# Fraction of the images assigned to the training set; the rest is used for validation
split_ratio = 0.8

In [None]:
# Fixed seed so the train/validation split is identical after every kernel restart
SPLIT_SEED = 42
# Shuffle the dataset's own image ids (the keys of ds.coco.imgs) instead of the
# positions 0..N-1. The shuffled values are later passed to ds.generator(data_ids=...);
# positions and ids only coincide when the ids happen to be 0..N-1.
# NOTE(review): assumes data_ids are COCO image ids - confirm against ds.generator.
idx_shuffled = np.random.RandomState(SPLIT_SEED).permutation(list(ds.coco.imgs.keys()))

In [None]:
# Position in the shuffled order where the training portion ends (rounded down)
idx_split = math.floor(split_ratio * len(ds.coco.imgs))

In [None]:
# The first `idx_split` shuffled entries (split_ratio of the data) form the training set
train_ids = idx_shuffled[:idx_split]
# Everything from position `idx_split` onwards forms the validation set.
# The slice must start at idx_split itself: starting at idx_split + 1 would
# silently drop one sample from both sets.
val_ids = idx_shuffled[idx_split:]

In [None]:
# Both splits use the same generator options (endless=True, shuffle_ids=True);
# only the ids they draw from differ
generator_options = dict(endless=True, shuffle_ids=True)
train_gen = ds.generator(data_ids=list(train_ids), **generator_options)
val_gen = ds.generator(data_ids=list(val_ids), **generator_options)

## Parameters for model training

In [None]:
# Size (in pixels) every image is resized to before being fed to the model
num_rows = num_cols = 299

# Optimisation settings
init_lr = 1e-5
epochs = 5
batch_size = 5

# Number of whole batches per pass over each split (floor division drops a partial batch)
steps_per_epoch, validation_steps = (
    len(ids) // batch_size for ids in (train_ids, val_ids)
)

In [None]:
def pipeline(gen):
    """
    Wrap a raw (image, caption) generator into batches for fit_generator.

    Stages, applied in this order:
        caption_map_gen:
            Maps the captions ['NR', 'RI'] to [0, 1] using `caption_map`
        onehot_gen:
            Creates a one hot vector of length len(caption_map) from the class index
        lambda_gen:
            Resizes the image to num_rows x num_cols and applies the Xception
            `preprocess_input`; the target is passed through unchanged
        batching_gen:
            Groups the samples into batches of `batch_size`

    Args:
        gen: generator yielding (image, caption) pairs, e.g. from ds.generator(...)

    Returns:
        The generator returned by batching_gen.
    """
    def _resize_and_preprocess(image, target):
        # cv2.resize expects dsize as (width, height), i.e. (cols, rows)
        resized = cv2.resize(image, (num_cols, num_rows))
        return preprocess_input(resized), target

    indexed = caption_map_gen(gen, caption_map)
    encoded = onehot_gen(indexed, len(caption_map))
    prepared = lambda_gen(encoded, _resize_and_preprocess)
    return batching_gen(prepared, batch_size)



In [None]:
# Build the classifier with the configured initial learning rate.
# No backbone is specified, so ImageClassifier's default is used
# (Xception, according to the original author's note).
model = ImageClassifier(init_lr=init_lr)

In [None]:
# Wrap both splits in the preprocessing/batching pipeline, then train
train_batches = pipeline(train_gen)
val_batches = pipeline(val_gen)
model.fit_generator(
    train_batches,
    validation_data=val_batches,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    epochs=epochs,
)

In [None]:
# Spot-check the trained model on one random sample:
# prints the predicted label first, then one of the sample's annotated captions.
image, caption = ds.sample()
# Prepare the image exactly like the training pipeline and the SWC evaluation
# cells do: resize to the configured input size (cv2.resize takes dsize as
# (width, height), i.e. (cols, rows)), add a batch axis, apply preprocess_input.
model_input = preprocess_input(np.expand_dims(cv2.resize(image, (num_cols, num_rows)), 0))
predicted_index = model.predict(model_input)[0]
print(caption_map_r[predicted_index])
print(caption.pop())

# Data preparation for model testing

In [None]:
# Annotation file of the SWC dataset used for testing.
# Defaults to the original shared-drive location; set the SWC_DATASET_JSON
# environment variable to point at a copy elsewhere without editing the notebook.
SWC = os.environ.get(
    "SWC_DATASET_JSON",
    "/mnt/rhino/ssd1/processed/industry-data/swc/train_1/cloudfactory/datasets/with-bg/notebook-ready/train-nb.json",
)

In [None]:
# SWC dataset, uncached; the caption translator is configured with ',' as separator
# (presumably a caption can hold several comma-separated labels - confirm)
swc_translator = CocoCaptionTranslator(separator=',')
swc_ds = ImageClassificationDataset(SWC, translator=swc_translator, cached=False)

In [None]:
# Equivalent SWC labels for a root: a sample carrying any of these
# labels is treated as class 'RI', everything else as 'NR'
roots = ['SeRB', 'SeRF', 'SeRJ', 'SeRM']

In [None]:
# Generator over the SWC dataset, created with the generator's default arguments
# (no explicit data_ids / endless / shuffle_ids)
swc_gen = swc_ds.generator()

In [None]:
def swc_pipeline(gen):
    """
    Wrap a raw SWC (image, labels) generator into one-hot encoded batches.

    Stages, applied in this order:
        lambda_gen:
            Resizes the image to num_rows x num_cols, applies the Xception
            `preprocess_input`, and reduces the SWC labels to ['RI'] when any of
            them is listed in `roots`, otherwise ['NR']
        caption_map_gen:
            Maps the caption to its class index using `caption_map`
        onehot_gen:
            Creates a one hot vector of length len(caption_map)
        batching_gen:
            Groups the samples into batches of `batch_size`

    Args:
        gen: generator yielding (image, labels) pairs, e.g. swc_ds.generator()

    Returns:
        The generator returned by batching_gen.
    """
    def _prepare(image, labels):
        # cv2.resize expects dsize as (width, height), i.e. (cols, rows)
        resized = cv2.resize(image, (num_cols, num_rows))
        caption = 'RI' if any(swc_label in roots for swc_label in labels) else 'NR'
        return preprocess_input(resized), [caption]

    relabelled = lambda_gen(gen, _prepare)
    indexed = caption_map_gen(relabelled, caption_map)
    encoded = onehot_gen(indexed, len(caption_map))
    return batching_gen(encoded, batch_size)

In [None]:
def process_image(image):
    """
    Turn a single image into a one-image batch ready for the model.

    The image is resized to num_rows x num_cols, given a leading batch axis,
    and passed through the Xception `preprocess_input`.

    Args:
        image: image array accepted by cv2.resize

    Returns:
        The result of preprocess_input on the resized image with a leading
        batch axis of length 1.
    """
    # cv2.resize expects dsize as (width, height), i.e. (cols, rows)
    resized = cv2.resize(image, (num_cols, num_rows))
    return preprocess_input(np.expand_dims(resized, axis=0))

## Iterate through SWC generator and calculate probabilities of either class root/non-root

In [None]:
# Seed the result arrays with the first SWC sample:
# `preds` holds the model's predict_proba output for the image,
# `labelled` holds 'RI' if any of the sample's labels is a root label, else 'NR'
image, caption = next(swc_gen)
preds = model.predict_proba(process_image(image))[0]
is_root = any(swc_label in roots for swc_label in caption)
labelled = np.array(['RI' if is_root else 'NR'])

In [None]:
# Run the model over the remaining SWC samples.
# Rows are gathered in Python lists and stacked once at the end: calling
# np.vstack inside the loop copies the whole array on every iteration.
# Starting the lists with the existing `preds` / `labelled` keeps the rows
# already collected and guarantees both results are 2-D (one row per sample),
# even when the generator has nothing left to yield.
pred_rows = [preds]
label_rows = [labelled]
for image, caption in swc_gen:
    pred_rows.append(model.predict_proba(process_image(image))[0])
    is_root = any(swc_label in roots for swc_label in caption)
    label_rows.append(np.array(['RI' if is_root else 'NR']))
preds = np.vstack(pred_rows)
labelled = np.vstack(label_rows)

In [None]:
# One line per sample: row index, predict_proba output, derived 'RI'/'NR' label
for idx, row in enumerate(zip(preds, labelled)):
    print(idx, *row)