# In [None]:
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.layers import Layer, LeakyReLU
from keras.models import Model
from PIL import Image

# Custom Subpixel Convolutional Layer for upsampling
class SubpixelConv2D(Layer):
    """Subpixel convolution ("pixel shuffle") upsampling layer.

    Rearranges blocks of channels into spatial blocks with
    ``tf.nn.depth_to_space``: an NHWC input of shape (batch, H, W, C) becomes
    (batch, H * scale, W * scale, C / scale**2). C must be divisible by
    ``scale**2``.

    Args:
        scale: integer upsampling factor applied to both height and width.
        **kwargs: standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, scale=2, **kwargs):
        super(SubpixelConv2D, self).__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        return tf.nn.depth_to_space(inputs, self.scale)

    def get_config(self):
        # Serialize the constructor argument so a saved model containing this
        # layer can be reloaded with the same upsampling factor.
        config = super(SubpixelConv2D, self).get_config()
        config.update({"scale": self.scale})
        return config

# Function to load and preprocess images
def load_and_preprocess_images(folder, image_size=(50, 50)):
    """Load every image file in ``folder`` as a normalized single-channel array.

    Each file is resized to ``image_size``, converted to RGB (so grayscale and
    palette images are accepted), and only its red channel is kept.

    Args:
        folder: directory to scan. Entries that are not regular files
            (e.g. sub-directories) are skipped.
        image_size: target size as ``(width, height)``, the order PIL's
            ``Image.resize`` expects.

    Returns:
        A tuple ``(images, filenames)``:
        ``images`` has shape ``(N, height, width, 1)`` with values scaled to
        [0, 1]. ``filenames[i]`` is the file that produced ``images[i]``.
        Files are processed in sorted order, so the result is reproducible.
    """
    width, height = image_size
    images = []
    filenames = []
    for filename in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, filename)
        if not os.path.isfile(img_path):
            continue
        # The context manager releases the file handle PIL keeps open.
        with Image.open(img_path) as img:
            # Resize first (as before), then convert to RGB so the red-channel
            # slice also works for single-channel ("L") and palette ("P") images.
            rgb = np.asarray(img.resize(image_size).convert("RGB"))
        images.append(rgb[:, :, 0])  # Store red channel images
        filenames.append(filename)
    if not images:
        return np.empty((0, height, width, 1)), filenames
    # A PIL image converts to an array of shape (rows, cols) = (height, width);
    # stack and add a trailing channel axis instead of reshaping, which would
    # scramble pixels whenever width != height.
    return np.stack(images)[..., np.newaxis] / 255.0, filenames

# Stochastic Quantization Layer (12-bit)
class StochasticQuantization(Layer):
    """Quantize activations to a grid with step ``1 / (2**num_bits - 1)``.

    Forward pass:
        ``round(x * scale) / scale`` with ``scale = 2**num_bits - 1``. Inputs
        are not clipped. During training, uniform dither of up to half a
        quantization step is added before rounding, which makes the rounding
        stochastic.

    Backward pass:
        ``tf.round`` provides no usable gradient. The layer therefore uses a
        straight-through estimator and passes gradients through as if it
        were the identity, so the layers before it (the encoder) still train.

    Args:
        num_bits: bit depth that defines the quantization grid.
        **kwargs: standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, num_bits=12, **kwargs):
        super(StochasticQuantization, self).__init__(**kwargs)
        self.num_bits = num_bits
        # Quantization levels per unit of input; the step size is 1 / scale.
        self.scale = tf.constant(2**num_bits - 1, dtype=tf.float32)

    def call(self, inputs, training=False):
        scaled = inputs * self.scale
        if training:
            # Noise is measured in *scaled* units (one unit = one quantization
            # step). Adding +-0.5 in raw input units would be thousands of
            # steps wide and would swamp the latent code.
            noise = tf.random.uniform(
                shape=tf.shape(scaled), minval=-0.5, maxval=0.5, dtype=scaled.dtype
            )
            scaled = scaled + noise
        quantized = tf.round(scaled) / self.scale
        # Straight-through estimator: the forward value equals `quantized`,
        # while the gradient w.r.t. `inputs` is the identity.
        return inputs + tf.stop_gradient(quantized - inputs)

    def get_config(self):
        # Serialize the constructor argument so saved models can be reloaded.
        config = super(StochasticQuantization, self).get_config()
        config.update({"num_bits": self.num_bits})
        return config


# Post-Processing Network
def post_processing_network(x):
    """Refine a decoded image tensor with a small convolutional stack.

    Applies two stages of 3x3 ``Conv2D(64)`` + ``LeakyReLU``, then a 3x3
    ``Conv2D(1)`` followed by a sigmoid, so the output has a single channel
    with values in (0, 1). All convolutions use 'same' padding, so the
    spatial size is unchanged (50x50 in this notebook).
    """
    for _stage in range(2):
        x = layers.Conv2D(64, (3, 3), padding='same')(x)
        x = LeakyReLU()(x)
    x = layers.Conv2D(1, (3, 3), padding='same')(x)
    return layers.Activation('sigmoid')(x)

# Paths to Training and Testing Data
train_folder_path = "Path to Training Data"
test_folder_path = "Path to Testing Data"

# Load images (red channel, scaled to [0, 1], shape (N, 50, 50, 1)) and filenames
x_train, train_filenames = load_and_preprocess_images(train_folder_path)
x_test, test_filenames = load_and_preprocess_images(test_folder_path)

