# Jet CGAN v1
## Nick Elsey
This notebook is exploratory work on using conditional generative adversarial networks (cGANs) to model the detector response for jets at STAR.
We use data generated with PYTHIA and the various detector-response models found in jetgan.data.generators.

Generative adversarial networks: https://arxiv.org/abs/1701.00160

Conditional adversarial networks: https://arxiv.org/abs/1611.07004 

U-net architecture: https://arxiv.org/abs/1505.04597


In [1]:
from __future__ import absolute_import, division, print_function

import tensorflow as tf

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from jetgan.data import image

## Load the dataset
We use data that has been pre-generated. If it has not been generated already, run `python jetgan/data/generate_data.py`
from the project home directory. The `--help` flag lists all relevant options, including the available types of
detector response. The generate_data.py script writes CSV files containing the jets and the jet images, by default
to data/raw.

In [2]:
# Pre-generated inputs from jetgan/data/generate_data.py. Paths are relative to
# the notebook's working directory. "det_*" files feed the detector_* arrays and
# "gen_*" files feed the generator_* arrays loaded below.
INPUT_DATA_JET = '../data/raw/det_jet.txt'
INPUT_DATA_IMAGE = '../data/raw/det_image.txt'
INPUT_GEN_JET = '../data/raw/gen_jet.txt'
INPUT_GEN_IMAGE = '../data/raw/gen_image.txt'
INPUT_DELIM = ','  # field separator in the files above

# Jet-image geometry shared by the loader and the networks.
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 32
IMAGE_CHANNELS = 2   # channels of the conditioning (input) image
OUTPUT_CHANNELS = 2  # channels of the generated / target image

In [3]:
# Load the detector-level and generator-level jets and jet images with the
# project's loader. The `with` blocks close every file handle once parsing is
# finished; previously the four handles were opened and never closed.
# NOTE(review): assumes image.load_jet_images reads its files eagerly and does
# not keep a lazy reference to the handles after returning. Confirm this.
with open(INPUT_DATA_JET, 'r') as det_jet_file, \
        open(INPUT_DATA_IMAGE, 'r') as det_image_file:
    detector_jets, detector_images = image.load_jet_images(
        jet_file=det_jet_file, image_file=det_image_file,
        image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT,
        image_channels=IMAGE_CHANNELS, delimiter=INPUT_DELIM)

with open(INPUT_GEN_JET, 'r') as gen_jet_file, \
        open(INPUT_GEN_IMAGE, 'r') as gen_image_file:
    generator_jets, generator_images = image.load_jet_images(
        jet_file=gen_jet_file, image_file=gen_image_file,
        image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT,
        image_channels=IMAGE_CHANNELS, delimiter=INPUT_DELIM)

## Create Generator
* Model is a modified U-net architecture
* skip connections between encoder & decoder layers with similar size
* each block of the downsampling consists of conv -> batchnorm -> leaky relu
* each block of the upsampling consists of transposed conv -> batchnorm -> dropout -> relu


In [4]:
def downsample(filters, size, apply_batchnorm=True):
    """Build one encoder block: stride-2 Conv2D [-> BatchNormalization] -> LeakyReLU.

    With ``padding='same'`` and ``strides=2`` the block halves height and
    width (rounding up).

    Args:
      filters: number of output channels of the convolution.
      size: convolution kernel size.
      apply_batchnorm: if True, insert BatchNormalization after the convolution.

    Returns:
      A ``tf.keras.Sequential`` implementing the block.
    """
    weights_init = tf.random_normal_initializer(0., 0.02)
    block = tf.keras.Sequential()

    block.add(tf.keras.layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=weights_init,
        use_bias=False,
    ))
    if apply_batchnorm:
        block.add(tf.keras.layers.BatchNormalization())
    block.add(tf.keras.layers.LeakyReLU())

    return block

In [5]:
def upsample(filters, size, apply_batchnorm=True, apply_dropout=False):
    """Build one decoder block: stride-2 Conv2DTranspose [-> BatchNorm] [-> Dropout(0.5)] -> ReLU.

    With ``padding='same'`` and ``strides=2`` the block doubles height and width.

    Args:
      filters: number of output channels of the transposed convolution.
      size: transposed-convolution kernel size.
      apply_batchnorm: if True, add BatchNormalization after the convolution.
      apply_dropout: if True, add Dropout with rate 0.5 after batch norm.

    Returns:
      A ``tf.keras.Sequential`` implementing the block.
    """
    stack = [
        tf.keras.layers.Conv2DTranspose(
            filters, size,
            strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
        ),
    ]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.5))
    stack.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(stack)

