In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# NOTE(review): the dataset globs further down read from a local
# '/Applications/ML projects/...' directory rather than /content/drive --
# confirm which environment this notebook is meant to run in.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
import numpy as np
from glob import glob
import tensorflow as tf
from tqdm import tqdm

Load Data

In [2]:
SIZE = 128  # side length, in pixels, that read_image resizes every image to
BATCH_SIZE = 32  # number of image pairs per batch produced by get_dataset
MAX_TRAIN_IMAGES = 300  # index at which each sorted file list is split into train / validation
IMAGE_COUNT = 600  # maximum number of files kept per low-resolution level

In [3]:
def read_image(image_path):
  """Load one JPEG file as a float32 RGB tensor of shape (SIZE, SIZE, 3).

  Args:
    image_path: path of the JPEG file to read.

  Returns:
    The decoded image, resized to SIZE x SIZE and divided by 255.
  """
  raw_bytes = tf.io.read_file(image_path)
  decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
  resized = tf.image.resize(decoded, [SIZE, SIZE])
  # Scale the 8-bit pixel range down by 1/255.
  return tf.cast(resized, dtype=tf.float32) / 255.0

In [4]:
def load_data(low_res_image_path, high_res_image_path):
  """Read the two images of one training pair.

  Args:
    low_res_image_path: path of the input (lower-resolution) image.
    high_res_image_path: path of the target (higher-resolution) image.

  Returns:
    A `(low_res_image, high_res_image)` tuple, each loaded with `read_image`.
  """
  return read_image(low_res_image_path), read_image(high_res_image_path)

In [5]:
def get_dataset(low_res_images, high_res_images):
  """Build a batched tf.data pipeline of (input, target) image pairs.

  Args:
    low_res_images: sequence of input image paths.
    high_res_images: sequence of target image paths, index-aligned with
      `low_res_images`.

  Returns:
    A dataset yielding batches of BATCH_SIZE decoded image pairs; a final
    incomplete batch is dropped.
  """
  path_pairs = tf.data.Dataset.from_tensor_slices((low_res_images, high_res_images))
  return (
      path_pairs
      .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(BATCH_SIZE, drop_remainder=True)
  )

In [6]:
# Collect every file in the low- and high-resolution folders.
low_res = glob('/Applications/ML projects/Failures/Super Resolution/low res/*')
high_res = glob('/Applications/ML projects/Failures/Super Resolution/high res/*')

# Fail early with a clear message: an empty glob would otherwise only surface
# later as an opaque dtype error inside the tf.data pipeline.
if not low_res or not high_res:
  raise FileNotFoundError(
      'No images found - check the "low res" / "high res" dataset folders.'
  )

# Bucket the low-resolution files by the suffix after the last underscore
# ('<id>_2.jpg', '<id>_4.jpg', '<id>_6.jpg').
_2 = []
_4 = []
_6 = []

for lrfile in tqdm(low_res):
  image_type = lrfile.split('/')[-1]
  image_res = image_type.split('_')[-1]
  if image_res == '2.jpg':
    _2.append(lrfile)
  elif image_res == '4.jpg':
    _4.append(lrfile)
  elif image_res == '6.jpg':
    _6.append(lrfile)

# Sort so the three lists line up by image id, then cap their size.
_2 = sorted(_2)[:IMAGE_COUNT]
_4 = sorted(_4)[:IMAGE_COUNT]
_6 = sorted(_6)[:IMAGE_COUNT]

# Index the high-resolution files by id (file name up to the first '.') once,
# instead of rescanning the whole folder for every low-resolution file.
high_res_by_id = {}
for hr_file in high_res:
  hr_image_type = hr_file.split('/')[-1]
  hr_image_res = hr_image_type.split('.')[0]
  high_res_by_id.setdefault(hr_image_res, []).append(hr_file)

# Look up the high-resolution target(s) for each '_2' file, in '_2' order.
_high_res = []
for lr_file in tqdm(_2):
  lr_image_type = lr_file.split('/')[-1]
  lr_image_res = lr_image_type.split('_')[0]
  _high_res.extend(high_res_by_id.get(lr_image_res, []))

# The datasets below pair the lists purely by position, so a single missing
# file would silently shift every later pair. Verify the alignment explicitly.
if not (len(_2) == len(_4) == len(_6) == len(_high_res)):
  raise ValueError(
      'Image lists differ in length: '
      f'_2={len(_2)}, _4={len(_4)}, _6={len(_6)}, high_res={len(_high_res)}'
  )
