# H-JEPA Interactive Model Explorer

This notebook provides an interactive environment for exploring trained H-JEPA models.

**Features:**
- Load and inspect model checkpoints
- Visualize hierarchical representations
- Analyze attention patterns
- Explore feature activations
- Run quick evaluations

**Usage:**
1. Set the checkpoint path, device, and dataset in the Configuration cell below
2. Run cells sequentially
3. Experiment with different samples and visualizations

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
from PIL import Image
import torchvision.transforms as transforms
from IPython.display import display
import ipywidgets as widgets
from tqdm.notebook import tqdm

# Add project root to path
import sys
sys.path.append('..')

from src.models.hjepa import create_hjepa
from src.data.datasets import get_dataset

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

## 1. Configuration

In [None]:
# ============ CONFIGURE THESE ============
# Checkpoint file to explore (path is relative to this notebook).
CHECKPOINT_PATH = "../results/validation_test/checkpoints/checkpoint_epoch_15.pt"
# Preferred compute device; validated below and downgraded to CPU if missing.
DEVICE = "mps"  # or "cuda" or "cpu"
# Dataset name handed to get_dataset().
DATASET = "cifar10"  # or "cifar100" or "stl10"
# =========================================

# Fall back to CPU when the requested accelerator is not available on this
# machine, so later cells that move the model and tensors to DEVICE do not fail.
_mps_backend = getattr(torch.backends, "mps", None)
if DEVICE == "mps" and not (_mps_backend is not None and _mps_backend.is_available()):
    print("MPS is not available; falling back to CPU")
    DEVICE = "cpu"
elif DEVICE.startswith("cuda") and not torch.cuda.is_available():
    print("CUDA is not available; falling back to CPU")
    DEVICE = "cpu"

## 2. Load Model and Data

