# Achieving Image Super-Resolution using CNNs and GANs

**This project aims to achieve Single Image Super Resolution using Deep Convolutional Neural Networks (SRCNN) and Generative Adversarial Networks (SRGAN)**

**Dataset:** [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)

**References**:

**[1]** Dong, C., Loy, C.C., He, K., Tang, X., 2016. Image super-resolution using deep convolutional networks. IEEE Transactions on Pattern Analysis and Machine Intelligence 38, 295–307. doi:10.1109/TPAMI.2015.2439281.

**[2]** Kim, J., Lee, J.K., Lee, K.M., 2016. Accurate image super-resolution using very deep convolutional networks, in: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

**[3]** Ledig, C., Theis, L., Huszar, F., Caballero, J., Cunningham, A., Acosta, A., Aitken, A., Tejani, A., Totz, J., Wang, Z., Shi, W., 2017. Photo-realistic single image super-resolution using a generative adversarial network, in: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

# Imports

In [None]:
# Tensorflow
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Rescaling, LeakyReLU, PReLU
from tensorflow.keras.utils import load_img, img_to_array, array_to_img
from tensorflow import keras
from tensorflow.keras import layers, models
import tensorflow as tf
from tensorflow.keras.applications import VGG19

import tensorflow_datasets as tfds
import tensorflow as tf

# IPython
from IPython.display import display
from IPython.display import clear_output

# Google
from google.colab import files

# Utils
import numpy as np
import matplotlib.pyplot as plt
import os
import zipfile
import random

# Dataset - DIV2K

In [None]:
def extract_dataset(dataset_path):
    """Unpack the archive stored next to ``dataset_path`` into that directory.

    The archive is looked up at ``dataset_path + ".tar.gz"`` but is opened with
    ``zipfile``, so despite the suffix its content must be a ZIP file.

    Args:
        dataset_path: Path without the ``.tar.gz`` suffix. It is also used as
            the directory the archive is extracted into (created if missing).

    Returns:
        The extraction directory, i.e. the same value as ``dataset_path``.
    """
    archive_path = dataset_path + ".tar.gz"
    target_dir = dataset_path

    os.makedirs(target_dir, exist_ok=True)

    with zipfile.ZipFile(archive_path, mode='r') as archive:
        archive.extractall(target_dir)

    print(f"Dataset extracted to: {target_dir}")

    ## TODO: add option to save dataset to the drive
    return target_dir

## High Resolution - 2K

In [None]:
# Fetch the DIV2K high-resolution training archive and unpack it.
dataset_name = "DIV2K_train_HR"
url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
dataset_path = tf.keras.utils.get_file(fname=dataset_name, origin=url, untar=True)
extracted_dir = extract_dataset(dataset_path)

# Image paths of the extracted folder, in lexicographic order.
complete_dataset_path = f"{extracted_dir}/{dataset_name}"
x_train_hr = sorted(
    os.path.join(complete_dataset_path, entry)
    for entry in os.listdir(complete_dataset_path)
    if entry.endswith(('.png', '.jpg', '.jpeg'))
)

In [None]:
# Fetch the DIV2K high-resolution validation archive and unpack it.
dataset_name = "DIV2K_valid_HR"
url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
dataset_path = tf.keras.utils.get_file(fname=dataset_name, origin=url, untar=True)
extracted_dir = extract_dataset(dataset_path)

# Image paths of the extracted folder, in lexicographic order.
complete_dataset_path_hr = f"{extracted_dir}/{dataset_name}"
x_validation_hr = sorted(
    os.path.join(complete_dataset_path_hr, entry)
    for entry in os.listdir(complete_dataset_path_hr)
    if entry.endswith(('.png', '.jpg', '.jpeg'))
)

split_rate = 0.2

# The last `split_rate` share of the sorted list becomes the test set; the
# leading part stays as the validation set.
hr_split_index = round(len(x_validation_hr) * (1 - split_rate))
x_validation_hr, x_test_hr = (
    x_validation_hr[:hr_split_index],
    x_validation_hr[hr_split_index:],
)

## Low Resolution - Downscaled using bicubic x8 interpolation

In [None]:
# Fetch the x8-downscaled training archive and unpack it.
dataset_name = "DIV2K_train_LR_x8"
url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_x8.zip"
dataset_path = tf.keras.utils.get_file(fname=dataset_name, origin=url, untar=True)
extracted_dir = extract_dataset(dataset_path)

# Image paths of the extracted folder, in lexicographic order.
complete_dataset_path_lr = f"{extracted_dir}/{dataset_name}"
x_train_lr_x8 = sorted(
    os.path.join(complete_dataset_path_lr, entry)
    for entry in os.listdir(complete_dataset_path_lr)
    if entry.endswith(('.png', '.jpg', '.jpeg'))
)

