#### Installing TF2.0-preview
Model training is intended to run only on TF 2.0+. If you are running a Docker container with a previous version, please duplicate the container in order to preserve the old version.

In [1]:
#!pip install tf-nightly-gpu-2.0-preview


In [1]:
import tensorflow as tf
import tensorflow.keras as tk
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import pandas as pd
# Last expression of the cell: displays whether TensorFlow reports an available GPU.
tf.test.is_gpu_available()


True

In [2]:
# Display the installed TensorFlow version (this notebook is intended for TF 2.0+).
tf.__version__

'2.0.0-alpha0'

### Note on COCO_Animals
The train/validation/test data comes from the COCO train2017 dataset, which is contained in /datasets/coco_animals/train. The validation folder contains unused data.

Declaring environment

In [3]:
# Base model file and the folder used for model outputs.
BASE_MODEL_PATH = 'models/model_base.h5'
MODEL_OUT_FOLDER = 'models/'

# Dataset label id <-> animal name (the reverse mapping is derived from the forward one).
ID_TO_LABEL = {16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse'}
LABEL_TO_ID = {name: label_id for label_id, name in ID_TO_LABEL.items()}

# Order of channels in output segmentation and corresponding dataset labels
CHANNEL_ORDER = [0, 16, 17, 18, 19]
# Human-readable name of every channel; label id 0 is reported as 'other'.
CHANNEL_NAMES = ['other' if i == 0 else ID_TO_LABEL[i] for i in CHANNEL_ORDER]
ALL_LABELS = list(LABEL_TO_ID)

# CSV path for every (split, label) pair: ds_csv_paths[split][label].
ds_csv_paths = {
    dset: {
        label: 'datasets/coco_animals_{}_{}.csv'.format(dset, label)
        for label in ALL_LABELS
    }
    for dset in ['train', 'validation', 'test']
}

In [4]:
def load_dataset(path, size=224, batch_size=32, filter_expr=None):
    """Build a shuffled, batched `tf.data` pipeline of (image, mask) pairs from a CSV file.

    Every CSV row after the header is read as four columns: image path,
    segmentation path, label name (string) and label value (int32). Only the
    two paths are used to build a sample.

    Args:
        path: CSV file to read.
        size: target height and width used when padding/resizing image and mask.
        batch_size: number of samples per batch.
        filter_expr: optional predicate given to `Dataset.filter` before
            shuffling; skipped when falsy.

    Returns:
        A dataset of `(image, mask)` batches. The image is float32 and passed
        through the MobileNetV2 `preprocess_input`; the mask is float32 with one
        0/1 channel per entry of `CHANNEL_ORDER`.
    """
    # TF 1.x and TF 2.x expose the pad-and-resize op under different names.
    if tf.__version__.startswith('1.'):
        pad_resize = tf.image.resize_image_with_pad
    else:
        pad_resize = tf.image.resize_with_pad

    def parse_sample(png_path, seg_path, lab_name, lab_value):
        # Input image: decode as RGB, pad/resize, then apply the network preprocessing.
        image = tf.image.decode_png(tf.io.read_file(png_path), channels=3)
        image = pad_resize(image, size, size)
        image = preprocess_input(tf.cast(image, tf.float32))

        # Label map: single channel, resized with nearest neighbour so label ids are not blended.
        label_map = tf.image.decode_png(tf.io.read_file(seg_path), channels=1)
        label_map = pad_resize(label_map, size, size,
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        # One binary mask per channel id, stacked along the last axis.
        masks = [tf.cast(tf.equal(label_map, lid), tf.float32) for lid in CHANNEL_ORDER]
        return image, tf.concat(masks, axis=-1)

    column_types = [tf.string, tf.string, tf.string, tf.int32]
    samples = tf.data.experimental.CsvDataset(path, column_types, header=True)
    if filter_expr:
        samples = samples.filter(filter_expr)
    return samples.shuffle(1000).map(parse_sample).batch(batch_size).prefetch(1)

In [5]:
def show_one_sample(dataset_label):
    """Plot the first sample of one training batch next to each of its mask channels.

    The figure has a single row: the input image first, then one panel per
    channel of the segmentation mask, titled with the matching entry of
    `CHANNEL_NAMES`.

    Args:
        dataset_label: label name used to pick the training CSV from `ds_csv_paths`.

    Raises:
        ValueError: if the training dataset for `dataset_label` yields no batch.
    """
    for png_batch, seg_batch in load_dataset(ds_csv_paths['train'][dataset_label]):
        # Keep only the first sample of the first batch.
        png = png_batch.numpy()[0, ...]
        seg = seg_batch.numpy()[0, ...]
        break
    else:
        # Without this guard an empty dataset would surface as an UnboundLocalError below.
        raise ValueError(
            'No samples found in the training dataset for label {!r}'.format(dataset_label))

    n_cols = seg.shape[-1] + 1
    plt.figure(figsize=(18, 6))
    plt.subplot(1, n_cols, 1)
    plt.axis('off')
    io.imshow(png)
    for i in range(seg.shape[-1]):
        plt.subplot(1, n_cols, i + 2)
        plt.axis('off')
        plt.title(CHANNEL_NAMES[i])
        io.imshow(seg[..., i])

In [6]:
def train_model(label, input_model=BASE_MODEL_PATH, output_folder=MODEL_OUT_FOLDER, epochs=1):
    """Train a model for `label` -- NOT IMPLEMENTED: the body is still a stub.

    Args:
        label: name of the target label to train for.
        input_model: path of the model file to start from.
        output_folder: folder intended to receive the trained model.
        epochs: number of training epochs (accepted so that callers passing
            `epochs=...` no longer fail with a TypeError).

    Raises:
        NotImplementedError: always, until the loading and training steps are
            written. Failing loudly here replaces the previous silent
            `return None`, which made callers crash later when reading
            `.history` on the result.
    """
    # TODO: Load model
    # TODO: Train the model
    raise NotImplementedError(
        'train_model is a stub: model loading and training are not implemented yet')
    

    

def predict(model, dataset_label):
    """Generate one comparison figure per validation sample of `dataset_label`.

    Samples are read one at a time (batch size 1). Each figure is a 2-row grid:
    the top row holds the input image followed by the ground-truth mask
    channels; the bottom row holds the model output channels, which line up
    under the mask channels when both have the same number of channels.

    Args:
        model: object whose `predict` is called on each image batch.
        dataset_label: label name used to pick the validation CSV from `ds_csv_paths`.

    Yields:
        The `matplotlib.pyplot` module, once per sample, after the figure for
        that sample has been drawn.
    """
    def draw_panel(n_cols, position, image, title=None):
        # One axis-less panel of the 2-row grid, optionally titled.
        plt.subplot(2, n_cols, position)
        plt.axis('off')
        if title is not None:
            plt.title(title)
        io.imshow(image)

    samples = load_dataset(ds_csv_paths['validation'][dataset_label], batch_size=1)
    for png_batch, seg_batch in samples:
        out = model.predict(png_batch)[0, ...]
        png = png_batch.numpy()[0, ...]
        seg = seg_batch.numpy()[0, ...]
        n_truth = seg.shape[-1]
        n_pred = out.shape[-1]

        plt.figure(figsize=(18, 6))
        # Top row: input image, then the ground-truth channels.
        draw_panel(n_truth + 1, 1, png)
        for ch in range(n_truth):
            draw_panel(n_truth + 1, ch + 2, seg[..., ch], CHANNEL_NAMES[ch])
        # Bottom row: first slot (under the image) stays empty, predictions follow.
        for ch in range(n_pred):
            draw_panel(n_pred + 1, n_truth + ch + 3, out[..., ch], CHANNEL_NAMES[ch])
        yield plt
    

### Train the models

In [None]:
# Train a model for each target label and save its training history as a CSV file.
for target_label in ['dog', 'bird', 'horse', 'cat']:
    history = train_model(target_label, epochs=2)
    # Presumably `history` is a Keras History object whose `.history` holds the per-epoch values -- confirm against train_model.
    hist = pd.DataFrame(history.history)
    hist.to_csv('models/history_{}.csv'.format(target_label))
    # NOTE(review): this unconditional break ends the loop after the first label ('dog'),
    # so the remaining labels are never processed -- confirm whether this is intended.
    break

### Evaluate models
Models are evaluated separately for their target label and the others

In [28]:
# Evaluate the base model once per label; results are stored keyed by label name.
# NOTE(review): `evaluate` is not defined in the cells shown here -- it must be defined before this cell runs.
base_metrics = {lab: evaluate('models/model_base.h5', lab) for lab in ['bird', 'cat', 'dog', 'horse']}

W0625 13:23:56.392818 140413771372288 hdf5_format.py:224] No training configuration found in save file: the model was *not* compiled. Compile it manually.


      8/Unknown - 3s 327ms/step - loss: 0.3315 - accuracy: 0.0000e+00 - binary_accuracy: 0.7999 - false_positives_19: 6725.0000 - false_negatives_19: 12794554.0000 - precision_19: 0.0462 - recall_19: 2.5479e-05

W0625 13:24:09.264626 140413771372288 hdf5_format.py:224] No training configuration found in save file: the model was *not* compiled. Compile it manually.


      8/Unknown - 2s 301ms/step - loss: 0.3300 - accuracy: 0.0000e+00 - binary_accuracy: 0.7999 - false_positives_20: 7352.0000 - false_negatives_20: 11339397.0000 - precision_20: 0.0490 - recall_20: 3.3422e-05

