In [None]:
import pandas as pd
import numpy as np
import torch
import os

from adapters import (
    AutoAdapterModel
)
from transformers import AutoTokenizer
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

In [None]:
# Input/output locations, relative to this notebook.
DATA_PATH = "../../data/processed"
MODEL_PATH = "../../models"

# Experiment grid: base checkpoint, data dimension, tuning regimes and tasks.
model_name = "bert-base-cased"
dimension = "merged"
tunings = ["finetuned", "fairtuned"]
tasks = [
    "propaganda",
    "webis",
    "pheme",
    "basil",
    "shadesoftruth",
    "fingerprints",
    "clickbait",
]

In [None]:
def _adapter_run_dir(model_folder_path):
    """Return the alphabetically first non-hidden sub-directory of `model_folder_path`.

    Sorting makes the choice reproducible (`os.listdir` order is arbitrary), and the
    filter skips stray files such as `.DS_Store` that are not adapter run folders.

    Raises:
        FileNotFoundError: if the folder contains no such sub-directory.
    """
    run_dirs = sorted(
        entry
        for entry in os.listdir(model_folder_path)
        if not entry.startswith(".") and os.path.isdir(os.path.join(model_folder_path, entry))
    )
    if not run_dirs:
        raise FileNotFoundError(f"No adapter run directory found in {model_folder_path!r}")
    return run_dirs[0]


# Every model below is built from the same `model_name` checkpoint, so one tokenizer serves all.
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset_path = os.path.join(DATA_PATH, dimension)

task_models_and_data = {}
for task in tasks:
    # The test split depends only on the task, so read it once and share it across tunings.
    test_df = pd.read_csv(os.path.join(dataset_path, f"{task}_test.csv")).dropna()

    models_and_data = {}
    for tuning in tunings:
        # Fine-tuned adapters are stored under "vanilla"; fair-tuned ones under `dimension`.
        variant = "vanilla" if tuning == "finetuned" else dimension
        model_folder_path = os.path.join(MODEL_PATH, model_name, variant, task)

        CONFIG = {
            "task_name": task,
            "model_name": model_name,
            "model_path": os.path.join(model_folder_path, _adapter_run_dir(model_folder_path), task),
            "max_length": 128,
        }

        model = AutoAdapterModel.from_pretrained(
            CONFIG["model_name"],
            output_attentions=True,
        )
        model.load_adapter(CONFIG["model_path"])
        model.set_active_adapters(task)
        model.set_active_embeddings("default")
        models_and_data[tuning] = {
            "model": model,
            "tokenizer": tokenizer,
            "dataset": test_df,
        }
    task_models_and_data[task] = models_and_data

In [5]:
# Special-token ids shared by every loaded model. All of them use the same `model_name`
# checkpoint, so any of their tokenizers gives the same ids; take one explicitly instead
# of relying on the `tokenizer` variable left over from a loop in an earlier cell.
_ids_tokenizer = task_models_and_data[tasks[0]][tunings[0]]["tokenizer"]
ref_token_id = _ids_tokenizer.pad_token_id  # padding id, used to build the reference (baseline) sequence
sep_token_id = _ids_tokenizer.sep_token_id  # appended after the text
cls_token_id = _ids_tokenizer.cls_token_id  # prepended before the text

