In [None]:
import bert
from bert import run_classifier
from bert import optimization
from bert import modeling
import tensorflow as tf
import tensorflow_hub as hub
from datetime import datetime
from sklearn import metrics
logger = tf.get_logger()
logger.propagate = False
import numpy as np
from collections import defaultdict

In [None]:
# --- Model / evaluation configuration ---------------------------------------
# TF-Hub handle of the small uncased BERT variant (L-4, H-512, A-8).
bert_model_hub = "https://tfhub.dev/google/small_bert/bert_uncased_L-4_H-512_A-8/1"

# Directory for the fine-tuned weights; created here if it does not exist yet.
model_output_dir = "finetuned_weights/bert_small"
tf.gfile.MakeDirs(model_output_dir)

num_labels = 3        # width of the classification head
hidden_size = 512     # width of BERT's pooled output; must agree with `config`
is_predicting = True  # inference mode: gates dropout in the classifier head

# Encoder hyper-parameters for the hand-built BertModel; the values mirror the
# L-4 / H-512 / A-8 model named in `bert_model_hub`.
config = modeling.BertConfig(
    vocab_size=30522,
    hidden_size=hidden_size,
    num_hidden_layers=4,
    num_attention_heads=8,
    intermediate_size=2048,
    type_vocab_size=2,
)

# Input pipeline over the file "test" with sequences of length 128
# (presumably a TFRecord file written by run_classifier -- TODO confirm).
input_fn = bert.run_classifier.file_based_input_fn_builder(
    input_file="test",
    seq_length=128,
    is_training=False,
    drop_remainder=False,
)


In [None]:
def forward(features):
    """Build the classification graph and per-position gradient saliencies.

    Relies on the notebook-level names `config`, `hidden_size`, `num_labels`
    and `is_predicting`.

    Args:
        features: mapping providing the tensors "input_ids", "input_mask",
            "segment_ids" and "label_ids".

    Returns:
        Dict of tensors with the keys
        "logits", "predictions", "probs", "input_ids", "labels" and
        "grad_m" / "grad_f" / "grad_n" -- the gradient of the class-0 / class-1 /
        class-2 probability w.r.t. the embedding lookup output, summed over
        axes 0, 2 and 3 so that one value per remaining (position) axis is left.
    """
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    with tf.variable_scope("module"):
        # BertModel only disables its internal hidden/attention dropout when
        # is_training is False.  Tie it to `is_predicting` (as the head dropout
        # below already is) so inference and gradients are deterministic.
        model = modeling.BertModel(config=config,
                                   is_training=not is_predicting,
                                   input_ids=input_ids,
                                   input_mask=input_mask,
                                   token_type_ids=segment_ids,
                                   use_one_hot_embeddings=False)

    # Use "pooled_output" for classification tasks on an entire sentence.
    output_layer = model.get_pooled_output()

    # Classification head; created in the root scope under the names
    # "output_weights" / "output_bias" (presumably matching the fine-tuned
    # checkpoint's variable names -- verify against the checkpoint).
    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
        A = tf.get_variable("output_weights", [hidden_size, num_labels], initializer=tf.truncated_normal_initializer(stddev=0.02))
        bias = tf.get_variable("output_bias", [num_labels], initializer=tf.zeros_initializer())

    output_layer = tf.keras.layers.Dropout(rate=0.1)(output_layer, training=not is_predicting)
    logits = tf.nn.xw_plus_b(output_layer, A, bias)

    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    probs = tf.nn.softmax(logits, axis=-1)

    # Output of the word-embedding lookup, fetched by graph name.
    # NOTE(review): the name depends on the installed BERT/TF version, and the
    # reduction below assumes this tensor is rank 4
    # (batch, position, 1, embedding) -- confirm when upgrading.
    xs = tf.get_default_graph().get_tensor_by_name("module/bert/embeddings/embedding_lookup/Identity:0")

    def position_gradient(class_index):
        # d probs[:, class_index] / d embeddings, collapsed to one value per position.
        grad = tf.gradients(probs[:, class_index], xs=[xs])[0]
        return tf.reduce_sum(grad, axis=[0, 2, 3])

    grad_m = position_gradient(0)
    grad_f = position_gradient(1)
    grad_n = position_gradient(2)

    return {"logits": logits, "predictions": predictions, "probs": probs, "input_ids": input_ids,
            "grad_m": grad_m, "grad_f": grad_f, "grad_n": grad_n, "labels": label_ids}

# Wire the input pipeline to the graph.  A batch size of 1 is used because
# `forward` sums gradients over the batch axis, so each fetched result then
# describes exactly one example.
dataset = input_fn(dict(batch_size=1, drop_remainder=False))
iterator = dataset.make_one_shot_iterator()
next_example = iterator.get_next()
model_output = forward(next_example)

In [None]:
# Restore the fine-tuned weights and run the graph over every example.
initializer = tf.global_variables_initializer()
saver = tf.train.Saver(filename='test_checkpoint')
tokenizer = bert.tokenization.FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)

results = []
with tf.Session() as sess:
    # Initialise first, then overwrite the variables with the checkpoint values.
    sess.run(initializer)
    saver.restore(sess, 'finetuned_weights/bert_small/model.ckpt-312')
    while True:
        try:
            result = sess.run(model_output)
            results.append(result)
        except tf.errors.OutOfRangeError:
            # The one-shot iterator is exhausted: every example was processed.
            # Any other exception is a real failure and must propagate instead
            # of silently truncating `results`.
            break
        

In [None]:
# Write a plain-text report: for each example its label, prediction, class
# probabilities, token sequence and, per class, the five tokens whose summed
# gradient has the largest magnitude.
# The file is opened as UTF-8 explicitly because word pieces may be non-ASCII
# and the platform default encoding could fail to encode them.
with open("results_raw.txt", "w", encoding="utf-8") as f:
    for i, result in enumerate(results):
        # Flatten every fetched array in place (re-running this cell is a no-op
        # on already flat arrays).
        for k, v in result.items():
            result[k] = v.flatten()
        label = result["labels"]
        prediction = result["predictions"]

        # `label` and `prediction` are arrays; the comparison relies on each
        # holding a single element (batch size 1).
        print("Example: ", i, "(correct)" if label == prediction else "(wrong)", "\n", file=f)
        words = tokenizer.convert_ids_to_tokens(result["input_ids"])
        text = " ".join(words)

        print("Label: ", label, " prediction: ", prediction, " probs: ", result["probs"], file=f)
        print(file=f)
        print(text, file=f)
        print(file=f)

        # Extract which words have the highest gradients
        for name in ["masculine", "feminine", "neutral"]:
            # Accumulate gradients per distinct token: a token occurring at
            # several positions contributes the sum of its position gradients.
            word_grads = defaultdict(int)
            position_grads = result["grad_" + name[0]]  # grad_m, grad_f, grad_n
            for w, g in zip(words, position_grads):
                word_grads[w] += g

            # Rank by magnitude.  abs() is used rather than squaring so that
            # very small gradients do not underflow to zero and tie.
            highest_influence_on_output = sorted(
                word_grads.keys(), key=lambda k: abs(word_grads[k]), reverse=True)
            top5_words = highest_influence_on_output[:5]

            print("Words with high {} gradient (word, grad): ".format(name), file=f)
            for w in top5_words:
                print(w.ljust(15, "_"), "{:6.4e}".format(word_grads[w]), file=f)
            print("   if negative: word is making the example less {}".format(name), file=f)
            print("   if positive: word is making the example more {}".format(name), file=f)
            print(file=f)

        print("__"*20, file=f)

    