In [None]:
from bertviz import head_view
import torch
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

import train_top
import training_pipeline.training_pipeline as training_pipeline
import bridge.tokens
from mubert import mubert, modeling, toptokens
from jargon.bert import tokenization as jargon_tokenization

# Switch TensorFlow to graph mode: the cells below build ops and evaluate them
# through explicit tf.compat.v1.Session objects instead of running eagerly.
tf.compat.v1.disable_eager_execution()

In [None]:
# Model configuration JSON, loaded further down with
# modeling.BridgebotConfig.from_json_file.
config_file = "bridgebot_config_small.json"
# NOTE(review): not referenced by any cell visible in this notebook — confirm
# whether it is still needed.
checkpoint = ""
# Default value of the `jargon_vocab_file` flag defined in the next cell.
bert_vocab_file = "gs://njt-serene-epsilon/jargon/uncased_L-8_H-512_A-8/vocab.txt"
# TFRecord file that get_examples() reads its records from.
data_file = "gs://njt-serene-epsilon/top-dataset-lin-bids/top_builder/bridgebot-top/0.0.6/top_builder-test.tfrecord-00000-of-00001"
# Checkpoint passed to train_top.init_from_checkpoint when the model graph is built.
init_checkpoint = "gs://njt-serene-epsilon/models/top.small.0/lin-bids/model.ckpt-362000"

In [None]:
# Flag definitions. Within this notebook the three sequence-length flags feed
# `data_config`, and the two jargon_* flags configure the jargon tokenizer in
# features_to_view(). The remaining flags are not read by any visible cell —
# presumably they are consumed through FLAGS by the imported project modules
# (TODO confirm).
flags = tf.compat.v1.flags

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    "max_player_seq_length", 256,
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded. Must match data generation.")

