# Set up environment

In [None]:
import math, re, os
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
import random
import cv2
import urllib
from functools import partial
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
import warnings
# Silence every Python warning for the rest of the notebook session.
warnings.filterwarnings(action='ignore')
# Report which TensorFlow build is active.
print(f"Tensorflow version {tf.__version__}")

# Set up global constants

In [None]:
# tf.data auto-tuning sentinel (not used by the ImageDataGenerator pipeline below).
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Number of images per batch for every data generator.
BATCH_SIZE = 64
# Upper bound on training epochs (the EarlyStopping callback may end training sooner).
EPOCHS = 100
# Class names indexed by class id. NOTE(review): flow_from_directory assigns class ids in
# alphanumeric order of the subdirectory names -- confirm this list follows that order.
CLASSES = ['crack', 'non crack']

# Define data loading methods

In [None]:
# Every .jpg file found under <split>/<class-subdirectory>/ for each split.
TRAINING_FILENAMES, VALID_FILENAMES, TEST_FILENAMES = [
    tf.io.gfile.glob(pattern)
    for pattern in (
        '../input/vgg-skz-crack-dataset/train/*/*.jpg',
        '../input/vgg-skz-crack-dataset/validation/*/*.jpg',
        '../input/vgg-skz-crack-dataset/test/*/*.jpg',
    )
]

### Augmentation method (adding random noise to the image)

In [None]:
def add_noise(image, variability=60):
    """Return a copy of *image* with random Gaussian noise added.

    On every call the noise standard deviation is drawn uniformly from
    ``[0, variability)``, so successive calls give differently-noisy images.

    Args:
        image: array-like image with pixel values on a 0-255 scale.
        variability: upper bound of the noise standard deviation
            (60 was the previously hard-coded value).

    Returns:
        An ndarray with the same shape as *image*, clipped to [0, 255].
        Floating-point inputs keep their dtype; integer inputs give float64.
        The input array is never modified.
    """
    image = np.asarray(image)
    out_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float64
    deviation = variability * random.random()
    noise = np.random.normal(0, deviation, image.shape)
    # Out-of-place add: an in-place `+=` of float noise raises on integer arrays
    # and silently mutates the caller's image. np.clip returns a new array, so its
    # result must be kept -- discarding it left the noise unclipped.
    return np.clip(image + noise, 0., 255.).astype(out_dtype, copy=False)

In [None]:
# Applies ResNet50's preprocess_input to every training image.
train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Re-binds the same AUTOTUNE value that the constants cell defines.
AUTOTUNE = tf.data.experimental.AUTOTUNE

def get_training_dataset():
    """Return a shuffled DirectoryIterator over the training split.

    Batches hold BATCH_SIZE images resized to 256x256 with one-hot
    (categorical) labels derived from the class subdirectories.
    """
    return train_datagen.flow_from_directory(
        '../input/vgg-skz-crack-dataset/train',
        target_size=[256, 256],
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=True,
    )


# Same ResNet50 preprocessing as the training generator.
valid_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

def get_validation_dataset():
    """Return a DirectoryIterator over the validation split.

    Batches hold BATCH_SIZE images resized to 256x256 with one-hot labels;
    flow_from_directory's default shuffling is left in place.
    """
    directory = '../input/vgg-skz-crack-dataset/validation'
    return valid_datagen.flow_from_directory(
        directory,
        target_size=[256, 256],
        batch_size=BATCH_SIZE,
        class_mode='categorical',
    )

test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
def get_test_dataset():
    """Return a DirectoryIterator over the test split, in a fixed order.

    Batches hold BATCH_SIZE images resized to 256x256, run through ResNet50's
    preprocess_input, with one-hot (categorical) labels.

    ``shuffle=False`` is required for evaluation: flow_from_directory shuffles
    by default, and then the rows returned by ``model.predict`` cannot be
    matched back to the images, labels or ``filenames`` of the iterator.
    """
    dataset = test_datagen.flow_from_directory(
        '../input/vgg-skz-crack-dataset/test',
        class_mode='categorical',
        target_size=[256, 256],
        batch_size=BATCH_SIZE,
        shuffle=False,
    )
    
    return dataset