In [None]:
def load_model(checkpoint_path, device="mps"):
    """Load an H-JEPA model from a training checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file readable by ``torch.load``.
        device: Device string used both as ``map_location`` when reading the
            checkpoint and as the device the model is moved to.

    Returns:
        Tuple ``(model, config)``: the model in eval mode on ``device`` and the
        ``config`` dict stored in the checkpoint (``{}`` when absent).

    Raises:
        KeyError: If the checkpoint contains neither a ``model_state_dict``
            nor a ``target_encoder`` entry (or the entry is empty).
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE: torch.load unpickles the file; only open checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    config = checkpoint.get('config', {})
    model_state = checkpoint.get('model_state_dict', checkpoint.get('target_encoder'))
    if not model_state:
        # Without weights the notebook would silently explore a randomly
        # initialised model, so fail loudly instead.
        raise KeyError(
            "Checkpoint has no 'model_state_dict' or 'target_encoder' weights: "
            f"{checkpoint_path}"
        )

    model_cfg = config.get('model', {})
    predictor_cfg = model_cfg.get('predictor', {})
    num_hierarchies = model_cfg.get('num_hierarchies', 3)

    model = create_hjepa(
        encoder_type=model_cfg.get('encoder_type', 'vit_base_patch16_224'),
        img_size=config.get('data', {}).get('image_size', 224),
        num_hierarchies=num_hierarchies,
        predictor_depth=predictor_cfg.get('depth', 6),
        predictor_heads=predictor_cfg.get('num_heads', 6),
        use_rope=model_cfg.get('use_rope', True),
        use_flash_attention=model_cfg.get('use_flash_attention', True),
    )

    # strict=False tolerates key mismatches; report them so a partially
    # loaded model does not go unnoticed.
    load_result = model.load_state_dict(model_state, strict=False)
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        print(
            f"⚠ Partial weight load: {len(missing)} missing, "
            f"{len(unexpected)} unexpected keys"
        )

    model = model.to(device)
    model.eval()

    print("✓ Model loaded")
    print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Hierarchies: {num_hierarchies}")

    return model, config

# Load the model and the config stored in the checkpoint; both are used by the cells below
model, config = load_model(CHECKPOINT_PATH, DEVICE)

In [None]:
# Load dataset
# Take the input resolution from the checkpoint config (the same key that
# load_model passes to create_hjepa) instead of hard-coding 224, so the images
# match the size the model was built for.
_image_size = config.get('data', {}).get('image_size', 224)
if isinstance(_image_size, int):
    # An explicit (H, W) pair forces square inputs even for non-square sources.
    _image_size = (_image_size, _image_size)
else:
    _image_size = tuple(_image_size)

transform = transforms.Compose([
    transforms.Resize(_image_size),
    transforms.ToTensor(),
    # ImageNet statistics; denormalize_image() inverts exactly these values.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = get_dataset(DATASET, root='../data', train=False, transform=transform)
print(f"✓ Loaded {len(dataset)} test samples from {DATASET}")

## 3. Utility Functions

In [None]:
def denormalize_image(image):
    """Undo the ImageNet mean/std normalization so an image can be displayed.

    Args:
        image: Tensor whose channel dimension has size 3 and is third from the
            end, e.g. ``(3, H, W)`` or ``(N, 3, H, W)``.

    Returns:
        A new tensor of the same shape, on the same device as ``image``, with
        values clamped to ``[0, 1]``. The input tensor is not modified.
    """
    # Build the statistics on the input's device so tensors that live on an
    # accelerator do not trigger a device-mismatch error.
    std = torch.tensor([0.229, 0.224, 0.225], device=image.device).view(3, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=image.device).view(3, 1, 1)
    # Out-of-place arithmetic already yields a fresh tensor, so no clone is needed.
    return (image * std + mean).clamp(0, 1)

def show_image(image, title="", ax=None):
    """Draw a normalized image tensor on a matplotlib axis.

    A new 5x5-inch figure is created when ``ax`` is not supplied. The image is
    denormalized and converted to channels-last before plotting. Returns the
    axis that was drawn on.
    """
    target_ax = ax if ax is not None else plt.subplots(figsize=(5, 5))[1]

    pixels = denormalize_image(image)
    pixels = pixels.permute(1, 2, 0).cpu().numpy()

    target_ax.imshow(pixels)
    target_ax.set_title(title)
    target_ax.axis('off')
    return target_ax

## 4. Interactive Sample Explorer

Use the slider to browse through different samples from the dataset.

In [None]:
@widgets.interact(sample_idx=widgets.IntSlider(min=0, max=len(dataset)-1, step=1, value=0))
def explore_sample(sample_idx=0):
    """Show a dataset sample next to a PCA rendering of each hierarchy level.

    For every embedding returned by ``forward_hierarchical`` the channel
    dimension is projected onto its first three principal components and the
    result is displayed as an RGB image over the spatial grid.

    Args:
        sample_idx: Index of the sample in ``dataset``.
    """
    image, label = dataset[sample_idx]

    with torch.no_grad():
        image_input = image.unsqueeze(0).to(DEVICE)
        embeddings = model.target_encoder.forward_hierarchical(image_input)

    # Display
    num_hierarchies = len(embeddings)
    # squeeze=False keeps `axes` indexable even when there is a single panel.
    fig, axes = plt.subplots(
        1, num_hierarchies + 1, figsize=(4 * (num_hierarchies + 1), 4), squeeze=False
    )
    axes = axes[0]

    # Show original image
    show_image(image, f"Sample {sample_idx}\nClass: {label}", axes[0])

    # Show hierarchical representations (PCA projection)
    for i, emb in enumerate(embeddings):
        _, C, H, W = emb.shape
        # (C, H*W) matrix for the single sample, moved to CPU float32 so the
        # decomposition does not depend on accelerator support for SVD.
        emb_flat = emb[0].reshape(C, -1).float().cpu()

        # Center every channel over the spatial positions; without this the
        # leading singular vector mostly captures the mean, not the variance.
        centered = emb_flat - emb_flat.mean(dim=1, keepdim=True)

        # torch.linalg.svd replaces the deprecated torch.svd.
        U, _, _ = torch.linalg.svd(centered, full_matrices=False)
        proj = U[:, :3].T @ centered  # (<=3, H*W)
        if proj.shape[0] < 3:
            # Fewer than three components exist (tiny grid or channel count):
            # pad with zeros so the result can still be shown as RGB.
            padding = torch.zeros(3 - proj.shape[0], proj.shape[1])
            proj = torch.cat([proj, padding], dim=0)

        proj_grid = proj.numpy().reshape(3, H, W).transpose(1, 2, 0)
        # Rescale to [0, 1] for imshow; the epsilon guards a constant map.
        proj_grid = (proj_grid - proj_grid.min()) / (proj_grid.max() - proj_grid.min() + 1e-8)

        axes[i + 1].imshow(proj_grid)
        axes[i + 1].set_title(f'Hierarchy {i+1}\n{H}x{W}')
        axes[i + 1].axis('off')

    plt.tight_layout()
    plt.show()

## 5. Attention Visualization

In [None]:
def visualize_attention(sample_idx=0, layer_idx=5, head_idx=0):
    """Visualize attention patterns for a specific layer and head.

    Draws three panels: the input image, the attention map of the selected
    head over the patch grid, and that map overlaid on the image.

    Args:
        sample_idx: Index of the sample in ``dataset``.
        layer_idx: Index of the encoder block whose attention is extracted.
        head_idx: Index of the attention head to display.

    Raises:
        IndexError: If ``layer_idx`` does not address an existing block.
    """
    image, _ = dataset[sample_idx]
    image_input = image.unsqueeze(0).to(DEVICE)

    vit = model.target_encoder.vit
    num_layers = len(vit.blocks)
    if not 0 <= layer_idx < num_layers:
        # Otherwise the loop below never computes `attn`.
        raise IndexError(f"layer_idx {layer_idx} out of range for {num_layers} blocks")

    # Everything runs under no_grad: nothing here needs an autograd graph.
    with torch.no_grad():
        x = vit.patch_embed(image_input)

        pos_embed = getattr(vit, 'pos_embed', None)
        if pos_embed is not None:
            cls_token = getattr(vit, 'cls_token', None)
            # NOTE(review): assumes a timm-style ViT where the class token is
            # concatenated before the positional embedding is added - confirm
            # against the encoder implementation. Only applied when the
            # positional embedding has exactly one more token than the patch
            # sequence, i.e. when the plain addition below would fail.
            if (
                cls_token is not None
                and x.dim() == 3
                and pos_embed.shape[1] == x.shape[1] + 1
            ):
                x = torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)
            x = x + pos_embed

        # Forward to specified layer
        for i, block in enumerate(vit.blocks):
            if i == layer_idx:
                # Recompute the attention weights of this block by hand
                B, N, C = x.shape
                num_heads = block.attn.num_heads
                qkv = block.attn.qkv(block.norm1(x))
                qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads)
                qkv = qkv.permute(2, 0, 3, 1, 4)
                q, k = qkv[0], qkv[1]

                attn = (q @ k.transpose(-2, -1)) * block.attn.scale
                attn = attn.softmax(dim=-1)
                break

            x = block(x)

    # Attention matrix of the chosen head: (tokens, tokens)
    head_attn = attn[0, head_idx]
    num_tokens = head_attn.shape[-1]
    grid_size = int(np.sqrt(num_tokens))
    # Tokens that are not part of the square patch grid (e.g. a class token).
    num_prefix = num_tokens - grid_size * grid_size
    if num_prefix > 0:
        # Attention from the first (class) token to every patch token.
        attn_map = head_attn[0, num_prefix:]
    else:
        # No prefix token present: show the average attention each patch
        # receives over all query positions instead.
        attn_map = head_attn.mean(dim=0)
    attn_grid = attn_map.float().cpu().numpy().reshape(grid_size, grid_size)

    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    show_image(image, f"Sample {sample_idx}", axes[0])

    axes[1].imshow(attn_grid, cmap='viridis')
    axes[1].set_title(f'Attention Map\nLayer {layer_idx}, Head {head_idx}')
    axes[1].axis('off')

    # Overlay: upsample the map to the actual image size (PIL expects (W, H)).
    image_vis = denormalize_image(image).permute(1, 2, 0).cpu().numpy()
    height, width = image.shape[-2:]
    attn_resized = np.array(
        Image.fromarray(attn_grid).resize((width, height), Image.BILINEAR)
    )

    axes[2].imshow(image_vis)
    axes[2].imshow(attn_resized, cmap='jet', alpha=0.5)
    axes[2].set_title('Overlay')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

# Interactive controls: slider ranges follow the loaded encoder instead of
# assuming 12 layers and 12 heads.
_attn_blocks = model.target_encoder.vit.blocks
_num_attn_layers = len(_attn_blocks)
_num_attn_heads = _attn_blocks[0].attn.num_heads
widgets.interact(
    visualize_attention,
    sample_idx=widgets.IntSlider(min=0, max=len(dataset)-1, step=1, value=0),
    layer_idx=widgets.IntSlider(
        min=0, max=_num_attn_layers - 1, step=1, value=min(5, _num_attn_layers - 1)
    ),
    head_idx=widgets.IntSlider(min=0, max=_num_attn_heads - 1, step=1, value=0)
)

## 6. Feature Similarity Search

Find similar images based on learned representations.

In [None]:
def extract_features(num_samples=1000, batch_size=64):
    """Extract L2-normalized features for the first samples of ``dataset``.

    Each image is encoded with the target encoder; the last hierarchy level is
    average-pooled over its spatial dimensions and L2-normalized.

    Args:
        num_samples: Maximum number of samples to encode (capped at the
            dataset size).
        batch_size: Number of images per forward pass.

    Returns:
        Tensor of shape ``(n, feature_dim)`` on the CPU, where ``n`` is
        ``min(num_samples, len(dataset))``.

    Raises:
        ValueError: If there is nothing to encode or ``batch_size`` is not
            positive.
    """
    total = min(num_samples, len(dataset))
    if total <= 0:
        raise ValueError("No samples to extract features from")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    features_list = []

    with torch.no_grad():
        # Encode in batches rather than one forward pass per image.
        for start in tqdm(range(0, total, batch_size), desc="Extracting features"):
            stop = min(start + batch_size, total)
            batch = torch.stack([dataset[i][0] for i in range(start, stop)]).to(DEVICE)

            embeddings = model.target_encoder.forward_hierarchical(batch)
            emb = embeddings[-1]  # Highest hierarchy
            feats = F.adaptive_avg_pool2d(emb, 1).flatten(1)
            feats = F.normalize(feats, p=2, dim=1)

            features_list.append(feats.cpu())

    return torch.cat(features_list)

print("Extracting features... (this may take a minute)")
features = extract_features(1000)
print(f"✓ Extracted {len(features)} feature vectors")

In [None]:
def find_similar_images(query_idx=0, top_k=9):
    """Show the images whose features are most similar to a query image.

    Similarity is the dot product between rows of ``features``.

    Args:
        query_idx: Index of the query sample (row of ``features`` and index
            into ``dataset``).
        top_k: Number of neighbours to display; capped at the number of other
            samples available.
    """
    query_feat = features[query_idx]

    # Compute similarities
    similarities = features @ query_feat

    # Exclude the query explicitly instead of assuming it ranks first, which
    # breaks when duplicates or ties are present.
    similarities[query_idx] = float('-inf')
    top_k = max(0, min(top_k, len(features) - 1))
    nearest = similarities.topk(top_k)
    top_k_indices = nearest.indices.tolist()
    top_k_sims = nearest.values.tolist()

    # Visualize: query plus neighbours, five panels per row
    n_cols = 5
    n_panels = top_k + 1
    n_rows = -(-n_panels // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    # Show query
    query_image, _ = dataset[query_idx]
    show_image(query_image, f"Query (idx={query_idx})", axes[0])

    # Show similar images
    for i, (idx, sim) in enumerate(zip(top_k_indices, top_k_sims)):
        image, _ = dataset[idx]
        show_image(image, f"#{i+1} (sim={sim:.3f})", axes[i + 1])

    # Hide panels that were not used
    for unused_ax in axes[n_panels:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Interactive controls
widgets.interact(
    find_similar_images,
    query_idx=widgets.IntSlider(min=0, max=len(features)-1, step=1, value=0),
    top_k=widgets.IntSlider(min=4, max=9, step=1, value=9)
)

## 7. Quick k-NN Evaluation

In [None]:
# Extract features from train set
train_dataset = get_dataset(DATASET, root='../data', train=True, transform=transform)

# Cap on the number of train samples used as the k-NN reference set, and the
# number of images encoded per forward pass.
NUM_TRAIN_SAMPLES = 5000
TRAIN_BATCH_SIZE = 64

_num_train = min(NUM_TRAIN_SAMPLES, len(train_dataset))
train_features = []
train_labels = []

with torch.no_grad():
    # Encode in batches rather than one forward pass per image.
    for _start in tqdm(range(0, _num_train, TRAIN_BATCH_SIZE), desc="Extracting train features"):
        _stop = min(_start + TRAIN_BATCH_SIZE, _num_train)
        _samples = [train_dataset[i] for i in range(_start, _stop)]
        _batch = torch.stack([img for img, _ in _samples]).to(DEVICE)

        embeddings = model.target_encoder.forward_hierarchical(_batch)
        emb = embeddings[-1]  # Highest hierarchy
        # Average-pool over the spatial grid, then L2-normalize each feature
        # vector so that dot products are cosine similarities.
        feats = F.adaptive_avg_pool2d(emb, 1).flatten(1)
        feats = F.normalize(feats, p=2, dim=1)

        train_features.append(feats.cpu())
        train_labels.extend(lbl for _, lbl in _samples)

train_features = torch.cat(train_features)
train_labels = torch.tensor(train_labels)

print(f"✓ Extracted {len(train_features)} train features")

In [None]:
# Run k-NN evaluation
def knn_accuracy(k=20, test_labels=None):
    """Compute k-NN classification accuracy of ``features`` against the train set.

    Every test feature is compared with all rows of ``train_features`` by dot
    product; the prediction is the most frequent label among the ``k`` most
    similar train samples.

    Args:
        k: Number of neighbours that vote; capped at the number of train
            features.
        test_labels: Optional labels for the rows of ``features``. When
            omitted they are read from ``dataset``, which loads every sample
            again, so pass them in when calling repeatedly.

    Returns:
        Accuracy in percent.

    Raises:
        ValueError: If ``k`` is smaller than 1 or there are no test features.
    """
    test_features = features
    if len(test_features) == 0:
        raise ValueError("No test features available")
    if k < 1:
        raise ValueError("k must be at least 1")

    if test_labels is None:
        test_labels = torch.tensor([dataset[i][1] for i in range(len(features))])
    else:
        test_labels = torch.as_tensor(test_labels)

    # topk fails when asked for more neighbours than there are train samples
    k = min(k, len(train_features))

    # All pairwise similarities at once: (num_test, num_train)
    sims = test_features @ train_features.T

    # Labels of the k nearest train samples for every test sample: (num_test, k)
    topk_indices = sims.topk(k, dim=1).indices
    topk_labels = train_labels[topk_indices]

    # Majority vote per test sample
    preds = topk_labels.mode(dim=1).values
    correct = (preds == test_labels).sum().item()

    accuracy = 100.0 * correct / len(test_features)
    return accuracy

# Test different k values
k_values = [1, 5, 10, 20]
accuracies = []

# Read the test labels once and reuse them for every k
_eval_labels = torch.tensor([dataset[i][1] for i in range(len(features))])

for k in k_values:
    acc = knn_accuracy(k, _eval_labels)
    accuracies.append(acc)
    print(f"k={k:2d}: {acc:.2f}%")

# Plot
plt.figure(figsize=(8, 5))
plt.plot(k_values, accuracies, marker='o', linewidth=2, markersize=8)
plt.xlabel('k')
plt.ylabel('Accuracy (%)')
plt.title('k-NN Accuracy vs k')
plt.grid(alpha=0.3)
plt.show()

## 8. Export Features

Export features for further analysis or downstream tasks.

In [None]:
# Write the extracted train/test features, their labels and the model config
# to a single file under the results directory.
output_dir = Path('../results/exported_features')
output_dir.mkdir(parents=True, exist_ok=True)

# Labels for the test features are read back from the dataset, one per row.
_export_labels = torch.tensor([dataset[idx][1] for idx in range(len(features))])
_export_payload = dict(
    train_features=train_features,
    train_labels=train_labels,
    test_features=features,
    test_labels=_export_labels,
    config=config,
)
torch.save(_export_payload, output_dir / f'{DATASET}_features.pt')

print(f"✓ Features saved to {output_dir}/{DATASET}_features.pt")

## 9. Summary

This notebook demonstrated:
1. ✓ Loading and inspecting H-JEPA models
2. ✓ Visualizing hierarchical representations
3. ✓ Analyzing attention patterns
4. ✓ Similarity search in feature space
5. ✓ Quick k-NN evaluation
6. ✓ Exporting features for downstream tasks

**Next steps:**
- Run full linear probing evaluation (see `scripts/eval_linear_probe.py`)
- Test transfer learning on other datasets (see `scripts/eval_transfer.py`)
- Explore feature visualizations (see `scripts/visualize_features.py`)
- Analyze attention rollout (see `scripts/visualize_attention_rollout.py`)