# Copy-Number-aware Differential Gene Expression

## Pydeseq2CN test

In [None]:
import multiprocessing
import sys
import time
import warnings
from math import floor
from pathlib import Path
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
from typing import cast

In [30]:
import anndata as ad  # type: ignore
import numpy as np
import pandas as pd
import statsmodels.api as sm  # type: ignore
from joblib import Parallel  # type: ignore
from joblib import delayed
from joblib import parallel_backend
from scipy.optimize import minimize
from scipy.special import gammaln 
from scipy.special import polygamma  # type: ignore
from scipy.stats import f  # type: ignore
from scipy.stats import trim_mean  # type: ignore
from scipy.stats import norm
from statsmodels.tools.sm_exceptions import DomainWarning  # type: ignore
from sklearn.linear_model import LinearRegression

In [31]:
import pydeseq2
from pydeseq2.preprocessing import deseq2_norm
from pydeseq2.utils import build_design_matrix
from pydeseq2.utils import dispersion_trend
from pydeseq2.utils import fit_alpha_mle
from pydeseq2.utils import fit_lin_mu
from pydeseq2.utils import fit_moments_dispersions
from pydeseq2.utils import fit_rough_dispersions
from pydeseq2.utils import irls_solver
from pydeseq2.utils import get_num_processes
from pydeseq2.utils import make_scatter
from pydeseq2.utils import mean_absolute_deviation
from pydeseq2.utils import nb_nll
from pydeseq2.utils import replace_underscores
from pydeseq2.utils import robust_method_of_moments_disp
from pydeseq2.utils import test_valid_counts
from pydeseq2.utils import trimmed_mean

from pydeseq2.grid_search import grid_fit_alpha
from pydeseq2.grid_search import grid_fit_beta

from pydeseq2.ds import DeseqStats

In [32]:
# statsmodels raises a DomainWarning when a Gamma GLM is fitted with an identity
# link; silence that warning category for the whole session.
warnings.simplefilter("ignore", DomainWarning)
# Silence every FutureWarning, which covers AnnData's warning about implicit
# data conversion (and any other FutureWarning raised later on).
warnings.simplefilter("ignore", FutureWarning)

### Model fit function