for path_2, path_4, path_6, path_hr in zip(_2, _4, _6, _high_res):
  pair_ids = {
      path_2.split('/')[-1].split('_')[0],
      path_4.split('/')[-1].split('_')[0],
      path_6.split('/')[-1].split('_')[0],
      path_hr.split('/')[-1].split('.')[0],
  }
  if len(pair_ids) != 1:
    raise ValueError(
        f'Misaligned image set: {path_2}, {path_4}, {path_6}, {path_hr}'
    )

# Split every list at the same index into training and validation parts.
_2_train, _2_val = _2[:MAX_TRAIN_IMAGES], _2[MAX_TRAIN_IMAGES:]
_4_train, _4_val = _4[:MAX_TRAIN_IMAGES], _4[MAX_TRAIN_IMAGES:]
_6_train, _6_val = _6[:MAX_TRAIN_IMAGES], _6[MAX_TRAIN_IMAGES:]
high_res_train, high_res_val = _high_res[:MAX_TRAIN_IMAGES], _high_res[MAX_TRAIN_IMAGES:]

# One (input, target) pipeline per stage: 2 -> 4, 4 -> 6 and 6 -> high res.
train_dataset_24 = get_dataset(_2_train, _4_train)
train_dataset_46 = get_dataset(_4_train, _6_train)
train_dataset_6h = get_dataset(_6_train, high_res_train)

val_dataset_24 = get_dataset(_2_val, _4_val)
val_dataset_46 = get_dataset(_4_val, _6_val)
val_dataset_6h = get_dataset(_6_val, high_res_val)

100%|██████████| 3762/3762 [00:00<00:00, 678461.18it/s]
100%|██████████| 600/600 [00:00<00:00, 1484.58it/s]
2023-03-29 22:57:21.476045: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
# Show the element specs of the 2 -> 4 training and validation pipelines.
print(train_dataset_24, val_dataset_24, sep='\n')

<BatchDataset element_spec=(TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None))>
<BatchDataset element_spec=(TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None))>


In [8]:
# Show the element specs of the 4 -> 6 training and validation pipelines.
print(train_dataset_46, val_dataset_46, sep='\n')

<BatchDataset element_spec=(TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None))>
<BatchDataset element_spec=(TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None))>


In [9]:
# Show the element specs of the 6 -> high-res training and validation pipelines.
print(train_dataset_6h, val_dataset_6h, sep='\n')

<BatchDataset element_spec=(TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None))>
<BatchDataset element_spec=(TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 128, 128, 3), dtype=tf.float32, name=None))>


Model

In [10]:
from keras.layers import Add, GlobalAveragePooling2D, Conv2D, Concatenate, MaxPooling2D, UpSampling2D, Input
from keras import Model
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau

Selective Kernel Feature Fusion

In [11]:
def selective_kernel_feature_fusion(multi_scale_feature1, multi_scale_feature2, multi_scale_feature3):
  """Fuse three same-shaped feature maps with per-channel branch attention.

  The inputs are summed, globally average-pooled and squeezed through a 1x1
  convolution to `channels // 8` features. Three further 1x1 convolutions
  produce one score vector per input branch. The scores are normalised with a
  softmax taken ACROSS the three branches, so that for every channel the three
  branch weights sum to one. Each input is scaled by its weights and the
  scaled inputs are added.

  Args:
    multi_scale_feature1: first feature map.
    multi_scale_feature2: second feature map, same shape as the first.
    multi_scale_feature3: third feature map, same shape as the first.

  Returns:
    The attention-weighted sum of the three inputs (same shape as each input).
  """
  channels = list(multi_scale_feature1.shape)[-1]
  branches = [multi_scale_feature1, multi_scale_feature2, multi_scale_feature3]
  combined_feature = Add()(branches)
  gap = GlobalAveragePooling2D()(combined_feature)
  # Restore the two spatial axes so 1x1 convolutions can be applied.
  channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
  compact_feature_representation = Conv2D(
      filters=channels // 8, kernel_size=(1, 1), activation='relu'
  )(channel_wise_statistics)
  # One un-normalised score vector per branch, each (batch, 1, 1, channels).
  branch_scores = [
      Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation)
      for _ in branches
  ]
  # Softmax over the branch axis (axis 1 of the stacked tensor). A per-layer
  # activation='softmax' would instead normalise over the channel axis of each
  # branch independently, so the branches would never compete with each other.
  branch_weights = tf.nn.softmax(tf.stack(branch_scores, axis=1), axis=1)
  weighted_branches = [
      feature * branch_weights[:, index]
      for index, feature in enumerate(branches)
  ]
  return Add()(weighted_branches)

Dual Attention Unit

