# DAW-JEPA – Colab Repro Notebook

This notebook walks through, step by step, **how to run my DAW-JEPA project code on Google Colab**:

1. Set up the runtime and clone the GitHub repo.
2. Prepare the ImageNet-100 data used for pretraining.
3. Run I-JEPA / DAW-JEPA pretraining on ImageNet-100.
4. Run linear probing on STL10 / CIFAR-10.
5. Run k-NN evaluation on STL10 / CIFAR-10.

> Note: This notebook is designed to be self-contained and readable for grading.
> It is not meant to be perfectly resource-efficient – full pretraining on ImageNet-100
> is still compute-heavy, even on Colab GPUs.


## 0. Runtime & environment

Please make sure the Colab runtime has a **GPU** attached:

- Runtime → Change runtime type → Hardware accelerator → GPU.


In [None]:
# Check GPU
!nvidia-smi || echo "No GPU found. Please enable GPU in Colab runtime settings."

import torch
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())


## 1. Clone the GitHub repo

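If the clone URL below still contains the placeholder `<YOUR_GITHUB_USERNAME>`, replace it with my GitHub username before running the cell; left as is, the shell treats `<` and `>` as redirection and the clone fails.
The repo corresponds exactly to the code submitted for this project.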


In [None]:
# Clone the DAW-JEPA repo (GitHub URL to be updated by me before submission)
%cd /content
!git clone https://github.com/<YOUR_GITHUB_USERNAME>/daw-jepa.git
%cd daw-jepa

# List top-level files for sanity check
!ls


## 2. Install additional Python packages (if needed)

Colab already comes with `torch` and `torchvision`.  
Here we only install **lightweight utilities** that may not be pre-installed.

If pip reports that a package is already installed, the message is safe to ignore.


In [None]:
# Optional: install utilities used in this project
!pip install -q pyyaml pandas matplotlib kagglehub wandb

## 3. Prepare the ImageNet-100 dataset

The pretraining in this project uses **ImageNet-100**.

Because of licensing, the dataset itself is **not included** in the repo.  
This notebook assumes **one** of the following:

1. You already have `imagenet100.zip` stored in your Google Drive at:
   - `/content/drive/MyDrive/imagenet100.zip`, or
2. You adapt the paths below to your own ImageNet-100 location.

In my experiments, I used Option 1. The following cells reproduce that setup.


In [None]:
# 3.1 Mount Google Drive (required if dataset is stored there)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 3.2 Unzip ImageNet-100 from Google Drive into the Colab runtime
# If your zip file is in a different location, please update the path below.
# -n skips files that already exist, so re-running this cell does not prompt.
!unzip -q -n /content/drive/MyDrive/imagenet100.zip -d /content/imagenet100_local

# After this, the expected directory structure is:
# /content/imagenet100_local/train/<class_name>/*.JPEG
# /content/imagenet100_local/val/<class_name>/*.JPEG

!ls /content/imagenet100_local || echo "Please check that imagenet100.zip exists in your Drive."
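
As an optional extra check on the layout described above, the following sketch counts class folders and image files per split. The helper name `summarize_imagefolder` and the accepted file extensions are my own assumptions for illustration; the repo's data loader may apply different rules.

```python
from pathlib import Path

# Hypothetical helper (not part of the repo): summarizes an
# ImageFolder-style directory with train/ and val/ splits.
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}  # assumed set of image extensions


def summarize_imagefolder(root):
    """Return {split: (num_class_dirs, num_images)} or None if a split is missing."""
    root = Path(root)
    summary = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            summary[split] = None  # split folder does not exist
            continue
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        n_images = sum(
            1
            for d in class_dirs
            for f in d.iterdir()
            if f.suffix.lower() in IMAGE_SUFFIXES
        )
        summary[split] = (len(class_dirs), n_images)
    return summary


print(summarize_imagefolder("/content/imagenet100_local"))
```

For ImageNet-100, each split should report 100 class folders; a `None` entry means the `train/` or `val/` folder was not found at the expected location.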


### 3.3 Verify config data paths

The training configs used in this project expect:

```yaml
data:
  root_path: /content
  image_folder: imagenet100_local
```

This matches the directory layout created above (`/content/imagenet100_local`).  
If you change the data path, please update the YAML configs under `configs/` accordingly.
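
A small sketch of how the two fields combine into the dataset directory. It assumes the training code simply joins `root_path` and `image_folder`; the helper `resolve_data_dir` is hypothetical and works on an already-parsed config dict.

```python
from pathlib import Path


def resolve_data_dir(cfg):
    """Join data.root_path and data.image_folder from a parsed config dict.

    Assumption: the training code builds the dataset path this way.
    """
    data = cfg["data"]
    return Path(data["root_path"]) / data["image_folder"]


# Same values as the YAML fragment above, written as a Python dict.
cfg = {"data": {"root_path": "/content", "image_folder": "imagenet100_local"}}
data_dir = resolve_data_dir(cfg)
print(data_dir, "exists" if data_dir.is_dir() else "is missing")
```

In the notebook, `cfg` could instead be loaded from one of the files under `configs/` with `yaml.safe_load`.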


## 4. Run pretraining (I-JEPA and DAW-JEPA)

This section shows how to:

- Train a **baseline I-JEPA** model on ImageNet-100.
- Train **DAW-JEPA** variants (EMA and Instant modes).

All commands run from the repo root (`/content/daw-jepa`).

In [None]:
# Ensure we are in the repo root
%cd /content/daw-jepa
!pwd

### 4.1 Baseline I-JEPA on ImageNet-100 (optional but recommended)

This uses the original I-JEPA objective without difficulty-aware weighting.

Config file: `configs/in100_vits16_ep100.yaml`

> Note: Full training for 100 epochs is compute-heavy. For a quick smoke test,
> you can reduce `num_epochs` in the config or stop the run early.
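
One way to set up such a smoke test without touching the original config is to write a reduced copy. The sketch below overrides a key wherever it appears in a parsed config dict. The helper `override_key` and the `optimization` section in the example are hypothetical, since the exact location of `num_epochs` in the YAML is not shown here.

```python
import copy


def override_key(cfg, key, value):
    """Return a deep copy of cfg with every occurrence of `key` set to `value`."""
    out = copy.deepcopy(cfg)

    def _walk(node):
        if isinstance(node, dict):
            for k in node:
                if k == key:
                    node[k] = value  # replace the value, keep the key
                else:
                    _walk(node[k])
        elif isinstance(node, list):
            for item in node:
                _walk(item)

    _walk(out)
    return out


# Illustrative config only; the section name "optimization" is an assumption.
base = {"optimization": {"num_epochs": 100, "lr": 0.001}}
smoke = override_key(base, "num_epochs", 1)
print(smoke)
```

In practice the dict would be read with `yaml.safe_load`, written to a new file with `yaml.safe_dump`, and that new file passed to `--fname`.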


In [None]:
# Baseline I-JEPA pretraining on ImageNet-100
!python main.py --fname configs/in100_vits16_ep100.yaml --devices cuda:0

### 4.2 DAW-JEPA (EMA mode)

This enables **difficulty-aware weighting** with an EMA difficulty buffer.

Config file: `configs/in100_vits16_ep100_daw_ema.yaml`

Key DAW settings (inside the YAML):

```yaml
daw:
  enabled: true
  mode: ema
  ema_alpha: 0.7
  gamma: 0.3
  w_min: 0.8
  w_max: 1.2
  warmup_epochs: 20
```
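
To give a feel for what these settings could control, here is a minimal sketch of one plausible reading of difficulty-aware weighting. It is not the repo's implementation. Everything in it is an assumption: that difficulty is tracked per item (sample or patch) from its loss, that `ema_alpha` is the weight kept on the old value, that weights are `1 + gamma * z` for a standardized difficulty `z`, and that all weights stay at 1 during warmup.

```python
from statistics import mean, pstdev


def update_difficulty(buffer, losses, mode="ema", ema_alpha=0.7):
    """Update per-item difficulty scores in place from current per-item losses.

    Assumption: in "ema" mode, ema_alpha weights the previous value;
    in "instant" mode the score is simply the current loss.
    """
    for idx, loss in losses.items():
        if mode == "ema" and idx in buffer:
            buffer[idx] = ema_alpha * buffer[idx] + (1.0 - ema_alpha) * loss
        else:  # instant mode, or first time this item is seen
            buffer[idx] = loss
    return buffer


def daw_weights(buffer, epoch, gamma=0.3, w_min=0.8, w_max=1.2, warmup_epochs=20):
    """Map difficulty scores to loss weights clamped to [w_min, w_max]."""
    if epoch < warmup_epochs or not buffer:
        return {idx: 1.0 for idx in buffer}  # uniform weights during warmup
    scores = list(buffer.values())
    mu = mean(scores)
    sigma = pstdev(scores) or 1.0  # avoid dividing by zero
    return {
        idx: min(w_max, max(w_min, 1.0 + gamma * (d - mu) / sigma))
        for idx, d in buffer.items()
    }


buffer = {}
update_difficulty(buffer, {0: 1.0, 1: 3.0})
print(daw_weights(buffer, epoch=25))
```

Under this reading, harder items (higher tracked loss) get weights above 1 and easier items get weights below 1, with `w_min` and `w_max` bounding how far any single item can shift the objective.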

The command below starts pretraining with these settings.


In [None]:
# DAW-JEPA pretraining (EMA difficulty)
!python main.py --fname configs/in100_vits16_ep100_daw_ema.yaml --devices cuda:0

### 4.3 DAW-JEPA (Instant mode)

This variant updates the difficulty buffer **instantly** with the current loss instead of using EMA.

Config file: `configs/in100_vits16_ep100_daw_instant.yaml`

Example DAW settings:

```yaml
daw:
  enabled: true
  mode: instant
  gamma: 0.3
  w_min: 0.8
  w_max: 1.2
  warmup_epochs: 10   # used in my best CIFAR-10 run
```

Run the command below to launch Instant-mode DAW-JEPA pretraining.


In [None]:
# DAW-JEPA pretraining (Instant difficulty)
!python main.py --fname configs/in100_vits16_ep100_daw_instant.yaml --devices cuda:0

### 4.4 Locate the latest checkpoints

Each run writes checkpoints and logs under the `logs/` directory.

The following helper cell prints all `jepa-latest.pth.tar` checkpoints and lets you pick one.


In [None]:
import glob, os

ckpts = sorted(glob.glob("logs/**/jepa-latest.pth.tar", recursive=True))
print("Found checkpoints:")
for i, p in enumerate(ckpts):
    print(f"[{i}] {p}")

# For convenience, select one index here. Paths are sorted alphabetically,
# so -1 picks the last path in that order, not necessarily the newest run.
SELECTED = -1  # or set to 0, 1, 2, ...

if ckpts:
    CKPT_PATH = ckpts[SELECTED]
    print("\nUsing CKPT_PATH =", CKPT_PATH)
else:
    CKPT_PATH = None
    print("No checkpoints found. Please run a pretraining cell above first.")
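
If several runs exist and the most recently written checkpoint is wanted, selecting by modification time may be more convenient than selecting by index. The helper `newest_checkpoint` below is a hypothetical addition, not part of the repo.

```python
import glob
import os


def newest_checkpoint(pattern="logs/**/jepa-latest.pth.tar"):
    """Return the most recently modified file matching pattern, or None."""
    paths = glob.glob(pattern, recursive=True)
    return max(paths, key=os.path.getmtime) if paths else None


print("Newest checkpoint:", newest_checkpoint())
```

The result can be assigned to `CKPT_PATH` in place of the index-based selection.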

## 5. Linear probing on STL10 / CIFAR-10

This section evaluates the **frozen encoder** using a linear classifier.

The main script is `src/linprobe.py`.  
Datasets (STL10 / CIFAR-10) will be downloaded automatically under `./data` by torchvision.

Before running, please make sure:

- `CKPT_PATH` above points to a valid `jepa-latest.pth.tar` checkpoint.


In [None]:
# Sanity-check CKPT_PATH
print("Current CKPT_PATH:", CKPT_PATH)

### 5.1 STL10 linear probe

This trains a linear classifier on top of the frozen JEPA encoder using STL10.


In [None]:
# Linear probe on STL10
!python -m src.linprobe \
  --ckpt_path "$CKPT_PATH" \
  --dataset stl10 \
  --data_root ./data \
  --crop_size 224 \
  --batch_size 256 \
  --epochs 100 \
  --lr 0.1

### 5.2 CIFAR-10 linear probe

This trains a linear classifier on CIFAR-10.


In [None]:
# Linear probe on CIFAR-10
!python -m src.linprobe \
  --ckpt_path "$CKPT_PATH" \
  --dataset cifar10 \
  --data_root ./data \
  --crop_size 224 \
  --batch_size 256 \
  --epochs 100 \
  --lr 0.1

## 6. k-NN evaluation on STL10 / CIFAR-10

This section evaluates the frozen encoder using a **non-parametric k-NN classifier**.

Script: `src/knn_eval.py` (run as the module `src.knn_eval`)
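
For readers unfamiliar with k-NN evaluation, the sketch below shows the basic idea on toy feature vectors: each query takes the majority label of its k most similar training features. Cosine similarity and unweighted voting are assumptions made for illustration; the repo's script may use a different similarity or a weighted vote.

```python
import math
from collections import Counter


def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def knn_predict(train_feats, train_labels, query, k=20):
    """Majority vote among the k training features most cosine-similar to query."""
    q = l2_normalize(query)
    sims = [
        (sum(a * b for a, b in zip(l2_normalize(f), q)), label)
        for f, label in zip(train_feats, train_labels)
    ]
    top = sorted(sims, key=lambda s: s[0], reverse=True)[:k]
    return Counter(label for _, label in top).most_common(1)[0][0]


# Toy 2-D "features"; real features come from the frozen encoder.
train_feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
train_labels = ["a", "a", "b", "b"]
print(knn_predict(train_feats, train_labels, [1.0, 0.05], k=3))
```

Because no classifier is trained, the score depends only on the geometry of the encoder's feature space.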

The following commands reproduce the evaluations I reported in the project write-up.


### 6.1 k-NN on STL10

The command below uses k = 20 (change it via `--k`).

In [None]:
# k-NN evaluation on STL10
!python -m src.knn_eval \
  --ckpt_path "$CKPT_PATH" \
  --data_root ./data \
  --dataset stl10 \
  --k 20 \
  --batch_size 512

### 6.2 k-NN on CIFAR-10

The command below uses k = 50 (change it via `--k`).

In [None]:
# k-NN evaluation on CIFAR-10
!python -m src.knn_eval \
  --ckpt_path "$CKPT_PATH" \
  --data_root ./data \
  --dataset cifar10 \
  --k 50 \
  --batch_size 512

## 7. DAW hyper-parameter sweep

The script `sweep.py` generates DAW configs from a base YAML and runs multiple pretraining jobs
with different `daw.mode` / `daw.gamma` settings.

Because this can be computationally expensive, this section is **optional** for reproducing the main results,
but is included here for completeness.
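
As a rough picture of what generating configs from a base could look like, the sketch below builds one config per `(mode, gamma)` pair. It is an illustration only: the function `make_sweep_configs`, the naming scheme, and the grid values are assumptions and may not match what `sweep.py` does.

```python
import copy
import itertools


def make_sweep_configs(base_cfg, modes, gammas):
    """Yield (name, config) pairs, one per (mode, gamma) combination."""
    for mode, gamma in itertools.product(modes, gammas):
        cfg = copy.deepcopy(base_cfg)  # never mutate the base config
        cfg.setdefault("daw", {}).update(enabled=True, mode=mode, gamma=gamma)
        yield f"daw_{mode}_g{gamma}", cfg


# Illustrative base config and grid; the gamma values are assumed, not from the repo.
base_cfg = {"daw": {"enabled": False, "w_min": 0.8, "w_max": 1.2}}
for name, cfg in make_sweep_configs(base_cfg, ["ema", "instant"], [0.1, 0.3]):
    print(name, cfg["daw"])
```

Each generated dict could then be written to its own YAML file and launched as a separate pretraining job.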


In [None]:
# Launch a DAW sweep (optional; may take a long time)
# This will use the base config inside sweep.py and generate configs under configs/sweep/
!python sweep.py