In [33]:
class DeseqDataSet(ad.AnnData):
    def __init__(
        self,
        *,
        adata: Optional[ad.AnnData] = None,
        counts: Optional[pd.DataFrame] = None,
        metadata: Optional[pd.DataFrame] = None,
        design_factors: Union[str, List[str]] = "condition",
        continuous_factors: Optional[List[str]] = None,
        ref_level: Optional[List[str]] = None,
        min_mu: float = 0.5,
        min_disp: float = 1e-8,
        max_disp: float = 10.0,
        refit_cooks: bool = True,
        min_replicates: int = 7,
        beta_tol: float = 1e-8,
        n_cpus: Optional[int] = None,
        batch_size: int = 128,
        joblib_verbosity: int = 0,
        quiet: bool = False,
    ) -> None:
        """Build the dataset from an AnnData object, or from counts plus metadata.

        Parameters
        ----------
        adata : ad.AnnData, optional
            Existing AnnData object whose attributes are copied into this
            instance. When given, ``counts`` and ``metadata`` are ignored (a
            ``UserWarning`` is emitted for each one that was also passed).
        counts : pd.DataFrame, optional
            Count table, used as ``X`` of the AnnData part when ``adata`` is None.
        metadata : pd.DataFrame, optional
            Sample annotations, used as ``obs`` when ``adata`` is None.
        design_factors : str or list of str
            Column(s) of ``obs`` used to build the design matrix. A single
            string is wrapped into a list.
        continuous_factors : list of str, optional
            Forwarded to ``build_design_matrix`` as ``continuous_factors``.
        ref_level : list of str, optional
            Forwarded to ``build_design_matrix`` after underscores are replaced
            by hyphens.
        min_mu : float
            Forwarded as ``min_mu`` to the mean / IRLS fitting helpers.
        min_disp : float
            Lower bound applied to dispersion estimates.
        max_disp : float
            Upper bound applied to dispersion estimates; the stored value is
            ``max(max_disp, n_obs)``.
        refit_cooks : bool
            Whether ``deseq2()`` calls ``refit()`` after computing Cook's
            distances.
        min_replicates : int
            Minimum group size (in the last design-matrix column) required for
            outlier replacement.
        beta_tol : float
            Forwarded as ``beta_tol`` to ``irls_solver``.
        n_cpus : int, optional
            Passed to ``get_num_processes`` to set the number of joblib workers.
        batch_size : int
            ``batch_size`` of the joblib ``Parallel`` calls.
        joblib_verbosity : int
            ``verbose`` level of the joblib ``Parallel`` calls.
        quiet : bool
            When True, progress messages are not printed to stderr.

        Raises
        ------
        ValueError
            If neither ``adata`` nor both ``counts`` and ``metadata`` are given,
            or if a design factor column contains NaNs.
        """
        # Initialize the AnnData part
        if adata is not None:
            if counts is not None:
                warnings.warn(
                    "adata was provided; ignoring counts.", UserWarning, stacklevel=2
                )
            if metadata is not None:
                warnings.warn(
                    "adata was provided; ignoring metadata.", UserWarning, stacklevel=2
                )
            # NOTE: count validation (test_valid_counts(adata.X)) is currently
            # disabled in this notebook version.
            # Copy fields from original AnnData
            self.__dict__.update(adata.__dict__)

        elif counts is not None and metadata is not None:
            # NOTE: count validation (test_valid_counts(counts)) is currently
            # disabled in this notebook version.
            super().__init__(X=counts, obs=metadata)

        else:
            raise ValueError(
                "Either adata or both counts and metadata arguments must be provided."
            )

        # Convert design_factors to list if a single string was provided.
        self.design_factors = (
            [design_factors] if isinstance(design_factors, str) else design_factors
        )
        self.continuous_factors = continuous_factors

        if self.obs[self.design_factors].isna().any().any():
            raise ValueError("NaNs are not allowed in the design factors.")
        self.obs[self.design_factors] = self.obs[self.design_factors].astype(str)

        # Check that design factors don't contain underscores. If so, convert them to
        # hyphens.
        if any("_" in factor for factor in self.design_factors):
            warnings.warn(
                "Some factor names in the design contain underscores ('_'). "
                "They will be converted to hyphens ('-').",
                UserWarning,
                stacklevel=2,
            )

            new_factors = replace_underscores(self.design_factors)

            # Keep the obs column names in sync with the renamed factors.
            self.obs.rename(
                columns=dict(zip(self.design_factors, new_factors)),
                inplace=True,
            )

            self.design_factors = new_factors

            # Also check continuous factors
            if self.continuous_factors is not None:
                self.continuous_factors = replace_underscores(self.continuous_factors)

        # If ref_level has underscores, convert them to hyphens.
        # Don't raise a warning: it will be raised by build_design_matrix()
        if ref_level is not None:
            ref_level = replace_underscores(ref_level)

        # Build the design matrix
        # Stored in the obsm attribute of the dataset
        self.obsm["design_matrix"] = build_design_matrix(
            metadata=self.obs,
            design_factors=self.design_factors,
            continuous_factors=self.continuous_factors,
            ref_level=ref_level,
            expanded=False,
            intercept=True,
        )

        # Check that the design matrix has full rank
        self._check_full_rank_design()

        self.min_mu = min_mu
        self.min_disp = min_disp
        # The dispersion upper bound is never smaller than the number of samples.
        self.max_disp = np.maximum(max_disp, self.n_obs)
        self.refit_cooks = refit_cooks
        self.ref_level = ref_level
        self.min_replicates = min_replicates
        self.beta_tol = beta_tol
        self.n_processes = get_num_processes(n_cpus)
        self.batch_size = batch_size
        self.joblib_verbosity = joblib_verbosity
        self.quiet = quiet
    
    def vst(
        self,
        use_design: bool = False,
        fit_type: Literal["parametric", "mean"] = "parametric",
    ) -> None:
        """Compute variance-stabilized counts and store them in ``layers["vst_counts"]``.

        Parameters
        ----------
        use_design : bool
            If True, the dispersion trend is fitted with the dataset's design
            matrix (only when ``"trend_coeffs"`` is not already in ``uns``).
            If False, dispersions are fitted with a temporary intercept-only
            design, and the original design matrix is restored afterwards.
        fit_type : {"parametric", "mean"}
            ``"parametric"`` uses the trend coefficients ``(a0, a1)``;
            ``"mean"`` uses a trimmed mean of the gene-wise dispersions.

        Raises
        ------
        NotImplementedError
            If ``fit_type`` is neither ``"parametric"`` nor ``"mean"``.
        """
        # Reject unsupported fit types before doing any (expensive) fitting.
        if fit_type not in ("parametric", "mean"):
            raise NotImplementedError(
                f"Found fit_type '{fit_type}'. Expected 'parametric' or 'mean'."
            )

        # Start by fitting median-of-ratio size factors, if not already present.
        if "size_factors" not in self.obsm:
            self.fit_size_factors()

        if use_design:
            # Check that the dispersion trend curve was fitted. If not, fit it.
            # This will call previous functions in a cascade.
            if "trend_coeffs" not in self.uns:
                self.fit_dispersion_trend()
        else:
            # Reduce the design matrix to an intercept and reconstruct at the end
            self.obsm["design_matrix_buffer"] = self.obsm["design_matrix"].copy()
            self.obsm["design_matrix"] = pd.DataFrame(
                1, index=self.obs_names, columns=[["intercept"]]
            )
            try:
                # Fit the trend curve with an intercept design
                self.fit_genewise_dispersions()
                if fit_type == "parametric":
                    self.fit_dispersion_trend()
            finally:
                # Restore the design matrix and free the buffer, even when the
                # fit raised, so the dataset is never left with the temporary
                # intercept-only design.
                self.obsm["design_matrix"] = self.obsm["design_matrix_buffer"].copy()
                del self.obsm["design_matrix_buffer"]

        # Apply VST
        if fit_type == "parametric":
            a0, a1 = self.uns["trend_coeffs"]
            cts = self.layers["normed_counts"]
            self.layers["vst_counts"] = np.log2(
                (1 + a1 + 2 * a0 * cts + 2 * np.sqrt(a0 * cts * (1 + a1 + a0 * cts)))
                / (4 * a0)
            )
        else:  # fit_type == "mean"
            gene_dispersions = self.varm["genewise_dispersions"]
            # Only dispersions clearly above the lower bound enter the mean.
            use_for_mean = gene_dispersions > 10 * self.min_disp
            mean_disp = trim_mean(gene_dispersions[use_for_mean], proportiontocut=0.001)
            self.layers["vst_counts"] = (
                2 * np.arcsinh(np.sqrt(mean_disp * self.layers["normed_counts"]))
                - np.log(mean_disp)
                - np.log(4)
            ) / np.log(2)
            
    def deseq2(self) -> None:
        """Run the whole estimation pipeline: dispersions, then log fold-changes.

        The fitting steps are executed in a fixed order. When ``refit_cooks``
        is set, outlier counts are replaced and the affected genes are refitted
        at the end.
        """
        pipeline = (
            # Normalization factors (median-of-ratios unless it is not applicable)
            self.fit_size_factors,
            # One independent negative binomial model per gene
            self.fit_genewise_dispersions,
            # Parameterized trend curve for the dispersions
            self.fit_dispersion_trend,
            # Prior dispersion variance
            self.fit_dispersion_prior,
            # A posteriori dispersions (shrunk towards the trend curve)
            self.fit_MAP_dispersions,
            # Log fold-changes, in natural log scale
            self.fit_LFC,
            # Cook's distances
            self.calculate_cooks,
        )
        for run_step in pipeline:
            run_step()

        if not self.refit_cooks:
            return

        # Replace outlier counts, then refit dispersions and LFCs for the genes
        # that had outliers replaced.
        self.refit()
        
    def fit_size_factors(
        self, fit_type: Literal["ratio", "iterative"] = "ratio"
    ) -> None:
        """Fit sample-wise size factors and the matching normalized counts.

        Parameters
        ----------
        fit_type : {"ratio", "iterative"}
            ``"iterative"`` calls ``_fit_iterate_size_factors``. Any other value
            uses ``deseq2_norm``, unless every gene contains at least one zero,
            in which case a ``RuntimeWarning`` is emitted and the iterative
            method is used instead.
        """
        if not self.quiet:
            print("Fitting size factors...", file=sys.stderr)
        t_begin = time.time()

        go_iterative = fit_type == "iterative"
        # The median-of-ratios test is only evaluated when "iterative" was not
        # explicitly requested.
        if not go_iterative and (self.X == 0).any(0).all():
            # There is at least a zero for each gene
            warnings.warn(
                "Every gene contains at least one zero, "
                "cannot compute log geometric means. Switching to iterative mode.",
                RuntimeWarning,
                stacklevel=2,
            )
            go_iterative = True

        if go_iterative:
            self._fit_iterate_size_factors()
        else:
            normed, factors = deseq2_norm(self.X)
            self.layers["normed_counts"] = normed
            self.obsm["size_factors"] = factors

        t_finish = time.time()

        if not self.quiet:
            print(f"... done in {t_finish - t_begin:.2f} seconds.\n", file=sys.stderr)
            
    
    def fit_genewise_dispersions(self) -> None:
        """Fit gene-wise dispersion estimates.

        Fits a negative binomial per gene, independently. Genes whose counts are
        all zero are excluded from the fit and keep NaN in every output.

        Writes ``varm["non_zero"]``, ``layers["_mu_hat"]``,
        ``varm["genewise_dispersions"]`` and ``varm["_genewise_converged"]``,
        and sets the ``non_zero_idx`` / ``non_zero_genes`` attributes.

        Raises
        ------
        ValueError
            If the non-zero gene names form a ``pd.MultiIndex``.
        """
        # Check that size factors are available. If not, compute them.
        if "size_factors" not in self.obsm:
            self.fit_size_factors()

        # Exclude genes with all zeroes
        self.varm["non_zero"] = ~(self.X == 0).all(axis=0)
        self.non_zero_idx = np.arange(self.n_vars)[self.varm["non_zero"]]
        self.non_zero_genes = self.var_names[self.varm["non_zero"]]

        if isinstance(self.non_zero_genes, pd.MultiIndex):
            raise ValueError("non_zero_genes should not be a MultiIndex")

        # Fit "method of moments" dispersion estimates
        self._fit_MoM_dispersions()

        # Convert design_matrix to numpy for speed
        design_matrix = self.obsm["design_matrix"].values

        # When the design matrix has as many distinct rows as columns, the means
        # are obtained with fit_lin_mu; otherwise they are fitted with
        # irls_solver, started from the method-of-moments dispersions.
        if (
            len(self.obsm["design_matrix"].value_counts())
            == self.obsm["design_matrix"].shape[-1]
        ):
            with parallel_backend("loky", inner_max_num_threads=1):
                mu_hat_ = np.array(
                    Parallel(
                        n_jobs=self.n_processes,
                        verbose=self.joblib_verbosity,
                        batch_size=self.batch_size,
                    )(
                        delayed(fit_lin_mu)(
                            counts=self.X[:, i],
                            size_factors=self.obsm["size_factors"],
                            design_matrix=design_matrix,
                            min_mu=self.min_mu,
                        )
                        for i in self.non_zero_idx
                    )
                )
        else:
            with parallel_backend("loky", inner_max_num_threads=1):
                res = Parallel(
                    n_jobs=self.n_processes,
                    verbose=self.joblib_verbosity,
                    batch_size=self.batch_size,
                )(
                    delayed(irls_solver)(
                        counts=self.X[:, i],
                        size_factors=self.obsm["size_factors"],
                        design_matrix=design_matrix,
                        disp=self.varm["_MoM_dispersions"][i],
                        min_mu=self.min_mu,
                        beta_tol=self.beta_tol,
                    )
                    for i in self.non_zero_idx
                )

                # Only the second element of each irls_solver result is used here.
                _, mu_hat_, _, _ = zip(*res)
                mu_hat_ = np.array(mu_hat_)

        # np.nan (not the np.NaN alias, which NumPy 2.0 removed) marks the
        # all-zero genes that were not fitted.
        self.layers["_mu_hat"] = np.full((self.n_obs, self.n_vars), np.nan)
        self.layers["_mu_hat"][:, self.varm["non_zero"]] = mu_hat_.T

        if not self.quiet:
            print("Fitting dispersions...", file=sys.stderr)
        start = time.time()
        with parallel_backend("loky", inner_max_num_threads=1):
            res = Parallel(
                n_jobs=self.n_processes,
                verbose=self.joblib_verbosity,
                batch_size=self.batch_size,
            )(
                delayed(fit_alpha_mle)(
                    counts=self.X[:, i],
                    design_matrix=design_matrix,
                    mu=self.layers["_mu_hat"][:, i],
                    alpha_hat=self.varm["_MoM_dispersions"][i],
                    min_disp=self.min_disp,
                    max_disp=self.max_disp,
                )
                for i in self.non_zero_idx
            )
        end = time.time()

        if not self.quiet:
            print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

        dispersions_, l_bfgs_b_converged_ = zip(*res)

        self.varm["genewise_dispersions"] = np.full(self.n_vars, np.nan)
        self.varm["genewise_dispersions"][self.varm["non_zero"]] = np.clip(
            dispersions_, self.min_disp, self.max_disp
        )

        self.varm["_genewise_converged"] = np.full(self.n_vars, np.nan)
        self.varm["_genewise_converged"][self.varm["non_zero"]] = l_bfgs_b_converged_
        
    
    def fit_dispersion_trend(self) -> None:
        r"""Fit the dispersion trend coefficients.

        A Gamma GLM with identity link is fitted on the gene-wise dispersions of
        the non-zero genes, using a constant and the inverse mean normalized
        count as covariates. The fit is repeated, each time discarding genes
        whose dispersion is too far from the prediction, until the coefficients
        stabilize.

        Writes ``varm["_normed_means"]``, ``uns["trend_coeffs"]`` (index
        ``["a0", "a1"]``) and ``varm["fitted_dispersions"]`` (NaN for all-zero
        genes).
        """

        # Check that genewise dispersions are available. If not, compute them.
        if "genewise_dispersions" not in self.varm:
            self.fit_genewise_dispersions()

        if not self.quiet:
            print("Fitting dispersion trend curve...", file=sys.stderr)
        start = time.time()
        self.varm["_normed_means"] = self.layers["normed_counts"].mean(0)

        # Exclude all-zero counts
        non_zero_view = self[:, self.non_zero_genes]
        targets = pd.Series(
            non_zero_view.varm["genewise_dispersions"].copy(),
            index=self.non_zero_genes,
        )
        covariates = sm.add_constant(
            pd.Series(
                1 / non_zero_view.varm["_normed_means"],
                index=self.non_zero_genes,
            )
        )

        # Drop genes with an infinite or NaN covariate, in one vectorized pass.
        finite_rows = (
            np.isfinite(covariates.to_numpy(dtype=float))
            .reshape(len(covariates), -1)
            .all(axis=1)
        )
        targets = targets[finite_rows]
        covariates = covariates[finite_rows]

        # Initialize coefficients
        old_coeffs = pd.Series([0.1, 0.1])
        coeffs = pd.Series([1.0, 1.0])

        # Iterate until the squared log-ratio between successive coefficient
        # vectors falls below 1e-6.
        while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6:
            glm_gamma = sm.GLM(
                targets.values,
                covariates.values,
                family=sm.families.Gamma(link=sm.families.links.identity()),
            )

            res = glm_gamma.fit()
            old_coeffs = coeffs.copy()
            coeffs = res.params

            # Filter out genes that are too far away from the curve before refitting.
            # `targets` holds the gene-wise dispersions of exactly the genes
            # still in `covariates`, so no new AnnData view is needed.
            predictions = covariates.values @ coeffs
            pred_ratios = targets.values / predictions
            too_far = (pred_ratios < 1e-4) | (pred_ratios >= 15)

            targets = targets[~too_far]
            covariates = covariates[~too_far]

        end = time.time()

        if not self.quiet:
            print(f"... done in {end - start:.2f} seconds.\n", file=sys.stderr)

        self.uns["trend_coeffs"] = pd.Series(coeffs, index=["a0", "a1"])

        # np.nan (not the np.NaN alias, which NumPy 2.0 removed) marks all-zero genes.
        self.varm["fitted_dispersions"] = np.full(self.n_vars, np.nan)
        self.varm["fitted_dispersions"][self.varm["non_zero"]] = dispersion_trend(
            self.varm["_normed_means"][self.varm["non_zero"]],
            self.uns["trend_coeffs"],
        )
        
    
    def fit_dispersion_prior(self) -> None:
        """Estimate the prior dispersion variance from the log-residuals.

        Stores the squared ``mean_absolute_deviation`` of the log-residuals
        (gene-wise vs. fitted dispersions) in ``uns["_squared_logres"]`` and the
        resulting prior variance, floored at 0.25, in ``uns["prior_disp_var"]``.
        """
        # The trend curve is a prerequisite; fit it on demand.
        if "fitted_dispersions" not in self.varm:
            self.fit_dispersion_trend()

        n_samples = self.n_obs
        n_coeffs = self.obsm["design_matrix"].shape[-1]
        residual_dof = n_samples - n_coeffs

        # Check the degrees of freedom
        if residual_dof <= 3:
            warnings.warn(
                "As the residual degrees of freedom is less than 3, the distribution "
                "of log dispersions is especially asymmetric and likely to be poorly "
                "estimated by the MAD.",
                UserWarning,
                stacklevel=2,
            )

        # Work on the genes that are not all-zero.
        expressed = self[:, self.non_zero_genes]
        genewise = expressed.varm["genewise_dispersions"]
        fitted = expressed.varm["fitted_dispersions"]

        # Log residuals of the gene-wise dispersions around the trend curve
        log_residuals = np.log(genewise) - np.log(fitted)

        # Only genes whose dispersion is at least 100 * min_disp contribute.
        # This is to reproduce DESeq2's behaviour.
        far_from_floor = genewise >= (100 * self.min_disp)

        self.uns["_squared_logres"] = (
            mean_absolute_deviation(log_residuals[far_from_floor]) ** 2
        )

        sampling_variance = polygamma(1, residual_dof / 2)
        self.uns["prior_disp_var"] = np.maximum(
            self.uns["_squared_logres"] - sampling_variance,
            0.25,
        )
        
    def fit_MAP_dispersions(self) -> None:
        """Fit Maximum a Posteriori dispersion estimates.

        After MAP dispersions are fit, filter genes for which we don't apply shrinkage.

        Writes ``varm["MAP_dispersions"]``, ``varm["_MAP_converged"]``,
        ``varm["_outlier_genes"]`` and the final ``varm["dispersions"]``. The
        final value is the MAP estimate, except for outlier genes, which keep
        their gene-wise estimate. All-zero genes are left as NaN.
        """

        # Check that the dispersion prior variance is available. If not, compute it.
        if "prior_disp_var" not in self.uns:
            self.fit_dispersion_prior()

        # Convert design matrix to numpy for speed
        design_matrix = self.obsm["design_matrix"].values

        if not self.quiet:
            print("Fitting MAP dispersions...", file=sys.stderr)
        start = time.time()
        with parallel_backend("loky", inner_max_num_threads=1):
            res = Parallel(
                n_jobs=self.n_processes,
                verbose=self.joblib_verbosity,
                batch_size=self.batch_size,
            )(
                delayed(fit_alpha_mle)(
                    counts=self.X[:, i],
                    design_matrix=design_matrix,
                    mu=self.layers["_mu_hat"][:, i],
                    # The optimization starts from the trend-curve value.
                    alpha_hat=self.varm["fitted_dispersions"][i],
                    min_disp=self.min_disp,
                    max_disp=self.max_disp,
                    prior_disp_var=self.uns["prior_disp_var"].item(),
                    cr_reg=True,
                    prior_reg=True,
                )
                for i in self.non_zero_idx
            )
        end = time.time()

        if not self.quiet:
            print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)

        dispersions_, l_bfgs_b_converged_ = zip(*res)

        # np.nan (not the np.NaN alias, which NumPy 2.0 removed) marks all-zero genes.
        self.varm["MAP_dispersions"] = np.full(self.n_vars, np.nan)
        self.varm["MAP_dispersions"][self.varm["non_zero"]] = np.clip(
            dispersions_, self.min_disp, self.max_disp
        )

        self.varm["_MAP_converged"] = np.full(self.n_vars, np.nan)
        self.varm["_MAP_converged"][self.varm["non_zero"]] = l_bfgs_b_converged_

        # Filter outlier genes for which we won't apply shrinkage: those whose
        # gene-wise log-dispersion exceeds the fitted one by more than
        # 2 * sqrt(_squared_logres).
        self.varm["dispersions"] = self.varm["MAP_dispersions"].copy()
        self.varm["_outlier_genes"] = np.log(self.varm["genewise_dispersions"]) > np.log(
            self.varm["fitted_dispersions"]
        ) + 2 * np.sqrt(self.uns["_squared_logres"])
        self.varm["dispersions"][self.varm["_outlier_genes"]] = self.varm[
            "genewise_dispersions"
        ][self.varm["_outlier_genes"]]
        
    def fit_LFC(self) -> None:
        """Fit log fold change (LFC) coefficients.

        In the 2-level setting, the intercept corresponds to the base mean,
        while the second is the actual LFC coefficient, in natural log scale.

        Writes ``varm["LFC"]`` (one column per design-matrix column),
        ``layers["_mu_LFC"]``, ``layers["_hat_diagonals"]`` and
        ``varm["_LFC_converged"]``. All-zero genes are left as NaN.
        """

        # Check that MAP dispersions are available. If not, compute them.
        if "dispersions" not in self.varm:
            self.fit_MAP_dispersions()

        # Convert design matrix to numpy for speed
        design_matrix = self.obsm["design_matrix"].values

        if not self.quiet:
            print("Fitting LFCs...", file=sys.stderr)
        start = time.time()
        with parallel_backend("loky", inner_max_num_threads=1):
            res = Parallel(
                n_jobs=self.n_processes,
                verbose=self.joblib_verbosity,
                batch_size=self.batch_size,
            )(
                delayed(irls_solver)(
                    counts=self.X[:, i],
                    size_factors=self.obsm["size_factors"],
                    design_matrix=design_matrix,
                    disp=self.varm["dispersions"][i],
                    min_mu=self.min_mu,
                    beta_tol=self.beta_tol,
                )
                for i in self.non_zero_idx
            )
        end = time.time()

        if not self.quiet:
            print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr)

        MLE_lfcs_, mu_, hat_diagonals_, converged_ = zip(*res)
        # Results come one row per gene; transpose to (samples, genes).
        mu_ = np.array(mu_).T
        hat_diagonals_ = np.array(hat_diagonals_).T

        # np.nan (not the np.NaN alias, which NumPy 2.0 removed) marks all-zero genes.
        self.varm["LFC"] = pd.DataFrame(
            np.nan,
            index=self.var_names,
            columns=self.obsm["design_matrix"].columns,
        )

        self.varm["LFC"].update(
            pd.DataFrame(
                MLE_lfcs_,
                index=self.non_zero_genes,
                columns=self.obsm["design_matrix"].columns,
            )
        )

        self.layers["_mu_LFC"] = np.full((self.n_obs, self.n_vars), np.nan)
        self.layers["_mu_LFC"][:, self.varm["non_zero"]] = mu_

        self.layers["_hat_diagonals"] = np.full((self.n_obs, self.n_vars), np.nan)
        self.layers["_hat_diagonals"][:, self.varm["non_zero"]] = hat_diagonals_

        self.varm["_LFC_converged"] = np.full(self.n_vars, np.nan)
        self.varm["_LFC_converged"][self.varm["non_zero"]] = converged_
        
    
    def calculate_cooks(self) -> None:
        """Compute Cook's distances and store them in ``layers["cooks"]``.

        For the non-zero genes, the distance is the squared Pearson residual
        divided by the number of design-matrix columns, multiplied by
        ``h / (1 - h) ** 2`` where ``h`` is the hat-matrix diagonal. The
        variance in the residual uses dispersions from
        ``robust_method_of_moments_disp``. All-zero genes are left as NaN.
        """
        # Check that MAP dispersions are available. If not, compute them.
        if "dispersions" not in self.varm:
            self.fit_MAP_dispersions()
        # The fitted means and hat diagonals come from fit_LFC; fit them on
        # demand rather than failing with a KeyError below.
        if "_mu_LFC" not in self.layers or "_hat_diagonals" not in self.layers:
            self.fit_LFC()

        num_vars = self.obsm["design_matrix"].shape[-1]

        # Keep only non-zero genes
        nonzero_data = self[:, self.non_zero_genes]
        normed_counts = pd.DataFrame(
            nonzero_data.X / self.obsm["size_factors"][:, None],
            index=self.obs_names,
            columns=self.non_zero_genes,
        )

        dispersions = robust_method_of_moments_disp(
            normed_counts, self.obsm["design_matrix"]
        )
        # Variance of the counts: mu + dispersion * mu ** 2
        V = (
            nonzero_data.layers["_mu_LFC"]
            + dispersions.values[None, :] * nonzero_data.layers["_mu_LFC"] ** 2
        )
        squared_pearson_res = (nonzero_data.X - nonzero_data.layers["_mu_LFC"]) ** 2 / V
        diag_mul = (
            nonzero_data.layers["_hat_diagonals"]
            / (1 - nonzero_data.layers["_hat_diagonals"]) ** 2
        )

        # np.nan (not the np.NaN alias, which NumPy 2.0 removed) marks all-zero genes.
        self.layers["cooks"] = np.full((self.n_obs, self.n_vars), np.nan)
        self.layers["cooks"][:, self.varm["non_zero"]] = (
            squared_pearson_res / num_vars * diag_mul
        )
        
    def refit(self) -> None:
        """Replace outlier counts, then refit the genes that were affected.

        Nothing is refitted when no gene had a count replaced.
        """
        # Replace outlier counts
        self._replace_outliers()

        if not self.quiet:
            n_flagged = sum(self.varm["replaced"])
            print(f"Refitting {n_flagged} outliers.\n", file=sys.stderr)

        if sum(self.varm["replaced"]) <= 0:
            return

        # Refit dispersions and LFCs for genes that had outliers replaced
        self._refit_without_outliers()
            
    
    def _fit_MoM_dispersions(self) -> None:
        """Compute initial "method of moments" dispersions for the non-zero genes.

        The estimate is the element-wise minimum of ``fit_rough_dispersions``
        and ``fit_moments_dispersions``, clipped to ``[min_disp, max_disp]`` and
        stored in ``varm["_MoM_dispersions"]`` (NaN for all-zero genes).

        Requires ``varm["non_zero"]``, which ``fit_genewise_dispersions`` sets
        before calling this method.
        """
        # Check that size_factors are available. If not, compute them.
        if "normed_counts" not in self.layers:
            self.fit_size_factors()

        # Restrict to the non-zero genes so the estimates have exactly one
        # entry per gene written below; passing every gene breaks the masked
        # assignment as soon as one gene is all-zero.
        non_zero = self.varm["non_zero"]
        normed_counts = self.layers["normed_counts"][:, non_zero]

        rde = fit_rough_dispersions(
            normed_counts,
            self.obsm["design_matrix"],
        )
        mde = fit_moments_dispersions(normed_counts, self.obsm["size_factors"])
        alpha_hat = np.minimum(rde, mde)

        # np.nan (not the np.NaN alias, which NumPy 2.0 removed) marks all-zero genes.
        self.varm["_MoM_dispersions"] = np.full(self.n_vars, np.nan)
        self.varm["_MoM_dispersions"][non_zero] = np.clip(
            alpha_hat, self.min_disp, self.max_disp
        )
        

    def plot_dispersions(
        self, log: bool = True, save_path: Optional[str] = None, **kwargs
    ) -> None:
        """Scatter-plot the dispersion estimates against the mean normalized counts.

        Parameters
        ----------
        log : bool
            Forwarded to ``make_scatter`` as ``log``.
        save_path : str, optional
            Forwarded to ``make_scatter`` as ``save_path``.
        **kwargs
            Extra keyword arguments forwarded to ``make_scatter``.
        """
        # Legend label -> dispersion vector, in plotting order.
        dispersions_by_label = {
            "Estimated": self.varm["genewise_dispersions"],
            "Final": self.varm["dispersions"],
            "Fitted": self.varm["fitted_dispersions"],
        }
        make_scatter(
            list(dispersions_by_label.values()),
            legend_labels=list(dispersions_by_label),
            x_val=self.varm["_normed_means"],
            log=log,
            save_path=save_path,
            **kwargs,
        )
    
    def _replace_outliers(self) -> None:
        """Flag counts with a large Cook's distance and build replacement counts.

        Sets ``varm["replaced"]`` (genes with at least one flagged count). When
        some gene is flagged, also sets ``obsm["replaceable"]`` and
        ``counts_to_refit``, a copy of the flagged genes in which flagged counts
        of replaceable samples are overwritten by the trimmed mean of the
        normalized counts times the sample's size factor (truncated to int).
        """
        # Check that cooks distances are available. If not, compute them.
        if "cooks" not in self.layers:
            self.calculate_cooks()

        design = self.obsm["design_matrix"]
        n_samples = self.n_obs
        n_coeffs = design.shape[1]
        last_factor = design[design.columns[-1]]

        # Check whether cohorts have enough samples to allow refitting
        has_enough = last_factor.value_counts() >= self.min_replicates
        if has_enough.sum() == 0:
            # No sample can be replaced. Set self.replaced to False and exit.
            self.varm["replaced"] = pd.Series(False, index=self.var_names)
            return

        # Per-sample flag: does the sample's group have enough replicates?
        self.obsm["replaceable"] = has_enough[last_factor].values

        # Get positions of counts with cooks above threshold
        threshold = f.ppf(0.99, n_coeffs, n_samples - n_coeffs)
        above_threshold = self.layers["cooks"] > threshold
        self.varm["replaced"] = above_threshold.any(axis=0)

        if not sum(self.varm["replaced"] > 0):
            return

        # Compute replacement counts: trimmed means * size_factors
        self.counts_to_refit = self[:, self.varm["replaced"]].copy()
        refit_genes = self.counts_to_refit.var_names

        base_means = cast(
            np.ndarray,
            trimmed_mean(
                self.counts_to_refit.X / self.obsm["size_factors"][:, None],
                trim=0.2,
                axis=0,
            ),
        )
        trim_base_mean = pd.DataFrame(base_means, index=refit_genes)

        per_gene_counts = pd.DataFrame(
            trim_base_mean.values * self.obsm["size_factors"],
            index=refit_genes,
            columns=self.counts_to_refit.obs_names,
        )
        # Truncate to integers and go back to a (samples, genes) layout.
        replacement_counts = per_gene_counts.astype(int).T

        # Only flagged counts belonging to replaceable samples are overwritten.
        overwrite = (
            self.obsm["replaceable"][:, None]
            & above_threshold[:, self.varm["replaced"]]
        )
        self.counts_to_refit.X[overwrite] = replacement_counts.values[overwrite]


      
    def _refit_without_outliers(
        self,
    ) -> None:
        """Re-run the whole DESeq2 pipeline with replaced outliers.

        Genes in ``counts_to_refit`` that are not all-zero after replacement are
        refitted in a sub-dataset that reuses this dataset's size factors, trend
        coefficients and dispersion prior. Their normalized means, LFCs and
        dispersions are then copied back, and ``layers["replace_cooks"]`` is
        written with the Cook's distances of replaceable samples zeroed for the
        refitted genes. Genes that became all-zero get a normalized mean and
        LFCs of zero.

        Raises
        ------
        AssertionError
            If ``refit_cooks`` is False.
        ValueError
            If the names of the new all-zero genes form a ``pd.MultiIndex``.
        """
        assert (
            self.refit_cooks
        ), "Trying to refit Cooks outliers but the 'refit_cooks' flag is set to False"

        # Check that _replace_outliers() was previously run.
        if "replaced" not in self.varm:
            self._replace_outliers()

        # Only refit genes for which replacing outliers hasn't resulted in all zeroes
        new_all_zeroes = (self.counts_to_refit.X == 0).all(axis=0)
        self.new_all_zeroes_genes = self.counts_to_refit.var_names[new_all_zeroes]
        if (~new_all_zeroes).sum() == 0:  # if no gene can be refit, we can skip
            return

        self.counts_to_refit = self.counts_to_refit[:, ~new_all_zeroes].copy()
        if isinstance(self.new_all_zeroes_genes, pd.MultiIndex):
            raise ValueError

        # Build the sub-dataset with the same configuration as this one. The
        # continuous factors are forwarded so the refit uses the same design
        # matrix, and the verbosity settings so a quiet dataset stays quiet.
        sub_dds = DeseqDataSet(
            counts=pd.DataFrame(
                self.counts_to_refit.X,
                index=self.counts_to_refit.obs_names,
                columns=self.counts_to_refit.var_names,
            ),
            metadata=self.obs,
            design_factors=self.design_factors,
            continuous_factors=self.continuous_factors,
            ref_level=self.ref_level,
            min_mu=self.min_mu,
            min_disp=self.min_disp,
            max_disp=self.max_disp,
            refit_cooks=self.refit_cooks,
            min_replicates=self.min_replicates,
            beta_tol=self.beta_tol,
            n_cpus=self.n_processes,
            batch_size=self.batch_size,
            joblib_verbosity=self.joblib_verbosity,
            quiet=self.quiet,
        )

        # Use the same size factors
        sub_dds.obsm["size_factors"] = self.counts_to_refit.obsm["size_factors"]

        # Estimate gene-wise dispersions.
        sub_dds.fit_genewise_dispersions()

        # Compute trend dispersions.
        # Note: the trend curve is not refitted.
        sub_dds.uns["trend_coeffs"] = self.uns["trend_coeffs"]
        sub_dds.varm["_normed_means"] = (
            self.counts_to_refit.X / self.counts_to_refit.obsm["size_factors"][:, None]
        ).mean(0)
        sub_dds.varm["fitted_dispersions"] = dispersion_trend(
            sub_dds.varm["_normed_means"],
            sub_dds.uns["trend_coeffs"],
        )

        # Estimate MAP dispersions.
        # Note: the prior variance is not recomputed.
        sub_dds.uns["_squared_logres"] = self.uns["_squared_logres"]
        sub_dds.uns["prior_disp_var"] = self.uns["prior_disp_var"]

        sub_dds.fit_MAP_dispersions()

        # Estimate log-fold changes (in natural log scale)
        sub_dds.fit_LFC()

        # Replace values in main object
        to_replace = self.varm["replaced"].copy()
        # Only replace if genes are not all zeroes after outlier replacement
        to_replace[to_replace] = ~new_all_zeroes

        self.varm["_normed_means"][to_replace] = sub_dds.varm["_normed_means"]
        self.varm["LFC"][to_replace] = sub_dds.varm["LFC"]
        self.varm["dispersions"][to_replace] = sub_dds.varm["dispersions"]

        replace_cooks = pd.DataFrame(self.layers["cooks"].copy())
        replace_cooks.loc[self.obsm["replaceable"], to_replace] = 0.0

        self.layers["replace_cooks"] = replace_cooks
        # Take into account new all-zero genes. Write into self.varm directly:
        # assigning through `self[:, genes].varm[...]` only modifies a
        # temporary copy of the view and leaves this object untouched.
        if new_all_zeroes.sum() > 0:
            zero_mask = self.var_names.isin(self.new_all_zeroes_genes)
            self.varm["_normed_means"][zero_mask] = 0.0
            self.varm["LFC"].loc[self.new_all_zeroes_genes, :] = 0.0
    
    def _fit_iterate_size_factors(self, niter: int = 10, quant: float = 0.95) -> None:
        """
        Fit size factors using the ``iterative`` method.

        Used when each gene has at least one zero.

        The method alternates between re-estimating dispersions under the current
        size factors and re-fitting the size factors by minimizing a truncated
        negative log-likelihood (``nb_nll``). While it runs, the design matrix is
        replaced by an intercept-only one; the original is put back once the loop
        ends. Results are written to ``self.obsm["size_factors"]`` and
        ``self.layers["normed_counts"]``.

        Parameters
        ----------
        niter : int
            Maximum number of iterations to perform (default: ``10``).

        quant : float
            Quantile value at which negative likelihood is cut in the optimization
            (default: ``0.95``).

        """

        # Initialize size factors and normed counts fields
        # (with all size factors equal to one, normed counts are the raw counts).
        self.obsm["size_factors"] = np.ones(self.n_obs)
        self.layers["normed_counts"] = self.X

        # Reduce the design matrix to an intercept and reconstruct at the end
        # NOTE(review): the buffer is only restored after the loop below; if an
        # exception escapes the loop, the intercept-only design stays in place.
        self.obsm["design_matrix_buffer"] = self.obsm["design_matrix"].copy()
        # NOTE(review): ``columns`` is passed as a nested list -- confirm the
        # resulting column index is what downstream code expects.
        self.obsm["design_matrix"] = pd.DataFrame(
            1, index=self.obs_names, columns=[["intercept"]]
        )

        # Fit size factors using MLE
        def objective(p):
            # ``p`` holds log size factors (the optimizer is started from
            # ``np.log(old_sf)``). Centering before exponentiating forces the
            # candidate size factors to have geometric mean 1.
            sf = np.exp(p - np.mean(p))
            # The fitted means are divided by the size factors currently stored
            # on the object and multiplied by the candidate ones.
            # presumably ``_mu_hat`` was fitted under the stored size factors --
            # TODO confirm.
            nll = nb_nll(
                counts=self[:, self.non_zero_genes].X,
                mu=self[:, self.non_zero_genes].layers["_mu_hat"]
                / self.obsm["size_factors"][:, None]
                * sf[:, None],
                alpha=self[:, self.non_zero_genes].varm["dispersions"],
            )
            # Take out the lowest likelihoods (highest neg) from the sum:
            # only entries strictly below the ``quant`` quantile are added up.
            return np.sum(nll[nll < np.quantile(nll, quant)])

        for i in range(niter):
            # Estimate dispersions based on current size factors
            self.fit_genewise_dispersions()

            # Use a mean trend curve: instead of a fitted trend, one trimmed-mean
            # dispersion is assigned to every gene. The mean is taken over genes
            # flagged in ``varm["non_zero"]`` whose gene-wise dispersion exceeds
            # ``10 * min_disp``.
            use_for_mean_genes = self.var_names[
                (self.varm["genewise_dispersions"] > 10 * self.min_disp)
                & self.varm["non_zero"]
            ]

            mean_disp = trimmed_mean(
                self[:, use_for_mean_genes].varm["genewise_dispersions"], trim=0.001
            )
            self.varm["fitted_dispersions"] = np.ones(self.n_vars) * mean_disp
            self.fit_dispersion_prior()
            self.fit_MAP_dispersions()
            # Keep the previous size factors: they seed the optimizer and are
            # compared against the new ones in the convergence test.
            old_sf = self.obsm["size_factors"].copy()

            # Fit size factors using MLE
            res = minimize(objective, np.log(old_sf), method="Powell")

            # The optimizer output is stored (re-centered, as in ``objective``)
            # before ``res.success`` is checked, so a failed iteration still
            # overwrites the size factors.
            self.obsm["size_factors"] = np.exp(res.x - np.mean(res.x))

            if not res.success:
                # NOTE(review): relies on a module-level ``import sys``, which is
                # not among the import cells at the top of the notebook -- confirm.
                print("A size factor fitting iteration failed.", file=sys.stderr)
                break

            # Convergence is only tested from the third iteration on (``i > 1``):
            # stop once the squared change in log size factors drops below 1e-4.
            if (i > 1) and np.sum(
                (np.log(old_sf) - np.log(self.obsm["size_factors"])) ** 2
            ) < 1e-4:
                break
            elif i == niter - 1:
                # Last allowed iteration reached without meeting the tolerance.
                print("Iterative size factor fitting did not converge.", file=sys.stderr)

        # Restore the design matrix and free buffer
        self.obsm["design_matrix"] = self.obsm["design_matrix_buffer"].copy()
        del self.obsm["design_matrix_buffer"]

        # Store normalized counts
        self.layers["normed_counts"] = self.X / self.obsm["size_factors"][:, None]
        
    def _check_full_rank_design(self):
        """Check that the design matrix has full column rank."""
        rank = np.linalg.matrix_rank(self.obsm["design_matrix"])
        num_vars = self.obsm["design_matrix"].shape[1]

        if rank < num_vars:
            warnings.warn(
                "The design matrix is not full rank, so the model cannot be "
                "fitted, but some operations like design-free VST remain possible. "
                "To perform differential expression analysis, please remove the design "
                "variables that are linear combinations of others.",
                UserWarning,
                stacklevel=2,
            )

