# Local Interpretable Model-Agnostic Explanations (LIME) for Text

In this notebook, we will:

1. Load a fine-tuned GPT-2 sentiment classifier.
2. Define your example review and tokenize it.
3. Generate perturbed samples of a review in the **interpretable domain** (binary token presence).
4. Obtain GPT-2 predictions for each perturbed sample.
5. Define a cosine-distance kernel in binary space.
6. Build and train the linear surrogate model.
7. Extract and visualize feature importances.
8. Compare with the official LIME implementation.
9. Open question.

> **Key point**: For text, LIME works in two domains:
>
> - **Original domain**: token IDs → embeddings → GPT-2 outputs.
> - **Interpretable domain**: binary mask vectors indicating which tokens are present.
>
> We compute **distance** in the interpretable (binary) space, not on embeddings. This is standard practice in text LIME, since the surrogate works on the binary vectors.

<img src="./static/text_sent.png" alt="LO1 Image" style="width: 90%; height: auto;">  
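
To make the two domains concrete, here is a minimal sketch of the mapping from a binary mask back to a token sequence. The token IDs and the `neutral_id` value are made-up placeholders, not real GPT-2 vocabulary entries.

```python
import numpy as np

# Hypothetical token IDs for a 5-token review (placeholder values).
token_ids = np.array([101, 202, 303, 404, 505])
neutral_id = 7  # placeholder ID standing in for a neutral token

# Interpretable domain: 1 = keep the token, 0 = mask it.
mask = np.array([1, 0, 1, 1, 0])

# Original domain: masked positions are replaced by the neutral token.
perturbed_ids = np.where(mask == 1, token_ids, neutral_id)
print(perturbed_ids.tolist())  # [101, 7, 303, 404, 7]
```

The surrogate only ever sees `mask`; the black-box model only ever sees `perturbed_ids`.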

### 0. Set up environment and imports

In [None]:
import torch # For machine learning
import numpy as np # Array manipulation 
import math # For generating perturbed data
import matplotlib.pyplot as plt # For plotting
import torch.nn as nn # machine learning layers
from IPython.display import HTML, display  # For display purposes
from lime.lime_text import LimeTextExplainer # Inbuilt LIME library for results comparison
from torch.utils.data import DataLoader, TensorDataset # for data handling and batch generation
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification # For mapping words to tokens and then to embeddings (model dependent)
import warnings
warnings.filterwarnings("ignore")  # suppress deprecation and trivial warnings

### 1. Load the fine-tuned GPT-2 model

In [None]:
# Load tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("hipnologo/gpt2-imdb-finetune") # A tokenizer converts text to tokens and is model-dependent
model = GPT2ForSequenceClassification.from_pretrained("hipnologo/gpt2-imdb-finetune") # A model for text classification fine-tuned on sentiment analysis
model.eval() # Set the model to evaluation mode (disables dropout)

# Use GPU for model if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) # Display model details

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=2, bias=False)
)
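
The `Dropout(p=0.1)` layers in the printout are the reason `model.eval()` matters here. A small standalone sketch (a toy `nn.Dropout` layer with an exaggerated `p`, not the GPT-2 model itself) shows the difference between the two modes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)  # toy layer; p is larger than GPT-2's 0.1 to make the effect visible
x = torch.ones(1000)

drop.train()              # training mode: elements are zeroed at random, survivors scaled by 1/(1-p)
train_out = drop(x)

drop.eval()               # evaluation mode: dropout acts as the identity
eval_out = drop(x)

print("zeros in train mode:", int((train_out == 0).sum()))
print("eval output unchanged:", torch.equal(eval_out, x))
```

Without evaluation mode, the same input could produce different predictions on each call, which would add noise to the targets the surrogate is trained on.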

### 2. Define your example review and tokenization

In [None]:
# YOUR CODE HERE: set a short IMDB-style review string preferably <= 10 words

token_ids = tokenizer.encode(review, add_special_tokens=False)
print("Token IDs:", token_ids)

We will use the token **"the"** as a neutral token for masking, since it carries minimal sentiment.

In [None]:
neutral_id = tokenizer.encode("the", add_special_tokens=False)[0]
print(f"Neutral token ID: {neutral_id}")

### 3. Generate perturbed samples in the interpretable domain

In [None]:
def perturb_tokens(token_ids, neutral_id, max_mask_ratio=0.6):
    """
    Generate (sampled) perturbations by randomly masking subsets of tokens.
    Returns perturbed token sequences and binary masks (1 = keep, 0 = mask).
    """
    n = len(token_ids)
    max_masks = int(n * max_mask_ratio)
    masks, sequences = [], []
    rng = np.random.default_rng(42)
    for mask_count in range(1, max_masks + 1):
        total_combos = math.comb(n, mask_count)
        samples = min(50, total_combos)
        for _ in range(samples):
            idxs = rng.choice(n, size=mask_count, replace=False)
            mask = np.ones(n, dtype=int)
            seq = token_ids.copy()
            for i in idxs:
                mask[i] = 0
                seq[i] = neutral_id
            masks.append(mask)
            sequences.append(seq)
    sequences = torch.tensor(sequences)
    masks = torch.tensor(np.array(masks))  # stack into one array first; converting a list of arrays is slow
    return sequences, masks

Z_ids, Z_masks = perturb_tokens(token_ids, neutral_id)
print("Perturbations:", Z_ids.shape, Z_masks.shape)
print(Z_ids[100])
print(Z_masks[100])
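
The number of perturbed samples follows from the loop structure: for each mask count `k` from 1 to `int(n * max_mask_ratio)`, the function draws `min(50, C(n, k))` masks. A short sketch of that arithmetic (the helper name `count_perturbations` is introduced here for illustration only):

```python
import math

def count_perturbations(n, max_mask_ratio=0.6, cap=50):
    """Number of samples perturb_tokens would draw for a sequence of n tokens."""
    max_masks = int(n * max_mask_ratio)
    return sum(min(cap, math.comb(n, k)) for k in range(1, max_masks + 1))

print(count_perturbations(10))  # 10 + 45 + 50 * 4 = 255
print(count_perturbations(5))   # 5 + 10 + 10 = 25
```

Because each mask is drawn independently, the same mask can appear more than once. Note also that indexing `Z_ids[100]` needs at least 101 samples, so a very short review (for example, 5 tokens) would be out of range.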

### 4. Obtain GPT-2 predictions for each perturbed sample

In [None]:
def get_prediction_probabilities(input_token_ids):
    """
    Takes a batch of token IDs and returns the model's prediction probabilities.

    input_token_ids: A tensor of shape (batch_size, sequence_length) containing the token IDs.

    Returns: A tensor of shape (batch_size, num_classes) with the probabilities for each class, located on the CPU.
    """
    # We use torch.no_grad() to tell PyTorch we are only doing inference (predicting),
    # not training. This makes the code run faster and use less memory.
    with torch.no_grad():
        
        # --- Step 1: Move input data to the correct device (e.g., a GPU) ---
        # The model and its data must be on the same device to work together.
        input_token_ids = input_token_ids.to(device)
        
        # --- Step 2: Get the raw model outputs (called "logits") ---
        # We pass the token IDs to the model. The model's output is often an object,
        # so we access the raw prediction scores with `.logits`.
     
        raw_logits =  # YOUR CODE HERE # Shape will be (batch_size, num_classes), e.g., (1, 2)
        
        # --- Step 3: Convert logits into probabilities ---
        # The softmax function turns the raw scores into probabilities that sum to 1.
        # `dim=-1` tells softmax to operate on the last dimension (the class scores).
        probabilities = # YOUR CODE HERE
        
        # --- Step 4: Move the result back to the CPU ---
        # It's good practice to move data back to the CPU so you can easily
        # use it with other libraries like NumPy or Matplotlib.
        probabilities_on_cpu = probabilities.cpu()

    return probabilities_on_cpu

<details>
  <summary>⚠️ Click here for the solution (this will use your "solution pass")</summary>
  
  ```python
def get_prediction_probabilities(input_token_ids):
    """
    Takes a batch of token IDs and returns the model's prediction probabilities.
    
    Args:
        input_token_ids (torch.Tensor): A tensor of shape (batch_size, sequence_length)
                                        containing the token IDs.
                                        
    Returns:
        torch.Tensor: A tensor of shape (batch_size, num_classes) with the probabilities
                      for each class, located on the CPU.
    """
    # We use `torch.no_grad()` to tell PyTorch we are only doing inference (predicting),
    # not training. This makes the code run faster and use less memory.
    with torch.no_grad():
        
        # --- Step 1: Move input data to the correct device (e.g., a GPU) ---
        # The model and its data must be on the same device to work together.
        input_token_ids = input_token_ids.to(device)
        
        # --- Step 2: Get the raw model outputs (called "logits") ---
        # We pass the token IDs to the model. The model's output is often an object,
        # so we access the raw prediction scores with `.logits`.
        model_outputs = model(input_token_ids)
        raw_logits = model_outputs.logits  # Shape will be (batch_size, num_classes), e.g., (1, 2)
        
        # --- Step 3: Convert logits into probabilities ---
        # The softmax function turns the raw scores into probabilities that sum to 1.
        # `dim=-1` tells softmax to operate on the last dimension (the class scores).
        probabilities = torch.softmax(raw_logits, dim=-1)
        
        # --- Step 4: Move the result back to the CPU ---
        # It's good practice to move data back to the CPU so you can easily
        # use it with other libraries like NumPy or Matplotlib.
        probabilities_on_cpu = probabilities.cpu()

    return probabilities_on_cpu
  ```

</details>

In [None]:
# Batch through DataLoader. A DataLoader helps manage large datasets by breaking them into smaller batches. 
# In this case, we are using it to get prediction probabilities for each batch.
dataset = TensorDataset(Z_ids)
loader = DataLoader(dataset, batch_size=32)

probs = [] # List to store probabilities for each batch
for (batch_ids,) in loader:
    probs.append(get_prediction_probabilities(batch_ids))
Y = torch.cat(probs, dim=0)  # shape (num_samples, 2)
print("Prediction probs:", Y.shape)
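
For reference, the softmax step that turns logits into probabilities can be written out by hand. The sketch below uses NumPy and made-up logits; it mirrors what `torch.softmax(..., dim=-1)` computes but is not the notebook's own code.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Made-up logits for a batch of 2 samples and 2 classes (neg, pos).
logits = np.array([[2.0, 0.0],
                   [-1.0, 1.0]])
probs = softmax(logits)
print(probs)
print(probs.sum(axis=-1))  # each row sums to 1
```

With two classes, the positive-class probability depends only on the difference between the two logits.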

### 5. Define cosine-distance kernel in binary space

The LIME weight for each sample is:

$$w_i = \exp\Big(-\frac{(1 - \cos(x',z'))^2}{\sigma^2}\Big)$$

where $x'$ is the original binary vector (all ones) and $z'$ is a perturbed mask.

> **Note**: The **distance** is 1 − cosine similarity, and it is squared in the exponent. Cosine similarity is defined as $\cos(x, y) = \frac{x \cdot y}{\lVert x \rVert \, \lVert y \rVert}$.

In [None]:
# Define the weight above. This weight is larger for augmentations closer to the original input

original_mask = torch.ones_like(Z_masks[0], dtype=torch.float32)  # all tokens present representation of the sentence
cos = nn.CosineSimilarity(dim=1) # Can handle batched inputs

def kernel_weights(masks, sigma=0.5):
    # YOUR CODE HERE 
    return w

<details>
  <summary>⚠️ Click here for the solution (this will use your "solution pass")</summary>
  
  ```python
def kernel_weights(masks, sigma=0.5):
    # masks: (batch, n_tokens)
    dist = 1 - cos(masks.float(), original_mask)
    w = torch.exp(-(dist ** 2) / (sigma ** 2))
    return w
  ```

</details>

In [None]:
# test
masks = Z_masks[:10]  # take first 10 masks for testing
weights = kernel_weights(masks)
print("Kernel weights:", weights.shape)
# Expected output: Kernel weights: torch.Size([10])
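
In binary space the kernel has a simple closed form. If $k$ of the $n$ tokens are masked, then $x' \cdot z' = n - k$, $\lVert x' \rVert = \sqrt{n}$ and $\lVert z' \rVert = \sqrt{n - k}$, so $\cos(x', z') = \sqrt{(n - k)/n}$. The NumPy sketch below evaluates the resulting weights for an assumed 10-token review; `weight_for_k_masked` is a name introduced here for illustration.

```python
import numpy as np

def weight_for_k_masked(n, k, sigma=0.5):
    """Kernel weight when k of n tokens are masked (requires k < n)."""
    cos_sim = np.sqrt((n - k) / n)
    dist = 1.0 - cos_sim
    return np.exp(-(dist ** 2) / (sigma ** 2))

n = 10  # assumed number of tokens
for k in range(0, 7):
    print(f"k={k}: w={weight_for_k_masked(n, k):.4f}")
```

The weight is exactly 1 when nothing is masked and decreases monotonically as more tokens are masked, so every mask with the same number of zeros receives the same weight.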

### 6. Build and train the linear surrogate model

Step 1: Building Our Simple Surrogate Model.  
 
First, we need to define the structure of our simple model. We'll create a standard PyTorch model that just has one single linear layer.

In [None]:
# This is the blueprint for our simple linear model.
# It inherits from nn.Module, the base class for all models in PyTorch.
class LinearSurrogate(nn.Module):
    
    # The __init__ method sets up the model's layers.
    def __init__(self, num_input_features):
        super().__init__() # Always call this first!
        # We define a single linear layer.
        # It takes num_input_features as input and produces 1 single number as output.
        self.linear = nn.Linear(num_input_features, 1)

    # The forward method defines what happens when we pass data through the model.
    def forward(self, input_data):
        # The data just goes through our one linear layer.
        return self.linear(input_data)
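
Because the surrogate is a single linear layer, it holds exactly one weight per token plus a bias. A quick check on a toy layer (4 input features, chosen arbitrarily) illustrates the shapes and shows that the all-ones mask yields the sum of the weights plus the bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 1)            # same structure as the surrogate's single layer
print(layer.weight.shape)          # torch.Size([1, 4]): one weight per input feature
print(layer.bias.shape)            # torch.Size([1])

all_present = torch.ones(1, 4)     # mask with every token kept
with torch.no_grad():
    out = layer(all_present)
    expected = layer.weight.sum() + layer.bias
print(torch.allclose(out.squeeze(), expected.squeeze()))  # True
```

Setting one mask entry to 0 removes exactly that token's weight from the output, which is what makes the weights readable as per-token contributions.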

Step 2: Preparing the Data for Training.  
 
Now, what data do we train this simple model on?
- Inputs (X): The binary masks we generated (Z_masks). These represent which features were "on" or "off."
- Targets (Y): The predictions from our original, complex model (Y).  

This is the key trick! We're teaching our simple model to predict what our complex model would predict.

In [None]:
# We pair up our masks (inputs) with the original model's predictions (targets).
# This creates our training dataset.
training_data = TensorDataset(Z_masks.float(), Y)

# The DataLoader will feed this data to our model in small, shuffled batches.
# Shuffling helps the model learn better and not get stuck on any ordering in the data.
train_loader = DataLoader(training_data, batch_size=64, shuffle=True)
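
The objective being fitted is a weighted least-squares problem, so on small data it can also be solved directly rather than by gradient descent. The sketch below does this with NumPy on synthetic masks and targets. All values are invented for illustration, and this is an alternative to the training loop used in this notebook, not a description of it.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_tokens = 200, 5

# Synthetic interpretable-domain data: random binary masks.
X = rng.integers(0, 2, size=(n_samples, n_tokens)).astype(float)
true_coefs = np.array([0.3, -0.2, 0.0, 0.1, 0.05])   # invented per-token effects
y = 0.5 + X @ true_coefs                              # invented noiseless "black-box" outputs
w = rng.uniform(0.1, 1.0, size=n_samples)             # invented sample weights

# Weighted least squares: scale rows by sqrt(w), then solve ordinary least squares.
X_aug = np.hstack([X, np.ones((n_samples, 1))])       # last column models the bias
sqrt_w = np.sqrt(w)[:, None]
beta, *_ = np.linalg.lstsq(X_aug * sqrt_w, y * sqrt_w[:, 0], rcond=None)

print("coefficients:", np.round(beta[:-1], 3))
print("intercept:", round(float(beta[-1]), 3))
```

Since the synthetic targets are exactly linear in the masks, the solver recovers the invented coefficients; with real model outputs the fit is only an approximation.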

Step 3: The PyTorch Training Loop Recipe

This is the standard "recipe" for training almost any model in PyTorch. We'll go through it step-by-step.

In [None]:
# Setup for Training
# 1. Create an instance of our simple model.
# The input size is the number of features in our masks.
surrogate_model = LinearSurrogate(num_input_features=Z_masks.shape[1])

# 2. Choose an optimizer.
# The optimizer's job is to adjust the model's weights to reduce the error.
# Adam is a popular and effective choice. lr is the learning rate.
optimizer = torch.optim.Adam(surrogate_model.parameters(), lr=0.05)


# The Training Loop 
for epoch in range(10):
    
    total_loss_for_epoch = 0
    
    # The DataLoader gives us one batch of data at a time.
    for masks_batch, predictions_batch in train_loader:
        
        # --- The 5 Core Steps of a Training Iteration ---
        
        # 1. PREPARE THE DATA
        # Our original model gave probabilities for two classes [prob_class_0, prob_class_1].
        # We only want to explain the probability of the positive class (class 1).
        true_predictions = predictions_batch[:, 1:2]
        
        # Calculate weights to focus the model on "important" samples (if needed).
        # This makes the model pay more attention to masks that are similar to the original input.
        sample_weights = kernel_weights(masks_batch, sigma=0.5)

        # 2. MAKE A PREDICTION
        # Get the surrogate model's prediction for the current batch of masks.
        surrogate_prediction = surrogate_model(masks_batch)

        # 3. CALCULATE THE LOSS (the error)
        # We use a "Weighted Mean Squared Error". It measures how far off the prediction is,
        # but gives more importance to the samples with higher weights.
        # You have everything you need to compose the loss seen in the lecture notes
        
        loss = # YOUR CODE

        # 4. BACKPROPAGATION
        # a) Reset previous calculations.
        optimizer.zero_grad()
        # b) Calculate how to adjust the weights to reduce the loss.
        loss.backward()

        # 5. UPDATE THE WEIGHTS
        # The optimizer takes a small step to improve the model's weights.
        optimizer.step()
        
        total_loss_for_epoch += loss.item()
        
    # Print the average loss for this epoch to see if the model is learning.
    average_loss = total_loss_for_epoch / len(train_loader)
    print(f"Epoch {epoch+1}, Average Loss = {average_loss:.4f}")
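
When combining per-sample weights with per-sample errors, tensor shapes deserve a check: a weight vector of shape `(batch,)` multiplied by an error tensor of shape `(batch, 1)` broadcasts to `(batch, batch)` instead of weighting each sample once. The toy tensors below are invented to show the effect.

```python
import torch

weights = torch.tensor([1.0, 0.5, 0.1])            # shape (3,), like a kernel-weight vector
errors = torch.tensor([[4.0], [2.0], [10.0]])      # shape (3, 1), like squared errors

wrong = weights * errors                 # broadcasts to shape (3, 3)
right = weights.unsqueeze(1) * errors    # shape (3, 1): one weight per sample

print(wrong.shape, right.shape)
print(wrong.mean().item(), right.mean().item())
```

The broadcast version averages every weight against every error, so the kernel no longer emphasizes the samples closest to the original input.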

<details>
  <summary>⚠️ Click here for the solution (this will use your "solution pass")</summary>
  
  ```python
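error = (true_predictions - surrogate_prediction) ** 2
# sample_weights has shape (batch,); add a trailing dimension so it lines up with error, shape (batch, 1).
# Without it, the product would broadcast to (batch, batch).
weighted_error = sample_weights.unsqueeze(1) * error
loss = weighted_error.mean()
  ```

</details>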

### 7. Extract and visualize feature importances

In [None]:
# 7.1 | Extract raw surrogate weights and corresponding tokens
w = surrogate_model.linear.weight.detach().cpu().squeeze().numpy()
tokens = tokenizer.convert_ids_to_tokens(token_ids)
# 7.2 | Aggregate duplicate tokens by averaging their weights
from collections import OrderedDict
agg = OrderedDict()
for tok, weight in zip(tokens, w):
    word = tok.lstrip('Ġ')
    agg.setdefault(word, []).append(weight)
words, vals = zip(*[(word, np.mean(weights)) for word, weights in agg.items()])
# 7.3 | Sort by absolute importance (descending)
sorted_items = sorted(zip(words, vals), key=lambda x: abs(x[1]), reverse=True)
words, vals = zip(*sorted_items)
# 7.4 | Plot from most to least important
plt.figure(figsize=(8, len(words) * 0.3))
colors = ['sandybrown' if v > 0 else 'skyblue' for v in vals]
plt.barh(words, vals, color=colors)
plt.xlabel('Surrogate weight')
plt.title('LIME explanation (interpretable domain) — Most to Least IMPORTANT')
plt.gca().invert_yaxis()  # highest importance on top
plt.tight_layout()
plt.show()
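
GPT-2's byte-level BPE marks a token that begins with a space using the `Ġ` prefix, which is why step 7.2 strips it before grouping. A minimal sketch of that aggregation, with invented tokens and weights:

```python
from collections import OrderedDict

# Invented tokens and surrogate weights for illustration.
tokens = ["The", "Ġmovie", "Ġwas", "Ġgreat", "Ġgreat"]
weights = [0.01, 0.05, -0.02, 0.30, 0.20]

agg = OrderedDict()
for tok, weight in zip(tokens, weights):
    word = tok.lstrip("Ġ")                  # drop the leading-space marker
    agg.setdefault(word, []).append(weight)

# Average the weights of repeated words.
averaged = {word: sum(ws) / len(ws) for word, ws in agg.items()}
print(averaged)
```

Averaging hides position effects: two occurrences of the same word can have quite different weights, so the per-position values are worth inspecting when a word repeats.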

### 8. Compare with the official LIME implementation

In [None]:
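# The base GPT-2 tokenizer defines no padding token. If this checkpoint does not
# define one either, fall back to EOS so that batched inputs can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id

def predict_for_lime(texts):
    # returns array of shape (n_samples, 2)
    enc = tokenizer(list(texts), return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**enc).logits
    return torch.softmax(logits, -1).cpu().numpy()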
# Create the explainer and explanation
explainer = LimeTextExplainer(class_names=['neg','pos'])
exp = explainer.explain_instance(
    review,
    predict_for_lime,
    labels=[1],
    num_features=len(token_ids),
    num_samples=500)
# Render the explanation as HTML and display
html = exp.as_html(labels=[1])
display(HTML(html))

### 9. Open question

**Experiment**: Change `sigma` in the kernel and observe how the explanation weights vary. Can you explain your observations?
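
One possible starting point for this experiment is to look at how `sigma` changes the kernel itself before retraining the surrogate. The sketch below assumes a 10-token review with 1 to 6 tokens masked, and uses the fact that the cosine similarity between an all-ones vector and a binary mask keeping $m$ of $n$ tokens is $\sqrt{m/n}$. The `sigma` values are arbitrary choices for illustration.

```python
import numpy as np

n = 10                                   # assumed number of tokens
k = np.arange(1, 7)                      # number of masked tokens
dist = 1.0 - np.sqrt((n - k) / n)        # cosine distance to the all-ones vector

ratios = {}
for sigma in (0.1, 0.5, 2.0):
    w = np.exp(-(dist ** 2) / (sigma ** 2))
    ratios[sigma] = w[-1] / w[0]         # weight of heavily masked vs. lightly masked samples
    print(f"sigma={sigma}: w(k=1)={w[0]:.3f}, w(k=6)={w[-1]:.3f}, ratio={ratios[sigma]:.3f}")
```

A small `sigma` concentrates the weight on lightly masked samples, while a large `sigma` weights all samples almost equally. How that shift shows up in the surrogate's coefficients is the part left to observe.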