# Description
This notebook trains and evaluates a U-Net segmentation model with an **EfficientNet** (B7) encoder for wire detection, training on a Colab TPU.

# MOUNT GOOGLE Drive


In [0]:
# Mount Google Drive at /content/gdrive so the project folder under
# "My Drive" (used as the working directory below) is reachable.
from google.colab import drive
drive.mount('/content/gdrive')

# Change your working directory

In [0]:
 cd /content/gdrive/My\ Drive/WIRE_DETECTION/TPU_COLAB/

# MODEL SPEC

In [0]:
# Encoder backbone passed to sm.Unet, plus a run identifier that names the
# weight file (model_weights/<iden>.h5) and the prediction folder.
model_name, iden = 'efficientnetb7', 'model1'

# TPU CHECK
The model trains approximately 17 times faster on a TPU. The cell below raises an error if the runtime is not connected to a TPU.

In [0]:
%tensorflow_version 2.x
import tensorflow as tf
print("Tensorflow version " + tf.__version__)

try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
  raise BaseException('ERROR: Not connected to a TPU runtime;')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

# FIXED PARAMETERS


In [0]:
from glob import glob 
import os 

BUCKET='tfalldata' # @param
TFIDEN='WireDTF'  # @param
IMG_DIM=256 # @param
NB_CHANNEL=3 # @param
BATCH_SIZE=128 # @param
BUFFER_SIZE=2048 # @param
TRAIN_DATA=1024*21 # @param
EVAL_DATA=1024*2 # @param
EPOCHS=250 # @param
TOTAL_DATA=TRAIN_DATA+EVAL_DATA
# One epoch = one pass over the *training* split only; the eval split is
# consumed separately through EVAL_STEPS. (Deriving this from TOTAL_DATA made
# every "epoch" about 10% longer than the training set.)
STEPS_PER_EPOCH = TRAIN_DATA//BATCH_SIZE
EVAL_STEPS      = EVAL_DATA//BATCH_SIZE
# TFRecords are read from gs://<BUCKET>/<TFIDEN>/<mode>/*.tfrecord
GCS_PATH='gs://{}/{}'.format(BUCKET,TFIDEN)
print(GCS_PATH)

# Weights for this run live at ./model_weights/<iden>.h5. Create the folder up
# front so the training checkpoint can write there on a fresh Drive.
WEIGHT_PATH=os.path.join(os.getcwd(),'model_weights','{}.h5'.format(iden))
os.makedirs(os.path.dirname(WEIGHT_PATH), exist_ok=True)
# If a weight file already exists, the model-creation cell resumes from it.
if os.path.exists(WEIGHT_PATH):
  print('FOUND PRETRAINED WEIGHTS')
  LOAD_WEIGHTS=True 
else:
  print('NO PRETRAINED WEIGHTS FOUND')
  LOAD_WEIGHTS=False

# Dataset pipeline with the tf.data API

In [0]:
def data_input_fn(mode,BUFFER_SIZE,BATCH_SIZE,img_dim): 
    """Build an infinite, shuffled, batched tf.data pipeline from TFRecords.

    Reads every ``*.tfrecord`` file under ``GCS_PATH/<mode>/``. Each record
    holds a PNG-encoded RGB ``image`` and a PNG-encoded single-channel
    ``target`` mask.

    Args:
        mode: sub-folder under GCS_PATH ("Train" or "Eval" in this notebook).
        BUFFER_SIZE: shuffle buffer size, in examples.
        BATCH_SIZE: examples per batch. The last partial batch is dropped so
            every batch has a static shape.
        img_dim: expected height and width of the stored images and masks.

    Returns:
        A repeating tf.data.Dataset of (image, target) float32 batches with
        shapes (BATCH_SIZE, img_dim, img_dim, 3) and
        (BATCH_SIZE, img_dim, img_dim, 1), scaled to [0, 1].

    Raises:
        FileNotFoundError: if no TFRecord file matches the pattern. Without this
            check the pipeline would silently be empty.
    """
    autotune = tf.data.experimental.AUTOTUNE

    def _parser(example):
        # Decode one serialized example into (image, target) float tensors.
        feature = {
            'image': tf.io.FixedLenFeature([], tf.string),
            'target': tf.io.FixedLenFeature([], tf.string),
        }
        parsed_example = tf.io.parse_single_example(example, feature)

        image = tf.image.decode_png(parsed_example['image'], channels=3)
        image = tf.cast(image, tf.float32) / 255.0
        # reshape (not resize): gives a static shape and fails if the stored
        # image is not img_dim x img_dim.
        image = tf.reshape(image, (img_dim, img_dim, 3))

        target = tf.image.decode_png(parsed_example['target'], channels=1)
        target = tf.cast(target, tf.float32) / 255.0
        target = tf.reshape(target, (img_dim, img_dim, 1))

        return image, target

    gcs_pattern = os.path.join(GCS_PATH, mode, '*.tfrecord')
    file_paths = tf.io.gfile.glob(gcs_pattern)
    if not file_paths:
        raise FileNotFoundError('No TFRecord files match {}'.format(gcs_pattern))

    # Read files and decode examples in parallel; order is irrelevant because
    # the examples are shuffled right after.
    dataset = tf.data.TFRecordDataset(file_paths, num_parallel_reads=autotune)
    dataset = dataset.map(_parser, num_parallel_calls=autotune)
    dataset = dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(autotune)
    return dataset