# Encoder
input_img = keras.Input(shape=(50, 50, 1))
x = layers.Conv2D(32, (3, 3), padding='same')(input_img)
x = LeakyReLU()(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D((2, 2), padding='same')(x)  # Output: (25, 25)

x = layers.Conv2D(64, (3, 3), padding='same')(x)
x = LeakyReLU()(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D((2, 2), padding='same')(x)  # Output: (13, 13)

x = layers.Flatten()(x)
encoded = layers.Dense(30)(x)  # 30-dimensional latent code

# Apply Stochastic Quantization (12-bit)
encoded = StochasticQuantization(num_bits=12)(encoded)

# Decoder (the 13x13 / crop constants below assume 50x50 inputs)
x = layers.Dense(13*13*64)(encoded)
x = layers.Reshape((13, 13, 64))(x)
x = layers.Conv2D(64, (3, 3), padding='same')(x)
x = LeakyReLU()(x)
x = layers.BatchNormalization()(x)
x = SubpixelConv2D(scale=2)(x)  # Output: (26, 26, 16)

x = layers.Conv2D(64, (3, 3), padding='same')(x)
x = LeakyReLU()(x)
x = layers.BatchNormalization()(x)

x = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same')(x)  # Output: (52, 52)
x = layers.Cropping2D(cropping=((1, 1), (1, 1)))(x)  # Crop 1 px per side -> (50, 50)
decoded = LeakyReLU()(x)

# Post-Processing Network
refined_output = post_processing_network(decoded)

# Autoencoder model: image -> reconstruction
autoencoder = keras.Model(input_img, refined_output)

# Encoder Model: image -> quantized latent code
encoder = keras.Model(input_img, encoded)
# Decoder Model: quantized latent code -> final reconstruction. It must end at
# refined_output (not the pre-refinement `decoded` tensor) so that
# decoder(encoder(x)) reproduces autoencoder(x).
decoder = keras.Model(encoded, refined_output)

# Compile the Autoencoder
autoencoder.compile(optimizer='adam', loss='mean_squared_error')
autoencoder.summary()

# Train the Autoencoder.
# NOTE(review): the test set doubles as validation data here, so it should not
# also be used for model selection / early stopping.
autoencoder.fit(x_train, x_train, epochs=5, batch_size=32, shuffle=True, validation_data=(x_test, x_test))



# In [None]:
# Run the full autoencoder (encode -> quantize -> decode -> refine) on the test set
decoded_imgs = autoencoder.predict(x=x_test)
# MSE and PSNR
def mse(imageA, imageB):
    """Return the mean squared error between two equally-shaped images.

    Both inputs are converted to float64 before subtracting. Otherwise,
    unsigned-integer images (e.g. uint8) wrap around on negative differences
    and give a wildly wrong error. Any array-like (including lists) is accepted.
    """
    diff = np.asarray(imageA, dtype=np.float64) - np.asarray(imageB, dtype=np.float64)
    return np.mean(diff ** 2)

def psnr(imageA, imageB, max_val=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        imageA, imageB: equally-shaped image arrays.
        max_val: largest possible pixel value. Use the default 1.0 for images
            normalized to [0, 1], as in this notebook, or 255 for raw 8-bit
            images.

    Returns:
        ``20 * log10(max_val / sqrt(MSE))``. Identical images (MSE == 0)
        return 100 as a finite stand-in for infinite PSNR, which keeps
        averages over many images finite.
    """
    mse_value = mse(imageA, imageB)
    if mse_value == 0:
        return 100
    return 20 * np.log10(max_val / np.sqrt(mse_value))
    # Calculate PSNR for each test image
mse_values = [mse(x_test[i].astype('float32'), decoded_imgs[i].astype('float32')) for i in range(len(x_test))]
psnr_values = [psnr(x_test[i].astype('float32'), decoded_imgs[i].astype('float32')) for i in range(len(x_test))]

# Calculate the average MSE
average_mse = np.mean(mse_values)
print(f"Average MSE: {average_mse}")

# Calculate the average PSNR
average_psnr = np.mean(psnr_values)
print(f"Average PSNR: {average_psnr} dB")
residuals = np.abs(x_test - decoded_imgs)
import matplotlib.pyplot as plt
# Visualization of results: one column per test image with four rows showing
# the original, the reconstruction, the absolute residual and the PSNR.
n = min(10, len(x_test))  # avoid IndexError when the test set has < 10 images
plt.figure(figsize=(30, 10))

for i in range(n):
    # Display original images. vmin/vmax pin the gray scale to the normalized
    # [0, 1] range so originals and reconstructions are directly comparable
    # rather than each being auto-contrasted independently.
    ax = plt.subplot(4, n, i + 1)
    plt.imshow(x_test[i][..., 0], cmap='gray', vmin=0, vmax=1)
    plt.title("Original")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Display reconstructed images
    ax = plt.subplot(4, n, i + 1 + n)
    plt.imshow(decoded_imgs[i][..., 0], cmap='gray', vmin=0, vmax=1)
    plt.title("Reconstructed")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Display residual images. These are auto-scaled per image, so even faint
    # errors are stretched to full contrast; compare magnitudes via the PSNR row.
    ax = plt.subplot(4, n, i + 1 + 2*n)
    plt.imshow(residuals[i][..., 0], cmap='gray')
    plt.title("Residual")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Display PSNR values directly below residuals
    ax = plt.subplot(4, n, i + 1 + 3*n)
    ax.axis('off')  # Hide the axis for cleanliness
    plt.text(0.5, 0.5, f"{psnr_values[i]:.2f} dB", fontsize=12, va='center', ha='center', transform=ax.transAxes)

plt.show()
