In [None]:
import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import scipy as sp
import scipy.io

In [None]:
# Load the read-count table and expose the input's `symbol` column as `gene_name`.
adata = sc.read_h5ad(snakemake.input.read_count_table)
adata.var = adata.var.rename(columns={"symbol": "gene_name"})

# Drop genes without a name. Slicing an AnnData yields a view; `.copy()` turns it
# into a standalone object so `.var` can be re-indexed without mutating a view.
has_gene_name = adata.var["gene_name"].notnull()
adata = adata[:, has_gene_name].copy()

# Use the gene name as the variable index (set_index consumes the column).
adata.var = adata.var.set_index("gene_name")

In [None]:
# Rescale every sample to counts per million (CPM) with scanpy.
# The log1p transform is applied separately in the following cell.
CPM_TARGET_SUM = 1e6
sc.pp.normalize_total(adata, target_sum=CPM_TARGET_SUM, inplace=True)

In [None]:
# Apply scanpy's log1p transform to `adata` itself (no copy is returned).
sc.pp.log1p(adata, copy=False)

In [None]:
# Optionally keep only the genes whose `biotype` annotation is "protein_coding".
if snakemake.params.filter_protein_coding:
    is_protein_coding = adata.var["biotype"] == "protein_coding"
    adata = adata[:, is_protein_coding]

In [None]:
# Filter for highly variable genes (currently disabled; uncomment the lines below to enable)
# sc.pp.highly_variable_genes(adata, n_top_genes=snakemake.params.num_hvg)
# adata_filtered = adata[:, adata.var.highly_variable]

In [None]:
# Load the per-observation sampling weights.
# NOTE: allow_pickle=True lets NumPy unpickle objects from the file, which can
# execute arbitrary code -- only point this at trusted weight files.
weights_dict = np.load(snakemake.input.weights, allow_pickle=True)

# The weights are paired with `adata.obs.index` by position when sampling, so the
# IDs stored next to them must match that index exactly (same entries, same order).
# np.array_equal returns False on a length mismatch instead of failing inside the
# element-wise comparison, so the check always ends in a readable assertion.
assert np.array_equal(
    np.asarray(weights_dict["orig_ids"]), np.asarray(adata.obs.index)
), "orig_ids in the weights file do not match adata.obs.index"
weights = weights_dict["weight"]

In [None]:
# Draw a seeded, weighted subsample of observations without replacement.
np.random.seed(snakemake.params.seed)

# Without replacement, only observations with a non-zero weight can be drawn.
# Cap the request by that count as well as by the number of observations;
# otherwise np.random.choice raises "Fewer non-zero entries in p than size".
# Whenever the uncapped size was already feasible the draw is unchanged.
num_to_draw = min(
    snakemake.params.num_samples,
    len(adata),
    int(np.count_nonzero(weights)),
)
sampling_probs = weights / weights.sum()
sampled_obs = np.random.choice(
    adata.obs.index, size=num_to_draw, p=sampling_probs, replace=False
)

# Keep only the sampled observations, in the order they were drawn.
adata = adata[sampled_obs]

In [None]:
# Dense CSV export, disabled because CSV is an inefficient format for this matrix:
# adata.to_df().to_csv(snakemake.output["sparse_matrix"])

In [None]:
# Write the expression matrix in Matrix Market format, transposed so that rows
# follow `adata.var` and columns follow `adata.obs`. The submodule is referenced
# through its explicit import rather than relying on `sp.io` being loaded already.
genes_by_samples = adata.X.transpose()
scipy.io.mmwrite(snakemake.output["sparse_matrix"], genes_by_samples)

# Column and row labels go to separate single-column CSV files, in matrix order.
pd.Series(adata.obs.index).to_csv(snakemake.output["colnames"], index=False)
pd.Series(adata.var.index).to_csv(snakemake.output["rownames"], index=False)