In [1]:
%load_ext autoreload

%autoreload 2
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
from data import get_datasets
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from data import get_datasets
from experiment import create_model
from sentiment import sentiment
import model_adv

In [2]:
# Load the train/test splits together with the dataset metadata.
(train_dataset, test_dataset), info = get_datasets()

# The text feature's encoder; its vocabulary size is passed to create_model.
encoder = info.features['text'].encoder
vocab_size = encoder.vocab_size

In [3]:
# Build the transformer and load its weights from the "checkpoints" directory.
transformer, _optimizer, _checkpoint, _manager = create_model(
    models_dir="checkpoints",
    load_checkpoint=True,
    vocab_size=vocab_size,
    use_positional_encoding=True,
    run_eagerly=True,
)

# adv = model_adv.create_model(models_dir="adv", load_checkpoint=False, run_eagerly=True, d_model=512)

Loaded previously trained model.


In [4]:
def kld(a1, a2) :
    """Return sum(a1 * (log(a1) - log(a2))) over every element of the inputs.

    a1, a2 - tensors of matching shape; the original note gives them as (B, *, A).

    Both inputs are clipped to [0, 1] and a small constant is added before the
    logarithm so that zero entries do not produce -inf. The sum runs over all
    axes, so the result is a single scalar tensor.
    """
    eps = 1e-10

    p = tf.clip_by_value(a1, 0, 1)
    q = tf.clip_by_value(a2, 0, 1)

    log_ratio = tf.math.log(p + eps) - tf.math.log(q + eps)

    return tf.math.reduce_sum(p * log_ratio)

def jsd(p, q) :
    """Jensen-Shannon divergence between p and q, built on kld.

    p, q - tensors of matching shape (the same inputs kld accepts).

    Computes m = 0.5 * (p + q) and returns 0.5 * (kld(p, m) + kld(q, m)).
    kld sums over every element, so the divergence is one scalar for the whole
    input; it is returned with a trailing axis of size 1 added.
    """
    m = 0.5 * (p + q)
    divergence = 0.5 * (kld(p, m) + kld(q, m))

    return tf.expand_dims(divergence, -1)

In [17]:
EPOCHS = 100
D_MODEL = 128
def find_adv_k(x, y, transformer, epochs=None, d_model=None):
    """
    Train a fresh adversarial model that replaces the transformer's k for one batch.

    x - (batch, text)
    y - (batch, label)
    transformer - called as transformer(x, training=False) -> (y_logits, w, k);
        it must also accept custom_k=... and hand that same k back as its third output.
    epochs - number of optimisation steps; None (default) uses the module-level EPOCHS.
    d_model - d_model given to model_adv.create_model; None (default) uses D_MODEL.

    Passes x through transformer and receives y_logits, w, k.
    Creates an adversarial model - adv (new on every call, no checkpoint loaded).
    On every epoch:
        adv_k = adv(k)
        adv_y_logits, adv_w, adv_k (same k) = transformer(x, custom_k=adv_k)
        loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)
        loss_k = sum(|k - adv_k|)
        loss = loss_t - log(loss_k)
    Minimizing loss keeps loss_t low while pushing adv_k away from k.
    Only adv's trainable weights are updated; the transformer is left unchanged.

    Returns a tuple (k, w, adv_k, adv_w, loss_k, loss_t, loss):
        k - original k
        w - original attention weights
        adv_k - adversarial k
        adv_w - adversarial attention weights w
        loss_k - the one that is maximized
        loss_t - transformer loss (the one that is minimized)
        loss - adversarial loss, loss = loss_t - log(loss_k)
    The adversarial values come from the forward pass of the last epoch, i.e. they
    were computed before that epoch's gradient step was applied. When epochs < 1,
    adv_k, adv_w, loss_k and loss are None and loss_t is the unmodified
    transformer's loss.
    """
    # Resolve defaults at call time so that editing EPOCHS / D_MODEL in the
    # notebook takes effect without redefining this function.
    if epochs is None:
        epochs = EPOCHS
    if d_model is None:
        d_model = D_MODEL

    y_logits, w, k = transformer(x, training=False)
    adv, optimizer = model_adv.create_model(models_dir="adv", load_checkpoint=False,
                                            run_eagerly=True, d_model=d_model)

    # Baseline: loss of the transformer with its own k.
    loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits, from_logits=True)
    print(f'[{0}] Loss_t = {loss_t:<20.10}')

    adv_k = adv_w = loss_k = loss = None
    for epoch in range(1, epochs + 1):
        with tf.GradientTape() as tape:
            adv_k = adv(k, training=True)

            adv_y_logits, adv_w, adv_k2 = transformer(x, custom_k=adv_k, training=False)

            # Sanity check: the transformer must really have used the injected k.
            assert tf.reduce_all(tf.equal(adv_k, adv_k2)).numpy(), \
                "transformer did not return the custom_k it was given"

            loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)

            loss_k = tf.math.reduce_sum(tf.math.abs(k - adv_k))

#             loss_k = jsd(k,adv_k)

            # NOTE(review): if adv_k ever equals k exactly, loss_k is 0 and log gives -inf.
            loss = (loss_t) - tf.math.log(loss_k)

        print(f'[{epoch:<2}] Loss_t = {loss_t:<20.10}  Loss_k = {loss_k:<20.10}   Loss = {loss:<20.10}')

        grads = tape.gradient(loss, adv.trainable_weights)
        optimizer.apply_gradients(zip(grads, adv.trainable_weights))

    return k, w, adv_k, adv_w, loss_k, loss_t, loss
  

In [18]:
# Run the adversarial-k search on every training batch; each call to
# find_adv_k builds its own adversarial model.
for batch, (x, y) in enumerate(train_dataset):
    find_adv_k(x=x, y=y, transformer=transformer)

[0] Loss_t = 44.30617523         
[1 ] Loss_t = 40.44004822           Loss_k = 1772366.25             Loss = 26.05222321         
[2 ] Loss_t = 33.23210526           Loss_k = 1786517.375            Loss = 18.8363266          
[3 ] Loss_t = 26.02960396           Loss_k = 1803062.375            Loss = 11.62460709         
[4 ] Loss_t = 18.75469971           Loss_k = 1822144.625            Loss = 4.339175224         
[5 ] Loss_t = 11.55359459           Loss_k = 1843906.0              Loss = -2.873802185        
[6 ] Loss_t = 8.217758179           Loss_k = 1868928.125            Loss = -6.223117828        
[7 ] Loss_t = 10.31872177           Loss_k = 1898565.25             Loss = -4.137887001        
[8 ] Loss_t = 12.32525253           Loss_k = 1927975.875            Loss = -2.146728516        
[9 ] Loss_t = 12.90054893           Loss_k = 1954663.0              Loss = -1.585179329        
[10] Loss_t = 12.38732147           Loss_k = 1978487.75             Loss = -2.11052227         
[11] L

KeyboardInterrupt: 

In [None]:
# Two scratch floats; presumably used to try out the `<20.10` format spec below.
w, w2 = 0.628025472164154, 0.1232

In [None]:
# Left-align each value in a 20-character column with 10 significant digits,
# printing the pair in both orders.
row = '{:<20.10} {:<20.10}'
print(row.format(w, w2))
print(row.format(w2, w))