# Build the evaluation and training pipelines (in that order).
eval_ds, train_ds = (data_input_fn(split, BUFFER_SIZE, BATCH_SIZE, IMG_DIM)
                     for split in ("Eval", "Train"))
# Sanity check: shapes of a single evaluation batch.
for x, y in eval_ds.take(1):
  print(x.shape)
  print(y.shape)

# install segmentation-models

In [0]:
!pip3 install segmentation-models

# Framework setup (use tf.keras as the segmentation_models backend)

In [0]:
import segmentation_models as sm
# Select tf.keras as the framework segmentation_models builds on; the models
# are then compiled and trained with tf.keras under the TPU strategy below.
sm.set_framework('tf.keras')

# model creation

In [0]:
def ssim(y_true, y_pred):
  """Mean SSIM over the batch between targets and predictions.

  Images are assumed to be scaled to [0, 1] (max_val=1.0). Passed to
  model.compile as a metric, so it is logged as "ssim" / "val_ssim".
  """
  per_image_ssim = tf.image.ssim(y_true, y_pred, max_val=1.0)
  return tf.reduce_mean(per_image_ssim)

# Build and compile the U-Net inside the TPU strategy scope so the model's
# variables are created under tpu_strategy.
with tpu_strategy.scope():
  # encoder_weights=None: the encoder is not initialized from pretrained weights.
  model = sm.Unet(model_name,input_shape=(IMG_DIM,IMG_DIM,NB_CHANNEL), encoder_weights=None)
  # MSE loss against the [0, 1]-scaled mask; the custom `ssim` function is
  # tracked as a metric (and monitored by the checkpoint callback as val_ssim).
  model.compile(optimizer="Adam",
                loss=tf.keras.losses.mean_squared_error,
                metrics=[ssim])
  # Resume from an existing weight file when one was found at WEIGHT_PATH.
  if LOAD_WEIGHTS:
    model.load_weights(WEIGHT_PATH)
model.summary()






# Training

In [0]:
import numpy as np 
import matplotlib.pyplot as plt
%matplotlib inline
# Reduce the learning rate 10x when val_loss (the default monitor) plateaus
# for 10 epochs, then wait 10 epochs before reducing again; floor at 1e-6.
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(factor=0.1,
                               cooldown=10,
                               patience=10,
                               verbose=1,
                               min_lr=1e-6)

# Keep the weights with the best validation SSIM, checked after every epoch.
# The deprecated `period=10` only evaluated every 10th epoch, so a better
# model in between could be missed. Weights-only matches WEIGHT_PATH's use
# with load_weights/save_weights elsewhere in the notebook.
mode_autosave = tf.keras.callbacks.ModelCheckpoint(WEIGHT_PATH,
                                                  monitor='val_ssim',
                                                  mode='max',
                                                  save_best_only=True,
                                                  save_weights_only=True,
                                                  verbose=1)

# Stop training when val_loss (the default monitor) has not improved for 15 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=15,
                               verbose=1,
                               mode='auto')

callbacks = [mode_autosave, lr_reducer, early_stopping]

history = model.fit(train_ds,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=EPOCHS,
                    verbose=1,
                    validation_data=eval_ds,
                    validation_steps=EVAL_STEPS,
                    callbacks=callbacks)

# Save the final-epoch weights to a separate "<iden>_last.h5" file. WEIGHT_PATH
# holds the best-val_ssim checkpoint from ModelCheckpoint and is what the
# inference section loads, so overwriting it here would discard the best model.
LAST_WEIGHT_PATH = '{}_last{}'.format(*os.path.splitext(WEIGHT_PATH))
model.save_weights(LAST_WEIGHT_PATH)
if not os.path.exists(WEIGHT_PATH):
  # Fallback: no checkpoint was written (e.g. val_ssim was never logged).
  model.save_weights(WEIGHT_PATH)
