# UMAP Visualisation of AION Embeddings

This notebook projects the saved CAMELS embeddings into two dimensions with UMAP and colours each point by the value of each target parameter. Use it to gauge how much parameter-related signal the pretrained AION encoder captures.

> **Note:** ensure the `umap-learn` package is available in your environment (`pip install umap-learn`).

In [3]:
import json
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import umap

# from camels_aion.data import CAMELS_FIELDS
# Order is assumed to match the label columns stored in the shards.
PARAMETER_NAMES = ["Omega_m", "sigma8", "A_SN1", "A_SN2", "A_AGN1", "A_AGN2"]

sns.set_context("talk")

## Configure paths

Point the variables below to the directory containing your embedding shards and the corresponding manifest JSON. `MAX_POINTS` caps the number of samples for faster visualisation.

In [4]:
SHARD_DIR = pathlib.Path("../outputs/")
MANIFEST_PATH = SHARD_DIR / "IllustrisTNG_LH_z0p00_manifest.json"
MAX_POINTS = 5000  # adjust as needed

In [5]:
with open(MANIFEST_PATH, "r", encoding="utf-8") as fh:
    manifest = json.load(fh)

shard_paths = [SHARD_DIR / name for name in manifest["shards"]]
len(shard_paths)

469
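
The manifest only lists shard file names, so it can be worth confirming that every listed file is actually present before loading. The following is a minimal sketch, assuming the manifest's `shards` entry is a flat list of file names relative to `SHARD_DIR`, as the cell above implies. `find_missing_shards` is a hypothetical helper, not part of any package used in this notebook.

```python
import pathlib


def find_missing_shards(shard_dir, shard_names):
    """Return the shard names that do not exist as files under shard_dir."""
    shard_dir = pathlib.Path(shard_dir)
    return [name for name in shard_names if not (shard_dir / name).is_file()]
```

Calling `find_missing_shards(SHARD_DIR, manifest["shards"])` should return an empty list when the shard directory is complete. A non-empty result points at the files to regenerate or at a wrong `SHARD_DIR`.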

## Load embeddings

We reuse the same logic as in the training script: load each shard, average embeddings over the token dimension, and collect the labels. To avoid exhausting memory, we stop once `MAX_POINTS` samples have been gathered, so the plots show the first `MAX_POINTS` samples in manifest order rather than a random subset.
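
The token-averaging step can be illustrated on a synthetic array. The shapes here (4 samples, 3 tokens, 8 features) are made up for illustration, and NumPy stands in for PyTorch so the snippet runs without the shard files.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 3, 8))  # synthetic (samples, tokens, features)

# Average over the token axis, the NumPy analogue of emb.mean(dim=1).
pooled = tokens.mean(axis=1)

print(tokens.shape, "->", pooled.shape)  # (4, 3, 8) -> (4, 8)
```

Each sample is reduced to a single feature vector, which is the form UMAP expects as input.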

In [6]:
embeddings = []
labels = []
total = 0

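for shard in shard_paths:
    # weights_only=False unpickles arbitrary objects, so only load shards you trust.
    payload = torch.load(shard, map_location="cpu", weights_only=False)
    emb = payload["embeddings"].float()
    if emb.ndim == 3:
        # (samples, tokens, features) -> (samples, features)
        emb = emb.mean(dim=1)
    lab = payload["labels"].float()

    embeddings.append(emb)
    labels.append(lab)
    total += emb.shape[0]
    if total >= MAX_POINTS:
        # Enough samples collected; skip the remaining shards.
        break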

embeddings = torch.cat(embeddings, dim=0)[:MAX_POINTS]
labels = torch.cat(labels, dim=0)[:MAX_POINTS]

embeddings.shape, labels.shape

FileNotFoundError: [Errno 2] No such file or directory: '../outputs/IllustrisTNG_LH_z0p00_00000-00031.pt'

## UMAP projection

In [None]:
# Fixing random_state makes the layout reproducible between runs.
reducer = umap.UMAP(n_neighbors=30, min_dist=0.2, metric="cosine", random_state=42)
embedding_2d = reducer.fit_transform(embeddings.cpu().numpy())
embedding_2d.shape

## Scatter plots coloured by parameter value

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

for idx, (ax, name) in enumerate(zip(axes, PARAMETER_NAMES)):
    sns.scatterplot(
        x=embedding_2d[:, 0],
        y=embedding_2d[:, 1],
        hue=labels[:, idx].numpy(),  # column idx holds the values of parameter `name`
        palette="viridis",
        s=10,
        linewidth=0,
        ax=ax,
    )
    ax.set_title(name)
    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")
    ax.legend(title=name, loc="upper right", bbox_to_anchor=(1.25, 1))

plt.tight_layout()
plt.show()
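
Because the parameters are continuous, a colour bar can be easier to read than a discrete legend. The sketch below is one possible alternative using plain Matplotlib, run on random stand-in data so it works without the embeddings. `coords` and `values` are hypothetical placeholders for `embedding_2d` and one column of `labels`, and the value range is arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
coords = rng.normal(size=(200, 2))        # stand-in for embedding_2d
values = rng.uniform(0.1, 0.5, size=200)  # stand-in for one label column

fig, ax = plt.subplots(figsize=(6, 5))
points = ax.scatter(
    coords[:, 0], coords[:, 1], c=values, cmap="viridis", s=10, linewidths=0
)
# The colour bar maps colours back to continuous parameter values.
fig.colorbar(points, ax=ax, label="Omega_m")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
```

In a notebook the figure is displayed at the end of the cell. In a script, finish with `plt.show()` or `fig.savefig(...)`.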