In [None]:
# Fetch the x8-downscaled validation archive and unpack it.
dataset_name = "DIV2K_valid_LR_x8"
url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_x8.zip"
dataset_path = tf.keras.utils.get_file(fname=dataset_name, origin=url, untar=True)
extracted_dir = extract_dataset(dataset_path)

# Image paths of the extracted folder, in lexicographic order.
complete_dataset_path_lr = f"{extracted_dir}/{dataset_name}"
x_validation_lr_x8 = sorted(
    os.path.join(complete_dataset_path_lr, entry)
    for entry in os.listdir(complete_dataset_path_lr)
    if entry.endswith(('.png', '.jpg', '.jpeg'))
)

# Same split as for the high-resolution list (uses the `split_rate` defined
# earlier): trailing share -> test set, leading part -> validation set.
lr_split_index = round(len(x_validation_lr_x8) * (1 - split_rate))
x_validation_lr_x8, x_test_lr_x8 = (
    x_validation_lr_x8[:lr_split_index],
    x_validation_lr_x8[lr_split_index:],
)

# Super Resolution Using CNNs (SRCNN)

## Defining the model

In [None]:
def srcnn_model(input_shape, upscaling=(2,2)):
    """Build and compile the SRCNN network.

    Pipeline: 9x9 feature extraction (64 filters) -> 1x1 non-linear mapping
    (32 filters) -> ``UpSampling2D`` by ``upscaling`` -> 5x5 reconstruction
    to 3 channels without activation.

    Args:
        input_shape: Shape handed to ``Input``, e.g. ``(128, 128, 3)``.
        upscaling: ``size`` argument of the ``UpSampling2D`` layer.

    Returns:
        The model, compiled with Adam, mean-squared-error loss and the
        "accuracy" metric.
    """
    low_res = Input(shape=input_shape)

    features = Conv2D(filters=64, kernel_size=(9, 9), padding='same', activation='relu')(low_res)
    mapped = Conv2D(filters=32, kernel_size=(1, 1), padding='same', activation='relu')(features)
    enlarged = UpSampling2D(size=upscaling)(mapped)
    # The reconstruction layer is intentionally left linear.
    restored = Conv2D(filters=3, kernel_size=(5, 5), padding='same')(enlarged)

    network = Model(low_res, restored)
    network.compile(optimizer='adam', loss='mean_squared_error', metrics=["accuracy"])
    return network


srcnn = srcnn_model(input_shape=(128,128, 3))

In [None]:
# Print the layer-by-layer summary of the current `srcnn` model.
srcnn.summary()

## Importing dataset

In [None]:
# Spatial sizes (height, width) are read back from the model itself.
input_img_size = srcnn.input_shape[1:3]
out_img_size = srcnn.output_shape[1:3]

# The high-resolution lists define how many samples each split holds.
num_imgs_train = len(x_train_hr)
num_imgs_validation = len(x_validation_hr)
num_imgs_test = len(x_test_hr)


def path_to_image(path, img_size):
    """Load the image at ``path`` resized to ``img_size``, divided by 255 as float32."""
    pixels = img_to_array(load_img(path, target_size=img_size))
    return pixels.astype("float32") / 255


# Training split: low-resolution inputs paired by index with HR targets.
train_inputs = np.zeros((num_imgs_train, *input_img_size, 3), dtype="float32")
train_targets = np.zeros((num_imgs_train, *out_img_size, 3), dtype="float32")
for i, hr_path in enumerate(x_train_hr):
    train_inputs[i] = path_to_image(x_train_lr_x8[i], input_img_size)
    train_targets[i] = path_to_image(hr_path, out_img_size)

# Validation split.
validation_inputs = np.zeros((num_imgs_validation, *input_img_size, 3), dtype="float32")
validation_targets = np.zeros((num_imgs_validation, *out_img_size, 3), dtype="float32")
for i, hr_path in enumerate(x_validation_hr):
    validation_inputs[i] = path_to_image(x_validation_lr_x8[i], input_img_size)
    validation_targets[i] = path_to_image(hr_path, out_img_size)

# Test split.
test_inputs = np.zeros((num_imgs_test, *input_img_size, 3), dtype="float32")
test_targets = np.zeros((num_imgs_test, *out_img_size, 3), dtype="float32")
for i, hr_path in enumerate(x_test_hr):
    test_inputs[i] = path_to_image(x_test_lr_x8[i], input_img_size)
    test_targets[i] = path_to_image(hr_path, out_img_size)

## Training the model

In [None]:
# Checkpoint callback that keeps only the best model seen during training.
# The path previously read "scrnn" (a typo for "srcnn") and had no extension;
# it now uses the same `.keras` format the notebook saves its models in, since
# newer Keras releases reject checkpoint paths without a recognised suffix.
callbacks = [
    keras.callbacks.ModelCheckpoint("srcnn_checkpoint.keras", save_best_only=True)
]

