In [None]:
%matplotlib notebook
import os
import sys
from pprint import pprint
# Restrict this process to the first GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Prepend the CUDA libraries to the loader path. LD_LIBRARY_PATH may be unset
# (e.g. a fresh kernel started outside a login shell), so read it with a default
# instead of indexing, which would raise KeyError.
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(
    part for part in ('/usr/local/cuda/lib64', os.environ.get('LD_LIBRARY_PATH', '')) if part)
from abyss_deep_learning.utils import config_gpu

from keras.utils import to_categorical
from skimage.io import imread
from skimage.transform import resize
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import keras.backend as K
from skimage.color import label2rgb
from keras import Model
from keras.optimizers import Nadam
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from keras.applications.resnet50 import preprocess_input

from abyss_deep_learning.datasets.base import DatasetTaskBase
from abyss_deep_learning.keras.tensorboard import ImprovedTensorBoard
from abyss_deep_learning.datasets.coco import CocoInterface, CocoDataset, ImageDatatype

#augmentation_gen, jaccard_index
from abyss_deep_learning.keras.utils import initialize_conv_transpose2d, lambda_gen

from crfrnn.crfrnn_model import get_crfrnn_model_def
# NOTE(review): CUDA_VISIBLE_DEVICES is set to '0' earlier in this cell while [1] is
# passed here — confirm against config_gpu's contract that this is the intended argument.
config_gpu([1])

In [None]:
#### If you don't have the imagenet weights below, instantiating VGG16 will auto download them.
# os.path.exists() does not expand "~" itself, so expand it explicitly; otherwise the
# check is always False and VGG16 is built on every run.
if not os.path.exists(os.path.expanduser(
        "~/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5")):
    from keras.applications.vgg16 import VGG16
    vgg = VGG16(include_top=False)
    del vgg

# Setup Data

In [None]:
import random
def _noop(*args):
    return args if len(args) > 1 else args[0]

In [None]:
def _pack_masks(masks, mask_classes, num_classes, dtype=np.uint8):
    '''Pack a list of instance masks into a categorical mask.
    Expects masks to be shape [height, width, num_instances] and mask_classes to be [num_instances].'''
    num_shapes = len(mask_classes)
    shape = masks.shape
    packed = np.zeros(shape[0:2] + (num_classes,), dtype=dtype)
    packed[..., 0] = 1
    for i in range(num_shapes):
        class_id = mask_classes[i]
        mask = masks[..., i]
        packed[..., class_id] |= mask
        packed[..., 0] &= ~mask
    return packed

import imgaug.augmenters as iaa
def augmentation_gen(gen, common_aug, image_aug, enable=True):
    '''
    Data augmentation for segmentation task.

    A common augmentation list is applied to both images and masks and should not contain
    colour augmentation, and should ensure order=0 is used for all geometric transforms.
    The same deterministic copy of the common sequence is used for an image and its target,
    so both receive identical geometric transforms.

    An image augmentation list (image_aug) is accepted for image-only, non-geometric
    augmentation, but it is currently NOT applied: only the common augmentation runs.

    Args:
        gen: iterable yielding (image, target) pairs.
        common_aug: augmenters passed to iaa.Sequential, applied to image and target.
        image_aug: image-only augmenters (currently unused, see above).
        enable: when False, pairs from gen are passed through untouched.

    Yields:
        (image, target) pairs, augmented when enable is True.
    '''
    if not enable:
        # Pass through and stop when `gen` is exhausted. Looping `while True` around
        # `yield from gen` would spin forever on a finished generator.
        yield from gen
        return
    common_seq = iaa.Sequential(common_aug)
    for image, target in gen:
        # Freeze the random parameters so image and target get the same transform.
        common_seq_det = common_seq.to_deterministic()
        image_c = common_seq_det.augment_image(image)
        masks_c = common_seq_det.augment_image(target)
        yield image_c, masks_c
       

In [None]:
from abyss_deep_learning.datasets.translators import AnnotationTranslator
import concurrent.futures

