In [1]:
from contextlib import ExitStack
import itertools
from dms_stan.utils import from_csv_noload, get_chunk_shape
from typing import Literal, Collection
import numpy as np
import dask
import dask.array as da
import h5py
from tqdm import tqdm

from dms_stan.datasets.trpb import TrpBExponentialGrowth
from dms_stan.model.components import DiscreteDistribution, Parameter, TransformedParameter

import re
IND_RE = re.compile(r"\[([0-9,]+)\]")


In [None]:
for i in range(10):
    var_inds = (i, 10)
    np.array(
                sorted(range(len(var_inds)), key=lambda x: var_inds[x])
            )

In [2]:
model = TrpBExponentialGrowth.from_data_file("~/GitRepos/DMSStan/raw_data/trpb/3-site_merged_replicates/LibI/20230926/LibI_merged_AAs.csv")

In [3]:
fit = from_csv_noload(
    "/home/bwittmann//GitRepos/DMSStan/flip3/trpB/exponential/libI-20250429183759_*.csv"
)

In [4]:
def get_c_order_inds():
    """Identifies the order of columns that would be used in C-style indexing."""

    # Get the indices of each column
    column_indices = [
        (
            tuple(int(ind) - 1 for ind in match_obj.group(1).split(","))
            if (match_obj := IND_RE.search(col))
            else ()
        )
        for col in fit.column_names
    ]

    # Assign indices to the parameters
    var_to_c_order_inds = {}
    for varname, var in itertools.chain(
        fit.metadata.method_vars.items(), fit.metadata.stan_vars.items()
    ):

        # Column indices must be unique by variable
        sliced_inds = column_indices[var.start_idx : var.end_idx]
        assert (
            len(set(sliced_inds)) == var.end_idx - var.start_idx
        ), f"Column indices for {varname} are not unique"
        assert all(
            len(sliced_ind) == len(var.dimensions) for sliced_ind in sliced_inds
        ), f"Column indices for {varname} do not match the number of dimensions"

        # Argsort such that the last dimension increases fastest (c-major order)
        var_to_c_order_inds[varname] = np.array(
            sorted(range(len(sliced_inds)), key=sliced_inds.__getitem__)
        )

    return var_to_c_order_inds


def get_dtypes():
    """Assigns numpy datatypes to the parameters in the fit object."""
    np_type_map = {
        "double": {"float": np.float64, "int": np.int64, "bool": np.bool_},
        "single": {"float": np.float32, "int": np.int32, "bool": np.bool_},
        "half": {"float": np.float16, "int": np.int16, "bool": np.bool_},
    }

    # Datatypes for the method variables
    method_var_dtypes = {
        "lp__": np_type_map[precision]["float"],
        "accept_stat__": np_type_map[precision]["float"],
        "stepsize__": np_type_map[precision]["float"],
        "treedepth__": np_type_map[precision]["int"],
        "n_leapfrog__": np_type_map[precision]["int"],
        "divergent__": np_type_map[precision]["bool"],
        "energy__": np_type_map[precision]["float"],
    }

    # Datatypes for the stan variables
    stan_var_dtypes = {}
    for varname, component in model.named_model_components_dict.items():

        # Only record parameters and transformed parameters
        if not isinstance(component, (Parameter, TransformedParameter)):
            continue

        # Get the datatype
        dtype = np_type_map[precision][
            "int" if isinstance(component, DiscreteDistribution) else "float"
        ]

        # Update the varname if needed
        if isinstance(component, Parameter) and component.observable:
            varname = f"{varname}_ppc"

        # Record the datatype
        stan_var_dtypes[varname] = dtype

    # No overlap between method and stan variables
    assert not set(method_var_dtypes).intersection(set(stan_var_dtypes))

    return method_var_dtypes, stan_var_dtypes


