# Imports and Constants Definition

In [None]:
# !/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
from collections import defaultdict
import logging
import matplotlib.pyplot as plt
import os
import shutil
import sys
import threading

%matplotlib inline

# add egosocial to the python path
from os.path import dirname, abspath
sys.path.extend([dirname(abspath('.'))])

import numpy as np
import sklearn
from sklearn.utils import compute_class_weight

import keras
from keras import backend as K
from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.layers import Input, Dense, Dropout
from keras.layers.noise import AlphaDropout
from keras.models import Model
from keras.regularizers import l1, l2
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.imagenet_utils import preprocess_input

import egosocial.config
from egosocial.core.types import relation_to_domain, relation_to_domain_vec
from egosocial.utils.keras.autolosses import AutoMultiLossWrapper
from egosocial.utils.logging import setup_logging
from egosocial.utils.keras.callbacks import PlotLearning
from egosocial.utils.keras.backend import limit_gpu_allocation_tensorflow
from egosocial.utils.filesystem import create_directory, check_directory

# constants

# task names: coarse "domain" level and fine-grained "relation" level
DOMAIN = 'domain'
RELATION = 'relation'
# training modes
END_TO_END = 'end_to_end'
ATTRIBUTES = 'attributes'
# number of classes per task
N_CLS_RELATION = 16
N_CLS_DOMAIN = 5
# (height, width, channels) of the stored images vs. the network input
IMAGE_SHAPE = (256, 256, 3)
INPUT_SHAPE = (227, 227, 3)
# seed handed to every data generator (presumably so the parallel input
# streams shuffle identically -- confirm against setup_datagen usage)
SHARED_SEED = 1
# identifiers of the two people streams, as strings ('1', '2')
STREAM_IDS = list(map(str, (1, 2)))

# class names ('0' .. '15'), also used as per-class directory names
RELATION_LABELS = list(map(str, range(N_CLS_RELATION)))
# home prefix found in old split files; rewritten by process_line
DEPRECATED_HOME = '/root'

# Limit GPU memory allocation with Tensorflow

In [None]:
# Cap how much GPU memory TensorFlow may allocate for this kernel; 0.25 is
# presumably a fraction of the device memory (confirm in
# egosocial.utils.keras.backend).
limit_gpu_allocation_tensorflow(0.25)

# Input arguments and a fake main() for notebook use

In [None]:
FILE_FORMAT = '{prefix}{dtype}{idx}_{split}_{n_cls}{ext}'

# TODO: move to utils
def get_file(split, idx=1, dtype='body', n_cls=N_CLS_RELATION,
             ext='.txt', prefix='single_'):
    """Return the file name of a split/label list, built from FILE_FORMAT.

    With the defaults, ``get_file('train')`` gives
    ``'single_body1_train_16.txt'``.
    """
    fields = dict(prefix=prefix, dtype=dtype, idx=idx, split=split,
                  n_cls=n_cls, ext=ext)
    return FILE_FORMAT.format(**fields)

