In [None]:
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import rp2
from rp2 import hagai_2018, create_folder
from rp2.paths import get_output_path

# NOTE(review): presumably verifies the local rp2 setup before any analysis runs —
# confirm against the rp2 package, which is not visible from this notebook.
rp2.check_environment()

In [None]:
def load_gene_list(csv_path, symbol_column="symbol"):
    """Read a CSV file and return one of its columns as an array of gene symbols.

    Parameters
    ----------
    csv_path : str or path-like
        Location of the CSV file to read.
    symbol_column : str, optional
        Name of the column holding the gene symbols (default ``"symbol"``).

    Returns
    -------
    numpy.ndarray
        Values of the requested column, in file order.
    """
    gene_table = pd.read_csv(csv_path)
    symbols = gene_table.loc[:, symbol_column]
    return symbols.values

## Extract expression counts from Hagai *et al.* (2018) dataset

To extract RNA counts for specific genes and conditions:
1. Ensure all previous cells have been executed
2. Edit the values of `extract_conditions` (and `show_histograms`) in the cell below as required
3. Run the cell below

Output files will be created in *Output/Misc/Counts*.

Note that it is possible to combine conditions (i.e. replicate, treatment and time_point) by creating a list of lists. For example
```python
replicate = [["1", "2"], "3"]
```
will create results with replicates 1 and 2 combined and replicate 3 as a separate condition.

For bulk exporting, a list of genes can be loaded from a .csv file as follows:

```python
genes = load_gene_list("Output/burst_trends_gene_list.csv")
```

In [None]:
# Conditions to extract. One output file is written for every combination of the
# values listed below (the Cartesian product across all keys).
# - species, count_type and gene select which dataset is loaded and which gene
#   symbol is exported (the symbol is matched case-insensitively).
# - Every remaining key is matched against the observation column of the same name.
# - A nested list (see treatment) pools its values into a single condition instead
#   of producing one file per value.
extract_conditions = dict(
    species=["mouse"],
    count_type=["median"],  # Choose from umi, cpt, cpm or median
    gene=["tnf"],
    replicate=["1", "2", "3"],
    treatment=[["unst", "lps", "pic"]],
    time_point=["0", "2", "4", "6"],
)

# When True, a histogram (log-scaled axis) is shown for each extracted count vector.
show_histograms = True


output_path = get_output_path("Misc", "Counts")
create_folder(output_path)

# Cache of loaded datasets keyed by "<species>-<count_type>", so each dataset is
# loaded once no matter how many gene/condition combinations are requested.
count_adata_map = {}

# One iteration (and one output file) per combination of the configured values.
for values in itertools.product(*extract_conditions.values()):
    extract_dict = dict(zip(extract_conditions.keys(), values))

    # These three keys pick the dataset and gene; whatever remains in extract_dict
    # is treated as a filter on the observation columns.
    species = extract_dict.pop("species")
    count_type = extract_dict.pop("count_type")
    gene_symbol = extract_dict.pop("gene").lower()

    count_key = f"{species}-{count_type}"
    if count_key not in count_adata_map:
        print(f'Loading {species} "{count_type}" counts')
        count_adata_map[count_key] = hagai_2018.load_counts(species, scaling=count_type)
    count_adata = count_adata_map[count_key]

    # Keep only the requested gene (case-insensitive match on the symbol column).
    count_adata = count_adata[:, count_adata.var.symbol.str.lower() == gene_symbol]
    if count_adata.n_vars != 1:
        if count_adata.n_vars == 0:
            print(f"  No genes with symbol '{gene_symbol}'")
        else:
            print(f"  Multiple genes with symbol '{gene_symbol}':")
            # List the identifiers of the ambiguous genes. Iterating the AnnData
            # object itself would not yield the gene names.
            for name in count_adata.var_names:
                print(f"  {name}")
        continue

    filename = f"{species}_{count_type}_{gene_symbol}"

    for k, v in extract_dict.items():
        # A bare value is one condition; a list pools several values together.
        if not isinstance(v, list):
            v = [v]
        count_adata = count_adata[count_adata.obs[k].isin(v), :]

        filename += f"_{k}_" + "_".join(v)

    filename += ".txt"

    print(f"File: {filename}")
    print(f"  {count_adata.n_obs:,} samples")

    # Densify without relying on the ``.A`` shorthand, which is not available on
    # every matrix container; ``toarray()`` is the documented sparse conversion.
    expression = count_adata.X
    if hasattr(expression, "toarray"):
        expression = expression.toarray()
    # ravel() (not squeeze()) keeps a 1-D vector even when a single sample remains,
    # which np.savetxt requires. The built-in int replaces the np.int alias that
    # was removed in NumPy 1.24.
    counts = np.asarray(expression).ravel().round().astype(int)
    np.savetxt(output_path.joinpath(filename), counts, fmt="%d")

    if show_histograms:
        plt.hist(counts, log=True)
        # Title each figure so the histograms of different conditions can be told apart.
        plt.title(filename)
        plt.xlabel("Count")
        plt.ylabel("Frequency")
        plt.show()

