In [None]:
%matplotlib notebook
import os
import sys
from pprint import pprint
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Prepend the CUDA libs without assuming LD_LIBRARY_PATH is already set (indexing
# os.environ directly raises KeyError when it is not) and without leaving an empty
# path entry behind.
# NOTE: the dynamic loader reads LD_LIBRARY_PATH at process start, so for this kernel
# this is usually too late; it mainly affects subprocesses launched from here.
_cuda_lib_dir = '/usr/local/cuda/lib64'
_existing_ld_path = os.environ.get('LD_LIBRARY_PATH')
os.environ['LD_LIBRARY_PATH'] = (
    _cuda_lib_dir + ':' + _existing_ld_path if _existing_ld_path else _cuda_lib_dir)
from abyss_deep_learning.utils import config_gpu

from keras.utils import to_categorical
from skimage.io import imread
from skimage.transform import resize
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import keras.backend as K
from skimage.color import label2rgb
from keras import Model
from keras.optimizers import Nadam
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, TensorBoard, EarlyStopping
from keras.applications.resnet50 import preprocess_input

from abyss_deep_learning.keras.segmentation import SegmentationDataset, augmentation_gen, jaccard_index
from abyss_deep_learning.keras.utils import initialize_conv_transpose2d, lambda_gen

from crfrnn.crfrnn_model import get_crfrnn_model_def

# Project helper from abyss_deep_learning.utils; presumably configures the Keras /
# TensorFlow session to use GPU 0 (matching CUDA_VISIBLE_DEVICES above) -- confirm.
config_gpu([0])

In [None]:
#### If the ImageNet VGG16 'notop' weights below are missing, building VGG16 downloads them.
# os.path.exists does not expand "~", so expand it explicitly; otherwise the check is
# always False and VGG16 is rebuilt on every run.
vgg16_weights_path = os.path.expanduser(
    "~/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5")
if not os.path.exists(vgg16_weights_path):
    from keras.applications.vgg16 import VGG16
    # Only the side effect (fetching the weights file) is wanted; drop the model.
    vgg = VGG16(include_top=False)
    del vgg

# Setup Data

In [None]:
# Globals -- tune these before "Run All".

# Geometry
image_dims = (500, 500, 3)  # (H, W, C). Can't be changed for crfrnn
num_classes = 6             # only reported below; the model size comes from dataset['classes']

# Data pipeline switches
use_cached = True           # preload every split into in-memory float32 arrays
use_balanced_set = False
use_class_weights = False
aug_config = None           # augmentation config handed to augmentation_gen

# Hardware
parallel = False            # Use multiple GPUs

In [None]:
# Dataset location: COCO-format annotation files under <database_dir>/<dataset_name>/.
database_dir = "/data/abyss"
dataset_name = "swc_pipeline"
dataset_dir = os.path.join(database_dir, dataset_name)
dataset_files = {
    'train': os.path.join(dataset_dir, "swc_pipeline-coco2_train.json"),
    'val': os.path.join(dataset_dir, "swc_pipeline-coco2_val.json"),
    # NOTE: there is no separate test file -- 'test' reuses the validation annotations,
    # and val_loss also drives checkpointing / early stopping, so "test" metrics are
    # not an independent hold-out estimate.
    'test': os.path.join(dataset_dir, "swc_pipeline-coco2_val.json"),
}
# Filesystem-safe variant of the name (no '/').
dataset_name = dataset_name.replace("/", "-")

print("Number of classes:", num_classes)

# Application Methods

In [None]:
def preprocess(image, target):
    """Resize an (image, target) pair to ``image_dims`` and normalise the image.

    The target is resized with ``order=0`` (nearest neighbour) so label values are
    never interpolated. The image goes through Keras' ``preprocess_input`` in
    ``'tf'`` mode, which scales pixels into [-1, 1].
    """
    out_height_width = image_dims[:2]
    resized_target = resize(
        target, out_height_width, order=0, mode='constant', preserve_range=True)
    resized_image = resize(image, image_dims, mode='constant', preserve_range=True)
    return preprocess_input(resized_image, mode='tf'), resized_target

def postprocess(image, target=None):
    """Invert ``preprocess``'s [-1, 1] scaling back to a displayable uint8 image.

    ``target`` is accepted for call-site symmetry with ``preprocess`` and is unused.
    Values are clipped to [0, 255] before the cast so small overshoots (e.g. from
    interpolation) saturate instead of wrapping around in uint8.
    """
    return np.clip((image + 1) * 127.5, 0, 255).astype(np.uint8)
        
def pipeline(gen, samples_per_image=1, augment=False):
    """Wrap a raw (image, target) generator: augmentation first, then ``preprocess``.

    ``augment`` is forwarded as ``enable`` to ``augmentation_gen`` together with the
    module-level ``aug_config``. ``samples_per_image`` is accepted but not used.
    """
    augmented = augmentation_gen(gen, aug_config=aug_config, enable=augment)
    return lambda_gen(augmented, func=preprocess)

In [None]:
import concurrent.futures

# Shared per-split state, keyed by the split names in dataset_files; the per-split
# dicts are filled by the loading loop in this cell.
dataset = dict(
    names=list(dataset_files),
    classes=[],  # MUST FILL IN -- assigned from the train COCO categories at the end of this cell
    class_weights=dict.fromkeys(dataset_files),  # split -> class weights, None until computed
    ids={},
    gens={},
    data={},
    coco={},
)
print("Combinations of labels present")

def annotation_filter(ann):
    """Keep an annotation only if it has a 'segmentation' and is not a point annotation."""
    if 'segmentation' not in ann:
        return False
    return 'annotation_type' not in ann or ann['annotation_type'] != 'point'

def _load_cached_sample(indexed_id):
    """Load and preprocess one (image, segmentation) pair for the in-memory cache.

    Takes an ``(idx, img_id)`` pair and returns ``(idx, img_id, image, segmentation)``.
    Runs inside ProcessPoolExecutor workers and reads the module-level ``coco`` bound
    by the loop below; this relies on workers being forked after ``coco`` is assigned
    (the default 'fork' start method on Linux) -- NOTE(review): not valid under 'spawn'.
    """
    idx, img_id = indexed_id
    image, segmentation = preprocess(coco.load_image(img_id), coco.load_segmentation(img_id))
    return idx, img_id, image, segmentation


for name, path in dataset_files.items():
    coco = SegmentationDataset(path, image_dims)
    coco.annotation_filter = annotation_filter
    if use_balanced_set:
        # NOTE(review): `balanced_set` is not defined or imported in this notebook, so
        # enabling use_balanced_set raises NameError until it is provided.
        ids = list(balanced_set(coco))
    else:
        ids = coco.image_ids
    # Pass `augment` by keyword: pipeline's second positional parameter is
    # `samples_per_image`, not a class count. Only the train split is augmented.
    gen = pipeline(
        coco.generator(imgIds=ids, shuffle_ids=True),
        augment=(name == 'train' and bool(aug_config)))
    print("{:s}: {:d} images".format(name, len(ids)))
    dataset['coco'][name] = coco
    dataset['ids'][name] = ids
    dataset['gens'][name] = gen

    if use_cached:
        num_images = len(ids)
        image_shape = (num_images,) + image_dims
        mask_shape = (num_images,) + image_dims[0:2] + (coco.num_classes,)
        # float32 = 4 bytes per element; count both the image and the mask arrays.
        expected_size = 4 * (np.prod(image_shape) + np.prod(mask_shape))
        print("Caching {:s} set will take {:.1f} GB".format(name, expected_size / 1024 ** 3))
        images = np.zeros(image_shape, dtype=np.float32)
        masks = np.zeros(mask_shape, dtype=np.float32)
        dataset['data'][name] = [images, masks]

        with concurrent.futures.ProcessPoolExecutor() as executor:
            for idx, _, image, segmentation in executor.map(_load_cached_sample, enumerate(ids)):
                images[idx] = image
                masks[idx] = np.asarray(segmentation)
        print("Dataset {:s} has {:d} classes".format(name, masks.shape[-1]))
    

def data_source(name):
    """Return the data for split ``name``.

    Prefers the in-memory ``[images, masks]`` cache when that split was cached;
    otherwise returns the split's generator.
    """
    cached = dataset['data']
    if name not in cached:
        return dataset['gens'][name]
    return cached[name]

def data_sample(name, size=None):
    """Return sample (images, labels) from split ``name``.

    Cached split: ``size`` indices drawn with replacement from the cached arrays
    (``size=None`` gives one un-batched pair).
    Generator split: the first pair the generator yields; ``size`` is ignored and
    None is returned if the generator is exhausted.
    """
    source = data_source(name)
    cached = dataset['data'][name] if isinstance(source, list) else None
    if cached:
        images, labels = cached[0], cached[1]
        picks = np.random.choice(np.arange(images.shape[0]), size=size)
        return images[picks], labels[picks]
    for image, label in source:
        return image, label

# Class list = the sorted category ids declared in the training COCO file.
dataset['classes'] = sorted(category['id'] for category in dataset['coco']['train'].cats.values())

In [None]:
from skimage.color import label2rgb
# Grid of random ground-truth samples: one column per split, num_rows rows.
num_rows = 3
num_cols = len(dataset['names'])
plt.figure()
print("Left to right: ground truth samples from", ", ".join(dataset['names']))
for row in range(num_rows):
    for col, name in enumerate(dataset['names']):
        plt.subplot(num_rows, num_cols, num_cols * row + col + 1)
        image, target = data_sample(name, None)
        # Overlay the per-pixel argmax over the class axis on the de-normalised image.
        plt.imshow(label2rgb(target.argmax(axis=2), postprocess(image, target), alpha=0.2))
        plt.title(name)
        plt.axis('off')


In [None]:
# This cell intentionally left blank due to display bug above.

In [None]:
from collections import Counter
from sklearn.utils.class_weight import compute_class_weight

###################
def label_encoding(y, mode):
    """Encode the labels in ``y`` according to ``mode``.

    Only ``'multihot'`` is implemented: it returns the axis-1 (column) index of every
    non-zero entry of ``y``, in row-major order. Any other mode raises
    NotImplementedError.
    """
    if mode != 'multihot':
        raise NotImplementedError()
    nonzero_positions = np.argwhere(y)
    return nonzero_positions[:, 1]

def noop(args, **kwargs):
    """Identity function: hand back ``args`` untouched; keyword arguments are ignored."""
    del kwargs  # accepted and deliberately ignored
    return args
    

if use_class_weights:
    for name, coco in dataset['coco'].items():
        print("{:s} {:s} class stats {:s}".format('=' * 8, name, '=' * 8))
        wanted_ids = set(dataset['ids'][name])  # O(1) membership instead of scanning a list per image
        # NOTE(review): load_caption looks inherited from a captioning notebook; confirm
        # SegmentationDataset provides it before enabling use_class_weights.
        labels = [coco.load_caption(image['id'])
                  for image in coco.imgs.values() if image['id'] in wanted_ids]
        y = [label for fields in labels for label in fields]
        print(np.unique(y), dataset['classes'])
        # Label counts ordered by label value.
        count = np.array([n for _, n in sorted(Counter(y).items())])
        spread = {i: float(v.round(2)) for i, v in enumerate(count / np.sum(count))}
        print("labels not in dataset['classes']:",
              sum(label not in dataset['classes'] for label in y))
        # `classes` and `y` are keyword-only in current scikit-learn releases.
        class_weights = compute_class_weight(
            class_weight='balanced', classes=np.asarray(dataset['classes']), y=y)
        # NOTE(review): keys are 0-based positions in dataset['classes'], not category
        # ids; also confirm Keras accepts class_weight for 4-D segmentation targets.
        class_weights = {i: float(np.round(v, 3)) for i, v in enumerate(class_weights)}
        dataset['class_weights'][name] = class_weights
        weights = np.array(list(class_weights.values()))

        print("{:s} class weights:".format(name))
        print(" ", class_weights)
        trivial = np.mean(weights / np.max(weights))
        print("trivial result accuracy:\n  {:.2f} or {:.2f}".format(trivial, 1 - trivial))
        print("class cover fractions:\n  ", spread)


In [None]:
from abyss_deep_learning.keras.utils import batching_gen, gen_dump_data

num_val_data = len(dataset['ids']['val'])
if not use_cached:
    # Draw the val and test splits out of their generators (val first, then test);
    # presumably gen_dump_data collects N samples into arrays -- confirm.
    val_data = gen_dump_data(dataset['gens']['val'], num_val_data)
    test_data = gen_dump_data(dataset['gens']['test'], len(dataset['ids']['test']))
else:
    print("use_cached")
    # Reuse the [images, masks] arrays cached when the dataset was loaded.
    val_data = dataset['data']['val']
    test_data = dataset['data']['test']

In [None]:
class Experiment(object):
    """One FCN (optionally + CRF-RNN) segmentation model and its training state.

    Typical use in this notebook: ``create_model`` -> ``load_model`` ->
    ``compile_model`` -> assign ``callbacks`` -> ``train``.
    ``train`` reads the module-level ``dataset``, ``data_source`` and ``batching_gen``.
    """

    def __init__(self):
        self.batch_size = 2
        self.model = None           # single-device Keras model, built by create_model()
        self.model_parallel = None  # multi-GPU wrapper, set by compile_model(parallel=True)
        self.has_crf = False
        self.callbacks = []         # Keras callbacks handed to fit / fit_generator
        self.history = None         # History returned by the last completed train() call

    def get_train_model(self):
        """Return the model to compile and fit: the multi-GPU wrapper if any, else ``self.model``."""
        return self.model_parallel or self.model

    def plot_test(self, gen, output_fn=np.argmax):
        """Plot the first (rgb, target) pair from ``gen`` next to the model's prediction.

        Panels: input image, ground truth for channel 1, predicted channel 0, and the
        predicted argmax label map. ``rgb`` is assumed to be one un-batched image scaled
        to [-1, 1] as produced by ``preprocess`` (TODO confirm for other generators).
        ``output_fn(target, axis=-1)`` turns the target into the label map used for the
        IoU in the title. Does nothing if ``gen`` is empty.
        """
        sample = next(iter(gen), None)
        if sample is None:
            return
        rgb, target = sample
        print("rgb.shape", rgb.shape)
        print("rgb min/max", np.min(rgb), np.max(rgb))
        print("target.shape", target.shape)
        print("target min/max", np.min(target), np.max(target))

        # Undo the [-1, 1] scaling; clip so small overshoots don't wrap around in uint8.
        rgb8 = np.clip((rgb + 1) / 2 * 255, 0, 255).astype(np.uint8)
        y_pred = self.model.predict(rgb[np.newaxis, ...])
        pred_labels = np.argmax(y_pred[0], axis=-1)
        plt.figure()
        plt.subplot(2, 2, 1)
        plt.imshow(rgb8)
        plt.subplot(2, 2, 2)
        plt.imshow(label2rgb(target[..., 1], rgb8, bg_label=-1))
        plt.subplot(2, 2, 3)
        plt.imshow(y_pred[0, ..., 0])
        plt.subplot(2, 2, 4)
        plt.imshow(label2rgb(pred_labels, rgb8, bg_label=0))
        plt.title("{} IoU={:.2f}".format(
            "CRF" if self.has_crf else "FCN",
            jaccard_index(output_fn(target, axis=-1), pred_labels)))
        plt.tight_layout()

    def init_crf(self, layer=None):
        """Initialise CRF-RNN weights: identity spatial/bilateral kernels, Potts compatibility.

        ``layer`` defaults to the model's ``'crfrnn'`` layer. Matrix sizes are taken from
        the layer's current weights, so any number of classes is supported.
        """
        if layer is None:
            layer = self.model.get_layer(name='crfrnn')
        # Assumes get_weights() is ordered [spatial_ker_weights, bilateral_ker_weights,
        # compatibility_matrix], each square over the classes -- confirm in crfrnn.
        num_labels = layer.get_weights()[0].shape[0]
        identity = np.eye(num_labels)
        layer.set_weights([identity, identity.copy(), 1 - identity])

    def create_model(self, num_classes, image_dims, upsample='new', num_iterations=0):
        '''Build a fresh FCN model, optionally with a CRF-RNN head.

        num_classes: the network is built with num_classes + 1 outputs.
        image_dims: input shape, e.g. (H, W, 3).
        upsample: 'bilinear' or 'train' initialise the 'score2', 'score4' and 'upsample'
            layers with initialize_conv_transpose2d (trainable); any other value
            (e.g. 'new') keeps their default initialisation.
        num_iterations: CRF-RNN iterations; 0 builds the model without a CRF layer.
        '''
        self.has_crf = num_iterations > 0
        # Drop references to any previous model before clearing the Keras session.
        self.model = None
        self.model_parallel = None
        K.clear_session()
        print("Making model with {:d} classes and {} input shape".format(num_classes, str(image_dims)))
        self.model = get_crfrnn_model_def(
            num_classes=(num_classes + 1), input_shape=image_dims,
            num_iterations=num_iterations, with_crf=self.has_crf)
        if upsample in ['bilinear', 'train']:
            initialize_conv_transpose2d(
                self.model,
                ['score2', 'score4', 'upsample'],
                trainable=True)
        if self.has_crf:
            self.init_crf(self.model.get_layer(name='crfrnn'))

    def load_model(self, model_path):
        """Load weights by layer name from ``model_path`` (skipped if it is falsy).

        ``by_name=True`` lets a partial file such as the ImageNet VGG16 'notop' weights
        fill only the layers whose names match.
        NOTE(review): with a CRF, its weights are re-initialised after loading, which
        discards any CRF weights stored in ``model_path``.
        """
        if model_path:
            self.model.load_weights(model_path, by_name=True)
        if self.has_crf:
            self.init_crf(self.model.get_layer(name='crfrnn'))

    def compile_model(self, train_layers=None, parallel=False, gpus=2):
        """Optionally freeze layers, optionally wrap for multi-GPU, then compile.

        train_layers: if given, exactly the layers whose names are listed stay trainable.
        parallel: wrap ``self.model`` with keras' multi_gpu_model over ``gpus`` GPUs;
            otherwise any wrapper left from an earlier call is dropped.
        """
        if train_layers:
            for layer in self.model.layers:
                layer.trainable = layer.name in train_layers
        if parallel:
            from keras.utils import multi_gpu_model
            self.model_parallel = multi_gpu_model(self.model, gpus=gpus)
        else:
            # Avoid compiling/training a stale wrapper from a previous parallel compile.
            self.model_parallel = None
        self.get_train_model().compile(optimizer='nadam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])

    def train(self, epochs, initial_epoch=0, val_data=None):
        """Fit the training split and return the Keras History.

        epochs: epoch index to stop at (Keras semantics), not a count from initial_epoch.
        val_data: optional static (x, y) validation data for generator training; it is
            ignored when the training split is cached (the cached val split is used).
        """
        steps_per_epoch = max(1, len(dataset['ids']['train']) // self.batch_size)
        # Generator validation is capped at roughly 100 samples per epoch.
        steps_per_epoch_val = max(1, 100 // self.batch_size)
        print("Steps per epoch:", steps_per_epoch)
        print("Steps per steps_per_epoch_val:", steps_per_epoch_val)

        source = data_source('train')
        common = {
            "class_weight": dataset['class_weights']['train'],
            "callbacks": self.callbacks,
            "epochs": epochs,
            "verbose": 1,
            "initial_epoch": initial_epoch,
        }
        train_model = self.get_train_model()
        if isinstance(source, list):  # cached [images, masks] arrays
            print("Training cached data")
            self.history = train_model.fit(
                x=source[0], y=source[1],
                batch_size=self.batch_size,
                validation_data=tuple(data_source('val')),
                shuffle=True,
                **common)
        elif val_data is not None:
            # Generator training with static validation data.
            # NOTE(review): workers > 1 with a plain Python generator requires it to be
            # thread-safe -- confirm batching_gen is.
            self.history = train_model.fit_generator(
                batching_gen(source, batch_size=self.batch_size),
                steps_per_epoch=steps_per_epoch,
                validation_steps=steps_per_epoch_val,
                validation_data=val_data,
                workers=10,
                **common)
        else:
            # Generator training with generator validation.
            self.history = train_model.fit_generator(
                batching_gen(source, batch_size=self.batch_size),
                validation_data=batching_gen(
                    data_source('val'), batch_size=self.batch_size),
                steps_per_epoch=steps_per_epoch,
                validation_steps=steps_per_epoch_val,
                workers=10,
                **common)
        return self.history

# Training

## First train only the FCN head (layers from `fc6` onward)

In [None]:

# Paths for this run. "~" must be expanded explicitly: os.path functions and h5py-based
# weight loading do not expand it themselves.
imagenet_weights = os.path.expanduser(
    "~/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5")
# Random 4-digit run id so repeated runs get separate TensorBoard / checkpoint dirs
# (collisions are possible; a timestamp would be safer).
logdir = os.path.join(
    "/data/log/fcn-crfrnn/test-swc", "{:04d}".format(np.random.randint(0, 9999)))
os.makedirs(os.path.join(logdir, "models"), exist_ok=True)
best_path = os.path.join(logdir, "models/best.{epoch:03d}-{val_loss:.4f}.h5")

In [None]:
# Release the previous experiment (if any) before clearing the Keras session.
exp = None
K.clear_session()
exp = Experiment()
exp.create_model(
    len(dataset['classes']), image_dims,
    upsample='train',
    num_iterations=0)
exp.batch_size = 4
exp.load_model(imagenet_weights)

# Stage 1: train only the FCN head (every layer from 'fc6' onward); the layers before
# it keep their loaded ImageNet weights frozen.
layer_names = [layer.name for layer in exp.model.layers]
train_layers = layer_names[layer_names.index('fc6'):]
exp.compile_model(train_layers=train_layers, parallel=parallel)

early_stopping = EarlyStopping(
    monitor='val_loss', min_delta=0.0, patience=10, verbose=1, mode='auto')
exp.callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, cooldown=5, verbose=1),
    ModelCheckpoint(
        best_path, monitor='val_loss', verbose=1,
        save_best_only=True, save_weights_only=True, mode='auto', period=1),
    TensorBoard(
        log_dir=logdir, histogram_freq=0, batch_size=8,
        write_graph=False, write_grads=False, write_images=False),
    early_stopping,
]
# Start from a conservative learning rate and watch for a loss explosion.
K.set_value(exp.get_train_model().optimizer.lr, 1e-4)
try:
    exp.train(200, val_data=val_data)
except KeyboardInterrupt:
    pass  # stopping stage 1 by hand is expected; continue with stage 2

# Resume stage 2 right after the last epoch that actually finished. Keras' fit attaches
# a History callback as `model.history` and appends every completed epoch, so this also
# works after an interrupt or an EarlyStopping stop.
# NOTE(review): relies on that Keras 2 behaviour -- confirm for your Keras version.
_stage1_history = getattr(exp.get_train_model(), 'history', None)
_completed_epochs = getattr(_stage1_history, 'epoch', None) or []
initial_epoch = _completed_epochs[-1] + 1 if _completed_epochs else 0
initial_lr = K.eval(exp.get_train_model().optimizer.lr)
exp.model.save_weights('/data/tmp/blah_weights.h5')

## Save weights, reload and train all layers

In [None]:
# Stage 2: rebuild the model, reload the stage-1 weights and fine-tune all layers.
exp = None  # release the stage-1 model before clearing the Keras session
K.clear_session()
exp = Experiment()
exp.create_model(
    len(dataset['classes']), image_dims,
    upsample='train',
    num_iterations=0)
exp.batch_size = 4
exp.model.load_weights('/data/tmp/blah_weights.h5')

# train_layers=None leaves every layer's `trainable` flag as create_model built it.
exp.compile_model(train_layers=None, parallel=parallel)

exp.callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, cooldown=10, verbose=1),
    ModelCheckpoint(
        best_path, monitor='val_loss', verbose=1,
        save_best_only=True, save_weights_only=True, mode='auto', period=1),
    TensorBoard(
        log_dir=logdir, histogram_freq=5, batch_size=8,
        write_graph=False, write_grads=False, write_images=False),
    EarlyStopping(
        monitor='val_loss', min_delta=0.0, patience=20, verbose=1, mode='auto'),
]

