In [None]:
from spf.scripts.train_single_point import (
    load_checkpoint,
    load_config_from_fn,
    load_model,
)


def load_model_and_config_from_config_fn_and_checkpoint(config_fn, checkpoint_fn):
    """Build the model described by a config file and restore it from a checkpoint.

    The config is read via ``load_config_from_fn``; its ``optim.checkpoint``
    entry is overwritten with ``checkpoint_fn`` before the checkpoint is loaded.

    Args:
        config_fn: Path passed to ``load_config_from_fn``.
        checkpoint_fn: Checkpoint path passed to ``load_checkpoint``.

    Returns:
        A ``(model, config)`` tuple: the first element returned by
        ``load_checkpoint`` and the loaded config (with the checkpoint entry
        overridden).
    """
    config = load_config_from_fn(config_fn)
    optim_section = config["optim"]
    optim_section["checkpoint"] = checkpoint_fn

    # Instantiate the model and move it to the device named in the config.
    model = load_model(config["model"], config["global"])
    model = model.to(optim_section["device"])

    # No optimizer or scheduler is supplied; only the first returned element
    # (the model) is kept.
    restored = load_checkpoint(
        checkpoint_fn=optim_section["checkpoint"],
        config=config,
        model=model,
        optimizer=None,
        scheduler=None,
        force_load=True,
    )
    model, _, _, _, _ = restored
    return model, config


def convert_datasets_config_to_inference(
    datasets_config,
    ds_fn,
    precompute_cache="/home/mouse9911/precompute_cache_chunk16_sept",
):
    """Return a copy of a datasets config overridden for single-file inference.

    The input mapping is shallow-copied and the inference-related keys below
    are overwritten; all other keys are carried over unchanged. The input
    mapping itself is not modified.

    Args:
        datasets_config: Mapping of dataset settings (must support ``copy()``
            and ``update()``), e.g. the ``"datasets"`` section of a config.
        ds_fn: Dataset path; becomes the sole entry of ``"train_paths"``.
        precompute_cache: Value stored under ``"precompute_cache"``. Defaults
            to the path that was previously hard-coded here, so existing
            callers are unaffected.

    Returns:
        The new mapping with the inference overrides applied.
    """
    inference_overrides = {
        "batch_size": 1,
        "flip": False,
        "double_flip": False,
        "precompute_cache": precompute_cache,
        "shuffle": False,
        "skip_qc": True,
        "snapshots_adjacent_stride": 1,
        "train_snapshots_per_session": 1,
        "val_snapshots_per_session": 1,
        "random_snapshot_size": False,
        "snapshots_stride": 1,
        "train_paths": [ds_fn],
        "train_on_val": True,
        "workers": 1,
    }
    # Shallow copy is sufficient: only top-level keys are replaced, so the
    # caller's mapping is left untouched.
    inference_config = datasets_config.copy()
    inference_config.update(inference_overrides)
    return inference_config

In [None]:
import torch
from spf.scripts.train_single_point import load_dataloaders

from tqdm import tqdm

config_fn = "/home/mouse9911/gits/spf/nov2_checkpoints/nov2_small_paired_checkpoints_inputdo0p3/config.yml"
checkpoint_fn = "/home/mouse9911/gits/spf/nov2_checkpoints/nov2_small_paired_checkpoints_inputdo0p3/best.pth"
ds_fn = "/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_08_21_10_30_58_nRX2_bounce_spacing0p05075.zarr"

# Build the model from the config and restore it from the checkpoint.
model, config = load_model_and_config_from_config_fn_and_checkpoint(
    config_fn=config_fn, checkpoint_fn=checkpoint_fn
)

# Derive inference dataset settings from the loaded config, pointed at ds_fn.
datasets_config = convert_datasets_config_to_inference(
    config["datasets"],
    ds_fn=ds_fn,
)

# Build the dataloaders.
optim_config = {"device": "cuda", "dtype": torch.float32}
# NOTE(review): global_config is not used below; the loaded config's "global"
# section is passed to load_dataloaders instead. Kept in case later cells
# rely on it -- confirm and remove if not.
global_config = {"nthetas": 65, "n_radios": 2, "seed": 0, "beamformer_input": True}
train_dataloader, val_dataloader = load_dataloaders(
    datasets_config, optim_config, config["global"], step=0, epoch=0
)

# Run inference. Gradients are not needed here, so autograd tracking is
# disabled to avoid building a graph for every batch.
model.eval()
with torch.no_grad():
    for val_batch_data in tqdm(val_dataloader, leave=False):
        val_batch_data = val_batch_data.to(config["optim"]["device"])
        output = model(val_batch_data)