# VAE+ANS SDK Tutorial – Installation & Setup

## Installation

To use this notebook, you must install the local `vaerans_ecs` package in editable mode:

```bash
# From the repository root directory:
pip install -e ".[dev]"
```

This installs the package with development dependencies (pytest, mypy, black, ruff, hypothesis).

If you only need core runtime dependencies (without dev tools):
```bash
pip install -e .
```
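
To confirm that the editable install is visible to the interpreter the notebook kernel uses, you can run a quick check. This is a small standard-library sketch. `is_installed` is a hypothetical helper and is not part of `vaerans_ecs`.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if the top-level `package` can be located by the current interpreter."""
    # find_spec returns None when a top-level module cannot be found
    return importlib.util.find_spec(package) is not None


print("vaerans_ecs importable:", is_installed("vaerans_ecs"))
```

If this prints `False` inside the notebook, the kernel is probably running in a different environment from the one where `pip install -e` was run.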

## Configuration

Create `vaerans_ecs.toml` in the repository root (copy from `vaerans_ecs.toml.example`):
```bash
cp vaerans_ecs.toml.example vaerans_ecs.toml
```

Update the paths in `vaerans_ecs.toml` to point to your ONNX model files:
- `models.sdxl-vae.encoder` → path to VAE encoder ONNX
- `models.sdxl-vae.decoder` → path to VAE decoder ONNX

Alternatively, set the `VAERANS_CONFIG` environment variable to point to your config file.
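
One plausible lookup order is sketched below: an explicit path first, then `VAERANS_CONFIG`, then `vaerans_ecs.toml` in the working directory. This tutorial does not say what precedence `vaerans_ecs` actually uses, so treat `resolve_config_path` as a hypothetical helper.

```python
import os
from pathlib import Path
from typing import Optional


def resolve_config_path(explicit: Optional[str] = None) -> Path:
    """Hypothetical lookup: explicit argument, then VAERANS_CONFIG, then ./vaerans_ecs.toml."""
    if explicit is not None:
        return Path(explicit)
    env_value = os.environ.get("VAERANS_CONFIG")
    if env_value:
        return Path(env_value)
    return Path.cwd() / "vaerans_ecs.toml"


print(resolve_config_path())
```

In the notebook, you could then pass `config_path=str(resolve_config_path())` to `compress` and `decompress` instead of building the path from the repository root.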

## Optional Dependencies

For image display and visualization:
```bash
pip install pillow matplotlib
```


# VAE+ANS SDK Tutorial

This tutorial walks through the high-level API and the ECS pipeline using the real SDXL VAE ONNX models shipped in `models/`.

**Prereqs**
- Ensure `vaerans_ecs.toml` points to `models/vae_encoder_sdxl.onnx` and `models/vae_decoder_sdxl.onnx`.
- Install dependencies (from repo root): `pip install -e ".[dev]"` (the quotes keep shells such as zsh from expanding the brackets)

Optional: install Pillow and matplotlib for image display: `pip install pillow matplotlib`


```python
from pathlib import Path

import numpy as np

from vaerans_ecs.api import (
    compress,
    decompress,
    get_compression_info,
    get_compression_ratio,
)
from vaerans_ecs.components.image import ReconRGB
from vaerans_ecs.components.latent import Latent4, YUVW4
from vaerans_ecs.core.world import World
from vaerans_ecs.systems.hadamard import Hadamard4
from vaerans_ecs.systems.vae import OnnxVAEDecode, OnnxVAEEncode

# The notebook expects the kernel's working directory to be the repository root
repo_root = Path.cwd()
config_path = repo_root / 'vaerans_ecs.toml'
print('Config:', config_path)
assert config_path.exists(), 'vaerans_ecs.toml not found in repo root'
```
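
`Path.cwd()` only equals the repository root when the kernel is started there. If the notebook lives in a subdirectory, an upward search is more forgiving. `find_repo_root` is a hypothetical helper; it treats the first ancestor directory that contains a marker file as the root.

```python
from pathlib import Path


def find_repo_root(start: Path, marker: str = "vaerans_ecs.toml") -> Path:
    """Walk up from `start` and return the first directory that contains `marker`."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"{marker} not found in {start} or any parent")
```

Usage in the notebook: `repo_root = find_repo_root(Path.cwd())`.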


## 1) Load or generate an image
If `examples/23.png` exists, load it; otherwise create a random image.


```python
input_path = repo_root / 'examples' / '23.png'
image = None

# Load the sample image with Pillow when both Pillow and the file are available
try:
    from PIL import Image
    if input_path.exists():
        image = np.array(Image.open(input_path).convert('RGB'))
        print(f'Loaded image: {input_path}')
except Exception as exc:
    print('Could not load sample image:', exc)
    image = None

# Otherwise fall back to a random 256x256 RGB image
if image is None:
    image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    print('Generated random image:', image.shape)

image.shape, image.dtype
```
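
The SDXL VAE downsamples height and width by a factor of 8, so images whose sides are multiples of 8 map cleanly onto the latent grid. For example, the 256×256 fallback image gives a 32×32 latent. This tutorial does not say whether `vaerans_ecs` pads, crops, or rejects other sizes. If you want to control this yourself, a center crop is one simple option. `crop_to_multiple` is a hypothetical helper, not an SDK function.

```python
import numpy as np


def crop_to_multiple(img: np.ndarray, multiple: int = 8) -> np.ndarray:
    """Center-crop an (H, W, C) array so that H and W are divisible by `multiple`."""
    h, w = img.shape[:2]
    new_h, new_w = h - h % multiple, w - w % multiple
    # Split the removed rows/columns as evenly as possible between both sides
    top, left = (h - new_h) // 2, (w - new_w) // 2
    return img[top:top + new_h, left:left + new_w]
```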


```python
# Display image if matplotlib is available
try:
    import matplotlib.pyplot as plt
    plt.imshow(image)
    plt.axis('off')
    plt.show()
except Exception as exc:
    print('matplotlib not available:', exc)
```


## 2) High-level API: compress + decompress
This uses the SDXL VAE configured in `vaerans_ecs.toml`. The compressed byte format is still a placeholder, but the VAE encode and decode steps run the real ONNX models.


```python
compressed = compress(
    image,
    model='sdxl-vae',
    quality=50,
    use_hadamard=True,
    config_path=str(config_path),
)
print('Compressed bytes:', len(compressed))
info = get_compression_info(compressed)
ratio = get_compression_ratio(image, compressed)
print('Metadata:', info)
print('Compression ratio:', f'{ratio:.2f}x')

recon = decompress(compressed, config_path=str(config_path))
print('Reconstructed:', recon.shape, recon.dtype)
```
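
To build intuition for the printed ratio, here is the usual arithmetic. A uint8 RGB image occupies H × W × 3 bytes uncompressed, and the bit rate is 8 × compressed bytes / (H × W) bits per pixel. This is a sketch of those standard definitions; `get_compression_ratio` may not use exactly the same one, and `compression_stats` is a hypothetical helper.

```python
def compression_stats(height: int, width: int, compressed_len: int, channels: int = 3) -> tuple:
    """Return (ratio, bits_per_pixel) for a uint8 image compressed to `compressed_len` bytes."""
    raw_bytes = height * width * channels  # one byte per uint8 sample
    ratio = raw_bytes / compressed_len
    bpp = 8 * compressed_len / (height * width)
    return ratio, bpp


ratio, bpp = compression_stats(256, 256, 4096)
print(f"{ratio:.1f}x, {bpp:.2f} bpp")  # 48.0x, 0.50 bpp
```

In the notebook, `compression_stats(*image.shape[:2], len(compressed))` gives figures you can compare with the SDK's output.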


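```python
# Simple metrics on the [0, 1] scale
orig = image.astype(np.float32) / 255.0
recon_f = recon.astype(np.float32)
if recon.dtype == np.uint8:
    recon_f /= 255.0  # bring uint8 output onto the same [0, 1] scale
mse = float(np.mean((orig - recon_f) ** 2))
psnr = 10 * np.log10(1.0 / mse) if mse > 0 else float('inf')
print('MSE:', f'{mse:.6f}')
print('PSNR:', f'{psnr:.2f} dB')
```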
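
PSNR is 10·log10(MAX² / MSE), where MAX is the largest possible pixel value: 1.0 on the [0, 1] scale used above, or 255 for uint8 data. The two conventions give the same number, as this standalone sketch shows. `psnr` here is a hypothetical helper, not an SDK function.

```python
import numpy as np


def psnr(a: np.ndarray, b: np.ndarray, max_val: float) -> float:
    """Peak signal-to-noise ratio in dB; returns inf for identical inputs."""
    mse = float(np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2))
    return float("inf") if mse == 0 else float(10 * np.log10(max_val ** 2 / mse))


ref = np.zeros((4, 4, 3), dtype=np.uint8)
noisy = ref + 10  # every sample off by 10, so MSE = 100 on the uint8 scale
print(psnr(ref, noisy, 255.0))                # about 28.13 dB
print(psnr(ref / 255.0, noisy / 255.0, 1.0))  # same value on the [0, 1] scale
```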


## 3) ECS Pipeline: encode -> Hadamard -> decode
Use the lower-level ECS systems to run the VAE encode, Hadamard transform, and VAE decode stages explicitly, one system at a time.


```python
world = World()
eid = world.spawn_image(image)

encoder = OnnxVAEEncode(model='sdxl-vae', config_path=str(config_path))
decoder = OnnxVAEDecode(model='sdxl-vae', config_path=str(config_path))

# Encode
encoder.run(world, [eid])
latent = world.get_component(eid, Latent4)
latent_view = world.arena.view(latent.z)
print('Latent shape:', latent_view.shape)

# Hadamard forward + inverse
Hadamard4(mode='forward').run(world, [eid])
yuvw = world.get_component(eid, YUVW4)
print('YUVW shape:', world.arena.view(yuvw.t).shape)
Hadamard4(mode='inverse').run(world, [eid])

# Decode
decoder.run(world, [eid])
recon_component = world.get_component(eid, ReconRGB)
recon_view = world.arena.view(recon_component.pix)
print('Recon shape:', recon_view.shape)
```
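
`Hadamard4` converts the four latent channels into a `YUVW4` representation and back. Its exact matrix, scaling, and memory layout are not shown in this tutorial. The sketch below assumes a 4×4 Sylvester Hadamard matrix scaled by 1/2, which is orthonormal and therefore its own inverse, applied across the channel axis of a channel-first `(4, h, w)` latent. `H4` and `hadamard_channels` are hypothetical names.

```python
import numpy as np

# Assumed transform: 4x4 Sylvester Hadamard matrix scaled by 1/2 (orthonormal, so H4 @ H4 == I)
H4 = 0.5 * np.array(
    [
        [1, 1, 1, 1],
        [1, -1, 1, -1],
        [1, 1, -1, -1],
        [1, -1, -1, 1],
    ],
    dtype=np.float32,
)


def hadamard_channels(z: np.ndarray) -> np.ndarray:
    """Mix the 4 channels of a (4, h, w) latent; applying it twice restores the input."""
    return np.einsum("ij,jhw->ihw", H4, z)


z = np.random.default_rng(0).standard_normal((4, 32, 32)).astype(np.float32)
t = hadamard_channels(z)       # forward
z_back = hadamard_channels(t)  # inverse uses the same matrix
print(np.abs(z - z_back).max())
```

Because the scaled matrix is orthonormal, the forward pass preserves the latent's total energy, and the round trip is exact up to floating-point error. This matches the forward-then-inverse round trip in the cell above.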


## 4) Batch processing
Encode and decode multiple images in a single run.


```python
world = World()
images = [
    np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    for _ in range(3)
]
eids = world.spawn_batch_images(images)

encoder = OnnxVAEEncode(model='sdxl-vae', config_path=str(config_path))
decoder = OnnxVAEDecode(model='sdxl-vae', config_path=str(config_path))

# Each system processes every entity in the list in one call
encoder.run(world, eids)
decoder.run(world, eids)

for eid in eids:
    recon_component = world.get_component(eid, ReconRGB)
    recon_view = world.arena.view(recon_component.pix)
    print(f'Entity {eid}: recon {recon_view.shape}')
```
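
Passing every entity to a single `run` call is the simplest approach. With many large images, though, you may want to limit peak memory by processing fixed-size groups. The standard-library sketch below shows the chunking. `chunked` is a hypothetical helper, and this tutorial does not cover whether `vaerans_ecs` systems already batch internally.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunked(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `size` items."""
    if size < 1:
        raise ValueError("size must be >= 1")
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch


print(list(chunked(range(5), 2)))  # [[0, 1], [2, 3], [4]]
```

In the notebook, `for group in chunked(eids, 2): encoder.run(world, group); decoder.run(world, group)` would run both systems two entities at a time.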