def plot_history(history):
  """
  Plot training vs. validation curves from a Keras ``History`` object.

  Left panel: loss; right panel: the ``ssim`` metric. Reads the keys
  "loss", "val_loss", "ssim" and "val_ssim" from ``history.history``.
  """
  fig, (ax_loss, ax_ssim) = plt.subplots(1, 2, figsize=(15, 5))
  for ax, key in ((ax_loss, "loss"), (ax_ssim, "ssim")):
    ax.plot(history.epoch, history.history[key], label="Train " + key)
    ax.plot(history.epoch, history.history["val_" + key], label="Validation " + key)
    ax.set(title=key, xlabel="epoch", ylabel=key)
    ax.legend()
  plt.show()
  plt.close(fig)
# show history
plot_history(history)

# Model Predictions and Scores

In [0]:
try:
    from skimage.metrics import structural_similarity as compare_ssim
except ImportError:  # scikit-image < 0.16 only provides the old name
    from skimage.measure import compare_ssim
try:
    from sklearn.metrics import jaccard_similarity_score
except ImportError:
    # Removed in scikit-learn 0.23. For binary (non-multilabel) inputs it was
    # documented to be equivalent to accuracy_score, so alias that to keep the
    # old behaviour. Note: that is pixel accuracy, not IoU.
    from sklearn.metrics import accuracy_score as jaccard_similarity_score
from glob import glob
from PIL import Image as imgop
import cv2
import imageio 

# Test-set folders: input images and their ground-truth masks.
img_dir, tgt_dir = (os.path.join(os.getcwd(), 'test', sub) for sub in ('images', 'masks'))

def create_dir(base_dir,ext_name):
    '''
        Ensure the directory base_dir/ext_name exists and return its path.

        Missing parent directories (including base_dir itself) are created too.
        Calling it again on an existing directory is a no-op.
    '''
    new_dir = os.path.join(base_dir, ext_name)
    # makedirs(exist_ok=True) also creates a missing base_dir and does not race
    # with a concurrent creator, unlike an exists()/mkdir() pair.
    os.makedirs(new_dir, exist_ok=True)
    return new_dir

# Predictions are written to test/preds/<iden>/.
pred_path = create_dir(os.path.join(os.getcwd(), 'test'), ext_name='preds')
pred_dir = create_dir(pred_path, ext_name=iden)

# preprocess data
def get_img(_path):
    """Load a test image as a model-ready batch of one.

    The image is converted to RGB, resized to IMG_DIM x IMG_DIM and scaled to
    [0, 1].

    Returns:
        float32 array of shape (1, IMG_DIM, IMG_DIM, 3).
    """
    # The context manager closes the file handle PIL would otherwise keep open.
    with imgop.open(_path) as src:
        # Force 3 channels: RGBA, grayscale or palette files would otherwise
        # yield an array that does not match the model's 3-channel input.
        rgb = src.convert('RGB').resize((IMG_DIM, IMG_DIM))
    data = np.asarray(rgb, dtype='float32') / 255.0
    return np.expand_dims(data, axis=0)

def get_gt(_path):
    """Load the ground-truth mask for test image `_path` and binarize it.

    The mask path is derived by replacing "images" with "masks" in `_path`.
    The mask is read as grayscale, resized to IMG_DIM x IMG_DIM,
    Gaussian-blurred (5x5) and binarized with Otsu's threshold.

    Returns:
        uint8 array of shape (IMG_DIM, IMG_DIM) containing only 0 and 255.

    Raises:
        FileNotFoundError: if no readable mask exists at the derived path.
            cv2.imread signals a missing file by returning None.
    """
    # NOTE(review): str.replace rewrites *every* "images" in the path,
    # including parent folders; kept as-is in case mask file names rely on it.
    _mpath = str(_path).replace("images", "masks")
    gt = cv2.imread(_mpath, cv2.IMREAD_GRAYSCALE)
    if gt is None:
        raise FileNotFoundError('Could not read ground-truth mask: {}'.format(_mpath))
    gt = cv2.resize(gt, (IMG_DIM, IMG_DIM), interpolation=cv2.INTER_AREA)
    # Otsu's thresholding after Gaussian filtering
    blur = cv2.GaussianBlur(gt, (5, 5), 0)
    _, gt = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return gt

