<a href="https://colab.research.google.com/github/dbpedia/RDF2text-GAN/blob/master/Transformers/Adversarial_Training_latest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Adversarial training script


In [None]:
#! pip install tf-nightly-gpu

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
import io
import unicodedata
import re
from re import finditer

## Setup input pipeline

In [None]:

# Mount Google Drive inside the Colab runtime so files stored there are
# reachable under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

# Training data file; passed to create_generator_dataset below.
file_path = "/content/gdrive/My Drive/f_data.txt"
# Test split path. NOTE(review): not referenced by any other visible cell.
test_path = "/content/gdrive/My Drive/data/processed_graphs/eng/gat/test_data.txt"

In [None]:
from pretraining import *
from transformer_generator import *
from transformer_discriminator import *

In [None]:
# Batch size and maximum length forwarded to the dataset builder.
batch_size = 16
max_len = 40
# Build the training dataset and the tokenizer from the data file.
# create_generator_dataset is defined outside this notebook (star import);
# presumably MAX_LEN truncates or filters sequences - TODO confirm.
train_dataset, tokenizer_txt = create_generator_dataset(file_path, BATCH_SIZE=batch_size, MAX_LEN=max_len)

## Loss and metrics

In [None]:
def discriminator_loss(real_output, fake_output):
    '''
    Measure how well the discriminator separates real sequences from fakes.

    The discriminator's logits for real sequences are scored against a
    target of all 1s and its logits for generated sequences against a
    target of all 0s; the two binary cross-entropy terms are summed.
    '''
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # Real samples should be classified as 1, generated samples as 0.
    loss_on_real = bce(tf.ones_like(real_output), real_output)
    loss_on_fake = bce(tf.zeros_like(fake_output), fake_output)

    return loss_on_real + loss_on_fake



def generator_loss(fake_output):
    '''
    Measure how well the generator fools the discriminator.

    A successful generator makes the discriminator label generated
    sequences as real (1), so the discriminator's logits on generated
    sequences are scored against a target of all 1s.
    '''
    # Accept anything convertible to a float32 tensor.
    logits = tf.convert_to_tensor(fake_output, dtype=tf.float32)
    all_real = tf.ones_like(logits, dtype=tf.float32)

    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    return bce(all_real, logits)




## Set hyperparameters and training variables

To keep this example small and relatively fast, the values for *num_layers, d_model, and dff* have been reduced. 

The values used in the base model of the transformer were: *num_layers=6*, *d_model = 512*, *dff = 2048*. See the [paper](https://arxiv.org/abs/1706.03762) for all the other versions of the transformer.

Note: By changing the values below, you can get the model that achieved state of the art on many tasks.

In [None]:
# Generator (Transformer) hyperparameters; all are passed to Transformer(...) below.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8

# Two ids beyond the tokenizer vocabulary are reserved: evaluate_ seeds the
# decoder with id vocab_size and stops on id vocab_size + 1 (the end token).
input_vocab_size = target_vocab_size = tokenizer_txt.vocab_size + 2

dropout_rate = 0.1

# Fixed-rate Adam optimizer used for both pretraining and adversarial
# updates of the generator.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)


In [None]:
# Learning-rate schedule built from d_model (CustomSchedule is defined
# outside this notebook).
# NOTE(review): no visible cell passes this schedule to an optimizer -
# generator_optimizer uses a fixed 1e-4. Confirm whether it is still needed.
learning_rate = CustomSchedule(d_model)


# Running metrics: updated by pretrain_step, reset and printed by train().
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')

## Define generator 

In [None]:
# Instantiate the generator network from the hyperparameters defined above.
# The vocabulary sizes are reused as the pe_input / pe_target arguments
# (presumably the positional-encoding lengths - TODO confirm against
# transformer_generator.Transformer).
generator = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)

## Define pre-training functions 

In [None]:
def pretrain_loss_function(real, pred):
  '''
  Padding-masked sparse categorical cross-entropy used for generator
  pretraining.

  Positions where the target id is 0 are treated as padding and excluded;
  the result is the summed per-position loss divided by the number of
  non-padding positions.
  '''
  # reduction='none' keeps one loss value per position so it can be masked.
  per_position = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none')(real, pred)

  # 1 where the target is a real token, 0 where it is padding.
  is_token = tf.math.logical_not(tf.math.equal(real, 0))
  weights = tf.cast(is_token, dtype=per_position.dtype)

  masked = per_position * weights
  return tf.reduce_sum(masked)/tf.reduce_sum(weights)



