In [2]:
import tensorflow as tf
import numpy as np 
import matplotlib.pyplot as plt
import os

In [10]:
from tensorflow.keras.callbacks import TensorBoard,ModelCheckpoint,EarlyStopping,LearningRateScheduler,Callback
from tensorflow.keras import backend as K
import numpy as np
from tensorflow.keras import Model
from DeepLabv3Plus_tf import deeplab_forward
# Next the model definition
from tensorflow.keras.layers import ZeroPadding2D,Conv2D,Cropping2D,Lambda,Concatenate,BatchNormalization,Activation,Lambda,Input,Reshape,Dropout

In [3]:
from semi_supervised_function import *

In [4]:
#run_on_cpu()

In [47]:
# Build three input pipelines from the TFRecord files. Each one is
# parse -> shuffle -> batch -> repeat forever.

batch_size = 32

# Training: salty_train records (supervised) mixed with salty_test records
# (the semi-supervised part).
dataset = (
    tf.data.TFRecordDataset(['salty_train.tfrecords', 'salty_test.tfrecords'])
    .map(parser_train)
    .shuffle(buffer_size=18000 * 2)
    .batch(batch_size)
    .repeat()
)

# Training: supervised salty_train records only.
dataset_sup = (
    tf.data.TFRecordDataset(['salty_train.tfrecords'])
    .map(parser_train)
    .shuffle(buffer_size=5000)
    .batch(batch_size)
    .repeat()
)

# Validation: parsed with `parser` rather than `parser_train`.
# NOTE(review): shuffling the validation stream means each `validation_steps`-limited
# pass may score a different subset, which makes `val_loss`-based checkpointing
# noisy. Confirm this is intended.
valid = (
    tf.data.TFRecordDataset(['salty_valid.tfrecords'])
    .map(parser)
    .shuffle(buffer_size=5000)
    .batch(batch_size)
    .repeat()
)

In [6]:
# Materialise evaluation data in memory for the IntervalEvaluation callback:
# the whole validation file plus 160 examples from the training file.
X_Val, Y_Val = load_to_memory('salty_valid.tfrecords')
X_Val_t, Y_Val_t = load_to_memory(
    'salty_train.tfrecords',
    size=160,
)

In [12]:
# Two-branch model: two corrupted views of an image go through the same stem and the
# same segmentation head, so dice_loss can compare their representations and logits.

# Two differently augmented copies of a 101x101 single-channel image.
image_corrupt_1 = Input((101, 101, 1), name="image_corrupt_1")
image_corrupt_2 = Input((101, 101, 1), name="image_corrupt_2")

# Per-example label mask: 1 if a label exists, 0 if not. The network does not use it;
# dice_loss reads it to switch the supervised term on or off.
scaling = Input((1, 1), name="scaling")

# Depth input, shape (1, 1, 1) per example.
depth = Input((1, 1, 1), name="depth")

# 1x1 conv that lifts the single channel to 3 channels. It is created once, so both
# views share its weights.
cnn_1 = Conv2D(3, (1, 1), activation='relu')

# Pad 101x101 -> 128x128 (13 + 101 + 14 = 128).
s = ZeroPadding2D(padding=((13, 14), (13, 14)))(image_corrupt_1)
s_c = ZeroPadding2D(padding=((13, 14), (13, 14)))(image_corrupt_2)

# Go to 3 channels.
img_input = cnn_1(s)
img_input_c = cnn_1(s_c)

# DeepLabv3+ forward passes with different `drop` settings.
# NOTE(review): whether these two calls share backbone weights depends on
# deeplab_forward (not visible here). Confirm it reuses its layers; otherwise the
# low_rep/deep_rep consistency terms compare two independent networks.
extract, low_rep, deep_rep = deeplab_forward(img_input, backbone="mobilenetv2", drop=1)
extract_c, low_rep_c, deep_rep_c = deeplab_forward(img_input_c, backbone="mobilenetv2", drop=0)

# Depth scaled by 0.001, then passed through _expand(x, 128, 128) (defined elsewhere;
# presumably it tiles depth to the 128x128 padded grid) so it can be concatenated as a channel.
dd = Lambda(lambda x: x * 0.001)(depth)
dd = Lambda(lambda x: _expand(x, 128, 128))(dd)

# Segmentation-head layers are created ONCE, so both views use the same head weights.
head_concat = Concatenate(axis=3, name="con_last")
head_conv = Conv2D(64, (1, 1), activation="relu", name="conv_last")
head_logits = Conv2D(1, (1, 1), name="logits_last")
head_crop = Cropping2D(cropping=((13, 14), (13, 14)), name="crop_last")


