In [None]:
import os
import random
import numpy as np
import psutil
import torch
import pickle
import sys

from olfmlm.configure_data import configure_data
from olfmlm.model import BertModel
from olfmlm.optim import Adam
from olfmlm.utils import save_checkpoint
from olfmlm.utils import load_checkpoint
from argparse import Namespace

In [None]:
# Hand-built stand-in for the parsed command-line arguments. One setting per
# line, kept in the original (alphabetical) order so the Namespace repr is
# unchanged.
args = Namespace(
    alternating=False,
    always_mlm=True,
    attention_dropout=0.1,
    batch_size=16,
    bert_config_file='bert_config.json',
    cache_dir='cache_dir',
    checkpoint_activations=False,
    clip_grad=1.0,
    continual_learning=False,
    cuda=True,
    delim=',',
    distributed_backend='nccl',
    dynamic_loss_scale=True,
    epochs=32,
    eval_batch_size=None,
    eval_iters=2000,
    eval_max_preds_per_seq=None,
    eval_seq_length=None,
    eval_text_key=None,
    eval_tokens=1000000,
    fp32_embedding=False,
    fp32_layernorm=False,
    fp32_tokentypes=False,
    hidden_dropout=0.0,
    hidden_size=1024,
    incremental=False,
    intermediate_size=None,
    layernorm_epsilon=1e-12,
    lazy_loader=True,
    load=None,
    load_all_rng=False,
    load_optim=True,
    load_rng=True,
    local_rank=None,
    log_interval=1000000,
    loose_json=False,
    lr=0.0001,
    lr_decay_iters=None,
    lr_decay_style='linear',
    max_dataset_size=None,
    max_position_embeddings=512,
    max_preds_per_seq=80,
    model_type='rg+mlm',
    modes='mlm,rg',
    no_aux=False,
    num_attention_heads=16,
    num_layers=24,
    num_workers=22,
    presplit_sentences=True,
    pretrained_bert=False,
    rank=0,
    resume_dataloader=False,
    save='pretrained_berts/rg+mlm',
    save_all_rng=False,
    save_iters=None,
    save_optim=True,
    save_rng=True,
    seed=1234,
    seq_length=128,
    shuffle=True,
    split='1000,1,1',
    test_data=None,
    text_key='text',
    tokenizer_model_type='bert-base-uncased',
    tokenizer_path='tokenizer.model',
    tokenizer_type='BertWordPieceTokenizer',
    track_results=True,
    train_data=['bert_corpus'],
    train_iters=1000000,
    train_tokens=1000000000,
    use_tfrecords=False,
    valid_data=None,
    vocab_size=30522,
    warmup=0.01,
    weight_decay=0.02,
    world_size=1,
)

In [None]:
# Create the project's data configuration object and override two defaults
# before applying it to the argument namespace.
data_config = configure_data()
data_config.set_defaults(data_set_type='BERT', transpose=False)

# apply() yields a pair: a 3-item collection of datasets and the tokenizer.
datasets, tokenizer = data_config.apply(args)
train_data, val_data, test_data = datasets

# Record the tokenizer's token count on the shared args namespace.
args.data_size = tokenizer.num_tokens

In [None]:
from colour import Color
# Endpoints of the highlight gradient, and the 256-step ramp between them
# (index 0 is white, index 255 is red).
white, red = Color("white"), Color("red")
color_list = [shade for shade in white.range_to(red, 256)]

In [None]:
import pickle

# Name of the sub-directory under grads/ whose gradient dumps are visualised.
model = 'max_token_hidden'

# NOTE: pickle.load executes arbitrary code from the file; only point these
# paths at files you produced yourself.
with open('embeds/raw_val.pkl', 'rb') as handle:
    raw_text = pickle.load(handle)

# The three per-view gradient dumps live side by side; read them in order.
loaded_views = []
for view_template in ('grads/{}/view_1', 'grads/{}/view_2', 'grads/{}/view_3'):
    with open(view_template.format(model), 'rb') as handle:
        loaded_views.append(pickle.load(handle))
