In [1]:
# Re-import edited project modules (e.g. ae_latent) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os

import numpy as np
import torch

from ae_latent.data.builder import build_dataloaders
from ae_latent.models.builders import build_vect_latent_ae
from ae_latent.analysis.latent.extract import extract_latents_to_npz

Extract latents for all data splits

In [3]:
# Run to export: reads <runs_dir>/<run_name>/config.json and best.pt and writes
# one <split>_latents.npz per split under <runs_dir>/<run_name>/latents/.
run_name = "celeba_z512"
runs_dir = "../runs/"

# Splits to export. The full list is ["train", "test", "val"]; trim it to export a subset.
data_splits = ["train", "test", "val"]

# os.path.join is correct whether or not runs_dir ends with a separator;
# plain `+` concatenation silently built a wrong path without the trailing "/".
run_dir = os.path.join(runs_dir, run_name)

# Load the run's root JSON config.
with open(os.path.join(run_dir, "config.json"), "r", encoding="utf-8") as f:
    cfg = json.load(f)

# IMPORTANT: force return_ids on regardless of the saved config — the loaders must
# yield sample ids (presumably what fills `base_idx` in the saved .npz files).
cfg["data"]["return_ids"] = True

loaders = build_dataloaders(cfg)

# Fail fast on a misspelled/unavailable split, before the checkpoint is loaded
# and before any earlier split's file has been overwritten.
missing_splits = [split for split in data_splits if split not in loaders]
if missing_splits:
    raise KeyError(f"splits {missing_splits} not in loaders; available: {list(loaders)}")

model = build_vect_latent_ae(
    {**cfg["model"], "z_dim": cfg["z_dim"], "dataset": cfg["dataset"]},
    device=cfg["device"],
    ckpt_path=os.path.join(run_dir, "best.pt"),
)
# NOTE(review): assumes extract_latents_to_npz puts the model in eval mode and
# disables autograd itself — confirm, otherwise call model.eval() here.

for split in data_splits:
    out = extract_latents_to_npz(
        model,
        loaders[split],
        os.path.join(run_dir, "latents", f"{split}_latents.npz"),
        overwrite=True,
        use_amp=cfg.get("use_amp", False),
    )
    print("saved:", out)
print("done")




saved: ../runs/celeba_z512/latents/train_latents.npz
saved: ../runs/celeba_z512/latents/test_latents.npz
saved: ../runs/celeba_z512/latents/val_latents.npz
done


Format of the saved latent files

In [None]:
# How to load a saved latents file. This reads the train split written above;
# replace "train" with "test" or "val" to inspect another split.
latents_path = f"{run_dir}/latents/train_latents.npz"

# np.load on an .npz returns a lazy NpzFile that keeps the file open; using it as a
# context manager closes the handle. Indexing it reads each array fully into memory,
# so Z and y stay usable after the `with` block exits.
with np.load(latents_path) as data:
    print(data.files)
    # ['Z', 'y', 'base_idx', 'dataset', 'split']

    # Example shapes from a celeba z256 train split (N = 162770). The second
    # dimension of Z follows the run's latent size (presumably cfg["z_dim"]).
    #
    # Z:        (162770, 256)
    # y:        (162770, 40)
    # base_idx: (162770,)
    # dataset:  (162770,) - just repeats "celeba"
    # split:    (162770,) - just repeats "train"

    Z = data["Z"]
    y = data["y"]