In [1]:
# paremeters
checkpoint = "lightning_logs/2021-04-27_134822.564762/latest/checkpoints/baseline-epoch=01-val_acc_lab_grade=58.76.ckpt"
jsons = ["/home/lhchan/eb_class/ismi/ismi_late_submission.json"]
batch_size = 1
bert_path = "TurkuNLP/bert-base-finnish-cased-v1"

In [2]:
# libraries
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import captum
from IPython.core.display import HTML, display

In [3]:
import torch
import pytorch_lightning as pl
from transformers import AutoTokenizer
import data_reader, model
import pickle

In [5]:
# data
data = data_reader.JsonDataModule(jsons,
                                  batch_size=batch_size,
                                  bert_model_name=bert_path,
                                  model_type="whole_essay")
data.setup()

Removing data points without label(s)...
BEFORE 78
AFTER 78
OUTPUT: lab_grade
3   32/78=41.0
4   31/78=39.7
5   8/78=10.3
2   6/78=7.7
1   1/78=1.3

After segmenting essays
OUTPUT: lab_grade
3   32/78=41.0
4   31/78=39.7
5   8/78=10.3
2   6/78=7.7
1   1/78=1.3



In [6]:
# LayerIntegratedGradients only takes in inputs (tensor or tuple of tensors)
class ExplainWholeEssayClassModel(model.WholeEssayClassModel):
    def __init__(self, class_nums):
        super().__init__(class_nums)
    def forward(self, batch_input_ids, batch_attention_mask, batch_token_type_ids):
        enc = self.bert(input_ids=batch_input_ids,
                        attention_mask=batch_attention_mask,
                        token_type_ids=batch_token_type_ids) #BxS_LENxSIZE; BxSIZE
        return {name: layer(enc.pooler_output) for name, layer in self.cls_layers.items()}

In [7]:
# model
tokenizer = AutoTokenizer.from_pretrained(bert_path)

trained_model = ExplainWholeEssayClassModel.load_from_checkpoint(checkpoint_path=checkpoint, class_nums=data.class_nums())
trained_model.eval()
#trained_model.cuda() # needs around ~13GB of memory

ExplainWholeEssayClassModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(50105, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [29]:
# helper functions
def predict(input_ids, attention_mask, token_type_ids): #inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    pred = trained_model(input_ids, attention_mask, token_type_ids)
    return pred["lab_grade"] #return the output of the classification layer

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

def aggregate(inp,attrs,tokenizer):
    """detokenize and merge attributions"""
    detokenized=[]
    for l in inp[0].cpu().tolist():
        detokenized.append(tokenizer.convert_ids_to_tokens(l))
    attrs=attrs.cpu().tolist()
    

    aggregated=[]
    for token,a_val in zip(detokenized[0],attrs): #One text from the batch at a time!
        if token.startswith("##"):
            #This is a continuation. We need to pool by absolute value, i.e. pick the most extreme one
            current_tok,current_a_val=aggregated[-1] #this is what we have so far
            if abs(current_a_val)>abs(a_val): #what we have has larger absval
                aggregated[-1]=(aggregated[-1][0]+token[2:],aggregated[-1][1])
            else:
                aggregated[-1]=(aggregated[-1][0]+token[2:],a_val) #the new value had a large absval, let's use that
        else:
            aggregated.append((token,a_val))
    return aggregated

def print_aggregated(target,aggregated):
    with open("delme_before_print", "wb") as f:
        pickle.dump([target,aggregated], f)
    
    to_print=""
    to_print = to_print+"<html><body>"
    x=captum.attr.visualization.format_word_importances([t for t,a in aggregated],[a for t,a in aggregated])
    to_print = to_print+"<b>"+str(target)+"</b>"
    to_print = to_print+"""<table style="border:solid;">"""+x+"</table>"
    to_print = to_print+"</body></html>"
    display(HTML(to_print))

In [9]:
def build_ref(i, batch, tokenizer, device):
    """Given index and a batch, return reference
    input_ids, token_type_ids, and attention_mask"""
    ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
    sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
    cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

    # ref input token id
    ref_input_ids = torch.tensor([token if token==cls_token_id or token==sep_token_id else ref_token_id for token in batch["input_ids"][i]])
    # ref_token_type_ids
    ref_token_type_ids = batch["token_type_ids"][i]
    # ref_attention_mask
    ref_token_type_ids = batch["attention_mask"][i]
    
    return (torch.unsqueeze(ref_input_ids,dim=0).to(device),
           torch.unsqueeze(ref_token_type_ids,dim=0).to(device),
           torch.unsqueeze(ref_token_type_ids,dim=0).to(device))

In [23]:
def predict_and_explain(trained_model, tokenizer, obj_batch):
    trained_model.zero_grad() #to be safe perhaps it's not needed
    device=trained_model.device

    lig = LayerIntegratedGradients(predict, trained_model.bert.embeddings)
    predictions = predict(torch.nn.utils.rnn.pad_sequence(obj_batch["input_ids"],batch_first=True).to(device),
                         torch.nn.utils.rnn.pad_sequence(obj_batch["attention_mask"],batch_first=True).to(device),
                         torch.nn.utils.rnn.pad_sequence(obj_batch["token_type_ids"],batch_first=True).to(device),)
    for i, prediction in enumerate(predictions):
        prediction_cls=int(torch.argmax(prediction))
        print("Gold standard:", obj_batch["lab_grade"][i])
        print("Prediction:", ("1","2","3","4","5")[prediction_cls],"Weights:",prediction.tolist())
        ref_input = build_ref(i, obj_batch, tokenizer, device)
        inp = (obj_batch["input_ids"][i].unsqueeze(0).to(device),
                obj_batch["attention_mask"][i].unsqueeze(0).to(device),
                obj_batch["token_type_ids"][i].unsqueeze(0).to(device))
        all_tokens = tokenizer.convert_ids_to_tokens(inp[0][0])
        for target, classname in enumerate(("1","2","3","4","5")):
            attrs, delta = lig.attribute(inputs=inp,
                                  baselines=ref_input,
                                  return_convergence_delta=True,target=target)
            try:
                with open("delme", "wb") as f:
                    pickle.dump([obj_batch["essay"],attrs, delta], f)
                print("saved")
            except Exception as e:
                print(e)
            attrs_sum = summarize_attributions(attrs)
            aggregated = aggregate(inp, attrs_sum, tokenizer)

            x=captum.attr.visualization.format_word_importances(all_tokens,attrs_sum)
            print("ATTRIBUTION WITH RESPECT TO",classname)
            print_aggregated(target, aggregated)
            #display(HTML(x))
            print()

In [30]:
for batch in data.val_dataloader():
    predict_and_explain(trained_model, tokenizer, batch)
    break

Gold standard: tensor(3)
Prediction: 3 Weights: [-1.7724227905273438, -0.8241284489631653, 1.4079138040542603, 1.323012113571167, 0.6676863431930542]
saved
ATTRIBUTION WITH RESPECT TO 1



saved
ATTRIBUTION WITH RESPECT TO 2



saved
ATTRIBUTION WITH RESPECT TO 3



saved
ATTRIBUTION WITH RESPECT TO 4



saved
ATTRIBUTION WITH RESPECT TO 5





In [32]:
c=0
for batch in data.val_dataloader():
    if c==0:
        c += 1
        continue
    predict_and_explain(trained_model, tokenizer, batch)
    break

Gold standard: tensor(3)
Prediction: 5 Weights: [-1.8693541288375854, -1.1776630878448486, 0.6236181855201721, 1.6303231716156006, 1.6535723209381104]
saved
ATTRIBUTION WITH RESPECT TO 1



saved
ATTRIBUTION WITH RESPECT TO 2



saved
ATTRIBUTION WITH RESPECT TO 3



saved
ATTRIBUTION WITH RESPECT TO 4



saved
ATTRIBUTION WITH RESPECT TO 5





In [18]:
for batch in data.val_dataloader():
    predict_and_explain(trained_model, tokenizer, batch)
    break

Gold standard: tensor(3)
Prediction: 3 Weights: [-1.7724227905273438, -0.8241284489631653, 1.4079138040542603, 1.323012113571167, 0.6676863431930542]
[Errno 2] No such file or directory: ''
ATTRIBUTION WITH RESPECT TO 1



[Errno 2] No such file or directory: ''
ATTRIBUTION WITH RESPECT TO 2



[Errno 2] No such file or directory: ''
ATTRIBUTION WITH RESPECT TO 3



[Errno 2] No such file or directory: ''
ATTRIBUTION WITH RESPECT TO 4



[Errno 2] No such file or directory: ''
ATTRIBUTION WITH RESPECT TO 5



