# Train and evaluate model
#### Calling the functions from a Jupyter (IPython) notebook keeps this workflow compatible with Google Colab, which provides free GPUs and TPUs.
## Import packages and python files

In [None]:
import tensorflow as tf
from model import unet
from pipeline import TrainGenerator, ValidationDataPreparation,TestDataPreparation
from tensorflow.keras.callbacks import ModelCheckpoint
import skimage.io as io
import matplotlib.pyplot as plt
import numpy as np

## Load the model

In [None]:
# Build the network and, when a checkpoint is available, resume from it.
model = unet()
# Target size handed to the data pipelines in the cells below.
input_size = 704
try:
    model.load_weights("weights.hdf5")
    print("Load weights successfully.")
except Exception as err:
    # Loading stays best-effort (a fresh run has no checkpoint yet), but
    # KeyboardInterrupt/SystemExit are no longer swallowed and the cause of
    # the failure is shown instead of being hidden.
    print("Cannot load weights.")
    print(f"Reason: {err!r}")
model.summary()

## Prepare the pipeline for training and the data for validation.

In [None]:
# Training batches (size 2, deformation switched off) and validation batches
# (size 1); both are requested with target_size=input_size.
data_generator = TrainGenerator(
    target_size=input_size,
    batch_size=2,
    deformation=False,
)
val_data = ValidationDataPreparation(
    target_size=input_size,
    batch_size=1,
)

## Train the model
### The weights are saved only when the loss on the validation set decreases.

In [None]:
# Checkpoint callback: monitors val_loss and, with save_best_only=True,
# writes weights.hdf5 only for the best value seen so far.
model_checkpoint = ModelCheckpoint(
    'weights.hdf5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)
# 30 epochs of 200 training steps each, with 25 validation steps per epoch.
model.fit(
    data_generator,
    validation_data=val_data,
    callbacks=[model_checkpoint],
    epochs=30,
    steps_per_epoch=200,
    validation_steps=25,
)

In [None]:
# Score the current model on 25 batches of the validation data.
model.evaluate(
    val_data,
    steps=25,
)

## Test the model and save predictions to files.
### Calculate the accuracy and AUROC

In [None]:
# Prepare the test data in batches of 5, then evaluate the model on a
# single batch (steps=1).
test_data = TestDataPreparation(
    target_size=input_size,
    batch_size=5,
)
model.evaluate(
    test_data,
    steps=1,
)

### Save the predictions to files

In [None]:
import os

# Predict on a single batch drawn from test_data.
results = model.predict(test_data, steps=1, verbose=1)

# Create the output folder up front so saving does not fail on a fresh
# checkout where "predict/" is missing.
os.makedirs("predict", exist_ok=True)

# Save every returned prediction as predict/<index>.png. Iterating over the
# results avoids repeating the batch size as a hard-coded loop bound.
for i, prediction in enumerate(results):
    io.imsave(f"predict/{i}.png", prediction)

### Additionally, calculate the SSIM for the test data.

In [None]:
# np.float was only an alias for the builtin float and has been removed from
# NumPy (1.24+), so an explicit TensorFlow dtype is passed instead.
# NOTE(review): tf.float32 is used on the assumption that TensorFlow resolved
# the builtin float to float32 - confirm if float64 was intended.
im1 = tf.image.convert_image_dtype(results, tf.float32)
# NOTE(review): next(test_data) draws a fresh batch from the iterator; this
# presumes its labels (element [1]) line up with the batch that produced
# `results` - verify against the pipeline.
im2 = tf.image.convert_image_dtype(next(test_data)[1], tf.float32)
# Mean SSIM over the batch, with pixel values taken to span [0, 1].
print(np.mean(tf.image.ssim(im1, im2, max_val=1.0).numpy()))