# Responsible AI: XAI GenAI project

## 0. Background



Previous lessons on explainability introduced post-hoc methods for explaining a model, such as saliency maps, SmoothGrad, LRP, LIME, and SHAP. Take LRP (Layer-wise Relevance Propagation) as an example: it highlights the pixels most relevant to the prediction of the class "cat" by backpropagating relevance. (Image source: [Montavon et al. (2016)](https://giorgiomorales.github.io/Layer-wise-Relevance-Propagation-in-Pytorch/))

![LRP example](images/catLRP.jpg)

Another example is text sentiment classification. Here we visualize the importance of each word for the prediction 'positive':

![text example](images/textGradL2.png)

Words highlighted in darker colours are more important for predicting that the sentence is 'positive' in sentiment.
More examples can be found [here](http://34.160.227.66/?models=sst2-tiny&dataset=sst_dev&hidden_modules=Explanations_Attention&layout=default).

Both cases above require the class or the prediction of the model. But:

***How do you explain a model that does not predict but generates?***

In this project, we explain a generative model based on the dependency between words. We first look at a simple example that uses Point-wise Mutual Information (PMI) to compute a saliency map of a sentence. After that, we construct the experiment step by step, followed by exercises and questions.


## 1. A simple example to start with
Given a sample sentence: 
> *Tokyo is the capital city of Japan.* 

We explain this sentence with a saliency map of the dependencies between its words.
The dependency between two words in the sentence can be measured by [Point-wise mutual information (PMI)](https://en.wikipedia.org/wiki/Pointwise_mutual_information), which we estimate as follows.


Mask two words out, e.g. 
> \[MASK-1\] is the capital city of \[MASK-2\].
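
The masking step can be sketched as follows. This is a minimal illustration, not the project's `generate_prompts` implementation: the helper name `mask_pairs` and the single `[MASK]` token (instead of numbered masks) are assumptions made for the example.

```python
def mask_pairs(sentence, anchor_idx, mask_token="[MASK]"):
    """Return one masked prompt per non-anchor word (hypothetical helper).

    In each prompt, the anchor word and exactly one other word are masked.
    """
    words = sentence.split()
    if not 0 <= anchor_idx < len(words):
        raise IndexError("anchor_idx is outside the sentence")
    prompts = []
    for other_idx in range(len(words)):
        if other_idx == anchor_idx:
            continue  # the anchor is masked in every prompt already
        masked = list(words)
        masked[anchor_idx] = mask_token
        masked[other_idx] = mask_token
        prompts.append(" ".join(masked))
    return prompts


for prompt in mask_pairs("Tokyo is the capital city of Japan", anchor_idx=0):
    print(prompt)
```

With `anchor_idx=0`, the last prompt printed corresponds to the example above, where *Tokyo* and *Japan* are masked together.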


Ask the generative model to fill in the sentence 10 times, and we have:

| MASK-1      | MASK-2 |
| ----------- | ----------- |
|    tokyo   |     japan   |
|  paris  |     france    |
|  london  |     england    |
|  paris  |     france    |
|  beijing |  china |
|    tokyo   |     japan   |
|  paris  |     france    |
|  paris  |     france    |
|  london  |     england    |
|  beijing |  china |

PMI is calculated by: 

$PMI(x,y)=\log_2 \frac{P(\{x,y\} \mid s-\{x,y\})}{P(\{x\} \mid s-\{x,y\})\,P(\{y\} \mid s-\{x,y\})}$

where $x$ and $y$ are the words that we masked out, $s$ is the sentence, and $s-\{x,y\}$ is the set of sentence tokens that remain after removing $x$ and $y$. Each probability is estimated as a relative frequency over the generated responses.

In this example, *tokyo* and *japan* each appear in 2 of the 10 responses, and always together, so $PMI(Tokyo, Japan) = \log_2 \frac{0.2}{0.2 \times 0.2} \approx 2.32$
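
The arithmetic can be reproduced from the table of 10 responses. The sketch below estimates each probability as a relative frequency; the function name `pmi_from_samples` is made up for this illustration, and the masked pair in the table is (tokyo, japan).

```python
import math

# The ten (MASK-1, MASK-2) fill-ins from the table above
pairs = [
    ("tokyo", "japan"), ("paris", "france"), ("london", "england"),
    ("paris", "france"), ("beijing", "china"), ("tokyo", "japan"),
    ("paris", "france"), ("paris", "france"), ("london", "england"),
    ("beijing", "china"),
]


def pmi_from_samples(pairs, x, y):
    """Estimate PMI(x, y) in bits from sampled (first, second) fill-ins."""
    n = len(pairs)
    p_x = sum(a == x for a, _ in pairs) / n
    p_y = sum(b == y for _, b in pairs) / n
    p_xy = sum(a == x and b == y for a, b in pairs) / n
    if p_xy == 0:
        return float("-inf")  # the pair was never generated together
    return math.log2(p_xy / (p_x * p_y))


print(round(pmi_from_samples(pairs, "tokyo", "japan"), 2))  # 2.32
```

A pair that never occurs together (for example tokyo with france) gets `-inf` here, which is a convention chosen for this sketch.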

Select a word of interest (the anchor word) in the sentence; we can now compute the PMI between it and every other word using the generative model.
(Here, we use a longer sentence and 20 responses per word.)
![](images/resPMI.png)


## 2. Preparation
### 2.1 Conda environment

```
conda env create -f environment.yml
conda activate xai_llm
```


### 2.2 Download the offline LLM

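We use an offline LLM from Hugging Face; the file is approximately 5 GB.
Download it with the command below and save it under `./models/`. Note that `--local-dir .` writes to the current directory, so run the command from inside `./models/` (or move the file there afterwards).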
```
huggingface-cli download TheBloke/openchat-3.5-0106-GGUF openchat-3.5-0106.Q4_K_M.gguf --local-dir . --local-dir-use-symlinks False
# credit to https://huggingface.co/TheBloke/openchat-3.5-0106-GGUF
```
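
Before loading the model, it can help to confirm that the file is where the notebook expects it. The sketch below assumes the file keeps its original name and sits directly under `./models/`; the helper `find_model` is hypothetical and not part of the project code.

```python
from pathlib import Path

MODEL_DIR = Path("models")                      # location named in the text above
MODEL_FILE = "openchat-3.5-0106.Q4_K_M.gguf"    # file name from the download command


def find_model(model_dir=MODEL_DIR, filename=MODEL_FILE):
    """Return the path to the model file, or None if it is missing."""
    path = Path(model_dir) / filename
    return path if path.is_file() else None


path = find_model()
if path is None:
    print(f"Model not found under {MODEL_DIR}/; re-run the download command.")
else:
    print(f"Found {path} ({path.stat().st_size / 1e9:.1f} GB)")
```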

## 3. Mask the sentence and get the responses from LLM
### 3.1 Get the input sentence

**Remember to change the anchor word index when changing the input sentence.**

In [None]:
# Removed for consistency, so that every run uses the same sentence
# def get_input():
    # ideally this reads inputs from a file, now it just takes an input
    #return input("Enter a sentence: ")
    
# Input sentence
sentence = "doctors assess symptoms to diagnose diseases"

anchor_word_idx = 0  # index of the anchor word (the word of interest)
prompts_per_word = 20  # number of generated responses per masked prompt

#sentence = get_input()
print("Sentence: ", sentence)

Sentence:  doctors assess symptoms to diagnose diseases
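
One way to avoid a stale anchor index after changing the sentence is to look the index up by word. This is only a sketch: `find_anchor_idx` is a hypothetical helper, and it returns the first occurrence if the word appears more than once.

```python
def find_anchor_idx(sentence, anchor_word):
    """Return the position of anchor_word in the whitespace-split sentence."""
    words = sentence.lower().split()
    try:
        return words.index(anchor_word.lower())
    except ValueError:
        raise ValueError(f"{anchor_word!r} is not a word in the sentence") from None


sentence = "doctors assess symptoms to diagnose diseases"
anchor_word_idx = find_anchor_idx(sentence, "doctors")
print(anchor_word_idx)  # 0
```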


### 3.2 Load the model

In [59]:
from models.ChatModel import ChatModel
model_name = "openchat"
model = ChatModel(model_name)
print(f"Model: {model_name}")

Model: openchat


### 3.3 Run the prompts and get all the responses


In [60]:
from tools.command_generator import generate_prompts, prefix_prompt
from tools.evaluate_response import get_replacements
from tqdm import tqdm

def run_prompts(model, sentence, anchor_idx, prompts_per_word=20):
    prompts = generate_prompts(sentence, anchor_idx)
    all_replacements = []
    for prompt in prompts:
        replacements = []
        for _ in tqdm(
            range(prompts_per_word),
            desc=f"Input: {prompt}",
        ):
            response = model.get_response(
                prefix_prompt(prompt),
            ).strip()
            if response:
                replacement = get_replacements(prompt, response)
                if replacement:
                    replacements.append(replacement)
        # Keep one entry per prompt, even if it is empty, so that list
        # positions stay aligned with word positions in compute_pmi
        all_replacements.append(replacements)
    return all_replacements

all_responses = run_prompts(model, sentence, anchor_word_idx, prompts_per_word)

Input: [MASK] [MASK] symptoms to diagnose diseases:   5%|▌         | 1/20 [00:14<04:35, 14.52s/it]

 Response is not valid. ['[mask]', '[mask]', 'symptoms', 'to', 'diagnose', 'diseases'] ['', 'doctors', 'use', 'various', 'tests']


Input: [MASK] [MASK] symptoms to diagnose diseases:  70%|███████   | 14/20 [00:26<00:06,  1.02s/it]

 Response is not valid. ['[mask]', '[mask]', 'symptoms', 'to', 'diagnose', 'diseases'] ['fever', 'and', 'headache', 'symptoms', 'help', 'diagnose', 'diseases']


Input: [MASK] [MASK] symptoms to diagnose diseases:  80%|████████  | 16/20 [00:28<00:03,  1.02it/s]

 Response is not valid. ['[mask]', '[mask]', 'symptoms', 'to', 'diagnose', 'diseases'] ['mri', 'scans', 'show', 'clear', '[mask]', '[mask]', 'for', 'accurate', 'disease', 'diagnosis']


Input: [MASK] [MASK] symptoms to diagnose diseases: 100%|██████████| 20/20 [00:30<00:00,  1.53s/it]
Input: [MASK] assess [MASK] to diagnose diseases:   5%|▌         | 1/20 [00:00<00:17,  1.08it/s]

 Response is not valid. ['[mask]', 'assess', '[mask]', 'to', 'diagnose', 'diseases'] ['doctors', 'evaluate', 'patients', 'to', 'diagnose', 'diseases']


Input: [MASK] assess [MASK] to diagnose diseases:  15%|█▌        | 3/20 [00:02<00:14,  1.20it/s]

 Response is not valid. ['[mask]', 'assess', '[mask]', 'to', 'diagnose', 'diseases'] ['the', 'doctor', 'assesses', 'patients', 'to', 'diagnose', 'diseases']


Input: [MASK] assess [MASK] to diagnose diseases:  25%|██▌       | 5/20 [00:04<00:13,  1.12it/s]

 Response is not valid. ['[mask]', 'assess', '[mask]', 'to', 'diagnose', 'diseases'] ['the', 'doctor', 'uses', 'a', 'stethoscope', 'to', 'assess', 'lung', 'sounds', 'and', 'diagnoses', 'diseases']


Input: [MASK] assess [MASK] to diagnose diseases:  35%|███▌      | 7/20 [00:05<00:11,  1.18it/s]

 Response is not valid. ['[mask]', 'assess', '[mask]', 'to', 'diagnose', 'diseases'] ['doctors', 'examine', '[patients]', 'to', 'diagnose', 'diseases']


Input: [MASK] assess [MASK] to diagnose diseases:  40%|████      | 8/20 [00:06<00:09,  1.23it/s]

 Response is not valid. ['[mask]', 'assess', '[mask]', 'to', 'diagnose', 'diseases'] ['the', 'doctor', 'assesses', 'the', 'patient', 'to', 'diagnose', 'diseases']


Input: [MASK] assess [MASK] to diagnose diseases:  50%|█████     | 10/20 [00:08<00:07,  1.28it/s]

 Response is not valid. ['[mask]', 'assess', '[mask]', 'to', 'diagnose', 'diseases'] ['the', 'doctor', 'assesses', 'patient', 'symptoms', 'to', 'diagnose', 'diseases']


Input: [MASK] assess [MASK] to diagnose diseases:  60%|██████    | 12/20 [00:09<00:05,  1.39it/s]

 Response is not valid. ['[mask]', 'assess', '[mask]', 'to', 'diagnose', 'diseases'] ['the', 'doctor', 'assesses', 'patients', 'to', 'diagnose', 'diseases']


Input: [MASK] assess [MASK] to diagnose diseases:  70%|███████   | 14/20 [00:11<00:04,  1.21it/s]


IndexError: pop from empty list

In [61]:
# Inspect the responses for the first masked prompt
all_responses[:1]

[[['', ''],
  ['flu', '[mask]'],
  ['', ''],
  ['', ''],
  ['', ''],
  ['flulike', '[mask] [mask]'],
  ['', ''],
  ['common', '[mask]'],
  ['patients', 'frequently exhibit classic [mask]'],
  ['flulike', '[mask]'],
  ['mumps', 'mumps'],
  ['flu', '[mask]'],
  ['fever', 'and chills'],
  ['flulike', '[mask]'],
  ['the', 'flu'],
  ['', ''],
  ['flulike', '[mask]'],
  ['', ''],
  ['flulike', ''],
  ['', '']]]
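
The raw pairs above contain empty strings, leftover `[mask]` tokens, and multi-word fill-ins. A minimal filter for well-formed pairs could look like the sketch below. The helper `is_clean_pair` is hypothetical, and dropping pairs is a design choice: `compute_pmi` as written keeps every pair in the denominator, so filtering first changes the estimated probabilities.

```python
def is_clean_pair(pair, mask_token="[mask]"):
    """True if the pair holds two non-empty single words without a leftover mask."""
    if len(pair) != 2:
        return False
    return all(
        bool(w) and mask_token not in w and len(w.split()) == 1
        for w in pair
    )


# A few pairs copied from the output above
raw = [["", ""], ["flu", "[mask]"], ["mumps", "mumps"],
       ["fever", "and chills"], ["the", "flu"]]
clean = [p for p in raw if is_clean_pair(p)]
print(clean)  # [['mumps', 'mumps'], ['the', 'flu']]
```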

In [62]:
# Load responses
import json
input_file = "responses.json"
with open(input_file, "r") as f:
    all_responses = json.load(f)
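
Because generation is slow, it is convenient to store the responses once and reload them later. The notebook does not show how `responses.json` was written, so the sketch below is an assumption about that step; `save_responses` and `load_responses` are hypothetical helpers.

```python
import json
from pathlib import Path


def save_responses(all_responses, path="responses.json"):
    """Write the nested list of replacement pairs to a JSON file."""
    Path(path).write_text(json.dumps(all_responses, indent=2), encoding="utf-8")


def load_responses(path="responses.json"):
    """Read the nested list of replacement pairs back from a JSON file."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


# Usage (after a successful run of run_prompts):
# save_responses(all_responses)
# all_responses = load_responses()
```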

### 3.4 EXERCISE: compute the PMI for each word

$PMI(x,y)=\log_2 \frac{P(\{x,y\} \mid s-\{x,y\})}{P(\{x\} \mid s-\{x,y\})\,P(\{y\} \mid s-\{x,y\})}$

* First compute $P(x)$, $P(y)$ and $P(x,y)$ and print them.
* Compute the PMI for each word.
* Visualize the result with colours. Tip: you might need to normalize the result first.


In [None]:
import math
import numpy as np
from termcolor import colored

def compute_pmi(sentence, all_responses, anchor_idx):
    """Compute PMI between anchor word and each other word."""
    words = sentence.lower().split()
    anchor_word = words[anchor_idx]
    pmi_scores = {}
    
    for other_idx in range(len(words)):
        if other_idx == anchor_idx:
            continue
        
        # Get pattern index (skips anchor position)
        pattern_idx = other_idx if other_idx < anchor_idx else other_idx - 1
        if pattern_idx >= len(all_responses):
            continue
            
        responses = all_responses[pattern_idx]
        if not responses:
            continue
        
        # Extract (anchor, other) replacement pairs; ignore malformed entries
        pairs = [(r[0].lower(), r[1].lower()) for r in responses if len(r) == 2]
        total = len(pairs)
        if total == 0:
            continue  # avoid dividing by zero when no pair is usable
        other_word = words[other_idx]

        # Estimate probabilities as relative frequencies
        count_x = sum(a == anchor_word for a, _ in pairs)
        count_y = sum(b == other_word for _, b in pairs)
        count_xy = sum(a == anchor_word and b == other_word for a, b in pairs)

        P_x = count_x / total
        P_y = count_y / total
        P_xy = count_xy / total
        
        # Calculate PMI
        if P_x > 0 and P_y > 0 and P_xy > 0:
            pmi = math.log2(P_xy / (P_x * P_y))
        else:
            pmi = float('-inf')
        
        pmi_scores[other_idx] = {'word': words[other_idx], 'pmi': pmi, 
                                  'P_x': P_x, 'P_y': P_y, 'P_xy': P_xy}
    
    return pmi_scores

def visualize_pmi(sentence, pmi_scores, anchor_idx):
    """Visualize PMI with colored words."""
    words = sentence.split()
    
    # Normalize PMI values
    valid_pmis = [s['pmi'] for s in pmi_scores.values() if s['pmi'] != float('-inf')]
    if not valid_pmis:
        print("No valid PMI scores")
        return
    
    min_pmi, max_pmi = min(valid_pmis), max(valid_pmis)
    pmi_range = max_pmi - min_pmi if max_pmi != min_pmi else 1
    
    # Color each word
    colored_words = []
    for i, word in enumerate(words):
        if i == anchor_idx:
            colored_words.append(colored(word, 'cyan', attrs=['bold']))
        elif i in pmi_scores:
            pmi = pmi_scores[i]['pmi']
            if pmi != float('-inf'):
                norm = (pmi - min_pmi) / pmi_range
                color = 'green' if norm > 0.66 else 'yellow' if norm > 0.33 else 'red'
                colored_words.append(colored(f"{word}({pmi:.2f})", color))
            else:
                colored_words.append(word)
        else:
            colored_words.append(word)
    
    print("\n" + " ".join(colored_words) + "\n")

In [64]:
# Compute PMI scores
pmi_scores = compute_pmi(sentence, all_responses, anchor_word_idx)

# Print results
words = sentence.lower().split()
print(f"Anchor word: '{words[anchor_word_idx]}'\n")
for idx in sorted(pmi_scores.keys()):
    data = pmi_scores[idx]
    print(f"{data['word']:<15} PMI={data['pmi']:7.3f}  "
          f"P(x)={data['P_x']:.3f} P(y)={data['P_y']:.3f} P(xy)={data['P_xy']:.3f}")

# Visualize with colors
visualize_pmi(sentence, pmi_scores, anchor_word_idx)

Anchor word: 'doctors'

assess          PMI=   -inf  P(x)=0.000 P(y)=0.000 P(xy)=0.000
symptoms        PMI=  1.737  P(x)=0.300 P(y)=0.200 P(xy)=0.200
to              PMI=  1.152  P(x)=0.100 P(y)=0.450 P(xy)=0.100
diagnose        PMI=  0.322  P(x)=0.100 P(y)=0.800 P(xy)=0.100
diseases        PMI=  0.737  P(x)=0.200 P(y)=0.300 P(xy)=0.100

[1m[36mdoctors[0m assess [32msymptoms(1.74)[0m [33mto(1.15)[0m [31mdiagnose(0.32)[0m [31mdiseases(0.74)[0m



### PMI Results Interpretation (Higher PMI = stronger association)

#### Results for "doctors assess symptoms to diagnose diseases"
- The visualization shows that "symptoms" has the strongest semantic bond with "doctors" in this sentence

**High PMI: Strong Dependency**
- **symptoms (1.74)**: has the highest association with "doctors". When both words are masked, the model generates them together more often than their individual frequencies would suggest (P(xy)=0.200 vs. P(x)P(y)=0.060).

**Medium PMI: Moderate Dependency**  
- **to (1.15)**: Moderate association.

**Low PMI: Weak Dependency**
- **diseases (0.74)**: predictable from context, but not uniquely tied to "doctors".
- **diagnose (0.32)**: only weakly associated with "doctors" (PMI=0.322), even though it is very predictable on its own (P(y)=0.80). The model almost always fills in "diagnose", but usually alongside alternatives to "doctors" (e.g. physicians, clinicians).

**Undefined PMI (-inf): No Observed Co-occurrence**
- **assess (-inf)**: the model never generated "doctors" or "assess" when both were masked (P(x)=P(y)=0), so the PMI cannot be estimated from these samples.
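
As a check on the interpretation, the highest and lowest finite scores follow directly from the printed probabilities:

$PMI(doctors, symptoms) = \log_2 \frac{0.200}{0.300 \times 0.200} = \log_2 3.33 \approx 1.737$

$PMI(doctors, diagnose) = \log_2 \frac{0.100}{0.100 \times 0.800} = \log_2 1.25 \approx 0.322$

In the second case the joint probability is as large as it can be given $P(x)=0.100$, but the high $P(y)=0.800$ inflates the denominator and pulls the score towards zero.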


## 4. EXERCISE: Try more examples; maybe come up with your own. Report the results.

* Try more examples, change the anchor word or the number of responses, and observe the results. What does the explanation mean? Do you think it is a good explanation? Why or why not?
* What are the limitations of the current method? When does the method fail to explain?

## 5. Bonus Exercises
### 5.1 Language pre-processing. 
In this exercise, we only lowercase the letters and split sentences into words; there is much more to do when pre-processing language, for example handling contractions (*I'll*, *She's*, *world's*), suffixes and prefixes, and compound words (*hard-working*). This is called word tokenization in NLP, and some Python packages can do it for us, e.g. [*TextBlob*](https://textblob.readthedocs.io/en/dev/).
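
As a starting point, the sketch below uses a regular expression from the standard library as a simple stand-in for a real tokenizer (it is not TextBlob). It keeps contractions and hyphenated compounds as single lowercase tokens and drops punctuation; whether contractions should instead be split is a design choice left open here.

```python
import re

# Letters, optionally joined by internal hyphens or straight apostrophes
TOKEN_RE = re.compile(r"[A-Za-z]+(?:[-'][A-Za-z]+)*")


def simple_tokenize(text):
    """Lowercase the text and return its word tokens without punctuation."""
    return [token.lower() for token in TOKEN_RE.findall(text)]


print(simple_tokenize("She's a hard-working doctor; I'll visit the world's best clinic."))
# ["she's", 'a', 'hard-working', 'doctor', "i'll", 'visit', 'the', "world's", 'best', 'clinic']
```

Note that this pattern only covers ASCII letters and straight apostrophes, so curly quotes and accented characters would need extra handling.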


### 5.2 Better word matching
In the above example of
> Tokyo is the capital of Japan and a popular metropolis in the world.

The generative model never gives the exact word 'metropolis' when it is masked out; instead, it sometimes provides words like 'city', which is a different word with a similar meaning. Instead of measuring exact matches (i.e. 0 or 1), we can measure the similarity of two words, e.g. the cosine similarity of their word embeddings, which ranges from -1 to 1 and is close to 1 for words with similar meanings.
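
A minimal sketch of such soft matching is shown below. The three-dimensional vectors are invented purely for illustration; in practice they would come from a word-embedding model. The helper `soft_match_rate` is hypothetical, and clipping negative similarities to 0 is one possible choice so that the score stays comparable to an exact-match rate.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)


# Toy embeddings, made up for this example only
toy_embeddings = {
    "metropolis": [0.9, 0.8, 0.1],
    "city": [0.8, 0.9, 0.2],
    "banana": [0.1, 0.0, 0.9],
}


def soft_match_rate(target, generated, embeddings):
    """Average clipped similarity between the target and each generated word."""
    scores = []
    for word in generated:
        if word in embeddings:
            sim = cosine_similarity(embeddings[target], embeddings[word])
            scores.append(max(0.0, sim))
        else:
            scores.append(0.0)  # unknown words count as a miss
    return sum(scores) / len(scores) if scores else 0.0


print(soft_match_rate("metropolis", ["city", "city", "banana"], toy_embeddings))
```

With exact matching, the same three responses would score 0 for 'metropolis'; the soft score gives partial credit to 'city'.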