In [6]:
def Generator(input_channels=None, output_channels=None):
    """Build the U-net generator.

    The encoder is five stride-2 ``downsample`` blocks. The decoder is four
    stride-2 ``upsample`` blocks, each followed by concatenation with the
    encoder output of the same spatial size (skip connection). A final
    stride-2 transposed convolution with tanh activation produces the output.

    The Input leaves height and width unspecified. The skip concatenations
    are only guaranteed to line up when both are multiples of 32 (2**5). The
    shape comments below assume 32 x 32 inputs.

    Args:
      input_channels: channels of the conditioning image. Defaults to the
        module constant IMAGE_CHANNELS, looked up when the model is built.
      output_channels: channels of the generated image. Defaults to the
        module constant OUTPUT_CHANNELS, looked up when the model is built.

    Returns:
      A ``tf.keras.Model`` mapping (bs, H, W, input_channels) to
      (bs, H, W, output_channels).
    """
    # Previously the Input was hard-coded to 2 channels, so changing
    # IMAGE_CHANNELS silently produced a model that rejected the loaded images.
    if input_channels is None:
        input_channels = IMAGE_CHANNELS
    if output_channels is None:
        output_channels = OUTPUT_CHANNELS

    down_stack = [
        downsample(32, 4, apply_batchnorm=False),  # (bs, 16, 16, 32)
        downsample(64, 4),   # (bs, 8, 8, 64)
        downsample(64, 4),   # (bs, 4, 4, 64)
        downsample(128, 4),  # (bs, 2, 2, 128)
        downsample(128, 4),  # (bs, 1, 1, 128)  bottleneck
    ]

    # Shapes are *after* concatenation with the matching skip connection.
    up_stack = [
        upsample(128, 4, apply_dropout=True),  # (bs, 2, 2, 128 + 128 = 256)
        upsample(128, 4, apply_dropout=True),  # (bs, 4, 4, 128 + 64 = 192)
        upsample(64, 4),                       # (bs, 8, 8, 64 + 64 = 128)
        upsample(64, 4),                       # (bs, 16, 16, 64 + 32 = 96)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    # tanh bounds the generated pixels to (-1, 1).
    # NOTE(review): confirm the target jet images are scaled into that range.
    last = tf.keras.layers.Conv2DTranspose(output_channels, 4,
                                           strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')  # (bs, 32, 32, output_channels)

    # Concatenate has no weights, so one instance is reused for every skip.
    concat = tf.keras.layers.Concatenate()

    inputs = tf.keras.layers.Input(shape=[None, None, input_channels])
    x = inputs

    # Downsampling through the model, remembering each block's output.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck output is not used as a skip. Pair the remaining encoder
    # outputs deepest-first with the decoder blocks.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

## Discriminator
* The discriminator is a PatchGAN: it outputs a spatial grid of logits, one per image patch, rather than a single score
* each layer is similar to the downsampler in the generator
  - conv2d -> batchnorm -> leakyrelu
* the discriminator receives two inputs: the conditioning input image, and either a generated image or the real target image from the training sample
* for each patch, the model then predicts whether the second image is real or generated

In [7]:
def Discriminator(input_channels=None, target_channels=None):
    """Build the PatchGAN discriminator.

    The discriminator concatenates the conditioning image with a candidate
    image (either the real target or the generator's output) along the
    channel axis. It returns a spatial grid of raw logits, with no final
    activation.

    Shape comments assume 32 x 32 inputs. For that size the output grid is
    only 2 x 2.
    NOTE(review): with so few patches the "patch" decision is close to a
    global one. Confirm this is intended.

    Args:
      input_channels: channels of the conditioning image. Defaults to the
        module constant IMAGE_CHANNELS, looked up when the model is built.
      target_channels: channels of the real/generated image. Defaults to
        the module constant OUTPUT_CHANNELS, which matches the generator
        output.

    Returns:
      A ``tf.keras.Model`` taking ``[input_image, target_image]`` and
      returning logits of shape (bs, h, w, 1).
    """
    # Previously both inputs were hard-coded to 2 channels, independent of the
    # IMAGE_CHANNELS / OUTPUT_CHANNELS settings.
    if input_channels is None:
        input_channels = IMAGE_CHANNELS
    if target_channels is None:
        target_channels = OUTPUT_CHANNELS

    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[None, None, input_channels], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, target_channels], name='target_image')

    x = tf.keras.layers.concatenate([inp, tar])  # (bs, 32, 32, input_channels + target_channels)

    down1 = downsample(32, 4, False)(x)  # (bs, 16, 16, 32)
    down2 = downsample(64, 4)(down1)     # (bs, 8, 8, 64)
    down3 = downsample(128, 4)(down2)    # (bs, 4, 4, 128)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 6, 6, 128)
    # Default 'valid' padding: 6 - 4 + 1 = 3.
    conv = tf.keras.layers.Conv2D(256, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)  # (bs, 3, 3, 256)

    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 5, 5, 256)

    # No activation: the losses use BinaryCrossentropy(from_logits=True).
    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(zero_pad2)  # (bs, 2, 2, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)

