In [None]:
!pip install --user "image-quality[dataset]>=1.2.4"

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, GlobalAveragePooling2D
from notebooks.utils import (
    show_images,
    gaussian_filter,
    image_normalization,
    rescale,
    image_shape)
import imquality.datasets

In [None]:
print(f'tensorflow version {tf.__version__}')

# Introduction

In this tutorial, we will implement the Deep CNN-Based Blind Image Quality Predictor (DIQA) methodology proposed by Jongyoo Kim, Anh-Duc Nguyen, and Sanghoon Lee [1]. Along the way, we will go through the following TensorFlow 2.0 concepts:
- Download and prepare a dataset using a *tf.data.Dataset builder*.
- Define a TensorFlow input pipeline to pre-process the dataset records using the *tf.data* API.
- Create the CNN model using the *tf.keras* functional API.
- Define a custom training loop for the objective error map model.
- Train the objective error map and subjective score model.
- Use the trained subjective score model to make predictions.

*Note: Some of the functions are implemented in [utils.py](https://github.com/ocampor/image-quality/blob/master/notebooks/utils.py) as they are out of the guide's scope.*

## What is DIQA?

DIQA is an original proposal that focuses on solving some of the most pressing challenges of applying deep learning to image quality assessment (IQA). Its advantages over other methodologies are:

- The model is not limited to work exclusively with Natural Scene Statistics (NSS) images [1].
- It prevents overfitting by splitting the training into two phases: (1) feature learning and (2) mapping the learned features to subjective scores.

## Problem

The cost of generating datasets for IQA is high since it requires expert supervision. As a result, the fundamental IQA benchmarks comprise only a few thousand records. This complicates the creation of deep learning models because they require large amounts of training samples to generalize.

As an example, let's consider the datasets most frequently used to train and evaluate IQA methods: [Live](https://live.ece.utexas.edu/research/quality/subjective.htm), [TID2008](http://www.ponomarenko.info/tid2008.htm), [TID2013](http://www.ponomarenko.info/tid2013.htm), and [CSIQ](http://vision.eng.shizuoka.ac.jp/mod/page/view.php?id=23). The next table summarizes each dataset:

| Dataset | References | Distortions | Severity | Total Samples |
|---------|------------|-------------|----------|---------------|
| LiveIQA | 29         | 5           | 5        | 1011          |
| TID2008 | 25         | 17          | 5        | 1701          |
| TID2013 | 25         | 24          | 5        | 3025          |
| CSIQ    | 30         | 6           | 5        | 930           |

The total amount of samples does not exceed 4,000 records for any of them.

# Dataset

The IQA benchmarks contain only a limited number of records, which might not be enough to train a CNN. However, for the purposes of this guide, we are going to use the [Live](https://live.ece.utexas.edu/research/quality/subjective.htm) dataset. It comprises 29 reference images and 5 different distortions with 5 severity levels each.

The first task is to download and prepare the dataset. I have created a couple of TensorFlow dataset builders
for image quality assessment and published them in the [image-quality](https://github.com/ocampor/image-quality) package. The builders
implement the interface defined by [tensorflow-datasets](https://www.tensorflow.org/datasets).

*Note: This process might take several minutes because of the size of the dataset (700 megabytes).*

In [None]:
builder = imquality.datasets.LiveIQA()
builder.download_and_prepare()

After downloading and preparing the data, turn the builder into a dataset and shuffle it. Note that the batch size is 1. The reason is that each image has a different shape, so TensorFlow raises an error if the batch size is increased.

In [None]:
ds = builder.as_dataset(shuffle_files=True)['train']
ds = ds.shuffle(1024).batch(1)
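
To see why a larger batch fails, recall that a batch is a single tensor, so every image in it must share one shape. The following is a minimal sketch that uses NumPy as a stand-in for TensorFlow; the two image shapes are made up for illustration and are not taken from the dataset.

```python
import numpy as np

# Two hypothetical images with different heights (shapes are illustrative only).
image_a = np.zeros((4, 6, 3), dtype=np.float32)
image_b = np.zeros((5, 6, 3), dtype=np.float32)

try:
    # Stacking into one batch requires a single common shape.
    batch = np.stack([image_a, image_b])
    stacked = True
except ValueError as error:
    stacked = False
    print(f'cannot batch images of different shapes: {error}')

# With one image per batch there is nothing to reconcile.
single = np.stack([image_a])
print(single.shape)
```

Resizing or padding every image to a common shape would allow larger batches, but it would also alter the images whose quality is being assessed, so this guide keeps a batch size of 1.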

The output is an iterable *tf.data.Dataset*; therefore, trying to access it with the bracket operator [ ] causes an error. There are two ways to access the images in the dataset. The first way is to turn the dataset into an iterator and extract a single sample at a time using the *next* function.

In [None]:
next(iter(ds)).keys()

As you can see, the output is a dictionary that contains the tensor representation of the distorted image, the reference image, and the subjective score (dmos).

The second way is to take samples from the dataset with a for loop:

In [None]:
for features in ds.take(2):
    distorted_image = features['distorted_image']
    reference_image = features['reference_image']
    # tf.round does not accept a number of decimals, so the score is
    # converted to a Python float and rounded when it is formatted.
    dmos = float(features['dmos'][0])
    distortion = features['distortion'][0]
    print(f'The DMOS of the image is {dmos:.2f} with'
          f' a distortion {distortion} and shape {distorted_image.shape}')
    show_images([reference_image, distorted_image])

# Methodology

## Image Normalization

The first step for DIQA is to pre-process the images. The image is converted into grayscale, and then its low-frequency component is subtracted, which leaves only the high-frequency content. The normalized image is defined as:

\begin{align*}
\hat{I} = I_{gray} - I^{low}
\end{align*}

where the low-frequency image is the result of the following algorithm:

1. Blur the grayscale image.
2. Downscale it by a factor of 1 / 4.
3. Upscale it back to the original size.

The main reasons for this normalization are (1) the Human Visual System (HVS) is not sensitive to changes in the low-frequency band, and (2) image distortions barely affect the low-frequency component of images.

In [None]:
def image_preprocess(image: tf.Tensor) -> tf.Tensor:
    image = tf.cast(image, tf.float32)
    image = tf.image.rgb_to_grayscale(image)
    image_low = gaussian_filter(image, 16, 7 / 6)
    image_low = rescale(image_low, 1 / 4, method=tf.image.ResizeMethod.BICUBIC)
    image_low = tf.image.resize(image_low, size=image_shape(image), method=tf.image.ResizeMethod.BICUBIC)
    return image - tf.cast(image_low, image.dtype)
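
To build some intuition for what this normalization removes, here is a minimal NumPy sketch of the same three steps. It is not the implementation used in this guide: a 3x3 box blur replaces the Gaussian filter, and nearest-neighbour sampling replaces bicubic resizing. The helper names `box_blur` and `normalize_sketch` are hypothetical.

```python
import numpy as np

def box_blur(image, k=3):
    """Blur with a k x k mean filter (a stand-in for the Gaussian filter)."""
    pad = k // 2
    padded = np.pad(image, pad, mode='edge')
    height, width = image.shape
    out = np.zeros(image.shape, dtype=np.float64)
    for dy in range(k):
        for dx in range(k):
            out += padded[dy:dy + height, dx:dx + width]
    return out / (k * k)

def normalize_sketch(gray, factor=4):
    """Subtract a blurred, downscaled, and upscaled copy from the image."""
    low = box_blur(gray)
    # Downscale by 1 / factor using nearest-neighbour sampling.
    low = low[::factor, ::factor]
    # Upscale back to the original size by repeating pixels, then crop.
    low = np.repeat(np.repeat(low, factor, axis=0), factor, axis=1)
    low = low[:gray.shape[0], :gray.shape[1]]
    return gray - low

# A flat image has only low-frequency content.
flat = np.full((8, 8), 100.0)
# A checkerboard is dominated by high-frequency content.
checker = (np.indices((8, 8)).sum(axis=0) % 2) * 255.0

flat_out = normalize_sketch(flat)
checker_out = normalize_sketch(checker)
print(np.abs(flat_out).max())
print(np.abs(checker_out).max())
```

The flat image is cancelled out completely, while the checkerboard keeps a large part of its signal. This is the behaviour expected when the low-frequency band is removed.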

In [None]:
for features in ds.take(2):
    distorted_image = features['distorted_image']
    reference_image = features['reference_image']
    I_d = image_preprocess(distorted_image)
    I_d = tf.image.grayscale_to_rgb(I_d)
    I_d = image_normalization(I_d, 0, 1)
    show_images([reference_image, I_d])

**Fig 1.** On the left, the reference image. On the right, the distorted image after subtracting its low-frequency component.

## Objective Error Map

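For the first model, objective error maps are used as a proxy target. Because they are computed directly from the reference and distorted images, they provide a per-pixel training signal, which has the effect of increasing the training data. The loss function is defined by the mean squared error between the predicted and ground-truth error maps. The ground-truth error map is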

\begin{align*}
\mathbf{e}_{gt} = err(\hat{I}_r, \hat{I}_d)
\end{align*}

where *err(·)* is an error function. For this implementation, the authors recommend using

\begin{align*}
\mathbf{e}_{gt} = | \hat{I}_r -  \hat{I}_d | ^ p
\end{align*}

with *p=0.2*. The exponent keeps most of the values in the error map from being small or close to zero.

In [None]:
def error_map(reference: tf.Tensor, distorted: tf.Tensor, p: float=0.2) -> tf.Tensor:
    assert reference.dtype == tf.float32 and distorted.dtype == tf.float32, 'dtype must be tf.float32'
    return tf.pow(tf.abs(reference - distorted), p)
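
The effect of the exponent is easier to see with numbers. The sketch below is a NumPy counterpart of the function above (the name `error_map_np` is hypothetical) and assumes, only for illustration, that the absolute differences lie in the range [0, 1].

```python
import numpy as np

def error_map_np(reference, distorted, p=0.2):
    """NumPy counterpart of the error map: |reference - distorted| ** p."""
    return np.abs(reference - distorted) ** p

reference = np.zeros(5)
# Hypothetical pixel differences in [0, 1].
distorted = np.array([0.0, 0.001, 0.01, 0.1, 1.0])

linear = error_map_np(reference, distorted, p=1.0)
compressed = error_map_np(reference, distorted, p=0.2)
print(np.round(linear, 3))
print(np.round(compressed, 3))
```

With *p=1*, a difference of 0.001 stays at 0.001. With *p=0.2*, it maps to roughly 0.25, so small but nonzero errors remain clearly visible in the target. An exact zero still maps to zero.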

In [None]:
for features in ds.take(3):
    reference_image = features['reference_image'] 
    I_r = image_preprocess(reference_image)
    I_d = image_preprocess(features['distorted_image'])
    e_gt = error_map(I_r, I_d, 0.2)
    I_d = image_normalization(tf.image.grayscale_to_rgb(I_d), 0, 1)
    e_gt = image_normalization(tf.image.grayscale_to_rgb(e_gt), 0, 1)
    show_images([reference_image, I_d, e_gt])

**Fig 2.** On the left, the reference image. In the middle, the pre-processed distorted image. On the right, the image representation of the error map.

## Reliability Map

According to the authors, the model is likely to fail on images with homogeneous regions. To prevent this, they propose a reliability function. The assumption is that blurry areas have lower reliability than textured ones. The reliability function is defined as

\begin{align*}
\mathbf{r} = \frac{2}{1 + exp(-\alpha|\hat{I}_d|)} - 1
\end{align*}

where α controls the saturation property of the reliability map. The positive part of a sigmoid is used so that even pixels with low intensity receive sufficiently large values.

In [None]:
def reliability_map(distorted: tf.Tensor, alpha: float) -> tf.Tensor:
    assert distorted.dtype == tf.float32, 'The Tensor must be of dtype tf.float32'
    return 2 / (1 + tf.exp(- alpha * tf.abs(distorted))) - 1
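
A few sample values help to understand the shape of this function. The following NumPy sketch (the name `reliability_np` is hypothetical, and the inputs are arbitrary) also checks a handy identity: the expression is equal to tanh(α|x| / 2).

```python
import numpy as np

def reliability_np(x, alpha):
    """NumPy counterpart of the reliability map."""
    return 2.0 / (1.0 + np.exp(-alpha * np.abs(x))) - 1.0

# Arbitrary normalized intensities, including a negative one.
x = np.array([0.0, 0.5, 2.0, 10.0, -10.0])
r = reliability_np(x, alpha=1.0)
print(np.round(r, 4))

# Same values computed through the tanh identity.
r_tanh = np.tanh(1.0 * np.abs(x) / 2.0)
print(np.allclose(r, r_tanh))
```

The reliability is 0 where the normalized image is 0 (homogeneous regions), grows with the absolute intensity, and saturates towards 1. Larger values of α make it saturate sooner.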

Used as is, the reliability map might directly affect the predicted score. Therefore, the average reliability map, which is the reliability map divided by its mean, is used instead.

\begin{align*}
\mathbf{\hat{r}} = \frac{1}{\frac{1}{H_rW_r}\sum_{(i,j)}\mathbf{r}(i,j)}\mathbf{r}
\end{align*}

For the TensorFlow function, we just calculate the reliability map and divide it by its mean.

In [None]:
def average_reliability_map(distorted: tf.Tensor, alpha: float) -> tf.Tensor:
    r = reliability_map(distorted, alpha)
    return r / tf.reduce_mean(r)
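
One consequence of dividing by the mean is that the resulting map always averages to 1, so it redistributes the weight between pixels without changing the overall scale. The NumPy sketch below illustrates this with random data standing in for a normalized image patch; the function name and the data are assumptions made for the example.

```python
import numpy as np

def average_reliability_np(distorted, alpha):
    """NumPy counterpart of the average reliability map."""
    r = 2.0 / (1.0 + np.exp(-alpha * np.abs(distorted))) - 1.0
    return r / r.mean()

rng = np.random.default_rng(0)
# Random values standing in for a normalized image patch.
I_d = rng.normal(scale=20.0, size=(6, 6))

r_hat = average_reliability_np(I_d, alpha=0.2)
print(r_hat.mean())
```

Note that a perfectly flat patch produces a reliability of 0 everywhere, and the division by the mean then results in NaN values. Adding a small constant to the denominator is one possible guard, although it is not part of the definition above.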

In [None]:
for features in ds.take(2):
    reference_image = features['reference_image'] 
    I_d = image_preprocess(features['distorted_image'])
    r = average_reliability_map(I_d, 1)
    r = image_normalization(tf.image.grayscale_to_rgb(r), 0, 1)
    show_images([reference_image, r], cmap='gray')

**Fig 3.** On the left, the original image, and on the right, its average reliability map.

## Loss function

The loss function is defined as the mean squared error of the product between the reliability map and the error, where the error is the difference between the predicted error map and the ground-truth error map.

\begin{align*}
\mathcal{L}_1(\hat{I}_d; \theta_f, \theta_g) = ||(g(f(\hat{I}_d, \theta_f), \theta_g) - \mathbf{e}_{gt}) \odot \mathbf{\hat{r}}||^2_2
\end{align*}

The loss function requires multiplying the error by the reliability map; therefore, we cannot use the default loss implementation *tf.losses.MeanSquaredError* as is.


In [None]:
def loss(model, x, y_true, r):
    y_pred = model(x)
    return tf.reduce_mean(tf.square((y_true - y_pred) * r))
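
The following NumPy sketch shows how the reliability map changes the result. The function `weighted_mse` is a hypothetical counterpart of the loss above that receives the prediction directly instead of the model, and the 2x2 arrays are made-up values.

```python
import numpy as np

def weighted_mse(y_true, y_pred, r):
    """Mean of the squared, reliability-weighted error."""
    return float(np.mean(np.square((y_true - y_pred) * r)))

y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = np.array([[1.5, 2.0], [3.0, 0.0]])

# Uniform reliability: identical to the plain mean squared error.
uniform = np.ones_like(y_true)
# Zero reliability on the bottom-right pixel: its large error is ignored.
masked = np.array([[1.0, 1.0], [1.0, 0.0]])

loss_uniform = weighted_mse(y_true, y_pred, uniform)
loss_masked = weighted_mse(y_true, y_pred, masked)
print(loss_uniform, loss_masked)
```

With a uniform map, the loss is the ordinary mean squared error. When the reliability of a pixel is 0, its error no longer contributes, which is how unreliable homogeneous regions are kept from dominating the training.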

After creating the custom loss, we need to tell TensorFlow how to differentiate it. The good thing is that we can take advantage of [automatic differentiation](https://www.tensorflow.org/tutorials/customization/autodiff) using *tf.GradientTape*.

In [None]:
def gradient(model, x, y_true, r):
    with tf.GradientTape() as tape:
        loss_value = loss(model, x, y_true, r)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)
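
As a sanity check on what the tape returns, the gradient of the loss implemented above has a simple closed form with respect to the prediction. The symbols $\mathbf{\hat{e}}$ (the predicted error map) and $N$ (the number of pixels) are introduced here for the derivation and are not taken from the paper. The reliability map is treated as a constant because it is an input of the loss.

\begin{align*}
\mathcal{L} &= \frac{1}{N}\sum_{(i,j)}\left[(\mathbf{e}_{gt}(i,j) - \mathbf{\hat{e}}(i,j))\,\mathbf{\hat{r}}(i,j)\right]^2 \\
\frac{\partial \mathcal{L}}{\partial \mathbf{\hat{e}}(i,j)} &= \frac{2}{N}\,\mathbf{\hat{r}}(i,j)^2\,(\mathbf{\hat{e}}(i,j) - \mathbf{e}_{gt}(i,j))
\end{align*}

Starting from this quantity, *tf.GradientTape* applies the chain rule through g(·) and f(·) to obtain the gradient with respect to every trainable variable.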

## Optimizer
The authors suggested using a Nadam optimizer with a learning rate of *2e-4*.

In [None]:
optimizer = tf.optimizers.Nadam(learning_rate=2 * 10 ** -4)

# Training

## Objective Error Model
For the training phase, it is convenient to use *tf.data* input pipelines to produce cleaner and more readable code. The only requirement is to create the function to apply to the input.

In [None]:
def calculate_error_map(features):
    I_d = image_preprocess(features['distorted_image'])
    I_r = image_preprocess(features['reference_image'])
    r = rescale(average_reliability_map(I_d, 0.2), 1 / 4)
    e_gt = rescale(error_map(I_r, I_d, 0.2), 1 / 4)
    return (I_d, e_gt, r)

Then, map the *calculate_error_map* function over the *tf.data.Dataset*.

In [None]:
train = ds.map(calculate_error_map)

Applying the transformation takes almost no time. The reason is that the processor does not perform any operation on the data yet; the work happens on demand. This concept is commonly called [lazy evaluation](https://wiki.python.org/moin/Generators).
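
The same idea can be shown with plain Python, since the built-in *map* is also lazy. This sketch is only an analogy for the behaviour of *tf.data*; the function `expensive` and the list `calls` are hypothetical names used to record when the work actually happens.

```python
calls = []

def expensive(x):
    """Record every call so we can see when the work is done."""
    calls.append(x)
    return x * x

# Creating the mapping does not call the function yet.
lazy = map(expensive, range(5))
before = len(calls)

# The function runs only when an element is requested.
first = next(lazy)
after = len(calls)

print(before, first, after)
```

No call is made when the mapping is created, and exactly one call is made when the first element is requested. In the same way, the images are pre-processed only when the training loop iterates over the dataset.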

So far, the following components are implemented:
1. The generator that pre-processes the input and calculates the target.
2. The loss and gradient functions required for the custom training loop.
3. The optimizer function.

The only missing piece is the definition of the models.


![alt text](https://d3i71xaburhd42.cloudfront.net/4b1f961ae1fac044c23c51274d92d0b26722f877/4-Figure2-1.png "CNN architecture")



**Fig 4.** Architecture of the objective error model and subjective score model.

The previous image depicts how:
- The pre-processed image gets into the convolutional neural network (CNN).
- It is transformed by 8 convolutions with the ReLU activation function and "same" padding. This is defined as f(·).
- The output of f(·) is processed by the last convolution with a linear activation function. This is defined as g(·).

In [None]:
input = tf.keras.Input(shape=(None, None, 1), batch_size=1, name='original_image')
f = Conv2D(48, (3, 3), name='Conv1', activation='relu', padding='same')(input)
f = Conv2D(48, (3, 3), name='Conv2', activation='relu', padding='same', strides=(2, 2))(f)
f = Conv2D(64, (3, 3), name='Conv3', activation='relu', padding='same')(f)
f = Conv2D(64, (3, 3), name='Conv4', activation='relu', padding='same', strides=(2, 2))(f)
f = Conv2D(64, (3, 3), name='Conv5', activation='relu', padding='same')(f)
f = Conv2D(64, (3, 3), name='Conv6', activation='relu', padding='same')(f)
f = Conv2D(128, (3, 3), name='Conv7', activation='relu', padding='same')(f)
f = Conv2D(128, (3, 3), name='Conv8', activation='relu', padding='same')(f)
g = Conv2D(1, (1, 1), name='Conv9', padding='same', activation='linear')(f)

objective_error_map = tf.keras.Model(input, g, name='objective_error_map')

objective_error_map.summary()
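
Two of the convolutions use a stride of 2, which is why the targets are rescaled by 1 / 4 in *calculate_error_map*. With "same" padding, the output size of a convolution is the input size divided by the stride, rounded up. The sketch below applies that rule to the nine layers; the helper names and the example image sizes are assumptions made for illustration.

```python
import math

# Strides of Conv1 to Conv9 as defined in the model above.
STRIDES = (1, 2, 1, 2, 1, 1, 1, 1, 1)

def same_conv_output(size, stride):
    """Output size of a convolution with 'same' padding."""
    return math.ceil(size / stride)

def objective_map_shape(height, width, strides=STRIDES):
    """Spatial shape of the predicted error map for a given input shape."""
    for stride in strides:
        height = same_conv_output(height, stride)
        width = same_conv_output(width, stride)
    return height, width

print(objective_map_shape(512, 768))
print(objective_map_shape(500, 375))
```

When both dimensions are divisible by 4, the output is exactly a quarter of the input in each dimension. Otherwise, the result depends on rounding, so it is worth checking that the rescaled targets and the prediction have the same shape.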

For the custom training loop, it is necessary to:

1. Define a metric to measure the performance of the model.
2. Calculate the loss and the gradients.
3. Use the optimizer to update the weights.
4. Print the metric.

In [None]:
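for epoch in range(1):
    # Mean squared error between the ground-truth and predicted error maps,
    # accumulated over the whole epoch.
    epoch_accuracy = tf.keras.metrics.MeanSquaredError()
    
    step = 0
    for I_d, e_gt, r in train:
        # Compute the reliability-weighted loss and its gradients.
        loss_value, gradients = gradient(objective_error_map, I_d, e_gt, r)
        # Update the weights with the optimizer.
        optimizer.apply_gradients(zip(gradients, objective_error_map.trainable_variables))
        
        epoch_accuracy(e_gt, objective_error_map(I_d))

        if step % 100 == 0:
            print('step %s: mean squared error = %s' % (step, epoch_accuracy.result()))
        
        step += 1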

*Note: It would be a good idea to use the Spearman’s rank-order correlation coefficient (SRCC) or Pearson’s linear correlation coefficient (PLCC) as accuracy metrics.*
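
Both coefficients are short to compute. The following is a minimal NumPy sketch, not part of the original training loop: the function names are hypothetical, the rank computation does not average ties, and the scores are made-up values. For real evaluations, *scipy.stats.pearsonr* and *scipy.stats.spearmanr* are well-tested alternatives.

```python
import numpy as np

def plcc(x, y):
    """Pearson's linear correlation coefficient."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.corrcoef(x, y)[0, 1])

def ranks(values):
    """Ordinal ranks; ties are NOT averaged (a simplifying assumption)."""
    values = np.asarray(values, dtype=float)
    order = np.argsort(values)
    result = np.empty(len(values), dtype=float)
    result[order] = np.arange(len(values))
    return result

def srcc(x, y):
    """Spearman's rank-order correlation: PLCC computed on the ranks."""
    return plcc(ranks(x), ranks(y))

# Made-up scores: the predictions are monotonic in the target but not linear.
dmos = [10.0, 20.0, 30.0, 40.0, 50.0]
pred = [1.0, 8.0, 27.0, 64.0, 125.0]

plcc_value = plcc(dmos, pred)
srcc_value = srcc(dmos, pred)
print(round(plcc_value, 4), round(srcc_value, 4))
```

SRCC only measures whether the predictions preserve the ordering of the subjective scores, so it is close to 1 here, while PLCC is lower because the relationship is not linear.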

## Subjective Score Model

To create the subjective score model, let's use the output of f(·) to train a regressor. 

In [None]:
v = GlobalAveragePooling2D(data_format='channels_last')(f)
h = Dense(128, activation='relu')(v)
h = Dense(1)(h)
subjective_error = tf.keras.Model(input, h, name='subjective_error')

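subjective_error.compile(
    # Use a fresh optimizer with the same settings; recent Keras versions
    # may reject an optimizer that already tracks another model's variables.
    optimizer=tf.optimizers.Nadam(learning_rate=2e-4),
    loss=tf.losses.MeanSquaredError(),
    metrics=[tf.metrics.MeanSquaredError()])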

subjective_error.summary()

Training a model with the *fit* method of *tf.keras.Model* expects a dataset that yields two elements. The first one is the input, and the second one is the target.

In [None]:
def calculate_subjective_score(features):
    I_d = image_preprocess(features['distorted_image'])
    mos = features['dmos']
    return (I_d, mos)

train = ds.map(calculate_subjective_score)

Then, *fit* the subjective score model.

In [None]:
history = subjective_error.fit(train, epochs=1)

# Prediction

Predicting with the already trained model is simple. Just use the *predict* method in the model.

In [None]:
sample = next(iter(ds))
I_d = image_preprocess(sample['distorted_image'])
target = sample['dmos'][0]
prediction = subjective_error.predict(I_d)[0][0]

print(f'the predicted value is: {prediction:.4f} and target is: {target:.4f}')

# Conclusion

In this article, we learned how to use the *tf.data* module to create easy-to-read and memory-efficient data pipelines. We also implemented the Deep CNN-Based Blind Image Quality Predictor (DIQA) model using the Keras functional API. The objective error map model was trained with a custom training loop that uses the automatic differentiation feature of TensorFlow.

The next step is to find the hyperparameters that maximize the PLCC or SRCC metrics and to evaluate the overall performance of the model against other methodologies.

Another idea is to train the objective error map model on a much larger dataset and see how the overall performance changes.

# Related Articles

If you want to learn more about image quality assessment methodologies, you can read:

http://bit.ly/advanced-iqa

Also, take a look at an image quality assessment method based on natural scene statistics and handcrafted features.

http://bit.ly/brisque-article

# Jupyter Notebook

http://bit.ly/train-diqa-github 

# Bibliography

[1] Kim, J., Nguyen, A. D., & Lee, S. (2019). Deep CNN-Based Blind Image Quality Predictor. IEEE Transactions on Neural Networks and Learning Systems. https://doi.org/10.1109/TNNLS.2018.2829819