class Configuration:
    """Experiment settings derived from the parsed command-line arguments.

    Every setting is an upper-case attribute. Paths are only composed
    relative to ``args.project_dir``; nothing is created or checked here.
    """

    def __init__(self, args):
        join = os.path.join

        # model / feature configuration
        self.DATA_TYPE = RELATION
        self.ARCH = 'caffeNet'
        self.LAYER = 'fc7'
        self.CONFIG = '{}_{}_{}'.format(self.LAYER, self.DATA_TYPE, self.ARCH)

        # setup directories
        project_dir = args.project_dir
        self.PROJECT_DIR = project_dir
        self.BASE_MODELS_DIR = join(project_dir, 'models/trained_models')
        self.ATTR_MODELS_DIR = join(self.BASE_MODELS_DIR, 'attribute_models')
        self.SVM_MODELS_DIR = join(project_dir, 'models/svm_models')
        self.SPLITS_DIR = join(project_dir,
                               'datasets/splits/annotator_consistency3')
        self.STATS_MODELS_DIR = join(self.SVM_MODELS_DIR, 'stats')

        # label file of each split; on disk the caffe names are kept:
        # keras train -> 'train', validation -> 'test', test -> 'eval'
        self.LABEL_FILES = {}
        for split in ('train', 'test', 'eval'):
            self.LABEL_FILES[split] = join(self.SPLITS_DIR,
                                           get_file(split=split))

        self.IS_END2END = False

        self.BASE_FEATURES_DIR = join(project_dir, 'extracted_features')
        self.FEATURES_DIR = join(self.BASE_FEATURES_DIR,
                                 'attribute_features', self.CONFIG)
        self.STORED_FEATURES_DIR = join(self.FEATURES_DIR,
                                        'all_splits_numpy_format')

        # command-line switches and training hyper-parameters
        self.PROCESS_FEATURES = args.port_features
        self.EPOCHS = args.epochs
        self.BATCH_SIZE = args.batch_size
        self.REUSE_MODEL = args.reuse_model  # reuse precomputed model?
        self.SAVE_MODEL = args.save_model    # save model to disk?
        self.SAVE_STATS = args.save_stats    # save model statistics to disk?

        # root of the symlink tree built by create_fake_directory
        self.FAKE_DIR = join(project_dir, 'datasets', 'fake_dir')
        
def positive_int(value):
    """argparse ``type=`` converter that accepts only strictly positive ints.

    Raises:
        argparse.ArgumentTypeError: if the parsed integer is zero or negative.
        ValueError: if ``value`` is not an integer literal (from ``int()``).
    """
    parsed = int(value)
    if parsed > 0:
        return parsed
    raise argparse.ArgumentTypeError(
        "%s is an invalid positive int value" % value)

def main(*fake_args):
    """Parse the command line and return the experiment Configuration.

    Args:
        *fake_args: forwarded positionally to ``parser.parse_args``. From a
            notebook, pass a single list of strings (e.g.
            ``main(['--project_dir', path])``) to bypass ``sys.argv``; with
            no arguments argparse reads ``sys.argv``.

    Returns:
        Configuration built from the parsed arguments.

    Side effects:
        Configures logging from ``egosocial.config.LOGGING_CONFIG`` and
        validates the project and splits directories with check_directory.
        argparse exits via SystemExit on invalid arguments.
    """
    setup_logging(egosocial.config.LOGGING_CONFIG)

    entry_msg = 'Reproduce experiments in Social Relation Recognition paper.'
    parser = argparse.ArgumentParser(description=entry_msg)

    parser.add_argument('--project_dir', required=True,
                        help='Base directory.')

    # the two help fragments are concatenated, so the space must be explicit
    parser.add_argument('--port_features', required=False,
                        action='store_true',
                        help='Whether to port features from other formats '
                             'to numpy.')

    parser.add_argument('--reuse_model', required=False,
                        action='store_true',
                        help='Use precomputed model if available.')

    parser.add_argument('--save_model', required=False,
                        action='store_true',
                        help='Save model to disk.')

    parser.add_argument('--save_stats', required=False,
                        action='store_true',
                        help='Save statistics to disk.')

    parser.add_argument('--epochs', required=False, type=positive_int,
                        default=100,
                        help='Max number of epochs.')

    parser.add_argument('--batch_size', required=False, type=positive_int,
                        default=32,
                        help='Batch size.')

    # TODO: implement correctly (fake_args lets the notebook inject argv)
    args = parser.parse_args(*fake_args)
    # keep configuration
    conf = Configuration(args)
    # check directories
    check_directory(conf.PROJECT_DIR, 'Project')
    check_directory(conf.SPLITS_DIR, 'Splits')

    return conf

# Helper functions