def get_pred(model,img,_path):
    """Predict the mask for `img`, save it, and binarize it like the ground truth.

    The raw prediction (scaled to 0-255 uint8) is written to the path obtained
    by replacing "images" with "preds/<iden>" in `_path`. It is then resized to
    IMG_DIM x IMG_DIM, Gaussian-blurred (5x5) and binarized with Otsu's
    threshold.

    Returns:
        uint8 array containing only 0 and 255.
    """
    _ppath = str(_path).replace('images', 'preds/{}'.format(iden))
    pred = model.predict([img])
    # Clip before the uint8 cast so values outside [0, 1] cannot wrap around.
    pred = (np.clip(np.squeeze(pred), 0.0, 1.0) * 255.0).astype('uint8')
    imageio.imsave(_ppath, pred)
    # Threshold the in-memory array rather than re-reading the saved file: a
    # lossy extension (e.g. .jpg) would feed compression artifacts into Otsu.
    pred = cv2.resize(pred, (IMG_DIM, IMG_DIM), interpolation=cv2.INTER_AREA)
    # Otsu's thresholding after Gaussian filtering
    blur = cv2.GaussianBlur(pred, (5, 5), 0)
    _, pred = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return pred

def get_overlay(pred,img):
    """Highlight predicted foreground on the image.

    Pixels where `pred` is non-zero keep their original value; all other
    pixels are dimmed to 20%. `img` may carry a leading batch axis of size 1
    (it is squeezed). Both (H, W, C) and (H, W) images are supported.

    Returns:
        Array with the squeezed shape of `img`.
    """
    img = np.squeeze(img)
    mask = np.asarray(pred) != 0
    if img.ndim == 3:
        # Broadcast the (H, W) mask across the channel axis.
        mask = mask[..., np.newaxis]
    # Vectorized replacement for the former per-pixel Python loop.
    return np.where(mask, img, img * 0.2)


def binary_iou(pred, gt):
    """Intersection-over-union of the non-zero pixels of two masks.

    Returns 1.0 when both masks are empty (full agreement on "no foreground").
    """
    pred_fg = np.asarray(pred) != 0
    gt_fg = np.asarray(gt) != 0
    union = np.logical_or(pred_fg, gt_fg).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(pred_fg, gt_fg).sum()) / float(union)


def get_score(pred,gt):
    """Score a binarized prediction against its ground-truth mask.

    Returns:
        (ssim_score, iou): the structural similarity of the two masks and
        their foreground IoU.
    """
    (ssim_score,_) = compare_ssim(gt,pred,full=True)
    # IoU computed directly: jaccard_similarity_score on flattened binary
    # labels is documented to equal accuracy_score, i.e. pixel accuracy.
    iou = binary_iou(pred, gt)
    return ssim_score,iou

def score_summary(arr,model_name,score_iden):
    """Print the max, mean and min of a collection of per-image scores.

    Prints a notice instead of raising when `arr` is empty, e.g. when no
    test images were found.
    """
    arr = np.asarray(arr)
    print(model_name,':',score_iden)
    if arr.size == 0:
        print('no scores to summarize')
        return
    print('max:',np.amax(arr))
    print('mean:',np.mean(arr))
    print('min:',np.amin(arr))

# plotting data
def plot_data(img,gt,pred,overlay):
    """Show image, ground truth, prediction and overlay in one row of four panels.

    Masks are drawn in grayscale. The figure is closed after display so a
    loop over many test images does not accumulate open figures.
    """
    # One row of four square panels: 20x5 in. (the old 20x20 was mostly padding).
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = (
        (img, 'image', None),
        (gt, 'ground truth', 'gray'),
        (pred, 'prediction', 'gray'),
        (overlay, 'Overlay', None),
    )
    for ax, (data, title, cmap) in zip(axes, panels):
        ax.imshow(np.squeeze(data), cmap=cmap)
        ax.title.set_text(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)

# Sort so the per-image plots and score lists come out in a stable order
# (glob's order depends on the filesystem).
img_paths=sorted(glob(os.path.join(img_dir,'*.*')))
if not img_paths:
    print('No test images found in', img_dir)
SSIM=[]
IOU=[]
# Inference model: same architecture, built outside the TPU strategy scope and
# loaded with the weights stored at WEIGHT_PATH.
model_infer = sm.Unet(model_name,input_shape=(IMG_DIM,IMG_DIM,NB_CHANNEL), encoder_weights=None)
model_infer.load_weights(WEIGHT_PATH)
print('Loaded inference weights')
for _path in img_paths:
    # ground truth
    gt=get_gt(_path)
    # image
    img=get_img(_path) 
    # prediction
    pred=get_pred(model_infer,img,_path)
    # overlay
    overlay=get_overlay(pred,img)
    # scores
    ssim_score,iou=get_score(pred,gt)
    SSIM.append(ssim_score)
    IOU.append(iou)
    plot_data(img,gt,pred,overlay)



# Evaluation Scores

In [0]:
# Summaries over all test images, using the per-image scores collected above.
for scores, score_iden in ((SSIM, 'ssim'), (IOU, 'IoU/F1')):
    score_summary(np.array(scores), model_name, score_iden)