In [None]:
%matplotlib notebook
import os
import sys
from pprint import pprint

import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from abyss_deep_learning.base.datasets import DatasetTaskBase
from abyss_deep_learning.datasets.coco import (CocoDataset, CocoInterface,
                                               ImageDatatype)
from abyss_deep_learning.keras.tensorboard import ImprovedTensorBoard
#augmentation_gen, jaccard_index
from abyss_deep_learning.keras.utils import (initialize_conv_transpose2d,
                                             lambda_gen, tiling_gen, skip_empty_gen)
from abyss_deep_learning.utils import config_gpu, detile
from crfrnn.crfrnn_model import get_crfrnn_model_def
from keras import Model
from keras.applications.resnet50 import preprocess_input
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from keras.optimizers import Nadam
from keras.utils import to_categorical
from skimage.color import label2rgb
from skimage.io import imread, imsave
from skimage.transform import resize

config_gpu([0])

In [None]:
#### If you don't have the imagenet weights below it will auto download them
# os.path.exists() does not expand "~": the literal string "~/.keras/..." is
# treated as a relative path, so the original check was practically always
# False. Expand the home directory explicitly before testing for the file.
_vgg16_weights_path = os.path.expanduser(
    "~/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5")
if not os.path.exists(_vgg16_weights_path):
    from keras.applications.vgg16 import VGG16
    # Instantiating the model is only done for its side effect; the object
    # itself is discarded straight away.
    vgg = VGG16(include_top=False)
    del vgg

# Setup Data

In [None]:
import random
def _noop(*args):
    return args if len(args) > 1 else args[0]

# Begin notebook

In [None]:
from imgaug.parameters import Deterministic
import imgaug.augmenters as iaa
def augmentation_gen(gen, common_aug, image_aug, enable=True):
    '''
    Data augmentation for a segmentation task.

    `gen` yields (image, target) pairs. When `enable` is false the pairs are
    passed through untouched. Otherwise a deterministic copy of the common
    augmentation sequence is taken for every pair and applied to both the image
    and the target.

    The common augmentation list should not contain colour augmentation and
    should use order=0 for all geometric transforms.

    NOTE(review): `image_aug` is wrapped in a Sequential below, but that
    sequence is never applied to the image, so image-only (colour) augmentation
    is currently a no-op. Confirm whether this is intentional before enabling
    it: this generator appears to receive already-preprocessed images.
    '''
    if not enable:
        yield from gen
        return

    shared_seq = iaa.Sequential(common_aug)
    image_seq = iaa.Sequential(image_aug)  # constructed but unused, see NOTE above
    for image, target in gen:
        # One deterministic copy per sample, used for both image and target.
        frozen_seq = shared_seq.to_deterministic()
        yield frozen_seq.augment_image(image), frozen_seq.augment_image(target)

In [None]:
from skimage.transform import resize
from keras.applications.resnet50 import preprocess_input

def preprocess_data(image):
    '''Cast the image to ARGS['nn_dtype'] and pass it through `preprocess_input` in 'tf' mode.

    Applied to images before (possibly caching) and input to the network.
    Resizing to ARGS['image_dims'] is currently disabled.
    '''
    network_image = image.astype(ARGS['nn_dtype'])
    return preprocess_input(network_image, mode='tf')

def preprocess_targets(image):
    '''Cast the mask to ARGS['nn_dtype'] before (possibly caching) and input to the network.

    Resizing to ARGS['image_dims'] is currently disabled.
    '''
    network_dtype = ARGS['nn_dtype']
    return image.astype(network_dtype)

def postprocess_data(image):
    '''Map an image from the [-1, 1] range back to uint8, for visualising dataset samples.

    Computes (image + 1) * 127.5 and casts to uint8; the cast truncates the
    fractional part rather than rounding.
    '''
    rescaled = (image + 1) * 127.5
    return rescaled.astype(np.uint8)


from abyss_deep_learning.datasets.translators import AnnotationTranslator
class AnnotationMapper(AnnotationTranslator):
    '''Transform COCO JSON annotations: filter them and map source to destination classes.'''

    # Annotations must have an 'area' strictly greater than this to be kept.
    MIN_AREA = 30 ** 2

    def __init__(self, class_map=None):
        '''
        Args:
            class_map: optional mapping from source category_id to destination
                category_id. When falsy, category ids are left unchanged.
        '''
        self.class_map = class_map

    def filter(self, annotation):
        '''Return True for annotations that carry a 'segmentation' and are larger than MIN_AREA.

        (An earlier variant additionally required annotation_type == 'poly'.)
        '''
        return (
            'segmentation' in annotation
            and annotation['area'] > self.MIN_AREA)

    def translate(self, annotation):
        '''Return a shallow copy of `annotation` with its category_id remapped through `class_map`.

        The input dict is not modified. Raises KeyError if a class_map is set
        and the annotation's category_id is not a key of it.
        '''
        output = dict(annotation)
        if self.class_map:
            output['category_id'] = self.class_map[annotation['category_id']]
        # Return the remapped copy. Previously the untouched input was returned,
        # which silently discarded the class mapping.
        return output

def pipeline(gen, aug_config=None):
    '''Run a dataset generator through augmentation, 500x500 tiling and empty-tile skipping.

    `aug_config` is a (common_aug, image_aug) pair; augmentation is enabled only
    when its first element is not None. A falsy `aug_config` disables it.
    '''
    config = aug_config or (None, None)
    use_augmentation = config[0] is not None
    augmented = augmentation_gen(gen, *config, enable=use_augmentation)
    tiled = tiling_gen(augmented, window_size=(500, 500))
    return skip_empty_gen(tiled, min_area=1000)
        