class SemanticSegmentationTask(CocoInterface, DatasetTaskBase):
    '''Semantic segmentation targets on top of a COCO dataset.

    Targets are categorical masks of shape [height, width, num_classes] where channel 0
    is the background and the remaining channels are indexed by category_id.

    Keyword arguments read here (all kwargs are also forwarded to CocoInterface):
        preprocess_targets: callable applied to every packed mask (default: identity).
        cached: when truthy, all targets are loaded up front into memory.
    '''

    def __init__(self, coco, translator=None, **kwargs):
        CocoInterface.__init__(self, coco, **kwargs)
        assert isinstance(translator, (AnnotationTranslator, type(None)))
        self.translator = translator or AnnotationTranslator()
        # One channel per COCO category plus the background channel.
        self.num_classes = len(self.coco.cats) + 1
        self.stats = dict()
        self._targets = dict()

        self._preprocess_targets = kwargs.get('preprocess_targets', _noop)

        if kwargs.get('cached', False):
            # Load all targets in worker processes and keep the results in memory.
            with concurrent.futures.ProcessPoolExecutor() as executor:
                for data_id, targets in zip(
                        self.data_ids, executor.map(self.load_targets, self.data_ids)):
                    self._targets[data_id] = targets

        self._calc_class_stats()

    def load_targets(self, data_id, **kwargs):
        '''Return the (preprocessed) categorical mask for a single image ID.

        Served from the in-memory cache when present. Images without any annotation
        accepted by the translator produce an all-background mask.
        '''
        assert np.issubdtype(type(data_id), np.integer), "Must pass exactly one ID"
        if data_id in self._targets:
            return self._targets[data_id]
        img = self.coco.loadImgs(ids=[data_id])[0]
        anns = [self.translator.translate(ann) for ann in self.coco.loadAnns(
            self.coco.getAnnIds([data_id])) if self.translator.filter(ann)]
        if anns:
            # annToMask gives one [height, width] mask per annotation; stack on the last axis.
            masks = np.array([self.coco.annToMask(ann) for ann in anns]).transpose((1, 2, 0))
            class_ids = np.array([ann['category_id'] for ann in anns])
            # NOTE(review): self.dtype_image is not set in this class — presumably provided
            # by the image datatype mixin; confirm.
            return self._preprocess_targets(
                _pack_masks(masks, class_ids, self.num_classes, dtype=self.dtype_image))
        masks = np.zeros((img['height'], img['width'], self.num_classes), dtype=self.dtype_image)
        masks[..., 0] = 1
        return self._preprocess_targets(masks)

    def _calc_class_stats(self):
        '''Compute per-class pixel counts and the derived class weights (once).

        Weights are inverse pixel frequencies, normalised so the most frequent class
        has weight 1. Classes that never occur get weight 0 rather than inf.
        '''
        if self.stats:
            return
        # Use this dataset's own class count rather than a notebook-level global.
        pixel_counts = np.zeros(self.num_classes, dtype=np.float64)
        for data_id in self.data_ids:
            labels = self.load_targets(data_id).argmax(-1)
            counts = np.bincount(labels.ravel(), minlength=self.num_classes)
            pixel_counts += counts[:self.num_classes]

        present = pixel_counts > 0
        class_weights = np.zeros(self.num_classes, dtype=np.float64)
        class_weights[present] = pixel_counts[present] ** -1.0
        if present.any():
            class_weights /= class_weights[present].min()
        self.stats = {'class_weights': class_weights}

    @property
    def class_weights(self):
        '''Returns the class weights that will balance the backprop update over the class distribution.'''
        return self.stats['class_weights']

    def print_class_stats(self):
        '''Prints statistics about the class/image distribution.'''
        self._calc_class_stats()
        print("class weights:")
        print(" ", self.class_weights)

In [None]:
import itertools
class ImageSemanticSegmentationDataset(CocoDataset, ImageDatatype, SemanticSegmentationTask):
    '''COCO-backed dataset yielding (image, categorical mask) pairs for semantic segmentation.'''
    # TODO:
    #   *  Class statistics readout
    #   *  Support for computing class weights given current dataset config
    #   *  Support for forcing class balance by selecting IDs evenly
    #   *  Generator data order optimization
    #   *  Support for visualising data sample or prediction with same format
    def __init__(self, json_path, **kwargs):
        CocoDataset.__init__(self, json_path, **kwargs)
        ImageDatatype.__init__(self, self.coco, **kwargs)
        SemanticSegmentationTask.__init__(self, self.coco, **kwargs)

    def sample(self, image_id=None, **kwargs):
        '''Return (data, targets) for image_id, or for a random image when image_id is None.'''
        # Compare against None explicitly so that an image ID of 0 is honoured.
        if image_id is None:
            image_id = random.choice(self.data_ids)
        return (self.load_data(image_id, **kwargs), self.load_targets(image_id, **kwargs))

    def generator(self, data_ids=None, shuffle_ids=False, endless=False, **kwargs):
        '''Yield (data, targets) pairs.

        Args:
            data_ids: IDs to iterate over; None or empty means every ID in the dataset.
            shuffle_ids: shuffle the ID order once before iterating. With endless=True the
                same shuffled order is repeated on every pass.
            endless: cycle over the IDs forever instead of stopping after one pass.
        '''
        # Work on a private copy so shuffling never reorders the caller's list.
        data_ids = list(data_ids) if data_ids is not None else []
        if not data_ids:
            data_ids = list(self.data_ids)
        if shuffle_ids:
            random.shuffle(data_ids)
        iterator = itertools.cycle if endless else iter
        for data_id in iterator(data_ids):
            yield self.load_data(data_id, **kwargs), self.load_targets(data_id, **kwargs)

