In [None]:
# Move from the notebook's directory to its parent so the relative `data/...`
# paths used below resolve. Plain `cd ..` only works via IPython automagic and
# climbs one more level on every re-run; this guard moves up exactly once per
# kernel session, keeping the cell idempotent.
import os

if "LAUNCH_DIR" not in globals():
    LAUNCH_DIR = os.getcwd()
    os.chdir(os.pardir)

## Fit polynomial curves to gene expression

In [None]:
from tqdm.auto import tqdm
import anndata
import numpy as np
import gc
import seaborn as sns
import matplotlib.pyplot as plt

### HLCA & PBMC data

Load the pseudotime-ordered AnnData objects (one file per time point).

In [None]:
# Cell type to analyse when running on HLCA; one of
# 'Endothelial', 'AT2', 'Fibroblasts'.
CELL_TYPE = 'Fibroblasts'

# Per-dataset settings: how many time-point files exist and the filename
# prefix those files share (the time-point index and ".h5ad" are appended).
meta = dict(
    hlca=dict(n_time_points=6, filenames=f'adata_{CELL_TYPE}_'),
    pbmc=dict(n_time_points=4, filenames='adata_PBMC_'),
)

Pick the dataset to process, one of:

- pbmc
- hlca

In [None]:
# Dataset to process: 'pbmc' or 'hlca'.
dataset = 'hlca'

# Tag used in every output filename; HLCA runs are per cell type, so the
# cell type is appended to keep their outputs apart.
if dataset == 'hlca':
    name = f"{dataset}_{CELL_TYPE}"
else:
    name = dataset

#### Read AnnDatas

In [None]:
n_time_points = meta[dataset]['n_time_points']
file_prefix = meta[dataset]['filenames']

# Read one AnnData file per time point, in order 0 .. n_time_points - 1.
# `anndata.read` is a deprecated alias; `anndata.read_h5ad` is the supported reader.
adata_list = [
    anndata.read_h5ad(f"data/pseudotime_adatas/{file_prefix}{i}.h5ad")
    for i in tqdm(range(n_time_points))
]

In [None]:
# Number of cells in each time-point AnnData, and their total.
n_cell_list = list(map(lambda chunk: chunk.shape[0], adata_list))
n_cells = sum(n_cell_list)
# Gene count is taken from the first AnnData; the concatenation below writes
# every AnnData into the same n_genes-wide matrix, so they must all agree.
n_genes = adata_list[0].shape[1]

Sort the cells of each AnnData by pseudotime (`obs['t']`) and concatenate them into a single dense matrix `x` (cells × genes).

In [None]:
# Dense (n_cells, n_genes) matrix: each time point's cells, sorted by
# pseudotime, stacked in time-point order.
x = np.zeros((n_cells, n_genes))

start_idx = 0
for adata in tqdm(adata_list):
    # Sort on the raw values: np.argsort places NaN last, whereas pandas'
    # Series.argsort has (in many versions) returned -1 for NaN positions,
    # which would silently index this chunk's last row instead.
    idx_order = np.argsort(adata.obs['t'].to_numpy())
    # `.toarray()` exists only on sparse matrices; dense X is used as-is.
    X_chunk = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    x[start_idx: start_idx + adata.shape[0]] = X_chunk[idx_order]
    start_idx += adata.shape[0]

assert start_idx == n_cells, "rows written do not match the total cell count"
print(f"{x.shape=}")

In [None]:
# Free the per-time-point AnnData objects now that `x` holds their data.
# The loop variable `adata` from the concatenation cell still references the
# last AnnData, so drop that reference too or it stays in memory.
del adata_list
adata = None
gc.collect()

### Fit polynomials

In [None]:
DEG: int = 4  # degree of the polynomial fitted to each gene

In [None]:
memory_friendly = True

# Independent variable: each cell's position in the concatenated,
# pseudotime-sorted ordering.
ticks = np.arange(n_cells)

if memory_friendly:
    # Slower, but fits one gene at a time; seems to use less memory.
    coefs = np.zeros((n_genes, DEG + 1))
    for gene_idx in tqdm(range(n_genes)):  # iterate over genes
        coefs[gene_idx] = np.polyfit(ticks, x[:, gene_idx], deg=DEG)
