<a href="https://colab.research.google.com/github/dude123studios/GANS/blob/main/ColorizationGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.losses import mean_absolute_error
from tensorflow.keras.losses import mean_squared_error

In [None]:
# Keep only the CIFAR-10 images; class labels are irrelevant for colorization.
# As before, `_` ends up holding the test labels.
train_split, test_split = tf.keras.datasets.cifar10.load_data()
data, _ = train_split
test, _ = test_split
del train_split, test_split
print(data.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
(50000, 32, 32, 3)


In [None]:
def grayConversion(image):
    """Convert RGB image(s) to 8-bit grayscale with luma weights.

    gray = 0.21 * R + 0.72 * G + 0.07 * B, reading the channels from the last
    axis in RGB order (channel 0 = red, 1 = green, 2 = blue). The weighted sum
    is truncated (not rounded) to uint8.

    Args:
        image: array of shape (..., H, W, 3) -- one (H, W, 3) image or a batch
            (N, H, W, 3). Pixel values are assumed to lie in [0, 255].

    Returns:
        uint8 array of shape (..., H, W).
    """
    # Indexing the channel axis with ``...`` instead of ``[:, :, c]`` lets the
    # same code handle a single image or a whole batch. A 3-D image gives
    # exactly the same result as before.
    red, green, blue = image[..., 0], image[..., 1], image[..., 2]
    gray_value = 0.07 * blue + 0.72 * green + 0.21 * red
    return gray_value.astype(np.uint8)
def _to_gray_batch(images):
    """Convert a stack of RGB images into a float32 (N, H, W, 1) grayscale array.

    Each image goes through grayConversion (uint8 output). The result is cast
    to float32 and given a trailing channel axis. N, H and W come from the
    input itself instead of being hard-coded.
    """
    gray_imgs = np.asarray([grayConversion(img) for img in images])
    return gray_imgs.astype('float32')[..., np.newaxis]


gray = _to_gray_batch(data)
print(gray.shape)

gray_test = _to_gray_batch(test)
print(gray_test.shape)

(50000, 32, 32, 1)
(10000, 32, 32, 1)


In [None]:
def downsample(filters, size=3, apply_batchnorm=True):
    """Return a stride-2 Conv2D -> [BatchNorm] -> LeakyReLU block.

    The convolution uses 'same' padding with stride 2, so it halves the spatial
    size (rounding up). It has no bias and a N(0, 0.02) kernel initializer.
    BatchNorm is included only when ``apply_batchnorm`` is true.
    """
    stack = [
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
    ]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stack)

In [None]:
def upsample(filters, size=3, apply_dropout=False):
    """Return a stride-2 Conv2DTranspose -> BatchNorm -> [Dropout(0.5)] -> ReLU block.

    The transposed convolution uses 'same' padding with stride 2, so it doubles
    the spatial size. It has no bias and a N(0, 0.02) kernel initializer.
    Dropout is inserted only when ``apply_dropout`` is true.
    """
    stack = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.5))
    stack.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stack)

In [None]:
class ResnetIdentityBlock(tf.keras.Model):
    """Residual block: 1x1 conv -> kxk conv -> 1x1 conv with an identity shortcut.

    Every convolution is followed by BatchNorm. ReLU is applied between the
    stages, and again after the shortcut addition. The shortcut is a plain
    addition, so the last filter count must equal the input's channel count.
    """

    def __init__(self, kernel_size, filters):
        super().__init__(name='')
        reduce_filters, mid_filters, out_filters = filters

        # Attribute names are kept as-is because they define the tracked
        # variable paths.
        self.conv2a = tf.keras.layers.Conv2D(reduce_filters, (1, 1))
        self.bn2a = tf.keras.layers.BatchNormalization()

        self.conv2b = tf.keras.layers.Conv2D(mid_filters, kernel_size, padding='same')
        self.bn2b = tf.keras.layers.BatchNormalization()

        self.conv2c = tf.keras.layers.Conv2D(out_filters, (1, 1))
        self.bn2c = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        stages = (
            (self.conv2a, self.bn2a),
            (self.conv2b, self.bn2b),
            (self.conv2c, self.bn2c),
        )
        last_stage = len(stages) - 1
        hidden = input_tensor
        for stage_idx, (conv, norm) in enumerate(stages):
            hidden = norm(conv(hidden), training=training)
            # No ReLU after the final BatchNorm: it is applied after the residual sum.
            if stage_idx != last_stage:
                hidden = tf.nn.relu(hidden)
        return tf.nn.relu(hidden + input_tensor)

    