In [None]:
# Train SRCNN on the low/high-resolution pairs and keep the per-epoch history
# for the plots further down.
history = srcnn.fit(
    x=train_inputs,
    y=train_targets,
    batch_size=32,
    epochs=5,
    validation_data=(validation_inputs, validation_targets),
    callbacks=callbacks,
)

In [None]:
## Save and download model
# Write the trained model to 'srcnn.keras', then pass that file to Colab's
# `files.download` helper.
srcnn.save('srcnn.keras')
files.download('srcnn.keras')

## Evaluating Results

In [None]:
# One figure per recorded metric: dots show the training curve, the solid
# line shows the validation curve.
for metric in ("loss", "accuracy"):
    train_curve = history.history[metric]
    val_curve = history.history[f"val_{metric}"]
    epochs = range(1, len(train_curve) + 1)

    plt.figure()
    plt.plot(epochs, train_curve, "bo", label=f"Training {metric}")
    plt.plot(epochs, val_curve, "b", label=f"Validation {metric}")
    plt.title(f"Training and validation {metric}")
    plt.legend()

## Plotting results

In [None]:
# Restore a previously saved model: `files.upload()` is expected to place
# 'srcnn.keras' in the working directory (NOTE(review): confirm the uploaded
# file carries exactly that name), after which it is loaded and summarised.
files.upload()
srcnn = keras.models.load_model("srcnn.keras")
srcnn.summary()

In [None]:
# Evaluates on the *training* arrays and prints the result as "Train accuracy",
# although the result variables are named `test_*`.
# NOTE(review): if a held-out score is wanted, `test_inputs`/`test_targets`
# should presumably be passed instead — confirm the intent.
test_loss, test_acc = srcnn.evaluate(train_inputs, train_targets)
print(f"Train accuracy: {test_acc:.3f}")

In [None]:
number_of_images = 10

# For each randomly drawn test sample, show input, prediction and target
# side by side.
for _ in range(number_of_images):
    i = random.randint(0, len(test_inputs) - 1)
    result_image = srcnn.predict(test_inputs[i:i + 1])[0]

    plt.figure(figsize=(12, 10))
    panels = (
        ("Low Resolution", test_inputs[i]),
        ("SR", result_image),
        ("Ground Truth", test_targets[i]),
    )
    for column, (caption, test_image) in enumerate(panels, start=1):
        plt.subplot(1, 3, column)
        plt.axis("off")
        plt.imshow(array_to_img(test_image))
        plt.title(caption)

In [None]:
# Show one test image and its SRCNN reconstruction, each in a figure whose
# size in inches is the pixel size divided by `dpi`.
index = 3
dpi = 50

# Named `lr_batch` (not `input`) so the builtin `input` is not shadowed for
# the rest of the session; the same batch is reused for the prediction.
lr_batch = np.expand_dims(test_inputs[index], axis=0)
_, height, width, _ = lr_batch.shape

figsize = width / float(dpi), height / float(dpi)

fig = plt.figure(figsize=figsize)
ax = fig.add_axes([0, 0, 1, 1])  # axes fill the whole figure

# Display the low-resolution image.
img = np.clip(np.squeeze(lr_batch), 0, 1)
ax.imshow(img)
ax.axis('on')

plt.show()

generated_images = srcnn.predict(lr_batch)

_, height, width, _ = generated_images.shape

figsize = width / float(dpi), height / float(dpi)

fig = plt.figure(figsize=figsize)
ax = fig.add_axes([0, 0, 1, 1])

# Display the prediction, clipped to the displayable [0, 1] range.
img = np.clip(np.squeeze(generated_images), 0, 1)
ax.imshow(img)
plt.show()


## SRCNN - Deeper Model

In [None]:
def srcnn_deeper_model(input_shape, upscaling=(2,2)):
    """Build and compile a deeper SRCNN variant.

    Compared with ``srcnn_model`` this network has two 9x9 feature-extraction
    convolutions and two 5x5 reconstruction convolutions.

    Args:
        input_shape: Input size as ``(height, width)`` or
            ``(height, width, channels)``. A two-element shape gets 3 channels
            appended. ``None`` builds a size-agnostic ``(None, None, 3)``
            input. Previously this argument was ignored and the input was
            always ``(None, None, 3)``.
        upscaling: ``size`` argument of the ``UpSampling2D`` layer.

    Returns:
        The model, compiled with RMSprop, mean-squared-error loss and the
        "accuracy" metric.
    """
    # Normalise the requested shape instead of discarding it.
    shape = (None, None) if input_shape is None else tuple(input_shape)
    if len(shape) == 2:
        shape = shape + (3,)

    input_layer = Input(shape=shape)

    # Feature extraction
    x = Conv2D(64, (9, 9), activation='relu', padding='same')(input_layer)
    x = Conv2D(64, (9, 9), activation='relu', padding='same')(x)

    # Non-linear mapping
    x = Conv2D(32, (1, 1), activation='relu', padding='same')(x)

    # Upsampling (n x n -> 2n x 2n with the default `upscaling`)
    x = UpSampling2D(size=upscaling)(x)

    # Reconstruction - outputs the upscaled 3-channel image
    x = Conv2D(32, (5, 5), activation='relu', padding='same')(x)
    x = Conv2D(3, (5, 5),  activation='relu', padding='same')(x)

    model = Model(input_layer, x)
    model.compile(optimizer='rmsprop', loss='mean_squared_error', metrics=["accuracy"])

    return model