In [12]:
def channel_attention_block(input_tensor):
  """Re-weight the channels of `input_tensor` with a squeeze-and-excite gate.

  Args:
    input_tensor: 4-D feature map whose last axis is the channel axis.

  Returns:
    `input_tensor` multiplied by a per-channel sigmoid gate.
  """
  num_channels = input_tensor.shape[-1]
  pooled = GlobalAveragePooling2D()(input_tensor)
  # Restore the two spatial axes so 1x1 convolutions can be applied.
  squeezed = tf.reshape(pooled, shape=(-1, 1, 1, num_channels))
  reduce_conv = Conv2D(
      filters=num_channels // 8, kernel_size=(1, 1), activation='relu'
  )
  expand_conv = Conv2D(
      filters=num_channels, kernel_size=(1, 1), activation='sigmoid'
  )
  channel_gate = expand_conv(reduce_conv(squeezed))
  return input_tensor * channel_gate

In [13]:
def spatial_attention_block(input_tensor):
  """Re-weight every spatial position of `input_tensor` with a sigmoid gate.

  The gate is computed by a 1x1 convolution over the channel-wise mean and
  channel-wise maximum of the input.

  Args:
    input_tensor: 4-D feature map whose last axis is the channel axis.

  Returns:
    `input_tensor` multiplied by a single-channel sigmoid map.
  """
  mean_map = tf.expand_dims(tf.reduce_mean(input_tensor, axis=-1), axis=-1)
  max_map = tf.expand_dims(tf.reduce_max(input_tensor, axis=-1), axis=-1)
  pooled_maps = Concatenate(axis=-1)([mean_map, max_map])
  gate_logits = Conv2D(1, kernel_size=(1, 1))(pooled_maps)
  spatial_gate = tf.nn.sigmoid(gate_logits)
  return input_tensor * spatial_gate

In [14]:
def dual_attention_unit_block(input_tensor):
  """Residual unit combining channel attention and spatial attention.

  Two 3x3 convolutions extract features, both attention blocks are applied to
  those features, their outputs are concatenated and projected back to the
  input channel count with a 1x1 convolution, and the input is added back.

  Args:
    input_tensor: 4-D feature map whose last axis is the channel axis.

  Returns:
    A tensor with the same shape as `input_tensor`.
  """
  num_channels = input_tensor.shape[-1]

  features = Conv2D(
      num_channels, kernel_size=(3, 3), padding='same', activation='relu'
  )(input_tensor)
  features = Conv2D(num_channels, kernel_size=(3, 3), padding='same')(features)

  channel_branch = channel_attention_block(features)
  spatial_branch = spatial_attention_block(features)
  merged = Concatenate(axis=-1)([channel_branch, spatial_branch])
  projected = Conv2D(num_channels, kernel_size=(1, 1))(merged)

  return Add()([input_tensor, projected])

Multi Scale Residual Block

In [15]:
def down_sampling_block(input_tensor):
  """Max-pool the input with a residual design and double its channel count.

  Args:
    input_tensor: 4-D feature map whose last axis is the channel axis.

  Returns:
    Sum of a convolutional branch and a shortcut branch, both pooled with a
    default-configured `MaxPooling2D` and projected to twice the input
    channels.
  """
  in_channels = input_tensor.shape[-1]
  out_channels = in_channels * 2

  # Convolutional branch: 1x1 -> 3x3 -> pool -> 1x1 projection.
  residual = Conv2D(in_channels, kernel_size=(1, 1), activation='relu')(input_tensor)
  residual = Conv2D(in_channels, kernel_size=(3, 3), padding='same', activation='relu')(residual)
  residual = MaxPooling2D()(residual)
  residual = Conv2D(out_channels, kernel_size=(1, 1))(residual)

  # Shortcut branch: pool -> 1x1 projection.
  shortcut = MaxPooling2D()(input_tensor)
  shortcut = Conv2D(out_channels, kernel_size=(1, 1))(shortcut)

  return Add()([residual, shortcut])

In [16]:
def up_sampling_block(input_tensor):
  """Upsample the input with a residual design and halve its channel count.

  Args:
    input_tensor: 4-D feature map whose last axis is the channel axis.

  Returns:
    Sum of a convolutional branch and a shortcut branch, both enlarged with a
    default-configured `UpSampling2D` and projected to half the input
    channels.
  """
  in_channels = input_tensor.shape[-1]
  out_channels = in_channels // 2

  # Convolutional branch: 1x1 -> 3x3 -> upsample -> 1x1 projection.
  residual = Conv2D(in_channels, kernel_size=(1, 1), activation='relu')(input_tensor)
  residual = Conv2D(in_channels, kernel_size=(3, 3), padding='same', activation='relu')(residual)
  residual = UpSampling2D()(residual)
  residual = Conv2D(out_channels, kernel_size=(1, 1))(residual)

  # Shortcut branch: upsample -> 1x1 projection.
  shortcut = UpSampling2D()(input_tensor)
  shortcut = Conv2D(out_channels, kernel_size=(1, 1))(shortcut)

  return Add()([residual, shortcut])