def get_logits(extract):
    """Apply the shared segmentation head to a DeepLab feature map.

    Concatenates the depth map `dd` on axis 3. Then applies a 64-filter 1x1 ReLU conv
    and a 1-filter 1x1 conv (raw logits, no activation). Finally crops the padded
    128x128 map back to 101x101.

    Args:
        extract: feature map returned by deeplab_forward.

    Returns:
        The cropped logits tensor.

    The layers are module-level and reused on every call. Previously each call built
    new Conv2D layers, so `logits_c` came from a second, separately initialised head.
    That head was not part of `model` (only `logits` feeds its output), so it was never
    trained, and the logits consistency term compared against random weights.
    """
    x_2 = head_concat([extract, dd])
    x_2 = head_conv(x_2)
    logits = head_logits(x_2)
    return head_crop(logits)


# Logits for both views, through the same head.
logits = get_logits(extract)
logits_c = get_logits(extract_c)

# Only the first view produces the model output. The second view and `scaling` reach
# the objective through dice_loss, which reads these tensors as globals.
sigmoids = Activation(activation="sigmoid")(logits)
model = Model(inputs=[image_corrupt_1, image_corrupt_2, depth, scaling], outputs=[sigmoids])


from tensorflow.keras.utils import get_file
# Pretrained DeepLabv3+ checkpoints (bonlime/keras-deeplab-v3-plus release 1.1),
# keyed by backbone name -> (cache file name, download URL).
_DEEPLAB_WEIGHTS = {
    "mobilenetv2": (
        "deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5",
        "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5",
    ),
    "xception": (
        "deeplabv3_xception_tf_dim_ordering_tf_kernels.h5",
        "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_xception_tf_dim_ordering_tf_kernels.h5",
    ),
}


def load_weights(backbone="xception"):
    """Fetch pretrained DeepLabv3+ weights with get_file and load them into the global `model`.

    get_file stores the download under the Keras cache's 'models' subdirectory.
    Loading uses by_name=True, so weights are matched to layers by layer name.

    Args:
        backbone: "mobilenetv2" or "xception".

    Raises:
        ValueError: if `backbone` is not one of the supported names. Previously this
            fell through both `if`s and failed with UnboundLocalError on `weights_path`.
    """
    try:
        file_name, url = _DEEPLAB_WEIGHTS[backbone]
    except KeyError:
        raise ValueError(
            "Unknown backbone %r; expected one of %s" % (backbone, sorted(_DEEPLAB_WEIGHTS))
        ) from None
    weights_path = get_file(file_name, url, cache_subdir='models')
    model.load_weights(weights_path, by_name=True)


load_weights(backbone="mobilenetv2")
    


    


In [43]:
def dice_loss(y_true, y_pred):
    """Semi-supervised training loss: masked supervised term plus weighted consistency terms.

    Despite its name, this is not a pure Dice loss. Keras supplies only `y_true` and
    `y_pred`. Everything else is read from notebook globals, which must exist when
    `model.compile` builds the loss: `scaling`, `dice_coef`, `low_rep`/`low_rep_c`,
    `deep_rep`/`deep_rep_c`, `logits`/`logits_c` and `scale`.

    Returns:
        super_loss + scale * (u_low + u_deep + u_logits), where
        * super_loss is the mean over axes 1 and 2 of
          scaling * (0.99 * (1 - dice_coef(y_true, y_pred)) + 0.01 * BCE(y_true, y_pred)).
          `scaling` is the label mask (1 = labelled, 0 = not), so unlabelled examples
          contribute nothing to this term.
        * each u_* is (1/3) * the mean over axes 1 and 2 of 0.5 * MSE between the
          matching tensors of the two forward passes.
        Assuming rank-4 (batch, H, W, C) tensors, this is one value per example.

    NOTE(review): if `dice_coef` reduces over the whole batch, its value already mixes
    in unlabelled examples before `scaling` is applied per example. Confirm.
    """
    spatial_axes = (1, 2)

    def consistency(rep, rep_c):
        # One third of the half-MSE between the two views' tensors, averaged over axes 1, 2.
        # 1.0 / 3.0 rather than 1/3, so the weight cannot truncate to 0 under integer division.
        return (1.0 / 3.0) * tf.reduce_mean(
            0.5 * tf.keras.losses.mean_squared_error(rep, rep_c), axis=spatial_axes)

    bin_crossentropyloss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    # `axis=` replaces the deprecated `reduction_indices=` alias, which TF 2's
    # tf.reduce_mean no longer accepts.
    super_loss = tf.reduce_mean(
        scaling * ((1 - dice_coef(y_true, y_pred)) * 0.99 + bin_crossentropyloss * 0.01),
        axis=spatial_axes)

    # Unsupervised consistency at three levels: low_rep and deep_rep (both returned by
    # deeplab_forward) and the final logits.
    unsuper_loss = (consistency(low_rep, low_rep_c)
                    + consistency(deep_rep, deep_rep_c)
                    + consistency(logits, logits_c))

    # `scale` is the unsupervised weight; the training loop rebinds it before each compile.
    return super_loss + scale * unsuper_loss