In [None]:
def process_line(line):
    """Parse one ``<image_path> <label>`` entry of a split file.

    A leading DEPRECATED_HOME prefix in the path is rewritten to ``~`` and
    then expanded to the current user's home directory.

    Args:
        line: raw text line, trailing newline allowed.

    Returns:
        (image_path, label) tuple of strings.

    Raises:
        ValueError: if the line does not contain both a path and a label.
    """
    # split on the LAST whitespace run, so any whitespace separator works
    # and paths that contain spaces are kept intact
    image_path, label = line.strip().rsplit(None, 1)

    # only rewrite a real home prefix ('/root' or '/root/...'); a plain
    # substring replace would also mangle paths such as '/data/rootfs/x.jpg'
    if (image_path == DEPRECATED_HOME
            or image_path.startswith(DEPRECATED_HOME + '/')):
        image_path = '~' + image_path[len(DEPRECATED_HOME):]

    # expanduser only acts on a leading '~'
    if image_path.startswith('~'):
        image_path = os.path.expanduser(image_path)

    return image_path, label

def get_split_mapping(files_dir, file_callback=get_file):
    """Map every ``<stream id><sep><split>`` key to its split file path.

    On disk the caffe split names are kept: 'train', 'test' (used as the
    keras validation split) and 'eval' (used as the keras test split).

    Args:
        files_dir: directory holding the split files; validated with
            check_directory first.
        file_callback: called as ``file_callback(split=..., idx=...)`` to
            obtain the file name of one split/stream.

    Returns:
        dict from ``'<idx>' + os.path.sep + '<split>'`` to
        ``os.path.join(files_dir, <file name>)``.
    """
    check_directory(files_dir, 'Splits')

    return {
        '{idx}{sep}{split}'.format(split=split, sep=os.path.sep, idx=idx):
            os.path.join(files_dir, file_callback(split=split, idx=idx))
        for split in ('train', 'test', 'eval')
        for idx in STREAM_IDS
    }

def create_fake_directory(fake_dir, splits_dir, labels):
    """(Re)build ``fake_dir`` as a symlink tree readable by keras.

    Layout: ``fake_dir/<dtype>/<idx>/<split>/<label>/fakelink<N>_<name>`` for
    dtype in ('body', 'face'). Each link points to an image listed in the
    matching split file of ``splits_dir``. Any existing ``fake_dir`` is
    deleted first.

    Args:
        fake_dir: destination directory (removed and recreated).
        splits_dir: directory containing the split files.
        labels: class labels; one sub-directory is created per label.
    """
    from functools import partial

    if os.path.isdir(fake_dir):
        shutil.rmtree(fake_dir)

    create_directory(fake_dir, 'Fake')

    for dtype in ('body', 'face'):
        dtype_dir = os.path.join(fake_dir, dtype)
        create_directory(dtype_dir, 'Fake {}'.format(dtype))

        # file names for this dtype (body or face)
        file_callback = partial(get_file, dtype=dtype)
        split_mapping = get_split_mapping(splits_dir,
                                          file_callback=file_callback)

        for images_dir, images_file_path in split_mapping.items():
            fake_images_dir = os.path.join(dtype_dir, images_dir)

            # keras asks for a different directory for each class
            for label in labels:
                fake_label_dir = os.path.join(fake_images_dir, str(label))
                create_directory(fake_label_dir, 'Label')

            # get_split_mapping already returns the full path of the split
            # file; joining it again under splits_dir/images_dir only worked
            # when splits_dir was absolute
            with open(images_file_path) as f:
                for fake_id, line in enumerate(f):
                    # tolerate blank lines (e.g. a trailing empty line)
                    if not line.strip():
                        continue
                    image_path, label = process_line(line)
                    # create a symlink for every entry; a name may appear in
                    # several entries, so a unique fake name is created
                    fake_name = 'fakelink{}_{}'.format(
                        fake_id, os.path.basename(image_path))
                    fake_image_path = os.path.join(fake_images_dir, label,
                                                   fake_name)
                    # lexists also sees an existing but dangling symlink
                    if not os.path.lexists(fake_image_path):
                        os.symlink(image_path, fake_image_path)

# TODO: move to utils
class threadsafe_iter(object):
    """Wrap an iterator/generator so that concurrent ``next`` calls are serialized.

    Every advance of the wrapped iterator happens while holding an internal
    lock, so several threads can pull from the same generator.
    """

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

    # Python 2's iterator protocol calls ``next``; the rest of this code is
    # written to stay py2-compatible, so expose both spellings.
    next = __next__

