This notebook trains a GAN to remove stars from deep-sky images. The code is inspired by the [pix2pix sample on TensorFlow's website](https://www.tensorflow.org/tutorials/generative/pix2pix). The training data consists of only two images: one of the Antennae Galaxies, and a star map that was created from an image of a star cluster.
                                                                                          
                                                                                        

![](https://github.com/code2k13/starreduction/raw/main/images/star_reduction_title.png)

## Attributions

The training images used in this code were sourced from Wikimedia Commons and processed using GIMP.
### Antennae Galaxy Image
This image was downloaded from Wikimedia Commons, and I converted it to grayscale using GIMP for the purpose of training the model.

Link to processed image: https://github.com/code2k13/starreduction/blob/main/training_data/Antennae_galaxies_xl.png
>[NASA, ESA, and the Hubble Heritage Team (STScI/AURA)-ESA/Hubble Collaboration](https://commons.wikimedia.org/wiki/File:Antennae_galaxies_xl.jpg), Public domain, via Wikimedia Commons

Url : [https://commons.wikimedia.org/wiki/File:Antennae_galaxies_xl.jpg](https://commons.wikimedia.org/wiki/File:Antennae_galaxies_xl.jpg)

Direct Link: [https://upload.wikimedia.org/wikipedia/commons/f/f6/Antennae_galaxies_xl.jpg](https://upload.wikimedia.org/wikipedia/commons/f/f6/Antennae_galaxies_xl.jpg)

### Star cluster NGC 3572 and its surroundings
This image was downloaded from Wikimedia Commons, and I created the star mask from it using GIMP.

Link to the processed image: https://github.com/code2k13/starreduction/blob/main/training_data/star_map_base.png

>[ESO/G. Beccari](https://commons.wikimedia.org/wiki/File:The_star_cluster_NGC_3572_and_its_dramatic_surroundings.jpg), [CC BY 4.0](https://creativecommons.org/licenses/by/4.0), via Wikimedia Commons

Url: [https://commons.wikimedia.org/wiki/File:The_star_cluster_NGC_3572_and_its_dramatic_surroundings.jpg](https://commons.wikimedia.org/wiki/File:The_star_cluster_NGC_3572_and_its_dramatic_surroundings.jpg)

Direct Link: [https://upload.wikimedia.org/wikipedia/commons/9/95/The_star_cluster_NGC_3572_and_its_dramatic_surroundings.jpg](https://upload.wikimedia.org/wikipedia/commons/9/95/The_star_cluster_NGC_3572_and_its_dramatic_surroundings.jpg)

In [None]:
import tensorflow as tf
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from PIL import Image, ImageFilter,ImageEnhance
import cv2
import random
from matplotlib import pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
IMG_SIZE = 1024        # side length (pixels) of the square images both networks take
OUTPUT_CHANNELS = 1    # single (grayscale) channel for inputs and outputs
LAMBDA = 10            # weight of the L1 term relative to the adversarial term in generator_loss

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Build an encoder block: Conv2D(stride 2) -> [BatchNormalization] -> LeakyReLU.

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        apply_batchnorm: when True, insert BatchNormalization after the convolution.

    Returns:
        A ``tf.keras.Sequential`` block that halves the spatial resolution
        (``strides=2`` with ``padding='same'``).
    """
    layers = [
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            # No bias term; when BatchNorm is applied its beta supplies the offset.
            use_bias=False,
        ),
    ]
    if apply_batchnorm:
        layers.append(tf.keras.layers.BatchNormalization())
    layers.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(layers)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a decoder block: Conv2DTranspose(stride 2) -> BatchNorm -> [Dropout(0.5)] -> ReLU.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size of the transposed convolution.
        apply_dropout: when True, insert a 50% Dropout layer before the ReLU.

    Returns:
        A ``tf.keras.Sequential`` block that doubles the spatial resolution
        (``strides=2`` with ``padding='same'``).
    """
    stages = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        stages.append(tf.keras.layers.Dropout(0.5))
    stages.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stages)

In [None]:
def Generator():
  """Build the U-Net generator: 8 encoder blocks, 7 decoder blocks with skip
  connections, and a final stride-2 transposed convolution with tanh.

  Input and output both have shape (batch, IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS).
  In the training loop the input is the image with stars and the output is
  compared against the starless target.

  The shape comments below assume IMG_SIZE == 1024 and OUTPUT_CHANNELS == 1.
  IMG_SIZE should be a multiple of 2**8 = 256 (eight stride-2 downsamplings)
  so every upsampled tensor matches the spatial size of its skip connection.

  NOTE(review): get_dataset_batch scales images to [0, 1], while the output
  layer is tanh with range [-1, 1]. The model can still fit [0, 1], but the
  negative half of the range goes unused. Confirm this is intended.
  """
  inputs = tf.keras.layers.Input(shape=[IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 512, 512, 64)
    downsample(128, 4),  # (batch_size, 256, 256, 128)
    downsample(256, 4),  # (batch_size, 128, 128, 256)
    downsample(512, 4),  # (batch_size, 64, 64, 512)
    downsample(512, 4),  # (batch_size, 32, 32, 512)
    downsample(512, 4),  # (batch_size, 16, 16, 512)
    downsample(512, 4),  # (batch_size, 8, 8, 512)
    downsample(512, 4),  # (batch_size, 4, 4, 512)  bottleneck
  ]

  # Shapes shown are after concatenation with the matching skip connection.
  up_stack = [
    upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 16, 16, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 32, 32, 1024)
    upsample(512, 4),  # (batch_size, 64, 64, 1024)
    upsample(256, 4),  # (batch_size, 128, 128, 512)
    upsample(128, 4),  # (batch_size, 256, 256, 256)
    upsample(64, 4),  # (batch_size, 512, 512, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 1024, 1024, 1)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # The bottleneck output has no skip partner; the rest are consumed deepest-first.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

In [None]:
# Shared binary cross-entropy for both GAN losses. The discriminator's last
# Conv2D has no activation, so its outputs are logits (hence from_logits=True).
loss_object = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
    """Compute the generator objective and its components.

    Args:
        disc_generated_output: discriminator logits for the generated images.
        gen_output: generator output.
        target: ground-truth images matching ``gen_output``.

    Returns:
        ``(total, adversarial, l1)``, where ``adversarial`` is the BCE of the
        discriminator logits against an all-ones ("real") label, ``l1`` is the
        mean absolute error between ``target`` and ``gen_output``, and
        ``total = adversarial + LAMBDA * l1``.
    """
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    l1 = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + LAMBDA * l1, adversarial, l1

In [None]:
def Discriminator():
  """Build the conditional discriminator.

  It takes ``[input_image, target_image]``, each of shape
  (batch, IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS), and concatenates them along
  the channel axis. It returns a 2-D grid of raw logits, one per receptive-field
  patch. There is no final activation, which matches ``loss_object``'s
  ``from_logits=True``.

  The shape comments below assume IMG_SIZE == 1024 and OUTPUT_CHANNELS == 1.
  """
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS], name='input_image')
  tar = tf.keras.layers.Input(shape=[IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 1024, 1024, 2)

  down1 = downsample(64, 4, False)(x)  # (batch_size, 512, 512, 64)
  down2 = downsample(128, 4)(down1)  # (batch_size, 256, 256, 128)
  down3 = downsample(256, 4)(down2)  # (batch_size, 128, 128, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 130, 130, 256)
  # Default padding='valid': a 4x4 kernel at stride 1 shrinks each side by 3.
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (batch_size, 127, 127, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 129, 129, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (batch_size, 126, 126, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator loss: BCE(real -> 1) + BCE(generated -> 0).

    Args:
        disc_real_output: discriminator logits for real (input, target) pairs.
        disc_generated_output: discriminator logits for (input, generated) pairs.
    """
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return loss_on_real + loss_on_fake

In [None]:
# One Adam instance per network so each keeps its own optimizer state.
# learning_rate=2e-4 with beta_1=0.5 instead of Keras's default 0.9.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
# Instantiate the U-Net generator and the conditional discriminator.
# Their kernels use unseeded random-normal initializers, so the weights differ between runs.
generator = Generator()
discriminator = Discriminator()

In [None]:
checkpoint_dir = './training_checkpoints'
# File prefix for saved checkpoints inside checkpoint_dir.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# One checkpoint object tracking both networks and both optimizers.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

In [None]:
def train_step(input_image, target, step):
    """Run one optimisation step for the generator and the discriminator.

    Args:
        input_image: batch fed to the generator and used as the discriminator's
            conditioning input (the training loop passes images with stars).
        target: ground-truth batch paired with ``input_image`` (starless images).
        step: integer used for TensorBoard; scalars are logged at ``step // 1000``.

    Writes four loss scalars through the global ``summary_writer``; returns None.
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = generator(input_image, training=True)
        real_logits = discriminator([input_image, target], training=True)
        fake_logits = discriminator([input_image, fake], training=True)
        g_total, g_gan, g_l1 = generator_loss(fake_logits, fake, target)
        d_loss = discriminator_loss(real_logits, fake_logits)

    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables
    # Both gradients are taken before either optimizer updates any weights.
    g_grads = g_tape.gradient(g_total, g_vars)
    d_grads = d_tape.gradient(d_loss, d_vars)
    generator_optimizer.apply_gradients(zip(g_grads, g_vars))
    discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))

    scalars = {
        'gen_total_loss': g_total,
        'gen_gan_loss': g_gan,
        'gen_l1_loss': g_l1,
        'disc_loss': d_loss,
    }
    with summary_writer.as_default():
        for tag, value in scalars.items():
            tf.summary.scalar(tag, value, step=step // 1000)

In [None]:
import albumentations as A

# Augmentation pipeline applied via transform(image=...)["image"] to the
# grayscale uint8 galaxy image in get_dataset_batch.
#
# Only native albumentations transforms are used. The imgaug-backed IAA*
# classes were removed from albumentations; their replacements are GaussNoise,
# PiecewiseAffine, Sharpen and Emboss. A.Flip is not available in 2.x.
# Inside A.OneOf, each child's p is only a relative selection weight; the
# OneOf's own p decides whether any child runs.
transform = A.Compose([
        A.RandomRotate90(),
        # Together with RandomRotate90 and Transpose these reach all 8
        # rotations/reflections of the square.
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.Transpose(),
        A.GaussNoise(p=0.2),
        A.OneOf([
            A.MotionBlur(p=.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        A.ShiftScaleRotate(shift_limit=0.5625, scale_limit=0.8, rotate_limit=145, p=0.2),
        A.OneOf([
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=.1),
            A.PiecewiseAffine(p=0.3),
        ], p=0.2),
        A.OneOf([
            A.CLAHE(clip_limit=2),
            A.Sharpen(),
            A.Emboss(),
            A.RandomBrightnessContrast(),
        ], p=0.3),
        A.HueSaturationValue(p=0.3),
    ])

 

Training images are generated on the fly from the two source images downloaded below.

In [None]:
!wget https://github.com/code2k13/starreduction/raw/main/training_data/Antennae_galaxies_xl.png
!wget https://github.com/code2k13/starreduction/raw/main/training_data/star_map_base.png

In [None]:
def _random_starless(galaxy_la):
    """Return one randomly augmented, star-free training image in 'L' mode.

    ``galaxy_la`` is the galaxy image already resized to IMG_SIZE x IMG_SIZE
    and converted to 'LA'.
    """
    # Brightness factor in [0.15, 3.15): from strongly dimmed to ~3x brighter.
    factor = 0.15 + random.random() * 3
    img = ImageEnhance.Brightness(galaxy_la).enhance(factor)
    # expand=False keeps the canvas at IMG_SIZE x IMG_SIZE.
    img = img.rotate(random.randint(1, 360), expand=False)
    augmented = transform(image=np.asarray(img.convert('L')))["image"]
    return Image.fromarray(augmented)


def _overlay_random_stars(starless, star_map):
    """Paste a random, blurred, dimmed crop of ``star_map`` over ``starless``.

    Returns the composite in 'L' mode; ``starless`` itself is not modified.
    """
    # NOTE(review): 1024 is presumably the star map's side length; with
    # IMG_SIZE == 1024 the crop offset is always (0, 0).
    x = random.randint(0, 1024 - IMG_SIZE)
    y = random.randint(0, 1024 - IMG_SIZE)
    stars = star_map.crop((x, y, x + IMG_SIZE, y + IMG_SIZE))
    # Image.filter returns a new image, so the blurred result must be kept.
    stars = stars.filter(ImageFilter.GaussianBlur(2))
    stars = stars.rotate(random.randint(1, 360), expand=False)
    # Brightness factor in [0, 1): stars are only ever dimmed.
    stars = ImageEnhance.Brightness(stars).enhance(random.random()).convert('RGBA')
    composite = starless.convert('RGBA')
    # Passing an RGBA image as the mask makes paste() use its alpha channel.
    composite.paste(stars, (0, 0), mask=stars)
    return composite.convert('L')


def get_dataset_batch(batch_size = 2):
    """Synthesize ``batch_size`` (starless, with-stars) training pairs.

    Reads ``Antennae_galaxies_xl.png`` and ``star_map_base.png`` from the
    current directory. Offsets, rotations and brightness factors come from the
    ``random`` module, so seed it beforehand for repeatable batches.

    Returns:
        ``(trainA, trainB)``: two lists of ``batch_size`` float32 arrays, each
        of shape ``(1, IMG_SIZE, IMG_SIZE, 1)`` and scaled to [0, 1].
        ``trainA`` holds the starless images; ``trainB`` holds the same images
        with stars pasted on.
    """
    trainA = []
    trainB = []
    with Image.open("Antennae_galaxies_xl.png") as im, Image.open("star_map_base.png") as im2:
        # Loop-invariant: the whole galaxy image is resized (not cropped), once.
        galaxy_la = im.resize((IMG_SIZE, IMG_SIZE)).convert('LA')
        for _ in range(batch_size):
            # Unused crop offset, drawn only so that a given seed keeps producing
            # the same subsequent random values (rotations, brightness factors).
            random.randint(0, 4096 - IMG_SIZE)
            random.randint(0, 4096 - IMG_SIZE)
            starless = _random_starless(galaxy_la)
            with_stars = _overlay_random_stars(starless, im2)
            trainA.append(np.asarray(starless, dtype="float32").reshape(1, IMG_SIZE, IMG_SIZE, 1) / 255)
            trainB.append(np.asarray(with_stars, dtype="float32").reshape(1, IMG_SIZE, IMG_SIZE, 1) / 255)
    return trainA, trainB
     

In [None]:
# Fixed seed so the preview/evaluation batch is identical on every run.
random.seed(10)
test_starless, test_stars = get_dataset_batch(20)

# Preview sample 10: first the with-stars input, then its starless target.
plt.imshow(test_stars[10].reshape((IMG_SIZE, IMG_SIZE, 1)), cmap='gray')
plt.show()
plt.imshow(test_starless[10].reshape((IMG_SIZE, IMG_SIZE, 1)), cmap='gray')
plt.show()

In [None]:
import datetime

log_dir = "logs/"
# A fresh run directory per launch: logs/fit/YYYYmmdd-HHMMSS
summary_writer = tf.summary.create_file_writer(
    f"{log_dir}fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
)

There is no real notion of 'epochs' here: every iteration generates a fresh batch of random training images. We retain some control over repeatability through the variable `n`, which is used as the random seed before the images are created. It advances by `BATCH_SIZE` on every iteration and is reset to 0 whenever it exceeds `total_samples`.

In [None]:
from matplotlib.pyplot import figure
EPOCHS = 5000            # number of training steps; each uses a freshly generated batch
SAVE_MODEL_AFTER = 500   # checkpoint and preview every this many steps
n = 0                    # random seed for the next batch; cycles back to 0 after total_samples
total_samples = 500
BATCH_SIZE = 4
plt.rcParams['figure.figsize'] = (15, 5)

for epoch in range(EPOCHS):
    random.seed(n)
    train_starless, train_stars = get_dataset_batch(BATCH_SIZE)
    # Use the monotonically increasing `epoch` as the summary step. `n` wraps
    # at total_samples (< 1000), so `n // 1000` in train_step would log every
    # scalar at step 0.
    train_step(np.asarray(train_stars).reshape((-1, IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS)),
               np.asarray(train_starless).reshape((-1, IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS)),
               epoch)

    if (epoch + 1) % SAVE_MODEL_AFTER == 0:
        # Save under checkpoint_prefix so checkpoints land in checkpoint_dir.
        ckpt_save_path = checkpoint.save(file_prefix=checkpoint_prefix)
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,ckpt_save_path))
        # Preview two fixed test samples. Columns: starless target, input with
        # stars, generator output.
        for idx in (10, 5):
            sample_prediction = generator(test_stars[idx], training=False)
            _, axarr = plt.subplots(1, 3)
            axarr[0].imshow(test_starless[idx].reshape((IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS)), cmap="gray")
            axarr[1].imshow(test_stars[idx].reshape((IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS)), cmap="gray")
            axarr[2].imshow(sample_prediction.numpy().reshape((IMG_SIZE, IMG_SIZE, OUTPUT_CHANNELS)), cmap="gray")
            plt.show()
    n = n + BATCH_SIZE
    if n > total_samples:
        n = 0