# Interpretability and Evaluation Demo
This notebook demonstrates how to:
- Load a trained model checkpoint.
- Evaluate sample predictions.
- Visualize attention weights.
- Explore basic metrics like perplexity.

In [None]:
# If running in Colab or a fresh environment, you may need to install these packages first
# !pip install transformers accelerate torch datasets tqdm plotly
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "notebook"

## Load Model and Tokenizer
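This assumes a checkpoint was saved under `../output/epoch_2/` (relative to this notebook) or a similar directory containing the model config, weights, and tokenizer files.
Replace the path if necessary.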

In [None]:
model_path = "../output/epoch_2"  # Adjust this if needed
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMaskedLM.from_pretrained(model_path)
model.eval()
model.to("cuda" if torch.cuda.is_available() else "cpu")

## Sample Predictions
Let's test the model on a few masked sentences to see if it predicts reasonable tokens.

In [None]:
# Use the tokenizer's own mask token: "[MASK]" for BERT-style models, "<mask>" for RoBERTa-style models
mask = tokenizer.mask_token
examples = [
    f"The capital of France is {mask}.",
    f"Machine learning is a field of {mask} intelligence.",
    f"The {mask} brown fox jumps over the lazy dog.",
]

for text in examples:
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # (batch, seq_len, vocab_size)
    # Positions of the mask token in the first (and only) sequence
    mask_token_index = (inputs.input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    # Most likely token id at each mask position
    predicted_token_id = logits[0, mask_token_index].argmax(dim=-1)
    predicted_token = tokenizer.decode(predicted_token_id)
    print(f"Input: {text}\nPrediction: {predicted_token}\n")

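Because `argmax` shows only a single candidate, it can help to look at the few most likely tokens and their probabilities. A minimal NumPy sketch, assuming you pass in the logits for one mask position (for example `logits[0, mask_token_index[0]].cpu().numpy()` from the prediction loop); `top_k_predictions` is a hypothetical helper name:

```python
import numpy as np

def top_k_predictions(logits: np.ndarray, k: int = 5):
    """Return (token_ids, probabilities) of the k most likely tokens.

    `logits` is a 1-D array of unnormalized scores over the vocabulary,
    e.g. the row for one mask position.
    """
    shifted = logits - logits.max()           # subtract the max for numerical stability
    exp = np.exp(shifted)
    probs = exp / exp.sum()                   # softmax over the vocabulary
    top = np.argsort(probs)[::-1][:k]         # indices sorted by descending probability
    return top, probs[top]
```

The returned ids can be mapped back to strings with `tokenizer.convert_ids_to_tokens(ids.tolist())`. In PyTorch, `torch.topk(logits.softmax(dim=-1), k)` does the same thing without leaving the GPU.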
## Compute Perplexity on a Small Sample
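Perplexity is most natural for causal LMs: a causal LM assigns a probability to every token given the tokens before it, so the log-likelihood of a dataset (and hence its perplexity) follows directly. A masked LM only predicts the tokens that are hidden, given the rest of the sequence, so there is no single standard perplexity for it.

Here, we'll compute a rough approximation: mask a couple of random tokens and exponentiate the mean negative log-likelihood the model assigns to the true tokens at those positions. With so few tokens this is a noisy signal of model confidence, not a benchmark figure.
For a more reliable number with a masked LM, compute the standard MLM loss over a validation set and exponentiate it.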
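In formula form, perplexity over $N$ scored tokens is $\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(x_i \mid \text{context}_i)\right)$, where for an MLM the context is the input with the scored positions masked. A minimal NumPy sketch of that formula, assuming you already have the logits at the scored positions and the true token ids; `log_softmax` and `perplexity_from_logits` are hypothetical helper names:

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def perplexity_from_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """exp(mean negative log-likelihood) of `targets` under `logits`.

    logits:  (num_positions, vocab_size) scores at the evaluated (e.g. masked) positions
    targets: (num_positions,) true token ids at those positions
    """
    log_probs = log_softmax(logits)
    nll = -log_probs[np.arange(len(targets)), targets]  # NLL of each true token
    return float(np.exp(nll.mean()))
```

A useful sanity check: uniform logits over a vocabulary of size $V$ give a perplexity of exactly $V$, the score of a model that is no better than guessing.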


In [None]:
# Approximate "masked perplexity": mask a few tokens, then exponentiate the mean
# negative log-likelihood the model assigns to the true tokens at those positions.

test_text = "Deep learning models rely heavily on large amounts of training data."  # Adjust as needed
inputs = tokenizer(test_text, return_tensors="pt").to(model.device)
input_ids = inputs.input_ids.clone()

# Randomly mask a few tokens, excluding all special tokens ([CLS]/[SEP] for BERT, <s>/</s> for RoBERTa)
rnd = np.random.RandomState(42)
special_ids = torch.tensor(tokenizer.all_special_ids, device=input_ids.device)
maskable_positions = ~torch.isin(input_ids[0], special_ids)
mask_positions = rnd.choice(maskable_positions.nonzero(as_tuple=True)[0].cpu().numpy(), size=2, replace=False)
for pos in mask_positions:
    input_ids[0, pos] = tokenizer.mask_token_id

with torch.no_grad():
    outputs = model(input_ids)
logits = outputs.logits

loss_values = []
for pos in mask_positions:
    true_id = inputs.input_ids[0, pos]
    pred_dist = logits[0, pos]
    # Negative log-likelihood for the correct token
    nll = -torch.log_softmax(pred_dist, dim=-1)[true_id]
    loss_values.append(nll.item())

approx_loss = np.mean(loss_values)
approx_ppl = np.exp(approx_loss)
print(f"Approximate masked perplexity on sample: {approx_ppl:.2f}")

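The estimate above rests on only two masked tokens and changes with the random seed. A more systematic alternative for masked LMs is pseudo-perplexity (Salazar et al., 2020, "Masked Language Model Scoring"): mask each token one at a time, score the true token, and exponentiate the average negative log-likelihood. The sketch below is model-agnostic; `pseudo_perplexity` and its `log_prob_fn` callback are hypothetical names, not a `transformers` API:

```python
import math
from typing import Callable, FrozenSet, List, Sequence

def pseudo_perplexity(
    token_ids: Sequence[int],
    mask_id: int,
    log_prob_fn: Callable[[List[int], int], Sequence[float]],
    skip_ids: FrozenSet[int] = frozenset(),
) -> float:
    """Mask each scorable position in turn and score its true token.

    `log_prob_fn(masked_ids, pos)` is a hypothetical callback that runs the
    masked LM on `masked_ids` and returns log-probabilities over the
    vocabulary at position `pos`.
    """
    total_nll, count = 0.0, 0
    for pos, true_id in enumerate(token_ids):
        if true_id in skip_ids:              # e.g. special tokens are not scored
            continue
        masked = list(token_ids)
        masked[pos] = mask_id                # hide exactly one token per forward pass
        log_probs = log_prob_fn(masked, pos)
        total_nll -= log_probs[true_id]      # accumulate negative log-likelihood
        count += 1
    if count == 0:
        raise ValueError("no scorable tokens")
    return math.exp(total_nll / count)
```

With the notebook's `model`, `log_prob_fn` could wrap a forward pass on `torch.tensor([masked_ids], device=model.device)` and return `torch.log_softmax(logits[0, pos], dim=-1).tolist()`; passing `skip_ids=frozenset(tokenizer.all_special_ids)` leaves special tokens unscored. It costs one forward pass per token, so keep it to short texts or batch the masked copies.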
## Visualizing Attention
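For interpretability, let's visualize the attention weights of one head in the model's last layer for a single input.
Note: Hugging Face models do not return attention weights unless asked. Pass `output_attentions=True`, either to `from_pretrained` (which stores it in the model config) or at call time, as in `model(**inputs, output_attentions=True)`. With recent `transformers` versions, the default SDPA attention implementation may fall back to eager attention with a warning or not return weights at all; loading the model with `attn_implementation="eager"` avoids this.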
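Each matrix in the returned attentions holds, for one head, the softmax-normalized scores of scaled dot-product attention (Vaswani et al., 2017), $\mathrm{softmax}(QK^\top/\sqrt{d_k})$, with one row per query token. A NumPy sketch of that computation for a single head; it ignores padding masks and dropout, which real models also apply, and `attention_weights` is a hypothetical helper name:

```python
import numpy as np

def attention_weights(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """softmax(Q K^T / sqrt(d_k)) for one head.

    q, k: (seq_len, d_k) query and key vectors. Row i of the result is the
    distribution of query token i over all key tokens, so it sums to 1.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])          # scaled dot products
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)
```

This is why every row of the heatmap sums to 1: the plot shows a probability distribution per query token, not a symmetric similarity matrix.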

In [None]:
# Re-load the model with attention outputs (if supported by the model architecture)
# If this doesn't work, check if your chosen model supports attention outputs.
model_att = AutoModelForMaskedLM.from_pretrained(model_path, output_attentions=True).to(model.device)
model_att.eval()

test_sentence = "The quick brown fox jumps over the lazy dog."
att_inputs = tokenizer(test_sentence, return_tensors="pt").to(model.device)

with torch.no_grad():
    att_outputs = model_att(**att_inputs)
attentions = att_outputs.attentions  # tuple with one tensor per layer, each (batch, heads, seq_len, seq_len)

# We'll visualize attentions from the last layer
last_layer_attn = attentions[-1][0]  # (heads, seq_len, seq_len), taking the first (and only) batch
head_to_visualize = 0
att_matrix = last_layer_attn[head_to_visualize].cpu().numpy()

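# Create a token-to-token attention heatmap (rows: query tokens, columns: key tokens).
# Prefix each token with its position so repeated tokens (e.g. "the") stay distinct on the axes.
tokens = tokenizer.convert_ids_to_tokens(att_inputs.input_ids[0].tolist())
labels = [f"{i}:{tok}" for i, tok in enumerate(tokens)]
# Attention weights lie in [0, 1], so a sequential color scale reads better than a diverging one
fig = px.imshow(att_matrix, x=labels, y=labels,
                labels=dict(x="Key (attended to)", y="Query (attending)", color="Weight"),
                color_continuous_scale="Blues",
                title=f"Attention Head Visualization (Last Layer, Head {head_to_visualize})")
fig.update_xaxes(side="top", tickangle=45)
fig.show()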

### Interpretation
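Each row shows how one token (the query) distributes its attention over all tokens in the sequence (the keys); every row sums to 1. For example, the determiner "The" might attend strongly to the noun it modifies, "fox". Patterns vary widely by head and layer, though: many heads in BERT-style models put much of their weight on `[CLS]`, `[SEP]`, or neighboring tokens. A single head is a narrow view, so treat these patterns as hints rather than explanations of the model's predictions.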
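A heatmap is easy to eyeball but hard to compare across heads. One simple complement is to list, for each query token, the key tokens it weights most heavily. A sketch, where `strongest_attention` is a hypothetical helper that takes the `att_matrix` and `tokens` from the visualization cell:

```python
import numpy as np

def strongest_attention(att_matrix, tokens, k=1):
    """For each query token (row of att_matrix), return the k key tokens with the highest weight.

    Returns a list of (query_token, [(key_token, weight), ...]) in sequence order.
    """
    summary = []
    for i, row in enumerate(np.asarray(att_matrix)):
        top = np.argsort(row)[::-1][:k]   # key indices, highest weight first
        summary.append((tokens[i], [(tokens[j], float(row[j])) for j in top]))
    return summary
```

Running it over every head of `last_layer_attn` (or every layer in `attentions`) is a quick way to spot heads that mostly attend to `[CLS]`, `[SEP]`, or the adjacent token.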
