In [38]:
import os
import warnings
import sys
import codecs
import numpy as np
import argparse
import json
import pickle

from util import read_passages, evaluate, make_folds, clean_words, test_f1, to_BIO, from_BIO, from_BIO_ind, arg2param

import tensorflow as tf
# Build the TF session options in one go: cap the process at the whole device
# (fraction 1.0) but let allocations grow on demand instead of reserving it all.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0, allow_growth=True)
config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=config)
import keras.backend as K
# Make Keras run every graph in the session configured above.
K.set_session(sess)
from keras.activations import softmax
from keras.regularizers import l2
from keras.models import Model, model_from_json
from keras.layers import Input, LSTM, Dense, Dropout, TimeDistributed, Bidirectional
from keras.callbacks import EarlyStopping,LearningRateScheduler, ModelCheckpoint
from keras.optimizers import Adam, RMSprop, SGD
from crf import CRF
from attention import TensorAttention
from custom_layers import HigherOrderTimeDistributedDense
from generator import BertDiscourseGenerator
from keras_bert import load_trained_model_from_checkpoint, Tokenizer

from discourse_tagger_generator_bert import PassageTagger

import matplotlib.pyplot as plt
from scipy.special import softmax
from matplotlib import transforms
from operator import itemgetter, attrgetter


In [2]:
# --- Architecture switches (several are also baked into the saved-model file names) ---
use_attention = True
att_context = "LSTM_clause"
bid = True
bidirectional = bid
crf = True
lstm = False

# --- Input geometry: clauses per passage, tokens per clause, embedding width ---
maxseqlen, maxclauselen, input_size = 40, 60, 768

# --- Dropout rates ---
embedding_dropout, high_dense_dropout = 0.4, 0.4
attention_dropout = 0.6
lstm_dropout = 0.5

# --- Layer sizes and attention options ---
word_proj_dim, lstm_dim, rec_hid_dim, att_proj_dim = 300, 350, 75, 200
hard_k = 0

# --- Training-time settings (reg is the l2 regularization factor) ---
batch_size = 10
reg = 0


In [3]:
# Derive the saved-artifact file names from the architecture switches and load
# the label -> class-index mapping. Files are opened in `with` blocks so the
# handles are closed instead of leaking for the lifetime of the kernel.
prefix = "scidt_scibert/"
model_ext = "att=%s_cont=%s_lstm=%s_bi=%s_crf=%s" % (str(use_attention), att_context, str(lstm), str(bid), str(crf))

# The architecture is rebuilt in code in the next cell. The saved JSON config is
# still opened, so a missing artifact fails fast, and its text is kept for inspection.
with open(prefix + "model_%s_config.json" % model_ext, "r") as model_config_file:
    model_config_json = model_config_file.read()

model_weights_file_name = prefix + "model_%s_weights" % model_ext
model_label_ind = prefix + "model_%s_label_ind.json" % model_ext
with open(model_label_ind) as label_ind_fh:
    label_ind_json = json.load(label_ind_fh)
# JSON values may be stored as strings; normalize them to ints.
label_ind = {k: int(v) for k, v in label_ind_json.items()}
num_classes = len(label_ind)

In [4]:
# Rebuild the tagger architecture in code. Layers are created in the same order
# as before, so the saved weights loaded later still line up.
if use_attention:
    # presumably (clauses per passage, tokens per clause, embedding dim) — TODO confirm
    inputs = Input(shape=(maxseqlen, maxclauselen, input_size))
    x = Dropout(embedding_dropout)(inputs)
    x = HigherOrderTimeDistributedDense(input_dim=input_size, output_dim=word_proj_dim, reg=reg)(x)
    att_input_shape = (maxseqlen, maxclauselen, word_proj_dim)
    x = Dropout(high_dense_dropout)(x)
    # With return_attention=True the layer yields two outputs: the attended
    # representation and raw_attention.
    x, raw_attention = TensorAttention(att_input_shape, context=att_context, hard_k=hard_k,
                                       proj_dim=att_proj_dim, rec_hid_dim=rec_hid_dim,
                                       return_attention=True)(x)
    x = Dropout(attention_dropout)(x)
else:
    inputs = Input(shape=(maxseqlen, input_size))
    x = Dropout(embedding_dropout)(inputs)
    # Bug fix: the projection layer must be *called* on x. Previously the layer
    # object itself was bound to x, which broke the next layer call.
    x = TimeDistributed(Dense(input_dim=input_size, units=word_proj_dim))(x)

