In [1]:
# Cell 1: Imports & project-root setup
from pathlib import Path
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

def find_project_root(marker_dir: str = "artifacts") -> Path:
    start = Path(__file__).resolve() if "__file__" in globals() else Path.cwd().resolve()
    for parent in (start, *start.parents):
        if (parent / marker_dir).is_dir():
            return parent
    return Path.cwd().resolve()

PROJECT_ROOT = find_project_root("artifacts")
sys.path.insert(0, str(PROJECT_ROOT))

from common.config import DEVICE, dims
from models.autoencoder.dataset_autoencoder import DatasetAutoencoder
from models.autoencoder.utils import split_dataset
from models.autoencoder.architectures.flexible_autoencoder import ConvAutoencoder
from models.autoencoder.architectures.flexible_autoencoder_norm_end import ConvAutoencoderNormEnd

from models.pixel_nn.dataset_pixel import PixelDataset

config.py: DEVICE is set as cuda


In [2]:
# Cell 2: Key constants (EDIT THESE AS NEEDED)
DATA_PATH       = Path("data/waveforms")
REDUCTION, N    = "resample", 200
TRAIN_FRAC      = 0.8
SEED            = 42
BATCH_SIZE      = 50
LATENT_DIM      = 32
DROPOUT         = 0.1
USE_BATCHNORM   = True
LR              = 1e-3

CKPT_DIR        = PROJECT_ROOT / "artifacts" / "autoencoder" / "checkpoints"
CKPT_FILENAME   = "ConvAENormEnd_resample200_lat32_do10_bn_2025-06-26_11-30-01.pt"
CKPT_PATH       = CKPT_DIR / CKPT_FILENAME

CKPT_PATH_OLD = CKPT_DIR / "ConvAE_resample200_lat32_bn_2025-06-25_11-23-01.pt"


In [3]:
# Cell 3: Load the model

# load dataset and make dataloader for validation data
dataset = DatasetAutoencoder(
    path=Path("data/waveforms"),
    reduction=REDUCTION,
    n=N,
    save=False,        # won’t try to write a new cache
    force_reload=False # will load the cache if present
)

_, val_set = split_dataset(dataset, train_frac=TRAIN_FRAC, seed=SEED)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

# load model
loaded_model = ConvAutoencoderNormEnd.load(
    path = CKPT_PATH,
    dataset = dataset,
    device=DEVICE
).to(DEVICE)

In [4]:
print(dataset[0:100].flatten().std())
print(dataset[0:100].flatten().mean())

tensor(0.7851, device='cuda:0')
tensor(7.0572e-09, device='cuda:0')


In [5]:
print(loaded_model.encode(dataset[0:100].to(DEVICE)).flatten().std())
print(loaded_model.encode(dataset[0:100].to(DEVICE)).flatten().mean())

tensor(1.5590, device='cuda:0', grad_fn=<StdBackward0>)
tensor(0.0355, device='cuda:0', grad_fn=<MeanBackward0>)


In [6]:
x_min = dims[0][0]
x_max = dims[0][1]
y_min = dims[1][0]
y_max = dims[1][1]

In [7]:
D_Pixel = PixelDataset(loaded_model, Xmin=x_min, Xmax=x_max, Ymin=y_min, Ymax=y_max, save=False, force_reload=False, )

Encoded waveforms with shape torch.Size([6600, 32]) (num_waveforms, encoding_dims)
Sorted files in rays
Saved ray tensors shape: torch.Size([100, 66, 30, 20]) (num_sections=100, num_rays=66, nX, nY)
Saved emitter/receptor positions tensor shape: torch.Size([66, 4]) (66, 4)
Loaded 100 section files with labels shape: torch.Size([100, 20, 30]) (num_sections, nY, nX)


In [8]:
torch.set_printoptions(threshold=10_000_000)
print(D_Pixel[0][0])

tensor([ 0.0000e+00,  0.0000e+00,  2.0000e-02,  2.0000e-02,  5.0000e-02,
         5.0000e-02,  4.0000e-01,  0.0000e+00,  0.0000e+00,  1.6421e+00,
        -6.8138e-02,  7.7963e-01, -2.8068e+00, -1.2730e+00, -1.5994e+00,
        -1.5082e+00,  1.5768e+00, -1.2296e+00,  1.2503e+00, -4.4155e-01,
         2.6688e+00,  3.5915e-02,  3.7219e-01,  1.7427e+00, -1.2869e+00,
        -4.0467e+00, -3.4922e+00, -1.9275e+00,  9.8453e-01,  4.8939e-01,
        -1.2041e+00,  8.7284e-01, -3.7598e-01,  3.9246e+00, -2.8631e-01,
         8.2217e-01,  2.0961e+00,  1.6561e-02, -1.3865e+00, -8.5945e-02,
         1.4704e+00,  5.0000e-02,  1.0000e-01,  4.0000e-01,  0.0000e+00,
         0.0000e+00,  5.7281e-01,  1.2706e+00, -9.5875e-01, -1.5668e+00,
        -1.1907e-01, -1.5027e+00, -2.0984e-01,  3.0559e+00, -1.1942e+00,
         1.9862e+00, -1.1800e+00,  1.8827e+00,  2.5103e+00,  1.2917e+00,
         1.8375e+00, -4.8695e-01, -3.1873e+00, -2.1730e+00,  2.1604e-01,
        -7.2640e-01,  1.1432e+00,  9.3294e-01,  2.3

  label = torch.tensor(self.labels[p, y, x])
