In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
%config InlineBackend.figure_format='retina'
import tensorflow as tf

# Load F3 inline data

In [None]:
# Load the inline amplitude volume and its label mask from the local data folder.
inline, inline_mask = (np.load(path) for path in ('../data/inline.npy',
                                                  '../data/inline_mask.npy'))

In [None]:
# expand the dimension: add a trailing channel axis, giving (n_inlines, dim1, dim2, 1)
# as later cells index it ([..., 0]). Assumes the .npy arrays are 3-D on load — TODO confirm.
# The ndim guard keeps this cell idempotent: re-running it does not stack a second
# singleton axis onto arrays that are already 4-D.
if inline.ndim == 3:
    inline = np.expand_dims(inline, axis=3)
if inline_mask.ndim == 3:
    inline_mask = np.expand_dims(inline_mask, axis=3)

# Crop the two spatial axes to 448 x 928 (both multiples of 32 — presumably so the UNet's
# down/up-sampling stages line up; TODO confirm) and keep the first 601 inlines.
# Array index k is labelled as inline number k + 100 in the plot titles below.
inline = inline[:601, :448, :928, :]
inline_mask = inline_mask[:601, :448, :928, :]

In [None]:
# Preview every 50th inline (array index k is shown as inline k + 100) on a 3x4 grid.
# vmin/vmax of 0..1 assumes the amplitudes were normalised to [0, 1] — TODO confirm.
fig, axes = plt.subplots(3, 4, figsize=[18, 6.5])
for panel, ax in enumerate(axes.flat):
    idx = panel * 50
    ax.imshow(inline[idx, :, :, 0], cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
    ax.set_title(f'Inline: {idx + 100}')

In [None]:
# Label masks for every 50th inline: 13 panels (array indices 0..600) on a 3x5 grid.
plt.figure(figsize=[18, 6])
for panel, idx in enumerate(range(0, 601, 50), start=1):
    ax = plt.subplot(3, 5, panel)
    ax.imshow(inline_mask[idx, :, :, 0], vmin=0, vmax=9)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f'Inline: {idx + 100}')

# Supervised training (from scratch, on two labeled inlines)

In [None]:
# Model: UNet
## Batchnorm: default
# A bare `tf.device(...)` call only creates a context manager and has no effect outside a
# `with` block. TensorFlow already prefers a visible GPU by default, so just report what it sees.
print('Visible GPUs:', tf.config.list_physical_devices('GPU'))

import sys
# Make the repository root importable for the `model` package; skip the append when the
# entry is already present so re-running the cell does not keep growing sys.path.
if '../' not in sys.path:
    sys.path.append('../')
# `model` is the only name this notebook uses from the module; import it explicitly.
from model.unet_BN_default import model

In [None]:
## save path: weights of the purely supervised model are written here
model_name = 'unet_BN_default'
checkpoint_filepath = f'../save_model/supervised_{model_name}.h5'

In [None]:
# Array index k corresponds to inline number k + 100 (see the plot titles).
known_label = np.array([100, 400])  # training: inline 200 and 500
valid_idx = 50                      # val: inline 150

## data (add augmented data if needed)
training_images = inline[known_label, :, :, :]
training_labels = inline_mask[known_label, :, :, :]
# Validate on the held-out inline only: stacking the training inlines into the validation
# set would make val_loss / val accuracy partly measure fit on the training data.
# Indexing with a one-element list keeps the leading batch axis.
validation_images = inline[[valid_idx], :, :, :]
validation_labels = inline_mask[[valid_idx], :, :, :]

In [None]:
## Training
n_epochs = 5 ## change into 500 for extensive training
learning_rate = 0.001

# NOTE(review): argument meaning (1, 10) is defined in model/unet_BN_default.py —
# presumably (input channels, number of classes); TODO confirm.
m_super = model(1, 10, activation = 'softmax', test_case = 'test_UNet_supervised')

# `learning_rate` is the documented keyword; `lr` is a deprecated alias.
m_super.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                loss=['sparse_categorical_crossentropy'],
                metrics=['mse','sparse_categorical_accuracy'])

# Keep the weights with the lowest training loss seen so far (weights only).
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='loss',
    mode='min',
    save_best_only=True)

history_super = m_super.fit(training_images, 
                            training_labels, 
                            batch_size = 2,
                            epochs=n_epochs, verbose = 1,callbacks= [model_checkpoint_callback],
                            validation_data = (validation_images,validation_labels))

## A quick inference

In [None]:
# Run the supervised model on every 50th inline, one image per predict step.
test_idx = np.arange(0, 601, 50)
images_test = inline[test_idx]
prediction = m_super.predict(images_test, batch_size=1)

In [None]:
# Predicted class map (argmax over the last axis) for each inferred inline.
# Computed once here instead of once per panel inside the loop.
predicted_classes = np.argmax(prediction, axis=3)

plt.figure(figsize=[18, 6])
for i in np.arange(13):
    ax = plt.subplot(3, 5, i + 1)
    ax.imshow(predicted_classes[i, :, :], vmin=0, vmax=9)
    ax.set_xticks([])
    ax.set_yticks([])
    # prediction[i] was computed from inline[i*50] (np.arange(0, 601, 50) in the inference cell).
    ax.set_title('Inline: ' + str(i * 50 + 100))

# Semi-supervised learning

## Stage 1: Pretraining (reconstruct the unlabeled inlines)

In [None]:
## data: every inline is both the input and the reconstruction target
training_images = training_labels = inline