def setup_args():
    '''Build the notebook-wide configuration dictionary.

    Returns a dict holding the annotation translator, the (common, image)
    augmentation pair, the class map, dataset location, image dimensions,
    network dtype, feature flags and the pre/post-processing callables.
    '''
    from bidict import bidict
    from imgaug import augmenters as iaa
    from imgaug.parameters import Normal

    # bidict mapping source -> dest category_id (further classes are disabled).
    class_map = bidict({0: 0, 1: 1})

    # Geometric augmentation, meant to be shared by image and mask (order=0).
    geometric_ops = [
        iaa.Fliplr(0.5),
        iaa.Flipud(0.5),
        iaa.Affine(
            scale=(0.8, 1.2),
            translate_percent=(-0.2, 0.2),
            rotate=(-22.5, 22.5),
            mode='constant', cval=0, order=0),
    ]
    augmentation_common = iaa.Sequential(geometric_ops)

    # Colour augmentation: additive Normal(0, 256 / 6) noise on each HSV channel.
    to_hsv = iaa.ChangeColorspace(from_colorspace="RGB", to_colorspace="HSV")
    channel_jitter = [
        iaa.WithChannels(channel, iaa.Add(Normal(0, 256 / 6)))
        for channel in (0, 1, 2)
    ]
    to_rgb = iaa.ChangeColorspace(from_colorspace="HSV", to_colorspace="RGB")
    augmentation_image = iaa.Sequential([to_hsv, *channel_jitter, to_rgb])

    data_location = {
        'base_dir': "/data/acfr/ladybird/labelbox/hashed",
        'name': "first",
        'sets': ('train', 'val', 'test'),
    }

    return {
        'annotation_translator': AnnotationMapper(class_map),
        'augmentation': (augmentation_common, augmentation_image),  # Training augmentation
        'class_map': class_map,
        'data': data_location,
        'image_dims': (500, 500, 3),    # What to resize images to before CNN
        'nn_dtype': np.float32,         # Pretrained networks are in float32
        'num_classes': len(class_map),
        'use_balanced_set': False,      # Force the use of the largest class-balanced dataset
        'use_cached': False,            # Cache the dataset in memory
        'use_class_weights': True,      # Use class population to weight in the training loss
        'use_parallel': False,          # Use multiple GPUs
        'preprocess_data': preprocess_data,
        'preprocess_targets': preprocess_targets,
        'postprocess_data': postprocess_data,
        'pipeline': pipeline,
    }
ARGS = setup_args()

In [None]:
def setup_datasets(args):
    '''Create one ImageSemanticSegmentationDataset per configured split.

    For each name in args['data']['sets'] the COCO JSON at
    <base_dir>/<name>/<set>.json is loaded and its class statistics printed.
    Returns a dict mapping split name -> dataset.
    '''
    from abyss_deep_learning.datasets.coco import ImageSemanticSegmentationDataset

    data_cfg = args['data']
    dataset = {}
    for set_name in data_cfg['sets']:
        json_path = os.path.join(
            data_cfg['base_dir'],
            "{:s}/{:s}.json".format(data_cfg['name'], set_name))
        split = ImageSemanticSegmentationDataset(
            json_path,
            translator=args['annotation_translator'],
            cached=args['use_cached'],
            preprocess_data=args['preprocess_data'],
            preprocess_targets=args['preprocess_targets'])
        dataset[set_name] = split
        print("\n", set_name)
        split.print_class_stats()

    print("\nNumber of classes:", args['num_classes'])
    print("captions:")
    print(args['class_map'])
    return dataset
DATASET = setup_datasets(ARGS)

In [None]:
from abyss_deep_learning.visualize import draw_semantic_seg
from collections import Counter
from skimage.morphology import label

def view_dataset_samples(num_rows=2):
    '''Plot `num_rows` pipeline samples from every split in DATASET, one column per split.

    Each sample is drawn as its semantic segmentation over the post-processed
    image. After each column, shape/min/mean/max of the last plotted image and
    label are printed.

    Args:
        num_rows: number of samples to draw per split.
    '''
    # One column per split instead of a hard-coded 3, so the grid stays valid
    # if the set of splits changes.
    num_cols = len(DATASET)
    plt.figure()
    print("Column-wise left to right, bottom row:")
    for col, (name, ds) in enumerate(DATASET.items()):
        print(name, end=' ')
        # Reset per split so an empty split cannot report the previous split's
        # sample (or raise NameError on the first split).
        image = mask = None
        samples = ARGS['pipeline'](ds.generator(shuffle_ids=True))
        for row, (image, mask) in enumerate(samples):
            plt.subplot(num_rows, num_cols, num_cols * row + col + 1)
            plt.imshow(draw_semantic_seg(mask, ARGS['postprocess_data'](image)))
            plt.axis('off')
            if row + 1 == num_rows:
                break
        if image is None:
            print('no samples produced')
            continue
        print('Image: shape: {}, min: {:.1f}, mean: {:.1f}, max: {:.1f}'.format(
            image.shape, image.min(), image.mean(), image.max()))
        print('Label: shape: {}, min: {:.1f}, mean: {:.1f}, max: {:.1f}'.format(
            mask.shape, mask.min(), mask.mean(), mask.max()))

view_dataset_samples(num_rows=3)

# Experiment Setup

