# Step 02
# Training a U-Net model

In [1]:
%matplotlib inline

import sys
import os

import numpy as np
import skimage.io

import tensorflow as tf

import keras.backend
import keras.callbacks
import keras.layers
import keras.models
import keras.optimizers

import sys
# A notebook kernel does not define __file__, so a stand-in name is assigned
# here. Only its directory is used on the next line, never the file name.
__file__ = '012-prerpocessing.ipynb'
# abspath() resolves the name against the current working directory, so this
# appends the PARENT of the working directory to sys.path (presumably where the
# project-local `utils` package imported below lives — confirm).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import utils.model_builder
import utils.data_provider
import utils.metrics
import utils.objectives
import utils.dirtools

import random
# Uncomment the following line if you don't have a GPU
# os.environ['CUDA_VISIBLE_DEVICES'] = ''

Using TensorFlow backend.


In [2]:
# Restrict the CUDA devices visible to this process to GPUs 0 and 1.
# NOTE(review): this runs after `import tensorflow`; presumably it still takes
# effect because no GPU has been initialised yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"

In [None]:
# Sanity check: print the list of local devices reported by TensorFlow.
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

# Configuration

In [4]:
from config import config_vars

In [5]:
# Dataset selection: set "root_directory" to the dataset to operate on.
# NOTE(review): the saved output of the next cell shows root_directory 'FISH/',
# which does not match this value — re-run the notebook to refresh the output.
config_vars["root_directory"] = '../CELL/seg_samples_KZ_images/'

# Experiment identifier; setup_experiment receives the configuration and this
# name and returns the configuration used from here on (presumably with the
# experiment-specific output paths filled in — confirm in utils.dirtools).
experiment_name = '00'
config_vars = utils.dirtools.setup_experiment(config_vars, experiment_name)

In [5]:
# Display the configuration dictionary for inspection.
config_vars

