In [None]:
import os
from PIL import Image, ImageFilter
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Add, Lambda
from tensorflow.python.keras.layers import PReLU
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy, MeanAbsoluteError
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.metrics import Mean
import time
from sklearn.model_selection import train_test_split

from tensorflow.image import psnr

import matplotlib.pyplot as plt

In [None]:
# Report how many GPUs TensorFlow can see (0 means training falls back to CPU).
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))


In [None]:
# Source dataset root (one sub-directory per patient) and the directory reserved
# for writing downsampled copies of the images.
dataset_path = "../mri_dataset/kaggle_3m/"
output_path = "../mri_dataset/super_resolved/"

# Accumulators filled by the loading loop: X = low-res inputs, y = original targets.
X = []
y = []

# Low-resolution size (width, height) every input image is resized to.
target_resolution = (128, 128)

# Create the output directory if needed; exist_ok avoids the check-then-create race.
os.makedirs(output_path, exist_ok=True)

def sort_by_numeric(filename):
    """Sort key that orders filenames by their numeric content.

    Every digit character in ``filename`` is concatenated, in order, and the
    result is parsed as an ``int`` (e.g. ``"img_12_3.tif"`` -> ``123``).

    Raises:
        ValueError: if ``filename`` contains no digits.
    """
    digits = "".join(ch for ch in filename if ch.isdigit())
    return int(digits)

def downsample_image(img, target_resolution):
    """Resize a PIL image to ``target_resolution`` using bicubic resampling.

    No blur is applied before resizing (``downsample_image_blur`` is the
    variant that Gaussian-blurs first).

    Args:
        img: PIL image to resize.
        target_resolution: ``(width, height)`` tuple passed to ``Image.resize``.

    Returns:
        A new, resized PIL image.
    """
    return img.resize(target_resolution, Image.BICUBIC)

def downsample_image_blur(img, target_resolution, blur_radius=2):
    """Gaussian-blur a PIL image, then resize it with bicubic resampling.

    Blurring before shrinking suppresses high-frequency content that would
    otherwise alias. This produces the degraded low-resolution model inputs.

    Args:
        img: PIL image to degrade.
        target_resolution: ``(width, height)`` tuple passed to ``Image.resize``.
        blur_radius: Gaussian blur radius in pixels. The default of 2 matches
            the previously hard-coded value.

    Returns:
        A new, blurred and resized PIL image.
    """
    img_blurred = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return img_blurred.resize(target_resolution, Image.BICUBIC)

# Build (low-res input, original target) pairs from every image slice.
# Listings are sorted because os.listdir order is arbitrary; a fixed sample
# order keeps the seeded train/val split in the next cell reproducible.
for patient in sorted(os.listdir(dataset_path)):
    patient_dir = os.path.join(dataset_path, patient)
    # The dataset root may also hold non-directory entries (e.g. metadata
    # files); only directories are scanned, otherwise os.listdir below fails.
    if not os.path.isdir(patient_dir):
        continue

    for tif_file in sorted(os.listdir(patient_dir)):
        # Keep TIFF slices only and skip the "_mask" companion images.
        if "_mask" in tif_file or not tif_file.lower().endswith((".tif", ".tiff")):
            continue

        original_image_path = os.path.join(patient_dir, tif_file)
        # PIL loads lazily and keeps the file handle open; the context manager
        # releases it once both arrays/derived images have been produced.
        with Image.open(original_image_path) as original_img:
            downscaled_img = downsample_image_blur(original_img, target_resolution)
            original_img_array = np.array(original_img)

        X.append(np.array(downscaled_img))
        y.append(original_img_array)

# Stack into arrays and scale pixel values to [0, 1] (assumes 8-bit image data).
X = np.array(X).astype('float32') / 255.0
y = np.array(y).astype('float32') / 255.0

In [None]:
# Hold out 1% of the samples for validation; the fixed seed makes the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.01, random_state=42
)

# Shapes of the full input set, the training inputs and the validation inputs.
for split in (X, X_train, X_val):
    print(split.shape)

In [None]:
batch_size = 1

# Group the training set into complete batches of shape
# (num_batches, batch_size, H, W, C); trailing samples that cannot fill a
# final batch are dropped. reshape can return a view of the (contiguous)
# prefix instead of building a Python list of sub-arrays and copying them
# into a second full-size array as np.array(np.split(...)) does. It also
# copes with num_batches == 0, where np.split would fail.
num_batches = X_train.shape[0] // batch_size
num_used = num_batches * batch_size
X_train_batches = X_train[:num_used].reshape((num_batches, batch_size) + X_train.shape[1:])
y_train_batches = y_train[:num_used].reshape((num_batches, batch_size) + y_train.shape[1:])