In [None]:
class Experiment(object):
    '''Holds a CRF-RNN segmentation model together with its training state.

    Attributes:
        batch_size: batch size used for training and step computation.
        model: the single-device Keras model (None until create_model()).
        model_parallel: multi-GPU wrapper of `model`, or None.
        has_crf: whether the model was built with the CRF layer.
        callbacks: Keras callbacks passed to fit_generator by train().
        history: return value of the last fit_generator call, or None.
    '''

    def __init__(self):
        self.batch_size = 1
        # Initialise every attribute the methods read, so that calling them
        # before create_model()/assigning callbacks fails predictably rather
        # than with AttributeError.
        self.model = None
        self.model_parallel = None
        self.has_crf = False
        self.callbacks = []
        self.history = None

    def get_train_model(self):
        '''Return the multi-GPU model when one exists, otherwise the plain model.'''
        return self.model_parallel or self.model

    def plot_test(self, output_fn=np.argmax):
        '''Predict on one shuffled test sample and plot image, ground truth and prediction.

        `output_fn` is accepted for backward compatibility but is currently
        unused (it was only referenced by a disabled IoU title).
        '''
        gen = DATASET['test'].generator(shuffle_ids=True)
        for rgb, target in gen:
            print("rgb.shape", rgb.shape)
            print("rgb min/max", np.min(rgb), np.max(rgb))
            print("target.shape", target.shape)
            print("target min/max", np.min(target), np.max(target))

            # Map from [-1, 1] back to uint8 for display.
            rgb8 = ((rgb + 1) / 2 * 255).astype(np.uint8)
            Y_pred = self.model.predict(rgb[np.newaxis, ...])
            plt.figure()
            plt.subplot(2, 2, 1)
            plt.imshow(rgb8)
            plt.subplot(2, 2, 2)
            plt.imshow(label2rgb(target.argmax(-1), rgb8, bg_label=0))
            plt.subplot(2, 2, (3, 4))
            plt.imshow(label2rgb(Y_pred[0].argmax(-1), rgb8, bg_label=0))
            plt.tight_layout()
            break  # only the first sample is plotted

    def init_crf(self):
        '''Reset the 'crfrnn' layer weights to [I, I, 1 - I] for ARGS['num_classes'] classes.

        Presumably the three arrays are, in order, the spatial kernel weights,
        the bilateral kernel weights and the compatibility matrix -- TODO
        confirm against the crfrnn layer definition.
        '''
        n = ARGS['num_classes']
        self.model.get_layer(name='crfrnn').set_weights([
            np.eye(n), np.eye(n), 1 - np.eye(n)
        ])

    def create_model(self, num_classes, image_dims, upsample='new', num_iterations=0):
        '''Build a fresh model, discarding any previous one and clearing the Keras session.

        Args:
            num_classes: number of output classes.
            image_dims: input shape passed to the model definition.
            upsample: 'bilinear' or 'train' initialises the 'score2', 'score4'
                and 'upsample' transposed-convolution kernels via
                initialize_conv_transpose2d (trainable only for 'train');
                any other value, including the default 'new', leaves them as
                created by the model definition.
            num_iterations: number of CRF iterations; the CRF layer is enabled
                (and initialised) only when this is greater than zero.
        '''
        self.has_crf = num_iterations > 0
        self.model = None
        self.model_parallel = None
        K.clear_session()
        print("Making model with {:d} classes and {} input shape".format(num_classes, str(image_dims)))
        self.model = get_crfrnn_model_def(
            num_classes=num_classes, input_shape=image_dims,
            num_iterations=num_iterations, with_crf=self.has_crf)
        self.model.summary()
        if upsample in ['bilinear', 'train']:
            print("initializing conv tranpose kernels")
            initialize_conv_transpose2d(
                self.model,
                ['score2', 'score4', 'upsample'],
                trainable=(upsample == 'train'))
        if self.has_crf:
            print("initializing CRF")
            self.init_crf()

    def load_model(self, model_path):
        '''Load weights by layer name from `model_path` (if given), then re-initialise the CRF.

        NOTE(review): init_crf() runs after loading, so any CRF weights stored
        in the file are overwritten -- confirm this is intended.
        '''
        if model_path:
            self.model.load_weights(model_path, by_name=True)
        if self.has_crf:
            self.init_crf()

    def compile_model(self, train_layers=None, parallel=False):
        '''Set layer trainability, optionally wrap for multi-GPU, and compile.

        Args:
            train_layers: names of layers to train; every other layer is
                frozen. When falsy, trainability is left unchanged.
            parallel: when true, wrap the model with multi_gpu_model (2 GPUs)
                and compile that wrapper instead.
        '''
        if train_layers:
            for layer in self.model.layers:
                layer.trainable = layer.name in train_layers
        if parallel:
            from keras.utils import multi_gpu_model
            self.model_parallel = multi_gpu_model(self.model, gpus=2)
        self.get_train_model().compile(
            optimizer='nadam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])

    def train(self, epochs, initial_epoch=0, val_data=None):
        '''Fit the training model on the augmented training pipeline.

        Args:
            epochs: final epoch index passed to fit_generator.
            initial_epoch: epoch index to start from.
            val_data: optional (inputs, targets) arrays used for validation.

        Returns:
            The value returned by fit_generator (also stored in self.history).
        '''
        from abyss_deep_learning.keras.utils import batching_gen
        steps_per_epoch = len(DATASET['train'].data_ids) // self.batch_size
        # val_data defaults to None; only derive validation steps when
        # validation data was actually supplied (previously a TypeError).
        steps_per_epoch_val = None
        if val_data is not None:
            steps_per_epoch_val = val_data[0].shape[0] // self.batch_size // 10
        print("Steps per epoch:", steps_per_epoch)
        print("Steps per steps_per_epoch_val:", steps_per_epoch_val)

        train_gen = ARGS['pipeline'](
            DATASET['train'].generator(shuffle_ids=True, endless=True),
            aug_config=ARGS['augmentation'])
        common = {
            "class_weight": DATASET['train'].class_weights if ARGS['use_class_weights'] else None,
            "callbacks": self.callbacks,
            "epochs": epochs,
            "verbose": 1,
            "initial_epoch": initial_epoch,
        }

        # NOTE(review): a plain Python generator is combined with workers=4 and
        # use_multiprocessing=True -- check the Keras docs on whether this can
        # duplicate samples across workers.
        self.history = self.get_train_model().fit_generator(
            batching_gen(train_gen, batch_size=self.batch_size),
            validation_data=val_data,
            steps_per_epoch=steps_per_epoch,
            validation_steps=steps_per_epoch_val,
            workers=4,
            use_multiprocessing=True,
            **common)
        return self.history

# Training

## Load val data

In [None]:
def dump_dataset(dataset, max_data=None):
    '''Materialise a dataset, run through the (non-augmenting) pipeline, into two arrays.

    Every sample gets a leading batch axis and the samples are concatenated
    along it. At most `max_data` samples are taken when it is not None.
    Returns an (inputs, targets) tuple of arrays.
    '''
    samples = ARGS['pipeline'](dataset.generator(endless=False), None)
    inputs = []
    targets = []
    for index, (datum, target) in enumerate(samples):
        limit_reached = max_data is not None and index >= max_data
        if limit_reached:
            break
        inputs.append(datum[np.newaxis, ...])
        targets.append(target[np.newaxis, ...])
    return np.concatenate(inputs), np.concatenate(targets)
VAL_DATA = dump_dataset(DATASET['val'], max_data=None)

## First train only Conv2D

In [None]:
# [
#     [
#         ("scratch", 1.4e-2): (9.00, 9.00),
#         ("scratch", 1.6e-3): (0.55, 0.57),
#         ("scratch", 6.4e-4): (0.56, 0.54),
#         ("scratch", 1.5e-6): (1.11, 1.38),
#     ],
#     [
#         ("scratch prelu", 1.0e-3): (7.67, 9.37),
#         ("scratch prelu", 1.6e-3): (5.50, 5.9),
#         ("scratch prelu", 3.5e-4): (0.59, 0.57),
#         ("scratch prelu", 1.6e-3): (0.59, 0.55),
#         ("scratch prelu", 3.1e-4): (0.61, 0.59),
        
#     ]
# ]

In [None]:
def set_prelu_value(value):
    '''Overwrite the first weight array of every layer of `exp.model` whose name contains 'prelu'.

    The replacement has the shape of the existing first weight array and is
    filled with `value`.
    '''
    prelu_layers = (lyr for lyr in exp.model.layers if 'prelu' in lyr.name)
    for lyr in prelu_layers:
        current = lyr.get_weights()
        lyr.set_weights([value * np.ones_like(current[0])])

In [None]:
# Set to a weights file (e.g. the VGG16 "notop" ImageNet weights under
# /home/docker/.keras/models/) to start from pretrained weights; None trains
# from scratch.
imagenet_weights = None#"/home/docker/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5"
# A random 4-digit suffix keeps the logs of separate runs apart.
logdir = "/data/log/fcn-crfrnn/labelbox-seg/tile-scratch/crf_5-ups_train-aug-{:04d}".format(np.random.randint(0, 9999))
# Pure-Python replacement for `!mkdir -p "$logdir/models"`: no dependency on
# the IPython shell escape, so the cell also runs as a plain script.
os.makedirs(os.path.join(logdir, "models"), exist_ok=True)
# Checkpoint filename template; {epoch} and {val_loss} are left as literal
# placeholders here.
best_path = os.path.join(logdir, "models/best.{epoch:03d}-{val_loss:.4f}.h5")
print(logdir)

In [None]:
exp = None
K.clear_session()

In [None]:
from abyss_deep_learning.keras.tensorboard import kernel_sparsity, avg_update_ratio

# Drop any previous experiment and reset the Keras session before building a
# new model (create_model() clears the session again itself).
exp = None
K.clear_session()
exp = Experiment()
exp.batch_size = 1
exp.create_model(
    ARGS['num_classes'], ARGS['image_dims'],
    upsample='train', # 'bilinear' breaks the network
    num_iterations=5)
if imagenet_weights:
    exp.load_model(imagenet_weights)
    set_prelu_value(1e-5) ############ Note if scratch training don't do this

# Layers named in `exclude` are frozen; with an empty list every layer is trained.
exclude = []#['score2', 'score4', 'upsample', 'crfrnn']
layers = [layer.name for layer in exp.model.layers]
# train_layers = layers[layers.index('fc6'):]
# train_layers = exclude
train_layers = list(set(layers) - set(exclude))#[layer for layer in layers if '_prelu' in layer]
print("Training layers:")
print(train_layers)
# Only referenced by the disabled ImprovedTensorBoard callback below.
predictions_kernel = exp.model.get_layer(name='fc6').trainable_weights[0] # Used in a scalar callback

exp.compile_model(train_layers=train_layers, parallel=ARGS['use_parallel'])
# NOTE: the staged-LR cell further down replaces exp.callbacks[-1], so
# EarlyStopping must stay the last entry of this list.
exp.callbacks = [
        # Save weights only, and only when val_loss improves; `best_path` is a
        # filename template filled in per checkpoint.
        ModelCheckpoint(
            best_path, monitor='val_loss', mode='min', verbose=1,
            save_best_only=True, save_weights_only=True, period=1),
        TensorBoard(log_dir=logdir, write_grads=False, write_graph=False, write_images=False),
#         ImprovedTensorBoard(
#             log_dir=logdir,
#             scalars={
#                 'learning_rate': exp.get_train_model().optimizer.lr,
#                 'feature_sparsity': kernel_sparsity(exp.get_train_model()),
#                 'prediction_UW_ratio': avg_update_ratio(exp.get_train_model(), predictions_kernel)
#             },
# #             groups={
# #                 'performance': {
# #                     'loss': ['loss', 'val_loss'],
# #                     'accuracy': [r'.*acc.*'],
# # #                     'Mean Per-Class Average Accuracy': [r'.*mpca.*'],
# # #                     'Mean Avg Precision': [r'.*PR.*'],
# # #                     'ROC AUC': [r'.*ROC.*']
# #                 }
# #             },
#             pr_curve=False,
#             num_classes=VAL_DATA[1].shape[1],
#             histogram_freq=None,
#             batch_size=exp.batch_size,
#             write_grads=False,
#         ),
        # Multiply the learning rate by 0.2 after 5 epochs without val_loss
        # improvement, then wait 5 epochs before monitoring again.
        ReduceLROnPlateau(
            monitor='val_loss', mode='min', factor=0.2, patience=5, cooldown=5, verbose=1),
        # Stop after 30 epochs without val_loss improvement and restore the
        # best weights seen.
        EarlyStopping(
            monitor='val_loss', mode='min',
            min_delta=0.0, patience=30, verbose=1, restore_best_weights=True)
]


In [None]:
# exp.model.load_weights("/data/log/fcn-crfrnn/oceaneering/tile-imagenet/crf_5-ups_train-aug-PR-7417/best.h5")

## LR Search

In [None]:
from abyss_deep_learning.keras.utils import LRSearch, batching_gen
# Endless, augmented training stream fed to the learning-rate search in
# batches of one.
train_stream = DATASET['train'].generator(shuffle_ids=True, endless=True)
train_gen = ARGS['pipeline'](train_stream, aug_config=ARGS['augmentation'])
search = LRSearch(
    exp.model,
    x=batching_gen(train_gen, batch_size=1),
    batch_size=1)


In [None]:
def plot(self):
    '''Scatter the entries of `self.results` (keys on a log-scaled x axis, values on y) in a new figure.

    Presumably keys are learning rates and values the resulting score --
    TODO confirm against LRSearch.
    '''
    xs = list(self.results.keys())
    ys = list(self.results.values())
    plt.figure()
    plt.semilogx(xs, ys, '.')
    
search.fit(n_lrs=10, n_epochs=4, lr_power_range=(-7, -2), steps_per_epoch=10)
plot(search)

In [None]:
del search
# exp.model.save_weights(os.path.join(logdir, "best.h5"))


## Training

In [None]:
print("training now...")
# Snapshot of the weights before training starts.
exp.model.save_weights('/data/tmp/blah_weights.h5')
try:
    K.set_value(exp.get_train_model().optimizer.lr, 2e-5)
    # NOTE(review): initial_epoch=15 is hard-coded although the model above may
    # have been freshly created -- confirm this offset is still wanted.
    exp.train(400, val_data=VAL_DATA, initial_epoch=15)
except KeyboardInterrupt:
    # A manual interrupt just ends training early; the weights are still saved
    # below. Any other exception propagates (the former bare `except: raise`
    # was redundant and has been removed).
    pass
# Save the weights and epoch for next training step
# initial_epoch = exp.callbacks[3].stopped_epoch + 1 if exp.callbacks[3].stopped_epoch else 16
# initial_lr = K.eval(exp.get_train_model().optimizer.lr)
exp.model.save_weights('/data/tmp/blah_weights2.h5')
# raise RuntimeError("Stop Run All")

In [None]:
# Staged training: one run per learning rate, each continuing from the epoch
# at which the previous run's EarlyStopping callback stopped.
try:
    initial_epoch = 0
    for lr in [1e-5, 1e-6, 1e-7]:
        # NOTE(review): the rate actually applied is lr * 100 (1e-3, 1e-4,
        # 1e-5), not the listed value -- confirm the factor is intended.
        K.set_value(exp.get_train_model().optimizer.lr, lr*100)
        exp.train(200, val_data=VAL_DATA, initial_epoch=initial_epoch)
        # NOTE(review): presumably stopped_epoch stays 0 when early stopping
        # never triggers, which would restart the next stage at epoch 1 --
        # verify against the Keras EarlyStopping implementation.
        initial_epoch = exp.callbacks[-1].stopped_epoch + 1
        # Swap in a fresh EarlyStopping (shorter patience) for the next stage;
        # relies on EarlyStopping being the last callback in the list.
        del exp.callbacks[-1]
        exp.callbacks.append(EarlyStopping(
            monitor='val_loss', mode='min',
            min_delta=0.0, patience=15, verbose=1, restore_best_weights=True))
except KeyboardInterrupt:
    # A manual interrupt ends the schedule early. Other exceptions propagate
    # (the former bare `except: raise` was redundant and has been removed).
    pass


In [None]:
#Prelu from scratch: 5e-5 for 200, LR plateau 10/5
#Prelu from scratch: 5e-5 for 200, LR plateau 10/5
#Prelu from imagenet: 1e-5 for 200, LR plateau 10/5
raise Exception("Stop Run All Cells")
# exp.model.get_layer(name='fc6').trainable_weights

In [None]:
# Print the weights of every '_prelu' layer and plot a histogram of the first one.
prelu_layers = [candidate for candidate in exp.model.layers if '_prelu' in candidate.name]
for layer in prelu_layers:
    print(layer.weights)
    weight = layer.weights[0]
    value = weight.eval(session=K.get_session())
    plt.figure()
    plt.hist(value.ravel())

## Predict

In [None]:
%matplotlib notebook

def plot_test(model, window_size=(500, 500), max_images=3):
    '''Predict tile-by-tile on the first test images and plot the prediction overlays.

    Each image from DATASET['test'] is cut into tiles with `tile_gen`, every
    tile is run through `model.predict`, and the tile outputs are stitched
    back together with `detile` before being overlaid on the image.

    Args:
        model: model exposing `predict` on a batch containing a single tile.
        window_size: tile size passed to `tile_gen` / `detile`.
        max_images: number of test images to process (was hard-coded to 3).

    Returns:
        The overlay of the last processed image, or None if the test generator
        yielded nothing.
    '''
    from abyss_deep_learning.utils import tile_gen, detile
    from skimage.color import label2rgb

    # Defined up front so an empty generator returns None instead of raising
    # UnboundLocalError at the return statement.
    output = None
    for index, (image, target) in enumerate(
            DATASET['test'].generator(data_ids=None)):
        # Tile once and reuse the tiles for both prediction and re-assembly
        # (the image was previously tiled twice).
        tiles = list(tile_gen(image, window_size))
        tile_predictions = [
            model.predict(tile[np.newaxis, ...])[0] for tile in tiles]
        # The image is rebuilt from its own tiles, presumably so that it goes
        # through the same tiling/stitching as the prediction -- TODO confirm.
        image = detile(tiles, window_size, image.shape)
        prediction = detile(tile_predictions, window_size, target.shape)
        image = postprocess_data(image) / 255
        print(image.shape, target.shape, prediction.shape)
        plt.figure()
        output = label2rgb(prediction.argmax(-1), image, bg_label=0)
        plt.imshow(output)
        if index + 1 >= max_images:
            break
    return output

# plt.figure(figsize=(6, 10))
plt.imshow(plot_test(exp.model))
# plt.tight_layout()
# plt.title("{:s}{:^50}{:s}".format('rgb', 'ground truth', 'predicted'))
# plt.gca().xaxis.set_visible(False)
# plt.gca().yaxis.set_visible(False)
# # exp.plot_test()

In [None]:
from sklearn.metrics import average_precision_score
# `test_data` was not defined anywhere in this notebook, so this cell raised
# NameError on a fresh kernel. Build it the same way VAL_DATA is built.
test_data = dump_dataset(DATASET['test'])
y_pred = exp.model.predict(x=test_data[0], batch_size=exp.batch_size)
# Drop channel 0 (treated as background by the message below) and flatten
# everything but the class axis, so each row is one pixel's foreground scores.
num_fg_classes = test_data[1].shape[-1] - 1
ap_micro = average_precision_score(
    test_data[1][..., 1:].reshape((-1, num_fg_classes)),
    y_pred[..., 1:].reshape((-1, num_fg_classes)),
    average='micro')
print("Micro average precision of FG classes is {:.4f}".format(ap_micro))

# Inspect Trained Parameters

## FCN Upsample Weights

In [None]:
# Look at upsampling weights
# exp.model.load_weights("/data/log/fcn-crfrnn/oceaneering/tile-imagenet-crf0/2323/models/best.019-0.2982.h5", by_name=True)
layer_names = ['score2', 'score4', 'upsample']
for name in layer_names:
    layer = exp.model.get_layer(name=name)
    v = layer.get_weights()[0]
    print(v.shape, v.mean(), v.std(), v.min(), v.max())
    plt.figure()
    # 2x2 grid: the kernel slice for each combination of the first two indices
    # on the last two axes.
    index_pairs = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for panel, (a, b) in enumerate(index_pairs, start=1):
        plt.subplot(2, 2, panel)
        plt.imshow(v[:, :, a, b])

## CRF Parameters

In [None]:
exp.get_train_model().get_layer(name='crfrnn').get_weights()

# Overlay on images and save

In [None]:
from keras.models import model_from_json
# model_path = "/data/log/fcn-crfrnn/anadarko/prelu/2635/models/best.071-0.2609.h5" # First good anadarko
# NOTE(review): the active path below points at an "oceaneering" run; its former
# trailing comment ("First good anadarko") looked copied from the line above.
model_path = "/data/log/fcn-crfrnn/oceaneering/tile-imagenet/crf_5-ups_train-aug-PR-7417/best.h5"
# Rebuild the architecture with the CRF enabled (5 iterations), then load the
# saved weights into it. This `model` global is what save_overlay() uses.
model = get_crfrnn_model_def(
    num_classes=ARGS['num_classes'],
    input_shape=ARGS['image_dims'],
    num_iterations=5, with_crf=True)
model.load_weights(model_path)

In [None]:
from glob import glob

def predict_image(image, model, window_size=(500, 500)):
    '''Predict an image tile-by-tile and stitch the tile outputs together.

    Every tile from `tile_gen` is preprocessed as a single-item batch, passed
    to `model.predict`, and the outputs are re-assembled by `detile` into an
    array of shape image.shape[:2] + (ARGS['num_classes'],).
    '''
    from abyss_deep_learning.utils import tile_gen, detile
    tile_outputs = []
    for tile in tile_gen(image, window_size):
        batch = preprocess_data(tile[np.newaxis, ...])
        tile_outputs.append(model.predict(batch)[0])
    stitched_shape = image.shape[:2] + (ARGS['num_classes'],)
    return detile(tile_outputs, window_size, stitched_shape)

def save_condition(path):
    '''Return True for image files that are not themselves generated outputs.

    A path qualifies when its file name contains neither '_pred' nor '_mask'
    and ends (case-insensitively) with .png, .jpg or .jpeg.
    '''
    filename = os.path.basename(path)
    if '_pred' in filename or '_mask' in filename:
        return False
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))

