# Unet_Train

Training a UNET with basic parameters:

* input_h, input_w = 450, 450
* classes = 2 (cell and background)
* targets = train_y, eval_y


In [5]:
# ---- Patch geometry -------------------------------------------------------
# The network's output patch is expected to be smaller than its input patch.
input_patch_size = [450, 450]  # input patch (height, width)
input_channels = [3]           # RGB
output_channels = [2]          # cell and background

# ---- Training hyper-parameters ---------------------------------------------
NBATCH = 5     # example patches per batch
EPCS = 100     # number of epochs
REG = True     # enable regularization
USEW = True    # use the weight maps during training
W = 10         # importance of the weights
# NOTE(review): the name says "W100" while W is 10 -- confirm which is intended.
NAME_NET = 'MIMO_W100'  # base name used when saving / loading the model
JT = True      # "just train": no evaluation during training

# ---- Locations on disk ----------------------------------------------------
dataset_path = '../CD_Dataset'
trained_models_path = './trained_models'

In [4]:
import os
import sys
import time

import keras
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Make the project-level packages (datasets, models, utility) importable.
sys.path.append('../')

from datasets import CD_Dataset
from datasets import combine_y_w
from models import Unet
from models import MimoNet
from utility import show_batches, from_categorical, train, crop_receptive, predict_full_image
from utility import dice, precision, Pc
# utility's `precision` gets a distinct name because the `precision` imported
# from `models` on the next line shadows it (that one is the Keras metric).
from utility import precision as precision_score
from models import dice_coef, precision

# Seed NumPy's global RNG from the clock and report the seed so that a run
# can be repeated by hard-coding the printed value.
seed = int((time.time() * 1e6) % 1e6)
np.random.seed(seed)
print("random seed: {}".format(seed))
# NOTE(review): this seeds NumPy only. A keras initializer object that is
# constructed but never passed to a layer has no effect, so seeding the
# Keras/backend weight initialisation (if wanted) must be done separately.

# Models are saved under trained_models_path; make sure it exists, and use
# basename() so re-running this cell does not nest the path twice.
os.makedirs(trained_models_path, exist_ok=True)
NAME_NET = os.path.join(trained_models_path, os.path.basename(NAME_NET))

In [None]:
# Load the CD dataset (downloading it if necessary). fit=True presumably
# computes the statistics later returned by mean_features()/std_features()
# -- TODO confirm in datasets.CD_Dataset.
dataset = CD_Dataset(
    path=dataset_path,
    train_y_path="train_y",
    eval_y_path="eval_y",
    fit=True,
    download=True,
    num_classes=output_channels[0],
)

In [None]:
# Build the network. Its input shape is the patch size followed by the channel
# count, e.g. [450, 450] + [3] -> [450, 450, 3].
model_input_size = input_patch_size + input_channels
mimo = MimoNet(
    model_input_size,
    classes=output_channels[0],
    metrics=[dice_coef, precision],
    regularized=REG,
)
# Shape of the first model output, as a plain list.
model_output_size = list(mimo.outputs_shape[0])
summary = "input size: {}\noutput_size: {}".format(model_input_size, model_output_size)
print(summary)

In [None]:
# Visualize a few training patches. Inputs, labels and weight maps are all
# cropped to the network's output size so the three panels line up.
means = dataset.mean_features()
stds = dataset.std_features()
xs, ys, ws = dataset.sample_X_Y_W_patch_batch(input_patch_size, n_batch=5, fit=False, rotate=False)
out_hw = model_output_size[:2]
xs_c = crop_receptive(xs, out_hw)
ys_c = crop_receptive(ys, out_hw)
# ws gets a trailing axis for crop_receptive; drop it again so the cropped map
# has the same rank as the uncropped ws.
# NOTE(review): assumes crop_receptive preserves the trailing axis -- confirm.
ws_c = crop_receptive(np.expand_dims(ws, 3), out_hw)[..., 0]
ys_imgs_c = from_categorical(ys_c)
# Show the *cropped* weight maps (previously the uncropped ws was shown).
show_batches([xs_c, ys_imgs_c, ws_c], ["xs", "ys", "ws"])

In [None]:
# Train the network; `histo` receives whatever history train() returns.
histo = train(
    mimo,
    dataset,
    n_batch=NBATCH,
    epochs=EPCS,
    just_train=JT,
    use_weights=USEW,
    W=W,
    name=NAME_NET,
)

In [None]:
# Persist the trained model, then the two history arrays next to it.
# ndarray.dump writes each array as a pickle file.
mimo.save_model(NAME_NET)
eval_histo, train_histo = np.array(histo[0]), np.array(histo[1])
train_histo.dump(NAME_NET + '_train_histo.pkl')
eval_histo.dump(NAME_NET + '_eval_histo.pkl')

In [None]:
# Rebuild the same architecture, then restore the saved weights and histories.
mimo = MimoNet(model_input_size, classes=output_channels[0], metrics=[dice_coef, precision], regularized=REG)
mimo.load_model(NAME_NET)
# The histories were written with ndarray.dump (a pickle). NumPy >= 1.16.3
# refuses to read pickled files unless allow_pickle=True. Only load files you
# wrote yourself: unpickling untrusted data can execute arbitrary code.
train_histo = np.load(NAME_NET + '_train_histo.pkl', allow_pickle=True)
eval_histo = np.load(NAME_NET + '_eval_histo.pkl', allow_pickle=True)

In [None]:
# Show results for one full training image (index 7), restricted to the region
# that predict_full_image reports as predicted.

# The weight map is not needed here; binding it to `_` (instead of `W`) keeps
# the W weight-importance hyper-parameter from being overwritten.
X, Y, _ = dataset.get_X_Y_W(index=7, train=True)
Y_hat, crops = predict_full_image(mimo, X)
cropsh, cropsw = crops
# (rows, cols) window of the predicted region, reused for every image/metric.
region = (slice(cropsh[0], cropsh[1]), slice(cropsw[0], cropsw[1]))

Yimg = from_categorical(np.expand_dims(Y, 0))[0]
Yimg_c = Yimg[region]
Y_hat_c = Y_hat[region]
# Map the input back through the dataset std/mean so it displays as an image.
X_c = X[region] * stds + means

# Ground truth, prediction, then input -- one figure each.
for img in (Yimg_c, Y_hat_c, X_c):
    fig = plt.gcf()
    fig.set_size_inches(10.5, 10.5, forward=True)
    plt.imshow(img)
    plt.show()

dice_s = dice(Yimg_c, Y_hat_c)
# precision_score is utility.precision; the bare name `precision` refers to
# the Keras metric imported from models and is not meant for these arrays.
precision_s = precision_score(Yimg_c, Y_hat_c)
Pc_s = Pc(Yimg_c, Y_hat_c)

table = [["dice", dice_s],
         ["Precision", precision_s],
         ["Pc", Pc_s]]
print(table)