## UMI count to transcript number

From Grün *et al.* (2014):

**Conversion of UMI count to transcript number.**

For each gene $i$, $k_{o,i}$ denotes the number of observed UMIs and $k_{n,i}$ the number of non-observed UMIs. The total number $K$ of UMIs is given by

$K=k_{o,i}+k_{n,i}$

and the number of sequenced transcripts $m_i$ is

$m_i=\frac{\ln{\left(1-\frac{k_{o,i}}{K}\right)}}{\ln{\left(1-\frac{1}{K}\right)}}\cong-K\ln{\left(1-\frac{k_{o,i}}{K}\right)}$

Define a function implementing this formula:

In [None]:
def umi_to_transcript_count(ko, K, approx=False):
    """Convert observed UMI counts to estimated transcript numbers.

    Implements m = ln(1 - ko/K) / ln(1 - 1/K), or the approximation
    m ~= -K * ln(1 - ko/K) when ``approx`` is True.

    Parameters
    ----------
    ko : scalar or array-like
        Number of observed UMIs. Converted to a float array before use.
    K : scalar
        Total number of available UMIs.
    approx : bool, optional
        If True, return the approximate form instead of the exact ratio.

    Returns
    -------
    numpy.ndarray or numpy.float64
        Estimated transcript number(s), as floats, with the shape of ``ko``.
        Where ``ko == K`` the logarithm diverges and the result is ``inf``.
    """
    # The built-in float replaces the np.float alias, which was removed in NumPy 1.24.
    ko = np.asarray(ko, dtype=float)
    num = np.log(1 - (ko / K))
    if approx:
        return -K * num
    den = np.log(1 - (1 / K))
    return num / den

Plot the mapping from UMI count to transcript number:

In [None]:
# Number of distinct UMI sequences available: 4 possible nucleotides per base.
tenx_umi_bases = 10
tenx_umi_max = 4**tenx_umi_bases
print(f"{tenx_umi_bases} bases in 10X experiments permits up to {tenx_umi_max:,} UMIs")

# Evaluate the mapping for every UMI count from 0 to tenx_umi_max - 1; the maximum
# itself is excluded by arange, which also avoids taking log(0).
umi_counts = np.arange(tenx_umi_max)
# The built-in int replaces the np.int alias, which was removed in NumPy 1.24.
# The cast truncates the fractional part of each estimate.
transcript_numbers = umi_to_transcript_count(umi_counts, tenx_umi_max).astype(int)

plt.plot(umi_counts, transcript_numbers)
plt.xlabel("UMI count")
plt.ylabel("Transcript number")
plt.show()

In [None]:
# Largest UMI count whose integer transcript number is equal to the UMI count itself.
matching_counts = np.flatnonzero(umi_counts == transcript_numbers)
identical_until = matching_counts.max()
print(f"The UMI-transcript transform is an identity mapping for UMI counts up to {identical_until:,}")

# Zoom in on the region around that threshold, marking it on both axes.
zoom_limit = identical_until * 4

plt.plot(umi_counts, transcript_numbers)
plt.axvline(x=identical_until, ls="--")
plt.axhline(y=identical_until, ls=":")
plt.xlim(0, zoom_limit)
plt.ylim(0, zoom_limit)
plt.xlabel("UMI count")
plt.ylabel("Transcript number")
plt.show()

In [None]:
# The mouse UMI counts are only in the cache if the extraction cell was configured
# with count_type "umi"; load them on demand so this cell also works with any other
# extraction settings (e.g. after Restart & Run All with the default "median").
umi_key = "mouse-umi"
if umi_key not in count_adata_map:
    print('Loading mouse "umi" counts')
    count_adata_map[umi_key] = hagai_2018.load_counts("mouse", scaling="umi")
umi_count_ad = count_adata_map[umi_key]

# Maximum UMI count per gene (column-wise maximum over all samples). Densify without
# relying on the ``.A`` shorthand, which is not available on every matrix container.
max_per_gene = umi_count_ad.X.max(axis=0)
if hasattr(max_per_gene, "toarray"):
    max_per_gene = max_per_gene.toarray()
# The built-in int replaces the np.int alias, which was removed in NumPy 1.24.
umi_count_ad.var["max"] = np.asarray(max_per_gene).ravel().astype(int)

# Genes whose maximum UMI count reaches the threshold, highest first.
high_count_var = umi_count_ad.var.loc[umi_count_ad.var["max"] >= identical_until].sort_values(by="max", ascending=False)
print(f"{len(high_count_var):,} genes out of {umi_count_ad.n_vars:,} have UMI count >= {identical_until:,}")

for i, row in enumerate(high_count_var.itertuples(), start=1):
    # transcript_numbers is indexed by UMI count, so the lookup gives the mapped value.
    print(f"  {i}. {row.symbol}: max UMI of {row.max:,} maps to {transcript_numbers[row.max]:,} transcripts")