In [2]:
import typing
import re

import pandas as pd
import numpy as np
import numpy.typing as npt

import dms_stan.model.components.parameters
# import dms_stan.operations as dms_ops

# from dms_stan.flip_dsets import load_nuclease_data, MultiNucDatasetType
# from dms_stan.model import Model


In [7]:
dms_components.Normal.CDF.mro()

[abc.NormalCDF,
 dms_stan.model.components.cdf.CDF,
 dms_stan.model.components.cdf.CDFLike,
 dms_stan.model.components.transformed_parameters.TransformedParameter,
 dms_stan.model.components.transformed_parameters.Transformation,
 dms_stan.model.components.abstract_model_component.AbstractModelComponent,
 abc.ABC,
 dms_stan.model.components.transformed_parameters.TransformableParameter,
 object]

In [32]:
mu_type = typing.get_type_hints(dms_components.Normal.__init__)["mu"]
typing.get_origin(mu_type)
set(typing.get_args(mu_type)).issubset(typing.get_args(dms.custom_types.CombinableParameterType))

True

In [36]:
mu_type is dms.custom_types.ContinuousParameterType

True

In [35]:
dms.custom_types.CombinableParameterType

typing.Union[dms_stan.model.components.transformed_parameters.TransformedParameter, dms_stan.model.components.constants.Constant, dms_stan.model.components.parameters.ContinuousDistribution, float, numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.floating]], dms_stan.model.components.parameters.DiscreteDistribution, int, numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.integer]]]

In [None]:
# def check_param_type(annotation: str | type, target: type) -> bool:
#     """Check if the annotation is a subclass of the target type."""
#     # If the annotation is a string, it might be a forward reference
#     if isinstance(annotation, str):

import dms_stan as dms
from
from typing import get_origin, get_args

# Get the type annotation for 'mu'
mu_type = typing.get_type_hints(dms_components.Normal.__init__)["mu"]

print(f"mu type annotation: {mu_type}")
print(f"CombinableParameterType: {dms.custom_types.CombinableParameterType}")

# Check if the annotation matches the expected type exactly
is_mu_exact_match = mu_type is dms.custom_types.CombinableParameterType
print(f"Exact match: {is_mu_exact_match}")

# Check if mu_type is ContinuousParameterType (which should be a subset of CombinableParameterType)
is_mu_continuous = mu_type is dms.custom_types.ContinuousParameterType
print(f"mu is ContinuousParameterType: {is_mu_continuous}")

# For Union types, we need to check if one is a subset of the other
# CombinableParameterType = Union[ContinuousParameterType, DiscreteParameterType]
# So ContinuousParameterType should be a subset of CombinableParameterType

# Get the args of each Union type
mu_args = set(get_args(mu_type)) if get_origin(mu_type) else {mu_type}
combinable_args = set(get_args(dms.custom_types.CombinableParameterType))

print(f"mu_type args: {mu_args}")
print(f"CombinableParameterType args: {combinable_args}")

# Check if mu_type is a subset of CombinableParameterType
is_subset = mu_args.issubset(combinable_args)
print(f"mu_type is subset of CombinableParameterType: {is_subset}")

SyntaxError: invalid syntax (861022897.py, line 7)

In [14]:
import sys

In [None]:
sys.

In [2]:
nuclease_data = load_nuclease_data(
    "/home/bwittmann/GitRepos/DMSStan/raw_data/nuclease/processed_data",
    "/home/bwittmann/GitRepos/DMSStan/raw_data/nuclease/processed_fiducial_data",
)