In [17]:
def multi_scale_residual_block(input_tensor, channels):
  """Process the input at three resolutions and fuse the results residually.

  Args:
    input_tensor: 4-D feature map at the full resolution.
    channels: number of filters of the final 3x3 convolution. The result is
      added to `input_tensor`, so this has to match the input channel count.

  Returns:
    `input_tensor` plus the fused multi-scale features.
  """
  # Three-level pyramid: full resolution and two successive down-samplings.
  scale1 = input_tensor
  scale2 = down_sampling_block(scale1)
  scale3 = down_sampling_block(scale2)

  # First round of dual attention, one unit per scale.
  scale1_dau = dual_attention_unit_block(scale1)
  scale2_dau = dual_attention_unit_block(scale2)
  scale3_dau = dual_attention_unit_block(scale3)

  # Exchange information: bring all scales to level 1 and fuse.
  scale2_at_1 = up_sampling_block(scale2_dau)
  scale3_at_1 = up_sampling_block(up_sampling_block(scale3_dau))
  fused1 = selective_kernel_feature_fusion(scale1_dau, scale2_at_1, scale3_at_1)

  # ... to level 2 ...
  scale1_at_2 = down_sampling_block(scale1_dau)
  scale3_at_2 = up_sampling_block(scale3_dau)
  fused2 = selective_kernel_feature_fusion(scale1_at_2, scale2_dau, scale3_at_2)

  # ... and to level 3.
  scale1_at_3 = down_sampling_block(down_sampling_block(scale1_dau))
  scale2_at_3 = down_sampling_block(scale2_dau)
  fused3 = selective_kernel_feature_fusion(scale1_at_3, scale2_at_3, scale3_dau)

  # Second round of dual attention, then return every scale to level 1.
  out1 = dual_attention_unit_block(fused1)
  out2 = up_sampling_block(dual_attention_unit_block(fused2))
  out3 = up_sampling_block(up_sampling_block(dual_attention_unit_block(fused3)))

  merged = selective_kernel_feature_fusion(out1, out2, out3)
  projected = Conv2D(channels, kernel_size=(3, 3), padding='same')(merged)

  return Add()([input_tensor, projected])

MIRNet Model

In [18]:
def recursive_residual_block(input_tensor, msrb_count, channels):
  """Chain several multi-scale residual blocks inside one skip connection.

  Args:
    input_tensor: 4-D feature map; it is added to the block output.
    msrb_count: how many multi-scale residual blocks to chain.
    channels: filter count of the convolutions and of the chained blocks.

  Returns:
    `input_tensor` plus the output of conv -> MSRB x `msrb_count` -> conv.
  """
  features = Conv2D(channels, kernel_size=(3, 3), padding='same')(input_tensor)
  for _ in range(msrb_count):
    features = multi_scale_residual_block(features, channels)
  features = Conv2D(channels, kernel_size=(3, 3), padding='same')(features)
  return Add()([input_tensor, features])

In [19]:
def MIRNet_Model(rrb_count, msrb_count, channels):
  """Assemble the MIRNet Keras model for 3-channel images of any size.

  Args:
    rrb_count: number of recursive residual blocks to stack.
    msrb_count: number of multi-scale residual blocks per recursive block.
    channels: feature channels used throughout the full-resolution path.

  Returns:
    A `Model` whose output is the input image plus a learned 3-channel
    residual.
  """
  image_in = Input(shape=(None, None, 3))
  features = Conv2D(channels, kernel_size=(3, 3), padding='same')(image_in)
  for _ in range(rrb_count):
    features = recursive_residual_block(features, msrb_count, channels)
  residual = Conv2D(3, kernel_size=(3, 3), padding='same')(features)
  restored = Add()([image_in, residual])
  return Model(image_in, restored)

Build Models

In [20]:
RRB_COUNT = 3  # recursive residual blocks stacked in each MIRNet model
MSRB_COUNT = 2  # multi-scale residual blocks inside every recursive residual block
CHANNELS = 64  # feature channels on the full-resolution path