# Three identical 128-channel identity blocks. Generator() applies them, in
# this order, as its bottleneck.
resnet = [ResnetIdentityBlock(3, [128, 128, 128]) for _ in range(3)]
block1, block2, block3 = resnet

In [None]:

def Generator():
    """Build the U-Net-style colorization generator.

    Input is a (32, 32, 1) grayscale image; output is a (32, 32, 3) image passed
    through a sigmoid, so values are in [0, 1].

    Layout:
      * Encoder: three stride-2 conv stages (32/64/128 filters) -> 4x4x128.
      * Bottleneck: the module-level ``resnet`` identity blocks.
      * Decoder: two stride-2 transposed-conv stages (64/32 filters), each
        concatenated with the matching encoder output.
      * Head: a final stride-2 transposed conv to 3 channels.

    NOTE: ``resnet`` is created once at module level, so every model built by
    this function shares those bottleneck weights.
    """
    encoder = [
        downsample(n_filters, 3, apply_batchnorm=(stage_idx > 0))
        for stage_idx, n_filters in enumerate((32, 64, 128))
    ]
    decoder = [upsample(n_filters, 3) for n_filters in (64, 32)]

    head = tf.keras.layers.Conv2DTranspose(
        3,
        3,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='sigmoid',
    )

    inputs = tf.keras.layers.Input(shape=[32, 32, 1])

    features = inputs
    encoder_outputs = []
    for stage in encoder:
        features = stage(features)
        encoder_outputs.append(features)

    for res_block in resnet:
        features = res_block(features)

    # The deepest encoder output feeds the bottleneck, so it is not reused as a
    # skip. The remaining outputs are consumed deepest-first.
    skip_connections = encoder_outputs[-2::-1]
    for stage, skip in zip(decoder, skip_connections):
        merge = tf.keras.layers.Concatenate()
        features = merge([stage(features), skip])

    return tf.keras.Model(inputs=inputs, outputs=head(features))

In [None]:
def Discriminator():
    """Build the discriminator: a fully-convolutional map of real/fake logits.

    It accepts 3-channel images of any spatial size and applies three stride-2
    downsample stages (32/64/128 filters). A final 4x4 Conv2D with one filter
    and no activation then emits one logit per output location.
    """
    base_filters = 32
    inputs = tf.keras.layers.Input(shape=[None, None, 3])

    stages = [downsample(base_filters * multiplier) for multiplier in (1, 2, 4)]

    features = inputs
    for stage in stages:
        features = stage(features)

    logits = tf.keras.layers.Conv2D(1, 4, strides=1, padding='same')(features)
    return tf.keras.Model(inputs=inputs, outputs=logits)

In [None]:
# Instantiate both networks (the generator is built first, as before).
generator, discriminator = Generator(), Discriminator()

In [None]:
# Binary cross-entropy on raw logits: the discriminator's last Conv2D has no
# activation, hence from_logits=True.
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# One Adam instance updates both networks: train_batch calls apply_gradients
# once for the discriminator and once for the generator. Newer Keras
# optimizers only accept the variables they were first built with, so the
# second call (with the other model's variables) would fail. Building up front
# over the union of both variable sets avoids that. Legacy optimizers without
# `build` need no such step.
optimizer = tf.keras.optimizers.Adam(beta_1=0.7)
if hasattr(optimizer, 'build'):
    optimizer.build(generator.trainable_variables + discriminator.trainable_variables)