# Validation runs as ONE batch holding every held-out sample.
num_batches_val = X_val.shape[0] // batch_size
X_val_batches = np.expand_dims(X_val, axis=0)
y_val_batches = np.expand_dims(y_val, axis=0)

# Batched target shapes: (num_batches, batch_size, H, W, C) and (1, n_val, H, W, C).
print(y_train_batches.shape)
print(y_val_batches.shape)

In [None]:
# Supported super-resolution factors mapped to the number of 2x pixel-shuffle
# upsampling stages they need (scale == 2 ** stages).
upsamples_per_scale = {2 ** stages: stages for stages in (1, 2, 3)}

In [None]:
def pixel_shuffle(scale):
    """Return a function performing sub-pixel shuffling by ``scale``.

    The returned callable wraps ``tf.nn.depth_to_space``: it moves channel
    blocks into spatial positions, multiplying height and width by ``scale``
    and dividing the channel count by ``scale ** 2``. It is meant to be
    wrapped in a Keras ``Lambda`` layer.
    """
    def _shuffle(x):
        return tf.nn.depth_to_space(x, scale)

    return _shuffle
    
def upsample(x_in, num_filters):
    """One 2x sub-pixel upsampling stage: Conv2D -> pixel shuffle (x2) -> PReLU.

    Args:
        x_in: Input feature tensor.
        num_filters: Output channels of the 3x3 conv. The x2 pixel shuffle
            divides this by 4, so it must be a multiple of 4.

    Returns:
        Tensor with doubled height/width and ``num_filters // 4`` channels.
    """
    # Use the public-API PReLU. The file-level import comes from the legacy
    # private ``tensorflow.python.keras`` package, and mixing its layers with
    # ``tensorflow.keras`` layers in one functional model is unsupported.
    from tensorflow.keras.layers import PReLU

    x = Conv2D(num_filters, kernel_size=3, padding='same')(x_in)
    x = Lambda(pixel_shuffle(scale=2))(x)
    # shared_axes=[1, 2]: one learned slope per channel, shared over H and W.
    return PReLU(shared_axes=[1, 2])(x)


def residual_block(block_input, num_filters, momentum=0.8):
    """SRResNet residual block: Conv-BN-PReLU-Conv-BN plus an identity skip.

    Args:
        block_input: Input tensor. Its channel count must equal
            ``num_filters`` so the ``Add`` skip connection is shape-compatible.
        num_filters: Channels of both 3x3 convolutions.
        momentum: BatchNormalization momentum.

    Returns:
        ``block_input + F(block_input)``, same shape as the input.
    """
    # Public-API PReLU; the file-level one comes from the legacy private
    # ``tensorflow.python.keras`` package, which should not be mixed with
    # ``tensorflow.keras`` layers.
    from tensorflow.keras.layers import PReLU

    x = Conv2D(num_filters, kernel_size=3, padding='same')(block_input)
    x = BatchNormalization(momentum=momentum)(x)
    x = PReLU(shared_axes=[1, 2])(x)
    x = Conv2D(num_filters, kernel_size=3, padding='same')(x)
    x = BatchNormalization(momentum=momentum)(x)
    return Add()([block_input, x])


