In [None]:
import os
import pandas as pd
import azureml.core
import numpy as np
import plotly.graph_objects as go
from IPython.core.display import HTML
from utils import *

from azureml.core import Workspace, Environment, Experiment, Datastore, Dataset, ScriptRunConfig
from azureml.train.automl.run import AutoMLRun
from azureml.core.datastore import Datastore
from azureml.data.data_reference import DataReference

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from sklearn import preprocessing
import torch
from transformers import TrainingArguments, Trainer
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
from transformers import EarlyStoppingCallback
from transformers.integrations import AzureMLCallback
from transformers import AutoTokenizer, DataCollatorWithPadding
import joblib
# from datasets import Dataset, DatasetDict

# Print azureml.core.VERSION so the SDK version used for this run is visible in the output.
print("SDK version:", azureml.core.VERSION)


In [None]:
# Load the Azure ML workspace via Workspace.from_config().
# `Workspace` is already imported in the notebook's import cell, so it is not
# re-imported here (keeps all imports in one place).
ws = Workspace.from_config()

In [None]:
# Fetch the registered train / validation / test datasets (pinned to version 8)
# and the temporal test dataset (pinned to version 3) from the workspace.
# `Dataset` here must be the azureml.core class imported in the import cell.
ds_X_train = Dataset.get_by_name(ws, name="owner_g_classfication_train", version=8)
ds_X_val = Dataset.get_by_name(ws, name="owner_g_classfication_val", version=8)
ds_X_test = Dataset.get_by_name(ws, name="owner_g_classfication_test", version=8)
ds_temporal_test = Dataset.get_by_name(ws, name="owner_g_classfication_temporal_test", version=3)

In [None]:
# Report the tags and resolved version of every dataset fetched above,
# in the order train, validation, test, temporal test.
for ds in (ds_X_train, ds_X_val, ds_X_test, ds_temporal_test):
    print(f'{ds.tags}: V{ds.version}')

In [None]:
# Materialize each dataset as a pandas DataFrame.
pdf_X_train = ds_X_train.to_pandas_dataframe()
pdf_X_val = ds_X_val.to_pandas_dataframe()
pdf_X_test = ds_X_test.to_pandas_dataframe()
pdf_temporal_test = ds_temporal_test.to_pandas_dataframe()

In [None]:
# Notebook-level configuration.
base_checkpoint = "bert-base-uncased"  # not referenced by any of the visible cells
text_field_name = "txt_field"  # name of the DataFrame column holding the input text
target_name = "target"  # presumably the label column name — TODO confirm against the dataset schema

In [None]:
# Load the sequence-classification model and its tokenizer from the local
# directory `model_directory`.
model_directory = 'model_output/model'
# num_labels=51 is passed explicitly.
# NOTE(review): presumably this matches the number of classes in the saved label encoder — confirm.
model = AutoModelForSequenceClassification.from_pretrained(model_directory, num_labels=51)
tokenizer = AutoTokenizer.from_pretrained(model_directory)

In [None]:
# Load the label encoder stored alongside the model files and display it.
# NOTE: joblib.load unpickles the file, so only load artifacts from a trusted source.
le=joblib.load(model_directory + '/labelEncoder.joblib')
le

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU;
# the bare `device` expression displays the choice.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Move the model to the selected device, switch it to evaluation mode and
# clear any gradients before running attributions.
# model = best_model.model
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
# tokenizer = best_model.tokenizer
print(device)

In [None]:
# Import the Hugging Face class under an alias: a bare `from datasets import Dataset`
# would rebind the `Dataset` name that the dataset-loading cell relies on for
# azureml.core.Dataset, so re-running that cell afterwards would fail.
from datasets import Dataset as HFDataset

# Wrap the temporal test DataFrame as a Hugging Face dataset.
# NOTE(review): the variable is named `train` but holds the temporal test split.
train = HFDataset.from_pandas(pdf_temporal_test)


