In [None]:
from config import simd_r_drive_server_config
from us_gaap_store import UsGaapStore

from simd_r_drive_ws_client import DataStoreWsClient

# Open a client to the SIMD R Drive server (host/port taken from config) and
# wrap it in a UsGaapStore.
_cfg = simd_r_drive_server_config
data_store = DataStoreWsClient(_cfg.host, _cfg.port)
us_gaap_store = UsGaapStore(data_store)

In [None]:
# Show the data store's reported file size as this cell's output.
# NOTE(review): units are not visible here — presumably bytes; confirm against
# DataStoreWsClient.file_size.
data_store.file_size()

In [None]:
from models.pytorch.narrative_stack.stage1.dataset import IterableConceptValueDataset, collate_with_scaler
from models.pytorch.narrative_stack.stage1 import Stage1Autoencoder

In [None]:
import logging
from pathlib import Path
from typing import Iterator, List, Optional, Union

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm  # purely for a progress bar
from utils.pytorch import get_device

# ── project imports ────────────────────────────────────────────────────────────
from config import project_paths, simd_r_drive_server_config
from models.pytorch.narrative_stack.stage1.dataset import IterableConceptValueDataset, collate_with_scaler
from models.pytorch.narrative_stack.stage1 import Stage1Autoencoder

# Root logger: INFO level with timestamped lines. basicConfig() does nothing if
# the root logger already has handlers (e.g. when this cell is re-run).
LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

# ── constants ─────────────────────────────────────────────────────────────────
# Checkpoint to run inference with.
CKPT = Path(project_paths.python_data).joinpath(
    "stage1_23_(no_pre_dedupe)",
    "stage1_resume-v10.ckpt",
)
# Set to an int (e.g. 128) to force smaller GPU batches; when falsy the
# checkpoint's `batch_size` hyper-parameter is used instead.
BATCH_SIZE_OVERRIDE = None
NUM_WORKERS = 2  # DataLoader worker processes
DEVICE = get_device()  # device chosen by utils.pytorch.get_device


# ── helpers ───────────────────────────────────────────────────────────────────
def inverse_batch(
    scaled_vals: Union[torch.Tensor, np.ndarray],
    scalers: List,  # list[sklearn Scaler]
) -> np.ndarray:
    """
    Undo per-sample scaling for a batch of values.

    Row ``i`` of ``scaled_vals`` is flattened to a column vector and passed to
    ``scalers[i].inverse_transform``; the results are stacked into an array of
    shape ``(batch, k)``, where ``k`` is the number of values per row (1 for a
    1-D input). This is a Python loop with one ``inverse_transform`` call per
    sample, not a vectorized operation.

    Args:
        scaled_vals: Scaled values, either a tensor (on any device; it is
            detached and copied to the CPU) or a NumPy array. Row ``i`` is
            paired with ``scalers[i]``.
        scalers: One fitted scaler per row. Presumably sklearn scalers; anything
            exposing ``inverse_transform`` on a 2-D array works.

    Returns:
        ``np.ndarray`` of inverse-transformed values with shape ``(batch, k)``.

    Raises:
        ValueError: if ``len(scalers)`` differs from the batch size.
    """
    if isinstance(scaled_vals, torch.Tensor):
        scaled_np = scaled_vals.detach().cpu().numpy()
    else:
        scaled_np = np.asarray(scaled_vals)

    if len(scalers) != len(scaled_np):
        # zip() would silently truncate and pair rows with the wrong scalers.
        raise ValueError(
            f"got {len(scalers)} scalers for a batch of {len(scaled_np)} rows"
        )

    if len(scaled_np) == 0:
        # np.stack() rejects an empty list; return an empty array with the
        # same trailing width a non-empty batch would have.
        return np.empty((0, int(np.prod(scaled_np.shape[1:]))))

    return np.stack(
        [s.inverse_transform(v.reshape(-1, 1)).ravel() for s, v in zip(scalers, scaled_np)]
    )


