# In [None]:
import dask
dask.config.set({"dataframe.query-planning": False})

import itertools
import scanpy as sc
import gseapy
import liana
import scipy
import numpy as np
import pandas as pd
import sys
import argparse
import json
from sklearn.preprocessing import StandardScaler
from pathlib import Path

sys.path.append("../../../workflow/scripts/")
import _utils
import readwrite

# ---------------------------------------------------------------------------
# Sample identity
# ---------------------------------------------------------------------------
segmentation = 'proseg_expected'
condition = 'NSCLC'
panel = 'lung'
donor = '1GA2'
sample = '1GA2'

# `k` is the sample key used under the analysis folders; `k_dir` is the key used
# under the segmentation folder, where any segmentation whose name contains
# 'proseg' is mapped to the plain 'proseg' directory.
k = (segmentation, condition, panel, donor, sample)
k_dir = ('proseg', *k[1:]) if 'proseg' in segmentation else k
name, name_dir = '/'.join(k), '/'.join(k_dir)

# ---------------------------------------------------------------------------
# Input paths
# ---------------------------------------------------------------------------
sample_dir = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/segmentation/{name_dir}') / 'raw_results'
sample_counts = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/std_seurat_analysis/{name}/lognorm/normalised_counts/data.parquet')
sample_idx = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/std_seurat_analysis/{name}/lognorm/normalised_counts/cells.parquet')
cell_type_labels = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/_cell_type_annotation_1/{name}/lognorm/reference_based/matched_reference_combo/rctd_class_aware/Level2/single_cell/labels.parquet')

# ---------------------------------------------------------------------------
# Output paths (all placed next to the raw results of the sample)
# ---------------------------------------------------------------------------
out_file_df_permutations = sample_dir / 'permutation_summary.parquet'
out_file_df_importances = sample_dir / 'importances.parquet'
out_file_df_diffexpr = sample_dir / 'diffexpr.parquet'
out_file_df_markers_rank_significance_logreg = sample_dir / 'markers_rank_significance_logreg.json'
out_file_df_markers_rank_significance_diffexpr = sample_dir / 'markers_rank_significance_diffexpr.json'
# out_dir_liana_lrdata = sample_dir / 'liana_lrdata_folder'

# ---------------------------------------------------------------------------
# Analysis parameters
# ---------------------------------------------------------------------------
n_neighbors = 10       # k used when building the spatial kNN graph
n_permutations = 30    # only referenced by the (disabled) logistic-regression test
n_repeats = 5          # only referenced by the (disabled) logistic-regression test
top_n = 20             # number of top genes kept as markers / passed to the rank test
top_n_lr = 10          # only referenced by the (disabled) ligand-receptor test
# NOTE: `cti` and `ctj` are rebound by the pairwise loops further down, so these
# two values only matter when running individual cells by hand.
cti = "macrophage"
ctj = "malignant cell"
scoring = 'f1'
# `markers` is either the sentinel 'diffexpr' (derive markers from the data) or
# a path to a JSON file of curated markers.
# markers = '/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/markers/cellmarker_cell_types_markers.json'
markers = 'diffexpr'

####
#### READ DATA
####
# read raw data to get spatial coordinates
# (only the AnnData table is requested; every other element is switched off)
adata = readwrite.read_xenium_sample(
    sample_dir,
    cells_as_circles=False,
    cells_boundaries=False,
    cells_boundaries_layers=False,
    nucleus_boundaries=False,
    cells_labels=False,
    nucleus_labels=False,
    transcripts=False,
    morphology_mip=False,
    morphology_focus=False,
    aligned_images=False,
    anndata=True,
)
# For 'proseg_expected' samples the cell ids get a 'proseg-' prefix
# (presumably to match the ids stored in `sample_idx` -- confirm).
if 'proseg_expected' in sample_counts.as_posix():
    adata.obs_names = 'proseg-' + adata.obs_names.astype(str)


# read normalised data, filter cells
X_normalised = pd.read_parquet(sample_counts)
X_normalised.index = pd.read_parquet(sample_idx).iloc[:, 0]
# regex=False: the '.' is a literal dot, not the regex wildcard. The default of
# `regex` differs between pandas versions, so it is stated explicitly.
X_normalised.columns = X_normalised.columns.str.replace('.', '-', regex=False)
# .copy(): the subset is modified right below (layers, normalisation), so make
# it a real AnnData instead of relying on the implicit copy-on-write of a view.
adata = adata[X_normalised.index, X_normalised.columns].copy()
adata.layers['X_normalised'] = X_normalised