# Begin notebook

In [None]:
from skimage.transform import resize
from keras.applications.resnet50 import preprocess_input

def preprocess_data(image):
    '''Prepare an image for the network (and for optional caching): resize, cast, then scale.'''
    resized = resize(image, ARGS['image_dims'], preserve_range=True, mode='constant')
    as_network_dtype = resized.astype(ARGS['nn_dtype'])
    return preprocess_input(as_network_dtype, mode='tf')

def preprocess_targets(image):
    '''Prepare a mask for the network (and for optional caching): resize to the network's spatial size, then cast.'''
    spatial_dims = ARGS['image_dims'][0:2]
    # NOTE(review): resize is called with its default interpolation order, so a one-hot
    # mask may come out non-binary along class boundaries — confirm this is intended.
    resized = resize(image, spatial_dims, preserve_range=True, mode='constant')
    return resized.astype(ARGS['nn_dtype'])

def postprocess_data(image):
    '''Inverse transform of preprocess_data, used when trying to visualize images out of the dataset.

    Shifts and scales the input by (x + 1) * 127.5 and truncates to uint8.
    '''
    rescaled = (image + 1) * 127.5
    return rescaled.astype(np.uint8)


from abyss_deep_learning.datasets.translators import AnnotationTranslator
class AnnotationMapper(AnnotationTranslator):
    '''Transform COCO JSON annotations in any way you want. This one maps source to dest classes.

    Args:
        class_map: mapping from source category_id to destination category_id.
            When None or empty, category IDs are left unchanged.
    '''
    def __init__(self, class_map=None):
        self.class_map = class_map

    def filter(self, annotation):
        '''Keep only polygon annotations that carry a segmentation.'''
        # .get() so annotations without an 'annotation_type' are rejected instead of raising.
        return 'segmentation' in annotation and annotation.get('annotation_type') in ['poly']

    def translate(self, annotation):
        '''Return a copy of the annotation with its category_id remapped through class_map.'''
        output = dict(annotation)
        if self.class_map:
            output['category_id'] = self.class_map[annotation['category_id']]
        # Return the remapped copy; returning the input would silently drop the mapping.
        return output
        
def setup_args():
    '''Assemble the notebook-wide configuration dictionary (assigned to ARGS below).'''
    from bidict import bidict
    from imgaug import augmenters as iaa
    from imgaug.parameters import Normal

    def pipeline(gen, aug_config=None):
        '''The pipeline to run the dataset generator through.'''
        aug_config = aug_config or (None, None)
        # Augmentation is switched on only when a common augmenter is supplied.
        augment = aug_config[0] is not None
        return augmentation_gen(gen, *aug_config, enable=augment)

    # Source -> destination category_id mapping (or give a bidict with a different mapping).
    class_map = bidict({0: 0, 1: 1, 2: 2, 3: 3})

    # Geometric augmentation shared by image and mask.
    augmentation_common = iaa.Sequential([
        iaa.Fliplr(0.5),
        iaa.Flipud(0.5),
        iaa.Affine(
            scale=(0.8, 1.2),
            translate_percent=(-0.2, 0.2),
            rotate=(-22.5, 22.5),
            mode='constant', cval=0, order=0),
    ])

    # Colour augmentation: jitter each HSV channel, then convert back to RGB.
    augmentation_image = iaa.Sequential(
        [iaa.ChangeColorspace(from_colorspace="RGB", to_colorspace="HSV")]
        + [iaa.WithChannels(channel, iaa.Add(Normal(0, 256 / 6))) for channel in range(3)]
        + [iaa.ChangeColorspace(from_colorspace="HSV", to_colorspace="RGB")])

    data_config = {
        'base_dir': "/data/abyss/anadarko/label-sets",
        'name': "first",
        'sets': ('train', 'val', 'test'),
    }

    return {
        'annotation_translator': AnnotationMapper(class_map),
        'augmentation': (augmentation_common, augmentation_image),  # Training augmentation
        'class_map': class_map,
        'caption_type': 'multi',            # Caption type can be either "single" or "multi".
                                            # This sets up various other parameters in the system.
        'data': data_config,
        'image_dims': (500, 500, 3),        # What to resize images to before CNN
        'nn_dtype': np.float32,             # Pretrained networks are in float32
        'num_classes': len(class_map),
        'use_balanced_set': False,          # Force the use of the largest class-balanced dataset
        'use_cached': True,                 # Cache the dataset in memory
        'use_class_weights': True,          # Use class population to weight in the training loss
        'use_parallel': False,              # Use multiple GPUs
        'preprocess_data': preprocess_data,
        'preprocess_targets': preprocess_targets,
        'postprocess_data': postprocess_data,
        'pipeline': pipeline,
    }
