In [None]:
import keras.backend
import keras.callbacks
import keras.layers
import keras.metrics
import keras.models
import keras.optimizers

import matplotlib
matplotlib.use('SVG')

import utils.callbacks
import utils.model_builder
import utils.visualize
import utils.data_provider
import utils.metrics

import skimage.io
import sklearn.metrics

import scipy.stats
import pandas as pd

import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt

import sys

import os
# Uncomment the following line if you don't have a GPU
#os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [None]:
# constants
const_lr = 1e-4  # RMSprop learning rate (used when compiling the model)

# Experiment tag: selects the per-run output sub-directories below.
tag = '01'

# Root of the BBBC022 data. Override with the BBBC022_BASE_DIR environment
# variable instead of editing this cell; os.path.join(..., "") guarantees the
# trailing slash that the string concatenations below rely on.
base_dir = os.path.join(
    os.environ.get("BBBC022_BASE_DIR", "/data1/image-segmentation/BBBC022/"),
    ""
)

# Output dirs (everything for this run lives under exp_dir, except TensorBoard logs)
exp_dir = base_dir + "unet/experiments/" + tag + "/"
out_dir = exp_dir + "out"
tb_log_dir = base_dir + "unet/tensorboard/" + tag + "/"
chkpt_dir = exp_dir + "checkpoints/"

for output_dir in (out_dir, tb_log_dir, chkpt_dir):
    os.makedirs(output_dir, exist_ok=True)

# Files
# ModelCheckpoint substitutes {epoch}, so each epoch gets its own weights file.
chkpt_file = chkpt_dir + "{epoch:04d}.hdf5"
csv_log_file = exp_dir + "log.csv"

# Input dirs
train_dir_x = base_dir + 'unet/split/training/x/'
train_dir_y = base_dir + 'unet/split/training/y/'
val_dir = base_dir + "unet/split/validation/"

# Learning Settings
rescale_labels = True  # forwarded unchanged to both data generators

epochs = 15

batch_size = 10
# TODO: remove the fixed steps_per_epoch and val_steps, and feed the
#       model through a queue instead.
steps_per_epoch = 500

# val_steps * val_batch_size should cover the validation set exactly.
# NOTE(review): the set is assumed to hold 50 * 4 = 200 samples; confirm
# against the contents of val_dir before changing either number.
val_batch_size = 10
val_samples = 50 * 4
val_steps = val_samples // val_batch_size

# Generator-only params: spatial size (dim1 x dim2) used by the data
# generators and by the model input.
dim1 = 256
dim2 = 256

bit_depth = 8  # passed to the generators; presumably the input image bit depth

In [None]:
# Build a TF session pinned to a single GPU and make Keras use it.
# visible_device_list takes a 0-based GPU index (among the GPUs visible to
# this process), so "2" selects the third GPU.
gpu_device = "2"

configuration = tf.ConfigProto()
# Uncomment to let TensorFlow grow GPU memory on demand.
#configuration.gpu_options.allow_growth = True
if os.environ.get('CUDA_VISIBLE_DEVICES') != '':
    # Only pin a GPU when GPUs have not been hidden (CPU-only run, see the
    # optional CUDA_VISIBLE_DEVICES line in the imports cell).
    configuration.gpu_options.visible_device_list = gpu_device
session = tf.Session(config=configuration)

# apply session
keras.backend.set_session(session)

# Training batches, built from the x/ and y/ training folders.
train_gen = utils.data_provider.random_sample_generator(
    train_dir_x,
    train_dir_y,
    batch_size,
    bit_depth,
    dim1,
    dim2,
    rescale_labels
)

# Validation batches, built from the validation split's x/ and y/ folders.
val_gen = utils.data_provider.single_data_from_images(
    val_dir + 'x/',
    val_dir + 'y/',
    val_batch_size,
    bit_depth,
    dim1,
    dim2,
    rescale_labels
)

In [None]:
# build model
model = utils.model_builder.get_model_3_class(dim1, dim2)
model.summary()

loss = "categorical_crossentropy"

# Output channel i corresponds to class_names[i]; every class gets a recall
# and a precision metric, in that order, after the overall accuracy.
class_names = ["background", "interior", "boundary"]
metrics = [keras.metrics.categorical_accuracy] + [
    metric_factory(channel=channel, name=class_name + suffix)
    for channel, class_name in enumerate(class_names)
    for metric_factory, suffix in (
        (utils.metrics.channel_recall, "_recall"),
        (utils.metrics.channel_precision, "_precision"),
    )
]

optimizer = keras.optimizers.RMSprop(lr=const_lr)

model.compile(loss=loss, metrics=metrics, optimizer=optimizer)

# CALLBACKS
# save model weights after every epoch (not only the best one)
callback_model_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=chkpt_file,
    save_weights_only=True,
    save_best_only=False
)
callback_csv = keras.callbacks.CSVLogger(filename=csv_log_file)
# Disabled: split/merge logging on the validation images to TensorBoard.
# Re-enable it and add it to `callbacks` below to use it.
# callback_splits_and_merges = utils.callbacks.SplitsAndMergesLoggerBoundary(
#     'images',
#     val_gen,
#     gen_calls = val_steps,
#     log_dir=tb_log_dir
# )

callbacks = [callback_model_checkpoint, callback_csv]  # , callback_splits_and_merges]


In [None]:
# TRAIN
# Fit from the training generator: `epochs` epochs of `steps_per_epoch`
# batches each. After every epoch, `val_steps` batches of val_gen are
# evaluated. The History object returned by Keras is kept in `statistics`.
statistics = model.fit_generator(
    train_gen,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_gen,
    validation_steps=val_steps,
    callbacks=callbacks,
    verbose=1,
)

# Disabled: plot the learning curves from `statistics` into out_dir.
#utils.visualize.visualize_learning_stats_boundary_hard(statistics, out_dir, metrics)

print('Done! :)')