# 3. Knowledge Distillation

This notebook demonstrates how to distill knowledge from a large Teacher model into a smaller Student model.

**Use Case:** Improving the performance of a small, efficient model (e.g., CLIP ViT-Base) by learning from a larger, more powerful model (e.g., CLIP ViT-Large).

Ensure you have run `01_setup_and_data.ipynb` first.

In [None]:
import os

# Ensure we are in the project root
if os.path.exists("vembed-factory"):
    os.chdir("vembed-factory")
elif not os.path.exists("run.py"):
    # Assume we are in notebooks dir
    os.chdir("..")

print(f"Working Directory: {os.getcwd()}")

In [None]:
# Configure data paths; fall back to the dummy dataset if Flickr30k is absent
if os.path.exists("data/flickr30k/train.jsonl"):
    DATA_PATH = "data/flickr30k/train.jsonl"
    IMAGE_ROOT = "data/flickr30k"
    VAL_DATA_PATH = "data/flickr30k/val.jsonl"
else:
    DATA_PATH = "data/dummy/train.jsonl"
    IMAGE_ROOT = "data/dummy"
    VAL_DATA_PATH = ""
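
Before launching a multi-epoch run, it can help to confirm that the selected JSONL file parses and to see which fields a record carries. The sketch below is a minimal sanity check; `peek_jsonl` is a hypothetical helper written for this notebook, not part of the project, and it makes no assumption about the record schema.

```python
import json
from pathlib import Path


def peek_jsonl(path, limit=1):
    """Return (record_count, first `limit` parsed records) for a JSONL file.

    Hypothetical helper for illustration; it only assumes one JSON value per line.
    """
    count = 0
    head = []
    with Path(path).open(encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            if count < limit:
                head.append(json.loads(line))
            count += 1
    return count, head
```

Calling `count, head = peek_jsonl(DATA_PATH)` and printing `head[0]` shows the number of training records and the fields of the first one.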

## Configuration

- **Teacher**: `openai/clip-vit-large-patch14` (Frozen)
- **Student**: `openai/clip-vit-base-patch32` (Trainable)
- **Method**: Relation Distillation (KL Divergence on Similarity Matrices)
- **Alpha**: 0.5 (50% Task Loss, 50% Distillation Loss)
- **Temperature**: 2.0 (softens the similarity distributions before the KL term)
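
To make the method concrete, here is a minimal pure-Python sketch of relation distillation: build a similarity matrix from each model's embeddings, soften each row with the temperature, and take the KL divergence from teacher to student. Because only the batch-by-batch similarity matrices are compared, the Teacher and Student embeddings do not need the same dimension. This is an illustration under stated assumptions, not the project's implementation: the function names are hypothetical, the use of cosine similarity is assumed, the convention that alpha weights the distillation term is assumed, and any extra scaling the real loss may apply (such as a temperature-squared factor) is omitted.

```python
import math


def softmax(row, temperature):
    """Temperature-scaled softmax over one row of similarities."""
    scaled = [x / temperature for x in row]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def similarity_matrix(queries, keys):
    """Cosine similarity between every query and every key embedding."""
    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v)) or 1.0
        return [x / norm for x in v]

    q = [normalize(v) for v in queries]
    k = [normalize(v) for v in keys]
    return [[sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]


def relation_kl(student_sim, teacher_sim, temperature=2.0):
    """Mean row-wise KL(teacher || student) over softened similarity rows."""
    total = 0.0
    for s_row, t_row in zip(student_sim, teacher_sim):
        p = softmax(t_row, temperature)  # teacher distribution (target)
        q = softmax(s_row, temperature)  # student distribution
        total += sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return total / len(student_sim)


def combined_loss(task_loss, distill_loss, alpha=0.5):
    # Assumed convention: alpha weights the distillation term.
    return (1.0 - alpha) * task_loss + alpha * distill_loss
```

With `alpha=0.5` the two terms contribute equally, and the distillation term is zero when the Student reproduces the Teacher's similarity structure exactly.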

In [None]:
TEACHER_MODEL = "openai/clip-vit-large-patch14"
STUDENT_MODEL = "openai/clip-vit-base-patch32"

print(f"Teacher: {TEACHER_MODEL}")
print(f"Student: {STUDENT_MODEL}")

## Start Distillation Training

In [None]:
!python run.py examples/clip_train.yaml \
    --data_path "$DATA_PATH" \
    --val_data_path "$VAL_DATA_PATH" \
    --image_root "$IMAGE_ROOT" \
    --config_override \
        output_dir=output_distill \
        epochs=3 \
        batch_size=32 \
        use_mrl=true \
        model_name="$STUDENT_MODEL" \
        teacher_model_name="$TEACHER_MODEL" \
        distillation_alpha=0.5 \
        distillation_temperature=2.0 \
        distillation_loss_type="kl"

## Evaluation Comparison

Evaluate the distilled model on the validation split, then compare its report against the standard fine-tuned model's results (if available).

In [None]:
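# Evaluation needs a validation file; the dummy dataset sets VAL_DATA_PATH to "".
if VAL_DATA_PATH and os.path.exists(VAL_DATA_PATH):
    print("Evaluating Distilled Model...")
    # checkpoint-epoch-3 corresponds to epochs=3 in the training cell
    !python scripts/evaluate_simple.py \
        --model_path output_distill/checkpoint-epoch-3 \
        --data_path "$VAL_DATA_PATH" \
        --image_root "$IMAGE_ROOT" \
        --output_dir eval_results_distill

    !cat eval_results_distill/evaluation_report.md
else:
    print("No validation data found; skipping evaluation.")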
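
The evaluation cell hard-codes `checkpoint-epoch-3`, which only exists if training ran for exactly three epochs. As a hedged alternative, the sketch below picks the highest-numbered checkpoint directory in the output folder. `latest_checkpoint` is a hypothetical helper, and it assumes checkpoints are saved as directories named `checkpoint-epoch-N`, as the path above suggests.

```python
import re
from pathlib import Path


def latest_checkpoint(output_dir, prefix="checkpoint-epoch-"):
    """Return the highest-numbered `<prefix>N` directory in output_dir, or None.

    Hypothetical helper; the directory naming scheme is an assumption.
    """
    root = Path(output_dir)
    if not root.is_dir():
        return None
    pattern = re.compile(re.escape(prefix) + r"(\d+)$")
    best = None
    for child in root.iterdir():
        match = pattern.match(child.name)
        if child.is_dir() and match:
            epoch = int(match.group(1))
            # Compare numerically so epoch 10 ranks above epoch 9.
            if best is None or epoch > best[0]:
                best = (epoch, child)
    return best[1] if best else None
```

For example, `ckpt = latest_checkpoint("output_distill")` returns a `Path` that can be passed to `--model_path`, or `None` if no checkpoint has been written yet.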