## Loss Function
* Generator Loss
    * sigmoid cross entropy loss of the generated images and an array of ones
    * The paper also includes L1 loss which is the mean absolute error between the generated image and the target image.
    * This allows the generated image to become structurally similar to the target image.
    * The total generator loss is $\mathcal{L}_G = \mathcal{L}_{GAN} + \lambda\,\mathcal{L}_{L1}$, i.e. `gan_loss + LAMBDA * l1_loss`, where LAMBDA is a tunable weight (set to 100 below)
* Discriminator Loss
    * two inputs: real images, generated images
    * loss of real images is sigmoid crossentropy between the prediction of the real images and an array of ones
    * loss of generated images is sigmoid crossentropy between prediction of the generated images and an array of zeroes
    * total discriminator loss is the equally weighted sum of the two

In [8]:
LAMBDA = 100  # weight of the L1 reconstruction term in the total generator loss

In [9]:
# Shared binary cross-entropy. The discriminator's final Conv2D has no
# activation, so its outputs are raw logits, hence from_logits=True.
loss_object = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

In [10]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator loss: BCE on real pairs (label 1) plus BCE on generated pairs (label 0).

    Both terms are weighted equally.

    Args:
      disc_real_output: discriminator logits for (input, real target) pairs.
      disc_generated_output: discriminator logits for (input, generated) pairs.
    """
    real_labels = tf.ones_like(disc_real_output)
    fake_labels = tf.zeros_like(disc_generated_output)
    return (loss_object(real_labels, disc_real_output)
            + loss_object(fake_labels, disc_generated_output))

In [11]:
def generator_loss(disc_generated_output, gen_output, target):
    """Return the total generator loss: adversarial BCE + LAMBDA * mean absolute error.

    Args:
      disc_generated_output: discriminator logits for (input, generated) pairs.
        They are scored against an all-ones label, i.e. "the generator fooled
        the discriminator".
      gen_output: images produced by the generator.
      target: real target images that gen_output should reproduce.
    """
    fooled_labels = tf.ones_like(disc_generated_output)
    adversarial = loss_object(fooled_labels, disc_generated_output)
    reconstruction = tf.reduce_mean(tf.abs(target - gen_output))  # L1 / MAE term
    return adversarial + LAMBDA * reconstruction

## Optimizers
* Generator Optimizer
    * Adam (learning rate 2e-4, beta_1 = 0.5)
* Discriminator Optimizer
    * Adam (learning rate 2e-4, beta_1 = 0.5)

In [12]:
# One independent Adam instance per network, so their optimizer state
# (moment estimates) is kept separate. Both use lr = 2e-4 and beta_1 = 0.5.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5),
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5),
)

KeyboardInterrupt: 

In [None]:
# Instantiate the two networks. Nothing above this cell binds `generator` or
# `discriminator`, so without these lines the Checkpoint call raises NameError.
generator = Generator()
discriminator = Discriminator()

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both models and their optimizers so they are saved and restored together.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Training