def build_loader(batch_size: Optional[int] = None) -> DataLoader:
    """
    Build the inference ``DataLoader`` over ``IterableConceptValueDataset``.

    The dataset is created with ``return_scaler=True`` and ``shuffle=False`` so
    batches (assembled by ``collate_with_scaler``) carry the per-sample scalers
    needed to undo value scaling, in a deterministic order.

    Args:
        batch_size: Batch size to use. When ``None`` (default) it resolves to
            ``BATCH_SIZE_OVERRIDE`` if that is truthy, otherwise to the
            ``batch_size`` hyper-parameter stored in ``CKPT``. That fallback
            loads the whole checkpoint, so callers that already hold the model
            should pass the value in.

    Returns:
        A ``DataLoader`` whose batches are produced by ``collate_with_scaler``.
    """
    ds = IterableConceptValueDataset(
        simd_r_drive_server_config,
        internal_batch_size=64,
        return_scaler=True,
        shuffle=False,
    )
    if batch_size is None:
        batch_size = (
            BATCH_SIZE_OVERRIDE
            or Stage1Autoencoder.load_from_checkpoint(CKPT).hparams.batch_size
        )
    return DataLoader(
        ds,
        batch_size=batch_size,
        collate_fn=collate_with_scaler,
        pin_memory=True,
        num_workers=NUM_WORKERS,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only enable it when workers exist.
        persistent_workers=NUM_WORKERS > 0,
    )


def inference_stream() -> Iterator[List[dict]]:
    """
    Run the Stage-1 autoencoder over the dataset, yielding one batch at a time.

    The checkpoint at ``CKPT`` is loaded once, moved to ``DEVICE`` and put in
    eval mode. Its ``batch_size`` hyper-parameter sizes the loader unless
    ``BATCH_SIZE_OVERRIDE`` is set.

    Yields:
        A list with one dict per sample, with keys:
        ``concept`` / ``uom`` from the batch's ``concept_units``;
        ``pred_val_scaled`` / ``true_val_scaled``, the reconstructed and target
        values in scaled space (floats);
        ``pred_val_orig`` / ``true_val_orig``, the same values after the
        sample's scaler ``inverse_transform``;
        ``latent``, the sample's row of the latent ``z`` as a NumPy array.
    """
    model = Stage1Autoencoder.load_from_checkpoint(CKPT).to(DEVICE).eval()
    # Reuse the loaded model's hparams instead of letting build_loader()
    # load the checkpoint a second time.
    loader = build_loader(BATCH_SIZE_OVERRIDE or model.hparams.batch_size)

    with torch.no_grad():
        for x, y, scalers, concept_units in tqdm(loader, desc="infer"):
            # Only the model input needs the device; targets are read on the
            # CPU, so y is not shipped to DEVICE and back.
            _recon_emb, recon_val, z = model(x.to(DEVICE))

            # The target value is the last column of y, kept 2-D.
            y_last = y[:, -1:]

            # Inverse-scale predicted & true values.
            pred_orig = inverse_batch(recon_val, scalers)
            true_orig = inverse_batch(y_last, scalers)

            z_cpu = z.cpu().numpy()
            pred_scaled = recon_val.cpu().numpy()
            true_scaled = y_last.cpu().numpy()

            yield [
                {
                    "concept": concept,
                    "uom": uom,
                    "pred_val_scaled": float(pred_scaled[i, 0]),
                    "true_val_scaled": float(true_scaled[i, 0]),
                    "pred_val_orig": float(pred_orig[i, 0]),
                    "true_val_orig": float(true_orig[i, 0]),
                    "latent": z_cpu[i],
                }
                for i, (concept, uom) in enumerate(concept_units)
            ]


# ── main ──────────────────────────────────────────────────────────────────────
# Preview only the first batch: pull one item from the generator, then close it
# explicitly so the DataLoader it owns can be torn down promptly. Iterate
# `inference_stream()` to exhaustion for a complete run.
PREVIEW_RECORDS = 5
stream = inference_stream()
try:
    first_batch = next(stream, [])
finally:
    stream.close()
first_batch[:PREVIEW_RECORDS]