grads_1, grads_2, grads_3 = loaded_views

In [None]:
def truncate_sequence(tokens):
    """
    Shrink `tokens` in place until it fits the validation token budget.

    The budget is ``val_data.dataset.max_seq_len - 2 - 3`` (the subtracted
    slots are presumably reserved for special tokens -- TODO confirm).
    While the sequence is over budget, one element is removed from either
    the front or the back, chosen at random with equal probability.
    Nothing is returned; the caller's sequence object is modified.
    """
    budget = val_data.dataset.max_seq_len - 2 - 3
    while len(tokens) > budget:
        drop_front = random.random() < 0.5
        position = 0 if drop_front else len(tokens) - 1
        tokens.pop(position)

In [None]:
from IPython.display import Markdown, display
def get_color_map(grad):
    """
    Assign a colour from `color_list` to every position of a gradient vector.

    The vector is rescaled so its entries sum to 255. Positions are then
    visited from smallest to largest value, and each one receives the colour
    at the running total of the (integer-truncated) scaled values seen so
    far. The largest entries therefore end up closest to the red end of the
    ramp.

    Args:
        grad: 1-D array-like supporting ``.sum()``, division and indexing.
            Values are assumed to be non-negative -- TODO confirm against
            the code that writes the gradient dumps.

    Returns:
        dict mapping position -> colour taken from `color_list`.
    """
    total = grad.sum()
    if total == 0:
        # Normalising would divide by zero and yield NaNs, which int() below
        # cannot convert. No position stands out, so use the first colour.
        return {position: color_list[0] for position in range(len(grad))}

    scaled = grad / total * 255
    color_map = {}
    color_index = 0
    for position in np.argsort(scaled):
        # Cumulative index: never exceeds 255 because the scaled values sum
        # to 255 and each contribution is truncated.
        color_index += int(scaled[position])
        color_map[position] = color_list[color_index]
    return color_map

def printmd(token, color_map, init_text=None):
    """
    Render a token sequence as Markdown with a background colour per token.

    Args:
        token: iterable of token ids; each id is turned into text with
            ``tokenizer.IdToToken``.
        color_map: mapping from position in `token` to the background colour
            for that position.
        init_text: optional label placed before the first token. It is
            inserted as-is (not escaped).

    The token text is HTML-escaped before being wrapped in a ``<span>`` so
    that tokens containing ``<``, ``>``, ``&`` or quotes are shown literally
    instead of breaking the generated markup.
    """
    import html  # local import so this cell does not depend on the import cell

    pieces = [init_text] if init_text else []
    for index, token_id in enumerate(token):
        string = html.escape(str(tokenizer.IdToToken(token_id)))
        color = color_map[index]
        colorstr = "<span style='background-color:{}'>{}</span>".format(color, string)
        pieces.append(colorstr)
    display(Markdown(" ".join(pieces)))

def grad_mapping(index):
    """
    Show the raw text of example `index`, then its tokens coloured once per
    facet using the matching entry of `grads_1`, `grads_2` and `grads_3`.

    The text is encoded with the notebook's tokenizer and truncated with
    `truncate_sequence` before rendering.
    """
    text = raw_text[index]

    token = tokenizer.EncodeAsIds(text)
    truncate_sequence(token)

    facets = (
        ('Facet 1: ', grads_1),
        ('Facet 2: ', grads_2),
        ('Facet 3: ', grads_3),
    )
    # Build every colour map before anything is printed.
    color_maps = [get_color_map(facet_grads[index]) for _, facet_grads in facets]
    print('Raw text: ', text)

    for (label, _), color_map in zip(facets, color_maps):
        printmd(token, color_map, label)

In [None]:
# Pick one example uniformly at random and visualise it.
choose_id = random.randrange(len(raw_text))
grad_mapping(choose_id)