In [1]:
%load_ext autoreload

%autoreload 2
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
from data import get_datasets
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from data import get_datasets
from experiment import create_model
from sentiment import sentiment
import model_adv

In [2]:
(train_dataset, test_dataset), info = get_datasets()
encoder = info.features['text'].encoder
vocab_size=info.features['text'].encoder.vocab_size 

In [3]:
transformer, _optimizer, _checkpoint, _manager = create_model(models_dir="models/pos_enc_True", load_checkpoint=True, 
                           vocab_size=vocab_size, use_positional_encoding=True, run_eagerly=True)

# adv = model_adv.create_model(models_dir="adv", load_checkpoint=False, run_eagerly=True, d_model=512)

Loaded previously trained model.


In [4]:
EPOCHS = 50
D_MODEL = 128



def find_adv_k(x, y, transformer):
    """
    x - (batch, text)
    y - (batch, label)
    
    Passes x through transformer and receives y_logits, w, k.
    Creates an adversarial model - adv.
    Pass k through adv model and receives adv_k.
    Computes loss by passing (x, custom_k=adv_k) through transformer:
        Receives adv_y_logits, adv_w, adv_k (same k)
        loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)
        loss_k = - |k - adv_k|
        loss = loss_t + loss_k
    
    Returns: 
        k - original k
        w - original attention weights
        adv_k - adversarial k
        adv_w - adversarial attention weights w
        loss_k - the one that is maximized
        loss_t - transformer loss (the one that is minimized)
        loss - adversarial loss, e.g.  loss = loss_t = loss_k
    """
    best_loss_t = None
    best_adv_k = None
    

    y_logits, w, k = transformer(x, training=False)
    print(f'Returned k: {tf.shape(k)}')
    y_logits2, _, _ = transformer(x, training=False, custom_k=k)
    
    adv, optimizer = model_adv.create_model(models_dir="adv", load_checkpoint=False, 
                                            run_eagerly=True, d_model=D_MODEL)
    
    loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits, from_logits=True)
    best_loss_t = loss_t
    original_loss_t = loss_t
    original_k = k
    
    print(f'[{0}] Loss_t = {loss_t:<20.10}')
    loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits2, from_logits=True)
    print(f'[{0}] Loss_t (custom k) = {loss_t:<20.10}')
    
    alpha = 0.9999
    
    
    for epoch in range(1, EPOCHS+1):        
        with tf.GradientTape() as tape:
            adv_k = adv(k, training=True)
            if best_adv_k is None:
                best_adv_k = adv_k
            
            adv_y_logits, adv_w, adv_k2  = transformer(x, custom_k=adv_k, training=False)
            cond = tf.reduce_all(tf.equal(adv_k, adv_k2)).numpy()
            bs = tf.shape(k)[0]
            assert cond == True
            loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)
            summed = tf.math.reduce_sum(tf.math.pow(tf.reshape(k, [bs, -1]) - tf.reshape(adv_k, [bs, -1]), 2), axis=1)
            
            loss_k = - tf.math.reduce_mean(summed)
            loss = alpha*loss_t + (1.0 - alpha)*loss_k
            
            if loss_t < best_loss_t:
                best_loss_t = loss_t
                best_adv_k  = adv_k
            
        
        print(f'[{epoch:<2}] Loss_t = {loss_t:<20.10}  Loss_k = {loss_k:<20.10}   Loss = {loss:<20.10}')

        grads = tape.gradient(loss, adv.trainable_weights)                
        optimizer.apply_gradients(zip(grads, adv.trainable_weights))
  
    return original_loss_t, original_k, best_loss_t, best_adv_k

In [5]:
x, y = next(iter(train_dataset))
#     for batch, (x, y) in enumerate(train_dataset):
loss_t, k, best_loss_t, best_adv_k = find_adv_k(x, y, transformer)

Returned k: [ 64   4 194 128]
[0] Loss_t = 0.1850163043        
[0] Loss_t (custom k) = 0.1850163043        
[1 ] Loss_t = 0.4331596494          Loss_k = -150493.1719           Loss = -14.61620045        
[2 ] Loss_t = 0.3474534154          Loss_k = -173734.3594           Loss = -17.02601814        
[3 ] Loss_t = 0.291121304           Loss_k = -200938.8594           Loss = -19.8027935         
[4 ] Loss_t = 0.2605164051          Loss_k = -232728.6406           Loss = -23.01237297        
[5 ] Loss_t = 0.2482067347          Loss_k = -269986.3125           Loss = -26.75044823        
[6 ] Loss_t = 0.2455919236          Loss_k = -313791.125            Loss = -31.13354492        
[7 ] Loss_t = 0.2470159978          Loss_k = -365251.5              Loss = -36.2781601         
[8 ] Loss_t = 0.2497117817          Loss_k = -425645.5625           Loss = -42.31486893        
[9 ] Loss_t = 0.2523645759          Loss_k = -496556.5625           Loss = -49.4033165         
[10] Loss_t = 0.2542693615 

In [6]:
print(loss_t)
print(best_loss_t)
print(k[0][0][0])
print(best_adv_k[0][0][0])

tf.Tensor(0.1850163, shape=(), dtype=float32)
tf.Tensor(0.1850163, shape=(), dtype=float32)
tf.Tensor(
[-5.01886308e-02 -1.64749131e-01  6.57088682e-02  1.29200399e+00
 -6.17367089e-01  4.79863137e-02  2.02779993e-01  2.81210303e-01
  1.17436022e-01  4.75367457e-01 -1.51617639e-03  4.78382438e-01
 -1.07432950e+00 -1.10027768e-01  1.01997055e-01 -6.40636742e-01
 -3.84948522e-01 -1.79203898e-01  2.02490919e-04  9.92396951e-01
  9.80666950e-02 -1.79248899e-01 -2.01024994e-01 -1.04747105e+00
 -1.81850165e-01 -1.59880206e-01 -7.11462200e-01  4.23706174e-01
  5.38996398e-01  3.63422751e-01  4.64513123e-01  3.26554865e-01
 -3.49301249e-01 -6.68285728e-01  1.50979722e+00 -7.32147843e-02
  9.85800754e-04 -2.60735989e-01 -3.84469897e-01  3.37077171e-01
  8.64960849e-01  6.10059023e-01  1.07984984e+00  2.57375509e-01
  9.72188890e-01  1.05649281e+00  1.32399893e+00  2.93195367e-01
 -5.40404499e-01  5.35042524e-01  1.11359501e+00  9.66471851e-01
  8.72436166e-01 -1.49107361e+00  1.22011757e+00 -5.

In [9]:
np.save('best_adv_k', best_adv_k.numpy())
np.save('k', k.numpy())
np.save('loss_t', loss_t.numpy())
np.save('best_loss_t', best_loss_t.numpy())

In [10]:
# Load the best values

np.load('best_adv_k.npy')
np.load('k.npy')
print(np.load('loss_t.npy'))
print(np.load('best_loss_t.npy'))

0.1850163
0.1850163