# Applies the same ResNet50 preprocessing to ad-hoc image crops (used by predict_on_crops).
test2_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

def count_data_items(filenames):
    """Return the number of entries in *filenames* (one per image file)."""
    total = len(filenames)
    return total

In [None]:
# Number of .jpg files matched by the glob pattern of each split.
NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES = (
    count_data_items(names)
    for names in (TRAINING_FILENAMES, VALID_FILENAMES, TEST_FILENAMES)
)

print(f'Dataset: {NUM_TRAINING_IMAGES} training images, '
      f'{NUM_VALIDATION_IMAGES} validation images, {NUM_TEST_IMAGES} test images')

In [None]:
# NumPy print defaults: summarise arrays with more than 15 elements, wrap lines at 80 columns.
np.set_printoptions(linewidth=80, threshold=15)

def title_from_label_and_target(label, correct_label):
    """Build a subplot title for a predicted class index.

    Returns ``(title, is_correct)``. With no ground truth (``correct_label``
    is None) the title is just the predicted class name. Otherwise it is
    ``"<pred> [OK]"`` or ``"<pred> [NO→<truth>]"``.
    """
    if correct_label is None:
        return CLASSES[label], True
    is_correct = label == correct_label
    predicted_name = CLASSES[int(label)]
    if is_correct:
        verdict, arrow, expected = 'OK', '', ''
    else:
        verdict, arrow, expected = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, arrow, expected), is_correct

def display_one_image(image, title, subplot, red=False, titlesize=16):
    """Draw *image* into grid cell ``subplot = (rows, cols, index)``.

    The image is shown as uint8 without axes. A non-empty *title* is drawn
    in red (slightly smaller) when *red* is true, black otherwise.
    Returns the ``(rows, cols, index + 1)`` tuple for the next cell.
    """
    rows, cols, index = subplot
    plt.subplot(rows, cols, index)
    plt.axis('off')
    plt.imshow(image.astype('uint8'))
    if len(title) > 0:
        if red:
            fontsize, color = int(titlesize / 1.2), 'red'
        else:
            fontsize, color = int(titlesize), 'black'
        plt.title(title, fontsize=fontsize, color=color,
                  fontdict={'verticalalignment': 'center'}, pad=int(titlesize / 1.5))
    return rows, cols, index + 1

def _undo_resnet50_preprocessing(images):
    """Return displayable RGB images in [0, 255] from ResNet50-preprocessed ones.

    ``tf.keras.applications.resnet50.preprocess_input`` flips RGB to BGR and
    subtracts the ImageNet per-channel means (BGR order, no scaling). Adding the
    means back and flipping the channel axis reverses both steps.
    """
    bgr = np.asarray(images, dtype=np.float64) + np.array([103.939, 116.779, 123.68])
    return np.clip(bgr[..., ::-1], 0., 255.)