In [21]:
# Three independent, identically configured networks, one per stage
# (2 -> 4, 4 -> 6, 6 -> high res).
model_24, model_46, model_6h = (
    MIRNet_Model(RRB_COUNT, MSRB_COUNT, CHANNELS) for _ in range(3)
)

Training Process

In [22]:
def charbonnier_loss(y_true, y_pred):
    """Charbonnier loss: mean of sqrt(diff^2 + eps^2) with eps = 1e-3.

    Args:
        y_true: target tensor.
        y_pred: predicted tensor, broadcast-compatible with `y_true`.

    Returns:
        Scalar tensor holding the mean over all elements.
    """
    residual = y_true - y_pred
    smoothed_abs = tf.sqrt(tf.square(residual) + tf.square(1e-3))
    return tf.reduce_mean(smoothed_abs)

In [23]:
def peak_signal_noise_ratio(y_true, y_pred):
    """PSNR metric between predictions and targets.

    Args:
        y_true: target image batch.
        y_pred: predicted image batch.

    Returns:
        The value of `tf.image.psnr` for the two batches.
    """
    # read_image divides pixel values by 255, so the dynamic range of the
    # images is 1.0. Passing 255.0 here inflates every reported PSNR by a
    # constant 20 * log10(255) (about 48 dB).
    return tf.image.psnr(y_pred, y_true, max_val=1.0)

In [24]:
# Adam optimizer configuration (learning rate 1e-4) used when compiling the models.
optimizer = Adam(1e-4)

In [25]:
# Give every model its OWN optimizer, cloned from the configured `optimizer`.
# A single shared instance also shares its learning-rate variable and internal
# state: reductions made by ReduceLROnPlateau while training one model would
# carry over into the next, and newer Keras optimizers refuse to update the
# variables of a second model.
model_24.compile(
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    loss=charbonnier_loss,
    metrics=[peak_signal_noise_ratio],
)
model_46.compile(
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    loss=charbonnier_loss,
    metrics=[peak_signal_noise_ratio],
)
model_6h.compile(
    optimizer=type(optimizer).from_config(optimizer.get_config()),
    loss=charbonnier_loss,
    metrics=[peak_signal_noise_ratio],
)

In [26]:
# Train the 2 -> 4 model for 30 epochs. The learning rate is halved whenever
# the validation PSNR has not improved for 5 epochs.
lr_plateau_24 = ReduceLROnPlateau(
    monitor="val_peak_signal_noise_ratio",
    mode="max",
    factor=0.5,
    patience=5,
    min_delta=1e-7,
    verbose=1,
)
history_24 = model_24.fit(
    train_dataset_24,
    validation_data=val_dataset_24,
    epochs=30,
    callbacks=[lr_plateau_24],
)

Epoch 1/30


In [None]:
# Write the 2 -> 4 model to 'model_24.h5' (relative path).
# NOTE(review): the model was compiled with a custom loss and metric; loading
# the file later presumably needs `custom_objects` or `compile=False` -- confirm.
model_24.save('model_24.h5')

In [None]:
# Train the 4 -> 6 model for 30 epochs. The learning rate is halved whenever
# the validation PSNR has not improved for 5 epochs.
lr_plateau_46 = ReduceLROnPlateau(
    monitor="val_peak_signal_noise_ratio",
    mode="max",
    factor=0.5,
    patience=5,
    min_delta=1e-7,
    verbose=1,
)
history_46 = model_46.fit(
    train_dataset_46,
    validation_data=val_dataset_46,
    epochs=30,
    callbacks=[lr_plateau_46],
)

In [None]:
# Write the 4 -> 6 model to 'model_46.h5' (relative path).
# NOTE(review): the model was compiled with a custom loss and metric; loading
# the file later presumably needs `custom_objects` or `compile=False` -- confirm.
model_46.save('model_46.h5')

In [None]:
# Train the 6 -> high-res model for 30 epochs. The learning rate is halved
# whenever the validation PSNR has not improved for 5 epochs.
lr_plateau_6h = ReduceLROnPlateau(
    monitor="val_peak_signal_noise_ratio",
    mode="max",
    factor=0.5,
    patience=5,
    min_delta=1e-7,
    verbose=1,
)
history_6h = model_6h.fit(
    train_dataset_6h,
    validation_data=val_dataset_6h,
    epochs=30,
    callbacks=[lr_plateau_6h],
)

In [None]:
# Write the 6 -> high-res model to 'model_6h.h5' (relative path).
# NOTE(review): the model was compiled with a custom loss and metric; loading
# the file later presumably needs `custom_objects` or `compile=False` -- confirm.
model_6h.save('model_6h.h5')