# Responsible AI: XAI GenAI project

## 0. Background



Previous lessons on explainability introduced post-hoc methods for explaining a model, such as saliency maps, SmoothGrad, LRP, LIME, and SHAP. Take LRP (Layer-wise Relevance Propagation) as an example: it highlights the pixels most relevant to the prediction of the class "cat" by backpropagating relevance. (image source: [Montavon et al. (2016)](https://giorgiomorales.github.io/Layer-wise-Relevance-Propagation-in-Pytorch/))

![LRP example](images/catLRP.jpg)

Another example comes from text sentiment classification. Here we visualize the importance of each word for the prediction 'positive':

![text example](images/textGradL2.png)

Words highlighted in darker colours are more critical for predicting that the sentence is 'positive' in sentiment.
More examples can be found [here](http://34.160.227.66/?models=sst2-tiny&dataset=sst_dev&hidden_modules=Explanations_Attention&layout=default).

Both cases above require the class or the prediction of the model. But:

***How do you explain a model that does not predict but generates?***

In this project, we will explain a generative model based on the dependencies between words. We will first look at a simple example that uses Point-wise Mutual Information (PMI) to compute a saliency map of a sentence. After that we will construct the experiment step by step, followed by exercises and questions.


## 1. A simple example to start with
Given a sample sentence: 
> *Tokyo is the capital city of Japan.* 

We will explain this sentence by visualizing the dependencies between its words as a saliency map.
The dependency between two words in the sentence can be measured by [Point-wise mutual information (PMI)](https://en.wikipedia.org/wiki/Pointwise_mutual_information), estimated as follows.


Mask two words out, e.g. 
> \[MASK-1\] is the capital city of \[MASK-2\].


Ask the generative model to fill in the sentence 10 times, and we have:

| MASK-1      | MASK-2 |
| ----------- | ----------- |
|    tokyo   |     japan   |
|  paris  |     france    |
|  london  |     england    |
|  paris  |     france    |
|  beijing |  china |
|    tokyo   |     japan   |
|  paris  |     france    |
|  paris  |     france    |
|  london  |     england    |
|  beijing |  china |

PMI is calculated by: 

$PMI(x,y)=\log_2 \frac{P(\{x,y\} \mid s-\{x,y\})}{P(\{x\} \mid s-\{x,y\})\,P(\{y\} \mid s-\{x,y\})}$

where $x$ and $y$ represent the words that we masked out, $s$ represents the sentence, and $s-\{x,y\}$ represents the sentence tokens after removing the words $x$ and $y$.

In this example, "tokyo" and "japan" each fill their mask in 2 of the 10 samples and always appear together, so $PMI(Tokyo, Japan) = \log_2 \frac{0.2}{0.2 \times 0.2} \approx 2.32$
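
The arithmetic for the masked pair (Tokyo, Japan) can be reproduced directly from the ten sampled fills. The following is a minimal sketch: the `samples` list is transcribed from the table above, and `pmi_from_samples` is a hypothetical helper name, not part of the project code.

```python
import math

# (MASK-1, MASK-2) fills transcribed from the table above
samples = [
    ("tokyo", "japan"), ("paris", "france"), ("london", "england"),
    ("paris", "france"), ("beijing", "china"), ("tokyo", "japan"),
    ("paris", "france"), ("paris", "france"), ("london", "england"),
    ("beijing", "china"),
]

def pmi_from_samples(samples, x, y):
    """Estimate PMI(x, y) from sampled (MASK-1, MASK-2) fills."""
    n = len(samples)
    p_x = sum(a == x for a, _ in samples) / n
    p_y = sum(b == y for _, b in samples) / n
    p_xy = sum(a == x and b == y for a, b in samples) / n
    if p_xy == 0:
        return float("-inf")  # the pair was never sampled together
    return math.log2(p_xy / (p_x * p_y))

pmi_tokyo_japan = pmi_from_samples(samples, "tokyo", "japan")
print(f"{pmi_tokyo_japan:.2f}")  # 2.32
```

A pair that never co-occurs in the samples, such as ("tokyo", "france"), has an estimated joint probability of zero, which is why the sketch returns `-inf` in that case.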

Select a word of interest (the anchor word) in the sentence; we can then compute the PMI between it and every other word using the generative model.
(Here, we use a longer sentence and sample 20 responses per word.)
![](images/resPMI.png)


## 2. Preparation
### 2.1 Conda environment

```
conda env create -f environment.yml
conda activate xai_llm
```


### 2.2 Download the offline LLM

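We use an offline LLM from Hugging Face. It is approximately 5 GB.
Download it using the command below and save it under `./models/`. The command writes to the current directory (`--local-dir .`), so run it from inside `./models/` or adjust `--local-dir` accordingly.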
```
huggingface-cli download TheBloke/openchat-3.5-0106-GGUF openchat-3.5-0106.Q4_K_M.gguf --local-dir . --local-dir-use-symlinks False
# credit to https://huggingface.co/TheBloke/openchat-3.5-0106-GGUF
```
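
Before loading the model, it can help to confirm that the download landed where the notebook expects it. The check below is a minimal sketch; `find_model` is a hypothetical helper, and the directory and file name are assumptions taken from the text and command above.

```python
from pathlib import Path

# Assumed location and file name, taken from the download command above
MODEL_DIR = Path("models")
MODEL_FILE = "openchat-3.5-0106.Q4_K_M.gguf"

def find_model(model_dir=MODEL_DIR, filename=MODEL_FILE):
    """Return the model path if the file exists, otherwise None."""
    path = Path(model_dir) / filename
    return path if path.is_file() else None

model_path = find_model()
if model_path is None:
    print(f"Model not found; expected {MODEL_DIR / MODEL_FILE}")
else:
    size_gb = model_path.stat().st_size / 1e9
    print(f"Found {model_path} ({size_gb:.1f} GB)")
```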

## 3. Mask the sentence and get the responses from LLM
### 3.1 Get the input sentence

**Remember to change the anchor word index when changing the input sentence.**

In [None]:
# The interactive prompt is disabled so that every run uses the same sentence.
# def get_input():
#     # ideally this reads inputs from a file; for now it just takes an input
#     return input("Enter a sentence: ")

sentence = "doctors assess symptoms to diagnose diseases"

anchor_word_idx = 0  # index of the anchor word within the sentence
prompts_per_word = 20  # number of generated responses per masked pair

# sentence = get_input()
print("Sentence: ", sentence)

### 3.2 Load the model

In [None]:
from llama_cpp import Llama

from models.ChatModel import ChatModel
model_name = "openchat"
model = ChatModel(model_name)
print(f"Model: {model_name}")

### 3.3 Run the prompts and get all the responses


In [None]:
from tools.command_generator import generate_prompts, prefix_prompt
from tools.evaluate_response import get_replacements
from tqdm import tqdm

def run_prompts(model, sentence, anchor_idx, prompts_per_word=20):
    prompts = generate_prompts(sentence, anchor_idx)
    all_replacements = []
    for prompt in prompts:
        replacements = []
        for _ in tqdm(
            range(prompts_per_word),
            desc=f"Input: {prompt}",
        ):
            response = model.get_response(
                prefix_prompt(prompt),
            ).strip()
            if response:
                replacement = get_replacements(prompt, response)
                if replacement:
                    replacements.append(replacement)
        # Keep one entry per prompt, even when empty, so list positions stay
        # aligned with word positions in compute_pmi.
        all_replacements.append(replacements)
    return all_replacements

all_responses = run_prompts(model, sentence, anchor_word_idx, prompts_per_word)

In [None]:
# visualize responses
all_responses[:1]
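
The PMI code in section 3.4 reads `all_responses` as one list per masked pair, in sentence order with the anchor position skipped, where each entry is an `[anchor_fill, other_fill]` pair. The sketch below illustrates that layout. It assumes `generate_prompts` emits prompts in this order (which `compute_pmi` relies on); `other_index_for_prompt` is a hypothetical helper and the data is made up for illustration.

```python
def other_index_for_prompt(prompt_idx, anchor_idx):
    """Map a position in all_responses back to the word index it was paired with."""
    return prompt_idx if prompt_idx < anchor_idx else prompt_idx + 1

# Illustrative data only (not real model output): anchor "doctors" at index 0,
# so entry 0 pairs the anchor with word 1 ("assess"), entry 1 with word 2, ...
example_words = "doctors assess symptoms to diagnose diseases".split()
example_responses = [
    [["doctors", "assess"], ["nurses", "check"]],        # anchor + word 1
    [["doctors", "symptoms"], ["doctors", "symptoms"]],  # anchor + word 2
]

for prompt_idx, pairs in enumerate(example_responses):
    other = example_words[other_index_for_prompt(prompt_idx, 0)]
    print(f"prompt {prompt_idx}: anchor paired with '{other}', {len(pairs)} fills")
```

This is the inverse of the `pattern_idx` computation in `compute_pmi`, which is why a missing entry in `all_responses` would shift every later word by one position.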

In [None]:
# Load responses
import json
input_file = "responses.json"
with open(input_file, "r") as f:
    all_responses = json.load(f)
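
The cell above loads `responses.json`, but writing that file is not shown. A minimal sketch of the round trip follows, assuming each replacement returned by `get_replacements` is a list or tuple of strings (which `compute_pmi` also assumes). `save_responses` and `load_responses` are hypothetical helper names.

```python
import json

def save_responses(all_responses, path="responses.json"):
    """Write the nested list of replacement pairs to a JSON file."""
    # Tuples are stored as JSON arrays and come back as lists on reload.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(all_responses, f, indent=2)

def load_responses(path="responses.json"):
    """Read the nested list of replacement pairs back from a JSON file."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
```

Calling `save_responses(all_responses)` right after `run_prompts` would let later sessions skip the slow generation step.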

### 3.4 EXERCISE: compute the PMI for each word

$PMI(x,y)=\log_2 \frac{P(\{x,y\} \mid s-\{x,y\})}{P(\{x\} \mid s-\{x,y\})\,P(\{y\} \mid s-\{x,y\})}$

* Compute $P(x)$, $P(y)$ and $P(x,y)$ first and print them.
* Compute the PMI for each word.
* Visualize the result by coloring. Tip: you might need to normalize the result first.


In [None]:
import math
import numpy as np
from termcolor import colored

def compute_pmi(sentence, all_responses, anchor_idx):
    """Compute PMI between anchor word and each other word."""
    words = sentence.lower().split()
    anchor_word = words[anchor_idx]
    pmi_scores = {}
    
    for other_idx in range(len(words)):
        if other_idx == anchor_idx:
            continue
        
        # Get pattern index (skips anchor position)
        pattern_idx = other_idx if other_idx < anchor_idx else other_idx - 1
        if pattern_idx >= len(all_responses):
            continue
            
        responses = all_responses[pattern_idx]
        if not responses:
            continue
        
        # Extract anchor and other word replacements
        anchor_replacements = [r[0].lower() for r in responses if len(r) == 2]
        other_replacements = [r[1].lower() for r in responses if len(r) == 2]
        total = len(anchor_replacements)
        if total == 0:
            continue  # no well-formed (anchor, other) pairs to count
        
        # Calculate probabilities
        count_x = sum(w == anchor_word for w in anchor_replacements)
        count_y = sum(w == words[other_idx] for w in other_replacements)
        count_xy = sum(anchor_replacements[i] == anchor_word and 
                      other_replacements[i] == words[other_idx] 
                      for i in range(total))
        
        P_x = count_x / total
        P_y = count_y / total
        P_xy = count_xy / total
        
        # Calculate PMI
        if P_x > 0 and P_y > 0 and P_xy > 0:
            pmi = math.log2(P_xy / (P_x * P_y))
        else:
            pmi = float('-inf')
        
        pmi_scores[other_idx] = {'word': words[other_idx], 'pmi': pmi, 
                                  'P_x': P_x, 'P_y': P_y, 'P_xy': P_xy}
    
    return pmi_scores

def visualize_pmi(sentence, pmi_scores, anchor_idx):
    """Visualize PMI with colored words."""
    words = sentence.split()
    
    # Normalize PMI values
    valid_pmis = [s['pmi'] for s in pmi_scores.values() if s['pmi'] != float('-inf')]
    if not valid_pmis:
        print("No valid PMI scores")
        return
    
    min_pmi, max_pmi = min(valid_pmis), max(valid_pmis)
    pmi_range = max_pmi - min_pmi if max_pmi != min_pmi else 1
    
    # Color each word
    colored_words = []
    for i, word in enumerate(words):
        if i == anchor_idx:
            colored_words.append(colored(word, 'cyan', attrs=['bold']))
        elif i in pmi_scores:
            pmi = pmi_scores[i]['pmi']
            if pmi != float('-inf'):
                norm = (pmi - min_pmi) / pmi_range
                color = 'green' if norm > 0.66 else 'yellow' if norm > 0.33 else 'red'
                colored_words.append(colored(f"{word}({pmi:.2f})", color))
            else:
                colored_words.append(word)
        else:
            colored_words.append(word)
    
    print("\n" + " ".join(colored_words) + "\n")

In [None]:
# Compute PMI scores
pmi_scores = compute_pmi(sentence, all_responses, anchor_word_idx)

# Print results
words = sentence.lower().split()
print(f"Anchor word: '{words[anchor_word_idx]}'\n")
for idx in sorted(pmi_scores.keys()):
    data = pmi_scores[idx]
    print(f"{data['word']:<15} PMI={data['pmi']:7.3f}  "
          f"P(x)={data['P_x']:.3f} P(y)={data['P_y']:.3f} P(xy)={data['P_xy']:.3f}")

# Visualize with colors
visualize_pmi(sentence, pmi_scores, anchor_word_idx)

### PMI Results Interpretation (Higher PMI = stronger association)

#### Results for "doctors assess symptoms to diagnose diseases"
- The visualization shows that "symptoms" has the strongest association with "doctors" in this sentence.

**High PMI: Strong Dependency**
- **symptoms (1.74)**: has the highest association with "doctors". When both words are masked, the model generates them together more often than their individual frequencies would predict.

**Medium PMI: Moderate Dependency**
- **to (1.15)**: Moderate association.

**Low PMI: Weak Dependency**
- **diseases (0.74)**: Predictable from context but not uniquely tied to "doctors".
- **diagnose (0.32)**: Weakly associated with "doctors" (PMI=0.322) despite being very predictable (P(y)=0.80). The model generates "diagnose" in most samples, including those where it fills the anchor slot with alternatives to "doctors" (e.g. physicians, clinicians).

**Undefined PMI (-inf): No Observed Co-occurrence**
- **assess (-inf)**: the model never generated "doctors" and "assess" together when both were masked, so $P(x,y)=0$ and the PMI is reported as -inf.
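
One detail helps when reading these numbers: since $P(x,y) \le \min(P(x), P(y))$, PMI can never exceed $-\log_2 \max(P(x), P(y))$. A very predictable word therefore has a low ceiling, regardless of how tightly it is tied to the anchor. The sketch below computes that ceiling; `pmi_upper_bound` is a hypothetical helper, and the only input taken from the results above is P(y)=0.80 for "diagnose".

```python
import math

def pmi_upper_bound(p_x, p_y):
    """Largest PMI reachable given the marginals, since P(x,y) <= min(P(x), P(y))."""
    return -math.log2(max(p_x, p_y))

# P(y) = 0.80 is the value reported for "diagnose"; p_x is a placeholder that
# does not affect the bound as long as it is at most 0.80.
bound = pmi_upper_bound(p_x=0.5, p_y=0.80)
print(f"{bound:.3f}")  # 0.322
```

The reported PMI of 0.32 for "diagnose" sits at this ceiling, which suggests that "diagnose" appeared in essentially every sample that contained "doctors". The low score may therefore say more about how predictable "diagnose" is than about a weak tie to the anchor.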


## 4. EXERCISE: Try more examples; maybe come up with your own. Report the results.

* Try more examples, change the anchor word and the number of responses, and observe the results. What does the explanation mean? Do you think it is a good explanation? Why or why not?
* What are the limitations of the current method? When does the method fail to explain?

In [None]:
# EXERCISE 4

from tqdm import tqdm
import numpy as np
from IPython.display import display
from tools.command_generator import generate_prompts, prefix_prompt
from tools.evaluate_response import get_replacements

# PROMPT SAMPLING
def run_prompts(model, sentence, anchor_idx, prompts_per_word=20):
    prompts = generate_prompts(sentence, anchor_idx)
    all_replacements = []

    for prompt in prompts:
        replacements = []
        for _ in tqdm(range(prompts_per_word),
                       desc=f"Processing prompt: {prompt}",
                       leave=False):

            response = model.get_response(prefix_prompt(prompt)).strip()
            if not response:
                continue

            replacement = get_replacements(prompt, response)
            if replacement:
                replacements.append(replacement)

        # Keep one entry per prompt, even when empty, so positions stay
        # aligned with word positions in compute_pmi.
        all_replacements.append(replacements)

    return all_replacements


# VISUALIZE RAW GENERATED SENTENCES
def visualize_generated_sentences(all_responses):
    print("\nGenerated Sentences:")
    flat = []

    for group in all_responses:
        for pair in group:
            text = " ".join(pair)
            flat.append(text)
            print(" •", text)

    print(f"\nTotal generated: {len(flat)}\n")
    return flat


# RUN PMI FOR ALL ANCHOR WORDS IN ONE SENTENCE
def run_sentence_experiment(sentence, prompts_per_word=20):
    words = sentence.split()
    anchor_indices = list(range(len(words)))

    print("\n" + "#"*110)
    print(f"ANALYZING SENTENCE:\n   '{sentence}'")
    print("#"*110)

    sentence_results = {}

    for anchor_idx in anchor_indices:
        anchor_word = words[anchor_idx]

        print("\n" + "="*90)
        print(f"Anchor index: {anchor_idx}   |   Anchor word: '{anchor_word}'")
        print("="*90)

        # Generate model outputs
        all_responses = run_prompts(model, sentence, anchor_idx, prompts_per_word)

        # Show generated sentences
        visualize_generated_sentences(all_responses)

        # Compute PMI
        pmi_scores = compute_pmi(sentence, all_responses, anchor_idx)

        # Print PMI table (no color here)
        print("\nPMI Scores:")
        if len(pmi_scores) == 0:
            print("No valid PMI scores (model did not regenerate expected words).")
        else:
            for idx in sorted(pmi_scores.keys()):
                d = pmi_scores[idx]
                print(f"{d['word']:<15} PMI={d['pmi']:7.3f}   "
                      f"P(x)={d['P_x']:.3f}  P(y)={d['P_y']:.3f}  P(xy)={d['P_xy']:.3f}")

        # Save PMI so we can visualize later
        sentence_results[anchor_word] = {
            "anchor_idx": anchor_idx,
            "pmi_scores": pmi_scores
        }

    return sentence_results


# RUN EXPERIMENT ACROSS MULTIPLE SENTENCES
experiment_sentences = [
    "doctors assess symptoms to diagnose diseases",
    "artificial intelligence transforms modern industries",
    "children love sweet ice cream on warm summer days",
    "plants require sunlight and water to grow",
    "the government announced new policies to support healthcare",
]

all_sentence_results = {}

for sent in experiment_sentences:
    results = run_sentence_experiment(sent, prompts_per_word=20)
    all_sentence_results[sent] = results

print("\n=== EXERCISE 4 COMPLETE — READY FOR VISUALIZATION ===")


In [None]:
# VISUALIZATION FOR  SENTENCES

for sentence, anchor_dict in all_sentence_results.items():

    print("\n" + "#"*120)
    print(f"VISUALIZATIONS FOR SENTENCE:\n  '{sentence}'")
    print("#"*120 + "\n")

    for anchor_word, info in anchor_dict.items():
        anchor_idx = info["anchor_idx"]
        pmi_scores = info["pmi_scores"]

        print("\n" + "="*100)
        print(f"Anchor word: '{anchor_word}' (index {anchor_idx})")
        print("="*100 + "\n")

        visualize_pmi(sentence, pmi_scores, anchor_idx)
        print("\n")


## Reflection on PMI Results and Explanation Quality

When we experimented with more sentences, different anchor words, and different numbers of model responses, the PMI values changed noticeably depending on how often the model regenerated particular word pairs. Words with clear semantic connections, such as *doctors* and *diseases* or *ice* and *cream*, tended to show higher PMI, while function words like *to* and *and*, or modifiers like *new*, often produced no valid PMI because the model did not reproduce them consistently. Increasing the number of responses made the PMI estimates more stable.

The explanation behind PMI is that it reflects how strongly the model associates two words by comparing the probability of generating them together versus independently. A high PMI therefore indicates that the model repeatedly regenerates those words in relation to each other, revealing an underlying learned association.

Overall, PMI gives a simple and intuitive explanation because it highlights which words the model considers related. However, it is also limited, since many anchor words do not produce valid PMI, the results are sensitive to sampling noise, and the method does not explain the model’s internal reasoning processes. PMI is therefore helpful for intuition but should not be viewed as a complete explanation of model behavior.

## Limitations of the PMI-Based Explanation Method

While PMI gives a simple way to estimate dependencies between words using a generative LLM, the method comes with several important limitations:

**Exact word matching is too rigid**

The method only counts a match if the model outputs the exact same word.
But LLMs often generate synonyms or variations (e.g., kids vs children), which leads to underestimating true semantic relationships.
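
One way to relax exact matching is to accept a small set of variants for each target word. The sketch below is illustrative only: `EQUIVALENTS` is a hand-written, hypothetical table, and `matches` is not part of the project code.

```python
# Hypothetical equivalence classes; in practice these could come from a
# lemmatizer or a thesaurus rather than a hand-written table.
EQUIVALENTS = {
    "children": {"children", "kids"},
    "doctors": {"doctors", "physicians", "clinicians"},
}

def matches(generated, target, equivalents=EQUIVALENTS):
    """True if `generated` equals `target` or is listed as an accepted variant."""
    generated, target = generated.lower(), target.lower()
    return generated == target or generated in equivalents.get(target, set())

print(matches("Kids", "children"))      # True
print(matches("teachers", "children"))  # False
```

Replacing the `w == anchor_word` comparisons in `compute_pmi` with a check like this would count synonyms as hits, at the cost of having to decide which variants are close enough.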

**Naïve tokenization creates noise**

Because the current implementation uses a simple `.split()`, the method struggles with contractions, punctuation, hyphenated words, and multi-word expressions like *ice cream* or *New York*.
This reduces the accuracy of the PMI associations.
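
As a small illustration of the punctuation problem, the sketch below compares whitespace splitting with a regex tokenizer. `TOKEN_RE` and `simple_tokenize` are hypothetical names, and the pattern is only one possible choice. It keeps apostrophes and hyphens inside words and drops surrounding punctuation, but it does not address multi-word expressions.

```python
import re

# Words may contain internal apostrophes or hyphens; punctuation is dropped.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+(?:['\-][A-Za-z0-9]+)*")

def simple_tokenize(text):
    """Lowercase the text and return word tokens without edge punctuation."""
    return TOKEN_RE.findall(text.lower())

demo_text = "The hard-working doctor's patients didn't wait."
print(demo_text.lower().split())   # last token is 'wait.' with the period attached
print(simple_tokenize(demo_text))  # last token is 'wait'
```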

**Low sample size leads to unstable probabilities**

With around 20 generated completions per masked pair, estimates of P(x), P(y), and P(x,y) can be noisy.
A few lucky or unlucky generations can shift PMI ranks significantly.
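
A rough way to see this instability is to bootstrap the sampled pairs and look at the spread of the resulting PMI values. The sketch below uses made-up pairs rather than real model output; `pmi` and `bootstrap_pmi` are hypothetical helpers, and the percentile interval is only an approximation.

```python
import math
import random

def pmi(pairs, x, y):
    """Estimate PMI(x, y) from (anchor_fill, other_fill) pairs."""
    n = len(pairs)
    p_x = sum(a == x for a, _ in pairs) / n
    p_y = sum(b == y for _, b in pairs) / n
    p_xy = sum(a == x and b == y for a, b in pairs) / n
    return math.log2(p_xy / (p_x * p_y)) if p_xy > 0 else float("-inf")

def bootstrap_pmi(pairs, x, y, n_resamples=1000, seed=0):
    """Resample the pairs with replacement and return the finite PMI values, sorted."""
    rng = random.Random(seed)
    values = []
    for _ in range(n_resamples):
        resample = rng.choices(pairs, k=len(pairs))
        value = pmi(resample, x, y)
        if math.isfinite(value):
            values.append(value)
    return sorted(values)

# Illustrative pairs only (not real model output): 20 fills for one masked pair
pairs = (
    [("doctors", "symptoms")] * 8
    + [("nurses", "patients")] * 7
    + [("doctors", "patients")] * 3
    + [("nurses", "symptoms")] * 2
)

values = bootstrap_pmi(pairs, "doctors", "symptoms")
low, high = values[int(0.025 * len(values))], values[int(0.975 * len(values))]
print(f"point estimate: {pmi(pairs, 'doctors', 'symptoms'):.2f}")
print(f"approx. 95% bootstrap interval: [{low:.2f}, {high:.2f}]")
```

A wide interval relative to the differences between words is a sign that the ranking should not be over-interpreted at this sample size.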

**PMI only captures pairwise relationships**

Natural language meaning is often determined by interactions between several words or phrases.
PMI cannot model multi-word dependencies, syntax, or context beyond two-word associations.

**LLM biases influence the results**

PMI reflects the model’s training distribution and biases.
High PMI may reflect frequency biases in training data rather than genuine dependency in the sentence.

Overall, PMI gives a simple and interpretable approximation of word dependencies, but its accuracy is limited by tokenization, sampling noise, synonym variation, masking artifacts, and the behavioral nature of the method.

## 5. Bonus Exercises
### 5.1 Language pre-processing
In this exercise, we only lowercase the text and split sentences into words; there is much more that can be done to pre-process the language, for example handling contractions (*I'll*, *She's*, *world's*), suffixes and prefixes, and compound words (*hard-working*). Splitting text into such units is called word tokenization in NLP, and there are Python packages that can do this work for us, e.g. [*TextBlob*](https://textblob.readthedocs.io/en/dev/).




In [None]:
import sys
!{sys.executable} -m spacy download en_core_web_sm

import spacy
import re

nlp = spacy.load("en_core_web_sm", disable=["parser"])  # keep tagger & lemmatizer

# Define preprocessing function
def preprocess_text(text, keep_stopwords=False, keep_pos=None, remove_punct=True):
    """
    text: single string
    keep_stopwords: if False, remove stopwords
    keep_pos: None or set like {"NOUN","VERB","ADJ"} to filter by POS
    remove_punct: whether to drop punctuation tokens
    returns: list of normalized tokens (lemmas)
    """
    # basic normalization
    text = text.strip()
    # optional: expand contractions (can add contraction library)
    # remove weird whitespace
    text = re.sub(r'\s+', ' ', text)
    doc = nlp(text)
    tokens = []
    for token in doc:
        if remove_punct and token.is_punct:
            continue
        if token.like_num:
            # choose policy: keep numbers or replace with <NUM>
            tokens.append("<NUM>")
            continue
        if not keep_stopwords and token.is_stop:
            continue
        if keep_pos and token.pos_ not in keep_pos:
            continue
        lemma = token.lemma_.lower()
        # strip residual punctuation
        lemma = re.sub(r'^\W+|\W+$', '', lemma)
        if lemma:
            tokens.append(lemma)
    return tokens

# Example
s = "She didn't believe the rumor, yet she felt uneasy."
print(preprocess_text(s, keep_stopwords=False, keep_pos={"NOUN","VERB","ADJ"}))
# Expected output (approx): ['believe', 'rumor', 'feel', 'uneasy']


### 5.1 Implementation: Advanced Text Preprocessing with spaCy

Comparing simple tokenization vs. advanced preprocessing:

In [None]:
# Install and load spaCy
import sys
import subprocess

# Install the spaCy model if it is not already available
try:
    import spacy
    nlp = spacy.load("en_core_web_sm")
    print("spaCy model already loaded")
except (ImportError, OSError):  # spaCy missing, or model not downloaded
    print("Downloading spaCy model...")
    subprocess.run([sys.executable, "-m", "pip", "install", "spacy"])
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    import spacy
    nlp = spacy.load("en_core_web_sm")
    print("spaCy model installed and loaded")

In [6]:
import spacy
import re

# Load spaCy model (disable parser and NER for speed; keep tagger & lemmatizer)
nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])

def preprocess_text(text, keep_stopwords=False, keep_pos=None, remove_punct=True):
    """
    Advanced text preprocessing using spaCy.
    
    Parameters:
    -----------
    text : str
        Input text to preprocess
    keep_stopwords : bool
        If False, remove stopwords (the, is, a, etc.)
    keep_pos : set or None
        Filter by part-of-speech tags (e.g., {"NOUN", "VERB", "ADJ"})
    remove_punct : bool
        Whether to remove punctuation tokens
    
    Returns:
    --------
    list : Normalized tokens (lemmas)
    """
    # Basic normalization
    text = text.strip()
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    
    # Process with spaCy
    doc = nlp(text)
    tokens = []
    
    for token in doc:
        # Skip punctuation if requested
        if remove_punct and token.is_punct:
            continue
        
        # Handle numbers
        if token.like_num:
            tokens.append("<NUM>")
            continue
        
        # Remove stopwords if requested
        if not keep_stopwords and token.is_stop:
            continue
        
        # Filter by POS tag if specified
        if keep_pos and token.pos_ not in keep_pos:
            continue
        
        # Get lemma (base form) and lowercase it
        lemma = token.lemma_.lower()
        
        # Strip any residual punctuation at edges
        lemma = re.sub(r'^\W+|\W+$', '', lemma)
        
        if lemma:
            tokens.append(lemma)
    
    return tokens

# Demonstration examples
print("=" * 70)
print("SIMPLE vs ADVANCED PREPROCESSING COMPARISON")
print("=" * 70)

test_sentences = [
    "She didn't believe the rumor, yet she felt uneasy.",
    "The world's best doctors assess patients' symptoms.",
    "It's a well-known fact that hard-working people succeed.",
    "I'll be there by 5:30 PM on 12/25/2024."
]

for sent in test_sentences:
    print(f"\n Original: {sent}")
    print(f"   Simple:   {sent.lower().split()}")
    print(f"   Advanced: {preprocess_text(sent, keep_stopwords=False)}")
    print(f"   With POS: {preprocess_text(sent, keep_pos={'NOUN', 'VERB', 'ADJ'})}")


SIMPLE vs ADVANCED PREPROCESSING COMPARISON

 Original: She didn't believe the rumor, yet she felt uneasy.
   Simple:   ['she', "didn't", 'believe', 'the', 'rumor,', 'yet', 'she', 'felt', 'uneasy.']
   Advanced: ['believe', 'rumor', 'feel', 'uneasy']
   With POS: ['believe', 'rumor', 'feel', 'uneasy']

 Original: The world's best doctors assess patients' symptoms.
   Simple:   ['the', "world's", 'best', 'doctors', 'assess', "patients'", 'symptoms.']
   Advanced: ['world', 'good', 'doctor', 'assess', 'patient', 'symptom']
   With POS: ['world', 'good', 'doctor', 'assess', 'patient', 'symptom']

 Original: It's a well-known fact that hard-working people succeed.
   Simple:   ["it's", 'a', 'well-known', 'fact', 'that', 'hard-working', 'people', 'succeed.']
   Advanced: ['know', 'fact', 'hard', 'work', 'people', 'succeed']
   With POS: ['know', 'fact', 'work', 'people', 'succeed']

 Original: I'll be there by 5:30 PM on 12/25/2024.
   Simple:   ["i'll", 'be', 'there', 'by', '5:30', 'pm', '

### Key Improvements Demonstrated:

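1. **Contractions** are split into separate tokens (`didn't` $\to$ `did` + `n't`, `I'll` $\to$ `I` + `'ll`); in the output above, both parts of `didn't` are then removed as stopwords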
2. **Possessives** (`patients'` $\to$ `patient`, `world's` $\to$ `world`)
3. **Lemmatization** (`doctors` $\to$ `doctor`, `felt` $\to$ `feel`)
4. **Compound words** (`hard-working` $\to$ separate tokens)
5. **Stopword removal** (removes `the`, `a`, `is`, etc.)
6. **POS filtering** (keep only NOUN/VERB/ADJ)

Now let's apply this to our PMI analysis:

In [8]:
# Enhanced PMI computation with preprocessing
def compute_pmi_enhanced(sentence, all_responses, anchor_idx, use_preprocessing=True):
    """
    Compute PMI with optional advanced preprocessing.
    """
    # Tokenize based on preprocessing choice
    if use_preprocessing:
        words = preprocess_text(sentence, keep_stopwords=True, remove_punct=False)
    else:
        words = sentence.lower().split()
    
    anchor_word = words[anchor_idx]
    pmi_scores = {}
    
    for other_idx in range(len(words)):
        if other_idx == anchor_idx:
            continue
        
        pattern_idx = other_idx if other_idx < anchor_idx else other_idx - 1
        if pattern_idx >= len(all_responses):
            continue
            
        responses = all_responses[pattern_idx]
        if not responses:
            continue
        
        # Process responses with same preprocessing
        anchor_replacements = []
        other_replacements = []
        
        for r in responses:
            if len(r) == 2:
                if use_preprocessing:
                    anchor_tokens = preprocess_text(r[0], keep_stopwords=True, remove_punct=False)
                    other_tokens = preprocess_text(r[1], keep_stopwords=True, remove_punct=False)
                    if anchor_tokens and other_tokens:
                        anchor_replacements.append(anchor_tokens[0])
                        other_replacements.append(other_tokens[0])
                else:
                    anchor_replacements.append(r[0].lower())
                    other_replacements.append(r[1].lower())
        
        if not anchor_replacements:
            continue
            
        total = len(anchor_replacements)
        
        # Calculate probabilities
        count_x = sum(w == anchor_word for w in anchor_replacements)
        count_y = sum(w == words[other_idx] for w in other_replacements)
        count_xy = sum(anchor_replacements[i] == anchor_word and 
                      other_replacements[i] == words[other_idx] 
                      for i in range(total))
        
        P_x = count_x / total if total > 0 else 0
        P_y = count_y / total if total > 0 else 0
        P_xy = count_xy / total if total > 0 else 0
        
        # Calculate PMI
        if P_x > 0 and P_y > 0 and P_xy > 0:
            pmi = math.log2(P_xy / (P_x * P_y))
        else:
            pmi = float('-inf')
        
        pmi_scores[other_idx] = {
            'word': words[other_idx], 
            'pmi': pmi, 
            'P_x': P_x, 
            'P_y': P_y, 
            'P_xy': P_xy
        }
    
    return pmi_scores, words

In [None]:
# Compare: Simple vs Advanced preprocessing
test_sentence = "The doctor's examining patients' symptoms carefully."

print("=" * 70)
print("COMPARISON: Simple vs. Advanced Preprocessing for PMI")
print("=" * 70)
print(f"\nTest sentence: '{test_sentence}'")
print(f"\nSimple tokenization: {test_sentence.lower().split()}")
print(f"Advanced preprocessing: {preprocess_text(test_sentence, keep_stopwords=True, remove_punct=False)}")

print("\n" + "=" * 70)
print("ANALYSIS:")
print("=" * 70)
print("""
Benefits of advanced preprocessing:
1. **Lemmatization**: 'doctor's' → 'doctor', 'patients' → 'patient'
   - Groups inflected forms together for better statistics
   
2. **Possessive handling**: Removes 's apostrophes properly
   - 'doctor's' and 'doctors' both map to 'doctor'
   
3. **Contraction expansion**: 'didn't' → 'did' + 'not'
   - Captures true meaning of negations
   
4. **Consistent tokenization**: Handles punctuation intelligently
   - Doesn't split compound words incorrectly

This leads to:
- More accurate probability estimates (fewer unique tokens)
- Better matching between original and generated words
- More meaningful PMI scores
""")

### Exercise 5.1 Tasks: 

In [None]:
# Task 1: Test with your own sentences
print("=" * 80)
print("TASK 1: Testing with Custom Sentences")
print("=" * 80)

# Sentences with contractions, possessives, compound words
custom_sentences = [
    "She didn't believe the rumor, yet she felt uneasy.",
    "John's well-known theory about quantum physics won't be forgotten.",
    "The hard-working scientist's groundbreaking discovery can't be ignored.",
    "It's a state-of-the-art system that doesn't require maintenance.",
]

for i, sent in enumerate(custom_sentences, 1):
    print(f"\nExample {i}: {sent}")
    print(f"   Simple:   {sent.lower().split()}")
    print(f"   Advanced: {preprocess_text(sent, keep_stopwords=False)}")


In [None]:
# Task 2: Experiment with different preprocessing options
print("\n" + "=" * 80)
print("TASK 2: Experimenting with Different Preprocessing Options")
print("=" * 80)

test_text = "The doctor's examining patients' symptoms carefully."

print(f"\nOriginal sentence: {test_text}\n")

# Option A: Keep all stopwords
print("A) With stopwords (keep_stopwords=True):")
print(f"   {preprocess_text(test_text, keep_stopwords=True, remove_punct=False)}")

# Option B: Remove stopwords
print("\nB) Without stopwords (keep_stopwords=False):")
print(f"   {preprocess_text(test_text, keep_stopwords=False, remove_punct=False)}")

# Option C: Only nouns and verbs
print("\nC) Only NOUN + VERB (keep_pos={'NOUN', 'VERB'}):")
print(f"   {preprocess_text(test_text, keep_stopwords=True, keep_pos={'NOUN', 'VERB'})}")

# Option D: Only adjectives and nouns
print("\nD) Only ADJ + NOUN (keep_pos={'ADJ', 'NOUN'}):")
print(f"   {preprocess_text(test_text, keep_stopwords=True, keep_pos={'ADJ', 'NOUN'})}")

# Option E: Keep punctuation
print("\nE) Keep punctuation (remove_punct=False):")
print(f"   {preprocess_text(test_text, keep_stopwords=False, remove_punct=False)}")

# Option F: Remove punctuation
print("\nF) Remove punctuation (remove_punct=True):")
print(f"   {preprocess_text(test_text, keep_stopwords=False, remove_punct=True)}")

print("\n" + "=" * 80)
print("Summary of Options:")
print("=" * 80)
print("""
- keep_stopwords: Controls whether common words (the, is, a) are included
- keep_pos: Filter by part-of-speech (NOUN, VERB, ADJ, ADV, etc.)
- remove_punct: Whether to remove punctuation tokens

Different combinations suit different purposes:
• Full preprocessing: best for semantic analysis
• POS filtering: emphasizes content words
• Keeping stopwords: preserves structure information
""")


In [None]:
# Task 3: Compare PMI results with and without preprocessing
print("\n" + "=" * 80)
print("TASK 3: PMI Comparison - With vs Without Preprocessing")
print("=" * 80)

# Load responses if not already loaded
import json
import math
if 'all_responses' not in globals():
    try:
        with open("responses.json", "r") as f:
            all_responses = json.load(f)
        print("Loaded all_responses from responses.json\n")
    except FileNotFoundError:
        print("Error: responses.json not found. Please run the prompts first or ensure the file exists.")
        all_responses = []

# Use a sentence from earlier that we have responses for
comparison_sentence = "doctors assess symptoms to diagnose diseases"
anchor_idx = 0  # "doctors"

print(f"\nTest sentence: '{comparison_sentence}'")
print(f"Anchor word (index {anchor_idx}): '{comparison_sentence.split()[anchor_idx]}'")

# Compute PMI both ways
print("\n" + "-" * 80)
print("WITHOUT Preprocessing (simple tokenization):")
print("-" * 80)
pmi_simple, words_simple = compute_pmi_enhanced(comparison_sentence, all_responses, anchor_idx, use_preprocessing=False)
for idx in sorted(pmi_simple.keys()):
    data = pmi_simple[idx]
    pmi_val = data['pmi'] if data['pmi'] != float('-inf') else "N/A"
    print(f"  {data['word']:<15} PMI={str(pmi_val):>7}  P(x)={data['P_x']:.3f} P(y)={data['P_y']:.3f} P(xy)={data['P_xy']:.3f}")

print("\n" + "-" * 80)
print("WITH Preprocessing (spaCy tokenization and lemmatization; stopwords kept):")
print("-" * 80)
pmi_advanced, words_advanced = compute_pmi_enhanced(comparison_sentence, all_responses, anchor_idx, use_preprocessing=True)
for idx in sorted(pmi_advanced.keys()):
    data = pmi_advanced[idx]
    pmi_val = data['pmi'] if data['pmi'] != float('-inf') else "N/A"
    print(f"  {data['word']:<15} PMI={str(pmi_val):>7}  P(x)={data['P_x']:.3f} P(y)={data['P_y']:.3f} P(xy)={data['P_xy']:.3f}")

print("\n" + "-" * 80)
print("Comparison Summary:")
print("-" * 80)
print(f"Simple tokenization found {len(pmi_simple)} word pairs")
print(f"Advanced preprocessing found {len(pmi_advanced)} word pairs")


### 5.1 Task 4: Analysis and Reflection Questions

**Question 1: How does preprocessing affect the PMI scores?**

**Answer:**
- Preprocessing groups inflected forms (doctor, doctors, doctor's → doctor), so the anchor word and the other word are matched more often. PMI values that were undefined because no exact match occurred can become finite, but PMI does not necessarily increase, because P(x), P(y) and P(x, y) all grow together
- By reducing vocabulary size (fewer unique tokens), probabilities become less sparse and more reliable
- Some PMI values may stabilize because preprocessing normalizes variations in how the model words its responses
- Words that co-occur in several surface forms are now counted together, which strengthens their association signal
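
As a small illustration of the first point, the sketch below computes exact-match PMI on made-up fills, first on raw tokens and then after grouping inflected forms. The data, the `pmi` helper and the `lemma` lookup table are hypothetical stand-ins (the table only imitates what a lemmatizer would do); they are not the notebook's `compute_pmi_enhanced` or `preprocess_text`.

```python
import math

def pmi(pairs, x, y):
    """Exact-match PMI (base 2) of x in slot 1 and y in slot 2."""
    n = len(pairs)
    p_x = sum(a == x for a, _ in pairs) / n
    p_y = sum(b == y for _, b in pairs) / n
    p_xy = sum(a == x and b == y for a, b in pairs) / n
    if p_xy == 0:
        return float("-inf")  # no joint match: PMI is undefined
    return math.log2(p_xy / (p_x * p_y))

# Hypothetical fills for "[MASK-1] assess [MASK-2] ..." (made-up data)
raw = [("doctor", "symptoms"), ("doctors", "symptom"),
       ("physicians", "signs"), ("doctor", "symptom")]

# Toy lookup table standing in for a real lemmatizer
lemma = {"doctors": "doctor", "symptoms": "symptom",
         "physicians": "physician", "signs": "sign"}
grouped = [(lemma.get(a, a), lemma.get(b, b)) for a, b in raw]

pmi_raw = pmi(raw, "doctors", "symptoms")       # surface forms never co-occur
pmi_grouped = pmi(grouped, "doctor", "symptom")  # lemmas co-occur in 3 of 4 fills
print(f"raw tokens: {pmi_raw}")
print(f"lemmatized: {pmi_grouped:.3f}")
```

On this toy data the raw tokens never match jointly, so PMI is undefined, while the grouped forms give a finite score. Whether PMI rises or falls on real responses depends on how the three probabilities change together.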

**Question 2: When would preprocessing help PMI analysis?**

**Answer:**
- Better accuracy when words have multiple forms (doctor, doctors, doctor's)
- More reliable statistics by grouping related words together
- Reduced sparsity (fewer unique tokens)
- Better handling of linguistic variations

**Question 3: When might preprocessing hurt or be problematic?**

**Answer:**
- Loss of information when lemmatizing (e.g., "running" and "ran" both → "run")
- Removing negations (not, no, didn't) removes important semantic information
- Stopword removal loses structural context
- Over-aggressive POS filtering might remove important words
- Domain-specific terms might be incorrectly lemmatized
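
One possible mitigation for the negation problem is to exempt negation words from stopword removal. The sketch below is a minimal illustration under stated assumptions: `STOPWORDS`, `NEGATIONS` and `filter_tokens` are hypothetical names, the word lists are deliberately tiny, and the input is assumed to be already tokenized with "didn't" split into "did" and "n't", as several common tokenizers do.

```python
# Tiny illustrative lists; a real pipeline would use a library stopword list
STOPWORDS = {"the", "a", "an", "is", "was", "did", "do", "does",
             "not", "no", "n't", "never"}
NEGATIONS = {"not", "no", "n't", "never"}

def filter_tokens(tokens, keep_negations=True):
    """Drop stopwords, optionally keeping negation words."""
    kept = []
    for tok in tokens:
        low = tok.lower()
        if keep_negations and low in NEGATIONS:
            kept.append(low)       # negation survives the filter
        elif low not in STOPWORDS:
            kept.append(low)       # ordinary content word
    return kept

# Assumed tokenization of: "The doctor didn't diagnose the disease correctly."
tokens = ["The", "doctor", "did", "n't", "diagnose", "the", "disease", "correctly"]
without_neg = filter_tokens(tokens, keep_negations=False)
with_neg = filter_tokens(tokens, keep_negations=True)
print(without_neg)
print(with_neg)
```

With `keep_negations=True` the token "n't" stays next to "diagnose", so the negated meaning is still visible to later steps.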

In [None]:
# Task 4: Practical Demonstration
print("=" * 80)
print("TASK 4: Reflection - Effects of Preprocessing")
print("=" * 80)

# Let's create a practical example showing benefits and drawbacks

print("\nBENEFIT EXAMPLE: Handling Inflections")
print("-" * 80)
example1 = "The doctors and the doctor's assistant work together."
print(f"Original: {example1}")
print(f"Simple:   {example1.lower().split()}")
print(f"Advanced: {preprocess_text(example1, keep_stopwords=True)}")
print("\nBenefit: 'doctors', 'doctor's' → all map to 'doctor'")
print("  This groups related forms, improving PMI statistics")

print("\n\nDRAWBACK EXAMPLE: Loss of Negation Information")
print("-" * 80)
example2 = "The doctor didn't diagnose the disease correctly."
print(f"Original: {example2}")
print(f"Simple:   {example2.lower().split()}")
advanced_no_stops = preprocess_text(example2, keep_stopwords=False)
print(f"Advanced (stopwords removed): {advanced_no_stops}")
advanced_keep_stops = preprocess_text(example2, keep_stopwords=True)
print(f"Advanced (stopwords kept):    {advanced_keep_stops}")
print("\nDrawback: Removing 'didn't' loses the negation!")
print("  'didn't diagnose' → 'diagnose' loses semantic meaning")

print("\n\nBENEFIT EXAMPLE: Reducing Sparsity")
print("-" * 80)
forms = ["running", "runs", "run", "runner"]
print(f"Original forms: {', '.join(forms)}")
lemmatized = []
for word in forms:
    tokens = preprocess_text(word, keep_stopwords=True)
    lemmatized.append(tokens[0] if tokens else word)  # fall back to the raw word
print(f"Lemmatized:    {lemmatized}")
print("\nBenefit: inflected forms (running, runs, run) typically group to 'run'")
print("  Note: a derived noun such as 'runner' usually keeps its own lemma")
print("  Fewer unique tokens = better probability estimates")

print("\n\n" + "=" * 80)
print("TAKEAWAYS:")
print("=" * 80)
print("""
Preprocessing HELPS when:
  1. Handling grammatical variations (plurals, tenses, possessives)
  2. Reducing sparsity (fewer unique tokens for better statistics)
  3. Normalizing text from different sources
  4. Focusing on content words (POS filtering)

Preprocessing HURTS when:
  1. Important semantic information is lost (negations, intensifiers)
  2. Domain-specific terminology is incorrectly normalized
  3. Removing context needed for interpretation
  4. Over-aggressive filtering removes meaningful words

RECOMMENDATION FOR PMI ANALYSIS:
Use selective preprocessing:
  - Keep lemmatization (group related forms)
  - Keep stopwords (preserve structure)
  - Avoid aggressive stopword removal
  - Consider task-specific POS filtering
""")


### 5.2 Better word matching
In the above example of
> Tokyo is the capital of Japan and a popular metropolis in the world.

When 'metropolis' is masked out, GenAI never produces that exact word; instead, it sometimes gives words like 'city', which is a different word with a similar meaning. Instead of scoring only exact matches (i.e. 0 or 1), we can measure how similar two words are, e.g. with the cosine similarity of their word embeddings. Cosine similarity ranges from -1 to 1 and can be rescaled to the range 0 to 1.
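
For reference, cosine similarity can be computed directly from two vectors. The sketch below uses hand-picked 2-D vectors rather than real embeddings, and the helper names `cosine_similarity` and `rescale` are hypothetical; it only shows the raw range of -1 to 1 and the linear rescaling to 0 to 1.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def rescale(sim):
    """Map a cosine similarity from [-1, 1] to [0, 1]."""
    return (sim + 1.0) / 2.0

anchor = [1.0, 0.0]
same_direction = cosine_similarity(anchor, [2.0, 0.0])
orthogonal = cosine_similarity(anchor, [0.0, 3.0])
opposite = cosine_similarity(anchor, [-1.0, 0.0])

for name, sim in [("same direction", same_direction),
                  ("orthogonal", orthogonal),
                  ("opposite", opposite)]:
    print(f"{name:<15} raw={sim:+.2f}  rescaled={rescale(sim):.2f}")
```

Note that after rescaling, unrelated (orthogonal) vectors score 0.5 rather than 0, which matters when the rescaled value is used as a match score.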


In [11]:
# Exercise 5.2: Better word matching using word embeddings
# Instead of exact matching (0 or 1), we use cosine similarity from word embeddings (-1 to 1, rescaled to 0 to 1)

import gensim.downloader as api
import numpy as np

print("=" * 80)
print("EXERCISE 5.2: Better Word Matching with Word Embeddings")
print("=" * 80)

# Load pre-trained word embeddings (Word2Vec trained on Google News)
try:
    word_vectors = api.load('word2vec-google-news-300')
    print("Word2Vec model loaded successfully")
    print(f"  Vocabulary size: {len(word_vectors)} words")
    print(f"  Vector dimensions: {word_vectors.vector_size}")
except Exception as e:
    print(f"Error loading word vectors: {e}")
    print("Attempting to use a smaller model...")
    word_vectors = api.load('glove-wiki-gigaword-50')
    print("GloVe model loaded successfully")

# Test with the metropolis/city example
print("\n" + "-" * 80)
print("Example: Similarity between 'metropolis' and 'city'")
print("-" * 80)

test_words = [
    ('metropolis', 'city'),
    ('metropolis', 'metropolis'),
    ('metropolis', 'town'),
    ('metropolis', 'village'),
    ('metropolis', 'urban'),
    ('city', 'town'),
    ('Japan', 'Tokyo'),
    ('capital', 'city')
]

for word1, word2 in test_words:
    try:
        similarity = word_vectors.similarity(word1, word2)
        # Normalize to [0, 1] range (cosine similarity is in [-1, 1])
        normalized_sim = (similarity + 1) / 2
        print(f"  similarity('{word1}', '{word2}') = {similarity:.4f} (normalized: {normalized_sim:.4f})")
    except KeyError as e:
        print(f"  similarity('{word1}', '{word2}') = N/A (word not in vocabulary)")


EXERCISE 5.2: Better Word Matching with Word Embeddings
Word2Vec model loaded successfully
  Vocabulary size: 3000000 words
  Vector dimensions: 300

--------------------------------------------------------------------------------
Example: Similarity between 'metropolis' and 'city'
--------------------------------------------------------------------------------
  similarity('metropolis', 'city') = 0.5717 (normalized: 0.7859)
  similarity('metropolis', 'metropolis') = 1.0000 (normalized: 1.0000)
  similarity('metropolis', 'town') = 0.3758 (normalized: 0.6879)
  similarity('metropolis', 'village') = 0.3228 (normalized: 0.6614)
  similarity('metropolis', 'urban') = 0.5147 (normalized: 0.7574)
  similarity('city', 'town') = 0.6724 (normalized: 0.8362)
  similarity('Japan', 'Tokyo') = 0.7002 (normalized: 0.8501)
  similarity('capital', 'city') = 0.3281 (normalized: 0.6641)

-------------------------

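### Observation:
- Identical words have similarity 1.0
- Semantically similar words (metropolis/city) score relatively high: raw cosine ≈ 0.57, or ≈ 0.79 after rescaling to [0, 1]
- Less similar words (metropolis/village, raw ≈ 0.32) score lower, but rescaling still maps them above 0.5, so the rescaled scores fall in a narrow band
- This allows us to give partial credit for similar words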
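
Partial credit can also be combined with a cutoff, so that weakly related words earn nothing. The sketch below is one hypothetical way to do that: `RAW_SIM` holds raw cosine values copied from the Word2Vec output above, while `soft_match` and its `threshold` default are assumptions made for illustration and are not part of the notebook's functions.

```python
# Raw cosine similarities copied from the Word2Vec output above
RAW_SIM = {
    ("metropolis", "city"): 0.5717,
    ("metropolis", "urban"): 0.5147,
    ("metropolis", "town"): 0.3758,
    ("metropolis", "village"): 0.3228,
}

def soft_match(target, candidate, threshold=0.5):
    """Match score in [0, 1]: 1 for identical words, the raw cosine when it
    reaches the threshold, and 0 otherwise (including unknown pairs)."""
    if target == candidate:
        return 1.0
    sim = RAW_SIM.get((target, candidate), 0.0)
    return sim if sim >= threshold else 0.0

candidates = ["metropolis", "city", "urban", "town", "village"]
scores = {w: soft_match("metropolis", w) for w in candidates}
for word, score in scores.items():
    print(f"{word:<12} {score:.4f}")
```

With a threshold of 0.5 on the raw cosine, 'city' and 'urban' earn partial credit while 'town' and 'village' are treated as non-matches. The right cutoff is a judgment call and would need tuning on real responses.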


In [12]:
# Enhanced PMI computation using word embedding similarity
def compute_pmi_with_similarity(sentence, all_responses, anchor_idx, word_vectors, 
                                use_preprocessing=True, similarity_threshold=0.5):
    """
    Compute PMI using word embedding similarity instead of exact matches.
    
    Parameters:
    - sentence: input sentence
    - all_responses: list of model responses for masked positions
    - anchor_idx: index of the anchor word
    - word_vectors: loaded word embedding model (e.g., Word2Vec)
    - use_preprocessing: whether to apply lemmatization/preprocessing
    - similarity_threshold: minimum similarity to consider (default 0.5);
      NOTE: accepted for future use but not applied anywhere in this function
    
    Returns:
    - pmi_scores: dictionary of PMI scores for each word
    - words: tokenized sentence
    """
    import math
    
    # Tokenize based on preprocessing choice
    if use_preprocessing:
        words = preprocess_text(sentence, keep_stopwords=True, remove_punct=False)
    else:
        words = sentence.lower().split()
    
    anchor_word = words[anchor_idx]
    pmi_scores = {}
    
    # Helper function to compute similarity with fallback
    def get_similarity(w1, w2):
        """Get similarity between two words, with fallback for OOV words"""
        try:
            # Cosine similarity from word vectors (range: -1 to 1)
            sim = word_vectors.similarity(w1, w2)
            # Normalize to [0, 1] range
            return max(0, (sim + 1) / 2)
        except KeyError:
            # If word not in vocabulary, use exact match (0 or 1)
            return 1.0 if w1 == w2 else 0.0
    
    for other_idx in range(len(words)):
        if other_idx == anchor_idx:
            continue
        
        pattern_idx = other_idx if other_idx < anchor_idx else other_idx - 1
        if pattern_idx >= len(all_responses):
            continue
            
        responses = all_responses[pattern_idx]
        if not responses:
            continue
        
        # Process responses with same preprocessing
        anchor_replacements = []
        other_replacements = []
        
        for r in responses:
            if len(r) == 2:
                if use_preprocessing:
                    anchor_tokens = preprocess_text(r[0], keep_stopwords=True, remove_punct=False)
                    other_tokens = preprocess_text(r[1], keep_stopwords=True, remove_punct=False)
                    if anchor_tokens and other_tokens:
                        anchor_replacements.append(anchor_tokens[0])
                        other_replacements.append(other_tokens[0])
                else:
                    anchor_replacements.append(r[0].lower())
                    other_replacements.append(r[1].lower())
        
        if not anchor_replacements:
            continue
            
        total = len(anchor_replacements)
        
        # Calculate SOFT probabilities using similarity scores
        # Instead of hard 0/1 matching, we use similarity scores
        similarity_x = sum(get_similarity(w, anchor_word) for w in anchor_replacements)
        similarity_y = sum(get_similarity(w, words[other_idx]) for w in other_replacements)
        similarity_xy = sum(
            get_similarity(anchor_replacements[i], anchor_word) * 
            get_similarity(other_replacements[i], words[other_idx]) 
            for i in range(total)
        )
        
        # Soft probabilities (normalized by total count)
        P_x = similarity_x / total if total > 0 else 0
        P_y = similarity_y / total if total > 0 else 0
        P_xy = similarity_xy / total if total > 0 else 0
        
        # Calculate PMI
        if P_x > 0 and P_y > 0 and P_xy > 0:
            pmi = math.log2(P_xy / (P_x * P_y))
        else:
            pmi = float('-inf')
        
        pmi_scores[other_idx] = {
            'word': words[other_idx], 
            'pmi': pmi, 
            'P_x': P_x, 
            'P_y': P_y, 
            'P_xy': P_xy,
            'similarity_x': similarity_x / total if total > 0 else 0,
            'similarity_y': similarity_y / total if total > 0 else 0
        }
    
    return pmi_scores, words

print("\n" + "=" * 80)
print("PMI with Word Embedding Similarity - Function Defined")
print("=" * 80)
print("compute_pmi_with_similarity() function ready to use")



PMI with Word Embedding Similarity - Function Defined
compute_pmi_with_similarity() function ready to use


### Key improvements:
- Uses cosine similarity instead of exact matching
- Gives partial credit for semantically similar words
- Handles out-of-vocabulary words gracefully
- Better captures semantic associations in PMI scores
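
Written out, the soft probabilities computed in `compute_pmi_with_similarity` take the following form. The notation is introduced here for illustration: $N$ is the number of usable responses, $(a_i, b_i)$ is the $i$-th pair of fills for the anchor slot and the other slot, $x$ is the anchor word, $y$ is the other word, and $s(\cdot,\cdot) \in [0, 1]$ is the rescaled cosine similarity.

$$
\hat{P}(x) = \frac{1}{N}\sum_{i=1}^{N} s(a_i, x), \qquad
\hat{P}(y) = \frac{1}{N}\sum_{i=1}^{N} s(b_i, y), \qquad
\hat{P}(x, y) = \frac{1}{N}\sum_{i=1}^{N} s(a_i, x)\, s(b_i, y)
$$

$$
\mathrm{PMI}(x, y) = \log_2 \frac{\hat{P}(x, y)}{\hat{P}(x)\,\hat{P}(y)}
$$

When $s$ is replaced by the 0/1 indicator of an exact match, these expressions reduce to the usual relative-frequency estimates.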


In [13]:
# Demonstration: Compare exact matching vs similarity-based matching
print("=" * 80)
print("COMPARISON: Exact Matching vs Word Embedding Similarity")
print("=" * 80)

# Use the comparison sentence from earlier
comparison_sentence = "doctors assess symptoms to diagnose diseases"
anchor_idx = 0  # "doctors"

print(f"\nTest sentence: '{comparison_sentence}'")
print(f"Anchor word (index {anchor_idx}): '{comparison_sentence.split()[anchor_idx]}'")

# Make sure we have responses loaded
if 'all_responses' not in globals() or not all_responses:
    print("\n⚠ Warning: No responses loaded. Using example responses for demonstration.")
    # Create example responses for demonstration
    all_responses = [
        [['physicians', 'evaluate'], ['doctors', 'examine'], ['medical professionals', 'check']],
        [['examine', 'indicators'], ['evaluate', 'signs'], ['assess', 'symptoms']],
        [['signs', 'for'], ['indicators', 'to'], ['symptoms', 'to']],
        [['identify', 'illnesses'], ['detect', 'conditions'], ['diagnose', 'diseases']]
    ]

print("\n" + "-" * 80)
print("METHOD 1: EXACT MATCHING (original approach)")
print("-" * 80)
try:
    pmi_exact, words_exact = compute_pmi_enhanced(comparison_sentence, all_responses, anchor_idx, use_preprocessing=True)
    for idx in sorted(pmi_exact.keys()):
        data = pmi_exact[idx]
        pmi_val = f"{data['pmi']:.3f}" if data['pmi'] != float('-inf') else "N/A"
        print(f"  {data['word']:<15} PMI={pmi_val:>7}  P(x)={data['P_x']:.3f} P(y)={data['P_y']:.3f} P(xy)={data['P_xy']:.3f}")
except NameError:
    print("  compute_pmi_enhanced function not found. Skipping exact matching comparison.")

print("\n" + "-" * 80)
print("METHOD 2: SIMILARITY-BASED MATCHING (with word embeddings)")
print("-" * 80)
pmi_similarity, words_similarity = compute_pmi_with_similarity(
    comparison_sentence, all_responses, anchor_idx, word_vectors, use_preprocessing=True
)
for idx in sorted(pmi_similarity.keys()):
    data = pmi_similarity[idx]
    pmi_val = f"{data['pmi']:.3f}" if data['pmi'] != float('-inf') else "N/A"
    avg_sim = (data['similarity_x'] + data['similarity_y']) / 2
    print(f"  {data['word']:<15} PMI={pmi_val:>7}  P(x)={data['P_x']:.3f} P(y)={data['P_y']:.3f} P(xy)={data['P_xy']:.3f}  avg_sim={avg_sim:.3f}")


COMPARISON: Exact Matching vs Word Embedding Similarity

Test sentence: 'doctors assess symptoms to diagnose diseases'
Anchor word (index 0): 'doctors'

--------------------------------------------------------------------------------
METHOD 1: EXACT MATCHING (original approach)
--------------------------------------------------------------------------------
  assess          PMI=    N/A  P(x)=0.333 P(y)=0.000 P(xy)=0.000
  symptom         PMI=    N/A  P(x)=0.000 P(y)=0.333 P(xy)=0.000
  to              PMI=    N/A  P(x)=0.000 P(y)=0.667 P(xy)=0.000
  diagnose        PMI=    N/A  P(x)=0.000 P(y)=0.000 P(xy)=0.000

--------------------------------------------------------------------------------
METHOD 2: SIMILARITY-BASED MATCHING (with word embeddings)
--------------------------------------------------------------------------------
  assess          PMI=  0.006  P(x)=0.854 P(y)=0.805 P(xy)=0.690  avg_sim=0.829
  symptom         PMI=  0.002  P(x)=0.547 P(y)=0.765 P(xy)=0.419  avg_sim=0.65

### Key Differences:

**EXACT MATCHING:**
- Only counts perfect word matches (doctors == doctors → 1, else → 0)
- Misses semantically similar words (doctors vs physicians)
- Lower probability estimates due to strict matching
- May produce -inf PMI when no exact matches occur

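**SIMILARITY-BASED MATCHING:**
- Uses cosine similarity from word embeddings, rescaled to the 0 to 1 range
- Gives partial credit for similar words (doctors/physicians ≈ 0.7)
- Higher probability estimates by capturing semantic similarity
- Yields finite PMI scores even when exact matches are rare, because every response contributes a non-zero rescaled similarity; in the run above the resulting values are close to 0 (0.006 and 0.002)
- Can reflect semantic associations that exact matching misses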

**WHEN TO USE EACH:**
- Exact matching: When precise word choice matters (e.g., sentiment analysis)
- Similarity matching: When semantic meaning matters (e.g., topic modeling, Q&A)
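
The near-zero PMI values in the similarity-based output above may be partly a side effect of rescaling cosine to [0, 1]: every fill then earns at least some credit, which pulls the ratio inside the logarithm toward 1. The sketch below compares three scoring functions on made-up data. The fills, the `RAW_COS` values and all helper names are hypothetical, and the clipped variant `max(0, cos)` is an alternative introduced here for comparison, not something the notebook uses.

```python
import math

def soft_pmi(pairs, x, y, score):
    """PMI (base 2) where score(a, b) in [0, 1] replaces 0/1 matching."""
    n = len(pairs)
    p_x = sum(score(a, x) for a, _ in pairs) / n
    p_y = sum(score(b, y) for _, b in pairs) / n
    p_xy = sum(score(a, x) * score(b, y) for a, b in pairs) / n
    if min(p_x, p_y, p_xy) <= 0:
        return float("-inf")
    return math.log2(p_xy / (p_x * p_y))

# Hypothetical raw cosine values (made up for this illustration)
RAW_COS = {("physician", "doctor"): 0.8, ("teacher", "doctor"): 0.2,
           ("evaluate", "assess"): 0.8, ("grade", "assess"): 0.2}

def raw_cos(a, b):
    return 1.0 if a == b else RAW_COS.get((a, b), 0.0)

def exact(a, b):
    return 1.0 if a == b else 0.0

def rescaled(a, b):
    return (raw_cos(a, b) + 1.0) / 2.0   # maps [-1, 1] to [0, 1]

def clipped(a, b):
    return max(0.0, raw_cos(a, b))       # negative cosines count as 0

# Made-up fills for the slots occupied by "doctor" and "assess"
pairs = [("doctor", "assess"), ("physician", "evaluate"), ("teacher", "grade")]

pmi_exact = soft_pmi(pairs, "doctor", "assess", exact)
pmi_rescaled = soft_pmi(pairs, "doctor", "assess", rescaled)
pmi_clipped = soft_pmi(pairs, "doctor", "assess", clipped)
print(f"exact={pmi_exact:.3f}  rescaled={pmi_rescaled:.3f}  clipped={pmi_clipped:.3f}")
```

On this toy data the rescaled score gives the smallest PMI of the three, which is consistent with the compressed values seen above. Real responses may behave differently.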


In [14]:
# Practical Example: The "metropolis" vs "city" case
print("=" * 80)
print("PRACTICAL EXAMPLE: Handling 'metropolis' vs 'city'")
print("=" * 80)

example_sentence = "Tokyo is the capital of Japan and a popular metropolis in the world"
words_example = example_sentence.lower().split()
example_anchor_idx = words_example.index("metropolis")  # index 9 (index 6 is 'and')

print(f"\nOriginal sentence: '{example_sentence}'")
print(f"Anchor word (index {example_anchor_idx}): '{words_example[example_anchor_idx]}'")

# Simulate model responses where it generates "city" instead of "metropolis".
# Illustrative only: entries for 'and', 'a' and 'popular' are omitted, and this
# list is not passed to compute_pmi_with_similarity() in this cell.
example_responses = [
    [['Tokyo', 'Tokyo'], ['Tokyo', 'Tokyo'], ['Tokyo', 'Tokyo']],  # idx 0: Tokyo
    [['is', 'is'], ['is', 'is'], ['is', 'is']],  # idx 1: is
    [['the', 'the'], ['the', 'the'], ['the', 'the']],  # idx 2: the
    [['capital', 'capital'], ['capital', 'city'], ['capital', 'capital']],  # idx 3: capital
    [['of', 'of'], ['of', 'of'], ['of', 'of']],  # idx 4: of
    [['Japan', 'Japan'], ['Japan', 'Japan'], ['Japan', 'Japan']],  # idx 5: Japan
    [['city', 'city'], ['city', 'Tokyo'], ['urban area', 'Japan']],  # responses when masking 'metropolis' (idx 9)
    [['in', 'in'], ['in', 'in'], ['in', 'in']],  # idx 10: in
    [['the', 'the'], ['the', 'the'], ['the', 'the']],  # idx 11: the
    [['world', 'world'], ['world', 'world'], ['world', 'globe']],  # idx 12: world
]

print("\n" + "-" * 80)
print("SCENARIO: Model generates 'city' instead of 'metropolis'")
print("-" * 80)
print("Sample model responses when masking 'metropolis':")
print("  - Generated: 'city' (not exact match)")
print("  - Generated: 'city' (not exact match)")
print("  - Generated: 'urban area' (not exact match)")

# Check similarity between metropolis and generated words
print("\nWord similarity analysis:")
generated_words = ['city', 'urban', 'area', 'town', 'metropolis']
for word in generated_words:
    try:
        sim = word_vectors.similarity('metropolis', word)
        normalized = (sim + 1) / 2
        print(f"  similarity('metropolis', '{word}') = {sim:.4f} (normalized: {normalized:.4f})")
    except KeyError:
        print(f"  similarity('metropolis', '{word}') = N/A (not in vocabulary)")


PRACTICAL EXAMPLE: Handling 'metropolis' vs 'city'

Original sentence: 'Tokyo is the capital of Japan and a popular metropolis in the world'
Anchor word (index 9): 'metropolis'

--------------------------------------------------------------------------------
SCENARIO: Model generates 'city' instead of 'metropolis'
--------------------------------------------------------------------------------
Sample model responses when masking 'metropolis':
  - Generated: 'city' (not exact match)
  - Generated: 'city' (not exact match)
  - Generated: 'urban area' (not exact match)

Word similarity analysis:
  similarity('metropolis', 'city') = 0.5717 (normalized: 0.7859)
  similarity('metropolis', 'urban') = 0.5147 (normalized: 0.7574)
  similarity('metropolis', 'area') = 0.3758 (normalized: 0.6879)
  similarity('metropolis', 'town') = 0.3758 (normalized: 0.6879)
  similarity('metropolis', 'metropolis') = 1.0000 (normalized: 1.0000)


### Comparison: How Each Method Handles This

**1. EXACT MATCHING:**
- 'metropolis' vs 'city': Match score = 0 (no match)
- 'metropolis' vs 'urban': Match score = 0 (no match)
- Result: the generated words earn no credit, so the association with 'metropolis' is underestimated and PMI can be undefined (-inf)

**2. SIMILARITY-BASED MATCHING:**
- 'metropolis' vs 'city': Match score ≈ 0.786 (rescaled cosine similarity)
- 'metropolis' vs 'urban': Match score ≈ 0.757 (rescaled cosine similarity)
- Result: the generated words earn partial credit, so PMI stays defined and the semantic association is captured

---

### Summary

Word embedding similarity provides a more nuanced understanding of word associations.

The similarity-based approach gives partial credit for semantically related words (e.g., 'city' and 'metropolis'), making it more robust to variations in model output. This better reflects semantic associations compared to exact string matching.
