In [None]:
import os
import glob
import functools
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import datetime

import matplotlib.pyplot as plt

import data_io
import u_net

In [None]:
import tensorflow as tf
# from tensorflow.keras import models, layers, losses

# Later cells iterate over datasets directly and call `.numpy()` on tensors,
# which needs eager execution. It is the default in TF 2.x; the line below is
# only needed on TF 1.x.
# tf.enable_eager_execution()
tf.__version__  # echoed as the cell output to record the TensorFlow version in use

# Set up

In [None]:
# Root folder of the plate being segmented; it holds one sub-folder of images
# and one of masks.
common_root = r'/awlab/users/chsu/WorkSpace/tensorflow/segmentation/data/plate_2017017086_ki67'
img_dir, mask_dir = (os.path.join(common_root, sub) for sub in ('images', 'masks'))

# Number of samples per batch, shared by the datasets and the step counts below.
batch_size = 5

# Currently unused settings, kept for reference:
# max_intensity = 4095.
# epochs = 5

In [None]:
# Nucleus images ('*-2.png') and nucleus masks ('*_nucleus.png'), each in
# sorted order. NOTE(review): image i is paired with mask i purely by position,
# which assumes both naming schemes sort identically - confirm.
x_train_fnames = glob.glob(os.path.join(img_dir, '*-2.png'))
x_train_fnames.sort()
y_train_fnames = glob.glob(os.path.join(mask_dir, '*_nucleus.png'))
y_train_fnames.sort()

# Hold out 20% of the pairs for validation; the fixed seed keeps the split
# reproducible between runs.
split = train_test_split(x_train_fnames, y_train_fnames, test_size=0.2, random_state=43)
x_train_fnames, x_val_fnames, y_train_fnames, y_val_fnames = split

num_train_data, num_val_data = len(x_train_fnames), len(x_val_fnames)

print("Number of training samples: {}".format(num_train_data))
print("Number of validation samples: {}".format(num_val_data))

In [None]:
# Spot-check: print the image and mask file names at one random position so
# the pairing can be verified by eye.
idx = np.random.choice(num_train_data)
for fname in (x_train_fnames[idx], y_train_fnames[idx]):
    print(os.path.basename(fname))

# Test the input pipeline

### Get image and mask from path names

In [None]:
# Load one random training pair through the data_io helper and show the first
# channel of the image (grayscale) next to the first channel of its mask.
idx = np.random.choice(num_train_data)
img, mask = data_io._get_image_from_path(x_train_fnames[idx], y_train_fnames[idx])

panels = [(img[:, :, 0], {'cmap': 'gray'}), (mask[:, :, 0], {})]
plt.figure(figsize=(12,16))
for col, (plane, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, 2, col)
    plt.imshow(plane, **imshow_kwargs)

In [None]:
# Distinct pixel values present in the mask loaded above.
np.unique(mask)

In [None]:
# Quick diagnostics for the loaded pair: image shape and dtype, mask shape,
# and the largest image value.
for value in (img.shape, img.dtype, mask.shape, img.numpy().max()):
    print(value)

### Test the input pipeline

In [None]:
# Dataset built from the training file lists only; no preproc_fn, shuffle or
# batch_size arguments are passed, so data_io.get_dataset's own defaults apply.
tmp_ds = data_io.get_dataset(x_train_fnames, y_train_fnames)

In [None]:
# Draw three shuffled batches and show the first element of each:
# image in the left column, mask in the right column.
plt.figure(figsize=(12,16))
preview_batches = tmp_ds.shuffle(num_train_data).take(3)
for row, (img, mask) in enumerate(preview_batches):
    left_cell = 2 * row + 1
    plt.subplot(3, 2, left_cell)
    plt.imshow(img[0, :, :, 0])
    plt.subplot(3, 2, left_cell + 1)
    plt.imshow(mask[0, :, :, 0])
plt.show()

# Configure training and validation datasets

In [None]:
# Training preprocessing: keyword arguments bound onto data_io._augment.
# Differs from the validation settings only by enabling `to_flip`.
train_cfg = dict(resize=None, scale=1/255., crop_size=[512, 512], to_flip=True)
tr_preproc_fn = functools.partial(data_io._augment, **train_cfg)

# Validation preprocessing: same scale and crop size, `to_flip` left at the
# helper's default.
val_cfg = dict(resize=None, scale=1/255., crop_size=[512, 512])
val_preproc_fn = functools.partial(data_io._augment, **val_cfg)

# Only the training dataset is shuffled; both use the shared batch size.
train_ds = data_io.get_dataset(
    x_train_fnames, y_train_fnames,
    preproc_fn=tr_preproc_fn, shuffle=True, batch_size=batch_size)
val_ds = data_io.get_dataset(
    x_val_fnames, y_val_fnames,
    preproc_fn=val_preproc_fn, shuffle=False, batch_size=batch_size)

# Build the model

In [None]:
# Filter counts handed to the u_net.Unet constructor.
num_filters_list = [32, 64, 128, 256, 512]
model = u_net.Unet(num_filters_list)

# Optimise the combined BCE + dice loss; dice loss alone is tracked as a metric
# (it is also what the checkpoint callback monitors on the validation set).
model.compile(
    optimizer='adam',
    loss=u_net.bce_dice_loss,
    metrics=[u_net.dice_loss],
)

