<a href="https://colab.research.google.com/github/mushir2004/Generation/blob/main/deblur_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Activation, BatchNormalization, Add, Dropout, LeakyReLU, Conv2DTranspose, Lambda
from tensorflow.keras.models import Model

# Custom Keras layer that pads the height and width axes by mirroring edge pixels.
class ReflectionPadding2D(tf.keras.layers.Layer):
    """Reflection-pad the spatial axes of a channels-last 4-D tensor.

    Padding is applied symmetrically to axis 1 (height) and axis 2 (width) of a
    ``(batch, height, width, channels)`` tensor; the batch and channel axes are
    left untouched.

    :param padding: int or ``(pad_height, pad_width)`` -- rows / columns added
        on *each* side, in the same order as ``tf.keras.layers.ZeroPadding2D``.
        An int pads both axes by the same amount. ``tf.pad`` in ``'REFLECT'``
        mode requires each pad to be smaller than the size of the padded axis.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base Layer before attaching attributes to it.
        super().__init__(**kwargs)
        if isinstance(padding, int):
            padding = (padding, padding)
        self.padding = tuple(padding)

    def compute_output_shape(self, input_shape):
        """Return the padded shape; unknown (``None``) sizes stay unknown."""
        batch, height, width, channels = tf.TensorShape(input_shape).as_list()
        pad_height, pad_width = self.padding
        return (batch,
                None if height is None else height + 2 * pad_height,
                None if width is None else width + 2 * pad_width,
                channels)

    def call(self, x):
        # Height padding goes on axis 1 and width padding on axis 2, so the
        # result agrees with compute_output_shape for non-square padding too.
        pad_height, pad_width = self.padding
        return tf.pad(x,
                      [[0, 0], [pad_height, pad_height], [pad_width, pad_width], [0, 0]],
                      'REFLECT')

    def get_config(self):
        """Include ``padding`` so a saved model can re-create this layer."""
        config = super().get_config()
        config.update({'padding': self.padding})
        return config

# Residual block: two padded convolutions whose output is added back to the input.
def res_block(input_tensor, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=False):
    """
    Instantiate a Keras ResNet Block using the Functional API.

    Two stages of (reflection pad -> conv -> batch norm), with a ReLU (and
    optional dropout) between them; the result is added to ``input_tensor``.

    :param input_tensor: Input tensor; its channel count must equal ``filters``
        for the residual ``Add`` to be shape-compatible.
    :param filters: Number of filters to use
    :param kernel_size: Shape of the kernel for the convolution (int or pair of
        odd ints). The reflection padding is derived from it so that each
        convolution preserves the spatial size.
    :param strides: Shape of the strides for the convolution; must be 1 on both
        axes, otherwise the block output cannot be added to ``input_tensor``.
    :param use_dropout: Boolean value to determine the use of dropout (rate 0.5)
    :return: Keras tensor
    :raises ValueError: if a kernel dimension is even or a stride is not 1.
    """
    kernel_h, kernel_w = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    stride_h, stride_w = (strides, strides) if isinstance(strides, int) else tuple(strides)
    if kernel_h % 2 == 0 or kernel_w % 2 == 0:
        raise ValueError(f"res_block needs odd kernel sizes, got {kernel_size!r}")
    if (stride_h, stride_w) != (1, 1):
        raise ValueError(f"res_block needs unit strides for the residual Add, got {strides!r}")

    # A stride-1 'valid' convolution with a k-wide kernel shrinks an axis by
    # k - 1, so (k - 1) // 2 pixels per side restores the size for odd k.
    padding = ((kernel_h - 1) // 2, (kernel_w - 1) // 2)

    x = ReflectionPadding2D(padding)(input_tensor)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D(padding)(x)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = BatchNormalization()(x)

    # Identity shortcut: no activation is applied after the addition.
    merged = Add()([input_tensor, x])
    return merged

# Generator hyper-parameters, read as module globals when the generator is built.
input_nc = 3                                        # channels of the input image
output_nc = 3                                       # channels produced by the generator
input_shape_generator = (256,) * 2 + (input_nc,)    # (height, width, channels)
ngf = 64                                            # filters in the first conv layer
n_blocks_gen = 9                                    # residual blocks in the bottleneck

def generator_model():
    """Build the DeblurGAN generator from the module-level hyper-parameters.

    Architecture:
      * 7x7 conv stem on a reflection-padded input (spatial size preserved),
      * 2 stride-2 convs, each halving height/width and doubling the filters
        (ngf -> 4 * ngf),
      * ``n_blocks_gen`` residual blocks at 4 * ngf filters, with dropout,
      * 2 stride-2 transposed convs, each doubling height/width and halving the
        filters back to ngf,
      * 7x7 conv to ``output_nc`` channels followed by tanh,
      * a global skip connection adding the input, then scaling by 0.5.

    Height and width of ``input_shape_generator`` must be divisible by 4 so the
    upsampled tensor matches the input again at the skip connection.

    :return: an uncompiled ``tf.keras.Model`` named ``'Generator'``.
    :raises ValueError: if ``input_nc != output_nc``; the skip connection adds
        the input to the output, so their channel counts must match.
    """
    if input_nc != output_nc:
        raise ValueError(
            f"global skip connection needs input_nc == output_nc, got {input_nc} and {output_nc}")

    inputs = Input(shape=input_shape_generator)

    # Stem: 3 pixels of reflection padding offset the 7x7 'valid' convolution.
    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Downsample: halve height/width and double the filter count each step.
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2 ** i
        x = Conv2D(filters=ngf * mult * 2, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Bottleneck: n_blocks_gen residual blocks at ngf * 2**n_downsampling filters.
    mult = 2 ** n_downsampling
    for _ in range(n_blocks_gen):
        x = res_block(x, ngf * mult, use_dropout=True)

    # Upsample: double height/width and halve the filter count each step
    # (back to ngf). The projection to output_nc channels happens below.
    for i in range(n_downsampling):
        mult = 2 ** (n_downsampling - i)
        x = Conv2DTranspose(filters=ngf * mult // 2, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Global skip connection, then halve: with the tanh output in [-1, 1] and
    # the input assumed scaled to [-1, 1] (data pipeline not visible here),
    # the result stays in [-1, 1]. A built-in Rescaling layer is used instead
    # of a Lambda so the model saves/loads from its config without serialized
    # Python bytecode.
    outputs = Add()([x, inputs])
    outputs = tf.keras.layers.Rescaling(0.5)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model

# Create the generator model once and bind it to the module-level name `generator`.
generator = generator_model()

# Print the model summary: one row per layer with its output shape,
# parameter count and inbound connections.
generator.summary()


Model: "Generator"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 reflection_padding2d_10 (R  (None, 262, 262, 3)          0         ['input_2[0][0]']             
 eflectionPadding2D)                                                                              
                                                                                                  
 conv2d_12 (Conv2D)          (None, 256, 256, 64)         9472      ['reflection_padding2d_10[0][0
                                                                    ]']                           
                                                                                          