def display_batch_of_images(directory_iterator, predictions=None, undo_preprocessing=True):
    """Plot the next batch of a Keras image iterator as a grid of titled images.

    Usage:
        display_batch_of_images(iterator)               # titles = true class
        display_batch_of_images(iterator, predictions)  # titles = predicted class,
                                                        # red when it is wrong

    Args:
        directory_iterator: Keras image iterator (e.g. from flow_from_directory)
            whose next batch is ``(images, labels[, weights])``, or just ``images``
            when it has no labels (``class_mode=None``).
        predictions: optional sequence of predicted class indices. Element ``i``
            must belong to image ``i`` of the batch fetched here.
        undo_preprocessing: reverse ResNet50 ``preprocess_input`` before drawing
            (every generator in this notebook applies it). Pass False for raw
            0-255 RGB images.
    """
    batch = next(directory_iterator)
    has_labels = isinstance(batch, tuple)
    if has_labels:
        images, labels = batch[0], batch[1]
        # One-hot (categorical) labels -> class indices; 1-D labels already are indices.
        if np.ndim(labels) > 1:
            labels = np.argmax(labels, axis=-1)
    else:
        images, labels = batch, [None] * len(batch)
    if len(images) == 0:
        return
    if undo_preprocessing:
        # Casting preprocess_input output (mean-subtracted BGR) straight to uint8
        # wraps negative values and swaps red/blue.
        images = _undo_resnet50_preprocessing(images)

    # auto-squaring: images that do not fit into a square-ish grid are dropped
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))
    # magic formula tested to work from 1x1 to 10x10 images
    dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3

    # display
    for i, (image, label) in enumerate(zip(images[:rows * cols], labels[:rows * cols])):
        true_label = None if label is None else int(label)
        title = '' if true_label is None else CLASSES[true_label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], true_label)
        subplot = display_one_image(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    # layout: images are packed tightly only when there are no titles at all
    plt.tight_layout()
    if not has_labels and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

In [None]:
ds_train = get_training_dataset()
ds_valid = get_validation_dataset()

# Preview one batch from each split (each call advances that iterator by one batch).
for preview_ds in (ds_train, ds_valid):
    display_batch_of_images(preview_ds)

# Building the model

### Learning rate schedule

In [None]:
# Learning rate = 1e-5 * 0.9 ** (step / 10000), evaluated at every optimizer step.
lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(1e-5, 10000, 0.9)

### Build the model on a pretrained ResNet50 backbone

In [None]:
# Enable NumPy-style behaviour (ndarray methods, type promotion) on tf.Tensor objects.
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()


# Backbone: ImageNet-pretrained ResNet50 without its classifier, ending in global
# average pooling. It is frozen, so only the head below is trained.
base_model = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', pooling='avg')
base_model.trainable = False

# Backbone followed by a small dense classification head.
model = tf.keras.Sequential()
model.add(base_model)
model.add(tf.keras.layers.Dense(128, activation='linear'))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(2, activation='softmax'))  # one probability per class

model.summary()

model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(epsilon=0.001, learning_rate=lr_scheduler),
    metrics=['accuracy'],
)

### Train the model

In [None]:
# Stop once val_loss has not improved for 10 consecutive epochs, and roll the model
# back to the weights of its best epoch. Without restore_best_weights, the model
# saved after training would carry weights from `patience` epochs past the best one.
early_stopping_callbacks = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=10, restore_best_weights=True)

In [None]:
# Take the step counts from the iterators themselves. len() of a Keras image
# iterator is the number of batches, including a final partial one. As a result:
# - every image is used each epoch;
# - a split smaller than BATCH_SIZE still gets one step instead of 0, which would
#   break validation and leave no 'val_loss';
# - the count matches the files the iterator actually found, not the '*.jpg' globs.
STEPS_PER_EPOCH = len(ds_train)
VALID_STEPS = len(ds_valid)

history = model.fit(
    ds_train,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=ds_valid,
    validation_steps=VALID_STEPS,
    callbacks=[early_stopping_callbacks],
)

# Evaluate the model

In [None]:
# Per-epoch metrics recorded by model.fit.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

# Accuracy figure. It needs its own legend: plt.legend() only applies to the
# current figure, so the single call at the end left this plot unlabeled.
plt.figure()
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.xlabel('Epoch')
plt.legend()

