# MSc Thesis - SHAP Methods
- This notebook contains the top 3 SHAP methods I have tested. So far, Method 1 seems the most promising, followed by Method 2 and then Method 3.
- Method 1 is the newest method I have experimented with. I didn't mention it in Monday's meeting because I hadn't finished it yet.

In [1]:
# The three topics with the largest absolute gap between the WVS survey data
# and the LLM's moral score.
top3_topics = [
    'for a man to beat his wife',
    'someone accepting a bribe in the course of their duties',
    'terrorism as a political, ideological or religious mean',
]

# The three countries with the largest gap, listed separately for each prompting style.
top3_countries_in = ['Netherlands', 'New Zealand', 'Germany']  # "in" prompting style
top3_countries_people = ['Libya', 'Maldives', 'Armenia']  # "people" prompting style

In [2]:
!pip install --quiet shap

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m540.5/540.5 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np

In [4]:
import shap

In [26]:
# Pretrained GPT-2 checkpoint and its tokenizer, shared by all three methods below.
name = 'gpt2'
tokenizer, model = AutoTokenizer.from_pretrained(name), AutoModelForCausalLM.from_pretrained(name)

## Method 1: Language Modeling

- Aim: get text generation explanations for moral tokens
- I followed this [example](https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/language_modelling/Language%20Modeling%20Explanation%20Demo.html) and modified the Top-K model from [here](https://github.com/shap/shap/blob/master/shap/models/_topk_lm.py)

In [6]:
from shap.models import Model
import scipy.special

In [27]:
class CustomModel(Model):
    """Generates scores (log odds) for the specified tokens for Causal/Masked LM.

    Adapted from shap's Top-K LM wrapper. Instead of the top-k next tokens, it scores a
    fixed list of target tokens. For every input text, the logits at the last position
    are turned into one-vs-all log odds. One output column is produced per entry of
    ``token_ids``.

    NOTE(review): when an entry of ``token_ids`` holds several ids (a word split into
    several BPE pieces), their next-token log odds at the SAME position are summed. That
    is not the log odds of the whole word: only the first piece is a genuine next-token
    prediction. Prefer single-token targets, or score multi-token words another way
    (e.g. teacher forcing).
    """

    def __init__(self, model, tokenizer, token_ids, batch_size=128, device=None):
        """
        Parameters:
            model: Causal language model. It is moved in place to ``device``.
            tokenizer: Tokenizer matching ``model``. If it has no pad token, its EOS
                token is used as the pad token (this mutates the tokenizer).
            token_ids (list): One sequence of token ids per output column.
            batch_size (int): Number of texts passed through the model at once.
            device: Torch device or device name. Defaults to CUDA when available, else CPU.
        """
        super().__init__(model)
        self.tokenizer = tokenizer
        # Batched, padded inputs need a pad token; only add one if the tokenizer lacks it.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.token_ids = [np.array(ids) for ids in token_ids]
        self.batch_size = batch_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
        self.inner_model = model.to(self.device)
        self.X = None
        self.output_names = None

    def __call__(self, texts):
        """Return a (len(texts), len(token_ids)) array of log odds, one row per text."""
        masked_X = np.array(list(texts))
        # Copy so the cache does not keep the whole batch alive.
        self.update_cache_X(masked_X[:1].copy())
        # Collect per-batch results and concatenate once (avoids quadratic re-copying).
        batches = [
            self.get_logodds(self.get_lm_logits(masked_X[start:start + self.batch_size]))
            for start in range(0, len(masked_X), self.batch_size)
        ]
        if not batches:
            return np.empty((0, len(self.token_ids)))
        return np.concatenate(batches)

    def update_cache_X(self, X):
        """Cache ``X`` and recompute ``output_names`` whenever ``X`` changes."""
        if (self.X is None) or (not np.array_equal(self.X, X)):
            self.X = X
            self.output_names = self.get_output_names(X)

    def get_output_names(self, X):
        """Decode each ``token_ids`` entry into a display name (``X`` is unused)."""
        output_names = [self.tokenizer.decode(ids) for ids in self.token_ids]
        return output_names

    def get_logodds(self, logits):
        """Convert (batch, vocab) logits into log odds summed per ``token_ids`` entry.

        Works in log space: a log-softmax, then logit(p) = log p - log(1 - p).
        An explicit ``np.exp(logits)`` would overflow to nan for large logits.
        """
        log_probs = scipy.special.log_softmax(logits, axis=-1)
        # p == 1 gives log1p(-1) = -inf, i.e. +inf log odds, matching scipy.special.logit(1).
        with np.errstate(divide='ignore'):
            logodds = log_probs - np.log1p(-np.exp(log_probs))
        logodds_for_token_ids = [np.take(logodds, ids, axis=-1).sum(axis=-1) for ids in self.token_ids]
        return np.stack(logodds_for_token_ids, axis=-1)

    def get_inputs(self, X, padding_side='right'):
        """Tokenize the strings in array ``X`` into padded PyTorch tensors.

        ``padding_side`` applies to this call only. The tokenizer's previous padding
        side is restored afterwards, even if tokenization raises.
        """
        previous_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            return self.tokenizer(X.tolist(), return_tensors='pt', padding=True)
        finally:
            self.tokenizer.padding_side = previous_side

    def get_lm_logits(self, X):
        """Return float64 next-token logits (batch, vocab) at the last position of each text."""
        inputs = self.get_inputs(X, padding_side="left")
        # With left padding, position ids must start at 0 on each text's first real token.
        inputs["position_ids"] = (inputs["attention_mask"].long().cumsum(-1) - 1)
        inputs["position_ids"].masked_fill_(inputs["attention_mask"] == 0, 0)
        inputs = inputs.to(self.device)
        with torch.no_grad():
            outputs = self.inner_model(**inputs, return_dict=True)
        # Left padding puts every text's final real token at index -1. Slice before the
        # float64 cast so the full (batch, seq, vocab) tensor is never copied.
        logits = outputs.logits[:, -1, :].detach().cpu().to(torch.float64).numpy()
        return logits

In [8]:
# GPT-2's byte-level BPE folds the preceding space into a word's token. The word that
# follows "... mean is" is therefore " moral" / " immoral" (with a leading space);
# encoding "moral" without the space scores a different, unlikely next token.
moral_id = tokenizer.encode(" moral", add_special_tokens=False)
immoral_id = tokenizer.encode(" immoral", add_special_tokens=False)

In [9]:
# Wrap GPT-2 so SHAP sees one log-odds output per target word, and mask words with "...".
wrapped_model = CustomModel(model=model, tokenizer=tokenizer, token_ids=[moral_id, immoral_id])
masker_1 = shap.maskers.Text(tokenizer=tokenizer, mask_token="...", collapse_mask_token=True)

In [10]:
# Method 1 prompts: the terrorism topic for each "in"-style country.
test1 = [
    f'People in {country} believe terrorism as a political, ideological or religious mean is'
    for country in ('the Netherlands', 'New Zealand', 'Germany')
]

In [11]:
# shap picks the explainer algorithm from the model/masker pair.
explainer = shap.Explainer(model=wrapped_model, masker=masker_1)

In [12]:
# SHAP values for every prompt in test1, one output per target word of wrapped_model.
shap_values_1 = explainer(test1)

PartitionExplainer explainer: 4it [00:12,  4.17s/it]


In [13]:
# Token-level attributions for Method 1.
shap.plots.text(shap_values_1)

## Method 2: Bias Prompting
- I followed this [example](https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/text_generation/Open%20Ended%20GPT2%20Text%20Generation%20Explanations.html). In particular, look at the last example.

In [14]:
def compute_shap_values(x, y, model, tokenizer):
    """
    Explain a fixed target continuation for each prompt using SHAP teacher forcing.

    Parameters:
        x (list): Input prompts to explain.
        y (str or list): Target continuation. A single string is reused for every
            prompt; a list must contain exactly one target per prompt.
        model: Causal language model wrapped by shap.models.TeacherForcing.
        tokenizer: Tokenizer matching ``model``; also used to build the text masker.

    Returns:
        shap_values: The explanation object returned by the SHAP explainer.

    Raises:
        ValueError: if ``y`` is a list whose length differs from ``x``.
    """
    teacher_forcing_model = shap.models.TeacherForcing(model, tokenizer)
    masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
    explainer_2 = shap.Explainer(teacher_forcing_model, masker)

    if isinstance(y, str):
        y_data = generate_y_prompts(y, len(x))
    else:
        y_data = list(y)
        if len(y_data) != len(x):
            raise ValueError(f"Expected {len(x)} targets (one per prompt), got {len(y_data)}.")

    shap_values = explainer_2(x, y_data)

    return shap_values

def generate_y_prompts(token, count):
    """Return a list with ``token`` repeated ``count`` times (one target per prompt)."""
    return [token] * count

In [15]:
# Method 2 prompts (same wording as Method 1).
test2 = [
    'People in ' + place + ' believe terrorism as a political, ideological or religious mean is'
    for place in ['the Netherlands', 'New Zealand', 'Germany']
]

In [16]:
# Teacher-forced explanations of the "moral" continuation, then the "immoral" one.
shap_values_2_moral, shap_values_2_immoral = (
    compute_shap_values(test2, target, model, tokenizer) for target in ("moral", "immoral")
)

  0%|          | 0/210 [00:00<?, ?it/s]

PartitionExplainer explainer: 4it [00:10, 10.03s/it]               


In [17]:
# Token-level attributions for the "moral" target (Method 2).
shap.plots.text(shap_values_2_moral)

In [18]:
# Token-level attributions for the "immoral" target (Method 2).
shap.plots.text(shap_values_2_immoral)

## Method 3: Perplexity

In [33]:
def perplexity_metric(lines, model=model, tokenizer=tokenizer, end_with_period=True):
    """
    Compute the model's perplexity on each text, scoring one text at a time.

    Also used directly as the model function of a SHAP explainer, which passes arrays
    of (masked) strings.

    Parameters:
        lines (iterable of str): Texts to score.
        model: Causal language model. Defaults to the notebook's ``model`` as it was
            when this cell ran.
        tokenizer: Tokenizer matching ``model``. Defaults likewise.
        end_with_period (bool): Append '.' to texts that do not already end with one.

    Returns:
        list[float]: exp(loss) per text, in input order.

    Raises:
        ValueError: if the tokenizer has neither an EOS nor a SEP token.
    """
    eos_token = tokenizer.eos_token or tokenizer.sep_token
    if eos_token is None:
        raise ValueError("Neither eos_token nor sep_token is set in the tokenizer.")

    # Inputs must live on the model's device. The shared model may have been moved to a
    # GPU with an in-place `.to(device)` call.
    device = next(model.parameters()).device

    perplexities = []

    for line in lines:
        if end_with_period and not line.endswith('.'):
            line += '.'

        # Prefix the EOS token so the first real word is also predicted and scored.
        # A single sequence never needs padding, so the tokenizer is left untouched.
        encoded = tokenizer(eos_token + line, return_tensors='pt', add_special_tokens=False)
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
            ppl = torch.exp(outputs.loss).item()

        perplexities.append(ppl)

    return perplexities

In [20]:
# Method 3 prompts: the Method 1 prompt completed with each verdict ("moral", then "immoral").
test3_moral, test3_immoral = (
    [f'People in {place} believe terrorism as a political, ideological or religious mean is {verdict}'
     for place in ('the Netherlands', 'New Zealand', 'Germany')]
    for verdict in ('moral', 'immoral')
)

In [36]:
# Perplexity of each "... is moral" sentence.
test3_moral_perplexity = perplexity_metric(lines=test3_moral)
print(test3_moral_perplexity)

[97.76705932617188, 99.75059509277344, 112.29773712158203]


In [35]:
# Perplexity of each "... is immoral" sentence.
test3_immoral_perplexity = perplexity_metric(lines=test3_immoral)
print(test3_immoral_perplexity)

[86.59269714355469, 88.64825439453125, 99.73043060302734]


In [37]:
# Explain perplexity directly; the tokenizer is given as the masker.
explainer_3 = shap.Explainer(perplexity_metric, masker=tokenizer)

In [38]:
# Perplexity attributions for the "moral" sentences, then the "immoral" ones.
shap_values_3_moral, shap_values_3_immoral = (
    explainer_3(sentences) for sentences in (test3_moral, test3_immoral)
)

  0%|          | 0/240 [00:00<?, ?it/s]

PartitionExplainer explainer:  33%|███▎      | 1/3 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

PartitionExplainer explainer: 100%|██████████| 3/3 [01:05<00:00, 16.16s/it]

  0%|          | 0/210 [00:00<?, ?it/s]

PartitionExplainer explainer: 4it [01:34, 31.57s/it]


  0%|          | 0/240 [00:00<?, ?it/s]

PartitionExplainer explainer:  33%|███▎      | 1/3 [00:00<?, ?it/s]

  0%|          | 0/240 [00:00<?, ?it/s]

PartitionExplainer explainer: 100%|██████████| 3/3 [01:14<00:00, 15.16s/it]

  0%|          | 0/210 [00:00<?, ?it/s]

PartitionExplainer explainer: 4it [01:36, 32.30s/it]


In [39]:
# Token-level perplexity attributions for the "moral" sentences (Method 3).
shap.plots.text(shap_values_3_moral)

In [40]:
# Token-level perplexity attributions for the "immoral" sentences (Method 3).
shap.plots.text(shap_values_3_immoral)