<a href="https://colab.research.google.com/github/cohen-raz/image_processing/blob/main/IMPR_Ex5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3 Dataset Handling

In [None]:
def get_rand_window(im, crop_size, get_index=False):
    """
    Pick a uniformly random window of size crop_size inside im.

    :param im: array whose first two axes are (rows, cols).
    :param crop_size: tuple (height, width) of the window.
    :param get_index: if True, return the window's index ranges instead of
        the window itself.
    :return: if get_index, ((row_start, row_end), (col_start, col_end)) with
        exclusive ends; otherwise the slice
        im[row_start:row_end, col_start:col_end].
    :raises ValueError: if crop_size exceeds im in either dimension.
    """
    crop_h, crop_w = crop_size[0], crop_size[1]
    max_row = im.shape[0] - crop_h
    max_col = im.shape[1] - crop_w
    if max_row < 0 or max_col < 0:
        raise ValueError(
            "crop_size %r is larger than image shape %r"
            % (tuple(crop_size), im.shape[:2]))
    # randint's upper bound is exclusive: +1 so the last valid offset can be
    # drawn and a window exactly as large as the image (offset 0) is allowed.
    rand_row = np.random.randint(0, max_row + 1)
    rand_col = np.random.randint(0, max_col + 1)
    row_range = (rand_row, rand_row + crop_h)
    # column extent must use the crop *width*, not its height
    col_range = (rand_col, rand_col + crop_w)
    if get_index:
        return row_range, col_range
    return im[row_range[0]:row_range[1], col_range[0]:col_range[1]]


def load_dataset(filenames, batch_size, corruption_func, crop_size):
    """
    Endlessly yield (source_batch, target_batch) pairs of image patches.

    Each iteration does the following:
    - Draw batch_size filenames with replacement.
    - Read each image on its first draw; later draws reuse the cached copy.
    - Cut a random window three times larger than crop_size and corrupt a
      copy of it with corruption_func.
    - Take the same random crop_size patch from both the corrupted and the
      clean window.

    :param filenames: a list of filenames of clean images.
    :param batch_size: number of patch pairs per yielded batch (one SGD step).
    :param corruption_func: callable taking an image array and returning a
        randomly corrupted version of it.
    :param crop_size: tuple (height, width) of the patches to extract.
    :return: generator of (source_batch, target_batch), each of shape
        (batch_size, height, width, 1) with 0.5 subtracted from every value;
        target_batch holds the clean patches and source_batch their corrupted
        counterparts.
    """
    image_cache = {}
    big_window = (crop_size[0] * 3, crop_size[1] * 3)

    def _to_sample(patch):
        # shift values by -0.5 and append a singleton channel axis
        return (patch - 0.5).reshape(patch.shape[0], patch.shape[1], 1)

    while True:
        chosen = np.random.choice(filenames, size=batch_size, replace=True)
        sources, targets = [], []

        for path in chosen:
            if path not in image_cache:
                image_cache[path] = read_image(path, 1)
            image = image_cache[path]

            clean_window = get_rand_window(image, big_window)
            # clean_window is a slice of the cached image: corrupt a copy so
            # neither the cache nor the clean window is modified in place
            dirty_window = corruption_func(clean_window.copy())

            (r0, r1), (c0, c1) = get_rand_window(dirty_window, crop_size,
                                                 get_index=True)
            sources.append(_to_sample(dirty_window[r0:r1, c0:c1]))
            targets.append(_to_sample(clean_window[r0:r1, c0:c1]))

        yield np.stack(sources, axis=0), np.stack(targets, axis=0)

# 4 Neural Network Model

In [None]:
def resblock(input_tensor, num_channels):
    """
    Append one residual block to input_tensor and return its output tensor.

    The block applies, in order:
    - a 3x3 convolution with num_channels filters and "same" padding;
    - a ReLU;
    - a second 3x3 convolution with num_channels filters and "same" padding;
    - addition of input_tensor (the skip connection);
    - a final ReLU.

    "Same" padding keeps the spatial size unchanged, so the skip addition is
    shape-compatible.

    :param input_tensor: symbolic Keras tensor feeding the block.
    :param num_channels: number of filters of both convolutions; assumed to
        match the channel count of input_tensor so the addition is valid.
    :return: symbolic output tensor of the residual block.
    """
    hidden = Conv2D(num_channels, (3, 3), padding='same')(input_tensor)
    hidden = Activation('relu')(hidden)
    hidden = Conv2D(num_channels, (3, 3), padding='same')(hidden)
    skip_sum = Add()([hidden, input_tensor])
    return Activation('relu')(skip_sum)

