In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D, concatenate, UpSampling2D, Dropout
import numpy as np
from sklearn.model_selection import train_test_split
import unet
from skimage import io, transform
import time
import skimage
from matplotlib import pyplot as plt
import glob

In [None]:
class TextSegDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields ``(images, masks)`` batches for text segmentation.

    The path lists are split with ``train_test_split(..., random_state=seed)``, so a
    'train' and a 'test' instance built from the same inputs and seed serve
    complementary subsets.

    Args:
        images_paths: image file paths.
        masks_paths: mask file paths, paired with ``images_paths`` purely by position.
        output_size: side length every image and mask is resized to.
        validation_split: ``test_size`` passed to ``train_test_split``.
        batch_size: pairs per batch; a trailing partial batch is dropped.
        type: which subset this instance serves, 'train' or 'test'.
        seed: ``random_state`` for the split.

    Raises:
        ValueError: if ``type`` is not 'train' or 'test'.
    """

    _SUBSETS = ('train', 'test')

    def __init__(self, images_paths, masks_paths, output_size, validation_split, batch_size=1, type='train', seed=968):
        # Reject unknown subsets up front; previously they only surfaced later as
        # `len()` returning None (TypeError) or as iteration over None in __getitem__.
        if type not in self._SUBSETS:
            raise ValueError(f"type must be one of {self._SUBSETS}, got {type!r}")
        super().__init__()

        self.type = type
        self.batch_size = batch_size
        self.output_size = output_size
        self.validation_split = validation_split
        self.images_paths_train, self.images_paths_test, self.masks_paths_train, self.masks_paths_test = train_test_split(
            images_paths, masks_paths, test_size=self.validation_split, random_state=seed)

    def _subset_paths(self):
        """Return ``(image_paths, mask_paths)`` for the subset selected by ``self.type``."""
        if self.type == 'train':
            return self.images_paths_train, self.masks_paths_train
        if self.type == 'test':
            return self.images_paths_test, self.masks_paths_test
        raise ValueError(f"type must be one of {self._SUBSETS}, got {self.type!r}")

    def __len__(self):
        """Number of full batches in the selected subset."""
        return len(self._subset_paths()[0]) // self.batch_size

    def __getitem__(self, idx):
        """Load, scale and resize batch ``idx``.

        Returns:
            ``(images, masks)`` numpy arrays. Images are resized to
            ``(output_size, output_size)`` (skimage keeps any trailing channel axis);
            masks are resized to ``(output_size, output_size, 1)``. Pixel values are
            divided by 255, which assumes 8-bit inputs — TODO confirm.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``, so plain
            ``for batch in dataset`` iteration stops instead of yielding empty batches.
        """
        n_batches = len(self)
        if idx < 0:
            idx += n_batches
        if not 0 <= idx < n_batches:
            raise IndexError(f"batch index out of range: {idx} (dataset has {n_batches} batches)")

        image_paths, mask_paths = self._subset_paths()
        batch = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        images = [transform.resize(io.imread(path) / 255, (self.output_size, self.output_size))
                  for path in image_paths[batch]]
        masks = [transform.resize(io.imread(path) / 255, (self.output_size, self.output_size, 1))
                 for path in mask_paths[batch]]
        return np.array(images), np.array(masks)

In [None]:
class Unet(tf.keras.Model):
    """U-Net encoder/decoder that maps an image batch to a 1-channel sigmoid mask.

    Layers are created once in ``__init__`` and applied in ``call``. The previous
    version overrode ``__call__`` and instantiated brand-new layers on every
    invocation. As a result, each forward pass used freshly initialised random
    weights, nothing was tracked in ``trainable_variables``, and training could not learn.

    Architecture (unchanged): four down-sampling stages (the first two use
    dilated 3x3 convs) with BatchNormalization, a 1024-filter bottleneck,
    Dropout(0.5) after stage 4 and after the bottleneck, and four up-sampling
    stages that concatenate the matching encoder features.

    Spatial input size must be divisible by 16 (four 2x2 poolings). Assumes
    NHWC float input — TODO confirm the expected channel count.

    Args:
        mc_dropout: when True (default, matching the original behaviour) both
            Dropout layers are active on every call, including inference
            (Monte-Carlo dropout). When False they follow the ``training`` flag.
    """

    def __init__(self, mc_dropout=True):
        super(Unet, self).__init__()
        self.mc_dropout = mc_dropout

        def conv(filters, kernel_size=3, dilation_rate=1):
            # Conv configuration shared by every hidden conv in the original network.
            return Conv2D(filters, kernel_size, activation='relu', dilation_rate=dilation_rate,
                          padding='same', kernel_initializer='he_normal')

        def encoder_stage(filters, dilation_rate=1):
            # Two (conv, BatchNorm) pairs per encoder level.
            return [conv(filters, dilation_rate=dilation_rate), BatchNormalization(),
                    conv(filters, dilation_rate=dilation_rate), BatchNormalization()]

        self.enc1 = encoder_stage(64, dilation_rate=2)
        self.enc2 = encoder_stage(128, dilation_rate=2)
        self.enc3 = encoder_stage(256)
        self.enc4 = encoder_stage(512)
        self.bottleneck = encoder_stage(1024)
        self.drop4 = Dropout(0.5)
        self.drop5 = Dropout(0.5)
        # Pooling and upsampling have no weights, so single instances are reused.
        self.pool = MaxPooling2D(pool_size=(2, 2))
        self.upsample = UpSampling2D(size=(2, 2))

        # Decoder stage: 2x2 conv after upsampling, then 3x3 convs (no BatchNorm, as before).
        self.up6, self.dec6 = conv(512, 2), [conv(512), conv(512)]
        self.up7, self.dec7 = conv(256, 2), [conv(256), conv(256)]
        self.up8, self.dec8 = conv(128, 2), [conv(128), conv(128)]
        self.up9, self.dec9 = conv(64, 2), [conv(64), conv(64), conv(2)]
        self.head = Conv2D(1, 1, activation='sigmoid')

    @staticmethod
    def _run(layers, x, training):
        """Apply ``layers`` in order; only BatchNormalization receives ``training``."""
        for layer in layers:
            x = layer(x, training=training) if isinstance(layer, BatchNormalization) else layer(x)
        return x

    def _dropout(self, layer, x, training):
        """Apply ``layer`` (a Dropout), forced on when ``mc_dropout`` is set."""
        return layer(x, training=True if self.mc_dropout else training)

    def _decode(self, up_conv, convs, x, skip):
        """Upsample ``x``, prepend the encoder ``skip`` features on the channel axis, then convolve."""
        up = up_conv(self.upsample(x))
        x = tf.concat([skip, up], axis=3)
        for layer in convs:
            x = layer(x)
        return x

    def call(self, inputs, training=None):
        """Run the network and return per-pixel sigmoid probabilities with one channel."""
        conv1 = self._run(self.enc1, inputs, training)
        conv2 = self._run(self.enc2, self.pool(conv1), training)
        conv3 = self._run(self.enc3, self.pool(conv2), training)
        drop4 = self._dropout(self.drop4, self._run(self.enc4, self.pool(conv3), training), training)
        drop5 = self._dropout(self.drop5, self._run(self.bottleneck, self.pool(drop4), training), training)

        conv6 = self._decode(self.up6, self.dec6, drop5, drop4)
        conv7 = self._decode(self.up7, self.dec7, conv6, conv3)
        conv8 = self._decode(self.up8, self.dec8, conv7, conv2)
        conv9 = self._decode(self.up9, self.dec9, conv8, conv1)
        return self.head(conv9)

In [None]:
def printProgressBar(epoch, iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', printEnd="\r",
                     eta=None, loss=None, train_type='train'):
    """Print a single-line text progress bar with epoch, ETA and loss.

    Args:
        epoch: epoch number shown in the status text.
        iteration: number of completed steps.
        total: total number of steps. ``0`` is rendered as a complete (100%) bar
            instead of raising ZeroDivisionError.
        prefix: text printed before the bar.
        suffix: text printed after the percentage.
        decimals: decimal places of the percentage.
        length: bar width in characters.
        fill: character used for the completed part of the bar.
        printEnd: line terminator; the default carriage return redraws the same line.
        eta: estimated seconds remaining, printed verbatim.
        loss: scalar or array-like loss, rounded to 4 decimals; ``None`` prints ``None``.
        train_type: accepted for call compatibility; currently unused.
    """
    if total:
        percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
        filledLength = int(length * iteration // total)
    else:
        # Nothing to iterate over: show a full bar rather than dividing by zero.
        percent = ("{0:." + str(decimals) + "f}").format(100.0)
        filledLength = length
    bar = fill * filledLength + '-' * (length - filledLength)
    # np.round(None) raises TypeError, so the default loss=None needs its own branch.
    loss_text = 'None' if loss is None else str(np.round(loss, 4))
    print(
        f'\r{prefix} |{bar}| {percent}% {suffix} ETA:{eta} s epoch={epoch}: recon_loss={loss_text}',
        end=printEnd)
    # Print New Line on Complete
    if iteration == total:
        print()


In [None]:
def generate_and_save_images(model, epoch, test_sample, filename_template='image_at_epoch_{:04d}.png'):
    """Save the model's predicted mask next to its input image as one PNG.

    Args:
        model: callable mapping a batch to per-pixel values; its output is assumed
            to be (1, H, W, 1) probabilities in [0, 1] — TODO confirm for other models.
        epoch: number substituted into ``filename_template``.
        test_sample: batch containing one image with values in [0, 1]; assumed
            (1, H, W, 3) so it can be concatenated with the RGB mask.
        filename_template: ``str.format`` pattern for the output path.

    Returns:
        The saved uint8 array of shape (H, 2*W, 3): predicted mask (left), input image (right).
    """
    prediction = np.asarray(model(test_sample))[0]
    # Scale [0, 1] probabilities to [0, 255] exactly once. The old code multiplied by 255
    # twice, so values overflowed and wrapped around in the uint8 cast; clip guards the cast.
    mask = np.clip(prediction * 255, 0, 255).astype(np.uint8)
    # Replicate the single mask channel to RGB so it can sit beside the colour image.
    mask = np.repeat(mask, 3, axis=-1)

    img = np.clip(np.asarray(test_sample[0]) * 255, 0, 255).astype(np.uint8)

    panel = np.concatenate([mask, img], axis=1)
    io.imsave(filename_template.format(epoch), panel)
    return panel

In [None]:
# Smoke test: save a prediction for one random test sample before training starts.
# This cell sits above the cells that define `test_dataset` and `model`, so on a fresh
# kernel (Restart & Run All) those names do not exist yet; skip instead of raising NameError.
if 'test_dataset' in globals() and 'model' in globals() and len(test_dataset) > 0:
    preview_batch = test_dataset[np.random.randint(0, len(test_dataset))][0]
    # Random index within the batch instead of the hard-coded 2 (IndexError when batch_size < 3).
    test_sample = preview_batch[np.random.randint(0, len(preview_batch))]
    generate_and_save_images(model, 0, np.expand_dims(test_sample, axis=0).astype(np.float32))
else:
    print('Skipping preview: run the dataset and model cells first.')

In [None]:
# Images and masks are paired purely by list position, so both listings must be in the
# same deterministic order; glob.glob returns files in arbitrary order, hence sorted().
# NOTE(review): assumes each mask's filename sorts to the same position as its image — verify.
images_path = 'data/image/'
masks_path = 'data/semantic_label/'

images_path_list = sorted(glob.glob(images_path + '*'))
masks_path_list = sorted(glob.glob(masks_path + '*'))
assert images_path_list, f'no images found in {images_path}'
assert len(images_path_list) == len(masks_path_list), 'expected exactly one mask per image'

batch_size = 3
OUTPUT_SIZE = 512
VALIDATION_SPLIT = 0.2

# Both instances use the default seed, so they receive complementary halves of the same split.
train_dataset = TextSegDataset(images_path_list, masks_path_list, output_size=OUTPUT_SIZE,
                               validation_split=VALIDATION_SPLIT, type='train', batch_size=batch_size)
test_dataset = TextSegDataset(images_path_list, masks_path_list, output_size=OUTPUT_SIZE,
                              validation_split=VALIDATION_SPLIT, type='test', batch_size=batch_size)

In [None]:
# Instantiate the segmentation network; `model` is used by the preview and training cells.
model = Unet()

In [None]:
epochs = 100
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
beta = 1  # NOTE(review): not referenced by any visible cell
total_loss = []

l_train = len(train_dataset)
# printProgressBar divides by l_train; fail with a clear message when there is no full batch.
assert l_train > 0, 'train_dataset has no full batches: lower batch_size or add data'

for epoch in range(1, epochs + 1):
    start_time = time.time()
    total_loss = []

    printProgressBar(epoch, 0, l_train, eta=None, loss=[0, 0, 0], prefix='Progress:', suffix='Complete',
                     train_type='train', length=10)
    # Index explicitly so the loop is bounded by len() and does not depend on the
    # Sequence's iteration protocol.
    for i in range(l_train):
        x_batch, y_batch = train_dataset[i]
        # NOTE(review): the model is called without `training=`, so Keras layers such as
        # BatchNormalization use their inference behaviour by default. Pass training=True
        # here (and training=False in evaluation) once the model's forward accepts `training`.
        with tf.GradientTape() as tape:
            logits = model(x_batch)
            # Keras losses take (y_true, y_pred); MSE is symmetric, so only the order changes.
            loss = tf.reduce_sum(tf.keras.losses.MSE(y_batch, logits))

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        total_loss.append(float(loss))
        # Remaining time = average seconds per batch so far * batches left this epoch.
        eta = round((time.time() - start_time) / (i + 1) * (l_train - i - 1), 1)
        printProgressBar(epoch, i + 1, l_train, eta=eta, prefix='Progress:', suffix='Complete', train_type='train',
                         loss=np.mean(total_loss), length=10)

    end_time = time.time()

    # Separate name so the metric no longer shadows the per-batch training `loss`.
    test_loss_metric = tf.keras.metrics.Mean()
    for j in range(len(test_dataset)):
        x_batch, y_batch = test_dataset[j]
        logits = model(x_batch)
        test_loss_metric(tf.reduce_sum(tf.keras.losses.MSE(y_batch, logits)))

    print('Epoch: {}, Test set recon: {}, time elapse for current epoch: {}'
          .format(epoch, test_loss_metric.result(), end_time - start_time))

    # Preview a random test sample; skip when the test split has no full batch
    # (np.random.randint(0, 0) would raise).
    if len(test_dataset) > 0:
        preview_batch = test_dataset[np.random.randint(0, len(test_dataset))][0]
        # Random index within the batch instead of the hard-coded 2 (IndexError when batch_size < 3).
        test_sample = preview_batch[np.random.randint(0, len(preview_batch))]
        # Use the real epoch so each epoch writes its own file (was hard-coded to 0 and overwritten).
        generate_and_save_images(model, epoch, np.expand_dims(test_sample, axis=0).astype(np.float32))