In [1]:
%load_ext autoreload

%autoreload 2
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
from data import get_datasets
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from data import get_datasets
from experiment import create_model, evaluate
from sentiment import sentiment
import model_adv

In [2]:
# Load the train/test splits plus dataset metadata; the text feature's encoder
# defines the vocabulary the transformer is built for.
(train_dataset, test_dataset), info = get_datasets()
encoder = info.features['text'].encoder
vocab_size = encoder.vocab_size

In [3]:
# Restore the pretrained transformer (positional-encoding variant) from its checkpoint.
# The optimizer / checkpoint / manager handles are not needed in this notebook.
transformer, _optimizer, _checkpoint, _manager = create_model(
    models_dir="models/pos_enc_True",
    load_checkpoint=True,
    vocab_size=vocab_size,
    use_positional_encoding=True,
    run_eagerly=True,
)

Loaded previously trained model.


In [4]:
# Baseline: metric of the restored transformer on the test split,
# evaluated without any adversarial model.
metrics = evaluate(transformer, test_dataset)
print(metrics.result().numpy())

0.82691425


In [5]:
# Adversarial-search hyperparameters.
EPOCHS = 50    # optimization steps per adversarial search
D_MODEL = 128  # d_model of the adversary; presumably must match the last axis of k — TODO confirm

# Adversary and its optimizer, NOT loaded from a checkpoint.
adv, optimizer = model_adv.create_model(
    models_dir="adv",
    load_checkpoint=False,
    run_eagerly=True,
    d_model=D_MODEL,
)
    
def find_adv_k(x, y, transformer, adv_model=None, adv_optimizer=None,
               epochs=None, alpha=0.5):
    """
    Train an adversary that replaces the transformer's `k` tensor.

    The transformer is run on `x` to obtain its own `k`. For each epoch the
    adversary maps `k` to a candidate `adv_k`, the transformer is re-run with
    `custom_k=adv_k`, and the adversary takes one gradient step on

        loss_t = binary_crossentropy(y, adv_y_logits, from_logits=True)
        loss_k = -log(mean over batch of sum((k - adv_k) ** 2))
        loss   = alpha * loss_t + (1 - alpha) * loss_k

    Minimizing `loss` keeps the prediction loss low while pushing `adv_k`
    away from `k` (minimizing `loss_k` maximizes the squared distance).

    The adversary and optimizer are NOT re-created here. Repeated calls
    keep training the same weights.

    Args:
        x: input batch for `transformer` (presumably token ids — TODO confirm).
        y: labels for `x`, used as `y_true` of the binary cross-entropy.
        transformer: callable returning `(y_logits, w, k)` and accepting `custom_k`.
        adv_model: adversarial model mapping `k -> adv_k`; defaults to global `adv`.
        adv_optimizer: optimizer for `adv_model`; defaults to global `optimizer`.
        epochs: number of optimization steps; defaults to global `EPOCHS`.
        alpha: weight of `loss_t` relative to `loss_k` in the combined loss.

    Returns:
        original_loss_t: loss_t of the transformer using its own `k`.
        original_k: the transformer's own `k` for `x`.
        best_loss_t: lowest loss_t seen over the adversarial epochs
            (may exceed `original_loss_t`; None if `epochs` < 1).
        best_adv_k: the `adv_k` that produced `best_loss_t` (None if `epochs` < 1).
    """
    if adv_model is None:
        adv_model = adv
    if adv_optimizer is None:
        adv_optimizer = optimizer
    if epochs is None:
        epochs = EPOCHS

    y_logits, _w, k = transformer(x, training=False)
    print(f'Returned k: {tf.shape(k)}')
    # Sanity check: feeding the transformer its own k via `custom_k` should
    # reproduce the same loss.
    y_logits_custom, _, _ = transformer(x, training=False, custom_k=k)

    original_loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits, from_logits=True)
    original_k = k
    print(f'[{0}] Loss_t = {original_loss_t:<20.10}')
    custom_loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits_custom, from_logits=True)
    print(f'[{0}] Loss_t (custom k) = {custom_loss_t:<20.10}')

    # k is fixed for the whole search, so flatten it per-sample once.
    batch_size = tf.shape(k)[0]
    flat_k = tf.reshape(k, [batch_size, -1])

    # Track the best loss and the adv_k that produced it as a pair, so the
    # returned values always belong together.
    best_loss_t = None
    best_adv_k = None

    for epoch in range(1, epochs + 1):
        with tf.GradientTape() as tape:
            adv_k = adv_model(k, training=True)
            adv_y_logits, _adv_w, returned_k = transformer(x, custom_k=adv_k, training=False)
            # The transformer must actually use (and echo back) the injected k.
            assert tf.reduce_all(tf.equal(adv_k, returned_k)).numpy()
            loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)
            # Per-sample squared L2 distance between k and adv_k.
            sq_dist = tf.math.reduce_sum(
                tf.math.pow(flat_k - tf.reshape(adv_k, [batch_size, -1]), 2), axis=1)
            loss_k = - tf.math.log(tf.math.reduce_mean(sq_dist))
            loss = alpha * loss_t + (1 - alpha) * loss_k

        if best_loss_t is None or loss_t < best_loss_t:
            best_loss_t = loss_t
            best_adv_k = adv_k

        print(f'[{epoch:<2}] Loss_t = {loss_t:<20.10}  Loss_k = {loss_k:<20.10}   Loss = {loss:<20.10}')

        grads = tape.gradient(loss, adv_model.trainable_weights)
        adv_optimizer.apply_gradients(zip(grads, adv_model.trainable_weights))

    return original_loss_t, original_k, best_loss_t, best_adv_k