def save_overlay(path):
    '''Predict on the image at `path` and save an overlay and a label mask next to it.

    Writes <stem>_pred.jpg (prediction overlaid on the image) and
    <stem>_mask.png (integer class labels) into the image's directory, using
    the module-level `model`.
    '''
    # splitext only strips the final extension of the last path component, and
    # also behaves sensibly for file names without any extension.
    stem, _ = os.path.splitext(path)
    output_path = stem + '_pred.jpg'
    # The mask holds class ids: JPEG is lossy and would corrupt them (and 16-bit
    # data is generally not writable as JPEG), so it is stored as PNG instead.
    output_path_mask = stem + '_mask.png'
    print(path)

    image_full = imread(path)
    if image_full.ndim == 1:  # Weird imread bug
        image_full = image_full[0]
    pred = predict_image(image_full, model)
    # order=0 so that scores are not interpolated between neighbouring pixels.
    pred_upscaled = resize(
        pred, image_full.shape[0:2], order=0, preserve_range=True)
    mask = pred_upscaled.argmax(-1).astype(np.uint16)
    pred_rgb = label2rgb(mask, image_full, bg_label=0)
    imsave(output_path, pred_rgb)
    imsave(output_path_mask, mask)

def foreach_file(image_glob, func, condition=None):
    '''Call `func(path)` for every path matching `image_glob` that passes `condition`.

    Args:
        image_glob: glob pattern selecting the candidate paths.
        func: callable invoked with each accepted path.
        condition: optional predicate on the path; when None every matching
            path is processed.
    '''
    for path in list(glob(image_glob)):
        # Previously `condition and condition(path)`, which skipped every file
        # when no condition was supplied.
        if condition is None or condition(path):
            func(path)


