# TESTING THE PIXELATED IMAGE CORRECTION MODEL

## UPGRADING THE LIBRARIES TO THE REQUIRED VERSIONS

In [None]:
%pip install --upgrade pip
%pip install --upgrade tensorflow
%pip install --upgrade scikit-learn
%pip install --upgrade pydot
%pip install --upgrade matplotlib

## IF YOU RUN INTO ISSUES WHILE LOADING THE MODELS OR IMPORTING LIBRARIES, UNCOMMENT THE LINES BELOW
## TO INSTALL THE EXACT LIBRARY VERSIONS THE MODELS WERE CREATED WITH

# %pip install --upgrade pip==24.1.1
# %pip install --upgrade ipykernel==5.5.6
# %pip install --upgrade numpy==1.25.2
# %pip install --upgrade pandas==2.0.3
# %pip install --upgrade tensorflow==2.16.2
# %pip install --upgrade keras==3.4.1
# %pip install --upgrade scikit-learn==1.5.0
# %pip install --upgrade matplotlib==3.9.0
# %pip install --upgrade pydot==2.0.0

## IMPORTING THE LIBRARIES

In [1]:
import os
import tensorflow as tf
import tensorflow.keras as tfk
from tensorflow.keras.applications import VGG19, vgg19

## REQUIRED FUNCTIONS

These are the custom loss and metric functions required by the `depixelator` model; they are registered with Keras so that the saved models can be loaded.

In [2]:
# ImageNet-pretrained VGG19 backbone without its classifier head; the
# (None, None, 3) input shape lets it accept images of any spatial size.
vgg = VGG19(include_top=False, weights='imagenet', input_shape=(None, None, 3))

# Feature extractor returning three intermediate activations of the backbone.
# In this notebook it is used only inside compute_perceptual_loss.
vgg_model = tfk.Model(
    inputs=vgg.input,
    outputs=[vgg.get_layer(layer_name).output for layer_name in ('block3_conv4', 'block4_conv1', 'block4_conv2')],
)

# Frozen: its weights are never updated.
vgg_model.trainable = False

@tfk.utils.register_keras_serializable()
def compute_perceptual_loss(y_true, y_pred):
    """
    Perceptual loss between two image batches, measured in VGG19 feature space.

    Both batches are multiplied by 255 (so they are presumably expected in [0, 1])
    and passed through `vgg19.preprocess_input` before being fed to `vgg_model`.
    The loss is the sum, over every feature map `vgg_model` returns, of the mean
    absolute difference between the two batches' features.
    """
    true_features = vgg_model(vgg19.preprocess_input(y_true * 255.0))
    pred_features = vgg_model(vgg19.preprocess_input(y_pred * 255.0))

    # Starting from 0.0 and adding left to right mirrors a plain accumulator loop.
    return sum(
        (tf.reduce_mean(tf.abs(true_map - pred_map)) for true_map, pred_map in zip(true_features, pred_features)),
        0.0,
    )

@tfk.utils.register_keras_serializable()
def compute_psnr(y_true, y_pred):
    """Peak signal-to-noise ratio for each image pair, for images whose values span [0, 1]."""
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)

@tfk.utils.register_keras_serializable()
def compute_ssim(y_true, y_pred):
    """Structural similarity (SSIM) for each image pair, for images whose values span [0, 1]."""
    dynamic_range = 1.0
    return tf.image.ssim(y_true, y_pred, max_val=dynamic_range)

This is the definition of the `depixelate_images` function.