srcnn = srcnn_deeper_model((128, 128))

In [None]:
# Print the layer-by-layer summary of the deeper SRCNN model.
srcnn.summary()

### Training

In [None]:
# Checkpoint callback that keeps only the best model seen during training.
# The path carries the `.keras` suffix used for every other saved model in
# this notebook; newer Keras releases reject checkpoint paths without a
# recognised suffix.
callback = [
    keras.callbacks.ModelCheckpoint("srcnn_deeper_checkpoint.keras", save_best_only=True)
]

In [None]:
# Train the deeper SRCNN with the same data and schedule as the base model.
fit_options = dict(
    batch_size=32,
    epochs=5,
    callbacks=callback,
    validation_data=(validation_inputs, validation_targets),
)
history = srcnn.fit(train_inputs, train_targets, **fit_options)

In [None]:
# Write the deeper model to 'srcnn_deeper.keras' and pass the file to Colab's
# `files.download` helper. (`files` is re-imported here so the cell can run
# on its own.)
srcnn.save('srcnn_deeper.keras')
from google.colab import files
files.download('srcnn_deeper.keras')

### Evaluating results

In [None]:
# One figure per recorded metric: dots show the training curve, the solid
# line shows the validation curve.
for metric in ("loss", "accuracy"):
    train_curve = history.history[metric]
    val_curve = history.history[f"val_{metric}"]
    epochs = range(1, len(train_curve) + 1)

    plt.figure()
    plt.plot(epochs, train_curve, "bo", label=f"Training {metric}")
    plt.plot(epochs, val_curve, "b", label=f"Validation {metric}")
    plt.title(f"Training and validation {metric}")
    plt.legend()

In [None]:
number_of_images = 10

# For each randomly drawn test sample, show input, prediction and target
# side by side.
for _ in range(number_of_images):
    i = random.randint(0, len(test_inputs) - 1)
    result_image = srcnn.predict(test_inputs[i:i + 1])[0]

    plt.figure(figsize=(12, 10))
    panels = (
        ("Low Resolution", test_inputs[i]),
        ("SR", result_image),
        ("Ground Truth", test_targets[i]),
    )
    for column, (caption, test_image) in enumerate(panels, start=1):
        plt.subplot(1, 3, column)
        plt.axis("off")
        plt.imshow(array_to_img(test_image))
        plt.title(caption)

# Very Deep Super Resolution (VDSR)

## Defining the model

In [None]:
def build_vdsr_model(scale_factor=2, num_filters=64, num_layers=20):
    """Build and compile the VDSR-style network.

    The builder is named ``build_vdsr_model`` so that assigning its result to
    ``vdsr_model`` below no longer overwrites the function itself; previously
    the cell could not be executed a second time.

    Layout: one initial 3x3 convolution, ``num_layers - 2`` further 3x3
    convolutions, a strided ``Conv2DTranspose`` for upsampling and a final
    3x3 convolution to 3 channels without activation.

    Args:
        scale_factor: Stride of the transposed convolution.
        num_filters: Number of filters of every hidden convolution.
        num_layers: Depth parameter; see the layout above for how it maps to
            the number of convolutions.

    Returns:
        The model, compiled with Adam, mean-squared-error loss and the
        "accuracy" metric.
    """
    input_low_resolution = tf.keras.Input(shape=(None, None, 3))

    # Initial convolution
    x = layers.Conv2D(num_filters, 3, padding='same', activation='relu')(input_low_resolution)

    # Intermediate convolutions
    for _ in range(num_layers - 2):
        x = layers.Conv2D(num_filters, 3, padding='same', activation='relu')(x)

    # Upsampling layer
    x = layers.Conv2DTranspose(num_filters, 3, strides=scale_factor, padding='same', activation='relu')(x)

    # Final convolution
    output_sr = layers.Conv2D(3, 3, padding='same')(x)  # No activation for the last layer

    model = Model(inputs=input_low_resolution, outputs=output_sr)

    model.compile(optimizer='adam', loss='mean_squared_error', metrics=["accuracy"])

    return model


# `vdsr_model` is the model instance used by the following cells.
vdsr_model = build_vdsr_model(scale_factor=2, num_filters=64, num_layers=20)