else:
    # With a 2-D `y`, np.polyfit returns shape (DEG + 1, n_genes). Transpose so
    # both branches yield (n_genes, DEG + 1), highest power first, which is
    # the layout later cells index (`coefs[idx]`, iterating over `coefs`).
    coefs = np.polyfit(ticks, x, deg=DEG).T

In [None]:
# np.save writes .npy-format data. Passing an open file object keeps the
# literal ".npz" filename (given a path string, np.save would append ".npy");
# the `with` block guarantees the handle is flushed and closed.
with open(f"data/coefs/coefs_{name}_deg{DEG}.npz", "wb") as f:
    np.save(f, coefs)

In [None]:
# Plot the fitted polynomial of one example gene over all cell positions.
idx = 1235
assert 0 <= idx < n_genes, f"example gene index {idx} out of range for {n_genes} genes"
fig, ax = plt.subplots()
y_smooth = np.poly1d(coefs[idx])(np.arange(n_cells))
sns.lineplot(data=y_smooth, ax=ax)
# Optional: overlay the raw values on the same axis.
# sns.lineplot(data=x[:, idx], ax=ax, alpha=0.1)
ax.set(
    title=f"Gene {idx}: degree-{DEG} polynomial fit ({name})",
    xlabel="cell (pseudotime order)",
    ylabel="fitted expression",
);

### Pick nodes: evaluate the fitted polynomials at every `STEPS`-th cell

In [None]:
# Reload the saved coefficients; `with` closes the handle once np.load has
# read the array into memory.
with open(f"data/coefs/coefs_{name}_deg{DEG}.npz", "rb") as f:
    coefs = np.load(f)

In [None]:
# Subsampling stride: polynomials are evaluated only at every STEPS-th cell
# position, to keep the saved arrays small.
STEPS: int = 10

In [None]:
# Evaluate every gene's fitted polynomial at cell positions 0, STEPS, 2*STEPS, ...
ticks = np.arange(0, n_cells, STEPS)
ys = np.zeros((n_genes, len(ticks)))

for row, coef in enumerate(tqdm(coefs)):
    # Horner evaluation; np.poly1d(coef)(ticks) delegates to this same call.
    ys[row, :] = np.polyval(coef, ticks)

In [None]:
# Split the subsampled columns of `ys` back into one array per time point.
# Column j of `ys` is global cell position j * STEPS, so the time point that
# covers cells [cell_start, cell_stop) owns columns
# ceil(cell_start / STEPS) .. ceil(cell_stop / STEPS) - 1. Computing both ends
# from the global cell offsets keeps the pieces contiguous and aligned even when
# a time point's cell count is not a multiple of STEPS.
ys_split = []
cell_start = 0
for n_cell in n_cell_list:
    cell_stop = cell_start + n_cell
    col_start = -(-cell_start // STEPS)  # ceil(cell_start / STEPS)
    col_stop = -(-cell_stop // STEPS)    # ceil(cell_stop / STEPS)
    ys_split.append(ys[:, col_start:col_stop])
    cell_start = cell_stop

# Every evaluated column lands in exactly one piece.
assert sum(y.shape[1] for y in ys_split) == ys.shape[1]

In [None]:
list(map(np.shape, ys_split))  # (n_genes, n_nodes) per time point

In [None]:
# One file per time point. An open handle keeps the literal ".npz" name (np.save
# would append ".npy" to a path string); `with` closes each file promptly.
for i, y in enumerate(tqdm(ys_split)):
    with open(f"data/polys/ys_{name}_deg{DEG}_{i}.npz", "wb") as f:
        np.save(f, y)

In [None]:
# Plot the same example gene, reassembled from the per-time-point subsampled arrays.
fig, ax = plt.subplots()
y_smooth = np.concatenate([y[idx] for y in ys_split])
sns.lineplot(data=y_smooth, ax=ax)
# The x-axis counts subsampled nodes (one per STEPS cells), not cells, so a raw
# overlay must be subsampled to roughly line up, e.g. x[::STEPS, idx] rather
# than x[:, idx]:
# sns.lineplot(data=x[::STEPS, idx], ax=ax, alpha=0.1)
ax.set(
    title=f"Gene {idx}: subsampled polynomial fit ({name})",
    xlabel=f"node (every {STEPS}th cell in pseudotime order)",
    ylabel="fitted expression",
);