In [None]:
from tasks.scfoundation import load
import scipy.sparse as sp
import anndata as ad
import hdf5plugin
import pandas as pd
import numpy as np
import scanpy as sc 
from scipy.spatial.distance import cdist

In [None]:
def prepare_data(adata, need_normalization=True):
    """Restrict an AnnData object to the tokenizer gene vocabulary and rebuild it.

    The gene vocabulary is read from ``OS_scRNA_gene_index.19264.tsv`` under the
    module-level ``tokenizer_dir``. Only variables whose ``var['gene_name']`` is
    in that vocabulary are kept, the expression matrix is passed through
    ``load.main_gene_selection`` and a new AnnData object is assembled from the
    result.

    Parameters
    ----------
    adata : AnnData
        Input object. Must provide ``var['gene_name']``; ``X`` may be a scipy
        sparse matrix or a dense array.
    need_normalization : bool, default True
        When true, apply ``sc.pp.normalize_total(target_sum=1e4)`` followed by
        ``sc.pp.log1p`` to the rebuilt object.

    Returns
    -------
    AnnData
        New object with a CSR ``X``, the ``obs`` and ``obsm`` of the filtered
        input, and a ``var`` frame that only carries the output column names.
    """
    gene_list_df = pd.read_csv(f'{tokenizer_dir}/OS_scRNA_gene_index.19264.tsv', header=0, delimiter='\t')
    gene_list = list(gene_list_df['gene_name'])
    adata = adata[:,adata.var['gene_name'].isin(gene_list)]

    # Densify with toarray(): the `.A` shortcut is deprecated in SciPy and is
    # not available at all when X is already a dense ndarray.
    expr = adata.X
    gexpr_feature = expr.toarray() if sp.issparse(expr) else np.asarray(expr)
    idx = adata.obs_names.tolist()
    col = adata.var['gene_name'].tolist()
    gexpr_feature = pd.DataFrame(gexpr_feature, index=idx, columns=col)
    # NOTE(review): main_gene_selection presumably reorders/pads the columns to
    # match gene_list; its second return value is not used here. Confirm against
    # tasks.scfoundation.load.
    gexpr_feature, _ = load.main_gene_selection(gexpr_feature, gene_list)

    varnames = gexpr_feature.columns.values
    matrix = sp.csr_matrix(gexpr_feature.values)
    adata_pp = ad.AnnData(X=matrix, obs=adata.obs, var=pd.DataFrame({}, index=varnames), obsm=adata.obsm)

    if need_normalization:
        sc.pp.normalize_total(adata_pp, target_sum=1e4)
        sc.pp.log1p(adata_pp)

    return adata_pp

def get_niche_samples_sc(adata, ligand_index):
    """Attach single-cell "niche" features in which every cell is its own niche.

    Each cell gets 50 niche slots. Only slot 0 is filled (with the cell itself);
    the remaining 49 slots stay zero.

    Side effects on ``adata`` (modified in place and also returned):

    * ``uns['cell_types_list']`` receives the category names of
      ``obs['cell_type']``.
    * ``obs['cell_type']`` is replaced by 1-based category codes, so the
      function cannot be applied twice to the same object.
    * ``X`` is cast to float32.
    * ``obsm['niche_composition']`` -- (n_cells, 50) CSR, 1 in column 0.
    * ``obsm['niche_celltypes']`` -- (n_cells, 50) CSR, cell-type code in
      column 0.
    * ``obsm['niche_ligands_expression']`` -- (n_cells, 50 * n_ligands) CSR,
      the cell's own ligand expression in the first ``n_ligands`` columns.

    Parameters
    ----------
    adata : AnnData
        Must provide a categorical ``obs['cell_type']``; ``X`` may be sparse or
        dense.
    ligand_index : sequence of int
        Column positions of the ligand genes in ``adata.X``.

    Returns
    -------
    AnnData
        The same object that was passed in.
    """
    niche_size = 50  # fixed number of niche slots per cell

    adata.uns['cell_types_list'] = adata.obs['cell_type'].cat.categories.tolist()
    cell_types_index = adata.obs['cell_type'].cat.codes.values+1
    adata.obs['cell_type'] = cell_types_index

    n_cells = adata.shape[0]
    n_ligands = len(ligand_index)

    # Only the ligand columns are needed, so densify just those. toarray()
    # replaces the deprecated `.A`; dense X is accepted as well.
    expr = adata.X
    if sp.issparse(expr):
        own_ligands = expr[:, ligand_index].toarray()
    else:
        own_ligands = np.asarray(expr)[:, ligand_index]

    niche_composition = np.zeros([n_cells, niche_size])
    niche_composition[:, 0] = 1

    niche_celltypes = np.zeros([n_cells, niche_size])
    niche_celltypes[:, 0] = cell_types_index

    # Slot 0 holds the cell's own ligand expression; the other slots are empty.
    # Assembling this sparsely avoids allocating a dense (50 * n_ligands) row
    # per cell and also works when there are no cells at all.
    own_block = sp.csr_matrix(own_ligands, dtype=np.float64)
    empty_block = sp.csr_matrix((n_cells, (niche_size - 1) * n_ligands), dtype=np.float64)
    niche_ligands_expression = sp.hstack([own_block, empty_block], format='csr')

    niche_composition = sp.csr_matrix(niche_composition)
    niche_celltypes = sp.csr_matrix(niche_celltypes)

    adata.X = adata.X.astype(np.float32)
    adata.obsm['niche_composition'] = niche_composition.astype(np.float32)
    adata.obsm['niche_ligands_expression'] = niche_ligands_expression.astype(np.float32)
    adata.obsm['niche_celltypes'] = niche_celltypes.astype(np.float32)

    return adata