In [None]:
# Print the layer-by-layer summary of the VDSR model.
vdsr_model.summary()

## Importing dataset

In [None]:
# Sizes are stated explicitly rather than read from a model: a model built
# with a (None, None, 3) input reports `None` for its spatial dimensions,
# which cannot be used to allocate the arrays below. The target is twice the
# input size, matching the x2 networks used in this notebook.
input_img_size = (128, 128)
out_img_size = (256, 256)

# The high-resolution lists define how many samples each split holds.
num_imgs_train = len(x_train_hr)
num_imgs_validation = len(x_validation_hr)
num_imgs_test = len(x_test_hr)


def path_to_image(path, img_size):
    """Load the image at ``path`` resized to ``img_size``, divided by 255 as float32."""
    img = img_to_array(load_img(path, target_size=img_size))
    img = img.astype("float32") / 255
    return img


def _load_split(lr_paths, hr_paths):
    """Load index-paired low/high-resolution images into two float32 arrays."""
    inputs = np.zeros((len(hr_paths), *input_img_size, 3), dtype="float32")
    targets = np.zeros((len(hr_paths), *out_img_size, 3), dtype="float32")
    for position, hr_path in enumerate(hr_paths):
        inputs[position] = path_to_image(lr_paths[position], input_img_size)
        targets[position] = path_to_image(hr_path, out_img_size)
    return inputs, targets


train_inputs, train_targets = _load_split(x_train_lr_x8, x_train_hr)
validation_inputs, validation_targets = _load_split(x_validation_lr_x8, x_validation_hr)
test_inputs, test_targets = _load_split(x_test_lr_x8, x_test_hr)

## Training the model

In [None]:
# Checkpoint callback that keeps only the best model seen during training.
# The path carries the `.keras` suffix used for every other saved model in
# this notebook; newer Keras releases reject checkpoint paths without a
# recognised suffix.
callback = [
    keras.callbacks.ModelCheckpoint("vdsr_checkpoint.keras", save_best_only=True)
]

In [None]:
# Train VDSR and keep the per-epoch history for the plots further down.
history = vdsr_model.fit(
    x=train_inputs,
    y=train_targets,
    batch_size=32,
    epochs=3,
    validation_data=(validation_inputs, validation_targets),
    callbacks=callback,
)

In [None]:
## Save and download model
# Write the trained model to 'vdsr.keras', then pass that file to Colab's
# `files.download` helper.
vdsr_model.save('vdsr.keras')
files.download('vdsr.keras')

In [None]:
# Restore a previously saved model: `files.upload()` is expected to place
# 'vdsr.keras' in the working directory (NOTE(review): confirm the uploaded
# file carries exactly that name), after which it is loaded and summarised.
files.upload()

vdsr_model = keras.models.load_model("vdsr.keras")
vdsr_model.summary()

## Evaluating Results

In [None]:
# One figure per recorded metric: dots show the training curve, the solid
# line shows the validation curve.
for metric in ("loss", "accuracy"):
    train_curve = history.history[metric]
    val_curve = history.history[f"val_{metric}"]
    epochs = range(1, len(train_curve) + 1)

    plt.figure()
    plt.plot(epochs, train_curve, "bo", label=f"Training {metric}")
    plt.plot(epochs, val_curve, "b", label=f"Validation {metric}")
    plt.title(f"Training and validation {metric}")
    plt.legend()

## Plotting results

In [None]:
number_of_images = 10

# For each randomly drawn test sample, show input, prediction and target
# side by side.
for _ in range(number_of_images):
    i = random.randint(0, len(test_inputs) - 1)
    result_image = vdsr_model.predict(test_inputs[i:i + 1])[0]

    plt.figure(figsize=(12, 10))
    panels = (
        ("Low Resolution", test_inputs[i]),
        ("SR", result_image),
        ("Ground Truth", test_targets[i]),
    )
    for column, (caption, test_image) in enumerate(panels, start=1):
        plt.subplot(1, 3, column)
        plt.axis("off")
        plt.imshow(array_to_img(test_image))
        plt.title(caption)

# Generative Adversarial Model (SRGAN)

## Defining the model

In [None]:
def residual_block(input):
    """Apply conv -> BN -> PReLU -> conv(ReLU) -> BN and add the block input.

    Args:
        input: Tensor entering the block; it is added back to the branch
            output, so it must be compatible with the 64-channel result.

    Returns:
        The sum of ``input`` and the transformed branch.
    """
    branch = layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same')(input)
    branch = layers.BatchNormalization(momentum=0.5)(branch)
    branch = PReLU(shared_axes=[1, 2])(branch)
    branch = layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(branch)
    branch = layers.BatchNormalization(momentum=0.5)(branch)

    # Skip connection around the whole branch.
    return layers.add([input, branch])