# TODO: move to utils
def threadsafe_generator(f):
    """Decorate a generator function so each call returns a threadsafe_iter.

    ``functools.wraps`` keeps the decorated function's name and docstring,
    which would otherwise be replaced by those of the inner wrapper.
    """
    import functools

    @functools.wraps(f)
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

# TODO: move to utils
@threadsafe_generator
def fuse_inputs_generator(generators, inputs, outputs_callback=None):
    """Merge several ``(X, y)`` batch generators into one multi-input generator.

    At each step one batch is pulled from every generator (in order), and
    ``({inputs[i]: X_i, ...}, outputs_callback(batches, inputs))`` is yielded.
    By default the target is the label part of the first generator's batch,
    i.e. all inputs are assumed to share the same labels. Never stops.
    """
    assert len(generators) == len(inputs)

    if outputs_callback is None:
        def outputs_callback(batches, _inputs):
            # batches[0]: first input's batch; index 1: its label data
            return batches[0][1]

    while True:
        batches = list(map(next, generators))
        features = dict(zip(inputs, (batch[0] for batch in batches)))
        yield features, outputs_callback(batches, inputs)

# TODO: move to utils
def flow_from_dirs(input_directories, **kwargs):
    """Create one keras directory iterator per input directory.

    Args:
        input_directories: directories to read, each validated with
            check_directory.
        **kwargs: forwarded to ``ImageDataGenerator.flow_from_directory``,
            except the optional ``gen_args`` dict, which is used as the
            ``ImageDataGenerator`` constructor arguments instead.

    Returns:
        list of iterators in the same order as ``input_directories``.
    """
    datagen_kwargs = kwargs.pop('gen_args', None) or {}

    iterators = []
    for directory in input_directories:
        check_directory(directory)
        # a fresh ImageDataGenerator for every directory
        iterator = ImageDataGenerator(**datagen_kwargs).flow_from_directory(
            directory,
            **kwargs
        )
        iterators.append(iterator)

    return iterators

# TODO: move to utils
def get_relation_domain(data_batch, inputs):
    """Build the two-output target dict from a list of per-input batches.

    Only the first input's labels are read (every input carries the same
    relation label). The domain target is derived from them with
    relation_to_domain_vec.

    Args:
        data_batch: list of ``(X, y)`` batches, one per input; ``y`` holds
            one-hot relation labels, decoded with argmax along axis 1.
        inputs: unused; present to match fuse_inputs_generator's
            ``outputs_callback`` signature.

    Returns:
        dict with 'domain' (one-hot over N_CLS_DOMAIN classes) and 'relation'.
    """
    relation_onehot = data_batch[0][1]
    domain_ids = relation_to_domain_vec(np.argmax(relation_onehot, axis=1))

    return {
        'domain': to_categorical(domain_ids, N_CLS_DOMAIN),
        'relation': relation_onehot,
    }

def create_data_split_generators(directory,
                                 train_gen_args=None,
                                 test_val_gen_args=None,
                                 **kwargs):
    """Return fused ``[train, validation, test]`` generators read from ``directory``.

    ``directory`` must contain the ``body/1``, ``body/2``, ``face/1`` and
    ``face/2`` sub-trees, each with one folder per split. Split folders keep
    caffe's names: 'train', 'test' (validation) and 'eval' (test).

    Args:
        directory: root of the image tree; validated with check_directory.
        train_gen_args: ImageDataGenerator arguments for the training split.
        test_val_gen_args: ImageDataGenerator arguments for validation/test.
        **kwargs: forwarded to ``flow_from_directory`` for every input.

    Returns:
        list of three fused generators yielding
        ``({input_name: X}, {'domain': ..., 'relation': ...})``.
    """
    check_directory(directory)

    inputs = ['body/1', 'body/2', 'face/1', 'face/2']
    per_split_args = (
        ('train', train_gen_args),
        ('test', test_val_gen_args),
        ('eval', test_val_gen_args),
    )

    split_generators = []
    for split, gen_args in per_split_args:
        directories = [os.path.join(directory, name, split) for name in inputs]
        iterators = flow_from_dirs(directories, gen_args=gen_args, **kwargs)
        split_generators.append(
            fuse_inputs_generator(iterators, inputs,
                                  outputs_callback=get_relation_domain)
        )

    return split_generators