def parse_csv(csv_file):

    with open(csv_file, "r", encoding="utf-8") as f:

        # Parse line-by-line
        for line in f:

            # Skip the header
            if line.startswith("#") or line.startswith("lp__"):
                continue

            # Process the line
            vals = line.strip().split(",")

            # Get data.
            processed_vals = {}
            for varname, dtype in itertools.chain(
                method_var_dtypes.items(), stan_var_dtypes.items()
            ):

                # Get the base python type
                if issubclass(dtype, np.floating):
                    base_type = float
                elif issubclass(dtype, np.integer):
                    base_type = int
                elif issubclass(dtype, np.bool_):
                    base_type = bool
                else:
                    raise ValueError(f"Unsupported dtype: {dtype}")

                # Slice out the appropriate columns and convert to an appropriately
                # shaped numpy array
                var = getattr(
                    fit.metadata,
                    "stan_vars" if varname in stan_var_dtypes else "method_vars",
                )[varname]
                processed_val = np.array(
                    list(map(base_type, vals[var.start_idx : var.end_idx])), order="C"
                )[c_order_inds[varname]].reshape(var.dimensions)

                # Must be c-contiguous
                assert processed_val.flags["C_CONTIGUOUS"]
                processed_vals[varname] = processed_val

            yield processed_vals

In [7]:
# TODO: Writing to hdf5 files should be the standard way of processing outputs of
# the cmdstan sampling run. This will make it very easy for us to toggle the dask
# backend off and on

precision: Literal["double", "single", "half"] = "single"


# Get a copy of the config object (the parameter embeds a copy operation)
config = fit.metadata.cmdstan_config

# How many samples?
num_draws = config["num_samples"] + (
    config["num_warmup"] if config["save_warmup"] else 0
)

# Get the indices by column and variable. Get the dtypes for each variable.
method_var_dtypes, stan_var_dtypes = get_dtypes()
c_order_inds = get_c_order_inds()

# Build the hdf5 file
with h5py.File("/home/bwittmann/GitRepos/DMSStan/tmp/fit.hdf5", "w") as f:

    # We need a group for metadata, samples, and posterior predictive checks
    metadata_group = f.create_group("metadata")
    sample_group = f.create_group("samples")
    ppc_group = f.create_group("ppc")

    # Create datasets for each variable of the metadata group
    varname_to_dset = {
        varname: metadata_group.create_dataset(
            name=varname,
            shape=(config["num_chains"], num_draws),
            dtype=method_var_dtypes[varname],
            chunks=(config["num_chains"], num_draws),  # Always 1 chunk
        )
        for varname in fit.metadata.method_vars.keys()
    }

    # Create datasets for each variable of the sample group, including posterior
    # predictive checks
    expected_n_elements = {}
    for varname, stan_dtype in stan_var_dtypes.items():

        # Determine the dictionary to use
        target = ppc_group if varname.endswith("_ppc") else sample_group

        # Build the dataset
        assert varname not in varname_to_dset, f"Duplicate variable name: {varname}"
        shape = (
            config["num_chains"],
            num_draws,
            *model[varname.removesuffix("_ppc")].shape,
        )
        varname_to_dset[varname] = target.create_dataset(
            name=varname,
            shape=shape,
            dtype=stan_dtype,
            chunks=get_chunk_shape(
                array_shape=shape, array_precision=precision, frozen_dims=(0, 1)
            ),
        )

        # For error-checking, how many elements are in this dataset?
        expected_n_elements[varname] = np.prod(shape[2:])

    # Parse each of the csv files
    assert config["num_chains"] == len(fit.runset.csv_files)
    for chain_ind, csv_file in tqdm(
        enumerate(sorted(fit.runset.csv_files)),
        total=config["num_chains"],
        desc="Converting csvs to hdf5",
    ):
        n_rows = 0
        for draw_ind, draw_data in tqdm(
            enumerate(parse_csv(csv_file)),
            total=num_draws,
            desc=f"Parsing csv {chain_ind}",
            position=1,
            leave=False,
        ):

            # Assign the data to the appropriate dataset
            for varname, vals in draw_data.items():

                # Write to the dataset for this variable
                varname_to_dset[varname][chain_ind, draw_ind] = vals

            # Increment the number of rows
            n_rows += 1

        # Check that the number of rows is correct
        assert n_rows == num_draws
        break

Converting csvs to hdf5:   0%|          | 0/4 [00:48<?, ?it/s]


In [8]:
n_rows

1000