# Optional smoke test of the untrained model on one training batch
# (disabled; uncomment to run):
# plt.figure(figsize=(12,16))
# for i, (img, mask) in enumerate(train_ds.take(1)):
#     y_pred = model(img)
    
#     plt.subplot(1,3,1)
#     plt.imshow(img[0,:,:,0])
#     plt.subplot(1,3,2)
#     plt.imshow(mask[0,:,:,0])
#     plt.subplot(1,3,3)
#     plt.imshow(y_pred[0,:,:,0])
# plt.show()

## Train the model

In [None]:
root_path = r'/awlab/users/chsu/WorkSpace/tensorflow/segmentation'

# Checkpoint callback: write weights only, and only when the validation dice
# loss improves. The epoch number is substituted into the file name by Keras.
save_path = os.path.join(root_path, 'models', 'weights-{epoch:04d}.ckpt')
cp = tf.keras.callbacks.ModelCheckpoint(filepath=save_path, monitor='val_dice_loss', 
                                        save_best_only=True, save_weights_only=True, verbose=1)

# TensorBoard callback: one log directory per run, named by start time
# (e.g. 2019-01-31_14_05_09). The timestamp is formatted on its own rather
# than by string-editing the joined path: splitting the full path on '.' would
# truncate it at the first dot in any parent directory name.
run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
log_dir = os.path.join(root_path, 'logs', run_stamp)
os.makedirs(log_dir, exist_ok=True)  # exist_ok makes a prior isdir() check unnecessary
tb = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

In [None]:
# Steps per pass over each split; rounding up counts a final partial batch
# as a step.
steps_per_epoch = int(np.ceil(num_train_data / batch_size))
validation_steps = int(np.ceil(num_val_data / batch_size))

# Train for 50 epochs, validating after each one; the callbacks handle
# checkpointing and TensorBoard logging.
history = model.fit(
    train_ds,
    epochs=50,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=[cp, tb],
)

### Plot training process

In [None]:
# One row per completed epoch, one column per logged quantity, plus a
# 1-based epoch counter for use as the plotting x-axis.
tr_process = pd.DataFrame.from_dict(history.history)
epoch_numbers = range(1, len(tr_process) + 1)
tr_process['epoch'] = np.array(epoch_numbers)

In [None]:
# One figure per quantity, each overlaying the training and validation curves.
for curves in (['loss', 'val_loss'], ['dice_loss', 'val_dice_loss']):
    tr_process.plot(x='epoch', y=curves)

### Or load trained weights

In [None]:
# Restore the most recent checkpoint written to the checkpoint directory.
ckpt_dir = os.path.dirname(save_path)
latest = tf.train.latest_checkpoint(ckpt_dir)

# latest_checkpoint returns None when the directory holds no checkpoint;
# fail here with a clear message instead of passing None to load_weights.
if latest is None:
    raise FileNotFoundError("No checkpoint found in {}".format(ckpt_dir))

model.load_weights(latest)

# Visualize performance

In [None]:
import cv2

In [None]:
# Fetch one validation batch here so the cell works on a fresh top-to-bottom
# run. Previously it indexed whatever `img` / `mask` / `idx` were left in the
# kernel, which at this point are the last pipeline-preview batch and a random
# training-file index, not a valid position inside a batch.
val_imgs, val_masks = next(iter(val_ds))
sample_idx = 0  # the first element exists in every batch, including a short one

# Convert the first channel of the chosen sample to 8-bit for OpenCV.
# NOTE(review): multiplying by 255 assumes values are in [0, 1] - confirm
# against data_io._augment with scale=1/255.
M = np.uint8(val_masks[sample_idx,:,:,0].numpy()*255.)
I = np.uint8(val_imgs[sample_idx,:,:,0].numpy()*255.)

# Expand the grayscale image to 3 channels so coloured contours can be drawn on it.
I = cv2.cvtColor(I,cv2.COLOR_GRAY2RGB)

In [None]:
# Distinct values in the 8-bit mask, as a sanity check before extracting contours.
np.unique(M)

In [None]:
# cv2.findContours returns (image, contours, hierarchy) on OpenCV 3.x but only
# (contours, hierarchy) on 2.x and 4.x. Taking the last two items of the
# result works on every version, whereas a fixed 3-value unpack raises
# ValueError outside 3.x.
true_conturs, hierarchy = cv2.findContours(M.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2:]

# Outline every external contour (-1 = all) in green, 1 px thick, on a copy of
# the RGB image so `I` itself stays unmodified.
img2 = cv2.drawContours(I.copy(), true_conturs, -1, (0,255,0), 1)

plt.figure(figsize=(8, 10))
plt.imshow(img2)

In [None]:
# Compare input, ground-truth mask and model output for one sample taken from
# a randomly chosen validation batch.
plt.figure(figsize=(12,16))
for img, mask in val_ds.shuffle(num_val_data).take(1):
    # Prefer the 4th sample, but clamp to the batch size: a final partial
    # batch can hold fewer than 4 elements and a fixed index would then fail.
    idx = min(3, int(img.shape[0]) - 1)
    print(img.shape)
    y_pred = model(img)

    # Same slice and colormap for all three panels; only the source tensor
    # and the title differ.
    panels = (('image', img), ('ground truth', mask), ('prediction', y_pred))
    for col, (title, batch) in enumerate(panels, start=1):
        plt.subplot(1, 3, col)
        plt.imshow(batch[idx,:,:,0], cmap='gray')
        plt.title(title)
plt.show()