# log-normalize before DE (operates on adata.X; the 'X_normalised' layer is untouched)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# read labels
label_key = "label_key"
# The first column of the labels table, indexed by "cell_id", is aligned on adata.obs_names.
adata.obs[label_key] = pd.read_parquet(cell_type_labels).set_index("cell_id").iloc[:, 0]
# Strip any " of <...>" suffix from the labels.
adata.obs[label_key] = adata.obs[label_key].replace(r' of .+', '', regex=True)
# Drop cells without a label; copy so later writes to obs/obsp do not hit a view.
adata = adata[adata.obs[label_key].notna()].copy()


# Curated markers are only loaded when `markers` is a file path rather than
# the 'diffexpr' sentinel.
if markers != "diffexpr":
    canonical_markers = pd.read_json(markers)["canonical"]
    # one row per (cell type, gene) pair
    df_markers = canonical_markers.explode().reset_index()
    df_markers.columns = ["cell_type", "gene"]

# Build the kNN graph on the coordinates stored in adata.obsm[obsm] and keep
# the sparse neighbour matrix on the AnnData object.
obsm = 'spatial'
knn_outputs = _utils.get_knn_labels(
    adata,
    n_neighbors=n_neighbors,
    label_key=label_key,
    obsm=obsm,
    return_sparse_neighbors=True,
)
knnlabels, knndis, knnidx, knn_graph = knn_outputs
adata.obsp[f'{obsm}_connectivities'] = knn_graph

# iterate over targets permutations (cell type i with cell type j presence in kNN)
# Each dict is keyed by the (cti, ctj) pair and filled inside the loops below.
df_permutations_logreg = {}
df_importances_logreg = {}
df_diffexpr = {}
df_markers_rank_significance_logreg = {}
df_markers_rank_significance_diffexpr = {}
df_markers_rank_significance_lrdata = {}
u_cell_types = adata.obs[label_key].unique()

# Minimum number of cells a group must contain for a comparison to be run.
min_cells = 30


