In [None]:
import os
import json
import h5py
import anndata as ad
import pandas as pd
import numpy as np
from scipy.sparse import spmatrix, hstack, csr_matrix, csc_matrix

In [None]:
# --- Inputs ---
# h5ad produced by cell2location (only obs, var and obsm are read from it)
c2l_output_h5ad_file = "/path/to/anndata.h5ad"
# h5ad with the spatial gene expression data
spatial_h5ad_file = "/path/to/spatial_anndata.h5ad"
# obsm key of the cell2location abundance table to use as the cell type matrix
q = "q05_cell_abundance_w_sf"

# --- Outputs ---
# Directory where the combined h5ad and the json lists are written
out_dir = "./"
# Suffix appended to the spatial h5ad file name for the combined h5ad
out_h5ad_suffix = "_celltypes"

# --- Optional settings ---
# (obs key, value) tuple selecting a single sample from both objects, or None
# e.g. sample = ("sample", "sp1234")
sample = None

# var column to use as the new gene index, or None to keep the current index
# e.g. var_reindex = "SYMBOL"
var_reindex = None

In [None]:
# Read only obs, var and obsm from the cell2location output h5ad; nothing else
# in the file is loaded. A group that is absent from the file is passed as None.
# The file is opened explicitly read-only: older h5py releases default to a
# read/write mode that can modify (or even create) the file.
with h5py.File(c2l_output_h5ad_file, "r") as f:
    c2l_adata = ad.AnnData(
        **{
            key: (ad._io.h5ad.read_elem(f[key]) if key in f else None)
            for key in ("obs", "var", "obsm")
        }
    )

In [None]:
# Create AnnData with cell2location obsm (cell type abundance) as X matrix
abundance_df = c2l_adata.obsm[q]

# Abundance columns are named "<first token of q>cell_abundance_w_sf_<cell type>";
# strip that prefix (as a literal string, not a regex) to get the cell type names.
column_prefix = q.split("_")[0] + "cell_abundance_w_sf_"
cell_type_names = abundance_df.columns.str.replace(column_prefix, "", regex=False)

cell_types = ad.AnnData(
    # Cast here instead of using AnnData's deprecated `dtype` argument
    abundance_df.to_numpy(dtype="float32"),
    obs=c2l_adata.obs,
    # Empty var frame indexed by cell type; built directly so it does not depend
    # on the name of the auto-generated column that `Index.to_frame()` creates
    var=pd.DataFrame(index=pd.Index(cell_type_names, name="CellType")),
)

if sample:
    # Keep only the requested sample; copy so the result is a real object, not a view
    sample_key, sample_value = sample
    cell_types = cell_types[cell_types.obs[sample_key] == sample_value].copy()

In [None]:
# Load the spatial gene expression object (`ad.read` is a deprecated alias of `ad.read_h5ad`)
spatial_adata = ad.read_h5ad(spatial_h5ad_file)

if sample:
    # Keep only the requested sample; copy so that later cells modify a real
    # object instead of a view of the full dataset
    sample_key, sample_value = sample
    spatial_adata = spatial_adata[spatial_adata.obs[sample_key] == sample_value].copy()

In [None]:
# Optionally reindex var: only when `var_reindex` names a var column.
# With `var_reindex = None` the var index is left untouched (calling
# `set_index(None)` would raise a KeyError).
if var_reindex:
    # Keep the previous index as a regular column and promote `var_reindex` to index.
    # Built on a new frame (no in-place edits) and assigned back in one step.
    reindexed_var = spatial_adata.var.reset_index().set_index(var_reindex)
    reindexed_var.index = reindexed_var.index.astype(str)
    spatial_adata.var = reindexed_var
    # The new index may contain duplicates; make the var names unique
    spatial_adata.var_names_make_unique()

In [None]:
# Display (n_obs, n_vars) of both objects; they are stacked column-wise below,
# so the number of observations has to agree
tuple(adata.shape for adata in (cell_types, spatial_adata))

In [None]:
# Create AnnData concatenating gene expression matrix and cell type abundance matrix
if cell_types.n_obs != spatial_adata.n_obs:
    raise ValueError(
        f"Cannot combine: spatial data has {spatial_adata.n_obs} observations "
        f"but cell type abundances have {cell_types.n_obs}"
    )

