# Pre-encode

> Pre-encode images using a frozen encoder for faster decoder training

## Notes

This script pre-encodes images using a trained encoder checkpoint, saving the embeddings for faster decoder training.

**Usage:**
```bash
python midi_rae/preencode.py encoder_ckpt=checkpoints/best.pt preencode.output_dir=preencoded/
```

**TODO:**
- May need a simpler Dataset that returns single images (not pairs) + their filenames
- Decide on output format: one `.pt` per image, or chunked/batched files?
- Add config entries for `encoder_ckpt` path and `preencode.output_dir`

In [None]:
#| default_exp preencode

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf
import hydra
from tqdm.auto import tqdm
from pathlib import Path

from midi_rae.vit import ViTEncoder
from midi_rae.data import PRPairDataset  # we use img1 and ignore img2

In [None]:
#| export
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def preencode(cfg: DictConfig):
    """Encode the train and val splits with a trained encoder and save the results.

    Builds a `ViTEncoder` from `cfg`, loads its weights from a checkpoint, runs
    the `img1` entry of every `PRPairDataset` batch through it with gradients
    disabled, and writes one `{split}_embeddings.pt` file per split holding the
    concatenated embeddings together with the images they were computed from.

    Config keys read:
      - `encoder_ckpt` (optional, defaults to `checkpoints/best.pt`)
      - `preencode.output_dir` (optional, defaults to `preencoded/`)
      - `data.in_channels`, `data.image_size`
      - `model.patch_size`, `model.dim`, `model.depth`, `model.heads`
      - `training.batch_size`

    A split that yields no batches is skipped with a warning instead of
    crashing, so the remaining splits are still processed.
    """
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    print(f"device = {device}")

    # Load encoder from checkpoint
    ckpt_path = cfg.get('encoder_ckpt', 'checkpoints/best.pt')
    print(f"Loading encoder from {ckpt_path}")

    model = ViTEncoder(
        cfg.data.in_channels,
        (cfg.data.image_size, cfg.data.image_size),  # square input
        cfg.model.patch_size,
        cfg.model.dim,
        cfg.model.depth,
        cfg.model.heads,
    ).to(device)

    # The checkpoint is expected to be a dict with the weights under 'model_state_dict'.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # Output directory
    output_dir = Path(cfg.get('preencode', {}).get('output_dir', 'preencoded/'))
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving embeddings to {output_dir}")

    # TODO: May want a single-image dataset here instead of PRPairDataset
    # For now, using PRPairDataset but only encoding img1
    for split in ['train', 'val']:
        print(f"\nProcessing {split} split...")
        # Shifts are set to zero for pre-encoding.
        ds = PRPairDataset(split=split, max_shift_x=0, max_shift_y=0)
        # shuffle=False keeps the saved rows in dataset order.
        dl = DataLoader(ds, batch_size=cfg.training.batch_size, num_workers=4, shuffle=False)

        # Everything for a split is held in CPU memory until it is written out.
        all_embeddings = []
        all_images = []  # optionally save original images too for reconstruction comparison

        with torch.no_grad():
            for batch in tqdm(dl, desc=f"Encoding {split}"):
                img = batch['img1'].to(device)
                # NOTE(review): presumably returns all tokens, shape (B, n_tokens, dim);
                # an earlier note gave (B, 65, 768), which depends on the config — confirm.
                z = model(img, return_cls_only=False)
                all_embeddings.append(z.cpu())
                all_images.append(img.cpu())

        if not all_embeddings:
            # torch.cat raises on an empty list; skip this split rather than
            # abort the whole run.
            print(f"Warning: {split} split produced no batches; nothing saved")
            continue

        # Concatenate and save
        embeddings = torch.cat(all_embeddings, dim=0)
        images = torch.cat(all_images, dim=0)

        save_path = output_dir / f"{split}_embeddings.pt"
        torch.save({
            'embeddings': embeddings,
            'images': images,  # for reconstruction loss computation
        }, save_path)
        print(f"Saved {len(embeddings)} embeddings to {save_path}")
        print(f"  embeddings shape: {embeddings.shape}")
        print(f"  images shape: {images.shape}")

In [None]:
#| export
#| eval: false
if __name__ == "__main__":
    # Only launch when executed as a script; inside a Jupyter kernel
    # ("ipykernel" is loaded) the cell must not start the Hydra app.
    if "ipykernel" not in __import__("sys").modules:
        preencode()

In [None]:
#| hide
# Export the cells marked `#| export` from this notebook into the package module.
import nbdev
nbdev.nbdev_export()