# Continue from the learning rate stage 1 ended on.
K.set_value(exp.get_train_model().optimizer.lr, initial_lr)
try:
    # Keras treats `epochs` as the epoch index to stop at, so offset the stage-2
    # budget (200 epochs) by the resume point.
    exp.train(initial_epoch + 200, val_data=val_data, initial_epoch=initial_epoch)
except KeyboardInterrupt:
    pass

In [None]:
# Visual sanity check on one test sample.
exp.plot_test(gen=dataset['gens']['test'])

In [None]:
from sklearn.metrics import average_precision_score
# Micro-averaged AP over all foreground pixel/class pairs (channel 0 excluded).
y_true = test_data[1]
y_pred = exp.model.predict(x=test_data[0], batch_size=exp.batch_size)
# Both tensors are flattened to (pixels, channels - 1); a channel-count mismatch
# would silently regroup the predictions, so fail loudly instead.
assert y_pred.shape[-1] == y_true.shape[-1], (
    "model outputs {} channels but labels have {}".format(y_pred.shape[-1], y_true.shape[-1]))
num_fg_channels = y_true.shape[-1] - 1
ap_micro = average_precision_score(
    y_true[..., 1:].reshape((-1, num_fg_channels)),
    y_pred[..., 1:].reshape((-1, num_fg_channels)),
    average='micro')
print("Micro average precision of FG classes is {:.4f}".format(ap_micro))

# Inspect Trained Parameters

## FCN Upsample Weights

In [None]:
# Look at upsampling weights: one figure per layer showing the 2x2 block of kernel
# slices kernel[:, :, i, j] for i, j in {0, 1}.
layer_names = ['score2', 'score4', 'upsample']
for name in layer_names:
    kernel = exp.model.get_layer(name=name).get_weights()[0]
    print(name, kernel.shape)
    plt.figure()
    plt.suptitle(name)
    for i in range(2):
        for j in range(2):
            plt.subplot(2, 2, 2 * i + j + 1)
            plt.imshow(kernel[:, :, i, j])
            plt.title("[:, :, {}, {}]".format(i, j))

## CRF Parameters

In [None]:
# The 'crfrnn' layer only exists when the model was built with a CRF (has_crf).
# Look it up on the single-device model; a multi-GPU wrapper presumably does not
# expose inner layers by name -- confirm.
(exp.model.get_layer(name='crfrnn').get_weights()
 if exp.has_crf
 else "No 'crfrnn' layer: the model was built without a CRF (num_iterations=0)")