# The matrices are stacked positionally. If both objects hold the same
# observations in a different order, align the abundances to the spatial order
# first; if the names differ altogether the original positional order is kept.
abundance = cell_types
if (
    not spatial_adata.obs_names.equals(cell_types.obs_names)
    and spatial_adata.obs_names.is_unique
    and cell_types.obs_names.is_unique
    and spatial_adata.obs_names.isin(cell_types.obs_names).all()
):
    abundance = cell_types[spatial_adata.obs_names]

# A name present in both var indices would be merged into a single row by the
# column-wise concat below, leaving var shorter than the stacked matrix.
shared_names = spatial_adata.var_names.intersection(cell_types.var_names)
if len(shared_names) > 0:
    raise ValueError(
        f"Gene and cell type names overlap: {shared_names.tolist()[:10]}"
    )

# Rows: genes followed by cell types. Columns of either var frame are missing
# (NaN) for the rows that come from the other frame.
combined_var = pd.concat(
    [
        spatial_adata.var.assign(is_gene=True),
        cell_types.var.assign(is_celltype=True),
    ],
    axis=1,
)
# Missing flags mean False; cast explicitly so the columns are real bool
# columns instead of relying on pandas' deprecated downcasting in fillna
for flag in ("is_gene", "is_celltype"):
    combined_var[flag] = combined_var[flag].fillna(False).astype(bool)

if isinstance(spatial_adata.X, spmatrix):
    # Keep the sparse layout of the expression matrix (CSR stays CSR, anything
    # else becomes CSC) and request it explicitly so hstack does not return COO
    if isinstance(spatial_adata.X, csr_matrix):
        to_sparse, sparse_format = csr_matrix, "csr"
    else:
        to_sparse, sparse_format = csc_matrix, "csc"
    combined_X = hstack(
        (spatial_adata.X, to_sparse(abundance.X)),
        format=sparse_format,
    )
else:
    combined_X = np.hstack((spatial_adata.X, abundance.X))

# NOTE(review): only obs and var are carried over from spatial_adata; its other
# slots (obsm, uns, ...) are not copied - confirm this is intended.
adata_combined = ad.AnnData(
    combined_X,
    obs=spatial_adata.obs,
    var=combined_var,
)

In [None]:
# Ensure bool columns remain bool: a column that was bool in either source var
# has NaN in the combined var for the rows that came from the other source.
# Fill those with False and cast back explicitly, since fillna alone may leave
# the column as object dtype (its implicit downcasting is deprecated in pandas).
bool_columns = [
    col
    for source in (spatial_adata, cell_types)
    for col in source.var_keys()
    if source.var[col].dtype == bool
]
for col in bool_columns:
    adata_combined.var[col] = adata_combined.var[col].fillna(False).astype(bool)

In [None]:
# Write h5ad
# File name: <spatial h5ad file name without extension><out_h5ad_suffix>.h5ad, placed in out_dir
combined_h5ad_name = "".join(
    (
        os.path.splitext(os.path.basename(spatial_h5ad_file))[0],
        out_h5ad_suffix,
        ".h5ad",
    )
)
out_file = os.path.join(out_dir, combined_h5ad_name)

adata_combined.write_h5ad(out_file)

In [None]:
# Write list of genes to json
# File name: <spatial h5ad file name without extension>_genes.json, placed in out_dir
genes_json_name = "".join(
    (
        os.path.splitext(os.path.basename(spatial_h5ad_file))[0],
        "_genes",
        ".json",
    )
)
out_genes_file = os.path.join(out_dir, genes_json_name)

with open(out_genes_file, "w") as genes_handle:
    json.dump(spatial_adata.var_names.tolist(), genes_handle)

In [None]:
# Write list of cell types to json
# File name: <spatial h5ad file name without extension>_celltypes.json, placed in out_dir
celltypes_json_name = "".join(
    (
        os.path.splitext(os.path.basename(spatial_h5ad_file))[0],
        "_celltypes",
        ".json",
    )
)
out_celltypes_file = os.path.join(out_dir, celltypes_json_name)

with open(out_celltypes_file, "w") as celltypes_handle:
    json.dump(cell_types.var_names.tolist(), celltypes_handle)