ARGS = setup_args()

In [None]:
def setup_datasets(args):
    '''Load one ImageSemanticSegmentationDataset per configured split.

    The JSON for each split is read from <base_dir>/<name>/<set_name>.json, and class
    statistics are printed for every split as it loads.

    Args:
        args: configuration dict as produced by setup_args().

    Returns:
        dict mapping split name to its dataset.
    '''
    data_config = args['data']
    dataset = dict()
    for set_name in data_config['sets']:
        path = os.path.join(
            data_config['base_dir'], "{:s}/{:s}.json".format(data_config['name'], set_name))
        dataset[set_name] = ImageSemanticSegmentationDataset(
            path,
            translator=args['annotation_translator'],
            cached=args['use_cached'],
            preprocess_data=args['preprocess_data'],
            preprocess_targets=args['preprocess_targets'])
        print("\n", set_name)
        dataset[set_name].print_class_stats()

    print("\nNumber of classes:", args['num_classes'])
    print("captions:")
    print(args['class_map'])
    return dataset
DATASET = setup_datasets(ARGS)

In [None]:
def view_dataset_samples(num_rows=2):
    '''Plot num_rows samples from every dataset split, one column per split.

    Each panel shows the label mask overlaid on the de-normalised image. A summary
    of the last sample of each split is printed.
    '''
    # One column per split instead of a hard-coded 3.
    num_cols = len(DATASET)
    plt.figure()
    print("Column-wise left to right, bottom row:")
    for col, (name, ds) in enumerate(DATASET.items()):
        print(name, end=' ')
        image = label = None
        samples = ARGS['pipeline'](ds.generator(shuffle_ids=True))
        for row, (image, label) in enumerate(itertools.islice(samples, num_rows)):
            plt.subplot(num_rows, num_cols, num_cols * row + col + 1)
            plt.imshow(
                label2rgb(label.argmax(-1), ARGS['postprocess_data'](image), bg_label=0))
            plt.axis('off')
        if image is None:
            # An empty split would otherwise hit a NameError in the summary below.
            print('(no samples)')
            continue
        print('shape: {}, label: {}, min: {:.1f}, mean: {:.1f}, max: {:.1f}'.format(
            image.shape, label.shape, image.min(), image.mean(), image.max()))

view_dataset_samples(num_rows=2)

# Experiment Setup

