# Task 1 - In silico perturbation workflow

## Imports

In [1]:
from collections.abc import Sequence
from typing import Literal

import numpy as np
import scanpy as sc
from anndata import AnnData
from scipy.sparse import issparse

## Functions

In [2]:
def perturb_genes(
    adata: AnnData,
    target: tuple[str, Sequence[str]],
    factor: tuple[float, Sequence[float]],
    copy: bool = False,
) -> AnnData | None:
    """Perturb selected genes.

    Change the expression of provided genes by a factor and round back to integers.
    Expects raw count data.

    Parameters
    ----------
    adata
        The annotated data matrix.
    target
        Which gene(s) to perturb.
    factor
        Factor for simulated perturbation per provided gene. Must be a positive number.
    copy
        Whether to copy `adata` or modify it inplace.

    Returns
    -------
    Returns `None` if `copy=False`, else returns an `AnnData` object.

    """
    # Sanity checks on inputs
    if isinstance(target, str):
        target = [target]
    if isinstance(factor, float):
        factor = [factor]
        
    if len(target) != len(factor):
        msg = (
            f"the number of provided targets must match the number of provided factors, but {len(target)} and {len(factor)} values were passed, respectively"
        )
        raise ValueError(msg)

    missing_genes = [t for t in target if t not in adata.var_names]
    if missing_genes:
        msg = (
            f"the following provided targets were not found in the index of the provided anndata object: {missing_genes!r}. Example names are {adata.var_names[:3].tolist()}"
        )
        raise ValueError(msg)

    if [f for f in factor if f < 0.]:
        msg = (
            f"some provided factors were smaller than 0. Please only provide positive numbers."
        )
        raise ValueError(msg)

    if len(target) != len(set(target)):
        msg = (
            f"target contains duplicate values."
        )
        raise ValueError(msg)

    # Create copy of adata if requested
    adata = adata.copy() if copy else adata

    # Perturb genes
    perturbed_counts = adata[:, target].X.multiply(factor)
    perturbed_counts = np.round(perturbed_counts)
    adata[:, target].X = perturbed_counts.tocsr()

    return adata if copy else None

## Testing

In [4]:
# Load data and check expression of two example genes
adata = sc.read_h5ad("../data/processed_data.h5ad")
# `.toarray()` is the supported way to densify a SciPy sparse matrix; the
# `.A` shorthand is deprecated and missing from recent SciPy releases.
adata[:, ["ENSG00000171532", "ENSG00000078018"]].X.toarray()

array([[ 1., 13.],
       [ 0., 26.],
       [ 0., 68.],
       ...,
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  1.]], dtype=float32)

In [5]:
# Perturb both example genes directly on `adata` (no copy is requested)
perturb_genes(
    adata,
    target=["ENSG00000171532", "ENSG00000078018"],
    factor=[2.0, 0.5],
)

In [6]:
# Check counts of perturbed genes
# (`.toarray()` instead of the deprecated `.A` shorthand on sparse matrices)
adata[:, ["ENSG00000171532", "ENSG00000078018"]].X.toarray()

array([[ 2.,  6.],
       [ 0., 13.],
       [ 0., 34.],
       ...,
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.]], dtype=float32)