In [76]:
import tensorflow as tf

# Show the TensorFlow version this notebook was executed with.
tf.version.VERSION

'2.3.0'

In [77]:
import argparse
import tqdm.notebook as tqdm
import tensorflow as tf
from tensorflow import keras

import data_generators as DA
import unet_utilities
from unet_utilities import *
from tf_da import *

# `imp` is deprecated (and removed in Python 3.12); importlib.reload is the
# supported replacement and binds the same name.
from importlib import reload

# --- configuration ---------------------------------------------------------
h, w = 256, 256      # input height / width used by the data pipeline
truth_only = True    # keep only samples with positive pixels in mask channels 1:
N = 30000            # number of training iterations
beta = 0.005         # forwarded to UNet(beta=...); NOTE(review): meaning defined in unet_utilities
dropout_rate = 0.1   # forwarded to UNet(dropout_rate=...)

u_net = UNet(depth_mult=0.25, padding='SAME',
             factorization=False, n_classes=2,
             beta=beta, dropout_rate=dropout_rate,
             squeeze_and_excite=False)

print("Setting up data generator...")
IA = ImageAugmenter(
    saturation_lower=0.9, saturation_upper=1.1,
    hue_max_delta=0.1,
    contrast_lower=0.9, contrast_upper=1.1)

# One HDF5 key per line. `with` guarantees the handle is closed, and blank
# lines (e.g. a trailing newline) are skipped so they never become empty keys.
with open('../u-net/training_set_files') as key_file:
    key_list = [line.strip() for line in key_file if line.strip()]

loss_fn = WeightedCrossEntropy()

Setting up data generator...


In [78]:
# Pick up on-disk edits to the helper module.
# NOTE(review): reload() only refreshes attribute access via
# `unet_utilities.<name>`; bare names bound earlier by
# `from unet_utilities import *` still refer to the pre-reload objects.
reload(unet_utilities)

# Sample source backed by the HDF5 file: h x w inputs for the keys in
# `key_list`, with IA.augment applied as the augmentation function.
hdf5_dataset = HDF5Dataset(
    h5py_path='../u-net/segmentation_dataset.h5',
    key_list=key_list,
    input_height=h,
    input_width=w,
    augment_fn=IA.augment,
)

def load_generator():
    """Endlessly yield samples from `hdf5_dataset` using `grab()` defaults.

    NOTE(review): presumably augmented, since the sibling generator passes
    `augment=False` explicitly.
    """
    while True:
        sample = hdf5_dataset.grab()
        yield sample
def load_generator_no_transform():
    """Endlessly yield samples from `hdf5_dataset` with augmentation disabled."""
    while True:
        sample = hdf5_dataset.grab(augment=False)
        yield sample

# Build batched tf.data pipelines for training (augmented) and validation
# (augmentation off).
# NOTE(review): both pipelines draw from the same `hdf5_dataset` / `key_list`,
# so the "validation" metrics are measured on training images. Consider a
# held-out key list.
generator = load_generator
# (image, one-hot mask, weight map), matching how the training loop unpacks them.
output_types = (tf.float32, tf.float32, tf.float32)
output_shapes = (
    tf.TensorShape((h, w, 3)),
    tf.TensorShape((h, w, 2)),
    tf.TensorShape((h, w, 1)))


def _has_foreground(image, mask, weights):
    """True when mask channels 1: contain any positive value.

    Argument is named `weights` (not `w`) so it is not confused with the
    global image width `w`.
    """
    return tf.reduce_sum(mask[:, :, 1:]) > 0.


tf_dataset = tf.data.Dataset.from_generator(
    generator, output_types=output_types, output_shapes=output_shapes)
tf_dataset_val = tf.data.Dataset.from_generator(
    load_generator_no_transform, output_types=output_types, output_shapes=output_shapes)
if truth_only:
    # Filtering happens before batching, so rejected samples are skipped individually.
    tf_dataset = tf_dataset.filter(_has_foreground)
    tf_dataset_val = tf_dataset_val.filter(_has_foreground)

# NOTE(review): with h = w = 256 each sample holds 6 float32 channels
# (~1.5 MiB), so prefetch(100) of 4-sample batches buffers ~600 MiB of host
# memory per pipeline. Lower this if memory-bound.
tf_dataset = tf_dataset.batch(4).prefetch(100)
tf_dataset_val = tf_dataset_val.batch(4).prefetch(100)

print("Setting up training...")
# Running metrics. `iou` is reset and accumulated by the validation pass.
# NOTE(review): `loss_average` is not referenced in the visible cells; the
# updater keeps its own `loss_average`.
loss_average = tf.keras.metrics.Mean()
iou = tf.keras.metrics.MeanIoU(num_classes=2)
train_updater = TrainUpdater(
    loss=loss_fn,
    optimizer=keras.optimizers.RMSprop(learning_rate=0.0001),
)

