In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Add, PReLU, Lambda, Conv2DTranspose, LeakyReLU, Flatten, Dense, GlobalMaxPooling2D, Activation, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [None]:
# Extract the dataset archive into /content/; the loading cell below reads images from
# /content/Data/LR and /content/Data/HR.
!unzip /content/archive.zip -d /content/;

In [None]:
class SubpixelConv2D(Layer):
    """Sub-pixel (pixel-shuffle) upsampling layer.

    Wraps ``tf.nn.depth_to_space``: with the default NHWC layout an input of shape
    (batch, H, W, C) becomes (batch, H * scale, W * scale, C // scale**2), so C
    must be divisible by ``scale ** 2``.

    Args:
        scale: spatial upsampling factor (the ``block_size`` of depth_to_space).
    """

    def __init__(self, scale, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        return tf.nn.depth_to_space(inputs, self.scale)

    def get_config(self):
        # Include `scale` so the layer can be re-created when a saved model is loaded.
        return {**super().get_config(), 'scale': self.scale}

In [None]:
def residual_block(inputs):
    """Residual block: (conv3x3/64 -> BN -> ReLU -> conv3x3/64 -> BN) + identity, then ReLU.

    The identity shortcut is added with ``Add``, so ``inputs`` must already have 64
    channels (the convolutions keep the spatial size via ``padding='same'``).

    NOTE(review): the SRGAN paper uses PReLU inside the block and no activation
    after the addition; this block uses ReLU in both places.
    """
    def conv_bn(tensor):
        # 3x3, 64-filter convolution followed by batch normalisation.
        tensor = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(tensor)
        return BatchNormalization()(tensor)

    hidden = conv_bn(inputs)
    hidden = Activation('relu')(hidden)
    hidden = conv_bn(hidden)
    merged = Add()([hidden, inputs])
    return Activation('relu')(merged)

def upscale_block(inputs, filters=1024, scale=2):
    """Upsample a feature map by `scale` with conv3x3 -> pixel shuffle -> PReLU.

    Args:
        inputs: feature-map tensor (NHWC).
        filters: number of conv filters before the pixel shuffle; must be divisible
            by ``scale ** 2``. The block outputs ``filters // scale**2`` channels.
        scale: spatial upsampling factor.

    Returns:
        Tensor with spatial size multiplied by `scale` and
        ``filters // scale**2`` channels.

    NOTE(review): the default of 1024 filters gives 256 output channels per
    block, and the full-resolution conv activations are very memory-hungry.
    The SRGAN paper uses 256 filters here, which gives 64 output channels.
    Consider `filters=256` if training runs out of memory.
    """
    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(inputs)
    x = SubpixelConv2D(scale=scale)(x)
    # One learned slope per channel, shared across spatial positions.
    x = PReLU(shared_axes=[1, 2])(x)
    return x

def build_generator(num_residual_blocks=16):
    """Build the SRGAN-style generator.

    Architecture:
    1. A 9x9 conv + PReLU feature extractor.
    2. A trunk of `num_residual_blocks` residual blocks.
    3. A 3x3 conv + batch-norm whose output is added back onto the
       feature-extractor output (the long skip connection around the trunk).
    4. Two upscale blocks.
    5. A final 9x9 conv with tanh.

    Args:
        num_residual_blocks: number of `residual_block`s in the trunk (default 16).

    Returns:
        A Keras Model that maps a (batch, H, W, 3) image batch to a 3-channel
        batch upsampled by the two upscale blocks, with values in (-1, 1) from
        the tanh output.
    """
    inputs = Input(shape=(None, None, 3))
    x = Conv2D(filters=64, kernel_size=9, strides=1, padding='same')(inputs)
    x = PReLU(shared_axes=[1, 2])(x)

    # Keep the pre-trunk features in their own variable. The long skip connection
    # must add these, not the trunk output, which `residual` holds after the loop.
    skip = x

    residual = x
    for _ in range(num_residual_blocks):
        residual = residual_block(residual)

    x = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(residual)
    x = BatchNormalization()(x)
    x = Add()([x, skip])

    x = upscale_block(x)
    x = upscale_block(x)

    outputs = Conv2D(filters=3, kernel_size=9, strides=1, padding='same', activation='tanh')(x)

    return Model(inputs, outputs)

def build_discriminator():
    """Build the image discriminator as a Sequential model.

    Structure:
    - Ten 3x3 'same'-padded convolutions, each followed by LeakyReLU(0.2). Filter
      counts go 64, 64, 128, 128, 256, 256, 512, 512, 1024, 1024, and every
      second conv has stride 2.
    - Global max pooling.
    - Dense(1024) + LeakyReLU.
    - Dense(1) with sigmoid.

    Accepts 3-channel images of any spatial size and returns one score in
    (0, 1) per image.
    """
    # (filters, stride) for every convolution after the first one.
    conv_plan = [
        (64, 2),
        (128, 1), (128, 2),
        (256, 1), (256, 2),
        (512, 1), (512, 2),
        (1024, 1), (1024, 2),
    ]

    stack = [
        Conv2D(64, (3, 3), strides=(1, 1), padding='same', input_shape=(None, None, 3)),
        LeakyReLU(alpha=0.2),
    ]
    for n_filters, step in conv_plan:
        stack.append(Conv2D(n_filters, (3, 3), strides=(step, step), padding='same'))
        stack.append(LeakyReLU(alpha=0.2))

    # Classification head: pool to a vector, then map it to a single probability.
    stack.extend([
        GlobalMaxPooling2D(),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    ])
    return tf.keras.Sequential(stack)



def build_srgan(generator, discriminator=None):
  """Combine `generator` and a frozen discriminator into the SRGAN training model.

  Args:
    generator: Keras model mapping a low-resolution image batch to a
      super-resolved batch.
    discriminator: optional discriminator to reuse. If omitted, a new one is
      built and compiled here. That internal instance is never returned, so the
      caller cannot train it; pass the same instance you train on real/fake
      batches to get a meaningful adversarial loss.

  Returns:
    A compiled Model with:
    - inputs [low_res, high_res];
    - outputs [validity, generated_high_res];
    - loss = 1e-3 * binary cross-entropy on validity + 1 * MSE on the image.
  """
  if discriminator is None:
    discriminator = build_discriminator()
    # `learning_rate` is the supported keyword; current Keras optimizers reject `lr`.
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4))

  # Freeze the discriminator inside the combined model so fitting `srgan`
  # only updates the generator's weights.
  discriminator.trainable = False

  # Not connected to any output; kept so the model still accepts the
  # [low_res, high_res] input pair that callers pass to fit().
  input_highres = Input(shape=(None, None, 3))
  input_lowres = Input(shape=(None, None, 3))

  # Super-resolve the low-resolution input, then score it with the discriminator.
  generated_highres = generator(input_lowres)
  validity = discriminator(generated_highres)

  srgan = Model(inputs=[input_lowres, input_highres], outputs=[validity, generated_highres])
  srgan.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1],
                optimizer=Adam(learning_rate=1e-4))

  return srgan

