# Pre-encode

> Pre-encode images using a frozen encoder for faster decoder training

## Notes

This script pre-encodes images using a trained encoder checkpoint, saving the embeddings for faster decoder training.

**Usage:**
```bash
python midi_rae/preencode.py encoder_ckpt=checkpoints/best.pt preencode.output_dir=preencoded/
```

**TODO:**
- May need a simpler Dataset that returns single images (not pairs) + their filenames
- Decide on output format: one `.pt` per image, or chunked/batched files?
- Add config entries for `encoder_ckpt` path and `preencode.output_dir`
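A minimal sketch of the first TODO item, assuming images live as individual files on disk. `SingleImageDataset`, its `paths` argument, and the `load_fn` callable are all hypothetical; the repo's actual image-loading code is not shown here. It does not subclass `torch.utils.data.Dataset`, because in practice `DataLoader` accepts any map-style object with `__len__` and `__getitem__`.

```python
from pathlib import Path


class SingleImageDataset:
    """Hypothetical map-style dataset: one image per item, plus its filename."""

    def __init__(self, paths, load_fn):
        # Sort so the item order (and therefore the saved-embedding order) is stable.
        self.paths = sorted(Path(p) for p in paths)
        self.load_fn = load_fn  # assumption: maps a Path to an image tensor

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        return {"img": self.load_fn(path), "filename": path.name}
```

With PyTorch's default collate function, a batch from this dataset is a dict whose `"img"` entry is stacked into one tensor and whose `"filename"` entry is a plain list of strings, which could be saved next to the embeddings in each output file.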

In [None]:
#| default_exp preencode

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf
import hydra
from tqdm.auto import tqdm
from pathlib import Path

from midi_rae.vit import ViTEncoder
from midi_rae.data import PRPairDataset  # we use img2 and ignore img1

In [None]:
#| export
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def preencode(cfg: DictConfig):
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    print(f"device = {device}")
    
    # Load encoder from checkpoint
    ckpt_path = cfg.get('encoder_ckpt', 'checkpoints/enc_best.pt')
    print(f"Loading encoder from {ckpt_path}")
    
    model = ViTEncoder(
        cfg.data.in_channels, 
        (cfg.data.image_size, cfg.data.image_size), 
        cfg.model.patch_size,
        cfg.model.dim, 
        cfg.model.depth, 
        cfg.model.heads
    ).to(device)
    
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in ckpt['model_state_dict'].items()}
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    
    # Output directory
    output_dir = Path(cfg.get('preencode', {}).get('output_dir', 'preencoded/'))
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving embeddings to {output_dir}")
    
    for split in ['train', 'val']:
        print(f"\nProcessing {split} split...")
        ds = PRPairDataset(split=split, max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y)
        dl = DataLoader(ds, batch_size=cfg.training.batch_size, num_workers=4, shuffle=False)
        
        num_chunks = cfg.get('preencode', {}).get('num_passes', 1)  # one chunk = one full pass through ds
        for chunk in range(1,num_chunks+1):
            chunk_embeddings = []
            chunk_images = []  # input images, saved for reconstruction comparison
            with torch.no_grad():
                for batch in tqdm(dl, desc=f"Encoding {split}, Chunk {chunk}/{num_chunks}"):
                    img = batch['img2'].to(device)  # img2 comes from a wider distribution than img1; img1 is ignored
                    z = model(img, return_cls_only=False)  # (B, num_tokens, dim), e.g. (B, 65, 768)
                    chunk_embeddings.append(z.cpu())
                    chunk_images.append(img.cpu())
            
            # Concatenate and save
            embeddings = torch.cat(chunk_embeddings, dim=0)
            images = torch.cat(chunk_images, dim=0)
            
            save_path = output_dir / f"{split}_embeddings_{chunk}.pt"
            torch.save({
                'embeddings': embeddings,
                'images': images,  # for reconstruction loss computation
            }, save_path)
            print(f"Saved {len(embeddings)} embeddings to {save_path}")
            print(f"  embeddings shape: {embeddings.shape}")
            print(f"  images shape: {images.shape}")

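The `_orig_mod.` stripping handles checkpoints saved from a `torch.compile`d model. Because the encoder is then loaded with `strict=False`, any remaining key mismatch (missing or unexpected parameter names) is silently ignored rather than raised. `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, so a small check can surface those. `check_load_result` and the `proj_head.` prefix below are hypothetical; which extra keys the real checkpoint contains is an assumption.

```python
from collections import namedtuple


def check_load_result(result, allowed_unexpected_prefixes=()):
    """Raise if load_state_dict(..., strict=False) silently skipped weights.

    `result` is what nn.Module.load_state_dict returns: a named tuple with
    `missing_keys` and `unexpected_keys`.
    """
    prefixes = tuple(allowed_unexpected_prefixes)
    # Missing keys mean some encoder parameters kept their random init.
    missing = list(result.missing_keys)
    # Unexpected keys are checkpoint entries the encoder has no slot for;
    # tolerate only those whose prefix is explicitly allowed.
    unexpected = [k for k in result.unexpected_keys if not k.startswith(prefixes)]
    if missing or unexpected:
        raise RuntimeError(f"missing keys: {missing}; unexpected keys: {unexpected}")


# Illustration with a stand-in for the real return value.
LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])
ok = LoadResult(missing_keys=[], unexpected_keys=["proj_head.weight"])
check_load_result(ok, allowed_unexpected_prefixes=("proj_head.",))  # no error
```

In the script this would read `result = model.load_state_dict(state_dict, strict=False)` followed by `check_load_result(result, ...)`.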

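The loop above writes `{split}_embeddings_{chunk}.pt` with chunk numbers starting at 1, so a plain `sorted(glob(...))` puts `..._10.pt` before `..._2.pt`. A decoder-side reader might therefore sort by the parsed chunk number and load one chunk at a time instead of all files at once. This is a sketch under those assumptions: `chunk_files` and `iter_chunks` are hypothetical helpers, and `load_fn` would be `torch.load` in practice.

```python
import re
from pathlib import Path


def chunk_files(output_dir, split):
    """Return this split's chunk files sorted numerically by chunk number."""
    pattern = re.compile(rf"{re.escape(split)}_embeddings_(\d+)\.pt")
    numbered = []
    for path in Path(output_dir).glob(f"{split}_embeddings_*.pt"):
        m = pattern.fullmatch(path.name)
        if m:  # skip files that don't end in a numeric chunk index
            numbered.append((int(m.group(1)), path))
    return [path for _, path in sorted(numbered)]


def iter_chunks(paths, load_fn):
    """Yield (embeddings, images) per chunk, loading files one at a time."""
    for path in paths:
        data = load_fn(path)  # in practice: torch.load(path)
        yield data["embeddings"], data["images"]
```
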
In [None]:
#| export
#| eval: false
if __name__ == "__main__" and "ipykernel" not in __import__("sys").modules:
    preencode()

Help me understand this error. I never saw anything like this when I was just training the encoder.
```
Encoding train, Chunk 1/1000:   0%|                                                                                                                                 | 0/186 [00:00<?, ?it/s]
Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/shawley/github/midi-rae/midi_rae/preencode.py", line 59, in preencode
    for batch in tqdm(dl, desc=f"Encoding {split}, Chunk {chunk}/{num_chunks}"):
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 741, in __next__
    data = self._next_data()
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1548, in _next_data
    return self._process_data(data, worker_id)
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1586, in _process_data
    data.reraise()
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/torch/_utils.py", line 775, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 358, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/shawley/envs/midi-rae/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/shawley/github/midi-rae/midi_rae/data.py", line 63, in __getitem__
    shift_x = random.choice([i for i in range(-self.max_shift_x, self.max_shift_x + 1) if i != 0])
  File "/usr/lib/python3.10/random.py", line 378, in choice
    return seq[self._randbelow(len(seq))]
IndexError: list index out of range



##### 🤖Reply🤖

The issue is in this line from your `PRPairDataset`:

```python
shift_x = random.choice([i for i in range(-self.max_shift_x, self.max_shift_x + 1) if i != 0])
```

When you pass `max_shift_x=0`, the list comprehension produces an **empty list** (since the only value in the range is 0, which gets filtered out by `if i != 0`). Then `random.choice([])` fails.
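A minimal reproduction of that failure, using only the standard library. The exact error message differs between Python versions (3.10 reports `list index out of range`, as in the traceback), but the exception type is `IndexError` either way.

```python
import random

max_shift_x = 0
candidates = [i for i in range(-max_shift_x, max_shift_x + 1) if i != 0]
print(candidates)  # []: range(0, 1) yields only 0, which the filter drops

caught = None
try:
    random.choice(candidates)
except IndexError as err:
    caught = err  # message text varies by Python version
print(type(caught).__name__)  # IndexError
```
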

You're creating the dataset with:
```python
ds = PRPairDataset(split=split, max_shift_x=0, max_shift_y=0)
```

The fix would be in `PRPairDataset.__getitem__` to handle the zero-shift case, something like:

```python
shift_x = 0 if self.max_shift_x == 0 else random.choice([i for i in range(-self.max_shift_x, self.max_shift_x + 1) if i != 0])
```

(And same for `shift_y`.)
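One way to apply the same guard to both axes is a small helper. `sample_nonzero_shift` is a hypothetical name, and rejecting negative values is an added assumption (a negative `max_shift` would otherwise hit the same empty-list error). Drawing a magnitude in `1..max_shift` and a random sign gives the same uniform distribution over the `2 * max_shift` nonzero values as the list comprehension, without building a list.

```python
import random


def sample_nonzero_shift(max_shift, rng=random):
    """Hypothetical helper: nonzero integer shift in [-max_shift, max_shift].

    Returns 0 when max_shift == 0 (augmentation disabled) instead of calling
    random.choice on an empty list.
    """
    if max_shift < 0:
        raise ValueError(f"max_shift must be >= 0, got {max_shift}")
    if max_shift == 0:
        return 0
    # Random magnitude in 1..max_shift times a random sign: uniform over the
    # same 2 * max_shift nonzero values as the list-comprehension version.
    return rng.randint(1, max_shift) * rng.choice((-1, 1))
```

Inside `PRPairDataset.__getitem__` this would become `shift_x = sample_nonzero_shift(self.max_shift_x)` and likewise for `shift_y`.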

Does that make sense?

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

Oh, we should read the max shift from the configuration file, since that's what the encoder was trained with, e.g. `cfg.training.max_shift_x`, etc.

##### 🤖Reply🤖

Good idea! Change this line:

```python
ds = PRPairDataset(split=split, max_shift_x=0, max_shift_y=0)
```

to:

```python
ds = PRPairDataset(split=split, max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y)
```

That way it stays consistent with how the encoder was trained.