# Main class

In [None]:
class SocialClassifier:
    """Train and evaluate a keras model on the fused body/face image streams.

    Typical use: ``setup_datagen`` -> ``set_model`` -> ``fit`` -> ``evaluate``.
    The model is wrapped in AutoMultiLossWrapper, which (per the comment in
    ``set_model``) allows the loss weights to be trained.
    """

    def __init__(self, data_dir):
        """Store the data root; generators and model are created later.

        Args:
            data_dir: root of the image tree read by
                create_data_split_generators.
        """
        self._data_dir = data_dir

        self._log = logging.getLogger(self.__class__.__name__)

        # initialized by set_model
        self._model_wrapper = None
        self.model = None

        # initialized by setup_datagen
        self._train_gen, self._val_gen, self._test_gen = None, None, None

    def setup_datagen(self, **kwargs):
        """Create the fused train/validation/test generators.

        Training batches are augmented with random shifts and horizontal
        flips. All splits use ``rescale=1/255`` and ``preprocess_input``.

        Args:
            **kwargs: forwarded to ``flow_from_directory`` for every input
                (e.g. batch_size, seed, classes, target_size).
        """
        self._log.debug("Creating data generators.")

        train_gen_args = dict(
            rescale=1./255,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
            preprocessing_function=preprocess_input,
        )
        # NOTE(review): keras' ImageDataGenerator.standardize appears to run
        # preprocessing_function before rescale (confirm for the installed
        # version), so ImageNet-preprocessed values end up divided by 255.
        # Check this matches what the model expects.
        test_val_gen_args = dict(
            rescale=1./255,
            preprocessing_function=preprocess_input,
        )

        split_gens = create_data_split_generators(
            self._data_dir, train_gen_args, test_val_gen_args,
            **kwargs
        )

        self._train_gen, self._val_gen, self._test_gen = split_gens

    def set_model(self, model, optimizer='adam',
                  loss='categorical_crossentropy',
                  loss_weights='auto',
                  **kwargs):
        """Wrap ``model`` in AutoMultiLossWrapper, compile it and log its summary.

        Args:
            model: keras model to train; must not be None.
            optimizer: passed to the wrapper's ``compile``.
            loss: passed to the wrapper's ``compile``.
            loss_weights: passed to the wrapper's ``compile`` ('auto' by
                default; its meaning is defined by AutoMultiLossWrapper).
            **kwargs: extra ``compile`` arguments (e.g. metrics).
        """
        assert model is not None
        self._log.debug("Initializing model.")

        # wrapper allows to train the loss weights
        self._model_wrapper = AutoMultiLossWrapper(model)
        self._model_wrapper.compile(optimizer, loss,
                                    loss_weights=loss_weights,
                                    **kwargs)

        self.model = self._model_wrapper.model
        # summary() prints and returns None; send its lines to the logger
        self.model.summary(print_fn=self._log.info)

    def fit(self, steps_per_epoch, validation_steps, **kwargs):
        """Train on the training split, validating on the validation split.

        Args:
            steps_per_epoch: training batches per epoch.
            validation_steps: validation batches per epoch.
            **kwargs: forwarded to ``model.fit_generator``.

        Returns:
            whatever ``model.fit_generator`` returns (a keras History).
        """
        assert self.model is not None
        assert self._train_gen is not None
        assert self._val_gen is not None
        self._log.debug("Training model from scratch.")

        return self.model.fit_generator(
            self._train_gen,
            steps_per_epoch=steps_per_epoch,
            validation_data=self._val_gen,
            validation_steps=validation_steps,
            **kwargs
        )

    def evaluate(self, steps, **kwargs):
        """Evaluate the model on the test split.

        Args:
            steps: number of test batches to draw.
            **kwargs: forwarded to ``model.evaluate_generator``.

        Returns:
            whatever ``model.evaluate_generator`` returns: per the keras docs
            a scalar loss, or a list ordered like ``model.metrics_names``.
        """
        assert self.model is not None
        assert self._test_gen is not None
        self._log.debug("Evaluating model in test data.")

        return self.model.evaluate_generator(
            self._test_gen,
            steps=steps,
            **kwargs
        )