In [57]:
# Load data
# The count table is transposed right after reading; the preview printed in the
# next cell shows samples (S1, ...) as rows and genes (G 1, ...) as columns.
rna_counts = pd.read_csv('data_simulation/sim4_omics/rna_cnv.csv', index_col=0).T
metadata = pd.read_csv('data_simulation/sim4_omics/metadata.csv', index_col=0)
#cnv_tumor = pd.read_csv('data_simulation/cnv_tumor.csv', index_col=0)

In [58]:
rna_counts

Unnamed: 0,G 1,G 2,G 3,G 4,G 5,G 6,G 7,G 8,G 9,G 10,...,G 14991,G 14992,G 14993,G 14994,G 14995,G 14996,G 14997,G 14998,G 14999,G 15000
S1,45.000000,179.000002,677.000007,16.000000,1428.000014,94.000001,115.000001,13.000000,12.000000,2804.000028,...,407.000004,1032.000005,1174.000012,221.000002,405.000004,349.000003,558.000006,7.0,229.500002,362.000004
S2,78.000000,438.000002,846.000004,30.000000,3748.000019,178.000001,40.000000,22.000000,66.000000,5246.000026,...,452.000005,504.000005,1132.500008,249.000002,490.000005,772.000008,1333.500009,7.0,111.000001,359.000004
S3,18.000000,146.000001,611.000006,6.000000,2511.000025,737.000007,77.000001,28.000000,29.000000,3103.000031,...,1143.000008,737.000007,1189.000012,358.000004,314.000003,753.000008,1248.000008,14.0,247.000002,651.000004
S4,4.000000,115.000001,488.000005,9.000000,2263.000023,165.000002,105.000001,24.000000,32.000000,2428.000024,...,902.000009,636.000006,705.000014,401.000004,613.000006,461.000005,809.000008,11.0,219.000002,486.000005
S5,116.000001,352.000002,1080.000005,54.000000,4292.000021,96.000000,458.000002,70.000000,50.000000,7088.000035,...,713.000007,548.000005,2007.000013,774.000005,748.500005,789.000008,912.000009,11.0,466.000002,390.000004
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
S68,160.000002,1280.000013,4200.000042,128.000001,18688.000187,752.000008,776.000008,80.000001,208.000002,20496.000205,...,474.000005,596.000006,1027.000010,223.000002,1246.500012,677.000007,1309.000013,2.0,212.000002,447.000004
S69,24.000000,256.500003,699.000007,7.500000,2944.500029,139.500001,81.000001,15.000000,69.000001,4017.000040,...,651.000007,758.000008,1186.500012,313.000003,1148.000011,321.000003,701.000007,9.0,291.000003,561.000006
S70,27.000000,324.000003,514.500005,10.500000,3411.000034,175.500002,153.000002,30.000000,39.000000,4837.500048,...,342.000003,814.000008,1066.000011,235.000002,766.000008,40.000000,825.000008,6.0,184.000002,655.500007
S71,70.000001,124.000001,556.000006,27.000000,2732.000027,10.000000,32.000000,24.000000,18.000000,2314.000023,...,638.000006,944.000009,1119.000011,440.000004,750.000008,556.000006,654.000007,14.0,335.000003,435.000004


