In [1]:
import os
import scanpy as sc
import numpy as np
import pandas as pd
from tasks.scfoundation import load
import anndata
import scipy.sparse as sp

In [2]:
def prepare_data(adata):
    """Build the expression table passed through ``load.main_gene_selection``.

    Reads the reference gene list from ``OS_scRNA_gene_index.19264.tsv`` in
    ``tokenizer_dir`` (a module-level global that must be defined before the
    call), turns ``adata.X`` into a dense cells x genes DataFrame labelled by
    ``adata.obs_names`` / ``adata.var['gene_name']``, and hands it to
    ``load.main_gene_selection`` together with the gene list.

    Parameters
    ----------
    adata : AnnData
        Must carry a ``gene_name`` column in ``.var``. ``.X`` may be a scipy
        sparse container or a dense array.

    Returns
    -------
    The first element of what ``load.main_gene_selection`` returns
    (presumably a DataFrame aligned to the reference gene list — callers in
    this notebook read ``.index``, ``.columns`` and ``.values`` from it).
    """
    gene_list_df = pd.read_csv(f'{tokenizer_dir}/OS_scRNA_gene_index.19264.tsv', header=0, delimiter='\t')
    gene_list = list(gene_list_df['gene_name'])

    # `.A` is not available on dense ndarrays (nor on every sparse container),
    # so densify explicitly depending on what X actually is.
    expression = adata.X
    if sp.issparse(expression):
        expression = expression.toarray()
    else:
        expression = np.asarray(expression)

    gexpr_feature = pd.DataFrame(
        expression,
        index=adata.obs_names.tolist(),
        columns=adata.var['gene_name'].tolist(),
    )
    gexpr_feature, _ = load.main_gene_selection(gexpr_feature, gene_list)

    return gexpr_feature

def preprocess_data(adata):
    """Convert a deconvolved AnnData into the gene space used downstream.

    Relies on the module-level globals ``tokenizer_dir`` and ``batch_label``
    and on ``prepare_data``. The input object is not modified: all work is
    done on explicit copies.

    Steps
    -----
    1. Keep only variables whose name appears in the ``gene_ids`` column of
       ``scfoundation_gene_df.csv`` and attach the matching ``gene_symbols``
       as ``var['gene_name']``.
    2. Where several variables share a gene name, write the per-cell maximum
       of their ``X`` columns into the first one and drop the others.
       Only ``X`` is merged; layers simply keep the first column.
    3. Row-normalise ``obsm['q05_cell_abundance_w_sf']``, zero entries below
       0.05 and row-normalise again.
    4. Run ``X`` through ``prepare_data`` and, for every layer, the
       ``normalize_total(target_sum=1e4)`` + ``log1p`` version of that layer.

    Parameters
    ----------
    adata : AnnData
        Needs ``obsm['q05_cell_abundance_w_sf']`` as a DataFrame and the
        layers to be converted (keyed by cell type, judging by how
        ``get_niche_samples`` indexes them).

    Returns
    -------
    AnnData
        ``X`` and one layer per input layer (CSR), ``obs`` indexed by
        ``f'{batch_label}_{barcode}'`` and ``obsm['celltype_proportion']``.
    """
    scfoundation_gene_df = pd.read_csv(f'{tokenizer_dir}/scfoundation_gene_df.csv').set_index('gene_ids')

    # Take a real copy rather than a view: the statements below assign into
    # .var and .X, and the result should not depend on how anndata treats
    # assignments made through a view.
    adata = adata[:, adata.var_names.isin(scfoundation_gene_df.index)].copy()
    adata.var['gene_name'] = scfoundation_gene_df.loc[adata.var_names, 'gene_symbols'].values

    # Collapse variables that map to the same gene symbol.
    gene_names = adata.var['gene_name']
    duplicated_var_names = gene_names[gene_names.duplicated()].unique()
    removed_indices = []
    for var_name in duplicated_var_names:
        indices = np.where(gene_names == var_name)[0]
        merged_column = np.max(adata.X[:, indices], axis=1)
        adata.X[:, indices[0]] = merged_column
        removed_indices.extend(indices[1:])
    keep_mask = ~np.isin(np.arange(adata.shape[1]), removed_indices)
    # Copy again so that `adata.X = ...` in the layer loop acts on an
    # independent object.
    adata = adata[:, keep_mask].copy()

    # Column names are cut after their first 23 characters — presumably the
    # 'q05cell_abundance_w_sf_' prefix; TODO confirm against the deconvolution
    # output. The rename returns a new frame, leaving the obsm entry as is.
    celltype_proportion = adata.obsm['q05_cell_abundance_w_sf'].rename(columns=lambda x: x[23:])
    celltype_proportion = celltype_proportion.div(celltype_proportion.sum(axis=1), axis=0)
    celltype_proportion[celltype_proportion < 0.05] = 0
    # Renormalise so the surviving proportions sum to one again (rows that end
    # up all-zero become NaN through the division).
    celltype_proportion = celltype_proportion.div(celltype_proportion.sum(axis=1), axis=0)

    gexpr_feature = prepare_data(adata)
    adata_pp = anndata.AnnData(
        X=sp.csr_matrix(gexpr_feature.values),
        obs=pd.DataFrame({}, index=gexpr_feature.index.values),
        var=pd.DataFrame({}, index=gexpr_feature.columns.values),
    )
    adata_pp.obsm['celltype_proportion'] = celltype_proportion
    adata_pp.obs.index = (batch_label+'_'+adata_pp.obs.index).values

    for celltype in adata.layers.keys():
        # Work on a copy of the layer: the scanpy calls below operate on X
        # in place and must not alter the stored layer.
        adata.X = adata.layers[celltype].copy()
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        # Drop the marker left by log1p before the next layer is processed;
        # tolerate it being absent.
        adata.uns.pop('log1p', None)
        adata_pp.layers[celltype] = sp.csr_matrix(prepare_data(adata).values)

    return adata_pp