In [3]:
def depixelate_images(detector, depixelator, pixelated_image_paths, depixelated_dir, real_image_paths=None):
    """
    Depixelate a list of pixelated images using the specified detector and depixelator models.
    Optionally, compare the depixelated images with real images if provided.

    Each image is zero-padded so its height and width are multiples of 64 before it is
    fed to the models. The model output is then cropped back to the original size
    before metrics are computed and before it is saved, so the saved file has the
    same dimensions as the input image.

    Parameters:
    - detector: The model used to detect pixelated regions. The depixelator is applied
      only when the detector output is greater than 0.
    - depixelator: The model used to depixelate the images.
    - pixelated_image_paths: List of file paths to the pixelated images.
    - depixelated_dir: Directory where depixelated images will be saved (created if missing).
      Each result is written as a PNG named `pred_<input name without extension>.png`.
    - real_image_paths: Optional list of file paths to the real images for comparison.
      When non-empty it must have the same length as `pixelated_image_paths`, with the
      ground truth for `pixelated_image_paths[i]` at index `i`.

    Raises:
    - ValueError: If `real_image_paths` is non-empty and its length differs from
      `pixelated_image_paths`.
    """

    def load_image(image_path):
        """
        Read and decode an image file.

        Parameters:
        - image_path: Path to the image file.

        Returns:
        - float32 tensor of shape (height, width, 3); decoding to a float dtype makes
          TensorFlow rescale the pixel values to [0, 1].
        """
        img = tf.io.read_file(image_path)
        return tf.image.decode_image(img, channels=3, expand_animations=False, dtype=tf.float32)

    def pad_for_model(img):
        """
        Zero-pad an image so both spatial dimensions are multiples of 64, keeping the
        original content centred, and add a batch axis.

        Parameters:
        - img: Image tensor of shape (height, width, 3).

        Returns:
        - The padded tensor of shape (1, padded_height, padded_width, 3).
        - A tuple (upper, left, height, width) locating the original image inside the
          padded one, in the argument order of tf.image.crop_to_bounding_box.
        """
        shape = tf.shape(img)
        height, width = shape[0], shape[1]
        # NOTE(review): presumably the models need sizes divisible by 64 — confirm
        # against the depixelator architecture.
        padding_width  = ((width  + 63) // 64) * 64 - width
        padding_height = ((height + 63) // 64) * 64 - height
        left  = padding_width  // 2
        upper = padding_height // 2

        padded = tf.image.pad_to_bounding_box(img, upper, left, height + padding_height, width + padding_width)
        return tf.expand_dims(padded, axis=0), (upper, left, height, width)

    def depixelate_image(image):
        """
        Depixelate the image if the detector model identifies pixelation.

        Parameters:
        - image: Batched input image tensor.

        Returns:
        - The depixelator output if the detector output is greater than 0, otherwise
          the input image unchanged.
        """
        # NOTE(review): a threshold of 0 suits a logit output (probability 0.5); if the
        # detector ends in a sigmoid this test is true for almost every image — confirm.
        if detector(image, training=False) > 0:
            return depixelator(image, training=False)
        return image

    # None instead of a mutable [] default; an explicit [] is still accepted.
    if real_image_paths is None:
        real_image_paths = []

    # Check if real images are available for comparison
    real_available = len(real_image_paths) > 0
    if real_available and len(real_image_paths) != len(pixelated_image_paths):
        raise ValueError(
            f'real_image_paths has {len(real_image_paths)} entries but pixelated_image_paths '
            f'has {len(pixelated_image_paths)}; they must pair up index by index.'
        )

    # Ensure the output directory exists
    os.makedirs(depixelated_dir, exist_ok=True)

    for i, pixelated_image_path in enumerate(pixelated_image_paths):

        model_input, (upper, left, height, width) = pad_for_model(load_image(pixelated_image_path))

        # Remove the padding again so the result has the input image's size.
        depixelated_image = tf.image.crop_to_bounding_box(
            tf.squeeze(depixelate_image(model_input), axis=0), upper, left, height, width
        )

        if real_available:
            # Compare the un-padded images so the zero border added for the model does
            # not inflate the scores. A batch axis is added because the VGG feature
            # extractor needs one and the PSNR/SSIM results are indexed per image.
            real_batch = tf.expand_dims(load_image(real_image_paths[i]), axis=0)
            pred_batch = tf.expand_dims(depixelated_image, axis=0)
            print(
                'Perceptual Loss:', compute_perceptual_loss(real_batch, pred_batch).numpy(),
                'PSNR:', compute_psnr(real_batch, pred_batch).numpy()[0],
                'SSIM:', compute_ssim(real_batch, pred_batch).numpy()[0]
            )

        # The file is always PNG-encoded, so give it a .png extension whatever the
        # input format was.
        stem = os.path.splitext(os.path.basename(pixelated_image_path))[0]
        depixelated_image_path = os.path.join(depixelated_dir, 'pred_' + stem + '.png')

        # saturate=True clamps out-of-range model outputs instead of letting the
        # float -> uint8 cast wrap around.
        png_bytes = tf.image.encode_png(
            tf.image.convert_image_dtype(depixelated_image, dtype=tf.uint8, saturate=True)
        )
        tf.io.write_file(depixelated_image_path, png_bytes)
        print('Depixelated Image Saved As:', depixelated_image_path, '\n')

## TODO: IMAGE PATHS (important)

- the tester has to implement this cell based on how their test data is organized
- to depixelate images we require `pixelated_image_paths` (mandatory): the list of paths (not just file names) to all of your pixelated images
- you can also build `real_image_paths` (optional) if you also have ground-truth images, to compute comparison metrics
- make sure each pixelated image and its corresponding ground-truth image have the same index in both lists

In [4]:
# Folder holding the pixelated test images — edit to match your data layout.
pixelated_image_dir = './test/'
# real_image_dir = './'   # optional: folder holding the ground-truth images

# Keep only files with a common image extension readable by tf.io.decode_image, so
# stray entries such as `.ipynb_checkpoints` are skipped. Sorted for a stable order.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
pixelated_image_basenames = sorted(
    bname for bname in os.listdir(pixelated_image_dir)
    if bname.lower().endswith(IMAGE_EXTENSIONS)
)
pixelated_image_paths = [os.path.join(pixelated_image_dir, bname) for bname in pixelated_image_basenames]

# Optional ground truth. This example maps e.g. `flower_d2.png` to `flower.png`; adapt it to your naming.
# real_image_paths = [os.path.join(real_image_dir, bname.split('_')[0] + '.png') for bname in pixelated_image_basenames]

In [5]:
# Check that the path lists from the previous cell are set up correctly.
print(len(pixelated_image_paths))
print(pixelated_image_paths)
assert len(pixelated_image_paths) > 0, 'pixelated_image_paths is empty; check the image folder'

# Ground-truth paths are optional. When they are defined, they must pair up with the
# pixelated images index by index.
if 'real_image_paths' in globals():
    print(len(real_image_paths))
    print(real_image_paths)
    assert len(real_image_paths) == len(pixelated_image_paths), \
        'real_image_paths and pixelated_image_paths must have the same length'

10
10
['./test/flower_d2.png', './test/flower_d3.png', './test/flower_d4.png', './test/flower_d5.png', './test/flower_d6.png', './test/flower_d7.png', './test/flower_d8.png', './test/flower_j10.png', './test/flower_j18.png', './test/flower_j28.png']
['./flower.png', './flower.png', './flower.png', './flower.png', './flower.png', './flower.png', './flower.png', './flower.png', './flower.png', './flower.png']


## LOADING MODELS

- the `depixelator` and `detector` models are loaded
- `depixelated_dir` is set to the directory where the depixelated images will be saved

In [6]:
# Directory that will receive the depixelated images.
depixelated_dir = './pred_test/'

# Trained models. They presumably reference the custom functions registered with
# register_keras_serializable earlier in the notebook.
detector = tf.keras.models.load_model('./detector_005_1_50.keras')
depixelator = tf.keras.models.load_model('./depixelator_004_2.keras')

## DEPIXELATION

- run the `depixelate_images` function with the loaded `detector` and `depixelator` models, `pixelated_image_paths` and `depixelated_dir`
- all the depixelated images will be saved in `depixelated_dir`
- optionally, if you also pass `real_image_paths`, this function will also print some metrics (Perceptual Loss, PSNR, SSIM), but this will slow down the depixelation process

In [7]:
# Depixelate every image and write the results into depixelated_dir (no metrics).
depixelate_images(
    detector=detector,
    depixelator=depixelator,
    pixelated_image_paths=pixelated_image_paths,
    depixelated_dir=depixelated_dir,
)

# To also print metrics against ground truth, pass the optional list instead:
# depixelate_images(detector, depixelator, pixelated_image_paths, depixelated_dir, real_image_paths)

Depixelated Image Saved As: ./pred_test/pred_flower_d2.png 

Depixelated Image Saved As: ./pred_test/pred_flower_d3.png 

Depixelated Image Saved As: ./pred_test/pred_flower_d4.png 

Depixelated Image Saved As: ./pred_test/pred_flower_d5.png 

Depixelated Image Saved As: ./pred_test/pred_flower_d6.png 

Depixelated Image Saved As: ./pred_test/pred_flower_d7.png 

Depixelated Image Saved As: ./pred_test/pred_flower_d8.png 

Depixelated Image Saved As: ./pred_test/pred_flower_j10.png 

Depixelated Image Saved As: ./pred_test/pred_flower_j18.png 

Depixelated Image Saved As: ./pred_test/pred_flower_j28.png 



In [None]:
# Release the Colab runtime once everything has run. Outside Colab the import fails;
# skip the step instead of ending Restart & Run All with an error.
try:
    from google.colab import runtime
except ImportError:
    print('Not running on Google Colab; skipping runtime.unassign().')
else:
    runtime.unassign()