def build_srresnet(scale=2, num_filters=64, num_res_blocks=16, channels=3):
    """Build the SRResNet generator used by the SRGAN.

    Layout:
    - a 9x9 conv + PReLU stem;
    - ``num_res_blocks`` residual blocks;
    - a conv + BatchNorm whose output is added back to the stem output (long skip);
    - ``upsamples_per_scale[scale]`` 2x pixel-shuffle stages;
    - a final 9x9 conv with tanh activation.

    Args:
        scale: Upscaling factor; must be a key of ``upsamples_per_scale``.
        num_filters: Feature maps used throughout the residual trunk.
        num_res_blocks: Number of residual blocks.
        channels: Image channels of both input and output. The default of 3
            (RGB) matches the previously hard-coded value.

    Returns:
        A Keras ``Model`` mapping ``(batch, H, W, channels)`` inputs to
        ``(batch, H*scale, W*scale, channels)`` tanh outputs in [-1, 1].

    Raises:
        ValueError: if ``scale`` is not a supported factor.
    """
    # Public-API PReLU; the file-level one comes from the legacy private
    # ``tensorflow.python.keras`` package, which should not be mixed with
    # ``tensorflow.keras`` layers.
    from tensorflow.keras.layers import PReLU

    if scale not in upsamples_per_scale:
        raise ValueError(f"available scales are: {upsamples_per_scale.keys()}")

    num_upsamples = upsamples_per_scale[scale]

    # Spatial dims are None: the network is fully convolutional and accepts any size.
    lr = Input(shape=(None, None, channels))

    x = Conv2D(num_filters, kernel_size=9, padding='same')(lr)
    x = x_1 = PReLU(shared_axes=[1, 2])(x)

    for _ in range(num_res_blocks):
        x = residual_block(x, num_filters)

    x = Conv2D(num_filters, kernel_size=3, padding='same')(x)
    x = BatchNormalization()(x)
    # Long skip connection from the stem over the whole residual trunk.
    x = Add()([x_1, x])

    for _ in range(num_upsamples):
        # 4x channels so each x2 pixel shuffle returns num_filters channels.
        x = upsample(x, num_filters * 4)

    # NOTE(review): tanh spans [-1, 1] while the training targets are scaled to
    # [0, 1]. The validation loop clips outputs to [0, 1]; consider a [0, 1]
    # output activation or [-1, 1] targets.
    sr = Conv2D(channels, kernel_size=9, padding='same', activation='tanh')(x)

    return Model(lr, sr)

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Add, Lambda, LeakyReLU, Flatten, Dense
from tensorflow.python.keras.layers import PReLU



def discriminator_block(x_in, num_filters, strides=1, batchnorm=True, momentum=0.8):
    """Discriminator unit: 3x3 Conv2D -> optional BatchNorm -> LeakyReLU(0.2).

    Args:
        x_in: Input tensor.
        num_filters: Conv output channels.
        strides: Conv stride; 2 halves the spatial size.
        batchnorm: Whether to insert BatchNormalization after the conv.
        momentum: BatchNormalization momentum (used only if ``batchnorm``).

    Returns:
        The activated output tensor.
    """
    features = Conv2D(num_filters, kernel_size=3, strides=strides, padding='same')(x_in)
    features = BatchNormalization(momentum=momentum)(features) if batchnorm else features
    return LeakyReLU(alpha=0.2)(features)


def build_discriminator(hr_crop_size):
    """Build the SRGAN discriminator for square ``hr_crop_size`` RGB images.

    Eight conv blocks: channel widths 64 -> 128 -> 256 -> 512, each width used
    twice, with the second of each pair using stride 2. These are followed by
    Dense(1024) + LeakyReLU and a single sigmoid unit.

    Args:
        hr_crop_size: Height and width of the (square) input images.

    Returns:
        A Keras ``Model`` that outputs, per image, a probability of being real.
    """
    image = Input(shape=(hr_crop_size, hr_crop_size, 3))

    # (filters, strides, batchnorm) per block, in construction order. The first
    # block has no BatchNorm; every stride-2 block halves height and width.
    block_specs = [
        (64, 1, False), (64, 2, True),
        (128, 1, True), (128, 2, True),
        (256, 1, True), (256, 2, True),
        (512, 1, True), (512, 2, True),
    ]

    features = image
    for filters, stride, use_bn in block_specs:
        features = discriminator_block(features, filters, strides=stride, batchnorm=use_bn)

    features = Flatten()(features)
    features = Dense(1024)(features)
    features = LeakyReLU(alpha=0.2)(features)
    probability = Dense(1, activation='sigmoid')(features)

    return Model(image, probability)