In [3]:
def get_niche_samples(adata):
    """Expand every observation into one sample per cell type present in it.

    Relies on the module-level globals ``ligand_symbol``,
    ``max_cell_type_num`` and ``batch_label``.

    For observation ``i`` the cell types with a proportion > 0 in
    ``adata.obsm['celltype_proportion']`` are collected. Each of them yields
    one sample whose expression row is row ``i`` of that cell type's layer.
    All samples of the same observation share:

    * a ligand block of shape ``(max_cell_type_num, n_ligands)`` flattened to
      one row, where the row of each present cell type holds that cell type's
      ligand expression and all other rows stay zero;
    * a composition row of length ``max_cell_type_num`` holding the
      proportions of the present cell types.

    Parameters
    ----------
    adata : AnnData
        Needs ``obsm['celltype_proportion']`` (DataFrame, one column per cell
        type) and one layer per cell type named like the column.

    Returns
    -------
    AnnData
        ``X``: stacked expression rows (CSR). ``obs``: ``batch`` and the
        0-based ``cell_type`` column index, indexed by
        ``f'{obs_name}_{k}'`` with ``k`` counting from 1 within an
        observation. ``obsm``: ``niche_ligands_expression`` and
        ``niche_composition`` (CSR). ``uns``: ``ligand_gene_symbols`` and
        ``max_cell_type_num``.
    """
    def dense_row(matrix, row):
        """Return row ``row`` of ``matrix`` as a dense array of shape (1, n_columns)."""
        selected = matrix[row]
        if sp.issparse(selected):
            selected = selected.toarray()
        return np.asarray(selected).reshape(1, -1)

    celltype_proportion = adata.obsm['celltype_proportion']
    cell_type_list = celltype_proportion.columns.values.tolist()
    proportions = celltype_proportion.to_numpy()
    obs_names = adata.obs_names.values

    samples_barcodes = []
    samples_celltype = []
    samples_expression = []
    samples_ligands_expression = []
    samples_ctprop = []

    # Materialise the ligand subset once so its layers are not sliced out of
    # the full object again on every access inside the loop.
    adata_ligands = adata[:, adata.var_names.isin(ligand_symbol)].copy()
    # Size the buffer by the ligands actually found in var_names; this can
    # differ from len(ligand_symbol) when symbols are missing or repeated.
    n_ligands = adata_ligands.shape[1]

    for i in range(adata.shape[0]):
        # Column positions of the cell types present in this observation
        # (NaN proportions compare False and are skipped).
        ct_prop_index = np.flatnonzero(proportions[i] > 0).tolist()

        ligands_expression = np.zeros([max_cell_type_num, n_ligands])
        niche_ctprop = np.zeros([1, max_cell_type_num])
        for ct_index in ct_prop_index:
            ligand_layer = adata_ligands.layers[cell_type_list[ct_index]]
            ligands_expression[ct_index] = dense_row(ligand_layer, i)[0]
            niche_ctprop[0, ct_index] = proportions[i, ct_index]
        # Flatten row-major: all ligands of cell type 0, then cell type 1, ...
        ligands_expression = ligands_expression.reshape(1, -1)

        for cell_index, ct_index in enumerate(ct_prop_index, start=1):
            samples_barcodes.append(f'{obs_names[i]}_{cell_index}')
            samples_celltype.append(ct_index)
            samples_expression.append(dense_row(adata.layers[cell_type_list[ct_index]], i))
            samples_ligands_expression.append(ligands_expression)
            samples_ctprop.append(niche_ctprop)

    matrix = sp.csr_matrix(np.concatenate(samples_expression))
    obs = pd.DataFrame(
        {'batch': [batch_label]*len(samples_barcodes), 'cell_type': samples_celltype},
        index=samples_barcodes,
    )
    adata_samples = anndata.AnnData(X=matrix, obs=obs, var=pd.DataFrame({}, index=adata.var_names))
    adata_samples.obsm['niche_ligands_expression'] = sp.csr_matrix(np.concatenate(samples_ligands_expression))
    adata_samples.obsm['niche_composition'] = sp.csr_matrix(np.concatenate(samples_ctprop))
    adata_samples.uns['ligand_gene_symbols'] = adata_ligands.var_names.values
    adata_samples.uns['max_cell_type_num'] = max_cell_type_num

    return adata_samples

