<a href="https://colab.research.google.com/github/jnickg/steganet/blob/main/steganogan_keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SteganoGAN in Keras
This notebook contains code attempting to reimplement SteganoGAN in Keras, for the purpose of better understanding (and scrutinizing) it.

*Based on https://github.com/DAI-Lab/SteganoGAN/tree/master/steganogan*

We start with the basic `Conv2D` layer, which is used in various parts of the overall model. Batch normalization comes after the activation, as seen [here](https://github.com/DAI-Lab/SteganoGAN/blob/master/steganogan/encoders.py#L117-L121).

## Sub-networks

In [144]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import LeakyReLU

# Sentinel meaning "use a fresh LeakyReLU()". It is distinct from None, which a
# caller may pass explicitly to request a linear (no-activation) convolution.
_DEFAULT_ACTIVATION = object()

def steganogan_conv2d_layer(layer_in, num_filters, kernel_size, name=None, normalize=True, activation_fn=_DEFAULT_ACTIVATION):
    """Apply a 'same'-padded Conv2D with activation, optionally followed by BatchNormalization.

    The order is conv -> activation -> batch norm, matching SteganoGAN.

    Args:
        layer_in: Keras tensor to convolve.
        num_filters: Number of output channels of the convolution.
        kernel_size: Kernel size passed straight to Conv2D.
        name: Optional layer name. The BatchNormalization layer (if any) is
            named f'{name}_normalize'.
        normalize: Whether to append a BatchNormalization layer.
        activation_fn: Activation passed to Conv2D. Defaults to a new
            LeakyReLU() per call; pass None for a linear convolution.

    Returns:
        The resulting Keras tensor.
    """
    if activation_fn is _DEFAULT_ACTIVATION:
        # Build a fresh layer per call rather than sharing one default instance
        # (evaluated once at definition time) across every Conv2D in the model.
        activation_fn = LeakyReLU()
    model = Conv2D(num_filters, kernel_size, padding='same', activation=activation_fn, name=name)(layer_in)
    if normalize:
        normalize_name = f'{name}_normalize' if name is not None else None
        model = BatchNormalization(name=normalize_name)(model)
    return model

### Encoder
We will focus on the Dense variant of SteganoGAN, since the researchers reported getting the best results with it. It takes four hyperparameters: the image dimensions $(W, H, C)$ and the depth $D$ of the data to be encoded.

In [145]:
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Add
from tensorflow.keras.activations import sigmoid

def steganogan_encoder_dense_model(W, H, C, D):
    """
    Build the DenseEncoder graph, which combines a cover image and a data
    tensor into a steganographic image.

    Tensors are channels-last, as Keras expects: (rows, cols, channels), i.e.
    (H, W, C) for images and (H, W, D) for data.

    Args:
        W: Image width (columns).
        H: Image height (rows).
        C: Number of image channels.
        D: Data depth (data channels per pixel).

    Returns:
        ([input_image, input_data], encoder_out), where input_image is
        (N, H, W, C), input_data is (N, H, W, D), and encoder_out is
        (N, H, W, C): the cover image plus a learned residual.
    """
    # Keras Input shapes are (rows, cols, channels) = (H, W, ...).
    input_image = Input(shape=(H, W, C), name=f'image{W}x{H}x{C}')
    input_data  = Input(shape=(H, W, D), name=f'data{W}x{H}x{D}')

    image_preprocess = steganogan_conv2d_layer(input_image, 32, 3, name='image_preprocess')

    # Dense connectivity: each stage sees every earlier feature map plus the data.
    image_data_process_1_in = Concatenate(name='image_data_process_1_in')([image_preprocess, input_data])
    image_data_process_1 = steganogan_conv2d_layer(image_data_process_1_in, 32, 3, name='image_data_process_1')

    image_data_process_2_in = Concatenate(name='image_data_process_2_in')([image_preprocess, image_data_process_1, input_data])
    image_data_process_2 = steganogan_conv2d_layer(image_data_process_2_in, 32, 3, name='image_data_process_2')

    encoder_in = Concatenate(name='encoder_in')([image_preprocess, image_data_process_1, image_data_process_2, input_data])
    # The residual needs C channels so it can be added to the C-channel cover image.
    encoder = steganogan_conv2d_layer(encoder_in, C, 3, name='encoder', normalize=False, activation_fn=sigmoid)

    encoder_out = Add(name='encoder_out')([input_image, encoder])

    return ([input_image, input_data], encoder_out)

### Decoder

Similarly, the decoder takes the same hyperparameters. We simplify usage a bit by flattening each image's decoded data into a single row vector of shape $(1, H \cdot W \cdot D)$.

In [146]:
from tensorflow.keras.layers import Reshape

def steganogan_decoder_dense_model(W, H, C, D, input=None):
    """
    Build the DenseDecoder graph, which takes a steganographic image and
    attempts to recover the embedded data tensor.

    Tensors are channels-last, as Keras expects: (rows, cols, channels), i.e.
    (H, W, C).

    Args:
        W: Image width (columns).
        H: Image height (rows).
        C: Number of image channels.
        D: Data depth (data channels recovered per pixel).
        input: Optional existing Keras tensor to decode. If None, a new Input
            of shape (H, W, C) is created. (The name shadows the builtin but
            is kept for caller compatibility.)

    Returns:
        (input, decoder), where decoder has shape (N, 1, H * W * D): the
        per-pixel sigmoid outputs flattened into one row per image.
    """
    if input is None:
        # Keras Input shapes are (rows, cols, channels) = (H, W, C).
        input = Input(shape=(H, W, C), name=f'cover_image{W}x{H}x{C}')
    decode_1 = steganogan_conv2d_layer(input, 32, 3, name='decode_1')

    decode_2 = steganogan_conv2d_layer(decode_1, 32, 3, name='decode_2')

    # Dense connectivity: each stage sees every earlier feature map.
    decode_3_in = Concatenate(name='decode_3_in')([decode_1, decode_2])
    decode_3 = steganogan_conv2d_layer(decode_3_in, 32, 3, name='decode_3')

    decoder_in = Concatenate(name='decoder_in')([decode_1, decode_2, decode_3])
    decoder = steganogan_conv2d_layer(decoder_in, D, 3, name='decoder', normalize=False, activation_fn=sigmoid)
    # Flatten (H, W, D) into a single (1, H*W*D) row per image.
    decoder = Reshape((1,-1))(decoder)

    return input, decoder

### Critic

The critic also takes the same hyperparameters, and produces a single-element tensor as its output: the average of the final convolutional layer.

In [147]:
from tensorflow.keras.layers import AveragePooling2D

def steganogan_critic_model(W, H, C, D):
    """
    Build the BasicCritic graph, which scores an image as cover or
    steganographic.

    Args:
        W: Image width (columns).
        H: Image height (rows).
        C: Number of image channels.
        D: Data depth. Unused; accepted for signature parity with the other
            sub-network builders.

    Returns:
        (model_in, model), where model_in has shape (N, H, W, C) and model has
        shape (N, 1, 1, 1): the spatial mean of the final single-channel
        sigmoid map.
    """
    # Keras Input shapes are (rows, cols, channels) = (H, W, C).
    model_in = Input(shape=(H, W, C), name=f'image{W}x{H}x{C}')
    model = steganogan_conv2d_layer(model_in, 32, 3, name='conv_1')
    model = steganogan_conv2d_layer(model, 32, 3, name='conv_2')
    model = steganogan_conv2d_layer(model, 32, 3, name='conv_3')
    model = steganogan_conv2d_layer(model, 1, 3, name='conv_4', normalize=False, activation_fn=sigmoid)
    # A pool window covering the whole spatial extent gives one mean per image.
    model = AveragePooling2D(pool_size=(model.shape[1], model.shape[2]), name='mean')(model)
    return model_in, model

## Overall Model

We implement the overall model as a subclass of `keras.Model`, which makes it easier to train. It is initialized with the inputs and outputs for each subnetwork, plus the hyperparameters used to create those. Finally, there is a `noise_func`, which is a data generator function. This allows the model to generate user-specified random data, for encoding/decoding steps.

In [149]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

class KerasSteganoGAN(tf.keras.Model):
  """Adversarial training wrapper around the SteganoGAN encoder, decoder and critic.

  The sub-networks are rebuilt as functional models from the input/output
  tensors given to the constructor. Calling this model runs the critic only.
  """

  def __init__(self, gen_in, gen_out, dec_in, dec_out, critic_in, critic_out, image_height, image_width, image_channels, data_depth, noise_func=None):
    """
    Args:
      gen_in, gen_out: Encoder inputs ([cover image, data]) and output.
      dec_in, dec_out: Decoder input and output tensors.
      critic_in, critic_out: Critic input and output tensors.
      image_height, image_width, image_channels: Cover/stego image shape.
      data_depth: Number of data channels fed to the encoder.
      noise_func: Optional callable (batch_size, data_depth) -> payload batch
        for the encoder. If None, random 0/1 payloads of shape
        (batch_size, image_height, image_width, data_depth) are drawn.
    """
    super(KerasSteganoGAN, self).__init__()

    # Metadata regarding critic input images
    self.img_rows = image_height
    self.img_cols = image_width
    self.channels = image_channels
    self.img_shape = (self.img_rows, self.img_cols, self.channels)

    # Metadata regarding encoder input data
    self.data_depth = data_depth
    self.noise_func = noise_func

    # Build the critic. tf.keras.Model is used directly so this class does not
    # depend on a `Model` import from a later cell.
    self.critic = tf.keras.Model(inputs=critic_in, outputs=critic_out, name='KerasSteganoGAN_critic')

    # Build the generator (encoder) and the decoder
    self.encoder = tf.keras.Model(inputs=gen_in, outputs=gen_out, name='KerasSteganoGAN_encoder')
    self.decoder = tf.keras.Model(inputs=dec_in, outputs=dec_out, name='KerasSteganoGAN_decoder')

  def compile(self,
              d_optimizer=None,
              g_optimizer=None,
              loss=None,
              disc_noise_in=(0.0, 4e-2),
              metrics=None):
    """Configure optimizers, loss and critic input noise.

    Defaults (Adam(learning_rate=2e-4, beta_1=0.5) for both optimizers and
    BinaryCrossentropy() for the loss) are created per call rather than once
    at class definition, so separate models never share optimizer state.
    `metrics` is accepted for API compatibility but is not used.

    Args:
      d_optimizer: Optimizer for the critic.
      g_optimizer: Optimizer for the encoder.
      loss: Loss used for both critic and encoder adversarial terms.
      disc_noise_in: (mean, stddev) of Gaussian noise added to critic inputs,
        or None to disable.
    """
    super(KerasSteganoGAN, self).compile()
    if d_optimizer is None:
      d_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)
    if g_optimizer is None:
      g_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)
    if loss is None:
      loss = BinaryCrossentropy()
    self.loss_fn = loss
    self.d_optimizer = d_optimizer
    self.g_optimizer = g_optimizer
    self.disc_noise_in = disc_noise_in

  def generator_noise(self, batch_size=1):
    """Return a payload batch for the encoder.

    Uses `noise_func(batch_size, data_depth)` when one was supplied.
    Otherwise it draws independent random bits of shape
    (batch_size, img_rows, img_cols, data_depth) as float32.
    """
    if self.noise_func is not None:
      return self.noise_func(batch_size, self.data_depth)
    bits = tf.random.uniform(
        [batch_size, self.img_rows, self.img_cols, self.data_depth],
        maxval=2, dtype=tf.int32)
    return tf.cast(bits, tf.float32)

  def critic_loss(self, real_output, fake_output):
    real_labels = tf.zeros_like(real_output)
    fake_labels = tf.ones_like(fake_output)
    # Calculate loss comparing real outputs ("not fake") to zeros, and fake
    # outputs ("yes fake") to ones.
    real_loss = self.loss_fn(real_labels, real_output)
    fake_loss = self.loss_fn(fake_labels, fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

  def generator_loss(self, fake_output):
    # Calculate loss assuming that they should have been evaluated as real, i.e.
    # zero, or "not fake"
    return self.loss_fn(tf.zeros_like(fake_output), fake_output)

  def call(self, input_tensor, training=False):
    x = input_tensor
    x = self.critic(x, training=training)
    return x

  def train_step(self, real_images):
    """Run one critic update followed by one encoder update.

    Args:
      real_images: Batch of cover images, or an (x, y) tuple whose labels are
        ignored.

    Returns:
      Dict with 'gen_loss' and 'disc_loss'.
    """
    # If we were given an (x, y) tuple, drop the labels because we don't use them
    if isinstance(real_images, tuple):
      real_images = real_images[0]
    cover_images = real_images

    # Payload batches must match the cover batch size
    batch_size = tf.shape(cover_images)[0]

    # Train the critic first, with cover images AND a batch of stego images.
    # Its loss is determined by how well it can tell them apart.
    payload = self.generator_noise(batch_size=batch_size)
    generated_images = self.encoder([cover_images, payload], training=True)

    # If specified, apply noise to prevent the critic from cheating. The noisy
    # copies are used only for the critic update; the encoder step below
    # still receives the clean cover images.
    critic_real_images = cover_images
    if self.disc_noise_in is not None:
      (noise_mean, noise_sd) = self.disc_noise_in
      noise_shape = tf.shape(cover_images)
      critic_real_images = cover_images + tf.random.normal(noise_shape, mean=noise_mean, stddev=noise_sd)
      generated_images = generated_images + tf.random.normal(noise_shape, mean=noise_mean, stddev=noise_sd)

    with tf.GradientTape() as disc_tape:
      real_predictions = self.critic(critic_real_images, training=True)
      fake_predictions = self.critic(generated_images, training=True)

      disc_loss = self.critic_loss(real_predictions, fake_predictions)

    disc_grad = disc_tape.gradient(disc_loss,
                                   self.critic.trainable_weights)
    self.d_optimizer.apply_gradients(zip(disc_grad,
                                         self.critic.trainable_weights))

    # Train the encoder against the updated critic. Its loss is determined by
    # how well its stego images pass as covers.
    # NOTE(review): the decoder is not trained in this step and there is no
    # decoding/reconstruction loss. TODO if full SteganoGAN training is intended.
    gen_payload = self.generator_noise(batch_size=batch_size)
    with tf.GradientTape() as gen_tape:
      new_generated_images = self.encoder([cover_images, gen_payload], training=True)
      predictions_on_gen = self.critic(new_generated_images, training=True)
      gen_loss = self.generator_loss(predictions_on_gen)

    gen_grad = gen_tape.gradient(gen_loss,
                                 self.encoder.trainable_weights)
    self.g_optimizer.apply_gradients(zip(gen_grad,
                                         self.encoder.trainable_weights))

    return {'gen_loss': gen_loss, 'disc_loss': disc_loss}

## Creating the Model

### Hyperparameters
We list out the hyperparameters here, so they are consolidated into a single space.

In [151]:
# Cover/stego image shape: height (rows), width (columns), channels
my_H, my_W, my_C = 128, 128, 3
# Payload depth: data channels per pixel; the payload spans my_H x my_W
my_D = 2

### Data Generator

Let's define a simple function we can use to generate data for the encoder to put into carrier images. We'll start by using a lorem ipsum generator, and later use it with a custom dictionary.

Here's a sample usage of the API:

In [155]:
from loremipsum import generate_sentences

sentences = generate_sentences(1, start_with_lorem=True)
# Show only the text (index 2 of each generated tuple) of the first sentence
first_sentence = next(iter(sentences), None)
if first_sentence is not None:
  print(first_sentence[2])

Lorem ipsum.


And then, we bake it into a function with a simple adapter interface.

In [156]:
def generate_data(quantity, depth):
  """Adapter that will turn generated text into payload batches for the encoder.

  Not implemented yet. It raises instead of silently returning None, which
  would otherwise surface later as an obscure error inside the model.

  Args:
    quantity: Number of payloads to produce (batch size).
    depth: Data depth D of each payload.

  Raises:
    NotImplementedError: Always, until the adapter is written.
  """
  raise NotImplementedError(
      'generate_data is a stub; implement the text-to-payload adapter first')

Finally, we put it all together and create our model.

In [157]:
from tensorflow.keras.models import Model

def build_steganogan(subnet_summary=False):
  """Build the encoder, decoder and critic graphs and wrap them in KerasSteganoGAN.

  Each sub-network is also wrapped in a standalone Model so its summary can be
  printed when `subnet_summary` is true. The combined model is built for
  (None, my_H, my_W, my_C) inputs, and its summary is always printed.
  """
  subnet_factories = (
      ('steganogan_encoder', steganogan_encoder_dense_model),
      ('steganogan_decoder', steganogan_decoder_dense_model),
      ('steganogan_critic', steganogan_critic_model),
  )

  # Flat list: enc_in, enc_out, dec_in, dec_out, critic_in, critic_out
  endpoints = []
  for model_name, make_subnet in subnet_factories:
    subnet_in, subnet_out = make_subnet(my_W, my_H, my_C, my_D)
    subnet = Model(subnet_in, subnet_out, name=model_name)
    if subnet_summary:
      subnet.summary()
    endpoints.extend((subnet_in, subnet_out))

  combined = KerasSteganoGAN(*endpoints, my_H, my_W, my_C, my_D)
  combined.build((None, my_H, my_W, my_C))
  combined.summary()
  return combined

Actually creating it prints out the parameters for each sub-network, and then prints out a final summary of the consolidated model.

In [158]:
# Build every sub-network (printing each summary) plus the combined model
steganoGAN = build_steganogan(True)


Model: "steganogan_encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image128x128x3 (InputLayer)     [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
image_preprocess (Conv2D)       (None, 128, 128, 32) 896         image128x128x3[0][0]             
__________________________________________________________________________________________________
image_preprocess_normalize (Bat (None, 128, 128, 32) 128         image_preprocess[0][0]           
__________________________________________________________________________________________________
data128x128x2 (InputLayer)      [(None, 128, 128, 2) 0                                            
_________________________________________________________________________________

## Helper Functions

This function displays a batch of images arranged into a grid, for instrumenting the training of our model. We'll be able to see a few test images along the way.

In [159]:
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from math import sqrt

def display_batch(images, count, vmin=None, vmax=None):
  """Show the first `count` images of a batch in a near-square grid.

  Only channel 0 of each image is drawn, in grayscale.

  Args:
    images: Indexable batch of rank 4 (N, rows, cols, channels) with N >= count.
    count: Number of images to show. It need not be a perfect square: the grid
      has ceil(sqrt(count)) columns and as many rows as needed. Non-positive
      values draw nothing.
    vmin, vmax: Optional intensity limits forwarded to `imshow`.
  """
  if count <= 0:
    return
  cols = int(sqrt(count))
  if cols * cols < count:
    cols += 1  # round up so a non-square count still fits
  rows = -(-count // cols)  # ceiling division
  fig = plt.figure(figsize=(cols + 2, rows + 2))
  # Subplots are created from this GridSpec so its tight spacing actually applies
  grid = gridspec.GridSpec(rows, cols, figure=fig, wspace=0.025, hspace=0.05)
  for i in range(count):
    ax = fig.add_subplot(grid[i])
    ax.axis('off')
    # plot raw pixel data of the first channel
    ax.imshow(np.array(images[i, :, :, 0]), cmap='gray', vmin=vmin, vmax=vmax)
  plt.show()

The class below was originally written for a standard GAN; it serves as a callback that displays batches of test images (using the function above) throughout training.

It will need some modification to support SteganoGAN, whose architecture does not totally align with a standard GAN (i.e. not just "latent noise" as the generator input; an actual image is needed too).

In [None]:
class GanSampler(keras.callbacks.Callback):
  """Keras callback that periodically displays samples generated from fixed seeds.

  Seeds are drawn once, via `model.generator_noise`, when training begins.
  Samples are displayed:
    - before training;
    - at the end of every epoch whose 0-based index is a multiple of
      `epoch_interval`;
    - after training.

  By default, samples come from `model.generator(seeds)`, as in a standard
  GAN. To drive an encoder that also needs an image (such as
  KerasSteganoGAN's encoder), pass `cover_images`. Samples then come from
  `model.encoder([covers, seeds])`.
  """

  def __init__(self, print_seed=False, sample_square_width=1, epoch_interval=10, cover_images=None):
    super(GanSampler, self).__init__()
    self.print_seed = print_seed
    self.num_samples = sample_square_width * sample_square_width
    self.w = sample_square_width
    self.epoch_interval = epoch_interval
    # Optional cover batch. It is assumed to hold at least num_samples images
    # shaped like the encoder's image input.
    self.cover_images = cover_images

  def _generate_samples(self):
    """Run the generator (or the encoder, when covers are set) on the fixed seeds."""
    if self.cover_images is None:
      return self.model.generator(self.fixed_seed)
    covers = self.cover_images[:self.num_samples]
    return self.model.encoder([covers, self.fixed_seed])

  def display_fixed_seeds(self):
    display_batch(self._generate_samples(), self.num_samples)

  def on_train_begin(self, logs=None):
    self.fixed_seed = self.model.generator_noise(batch_size=self.num_samples)

    if self.print_seed:
      print(f'\nFixed seeds for sampling: {self.fixed_seed}')

    print('\nGenerated images using fixed seeds, before ANY training:')
    self.display_fixed_seeds()

  def on_train_end(self, logs=None):
    print('\nFinal generated images using fixed seeds')
    self.display_fixed_seeds()

  def on_epoch_end(self, epoch, logs=None):
    # A zero interval would raise ZeroDivisionError; treat it as "disabled"
    if self.epoch_interval and epoch % self.epoch_interval == 0:
      print(f'\nGenerated image at end of epoch {epoch+1}, using fixed seeds')
      self.display_fixed_seeds()