In [None]:
import anndata
import numpy as np
import logging
import pandas as pd

In [None]:
# data = anndata.read_h5ad("/msc/home/q56ppene/cellwhisperer/cellwhisperer/resources/cellxgene_census/00476f9f-ebc1-4b72-b541-32f912ce36ea.h5ad")
# data

In [None]:
# Load a single processed dataset; it is only used by the sanity-check call to
# `layerize_dataset(proc)` further down, before the full loop over all datasets.
# NOTE(review): hard-coded absolute path into a user home directory — this cell only
# runs where that path exists; consider reading one of `snakemake.input.datasets` instead.
proc = anndata.read_h5ad("/msc/home/q56ppene/cellwhisperer/cellwhisperer/resources/cellxgene_census/cellxgene_census_00476f9f-ebc1-4b72-b541-32f912ce36ea_processed.h5ad")
# proc  (uncomment to display the loaded AnnData)

In [None]:
# Lookup table as a Series: index = "ensembl_gene_id", values = "external_gene_name".
ensembl_symbol_conversion = (
    pd.read_csv(snakemake.input.ensembl_symbol_conversion)
    .set_index("ensembl_gene_id")["external_gene_name"]
)

In [None]:
def layerize_dataset(proc):
    """
    Build an AnnData with one row per pseudobulk and one layer per replicate.

    The rows of `proc` with ``obs["is_pseudobulk"] == "True"`` become the rows (and `X`)
    of the result. For every ``i`` in ``range(snakemake.params.num_replicates)`` a layer
    ``replicate_{i+1}`` is added that holds, for each pseudobulk row, one of the
    non-pseudobulk rows that share its cell_id. If a cell_id has fewer than
    `num_replicates` such rows, the existing ones are reused cyclically (via modulo).

    Reads two notebook-level globals: `ensembl_symbol_conversion` (genes are restricted
    to the IDs present in its index) and `snakemake.params.num_replicates`.

    Parameters
    ----------
    proc : anndata.AnnData
        Processed dataset. The code reads ``obs["is_pseudobulk"]`` (compared against the
        strings "True" / "False"), drops ``obs["replicate"]``, and reads
        ``uns["abstract"]`` and ``uns["dataset_title"]``.

    Returns
    -------
    anndata.AnnData
        New object indexed by cell_id, with layers ``replicate_1`` .. ``replicate_N``
        (int32), ``obs["abstract"]`` / ``obs["dataset_title"]`` as single-category
        categoricals, an empty ``uns`` and all int64 obs columns downcast to int32.

    Raises
    ------
    RuntimeError
        If a pseudobulk cell_id has no matching non-pseudobulk row.
    """

    # keep only genes that have an entry in the Ensembl ID -> symbol table
    symboled_ensids = proc.var.index.intersection(ensembl_symbol_conversion.index)

    # workaround for missing cell_id in pseudocells. works because it is ordered:
    # the k-th pseudobulk row gets the suffix "_k" in place of "_pseudobulk"
    adata = proc[proc.obs["is_pseudobulk"] == "True", symboled_ensids].copy()
    adata.obs = adata.obs.drop(columns=["is_pseudobulk", "replicate"])
    adata.obs.index = [x.replace("_pseudobulk", f"_{i}") for i, x in enumerate(adata.obs.index)]
    adata.obs["cell_id"] = adata.obs.index

    # Left as a view on purpose: nothing is written to it, so no implicit copy of the
    # full matrix is triggered.
    nonpseudo = proc[proc.obs["is_pseudobulk"] == "False"]

    # Group the non-pseudobulk obs names by cell_id once, preserving their order.
    # The cell_id is the obs name with its last "_"-separated part removed.
    # TODO test whether this yields the cell_id (i.e. {dataset_id}_{i})
    replicates_by_cell = {}
    for obs_name in nonpseudo.obs_names:
        cell_id = obs_name.rsplit("_", maxsplit=1)[0]
        replicates_by_cell.setdefault(cell_id, []).append(obs_name)

    for i in range(snakemake.params.num_replicates):
        def get_replicate_modulo(cell_id):
            candidates = replicates_by_cell.get(cell_id, [])
            if len(candidates) == 0:
                raise RuntimeError(f"{cell_id} has 0 replicates")
            # wrap around when fewer than num_replicates rows exist for this cell_id
            return candidates[i % len(candidates)]
        indices = adata.obs.cell_id.apply(get_replicate_modulo).values

        adata.layers[f"replicate_{i+1}"] = nonpseudo[indices, symboled_ensids].X.astype(np.int32)

    # set cell_id as the obs index
    adata.obs.set_index("cell_id", inplace=True)
    # move the dataset-level metadata from uns into (single-category) obs columns
    adata.obs["abstract"] = pd.Categorical([adata.uns["abstract"]] * len(adata.obs), categories=[adata.uns["abstract"]])
    adata.obs["dataset_title"] = pd.Categorical([adata.uns["dataset_title"]] * len(adata.obs), categories=[adata.uns["dataset_title"]])
    adata.uns = {}

    # convert int64 to int32
    int_conv = dict.fromkeys(adata.obs.select_dtypes(np.int64).columns, np.int32)
    adata.obs = adata.obs.astype(int_conv)

    return adata

# Sanity check on the notebook-level `proc`: display the obs column dtypes of the result.
layerize_dataset(proc).obs.dtypes

In [None]:
# Layerize every input dataset (takes ~30 minutes)
datasets = []
for dataset_path in snakemake.input.datasets:
    proc = anndata.read_h5ad(dataset_path)
    datasets.append(layerize_dataset(proc))
del proc  # drop the reference to the last loaded dataset

In [None]:
# Stack all layerized datasets into one AnnData; join="outer" keeps the union of genes.
full = anndata.concat(datasets, join="outer")
del datasets  # save memory

In [None]:
# outer fills '0' in sparse matrices, so we will need to compress them
# Iterate over a snapshot of the layer names because every layer is reassigned below.
for layer_name in list(full.layers.keys()):
    layer = full.layers[layer_name]
    layer.eliminate_zeros()  # in place: drops the explicitly stored zeros
    # astype returns a new matrix, so it has to be stored back into the layer
    # (the join resulted in np.float64)
    full.layers[layer_name] = layer.astype(np.int32)
full.X.eliminate_zeros()
full.X = full.X.astype(np.int32)

In [None]:
# Keep the Ensembl ID as an explicit var column and add the matching gene symbol.
full.var["ensembl_id"] = full.var.index
full.var["gene_name"] = full.var.index.map(ensembl_symbol_conversion.get)
# These obs columns have to go (fails otherwise). errors="ignore" keeps the cell
# re-runnable and tolerant of inputs in which the columns never existed.
full.obs = full.obs.drop(columns=["cluster_id", "GSE"], errors="ignore")

In [None]:
# Columns holding more than one Python type are converted to string categoricals
# (presumably so they can be written to h5ad — TODO confirm).
for col in full.obs.columns:
    # scan the column once and reuse the result for the check
    observed_types = full.obs[col].apply(type).unique()
    if len(observed_types) > 1:
        # astype(str) turns NaN into the string "nan"; map it back to a real missing value
        full.obs[col] = full.obs[col].astype(str).replace("nan", np.nan).astype("category")

In [None]:
# Persist the combined AnnData to the first Snakemake output path.
full.write_h5ad(snakemake.output[0])

In [None]:
# orig_obs = full.obs.copy()
