In [None]:
# Import dependencies
import os
import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML

from tensorflow.keras import Model
from tensorflow.keras.layers import Dropout, BatchNormalization, Conv2D, MaxPooling2D, Conv2DTranspose, \
    concatenate, Input, Add, Concatenate, GlobalAveragePooling2D, Activation, GaussianNoise, Input, Softmax


In [None]:
# ----------------------------------------------------------------------
# Notebook configuration
# ----------------------------------------------------------------------

# Data locations.
# NOTE: training and testing both read from the "normal" folders only;
# the two remaining folders are used for postprocessing.
datapathTrain, datapathTest = (
    '../data_plants/plants_binary/train/normal/',
    '../data_plants/plants_binary/test/normal/',
)
datapathAnomalous, datapathVeryAnomalous = (
    '../data_plants/plants_binary/test/sick/',
    '../data_misc/',
)

# Geometry of the images fed to the network (square RGB input)
imageWidth = imageHeight = 240
imageChannels = 3

# Training-loop knobs
learningrate = 1e-4
nepoches, batchSize = 30, 10
intermediateResults = 300   # training loss is printed every N iterations
patience = 3                # early stopping: epochs without a better test loss

# Let's build a data pipeline generator

The data pipeline generator is a helper class which builds `tf.data.Dataset` objects tailor-made for your model's needs. Clearly, Keras offers this functionality already built in, but it's good to know how things work under the hood, especially when you start working with non-standard models.

The generated `tf.data.Dataset` object prepares the raw data for the neural network. It also carries out image augmentation (noise, flips, color jitter and random crops).

In [None]:


