# Notebook 4: Vision-LLM XAI & Evaluation

**Goal**: Interpret the Vision-LLM with attention visualization and evaluate text generation quality.

**Key Practices**:
- Extract attention maps from the Vision Encoder
- BLEU/ROUGE scores for text quality
- Compare ResNet Grad-CAM vs VLM Attention

In [None]:
# Install dependencies
!pip install -q evaluate nltk matplotlib seaborn

In [None]:
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
import evaluate

In [None]:
# Configuration: Kaggle input/output locations used throughout the notebook
DATA_DIR = "/kaggle/input/processed-rafce/processed_dataset"
VLM_MODEL_DIR = "/kaggle/input/vlm-lora/lora_model"  # From Notebook 3
OUTPUT_DIR = "/kaggle/working/xai_output"

# Create the output directory (and any missing parents); no-op if it exists.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Load the fine-tuned model
# unsloth is imported in this cell rather than with the imports at the top.
from unsloth import FastVisionModel

# VLM_MODEL_DIR points at the LoRA model produced by Notebook 3.
model, tokenizer = FastVisionModel.from_pretrained(
    model_name=VLM_MODEL_DIR,
    max_seq_length=2048,
    dtype=None,  # presumably lets unsloth choose the dtype — TODO confirm against unsloth docs
    load_in_4bit=True,  # presumably 4-bit quantized loading to save GPU memory — TODO confirm
)

# Switch the model to unsloth's inference mode before the generate()/forward
# calls made later in the notebook.
FastVisionModel.for_inference(model)
print("Model loaded for inference")

## Part 1: Text Evaluation (BLEU/ROUGE)

In [None]:
# Load the text-quality metrics from the `evaluate` library (ROUGE first, then BLEU)
rouge, bleu = (evaluate.load(metric_name) for metric_name in ("rouge", "bleu"))

In [None]:
# Create test dataset with ground truth
def create_test_samples(data_dir, au_labels_file, num_samples=50, per_class=5):
    """Create test samples with template ground-truth explanations.

    Walks ``<data_dir>/test/<emotion>/*.jpg`` and pairs each image with a
    reference sentence built from the emotion folder name and, when available,
    the image's Action Unit (AU) annotation.

    Args:
        data_dir: Dataset root containing a ``test`` directory with one
            sub-directory per emotion class.
        au_labels_file: Text file with one ``<image name> <AU string>`` entry
            per line. AU codes are separated by ``+``; the literal ``null``
            means "no AUs annotated". Lines with fewer than two fields are
            ignored.
        num_samples: Maximum number of samples returned overall.
        per_class: Maximum number of images taken from each emotion directory.

    Returns:
        A list of dicts with keys ``image_path``, ``emotion`` and
        ``reference``, truncated to ``num_samples`` entries.
    """
    au_descriptions = {
        '1': 'Inner Brow Raiser', '2': 'Outer Brow Raiser', '4': 'Brow Lowerer',
        '6': 'Cheek Raiser', '12': 'Lip Corner Puller (Smile)', '25': 'Lips Part',
    }

    # Load AU labels, keyed by image file name
    au_labels = {}
    with open(au_labels_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2:
                au_labels[parts[0]] = '' if parts[1] == 'null' else parts[1]

    samples = []
    test_dir = Path(data_dir) / 'test'

    # Sort directories and files: filesystem iteration order is arbitrary, and
    # the evaluation set must be the same on every run for scores to be comparable.
    for emotion_dir in sorted(test_dir.iterdir()):
        if not emotion_dir.is_dir():
            continue

        # Folder names use underscores between words
        emotion_name = emotion_dir.name.replace('_', ' ')

        for img_path in sorted(emotion_dir.glob('*.jpg'))[:per_class]:
            # Label keys omit the '_aligned' suffix carried by the image files
            base_name = img_path.stem.replace('_aligned', '') + '.jpg'
            au_string = au_labels.get(base_name, '')

            # Create ground truth
            if au_string:
                aus = au_string.replace('+', ' ').split()
                # AUs without a description fall back to their raw code, e.g. "AU9"
                au_text = ', '.join(au_descriptions.get(au, f'AU{au}') for au in aus)
                reference = f"The emotion is {emotion_name}. Visible facial cues include: {au_text}."
            else:
                reference = f"The emotion is {emotion_name}."

            samples.append({
                'image_path': str(img_path),
                'emotion': emotion_name,
                'reference': reference
            })

    return samples[:num_samples]

# AU annotation file parsed by create_test_samples (one "<image> <AUs>" entry per line)
AU_LABELS_FILE = "/kaggle/input/raf-au/RAFCE_AUlabel.txt"
# Uses the default cap of num_samples=50
test_samples = create_test_samples(DATA_DIR, AU_LABELS_FILE)
print(f"Created {len(test_samples)} test samples")

In [None]:
# Run inference and collect predictions
def run_inference(model, tokenizer, image_path, prompt):
    """Run inference on a single image and return only the generated answer.

    Args:
        model: Vision-language model exposing ``generate``.
        tokenizer: Processor that supports ``apply_chat_template`` and can be
            called with ``(image, text)`` to build model inputs.
        image_path: Path of the image to analyse.
        prompt: Instruction text sent together with the image.

    Returns:
        The decoded text of the newly generated tokens, stripped of
        surrounding whitespace (the prompt is not included).
    """
    # Convert inside the context manager so the file handle is released
    # promptly; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    messages = [
        {'role': 'user', 'content': [
            {'type': 'image'},
            {'type': 'text', 'text': prompt}
        ]}
    ]

    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors='pt'
    ).to('cuda')

    # NOTE(review): `temperature` only has an effect when sampling is enabled,
    # which here depends on the model's generation config. For reproducible
    # BLEU/ROUGE numbers consider passing do_sample=False explicitly.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            use_cache=True,
            temperature=0.7,
        )

    # generate() returns the prompt tokens followed by the new tokens (assumes
    # a decoder-only model such as Qwen2-VL — confirm if the backbone changes).
    # Slicing by prompt length is more reliable than splitting the decoded text
    # on the word "assistant", which breaks when the answer contains that word.
    prompt_length = inputs['input_ids'].shape[1]
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