Setting up training...


In [79]:
# Train for N steps, validating every `validation_every` steps and saving
# weights every `checkpoint_every` steps.
reload(DA)  # `data_generators` is bound as `DA` by the imports cell
print("Training...")

validation_every = 200      # steps between validation passes
n_validation_batches = 15   # batches evaluated per validation pass
checkpoint_every = 5000     # steps between intermediate checkpoints
write_summaries = False     # True -> log scalars/images to 'summaries/testing'


def _run_validation():
    """Evaluate `u_net` on `n_validation_batches` batches of the validation pipeline.

    Resets `iou` and the updater's running loss, then accumulates both over
    the validation batches.

    Returns:
        (iou_value, loss_value, summary_images): the metric results as numpy
        values, plus a dict mapping a TensorBoard tag to the list of per-batch
        image tensors collected for it.
    """
    iou.reset_states()
    train_updater.reset()
    summary_images = {"InputImage": [], "GroundTruth": [], "Prediction": [],
                      "PredictionBinary": [], "WeightMap": []}
    for _ in range(n_validation_batches):
        x_val, y_val, weights_val = next(tf_dataset_val_iterable)
        y_pred = u_net(x_val, training=False)
        y_pred_binary = tf.argmax(y_pred, axis=-1)
        iou.update_state(y_val[:, :, :, 1], y_pred_binary)
        val_loss = train_updater.loss(y_val, y_pred, weights_val, u_net)
        train_updater.loss_average.update_state(val_loss)
        summary_images["InputImage"].append(x_val)
        summary_images["GroundTruth"].append(
            tf.expand_dims(y_val[:, :, :, 1], axis=-1))
        summary_images["Prediction"].append(y_pred[:, :, :, 1:])
        # argmax yields int64; tf.summary.image would rescale that by the
        # int64 range (0/1 -> black), so cast to float in [0, 1].
        summary_images["PredictionBinary"].append(
            tf.cast(tf.expand_dims(y_pred_binary, axis=-1), tf.float32))
        summary_images["WeightMap"].append(weights_val)
    return iou.result().numpy(), train_updater.get_loss().numpy(), summary_images


writer = tf.summary.create_file_writer('summaries/testing')
tf_dataset_iterable = iter(tf_dataset)
tf_dataset_val_iterable = iter(tf_dataset_val)
with writer.as_default():
    for i in range(N):
        # Unpack as `weights`, not `w`: `w` is the global image width and
        # clobbering it breaks re-running the pipeline cell.
        x, y, weights = next(tf_dataset_iterable)

        train_updater(u_net, x, y, weights)  # does all the heavy lifting

        if i % validation_every == 0:
            iou_value, loss_value, summary_images = _run_validation()
            print(i, iou_value, loss_value)
            if write_summaries:
                tf.summary.scalar("Loss", loss_value, step=i)
                tf.summary.scalar("MeanIoU", iou_value, step=i)
                for tag, batches in summary_images.items():
                    tf.summary.image(tag, tf.concat(batches, axis=0), step=i)
            # Start the next training window with fresh accumulators.
            iou.reset_states()
            train_updater.reset()

        if i % checkpoint_every == 0:
            # Step-suffixed path so intermediate checkpoints don't overwrite each other.
            u_net.save_weights('checkpoints/u-net-0.25-{}'.format(i))

# Final weights go to the fixed, unsuffixed path.
u_net.save_weights('checkpoints/u-net-0.25')

Training...
0 0.17072754 0.6663724
200 0.7493442 0.20614417
400 0.8112313 0.14664334
600 0.8491573 0.123913445
800 0.8186612 0.10911007
1000 0.8445133 0.09429424
1200 0.84624743 0.087796874
1400 0.8785995 0.07665678
1600 0.8515825 0.0909051
1800 0.9078976 0.072904624
2000 0.8980746 0.079080924
2200 0.9016384 0.07325923
2400 0.86834 0.07657171
2600 0.88825417 0.070163935
2800 0.90867794 0.07226667
3000 0.8940488 0.068074904
3200 0.88487947 0.07196034
3400 0.8998283 0.07252496
3600 0.8938217 0.06432007
3800 0.9028151 0.064660475
4000 0.8921608 0.06401144
4200 0.9088532 0.0639839
4400 0.9086637 0.057518363
4600 0.8988035 0.060506247
4800 0.8188261 0.07018542
5000 0.89535916 0.060218368
5200 0.9169121 0.05309448
5400 0.8762423 0.05995385
5600 0.8955216 0.058003552
5800 0.9146917 0.050292935
6000 0.92140293 0.053135004
6200 0.87578285 0.06143992
6400 0.90291536 0.054957986
6600 0.87824386 0.061856724
6800 0.86962914 0.053638227
7000 0.9145088 0.049661662
7200 0.9028008 0.053794425
7400 0.88