In [1]:
%load_ext autoreload

%autoreload 2
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
from data import get_datasets
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from data import get_datasets
from experiment import create_model
from sentiment import sentiment
import model_adv

In [2]:
# Load the train/test splits together with the dataset info object. The
# 'text' feature's encoder provides the vocabulary size that is passed to
# create_model when building the transformer.
(train_dataset, test_dataset), info = get_datasets()
encoder = info.features['text'].encoder
vocab_size = encoder.vocab_size

In [3]:
# Restore the previously trained transformer (with positional encoding) from
# its checkpoint. Only `transformer` is used later; the optimizer, checkpoint
# and manager are kept under underscore names to mark them as unused here.
transformer, _optimizer, _checkpoint, _manager = create_model(
    models_dir="models/pos_enc_True",
    load_checkpoint=True,
    vocab_size=vocab_size,
    use_positional_encoding=True,
    run_eagerly=True,
)

Loaded previously trained model.


In [4]:
EPOCHS = 100   # default number of optimisation steps for each per-batch adversary
D_MODEL = 128  # default d_model forwarded to model_adv.create_model


def find_adv_k(x, y, transformer, epochs=EPOCHS, d_model=D_MODEL, alpha=0.9999):
    """Train a fresh adversary that perturbs the transformer's ``k`` for one batch.

    Passes ``x`` through ``transformer`` to get the original ``y_logits, w, k``.
    It then builds a brand-new adversary with ``model_adv.create_model``
    (no checkpoint, on every call). For ``epochs`` steps it maps
    ``k -> adv_k``, runs ``transformer(x, custom_k=adv_k)`` and minimises

        loss_t = binary_crossentropy(y, adv_y_logits, from_logits=True)
        loss_k = -mean_over_batch(sum((k - adv_k) ** 2))  # per-example squared L2
        loss   = alpha * loss_t + (1 - alpha) * loss_k

    In other words, it keeps the transformer loss low while pushing ``adv_k``
    away from ``k``.

    Args:
        x: batch of text, passed straight to ``transformer``.
        y: batch of labels, used as ``y_true`` for binary cross-entropy.
        transformer: model called as ``transformer(x, training=..., custom_k=...)``
            that returns ``(logits, attention weights, k)``.
        epochs: number of adversary optimisation steps (default ``EPOCHS``).
        d_model: ``d_model`` forwarded to ``model_adv.create_model``
            (default ``D_MODEL``).
        alpha: weight of ``loss_t``; ``1 - alpha`` weights ``loss_k``.

    Returns:
        Tuple ``(k, w, adv_k, adv_w, loss_k, loss_t, loss)``:
            k - original k
            w - original attention weights
            adv_k - adversarial k from the last step
            adv_w - attention weights obtained with adv_k in the last step
            loss_k - negative squared distance term from the last step
                     (minimising it maximises the distance between k and adv_k)
            loss_t - transformer loss from the last step (the one kept low)
            loss - combined adversarial loss from the last step
        If ``epochs < 1`` no step runs. In that case adv_k, adv_w, loss_k and
        loss are None, and loss_t is the loss obtained with ``custom_k=k``.

    NOTE(review): loss_k is unbounded below, so even with alpha=0.9999 it
    soon dominates the total loss. In the recorded run it went from about
    -1.6e5 to -5.5e5 in 9 steps. Consider normalising or bounding it.
    """
    y_logits, w, k = transformer(x, training=False)
    print(f'Returned k: {tf.shape(k)}')
    # Sanity check: feeding k back in via custom_k should reproduce the plain pass.
    y_logits2, _, _ = transformer(x, training=False, custom_k=k)

    adv, optimizer = model_adv.create_model(models_dir="adv", load_checkpoint=False,
                                            run_eagerly=True, d_model=d_model)

    loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits, from_logits=True)
    print(f'[{0}] Loss_t = {loss_t:<20.10}')
    loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits2, from_logits=True)
    print(f'[{0}] Loss_t (custom k) = {loss_t:<20.10}')

    # k is a fixed target (no gradient flows into it), so compute the batch
    # size and its flattened form once instead of on every step.
    bs = tf.shape(k)[0]
    k_flat = tf.reshape(k, [bs, -1])

    # Pre-initialise so the return statement is valid even when epochs < 1.
    adv_k = adv_w = loss_k = loss = None
    for epoch in range(1, epochs + 1):
        with tf.GradientTape() as tape:
            adv_k = adv(k, training=True)

            adv_y_logits, adv_w, adv_k2 = transformer(x, custom_k=adv_k, training=False)
            # The transformer must echo back exactly the k it was given.
            assert tf.reduce_all(tf.equal(adv_k, adv_k2)).numpy()
            loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)
            # Per-example squared L2 distance between the original and adversarial k.
            sq_dist = tf.reduce_sum(tf.square(k_flat - tf.reshape(adv_k, [bs, -1])), axis=1)
            loss_k = -tf.reduce_mean(sq_dist)
            loss = alpha * loss_t + (1.0 - alpha) * loss_k

        print(f'[{epoch:<2}] Loss_t = {loss_t:<20.10}  Loss_k = {loss_k:<20.10}   Loss = {loss:<20.10}')

        # Only the adversary is updated; the transformer stays frozen.
        grads = tape.gradient(loss, adv.trainable_weights)
        optimizer.apply_gradients(zip(grads, adv.trainable_weights))

    return k, w, adv_k, adv_w, loss_k, loss_t, loss
  

In [5]:
# Each find_adv_k call builds and trains a brand-new adversary for EPOCHS
# steps, so sweeping the whole training set is very slow; the recorded run
# was interrupted inside the first batch. Limit the sweep here and set
# N_ADV_BATCHES = -1 to process every batch.
N_ADV_BATCHES = 1
for x, y in train_dataset.take(N_ADV_BATCHES):
    find_adv_k(x, y, transformer)

Returned k: [ 64   4 197 128]
[0] Loss_t = 0.2537696958        
[0] Loss_t (custom k) = 0.2537696958        
[1 ] Loss_t = 0.7581431866          Loss_k = -159044.7031           Loss = -15.14640236        
[2 ] Loss_t = 0.550424099           Loss_k = -184534.5312           Loss = -17.9030838         
[3 ] Loss_t = 0.3905030489          Loss_k = -214514.8125           Loss = -21.06101608        
[4 ] Loss_t = 0.2757305205          Loss_k = -249708.4375           Loss = -24.69513893        
[5 ] Loss_t = 0.2109524012          Loss_k = -291155.8125           Loss = -28.90464783        
[6 ] Loss_t = 0.1844428182          Loss_k = -340043.625            Loss = -33.81993484        
[7 ] Loss_t = 0.1791913807          Loss_k = -397746.5625           Loss = -39.59548187        
[8 ] Loss_t = 0.1847180128          Loss_k = -465797.7188           Loss = -46.39506912        
[9 ] Loss_t = 0.1941983253          Loss_k = -546062.5625           Loss = -54.41207504        
[10] Loss_t = 0.203651756  

KeyboardInterrupt: 