In [None]:
class NucleaseModel(Model):
    """Defines the DMS Stan model for the Nuclease dataset."""

    def __init__(
        self,
        data: MultiNucDatasetType,
        fluorescence_beta: float = 1.0,
        codon_noise_sigma: float = 0.1,
        experimental_noise_sigma: float = 0.1,
        g1_alpha: float = 1.0,
        g2_alpha: float = 1.0,
        g3_alpha: float = 1.0,
        g4_alpha: float = 1.0,
    ):

        # Store the raw data
        self.data = data

        # We have one true fluorescence value per variant. We define it on the
        # log scale
        self.mean_log_fluorescence = dms_components.ExpExponential(
            beta=fluorescence_beta, shape=(len(data["variants"]),)
        )

        # Define the generative model for the generation-specific fluorescence values
        self.generation_info = self._init_gen_fluorescence(
            codon_noise_sigma=codon_noise_sigma,
            experimental_noise_sigma=experimental_noise_sigma,
        )

        # Define input proportions. The proportions are defined as the fraction of
        # the total input population represented by a specific variant. Because
        # we have many variants, we define the input proportions on the log scale
        self._init_input_proportions(
            g1_alpha=g1_alpha,
            g2_alpha=g2_alpha,
            g3_alpha=g3_alpha,
            g4_alpha=g4_alpha,
        )

        # Calculate the proportions of the input populations that we expect to make
        # it through a given filter. We assume that all variants experience the
        # same noise resulting from differing codon usage. Experimental noise is
        # already baked into the system.
        self._calculate_output_proportions()

    def _init_gen_fluorescence(
        self, codon_noise_sigma: float, experimental_noise_sigma: float
    ):
        """
        Initializes codon-level fluorescence values for the fiducial and nonfiducial
        variants.
        """
        # We expect at least two sources of noise. One comes from different expression
        # levels resulting from different codons and the other comes from slight
        # differences in conditions between generations (i.e., other experimental
        # noise).
        self.codon_noise = dms_components.HalfNormal(sigma=codon_noise_sigma)
        self.experimental_noise = dms_components.HalfNormal(
            sigma=experimental_noise_sigma
        )

        # Define the fluorescence values for all generations
        generation_info: dict[str, tuple(int, int, int)] = {}
        for key in ("g1", "g2", "g3", "g4"):

            # Get the variant indices for fiducial and non-fiducial datasets
            fiducial_variant_inds = self.data[key]["fiducial"]["variant_inds"]
            data_variant_inds = self.data[key]["data"]["variant_inds"]

            # Ignore any data variants that are also present in the fiducial dataset.
            # This is to avoid double counting the fluorescence values for these
            # variants.
            data_variant_inds = np.setdiff1d(data_variant_inds, fiducial_variant_inds)

            # Get the mean fluorescence values for the population of droplets corresponding
            # to a specific variant in this generation. This part captures the cumulative
            # noise, which we assume to be normally distributed. To make sure we
            # keep on the log scale, we use the ExpNormal distribution, which models
            # a random variable whose exponential is normally distributed.
            fiducial_mean_log_fluorescence = f"{key}_mean_fid_log_fluorescence"
            data_mean_log_fluorescence = f"{key}_mean_log_fluorescence"
            n_fiducial_variants = len(fiducial_variant_inds)
            n_data_variants = len(data_variant_inds)
            setattr(
                self,
                fiducial_mean_log_fluorescence,
                dms_components.ExpNormal(
                    mu=dms_ops.exp(self.mean_log_fluorescence[fiducial_variant_inds]),
                    sigma=self.experimental_noise,
                    shape=(n_fiducial_variants, 1),
                ),
            )
            setattr(
                self,
                data_mean_log_fluorescence,
                dms_components.ExpNormal(
                    mu=dms_ops.exp(self.mean_log_fluorescence[data_variant_inds]),
                    sigma=self.experimental_noise,
                    shape=(n_data_variants,),
                ),
            )

            # Fiducial datasets have fluorescence values at the codon level. We
            # need to expand their mean fluorescence values to the codon level.
            # We would expect the fluorescence values here to be normally distributed
            # on the log scale (i.e., we expect a fold change in fluorescence for
            # different codons)
            n_fiducial_codons = self.data[key]["fiducial"]["ic1"].shape[-1]
            setattr(
                self,
                f"{key}_fid_codon_log_fluorescence",
                dms_components.Normal(
                    mu=getattr(self, fiducial_mean_log_fluorescence),
                    sigma=self.codon_noise,
                    shape=(n_fiducial_variants, n_fiducial_codons),
                ),
            )

            # Record information on this generation, including the number of non-fiducial
            # variants, the number of fiducial codon variants, and the total number
            # of variants (non-fiducial + fiducial codons variants) for which we
            # have count data.
            generation_info[key] = (
                n_data_variants,
                n_fiducial_codons,
                n_fiducial_variants + n_fiducial_codons,
            )

        # Return the generation information
        return generation_info

    def _init_input_proportions(
        self, g1_alpha: float, g2_alpha: float, g3_alpha: float, g4_alpha: float
    ):
        """
        Initializes the input proportions for each generation. The input proportions
        are defined as the fraction of the total input population represented by
        a specific variant.
        """
        # TODO: Consider a hyperprior for alpha in the cases where we have multiple
        # samples from the same Dirichlet distribution.
        self.g1_input_log_prop = dms_components.ExpDirichlet(
            alpha=g1_alpha,
            shape=(
                3,  # TODO: We are sure that there are three different samples here?
                self.generation_info["g1"][-1],
            ),
        )
        self.g2_input_log_prop = dms_components.ExpDirichlet(
            alpha=g2_alpha,
            shape=(
                2,  # TODO: We are sure that there are two different samples here?
                self.generation_info["g2"][-1],
            ),
        )
        self.g3_input_log_prop = dms_components.ExpDirichlet(
            alpha=g3_alpha,
            shape=self.generation_info["g3"][-1],
        )
        self.g4_input_log_prop = dms_components.ExpDirichlet(
            alpha=g4_alpha,
            shape=self.generation_info["g4"][-1],
        )

    def _calculate_output_proportions(self):
        """
        Calculates the expected output proportions for a population of variants
        whose mean log fluorescence values are defined by `self.{generation}_mean_log_fluorescence`.
        We assume that all variants experience the same noise resulting from differing
        codon usage, so we apply the codon noise inferred from the fiducial sequences
        to the non-fiducial sequences. We assume that all other system noise is
        captured when we define the generational mean fluorescence values from
        the overall mean fluorescence values. In other words, the proportion of
        variants that make it through a given filter is the proportion of variants
        above a certain threshold defined by the distribution of fluorescence
        values (i.e., the complementary CDF of the fluorescence values).
        """

        def single_calculation(
            input_log_prop: dms_components.ExpDirichlet,
            mean_log_fluorescence: dms_components.ExpNormal,
            threshold: npt.NDArray[np.floating],
        ) -> dms_components.TransformedParameter:
            """
            Returns the proportion of variants that are above a given threshold.
            Because our mean fluorescence is defined on the log scale, we need to
            use the CDF of the log-normal distribution to calculate the proportion
            of variants whose fluorescence values are above a given threshold.
            """
            # Get the log complementary CDF of the fluorescence values
            # TODO: Overload this function. It should be callable as an instance
            # or a class method.
            log_ccdf = dms_components.LogNormal(
                mu=mean_log_fluorescence, sigma=self.codon_noise
            ).log_ccdf(threshold)

            # Now update the input log proportions to reflect the decrease.
            raw_output_log_prop = input_log_prop + log_ccdf

            # Finally, renormalize the output log proportions such that they sum
            # to 1.0 across all variants. This is our output proportion.
            return dms_ops.normalize_log(raw_output_log_prop)

        # Calculate the output proportions, making sure to use the correct combinations
        # of input log proportions and fluorescence thresholds for each generation.
        for key in ("g1", "g2", "g3", "g4"):

            # Calculate the output log proportions for this generation
            setattr(
                self,
                f"{key}_output_log_prop",
                single_calculation(
                    input_log_prop=getattr(self, f"{key}_input_log_prop"),
                    mean_log_fluorescence=getattr(self, f"{key}_mean_log_fluorescence"),
                    threshold=self.data[key]["ft"],
                ),
            )

    def _model_counts(self):
        """
        Defines the distributions that model our observations. These are all multinomial
        distributions parametrized by the log proportions of the variants.
        """

SyntaxError: expected ':' (1139892077.py, line 23)

In [4]:
nuclease_data["g1"]["fiducial"]["variant_inds"]

array([0])