class DatapipeGenerator:
    """Collects the image files below a directory and builds tf.data pipelines from them.

    The constructor only gathers file names; the pipeline itself is assembled
    by :meth:`create`. Every element of a pipeline is a tuple
    ``(images, paths)``.
    """

    # Lower-case file extensions that are picked up as images
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    # Augmentations used by create() when the caller does not pass any
    DEFAULT_AUGMENTATIONS = ('fliph', 'flipv', 'color', 'crop', 'noise')

    def __init__(self, datapath: str):
        """Recursively collect all image files below ``datapath``.

        Args:
            datapath: Root directory that is searched, including sub-directories.
        """

        # Save datapath
        self.datapath = datapath

        # Find all image files in datapath. The extension check ignores case
        # (so "IMG.JPG" is found too) and the result is sorted because os.walk
        # does not guarantee a stable order.
        self.filenames = sorted(
            str(os.path.join(root, file))
            for root, dirs, files in os.walk(datapath)
            for file in files
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        # Target image geometry; filled in by create()
        self.iw, self.ih, self.ic = None, None, None


    # ============================
    def create(
        self, iw: int, ih: int, ic:int, batchSize: int,
        shuffle_buffer_size:int=50000,
        augmentations:list=None,
        nrepeat:int=1
    ):
        """Creates the datapipe.

        Args:
            iw: Width the images are resized to.
            ih: Height the images are resized to.
            ic: Number of channels the images are decoded with.
            batchSize: Number of images per batch.
            shuffle_buffer_size: Buffer size used to shuffle the file names
                (shuffling happens before the images are loaded).
            augmentations: Names of the augmentations to enable, any of
                'noise', 'fliph', 'flipv', 'color', 'crop'. ``None`` enables
                all five; an empty list disables augmentation entirely.
                Unknown names are ignored.
            nrepeat: How often the file list is repeated.

        Returns:
            A batched and prefetched ``tf.data.Dataset`` yielding
            ``(images, paths)`` tuples.

        Raises:
            ValueError: If no image file was found below ``self.datapath``.
        """

        # Fail early with a readable message instead of an obscure dtype error
        # when the (empty) file list is mapped through tf.io.read_file.
        if not self.filenames:
            raise ValueError(f"No image files found in '{self.datapath}'")

        # A None default avoids sharing one mutable list between calls
        if augmentations is None:
            augmentations = self.DEFAULT_AUGMENTATIONS

        self.iw, self.ih, self.ic = iw, ih, ic

        # Let's build the pipeline
        dataset = tf.data.Dataset.from_tensor_slices(self.filenames)

        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
        dataset = dataset.repeat(nrepeat)

        # Load the image
        dataset = dataset.map(self._processLoadImage)

        # Augment the image. Noise is applied only on request, so pipelines
        # created with augmentations=[] deliver the unmodified images.
        if 'noise' in augmentations:
            dataset = dataset.map(self._processAddNoise)
        if 'fliph' in augmentations:
            dataset = dataset.map(self._processAugmentFlip)
        if 'flipv' in augmentations:
            dataset = dataset.map(self._processAugmentFlipVertically)
        if 'color' in augmentations:
            dataset = dataset.map(self._processAugmentColor)
        if 'crop' in  augmentations:
            dataset = dataset.map(self._processAugmentCrop)

        # Apply batching
        dataset = dataset.batch(batchSize)

        # Prefetching
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        return dataset



    # ============================
    def _processLoadImage(self, imgpath):
        """Read one file and return ``(image, imgpath)``.

        The image is decoded with ``self.ic`` channels, converted to float32
        and resized to ``(self.ih, self.iw)``.
        """
        img = tf.io.read_file(imgpath)
        # NOTE(review): decode_jpeg is also applied to the collected .png files -
        # confirm that your TensorFlow version decodes them correctly.
        img = tf.image.decode_jpeg(img, channels=self.ic)
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, (self.ih, self.iw))
        return img, imgpath


    # ============================
    def _processAddNoise(self, img, imgpath, mean=0.0, stddev=0.1):
        """With probability 0.5, add randomly weighted Gaussian noise to ``img``.

        The noise is drawn from N(mean, stddev) and scaled by a weight drawn
        uniformly from [0, 1). The result is not clipped.
        """

        def addnoise():
            weight = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
            gnoise = tf.random.normal(shape=tf.shape(img), mean=mean, stddev=stddev, dtype=tf.float32)
            return tf.add(img, gnoise*weight)

        choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
        noisy = tf.cond(choice < 0.5, addnoise, lambda: img)

        return noisy, imgpath

    # ============================
    def _processAugmentColor(self, img, imgpath,
                             rand_hue=0.01, rand_saturation=(0.8,1.2),
                             rand_brightness=0.01, rand_contrast=(0.8,1.1), **kwargs):
        """Randomly jitter brightness and contrast (and hue/saturation for 3-channel images).

        Args:
            rand_hue: Maximum hue delta.
            rand_saturation: (lower, upper) saturation factor.
            rand_brightness: Maximum brightness delta.
            rand_contrast: (lower, upper) contrast factor.
        """

        # Hue and saturation are only adjusted for 3-channel images
        if self.ic == 3:
            img = tf.image.random_hue(img, rand_hue)
            img = tf.image.random_saturation(img, rand_saturation[0], rand_saturation[1])

        img = tf.image.random_brightness(img, rand_brightness)
        img = tf.image.random_contrast(img, rand_contrast[0], rand_contrast[1])

        return img, imgpath


    # ============================
    def _processAugmentCrop(self, img, imgpath, rand_scales=(0.7, 1.0, 0.01), **kwargs):
        """With probability 0.5, replace ``img`` by a random centred crop.

        The crop is resized back to ``(self.ih, self.iw)``.

        Args:
            rand_scales: ``(start, stop, step)`` passed to ``np.arange``; every
                resulting value is the fraction of the image side that a crop keeps.
        """

        def cropimage(img, width, height):
            # One centred crop box per scale. With the default rand_scales these
            # are the scales 0.70, 0.71, ... (about 30 boxes), i.e. crops that
            # keep between 70% and just under 100% of the image side.
            scales = np.arange(rand_scales[0], rand_scales[1], rand_scales[2])
            cropboxes = np.zeros((len(scales), 4))
            for i, scale in enumerate(scales):
                lo = 0.5 - (0.5 * scale)
                hi = 0.5 + (0.5 * scale)
                cropboxes[i] = [lo, lo, hi, hi]
            nboxes = cropboxes.shape[0]
            cropboxes = tf.convert_to_tensor(cropboxes, dtype=tf.float32)

            # Pick the box first so that only the selected crop is computed
            idx = tf.random.uniform(shape=[], minval=0, maxval=nboxes, dtype=tf.int32)
            crops = tf.image.crop_and_resize(
                tf.expand_dims(img, 0),
                boxes=tf.expand_dims(cropboxes[idx], 0),
                box_indices=tf.zeros([1], dtype=tf.int32),
                crop_size=(height, width)
            )
            return crops[0]

        # =======================
        choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
        cropped = tf.cond(
            choice < 0.5,
            lambda: img,
            lambda: cropimage(img, width=self.iw, height=self.ih)
        )

        return cropped, imgpath


    # ============================
    def _processAugmentFlip(self, img, imgpath):
        """With probability 0.5, mirror ``img`` left-right."""

        choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
        flipped = tf.cond(
            choice < 0.5,
            lambda: img,
            lambda: tf.image.flip_left_right(img)
        )

        return flipped, imgpath

    # ============================
    def _processAugmentFlipVertically(self, img, imgpath):
        """With probability 0.5, mirror ``img`` upside-down."""

        choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
        flipped = tf.cond(
            choice < 0.5,
            lambda: img,
            lambda: tf.image.flip_up_down(img)
        )

        return flipped, imgpath


