In [None]:
import json
import math
import os

from keras.backend import clear_session
from keras.applications.xception import preprocess_input
import matplotlib.pyplot as plt
import numpy as np
import cv2

from abyss_deep_learning.datasets.coco import ImageClassificationDataset
from abyss_deep_learning.datasets.translators import CocoCaptionTranslator
from abyss_deep_learning.keras.classification import caption_map_gen, onehot_gen, hamming_loss
from abyss_deep_learning.keras.utils import lambda_gen, batching_gen
from abyss_deep_learning.keras.models import ImageClassifier

In [None]:
# Dataset location: a "data" folder under the current working directory,
# and the file name of the training annotations inside it.
TRAIN = "train.json"
DATA_DIR = os.path.join(os.getcwd(), "data")

In [None]:
# Open the training annotation file as a classification dataset, without
# caching and with the default caption translator.
train_json_path = os.path.join(DATA_DIR, TRAIN)
ds = ImageClassificationDataset(
    train_json_path,
    cached=False,
    translator=CocoCaptionTranslator(),
)


In [None]:
# Collapse the labels into a binary root / non-root problem:
# every caption other than 'RI' is rewritten in place to 'NR'.
for ann in ds.coco.anns.values():
    if ann['caption'] == 'RI':
        continue
    ann['caption'] = 'NR'

In [None]:
# Draw one (image, caption) pair from the dataset, print the caption and
# show the image so the relabelling above can be checked by eye.
image, caption = ds.sample()
print(caption)
plt.imshow(image)

In [None]:
# Caption string -> class index used for the one-hot targets.
caption_map = dict(NR=0, RI=1)

In [None]:
# Reverse lookup: class index -> caption string.
caption_map_r = dict(zip(caption_map.values(), caption_map.keys()))

## Data set splitting

In [None]:
# Fraction of the shuffled images assigned to the training split;
# the remaining images are used for validation.
split_ratio = 0.8

In [None]:
# Shuffle the actual image ids (the keys of ds.coco.imgs) rather than the
# positions 0..N-1: the shuffled values are later handed to
# ds.generator(data_ids=...), so they must be valid ids even when the ids in
# the annotation file are not a contiguous zero-based range.
# NOTE(review): assumes data_ids are looked up as keys of ds.coco.imgs — confirm.
# A fixed seed keeps the train/validation split identical across re-runs, so
# validation images never migrate into the training split between runs.
SPLIT_SEED = 42
idx_shuffled = np.random.RandomState(SPLIT_SEED).permutation(list(ds.coco.imgs))

In [None]:
# Position in the shuffled order where the training split ends.
idx_split = math.floor(split_ratio * len(ds.coco.imgs))

In [None]:
# Split the shuffled entries into two contiguous, non-overlapping parts.
# The validation slice starts exactly at idx_split: starting at idx_split + 1
# would silently drop the element at position idx_split from both splits.
train_ids = idx_shuffled[:idx_split]
val_ids = idx_shuffled[idx_split:]

In [None]:
# One sample generator per split, both created with endless=True and
# shuffle_ids=True.
stream_options = {'endless': True, 'shuffle_ids': True}
train_gen = ds.generator(data_ids=list(train_ids), **stream_options)
val_gen = ds.generator(data_ids=list(val_ids), **stream_options)

## Parameters for model training

In [None]:
# Input resolution the images are resized to (pixels).
num_rows, num_cols = 299, 299

# Samples per mini-batch.
batch_size = 5

# Training schedule: number of epochs and initial learning rate.
epochs = 5
init_lr = 1e-5

# Whole batches available in each split; integer division drops any
# trailing partial batch.
validation_steps = len(val_ids) // batch_size
steps_per_epoch = len(train_ids) // batch_size

In [None]:
def pipeline(gen):
    """Turn a raw (image, caption) generator into model-ready batches.

    Stages, applied in order to the samples produced by ``gen``:
      1. ``caption_map_gen`` with ``caption_map``;
      2. ``onehot_gen`` with ``len(caption_map)`` classes, converting the
         caption to a one-hot representation;
      3. resize of the image to ``num_rows`` x ``num_cols`` followed by
         ``preprocess_input``;
      4. ``batching_gen`` with ``batch_size``.

    Args:
        gen: generator yielding (image, caption) pairs.

    Returns:
        The generator produced by ``batching_gen``.
    """
    def _prepare(x, y):
        # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
        return preprocess_input(cv2.resize(x, (num_cols, num_rows))), y

    mapped = caption_map_gen(gen, caption_map)
    onehot = onehot_gen(mapped, len(caption_map))
    return batching_gen(lambda_gen(onehot, _prepare), batch_size)


In [None]:
# Pull a single batch from the validation pipeline and print its targets
# to check the encoding.
image, caption = next(iter(pipeline(val_gen)))
print(caption)

In [None]:
# Build the classifier with the initial learning rate configured in init_lr.
# Model default uses Xception model (per the original author's note), which is
# why images are passed through keras.applications.xception.preprocess_input.
model = ImageClassifier(init_lr=init_lr)

In [None]:
# Fit the model on batches from the training generator, validating against
# batches from the validation generator.
train_batches = pipeline(train_gen)
val_batches = pipeline(val_gen)
model.fit_generator(
    train_batches,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=val_batches,
    validation_steps=validation_steps,
)