#### Generate CN-corrected RNA counts matrix

In [138]:
import random

def create_matrix(rows, cols, min_val, max_val):
    """
    Build a nested list of random integers.

    Parameters
    ----------
    rows : int
        Number of inner lists (rows).

    cols : int
        Number of entries in each row.

    min_val : int
        Smallest value that can be drawn (inclusive).

    max_val : int
        Largest value that can be drawn (inclusive).

    Returns
    -------
    list of list of int
        ``rows`` lists of ``cols`` integers drawn with ``random.randint``.
    """
    matrix = []
    # Rows are filled one after the other, so draws happen in row-major order.
    for _ in range(rows):
        row = []
        for _ in range(cols):
            row.append(random.randint(min_val, max_val))
        matrix.append(row)
    return matrix

In [139]:
def rna_cnv_count_matrix(rna_counts, cnv_tumor):
    """
    Scale RNA counts by copy number relative to a diploid baseline.

    Every count is multiplied by ``copy_number / 2``. Tumor samples take their
    copy numbers from ``cnv_tumor``; all remaining samples are treated as normal
    and get a copy number of 2, i.e. their counts are left unchanged.

    Parameters
    ----------
    rna_counts : pandas.DataFrame
        Counts with genes as rows (index = gene IDs) and samples as columns.
        The tumor samples must come first, in the same column order as
        ``cnv_tumor``, followed by the normal samples.

    cnv_tumor : array-like
        Copy numbers of the tumor samples, genes x tumor samples, with the same
        number of rows as ``rna_counts``.

    Returns
    -------
    pandas.DataFrame
        Copy-number-scaled counts, transposed to samples x genes. The row labels
        are the columns of ``rna_counts``; the column labels are its index,
        named ``geneID``.

    Raises
    ------
    ValueError
        If ``cnv_tumor`` is not two-dimensional, has a different number of rows
        than ``rna_counts``, or has more columns than ``rna_counts``.
    """
    counts_mat = np.asarray(rna_counts)
    cnv_tumor_mat = np.asarray(cnv_tumor)

    if cnv_tumor_mat.ndim != 2:
        raise ValueError("cnv_tumor must be a two-dimensional genes x samples table.")

    n_genes, n_samples = counts_mat.shape
    n_tumor = cnv_tumor_mat.shape[1]
    if cnv_tumor_mat.shape[0] != n_genes:
        raise ValueError(
            f"cnv_tumor has {cnv_tumor_mat.shape[0]} rows but rna_counts has "
            f"{n_genes}."
        )
    if n_tumor > n_samples:
        raise ValueError(
            f"cnv_tumor has {n_tumor} columns but rna_counts only has "
            f"{n_samples}."
        )

    # Normal samples are diploid: a constant block of 2s, sized from the inputs
    # rather than hard-coded, for every sample not covered by cnv_tumor.
    cnv_normal_mat = np.full((n_genes, n_samples - n_tumor), 2)

    # Tumor columns first, then normal ones; dividing by 2 turns copy numbers
    # into multiplicative factors relative to the diploid state.
    cnv = np.concatenate((cnv_tumor_mat, cnv_normal_mat), axis=1) / 2

    # Re-attach the gene and sample labels directly while building the frame.
    rna_counts_cnv = pd.DataFrame(
        counts_mat * cnv,
        index=pd.Index(rna_counts.index, name="geneID"),
        columns=rna_counts.columns,
    )

    return rna_counts_cnv.T