In [None]:
def get_niche_radius(spatial_coords, max_niche_cell_num):
    """Find the largest 0.1-step radius at which no niche exceeds a cell limit.

    The radius is grown from 0 in increments of 0.1. A cell's niche at radius
    ``r`` is every cell (itself included) whose Euclidean distance is strictly
    less than ``r``. The returned value is the last radius of that sequence for
    which every niche holds at most ``max_niche_cell_num`` cells.

    If the field contains no more than ``max_niche_cell_num`` cells the limit
    can never be exceeded; in that case the first stepped radius that is larger
    than every pairwise distance is returned, so that all cells fall inside
    every niche.

    Parameters
    ----------
    spatial_coords : array-like, shape (n_cells, n_dims)
        Cell coordinates.
    max_niche_cell_num : int
        Maximum number of cells allowed in a niche.

    Returns
    -------
    float
        The selected radius.

    Raises
    ------
    ValueError
        If ``spatial_coords`` is empty.
    """
    if len(spatial_coords) == 0:
        raise ValueError('spatial_coords is empty; cannot determine a niche radius')

    distances = cdist(spatial_coords, spatial_coords)
    n_cells = distances.shape[0]
    radius = 0

    if max_niche_cell_num < 0:
        # Niche sizes are never negative, so the limit is already exceeded at
        # radius 0 and the search ends before the first step.
        return radius - 0.1

    if n_cells <= max_niche_cell_num:
        # No niche can ever exceed the limit: stepping "until it does" would
        # never finish. Stop once the radius covers the whole field instead.
        farthest = distances.max()
        while radius <= farthest:
            radius += 0.1
        return radius

    # A niche holds more than k cells exactly when the radius is larger than
    # that cell's k-th smallest distance (0-based, self-distance included).
    # The tightest such distance over all cells is where the search must stop,
    # so the n x n comparison does not have to be repeated for every step.
    k = int(max_niche_cell_num)
    limit = np.partition(distances, k, axis=1)[:, k].min()

    # Keep the step-wise accumulation so results match the incremental search
    # bit for bit.
    while radius <= limit:
        radius += 0.1
    return radius - 0.1

