In [2]:
%load_ext autoreload

%autoreload 2
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.callbacks as cb
import time
import numpy as np
from data import get_datasets
import matplotlib.pyplot as plt
import os
import shutil
from functools import partial
from data import get_datasets
from experiment import create_model
from sentiment import sentiment
import model_adv

In [3]:
# Load both dataset splits together with the metadata object that describes
# the features.
(train_dataset, test_dataset), info = get_datasets()

# The 'text' feature carries the encoder (its .encode() turns a review string
# into ids); the model needs that encoder's vocabulary size.
encoder = info.features['text'].encoder
vocab_size = encoder.vocab_size

In [4]:
# Build the transformer with positional encoding and restore its weights from
# the checkpoint directory. create_model also returns the optimizer, the
# checkpoint and its manager; they are bound to underscore-prefixed names.
transformer, _optimizer, _checkpoint, _manager = create_model(
    models_dir="models/pos_enc_True",
    load_checkpoint=True,
    vocab_size=vocab_size,
    use_positional_encoding=True,
    run_eagerly=True,
)

# adv = model_adv.create_model(models_dir="adv", load_checkpoint=False, run_eagerly=True, d_model=512)

Loaded previously trained model.


In [4]:
def kld(a1, a2):
    """Return the summed KL-divergence terms ``a1 * (log(a1) - log(a2))``.

    Both inputs are clipped to [0, 1] and a small constant (1e-10) is added
    before taking the logarithm so that zero entries do not yield ``-inf``.
    The sum runs over every element, so the result is a single scalar.

    Per the original author's note the inputs are expected to be shaped
    (B, *, A) -- TODO confirm against callers.
    """
    eps = 1e-10
    p = tf.clip_by_value(a1, 0, 1)
    q = tf.clip_by_value(a2, 0, 1)

    log_ratio = tf.math.log(p + eps) - tf.math.log(q + eps)
    return tf.math.reduce_sum(p * log_ratio)

def jsd(p, q):
    """Return the Jensen-Shannon divergence between ``p`` and ``q``.

    Computed as ``0.5 * (kld(p, m) + kld(q, m))`` with ``m = 0.5 * (p + q)``.
    ``kld`` sums over every element, so the divergence is a scalar; a trailing
    axis is appended before returning, giving a tensor of shape (1,).
    """
    m = 0.5 * (p + q)
    # Named `divergence` (not `jsd`) so the local does not shadow this function.
    divergence = 0.5 * (kld(p, m) + kld(q, m))

    return tf.expand_dims(divergence, -1)

In [38]:
EPOCHS = 100
D_MODEL = 128

# Create a fresh adversarial model (no checkpoint is loaded) and its optimizer.
adv, optimizer = model_adv.create_model(
    models_dir="adv",
    load_checkpoint=False,
    run_eagerly=True,
    d_model=D_MODEL,
)
def find_adv_k(x, y, transformer, n_steps=499):
    """Train the global ``adv`` model to produce an adversarial ``k`` for one batch.

    The transformer is first run on ``x`` to obtain its logits and ``k``.
    Then, for ``n_steps`` iterations, ``adv`` maps ``k`` to ``adv_k``, the
    transformer is re-run with ``custom_k=adv_k``, and ``adv`` is updated to
    minimise::

        loss   = loss_t - log(loss_k)
        loss_t = binary_crossentropy(y, adv_y_logits, from_logits=True)
        loss_k = sum(|k - adv_k|)

    i.e. the prediction loss is kept low while ``adv_k`` is pushed away from
    ``k``. Only ``adv.trainable_weights`` receive gradient updates; the
    transformer is always called with ``training=False``.

    This function relies on the notebook-level globals ``adv`` (the
    adversarial model) and ``optimizer`` (its optimizer), and prints the
    losses at every step.

    Args:
        x: batch of encoded text.
        y: batch of labels.
        transformer: model called as ``transformer(x, training=False)`` and
            ``transformer(x, custom_k=..., training=False)``; both calls
            return a ``(logits, attention_weights, k)`` triple.
        n_steps: number of optimisation steps. The default reproduces the
            previously hard-coded ``range(1, 500)`` loop.

    Returns:
        The ``adv_k`` produced in the last step, or ``None`` if ``n_steps``
        is smaller than 1.
    """
    y_logits, _, k = transformer(x, training=False)

    # NOTE(review): the transformer's real `k` is discarded here and replaced
    # by uniform noise in [0, 10) of the same shape, so `adv` is trained on
    # (and `loss_k` is measured against) random input. This contradicts the
    # original description ("pass k through adv model") -- confirm whether
    # this is an intentional experiment before relying on the results.
    k = tf.random.uniform(tf.shape(k))*10

    loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=y_logits, from_logits=True)
    print(f'[{0}] Loss_t = {loss_t:<20.10}')

    adv_k = None
    for epoch in range(1, n_steps + 1):
        with tf.GradientTape() as tape:
            adv_k = adv(k, training=True)

            adv_y_logits, adv_w, returned_k = transformer(x, custom_k=adv_k, training=False)

            loss_t = tf.keras.losses.binary_crossentropy(y_true=y, y_pred=adv_y_logits, from_logits=True)

            # Distance between the input k and the adversarial k (to be maximised).
            loss_k = tf.math.reduce_sum(tf.math.abs(k - adv_k))

            # Minimising -log(loss_k) maximises the distance.
            loss = (loss_t) - tf.math.log(loss_k)

        # Sanity check: the k reported by the transformer must be exactly the
        # custom k that was passed in.
        assert bool(tf.reduce_all(tf.equal(adv_k, returned_k))), \
            "transformer returned a k that differs from custom_k"

        print(f'[{epoch:<2}] Loss_t = {loss_t:<20.10}  Loss_k = {loss_k:<20.10}   Loss = {loss:<20.10}')

        grads = tape.gradient(loss, adv.trainable_weights)
        optimizer.apply_gradients(zip(grads, adv.trainable_weights))
    return adv_k
  