In [None]:
@tf.function
def grayConversionTensors(imgs):
    """Return the luma of a batch of RGB tensors, keeping a channel axis of size 1.

    Uses the same weights as grayConversion: 0.21 * R + 0.72 * G + 0.07 * B.
    The 4-element begin/size arguments require a rank-4 (batch, H, W, C) input.
    The result has shape (batch, H, W, 1).
    """
    red = tf.slice(imgs, [0, 0, 0, 0], [-1, -1, -1, 1])
    green = tf.slice(imgs, [0, 0, 0, 1], [-1, -1, -1, 1])
    blue = tf.slice(imgs, [0, 0, 0, 2], [-1, -1, -1, 1])
    return tf.math.add(tf.math.add(red * 0.21, green * 0.72), 0.07 * blue)
  

In [None]:
@tf.function
def train_batch(gray_imgs, normal_imgs):
    """Run one adversarial update of the discriminator and the generator.

    Args:
        gray_imgs: batch of single-channel images matching the generator input
            (batch, 32, 32, 1). Expected in 0-255 pixel units, as built by this
            notebook; they are rescaled to [0, 1] here.
        normal_imgs: matching batch of colour images (batch, 32, 32, 3), also
            in 0-255 pixel units (rescaled here).

    Returns:
        (disc_loss, gen_loss) scalar tensors.
        * disc_loss: BCE(real -> 1) + BCE(fake -> 0).
        * gen_loss: BCE(fake -> 1) plus the MSE between the generated image's
          gray conversion and the [0, 1]-scaled gray input.
    """
    # The generator ends in a sigmoid, so its output lies in [0, 1]. Real
    # images and gray targets must use the same range. Otherwise the
    # discriminator can separate real from fake by brightness alone, and the
    # reconstruction term compares [0, 1] against [0, 255].
    gray_imgs = tf.cast(gray_imgs, tf.float32) / 255.0
    normal_imgs = tf.cast(normal_imgs, tf.float32) / 255.0

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_imgs = generator(gray_imgs, training=True)

        logits_real = discriminator(normal_imgs, training=True)
        logits_fake = discriminator(fake_imgs, training=True)

        real_loss = loss(tf.ones_like(logits_real), logits_real)
        fake_loss = loss(tf.zeros_like(logits_fake), logits_fake)
        disc_loss = real_loss + fake_loss

        # Generator: fool the discriminator, while its output must still
        # convert back to the given grayscale image.
        adversarial_loss = loss(tf.ones_like(logits_fake), logits_fake)
        gray_of_fake = grayConversionTensors(fake_imgs)
        reconstruction_loss = tf.math.reduce_mean(tf.square(gray_of_fake - gray_imgs))
        gen_loss = adversarial_loss + reconstruction_loss

    # Gradients are taken outside the tape context, so the gradient computation
    # itself is not recorded on the still-open tapes.
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))

    return disc_loss, gen_loss


In [None]:

def train(gray_imgs, real_imgs, test_gray_imgs, test_real_imgs, epochs=50, batch_size=10):
    """Train the colorization GAN over sequential mini-batches.

    Batches are taken in order, without shuffling. Every full batch is used;
    the trailing ``len(gray_imgs) % batch_size`` samples are skipped. The losses
    of the first batch of each epoch are printed.

    Args:
        gray_imgs: grayscale generator inputs, indexed along axis 0.
        real_imgs: colour targets aligned with ``gray_imgs`` along axis 0.
        test_gray_imgs, test_real_imgs: held-out arrays. They are not used
            during training and are kept in the signature for existing callers.
        epochs: number of passes over the data.
        batch_size: samples per training step.
    """
    num_batches = gray_imgs.shape[0] // batch_size
    for epoch in range(epochs):
        for batch in range(num_batches):
            start = batch * batch_size
            stop = start + batch_size
            disc_loss, gen_loss = train_batch(gray_imgs[start:stop], real_imgs[start:stop])
            if batch == 0:
                print("epoch {}/{}".format(epoch + 1, epochs))
                # .numpy() must be *called* to get the value, not the bound method.
                print("discriminator: ", disc_loss.numpy())
                print("generator: {}\n".format(gen_loss.numpy()))
  

In [None]:
# Train on the grayscale/colour CIFAR-10 arrays prepared above, placing the ops
# on the first GPU ('gpu:0').
with tf.device('gpu:0'):
    train(gray, data, test_gray_imgs=gray_test, test_real_imgs=test)