image_glob = "/data/abyss/oceaneering/data/*.JPG"
foreach_file(image_glob, save_overlay, condition=save_condition)

In [None]:
"""Module for visualizing various machine learning outputs.

Attributes
----------
COLOR_DICT : dict(str - > tuple(float, float, float))
    dict mapping color strings to RGB values in the range [0, 1].
DEFAULT_COLORS : list of str
    Default keys in COLOR_DICT to use.
"""
import numpy as np

from skimage._shared.utils import warn
from skimage.color import rgb_colors, rgb2gray, gray2rgb
from skimage.color.colorlabel import _rgb_vector, _match_label_with_color
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float

from abyss_deep_learning.utils import instance_to_categorical

__all__ = ['COLOR_DICT', 'label2rgb', 'DEFAULT_COLORS']


COLOR_DICT = {
    k: v for k, v in rgb_colors.__dict__.items()
    if isinstance(v, tuple)}

DEFAULT_COLORS = (
    'red', 'blue', 'yellow', 'magenta', 'green',
    'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen')



def label2rgb(
        label, image=None, colors=None, alpha=0.3,
        gray_bg=False, contours='thick',
        bg_label=-1, bg_color=(0, 0, 0), image_alpha=1, kind='overlay'):
    """Return an RGB image where color-coded labels are painted over the image, and optionally contours are painted.
    Source: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorlabel.py

    Parameters
    ----------
    label : array, shape (M, N)
        Integer array of labels with the same shape as `image`.
    image : array, shape (M, N, 3), optional
        Image used as underlay for labels. It is converted to grayscale before
        coloring only when `gray_bg` is set.
    colors : list, optional
        List of colors. If the number of labels exceeds the number of colors,
        then the colors are cycled.
    alpha : float [0, 1], optional
        Opacity of colorized labels. Ignored if image is `None`.
    gray_bg : bool, optional
        Set the background image to grayscale when kind='overlay'.
    contours : string in {'thick', 'inner', 'outer', 'subpixel'}, optional
        The mode for finding and drawing class spatial boundaries, use None to
        not draw contours. Only used when kind='overlay'.
    bg_label : int, optional
        Label that's treated as the background.
    bg_color : str or array, optional
        Background color. Must be a name in `COLOR_DICT` or RGB float values
        between [0, 1].
    image_alpha : float [0, 1], optional
        Opacity of the image.
    kind : string, one of {'overlay', 'avg'}
        The kind of color image desired. 'overlay' cycles over defined colors
        and overlays the colored labels over the original image. Any other
        value selects 'avg', which replaces each labeled segment with its
        average color, for a stained-class or pastel painting appearance.

    Returns
    -------
    result : array of float, shape (M, N, 3)
        The result of blending a cycling colormap (`colors`) for each distinct
        value in `label` with the image, at a certain alpha value.
    """
    if kind == 'overlay':
        return _label2rgb_overlay(label, image, colors, alpha, bg_label,
                                  bg_color, image_alpha, gray_bg=gray_bg, contours=contours)
    # `_label2rgb_avg` was referenced without being imported or defined in this
    # module, so kind='avg' raised NameError. Import it from the scikit-image
    # module this file already borrows its other private helpers from.
    from skimage.color.colorlabel import _label2rgb_avg
    return _label2rgb_avg(label, image, bg_label, bg_color)