In [None]:
class Experiment(object):
    '''Holds one FCN / CRF-RNN model together with its training configuration.

    Typical use: create_model() -> load_model() -> compile_model() -> set .callbacks -> train().
    '''

    def __init__(self):
        self.batch_size = 1
        self.model = None
        self.model_parallel = None
        self.has_crf = False
        # Initialised here so train() works even when no callbacks were assigned.
        self.callbacks = []
        self.history = None

    def get_train_model(self):
        '''Return the multi-GPU wrapper when one exists, otherwise the plain model.'''
        return self.model_parallel or self.model

    def plot_test(self, output_fn=np.argmax):
        '''Plot one random test sample: the image, the ground truth overlay and the prediction overlay.

        output_fn is accepted for interface compatibility and is currently unused.
        '''
        sample = next(DATASET['test'].generator(shuffle_ids=True), None)
        if sample is None:
            return
        rgb, target = sample
        print("rgb.shape", rgb.shape)
        print("rgb min/max", np.min(rgb), np.max(rgb))
        print("target.shape", target.shape)
        print("target min/max", np.min(target), np.max(target))

        # Map the network input back to 8-bit for display.
        rgb8 = ((rgb + 1) / 2 * 255).astype(np.uint8)
        Y_pred = self.model.predict(rgb[np.newaxis, ...])
        plt.figure()
        plt.subplot(2, 2, 1)
        plt.imshow(rgb8)
        plt.subplot(2, 2, 2)
        plt.imshow(label2rgb(target.argmax(-1), rgb8, bg_label=0))
        plt.subplot(2, 2, 3)
        plt.imshow(label2rgb(Y_pred[0].argmax(-1), rgb8, bg_label=0))
        plt.tight_layout()

    def init_crf(self):
        '''Reset the 'crfrnn' layer: identity spatial and bilateral kernel weights, and a
        compatibility matrix of 1 - identity.'''
        crf_layer = self.model.get_layer(name='crfrnn')
        # Size the matrices from the layer's own weights instead of hard-coding 2 classes.
        # NOTE(review): assumes the first weight is square [num_classes, num_classes] — confirm.
        num_classes = crf_layer.get_weights()[0].shape[0]
        identity = np.eye(num_classes)
        crf_layer.set_weights([identity, identity, 1 - identity])

    def create_model(self, num_classes, image_dims, upsample='new', num_iterations=0):
        '''Build a fresh model, discarding any previous one and clearing the Keras session.

        Args:
            num_classes: number of output classes.
            image_dims: input shape passed to the model definition.
            upsample: 'bilinear' or 'train' (re)initialises the 'score2', 'score4' and
                'upsample' layers via initialize_conv_transpose2d with trainable=True;
                any other value (including the default 'new') leaves them as created.
            num_iterations: number of CRF iterations; a value > 0 enables the CRF layer.
        '''
        self.has_crf = num_iterations > 0
        self.model = None
        self.model_parallel = None
        K.clear_session()
        print("Making model with {:d} classes and {} input shape".format(num_classes, str(image_dims)))
        self.model = get_crfrnn_model_def(
            num_classes=num_classes, input_shape=image_dims,
            num_iterations=num_iterations, with_crf=self.has_crf)
        if upsample in ['bilinear', 'train']:
            initialize_conv_transpose2d(
                self.model,
                ['score2', 'score4', 'upsample'],
                trainable=True)
        if self.has_crf:
            self.init_crf()

    def load_model(self, model_path):
        '''Load weights by layer name from model_path (if given), then re-initialise the CRF layer.'''
        if model_path:
            self.model.load_weights(model_path, by_name=True)
        if self.has_crf:
            self.init_crf()

    def compile_model(self, train_layers=None, parallel=False):
        '''Set per-layer trainability and compile the training model.

        Args:
            train_layers: names of the layers to train; all other layers are frozen.
                When empty/None the layers' trainable flags are left unchanged.
            parallel: wrap the model with multi_gpu_model (2 GPUs) before compiling.
        '''
        if train_layers:
            for layer in self.model.layers:
                layer.trainable = layer.name in train_layers
        if parallel:
            from keras.utils import multi_gpu_model
            self.model_parallel = multi_gpu_model(self.model, gpus=2)
        self.get_train_model().compile(
            optimizer='nadam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])

    def train(self, epochs, initial_epoch=0, val_data=None):
        '''Fit the training model on the augmented, endlessly cycled training set.

        Args:
            epochs: final epoch index to train up to.
            initial_epoch: epoch index to resume from.
            val_data: (data, targets) validation arrays; falls back to the notebook-level
                VAL_DATA when not given.

        Returns:
            The Keras History object (also stored in self.history).
        '''
        from abyss_deep_learning.keras.utils import batching_gen

        # Honour the argument; previously the global was always used and val_data ignored.
        if val_data is None:
            val_data = VAL_DATA

        steps_per_epoch = len(DATASET['train'].data_ids) // self.batch_size
        steps_per_epoch_val = val_data[0].shape[0] // self.batch_size
        print("Steps per epoch:", steps_per_epoch)
        print("Steps per steps_per_epoch_val:", steps_per_epoch_val)

        train_gen = ARGS['pipeline'](DATASET['train'].generator(
            shuffle_ids=True, endless=True), aug_config=ARGS['augmentation'])
        common = {
            # NOTE(review): class_weight is given as an array here — confirm that the
            # installed Keras version actually applies it for multi-dimensional targets.
            "class_weight": DATASET['train'].class_weights,
            "callbacks": self.callbacks,
            "epochs": epochs,
            "verbose": 1,
            "initial_epoch": initial_epoch,
        }

        self.history = self.get_train_model().fit_generator(
            batching_gen(train_gen, batch_size=self.batch_size),
            validation_data=val_data,
            steps_per_epoch=steps_per_epoch,
            validation_steps=steps_per_epoch_val,
            workers=10,
            **common)
        return self.history

# Training

## Load val data

In [None]:
def dump_dataset(dataset, num_data, aug_config=None):
    '''Materialise up to num_data samples of a dataset into two dense arrays.

    Args:
        dataset: object exposing generator() that yields (data, target) pairs.
        num_data: maximum number of samples to collect.
        aug_config: optional augmentation config forwarded to ARGS['pipeline'].

    Returns:
        (data, targets) arrays of dtype ARGS['nn_dtype']. Their first dimension is the
        number of samples actually collected, which is smaller than num_data when the
        generator runs out early.
    '''
    image_dims = tuple(ARGS['image_dims'])
    data = np.empty((num_data,) + image_dims, dtype=ARGS['nn_dtype'])
    targets = np.empty(
        (num_data,) + image_dims[0:2] + (ARGS['num_classes'],), dtype=ARGS['nn_dtype'])
    samples = ARGS['pipeline'](dataset.generator(), aug_config)
    count = 0
    # islice stops after num_data items and also copes with num_data == 0.
    for datum, target in itertools.islice(samples, num_data):
        data[count], targets[count] = datum, target
        count += 1
    # Drop rows that were never filled; np.empty leaves them uninitialised.
    return data[:count], targets[:count]

VAL_DATA = dump_dataset(DATASET['val'], num_data=len(DATASET['val'].data_ids), aug_config=None)

## First train only Conv2D

In [None]:
# ImageNet VGG16 weights in the current user's Keras cache (the same location checked
# at the top of the notebook), rather than a path tied to one user's home directory.
imagenet_weights = os.path.expanduser(
    "~/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5")
# Each run logs into a randomly numbered directory.
logdir = os.path.join("/data/log/fcn-crfrnn/anadarko/first/{:04d}".format(np.random.randint(0, 9999)))
# Pure-Python equivalent of `mkdir -p`, so the cell does not depend on a shell.
os.makedirs(os.path.join(logdir, "models"), exist_ok=True)
# Checkpoint filename template; {epoch} and {val_loss} are filled in by ModelCheckpoint.
best_path = os.path.join(logdir, "models/best.{epoch:03d}-{val_loss:.4f}.h5")
print(logdir)

In [None]:
# Handle on the default TensorFlow graph, displayed for inspection.
# (The previous dangling `g.` was a syntax error that stopped Run All.)
g = K.tf.get_default_graph()
g

In [None]:
from abyss_deep_learning.keras.tensorboard import kernel_sparsity, avg_update_ratio

# Stage 1: FCN without the CRF layer (num_iterations=0), starting from ImageNet VGG16 weights.
exp = None
K.clear_session()
exp = Experiment()
exp.batch_size = 2
exp.create_model(
    ARGS['num_classes'], ARGS['image_dims'],
    upsample='train',
    num_iterations=0)
exp.load_model(imagenet_weights)

# Train only the layers from 'fc6' onwards; earlier layers are frozen by compile_model.
layers = [layer.name for layer in exp.model.layers]
train_layers = layers[layers.index('fc6'):]
predictions_kernel = exp.model.get_layer(name='upsample').trainable_weights[0] # Used in a scalar callback

exp.compile_model(train_layers=train_layers, parallel=ARGS['use_parallel'])

# Kept in a named variable so its stopped_epoch can be read back after training
# without relying on its position in the callback list.
early_stopping = EarlyStopping(
    monitor='val_loss', min_delta=0.0, patience=10, verbose=1, mode='auto')
exp.callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, cooldown=5, verbose=1),
    ModelCheckpoint(
        best_path, monitor='val_loss', verbose=1,
        save_best_only=True, save_weights_only=True, mode='auto', period=1),
    ImprovedTensorBoard(
        log_dir=logdir,
        histogram_freq=5, batch_size=exp.batch_size,
        scalars={
            'learning_rate': exp.get_train_model().optimizer.lr,
            'feature_sparsity': kernel_sparsity(exp.get_train_model()),
            'prediction_UW_ratio': avg_update_ratio(exp.get_train_model(), predictions_kernel)
        },
        groups={
            'performance': {
                'loss': ['loss', 'val_loss'],
                'accuracy': [r'.*accuracy.*'],
                'Mean Per-Class Average Accuracy': [r'.*mpca.*'],
                'Mean Avg Precision': [r'.*PR.*'],
                'ROC AUC': [r'.*ROC.*']
            }
        },
        pr_curve=False,
        # Targets are [N, height, width, classes]: the class count is the LAST axis
        # (axis 1 is the image height).
        num_classes=VAL_DATA[1].shape[-1],
    ),
    early_stopping,
]
# Test to see if LR causes loss explosion
K.set_value(exp.get_train_model().optimizer.lr, 5e-4)
try:
    exp.train(200, val_data=VAL_DATA)