In [44]:
import math

In [45]:
# Callbacks shared by every model.fit call below.

# 1. TensorBoard logging (histogram_freq=0 disables histograms). batch_size follows the
#    pipeline's `batch_size` setting instead of repeating the literal 32.
tens = TensorBoard(log_dir='../logs/semi_upsampled_mobilenet_ramping', histogram_freq=0,
                   batch_size=batch_size)

# 2. Learning-rate schedule (step_decay is not defined in this notebook; presumably it
#    comes from the semi_supervised_function star import).
lrate = LearningRateScheduler(step_decay)

# 3. Periodic evaluation on the in-memory validation set and training slice, every epoch.
internalEval = IntervalEvaluation(validation_data=(X_Val, Y_Val),
                                  training_data=(X_Val_t, Y_Val_t),
                                  interval=1)

# 4. Checkpoint that only saves when val_loss improves (save_best_only=True).
# NOTE(review): the file name says "xcept", but the model above uses backbone="mobilenetv2".
#    Rename it if nothing else depends on this path.
check = ModelCheckpoint(filepath="semi_upsampled_xcept_0.001.h5", save_best_only=True,
                        monitor="val_loss")


In [None]:
# Number of epochs over which the unsupervised loss weight ramps up to 1.
RAMP_UP_EPOCHS = 100

for j in range(1, 500):
    # Unsupervised weight: Gaussian ramp-up exp(-5 * (1 - t)^2), with t = j / RAMP_UP_EPOCHS
    # clipped to 1. The weight rises to 1.0 at RAMP_UP_EPOCHS and then stays there.
    # Without the clip, t > 1 after epoch 100 and the weight decays back toward 0
    # (exp(-5) ~ 0.007 at j = 200). That switched off the unsupervised loss for most of
    # training, contrary to the goal of slowly increasing it.
    t = min(j / float(RAMP_UP_EPOCHS), 1.0)
    x = math.exp(-5 * (1 - t) ** 2)
    # dice_loss reads this global when the loss is built, so rebind it before compiling.
    scale = tf.convert_to_tensor(x, dtype=tf.float32, name="scale_Loss")
    # Recompile so the loss picks up the new scale.
    # NOTE(review): the string 'adam' builds a new optimizer on every compile, so Adam's
    #    moment estimates do not carry over between epochs. Confirm this is acceptable.
    model.compile(optimizer='adam', loss=[dice_loss])
    # Train exactly one epoch (index j): initial_epoch=j, epochs=j + 1.
    # 8000 samples per epoch (original note: "18000+").
    model.fit(dataset, steps_per_epoch=int(8000 / batch_size), epochs=j + 1,
              validation_data=valid, validation_steps=20,
              callbacks=[tens, lrate, check, internalEval], initial_epoch=j)
    

Epoch 2/2
Changing learning rate to 0.001
Validation score is 0.15499999999999997.Train score is 0.20187500000000003 Best so far is 0.403125
Epoch 3/3
Changing learning rate to 0.001
Validation score is 0.20437500000000003.Train score is 0.231875 Best so far is 0.403125
Epoch 4/4
Changing learning rate to 0.001

In [None]:
# Single long run without the ramp-up loop: compile once, then fit for up to 1000 epochs.
# No initial_epoch is passed, so epoch numbering starts at 0.
# Each epoch draws 8000 samples from `dataset`. Roughly half are expected to be labelled
# and half unlabelled ("4k true + 4k unlabelled").
# NOTE(review): that split only holds if salty_train and salty_test contribute equally
#    to the shuffled stream. Confirm.
# NOTE(review): dice_loss reads the global `scale`, which this cell never sets. It
#    inherits whatever the ramp-up loop left behind, and raises NameError when the loss
#    is built on a fresh kernel if that loop was skipped. Bind `scale` explicitly here
#    if this cell is meant to stand alone.
model.compile(loss=[dice_loss], optimizer='adam')
model.fit(
    dataset,
    steps_per_epoch=int(8000 / batch_size),  # original note: "18000+"
    epochs=1000,
    validation_data=valid,
    validation_steps=20,
    callbacks=[tens, lrate, check, internalEval],
)