In [None]:
import numpy as np
from PIL import Image
import os
# Paths to the low- and high-resolution image directories.
# Assumes Data/LR holds the degraded counterparts of Data/HR under the same
# file names (0.png, 1.png, ...) — TODO confirm against the dataset.
lr_image_dir = '/content/Data/LR'
hr_image_dir = '/content/Data/HR'

# The generator upsamples 4x (two 2x upscale blocks). Each HR target must be
# exactly SCALE times the LR input size, or the MSE loss shapes will not match.
SCALE = 4
HR_SIZE = 256
LR_SIZE = HR_SIZE // SCALE
N_TRAIN = 80   # images 0..79 are used for training
N_TOTAL = 100  # images 80..99 are held out for validation


def load_image(path, size, low, high):
    """Load `path` as a float32 RGB array of shape (size, size, 3), scaled to [low, high].

    `convert('RGB')` turns grayscale or RGBA PNGs into 3 channels.
    `Image.resize` interpolates pixels, unlike `np.resize`, which only repeats or
    truncates the flattened buffer and scrambles the image.
    """
    with Image.open(path) as img:
        rgb = img.convert('RGB').resize((size, size), Image.BICUBIC)
        pixels = np.asarray(rgb, dtype=np.float32)
    return pixels / 255.0 * (high - low) + low


def load_pairs(indices):
    """Load matching (LR, HR) arrays for the image ids in `indices`.

    LR inputs are scaled to [0, 1]. HR targets are scaled to [-1, 1] so they
    match the generator's tanh output range.
    """
    lr_batch = [load_image(os.path.join(lr_image_dir, f'{i}.png'), LR_SIZE, 0.0, 1.0)
                for i in indices]
    hr_batch = [load_image(os.path.join(hr_image_dir, f'{i}.png'), HR_SIZE, -1.0, 1.0)
                for i in indices]
    return np.array(lr_batch), np.array(hr_batch)


lr_images, hr_images = load_pairs(range(N_TRAIN))
lr_images_vad, hr_images_vad = load_pairs(range(N_TRAIN, N_TOTAL))

assert hr_images.shape[1:3] == (lr_images.shape[1] * SCALE, lr_images.shape[2] * SCALE)
lr_images.shape

In [None]:
from tensorflow.keras.optimizers import Adam

# Build the generator and discriminator models.
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator on its own. `learning_rate` is the supported keyword;
# current Keras optimizers reject `lr`.
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=1e-4))

# Mark the discriminator non-trainable for use inside a combined GAN model.
discriminator.trainable = False

# NOTE(review): when build_srgan is called with only the generator, it builds and
# freezes its own internal discriminator. The `discriminator` above is therefore
# NOT the one inside `srgan`, and training it will not affect srgan's adversarial loss.
srgan = build_srgan(generator)


In [None]:
# Overview of the combined model: its two image inputs, the generator, and the discriminator.
srgan.summary()

In [None]:
# Layer-by-layer overview of the generator returned by build_generator().
generator.summary()

In [None]:
# Overview of the discriminator built in the model-construction cell
# (build_srgan(generator) creates a separate internal instance for `srgan`).
discriminator.summary()

In [None]:
# Targets must follow the order of srgan's outputs, [validity, generated_highres]:
# - one "real" label per image for the adversarial (binary cross-entropy) head;
# - the HR images for the MSE head.
# NOTE(review): this only updates the generator; the discriminator is never
# trained on real vs. generated batches here.
srgan.fit([lr_images, hr_images], [np.ones((len(lr_images), 1)), hr_images], batch_size=16, epochs=100)