except KeyboardInterrupt:
    # Allow a manual stop; the weights trained so far are still saved below.
    pass
# Save the weights and epoch for next training step
# NOTE(review): 16 is the fallback when early stopping did not trigger — confirm the intended value.
initial_epoch = early_stopping.stopped_epoch + 1 if early_stopping.stopped_epoch else 16
initial_lr = K.eval(exp.get_train_model().optimizer.lr)
exp.model.save_weights('/data/tmp/blah_weights.h5')
# raise RuntimeError("Stop Run All")

## Reload the saved weights and continue training with the CRF-RNN layer enabled

In [None]:
# Stage 2: rebuild the model with the CRF layer enabled (num_iterations=5) and continue
# from the weights saved by the previous stage.
exp = None
K.clear_session()
exp = Experiment()
exp.create_model(
    ARGS['num_classes'], ARGS['image_dims'],
    upsample='train',
    num_iterations=5)
exp.batch_size = 1
exp.load_model('/data/tmp/blah_weights.h5')

# Train only the layers from 'fc6' onwards; earlier layers are frozen by compile_model.
layers = [layer.name for layer in exp.model.layers]
train_layers = layers[layers.index('fc6'):]
predictions_kernel = exp.model.get_layer(name='upsample').trainable_weights[0] # Used in a scalar callback