In [14]:
# Run the adversarial search on every training batch in turn; the returned
# adv_k of each batch is discarded here.
for batch, (x, y) in enumerate(train_dataset):
    find_adv_k(x=x, y=y, transformer=transformer)

[0] Loss_t = 0.3737921715        
[1 ] Loss_t = 0.447034508           Loss_k = 75698112.0             Loss = -17.69522858        
[2 ] Loss_t = 0.3983289003          Loss_k = 77519544.0             Loss = -17.76771164        
[3 ] Loss_t = 0.3674170375          Loss_k = 79313352.0             Loss = -17.82150078        
[4 ] Loss_t = 0.3464756012          Loss_k = 81107984.0             Loss = -17.86481667        
[5 ] Loss_t = 0.3062952161          Loss_k = 82907392.0             Loss = -17.92693901        
[6 ] Loss_t = 0.2654330432          Loss_k = 84716992.0             Loss = -17.98939514        
[7 ] Loss_t = 0.227781266           Loss_k = 86535872.0             Loss = -18.04828835        
[8 ] Loss_t = 0.2096125185          Loss_k = 88360816.0             Loss = -18.08732796        
[9 ] Loss_t = 0.1970053017          Loss_k = 90178688.0             Loss = -18.12029839        
[10] Loss_t = 0.1664873362          Loss_k = 91977056.0             Loss = -18.17056274        
[11] L

KeyboardInterrupt: 

In [None]:
from experiment import evaluate


# Evaluate the transformer on the test split, handing the adversarial model
# to evaluate() as well; the returned object exposes .result().
results = evaluate(transformer, test_dataset, adv)

here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here


In [20]:
# Display the value accumulated by the evaluation as a plain NumPy scalar.
results.result().numpy()

0.8160248

In [39]:
# Review used for the single-sentence adversarial search. Alternatives tried:
#   "Once again has out a movie for far longer than . No magic , it was all I could do to keep from turning it off an hour in ."
#   "this was a perfect movie . I enjoyed it a lot"
sentence = "This is a very terrible movie . I will never watch this bad movie again ."

# Encode the review into a batch of one.
input_sent = encode_sentence(sentence, encoder)

# Target label 0 (presumably the negative class -- confirm against the dataset).
adv_k = find_adv_k(input_sent, [0], transformer)





[0] Loss_t = 0.4253558218        
[1 ] Loss_t = 0.5418224335          Loss_k = 46607.38281            Loss = -10.20769215        
[2 ] Loss_t = 0.3511209488          Loss_k = 48955.91406            Loss = -10.44755459        
[3 ] Loss_t = 0.2920276821          Loss_k = 52272.45703            Loss = -10.57219696        
[4 ] Loss_t = 0.2646303475          Loss_k = 56341.35938            Loss = -10.67455387        
[5 ] Loss_t = 0.2502566874          Loss_k = 61138.85938            Loss = -10.7706461         
[6 ] Loss_t = 0.2413300425          Loss_k = 66670.57031            Loss = -10.866189          
[7 ] Loss_t = 0.2354661971          Loss_k = 72968.02344            Loss = -10.96231079        
[8 ] Loss_t = 0.2310985178          Loss_k = 80082.42969            Loss = -11.05971336        
[9 ] Loss_t = 0.2275890112          Loss_k = 88080.00781            Loss = -11.15841198        
[10] Loss_t = 0.2247388661          Loss_k = 97001.64062            Loss = -11.25774384        
[11] L