In [6]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
def construct_input_ref_pair(text, tokenizer, ref_token_id, sep_token_id, cls_token_id, max_length=None):
    """Encode `text` as ``[CLS] text [SEP]`` plus a same-length reference sequence.

    The reference keeps [CLS] and [SEP] and replaces every text token with
    `ref_token_id`; it serves as the integrated-gradients baseline.

    Args:
        text: raw input string.
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``.
        ref_token_id, sep_token_id, cls_token_id: special token ids.
        max_length: optional cap on the total length *including* [CLS] and [SEP];
            text tokens beyond it are dropped. ``None`` (default) keeps the whole
            text, which may exceed the model's maximum sequence length.

    Returns:
        ``(input_ids, ref_input_ids, n_text_tokens)``: two ``(1, n_text_tokens + 2)``
        tensors on `device`, and the number of text tokens kept. Note the last value
        is not the [SEP] position (that is ``n_text_tokens + 1``).
    """
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    if max_length is not None:
        # Reserve two slots for [CLS] and [SEP].
        text_ids = text_ids[: max(max_length - 2, 0)]
    input_ids = [cls_token_id, *text_ids, sep_token_id]
    ref_input_ids = [cls_token_id, *([ref_token_id] * len(text_ids)), sep_token_id]
    return (
        torch.tensor([input_ids], device=device),
        torch.tensor([ref_input_ids], device=device),
        len(text_ids),
    )

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build segment (token-type) ids for `input_ids` and an all-zero reference.

    Positions ``0..sep_ind`` (inclusive) are segment 0; every later position is
    segment 1. Both returned tensors have shape ``(1, input_ids.size(1))`` and are
    placed on `device`.
    """
    width = input_ids.size(1)
    segments = [0 if pos <= sep_ind else 1 for pos in range(width)]
    token_type_ids = torch.tensor([segments], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Return ``(position_ids, ref_position_ids)`` broadcast to the shape of `input_ids`.

    Real positions count ``0, 1, ..., seq_len - 1``; the reference puts every token
    at position 0 (a random permutation via ``torch.randperm(seq_length, device=device)``
    would be another possible reference). Both are long tensors on `device`, returned
    as expanded views of 1-D rows.
    """
    seq_length = input_ids.size(1)

    def _broadcast(row):
        # (seq_len,) -> view with the same shape as input_ids
        return row.unsqueeze(0).expand_as(input_ids)

    position_ids = _broadcast(torch.arange(seq_length, dtype=torch.long, device=device))
    ref_position_ids = _broadcast(torch.zeros(seq_length, dtype=torch.long, device=device))
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape, dtype and device as `input_ids`.

    Every position is attended to, i.e. the sequence is assumed to carry no padding.
    """
    return input_ids.new_ones(input_ids.shape)

In [None]:
def save_act(module, inp, out):
    """Record `out` in the module-level global ``saved_act`` and return it unchanged.

    The ``(module, inp, out)`` signature matches a PyTorch forward hook. Returning the
    same object means that, if registered as one, it leaves the forward output as is.
    """
    global saved_act
    saved_act = out
    return saved_act

In [9]:
def summarize_attributions(attributions):
    """Collapse attributions to one score per token and L2-normalise them.

    Sums over the last dimension, drops a leading dimension of size 1, then divides
    by the L2 norm. Presumably the input is ``(1, seq_len, hidden)`` from
    LayerIntegratedGradients on the embedding layer; TODO confirm.

    When the summed attributions are all zero, they are returned as-is (zeros)
    instead of dividing 0/0, which would produce NaNs.
    """
    summed = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(summed)
    if norm == 0:
        return summed
    return summed / norm

In [None]:
# Folder holding the entity-swap lists; it ends with a slash so a file name can be appended directly.
entity_path = "../../" + "heterogeneity/lists_for_perturbations/"

In [11]:
# Load every entity-swap list as a two-column frame ("original", "swap").
# Files are read in sorted order: the combined table is later searched for the
# *first* matching original, and raw `os.listdir` order is arbitrary, which made
# that choice (and therefore the results) non-reproducible.
dfs = []
for file in sorted(os.listdir(entity_path)):
    # Skip hidden files (e.g. `.DS_Store`), which are not swap lists.
    if file.startswith("."):
        continue
    df_ent = pd.read_csv(os.path.join(entity_path, file))
    df_ent.columns = ["original", "swap"]
    dfs.append(df_ent)

In [12]:
# Stack all swap lists row-wise into one lookup table with a fresh 0..n-1 index.
target_entities = pd.concat(dfs, ignore_index=True)

In [None]:
# Integrated-gradients settings.
IG_N_STEPS = 10
IG_INTERNAL_BATCH_SIZE = 3

# First swap listed for each original entity. drop_duplicates(keep="first") reproduces
# the first-match choice of a boolean-mask `.iloc[0]` lookup, but as an O(1) dict.
_first_swaps = target_entities.drop_duplicates(subset="original", keep="first")
swap_for_original = dict(zip(_first_swaps["original"], _first_swaps["swap"]))


def _entity_target_attribution(text, entity, tokenizer, lig):
    """Attribute `lig`'s target to the tokens of `entity` inside `text`.

    Returns ``(score, n_tokens)``:
      - ``score`` is the absolute value of the summed, L2-normalised per-token
        attributions over tokens whose individually decoded string equals one of the
        whitespace-separated words of `entity` (0.0 when no token matches);
      - ``n_tokens`` is the encoded length including [CLS] and [SEP].
    """
    input_ids, ref_input_ids, _ = construct_input_ref_pair(
        text, tokenizer, ref_token_id, sep_token_id, cls_token_id
    )
    attributions, _delta = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        n_steps=IG_N_STEPS,
        internal_batch_size=IG_INTERNAL_BATCH_SIZE,
        return_convergence_delta=True,
    )
    attributions_sum = summarize_attributions(attributions)

    entity_words = set(entity.split())
    matched = [
        attributions_sum[idx]
        for idx, token_id in enumerate(input_ids[0])
        if tokenizer.decode(token_id) in entity_words
    ]
    return float(abs(sum(matched))), input_ids.size(1)


tot_final_target_avg_fine = []
tot_final_target_avg_fair = []
for task, models_and_data in task_models_and_data.items():
    for tuning, data in models_and_data.items():

        model = data["model"]
        tokenizer = data["tokenizer"]
        df = data["dataset"]

        model.to(device)
        model.eval()
        model.zero_grad()

        def predict(inputs):
            """Return the first element of the model output (presumably the head's logits)."""
            return model(inputs)[0]

        def custom_forward(inputs):
            """Softmax probability of class index 0: the scalar that IG attributes."""
            preds = predict(inputs)
            # [:, 0] -> negative attribution; use torch.softmax(preds, dim=1)[:, 1] for positive attribution.
            return torch.softmax(preds, dim=1)[:, 0]

        label2id = model.config.prediction_heads[task]["label2id"]
        id2label = {v: k for k, v in label2id.items()}

        lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

        sents_len = []
        all_target_att = []
        entity_to_sum = 0  # samples that name an entity, including ones skipped below
        for sample in df.to_dict("records"):
            original_entity = sample["entity"]
            if not original_entity:
                continue
            entity_to_sum += 1
            # Normalise "US" before the membership check, so the lookup below only
            # ever uses a key that is known to exist in the swap table.
            if original_entity == "US":
                original_entity = "America"
            if original_entity not in swap_for_original:
                continue
            entity = swap_for_original[original_entity]

            target_att, n_tokens = _entity_target_attribution(
                sample["perturbed_text"], entity, tokenizer, lig
            )
            sents_len.append(n_tokens)
            all_target_att.append(target_att)

        final_target_avg = np.mean(all_target_att)
        std = np.std(all_target_att)
        print(
            f"Task: {task} | Tuning: {tuning} | Final target attribution avg: {final_target_avg} "
            f"| Std: {std} | Total data samples: {entity_to_sum} | Scored samples: {len(all_target_att)}"
        )
        if tuning == "finetuned":
            tot_final_target_avg_fine.append(final_target_avg)
        else:
            tot_final_target_avg_fair.append(final_target_avg)

print(f"Final fine target attr: {np.mean(tot_final_target_avg_fine)} | Final std: {np.std(tot_final_target_avg_fine)}")
print(f"Final fair target attr: {np.mean(tot_final_target_avg_fair)} | Final std: {np.std(tot_final_target_avg_fair)}")

Task: propaganda | Tuning: finetuned | Final target attribution avg: 0.13045565961742342 | Std: 0.1498988096778749 | Total data samples: 87
Task: propaganda | Tuning: fairtuned | Final target attribution avg: 0.1740882033130756 | Std: 0.1890709638707351 | Total data samples: 87
Task: webis | Tuning: finetuned | Final target attribution avg: 0.10821403801119417 | Std: 0.13558178200534354 | Total data samples: 244
Task: webis | Tuning: fairtuned | Final target attribution avg: 0.10533005783055233 | Std: 0.14331763658236657 | Total data samples: 244
Task: pheme | Tuning: finetuned | Final target attribution avg: 0.13319829309470294 | Std: 0.22310857851974136 | Total data samples: 57
Task: pheme | Tuning: fairtuned | Final target attribution avg: 0.12209910530023847 | Std: 0.23202480415235183 | Total data samples: 57
Task: basil | Tuning: finetuned | Final target attribution avg: 0.109724760017304 | Std: 0.12706068889636934 | Total data samples: 430
Task: basil | Tuning: fairtuned | Final 