def get_niche_samples(adata, spatial_coords, niche_radius, ligand_index, max_niche_cell_num=None):
    """Build per-cell spatial niche features and store them in ``adata.obsm``.

    For every cell a niche is selected from the pairwise Euclidean distances of
    ``spatial_coords``:

    * ``niche_radius is None`` -- the ``max_niche_cell_num`` nearest cells
      (ordered by distance);
    * otherwise -- every cell with distance ``<= niche_radius`` (ordered by
      row index, so the focal cell is not necessarily in the first slot).

    Each niche is padded with zeros to 50 slots and three matrices are written
    as float32 CSR:

    * ``niche_ligands_expression`` -- (n_cells, 50 * n_ligands), the ligand
      expression of each niche member, member after member;
    * ``niche_composition`` -- (n_cells, 50), ``1 / niche size`` for each
      occupied slot (all zeros when the niche is empty);
    * ``niche_celltypes`` -- (n_cells, 50), ``obs['cell_type']`` of each
      niche member.

    When ``max_niche_cell_num`` is given, the keys are suffixed with
    ``_niche{max_niche_cell_num}``. ``adata.X`` is cast to float32.

    Parameters
    ----------
    adata : AnnData
        Cells to process; ``obs['cell_type']`` must be numeric. ``X`` may be
        sparse or dense.
    spatial_coords : array-like, shape (n_cells, n_dims)
        Coordinates in the same row order as ``adata``.
    niche_radius : float or None
        Neighbourhood radius, or None to use nearest neighbours.
    ligand_index : sequence of int
        Column positions of the ligand genes in ``adata.X``.
    max_niche_cell_num : int, optional
        Neighbour count for the nearest-neighbour mode and suffix for the
        output keys.

    Returns
    -------
    AnnData
        The same object that was passed in.

    Raises
    ------
    ValueError
        If a niche contains more than 50 cells.
    """
    niche_size = 50  # fixed number of niche slots per cell
    # Width of one slot. Derived from ligand_index rather than a fixed gene
    # count so that any ligand set can be used.
    n_ligands = len(ligand_index)

    # toarray() replaces the deprecated `.A`; dense X is accepted as well.
    expr = adata.X
    all_counts = expr.toarray() if sp.issparse(expr) else np.asarray(expr)
    cell_types_index = adata.obs['cell_type'].values
    n_cells = len(all_counts)

    ligand_rows = []
    niche_composition = np.zeros([n_cells, niche_size])
    niche_celltypes = np.zeros([n_cells, niche_size])

    distances = cdist(spatial_coords, spatial_coords)
    for i in range(n_cells):
        distances_i = distances[i]

        if niche_radius is None:
            niche_cells = np.argsort(distances_i)[:max_niche_cell_num]
        else:
            niche_cells = np.where(distances_i <= niche_radius)[0]

        n_niche = len(niche_cells)
        if n_niche > niche_size:
            raise ValueError(
                f'niche of cell {i} contains {n_niche} cells, more than the '
                f'{niche_size} available slots; use a smaller niche_radius or '
                f'max_niche_cell_num'
            )

        # Flatten the members' ligand expression into the leading slots and
        # keep each row sparse to avoid holding n_cells dense padded rows.
        row = np.zeros(niche_size * n_ligands)
        row[:n_niche * n_ligands] = all_counts[niche_cells][:, ligand_index].ravel()
        ligand_rows.append(sp.csr_matrix(row.reshape(1, -1)))

        # Equal weight for every member; an empty niche keeps an all-zero row
        # instead of dividing by zero.
        if n_niche:
            niche_composition[i, :n_niche] = 1.0 / n_niche

        niche_celltypes[i, :n_niche] = cell_types_index[niche_cells]

    if ligand_rows:
        niche_ligands_expression = sp.vstack(ligand_rows, format='csr')
    else:
        niche_ligands_expression = sp.csr_matrix((0, niche_size * n_ligands))
    niche_composition = sp.csr_matrix(niche_composition)
    niche_celltypes = sp.csr_matrix(niche_celltypes)

    adata.X = adata.X.astype(np.float32)

    suffix = '' if max_niche_cell_num is None else f'_niche{max_niche_cell_num}'
    adata.obsm[f'niche_composition{suffix}'] = niche_composition.astype(np.float32)
    adata.obsm[f'niche_ligands_expression{suffix}'] = niche_ligands_expression.astype(np.float32)
    adata.obsm[f'niche_celltypes{suffix}'] = niche_celltypes.astype(np.float32)

    return adata

In [None]:
# Directory holding the tokenizer's gene vocabulary and ligand table.
tokenizer_dir = '../stformer/tokenizer/'