def _label2rgb_overlay(label, image=None, colors=None, alpha=0.3,
                       bg_label=-1, bg_color=None, image_alpha=1, gray_bg=False, contours=None):
    """Return an RGB image where color-coded labels are painted over the image.

    Parameters
    ----------
    label : array, shape (M, N)
        Integer array of labels with the same shape as `image`.
    image : array, shape (M, N, 3), optional
        Image used as underlay for labels. It is converted to grayscale before
        coloring only when `gray_bg` is set.
    colors : list, optional
        List of colors. If the number of labels exceeds the number of colors,
        then the colors are cycled.
    alpha : float [0, 1], optional
        Opacity of colorized labels. Ignored if image is `None`.
    bg_label : int, optional
        Label that's treated as the background.
    bg_color : str or array, optional
        Background color. Must be a name in `COLOR_DICT` or RGB float values
        between [0, 1].
    image_alpha : float [0, 1], optional
        Opacity of the image. Only applied when `gray_bg` is set.
    gray_bg : bool, optional
        Convert the underlay image to grayscale before blending.
    contours : string in {'thick', 'inner', 'outer', 'subpixel'}, optional
        The mode for finding and drawing class spatial boundaries, use None to
        not draw contours.

    Returns
    -------
    result : array of float, shape (M, N, 3)
        The result of blending a cycling colormap (`colors`) for each distinct
        value in `label` with the image, at a certain alpha value.

    Raises
    ------
    ValueError
        When image and label are not the same shape.
    """
    if colors is None:
        colors = DEFAULT_COLORS
    colors = [_rgb_vector(c) for c in colors]

    if image is None:
        image = np.zeros(label.shape + (3,), dtype=np.float64)
        # Opacity doesn't make sense if no image exists.
        alpha = 1
    else:
        if not image.shape[:2] == label.shape:
            raise ValueError("`image` and `label` must be the same shape")

        if image.min() < 0:
            warn("Negative intensities in `image` are not supported")
        if gray_bg:
            image = img_as_float(rgb2gray(image))
            image = gray2rgb(image) * image_alpha + (1 - image_alpha)
        else:
            image = img_as_float(image)

    # Ensure that all labels are non-negative so we can index into
    # `label_to_color` correctly.
    offset = min(label.min(), bg_label)
    if offset != 0:
        label = label - offset  # Make sure you don't modify the input array.
        bg_label -= offset

    new_type = np.min_scalar_type(int(label.max()))
    # `np.bool` was a deprecated alias that newer NumPy releases removed;
    # `np.bool_` is the scalar type available in every version.
    if new_type == np.bool_:
        new_type = np.uint8
    # astype() copies, so the in-place write to `.flat` below never touches
    # the caller's array.
    label = label.astype(new_type)

    mapped_labels_flat, color_cycle = _match_label_with_color(label, colors,
                                                              bg_label, bg_color)

    if len(mapped_labels_flat) == 0:
        return image

    dense_labels = range(max(mapped_labels_flat) + 1)

    # zip() against the finite range bounds how many colors are taken from
    # `color_cycle`.
    label_to_color = np.array([color for _, color in zip(dense_labels, color_cycle)])

    mapped_labels = label
    mapped_labels.flat = mapped_labels_flat
    # Both branches of the former `if gray_bg` were identical, so a single
    # blend is used.
    result = label_to_color[mapped_labels] * alpha + image * (1 - alpha)

    # Remove background label if its color was not specified.
    remove_background = 0 in mapped_labels_flat and bg_color is None
    if remove_background:
        result[label == bg_label] = image[label == bg_label]

    if contours:
        # Draw each class boundary in that class's own color.
        for label_idx in range(label_to_color.shape[0]):
            result = mark_boundaries(
                result, label == label_idx, color=label_to_color[label_idx], mode=contours)

    return result

