In [None]:
# Install/upgrade Captum, the attribution library used by the cells below.
# TODO: pin a version (e.g. captum==X.Y.Z) so re-runs are reproducible.
%pip install captum --upgrade

In [None]:
import pickle
import io
from models import *

# Imported explicitly: CPU_Unpickler calls torch.load, and relying on
# `from models import *` to re-export torch is fragile.
import torch

# No throwaway Model() is constructed here; the model object is restored
# entirely from the pickle file.


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that restores pickled torch storages onto the CPU.

    Lookups of ``torch.storage._load_from_bytes`` are intercepted and replaced
    with a loader that calls ``torch.load(..., map_location='cpu')``. This lets
    a model pickled in a GPU session be loaded on a CPU-only machine. Every
    other global is resolved by the standard ``pickle.Unpickler.find_class``.

    SECURITY: like any unpickler, this can execute arbitrary code embedded in
    the file. Only use it on trusted input.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, redirecting torch storage loading to CPU."""
        if (module, name) == ('torch.storage', '_load_from_bytes'):
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)

# Restore the trained model from disk; CPU_Unpickler remaps tensor storages to CPU.
# SECURITY: unpickling can run arbitrary code -- only load a model.pkl you trust.
with open('model.pkl', 'rb') as model_file:
    model = CPU_Unpickler(model_file).load()

# Put the model in evaluation mode before computing attributions.
model.eval()

In [60]:
def model_output(*inputs):
    """Forward function handed to Captum.

    Runs the global ``model`` on ``inputs``, keeps element 0 of the result and
    re-adds a leading dimension of size 1. Raw outputs are returned; no
    sigmoid is applied.

    NOTE(review): element 0 is presumably the first (only) sample of the batch.
    That is why attribution runs with ``internal_batch_size=1``; with larger
    batches every sample except the first would be dropped. Confirm the model's
    output layout before changing the batch size.
    """
    first_row = model(*inputs)[0]
    return first_row.unsqueeze(0)

# Layer whose output is attributed: the base model's embedding module.
model_input = model.base_model.embedding

In [61]:
from captum.attr import LayerIntegratedGradients

# Integrated gradients attributed to the embedding layer's output, using
# model_output as the forward function.
lig = LayerIntegratedGradients(forward_func=model_output, layer=model_input)

In [62]:
from params import *
import torch

def construct_input_and_baseline(text, max_length=510, tokenizer=None):
    """Tokenize ``text`` into a fixed-length model input and an all-padding baseline.

    Both id sequences have length ``max_length + 2`` and are laid out as::

        input:    [CLS] text_ids... [PAD]... [SEP]
        baseline: [CLS] [PAD]............... [SEP]

    Text longer than ``max_length`` tokens is truncated. Shorter text is padded
    with the pad token before the closing [SEP].

    NOTE(review): padding sits *before* [SEP], unlike the usual
    ``[CLS] text [SEP] [PAD]...`` layout. Confirm this matches how the model
    was trained.

    Args:
        text: raw string to encode.
        max_length: number of token slots between [CLS] and [SEP]. The default
            of 510 gives a total length of 512.
        tokenizer: tokenizer to use. Defaults to the global ``TOKENIZER``.

    Returns:
        ``(input_ids, baseline_input_ids, tokens)``:

        - ``input_ids`` and ``baseline_input_ids`` are ``(1, max_length + 2)``
          CPU tensors.
        - ``tokens`` holds the token strings for [CLS] + (truncated) text +
          [SEP], used to label attributions. It never has more entries than
          ``input_ids`` has positions.

        NOTE(review): when the text is shorter than ``max_length``, the
        trailing "[SEP]" label lines up with the first [PAD] position; the
        real [SEP] is the last position.
    """
    if tokenizer is None:
        tokenizer = TOKENIZER
    pad_id = tokenizer.pad_token_id
    cls_id = tokenizer.cls_token_id
    sep_id = tokenizer.sep_token_id

    # Special tokens are added by hand below, so encode the bare text only.
    text_ids = tokenizer.encode(text, truncation=True, add_special_tokens=False)
    # Truncate *before* building display tokens, so there are never more labels
    # than attribution positions (Captum's text visualizer relies on this).
    text_ids = text_ids[:max_length]
    tokens = tokenizer.convert_ids_to_tokens([cls_id] + text_ids + [sep_id])

    padding = [pad_id] * (max_length - len(text_ids))
    input_ids = [cls_id] + text_ids + padding + [sep_id]
    baseline_input_ids = [cls_id] + [pad_id] * max_length + [sep_id]

    input_ids_tensor = torch.tensor([input_ids], device='cpu')
    baseline_input_ids_tensor = torch.tensor([baseline_input_ids], device='cpu')
    return input_ids_tensor, baseline_input_ids_tensor, tokens

In [None]:
# Encode a sample sentence and show the model input next to its all-padding baseline.
text = 'Your input text here.'
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

for caption, ids in (('original text', input_ids), ('baseline text', baseline_input_ids)):
    print(f'{caption}: {ids}')

In [64]:
# Output index to attribute for the sample run. In explain()'s label mapping,
# index 5 is 'identity_hate'.
target_class_index=5

In [None]:
# Attribute the chosen output index to the embedding layer for the sample input.
# internal_batch_size=1 keeps every forward call at batch size 1, which
# model_output relies on.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_input_ids,
    target=target_class_index,
    internal_batch_size=1,
    return_convergence_delta=True,
)
print(attributions.size())

In [None]:
def summarize_attributions(attributions):
    """Collapse attributions to one L2-normalised score per token.

    Sums over the last dimension (presumably the embedding/hidden size), drops
    a leading size-1 batch dimension and divides by the L2 norm.

    An all-zero result is returned unchanged instead of becoming NaN through
    0/0. This happens, for example, when the input equals the baseline, as
    with empty text.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(per_token)
    if norm == 0:
        return per_token
    return per_token / norm