[86] Loss_t = 0.1947472841          Loss_k = 2772090.5              Loss = -14.6403656         
[87] Loss_t = 0.1946867853          Loss_k = 2821671.5              Loss = -14.65815353        
[88] Loss_t = 0.1946270019          Loss_k = 2871532.5              Loss = -14.67572975        
[89] Loss_t = 0.1945672929          Loss_k = 2921663.25             Loss = -14.69309616        
[90] Loss_t = 0.1945032626          Loss_k = 2972054.0              Loss = -14.71026134        
[91] Loss_t = 0.1944360286          Loss_k = 3022700.75             Loss = -14.7272253         
[92] Loss_t = 0.1943701506          Loss_k = 3073612.5              Loss = -14.74399376        
[93] Loss_t = 0.1943125427          Loss_k = 3124804.5              Loss = -14.76057053        
[94] Loss_t = 0.1942647099          Loss_k = 3176285.0              Loss = -14.77695847        
[95] Loss_t = 0.1942239553          Loss_k = 3228061.25             Loss = -14.79316807        
[96] Loss_t = 0.1941892803          Loss

[172] Loss_t = 0.19215554            Loss_k = 7975559.0              Loss = -15.6997366         
[173] Loss_t = 0.1921373308          Loss_k = 8045862.5              Loss = -15.70853043        
[174] Loss_t = 0.1921211481          Loss_k = 8116374.0              Loss = -15.71727276        
[175] Loss_t = 0.1921062768          Loss_k = 8187094.0              Loss = -15.72596359        
[176] Loss_t = 0.1920926124          Loss_k = 8258012.5              Loss = -15.73460197        
[177] Loss_t = 0.1920792311          Loss_k = 8329122.5              Loss = -15.74318886        
[178] Loss_t = 0.1920663267          Loss_k = 8400435.0              Loss = -15.75172806        
[179] Loss_t = 0.1920549423          Loss_k = 8471948.0              Loss = -15.76021671        
[180] Loss_t = 0.1920464039          Loss_k = 8543673.0              Loss = -15.76865578        
[181] Loss_t = 0.1920387894          Loss_k = 8615595.0              Loss = -15.7770462         
[182] Loss_t = 0.1920329034   

[259] Loss_t = 0.1912311614          Loss_k = 14786009.0             Loss = -16.31796074        
[260] Loss_t = 0.1912177652          Loss_k = 14872024.0             Loss = -16.32377434        
[261] Loss_t = 0.1912064701          Loss_k = 14958210.0             Loss = -16.32956505        
[262] Loss_t = 0.1911951005          Loss_k = 15044499.0             Loss = -16.33532906        
[263] Loss_t = 0.1912409216          Loss_k = 15131226.0             Loss = -16.34103012        
[264] Loss_t = 0.1912326962          Loss_k = 15217848.0             Loss = -16.34674644        
[265] Loss_t = 0.1911713779          Loss_k = 15304428.0             Loss = -16.35248184        
[266] Loss_t = 0.1915786862          Loss_k = 15391050.0             Loss = -16.35771942        
[267] Loss_t = 0.191588521           Loss_k = 15478946.0             Loss = -16.36340141        
[268] Loss_t = 0.1921682507          Loss_k = 15566088.0             Loss = -16.36843681        
[269] Loss_t = 0.1913272142   

[345] Loss_t = 0.1906948239          Loss_k = 22705214.0             Loss = -16.74740982        
[346] Loss_t = 0.1907128245          Loss_k = 22803720.0             Loss = -16.75172234        
[347] Loss_t = 0.1906701028          Loss_k = 22902636.0             Loss = -16.75609207        
[348] Loss_t = 0.1907052547          Loss_k = 23001862.0             Loss = -16.7603817         
[349] Loss_t = 0.1907369196          Loss_k = 23101276.0             Loss = -16.76466179        
[350] Loss_t = 0.1907346398          Loss_k = 23200748.0             Loss = -16.76896095        
[351] Loss_t = 0.1906926632          Loss_k = 23300260.0             Loss = -16.773283          
[352] Loss_t = 0.1906369478          Loss_k = 23399866.0             Loss = -16.77760315        
[353] Loss_t = 0.1907259077          Loss_k = 23499636.0             Loss = -16.78177071        
[354] Loss_t = 0.1906624138          Loss_k = 23599660.0             Loss = -16.78608131        
[355] Loss_t = 0.1906809956   