In [140]:
# Apply the copy-number scaling defined above to the loaded counts.
# NOTE(review): the only visible definition of `cnv_tumor` is the commented-out
# read in the "Load data" cell -- confirm it is defined before running this cell.
rna_counts_cnv = rna_cnv_count_matrix(rna_counts, cnv_tumor)


#### Model fit

In [85]:
def test_pydeseqCN(rna_counts, metadata):
    """
    Fit the model on a count table and test condition B against condition A.

    Parameters
    ----------
    rna_counts
        Count table handed to ``DeseqDataSet`` as ``counts``.

    metadata
        Sample annotation handed to ``DeseqDataSet`` as ``metadata``; its
        ``condition`` factor is used as the design, with levels ``A`` and ``B``
        entering the contrast.

    Returns
    -------
    The ``results_df`` table of the ``DeseqStats`` object, read after LFC
    shrinkage has been applied.
    """
    # Build the dataset and run the fitting pipeline.
    dataset = DeseqDataSet(
        counts=rna_counts,
        metadata=metadata,
        design_factors="condition",
        refit_cooks=False,
        n_cpus=8,
    )
    dataset.deseq2()

    # Statistical test: options for the B-vs-A comparison.
    test_options = {
        "contrast": ["condition", "B", "A"],
        "alpha": 0.05,
        "cooks_filter": False,
        "independent_filter": True,
        "prior_LFC_var": None,
        "lfc_null": 0,
        "alt_hypothesis": None,
        "inference": None,
        "quiet": False,
    }
    stats = DeseqStats(dataset, **test_options)
    stats.summary()

    # LFC shrinkage (apeGLM)
    stats.lfc_shrink(coeff="condition_B_vs_A")

    return stats.results_df