if bidirectional:
    x = Bidirectional(LSTM(input_shape=(maxseqlen, word_proj_dim), units=lstm_dim,
                           return_sequences=True, kernel_regularizer=l2(reg),
                           recurrent_regularizer=l2(reg),
                           bias_regularizer=l2(reg)))(x)
    x = Dropout(lstm_dropout)(x)
elif lstm:
    x = LSTM(input_shape=(maxseqlen, word_proj_dim), units=lstm_dim, return_sequences=True,
             kernel_regularizer=l2(reg),
             recurrent_regularizer=l2(reg),
             bias_regularizer=l2(reg))(x)
    x = Dropout(lstm_dropout)(x)

# Output layer: either a CRF (its loss/accuracy are reused when compiling) or a
# per-step softmax classifier.
if crf:
    Crf = CRF(num_classes, learn_mode="join")
    discourse_prediction = Crf(x)
else:
    discourse_prediction = TimeDistributed(Dense(num_classes, activation='softmax'), name='discourse')(x)
tagger = Model(inputs=inputs, outputs=[discourse_prediction])

In [5]:
# Restore the trained weights into the rebuilt architecture.
# NOTE(review): without by_name=True, Keras presumably matches weights by layer
# order, so the model above must mirror the one that was saved — confirm.
tagger.load_weights(model_weights_file_name)


In [6]:
# Choose the objective that matches the output layer, then compile once.
if crf:
    compile_loss, compile_metrics = Crf.loss_function, [Crf.accuracy]
else:
    compile_loss, compile_metrics = 'categorical_crossentropy', ['accuracy']
tagger.compile(optimizer=Adam(), loss=compile_loss, metrics=compile_metrics)

In [7]:
inp = tagger.input
# Find the attention layer by type instead of by the hard-coded position
# `tagger.layers[4]`, which silently points at the wrong layer if the
# architecture changes.
attention_layers = [layer for layer in tagger.layers if isinstance(layer, TensorAttention)]
if len(attention_layers) != 1:
    raise ValueError("expected exactly one TensorAttention layer, found %d" % len(attention_layers))
# Output index 1 is the second output (raw attention), requested via return_attention=True.
attention_output = attention_layers[0].output[1]

In [8]:
# Backend function mapping (model input, learning-phase flag) -> [raw attention].
functor = K.function(
    inputs=[inp, K.learning_phase()],
    outputs=[attention_output],
)

In [9]:
# Labeled passages to run attention analysis on.
# NOTE(review): despite the name `test_file`, this points at a *train* file — confirm intended.
test_file = "lucky_train.txt"

In [10]:
# Generator/encoder settings. Reuse the config-cell constants rather than
# repeating literals, so the data pipeline cannot drift from the model's input shape.
params = {
    # Set the SCIBERT_DIR environment variable to point at a local checkpoint;
    # the default is the original shared path.
    "repfile": os.environ.get("SCIBERT_DIR", "/nas/home/xiangcil/scibert_scivocab_uncased"),
    "use_attention": use_attention,
    "batch_size": batch_size,
    "maxseqlen": maxseqlen,
    "maxclauselen": maxclauselen,
}

In [11]:
# Locate the checkpoint artifacts inside the pretrained directory.
pretrained_path = params["repfile"]
config_path, checkpoint_path, vocab_path = (
    os.path.join(pretrained_path, artifact)
    for artifact in ('bert_config.json', 'bert_model.ckpt', 'vocab.txt')
)

bert = load_trained_model_from_checkpoint(config_path, checkpoint_path)
# Build the predict function eagerly; without this TF raises an error later.
bert._make_predict_function()

# Vocabulary: each stripped line maps to the dict's size at insertion time.
token_dict = {}
with codecs.open(vocab_path, 'r', 'utf8') as vocab_file:
    for vocab_line in vocab_file:
        token_dict[vocab_line.strip()] = len(token_dict)
tokenizer = Tokenizer(token_dict)

In [12]:
# Read labeled passages, then clean the text and convert labels to BIO tags.
str_seqs, label_seqs = read_passages(test_file, is_labeled=True)
str_seqs, label_seqs = clean_words(str_seqs), to_BIO(label_seqs)

In [13]:
# Positional arguments after label_ind appear to mirror `params`: batch size,
# use_attention, max clauses per passage, max tokens per clause. The shape values
# must match the tagger's Input shape, so they come from the config constants
# instead of literals. NOTE(review): confirm this order against
# BertDiscourseGenerator's signature; the trailing True is passed through unchanged.
bert_generator = BertDiscourseGenerator(
    bert, tokenizer, str_seqs, label_seqs, label_ind,
    batch_size, use_attention, maxseqlen, maxclauselen, True,
    input_size=input_size,
)