exp.compile_model(train_layers=train_layers, parallel=ARGS['use_parallel'])

# Kept in a named variable so its stopped_epoch can be read back after training
# without relying on its position in the callback list.
early_stopping = EarlyStopping(
    monitor='val_loss', min_delta=0.0, patience=10, verbose=1, mode='auto')
exp.callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, cooldown=5, verbose=1),
    ModelCheckpoint(
        best_path, monitor='val_loss', verbose=1,
        save_best_only=True, save_weights_only=True, mode='auto', period=1),
    ImprovedTensorBoard(
        log_dir=logdir,
        histogram_freq=5, batch_size=exp.batch_size,
        scalars={
            'learning_rate': exp.get_train_model().optimizer.lr,
            'feature_sparsity': kernel_sparsity(exp.get_train_model()),
            'prediction_UW_ratio': avg_update_ratio(exp.get_train_model(), predictions_kernel)
        },
        groups={
            'performance': {
                'loss': ['loss', 'val_loss'],
                'accuracy': ['acc', 'val_acc'],
                'Mean Per-Class Average Accuracy': [r'.*mpca.*'],
                'Mean Avg Precision': [r'.*PR.*'],
                'ROC AUC': [r'.*ROC.*']
            }
        },
        pr_curve=False,
        # Targets are [N, height, width, classes]: the class count is the LAST axis
        # (axis 1 is the image height).
        num_classes=VAL_DATA[1].shape[-1],
    ),
    early_stopping,
]
# Test to see if LR causes loss explosion
# NOTE(review): this flag is changed AFTER compile_model(); Keras documents that trainable
# changes need a re-compile to take effect, so the last layer is presumably still being
# trained — confirm whether freezing it is actually intended.
exp.model.layers[-1].trainable = False
K.set_value(exp.get_train_model().optimizer.lr, 5e-4)
try:
    # NOTE(review): initial_epoch is hard-coded to 2 rather than using the value computed
    # at the end of the previous stage — confirm.
    exp.train(200, val_data=VAL_DATA, initial_epoch=2)