In [4]:
# Directory holding the gene index / ligand tables read by this notebook.
tokenizer_dir = '../stformer/tokenizer/'

# Ligand table indexed by its first column; keep the index labels of the rows
# whose values sum to more than 1 across columns.
ligand_database = pd.read_csv(f'{tokenizer_dir}ligand_database.csv', header=0, index_col=0)
ligand_symbol = ligand_database.loc[ligand_database.sum(axis=1) > 1].index.values

# Number of cell-type slots reserved in the per-sample niche arrays.
max_cell_type_num = 25

In [None]:
# --- Which slide to process -------------------------------------------------
dataset = 'human_myocardial_infarction'
slide = 'ACH0011'
directory = f'{dataset}/{slide}'

# batch_label is read as a global by preprocess_data and get_niche_samples.
batch_label = f'{dataset}-{slide}'
adata = sc.read_h5ad(f'{directory}/deconv.h5ad')
adata_pp = preprocess_data(adata)
adata_samples = get_niche_samples(adata_pp)
adata_samples.write(f'{directory}/samples.h5ad')
# To skip the two calls above on a later run, load the saved samples instead:
# adata_samples = sc.read_h5ad(f'{directory}/samples.h5ad')

# --- Attach cell-type metadata ----------------------------------------------
# Derive the cell-type names without touching adata.obsm: an in-place rename
# would cut another 23 characters off the names every time this is re-run on
# the same object.
celltype_proportion = adata.obsm['q05_cell_abundance_w_sf'].rename(columns=lambda x: x[23:])
cell_types_list = celltype_proportion.columns.values.tolist()

# Shift cell-type ids to 1..k so they match niche_celltypes below, where 0
# fills the unused slots.
adata_samples.obs['cell_type'] = adata_samples.obs['cell_type'] + 1
adata_samples.uns['cell_types_list'] = cell_types_list

n_padding = adata_samples.uns['max_cell_type_num'] - len(cell_types_list)
if n_padding < 0:
    raise ValueError(
        f"{len(cell_types_list)} cell types exceed max_cell_type_num="
        f"{adata_samples.uns['max_cell_type_num']}"
    )

# Every sample gets the same row: ids 1..k followed by zero padding up to
# max_cell_type_num columns.
niche_celltypes = np.tile(np.arange(1, len(cell_types_list) + 1), (adata_samples.shape[0], 1))
complement = np.zeros([adata_samples.shape[0], n_padding])
niche_celltypes = np.concatenate([niche_celltypes, complement], axis=1)
adata_samples.obsm['niche_celltypes'] = sp.csr_matrix(niche_celltypes)

adata_samples.write(f'{slide}_niche.h5ad')