In [None]:
# Build one generator and one pipeline per data folder.
# Only the training pipeline is augmented; the three evaluation pipelines are
# created with an empty augmentation list.

# --- Training data (normal samples) ---
datapipeGenTrain = DatapipeGenerator(datapath=datapathTrain)
dpTrain = datapipeGenTrain.create(
    imageWidth, imageHeight, imageChannels, batchSize,
    augmentations=['fliph', 'flipv', 'color', 'crop', 'noise'],
)

# --- Test data (normal samples) ---
datapipeGenTest = DatapipeGenerator(datapath=datapathTest)
dpTest = datapipeGenTest.create(
    imageWidth, imageHeight, imageChannels, batchSize,
    augmentations=[],  # No augmentations on the test data set!
)

# --- Anomalous data (sick plants) ---
datapipeGenAnomalous = DatapipeGenerator(datapath=datapathAnomalous)
dpAnomalous = datapipeGenAnomalous.create(
    imageWidth, imageHeight, imageChannels, batchSize,
    augmentations=[],  # No augmentations on the test data set!
)

# --- Very anomalous data (miscellaneous images) ---
datapipeGenVeryAnomalous = DatapipeGenerator(datapath=datapathVeryAnomalous)
dpVeryAnomalous = datapipeGenVeryAnomalous.create(
    imageWidth, imageHeight, imageChannels, batchSize,
    augmentations=[],  # No augmentations on the test data set!
)

In [None]:
# Let's show how to use and test our pipeline

for it, (imgs, paths) in enumerate(dpTrain):

    fig, axs = plt.subplots(1,4, figsize=(15,15))

    # A batch can hold fewer than 4 images (small data set or small batchSize)
    nshown = min([4, imgs.shape[0]])
    for b in range(nshown):
        # Augmentation can push values outside [0, 1]; clip for display only
        axs[b].imshow(np.clip(imgs[b,...].numpy(), 0, 1))

    # Hide the axes that received no image
    for b in range(nshown, 4):
        axs[b].axis('off')

    plt.show()

    break # Only the first batch is needed for the preview

# Let's build a wrapper for the autoencoder

Wrap the Keras model and all methods required for training in one class.

In [None]:
# This time we outsource some stuff...

from bottleneck import Nonbt1d    # Bottleneck module
from downsample import Downsample # Downsample module
from upsample import Upsample     # Upsample module


class AutoencoderWrapper:
    """Bundles the Keras autoencoder, its optimizer and the train/test steps.

    Args:
        iw: Input image width.
        ih: Input image height.
        ic: Number of input (and reconstructed) channels.
        learnRate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, iw, ih, ic=3, learnRate=0.001):

        self.iw = iw
        self.ih = ih
        self.ic = ic

        self.model = None

        # Build the model
        self._buildModel()

        # Pick an optimizer
        self.optimizer = tf.keras.optimizers.Adam(learnRate)

    def _buildModel(self):
        """Build the encoder/decoder network and store it in ``self.model``."""

        inputs = Input((self.ih, self.iw, self.ic))

        # Encoder: four Downsample stages followed by one bottleneck block.
        # NOTE(review): Downsample/Upsample/Nonbt1d live in separate modules that
        # are not shown here; presumably every stage changes the spatial size
        # by a factor of 2 (strides=(2,2)) - confirm against those modules.
        d1 = Downsample(8, kernelsize=(3,3), maxpoolsize=(2,2), strides=(2,2))(inputs)
        d2 = Downsample(16, kernelsize=(3,3), maxpoolsize=(2,2), strides=(2,2))(d1)
        d3 = Downsample(32, kernelsize=(3,3), maxpoolsize=(2,2), strides=(2,2))(d2)
        d4 = Downsample(64, kernelsize=(3,3), maxpoolsize=(2,2), strides=(2,2))(d3)
        b4 = Nonbt1d(nfilters=64)(d4)

        # Decoder: four Upsample stages; the last one maps back to self.ic
        # channels with a sigmoid activation.
        u19 = Upsample(32, (2, 2), strides=(2, 2), padding='same')(b4)
        u20 = Upsample(16, (2, 2), strides=(2, 2), padding='same')(u19)
        u20 = Dropout(0.5)(u20)
        b21 = Nonbt1d(nfilters=16)(u20)
        u22 = Upsample(8, (2, 2), strides=(2, 2), padding='same')(b21)
        u22 = Dropout(0.5)(u22)
        b23 = Nonbt1d(nfilters=8)(u22)
        outputs = Upsample(self.ic, (2, 2), strides=(2, 2), padding='same', activation='sigmoid')(b23)

        # Autoencoder: input image -> reconstructed image
        self.model = Model(inputs=[inputs], outputs=outputs)



    @tf.function
    def loss(self, ytrue, ypred):
        """Mean squared error between ``ypred`` and ``ytrue`` over all elements."""
        return tf.reduce_mean(tf.square(ypred-ytrue))

    @tf.function
    def trainStep(self, imgs, ytrue):
        """Run one optimization step on a batch.

        Args:
            imgs: Batch of input images.
            ytrue: Reconstruction target for ``imgs``.

        Returns:
            Tuple ``(ypred, loss)``: the reconstruction computed in training
            mode and the scalar loss of the batch.
        """

        # Record the forward pass. training=True is required so that the
        # Dropout layers are active; a model called directly defaults to
        # inference mode.
        with tf.GradientTape() as t:
            ypred = self.model(imgs, training=True)
            loss = self.loss(ytrue, ypred)

        # Gradients of the loss w.r.t. the model weights, then one Adam update
        grads = t.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

        return ypred, loss

    @tf.function
    def testStep(self, imgs, ytrue):
        """Evaluate a batch without updating the weights.

        Returns:
            Tuple ``(ypred, loss)`` computed in inference mode.
        """
        ypred = self.model(imgs, training=False)
        loss = self.loss(ytrue, ypred)
        return ypred, loss

In [None]:
# Instantiate our Model Wrap
autoencoderWrap = AutoencoderWrapper(
    iw = imageWidth,
    ih = imageHeight,
    ic = imageChannels,
    learnRate = learningrate
)

# Print model summary. summary() prints the table itself and returns None, so
# wrapping it in print() would only add a stray "None" to the output.
autoencoderWrap.model.summary()

# Let's train our model

In [None]:
# Training loop starts here

# Both logs are [iterations, losses] pairs used for the loss plot
lossTrain, lossTest = [[],[]], [[],[]]
bestTestLoss, earlyStoppingCtr = 1e+4, 0

# Load weights if exist
if os.path.isfile("./weightsBest.h5"):
    try:
        autoencoderWrap.model.load_weights("./weightsBest.h5")
    except Exception as err:
        # Keep going with fresh weights, but show why loading failed.
        # (Exception instead of a bare except, so Ctrl+C is not swallowed.)
        print("Wegight load failed! Have you changed your model architecture?")
        print(f"Reason: {err}")

# Number of batches per pass; constant, so compute them only once
stepsPerEpoch = len(dpTrain)
nTestBatches = len(dpTest)

for e in range(nepoches):
    print(f"Starting epoche {e}")


    # =====================
    # Run Training data set
    print("Start Traing")

    # Loop through dataset
    for it, (imgs, paths) in enumerate(dpTrain):

        # Run a train step (autoencoder: the input is its own target)
        pred, loss = autoencoderWrap.trainStep(imgs=imgs, ytrue=imgs)

        # Log results
        globalStep = e * stepsPerEpoch + it
        lossValue = loss.numpy()
        lossTrain[0].append(globalStep)
        lossTrain[1].append(lossValue)

        # Write out intermediate results
        if globalStep % intermediateResults == 0:

            print(f"Train - Epoche: {e}, Iteration: {it}/{stepsPerEpoch}, TrainLoss: {lossValue:.7f}")


    # =====================
    # Run test data set:
    print("Start Testing")

    lossMean = 0

    # Loop through dataset
    for it, (imgs, paths) in enumerate(dpTest):

        # Run a test step
        pred, loss = autoencoderWrap.testStep(imgs=imgs, ytrue=imgs)
        lossMean += loss.numpy()/nTestBatches

    lossTest[0].append((e+1) * stepsPerEpoch)
    lossTest[1].append(lossMean)


    # Finally plot some images of the last test batch:
    # originals in the top row, reconstructions in the bottom row
    imgs = imgs.numpy()
    pred = pred.numpy()

    fig, axs = plt.subplots(2,4, figsize=(15,7))
    for b in range(min([4,imgs.shape[0]])):
        axs[0,b].imshow(imgs[b,...])
        axs[1,b].imshow(pred[b,...])

    plt.show()
    plt.close(fig)  # do not pile up one open figure per epoch

    # For early stopping
    if lossMean < bestTestLoss:
        bestTestLoss = lossMean
        earlyStoppingCtr = 0
        autoencoderWrap.model.save_weights("./weightsBest.h5")

    else:
        earlyStoppingCtr += 1

    print(f"Test  - Epoche: {e}, TestLoss: {lossMean:.7f}, BestLoss: {bestTestLoss:.7f}, EarlyStoppingCtr {earlyStoppingCtr}/{patience}")


    # =====================
    # Plot losses (before the early-stopping check, so that the final epoch
    # is plotted as well)
    fig, axs = plt.subplots(1, figsize=(15,7))
    axs.plot(lossTrain[0],lossTrain[1],'b-')
    axs.plot(lossTest[0],lossTest[1],'r-')
    axs.set_yscale('log')
    plt.show()
    plt.close(fig)


    # =====================
    # Early Stopping
    if earlyStoppingCtr >= patience:
        print("Maximum patience reached. Stopping training")
        break

# Postprocessing

In [None]:
# Load weights if exist
if os.path.isfile("./weightsBest.h5"):
    autoencoderWrap.model.load_weights("./weightsBest.h5")


def showReconstructions(title, datapipe, model, ncols=4):
    """Reconstruct the first batch of ``datapipe`` and plot the result.

    The figure has three rows: the original images, the reconstructions and
    the per-pixel absolute error (averaged over the channels, clipped to [0, 1]).

    Args:
        title: Heading rendered above the figure.
        datapipe: Dataset yielding ``(images, paths)`` batches.
        model: Keras model used for the reconstruction.
        ncols: Maximum number of images shown.

    Returns:
        Tuple ``(imgs, pred, error)`` as NumPy arrays for the plotted batch, or
        ``(None, None, None)`` if the dataset yields no batch.
    """
    display(HTML(f'<h1>{title}</h1>'))

    # Only the first batch is needed; do not predict the whole data set
    batch = next(iter(datapipe), None)
    if batch is None:
        print(f"No images available for '{title}'")
        return None, None, None

    imgs, paths = batch

    # Predict
    pred = model.predict(imgs)

    imgs = imgs.numpy()
    error = np.clip(np.abs(pred-imgs).mean(axis=-1), 0, 1)

    # squeeze=False keeps axs two-dimensional even for ncols == 1
    fig, axs = plt.subplots(3, ncols, figsize=(15,7), squeeze=False)
    nshown = min([ncols, imgs.shape[0]])
    for b in range(nshown):
        axs[0,b].imshow(imgs[b,...])
        axs[0,b].title.set_text(f"Original")
        axs[1,b].imshow(pred[b,...])
        axs[1,b].title.set_text(f"Reconstructed")
        axs[2,b].imshow(error[b,...], cmap="gray")
        axs[2,b].title.set_text(f"Error: {np.mean(error[b,...]):.3f}")
        for row in range(3):
            axs[row,b].set_xticks([])
            axs[row,b].set_yticks([])

    # Hide the columns that received no image
    for b in range(nshown, ncols):
        for row in range(3):
            axs[row,b].axis('off')

    plt.show()
    plt.close(fig)

    return imgs, pred, error


# Same visualisation for the three data sets
imgs, pred, error = showReconstructions('Normal data', dpTest, autoencoderWrap.model)
imgs, pred, error = showReconstructions('Anomalous data', dpAnomalous, autoencoderWrap.model)
imgs, pred, error = showReconstructions('Very anomalous data', dpVeryAnomalous, autoencoderWrap.model)