In [None]:
def predict(inputs):
    """Run the notebook-level `model` on `inputs` and return the first element of its output.

    Args:
        inputs: value passed straight through as the model's single positional argument
            (the visible callers pass a tensor of token ids).

    Returns:
        Element 0 of whatever the model call returns.
    """
    model_output = model(inputs)
    return model_output[0]

In [None]:
# Special-token ids used to build the real input and its baseline.
ref_token_id = tokenizer.pad_token_id  # padding token: stands in for every text token in the baseline
sep_token_id = tokenizer.sep_token_id  # separator token appended at the end of the sequence
cls_token_id = tokenizer.cls_token_id  # classification token prepended to the sequence

In [None]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Tokenize `text` and build the real input ids plus a same-length baseline.

    The real sequence is ``[CLS] + text tokens + [SEP]``; the baseline keeps the
    ``[CLS]`` / ``[SEP]`` markers and replaces every text token with `ref_token_id`.
    Uses the notebook-level `tokenizer` and `device`.

    Args:
        text: raw text to encode (special tokens are added here, not by the tokenizer).
        ref_token_id: id substituted for each text token in the baseline.
        sep_token_id: id appended at the end of both sequences.
        cls_token_id: id prepended to both sequences.

    Returns:
        Tuple ``(input_ids, ref_input_ids, n_text_tokens)``: two tensors with a leading
        batch dimension of 1 placed on `device`, and the number of text tokens.
    """
    body_ids = tokenizer.encode(text, add_special_tokens=False)
    n_text_tokens = len(body_ids)

    real_sequence = [cls_token_id, *body_ids, sep_token_id]
    baseline_sequence = [cls_token_id, *([ref_token_id] * n_text_tokens), sep_token_id]

    input_tensor = torch.tensor([real_sequence], device=device)
    baseline_tensor = torch.tensor([baseline_sequence], device=device)
    return input_tensor, baseline_tensor, n_text_tokens

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build token-type ids for `input_ids` together with an all-zero baseline.

    Positions up to and including `sep_ind` get type 0, later positions get type 1.
    Tensors are created on the notebook-level `device`.

    Args:
        input_ids: tensor whose second dimension is the sequence length.
        sep_ind: last position (inclusive) that belongs to segment 0.

    Returns:
        Tuple ``(token_type_ids, ref_token_type_ids)``, each with a leading batch
        dimension of 1; the reference tensor is all zeros.
    """
    n_positions = input_ids.size(1)
    segment_flags = [int(pos > sep_ind) for pos in range(n_positions)]
    token_type_ids = torch.tensor([segment_flags], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids for `input_ids` together with an all-zero baseline.

    Tensors are created on the notebook-level `device` and expanded to the
    shape of `input_ids`.

    Args:
        input_ids: tensor whose second dimension is the sequence length.

    Returns:
        Tuple ``(position_ids, ref_position_ids)``: positions ``0..seq_len-1`` and
        zeros, both of dtype long and shaped like `input_ids`.
    """
    n_positions = input_ids.size(1)

    running_index = torch.arange(n_positions, dtype=torch.long, device=device)
    all_zero_index = torch.zeros(n_positions, dtype=torch.long, device=device)

    # Add the batch axis, then broadcast to the input's shape.
    return (
        running_index.unsqueeze(0).expand_as(input_ids),
        all_zero_index.unsqueeze(0).expand_as(input_ids),
    )
    
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape, dtype and device as `input_ids`.

    Args:
        input_ids: tensor to mirror.

    Returns:
        Tensor filled with 1, i.e. every position is attended to.
    """
    mask = torch.full_like(input_ids, 1)
    return mask

In [None]:
def custom_forward(inputs, target_class=1):
    """Forward function for attribution: softmax probability of one class per input row.

    The result has one entry per row of `inputs` (shape ``(batch,)``). Returning a
    value for every row matters because the attribution method evaluates several
    interpolated inputs in a single forward call when `internal_batch_size` is
    larger than 1; selecting only row 0 would leave the remaining rows with zero
    gradient and distort the attributions. For a batch of one the result is the
    same 1-element tensor as indexing ``[0][target_class]`` and unsqueezing.

    Args:
        inputs: value forwarded to `predict` (a batch of token ids in the visible callers).
        target_class: index of the class whose probability is returned. Defaults to 1.

    Returns:
        1-D tensor holding the `target_class` probability for each row.
    """
    logits = predict(inputs)
    class_probs = torch.softmax(logits, dim=1)
    return class_probs[:, target_class]

In [None]:
# Set up layer integrated gradients for custom_forward with respect to the model's embedding layer.
# NOTE(review): `model.bert` assumes the checkpoint loaded through AutoModelForSequenceClassification
# is a BERT-style model exposing a `.bert` submodule — confirm.
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [None]:
# Pick the first row's text as the example to explain. The column name comes from
# `text_field_name` so it stays in sync with the configuration cell.
text = pdf_temporal_test[text_field_name].iloc[0]

In [None]:
# print(pdf_temporal_test[(pdf_temporal_test['target'] == pdf_temporal_test['pred'])][['txt_field', 'target', 'pred']])
# text = list(pdf_temporal_test[(pdf_temporal_test['target'] == pdf_temporal_test['pred'])]['txt_field'])[1]
# text

In [None]:
# Build the token ids for `text` and the matching baseline.
# `sep_id` receives the third return value of construct_input_ref_pair, i.e. the number of text tokens.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
# Token-type ids, position ids and the attention mask are built here, but none of them
# is passed to predict() or lig.attribute() in the visible cells.
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# Token strings for the input ids; used later to label the visualization.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [None]:
# Display the model output (first element of the model's return value) for the example input.
predict(input_ids)

In [None]:
# Display the value the attribution method will explain for the example input.
custom_forward(input_ids)

In [None]:
# Run layer integrated gradients from the baseline (ref_input_ids) to the real input,
# with 700 interpolation steps and an internal batch size of 3.
# `delta` is the convergence delta requested via return_convergence_delta=True.
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    n_steps=700,
                                    internal_batch_size=3,
                                    return_convergence_delta=True)