W0625 13:24:20.741027 140413771372288 hdf5_format.py:224] No training configuration found in save file: the model was *not* compiled. Compile it manually.


      8/Unknown - 2s 292ms/step - loss: 0.3323 - accuracy: 0.0000e+00 - binary_accuracy: 0.7999 - false_positives_21: 6303.0000 - false_negatives_21: 11339734.0000 - precision_21: 0.0066 - recall_21: 3.7038e-06

W0625 13:24:32.311767 140413771372288 hdf5_format.py:224] No training configuration found in save file: the model was *not* compiled. Compile it manually.


      8/Unknown - 3s 326ms/step - loss: 0.3312 - accuracy: 0.0000e+00 - binary_accuracy: 0.7999 - false_positives_22: 6289.0000 - false_negatives_22: 12844764.0000 - precision_22: 0.0444 - recall_22: 2.2732e-05

In [7]:
# Evaluate the 'cat' checkpoint for label 'cat'. The displayed result is a pair of metric lists
# (presumably target label vs. the other labels -- confirm against `evaluate`).
evaluate('models/model_cat-500-0.98.h5', 'cat')

      8/Unknown - 3s 366ms/step - loss: 0.0209 - accuracy: 0.5440 - binary_accuracy: 0.9783 - false_positives_1: 615126.0000 - false_negatives_1: 615191.0000 - precision_1: 0.9458 - recall_1: 0.9457

