In [1]:
%load_ext autoreload

%autoreload 2
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
from data import get_datasets
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from data import get_datasets
from experiment import create_model
from sentiment import sentiment
import model_adv

In [2]:
# Fetch the data: get_datasets() yields a (train, test) pair plus an info object.
splits, info = get_datasets()
train_dataset, test_dataset = splits

# The 'text' feature exposes the encoder; its vocab_size is what create_model
# receives further down.
encoder = info.features['text'].encoder
vocab_size = encoder.vocab_size

In [3]:
# Build the transformer with positional encoding enabled and restore it from
# the checkpoint directory. Of the four values create_model returns, only
# `transformer` is used by the cells shown in this notebook; the rest are bound
# to underscore-prefixed names.
transformer, _optimizer, _checkpoint, _manager = create_model(
    models_dir="models/pos_enc_True",
    load_checkpoint=True,
    vocab_size=vocab_size,
    use_positional_encoding=True,
    run_eagerly=True,
)

# Disabled: find_adv_k creates its own adversarial model on every call.
# adv = model_adv.create_model(models_dir="adv", load_checkpoint=False, run_eagerly=True, d_model=512)

Loaded previously trained model.


In [6]:
EPOCHS = 100   # default number of adversary update steps per find_adv_k call
D_MODEL = 128  # default d_model forwarded to model_adv.create_model


def find_adv_k(x, y, transformer, epochs=None, d_model=None):
    """
    Train a freshly created adversarial model that rewrites the transformer's k.

    x - (batch, text)
    y - (batch, label)
    transformer - callable such that transformer(x, training=False) returns
        (y_logits, w, k) and that accepts a `custom_k` keyword argument.
    epochs - number of adversary update steps. None (default) means "use the
        module-level EPOCHS as it is at call time".
    d_model - forwarded to model_adv.create_model. None (default) means "use
        the module-level D_MODEL as it is at call time".

    Passes x through transformer and receives y_logits, w, k.
    Creates an adversarial model - adv (new on every call, load_checkpoint=False).
    On every step, passes k through adv and receives adv_k, then computes the
    loss by passing (x, custom_k=adv_k) through transformer:
        Receives adv_y_logits, adv_w, adv_k (same k)
        loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)
        loss_k = - sum(|k - adv_k|)
        loss = loss_t + loss_k
    Only adv.trainable_weights are updated; the transformer is only called
    with training=False and none of its weights are passed to the optimizer.

    Returns a tuple (k, w, adv_k, adv_w, loss_k, loss_t, loss):
        k - original k
        w - original attention weights
        adv_k - adversarial k
        adv_w - adversarial attention weights w
        loss_k - the one that is maximized
        loss_t - transformer loss (the one that is minimized)
        loss - adversarial loss, e.g.  loss = loss_t + loss_k

    The adversarial values come from the last forward pass, i.e. they were
    computed just before the final weight update was applied. When no step is
    run (epochs <= 0), adv_k, adv_w, loss_k and loss are None and loss_t is the
    loss of the unmodified transformer.
    """
    # Resolve defaults at call time so that editing EPOCHS / D_MODEL in another
    # cell takes effect without re-running this definition.
    n_steps = EPOCHS if epochs is None else epochs
    model_width = D_MODEL if d_model is None else d_model

    y_logits, w, k = transformer(x, training=False)
    adv, optimizer = model_adv.create_model(models_dir="adv", load_checkpoint=False,
                                            run_eagerly=True, d_model=model_width)

    # Baseline: loss of the transformer with its own, unmodified k.
    loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits, from_logits=True)
    print(f'[{0}] Loss_t = {loss_t:<20.10}')

    # Defined up front so the return statement is valid even with zero steps.
    adv_k = adv_w = loss_k = loss = None

    for epoch in range(1, n_steps + 1):
        with tf.GradientTape() as tape:
            adv_k = adv(k, training=True)
            adv_y_logits, adv_w, adv_k2 = transformer(x, custom_k=adv_k, training=False)

            loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)
            # NOTE(review): loss_k is an unbounded sum over all elements of k.
            # In the recorded run it is on the order of -1e6 while loss_t stays
            # near 0.5, so it dominates the combined loss. Confirm this is the
            # intended weighting.
            loss_k = - tf.math.reduce_sum(tf.math.abs(k - adv_k))
            loss = loss_t + loss_k

        # Sanity check: the transformer must hand back exactly the k it was given.
        # Done outside the tape because it takes no part in the gradient.
        k_unchanged = bool(tf.reduce_all(tf.equal(adv_k, adv_k2)).numpy())
        assert k_unchanged, "transformer did not return the custom_k it was given"

        print(f'[{epoch:<2}] Loss_t = {loss_t:<20.10}  Loss_k = {loss_k:<20.10}   Loss = {loss:<20.10}')

        grads = tape.gradient(loss, adv.trainable_weights)
        optimizer.apply_gradients(zip(grads, adv.trainable_weights))

    return k, w, adv_k, adv_w, loss_k, loss_t, loss
  

In [7]:
# Run the adversarial search on every training batch. Each call creates a new
# adversarial model and runs up to EPOCHS update steps, so a full pass over the
# dataset is expensive.
for batch, example in enumerate(train_dataset):
    x, y = example
    find_adv_k(x, y, transformer)

[0] Loss_t = 0.2084553838        
[1 ] Loss_t = 0.5221247673          Loss_k = -1806262.5             Loss = -1806262.0          
[2 ] Loss_t = 0.5203827024          Loss_k = -2095850.625           Loss = -2095850.125        
[3 ] Loss_t = 0.5209338665          Loss_k = -2410787.0             Loss = -2410786.5          
[4 ] Loss_t = 0.5219099522          Loss_k = -2755802.5             Loss = -2755802.0          
[5 ] Loss_t = 0.5249029994          Loss_k = -3135517.25            Loss = -3135516.75         
[6 ] Loss_t = 0.527751863           Loss_k = -3554072.75            Loss = -3554072.25         
[7 ] Loss_t = 0.5303804278          Loss_k = -4016218.0             Loss = -4016217.5          
[8 ] Loss_t = 0.5345689058          Loss_k = -4525277.5             Loss = -4525277.0          
[9 ] Loss_t = 0.538839817           Loss_k = -5084765.5             Loss = -5084765.0          
[10] Loss_t = 0.5418853164          Loss_k = -5697141.0             Loss = -5697140.5          
[11] L

KeyboardInterrupt: 

In [29]:
# Two sample floats used to try out the column format spec in the next cell.
w, w2 = 0.628025472164154, 0.1232

In [32]:
# Show both orderings to check that '<20.10' left-aligns each value in a
# 20-character column with at most 10 significant digits.
print(f'{w:<20.10} {w2:<20.10}', f'{w2:<20.10} {w:<20.10}', sep='\n')

0.6280254722         0.1232              
0.1232               0.6280254722        
