<a href="https://colab.research.google.com/github/dbpedia/RDF2text-GAN/blob/master/Transformers/Adverserial_training(alternate%20approach).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### GAN class for adversarial training

In [None]:
#! pip install tf-nightly-gpu

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
import io
import unicodedata
import re
from re import finditer

## Setup input pipeline

In [None]:
# Mount Google Drive inside the Colab runtime so the data files are reachable.
from google.colab import drive

drive.mount('/content/gdrive')
# Path of the training data file handed to create_generator_dataset.
file_path = "/content/gdrive/My Drive/f_data.txt"
# Path of the test data file (by its name); it is not read anywhere in the
# visible part of this notebook.
test_path = "/content/gdrive/My Drive/data/processed_graphs/eng/gat/test_data.txt"

In [None]:
from pretraining import *
from transformer_generator import *
from transformer_discriminator import *

In [None]:

# Build the training dataset and the text tokenizer from the training file.
# create_generator_dataset is provided by one of the star imports above.
# NOTE(review): BATCH_SIZE presumably has to agree with GAN(batch_size=...),
# whose default is also 16, because GAN builds label tensors from it -- confirm.
train_dataset, tokenizer_txt = create_generator_dataset(file_path, BATCH_SIZE=16)

## Loss and metrics

In [None]:
def discriminator_loss(real_output, fake_output):
    """Return the discriminator's binary cross-entropy loss.

    Both arguments are treated as raw logits (``from_logits=True``).

    Args:
        real_output: discriminator scores for real samples; target is 1.
        fake_output: discriminator scores for generated samples; target is 0.

    Returns:
        The sum of the loss on the real scores and the loss on the fake scores.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    return (
        bce(tf.ones_like(real_output), real_output)
        + bce(tf.zeros_like(fake_output), fake_output)
    )

# Primary loss for plain adversarial training
def generator_loss(real_output, fake_output):
    """Return the generator's binary cross-entropy loss.

    The generator is rewarded when the discriminator scores its samples as
    real, so the target for every entry of ``fake_output`` is 1.

    Args:
        real_output: accepted for signature symmetry with
            ``discriminator_loss``; its value is not used.
        fake_output: discriminator logits for generated samples.

    Returns:
        Cross-entropy between an all-ones target and ``fake_output``.
    """
    all_real = tf.ones_like(fake_output)
    return tf.keras.losses.BinaryCrossentropy(from_logits=True)(all_real, fake_output)

## Set hyperparameters

To keep this example small and relatively fast, the values for *num_layers, d_model, and dff* have been reduced. 

The values used in the base Transformer model were: *num_layers=6*, *d_model = 512*, *dff = 2048*. See the [paper](https://arxiv.org/abs/1706.03762) for all the other versions of the transformer.

Note: By changing the values below, you can get the model that achieved state of the art on many tasks.

In [None]:
# Size hyperparameters handed to the Transformer generator constructor below.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8

# Two ids beyond the tokenizer vocabulary are reserved: evaluate_ uses
# vocab_size as the start id and vocab_size + 1 as the end id.
target_vocab_size = tokenizer_txt.vocab_size + 2
# Input and target share one tokenizer, hence the same vocabulary size.
input_vocab_size = target_vocab_size
dropout_rate = 0.1

In [None]:
# NOTE(review): neither object below is used by the visible code --
# gan.compile(...) creates its own Adam optimizers, and learning_rate is never
# passed to an optimizer. Confirm whether they are still needed.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)

# CustomSchedule comes from one of the star imports above.
learning_rate = CustomSchedule(d_model)

In [None]:
# Instantiate the Transformer used as the GAN generator.
# NOTE(review): pe_input / pe_target are set to the vocabulary sizes;
# presumably they are positional-encoding lengths -- confirm against the
# Transformer definition in transformer_generator.
generator = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)


In [None]:
# Sequence length given to the discriminator; it equals the default maxlen of
# GAN.pad (250), which pads the sequences the discriminator is fed.
DATA_MAX_LEN = 250
# NOTE(review): not used by the visible code -- gan.compile(...) builds its
# own discriminator optimizer.
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# The vocabulary size is tokenizer size + 2, the same value as target_vocab_size.
discriminator = TransformerDiscriminator(tokenizer_txt.vocab_size+2, maxlen=DATA_MAX_LEN)



In [None]:
#generator.load_weights('./generator_weights.h5')
#discriminator.load_weights('./discriminator_weights.h5')

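Define the `GAN` model that couples the generator and the discriminator for adversarial training. (Creating the checkpoint path and the checkpoint manager, which would be used to save checkpoints every `n` epochs, is not done in this notebook.)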

In [None]:
class GAN(tf.keras.Model):
    """Keras model that trains a text generator against a discriminator.

    On every step the generator's arg-max predictions are decoded to text, cut
    at the first ``<end>`` marker, re-encoded, appended to the input sequence
    and scored by the discriminator next to the real ``input + target``
    sequences.
    """

    def __init__(self, discriminator, generator, tokenizer, batch_size=16):
        """Store the two networks, the tokenizer and the expected batch size.

        Args:
            discriminator: model called as ``discriminator(sequences)``.
            generator: model called as ``generator(inp, tar_inp, False,
                enc_padding_mask, combined_mask, dec_padding_mask)``.
            tokenizer: object exposing ``encode``, ``decode`` and
                ``vocab_size``.
            batch_size: number of rows used for the generator-step labels.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.tokenizer_txt = tokenizer
        self.batch_size = batch_size

    def compile(self, d_optimizer, g_optimizer, d_loss, g_loss):
        """Attach optimizers and loss callables.

        Both losses are invoked in ``train_step`` as
        ``loss(labels, predictions)``.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.g_loss = g_loss
        self.d_loss = d_loss

    def pad(self, tensor, maxlen=250):
        """Pad the sequences in *tensor* with trailing zeros up to *maxlen*.

        Uses ``pad_sequences``; per its documentation, longer sequences are
        truncated to *maxlen*.
        """
        return tf.keras.preprocessing.sequence.pad_sequences(
            tensor,
            padding='post',
            value=0,
            maxlen=maxlen,
        )

    @tf.function
    def tf_disc_predict(self, input_):
        """Run the discriminator on *input_* through ``tf.py_function``.

        Returns:
            A one-element list holding the float32 discriminator output
            (``Tout`` is given as a list).
        """

        def get_disc_predictions(input_):
            return self.discriminator(input_)

        rv = tf.py_function(get_disc_predictions, inp=[input_], Tout=[tf.float32])
        return rv

    @tf.function(experimental_relax_shapes=True)
    def tf_gen_batch(self, preds, inp, tar):
        """Build the discriminator batch from generated and real sequences.

        Args:
            preds: generated token ids, one sequence per row.
            inp: encoded input sequences.
            tar: encoded reference target sequences.

        Returns:
            ``(all_data, all_labels, gens)``: the generated rows followed by
            the real rows (int32), their labels (0.0 for generated, 1.0 for
            real) and the generated rows alone as float32.
        """

        def gen_batch(preds, inp, tar):

            def decode_text(array, tokenizer):
                # Drop padding (0) and ids outside the tokenizer vocabulary
                # before decoding.
                return tokenizer.decode(
                    [i for i in array if 0 < i < tokenizer.vocab_size]
                )

            disc_data = []
            for sent in preds:
                unparsed = decode_text(sent, self.tokenizer_txt)
                # Keep only the text before the first '<end>' marker and
                # terminate it with exactly one marker.
                retokenized = self.tokenizer_txt.encode(
                    unparsed.split('<end>')[0] + '<end>'
                )
                # Fix: the original appended the undefined name `padded`.
                disc_data.append(retokenized)

            disc_data = self.pad(disc_data)
            gens = self.pad(tf.concat([inp, disc_data], axis=-1, name='concat'))
            real = self.pad(tf.concat([inp, tar], axis=-1, name='concat'))
            all_data = tf.concat([gens, real], axis=0)
            # Size the labels from the arrays themselves so a final batch that
            # is shorter than self.batch_size still gets matching labels.
            all_labels = tf.concat(
                [tf.zeros((gens.shape[0], 1)), tf.ones((real.shape[0], 1))],
                axis=0,
            )

            return all_data, all_labels, gens

        all_data, all_labels, gens = tf.py_function(
            gen_batch,
            inp=[preds, inp, tar],
            Tout=[tf.int32, tf.float32, tf.float32],
        )

        return all_data, all_labels, gens

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: ``(inp, tar)`` batch of encoded input and target sequences.

        Returns:
            Dict with the step's ``"d_loss"`` and ``"g_loss"`` values.
        """
        inp, tar = data
        # Decoder input drops the last token of each target row.
        tar_inp = tar[:, :-1]

        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

        # Fix: use the generator held by this model, not a module-level global.
        predictions, _ = self.generator(inp, tar_inp,
                                        False,
                                        enc_padding_mask,
                                        combined_mask,
                                        dec_padding_mask)

        batch_pred = tf.argmax(predictions, axis=-1)
        all_, labels, gens = self.tf_gen_batch(batch_pred, inp, tar)
        # Add a little random noise to the labels.
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(all_)
            d_loss_ = self.d_loss(labels, predictions)

        grads = tape.gradient(d_loss_, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Train the generator
        # NOTE(review): the generator forward pass above runs outside this
        # tape and reaches the discriminator through argmax and py_function
        # text round-trips, so the gradients with respect to the generator
        # weights are presumably all None here. Confirm, and route a
        # differentiable signal to the generator if so.
        with tf.GradientTape() as tape:
            tape.watch(gens)
            predictions = self.tf_disc_predict(gens)

            # Assemble labels that say "all real sequences"
            misleading_labels = tf.ones((self.batch_size, 1))
            g_loss = self.g_loss(misleading_labels, predictions)

        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        # Fix: the original returned the undefined name `d_loss`.
        return {"d_loss": d_loss_, "g_loss": g_loss}

In [None]:
# NOTE(review): EPOCHS is not used by the visible code; the gan.fit call in
# this notebook passes epochs=1 directly.
EPOCHS = 10

In [None]:
# Wrap the two networks in the GAN model. batch_size keeps its default (16),
# matching the BATCH_SIZE used when the dataset was created.
gan = GAN(discriminator=discriminator, generator=generator, tokenizer=tokenizer_txt)
gan.compile(
    # `tf.keras` is used because only `tf` is imported explicitly in this
    # notebook; a bare `keras` name is not guaranteed to be in scope.
    d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0003),
    # GAN.train_step calls d_loss(labels, predictions). discriminator_loss
    # expects (real_output, fake_output) instead, so passing it here would
    # treat the labels as logits and score every prediction against 0.
    # A plain cross-entropy on logits matches the call convention.
    # NOTE(review): from_logits=True mirrors discriminator_loss; confirm the
    # discriminator really returns logits.
    d_loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    # generator_loss ignores its first argument and scores the predictions
    # against an all-ones target, which is what train_step intends.
    g_loss=generator_loss,
)

# To limit execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.

In [None]:
# Run adversarial training over the full training dataset for a single epoch.
gan.fit(train_dataset, epochs=1)

In [None]:
def render_preds(batch_pred, inp, tar, n=2):
    """Print decoded predictions next to their references and inputs.

    First prints the types and shapes of ``batch_pred`` and ``inp``, then one
    group of lines per row. Rows are printed up to and including index ``n``.

    Args:
        batch_pred: batch of predicted token-id sequences.
        inp: batch of encoded input sequences, indexed like ``batch_pred``.
        tar: batch of encoded reference sequences, indexed like ``batch_pred``.
        n: index of the last row to print.
    """
    print(type(batch_pred), type(inp), batch_pred.shape, inp.shape)
    # Label / batch pairs that are looked up by row index for every prediction.
    companions = (('| True: ', tar), ('| Input RDF: ', inp))
    for row, predicted in enumerate(batch_pred):
        print('\n| Predicted: ', decode_text(predicted, tokenizer_txt))
        for label, batch in companions:
            print(label, decode_text(batch[row], tokenizer_txt))
        print()
        if row == n:
            break

## Evaluate

In [None]:
def evaluate_(inp_sentence):
  """Greedily decode an output sequence for one encoded input sequence.

  Decoding starts from the start id (``tokenizer_txt.vocab_size``) and appends
  the arg-max token at each step until the end id
  (``tokenizer_txt.vocab_size + 1``) is predicted or ``MAX_LENGTH`` steps have
  run. ``MAX_LENGTH`` must be defined at module level before this is called.

  Args:
    inp_sentence: encoded input sequence without a batch axis.

  Returns:
    The decoded id sequence without a batch axis. It begins with the start id;
    the end id is not appended.
  """

  encoder_input = tf.expand_dims(inp_sentence, 0)

  decoder_input = [tokenizer_txt.vocab_size]
  output = tf.expand_dims(decoder_input, 0)

  for i in range(MAX_LENGTH):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)

    # predictions.shape == (batch_size, seq_len, vocab_size)
    # Fix: the model in this notebook is named `generator`; the original
    # called `transformer`, which is not defined here.
    predictions, attention_weights = generator(encoder_input,
                                               output,
                                               False,
                                               enc_padding_mask,
                                               combined_mask,
                                               dec_padding_mask)

    # select the last word from the seq_len dimension
    predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

    # return the result if the predicted_id is equal to the end token
    if predicted_id == tokenizer_txt.vocab_size+1:
      return tf.squeeze(output, axis=0)

    # concatenate the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)

  return tf.squeeze(output, axis=0)

In [None]:
# Upper bound on decoding steps, read as a global by evaluate_.
MAX_LENGTH=250
# Pull the first (input, target) batch from the training dataset.
rdfb, txtb = next(iter(train_dataset))

In [None]:
# Greedily decode the first input sequence of the sampled batch.
predicted_sentence = evaluate_(rdfb[0])

In [None]:
# Turn the predicted ids back into text.
# NOTE(review): a module-level decode_text is not defined in this notebook;
# presumably it comes from one of the star imports -- confirm.
decode_text(predicted_sentence, tokenizer_txt)