In [None]:
import os
from pathlib import Path

# This snippet ensures consistent import paths across environments.
# When running notebooks via JupyterLab's web UI, the current working
# directory is often different (e.g., /notebooks) compared to VS Code,
# which typically starts at the project root. This handles that by 
# retrying the import after changing to the parent directory.
# 
# Include this at the top of every notebook to standardize imports
# across development environments.

try:
    # Works when the kernel's working directory already has `utils` importable.
    from utils.os import chdir_to_git_root
except ModuleNotFoundError:
    # Retry exactly once from the parent directory; a second failure propagates
    # so a misconfigured environment fails loudly instead of silently.
    os.chdir(Path.cwd().parent)
    print(f"Retrying import from: {os.getcwd()}")
    from utils.os import chdir_to_git_root

# NOTE(review): presumably moves the working directory to the `python` folder
# under the git root — confirm against utils.os. Relative paths used later in
# this notebook depend on the resulting directory.
chdir_to_git_root("python")

# Show the final working directory so path problems are visible in the output.
print(os.getcwd())

In [None]:
import logging
# from models.pytorch.narrative_stack.stage1.preprocessing import generate_concept_unit_embeddings, generate_concepts_report
from db import DbUsGaap

# Shared handles used by every cell below.
db_us_gaap = DbUsGaap()
data_dir = "../data/june-us-gaap" # Where CSV data is read from (one CSV file per symbol); relative to the working directory set above


In [None]:
# NOTE: For debugging / monitoring purposes only
# Determine "category stack" depths

from collections import defaultdict
import numpy as np
from utils.csv import walk_us_gaap_csvs, get_filtered_us_gaap_form_rows_for_symbol

class RunningStats:
    """Accumulates count, sum and maximum of a stream of numbers.

    Every sample is also kept in `values` so that `finalize()` can report the
    median; memory therefore grows linearly with the number of updates.
    """

    def __init__(self):
        self.count = 0
        self.total = 0
        self.max_val = 0  # Placeholder until the first sample arrives
        self.values = []

    def update(self, val: int):
        """Record one sample."""
        # The first sample seeds the maximum. Comparing against the initial 0
        # would report a spurious max of 0 for an all-negative stream.
        self.max_val = val if self.count == 0 else max(self.max_val, val)
        self.count += 1
        self.total += val
        self.values.append(val)  # Optional: remove this if median not needed

    def finalize(self):
        """Return a dict with "avg" and "max", plus "median" when samples exist.

        With no samples recorded, "avg" and "max" are both 0 and "median" is
        omitted.
        """
        result = {
            "avg": self.total / self.count if self.count else 0,
            "max": self.max_val,
        }
        if self.values:
            result["median"] = float(np.median(self.values))
        return result

# One RunningStats per (balance_type, period_type) key. Each sample is the
# number of entries with that key found in a single row.
stats = defaultdict(RunningStats)

gen = walk_us_gaap_csvs(data_dir, db_us_gaap, "row")

# A plain `for` loop drains the generator. The walker's return value is not
# needed in this cell, and unlike a blanket `except StopIteration: pass`, a
# StopIteration raised by mistake inside the loop body is no longer swallowed.
for row in gen:
    # Entries per key within this row only
    counter = defaultdict(int)
    for entry in row.entries:
        # Missing balance/period types are bucketed under the literal "none"
        key = (entry.balance_type or "none", entry.period_type or "none")
        counter[key] += 1
    for key, val in counter.items():
        stats[key].update(val)

# Final summary
summary = {key: stat.finalize() for key, stat in stats.items()}

from pprint import pprint
pprint(summary)


In [None]:
# NOTE: For debugging / monitoring purposes only

import numpy as np
from utils.csv import walk_us_gaap_csvs, UsGaapRowRecord
from collections import defaultdict


def generate_concepts_report_from_walker(
    data_dir: Path,
    db_us_gaap: DbUsGaap,
    filtered_symbols: set[str] | None = None,
):
    """Print per-unit value statistics gathered from the US GAAP CSV walker.

    Walks the CSVs in "row" mode, groups every entry value and concept name by
    its unit of measure, then prints count/min/max/mean/std and the concept
    list for each unit. The walker's return value (read from
    StopIteration.value) supplies the scanned file list and the non-numeric
    units. Returns None; all output goes to stdout.
    """
    walker = walk_us_gaap_csvs(
        data_dir=data_dir,
        db_us_gaap=db_us_gaap,
        walk_type="row",
        filtered_symbols=filtered_symbols,
    )

    unit_stats = defaultdict(list)
    concept_by_unit = defaultdict(set)

    # The generator is stepped by hand because its return value is only
    # available on the StopIteration that ends it.
    try:
        while True:
            record = next(walker)
            if not isinstance(record, UsGaapRowRecord):
                continue
            for entry in record.entries:
                unit_stats[entry.uom].append(entry.value)
                concept_by_unit[entry.uom].add(entry.concept)
    except StopIteration as finished:
        summary = finished.value

    print(f"\n✅ Scanned {len(summary.csv_files)} files.")
    print(
        f"📦 Found {len(unit_stats)} numeric units and "
        f"{len(summary.non_numeric_units)} non-numeric units."
    )

    for unit, samples in sorted(unit_stats.items()):
        arr = np.array(samples)
        print(f"🔹 {unit}")
        print(f"   Count: {len(arr)}")
        print(f"   Min:   {arr.min():,.4f}")
        print(f"   Max:   {arr.max():,.4f}")
        print(f"   Mean:  {arr.mean():,.4f}")
        print(f"   Std:   {arr.std():,.4f}")
        print(f"   Concepts: {', '.join(sorted(concept_by_unit[unit]))}")

    if summary.non_numeric_units:
        print("\n⚠️ Non-numeric units encountered:")
        for unit in sorted(summary.non_numeric_units):
            print(f"  - {unit}")

    total_values = sum(map(len, unit_stats.values()))
    print(f"\n🧮 Total values extracted: {total_values:,}")


# Report over every symbol (no symbol filter). Note that `data_dir` is a plain
# string here although the parameter is annotated as Path.
generate_concepts_report_from_walker(data_dir, db_us_gaap, None)


In [None]:
from simd_r_drive import DataStore, NamespaceHasher
from collections import defaultdict
from pydantic import BaseModel
import msgpack
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import QuantileTransformer
from utils.csv import walk_us_gaap_csvs

# Open data store (relative path: resolved against the working directory set
# by the first cell). Every later cell reads and writes through `store`.
store = DataStore("proto.bin")
# store = DataStore("/Volumes/Expansion/proto.bin")

# Define immutable concept/unit pair model
# Instances are used as dictionary keys by the ingestion loop below, which is
# why the model must be hashable.
# NOTE(review): a later cell defines a NamedTuple that is also called
# ConceptUnitPair; once that cell has run, the name refers to the NamedTuple
# until this cell is executed again.
class ConceptUnitPair(BaseModel):
    concept: str  # concept name as read from the row entry
    uom: str  # unit of measure as read from the row entry

    class Config:
        # Enables hashing
        frozen = True

# Namespaces for storing structured data
# Each hasher prefixes/hashes a raw key so the different record kinds written
# below do not collide inside the single store.
TRIPLET_REVERSE_INDEX_NAMESPACE = NamespaceHasher(b"triplet-reverse-index")  # packed (concept, uom, value) -> i_cell
UNSCALED_SEQUENTIAL_CELL_NAMESPACE = NamespaceHasher(b"unscaled-sequential-cell")  # i_cell -> raw value
SCALED_SEQUENTIAL_CELL_NAMESPACE = NamespaceHasher(b"scaled-sequential-cell")  # i_cell -> quantile-scaled value
CELL_META_NAMESPACE = NamespaceHasher(b"cell-meta")  # i_cell -> concept/unit pair id
CONCEPT_UNIT_PAIR_NAMESPACE = NamespaceHasher(b"concept-unit-pair")  # pair id -> packed (concept, uom)

# Initialize CSV stream generator
gen = walk_us_gaap_csvs(data_dir, db_us_gaap, "row")

# Track per (concept, uom) the list of i_cell indices that use it
concept_unit_pairs_i_cells: dict[ConceptUnitPair, list[int]] = defaultdict(list)
# Pair -> sequential id, assigned in order of first appearance
pair_to_id: dict[ConceptUnitPair, int] = {}
# (namespaced pair-id key, packed (concept, uom)) entries, written in one batch
# after the stream has been consumed
concept_unit_entries: list[tuple[bytes, bytes]] = []

# Global sequential index for each cell value
# Starts at -1 because the loop increments before use, so the first cell is 0.
i_cell = -1
next_pair_id = 0

# Stream and store data
# `gen` is stepped manually with next() because the walker's return value is
# only reachable through StopIteration.value once the stream is exhausted.
try:
    while True:
        row = next(gen)
        batch = []  # (key, value) entries for this row, written together below

        for cell in row.entries:
            i_cell += 1

            pair = ConceptUnitPair(concept=cell.concept, uom=cell.uom)
            # 4-byte little-endian unsigned index; to_bytes raises OverflowError
            # once i_cell exceeds 2**32 - 1.
            i_bytes = i_cell.to_bytes(4, "little", signed=False)

            # Assign ID to concept/unit pair if not already done
            if pair not in pair_to_id:
                pair_to_id[pair] = next_pair_id
                pair_id_bytes = next_pair_id.to_bytes(4, "little", signed=False)
                pair_key = CONCEPT_UNIT_PAIR_NAMESPACE.namespace(pair_id_bytes)
                pair_val = msgpack.packb((pair.concept, pair.uom))
                concept_unit_entries.append((pair_key, pair_val))
                next_pair_id += 1

            pair_id = pair_to_id[pair]
            pair_id_bytes = pair_id.to_bytes(4, "little", signed=False)

            # Track cell indices per (concept, uom)
            concept_unit_pairs_i_cells[pair].append(i_cell)

            # Store raw unscaled value
            value_bytes = msgpack.packb(cell.value)
            unscaled_key = UNSCALED_SEQUENTIAL_CELL_NAMESPACE.namespace(i_bytes)
            batch.append((unscaled_key, value_bytes))

            # Store reverse triplet → i_cell mapping
            # NOTE(review): identical (concept, uom, value) triplets produce the
            # same key, so presumably only the last i_cell survives in the
            # index — confirm against the store's duplicate-key semantics.
            triplet_bytes = msgpack.packb((cell.concept, cell.uom, cell.value))
            triplet_key = TRIPLET_REVERSE_INDEX_NAMESPACE.namespace(triplet_bytes)
            batch.append((triplet_key, i_bytes))

            # Store cell meta (i_cell → concept_unit_id)
            cell_meta_key = CELL_META_NAMESPACE.namespace(i_bytes)
            batch.append((cell_meta_key, pair_id_bytes))

        # Write current batch of entries to store
        store.batch_write(batch)

        # TODO: Comment-out
        # Optional cutoff for debugging
        # NOTE: if this cutoff is enabled the loop exits via `break`, so the
        # `except` branch below never runs and `summary` is not assigned here.
        # if i_cell > 1000:
        #     break

except StopIteration as stop:
    # The walker's return value travels on the StopIteration that ends it.
    summary = stop.value
    display(summary)

# i_cell is the zero-based index of the last cell written (-1 if none).
total_triplets = i_cell + 1

# Stored as a 4-byte little-endian unsigned integer, the same layout that the
# reader helpers later in this notebook decode.
store.write(
    b"__triplet_count__",
    total_triplets.to_bytes(4, byteorder="little", signed=False),
)

print(f"Total triplets: {total_triplets}")

# Persist concept_unit_id → (concept, uom) mapping
store.batch_write(concept_unit_entries)

# One dictionary key per distinct pair, so this equals the number of pair ids
# assigned during ingestion.
total_pairs = len(concept_unit_pairs_i_cells)

store.write(b"__pair_count__", total_pairs.to_bytes(4, byteorder="little", signed=False))

# Show number of unique concept/unit pairs
print(f"Total concept/unit pairs: {total_pairs}")

# Show binary keys for each concept/unit pair
# NOTE: this prints one line per pair (debug output only; nothing is stored).
for pair in tqdm(concept_unit_pairs_i_cells, desc="Tracking concept/unit pairs"):
    print(msgpack.packb((pair.concept, pair.uom)))


In [None]:

# Scale all values per concept/unit group
# Every (concept, uom) group gets its own QuantileTransformer, fitted on that
# group's unscaled values and mapped to a normal output distribution.
for pair, i_cells in tqdm(concept_unit_pairs_i_cells.items(), desc="Scaling per concept/unit"):
    if len(i_cells) < 2:
        # Skip scaling for singleton values — nothing to transform.
        # Checked before any store reads so skipped groups cost nothing; these
        # cells end up with no entry in the scaled namespace.
        continue

    # Encode each index once and reuse the bytes for both the unscaled reads
    # and the scaled writes.
    i_bytes_list = [i.to_bytes(4, "little", signed=False) for i in i_cells]
    keys = [UNSCALED_SEQUENTIAL_CELL_NAMESPACE.namespace(i_bytes) for i_bytes in i_bytes_list]

    values = [
        msgpack.unpackb(store.read(key), raw=True)
        for key in keys
    ]

    # QuantileTransformer expects a 2-D (n_samples, n_features) array
    vals_np = np.array(values).reshape(-1, 1)

    # Clamp quantiles based on sample size. len(values) >= 2 is guaranteed by
    # the guard above, so no separate lower-bound adjustment is needed.
    n_q = min(len(values), 1000)

    scaler = QuantileTransformer(
        output_distribution="normal",
        n_quantiles=n_q,
        subsample=len(values),
        random_state=42,
    )

    scaled_vals = scaler.fit_transform(vals_np).flatten()

    assert len(scaled_vals) == len(i_cells)

    store.batch_write([
        (
            SCALED_SEQUENTIAL_CELL_NAMESPACE.namespace(i_bytes),
            msgpack.packb(val)
        )
        for i_bytes, val in zip(i_bytes_list, scaled_vals)
    ])


In [None]:
from typing import Iterator, NamedTuple, Tuple

# TODO: Migrate to common type
# NOTE(review): this reuses the name of the pydantic ConceptUnitPair model
# defined in the ingestion cell. After this cell runs, the name refers to this
# NamedTuple, while objects created earlier remain instances of the pydantic
# class. A single shared definition would remove the ambiguity.
class ConceptUnitPair(NamedTuple):
    concept: str  # concept name decoded from the store
    uom: str  # unit of measure decoded from the store


def iterate_concept_unit_pairs(
    store: DataStore
) -> Iterator[Tuple[int, ConceptUnitPair]]:
    """
    Walk the concept/unit pair namespace in ascending pair-id order.

    Reads the pair count stored under ``__pair_count__`` and then yields
    ``(pair_id, ConceptUnitPair)`` for every id in ``range(count)``, with the
    concept and unit decoded from UTF-8.

    Because this is a generator, nothing is read (and no error is raised)
    until iteration starts.

    Raises:
        ValueError: the ``__pair_count__`` key is absent.
        KeyError: no entry is stored for one of the expected pair ids.
    """
    count_bytes = store.read(b"__pair_count__")
    if count_bytes is None:
        raise ValueError("Missing __pair_count__ key in store")

    for pair_id in range(int.from_bytes(count_bytes, "little", signed=False)):
        id_bytes = pair_id.to_bytes(4, "little", signed=False)
        packed = store.read(CONCEPT_UNIT_PAIR_NAMESPACE.namespace(id_bytes))
        if packed is None:
            raise KeyError(f"Missing concept/unit for pair_id={pair_id}")

        # raw=True returns bytes, so both fields are decoded explicitly.
        concept_raw, uom_raw = msgpack.unpackb(packed, raw=True)
        decoded_pair = ConceptUnitPair(
            concept=concept_raw.decode("utf-8"),
            uom=uom_raw.decode("utf-8"),
        )
        yield pair_id, decoded_pair


In [None]:
from typing import Iterator, Tuple
import torch
import numpy as np
import logging
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from utils.pytorch import seed_everything, model_hash
from utils import generate_us_gaap_description
from simd_r_drive import DataStore
# from your_model_defs import ConceptUnitPair  # Replace with actual import
# from your_store_iterator import iterate_concept_unit_pairs  # Replace with actual import


def generate_concept_unit_embeddings(
    store: DataStore,
    device: torch.device,
    batch_size: int = 64,
) -> Iterator[Tuple[int, ConceptUnitPair, np.ndarray]]:
    """
    Stream sentence embeddings for every concept/unit pair held in `store`.

    Each pair is described as "<description of concept> measured in <uom>",
    and the descriptions are embedded in groups of `batch_size` with the
    "BAAI/bge-large-en-v1.5" SentenceTransformer placed on `device`. A final,
    smaller group is embedded once the pairs are exhausted.

    Yields:
        (pair_id, concept_unit_pair, embedding) in the order pairs are read.
    """
    pairs_iter = iterate_concept_unit_pairs(store)

    model = SentenceTransformer("BAAI/bge-large-en-v1.5")
    model.eval()
    model.to(device)

    logging.info(f"Embedding model hash: {model_hash(model)}")

    def _columns(rows):
        # Split [(id, pair, text), ...] into three parallel lists.
        ids, row_pairs, texts = zip(*rows)
        return list(ids), list(row_pairs), list(texts)

    pending = []  # (pair_id, pair, text) rows waiting to be embedded

    for pair_id, pair in pairs_iter:
        text = f"{generate_us_gaap_description(pair.concept)} measured in {pair.uom}"
        pending.append((pair_id, pair, text))

        if len(pending) != batch_size:
            continue

        # A full batch is ready: embed it, then start collecting the next one.
        yield from _embed_batch(*_columns(pending), model, device)
        pending.clear()

    # Flush whatever is left over after the last full batch.
    if pending:
        yield from _embed_batch(*_columns(pending), model, device)


def _embed_batch(pair_ids, pairs, texts, model, device):
    """Embed one batch of texts and yield (pair_id, pair, embedding) triples.

    The texts are tokenized by `model`, every tokenizer output is moved to
    `device`, and the forward pass runs with gradients disabled. The
    "sentence_embedding" output is copied to the CPU as a NumPy array, whose
    rows are paired positionally with `pair_ids` and `pairs`.
    """
    features = {}
    for name, tensor in model.tokenize(texts).items():
        features[name] = tensor.to(device)

    with torch.no_grad():
        sentence_vectors = model.forward(features)["sentence_embedding"].cpu().numpy()

    yield from zip(pair_ids, pairs, sentence_vectors)


In [None]:
from utils.pytorch import get_device
from models.pytorch.narrative_stack.stage1.preprocessing import pca_compress_concept_unit_embeddings
from tqdm import tqdm

# Cache embeddings in RAM and apply PCA

# TODO: Move these declarations
PCA_MODEL_NAMESPACE = NamespaceHasher(b"pca-model")  # holds the serialized PCA model
PCA_REDUCED_EMBEDDING_NAMESPACE = NamespaceHasher(b"pca-reduced-embedding")  # pair id -> reduced vector


# Parallel lists: pairs[k] describes the pair whose embedding is embeddings[k].
pairs = []
embeddings = []

for pair_id, pair, embedding in tqdm(generate_concept_unit_embeddings(store, get_device()), desc="Generating Semantic Embeddings"):
    pairs.append((pair_id, pair))
    embeddings.append(embedding)


# Convert to NumPy array of shape (N, D)
# np.stack raises ValueError when `embeddings` is empty, i.e. when the store
# holds no concept/unit pairs.
embedding_matrix = np.stack(embeddings, axis=0)

# Now ready to pass `embedding_matrix` to PCA
print(f"Embedding matrix shape: {embedding_matrix.shape}")

# TODO: Reuse PCA if already existing (provide ability to pull from another store, etc.)

# NOTE(review): n_components is 234 here but 243 in the older PCA cell further
# down this notebook — confirm which value is intended.
pca_compressed_concept_unit_embeddings, pca = pca_compress_concept_unit_embeddings(embedding_matrix, n_components=234, pca=None, stable=True)

# The reduced matrix must stay row-aligned with `pairs` for the zip below.
assert len(pairs) == len(pca_compressed_concept_unit_embeddings)


# Persist the PCA-reduced embeddings in the store, keyed by pair id; each
# vector is stored as a msgpack list of float32-rounded values.
pca_embedding_entries = [
    (
        PCA_REDUCED_EMBEDDING_NAMESPACE.namespace(
            pair_id.to_bytes(4, "little", signed=False)
        ),
        msgpack.packb(vec.astype(np.float32).tolist())  # Convert numpy array to list
    )
    for (pair_id, _), vec in zip(pairs, pca_compressed_concept_unit_embeddings)
]

store.batch_write(pca_embedding_entries)
print(f"Wrote {len(pca_embedding_entries)} PCA-compressed embeddings to store.")



In [None]:
import joblib
from io import BytesIO

# Serialize the fitted PCA model into an in-memory buffer.
# NOTE: joblib payloads are pickle-based, so only load this entry back from a
# store you trust.
pca_model_stream = BytesIO()
joblib.dump(pca, pca_model_stream)

# getvalue() returns the complete buffer regardless of the cursor position, so
# no rewind is needed before writing it to the DataStore.
store.write(PCA_MODEL_NAMESPACE.namespace(b"model"), pca_model_stream.getvalue())

print("Stored PCA model in store.")

In [None]:
from models.pytorch.narrative_stack.stage1.preprocessing.plots import plot_pca_explanation

# Plot the PCA explanation for the raw embedding matrix with a 0.95 variance
# threshold, then display the value the helper returns.
# NOTE(review): `dim` is presumably the component count reaching the
# threshold — confirm in the plots module.
dim = plot_pca_explanation(embedding_matrix, variance_threshold=0.95)

display(dim)

In [None]:
def get_triplet_count(store: DataStore) -> int:
    """Return the number stored under ``__triplet_count__``.

    The value is decoded as a little-endian unsigned integer.

    Raises:
        KeyError: the key is absent from the store.
    """
    encoded = store.read(b"__triplet_count__")
    if encoded is not None:
        return int.from_bytes(encoded, "little", signed=False)
    raise KeyError("Triplet count key not found")

def get_pair_count(store: DataStore) -> int:
    """Return the number stored under ``__pair_count__``.

    The value is decoded as a little-endian unsigned integer.

    Raises:
        KeyError: the key is absent from the store.
    """
    encoded = store.read(b"__pair_count__")
    if encoded is not None:
        return int.from_bytes(encoded, "little", signed=False)
    raise KeyError("Pair count key not found")

In [None]:
# Sanity check: both counters written by the ingestion cell can be read back.
display(get_triplet_count(store))

display(get_pair_count(store))

In [None]:
def lookup_by_triplet(
    store: DataStore,
    concept: str,
    uom: str,
    unscaled_value: float
) -> dict:
    """
    Resolve a (concept, uom, value) triplet through the reverse index.

    The triplet is msgpack-encoded exactly as given and looked up in the
    triplet reverse index to obtain its i_cell; the scaled value for that
    cell is then fetched if one was stored.

    NOTE(review): the key is the packed bytes of the triplet, so an int and a
    float of equal magnitude encode differently; pass the value with the same
    type that was ingested.

    Returns a dict with keys: i_cell, unscaled_value (the argument, echoed
    back unchanged) and scaled_value (None when no scaled entry exists).

    Raises:
        KeyError: the triplet is not present in the reverse index.
    """
    reverse_key = TRIPLET_REVERSE_INDEX_NAMESPACE.namespace(
        msgpack.packb((concept, uom, unscaled_value))
    )

    index_bytes = store.read(reverse_key)
    if index_bytes is None:
        raise KeyError(f"Triplet ({concept}, {uom}, {unscaled_value}) not found in reverse index")

    i_cell = int.from_bytes(index_bytes, "little", signed=False)

    # The scaled entry is optional: it may be missing for this cell.
    scaled_payload = store.read(
        SCALED_SEQUENTIAL_CELL_NAMESPACE.namespace(
            i_cell.to_bytes(4, "little", signed=False)
        )
    )
    if scaled_payload is None:
        scaled_value = None
    else:
        scaled_value = msgpack.unpackb(scaled_payload, raw=True)

    return {
        "i_cell": i_cell,
        "unscaled_value": unscaled_value,
        "scaled_value": scaled_value
    }

In [None]:
# Example lookup. Raises KeyError unless this exact triplet was ingested.
lookup_by_triplet(store, "AccountsReceivableNetCurrent", "USD", 1324000000.0)

In [None]:
def lookup_by_index(store: DataStore, i_cell: int) -> dict:
    """
    Resolve a sequential cell index into its stored concept, unit and values.

    Reads, in order: the cell meta entry (the concept/unit pair id), the
    (concept, uom) pair for that id, the unscaled value, and finally the
    scaled value, which is optional.

    Returns:
        dict with keys: i_cell, concept, uom, unscaled_value, scaled_value
        (scaled_value is None when no scaled entry exists).

    Raises:
        KeyError if the meta entry, the pair entry or the unscaled value is
        missing.
        OverflowError if i_cell is negative or does not fit in 4 bytes.
    """
    cell_suffix = i_cell.to_bytes(4, "little", signed=False)

    def _require(key: bytes, error_message: str) -> bytes:
        # Read a mandatory entry, failing with the caller's message if absent.
        payload = store.read(key)
        if payload is None:
            raise KeyError(error_message)
        return payload

    meta_payload = _require(
        CELL_META_NAMESPACE.namespace(cell_suffix),
        f"Missing concept_unit_id for i_cell {i_cell}",
    )
    concept_unit_id = int.from_bytes(meta_payload, "little", signed=False)

    pair_payload = _require(
        CONCEPT_UNIT_PAIR_NAMESPACE.namespace(
            concept_unit_id.to_bytes(4, "little", signed=False)
        ),
        f"Missing (concept, uom) for concept_unit_id {concept_unit_id}",
    )
    # raw=False: the two names come back as str rather than bytes.
    concept, uom = msgpack.unpackb(pair_payload, raw=False)

    unscaled_payload = _require(
        UNSCALED_SEQUENTIAL_CELL_NAMESPACE.namespace(cell_suffix),
        f"Missing unscaled value for i_cell {i_cell}",
    )
    unscaled_value = msgpack.unpackb(unscaled_payload, raw=True)

    # The scaled entry is optional, so a missing key is not an error.
    scaled_payload = store.read(SCALED_SEQUENTIAL_CELL_NAMESPACE.namespace(cell_suffix))
    if scaled_payload is None:
        scaled_value = None
    else:
        scaled_value = msgpack.unpackb(scaled_payload, raw=True)

    return {
        "i_cell": i_cell,
        "concept": concept,
        "uom": uom,
        "unscaled_value": unscaled_value,
        "scaled_value": scaled_value
    }


In [None]:
# for i in range(0, 10):
#     display(lookup_by_index(store, i * 20))

# TODO: Continue prototyping from here...

In [None]:
# # For deterministic hashing (TODO: Move to tests)

# import hashlib
# import pickle

# # def hash_extracted_data(data: ExtractedConceptUnitValueData) -> str:
# def hash_extracted_data(data) -> str:
#     """
#     Computes a SHA-256 hash of the full extracted concept/unit/value data structure.
#     This includes tuples, unit stats, and file list — all serialized deterministically.
#     """
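#     # Serialize using pickle protocol 5. NOTE: pickle does not guarantee
#     # canonical output (set iteration order, dict insertion order and object
#     # sharing can all change the bytes), so equal data may hash differently.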
#     serialized = pickle.dumps(data.dict(), protocol=5)
#     return hashlib.sha256(serialized).hexdigest()

# hash_extracted_data(extracted_concept_unit_value_data)

In [None]:
# print("Fetching...")
# extracted_concept_unit_value_data_2 = extract_concept_unit_value_tuples(data_dir, valid_concepts)

# print("Hashing...")
# hash_extracted_data(extracted_concept_unit_value_data_2)

In [None]:
# View concepts report (not needed for preprocessing but contains useful information)
# NOTE(review): legacy cell. `generate_concepts_report` is only referenced by
# the commented-out import near the top of the notebook, and
# `extracted_concept_unit_value_data` is not assigned in any visible cell, so
# this raises NameError on a fresh kernel.

generate_concepts_report(extracted_concept_unit_value_data)

In [None]:
from utils.pytorch import get_device, seed_everything

# NOTE(review): legacy cell from the pre-DataStore flow.
# `collect_concept_unit_pairs` and `extracted_concept_unit_value_data` are not
# defined in any visible cell.
device = get_device()

logging.info("Collecting concept unit pairs...")
concept_unit_pairs = collect_concept_unit_pairs(extracted_concept_unit_value_data)

logging.info(f"Total concept unit pairs: {len(concept_unit_pairs)}")

logging.info("Generating concept unit embeddings...")
# NOTE(review): the function of this name defined earlier in the notebook takes
# a DataStore as its first argument and is a generator, so this call would bind
# `concept_unit_pairs` to `store` and return a lazy generator, not an array.
concept_unit_embeddings = generate_concept_unit_embeddings(concept_unit_pairs, device=device)
# concept_unit_embeddings_2 = generate_concept_unit_embeddings(concept_unit_pairs, device=device)



In [None]:
# import numpy as np
# import torch
# from tqdm import tqdm

# # Normalize for cosine similarity
# A = concept_unit_embeddings_1
# B = concept_unit_embeddings_2

# A_norm = A / np.linalg.norm(A, axis=1, keepdims=True)
# B_norm = B / np.linalg.norm(B, axis=1, keepdims=True)

# # Cosine similarity per row
# cos_sim = np.sum(A_norm * B_norm, axis=1)

# # Report
# print(f"Cosine similarity:")
# print(f"  Mean: {cos_sim.mean():.8f}")
# print(f"  Min:  {cos_sim.min():.8f}")
# print(f"  Std:  {cos_sim.std():.8f}")

# # Optional: show rows below threshold
# threshold = 0.999
# bad_indices = np.where(cos_sim < threshold)[0]
# print(f"\n🔻 Below {threshold}: {len(bad_indices)} / {len(cos_sim)} rows")
# if len(bad_indices):
#     for idx in bad_indices[:10]:
#         print(f"  Row {idx}: cosine = {cos_sim[idx]:.8f}")


In [None]:
# concept_unit_embeddings

In [None]:
from models.pytorch.narrative_stack.stage1.preprocessing.plots import plot_pca_explanation

# NOTE(review): legacy duplicate of the earlier PCA-explanation cell, fed from
# `concept_unit_embeddings` (produced by the legacy cell above) instead of
# `embedding_matrix`. It also reassigns `dim`.
dim = plot_pca_explanation(concept_unit_embeddings, variance_threshold=0.95)

display(dim)

In [None]:
from models.pytorch.narrative_stack.stage1.preprocessing import pca_compress_concept_unit_embeddings


# NOTE(review): this overwrites the `pca_compressed_concept_unit_embeddings`
# and `pca` globals produced by the DataStore-based cell, and uses
# n_components=243 where that cell uses 234 — confirm which is intended.
pca_compressed_concept_unit_embeddings, pca = pca_compress_concept_unit_embeddings(concept_unit_embeddings, n_components=243)


In [None]:
# # TODO: Prototype this

# import joblib
# import io

# # `pca` is the fitted PCA object
# buffer = io.BytesIO()
# joblib.dump(pca, buffer)
# pca_bytes = buffer.getvalue()

In [None]:
from models.pytorch.narrative_stack.stage1.preprocessing.plots import plot_semantic_embeddings

# Scatterplots of the PCA-compressed and the raw embeddings, in that order.
# NOTE(review): the second call reads `concept_unit_embeddings`, which only the
# legacy embedding cell assigns.
plot_semantic_embeddings(pca_compressed_concept_unit_embeddings, title="PCA Semantic Embedding Scatterplot")
plot_semantic_embeddings(concept_unit_embeddings, title="Raw Semantic Embedding Scatterplot")

In [None]:
# Bare expression: the notebook renders the compressed embeddings as cell output.
pca_compressed_concept_unit_embeddings

In [None]:
import logging
import numpy as np

# TODO: Add types
def save_concept_unit_value_tuples(pca_compressed_concept_unit_embeddings, concept_unit_pairs, concept_unit_value_tuples, file_path):
    """Write embeddings, their "concept::unit" keys and the value tuples to one .npz file.

    The compressed archive holds three arrays:
      - keys: one "concept::unit" string per entry of `concept_unit_pairs`
      - embeddings: `pca_compressed_concept_unit_embeddings` as given
      - concept_unit_value_tuples: the tuples as an object-dtype array, so
        reading them back requires ``allow_pickle=True``

    An AssertionError is raised when the number of embeddings differs from the
    number of pairs. An info-level log line reports what was saved.
    """
    embedding_count = len(pca_compressed_concept_unit_embeddings)
    pair_count = len(concept_unit_pairs)
    assert embedding_count == pair_count, \
        f"Mismatch: {len(pca_compressed_concept_unit_embeddings)} embeddings vs {len(concept_unit_pairs)} keys"

    # Build the three arrays up front, then write them in a single call.
    key_array = np.array([f"{c}::{u}" for c, u in concept_unit_pairs])
    tuple_array = np.array(concept_unit_value_tuples, dtype=object)

    np.savez_compressed(
        file_path,
        keys=key_array,
        embeddings=pca_compressed_concept_unit_embeddings,
        concept_unit_value_tuples=tuple_array,
    )

    logging.info(f"Saved {len(concept_unit_value_tuples):,} tuples and {len(pca_compressed_concept_unit_embeddings):,} embeddings to '{file_path}'")


# NOTE(review): legacy call. `concept_unit_pairs` and
# `extracted_concept_unit_value_data` are only produced by the legacy cells
# above, and the output path is relative to the current working directory.
save_concept_unit_value_tuples(
    pca_compressed_concept_unit_embeddings,
    concept_unit_pairs,
    extracted_concept_unit_value_data.concept_unit_value_tuples,
    "data/stage1_latents.npz" # TODO: Rename! These are not latent vectors!
)

# save_concept_unit_value_tuples(
#     pca_compressed_concept_unit_embeddings_2,
#     concept_unit_pairs,
#     extracted_concept_unit_value_data.concept_unit_value_tuples,
#     "data/stage1_latents_new_2.npz"
# )

In [None]:
# TODO: Validate subsequent

# import numpy as np

# # Load from disk
# new_data = np.load("data/stage1_latents_new.npz", allow_pickle=True)
# new_concept_unit_value_tuples = new_data["concept_unit_value_tuples"].tolist()
# new_embeddings = new_data["embeddings"]

# # Check shape match
# assert len(pca_compressed_concept_unit_embeddings) == len(new_embeddings), \
#     "Mismatch in embedding row counts"

# # Cosine similarity check
# def cosine_similarity(a, b):
#     a = a / np.linalg.norm(a)
#     b = b / np.linalg.norm(b)
#     return np.dot(a, b)

# cos_sims = []
# for a_vec, b_vec in zip(pca_compressed_concept_unit_embeddings, new_embeddings):
#     sim = cosine_similarity(a_vec.astype(np.float64), b_vec.astype(np.float64))
#     cos_sims.append(sim)

# # Report
# cos_sims = np.array(cos_sims)
# print(f"✅ Compared {len(cos_sims)} rows")
# print(f"🔹 Mean cosine similarity: {cos_sims.mean():.8f}")
# print(f"🔹 Min cosine similarity:  {cos_sims.min():.8f}")
# print(f"🔹 Std dev:                {cos_sims.std():.8f}")


In [None]:
# import numpy as np

# # Load saved latent data
# old_data = np.load("data/stage1_latents.npz", allow_pickle=True)

# # Build embedding map
# embedding_map = {
#     tuple(key.split("::", 1)): vec
#     for key, vec in zip(old_data["keys"], old_data["embeddings"])
# }

# # Load concept-unit-value tuples
# old_concept_unit_value_tuples = old_data["concept_unit_value_tuples"].tolist()

# # Load saved latent data
# new_data = np.load("data/stage1_latents_new.npz", allow_pickle=True)

# # Build embedding map
# embedding_map = {
#     tuple(key.split("::", 1)): vec
#     for key, vec in zip(new_data["keys"], new_data["embeddings"])
# }

# # Load concept-unit-value tuples
# new_concept_unit_value_tuples = new_data["concept_unit_value_tuples"].tolist()

In [None]:
# import numpy as np
# from hashlib import sha256

# # a = np.load("data/stage1_latents.npz", allow_pickle=True)
# a = np.load("data/stage1_latents_new_1.npz", allow_pickle=True)
# b = np.load("data/stage1_latents_new_2.npz", allow_pickle=True)

# # with open("data/stage1_latents_new_1.pkl", "rb") as f:
# #     a = pickle.load(f)

# # with open("data/stage1_latents_new_2.pkl", "rb") as f:
# #     b = pickle.load(f)

# def hash_array(arr):
#     return sha256(np.ascontiguousarray(arr)).hexdigest()

# print("Hash a[keys]:", hash_array(a["keys"]))
# print("Hash b[keys]:", hash_array(b["keys"]))

# for k in a.files:
# # for k in a.keys():
#     print(f"Checking: {k}")
#     if k == "embeddings":
#         A = a[k].astype(np.float32)
#         B = b[k].astype(np.float32)
#         assert A.shape == B.shape, "Shape mismatch in embeddings"

#         # Normalize to unit vectors
#         A_norm = A / np.linalg.norm(A, axis=1, keepdims=True)
#         B_norm = B / np.linalg.norm(B, axis=1, keepdims=True)

#         # Cosine similarity
#         cos_sim = np.sum(A_norm * B_norm, axis=1)
#         mean_sim = np.mean(cos_sim)
#         min_sim = np.min(cos_sim)

#         print(f"Mean cosine similarity: {mean_sim:.8f}")
#         print(f"Min cosine similarity:  {min_sim:.8f}")
#         print(f"Std cosine similarity:  {cos_sim.std():.8f}")
#         assert min_sim > 0.999, f"Cosine similarity too low in embeddings: {min_sim}"
#     else:
#         assert np.array_equal(a[k], b[k]), f"Mismatch in {k}"