def upsampling_block(input):
    """Apply a 3x3 convolution (64 filters), ``UpSampling2D(size=2)`` and PReLU.

    Args:
        input: Tensor to enlarge.

    Returns:
        The activated, upsampled tensor.
    """
    convolved = layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same')(input)
    enlarged = layers.UpSampling2D(size=2)(convolved)
    return PReLU(shared_axes=[1, 2])(enlarged)

def generator_model(input_shape, upscaling_factor=2):
    """Build the SRGAN-style generator (not compiled).

    Layout: 9x9 convolution + PReLU, 32 residual blocks, a 3x3 convolution
    whose output is added to the first feature map (long skip connection),
    the upsampling blocks and a final 3x3 convolution to 3 channels.

    Each ``upsampling_block`` doubles the spatial size, so
    ``log2(upscaling_factor)`` blocks are stacked. The previous
    ``upscaling_factor - 1`` count was only correct for a factor of 2
    (a factor of 4 produced an x8 output).

    Args:
        input_shape: Shape of the low-resolution input, e.g. ``(128, 128, 3)``.
        upscaling_factor: Overall enlargement; must be a positive power of two.

    Returns:
        The generator ``Model``.

    Raises:
        ValueError: If ``upscaling_factor`` is not a positive power of two.
    """
    factor = int(upscaling_factor)
    if factor != upscaling_factor or factor < 1 or factor & (factor - 1):
        raise ValueError(
            f"upscaling_factor must be a positive power of two, got {upscaling_factor!r}"
        )
    num_upsampling_blocks = factor.bit_length() - 1  # log2 of a power of two

    # Local is named `lr_input` so the builtin `input` is not shadowed.
    lr_input = tf.keras.Input(shape=input_shape)

    # Encoder
    x = layers.Conv2D(64, (9, 9), padding='same')(lr_input)
    x = PReLU(shared_axes = [1,2])(x)
    temp = x  # kept for the long skip connection below

    for _ in range(32):
        x = residual_block(x)

    # Decoder
    x = layers.Conv2D(64, (3, 3), padding='same', activation=tf.keras.layers.LeakyReLU(alpha=0.4))(x)
    x = layers.add([x, temp])
    for _ in range(num_upsampling_blocks):
        x = upsampling_block(x)

    output = layers.Conv2D(3, (3, 3), padding='same', activation=tf.keras.layers.LeakyReLU(alpha=0.3))(x)

    model = Model(inputs=lr_input, outputs=output)

    return model

def discriminator_block(input, filters, strides=1, batch_norm=True):
    """Apply a 3x3 convolution, optional batch normalisation and LeakyReLU.

    NOTE(review): LeakyReLU(0.2) is applied twice, once as the convolution's
    activation and once more at the end. Confirm that this is intended.

    Args:
        input: Tensor entering the block.
        filters: Number of convolution filters.
        strides: Convolution stride.
        batch_norm: Whether to insert ``BatchNormalization`` after the
            convolution.

    Returns:
        The activated output tensor.
    """
    convolution = layers.Conv2D(
        filters,
        (3, 3),
        strides=strides,
        padding='same',
        activation=tf.keras.layers.LeakyReLU(alpha=0.2),
    )
    features = convolution(input)

    if batch_norm:
        features = layers.BatchNormalization(momentum=0.5)(features)

    return LeakyReLU(alpha=0.2)(features)

# Discriminator model
def discriminator_model(input_shape):
    """Build the discriminator (not compiled).

    Seven ``discriminator_block`` stages (32, 32 with stride 2, 64, 64, 128,
    256, 256 filters) are followed by ``Flatten``, a 32-unit dense layer with
    LeakyReLU and a single sigmoid unit that scores real vs. fake.

    Args:
        input_shape: Shape of the images to classify, e.g. ``(256, 256, 3)``.

    Returns:
        The discriminator ``Model``.
    """
    image = tf.keras.Input(shape=input_shape)

    features = 32

    # First stage has no batch normalisation; the second one is strided.
    x = discriminator_block(image, features, batch_norm=False)
    x = discriminator_block(x, features, strides=2)
    for multiplier in (2, 2, 4, 8, 8):
        x = discriminator_block(x, features * multiplier)

    x = layers.Flatten()(x)
    x = layers.Dense(features)(x)  # hidden dense layer with `features` units
    x = LeakyReLU(alpha=0.2)(x)

    # Single sigmoid output: the real/fake score.
    score = layers.Dense(1, activation='sigmoid')(x)

    return Model(inputs=image, outputs=score)