# Gene vocabulary: one row per gene with its name and column index.
gene_list_df = pd.read_csv(
    f'{tokenizer_dir}/OS_scRNA_gene_index.19264.tsv',
    sep='\t',
    header=0,
)

# Ligand table, indexed by ligand symbol.
ligand_database = pd.read_csv(
    tokenizer_dir+'ligand_database.csv',
    index_col=0,
    header=0,
)

# Keep the ligands whose row total in the table is greater than 1, then look
# up their positions in the gene vocabulary.
ligand_symbol = ligand_database.loc[ligand_database.sum(axis=1) > 1].index.values
ligand_index = gene_list_df['index'].loc[gene_list_df['gene_name'].isin(ligand_symbol)].values

In [None]:
# Assemble an AnnData object from the exported files in the dataset folder
# (count matrix, gene list, barcode list, cell metadata) and cache it as h5ad.
dataset = 'pancreas_cosmx_rds'

# The matrix is transposed after reading.
# NOTE(review): presumably the .mtx file is stored genes x cells, so that the
# transpose yields cells x genes -- confirm against the export script.
adata = sc.read_mtx(f'{dataset}/count_matrix.mtx').T

# genes.tsv / barcodes.tsv have no header row; only their first column is used.
df1 = pd.read_csv(f'{dataset}/genes.tsv', sep='\t', header=None)
df2 = pd.read_csv(f'{dataset}/barcodes.tsv', sep='\t', header=None)
# Cell metadata, indexed by its first column.
df3 = pd.read_csv(f'{dataset}/meta_data.csv', index_col=0)

adata.var_names = df1[0]
adata.obs_names = df2[0]
# prepare_data() filters on var['gene_name'], so the names are stored as a column too.
adata.var['gene_name'] = df1[0].values
# Reorder the metadata rows to follow the barcode order of the matrix.
adata.obs = df3.loc[adata.obs_names]

adata.write(f'{dataset}/pancreas_cosmx.h5ad')

In [None]:
# Remove the cells labelled 'QC_dropped'. The subset is copied explicitly so
# that the assignment below writes into a real object instead of triggering an
# implicit copy (and warning) on a view.
adata = adata[adata.obs['cell_type']!='QC_dropped'].copy()
# Store the labels as a categorical and drop categories that no longer occur,
# so that the category codes derived later only cover cell types that exist.
adata.obs['cell_type'] = adata.obs['cell_type'].astype('category').cat.remove_unused_categories()

In [None]:
# Align the expression matrix to the tokenizer vocabulary and normalise it.
adata = prepare_data(adata, need_normalization=True)

# #--------- Prepare sample data from single-cell RNA-seq data ------------------
# adata = get_niche_samples_sc(adata, ligand_index)


#--------- Prepare sample data from single-cell spatial data ------------------
# Replace the cell-type labels by 1-based category codes; the label list is
# kept aside and stored in `uns` at the end, after the concatenations.
cell_types_list = adata.obs['cell_type'].cat.categories.tolist()
adata.obs['cell_type'] = adata.obs['cell_type'].cat.codes.values+1

# One set of niche features is added per niche size. Each pass processes every
# field of view (FOV) separately and then concatenates the FOVs again.
for max_niche_cell_num in [5,10,15,20,25,30,35]:
    adata_list = []

    # Iterate the FOVs in sorted order: iterating a set would leave the row
    # order of the concatenated result unspecified.
    for fov in sorted(adata.obs['fov'].dropna().unique()):
        # Copy the subset: get_niche_samples modifies the object it receives.
        adata0 = adata[adata.obs['fov']==fov].copy()

        spatial_coords = adata0.obs[['x_FOV_px', 'y_FOV_px']].values # spatial_coords = adata0.obsm['spatial']
        niche_radius = get_niche_radius(spatial_coords, max_niche_cell_num) # niche_radius = None
        print(fov, niche_radius)
        adata0 = get_niche_samples(adata0, spatial_coords, niche_radius, ligand_index, max_niche_cell_num)

        adata_list.append(adata0)

    adata = sc.concat(adata_list)

adata.uns['cell_types_list'] = cell_types_list

In [None]:
# Save the object with all niche features to disk.
adata.write('pancreas_cosmx_niche.h5ad')