In [1]:
%load_ext autoreload

%autoreload 2
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
from data import get_datasets
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from data import get_datasets
from experiment import create_model
from sentiment import sentiment
import model_adv

In [2]:
# Load the datasets with batch_size=1, so every batch drawn from train_dataset
# holds a single example, and grab the text encoder and its vocabulary size.
(train_dataset, test_dataset), info = get_datasets(batch_size=1)
encoder = info.features['text'].encoder
vocab_size = encoder.vocab_size

In [3]:
# Restore the trained transformer (positional encoding enabled) from its
# checkpoint. The optimizer, checkpoint and manager are not used in this notebook.
transformer, _optimizer, _checkpoint, _manager = create_model(
    models_dir="models/pos_enc_True",
    load_checkpoint=True,
    vocab_size=vocab_size,
    use_positional_encoding=True,
    run_eagerly=True,
)

Loaded previously trained model.


In [4]:
EPOCHS = 50      # default number of optimisation steps for the adversarial model
D_MODEL = 128    # default d_model passed to model_adv.create_model
ALPHA = 0.9999   # default weight of loss_t in the combined loss; 1 - ALPHA weights loss_k


def _bce_loss(y_true, logits):
    """Binary cross-entropy on logits, averaged to a scalar.

    Reducing to a scalar lets the value be compared with ``<`` and formatted
    with a width/precision spec regardless of batch size.
    """
    return tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(y_true=y_true, y_pred=logits, from_logits=True))


def find_adv_k(x, y, transformer, epochs=EPOCHS, alpha=ALPHA, d_model=D_MODEL):
    """Search for an adversarial replacement ``adv_k`` of the transformer's ``k``.

    Runs ``x`` through ``transformer`` once to obtain ``k``. It then trains a freshly
    created adversarial model (``model_adv.create_model``) that maps ``k`` to
    ``adv_k``. Every epoch the transformer is re-run with ``custom_k=adv_k`` and
    the adversarial model takes one optimiser step on

        loss = alpha * loss_t + (1 - alpha) * loss_k

    with
        loss_t = mean binary_crossentropy(y, adv_y_logits, from_logits=True)
        loss_k = -mean over the batch of sum((k - adv_k) ** 2)

    That is, keep the transformer's prediction loss low while pushing ``adv_k``
    away (squared L2) from ``k``. ``loss_k`` is unbounded below, so over many
    epochs it grows in magnitude and can dominate ``loss``.

    Args:
        x: input batch passed to ``transformer``.
        y: labels for ``x``.
        transformer: callable returning ``(y_logits, w, k)`` and accepting
            ``training`` and ``custom_k`` keyword arguments.
        epochs: number of optimisation steps for the adversarial model.
        alpha: weight of ``loss_t``; ``1 - alpha`` weights ``loss_k``.
        d_model: ``d_model`` passed to ``model_adv.create_model``.

    Returns:
        Tuple ``(original_loss_t, original_k, best_loss_t, best_adv_k)``:
        - the transformer loss with its own ``k``;
        - that ``k``;
        - the lowest ``loss_t`` seen for an adversarial ``adv_k``;
        - the ``adv_k`` that produced it.
        ``best_loss_t`` and ``best_adv_k`` always come from the same epoch.
        Both are ``None`` when ``epochs`` is 0.

    Raises:
        AssertionError: if the transformer does not return the ``custom_k`` it
            was given, i.e. the adversarial ``k`` was not actually used.
    """
    y_logits, _w, k = transformer(x, training=False)
    print(f'Returned k: {tf.shape(k)}')
    # Sanity check: feeding the transformer its own k through custom_k should
    # reproduce the original loss (both are printed below).
    y_logits_custom, _, _ = transformer(x, training=False, custom_k=k)

    adv, optimizer = model_adv.create_model(models_dir="adv", load_checkpoint=False,
                                            run_eagerly=True, d_model=d_model)

    original_loss_t = _bce_loss(y, y_logits)
    print(f'[0] Loss_t = {original_loss_t:<20.10}')
    custom_loss_t = _bce_loss(y, y_logits_custom)
    print(f'[0] Loss_t (custom k) = {custom_loss_t:<20.10}')

    # k is constant across epochs, so flatten it once.
    batch_size = tf.shape(k)[0]
    k_flat = tf.reshape(k, [batch_size, -1])

    best_loss_t = None
    best_adv_k = None

    for epoch in range(1, epochs + 1):
        with tf.GradientTape() as tape:
            adv_k = adv(k, training=True)

            adv_y_logits, _adv_w, returned_k = transformer(x, custom_k=adv_k, training=False)
            # The transformer echoes back the k it used; if it differs from
            # adv_k, custom_k was ignored and the search is meaningless.
            if not tf.reduce_all(tf.equal(adv_k, returned_k)).numpy():
                raise AssertionError('transformer did not use the supplied custom_k')

            loss_t = _bce_loss(y, adv_y_logits)
            # Negative mean squared L2 distance: minimising it pushes adv_k away from k.
            sq_dist = tf.math.reduce_sum(
                tf.math.square(k_flat - tf.reshape(adv_k, [batch_size, -1])), axis=1)
            loss_k = -tf.math.reduce_mean(sq_dist)
            loss = alpha * loss_t + (1.0 - alpha) * loss_k

        # Update both together so the returned pair always describes one epoch.
        if best_loss_t is None or loss_t < best_loss_t:
            best_loss_t = loss_t
            best_adv_k = adv_k

        print(f'[{epoch:<2}] Loss_t = {loss_t:<20.10}  Loss_k = {loss_k:<20.10}   Loss = {loss:<20.10}')

        grads = tape.gradient(loss, adv.trainable_weights)
        optimizer.apply_gradients(zip(grads, adv.trainable_weights))

    return original_loss_t, k, best_loss_t, best_adv_k