except KeyboardInterrupt:
    # Allow a manual stop; the weights trained so far are still saved below.
    pass
# Save the weights and epoch for next training step
initial_epoch = early_stopping.stopped_epoch + 1 if early_stopping.stopped_epoch else 16
initial_lr = K.eval(exp.get_train_model().optimizer.lr)
exp.model.save_weights('/data/tmp/blah_weights2.h5')
# raise RuntimeError("Stop Run All")

In [None]:
# Inspect the model's final layer together with its output shape at node 0.
a = exp.model.layers[-1]
a, a.get_output_shape_at(0)

In [None]:
# Sanity check: push one random test sample through the model and compare shapes.
img, label = DATASET['test'].sample()
print(img.shape, label.shape)
# The model expects a leading batch axis.
y_pred = exp.model.predict(np.expand_dims(img, axis=0))
print(y_pred.shape)

In [None]:
def plot_test(self, output_fn=np.argmax):
    '''Plot one random test sample: the image, the ground truth overlay and the prediction overlay.

    `self` is any object exposing a `.model` with `predict`; `output_fn` is unused.
    Does nothing when the test generator yields no sample.
    '''
    sample = next(DATASET['test'].generator(shuffle_ids=True), None)
    if sample is None:
        return
    rgb, target = sample
    print("rgb.shape", rgb.shape)
    print("rgb min/max", np.min(rgb), np.max(rgb))
    print("target.shape", target.shape)
    print("target min/max", np.min(target), np.max(target))

    # Map the network input back to 8-bit for display.
    rgb8 = ((rgb + 1) / 2 * 255).astype(np.uint8)
    Y_pred = self.model.predict(rgb[np.newaxis, ...])
    truth_overlay = label2rgb(target.argmax(-1), rgb8, bg_label=0)
    predicted_overlay = label2rgb(np.argmax(Y_pred[0], axis=-1), rgb8, bg_label=0)

    plt.figure()
    plt.subplot(2, 2, 1)
    plt.title("RGB")
    plt.imshow(rgb8)
    plt.subplot(2, 2, 2)
    plt.title("Ground Truth")
    plt.imshow(truth_overlay)
    # The prediction spans the whole bottom row.
    plt.subplot(2, 2, (3, 4))
    plt.imshow(predicted_overlay)
    plt.title("Predicted")
    plt.tight_layout()
plot_test(exp)
# exp.plot_test()

In [None]:
from sklearn.metrics import average_precision_score

# `test_data` was never defined in this notebook; build it the same way as VAL_DATA.
test_data = dump_dataset(DATASET['test'], num_data=len(DATASET['test'].data_ids), aug_config=None)
y_pred = exp.model.predict(x=test_data[0], batch_size=exp.batch_size)

# Score the foreground classes only (channel 0 is the background), flattening every
# pixel into one row of per-class indicators / scores.
num_foreground = test_data[1].shape[-1] - 1
# Threshold the targets so they are strict binary indicators even if resizing
# produced fractional values along class boundaries.
y_true_fg = (test_data[1][..., 1:] >= 0.5).reshape((-1, num_foreground))
y_score_fg = y_pred[..., 1:].reshape((-1, num_foreground))
ap_micro = average_precision_score(y_true_fg, y_score_fg, average='micro')
print("Micro average precision of FG classes is {:.4f}".format(ap_micro))

# Inspect Trained Parameters

## FCN Upsample Weights

In [None]:
# Look at upsampling weights
layer_names = ['score2', 'score4', 'upsample']
for name in layer_names:
    layer = exp.model.get_layer(name=name)
    v = layer.get_weights()[0]
    print(v.shape)
    plt.figure()
    plt.subplot(2, 2, 1)
    plt.imshow(v[:, :, 0, 0])
    plt.subplot(2, 2, 2)
    plt.imshow(v[:, :, 0, 1])
    plt.subplot(2, 2, 3)
    plt.imshow(v[:, :, 1, 0])
    plt.subplot(2, 2, 4)
    plt.imshow(v[:, :, 1, 1])

## CRF Parameters

In [None]:
# Display the current weights of the layer named 'crfrnn'.
# NOTE(review): presumably only valid when the model was built with the CRF enabled
# (create_model with num_iterations > 0) — confirm before running after the stage without CRF.
exp.get_train_model().get_layer(name='crfrnn').get_weights()