In [None]:
from captum.attr import visualization as viz

def explain(text, **labels_and_values):
    """Render Captum word-attribution visualizations of ``text`` for several labels.

    Each keyword names a label (a key of ``label_to_index_mapping``); its value
    is shown as that record's true class. Layer integrated-gradient
    attributions are computed once per label against that label's output
    index, and all records are rendered together with ``viz.visualize_text``.

    Args:
        text: raw input text to explain.
        **labels_and_values: ``label=true_value`` pairs, e.g. ``toxic=1``.

    Raises:
        KeyError: if a keyword is not one of the known labels.
    """
    label_to_index_mapping = {'toxic': 0, 'severe_toxic': 1, 'obscene': 2, 'threat': 3, 'insult': 4, 'identity_hate': 5}

    # Tokenization does not depend on the label, so build the input once.
    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

    # A single forward pass provides the score for every label. no_grad skips
    # building an autograd graph that is never used.
    with torch.no_grad():
        scores = model(input_ids)[0]

    visualizations = []
    for label, value in labels_and_values.items():
        target_class_index = label_to_index_mapping[label]
        attributions, delta = lig.attribute(inputs=input_ids,
                                            target=target_class_index,
                                            baselines=baseline_input_ids,
                                            return_convergence_delta=True,
                                            internal_batch_size=1)
        attributions_sum = summarize_attributions(attributions)

        pred_prob = torch.sigmoid(scores[target_class_index])
        visualizations.append(viz.VisualizationDataRecord(
            word_attributions=attributions_sum,
            pred_prob=pred_prob,
            pred_class=pred_prob.round(),
            true_class=value,
            # The attributed label; previously the raw input text was passed here.
            attr_class=label,
            attr_score=attributions_sum.sum(),
            raw_input_ids=all_tokens,
            convergence_score=delta,
        ))

    viz.visualize_text(visualizations)

In [None]:
# Explain the sample text for all six labels. Each keyword value is passed to
# explain() and displayed as that label's true class.
# NOTE(review): 1..6 look like placeholders rather than real ground-truth
# labels; replace them with the actual targets.
explain("try it here", toxic=1, severe_toxic=2, obscene=3, threat=4, insult=5, identity_hate=6)