[432] Loss_t = 0.1903701276          Loss_k = 31846956.0             Loss = -17.08608246        
[433] Loss_t = 0.1903612465          Loss_k = 31958952.0             Loss = -17.08960152        
[434] Loss_t = 0.1903528124          Loss_k = 32071140.0             Loss = -17.09311295        
[435] Loss_t = 0.1903429478          Loss_k = 32183496.0             Loss = -17.09662056        
[436] Loss_t = 0.1903341711          Loss_k = 32295946.0             Loss = -17.10011864        
[437] Loss_t = 0.1903371215          Loss_k = 32408552.0             Loss = -17.10359573        
[438] Loss_t = 0.1903256178          Loss_k = 32521360.0             Loss = -17.10708237        
[439] Loss_t = 0.1903253794          Loss_k = 32634220.0             Loss = -17.11054802        
[440] Loss_t = 0.1903202981          Loss_k = 32747112.0             Loss = -17.11400414        
[441] Loss_t = 0.1903109103          Loss_k = 32860118.0             Loss = -17.11745834        
[442] Loss_t = 0.1903006285   

In [10]:
def sentiment(inp_sentence, encoder, transformer,adv_k = None, adv = None):
    """Classify a review as 'pos' or 'neg' and return the attention weights.

    Args:
        inp_sentence: the review text (a string).
        encoder: object whose ``encode`` turns the text into a sequence of ids.
        transformer: model returning ``(logits, attention_weights, k)``.
        adv_k: optional precomputed adversarial ``k``; when given it is passed
            to the transformer as ``custom_k`` and takes precedence over ``adv``.
        adv: optional adversarial model; when given (and ``adv_k`` is not),
            the transformer is run once to obtain ``k``, ``adv(k)`` is
            computed, and the transformer is re-run with that as ``custom_k``.

    Returns:
        A ``(label, weights)`` pair where ``label`` is ``'pos'`` or ``'neg'``
        and ``weights`` are the attention weights of the final forward pass.
    """
    # Encode the review and add a leading batch axis of size 1.
    encoder_input = tf.expand_dims(encoder.encode(inp_sentence), 0)

    if adv_k is None and adv is not None:
        # Derive the adversarial k from the transformer's own k.
        _, _, k = transformer(encoder_input, training=False)
        adv_k = adv(k, training=False)

    if adv_k is not None:
        predictions, weights, _ = transformer(encoder_input, custom_k=adv_k, training=False)
    else:
        predictions, weights, _ = transformer(encoder_input, training=False)

    # The transformer output is a logit (the adversarial training code feeds
    # it to binary_crossentropy with from_logits=True), so convert it to a
    # probability before applying the 0.5 decision threshold. Thresholding the
    # raw logit at 0.5 would mislabel probabilities between 0.5 and ~0.62.
    probability = tf.math.sigmoid(tf.squeeze(predictions, axis=0))
    sent = 'pos' if probability >= 0.5 else 'neg'
    return sent,weights

In [7]:
def encode_sentence(inp_sentence, encoder):
    """Encode a sentence with ``encoder`` and add a leading batch axis.

    Returns the result of ``encoder.encode(inp_sentence)`` expanded at axis 0,
    i.e. a batch containing that single encoded sentence.
    """
    return tf.expand_dims(encoder.encode(inp_sentence), 0)
    

In [13]:
from bertviz import head_view
import torch

In [9]:

%%javascript
// Tell RequireJS where to load the d3 and jquery modules from (CDN URLs,
// protocol-relative). Presumably needed by the bertviz head_view widget
// rendered further down -- confirm.
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [40]:
# sentence = "This is a very terrible movie . I will never watch this bad movie again ."
# sentence = "this was a perfect movie . I enjoyed it a lot"
# Whitespace-separated words of the review, used as labels in the attention
# view. NOTE(review): the encoder may split text differently, so these labels
# may not line up one-to-one with the attention positions -- confirm.
tokens = sentence.split()

# Classify the review while forcing the transformer to use the adversarial k.
prediction, weights = sentiment(
    sentence,
    encoder=encoder,
    transformer=transformer,
    adv_k=adv_k,
)
# prediction,weights = sentiment(sentence, encoder=encoder, transformer=old_transformer)

In [41]:
# Render the attention heads for the first (only) item of the batch.
head_view(torch.tensor(weights.numpy()[0]),tokens)
# "%s" (not "%str") -- the old format string appended a stray "tr" to the label.
print("prediction after permutation is %s" % prediction)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

prediction after permutation is negtr


In [32]:
# Render the attention heads for the first (only) item of the batch.
head_view(torch.tensor(weights.numpy()[0]),tokens)
# "%s" (not "%str") -- the old format string appended a stray "tr" to the label.
print("prediction after permutation is %s" % prediction)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

prediction after permutation is negtr


In [None]:
# Scratch cell: prints two values side by side, left-aligned in 20-character
# columns with 10 significant digits, in both orders.
# NOTE(review): `w` and `w2` are not defined in any visible cell, so this
# fails on a fresh kernel -- define them above or remove the cell.
print(f'{w:<20.10} {w2:<20.10}')
print(f'{w2:<20.10} {w:<20.10}')