### WARNING

The two XAI methods we have implemented are rather hard to apply to the GTR-T5 architecture.

Both methods register hooks on specific submodules of the model to retrieve the attention maps (after the softmax) computed in the multi-head self-attention module.
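The hook mechanism can be sketched on a toy module. In BERT-style encoders (such as the XLM-R backbone of the E5 model used later), the post-softmax attention map is passed through a dropout submodule, which is presumably why the cells below hook `...attention.self.dropout`: the dropout's input is the attention map, and in eval mode dropout is the identity anyway. The toy classes are hypothetical stand-ins, not Hugging Face code, and resolving wildcard path expressions with `fnmatch` is our assumption about how `module_paths_to_hook` is interpreted. With recent `transformers` versions this may also require the eager (non-SDPA) attention implementation, since fused attention kernels need not call the dropout module.

```python
import fnmatch

import torch
from torch import nn


class ToySelfAttention(nn.Module):
    """Hypothetical stand-in for a BERT-style self-attention block."""

    def __init__(self, d: int):
        super().__init__()
        self.query = nn.Linear(d, d)
        self.key = nn.Linear(d, d)
        self.value = nn.Linear(d, d)
        self.dropout = nn.Dropout(0.1)  # its input is the post-softmax attention map

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.query(x) @ self.key(x).transpose(-1, -2) / x.shape[-1] ** 0.5
        probs = self.dropout(scores.softmax(dim=-1))
        return probs @ self.value(x)


class ToyModel(nn.Module):
    def __init__(self, d: int = 8, n_layers: int = 2):
        super().__init__()
        # Build submodule names of the form "encoder.layer.<i>.attention.self.dropout"
        self.encoder = nn.Module()
        self.encoder.layer = nn.ModuleList()
        for _ in range(n_layers):
            block = nn.Module()
            block.attention = nn.Module()
            block.attention.self = ToySelfAttention(d)
            self.encoder.layer.append(block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.encoder.layer:
            x = x + block.attention.self(x)
        return x


def hook_attention_maps(model: nn.Module, path_expressions: list[str], store: dict) -> list:
    """Register a forward hook on every submodule whose name matches a wildcard expression."""
    handles = []
    for name, module in model.named_modules():
        if any(fnmatch.fnmatchcase(name, expr) for expr in path_expressions):
            def hook(mod, inputs, output, name=name):
                store[name] = inputs[0].detach()  # dropout input = softmax output
            handles.append(module.register_forward_hook(hook))
    return handles


torch.manual_seed(0)
toy_model = ToyModel().eval()
toy_maps = {}
toy_handles = hook_attention_maps(toy_model, ["encoder.layer.*.attention.self.dropout"], toy_maps)
with torch.no_grad():
    toy_model(torch.randn(1, 5, 8))
for handle in toy_handles:
    handle.remove()  # remove the hooks once the maps are collected
```

Each captured map has shape `(batch, seq, seq)` and its rows sum to 1, as expected of a softmax output.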

However, the GTR-T5 implementation does not let us inspect these attention maps directly, so we need extra logic to retrieve them. Most likely, we will have to override some modules of the GTR-T5 architecture to do so.
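One way to override modules for this purpose is to swap each attention block for a subclass that routes its attention map through a hookable `nn.Identity`. As far as we can tell, Hugging Face's T5 attention applies softmax and dropout functionally rather than through submodules, which would explain why there is nothing to hook. The sketch below uses a hypothetical toy class (`FunctionalAttention`) with that property instead of the real T5 code; `HookableAttention` and `make_hookable` are likewise illustrative names.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FunctionalAttention(nn.Module):
    """Hypothetical toy attention that applies softmax and dropout functionally,
    so no submodule ever receives the attention map."""

    def __init__(self, d: int, p: float = 0.1):
        super().__init__()
        self.qkv = nn.Linear(d, 3 * d)
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        weights = F.softmax(q @ k.transpose(-1, -2), dim=-1)
        weights = F.dropout(weights, p=self.p, training=self.training)
        return weights @ v


class HookableAttention(FunctionalAttention):
    """Same maths, but the attention map passes through an nn.Identity probe."""

    def __init__(self, d: int, p: float = 0.1):
        super().__init__(d, p)
        self.probe = nn.Identity()  # parameter-free hook target

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        weights = self.probe(F.softmax(q @ k.transpose(-1, -2), dim=-1))
        weights = F.dropout(weights, p=self.p, training=self.training)
        return weights @ v


def make_hookable(parent: nn.Module, attr: str) -> HookableAttention:
    """Replace parent.<attr> with a HookableAttention carrying the same weights."""
    old = getattr(parent, attr)
    new = HookableAttention(old.qkv.in_features, old.p)
    new.load_state_dict(old.state_dict())  # the probe adds no parameters, so keys match
    new.train(old.training)
    setattr(parent, attr, new)
    return new


torch.manual_seed(0)
toy_model = nn.Sequential(FunctionalAttention(8)).eval()
toy_x = torch.randn(1, 4, 8)
with torch.no_grad():
    before = toy_model(toy_x)

toy_attention = make_hookable(toy_model, "0")
captured = {}
handle = toy_attention.probe.register_forward_hook(
    lambda module, inputs, output: captured.update(attn=output.detach())
)
with torch.no_grad():
    after = toy_model(toy_x)
handle.remove()
```

Because the override keeps the same parameters and maths, the model output is unchanged; only the extra probe becomes visible to hooks.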

For now, we illustrate both methods on a different Transformer architecture whose attention maps are easier to access.

*Keep in mind that we are now using an STS model that was not further fine-tuned on claim-post matching, so its performance is expected to be worse than that of our fine-tuned GTR-T5 model.*

In [1]:
import sys
sys.path.append("../src")

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [2]:
from functools import partial
import captum.attr as a
import torch
from captum.attr import visualization as viz

from dataset import OurDataset
from explain import SentenceTransformerToHF, STS_ExplainWrapper
from xai import GAE_Explain, ConservativeLRP, semantic_search_forward_function


In [3]:
dataset = OurDataset(csv_dirpath="./data", split="test")
claim, post = dataset[0]

### GTR-T5 model with GAE explainability method

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
model = SentenceTransformerToHF("../models/GTR-T5-FT").to(device).eval()
tokenizer = model.tokenizer

claim_enc = tokenizer(claim, return_tensors="pt").to(device)
post_enc = tokenizer(post, return_tensors="pt").to(device)

with torch.no_grad():
    claim_emb = model(**claim_enc)[0]
forward_function = partial(semantic_search_forward_function, embedding=claim_emb)


In [10]:
explain_class = GAE_Explain(
    module_path_expressions=None,  # None: read attentions via output_attentions instead of hooks
    apply_normalization=False,
    normalization_approach="min-max",
)
explain_class.prepare_model(model)
post_explanation, predictions = explain_class._explain_batch(
    model, tokenizer, post, forward_function=forward_function
)
explain_class.cleanup()

print(post_explanation.shape)

torch.Size([1, 265])
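For reference, GAE (Chefer et al., 2021) aggregates attention layer by layer: starting from `R = I`, each layer contributes `A_bar`, the head average of the positive part of gradient × attention, through the update `R = R + A_bar @ R`. The sketch below assumes that `GAE_Explain` follows this self-attention rule and that the explanation is the row of the first (`<s>`/CLS) token; neither is confirmed by the code shown here, and for a mean-pooled embedding model other row aggregations are plausible. In practice the attention maps would come from `output_attentions=True` (or dropout hooks) and the gradients from a backward pass on the similarity score; here both are toy tensors.

```python
import torch


def gae_relevance(attentions, gradients, token_index: int = 0) -> torch.Tensor:
    """Sketch of the GAE self-attention rule.

    attentions, gradients: one (batch, heads, seq, seq) tensor per layer,
    ordered from the first layer to the last.
    Returns the relevance row of `token_index`, shape (batch, seq).
    """
    batch, _, seq, _ = attentions[0].shape
    first = attentions[0]
    R = torch.eye(seq, dtype=first.dtype, device=first.device).repeat(batch, 1, 1)
    for A, G in zip(attentions, gradients):
        A_bar = (G * A).clamp(min=0).mean(dim=1)  # positive part, averaged over heads
        R = R + torch.bmm(A_bar, R)               # R <- R + A_bar @ R
    return R[:, token_index, :]


# Toy usage: random attention maps and a made-up scalar target
torch.manual_seed(0)
logits = [torch.randn(1, 2, 4, 4, requires_grad=True) for _ in range(2)]
attns = [z.softmax(dim=-1) for z in logits]
target = sum((a * torch.arange(4.0)).sum() for a in attns)
grads = torch.autograd.grad(target, attns)
relevance = gae_relevance([a.detach() for a in attns], grads)
print(relevance.shape)  # torch.Size([1, 4])
```

With all gradients set to zero the rule leaves `R` at the identity, so the returned row is a one-hot vector on `token_index`.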


In [9]:
post_explanation

tensor([[1.0000e+00, 1.2613e-04, 5.7248e-05, 1.4765e-05, 1.3159e-05, 1.1565e-04,
         6.0412e-06, 5.8283e-06, 3.1936e-05, 1.9109e-05, 7.3629e-06, 8.0102e-06,
         3.2303e-06, 2.5099e-06, 5.4276e-06, 5.6052e-06, 7.4079e-06, 3.3228e-06,
         2.3297e-05, 6.1909e-06, 1.2348e-05, 6.7844e-06, 2.8338e-06, 2.8651e-06,
         2.2784e-05, 8.9914e-06, 1.1358e-05, 7.6824e-06, 5.2267e-06, 5.7656e-06,
         5.9684e-06, 6.5961e-06, 6.6663e-06, 7.0466e-06, 3.3931e-06, 3.9102e-06,
         2.2554e-06, 1.8821e-06, 2.6749e-06, 4.7805e-06, 2.9755e-06, 7.9391e-06,
         4.4908e-06, 2.3143e-05, 1.1125e-05, 5.7555e-06, 2.4182e-06, 6.7310e-06,
         1.5927e-05, 5.7595e-06, 6.8159e-06, 3.9504e-06, 2.6456e-06, 2.9310e-06,
         2.6321e-06, 8.4157e-06, 3.7179e-06, 4.8046e-06, 3.0956e-05, 3.3579e-05,
         1.4366e-05, 4.4978e-06, 4.5215e-06, 5.1890e-06, 2.4793e-06, 5.5241e-06,
         8.4987e-06, 1.4557e-06, 6.8409e-07, 1.1352e-06, 1.4743e-06, 1.3147e-06,
         1.1462e-06, 7.1690e

### E5 model

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
model = SentenceTransformerToHF("intfloat/multilingual-e5-large").to(device).eval()
tokenizer = model.tokenizer

In [5]:
claim_enc = tokenizer(claim, return_tensors="pt").to(device)
post_enc = tokenizer(post, return_tensors="pt").to(device)
with torch.no_grad():
    claim_emb = model(**claim_enc)[0]
forward_function = partial(semantic_search_forward_function, embedding=claim_emb)
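`semantic_search_forward_function` is imported from `xai` and its body is not shown. Given that it is partially applied with the claim embedding, a plausible shape is a function that maps post embeddings to a similarity score against that fixed embedding. The sketch below assumes cosine similarity; `cosine_forward_function` is a hypothetical name, not the project's function, and the toy variables are prefixed so they do not overwrite the real `claim_emb` and `forward_function`.

```python
from functools import partial

import torch
import torch.nn.functional as F


def cosine_forward_function(post_emb: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor:
    """Hypothetical scorer: cosine similarity between each post embedding
    (batch, dim) and one fixed claim embedding (dim,). Returns shape (batch,)."""
    return F.cosine_similarity(post_emb, embedding.unsqueeze(0), dim=-1)


toy_claim_emb = torch.tensor([1.0, 0.0])  # toy stand-in for the real claim embedding
toy_forward = partial(cosine_forward_function, embedding=toy_claim_emb)
toy_scores = toy_forward(torch.tensor([[2.0, 0.0], [0.0, 3.0]]))
```

Binding the claim embedding with `partial` leaves a one-argument function of the post embeddings, which is the form an attribution method can differentiate.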

In [6]:
def visualize(
    post_explanation: torch.Tensor,
    predictions: torch.Tensor,
    post_enc: dict[str, torch.Tensor],
) -> None:
    # Captum's text visualizer is built for classification, so the class fields
    # and the convergence score below are fixed placeholders.
    # Relies on the global `tokenizer` defined above.
    visualization = viz.VisualizationDataRecord(
        word_attributions=post_explanation[0],
        pred_prob=predictions[0],
        pred_class=1,
        true_class=1,
        attr_class=1,
        attr_score=post_explanation.sum(),
        raw_input_ids=tokenizer.convert_ids_to_tokens(post_enc["input_ids"][0]),
        convergence_score=1,
    )
    viz.visualize_text([visualization])

### GAE (min-max normalization)

In [6]:
module_paths_to_hook = [
    "hf_transformer.encoder.layer.*.attention.self.dropout"
]
explain_class = GAE_Explain(
    module_paths_to_hook, apply_normalization=False, normalization_approach="min-max"
)
explain_class.prepare_model(model)
post_explanation, predictions = explain_class._explain_batch(
    model, tokenizer, post, forward_function=forward_function
)
explain_class.cleanup()


In [18]:
post_explanation

tensor([[0.0000e+00, 7.1155e-06, 5.9631e-06, 5.3429e-06, 5.8513e-06, 5.9327e-06,
         6.8035e-06, 6.4001e-06, 6.8093e-06, 6.6593e-06, 6.4982e-06, 8.0593e-06,
         5.0477e-06, 1.4904e-05, 1.1923e-05, 5.4962e-06, 9.1611e-06, 3.5473e-06,
         5.2752e-06, 5.9625e-06, 3.8489e-06, 4.1159e-06, 5.2670e-06, 4.1888e-06,
         6.1746e-06, 2.8636e-06, 2.6356e-06, 4.3011e-06, 5.4318e-06, 4.3307e-06,
         6.2363e-06, 2.9962e-06, 4.9295e-06, 2.1995e-06, 3.1132e-06, 3.6458e-06,
         5.4378e-06, 2.4725e-06, 4.0849e-06, 3.1929e-06, 3.6560e-06, 4.6275e-06,
         5.9486e-06, 5.2603e-06, 5.7327e-06, 4.1946e-06, 7.2795e-06, 4.2807e-06,
         3.7396e-06, 4.7363e-06, 5.3330e-06, 1.9722e-06, 2.3648e-06, 3.4652e-06,
         3.8025e-06, 4.0812e-06, 2.7034e-06, 4.0414e-06, 3.3803e-06, 5.1507e-06,
         3.3043e-06, 1.0579e-05, 3.1239e-06, 3.0423e-06, 2.2744e-06, 3.5977e-06,
         5.1379e-06, 6.3002e-06, 6.3288e-06, 5.6651e-06, 1.8999e-06, 3.4449e-06,
         2.8494e-06, 3.2891e

In [18]:
visualize(post_explanation, predictions, post_enc)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.88) | 1.0 | 13.61 | `<s> ▁You ▁still ▁drink ing ▁Das ani ? ▁Das ani ▁is ▁considered ▁the ▁worst ▁bottle d ▁water ▁in ▁the ▁market ▁because ▁it ▁contain s ▁po tas s ium ▁chlor ide . ▁It ' s ▁the ▁same ▁chemical ▁that ' s ▁given ▁to ▁death ▁ row ▁in mates ▁before ▁they ▁die . ▁Ex tende d ▁exposure ▁to ▁this ▁may ▁lead ▁to ▁we a ker ▁bone s ▁& ▁cardiac ▁arrest . ▁@ The ▁Fo o Commun ity ▁PRI FI ED ▁WA TER ▁AFF ORA PUR E , ▁FRE SH ▁T AST ▁ PUR IFI ED ▁WA TER ▁WE DS VE ▁16. 9 ▁FL ▁ OZ ▁C LOS P ▁500 ▁ml ▁plant ▁bottle ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁& ▁P URE , ▁F REN ▁ LOG ▁PT ) ▁500 ▁ml . ▁ ASAN IS ANI ▁DAS ANI ▁D ASAN ASAN T ▁plo n bo ▁O ▁plant ball ▁0 ▁CAR ES ▁CAR TIE ▁D ▁plant ▁bottle ▁A TER ▁16. 9 ▁FL ▁ OZ ▁( 106 ▁PT ) ▁500 ▁ml ▁N EFE S ▁S ▁18 464 ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁0 ▁CAL ORES ▁PER A TR ▁ PUR IFI ED ▁WA TER ▁FINAL S ▁FOR ▁A ▁P URE , ▁FRE SH ▁TA SSE ▁16. 9 ▁FL OZ ▁( LOG ▁PI ) ▁500 ▁ml . ▁plant ▁bottle s ▁KO PL OZ ▁( L 06 ▁PT ) ▁500 ▁ml </s>` |


In [24]:
visualize(post_explanation, predictions, post_enc)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.88) | 1.0 | 89.27 | `<s> ▁You ▁still ▁drink ing ▁Das ani ? ▁Das ani ▁is ▁considered ▁the ▁worst ▁bottle d ▁water ▁in ▁the ▁market ▁because ▁it ▁contain s ▁po tas s ium ▁chlor ide . ▁It ' s ▁the ▁same ▁chemical ▁that ' s ▁given ▁to ▁death ▁ row ▁in mates ▁before ▁they ▁die . ▁Ex tende d ▁exposure ▁to ▁this ▁may ▁lead ▁to ▁we a ker ▁bone s ▁& ▁cardiac ▁arrest . ▁@ The ▁Fo o Commun ity ▁PRI FI ED ▁WA TER ▁AFF ORA PUR E , ▁FRE SH ▁T AST ▁ PUR IFI ED ▁WA TER ▁WE DS VE ▁16. 9 ▁FL ▁ OZ ▁C LOS P ▁500 ▁ml ▁plant ▁bottle ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁& ▁P URE , ▁F REN ▁ LOG ▁PT ) ▁500 ▁ml . ▁ ASAN IS ANI ▁DAS ANI ▁D ASAN ASAN T ▁plo n bo ▁O ▁plant ball ▁0 ▁CAR ES ▁CAR TIE ▁D ▁plant ▁bottle ▁A TER ▁16. 9 ▁FL ▁ OZ ▁( 106 ▁PT ) ▁500 ▁ml ▁N EFE S ▁S ▁18 464 ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁0 ▁CAL ORES ▁PER A TR ▁ PUR IFI ED ▁WA TER ▁FINAL S ▁FOR ▁A ▁P URE , ▁FRE SH ▁TA SSE ▁16. 9 ▁FL OZ ▁( LOG ▁PI ) ▁500 ▁ml . ▁plant ▁bottle s ▁KO PL OZ ▁( L 06 ▁PT ) ▁500 ▁ml </s>` |


### GAE (L2 normalization)

In [25]:
module_paths_to_hook = [
    "hf_transformer.encoder.layer.*.attention.self.dropout"
]
explain_class = GAE_Explain(
    module_paths_to_hook, apply_normalization=True, normalization_approach="l2"
)
explain_class.prepare_model(model)
post_explanation, predictions = explain_class._explain_batch(
    model, tokenizer, post, forward_function=forward_function
)
explain_class.cleanup()
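`GAE_Explain` and `ConservativeLRP` both take `apply_normalization` and `normalization_approach`, but their internals are not shown here. Assuming the usual definitions, the two approaches differ as sketched below: min-max rescales each row into [0, 1], while L2 divides by the Euclidean norm and keeps the sign of every attribution. Note also that the min-max GAE cell passes `apply_normalization=False`, so whether it actually normalizes depends on how `GAE_Explain` treats that flag. `normalize_attributions` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def normalize_attributions(attr: torch.Tensor, approach: str, eps: float = 1e-12) -> torch.Tensor:
    """Normalize each row of a (batch, seq) attribution tensor."""
    if approach == "min-max":
        lo = attr.min(dim=-1, keepdim=True).values
        hi = attr.max(dim=-1, keepdim=True).values
        return (attr - lo) / (hi - lo).clamp_min(eps)  # each row mapped into [0, 1]
    if approach == "l2":
        return F.normalize(attr, p=2.0, dim=-1, eps=eps)  # unit L2 norm, signs kept
    raise ValueError(f"unknown normalization approach: {approach!r}")
```

For example, `[[1., 2., 3.]]` becomes `[[0., 0.5, 1.]]` under min-max, and `[[3., -4.]]` becomes `[[0.6, -0.8]]` under L2.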

In [26]:
visualize(post_explanation, predictions, post_enc)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.88) | 1.0 | 13.61 | `<s> ▁You ▁still ▁drink ing ▁Das ani ? ▁Das ani ▁is ▁considered ▁the ▁worst ▁bottle d ▁water ▁in ▁the ▁market ▁because ▁it ▁contain s ▁po tas s ium ▁chlor ide . ▁It ' s ▁the ▁same ▁chemical ▁that ' s ▁given ▁to ▁death ▁ row ▁in mates ▁before ▁they ▁die . ▁Ex tende d ▁exposure ▁to ▁this ▁may ▁lead ▁to ▁we a ker ▁bone s ▁& ▁cardiac ▁arrest . ▁@ The ▁Fo o Commun ity ▁PRI FI ED ▁WA TER ▁AFF ORA PUR E , ▁FRE SH ▁T AST ▁ PUR IFI ED ▁WA TER ▁WE DS VE ▁16. 9 ▁FL ▁ OZ ▁C LOS P ▁500 ▁ml ▁plant ▁bottle ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁& ▁P URE , ▁F REN ▁ LOG ▁PT ) ▁500 ▁ml . ▁ ASAN IS ANI ▁DAS ANI ▁D ASAN ASAN T ▁plo n bo ▁O ▁plant ball ▁0 ▁CAR ES ▁CAR TIE ▁D ▁plant ▁bottle ▁A TER ▁16. 9 ▁FL ▁ OZ ▁( 106 ▁PT ) ▁500 ▁ml ▁N EFE S ▁S ▁18 464 ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁0 ▁CAL ORES ▁PER A TR ▁ PUR IFI ED ▁WA TER ▁FINAL S ▁FOR ▁A ▁P URE , ▁FRE SH ▁TA SSE ▁16. 9 ▁FL OZ ▁( LOG ▁PI ) ▁500 ▁ml . ▁plant ▁bottle s ▁KO PL OZ ▁( L 06 ▁PT ) ▁500 ▁ml </s>` |


### ConservativeLRP (min-max normalization)

In [4]:
store_A_path_expressions = [
    "hf_transformer.embeddings"
]
attent_path_expressions = [
    "hf_transformer.encoder.layer.*.attention.self.dropout"
]
norm_layer_path_expressions = [
    "hf_transformer.embeddings.LayerNorm",
    "hf_transformer.encoder.layer.*.attention.output.LayerNorm",
    "hf_transformer.encoder.layer.*.output.LayerNorm",
]
lrp = ConservativeLRP(
    store_A_path_expressions,
    attent_path_expressions,
    norm_layer_path_expressions,
    apply_normalization=True,
    normalization_approach="min-max",
)
lrp.prepare_model(model)
post_explanation, predictions = lrp._explain_batch(
    model, tokenizer, post, forward_function=forward_function
)
lrp.cleanup()

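The three groups of path expressions suggest how `ConservativeLRP` works, although its code is not shown: the module matched by `store_A_path_expressions` stores the activations A (here the embeddings), the attention maps are treated as constants, and LayerNorm's standard deviation is treated as a constant. That is the detach-based conservative propagation proposed for Transformers by Ali et al. (2022), after which relevance is gradient × input on A. The sketch below reproduces the idea with forward hooks on a hypothetical one-layer toy encoder; `ToyEncoder` and `add_conservative_hooks` are illustrative names, and the LayerNorm recomputation assumes normalization over the last dimension only.

```python
import fnmatch

import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical one-layer, bias-free encoder used only for illustration."""

    def __init__(self, vocab_size: int = 50, d: int = 16):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, d)
        self.query = nn.Linear(d, d, bias=False)
        self.key = nn.Linear(d, d, bias=False)
        self.value = nn.Linear(d, d, bias=False)
        self.dropout = nn.Dropout(0.1)
        self.LayerNorm = nn.LayerNorm(d, elementwise_affine=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embeddings(input_ids)                           # (B, T, d)
        scores = self.query(x) @ self.key(x).transpose(-1, -2) / x.shape[-1] ** 0.5
        probs = self.dropout(scores.softmax(dim=-1))             # attention map
        h = self.LayerNorm(probs @ self.value(x) + x)            # residual + LayerNorm
        return h.mean(dim=1)                                     # mean pooling -> (B, d)


def add_conservative_hooks(model, store_paths, attention_paths, norm_paths, cache):
    def matches(name, exprs):
        return any(fnmatch.fnmatchcase(name, e) for e in exprs)

    def store_hook(module, inputs, output):
        output.retain_grad()  # keep .grad on this non-leaf activation
        cache["A"] = output

    def detach_attention_hook(module, inputs, output):
        return output.detach()  # attention map treated as a constant

    def detach_std_hook(module, inputs, output):
        # Recompute LayerNorm over the last dim with the std treated as a constant
        x = inputs[0]
        centered = x - x.mean(dim=-1, keepdim=True)
        std = (centered.pow(2).mean(dim=-1, keepdim=True) + module.eps).sqrt()
        y = centered / std.detach()
        if module.weight is not None:
            y = y * module.weight
        if module.bias is not None:
            y = y + module.bias
        return y

    handles = []
    for name, module in model.named_modules():
        if matches(name, store_paths):
            handles.append(module.register_forward_hook(store_hook))
        if matches(name, attention_paths):
            handles.append(module.register_forward_hook(detach_attention_hook))
        if matches(name, norm_paths):
            handles.append(module.register_forward_hook(detach_std_hook))
    return handles


torch.manual_seed(0)
toy_model = ToyEncoder().eval()
cache = {}
handles = add_conservative_hooks(toy_model, ["embeddings"], ["dropout"], ["LayerNorm"], cache)
input_ids = torch.randint(0, 50, (1, 6))
target = torch.randn(16)                        # toy stand-in for a claim embedding
score = (toy_model(input_ids) * target).sum()   # toy dot-product similarity
score.backward()
A = cache["A"]
relevance = (A * A.grad).sum(dim=-1)            # gradient x input per token, shape (1, 6)
for handle in handles:
    handle.remove()
```

Because every operation left in the toy's gradient path is linear and bias-free, the token relevances sum to the score up to float error; in real models biases absorb part of the relevance, so conservation is only approximate.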

In [22]:
visualize(post_explanation, predictions, post_enc)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.88) | 1.0 | 179.18 | `<s> ▁You ▁still ▁drink ing ▁Das ani ? ▁Das ani ▁is ▁considered ▁the ▁worst ▁bottle d ▁water ▁in ▁the ▁market ▁because ▁it ▁contain s ▁po tas s ium ▁chlor ide . ▁It ' s ▁the ▁same ▁chemical ▁that ' s ▁given ▁to ▁death ▁ row ▁in mates ▁before ▁they ▁die . ▁Ex tende d ▁exposure ▁to ▁this ▁may ▁lead ▁to ▁we a ker ▁bone s ▁& ▁cardiac ▁arrest . ▁@ The ▁Fo o Commun ity ▁PRI FI ED ▁WA TER ▁AFF ORA PUR E , ▁FRE SH ▁T AST ▁ PUR IFI ED ▁WA TER ▁WE DS VE ▁16. 9 ▁FL ▁ OZ ▁C LOS P ▁500 ▁ml ▁plant ▁bottle ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁& ▁P URE , ▁F REN ▁ LOG ▁PT ) ▁500 ▁ml . ▁ ASAN IS ANI ▁DAS ANI ▁D ASAN ASAN T ▁plo n bo ▁O ▁plant ball ▁0 ▁CAR ES ▁CAR TIE ▁D ▁plant ▁bottle ▁A TER ▁16. 9 ▁FL ▁ OZ ▁( 106 ▁PT ) ▁500 ▁ml ▁N EFE S ▁S ▁18 464 ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁0 ▁CAL ORES ▁PER A TR ▁ PUR IFI ED ▁WA TER ▁FINAL S ▁FOR ▁A ▁P URE , ▁FRE SH ▁TA SSE ▁16. 9 ▁FL OZ ▁( LOG ▁PI ) ▁500 ▁ml . ▁plant ▁bottle s ▁KO PL OZ ▁( L 06 ▁PT ) ▁500 ▁ml </s>` |


### ConservativeLRP (L2 normalization)

In [14]:
store_A_path_expressions = [
    "hf_transformer.embeddings"
]
attent_path_expressions = [
    "hf_transformer.encoder.layer.*.attention.self.dropout"
]
norm_layer_path_expressions = [
    "hf_transformer.embeddings.LayerNorm",
    "hf_transformer.encoder.layer.*.attention.output.LayerNorm",
    "hf_transformer.encoder.layer.*.output.LayerNorm",
]
lrp = ConservativeLRP(
    store_A_path_expressions,
    attent_path_expressions,
    norm_layer_path_expressions,
    apply_normalization=True,
    normalization_approach="l2",
)
lrp.prepare_model(model)
post_explanation, predictions = lrp._explain_batch(
    model, tokenizer, post, forward_function=forward_function
)
lrp.cleanup()


In [28]:
visualize(post_explanation, predictions, post_enc)

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 1.0 | 1 (0.88) | 1.0 | -7.13 | `<s> ▁You ▁still ▁drink ing ▁Das ani ? ▁Das ani ▁is ▁considered ▁the ▁worst ▁bottle d ▁water ▁in ▁the ▁market ▁because ▁it ▁contain s ▁po tas s ium ▁chlor ide . ▁It ' s ▁the ▁same ▁chemical ▁that ' s ▁given ▁to ▁death ▁ row ▁in mates ▁before ▁they ▁die . ▁Ex tende d ▁exposure ▁to ▁this ▁may ▁lead ▁to ▁we a ker ▁bone s ▁& ▁cardiac ▁arrest . ▁@ The ▁Fo o Commun ity ▁PRI FI ED ▁WA TER ▁AFF ORA PUR E , ▁FRE SH ▁T AST ▁ PUR IFI ED ▁WA TER ▁WE DS VE ▁16. 9 ▁FL ▁ OZ ▁C LOS P ▁500 ▁ml ▁plant ▁bottle ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁& ▁P URE , ▁F REN ▁ LOG ▁PT ) ▁500 ▁ml . ▁ ASAN IS ANI ▁DAS ANI ▁D ASAN ASAN T ▁plo n bo ▁O ▁plant ball ▁0 ▁CAR ES ▁CAR TIE ▁D ▁plant ▁bottle ▁A TER ▁16. 9 ▁FL ▁ OZ ▁( 106 ▁PT ) ▁500 ▁ml ▁N EFE S ▁S ▁18 464 ▁ PUR IFI ED ▁WA TER ▁E NH ANCE D ▁WI TH ▁MIN ERA LS ▁FOR ▁0 ▁CAL ORES ▁PER A TR ▁ PUR IFI ED ▁WA TER ▁FINAL S ▁FOR ▁A ▁P URE , ▁FRE SH ▁TA SSE ▁16. 9 ▁FL OZ ▁( LOG ▁PI ) ▁500 ▁ml . ▁plant ▁bottle s ▁KO PL OZ ▁( L 06 ▁PT ) ▁500 ▁ml </s>` |