([0.014624733167390028,
  0.5413325,
  0.98538685,
  522381.0,
  522471.0,
  0.96347004,
  0.96346396],
 [0.02085717290174216,
  0.5439969,
  0.97830087,
  615126.0,
  615191.0,
  0.9457547,
  0.9457493])

In [8]:
# Evaluate the 'dog' checkpoint for label 'dog'. The displayed result is a pair of metric lists
# (presumably target label vs. the other labels -- confirm against `evaluate`).
evaluate('models/model_dog-500-0.98.h5', 'dog')

      8/Unknown - 3s 330ms/step - loss: 0.0274 - accuracy: 0.5105 - binary_accuracy: 0.9699 - false_positives: 854592.0000 - false_negatives: 854598.0000 - precision: 0.9246 - recall: 0.9246

([0.012869831485052904,
  0.5234587,
  0.98713106,
  463300.0,
  463302.0,
  0.9678275,
  0.9678274],
 [0.027432613307610154,
  0.5105202,
  0.96985495,
  854592.0,
  854598.0,
  0.9246376,
  0.92463714])

In [7]:
# Evaluate the 'bird' checkpoint for label 'bird'. The displayed result is a pair of metric lists
# (presumably target label vs. the other labels -- confirm against `evaluate`).
evaluate('models/model_bird-500-0.99.h5', 'bird')

      8/Unknown - 4s 439ms/step - loss: 0.0159 - accuracy: 0.6226 - binary_accuracy: 0.9841 - false_positives_1: 507788.0000 - false_negatives_1: 507823.0000 - precision_1: 0.9603 - recall_1: 0.9603

([0.010003979583936078,
  0.6380507,
  0.9899777,
  250178.0,
  250185.0,
  0.97494465,
  0.974944],
 [0.01588962972164154,
  0.62264305,
  0.9841247,
  507788.0,
  507823.0,
  0.9603131,
  0.96031046])

In [8]:
# Evaluate the 'horse' checkpoint for label 'horse'. The displayed result is a pair of metric lists
# (presumably target label vs. the other labels -- confirm against `evaluate`).
evaluate('models/model_horse-500-0.98.h5', 'horse')

      8/Unknown - 3s 381ms/step - loss: 0.0250 - accuracy: 0.5023 - binary_accuracy: 0.9750 - false_positives: 803281.0000 - false_negatives: 803285.0000 - precision: 0.9375 - recall: 0.9375

([0.016505081206560135,
  0.488655,
  0.98318064,
  411413.0,
  411414.0,
  0.9579518,
  0.9579517],
 [0.025014091515913606,
  0.502264,
  0.9749854,
  803281.0,
  803285.0,
  0.93746376,
  0.93746346])