In [14]:
# Materialize model inputs (test_X) and targets (test_Y) for every passage.
# NOTE(review): presumably one row of test_X per passage, in str_seqs order —
# the per-passage indexing further down relies on this; confirm in make_data.
test_X, test_Y = bert_generator.make_data(str_seqs, label_seqs)

In [15]:
# The functor takes two inputs: the model input and the learning-phase flag.
# Feed 0 (test mode) explicitly so every Dropout layer is disabled, instead of
# leaving the flag unfed and relying on the backend default.
attention_raw_scores = functor([test_X, 0])[0]

In [16]:
# Normalize raw attention over the last axis (scipy's softmax, which shadows the
# Keras one imported earlier). Note: the token filtering below uses the raw scores.
attention_scores = softmax(x=attention_raw_scores, axis=-1)

In [17]:
# Invert the vocabulary (token id -> token); later entries win on duplicate ids.
reverse_token_dict = dict(zip(token_dict.values(), token_dict.keys()))

In [34]:
# Re-tokenize every clause with the vocabulary so each attention position can be
# mapped back to a token string. Only the first `maxclauselen` positions are kept,
# matching the clause length the model was built with (previously a hard-coded 60).
str_seqs_tokenized = []
for str_seq in str_seqs:
    str_seq_tokenized = []
    for clause in str_seq:
        # max_len=512 pads/truncates the encoding; the slice then keeps the model's window.
        indices, _segments = tokenizer.encode(clause.lower(), max_len=512)
        str_seq_tokenized.append([reverse_token_dict[i] for i in indices[:maxclauselen]])
    str_seqs_tokenized.append(str_seq_tokenized)

# NOTE: this rebinds label_seqs (BIO -> plain labels), so the cell is NOT
# idempotent: re-running it applies from_BIO twice. Re-run the loading cell first.
label_seqs = from_BIO(label_seqs)

In [35]:
# Strip the tag prefix (everything up to the last "_") to recover the base labels.
# Sorted so the order, and with it the order of the dicts built from it, is
# stable across kernels; a set of strings iterates in hash-randomized order.
original_label = sorted(set(label.split("_")[-1] for label in label_ind))


In [64]:
# Raw (pre-softmax) attention score a token must reach to be recorded.
ATTENTION_THRESHOLD = 2.5

# For each gold label, collect the raw attention weights of every token that
# clears the threshold: {label: {token: [weight, ...]}}.
all_distributions = {label: {} for label in original_label}
for para, (str_seq_tokenized, label_seq) in enumerate(zip(str_seqs_tokenized, label_seqs)):
    # The indexing assumes a passage's clauses occupy the LAST len(...) rows of
    # axis 1. NOTE(review): confirm this padding side in make_data, and that
    # passages never exceed maxseqlen clauses (larger ones would misalign here).
    first_row = -len(str_seq_tokenized)
    for sent, (sentence, label) in enumerate(zip(str_seq_tokenized, label_seq)):
        for idx, token in enumerate(sentence):
            weight = attention_raw_scores[para, first_row + sent, idx]
            # Skip special tokens such as [CLS]/[SEP]/[PAD]. startswith() also
            # tolerates an empty token, where token[0] would raise IndexError.
            if not token.startswith("[") and weight >= ATTENTION_THRESHOLD:
                all_distributions[label].setdefault(token, []).append(weight)
                

In [65]:
all_mean_distribution = {}
for label, token_weights in all_distributions.items():
    this_label = []
    for token, weights in token_weights.items():
        count = len(weights)
        this_label.append((token,count))
    sorted_list = sorted(this_label, key=itemgetter(1), reverse=True)
    all_mean_distribution[label] = sorted_list[:5]

In [66]:
# Show the top-5 (token, count) pairs per label.
all_mean_distribution

{'method': [('we', 65), ('to', 49), ('for', 33), ('was', 25), ('were', 24)],
 'fact': [('to', 10), ('not', 5), ('that', 4), ('well', 3), ('evidence', 3)],
 'hypothesis': [('be', 32),
  ('to', 23),
  ('could', 23),
  ('might', 22),
  ('that', 19)],
 'result': [('not', 102),
  ('was', 87),
  ('that', 48),
  ('did', 42),
  ('found', 39)],
 'problem': [('not', 25),
  ('to', 15),
  ('however', 12),
  ('still', 11),
  ('been', 7)],
 'implication': [('that', 72),
  ('not', 20),
  ('suggest', 13),
  ('be', 12),
  ('may', 10)],
 'goal': [('to', 134),
  ('whether', 16),
  ('determine', 14),
  ('investigate', 11),
  ('we', 10)],
 'none': [('is', 2), ('described', 2), ('shown', 1)]}