# Model definitions

# Fake call to main() to process the input arguments

In [None]:
# Emulated command line: main() parses this list instead of sys.argv.
args = ["--project_dir", "/home/shared/Documents/final_proj"]
args += ["--epochs", "30"]
args += ["--batch_size", "32"]

# parse the arguments, build the Configuration and check its directories
conf = main(args)

# Prepare the data

In [None]:
# Rebuild the symlink tree (one folder per relation label) from the split files.
create_fake_directory(conf.FAKE_DIR, conf.SPLITS_DIR, labels=RELATION_LABELS)

In [None]:
helper = SocialClassifier(data_dir=conf.FAKE_DIR)

In [None]:
batch_size = conf.BATCH_SIZE

# Every input stream receives the same seed (presumably so their shuffled
# batches stay aligned -- the labels are read from the first stream only).
helper.setup_datagen(
    seed=SHARED_SEED,
    batch_size=batch_size,
    target_size=INPUT_SHAPE[:2],
    classes=RELATION_LABELS,
    follow_links=True,  # the fake directory is made of symlinks
)

# Initialize model

In [None]:
def create_model(input_shape, target_size):
    """Build the keras model to train (not implemented yet).

    Args:
        input_shape: shape of the network input.
        target_size: image shape; unused so far.

    Raises:
        NotImplementedError: always, until a model definition is provided.
    """
    # Returning None only surfaced later as a bare assertion inside
    # set_model; failing here points directly at the missing piece.
    raise NotImplementedError(
        'create_model: the model architecture is not defined yet.')

learning_rate = 1e-4
# compile through the helper; Adam with a small per-update decay of the rate
helper.set_model(
    create_model(INPUT_SHAPE, IMAGE_SHAPE),
    metrics=['accuracy'],
    optimizer=keras.optimizers.Adam(learning_rate, decay=1e-6),
)

# Training

In [None]:
# size of each split, output from the data generator
n_train, n_validation, n_test = 13729, 709, 5106

epochs = conf.EPOCHS

checkpoint_path = os.path.join(egosocial.config.MODELS_CACHE_DIR, 'training',
                               'finetuned_weights.{epoch:02d}-{val_loss:.2f}.h5')
checkpointer = ModelCheckpoint( 
    filepath=checkpoint_path, monitor='val_loss',
    save_best_only=True, period=1,
)

lr_handler = ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=10, min_lr=0.00001
)

# TODO: implement csv logs
#csv_logger = CSVLogger(filename)

# if plot is enabled, set verbose=0
plot_metrics = PlotLearning(update_step=1)

callbacks = [
#    checkpointer,
#    lr_handler, 
    plot_metrics
]

hist = helper.fit(
    steps_per_epoch=np.ceil(1.0 * n_train / batch_size),    
    validation_steps=np.ceil(1.0 * n_validation / batch_size),
    epochs=epochs,
    callbacks=callbacks,
    workers=4,
    max_queue_size=5,
    verbose=0,     
)

# Evaluation

In [None]:
# keras documents step counts as integers, while np.ceil returns a float
scores = helper.evaluate(
    steps=int(np.ceil(1.0 * n_test / batch_size)),
    workers=4,
    max_queue_size=5,
)

# evaluate_generator returns a bare scalar when only the loss is computed;
# atleast_1d makes the zip below work in both cases.
# '{:0.4f}' (not '{0.4f}') is required: mixing automatic '{}' fields with a
# numbered field raises ValueError in str.format.
for score, metric_name in zip(np.atleast_1d(scores), helper.model.metrics_names):
    helper._log.info("{} : {:0.4f}".format(metric_name, score))