In [None]:
def build_nn_model(height, width, num_channels, num_res_blocks):
    """
    Create an untrained Keras restoration model.

    Architecture:
    - input of shape (height, width, 1);
    - a 3x3 convolution with num_channels filters, followed by a ReLU;
    - num_res_blocks residual blocks;
    - a final 3x3 convolution with a single output channel.

    The model output is the input plus that final convolution's output.

    :param height: input patch height.
    :param width: input patch width.
    :param num_channels: number of output channels of every convolution except
        the last one.
    :param num_res_blocks: number of residual blocks.
    :return: an untrained Keras Model.
    """
    # named model_input (not `input`) to avoid shadowing the builtin
    model_input = Input(shape=(height, width, 1))
    features = Conv2D(num_channels, (3, 3), padding='same')(model_input)
    features = Activation('relu')(features)
    for _ in range(num_res_blocks):
        features = resblock(features, num_channels)

    correction = Conv2D(1, (3, 3), padding='same')(features)
    # global skip: the conv stack's output is added back onto the input
    output = Add()([model_input, correction])
    return Model(inputs=model_input, outputs=output)

# 5 Training Networks for Image Restoration

In [None]:
def train_model(model, images, corruption_func, batch_size, steps_per_epoch, num_epochs, num_valid_samples):
    """
    Split the images 80-20 into training and validation sets, build a patch
    generator for each with the given batch size and corruption function, then
    compile and train the model.

    :param model: a general neural network model for image restoration.
    :param images: a list of file paths pointing to image files. These paths
        are assumed to be complete; nothing should be appended to them.
    :param corruption_func: a corruption function.
    :param batch_size: the size of the batch of examples for each iteration of SGD.
    :param steps_per_epoch: the number of update steps in each epoch.
    :param num_epochs: the number of epochs for which the optimization will run.
    :param num_valid_samples: the number of validation samples to test on after
        every epoch (rounded down to whole batches, at least one batch).
    """
    # split data
    im_train, im_validation = train_test_split(images, train_size=0.8)

    # patches must match the model's fixed input size; int() also accepts
    # legacy Dimension objects
    input_shape = model.inputs[0].shape
    crop_size = (int(input_shape[1]), int(input_shape[2]))

    data_gen_train = load_dataset(im_train, batch_size, corruption_func, crop_size)
    data_gen_validation = load_dataset(im_validation, batch_size, corruption_func, crop_size)

    # validation_steps counts batches, not samples; floor division would give
    # 0 steps when num_valid_samples < batch_size, so keep at least one
    validation_steps = max(1, num_valid_samples // batch_size)

    model.compile(loss="mean_squared_error", optimizer=Adam(beta_2=0.9))
    # Model.fit accepts generators directly (fit_generator is deprecated and
    # removed in Keras 3). Multiprocessing is not requested: these are plain
    # Python generators, which are not safe to share across worker processes.
    model.fit(data_gen_train,
              steps_per_epoch=steps_per_epoch,
              epochs=num_epochs,
              validation_data=data_gen_validation,
              validation_steps=validation_steps)


# 6 Image Restoration of Complete Images

In [None]:
def restore_image(corrupted_image, base_model):
    """
    Restore a full-size image with a network that was trained on small patches.

    A new Keras model with an input sized to the whole image is wrapped around
    base_model. This presumably relies on base_model being fully convolutional.
    The image is passed through that model once.

    :param corrupted_image: grayscale float64 image of shape (height, width),
        values in [0, 1], corrupted the same way as the training data (the
        image itself need not come from the training set).
    :param base_model: network trained on patches whose inputs and outputs lie
        in [-0.5, 0.5].
    :return: the restored image, same shape as corrupted_image, float64,
        clipped to [0, 1].
    """
    height, width = corrupted_image.shape[0], corrupted_image.shape[1]
    full_input = Input(shape=(height, width, 1))
    full_model = Model(inputs=full_input, outputs=base_model(full_input))

    # move into the network's [-0.5, 0.5] range and add batch + channel axes
    batch = (corrupted_image - 0.5).reshape(-1, height, width, 1)
    restored = full_model.predict(batch)[0] + 0.5
    return np.clip(restored, 0, 1).reshape(corrupted_image.shape).astype(np.float64)


    

# 7 Application to Image Denoising and Deblurring
## 7.1 Image Denoising
### 7.1.1 Gaussian Noise

In [None]:
def add_gaussian_noise(image, min_sigma, max_sigma):
    """
    Add zero-mean Gaussian noise with a randomly chosen standard deviation.

    sigma is drawn uniformly from [min_sigma, max_sigma). Independent noise
    with standard deviation sigma is added to every element. The result is
    rounded to the nearest multiple of 1/255 and clipped to [0, 1].

    :param image: grayscale image with values in [0, 1] of type float64; any
        shape is accepted (the noise matches image.shape).
    :param min_sigma: non-negative lower bound on the noise standard deviation.
    :param max_sigma: upper bound on the noise standard deviation,
        >= min_sigma.
    :return: the corrupted image, same shape as image, values in [0, 1].
    """
    sigma = np.random.uniform(min_sigma, max_sigma)
    # sigma is np.random.normal's `scale`, i.e. a standard deviation, not a variance
    noise = np.random.normal(0, sigma, image.shape)
    noised_im = np.around(255 * (image + noise)) / 255
    return np.clip(noised_im, 0, 1)

In [None]:
#@markdown ### 7.1.2 Training a Denoising Model
# Colab form slider: number of residual blocks for the denoising network.
denoise_num_res_blocks = 5 #@param {type:"slider", min:1, max:15, step:1}


In [None]:
def denoising_corruption_func(im):
    """
    Corrupt im with Gaussian noise whose standard deviation lies in [0, 0.2).

    :param im: grayscale image with values in [0, 1].
    :return: the noisy image produced by add_gaussian_noise.
    """
    return add_gaussian_noise(im, min_sigma=0, max_sigma=0.2)

def learn_denoising_model(denoise_num_res_blocks, quick_mode=False):
    """
    Build and train the denoising network on images_for_denoising().

    The network takes 24x24 patches and uses 48 channels.

    :param denoise_num_res_blocks: number of residual blocks.
    :param quick_mode: if True, train with a tiny schedule (batch 10, 3 steps,
        2 epochs, 30 validation samples) instead of the full one (batch 100,
        100 steps, 10 epochs, 1000 validation samples).
    :return: the trained model.
    """
    image_paths = images_for_denoising()
    patch_height, patch_width, channels = 24, 24, 48

    # (batch_size, steps_per_epoch, num_epochs, num_valid_samples)
    schedule = (10, 3, 2, 30) if quick_mode else (100, 100, 10, 1000)
    batch_size, steps_per_epoch, num_epochs, num_valid_samples = schedule

    model = build_nn_model(patch_height, patch_width, channels,
                           denoise_num_res_blocks)
    train_model(model, image_paths, denoising_corruption_func, batch_size,
                steps_per_epoch, num_epochs, num_valid_samples)
    return model

## 7.2 Image Deblurring
### 7.2.1 Motion Blur

In [None]:
def add_motion_blur(image, kernel_size, angle):
    """
    Simulate straight-line motion blur on image.

    :param image: grayscale image with values in [0, 1] of type float64.
    :param kernel_size: odd side length of the square blur kernel.
    :param angle: direction of the blur line in radians, in [0, pi), measured
        from the positive horizontal axis.
    :return: image convolved with motion_blur_kernel(kernel_size, angle).
    """
    return convolve(image, motion_blur_kernel(kernel_size, angle))


In [None]:
def random_motion_blur(image, list_of_kernel_sizes):
    """
    Apply motion blur with a random direction and a randomly chosen kernel size.

    The angle is drawn uniformly from [0, pi) and the kernel size uniformly
    from list_of_kernel_sizes. The blurred result is rounded to the nearest
    multiple of 1/255 and clipped to [0, 1].

    :param image: grayscale image with values in [0, 1] of type float64.
    :param list_of_kernel_sizes: non-empty sequence of odd integers.
    :return: the blurred image.
    :raises ValueError: if list_of_kernel_sizes is empty.
    """
    if len(list_of_kernel_sizes) == 0:
        raise ValueError("list_of_kernel_sizes must not be empty")
    angle = np.random.uniform(0, np.pi)
    kernel_size = list_of_kernel_sizes[np.random.randint(len(list_of_kernel_sizes))]
    corrupted = add_motion_blur(image, kernel_size, angle)
    corrupted = np.around(255 * corrupted) / 255
    return np.clip(corrupted, 0, 1)

In [None]:
#@markdown ### 7.2.2 Training a Deblurring Model

# Colab form slider: number of residual blocks for the deblurring network.
deblur_num_res_blocks = 5 #@param {type:"slider", min:1, max:15, step:1}



In [None]:
def deblurring_corruption_func(im):
    """
    Corrupt im with random-direction motion blur using a fixed kernel size of 7.

    :param im: grayscale image with values in [0, 1].
    :return: the blurred image produced by random_motion_blur.
    """
    kernel_sizes = [7]
    return random_motion_blur(im, kernel_sizes)

def learn_deblurring_model(deblur_num_res_blocks, quick_mode=False):
    """
    Build and train the deblurring network on images_for_deblurring().

    The network takes 16x16 patches and uses 32 channels.

    :param deblur_num_res_blocks: number of residual blocks.
    :param quick_mode: if True, train with a tiny schedule (batch 10, 3 steps,
        2 epochs, 30 validation samples) instead of the full one (batch 100,
        100 steps, 10 epochs, 1000 validation samples).
    :return: the trained model.
    """
    image_paths = images_for_deblurring()
    patch_height = patch_width = 16
    channels = 32

    if quick_mode:
        batch_size, steps_per_epoch, num_epochs, num_valid_samples = 10, 3, 2, 30
    else:
        batch_size, steps_per_epoch, num_epochs, num_valid_samples = 100, 100, 10, 1000

    model = build_nn_model(patch_height, patch_width, channels,
                           deblur_num_res_blocks)
    train_model(model, image_paths, deblurring_corruption_func, batch_size,
                steps_per_epoch, num_epochs, num_valid_samples)
    return model

## 7.3 Image Super-resolution
### 7.3.1 Image Low-Resolution Corruption



In [None]:
def super_resolution_corruption(image):
    """
    Simulate a low-resolution version of image.

    The corruption works in three steps:
    - Draw a zoom factor uniformly from {2, 3, 4}.
    - Shrink the image by that factor with `zoom` (presumably
      scipy.ndimage.zoom).
    - Zoom the small image back to the exact original shape.

    The result is clipped to [0, 1], because spline interpolation can
    overshoot the input range. The other corruption functions likewise return
    values in [0, 1].

    :param image: 2-D grayscale image with values in [0, 1] of type float64.
    :return: corrupted image with the same shape as image, values in [0, 1].
    """
    zoom_factors = [2, 3, 4]
    zoom_factor = zoom_factors[np.random.randint(len(zoom_factors))]

    reduced_im = zoom(image, 1 / zoom_factor)

    # per-axis factors mapping the (rounded) reduced shape back to the input shape
    row_correction = image.shape[0] / reduced_im.shape[0]
    col_correction = image.shape[1] / reduced_im.shape[1]

    restored = zoom(reduced_im, (row_correction, col_correction))
    return np.clip(restored, 0, 1)





In [None]:
#@markdown ### 7.3.2 Training a Super Resolution Model

# Colab form values for the super-resolution experiment.
# NOTE(review): learn_super_resolution_model assigns its own local batch_size,
# steps_per_epoch, num_epochs, num_valid_samples and num_channels, and uses a
# fixed 16x16 patch. Those form values therefore do not reach training through
# that function. Confirm whether the sliders are meant to be wired in.
super_resolution_num_res_blocks = 15 #@param {type:"slider", min:1, max:15, step:1}
batch_size = 65 #@param {type:"slider", min:1, max:128, step:16}
steps_per_epoch = 300 #@param {type:"slider", min:100, max:5000, step:100}
num_epochs = 10 #@param {type:"slider", min:1, max:20, step:1}
patch_size = 16 #@param {type:"slider", min:8, max:32, step:2}
num_channels = 32 #@param {type:"slider", min:16, max:64, step:2}
num_valid_samples=1000



In [None]:
def learn_super_resolution_model(super_resolution_num_res_blocks, quick_mode=False):
    """
    Build and train the super-resolution network on images_for_super_resolution().

    The network takes 16x16 patches and uses 32 channels.

    :param super_resolution_num_res_blocks: number of residual blocks.
    :param quick_mode: if True, train with a tiny schedule (batch 10, 3 steps,
        2 epochs, 30 validation samples) instead of the full one (batch 65,
        300 steps, 10 epochs, 1000 validation samples).
    :return: the trained model.

    NOTE(review): every hyper-parameter here is a local constant. The
    module-level Colab form variables with the same names are not consulted.
    """
    image_paths = images_for_super_resolution()
    patch_height = patch_width = 16
    channels = 32

    if quick_mode:
        schedule = (10, 3, 2, 30)
    else:
        schedule = (65, 300, 10, 1000)
    batch_size, steps_per_epoch, num_epochs, num_valid_samples = schedule

    model = build_nn_model(patch_height, patch_width, channels,
                           super_resolution_num_res_blocks)
    train_model(model, image_paths, super_resolution_corruption, batch_size,
                steps_per_epoch, num_epochs, num_valid_samples)
    return model