# One instruction is used for every test image
prompt = "Analyze the facial expression. Classify the compound emotion and explain which facial cues led to this conclusion."
predictions = []
references = []

for sample in tqdm(test_samples, desc="Generating predictions"):
    # Best effort: a failing image is reported and recorded as an empty
    # prediction so that predictions and references stay index-aligned.
    try:
        generated = run_inference(model, tokenizer, sample['image_path'], prompt)
    except Exception as e:
        print(f"Error on {sample['image_path']}: {e}")
        generated = ""
    predictions.append(generated)
    references.append(sample['reference'])

print(f"\nGenerated {len(predictions)} predictions")

In [None]:
# Calculate BLEU and ROUGE
# Only (prediction, reference) pairs with a non-empty prediction are scored
valid_preds = [pair for pair in zip(predictions, references) if pair[0]]
pred_list = [prediction for prediction, _ in valid_preds]
ref_list = [[reference] for _, reference in valid_preds]  # BLEU expects list of references

# BLEU Score
bleu_result = bleu.compute(predictions=pred_list, references=ref_list)
print(f"BLEU Score: {bleu_result['bleu']:.4f}")

# ROUGE Scores (one flat reference string per prediction)
rouge_result = rouge.compute(
    predictions=pred_list,
    references=[reference for _, reference in valid_preds],
)
print(f"ROUGE-1: {rouge_result['rouge1']:.4f}")
print(f"ROUGE-2: {rouge_result['rouge2']:.4f}")
print(f"ROUGE-L: {rouge_result['rougeL']:.4f}")

In [None]:
# Show sample predictions (at most the first five scored pairs)
print("\n=== Sample Predictions ===")
for sample_no, (refs, prediction) in enumerate(zip(ref_list[:5], pred_list[:5]), start=1):
    print(f"\n--- Sample {sample_no} ---")
    print(f"Reference: {refs[0]}")
    print(f"Prediction: {prediction}")

## Part 2: Vision Encoder Attention Visualization

In [None]:
# Extract attention weights from a full forward pass.
# Note: This is model-specific. The function returns whatever the top-level
# forward pass exposes as `outputs.attentions`; capturing the vision encoder's
# own attention needs model-specific hooks.

