
# Tutorial: Fine-tuning MAE on SOP for Image-to-Image Retrieval

In this tutorial, we will fine-tune a **Masked Autoencoder (MAE)** on the **Stanford Online Products (SOP)** dataset for the Image-to-Image (I2I) retrieval task.

**Goal**: Train a model that takes a query image and retrieves the same product from a gallery of images (different views/conditions).

**Key Concepts**:
- **Task**: Image-to-Image (I2I) Retrieval
- **Model**: `facebook/vit-mae-base` (Vision Transformer trained with Masked Autoencoding)
- **Loss**: InfoNCE (Contrastive Loss)
- **Dataset**: Stanford Online Products (SOP)
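The InfoNCE loss listed above treats each (query, positive) pair in a batch as the correct match and every other positive in the same batch as a negative. Below is a minimal NumPy sketch of that in-batch formulation. The temperature value and the one-directional (query to positive) form are assumptions, and vembed's own loss may differ, for example by being symmetric or by using a learnable temperature.

```python
import numpy as np


def info_nce_loss(q: np.ndarray, p: np.ndarray, temperature: float = 0.05) -> float:
    """In-batch InfoNCE: row i of `q` should match row i of `p`; all other rows of `p` are negatives."""
    # L2-normalise so that the dot product is cosine similarity
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    logits = q @ p.T / temperature  # (B, B) similarity matrix
    # Numerically stable log-softmax over each row
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # The positives sit on the diagonal
    return float(-np.mean(np.diag(log_probs)))
```

When every embedding is identical, the model cannot distinguish positives from negatives, and the loss equals `log(B)`. When each query is closest to its own positive, the loss approaches 0.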


In [None]:
import os
import sys

# Ensure we are in the project root
if os.path.exists("vembed-factory"):
    os.chdir("vembed-factory")
elif os.getcwd().endswith("notebooks"):
    os.chdir("..")

print(f"Current working directory: {os.getcwd()}")

# Install dependencies if needed
# !pip install -e ".[all]"


## 1. Data Preparation

We use the Stanford Online Products (SOP) dataset.
For this tutorial, we assume you have converted the SOP dataset into JSONL format suitable for `vembed-factory`.

**Expected Format (JSONL)**:
```json
{"query_image": "path/to/img1.jpg", "positive": "path/to/img2.jpg", "class_id": 123}
```
- `query_image`: The anchor image.
- `positive`: A positive sample (same product).
- `class_id`: Product ID (used for evaluation).

If you have not prepared the data yet, see `benchmark/bench_datasets/sop.py` for the processing logic.
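You can also build the JSONL yourself. One option is to pair every image with another image of the same product. The sketch below assumes the raw SOP split files (`Ebay_train.txt`, `Ebay_test.txt`) contain a header line followed by `image_id class_id super_class_id path` rows. `build_i2i_pairs` and `write_jsonl` are hypothetical helpers and do not reproduce the logic in `sop.py`. Random pairing is only one possible choice.

```python
import json
import random
from collections import defaultdict
from pathlib import Path


def build_i2i_pairs(lines, image_root, seed=0):
    """Pair each image with a random other image of the same product (class_id).

    Assumes SOP split-file rows of the form `image_id class_id super_class_id path`
    after a single header line.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for line in lines[1:]:  # skip the header
        parts = line.split()
        if len(parts) != 4:
            continue
        _, class_id, _, rel_path = parts
        by_class[int(class_id)].append(str(Path(image_root) / rel_path))

    records = []
    for class_id, paths in sorted(by_class.items()):
        for anchor in paths:
            others = [p for p in paths if p != anchor]
            if not others:
                continue  # a positive needs a second view of the same product
            records.append({"query_image": anchor, "positive": rng.choice(others), "class_id": class_id})
    return records


def write_jsonl(records, out_path):
    """Write one JSON object per line, creating the parent directory if needed."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")


# Example usage (paths are assumptions):
# lines = Path("data/stanford_online_products/Ebay_train.txt").read_text().splitlines()
# write_jsonl(build_i2i_pairs(lines, "data/stanford_online_products"), "data/sop_i2i/train.jsonl")
```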


In [None]:
import os

DATA_PATH = "data/sop_i2i/train.jsonl"
VAL_DATA_PATH = "data/sop_i2i/val.jsonl"

if not os.path.exists(DATA_PATH):
    print(f"Data not found at {DATA_PATH}")
    print("Please ensure you have prepared the SOP dataset.")
else:
    print(f"Found training data: {DATA_PATH}")
    # Preview data
    !head -n 2 {DATA_PATH}
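Before training, it can help to confirm that every record has the three fields shown in the expected format. Below is a minimal validator sketch. `validate_jsonl_lines` is a hypothetical helper. It checks only the schema and does not check whether the image files exist.

```python
import json

REQUIRED_KEYS = {"query_image", "positive", "class_id"}


def validate_jsonl_lines(lines):
    """Return (line_number, problem) pairs for records that do not match the expected schema."""
    problems = []
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        try:
            rec = json.loads(line)
        except json.JSONDecodeError as err:
            problems.append((lineno, f"invalid JSON: {err.msg}"))
            continue
        if not isinstance(rec, dict):
            problems.append((lineno, "record is not a JSON object"))
            continue
        missing = REQUIRED_KEYS - set(rec)
        if missing:
            problems.append((lineno, f"missing keys: {sorted(missing)}"))
    return problems
```

Usage: `with open(DATA_PATH) as f: print(validate_jsonl_lines(f)[:5])`.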


## 2. Training

We will use `vembed.Trainer` to fine-tune the MAE model.
We specify `retrieval_mode="i2i"` to select Image-to-Image training, where the query and the positive are both images and pass through the same weight-shared (Siamese) image encoder.
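InfoNCE uses the other positives in a batch as negatives. If two pairs in one batch share a `class_id`, the loss pushes apart two views of the same product, which creates a false negative. Below is a minimal sketch of a batch builder that avoids this. `class_unique_batches` is hypothetical, and we have not confirmed whether vembed's sampler already handles this case.

```python
import random
from collections import defaultdict


def class_unique_batches(records, batch_size, seed=0):
    """Group records into batches in which every class_id appears at most once."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for rec in records:
        by_class[rec["class_id"]].append(rec)
    queues = list(by_class.values())
    for queue in queues:
        rng.shuffle(queue)

    batches = []
    while any(queues):
        # One round: take at most one record from every class that still has records
        live = [queue for queue in queues if queue]
        rng.shuffle(live)
        for start in range(0, len(live), batch_size):
            batches.append([queue.pop() for queue in live[start:start + batch_size]])
    return batches
```

The last batch of each round can be smaller than `batch_size`. A production sampler would probably merge or drop those batches.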


In [None]:
from vembed import Trainer

# Initialize Trainer with MAE model
# mode="custom" tells the factory to use the AutoModel backend
trainer = Trainer(
    model_name="facebook/vit-mae-base",
    mode="custom",
    output_dir="experiments/output_mae_i2i_notebook"
)

# Start training
# We use a small number of epochs for demonstration
if os.path.exists(DATA_PATH):
    trainer.train(
        data_path=DATA_PATH,
        val_data_path=VAL_DATA_PATH if os.path.exists(VAL_DATA_PATH) else None,
        epochs=3,
        batch_size=64,
        learning_rate=5e-5,
        retrieval_mode="i2i",
        use_gradient_cache=True,  # Enable gradient cache for memory efficiency
        save_steps=500
    )
else:
    print("Skipping training (no data).")
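The training cell enables `use_gradient_cache=True` for memory efficiency. A common way to get that saving is gradient caching, which works in three steps:

1. Embed the whole batch chunk by chunk without keeping activations.
2. Compute the contrastive loss, and its gradient with respect to the embeddings, on the full batch.
3. Re-encode each chunk and backpropagate the cached embedding gradients.

The sketch below assumes vembed follows this scheme; we have not confirmed its internals. It replaces the ViT with a toy linear encoder `W`, shared by query and positive, so the gradients can be written out by hand in NumPy. To keep the derivation short, it uses raw dot-product logits without L2 normalisation. `grad_cache_step` and `info_nce_and_grad` are hypothetical names.

```python
import numpy as np


def info_nce_and_grad(q_emb, p_emb, temperature=0.05):
    """In-batch InfoNCE on raw dot products, plus its gradient w.r.t. both embedding matrices."""
    batch = q_emb.shape[0]
    logits = q_emb @ p_emb.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # stability; softmax is unchanged
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    loss = -np.mean(np.log(np.diag(probs)))
    d_logits = (probs - np.eye(batch)) / batch  # dL/dlogits
    d_q = d_logits @ p_emb / temperature
    d_p = d_logits.T @ q_emb / temperature
    return loss, d_q, d_p


def grad_cache_step(W, xq, xp, chunk_size=2, temperature=0.05):
    """Gradient of InfoNCE w.r.t. a shared linear encoder W, computed chunk by chunk."""
    chunks = range(0, len(xq), chunk_size)
    # 1) Embed the batch in chunks, keeping only the outputs (the "no-grad" pass)
    q_emb = np.concatenate([xq[i:i + chunk_size] @ W for i in chunks])
    p_emb = np.concatenate([xp[i:i + chunk_size] @ W for i in chunks])
    # 2) Full-batch loss, so every sample still sees every in-batch negative
    loss, d_q, d_p = info_nce_and_grad(q_emb, p_emb, temperature)
    # 3) Revisit each chunk and push the cached embedding gradients into W (dL/dW = X^T dL/dE)
    grad_W = np.zeros_like(W)
    for i in chunks:
        grad_W += xq[i:i + chunk_size].T @ d_q[i:i + chunk_size]
        grad_W += xp[i:i + chunk_size].T @ d_p[i:i + chunk_size]
    return loss, grad_W
```

The loss is always evaluated on the full batch, so the resulting gradient does not depend on `chunk_size`. Chunk size only trades peak memory for an extra forward pass.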


## 3. Evaluation

After training, we can evaluate the model on the SOP test set.
We provide a benchmark script `benchmark/compare_sop_before_after.sh` that compares the pre-trained MAE vs. our fine-tuned MAE.

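We can also test Matryoshka Representation Learning (MRL) performance, meaning retrieval quality when each embedding is truncated to its first *d* dimensions. To do this, pass a space-separated list of dimensions as the last script argument (`MRL_DIMS`). `facebook/vit-mae-base` has a hidden size of 768, so 768 is the full embedding size.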
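On SOP, Recall@K is usually computed leave-one-out. Every test image queries all other test images, and a query counts as a hit if any of its top-K neighbours shares its `class_id`. The sketch below computes that metric. It also computes the MRL variant by keeping only the first `dim` coordinates before re-normalising.

This is an illustration under those assumptions, not the code inside `compare_sop_before_after.sh`, and `recall_at_k` is a hypothetical name. The training call above does not explicitly request an MRL objective, so truncated dimensions may lose some accuracy.

```python
import numpy as np


def recall_at_k(embeddings, labels, ks=(1, 10, 100), dim=None, chunk_size=1024):
    """Leave-one-out Recall@K with optional MRL-style truncation to the first `dim` coordinates."""
    emb = embeddings[:, :dim]  # dim=None keeps every coordinate
    emb = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-12)
    labels = np.asarray(labels)
    max_k = min(max(ks), len(emb) - 1)  # never reach the masked self-match
    hits = []
    for start in range(0, len(emb), chunk_size):
        sims = emb[start:start + chunk_size] @ emb.T  # (chunk, N) cosine similarities
        rows = np.arange(sims.shape[0])
        sims[rows, start + rows] = -np.inf  # a query must not retrieve itself
        top = np.argsort(-sims, axis=1)[:, :max_k]
        hits.append(labels[top] == labels[start:start + chunk_size, None])
    hits = np.concatenate(hits)
    return {k: float(hits[:, :k].any(axis=1).mean()) for k in ks}
```

Chunking keeps memory at `chunk_size x N` similarities instead of the full `N x N` matrix, which matters for the roughly 60k-image SOP test split.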

In [None]:
# Run the benchmark comparison script
# This script runs the 'sop' benchmark pipeline and calculates Recall@K

# Make sure the script is executable
!chmod +x benchmark/compare_sop_before_after.sh

# Run comparison (assuming SOP raw data is in data/stanford_online_products)
# ./benchmark/compare_sop_before_after.sh <BEFORE_MODEL> <AFTER_MODEL> [SIMILARITY_MODE] [BATCH_SIZE] [TOPK] [MRL_DIMS]

if os.path.exists("data/stanford_online_products"):
    !./benchmark/compare_sop_before_after.sh \
        facebook/vit-mae-base \
        experiments/output_mae_i2i_notebook/checkpoint-epoch-3 \
        cosine \
        64 \
        100 \
        "768 512 256"
else:
    print("Skipping benchmark (SOP raw data not found in data/stanford_online_products).")


## 4. Visualization

Let's visualize the retrieval results!
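We randomly select a few query images from the test set and show their Top-5 retrieved results. The results use the embeddings that the benchmark step saved for the fine-tuned model. Retrieved images with the same `class_id` as the query are titled in green, and the rest in red.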
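The visualization cell ranks the whole gallery with a full `np.argsort` for each query. When only the top few results are needed, `np.argpartition` can select them in linear time, and then only those k results need sorting. Below is a minimal sketch; `topk_indices` is a hypothetical helper. It returns the same ranking as `np.argsort(scores)[::-1][:k]` except possibly for the order of tied scores.

```python
import numpy as np


def topk_indices(scores, k):
    """Indices of the k largest scores, highest first, without fully sorting `scores`."""
    k = min(k, scores.shape[0])
    part = np.argpartition(-scores, k - 1)[:k]  # the top-k, in arbitrary order
    return part[np.argsort(-scores[part])]  # sort only those k
```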


In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Add project root to path
if os.path.basename(os.getcwd()) == "notebooks":
    sys.path.append("..")
else:
    sys.path.append(".")

# Import dataset loader from vembed-factory
try:
    from benchmark.bench_datasets.sop import _load_sop_entries
except ImportError:
    # Fallback if running from a different context
    sys.path.append(os.path.abspath(".."))
    from benchmark.bench_datasets.sop import _load_sop_entries

# Config
SOP_ROOT = "data/stanford_online_products"
EMB_DIR = "experiments/benchmark_output_sop_compare/after"

def visualize_results():
    if not os.path.exists(os.path.join(EMB_DIR, "test_query_embeddings.npy")):
        print(f"Embeddings not found in {EMB_DIR}. Please run the benchmark step above.")
        return

    print("Loading embeddings...")
    q_emb = np.load(os.path.join(EMB_DIR, "test_query_embeddings.npy"))
    d_emb = np.load(os.path.join(EMB_DIR, "test_doc_embeddings.npy"))

    q_norm = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-12)
    d_norm = d_emb / (np.linalg.norm(d_emb, axis=1, keepdims=True) + 1e-12)

    print("Loading SOP dataset index...")
    try:
        entries, class_ids = _load_sop_entries(SOP_ROOT, "test")
    except FileNotFoundError:
        print(f"SOP dataset index not found in {SOP_ROOT}. Cannot visualize images.")
        return

    if len(entries) == 0:
        print("Dataset is empty.")
        return

    # Sanity Check
    if len(entries) != len(q_emb):
        print(f"Warning: Dataset size ({len(entries)}) != Embedding size ({len(q_emb)}).")
        print("Using the minimum size for safety.")
        min_len = min(len(entries), len(q_emb))
        entries = entries[:min_len]
        q_norm = q_norm[:min_len]
        d_norm = d_norm[:min_len]

    def show_retrieval(idx, topk=5):
        query_vec = q_norm[idx]

        # Compute scores: (1, D) @ (N, D).T -> (1, N)
        scores = query_vec @ d_norm.T

        # Get Top-K indices
        top_indices = np.argsort(scores)[::-1][:topk]

        # Plot
        fig, axes = plt.subplots(1, topk + 1, figsize=(15, 3))

        # Query image
        q_path = entries[idx]["query_image"]
        try:
            img = Image.open(q_path).convert("RGB")
            axes[0].imshow(img)
            axes[0].set_title(f"Query\nClass: {entries[idx]['class_id']}")
            # Hide ticks but keep the spines: axis("off") would also hide the blue border
            axes[0].set_xticks([])
            axes[0].set_yticks([])
            for spine in axes[0].spines.values():
                spine.set_edgecolor("blue")
                spine.set_linewidth(2)
        except Exception as e:
            axes[0].text(0.5, 0.5, "Img Not Found", ha="center")
            axes[0].axis("off")
            print(f"Error loading {q_path}: {e}")

        # Results
        for i, res_idx in enumerate(top_indices):
            res_path = entries[res_idx]["positive"]
            score = scores[res_idx]
            is_same_class = (entries[idx]["class_id"] == entries[res_idx]["class_id"])

            ax = axes[i+1]
            try:
                img = Image.open(res_path).convert("RGB")
                ax.imshow(img)

                # Title color based on correctness
                color = "green" if is_same_class else "red"
                title = f"Rank {i+1}\n{score:.3f}"
                ax.set_title(title, color=color, fontweight="bold")
                ax.axis("off")

            except Exception:
                ax.text(0.5, 0.5, "Img Not Found", ha="center")
                ax.axis("off")

        plt.tight_layout()
        plt.show()

    print("Visualizing random samples...")
    n_samples = min(3, len(entries))
    indices = np.random.choice(len(entries), n_samples, replace=False)
    for idx in indices:
        show_retrieval(idx, topk=5)


# Execute (outside the function body, so it actually runs)
visualize_results()