{'root_directory': 'FISH/',
 'max_training_images': 0,
 'create_split_files': True,
 'training_fraction': 0.8,
 'validation_fraction': 0.1,
 'transform_images_to_PNG': True,
 'pixel_depth': 8,
 'min_nucleus_size': 25,
 'boundary_size': 2,
 'augment_images': False,
 'elastic_points': 16,
 'elastic_distortion': 5,
 'elastic_augmentations': 10,
 'learning_rate': 0.0001,
 'epochs': 15,
 'steps_per_epoch': 500,
 'batch_size': 10,
 'val_batch_size': 10,
 'rescale_labels': True,
 'crop_size': 256,
 'cell_min_size': 16,
 'boundary_boost_factor': 1,
 'object_dilation': 3,
 'raw_images_dir': 'FISH/raw_images/',
 'raw_annotations_dir': 'FISH/raw_annotations/',
 'path_files_training': 'FISH/training.txt',
 'path_files_validation': 'FISH/validation.txt',
 'path_files_test': 'FISH/test.txt',
 'normalized_images_dir': 'FISH/norm_images/',
 'boundary_labels_dir': 'FISH/boundary_labels/',
 'experiment_dir': 'FISH/experiments/00/out/',
 'probmap_out_dir': 'FISH/experiments/00/out/prob/',
 'labels_out_di

### Set the train/validation split lists

##### exp 00
train_aug: 有机标图的前48张图  
用其余图训练，然后在这48张图上valid
##### exp 01
只使用剩余图里随机的24张做训练，随机24张做验证（代码中写入 validation 列表，test 列表为空）  
加入或者不加入机标图的数据会产生什么影响

In [6]:
# 00
def create_image_lists(dir_raw_images, n_aug=48):
    """Build the image lists for experiment 00.

    The entries of ``dir_raw_images`` whose names end with ``"png"`` are sorted
    by name. The first ``n_aug`` of them are returned both as the validation
    list and as the ``train_aug`` list; every remaining file goes into the
    training list. The test list is always empty.

    Args:
        dir_raw_images: Directory whose entries are listed.
        n_aug: Number of leading (sorted) images reserved for the
            validation / ``train_aug`` lists. Defaults to 48, the value that
            used to be hard-coded.

    Returns:
        Tuple ``(train, test, validation, train_aug)`` of file-name lists.
    """
    image_list = sorted(
        name for name in os.listdir(dir_raw_images) if name.endswith("png")
    )

    # Two separate slices so the returned lists are independent objects.
    image_list_train_aug = image_list[:n_aug]
    image_list_validation = image_list[:n_aug]
    image_list_test = []
    image_list_train = image_list[n_aug:]

    return image_list_train, image_list_test, image_list_validation, image_list_train_aug

In [5]:
# 01
def create_image_lists(dir_raw_images, n_aug=48, n_train=24, n_validation=24, seed=None):
    """Build the image lists for experiment 01.

    The entries of ``dir_raw_images`` whose names end with ``"png"`` are sorted
    by name. The first ``n_aug`` form the ``train_aug`` list. The remaining
    images are shuffled; the first ``n_train`` become the training list and up
    to ``n_validation`` images from the end become the validation list. The
    test list is always empty.

    Training and validation never share an image. When fewer than
    ``n_train + n_validation`` images remain, the validation list is shortened
    (possibly to empty) instead of overlapping with the training list.

    Args:
        dir_raw_images: Directory whose entries are listed.
        n_aug: Number of leading (sorted) images reserved for ``train_aug``.
        n_train: Number of shuffled images used for training.
        n_validation: Maximum number of shuffled images used for validation.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` state is used, as before.

    Returns:
        Tuple ``(train, test, validation, train_aug)`` of file-name lists.
    """
    image_list = sorted(
        name for name in os.listdir(dir_raw_images) if name.endswith("png")
    )

    image_list_train_aug = image_list[:n_aug]
    image_list_test = []

    remaining = image_list[n_aug:]
    if seed is None:
        random.shuffle(remaining)
    else:
        random.Random(seed).shuffle(remaining)

    image_list_train = remaining[:n_train]
    # Take validation from the tail, but never start before the end of the
    # training slice; with enough images this equals remaining[-n_validation:].
    validation_start = max(n_train, len(remaining) - n_validation)
    image_list_validation = remaining[validation_start:]

    return image_list_train, image_list_test, image_list_validation, image_list_train_aug

#### 将每一部分(train valid)的数据的名字写进一个list，通过data_partitions字典读取

In [7]:
# Path of the split file that lists the extra "train_aug" images.
# NOTE(review): the 'FISH/' prefix is hard-coded rather than derived from
# config_vars["root_directory"] — confirm it matches the selected dataset.
config_vars["path_files_training_aug"] = 'FISH/training_aug.txt'

In [8]:
# Prepare split txt files: build the four image lists from the directory of
# normalized images. The fraction arguments below are commented out because
# create_image_lists takes only the directory.
[list_training, list_test, list_validation, list_training_aug] = create_image_lists(
    config_vars["normalized_images_dir"],
#         config_vars["training_fraction"],
#         config_vars["validation_fraction"]
)

# Write one path file per partition.
utils.dirtools.write_path_files(config_vars["path_files_training"], list_training)
utils.dirtools.write_path_files(config_vars["path_files_test"], list_test)
utils.dirtools.write_path_files(config_vars["path_files_validation"], list_validation)

# The train_aug list goes through write_path_files2, a modified writer that
# adds 'raw_masks/' ahead of each name.
utils.dirtools.write_path_files2(config_vars["path_files_training_aug"], list_training_aug)
    
# Load the partitions with load_augmented=False (presumably read back from the
# split files written above — confirm in utils.dirtools).
data_partitions = utils.dirtools.read_data_partitions(config_vars, load_augmented=False)

In [9]:
# Load the partitions again, this time with the default load_augmented
# setting; this replaces any earlier value of data_partitions.
data_partitions = utils.dirtools.read_data_partitions(config_vars)

#### add raw masks into training set

In [9]:
# Append the train_aug entries (raw masks) to the training partition.
# Entries already present in the list are skipped, so re-running this cell does
# not add the same images a second time.
with open(config_vars["path_files_training_aug"]) as f:
    training_aug_entries = f.read().splitlines()
already_listed = set(data_partitions["training"])
data_partitions["training"] += [
    entry for entry in training_aug_entries if entry not in already_listed
]
 

In [None]:
# Inspect the training list after the train_aug entries have been appended.
data_partitions["training"]

# Initialize data generators

In [10]:
# Build a TensorFlow session restricted to GPUs 0 and 1
# (visible_device_list = "0, 1"), with allow_growth enabled so GPU memory is
# not reserved all at once.
configuration = tf.ConfigProto()
configuration.gpu_options.allow_growth = True
configuration.gpu_options.visible_device_list = "0, 1"
session = tf.Session(config = configuration)

# Make Keras use this session.
keras.backend.set_session(session)

# Training generator over the "training" partition. crop_size is passed twice
# (presumably crop height and width — confirm in utils.data_provider).
train_gen = utils.data_provider.random_sample_generator(
    config_vars["normalized_images_dir"],
    config_vars["boundary_labels_dir"],
    data_partitions["training"],
    config_vars["batch_size"],
    config_vars["pixel_depth"],
    config_vars["crop_size"],
    config_vars["crop_size"],
    config_vars["rescale_labels"]
)

# Validation generator over the "validation" partition, using val_batch_size
# and otherwise the same directories and settings as the training generator.
val_gen = utils.data_provider.single_data_from_images(
     config_vars["normalized_images_dir"],
     config_vars["boundary_labels_dir"],
     data_partitions["validation"],
     config_vars["val_batch_size"],
     config_vars["pixel_depth"],
     config_vars["crop_size"],
     config_vars["crop_size"],
     config_vars["rescale_labels"]
)

# Build model

In [11]:
# Suppress all Python warnings from this point on.
# NOTE(review): this also hides deprecation warnings from Keras/TensorFlow.
import warnings
warnings.filterwarnings("ignore")

In [12]:
# Build the model from the project's get_model_3_class builder. crop_size is
# passed twice (presumably input height and width — confirm) and
# activation=None is requested explicitly.
model = utils.model_builder.get_model_3_class(config_vars["crop_size"], config_vars["crop_size"], activation=None)
# model.summary()

# Loss: the project's weighted cross-entropy; the plain Keras
# "categorical_crossentropy" alternative is kept below for reference.
#loss = "categorical_crossentropy"
loss = utils.objectives.weighted_crossentropy

# Metrics: overall categorical accuracy plus recall and precision per output
# channel (0 = background, 1 = interior, 2 = boundary, per the metric names).
metrics = [keras.metrics.categorical_accuracy, 
           utils.metrics.channel_recall(channel=0, name="background_recall"), 
           utils.metrics.channel_precision(channel=0, name="background_precision"),
           utils.metrics.channel_recall(channel=1, name="interior_recall"), 
           utils.metrics.channel_precision(channel=1, name="interior_precision"),
           utils.metrics.channel_recall(channel=2, name="boundary_recall"), 
           utils.metrics.channel_precision(channel=2, name="boundary_precision"),
          ]

# RMSprop with the learning rate taken from the configuration.
optimizer = keras.optimizers.RMSprop(lr=config_vars["learning_rate"])

model.compile(loss=loss, metrics=metrics, optimizer=optimizer)

In [13]:
# Performance logging: Keras CSVLogger writing to the file named by
# config_vars["csv_log_file"].
callback_csv = keras.callbacks.CSVLogger(filename=config_vars["csv_log_file"])

In [14]:
# TensorBoard logging. The log directory is derived from experiment_name so
# each experiment writes to its own folder; for experiment_name '00' this is
# the same './logs/00/0' path that used to be hard-coded.
callback_tboard = keras.callbacks.TensorBoard(
    log_dir='./logs/{}/0'.format(experiment_name),
    histogram_freq=0,
    batch_size=32,
    write_graph=True,
    write_grads=False,
    write_images=True,
    update_freq='epoch',
)

# Callbacks handed to the training call; shown as the cell output.
callbacks = [callback_csv, callback_tboard]
callbacks

[<keras.callbacks.CSVLogger at 0x7f4bca732dd8>,
 <keras.callbacks.TensorBoard at 0x7f4bca73f278>]

# Training 

In [15]:
# TRAIN: fit the model from the training generator and evaluate on the
# validation generator after each epoch.
statistics = model.fit_generator(
    generator=train_gen,
    steps_per_epoch=config_vars["steps_per_epoch"],
    epochs=config_vars["epochs"],
    validation_data=val_gen,
    # Number of validation batches per epoch. int(...) truncates, so images
    # beyond the last full batch are not counted, and a validation list
    # shorter than val_batch_size yields 0 steps.
    validation_steps=int(len(data_partitions["validation"])/config_vars["val_batch_size"]),
    callbacks=callbacks,
    verbose = 1
)

# Save only the trained weights to the path in config_vars["model_file"].
model.save_weights(config_vars["model_file"])

print('Done! :)')

Epoch 1/15
Training with 291 images.
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Done! :)


### Continue training by loading the existing model, and train on the MATLAB-labelled augmentation images