In [None]:
# Spot-check: predict the class of one dataset sample and print it next to
# the sample's own caption.
image, caption = ds.sample()
# Apply the same resize + preprocess_input that `pipeline` applies to the
# training data; feeding the raw resized pixels would not match the inputs
# the model was fitted on. cv2.resize takes dsize as (width, height),
# i.e. (cols, rows).
model_input = preprocess_input(
    np.expand_dims(cv2.resize(image, (num_cols, num_rows)), 0)
)
print(caption_map_r[model.predict(model_input)[0]])
print(caption.pop())

In [None]:
# Path to the SWC annotation file. Set the SWC_DATASET_JSON environment
# variable to point at a different copy; when it is unset the original
# absolute path is used, so existing setups keep working unchanged.
SWC = os.environ.get(
    "SWC_DATASET_JSON",
    "/mnt/rhino/ssd1/processed/industry-data/swc/train_1/cloudfactory/datasets/with-bg/notebook-ready/train-nb.json",
)

In [None]:
# Open the SWC annotation file as a classification dataset, without caching;
# the caption translator is configured with ',' as its separator.
swc_translator = CocoCaptionTranslator(separator=',')
swc_ds = ImageClassificationDataset(SWC, cached=False, translator=swc_translator)

In [None]:
# SWC caption codes that are relabelled as 'RI'; any sample without one of
# these codes becomes 'NR'.
roots = [
    'SeRB',
    'SeRF',
    'SeRJ',
    'SeRM',
]

In [None]:
# Preview only the first few annotations: printing every entry floods the
# notebook output and bloats the saved file on a large dataset.
ANN_PREVIEW_COUNT = 10
for idx, ann in list(swc_ds.coco.anns.items())[:ANN_PREVIEW_COUNT]:
    print(idx, ann)

In [None]:
# Sample generator over the SWC dataset using the library defaults (no
# data_ids, endless or shuffle_ids arguments are given).
# NOTE(review): the evaluation loop further down iterates this generator to
# exhaustion, so the default is presumably a single finite pass — confirm.
swc_gen = swc_ds.generator()

In [None]:
def swc_pipeline(gen):
    """Batch SWC samples for the binary root / non-root classifier.

    Each (image, captions) pair from ``gen`` is resized to ``num_rows`` x
    ``num_cols`` and passed through ``preprocess_input``. Its captions are
    collapsed to ``['RI']`` when any of them is listed in ``roots`` and to
    ``['NR']`` otherwise. The result is then fed through ``caption_map_gen``
    (with ``caption_map``), ``onehot_gen`` (with ``len(caption_map)``
    classes) and ``batching_gen`` (with ``batch_size``).

    Args:
        gen: generator yielding (image, captions) pairs.

    Returns:
        The generator produced by ``batching_gen``.
    """
    def _to_binary_sample(x, y):
        # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
        resized = cv2.resize(x, (num_cols, num_rows))
        binary_label = 'RI' if any(label in roots for label in y) else 'NR'
        return preprocess_input(resized), [binary_label]

    relabelled = lambda_gen(gen, _to_binary_sample)
    mapped = caption_map_gen(relabelled, caption_map)
    return batching_gen(onehot_gen(mapped, len(caption_map)), batch_size)

In [None]:
# Check the relabelling rule on a hand-made caption set: the result is
# ['RI'] when the set shares at least one code with `roots`, else ['NR'].
y = {'F', 'SeRB', 'SJ'}
['NR' if set(roots).isdisjoint(y) else 'RI']

In [None]:
# Display the `stats` attribute of the SWC dataset.
swc_ds.stats

In [None]:
# Number of whole batches of batch_size in the SWC dataset
# (integer division drops any trailing partial batch).
steps = len(swc_ds.coco.imgs) // batch_size

In [None]:
# Print the per-class statistics reported by the SWC dataset object.
swc_ds.print_class_stats()

In [None]:
# Display the metric names exposed by the object held in `model.model_`
# (presumably the wrapped Keras model — confirm).
model.model_.metrics_names

In [None]:
# Seed the result arrays with the first SWC sample: `preds` holds the model's
# predict_proba output for it, `labelled` its binary ground-truth label.
image, caption = next(swc_gen)
# cv2.resize takes dsize as (width, height), i.e. (cols, rows).
resized = cv2.resize(image, (num_cols, num_rows))
preds = model.predict_proba(preprocess_input(np.expand_dims(resized, axis=0)))[0]
labelled = np.array(['RI' if any(label in roots for label in caption) else 'NR'])

In [None]:
# Run the model over the remaining SWC samples. Rows are collected in lists
# and stacked once at the end: calling np.vstack inside the loop re-copies
# the whole accumulated array on every iteration (quadratic in sample count).
pred_rows = [preds]
label_rows = [labelled]
for image, caption in swc_gen:
    # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
    resized = cv2.resize(image, (num_cols, num_rows))
    pred_rows.append(
        model.predict_proba(preprocess_input(np.expand_dims(resized, axis=0)))[0]
    )
    label_rows.append(
        np.array(['RI' if any(label in roots for label in caption) else 'NR'])
    )

# Leave preds/labelled untouched when the generator produced nothing new,
# matching what the incremental version did in that case.
if len(pred_rows) > 1:
    preds = np.vstack(pred_rows)
    labelled = np.vstack(label_rows)

In [None]:
# List every prediction next to its ground-truth label, numbered from 0.
for idx, pair in enumerate(zip(preds, labelled)):
    pred, label = pair
    print(idx, pred, label)

In [None]:
# Show the image that swc_ds.load_data returns for the hard-coded id 472.
# NOTE(review): 472 was presumably picked by hand from the printed
# predictions — confirm, and consider naming it as a constant.
plt.imshow(swc_ds.load_data(472))