In [86]:
# Run the pipeline on the loaded counts and metadata; `res` is the results
# table returned by test_pydeseqCN.
res = test_pydeseqCN(rna_counts, metadata)

Fitting size factors...
... done in 0.04 seconds.

Fitting dispersions...
... done in 1.10 seconds.

Fitting dispersion trend curve...
... done in 0.74 seconds.

Fitting MAP dispersions...
... done in 1.31 seconds.

Fitting LFCs...
... done in 0.56 seconds.

Running Wald tests...
... done in 0.42 seconds.

Fitting MAP LFCs...


Log2 fold change & Wald test p-value: condition B vs A
            baseMean  log2FoldChange     lfcSE      stat    pvalue      padj
G 1        47.724375       -0.146760  0.293935 -0.499296  0.617571  0.999868
G 2       363.551597       -0.067415  0.211498 -0.318749  0.749917  0.999868
G 3       879.641536       -0.087083  0.171480 -0.507833  0.611570  0.999868
G 4        24.362647        0.233736  0.268687  0.869921  0.384343  0.999868
G 5      3707.907244       -0.062938  0.192512 -0.326931  0.743720  0.999868
...              ...             ...       ...       ...       ...       ...
G 14996   594.658205       -0.197746  0.184121 -1.074002  0.282822  0.999868
G 14997  1132.731293        0.076383  0.139005  0.549498  0.582664  0.999868
G 14998    14.009224        0.401684  0.229216  1.752428  0.079700  0.999868
G 14999   265.627439        0.098208  0.148001  0.663565  0.506969  0.999868
G 15000   438.623532       -0.003224  0.089394 -0.036066  0.971230  0.999868

