In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
run_our_lrp = True
attn_eager = False
use_bf16 = True
if use_bf16:
    lrp_dtype = torch.bfloat16
    dtype_name = "bf16"
else:
    lrp_dtype = torch.float32
    dtype_name = "f32"
if run_our_lrp:
    lrp_name = "our"
else:
    lrp_name = "attn"
if attn_eager:
    attn_mode = "eager"
else:
    attn_mode = "sdpa"

In [3]:
from transformers import AutoTokenizer

if not run_our_lrp:
    from lxt.explicit.models.llama import LlamaForCausalLM, attnlrp
    from lxt.utils import pdf_heatmap, clean_tokens
    import transformers
else:
    from transformers import LlamaForCausalLM

import os
import sys
sys.path.append('../../src')
from lrp_engine import LRPEngine

path = 'meta-llama/Llama-3.2-1B-Instruct'

model = LlamaForCausalLM.from_pretrained(path, torch_dtype=lrp_dtype, device_map="cuda", attn_implementation=attn_mode)
tokenizer = AutoTokenizer.from_pretrained(path)

if not run_our_lrp:
    # apply AttnLRP rules
    attnlrp.register(model)

device = "cuda" if torch.cuda.is_available() else "cpu"


In [4]:
def evaluate_llama(model, tokenizer, input_ids, lrp=None):
    # get input embeddings so that we can compute gradients w.r.t. input embeddings
    input_embeds = model.get_input_embeddings()(input_ids)
    
    # inference and get the maximum logit at the last position
    output_logits = model(inputs_embeds=input_embeds.requires_grad_()).logits
    max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
    
    # get the top k tokens and their logits
    topk_logits, topk_indices = torch.topk(output_logits[0, -1, :], k=5, dim=-1)
    
    # convert token indices to strings
    topk_tokens = tokenizer.convert_ids_to_tokens(topk_indices)
    
    # initialize relevance scores with max_logits itself and backpropagate
    if lrp is not None:
        relevance = lrp.run(max_logits.unsqueeze(0))
        if len(relevance) > 2:
            print("error")
            return relevance, 1
        relevance = relevance[1][0][0].float().cpu()
    else:
        max_logits.backward(max_logits)
        relevance = input_embeds.grad.float().sum(-1).cpu()[0] # cast to float32 before summation for higher precision
    
    # normalize relevance between [-1, 1] for plotting
    relevance = relevance / relevance.abs().max()

    # print(relevance.topk(5))

    topk_indices = topk_indices.tolist()
    non_start_end_tok_relevance = relevance
    non_start_end_tok_relevance[0] = 0.0
    non_start_end_tok_relevance[-1] = 0.0
    topk_attr_inds = input_ids[0][non_start_end_tok_relevance.topk(5).indices].tolist()
    top_attr_ind = topk_attr_inds[0]
    union = len(set(topk_attr_inds).union(set(topk_indices)))
    intersect = len(set(topk_attr_inds).intersection(set(topk_indices)))

    # print(topk_attr_inds, topk_indices)
    # print(tokenizer.convert_ids_to_tokens(topk_attr_inds), tokenizer.convert_ids_to_tokens(topk_indices))

    top1_hit = top_attr_ind == topk_indices[0]

    return top1_hit, relevance#, union, intersect
    

In [5]:
from datasets import load_dataset

dataset = load_dataset("squad_v2")

In [6]:
if run_our_lrp:
    lrp = LRPEngine(use_attn_lrp=True, dtype=lrp_dtype)
else:
    lrp = None

In [7]:
ds = [ example for example in dataset["validation"] if example["answers"]["text"] ]

In [8]:
example = ds[0]
question = example["question"]
context = example["context"]
prompt = context + " " + question
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)

In [9]:
# Warmup iteration
info, _ = evaluate_llama(model, tokenizer, input_ids, lrp)

In [10]:
len(lrp.promise_bucket.start_nodes_to_promise)

162

In [11]:
from tqdm import tqdm

results = []
top1_label_hits = 0
top1_model_hits = 0
total_examples = 0
total_intersect = 0
total_union = 0

for example in tqdm(ds):
    question = example["question"]
    context = example["context"]
    answers = example["answers"]["text"]

    prompt = context + " " + question
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
    if input_ids.shape[-1] > 512:
        continue
    
    top1_hit, relevance = evaluate_llama(model, tokenizer, input_ids, lrp)
    # print(answers)
    if top1_hit:
        top1_model_hits += 1
    total_examples += 1

    results.append({
        "input_ids": input_ids,
        "attr": relevance.detach().cpu(),
    })
    
    if not (total_examples % 100):
        print(top1_model_hits, total_examples, top1_model_hits / total_examples)
        break


  0%|                                                                                         | 0/5928 [00:00<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
from lxt.utils import pdf_heatmap, clean_tokens
os.chdir("heatmaps")

In [13]:
for i, result in enumerate(results):
    relevance = result["attr"]
    relevance = relevance / relevance.abs().max()
    tokens = tokenizer.convert_ids_to_tokens(result["input_ids"][0])
    tokens = clean_tokens(tokens)
    pdf_heatmap(tokens, relevance, path=f'{lrp_name}lrp_{dtype_name}_{attn_mode}_heatmap_{i}_2.pdf', backend='xelatex')
    # pdf_heatmap(tokens, relevance, path=f'{lrp_name}lrp_single_{attn_mode}_test2.pdf', backend='xelatex')

PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.
PDF file generated successfully.


UnicodeEncodeError: 'charmap' codec can't encode character '\u0130' in position 938: character maps to <undefined>

In [None]:
# top1_model_hits / total_examples