def pretrain_step(inp, tar):
    '''
    Run one teacher-forced pretraining update of the generator.

    The target batch is split into the decoder input (every token but the
    last) and the expected output (every token but the first). The generator
    is called in training mode, the padding-masked cross-entropy is
    back-propagated through the generator's trainable variables, and the
    running train_loss / train_accuracy metrics are updated.

    Args:
        inp: batch of source token ids fed to the encoder.
        tar: batch of target token ids; sliced along its second axis, so it
            must be at least 2-D.
    '''
    # Decoder sees tokens [0, T-1) and must predict tokens [1, T).
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = generator(inp, tar_inp,
                                   True,
                                   enc_padding_mask,
                                   combined_mask,
                                   dec_padding_mask)
        loss = pretrain_loss_function(tar_real, predictions)

    # The statements below were previously dedented to an indentation level
    # that matched no enclosing block, which made the cell unparsable.
    gradients = tape.gradient(loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients, generator.trainable_variables))
    train_loss(loss)
    train_accuracy(tar_real, predictions)

## Pass one batch through the generator so its variables are built before loading weights

In [None]:

# Run exactly one pretraining step on the first batch, then stop (see the
# `break`). Per the section heading this is done so that weights can be
# loaded into the generator afterwards.
# NOTE: pretrain_step also applies one optimizer update and records the
# batch in train_loss / train_accuracy.
for (inpt, targ) in train_dataset:
  pretrain_step(inpt, targ)
  print('Loss {:.4f} \nAccuracy {:.4f}'.format(
                                   train_loss.result(),
                                   train_accuracy.result()))
  break


## Define discriminator 

In [None]:
# Define the discriminator and its optimizer (weights are loaded in a later cell).
# NOTE(review): DATA_MAX_LEN (135) differs from max_len (40) used to build
# train_dataset - confirm this matches the pretrained discriminator weights.
DATA_MAX_LEN = 135
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# Vocabulary size matches the generator's (tokenizer vocabulary + 2 extra ids).
discriminator = TransformerDiscriminator2(tokenizer_txt.vocab_size+2, maxlen=DATA_MAX_LEN)


## Load in weights from earlier pre-training (Each model trained for 10 epochs)

In [None]:
# Restore pretrained weights; both paths are relative to the current working directory.
# NOTE: './generator_weights.h5' is the same path that generator.save_weights
# writes to after adversarial training, so the pretrained file gets overwritten.
generator.load_weights('./generator_weights.h5')
discriminator.load_weights('./discriminator_weights.h5')

## Define helper functions to render generations

In [None]:
def render_preds(batch_pred, inp, tar, n=2):
    '''
    Print the decoded prediction, target and input for the first n batch
    elements.

    Args:
        batch_pred: batch of predicted token-id sequences; must expose
            `.shape` and be iterable along its first axis.
        inp: batch of input (RDF) token-id sequences, indexable by position.
        tar: batch of target token-id sequences, indexable by position.
        n: maximum number of batch elements to print.
    '''
    print(type(batch_pred), type(inp), batch_pred.shape, inp.shape)
    for ind, pred in enumerate(batch_pred):
        # Stop before printing element n. The previous check ran after the
        # prints and compared with ==, so n + 1 elements were rendered.
        if ind >= n:
            break
        print('\n| Predicted: ', decode_text(pred, tokenizer_txt))
        print('| True: ', decode_text(tar[ind], tokenizer_txt))
        print('| Input RDF: ', decode_text(inp[ind], tokenizer_txt))
        print()


## Define adversarial training step

In [None]:
def train_step(inp, tar):
    '''
    Run one adversarial update of the discriminator and the generator.

    The generator is run with teacher forcing, its most likely token per
    position is taken as the generated sequence, and the discriminator
    scores both the real pair (inp, tar) and the generated pair
    (inp, batch_pred). Both networks are then updated from their own loss,
    and the running metrics reported by train() are refreshed.

    Args:
        inp: batch of source token ids fed to the encoder.
        tar: batch of target token ids; sliced along its second axis.
    '''
    # targets shifted by 1 index position
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    # Get encoding, combined and decoding masks
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    # One tape per network so each loss can be differentiated independently
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

        # Get per-position vocabulary scores from the generator
        predictions, _ = generator(inp, tar_inp,
                                   True,
                                   enc_padding_mask,
                                   combined_mask,
                                   dec_padding_mask)
        # Get predicted sequences for batch
        batch_pred = tf.argmax(predictions, axis=-1)

        # Pad the predicted batch back to the full target length.
        # NOTE(review): argmax is non-differentiable and pad_sequences
        # returns a NumPy array, so g_loss presumably has no gradient path
        # back to the generator's variables - confirm; a policy-gradient or
        # Gumbel-softmax formulation may be needed.
        batch_pred = tf.keras.preprocessing.sequence.pad_sequences(batch_pred, padding='post',
                                                                   value=0, maxlen=tar.shape[-1])
        # Get discriminator's predictions of real & generated output
        disc_preds_real = discriminator([inp, tar], training=True)
        disc_preds_fake = discriminator([inp, batch_pred], training=True)

        # Calculate loss using discriminator and generator loss functions
        d_loss = discriminator_loss(disc_preds_real, disc_preds_fake)
        g_loss = generator_loss(disc_preds_fake)

    # Get discriminator gradients and apply using optimizer
    disc_grads = disc_tape.gradient(d_loss, discriminator.trainable_weights)
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_weights))

    # Get generator gradients and apply using optimizer
    gen_grads = gen_tape.gradient(g_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_weights))

    # train() resets and prints these metrics every epoch; without feeding
    # them here the reported values never move. train_loss tracks the
    # generator's adversarial loss, train_accuracy its token accuracy.
    train_loss(g_loss)
    train_accuracy(tar_real, predictions)