# GAN model combining generator and discriminator
def gan_model(input_shape, output_shape, generator, discriminator, vgg):
    """Chain the generator with the frozen discriminator and the VGG extractor.

    Args:
        input_shape: Shape of the low-resolution input.
        output_shape: Shape of the high-resolution input. That input is part
            of the model's inputs but is not connected to any output.
        generator: Generator model applied to the low-resolution input.
        discriminator: Discriminator model; its ``trainable`` flag is set to
            ``False`` here.
        vgg: Feature extractor applied to the generated image.

    Returns:
        A ``Model`` mapping ``[low_res, high_res]`` to
        ``[discriminator score, VGG features]`` of the generated image.
    """
    # Freeze discriminator during GAN training
    discriminator.trainable = False

    lr_in = tf.keras.Input(shape=input_shape)
    hr_in = tf.keras.Input(shape=output_shape)

    fake_image = generator(lr_in)
    fake_features = vgg(fake_image)
    fake_score = discriminator(fake_image)

    return Model(inputs=[lr_in, hr_in], outputs=[fake_score, fake_features])

def vgg_model(input_shape):
    """Return a VGG19 (ImageNet weights, no top) truncated at ``layers[10]``.

    Args:
        input_shape: Shape of the images fed to VGG19.

    Returns:
        A ``Model`` from the VGG19 inputs to the output of its layer at
        index 10.
    """
    backbone = VGG19(weights='imagenet', include_top=False, input_shape=input_shape)
    feature_layer = backbone.layers[10]
    return Model(inputs=backbone.inputs, outputs=feature_layer.output)


# Create instances of the models
# The generator maps 128x128 inputs to the 256x256 size expected by the
# discriminator and the VGG feature extractor.
generator = generator_model((128, 128, 3), upscaling_factor=2)
discriminator = discriminator_model(input_shape=(256, 256, 3))
vgg = vgg_model(input_shape=(256, 256, 3))

# The VGG extractor is only used to compute features; its weights stay fixed.
vgg.trainable = False

# NOTE: gan_model() sets discriminator.trainable = False as a side effect.
gan = gan_model((128, 128, 3), (256, 256, 3), generator, discriminator, vgg)

# The combined model is trained on two outputs: the discriminator score
# (binary cross-entropy, weight 1e-3) and the VGG features (MSE, weight 1).
generator.compile(optimizer='adam', loss='mean_squared_error')
discriminator.compile(optimizer='adam', loss='binary_crossentropy', loss_weights=[1e-3])
gan.compile(optimizer='adam', loss=['binary_crossentropy', 'mean_squared_error'], loss_weights=[1e-3, 1])

In [None]:
# Print the layer-by-layer summary of the generator.
generator.summary()

In [None]:
# Print the layer-by-layer summary of the discriminator.
discriminator.summary()

In [None]:
# Print the summary of the combined generator + discriminator + VGG model.
gan.summary()

## Importing dataset

In [None]:
# Spatial sizes (height, width) are read back from the generator.
input_img_size = generator.input_shape[1:3]
out_img_size = generator.output_shape[1:3]

# The high-resolution lists define how many samples each split holds.
num_imgs_train = len(x_train_hr)
num_imgs_validation = len(x_validation_hr)
num_imgs_test = len(x_test_hr)


def get_img_array(img_path, target_size):
    """Load an image resized to ``target_size`` as a batch of one.

    No rescaling or VGG preprocessing is applied to the pixel values here.
    """
    image = keras.utils.load_img(img_path, target_size=target_size)
    pixels = keras.utils.img_to_array(image)
    return np.expand_dims(pixels, axis=0)


# Training split: low-resolution inputs paired by index with HR targets.
train_inputs = np.zeros((num_imgs_train, *input_img_size, 3), dtype="float32")
train_targets = np.zeros((num_imgs_train, *out_img_size, 3), dtype="float32")
for i, hr_path in enumerate(x_train_hr):
    train_inputs[i] = get_img_array(x_train_lr_x8[i], input_img_size)
    train_targets[i] = get_img_array(hr_path, out_img_size)

# Validation split.
validation_inputs = np.zeros((num_imgs_validation, *input_img_size, 3), dtype="float32")
validation_targets = np.zeros((num_imgs_validation, *out_img_size, 3), dtype="float32")
for i, hr_path in enumerate(x_validation_hr):
    validation_inputs[i] = get_img_array(x_validation_lr_x8[i], input_img_size)
    validation_targets[i] = get_img_array(hr_path, out_img_size)

# Test split.
test_inputs = np.zeros((num_imgs_test, *input_img_size, 3), dtype="float32")
test_targets = np.zeros((num_imgs_test, *out_img_size, 3), dtype="float32")
for i, hr_path in enumerate(x_test_hr):
    test_inputs[i] = get_img_array(x_test_lr_x8[i], input_img_size)
    test_targets[i] = get_img_array(hr_path, out_img_size)

## Training the model

In [None]:
# Adversarial training: every iteration draws one random batch, updates the
# discriminator on generated and real images, then updates the generator
# through the combined GAN model.
epochs = 1000
batch_size = 1

low_resolution_images = train_inputs
high_resolution_images = train_targets