def get_vision_attention(model, tokenizer, image_path):
    """Run one forward pass with attention output enabled.

    Args:
        model: Vision-language model, called with ``output_attentions=True``.
        tokenizer: Processor that supports ``apply_chat_template`` and can be
            called with ``(image, text)`` to build model inputs.
        image_path: Path of the image to analyse.

    Returns:
        A tuple ``(attentions, image)``. ``attentions`` is ``outputs.attentions``
        as returned by the model (all reported layers, not just the last one),
        or ``None`` when the output has no such attribute. ``image`` is the
        loaded RGB PIL image.

    NOTE(review): for Hugging Face causal-LM style outputs, ``attentions`` are
    presumably the language model's self-attention over the whole token
    sequence, not the vision tower's patch attention, and may be unavailable
    with fused attention kernels — confirm for the loaded model.
    """
    # Convert inside the context manager so the file handle is released
    # promptly; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    messages = [
        {'role': 'user', 'content': [
            {'type': 'image'},
            {'type': 'text', 'text': 'Describe this image.'}
        ]}
    ]

    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors='pt'
    ).to('cuda')

    # Forward pass with attention output
    with torch.no_grad():
        outputs = model(
            **inputs,
            output_attentions=True,
            return_dict=True
        )

    # Structure depends on the model architecture; fall back to None when the
    # output object does not carry attention weights at all.
    attentions = getattr(outputs, 'attentions', None)

    return attentions, image

In [None]:
def visualize_attention_simple(image_path, attention_map, output_path, image_size=336):
    """Plot an image, its attention heatmap and an overlay side by side.

    Args:
        image_path: Path of the image to display.
        attention_map: Array-like (on CPU) holding one attention weight per
            image patch, with a perfect-square number of elements, or ``None``
            when no attention is available.
        output_path: File the figure is saved to.
        image_size: Side length in pixels that both the image and the heatmap
            are resized to.

    Raises:
        ValueError: If ``attention_map`` cannot be arranged as a square grid.
    """
    target_size = (image_size, image_size)

    # Load and resize image; the context manager releases the file handle
    with Image.open(image_path) as raw:
        img_array = np.array(raw.convert('RGB').resize(target_size))

    # Prepare the heatmap before creating the figure so that invalid input
    # does not leave an orphaned figure open.
    attn_resized = None
    if attention_map is not None:
        # float32 maps to a PIL mode that Image.fromarray is known to accept,
        # regardless of the dtype of the incoming attention weights.
        attn = np.asarray(attention_map, dtype=np.float32)
        # Reshape attention to 2D grid (depends on patch size)
        side = int(np.sqrt(attn.shape[-1]))
        if side * side != attn.size:
            raise ValueError(
                f"attention_map with shape {attn.shape} cannot be reshaped "
                f"into a square patch grid"
            )
        attn_grid = attn.reshape(side, side)
        attn_resized = np.array(Image.fromarray(attn_grid).resize(target_size))

    # Create figure
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Original image
    axes[0].imshow(img_array)
    axes[0].set_title('Original Image')

    if attn_resized is not None:
        # Attention heatmap
        axes[1].imshow(attn_resized, cmap='jet')
        axes[1].set_title('Attention Map')

        # Overlay
        axes[2].imshow(img_array)
        axes[2].imshow(attn_resized, cmap='jet', alpha=0.5)
        axes[2].set_title('Attention Overlay')
    else:
        axes[1].text(0.5, 0.5, 'Attention not available', ha='center', va='center')
        axes[2].text(0.5, 0.5, 'N/A', ha='center', va='center')

    # Hide ticks and frames on every panel, including the placeholder ones
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

In [None]:
# Note: Full attention extraction may require model-specific hooks.
# This cell only prints guidance; it may need adjusting for the actual model.

print("\n".join([
    "\n=== Attention Visualization ===",
    "Note: Full Vision Encoder attention extraction requires model-specific hooks.",
    "For Qwen2-VL, you may need to access model.visual.blocks[-1].attn",
    "\nFor a complete implementation, consider using:",
    "- BertViz for transformer attention visualization",
    "- Custom forward hooks to capture intermediate activations",
]))

## Part 3: Summary Report

In [None]:
# Save evaluation results
# Collect the headline scores plus the number of scored (non-empty) predictions
results = dict(
    bleu=bleu_result['bleu'],
    rouge1=rouge_result['rouge1'],
    rouge2=rouge_result['rouge2'],
    rougeL=rouge_result['rougeL'],
    num_samples=len(valid_preds),
)

# Persist the scores as pretty-printed JSON in the output directory
with open(os.path.join(OUTPUT_DIR, 'evaluation_results.json'), 'w') as f:
    f.write(json.dumps(results, indent=2))

# Echo the same numbers to the notebook output
print("\n=== Evaluation Summary ===")
print(f"BLEU Score: {results['bleu']:.4f}")
print(f"ROUGE-1: {results['rouge1']:.4f}")
print(f"ROUGE-2: {results['rouge2']:.4f}")
print(f"ROUGE-L: {results['rougeL']:.4f}")
print(f"Samples evaluated: {results['num_samples']}")
print(f"\nResults saved to {OUTPUT_DIR}/evaluation_results.json")