In [None]:
## save path: best pretraining weights are stored under the shared model folder
save_dir = '../save_model'
checkpoint_filepath = save_dir + '/pretraining.h5'

In [None]:
n_epochs = 1 ## change into 500 for extensive training
learning_rate = 0.0005

# Sigmoid output trained with MSE against the input inline itself (training_labels = inline).
# NOTE(review): test_case reuses 'test_UNet_supervised'; its effect is defined in the model
# module — confirm this is intended for the pretraining network.
m_pre = model(1, 1, activation = 'sigmoid', test_case = 'test_UNet_supervised')

# `learning_rate` is the documented keyword; `lr` is a deprecated alias.
m_pre.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
              loss=['mse'],metrics=['mse'])

# Keep the weights with the lowest validation loss.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)

# Keras takes validation_split from the last 5% of samples (before shuffling),
# i.e. the highest-indexed inlines.
history_pre = m_pre.fit(training_images, 
                        training_labels, 
                        batch_size = 8,
                        epochs=n_epochs, verbose = 1,callbacks= [model_checkpoint_callback], 
                        validation_split = 0.05)

In [None]:
# Visual check of the pretrained network's reconstructions for every 50th inline.
preview_idx = np.arange(0, 601, 50)
# One predict() call for all panels (batch_size=1 keeps per-step memory the same as
# predicting inlines one at a time). A dedicated name avoids overwriting `prediction`.
reconstructions = m_pre.predict(inline[preview_idx], batch_size=1)

plt.figure(figsize=[15, 6])
for panel, idx in enumerate(preview_idx):
    ax = plt.subplot(3, 5, panel + 1)
    ax.imshow(reconstructions[panel, :, :, 0], cmap='gray', vmin=0, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title('Inline: ' + str(idx + 100))

## Stage 2: Supervised fine-tuning (pretrained encoder + two labeled inlines)

In [None]:
# Rebuild the Stage 1 network and load its checkpoint (the lowest-val_loss weights kept by
# ModelCheckpoint during pretraining); m_pre is only used as a weight source for Stage 2.
m_pre = model(1, 1, activation = 'sigmoid', test_case = 'test_UNet_supervised')

# When cells run in order, `learning_rate` still holds the Stage 1 value here.
# `learning_rate` is the documented keyword; `lr` is a deprecated alias.
m_pre.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
              loss=['mse'],metrics=['mse'])

m_pre.load_weights('../save_model/pretraining.h5')

In [None]:
## save path for the semi-supervised (fine-tuned) model — the same path the Stage 2
## training cell writes its checkpoint to.
model_name = 'unet_BN_default'
checkpoint_filepath = '../save_model/semi_model_'+str(model_name)+'.h5'

In [None]:
## data (add augmented data if needed)
training_images = inline[known_label, :, :, :]
training_labels = np.array(inline_mask[known_label, :, :, :], dtype='float64')
# Validate on the held-out inline (valid_idx) only: including the training inlines would make
# val_loss / val accuracy partly measure fit on the training data.
# Indexing with a one-element list keeps the leading batch axis.
validation_images = inline[[valid_idx], :, :, :]
validation_labels = np.array(inline_mask[[valid_idx], :, :, :], dtype='float64')

In [None]:
## training
model_name = 'unet_BN_default'
checkpoint_filepath = '../save_model/semi_model_'+str(model_name)+'.h5'

n_epochs = 5 ## change into 500 for extensive training
learning_rate = 0.001

m_semi = model(1, 10, activation = 'softmax', test_case = 'test_UNet_semi_supervised')

# `learning_rate` is the documented keyword; `lr` is a deprecated alias.
m_semi.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
               loss=['sparse_categorical_crossentropy'],
               metrics=['mse','sparse_categorical_accuracy'])

# Number of leading layers treated as the encoder and copied from the pretrained model.
encoder_layer_num = {}
encoder_layer_num['UNet'] = 30
## load the pretraining encoder weight
# Assumes the first 30 layers of m_pre and m_semi are built in the same order with matching
# weight shapes (set_weights raises ValueError on a mismatch) — TODO confirm against
# model/unet_BN_default.py.
for i in range(encoder_layer_num['UNet']):
    m_semi.layers[i].set_weights(m_pre.layers[i].get_weights())


# Keep the weights with the lowest training loss seen so far (weights only).
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='loss',
    mode='min',
    save_best_only=True)

history_semi = m_semi.fit(training_images, 
                          training_labels, 
                          batch_size = 2,
                          epochs=n_epochs, verbose = 1,callbacks= [model_checkpoint_callback],
                          validation_data = (validation_images,validation_labels))

## A quick inference

In [None]:
# Run the fine-tuned model on every 50th inline, one image per predict step.
test_idx = np.arange(0, 601, 50)
images_test = inline[test_idx]
prediction = m_semi.predict(images_test, batch_size=1)

In [None]:
# Predicted class map (argmax over the last axis) for each inferred inline.
# Computed once here instead of once per panel inside the loop.
predicted_classes = np.argmax(prediction, axis=3)

plt.figure(figsize=[18, 6])
for i in np.arange(13):
    ax = plt.subplot(3, 5, i + 1)
    ax.imshow(predicted_classes[i, :, :], vmin=0, vmax=9)
    ax.set_xticks([])
    ax.set_yticks([])
    # prediction[i] was computed from inline[i*50] (np.arange(0, 601, 50) in the inference cell).
    ax.set_title('Inline: ' + str(i * 50 + 100))