# Loss histories, plotted after training. They must exist before the loop;
# previously their creation was commented out, so the first `append` raised
# a NameError.
d_losses = []
g_losses = []

# Label real images as 1 and fake images as 0. The generator is trained to
# make the discriminator answer "real", hence `valid_labels` are ones too.
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))
valid_labels = np.ones((batch_size, 1))

for epoch in range(epochs):

    # Draw a random batch of paired low/high-resolution images.
    idx = np.random.randint(0, len(low_resolution_images), batch_size)
    low_res_images = low_resolution_images[idx]
    real_images = high_resolution_images[idx]

    print(f"Input shape {np.shape(low_res_images)}")
    # Generate fake images using the generator
    generated_images = generator.predict(low_res_images)

    # Live preview of the current input and its generated counterpart.
    clear_output()
    display(array_to_img(low_res_images[0]))
    display(array_to_img(generated_images[0]))

    # Train the discriminator on real and fake images
    discriminator.trainable = True
    d_loss_fake = discriminator.train_on_batch(generated_images, y=fake_labels)
    d_loss_real = discriminator.train_on_batch(real_images, y=real_labels)
    discriminator.trainable = False
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    d_losses.append(d_loss)

    # VGG features of the real images are the generator's content target.
    image_features = vgg.predict(real_images)

    # Update the generator via the GAN model
    generator.trainable = True
    g_loss, _, _ = gan.train_on_batch([low_res_images, real_images], [valid_labels, image_features])
    g_losses.append(g_loss)

    print(f"Epoch {epoch}, D Loss: {d_loss}, G Loss: {g_loss}")

In [None]:
# Write all three models to '.keras' files in the working directory.
discriminator.save('discriminator.keras')
generator.save('generator.keras')
gan.save('gan.keras')

# Pass the saved files to Colab's `files.download` helper. (`files` is
# re-imported here so the cell can run on its own.)
from google.colab import files
files.download('discriminator.keras')
files.download('generator.keras')
files.download('gan.keras')

## Evaluating results

In [None]:
# Plot the loss values collected in `g_losses` and `d_losses` during the
# training loop, one figure each.
plt.figure()
plt.plot(g_losses)
plt.title("Loss generator")
plt.figure()
plt.plot(d_losses)
plt.title("Loss discriminator")

## Plotting results

In [None]:
# Restore a previously saved generator: `files.upload()` is expected to place
# 'generator.keras' in the working directory (NOTE(review): confirm the
# uploaded file carries exactly that name), after which it is loaded and
# summarised.
files.upload()

generator = keras.models.load_model("generator.keras")
generator.summary()

In [None]:
# Show one training image and the generator's output for it, each in a figure
# whose size in inches is the pixel size divided by `dpi`.
index = 395
dpi = 50

# Named `lr_batch` (not `input`) so the builtin `input` is not shadowed for
# the rest of the session; the same batch is reused for the prediction.
lr_batch = np.expand_dims(train_inputs[index], axis=0)
_, height, width, _ = lr_batch.shape

figsize = width / float(dpi), height / float(dpi)

fig = plt.figure(figsize=figsize)
ax = fig.add_axes([0, 0, 1, 1])  # axes fill the whole figure

# Display the low-resolution image, divided by 255 and clipped to [0, 1].
img = np.clip(np.squeeze(lr_batch).astype("float32") / 255, 0, 1)
ax.imshow(img)
ax.axis('on')

plt.show()

generated_images = generator.predict(lr_batch)

_, height, width, _ = generated_images.shape

figsize = width / float(dpi), height / float(dpi)

fig = plt.figure(figsize=figsize)
ax = fig.add_axes([0, 0, 1, 1])

# Display the generated image with the same scaling.
img = np.clip(np.squeeze(generated_images).astype("float32") / 255, 0, 1)
ax.imshow(img)
plt.show()


In [None]:

# Redraw the generator output for the image selected by `index` (set in the
# previous cell) with a different `dpi`, i.e. a different figure size.
dpi = 55

# Named `lr_batch` (not `input`) so the builtin `input` is not shadowed.
lr_batch = np.expand_dims(train_inputs[index], axis=0)
generated_images = generator.predict(lr_batch)

_, height, width, _ = generated_images.shape

figsize = width / float(dpi), height / float(dpi)

fig = plt.figure(figsize=figsize)
ax = fig.add_axes([0, 0, 1, 1])  # axes fill the whole figure

# Divide by 255 and clip to the displayable [0, 1] range.
img = np.clip(np.squeeze(generated_images).astype("float32") / 255, 0, 1)
ax.imshow(img)
plt.show()


In [None]:
# Save the discriminator again and pass the file to Colab's `files.download`
# helper (repeats the discriminator part of the earlier save cell).
discriminator.save('discriminator.keras')
files.download('discriminator.keras')