# Intro

This notebook generates the intermediate files useful for the interpretability analysis.

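In particular, for each part of the context it saves the most relevant tokens (those whose absolute attribution score is more than 2 standard deviations above the mean computed across the entire context) together with their count.
To facilitate aggregate statistics, it also saves the average score of each context part, after rescaling the absolute scores to the range [0, 1] (division by the maximum value found in the context).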

In [2]:
import sys
sys.path.append("../src")
from xai import AttributionOutput, AttributionUnit, RowAttribution

import pickle
import torch
import json
from tqdm import tqdm

import numpy as np
import scipy.stats
import pandas as pd

## Instructions

In the next cell, there are some variables you can set to generate the intermediate json files for a given explanation target (the part of the output the scores are relevant to) and language.

In [106]:
MODEL = "Qwen2.5-72B-Instruct"
LANGUAGE = "es"

# Explanation target, i.e., the part of the output the attribution scores refer to.
# Values used in this notebook: "translation_label" or "translation".
# Only one assignment is kept: assigning both back to back silently discarded the first.
ATTR_TARGET = "translation"

# TODO: Set the input directory where the raw pickle files containing attribution scores
# are located. Set also the id of the set we are studying.
INPUT_DIR = "../../results-interim-gente-xai/attributions/attnlrp/set-g/"
SET = "set-g"

# Shall we exclude the source tokens from the mean computation? Yes, if we are attributing
# the output translation.
EXCLUDE_SRC_SCORES = True

# Shall we consider the translation label as part of the input context and hence
# compute an attribution score to each of its tokens? Yes, if we are attributing
# the output translation.
ADD_TRANSLATION_LABEL = True

In [107]:
# Raw attribution scores for the selected set / language / model.
# NOTE: pickle.load executes arbitrary code from the file; only load files you trust.
input_file = f"{INPUT_DIR}/{SET}_{LANGUAGE}_xai_{MODEL}_prompt_v1-4shot.json.attr.pkl"
with open(input_file, "rb") as f:
    data = pickle.load(f)

# The "_wtl" suffix marks outputs where the translation label is treated as context.
output_file = (
    f"{INPUT_DIR}/processed-{ATTR_TARGET}_wtl-{MODEL}-{LANGUAGE}.json"
    if ADD_TRANSLATION_LABEL
    else f"{INPUT_DIR}/processed-{ATTR_TARGET}-{MODEL}-{LANGUAGE}.json"
)

# Table with the source sentences and model outputs, indexed by row ID.
translation_file = f"../../results-interim-gente-xai/attributions/data/{SET}_{LANGUAGE}_xai_{MODEL}_prompt_v1-4shot.json.tsv"
translation_df = pd.read_csv(translation_file, sep="\t", index_col="ID")
print(translation_df.columns)

Index(['Europarl_ID', 'SET', 'SRC', 'REF-G', 'REF-N', 'COMMON', 'GENDER',
       'REF-G_ann', 'G-WORDS', 'translation_label', 'translation',
       'raw_output', 'set_label',
       'neutrality_label_Qwen/Qwen2.5-72B-Instruct'],
      dtype='object')


In [108]:
# Show the context parts shared by all rows, then the first row, each followed
# by a separator line.
separator = "#" * 10
for shared_part in (data.guidelines, data.system_prompt, data.demonstrations):
    print(shared_part)
    print(separator)
print(data.rows[0])

AttributionUnit(tokens=['Guid', 'elines', ' for', ' Gender', '-', 'Neutral', ' Translation', ':\n', '   ', ' -', ' Use', ' neutral', ' synonyms', ' (', 'e', '.g', '.', ' coleg', 'as', ')\n', '   ', ' -', ' Use', ' neutral', ' collective', ' nouns', ' (', 'e', '.g', '.', ' el', ' profes', 'orado', ',', ' el', ' minister', 'io', ')\n', '   ', ' -', ' Use', ' neutral', ' re', 'ph', 'ras', 'ings', ' like', ' "', 'las', ' personas', ' que', '"', ' or', ' "', 'qu', 'ien', '/qu', 'ienes', '"\n', '   ', ' -', ' Avoid', ' masculine', ' forms', ' for', ' generic', ' refer', 'ents', ' (', 'ending', ' in', ' o', '/os', ')\n', '   ', ' -', ' Avoid', ' @', ' or', ' x', '\n', '   ', ' -', ' Avoid', ' double', ' feminine', '/m', 'as', 'cul', 'ine', ' forms', ' (', 'e', '.g', '.', ' amigo', '/a', ')\n', '    \n    \n'], span=(102, 201), metadata=None)
##########
AttributionUnit(tokens=['You', ' are', ' a', ' helpful', ' Spanish', ' translator', ' specialized', ' in', ' gender', '-neutral', ' language',

In [109]:
# Worked example: how to read attribution scores and spans for a single row.
iid = 28
example_row = data.rows[iid]

# "attributions" is the tensor we slice to pick the token spans and their contributions.
# Its shape is the number of attributed tokens: it holds a score for every input token,
# including the special tokens of the chat template.
# Note: when checking the attributions of "translation", the last positions also hold
# the scores of the output tokens generated before the translation itself, for example
# <de> **GENDERED**
example_scores = example_row.translation.metadata["attributions"]
print(example_scores.shape)

# full_prompt is the entire input context; it is slightly shorter than the scores above.
example_prompt = example_row.full_prompt
print(example_prompt.span)
print(len(example_prompt.tokens))

# The translation label span is expressed over the output tokens, i.e., it is the span
# in the generation covering all the tokens of the translation label.
print(example_row.translation_label.span)


torch.Size([674])
(0, 666)
666
(3, 8)


In [110]:
def slice_array(x, l_idx, r_idx):
    """Return the elements of ``x`` from ``l_idx`` (inclusive) to ``r_idx`` (exclusive)."""
    window = slice(l_idx, r_idx)
    return x[window]

def find_max(tensors: list):
    """Return the largest element found across all the given tensors.

    Yields ``float("-inf")`` when the list is empty.
    """
    # Seeding the candidates with -inf keeps the result defined for an empty list.
    candidates = [float("-inf")]
    candidates.extend(t.max().item() for t in tensors)
    return max(candidates)

def find_mean_std(tensors: list):
    """Concatenate the tensors and return ``(mean, std)`` of all their elements.

    Both values are returned as tensors, as produced by ``.mean()`` and ``.std()``.
    """
    pooled = torch.cat(tensors)
    pooled_mean = pooled.mean()
    pooled_std = pooled.std()
    return pooled_mean, pooled_std

def find_relevant_tokens(scores: torch.Tensor, tokens: list[str], tr: float) -> list[tuple[int, str]]:
    """Return the tokens whose score is strictly greater than the threshold.

    Args:
        scores: one score per token, aligned with ``tokens``.
        tokens: the tokens the scores refer to.
        tr: the threshold a score must exceed for its token to be kept.

    Returns:
        A list of ``(position, token)`` pairs, in order of appearance, where
        ``position`` is the index of the token within ``tokens``.

    Note:
        ``scores`` and ``tokens`` are paired with ``zip``, so if their lengths
        differ the extra elements of the longer one are ignored.
    """
    return [
        (position, token)
        for position, (score, token) in enumerate(zip(scores, tokens))
        if score > tr
    ]

# Number of highest-scoring tokens stored for every row.
TOP_N = 20
# A token is "relevant" when its absolute score exceeds mean + STD_FACTOR * std.
STD_FACTOR = 2


def _context_parts(row_attr):
    """List the context parts to analyse for one row, in context order.

    Returns a list of ``(key, tag, tokens, span)`` tuples, where ``key`` is the
    name used in the output json, ``tag`` is the short label stored in
    ``top_parts``, ``tokens`` are the tokens of the part and ``span`` is the
    ``(start, end)`` pair used to slice the attribution tensor.
    """
    parts = [
        ("system_prompt", "SYS", data.system_prompt.tokens, data.system_prompt.span),
        ("preamble", "PRE", data.preamble.tokens, data.preamble.span),
        ("guidelines", "GUI", data.guidelines.tokens, data.guidelines.span),
    ]
    for i, (u, a) in enumerate(data.demonstrations):
        parts.append((f"shot_{i}_u", f"USR_{i}", u.tokens, u.span))
        parts.append((f"shot_{i}_a", f"ASS_{i}", a.tokens, a.span))

    if not EXCLUDE_SRC_SCORES:
        parts.append(("source", "SRC", row_attr.source.tokens, row_attr.source.span))

    if ADD_TRANSLATION_LABEL:
        # The translation label span is expressed over the output tokens, so it is
        # shifted by the end of the full prompt to index the attribution tensor.
        prompt_end = row_attr.full_prompt.span[1]
        label = row_attr.translation_label
        label_span = (prompt_end + label.span[0], prompt_end + label.span[1])
        parts.append(("translation_label", "TL", label.tokens, label_span))

    return parts


full_json = dict()

for row_attr in tqdm(data.rows):
    scores = getattr(row_attr, ATTR_TARGET).metadata["attributions"]

    # Slice the attribution tensor once per context part, keeping tokens,
    # part tags and (signed) scores aligned position by position.
    row_dict = dict()
    tokens_by_key = dict()
    tokens_cat = list()
    part_cat = list()
    scores_cat = list()
    for key, tag, part_tokens, span in _context_parts(row_attr):
        part_scores = slice_array(scores, *span)
        row_dict[key] = part_scores
        tokens_by_key[key] = part_tokens
        tokens_cat.extend(part_tokens)
        part_cat.extend([tag] * len(part_tokens))
        scores_cat.append(part_scores)

    # row_dict at this point contains all the parts of the context
    # that are relevant for the max / normalization computation

    # 1. turn everything into abs values and compute the max, mean, and std
    row_dict = {k: v.abs() for k, v in row_dict.items()}
    abs_scores = list(row_dict.values())
    max_value = find_max(abs_scores)
    mean_value, std_value = find_mean_std(abs_scores)

    # 2. set the threshold for significance to the mean plus twice the std
    threshold = mean_value + STD_FACTOR * std_value

    # 3. collect, for each context part, the tokens above the threshold and their count.
    # The key order (guidelines before preamble) matches previously generated files.
    leading_keys = ("system_prompt", "guidelines", "preamble")
    ordered_keys = [*leading_keys, *(k for k in row_dict if k not in leading_keys)]
    relevant_tokens = {
        k: find_relevant_tokens(row_dict[k], tokens_by_key[k], threshold)
        for k in ordered_keys
    }
    relevant_counts = {k: len(rt) for k, rt in relevant_tokens.items()}

    # 4. rescale in [0,1] and compute the mean score.
    # If every score is zero, skip the rescaling: dividing by zero would produce
    # NaN, which json.dump writes as a token that is not valid JSON.
    scale = max_value if max_value > 0 else 1.0
    mean_dict = {k: (v / scale).mean().item() for k, v in row_dict.items()}

    # 5. get stats about the top N tokens
    # NOTE(review): the ranking uses the signed scores, while the threshold statistics
    # above use absolute values, so strongly negative tokens never appear here.
    # Confirm this is intended.
    scores_cat = torch.cat(scores_cat)
    top_indices = torch.argsort(scores_cat, descending=True)[:TOP_N]
    top_scores = scores_cat[top_indices].tolist()
    # Plain Python ints are used to index the Python lists.
    top_positions = top_indices.tolist()
    top_parts = [part_cat[i] for i in top_positions]
    top_tokens = [tokens_cat[i] for i in top_positions]

    row_info = translation_df.loc[row_attr.rid]
    full_json[row_attr.rid] = {
        "translation_label": row_info["translation_label"],
        "gold_neutrality_label": row_info.get("gold_neutrality_label", None),
        "SRC": row_info["SRC"],
        "relevant_counts": relevant_counts,
        "relevant_tokens": relevant_tokens,
        "max_value": max_value,
        "mean_value": mean_value.item(),
        "std_value": std_value.item(),
        "top_scores": top_scores,
        "top_parts": top_parts,
        "top_tokens": top_tokens,
        "mean_values": mean_dict,
    }

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 392/392 [00:01<00:00, 233.00it/s]


In [111]:
# Context-part keys computed for the last row processed by the loop above.
row_dict.keys()

dict_keys(['system_prompt', 'preamble', 'guidelines', 'shot_0_u', 'shot_0_a', 'shot_1_u', 'shot_1_a', 'shot_2_u', 'shot_2_a', 'shot_3_u', 'shot_3_a', 'translation_label'])

In [112]:
# Tabular view of the results: one row per row id, one column per saved field.
out_df = pd.DataFrame(full_json).transpose()

In [113]:
# Preview the first rows of the per-row summary table.
out_df.head()

Unnamed: 0,translation_label,gold_neutrality_label,SRC,relevant_counts,relevant_tokens,max_value,mean_value,std_value,top_scores,top_parts,top_tokens,mean_values
26,**NEUTRAL**,,May I also warn against granting migrant women...,"{'system_prompt': 0, 'guidelines': 1, 'preambl...","{'system_prompt': [], 'guidelines': [(55, 'ien...",0.25248,0.008324,0.017153,"[0.14895403385162354, 0.10928550362586975, 0.0...","[TL, GUI, USR_2, SYS, ASS_2, ASS_2, USR_2, USR...","[ **, ien, legitimate, Spanish, veces, .], ...","{'system_prompt': 0.04472806677222252, 'preamb..."
35,**NEUTRAL**,,This time I will give the Commissioner a chanc...,"{'system_prompt': 1, 'guidelines': 2, 'preambl...","{'system_prompt': [(4, ' Spanish')], 'guidelin...",0.163181,0.014325,0.01686,"[0.1631811410188675, 0.13878124952316284, 0.10...","[PRE, TL, SYS, ASS_2, TL, USR_2, USR_2, USR_2,...","[ Spanish, **, Spanish, veces, **, legitim...","{'system_prompt': 0.13752689957618713, 'preamb..."
47,**NEUTRAL**,,I would like to address our colleague and than...,"{'system_prompt': 1, 'guidelines': 1, 'preambl...","{'system_prompt': [(4, ' Spanish')], 'guidelin...",0.474094,0.014127,0.025933,"[0.10701923817396164, 0.08874690532684326, 0.0...","[PRE, GUI, SYS, ASS_2, ASS_2, GUI, ASS_1, ASS_...","[ Spanish, coleg, Spanish, .], veces, as, i...","{'system_prompt': 0.025611044839024544, 'pream..."
49,**NEUTRAL**,,Since the Commissioner doubts whether sufficie...,"{'system_prompt': 2, 'guidelines': 3, 'preambl...","{'system_prompt': [(4, ' Spanish'), (11, '.')]...",0.265968,0.022166,0.024026,"[0.2659681439399719, 0.19200143218040466, 0.13...","[TL, GUI, SYS, PRE, USR_2, USR_2, USR_2, TL, U...","[**, ien, Spanish, Spanish, visitors, legi...","{'system_prompt': 0.10917984694242477, 'preamb..."
82,**NEUTRAL**,,"Many of them have lost their jobs and, with th...","{'system_prompt': 2, 'guidelines': 2, 'preambl...","{'system_prompt': [(4, ' Spanish'), (8, ' gend...",0.18237,0.018524,0.020964,"[0.18237033486366272, 0.14360371232032776, 0.1...","[TL, PRE, ASS_3, SYS, GUI, ASS_2, ASS_3, GUI, ...","[ **, Spanish, las, Spanish, personas, ve...","{'system_prompt': 0.16190151870250702, 'preamb..."


In [114]:
# Persist the per-row results; ensure_ascii=False keeps non-ASCII tokens readable.
with open(output_file, mode="w", encoding="utf-8") as json_handle:
    json.dump(full_json, json_handle, ensure_ascii=False, indent=2)