In [9]:
# Run the adversarial search on the first training batch only.
x, y = next(iter(train_dataset))
loss_t, k, best_loss_t, best_adv_k = find_adv_k(x, y, transformer)

Returned k: [ 64   4 199 128]
[0] Loss_t = 0.1198091134        
[0] Loss_t (custom k) = 0.1198091134        
[1 ] Loss_t = 0.2460487336          Loss_k = -17.12693787           Loss = -8.440444946        
[2 ] Loss_t = 0.1655866802          Loss_k = -17.18998146           Loss = -8.512197495        
[3 ] Loss_t = 0.1474911571          Loss_k = -17.25179863           Loss = -8.552153587        
[4 ] Loss_t = 0.145966202           Loss_k = -17.31235313           Loss = -8.583193779        
[5 ] Loss_t = 0.1431607604          Loss_k = -17.37161255           Loss = -8.614226341        
[6 ] Loss_t = 0.138781175           Loss_k = -17.42959595           Loss = -8.645407677        
[7 ] Loss_t = 0.1318671852          Loss_k = -17.48634338           Loss = -8.677238464        
[8 ] Loss_t = 0.123671256           Loss_k = -17.54191589           Loss = -8.709122658        
[9 ] Loss_t = 0.117032364           Loss_k = -17.59639549           Loss = -8.739681244        
[10] Loss_t = 0.1128930598 

In [10]:
# Same test-split metric, now passing the trained adversary to `evaluate`
# (presumably it substitutes adv(k) for k — see experiment.evaluate).
metrics = evaluate(transformer, test_dataset, adv_model=adv)
print(metrics.result().numpy())

0.79986244


In [34]:
# Overall Euclidean norm (over all entries) of the original k vs. the best adversarial k.
print(tf.norm(k, ord='euclidean').numpy())
print(tf.norm(best_adv_k, ord='euclidean').numpy())

214.1467
79561.1


In [14]:
# Transformer loss with its own k vs. the best loss found with an adversarial k.
# NOTE(review): the saved output below does not match the find_adv_k run shown
# earlier (out-of-order execution); re-run the notebook top to bottom.
print(f'Original loss_t: {loss_t}')
print(f'Best loss_t:     {best_loss_t}')

Original loss_t: 0.0007400644826702774
Best loss_t:     0.00014332833234220743


In [24]:
# Show the vector k[0, 0, 0] from the original and the best adversarial k,
# laid out as rows of 4 for readability.
# (Axes presumably batch, head, position, depth — TODO confirm.)
print(tf.reshape(k[0, 0, 0], (-1, 4)))
print(tf.reshape(best_adv_k[0, 0, 0], (-1, 4)))

tf.Tensor(
[[-1.0622562  -0.23984744  1.0748044   1.3748744 ]
 [-0.69511235  0.541495   -0.26275066  0.00983193]
 [ 0.587861    0.8562857   0.37515846  0.7165856 ]
 [-0.8721777  -0.08632609 -0.13359296 -0.27416104]
 [-0.9817763  -0.07322308 -0.07859396  1.1635443 ]
 [-0.5991031  -0.3586921  -0.63432235 -1.2879654 ]
 [-1.3708589   0.6427558   0.01798093 -0.36737213]
 [ 2.1075368   0.32159662  0.22088861  0.41846114]
 [ 0.18461812 -0.52916014  1.0895694  -1.4266332 ]
 [-0.21394011 -0.13209644 -0.45749953  0.2583496 ]
 [ 0.8876147   0.55790997  0.8856335   0.17673488]
 [ 1.8994558   0.9954314   1.8560383   1.3102771 ]
 [-0.1030802   0.2939429   1.0747086   0.5858982 ]
 [ 1.9763365  -0.88368344  2.5498774  -0.2831556 ]
 [ 1.8231454   0.20237808 -0.8465917   0.2549476 ]
 [-1.0648656  -1.008372    1.7104899   0.34775013]
 [-0.35755393 -1.0473596  -0.05636305  1.8079398 ]
 [-0.05273503 -0.01887925 -1.3248074  -1.0591247 ]
 [-0.7161503   0.5835323   2.18067     0.10653253]
 [ 1.4768964  -0.481

In [11]:
# Persist the adversarial-search results so they can be reloaded without
# re-running the search. np.save appends the ".npy" suffix itself.
adv_results_dir = 'adv_results'
# np.save does not create missing directories.
os.makedirs(adv_results_dir, exist_ok=True)
# np.asarray accepts both eager tf.Tensors and already-loaded NumPy arrays.
for result_name, result_value in {
    'x': x,
    'y': y,
    'best_adv_k': best_adv_k,
    'k': k,
    'loss_t': loss_t,
    'best_loss_t': best_loss_t,
}.items():
    np.save(os.path.join(adv_results_dir, result_name), np.asarray(result_value))

In [12]:
# Load the best values saved under adv_results/ back into the notebook variables
# (as NumPy arrays, not tf.Tensors), so downstream cells can use them without
# re-running the adversarial search.
x = np.load('adv_results/x.npy')
y = np.load('adv_results/y.npy')
best_adv_k = np.load('adv_results/best_adv_k.npy')
k = np.load('adv_results/k.npy')
loss_t = np.load('adv_results/loss_t.npy')
best_loss_t = np.load('adv_results/best_loss_t.npy')
print(loss_t)
print(best_loss_t)

0.024033187
0.01337338