In [None]:
def summarize_attributions(attributions):
    """Collapse attributions to one L2-normalized score per token.

    Sums over the last dimension, drops the leading dimension when it has size 1,
    and divides by the L2 norm of the result.

    Args:
        attributions: floating-point tensor of attributions.

    Returns:
        Tensor of per-token scores with unit L2 norm. When every score is zero the
        zeros are returned unchanged, because dividing by a zero norm would turn
        the whole result into NaN.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(per_token)
    if norm == 0:
        # Nothing to normalize; avoid 0 / 0 -> NaN.
        return per_token
    return per_token / norm

In [None]:
# Reduce the attributions to one normalized score per token and display them.
attributions_sum = summarize_attributions(attributions)
attributions_sum

In [None]:
# Re-run the model on the example; `score` feeds the visualization record below.
score = predict(input_ids)

# Print the example text, then display the model output for it.
print('Sentence: ', text)
score

In [None]:
# Class probabilities for the single example in the batch (softmax over the class axis).
class_probs = torch.softmax(score, dim=1)[0]

# Positional arguments of VisualizationDataRecord, in order:
# word attributions, predicted probability, predicted class, true class,
# attribution class, attribution score, raw input tokens, convergence delta.
score_vis = viz.VisualizationDataRecord(attributions_sum,
                                        class_probs[1],  # probability of class 1, the class custom_forward explains by default
                                        # Predicted class = argmax over the class axis. A softmax over dim=0
                                        # (the batch axis, size 1 here) turns every entry into 1.0, so its
                                        # argmax would not reflect the model's prediction.
                                        torch.argmax(class_probs),
                                        1,  # NOTE(review): hard-coded true class — confirm against the row's actual label
                                        text,
                                        attributions_sum.sum(),
                                        all_tokens,
                                        delta)


In [None]:
# Render the attribution record as highlighted text.
print('Visualization For Score')
viz.visualize_text([score_vis])

In [None]:
# Display the raw example text that was explained.
text