# Loss figure.
plt.figure()
plt.plot(epochs, loss, 'r', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.xlabel('Epoch')
plt.legend()

plt.show()

# Save the model

In [None]:
# Persist the trained model under ./crack_detection_resnet50_model.
model.save(filepath='./crack_detection_resnet50_model')

# Load the model

In [None]:
# model = tf.keras.models.load_model("../input/resnet-new-crack-detection/crack_detection_resnet50_model")

# Run predictions

In [None]:
ds_test = get_test_dataset()
# Round up so the final partial batch is predicted too (floor division dropped it).
STEP_SIZE_TEST = math.ceil(ds_test.n / ds_test.batch_size)
ds_test.reset()
probabilities = model.predict(ds_test, steps=STEP_SIZE_TEST, verbose=1)
predictions = np.argmax(probabilities, axis=-1)

# Rewind so the batch displayed next is the first one, whose images are those
# predictions[:batch_size] refer to. This alignment holds only when the test
# iterator is unshuffled (flow_from_directory(..., shuffle=False)).
ds_test.reset()
display_batch_of_images(ds_test, predictions)

### Run predictions on a whole image (tile by tile)

In [None]:
def _read_image_bgr(input_image, https):
    """Load *input_image* (URL if *https*, else local path) as a 3-channel BGR uint8 array.

    Grayscale images are expanded to 3 channels, and an alpha channel (possible
    with IMREAD_UNCHANGED) is dropped, because the model takes 3-channel input.

    Raises:
        ValueError: if the image cannot be read or decoded.
    """
    if https:
        # `import urllib` alone does not guarantee that the `request` submodule is loaded.
        import urllib.request
        with urllib.request.urlopen(input_image) as response:
            arr = np.asarray(bytearray(response.read()), dtype=np.uint8)
        im = cv2.imdecode(arr, cv2.IMREAD_UNCHANGED)
    else:
        im = cv2.imread(input_image)
    if im is None:
        raise ValueError('Could not read image: {}'.format(input_image))
    if im.ndim == 2:
        im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
    elif im.shape[2] == 4:
        im = cv2.cvtColor(im, cv2.COLOR_BGRA2BGR)
    return im


def _classify_crop(crop_bgr):
    """Return the predicted class name (an entry of CLASSES) for one BGR uint8 crop."""
    # cv2 decodes to BGR, but the training images came from flow_from_directory
    # (RGB), and preprocess_input expects RGB input. Convert first so inference
    # sees the same channel order as training.
    crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
    batch = test2_datagen.flow(np.expand_dims(crop_rgb, axis=0), batch_size=1, shuffle=False).next()
    probabilities = model.predict(batch, verbose=0)
    return CLASSES[int(np.argmax(probabilities[0]))]


def predict_on_crops(input_image, https=False, height=256, width=256, save_crops = False):
    """Classify an image tile by tile, then save and show the annotated mosaic.

    The image is cut into a grid of ``height`` x ``width`` tiles. Tiles on the
    right/bottom edge may be smaller and are classified as they are. Each tile
    gets its predicted class written on it and is tinted red for 'crack',
    green otherwise. The reassembled image is written to
    ``predictions/out_<name>.jpg`` and displayed with matplotlib.

    Args:
        input_image: local file path, or a URL when ``https`` is true.
        https: download ``input_image`` instead of reading it from disk.
        height, width: tile size in pixels.
        save_crops: also write each annotated tile to
            ``predictions/out_<name>/img_<k>.png``.

    Raises:
        ValueError: if the image cannot be read or decoded.
    """
    import urllib.parse

    im = _read_image_bgr(input_image, https)
    img_height, img_width = im.shape[:2]

    source_path = urllib.parse.urlparse(input_image).path if https else input_image
    image_name = os.path.splitext(os.path.basename(source_path))[0]
    folder_name = 'out_' + image_name
    crops_dir = os.path.join('predictions', folder_name)
    # Create the output folder up front: cv2.imwrite returns False instead of
    # raising when the target directory does not exist.
    os.makedirs(crops_dir if save_crops else 'predictions', exist_ok=True)

    output_image = np.zeros_like(im)
    k = 0
    for i in range(0, img_height, height):
        for j in range(0, img_width, width):
            crop = im[i:i + height, j:j + width]
            predicted_class = _classify_crop(crop)
            # BGR colours: red for cracks, green otherwise.
            color = (0, 0, 255) if predicted_class == 'crack' else (0, 255, 0)
            annotated = crop.copy()
            cv2.putText(annotated, predicted_class, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 1, cv2.LINE_AA)
            tint = np.full_like(annotated, color)
            # Keep uint8 output (the default dtype) so the tile can go straight to cv2.imwrite.
            tinted = cv2.addWeighted(annotated, 0.9, tint, 0.1, 0)
            if save_crops:
                cv2.imwrite(os.path.join(crops_dir, 'img_{}.png'.format(k)), tinted)
            output_image[i:i + height, j:j + width] = tinted
            k += 1

    cv2.imwrite(os.path.join('predictions', folder_name + '.jpg'), output_image)

    plt.figure(figsize=(10, 10))
    plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))

In [None]:
# Tile-wise crack prediction on a sample full-size image.
predict_on_crops(input_image='../input/crack-test/test_big/00001.jpg')