In [None]:
@tf.function
def train_step(lr, hr):
    """Run one adversarial SRGAN update on a single (low-res, high-res) batch.

    The generator minimizes ``content_loss + 0.001 * adversarial_loss``
    ("perceptual loss"). The discriminator minimizes BCE on real-vs-generated
    images. Both networks are updated once, using the module-level
    ``generator_optimizer`` / ``discriminator_optimizer``.

    Args:
        lr: Low-resolution input batch.
        hr: Matching high-resolution target batch.

    Returns:
        ``(perceptual_loss, discriminator_loss)`` for this batch.
    """
    gen_net = srgan_checkpoint.generator
    disc_net = srgan_checkpoint.discriminator

    # Two tapes, so each network's loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        sr = gen_net(lr, training=True)
        real_scores = disc_net(hr, training=True)
        fake_scores = disc_net(sr, training=True)

        content = calculate_content_loss(hr, sr)
        adversarial = calculate_generator_loss(fake_scores)
        perc_loss = content + 0.001 * adversarial
        disc_loss = calculate_discriminator_loss(real_scores, fake_scores)

    gen_vars = gen_net.trainable_variables
    disc_vars = disc_net.trainable_variables
    gen_grads = gen_tape.gradient(perc_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return perc_loss, disc_loss

@tf.function
def calculate_content_loss(hr, sr):
    """VGG feature-space MSE between high-res targets and generator outputs.

    Both images are scaled to the 0-255 range, passed through the VGG19
    ``preprocess_input``, and fed to ``perceptual_model`` (the truncated VGG19
    trunk built elsewhere in the notebook). Feature maps are divided by 2000
    before the MSE to keep the loss magnitude small.

    Everything stays in float32 for two reasons:
    - ``tf.round`` and casts to integer dtypes have no gradient, so quantizing
      here would stop the content loss from training the generator at all.
    - ``preprocess_input`` casts its (negative) mean offsets to the input
      dtype, which corrupts them for unsigned integer inputs.

    Args:
        hr: High-resolution target batch, assumed scaled to [0, 1].
        sr: Generator output batch (same shape as ``hr``).

    Returns:
        Scalar MSE between the scaled VGG features of ``hr`` and ``sr``.
    """
    hr_255 = tf.cast(hr, tf.float32) * 255.0
    sr_255 = tf.cast(sr, tf.float32) * 255.0

    hr_features = perceptual_model(preprocess_input(hr_255)) / 2000
    sr_features = perceptual_model(preprocess_input(sr_255)) / 2000
    return mean_squared_error(hr_features, sr_features)

def calculate_generator_loss(sr_out):
    """Adversarial loss for the generator.

    Binary cross-entropy between the discriminator's scores on generated
    images and an all-ones ("real") target. It decreases as the discriminator
    rates super-resolved images as real.

    Args:
        sr_out: Discriminator sigmoid outputs for generated images.

    Returns:
        Scalar BCE loss.
    """
    # (An unused quantized copy of sr_out was removed; it never affected the result.)
    return binary_cross_entropy(tf.ones_like(sr_out), sr_out)

def calculate_discriminator_loss(hr_out, sr_out):
    """Discriminator loss: BCE(real -> 1) + BCE(generated -> 0).

    Args:
        hr_out: Discriminator sigmoid outputs for real high-res images.
        sr_out: Discriminator sigmoid outputs for generated images.

    Returns:
        Sum of the two scalar BCE terms.
    """
    # (Unused quantized copies of both inputs were removed; they never affected the result.)
    hr_loss = binary_cross_entropy(tf.ones_like(hr_out), hr_out)
    sr_loss = binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
    return hr_loss + sr_loss

In [None]:
# Low-res input shape: target_resolution is (128, 128), so this is (128, 128, 3).
# Kept for reference only; build_srresnet accepts any spatial size.
input_shape = target_resolution + (3,)

# Super-resolution factor shared by generator and discriminator sizing.
SR_SCALE = 2
generator = build_srresnet(scale=SR_SCALE, num_res_blocks=16)

# The discriminator scores both real HR images and generator outputs, so its
# input size must equal the super-resolved size (128 * 2 = 256 here).
discriminator = build_discriminator(hr_crop_size=target_resolution[0] * SR_SCALE)

In [None]:
# Feature extractor for the content loss: VGG19 convolutional trunk (no
# classifier head), truncated at layer index 20. Per the variable name this is
# meant to be conv5_4 (presumably Keras' "block5_conv4"; verify with
# vgg.layers[layer_5_4].name).
layer_5_4 = 20
vgg = VGG19(input_shape=(None, None, 3), include_top=False)
perceptual_model = Model(
    inputs=vgg.input,
    outputs=vgg.layers[layer_5_4].output,
)

In [None]:
# Everything that should be saved/restored together: a step counter, a PSNR
# slot, both Adam optimizers (lr 1e-4) and both networks.
checkpoint_objects = dict(
    step=tf.Variable(0),
    psnr=tf.Variable(0.0),
    generator_optimizer=Adam(0.0001),
    discriminator_optimizer=Adam(0.0001),
    generator=generator,
    discriminator=discriminator,
)
srgan_checkpoint = tf.train.Checkpoint(**checkpoint_objects)

In [None]:
# Loss objects used by the loss helpers (expect probabilities, not logits).
binary_cross_entropy = BinaryCrossentropy()
mean_squared_error = MeanSquaredError()

# Train with the optimizers owned by srgan_checkpoint. The optimizers that
# update the networks are then the same objects whose state (Adam moments,
# iteration count) is saved and restored. Separate Adam instances would leave
# the checkpointed optimizers unused.
generator_optimizer = srgan_checkpoint.generator_optimizer
discriminator_optimizer = srgan_checkpoint.discriminator_optimizer

In [None]:
# Running means of the per-step training losses.
perceptual_loss_metric = Mean()
discriminator_loss_metric = Mean()

sr_out = []
step = srgan_checkpoint.step.numpy()
step_counter = 0
sr_outputs = []
hr_outputs = []

# Validation losses: one entry per validation batch per validation round.
val_gen_loss_array = []
val_con_loss_array = []
val_perc_loss_array = []
val_disc_loss_array = []

VAL_EVERY = 100   # run validation and save preview images every N training steps
N_PREVIEW = 4     # images per preview strip

# savefig fails if the target directory is missing.
os.makedirs("./SRGAN_images", exist_ok=True)


def _save_image_row(images, path, n=N_PREVIEW):
    """Plot up to ``n`` images side by side, save to ``path``, show, then close."""
    n = min(n, len(images))
    fig, axes = plt.subplots(1, n, figsize=(10, 5), dpi=300, squeeze=False)
    for ax, img in zip(axes[0], images[:n]):
        ax.imshow(np.asarray(img), clim=[0, 1])
        ax.axis("off")
    # Save before show(): some backends release the canvas on show, which
    # makes a later savefig write a blank image.
    fig.savefig(path)
    plt.show()
    # Close explicitly so figures do not accumulate over many validation rounds.
    plt.close(fig)


now = time.perf_counter()

for lr, hr in zip(X_train_batches, y_train_batches):
    srgan_checkpoint.step.assign_add(1)
    step = srgan_checkpoint.step.numpy()
    step_counter += 1

    perceptual_loss, discriminator_loss = train_step(lr, hr)
    perceptual_loss_metric(perceptual_loss)
    discriminator_loss_metric(discriminator_loss)

    if step_counter % VAL_EVERY != 0:
        continue

    # Distinct names so the training-loop variables are not shadowed.
    for val_lr, val_hr in zip(X_val_batches, y_val_batches):
        sr = srgan_checkpoint.generator(val_lr, training=False)
        sr = tf.clip_by_value(sr, 0, 1)
        hr_output = srgan_checkpoint.discriminator(val_hr, training=False)
        sr_output = srgan_checkpoint.discriminator(sr, training=False)

        val_gen_loss = calculate_generator_loss(sr_output)
        val_con_loss = calculate_content_loss(val_hr, sr)
        val_perc_loss = val_con_loss + 0.001 * val_gen_loss
        val_disc_loss = calculate_discriminator_loss(hr_output, sr_output)

        # Store plain floats so the history is easy to plot / serialize.
        val_gen_loss_array.append(float(val_gen_loss))
        val_con_loss_array.append(float(val_con_loss))
        val_perc_loss_array.append(float(val_perc_loss))
        val_disc_loss_array.append(float(val_disc_loss))

        _save_image_row(val_hr, f"./SRGAN_images/hr_{step_counter}.png")
        _save_image_row(sr.numpy(), f"./SRGAN_images/sr_{step_counter}.png")

    duration = time.perf_counter() - now
    now = time.perf_counter()
    print(f'{step_counter}, perceptual loss = {val_perc_loss:.4f}, discriminator loss = {val_disc_loss:.4f}, generator loss = {val_gen_loss:.4f} ({duration:.2f}s)')
print("done")

In [None]:
import matplotlib.pyplot as plt

# Each point is one validation round. The training loop validates every 100
# training steps on a single validation batch, so the x-axis is not epochs.
fig, ax1 = plt.subplots(dpi=300)

# Content, perceptual and discriminator losses share the left axis.
ax1.plot(val_con_loss_array, label='Content Loss')
ax1.plot(val_perc_loss_array, label='Perceptual Loss')
ax1.plot(val_disc_loss_array, label='Discriminator Loss')
ax1.set_xlabel('Validation round (every 100 training steps)')
ax1.set_ylabel('C, P, D Loss')

# The generator (adversarial) loss gets its own right-hand axis because its scale differs.
ax2 = ax1.twinx()
ax2.plot(val_gen_loss_array, color='red', label='Generator Loss')
ax2.set_ylabel('G Loss')

# One legend covering the lines from both axes.
lines_1, labels_1 = ax1.get_legend_handles_labels()
lines_2, labels_2 = ax2.get_legend_handles_labels()
ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper right')

ax1.set_title('SRGAN Losses')

# Save before show so the file is written from the fully drawn figure.
fig.savefig("SRGAN_losses.png", dpi=300)
plt.show()