[15000 rows x 6 columns]

... done in 2.19 seconds.



#### Save results

In [84]:
# Replace this with the path to directory where you would like results to be saved
OUTPUT_PATH = "data_simulation/results/sim_4/"
# pathlib's Path (imported in the first cell) is used for the directory and file
# handling because `os` is not among the notebook's visible imports.
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)  # Create path if it doesn't exist
res.to_csv(Path(OUTPUT_PATH) / "res_sim_cnv.csv")

In [72]:
res

Unnamed: 0,baseMean,log2FoldChange,lfcSE,stat,pvalue,padj
G 1,47.724375,-1.759724e-06,0.001408,-0.499296,0.617571,0.999868
G 2,363.551597,-1.657652e-06,0.001426,-0.318749,0.749917,0.999868
G 3,879.641536,-3.458112e-06,0.001421,-0.507833,0.611570,0.999868
G 4,24.362647,3.328379e-06,0.001500,0.869921,0.384343,0.999868
G 5,3707.907244,-1.724637e-06,0.001427,-0.326931,0.743720,0.999868
...,...,...,...,...,...,...
G 14996,594.658205,-6.047613e-06,0.001396,-1.074002,0.282822,0.999868
G 14997,1132.731293,4.188237e-06,0.001462,0.549498,0.582664,0.999868
G 14998,14.009224,7.927858e-06,0.001532,1.752428,0.079700,0.999868
G 14999,265.627439,3.617875e-06,0.001467,0.663565,0.506969,0.999868