for ctj in u_cell_types:
    if (adata.obs[label_key] == ctj).sum() < min_cells:
        print(f"Not enough cells from class {ctj}")
        continue

    # get markers
    if markers == "diffexpr":
        # markers of ctj = top_n genes of a ctj-vs-rest Wilcoxon test
        sc.tl.rank_genes_groups(adata, groupby=label_key, groups=[ctj], reference='rest', method="wilcoxon")
        ctj_marker_genes = sc.get.rank_genes_groups_df(adata, group=ctj)['names'][:top_n].tolist()
    else:
        ctj_marker_genes = df_markers[df_markers["cell_type"] == ctj]["gene"].tolist()
        ctj_marker_genes = [g for g in ctj_marker_genes if g in adata.var_names]

        # explicit raise instead of assert: asserts are stripped under `python -O`
        if not ctj_marker_genes:
            raise ValueError(f"no markers found for {ctj}")

    # Flag cells whose kNN neighbourhood contains ctj. This only depends on ctj,
    # so it is computed once per ctj instead of once per (cti, ctj) pair.
    # (assumes knnlabels[ctj] holds per-cell neighbour counts for ctj -- confirm in _utils)
    has_ctj_key = f"has_{ctj}_neighbor"
    adata.obs[has_ctj_key] = knnlabels[ctj] > 0

    for cti in u_cell_types:
        if cti == ctj:
            continue
        print(cti, ctj)

        # Restrict to cti and check both groups are large enough before paying
        # for a copy of the expression data.
        is_cti = adata.obs[label_key] == cti
        has_ctj = adata.obs.loc[is_cti, has_ctj_key]

        if has_ctj.sum() < min_cells or (~has_ctj).sum() < min_cells:
            print(f"Not enough cells from each class to test {cti} with {ctj} neighbors")
            continue

        # .copy(): adata_cti is written to below (new obs column, rank_genes_groups
        # results), which must not happen on a view of `adata`.
        adata_cti = adata[is_cti].copy()

        ####
        #### LOGISTIC REGRESSION TEST: predict ctj in kNN based on cti expression
        ####

        # train logreg model
        # df_permutations_logreg[cti,ctj], df_importances_logreg[cti,ctj] = _utils.logreg(
        #     X = adata_cti.layers["X_normalised"],
        #     y = adata_cti.obs[f"has_{ctj}_neighbor"],
        #     feature_names=adata.var_names,
        #     scoring=scoring,
        #     test_size=0.2,
        #     n_permutations=n_permutations,
        #     n_repeats=n_repeats,
        #     max_iter=1000,
        #     random_state=0,
        #     importance_mode='coef'
        # )

        # # get significance from gsea and hypergeometric test
        # df_markers_rank_significance_logreg[cti,ctj] = _utils.get_marker_rank_significance(
        #     rnk=df_importances_logreg[cti,ctj]['importances'],
        #     gene_set=ctj_marker_genes,
        #     top_n = top_n)


        ###
        ### DIFF EXPR TEST: check DE genes between cti with ctj neighbor or not
        ###
        # the grouping column is cast to str so the groups are named 'True' / 'False'
        adata_cti.obs[f'{has_ctj_key}_str'] = adata_cti.obs[has_ctj_key].astype(str)
        sc.tl.rank_genes_groups(adata_cti, groupby=f"{has_ctj_key}_str", groups=['True'], reference='False', method="wilcoxon")
        df_diffexpr[cti, ctj] = sc.get.rank_genes_groups_df(adata_cti, group='True').sort_values('pvals_adj')

        # get significance from gsea and hypergeometric test
        df_markers_rank_significance_diffexpr[cti, ctj] = _utils.get_marker_rank_significance(
            rnk=df_diffexpr[cti, ctj].set_index('names')['logfoldchanges'],
            gene_set=ctj_marker_genes,
            top_n=top_n)


        ###
        ### CELL-CELL COMMUNICATION TEST: check communication between cti with ctj neighbor
        ###
        # adata_cti_ctj = adata[adata.obs[label_key].isin([cti, ctj])]
        # lrdata = liana.mt.bivariate(
        #     adata_cti_ctj,
        #     connectivity_key = f'{obsm}_connectivities',
        #     resource_name='consensus', # NOTE: uses HUMAN gene symbols!
        #     local_name='cosine', # Name of the function
        #     global_name=None,
        #     n_perms=30, # Number of permutations to calculate a p-value
        #     mask_negatives=True, # Whether to mask LowLow/NegativeNegative interactions
        #     add_categories=True, # Whether to add local categories to the results
        #     nz_prop=0.0, # Minimum expr. proportion for ligands/receptors and their subunits
        #     use_raw=False,
        #     verbose=True
        #     )


        # lrdata_cti_has_ctj_neighbor = lrdata[(lrdata.obs[label_key] == cti) & lrdata.obs[f"has_{ctj}_neighbor"]]
        # lrdata_cti_has_ctj_neighbor.var['mean_cti_has_ctj_neighbor'] = lrdata_cti_has_ctj_neighbor.X.mean(0).A1
        # lrdata_cti_has_ctj_neighbor.var['std_cti'] = StandardScaler(with_mean=False).fit(lrdata_cti_has_ctj_neighbor.X).scale_ # std for sparse matrix

        # # get significance from gsea and hypergeometric test
        # ctj_marker_lr = [lr for lr in lrdata_cti_has_ctj_neighbor.var_names if any([g in lr for g in ctj_marker_genes]) ]
        # df_markers_rank_significance_lrdata[cti,ctj] = _utils.get_marker_rank_significance(
        #     rnk=lrdata_cti_has_ctj_neighbor.var['mean_cti_has_ctj_neighbor'],
        #     gene_set=ctj_marker_lr,
        #     top_n = top_n_lr)

###
### SAVE OUTPUTS
###
def _concat_or_empty(frames):
    """Stack a dict of DataFrames with pd.concat; return an empty DataFrame if the dict is empty.

    pd.concat raises ValueError ("No objects to concatenate") on an empty
    mapping, which happens whenever a test is disabled or no cell-type pair
    passed the size filters.
    """
    return pd.concat(frames) if frames else pd.DataFrame()


# Collapse the per-(cti, ctj) dicts into single tables; the dict keys become
# the outer index levels of the result.
df_permutations_logreg = _concat_or_empty(df_permutations_logreg)
df_importances_logreg = _concat_or_empty(df_importances_logreg)
df_diffexpr = _concat_or_empty(df_diffexpr)
df_markers_rank_significance_logreg = _concat_or_empty(df_markers_rank_significance_logreg)
df_markers_rank_significance_diffexpr = _concat_or_empty(df_markers_rank_significance_diffexpr)

# NOTE(review): the two markers_rank_significance output paths end in '.json'
# but the (disabled) calls below write parquet -- align suffix and writer before
# re-enabling them.

#logreg
# df_permutations_logreg.to_parquet(out_file_df_permutations)
# df_importances_logreg.to_parquet(out_file_df_importances)
# df_markers_rank_significance_logreg.to_parquet(out_file_df_markers_rank_significance_logreg)

# #diffexpr
# df_diffexpr.to_parquet(out_file_df_diffexpr)
# df_markers_rank_significance_diffexpr.to_parquet(out_file_df_markers_rank_significance_diffexpr)

#liana
# readwrite.write_anndata_folder(lrdata, out_dir_liana_lrdata)