In [None]:
def save_dataset_overlay(coco_path, dataset):
    '''Save a ground-truth overlay (<stem>_gt.jpg) next to every image of a COCO dataset.

    Args:
        coco_path: path to the COCO JSON; each image entry must carry a 'path'.
        dataset: dataset whose generator yields (image, target) pairs; only
            the target is used, the image is re-read from disk.

    NOTE(review): images and targets are paired purely by position, which
    assumes the unshuffled generator follows the order of `coco.imgs` --
    verify against the dataset implementation.
    '''
    from pycocotools.coco import COCO
    from abyss_deep_learning.visualize import label2rgb

    coco = COCO(coco_path)
    # `shuffle_ids` is the keyword every other generator() call in this
    # notebook uses; the former `shuffle=False` did not match it.
    samples = dataset.generator(endless=False, shuffle_ids=False)
    for img, (_, target) in zip(coco.imgs.values(), samples):
        path = img['path']
        image = imread(path, plugin='imread')
        print(image.shape, image.dtype, target.shape)
        output_path = os.path.splitext(path)[0] + '_gt.jpg'
        overlay = label2rgb(target.argmax(-1), image, bg_label=0)
        # The overlay is scaled by 255 and cast to uint8 for saving.
        imsave(output_path, (255 * overlay).astype(np.uint8))
        
"/data/log/fcn-crfrnn/oceaneering/tile-imagenet-crf0/ups_new-ups_only-reinit_ups2363/models/best.015-0.2838.h5"
save_dataset_overlay("/data/abyss/oceaneering/annotations/separation/train.json", DATASET['train'])

In [None]:
del model

In [None]:
# from abyss_deep_learning.utils import image_streamer
# for path, seq_no, image_full in image_streamer()