In [None]:
# Number of passes over train_dataset performed by train().
EPOCHS = 10

## Define training function

In [None]:
def train():
  '''
  Run the adversarial training loop for EPOCHS passes over train_dataset.

  Each epoch resets the train_loss / train_accuracy metrics and calls
  train_step on every batch. Metric values are printed every 50 batches
  and once more at the end of the epoch, followed by the elapsed
  wall-clock time. A checkpoint is saved through ckpt_manager on every
  5th epoch.
  '''
  # NOTE(review): ckpt_manager is not defined in any visible cell;
  # presumably it is provided by one of the star imports - confirm.
  for epoch_index in range(EPOCHS):
    epoch_number = epoch_index + 1
    epoch_started = time.time()

    for metric in (train_loss, train_accuracy):
      metric.reset_states()

    for step, (inp, tar) in enumerate(train_dataset):
      train_step(inp, tar)

      # Only report progress on every 50th batch.
      if step % 50 != 0:
        continue
      print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
          epoch_number, step, train_loss.result(), train_accuracy.result()))

    if epoch_number % 5 == 0:
      saved_to = ckpt_manager.save()
      print ('Saving checkpoint for epoch {} at {}'.format(epoch_number,
                                                          saved_to))

    print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(
        epoch_number, train_loss.result(), train_accuracy.result()))

    elapsed = time.time() - epoch_started
    print ('Time taken for 1 epoch: {} secs\n'.format(elapsed))

In [None]:
# Start adversarial training (runs EPOCHS passes over train_dataset).
train()

In [None]:
# Persist the generator's weights.
# NOTE: this is the same path the pretrained weights were loaded from, so
# that file is overwritten.
generator.save_weights('./generator_weights.h5')

## Evaluate

In [None]:
def evaluate_(inp_sentence):
  '''
  Greedily decode an output sequence for one encoded input with the
  generator.

  Decoding starts from the id tokenizer_txt.vocab_size and repeatedly
  appends the highest-scoring next id until the end token
  (tokenizer_txt.vocab_size + 1) is predicted or MAX_LENGTH steps have run.

  Args:
      inp_sentence: a single sequence of input token ids (no batch axis).

  Returns:
      A 1-D int32 tensor of decoded ids, beginning with the start id and
      excluding the end token.
  '''
  # Add a batch axis of size 1.
  encoder_input = tf.expand_dims(inp_sentence, 0)

  decoder_input = [tokenizer_txt.vocab_size]
  output = tf.expand_dims(decoder_input, 0)
  end_token = tokenizer_txt.vocab_size + 1

  for _ in range(MAX_LENGTH):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)

    # predictions.shape == (batch_size, seq_len, vocab_size)
    # The model trained in this notebook is `generator`; the previous code
    # called `transformer`, a name no cell of this notebook defines.
    predictions, _ = generator(encoder_input,
                               output,
                               False,
                               enc_padding_mask,
                               combined_mask,
                               dec_padding_mask)

    # select the last word from the seq_len dimension
    predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

    # return the result if the predicted_id is equal to the end token
    if predicted_id == end_token:
      return tf.squeeze(output, axis=0)

    # concatenate the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)

  return tf.squeeze(output, axis=0)

In [None]:
# Upper bound on the number of decoding steps performed by evaluate_.
MAX_LENGTH=250
# Take the first (input, target) batch of the training dataset as a sample.
rdfb, txtb = next(iter(train_dataset))

In [None]:
# Decode the first input of the sampled batch.
predicted_sentence = evaluate_(rdfb[0])

In [None]:
# Convert the predicted ids back to text (shown as the cell's output).
decode_text(predicted_sentence, tokenizer_txt)