In [5]:
# Run the adversarial search on the first training batch only.
first_batch = next(iter(train_dataset))
x, y = first_batch
loss_t, k, best_loss_t, best_adv_k = find_adv_k(x, y, transformer)

Returned k: [  1   4  56 128]
[0] Loss_t = 0.02403318696       
[0] Loss_t (custom k) = 0.02403318696       
[1 ] Loss_t = 0.05705162138         Loss_k = -43891.63281           Loss = -4.332117081        
[2 ] Loss_t = 0.05273097008         Loss_k = -51835.21094           Loss = -5.130795002        
[3 ] Loss_t = 0.04886035249         Loss_k = -61121.37109           Loss = -6.063281536        
[4 ] Loss_t = 0.04534782469         Loss_k = -71994.59375           Loss = -7.154115677        
[5 ] Loss_t = 0.04206982255         Loss_k = -84694.39062           Loss = -8.427372932        
[6 ] Loss_t = 0.03912955523         Loss_k = -99606.90625           Loss = -9.921565056        
[7 ] Loss_t = 0.03650836274         Loss_k = -117143.4219           Loss = -11.67783737        
[8 ] Loss_t = 0.03418093175         Loss_k = -137766.4531           Loss = -13.74246693        
[9 ] Loss_t = 0.03208424896         Loss_k = -162000.8125           Loss = -16.16799927        
[10] Loss_t = 0.03020050563

In [9]:
# Compare the best adversarial loss with the transformer's original loss.
for label, value in (('Best loss_t', best_loss_t), ('Original loss_t', loss_t)):
    print(f'{label}: {value}')

Best loss_t: 0.013373379595577717
Original loss_t: 0.024033186957240105


In [10]:
# Show the innermost vector at index [0][0][0] of the original k, then of the best adversarial k.
for tensor in (k, best_adv_k):
    print(tensor[0][0][0])

tf.Tensor(
[-0.10514279 -0.7446283  -0.37535304  0.7732584  -0.1891772   0.7325174
 -0.05518687  1.0767896   0.11997275  0.49861163 -0.09280941 -0.0515768
 -0.61197674  0.60941297 -0.20162156 -1.0039016  -1.4525076  -0.34856468
  0.31456915  0.7075647  -1.091167   -0.68078005 -0.67013174 -0.96027386
 -0.7915867  -0.23067713 -1.0558378   0.13707896  0.7173154   0.00786834
  0.85752827  0.90897775 -0.31805348  0.44462976  0.3518911   0.15995233
 -0.21541786  0.5320765  -0.3091937   0.23128088  0.84781516  0.89916176
 -0.2199289   0.0950876   0.13624406  1.1994206   0.43694854  0.5668871
 -0.8313954   0.1284687   0.8594588   0.70675844  0.3895667  -0.21708894
  1.3379259  -0.90666705 -0.6052241   0.17854702 -0.89288545 -0.02649421
 -0.70923495 -1.9466168   0.58874416  0.4724625  -1.5841348  -0.2232738
 -0.17609075  0.48799652 -0.96991855  1.281586   -1.060228   -0.46949005
 -0.13753739 -0.37512478  1.5053993  -0.01799305 -0.10419203  0.20251174
 -1.3018364   0.28232265  0.716445   -0.6071

In [11]:
# Persist the inputs and results of the adversarial search so they can be
# reloaded later without re-running find_adv_k.
# np.save does not create missing directories, so make sure it exists first.
os.makedirs('adv_results', exist_ok=True)
for name, value in (('x', x), ('y', y), ('best_adv_k', best_adv_k), ('k', k),
                    ('loss_t', loss_t), ('best_loss_t', best_loss_t)):
    # np.save appends the .npy extension, e.g. adv_results/x.npy
    np.save(os.path.join('adv_results', name), value.numpy())

In [12]:
# Load the saved search inputs/results back from disk (e.g. after a kernel
# restart). They are kept in `loaded`, keyed by the file stem, so every
# array stays available; the in-memory tensors from earlier cells are not overwritten.
loaded = {
    name: np.load(f'adv_results/{name}.npy')
    for name in ('x', 'y', 'best_adv_k', 'k', 'loss_t', 'best_loss_t')
}
print(loaded['loss_t'])
print(loaded['best_loss_t'])

0.024033187
0.01337338