flags.DEFINE_integer(
    "max_jargon_seq_length", 256,
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded. Must match data generation.")

flags.DEFINE_integer(
    "max_top_seq_length", 16,
    "Must match data generation.")

# Defaults to the vocabulary path configured in the previous cell.
flags.DEFINE_string("jargon_vocab_file", bert_vocab_file,
    "The vocabulary file that the jargon/BERT model was trained on.")

flags.DEFINE_bool("jargon_do_lower_case", True, "True if model is uncased.")

flags.DEFINE_bool("use_player_sequence_out", False, "Use only embedding if False.")

flags.DEFINE_bool("use_jargon_sequence_out", False, "Use only embedding if False")

flags.DEFINE_bool("freeze_player", True, "Stop player gradient if True")

flags.DEFINE_bool("freeze_jargon", True, "Stop jargon gradient if True")

# Dummy flag named "f". Presumably it absorbs the `-f <connection file>`
# argument the Jupyter kernel is started with, which flag parsing would
# otherwise reject — TODO confirm.
flags.DEFINE_string("f", '', "hack") # seems to be required for flags to work

In [None]:
# Dataset description built from the sequence-length flags above; per the flag
# help texts these lengths must match the ones used at data generation.
data_config = training_pipeline.TopConfig(name='attention_data',
        max_player_seq_length=FLAGS.max_player_seq_length,
        max_jargon_seq_length=FLAGS.max_jargon_seq_length,
        max_top_seq_length=FLAGS.max_top_seq_length)

def _ex_feature_info(fi):
    """Build the ``tf.io.FixedLenFeature`` parsing spec for one feature.

    Args:
        fi: feature description exposing ``shape`` and ``dtype`` attributes.

    Returns:
        A ``tf.io.FixedLenFeature`` with the same shape as ``fi``. A
        ``tf.int32`` feature is declared as ``tf.int64`` for parsing; every
        other dtype is passed through unchanged.
    """
    # Only int32 is widened; run_example() narrows int64 back to int32 later.
    parse_dtype = tf.int64 if fi.dtype == tf.int32 else fi.dtype
    return tf.io.FixedLenFeature(fi.shape, parse_dtype)

# Parsing spec handed to tf.io.parse_single_example: one FixedLenFeature per
# entry returned by training_pipeline._top_features_info for `data_config`.
features_info = {k: _ex_feature_info(v) for k,v in
                    training_pipeline._top_features_info(data_config).items()}

In [None]:
def get_examples(data_config):
    """Yield the parsed examples stored in ``data_file``, one record at a time.

    Every record is parsed with ``tf.io.parse_single_example`` against the
    module-level ``features_info`` spec. The parse op is built once on a
    string placeholder and evaluated in a single session that is reused for
    all records, so iterating neither grows the graph nor opens a session per
    record.

    Args:
        data_config: Unused; kept so existing callers keep working.

    Yields:
        The evaluated result of parsing one record (the features described by
        ``features_info``).
    """
    # Eager execution is disabled, so the records are pulled out as numpy
    # values first and then fed through the parse op.
    dataset = tfds.as_numpy(tf.data.TFRecordDataset(data_file))
    serialized = tf.compat.v1.placeholder(tf.string, shape=[])
    parsed = tf.io.parse_single_example(serialized, features_info)

    # Deliberately not a `with` block: this generator stays suspended between
    # records, and the session should not remain installed as the default
    # session while the caller enters and leaves sessions of its own.
    sess = tf.compat.v1.Session()
    try:
        for rec in dataset:
            yield sess.run(parsed, feed_dict={serialized: rec})
    finally:
        sess.close()
            
def model_builder(bridgebot_config, init_checkpoint):
    """Create a ``model_fn`` closure that builds the inference graph.

    Args:
        bridgebot_config: configuration forwarded to
            ``train_top.features_to_predictions``; its ``top`` attribute is
            passed to ``mubert.get_masked_output``.
        init_checkpoint: checkpoint handed to
            ``train_top.init_from_checkpoint`` after the graph is built.

    Returns:
        A function ``model_fn(features, params)`` that returns
        ``(top_log_probs, top_attention_probs)``.
    """
    def model_fn(features, params):
        # Inference-mode forward pass that also requests attention probabilities.
        forward_outputs = train_top.features_to_predictions(
            features,
            params,
            bridgebot_config,
            use_one_hot_embeddings=False,
            is_training=False,
            compute_attention_probs=True)
        predictions, embedding_table, attention_probs = forward_outputs

        ids = features["target_ids"]
        positions = features["target_positions"]
        weights = features["target_weights"]

        with tf.compat.v1.variable_scope("top"):
            # Only the log-probabilities are returned; both loss values are
            # computed and then discarded.
            _loss, _example_loss, log_probs = mubert.get_masked_output(
                bridgebot_config.top, predictions, embedding_table,
                positions, ids, weights)

        # NOTE(review): the meaning of the second argument (False) is not
        # visible from this notebook — confirm against train_top.
        train_top.init_from_checkpoint(init_checkpoint, False)

        return log_probs, attention_probs

    return model_fn
    
def get_initialized_model(bridgebot_config):
    """Return the ``model_fn`` produced by ``model_builder``.

    The builder is called inside a ``tf.name_scope("attention")`` block and is
    given the module-level ``init_checkpoint``. No graph is built here; that
    happens when the returned function is called.

    Args:
        bridgebot_config: configuration forwarded to ``model_builder``.

    Returns:
        The ``model_fn(features, params)`` closure.
    """
    with tf.name_scope("attention"):
        return model_builder(bridgebot_config, init_checkpoint)
        
def run_example(bridgebot_config, features_gen):
    """Run the model on every example and yield its attention probabilities.

    The model graph is built a single time, on placeholders shaped after the
    first example, and each example is then fed through that same graph.
    Building the graph inside the loop would add a complete copy of the model
    per example while the variable initializer only runs once.

    Args:
        bridgebot_config: configuration forwarded to the model builder.
        features_gen: iterable of dicts keyed like ``features_info``; each
            value must support ``len()`` and is batched to shape
            ``(1, len(value))``. All examples are expected to share the
            lengths of the first one, since the placeholders are shaped
            after it.

    Yields:
        tuple: ``(top_attention_probs, features)`` where the first item is the
        evaluated attention output of the model and the second is the dict of
        batched numpy arrays that was fed (int64 narrowed to int32).
    """

    def _to32(dtype):
        """Map tf.int64 to tf.int32 and leave every other dtype unchanged."""
        return tf.int32 if dtype == tf.int64 else dtype

    model = get_initialized_model(bridgebot_config)
    params = {"batch_size": 1}
    placeholders = None
    top_attention_probs = None

    with tf.compat.v1.Session() as sess:
        for example in features_gen:
            # Batch size is 1: every feature becomes a (1, len) array.
            batch = {
                k: np.reshape(
                    np.asarray(
                        v, dtype=_to32(features_info[k].dtype).as_numpy_dtype),
                    (1, len(v)))
                for k, v in example.items()}

            if placeholders is None:
                # First example: build the graph once, then initialise.
                placeholders = {
                    k: tf.compat.v1.placeholder(
                        _to32(features_info[k].dtype), shape=arr.shape)
                    for k, arr in batch.items()}
                _, top_attention_probs = model(placeholders, params)
                # Must run after the model (and its init_from_checkpoint call)
                # has been added to the graph.
                sess.run(tf.compat.v1.global_variables_initializer())

            feed = {placeholders[k]: arr for k, arr in batch.items()}
            yield sess.run(top_attention_probs, feed_dict=feed), batch

def features_to_view(bridgebot_config, features):
    """Turn one batched example into the token strings shown by the viewer.

    Args:
        bridgebot_config: configuration; the value at
            ``mubert.representation.hidden_state_length`` gives the number of
            leading ``"[HIDDEN]"`` placeholders.
        features: dict with ``"player_input_ids"``, ``"jargon_input_ids"``
            and ``"query_ids"``; only row 0 of each is used (batch size 1).

    Returns:
        The ``"[HIDDEN]"`` placeholders followed by the player, jargon and
        query tokens, concatenated in that order.
    """
    player_tok = bridge.tokens.Tokenizer()
    jargon_tok = jargon_tokenization.FullTokenizer(
        vocab_file=FLAGS.jargon_vocab_file,
        do_lower_case=FLAGS.jargon_do_lower_case)
    top_tok = toptokens.Tokenizer()

    # Batch size is 1, so take the only row of each id matrix.
    player_ids = features["player_input_ids"][0]
    jargon_ids = features["jargon_input_ids"][0]
    top_ids = features["query_ids"][0]

    n_hidden = bridgebot_config.mubert.representation.hidden_state_length
    view_tokens = ["[HIDDEN]"] * n_hidden
    view_tokens = view_tokens + player_tok.ids_to_tokens(player_ids)
    view_tokens = view_tokens + jargon_tok.convert_ids_to_tokens(jargon_ids)
    view_tokens = view_tokens + top_tok.ids_to_tokens(top_ids)
    return view_tokens


def show_head_view(bridgebot_config, features, attention):
    """Display the bertviz head view for one example.

    Args:
        bridgebot_config: configuration forwarded to ``features_to_view``.
        features: batched example forwarded to ``features_to_view``.
        attention: iterable of numpy arrays, one per layer; each is converted
            with ``torch.from_numpy`` before being passed to ``head_view``.
    """
    layer_tensors = list(map(torch.from_numpy, attention))
    view_tokens = features_to_view(bridgebot_config, features)
    head_view(layer_tensors, view_tokens)

    
def tokens_to_squeeze(tokens, accepted=('[HIDDEN]', '[PAD]')):
    """Group consecutive tokens that should be displayed as a single entry.

    A token continues the current group when it starts with ``'##'`` (a
    word-piece continuation) or when it is one of ``accepted`` and equal to
    the token immediately before it. Any other token starts a new group. The
    first token always starts a group, so the yielded slices cover the whole
    sequence even if it begins with a ``'##'`` piece.

    Args:
        tokens: sequence of token strings (any iterable; it is copied to a
            list).
        accepted: tokens whose consecutive repetitions are collapsed.

    Yields:
        tuple: ``(slice(start, stop), label)`` per group, in order. The label
        is the token itself for a one-token group, the joined word (with the
        ``'##'`` prefixes stripped from the continuation pieces) when the
        group ends in a word-piece, and ``"<count> * <token>"`` otherwise.
    """
    tokens = list(tokens)

    def _label(start, stop):
        """Build the display label for tokens[start:stop]."""
        group = tokens[start:stop]
        if len(group) == 1:
            return group[0]
        if group[-1].startswith('##'):
            return group[0] + "".join(piece[2:] for piece in group[1:])
        return "{} * {}".format(len(group), group[0])

    start_idx = None
    for i, tok in enumerate(tokens):
        # Index 0 can never continue a group: there is nothing before it.
        continues_group = i > 0 and (
            tok.startswith('##')
            or (tok in accepted and tok == tokens[i - 1]))
        if continues_group:
            continue
        if start_idx is not None:
            yield slice(start_idx, i), _label(start_idx, i)
        start_idx = i

    # Flush the trailing group (absent only when `tokens` is empty).
    if start_idx is not None:
        yield slice(start_idx, len(tokens)), _label(start_idx, len(tokens))
        
def squeezed_tokens_and_attention(tokens, attention):
    """Collapse token groups and the matching rows/columns of each layer.

    The groups come from ``tokens_to_squeeze(tokens)``. For every layer,
    axis 2 is reduced group by group with a sum, and axis 3 of that result is
    then reduced group by group with a mean.

    Args:
        tokens: token strings, forwarded to ``tokens_to_squeeze``.
        attention: non-empty sequence of 4-D arrays, one per layer. The shape
            of the first layer is used to size every output array.

    Returns:
        tuple: ``(labels, layers)`` — the group labels and a list with one
        squeezed array per layer (allocated with ``np.zeros``, so float64).
    """
    groups = list(tokens_to_squeeze(tokens))
    labels = [label for _, label in groups]
    spans = [span for span, _ in groups]
    n_groups = len(groups)
    base_shape = attention[0].shape

    # NOTE(review): axis 2 is summed and axis 3 is averaged. If the layers are
    # laid out as (batch, head, from, to) with probabilities normalised over
    # the last axis, the opposite pairing would keep rows normalised — confirm
    # the intended layout.
    squeezed_layers = []
    for layer in attention:
        rows_shape = list(base_shape)
        rows_shape[2] = n_groups
        rows_merged = np.zeros(rows_shape)
        for out_idx, span in enumerate(spans):
            rows_merged[:, :, out_idx, :] = np.sum(layer[:, :, span, :], axis=2)

        full_shape = list(rows_shape)
        full_shape[3] = n_groups
        fully_merged = np.zeros(full_shape)
        for out_idx, span in enumerate(spans):
            fully_merged[:, :, :, out_idx] = np.mean(
                rows_merged[:, :, :, span], axis=3)

        squeezed_layers.append(fully_merged)

    return labels, squeezed_layers

In [None]:
# Load the model configuration and create the (lazy) generator of parsed
# examples; nothing is read from `data_file` until the generator is advanced.
bridgebot_config = modeling.BridgebotConfig.from_json_file(config_file)
features_gen = get_examples(data_config)

In [None]:
# Lazy generator of (attention, features) pairs; the model graph is only built
# once the generator is advanced.
attention_gen = run_example(bridgebot_config, features_gen)

In [None]:
# Pull the next example through the model: `atten` is the evaluated attention
# output, `ex` the batched features it was computed from.
atten, ex = next(attention_gen)

In [None]:
# Token strings for the example, then the squeezed tokens/attention, then one
# torch tensor per layer for the viewer.
all_tokens = features_to_view(bridgebot_config, ex)
sq_tokens, sq_atten = squeezed_tokens_and_attention(all_tokens, atten)
# NOTE(review): every layer is scaled by 10 — presumably to make faint weights
# visible in the viewer; confirm the factor is intentional.
all_torch_atten = [torch.from_numpy(10*layer) for layer in sq_atten]

In [None]:
# Show the squeezed attention; the position of the first '[CLS]' token is
# passed as `sentence_b_start`.
head_view(all_torch_atten, sq_tokens, sentence_b_start=sq_tokens.index('[CLS]'))

# Cell output: two hand-picked token windows from the unsqueezed sequence.
# (This line previously carried a stray leading space, which is an
# "unexpected indent" error inside the cell.)
# NOTE(review): the index ranges are magic numbers tied to one specific
# example — confirm they still point at the intended tokens.
all_tokens[177:181] + all_tokens[135:139]

In [None]:
# Boolean mask over `all_tokens` marking six hand-picked index windows; the
# last line displays the tokens the mask selects.
# NOTE(review): the ranges are magic numbers tied to one specific example.
sel = np.zeros(len(all_tokens), dtype=bool)
sel[177:181] = 1
sel[390:394] = 1
sel[414:417] = 1
sel[640:644] = 1
sel[384:388] = 1
sel[403:407] = 1
np.array(all_tokens)[sel]

In [None]:
# Inspect one of the windows included in `sel` above.
all_tokens[384:388]

In [None]:
# Inspect another of the windows included in `sel` above.
all_tokens[403:407]

In [None]:
def select_atten(idxs, tokens, atten):
    """Restrict tokens and stacked attention to the selected positions.

    Args:
        idxs: index applied with numpy indexing (e.g. a boolean mask or an
            integer index array) to the tokens and to axes 3 and 4 of the
            stacked attention.
        tokens: sequence of tokens; converted with ``np.array``.
        atten: per-layer arrays; ``np.array(atten)`` must be 5-D.

    Returns:
        tuple: ``(selected_tokens, selected_attention)`` where the attention
        keeps axes 0-2 whole and is filtered on axis 3 and then axis 4.
    """
    everything = slice(None)
    stacked = np.array(atten)
    rows_kept = stacked[everything, everything, everything, idxs, everything]
    picked_tokens = np.array(tokens)[idxs]
    both_kept = rows_kept[everything, everything, everything, everything, idxs]
    return picked_tokens, both_kept

In [None]:
# Apply the `sel` mask to the unsqueezed tokens and attention.
sel_tokens, sel_atten = select_atten(sel, all_tokens, atten)

In [None]:
# Head view restricted to the selected tokens. Unlike the earlier call, the
# attention is passed as one stacked tensor (layers on the first axis) rather
# than as a list of per-layer tensors.
sel_atten_torch = torch.from_numpy(sel_atten)
head_view(sel_atten_torch, sel_tokens)

In [None]:
# Heat map of the selected attention summed over its first three axes.
# NOTE(review): the full slice [:,:,:,:,:] selects everything and could be
# dropped.
plt.imshow(sel_atten[:,:,:,:,:].sum(axis=(0,1,2)))

In [None]:
# NOTE(review): bare reference to the pyplot module — this only displays the
# module's repr and looks like a leftover; confirm whether the cell can go.
plt