In [None]:
import anndata
import itertools
import matplotlib.pyplot as plt
import numpy as np

from rp2 import create_folder, load_biomart_gene_symbols_df
from rp2.paths import get_data_path, get_output_path

## Extract UMI counts from Hagai *et al.* (2018) dataset

Load full UMI counts in preparation to extract subsets.

In [None]:
def load_umi_count(species):
    """Load the processed Hagai et al. (2018) UMI count matrix for one species.

    The matrix is read from ``ArrayExpress/E-MTAB-6754.processed.2.<species>.h5ad``.
    A ``gene_symbol`` column is added to ``adata.var``: each variable name is
    mapped through the ``symbol`` column of the BioMart gene-symbol table for
    ``species`` and lower-cased. Names with no BioMart entry presumably end up
    as NaN (verify).

    Parameters
    ----------
    species : str
        Species name, used both in the h5ad filename and for the BioMart lookup.

    Returns
    -------
    anndata.AnnData
        The loaded dataset with ``var["gene_symbol"]`` populated.
    """
    h5ad_name = f"E-MTAB-6754.processed.2.{species}.h5ad"
    adata = anndata.read_h5ad(get_data_path("ArrayExpress", h5ad_name))

    symbol_by_var_name = load_biomart_gene_symbols_df(species).symbol
    adata.var["gene_symbol"] = adata.var_names.map(symbol_by_var_name).str.lower()
    return adata


# Species whose processed UMI count matrices are loaded up front.
available_species = ["mouse"]

# species -> AnnData with `var["gene_symbol"]` attached by load_umi_count.
umi_ad_map = dict(zip(available_species, map(load_umi_count, available_species)))

To extract UMI counts for specific genes and conditions:
1. Ensure all previous cells have been executed
2. Edit the values of extract_conditions (and show_histograms) accordingly in the cell below
3. Run the cell below

Output files will be created in *Output/Misc/UmiCounts*.

Note that it is possible to combine conditions (but not genes or species) by creating a list of lists. For example:
```python
replicate = [["1", "2"], "3"]
```
will create results with replicates 1 and 2 combined, and replicate 3 as a separate condition.

In [None]:
output_path = get_output_path("Misc", "UmiCounts")
create_folder(output_path)

# Each entry lists the choices to iterate over. A choice may itself be a list of
# values that are pooled into a single condition (e.g. treatment below).
# `species` and `gene` choices must be plain strings.
extract_conditions = dict(
    species=["mouse"],
    gene=["tnf"],
    replicate=["1", "2", "3"],
    treatment=[["unst", "lps"]],
    time_point=["0", "2", "4", "6"],
)

show_histograms = True

# One output file per combination of the listed choices.
for values in itertools.product(*extract_conditions.values()):
    extract_dict = dict(zip(extract_conditions.keys(), values))

    species = extract_dict.pop("species")
    gene_symbol = extract_dict.pop("gene").lower()

    umi_ad = umi_ad_map[species]

    # Keep only the requested gene's column (symbols are stored lower-cased).
    umi_ad = umi_ad[:, umi_ad.var.gene_symbol == gene_symbol]
    if umi_ad.n_vars != 1:
        if umi_ad.n_vars == 0:
            print(f"  No genes with symbol '{gene_symbol}'")
        else:
            print(f"  Multiple genes with symbol '{gene_symbol}':")
            # Iterating the AnnData object itself does not yield the matching
            # variables; list their names explicitly.
            for name in umi_ad.var_names:
                print(f"  {name}")
        continue

    file_prefix = f"species={species}-gene={gene_symbol}"

    # Remaining keys are obs columns; a scalar choice is treated as a
    # one-element list so pooled and single conditions share one code path.
    for k, v in extract_dict.items():
        if not isinstance(v, list):
            v = [v]
        umi_ad = umi_ad[umi_ad.obs[k].isin(v), :]

        file_prefix += f"-{k}={'+'.join(v)}"

    filename = file_prefix + ".txt"
    print(filename)
    print(f"  {len(umi_ad):,} samples")

    # X may be a SciPy sparse matrix or a dense ndarray: `.A` is deprecated for
    # sparse types and missing on ndarrays, so densify explicitly. `np.int` no
    # longer exists in NumPy >= 1.24. ravel() (not squeeze()) keeps a 1-D array
    # even when exactly one sample is selected, which np.savetxt requires.
    x = umi_ad.X
    dense = x.toarray() if hasattr(x, "toarray") else np.asarray(x)
    counts = dense.ravel().astype(np.int64)
    np.savetxt(output_path.joinpath(filename), counts, fmt="%d")

    # Nothing to plot for an empty selection.
    if show_histograms and counts.size:
        fig, ax = plt.subplots()
        ax.hist(counts, log=True)
        ax.set(title=filename, xlabel="UMI count", ylabel="Frequency")
        plt.show()