In [None]:
import numpy as np
import pandas as pd
import emdatabase as emdb
import networkx as nx
import numpy as np
import pickle
from scipy.io import mmread
from scipy.spatial import cKDTree
from scipy.stats import pearsonr, spearmanr
from scipy.optimize import linear_sum_assignment

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 4)

from sklearn.metrics import confusion_matrix, f1_score, adjusted_rand_score, ConfusionMatrixDisplay
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors

import maxfuse.match_utils as match_utils
from maxfuse import utils
from maxfuse.model import Fusor

import anndata as ad
import scanpy as sc
import maxfuse as mf
import tangram as tg
import pickle

import seaborn as sns
import json
import plotly.express as px
import os
import requests
import gseapy as gp
import blitzgsea as blitz

## Data Preprocessing

MaxFuse takes its inputs as AnnData (`adata`) objects, so both modalities are converted to AnnData below.

In [None]:
target_acq = 'FinalLiv-27_c002_v001_r001_reg012' # main sample tissue

# Fetch the segmentation output and the per-cell biomarker expression for this
# acquisition from emdatabase, and read the per-cell cell-type calls from CSV.
# NOTE(review): the meaning of the second positional argument (1) is not visible here — confirm.
seg_df = emdb.get_cell_segmentation_output_for_acquisition_id(target_acq, 1)
bio_df = emdb.get_cell_biomarker_expression_for_acquisition_id(target_acq, 1)
class_df = pd.read_csv(f'scratch/SP_SC_data/s4769_cell_types/{target_acq}.cell_types.csv')

# Join the three tables on CELL_ID. pd.merge defaults to an inner join, so a cell
# missing from any one of the tables is dropped from `protein`.
seg_bio_df = pd.merge(seg_df, bio_df, on='CELL_ID')
protein = pd.merge(seg_bio_df, class_df, on='CELL_ID')

# Biomarker columns that get rescaled below.
biomarker_columns = ['BCL6', 'CCR7', 'CD107a', 'CD11c', 'CD138',
                     'CD14', 'CD141', 'CD15', 'CD163', 'CD1c',
                     'CD20', 'CD204', 'CD206', 'CD209', 'CD21',
                     'CD3', 'CD31', 'CD34', 'CD4', 'CD45',
                     'CD45RO', 'CD56', 'CD68', 'CD79a', 'CD8',
                     'CXCL13', 'CXCR5', 'DAPI', 'DCLAMP', 'FAP',
                     'FOXP3', 'GRB', 'HLADR', 'ICOS', 'IDO1',
                     'IFNg', 'INOS', 'KERATIN 8_18', 'Ki67',
                     'LAG3', 'PD1', 'PDL1', 'PNAD', 'Podoplanin',
                     'TCF1', 'TOX', 'Tbet', 'VISTA', 'XCR1',
                     'aSMA', 'eomes']

# Rescale each listed biomarker column independently (MinMaxScaler default range is [0, 1]).
scaler = MinMaxScaler()
protein[biomarker_columns] = scaler.fit_transform(protein[biomarker_columns])

print(protein[biomarker_columns].head())

In [None]:
# check what the CODEX data looks like: one dot per cell at its X/Y position,
# colored by the CELL_TYPE annotation
sns.scatterplot(data=protein, x="X", y="Y", hue = "CELL_TYPE", s = 3)

In [None]:
# Every bio_df column except the CELL_ID key is treated as a protein feature.
# NOTE(review): only the columns listed in biomarker_columns were min-max scaled;
# any other bio_df column would enter the matrix unscaled — confirm the two lists agree.
columns = bio_df.columns
protein_features = [col for col in columns if col != 'CELL_ID']

# convert to AnnData (rows = cells of the merged `protein` table, columns = protein features)
protein_adata = ad.AnnData(
    protein[protein_features].to_numpy(), dtype=np.float32
)
protein_adata.var_names = protein[protein_features].columns

In [None]:
# read in RNA data: a Matrix Market count matrix plus the feature names from the
# 'names' column of a companion CSV
rna = mmread("scratch/SP_SC_data/liver_rna_counts.txt") # rna count as sparse matrix
rna_names = pd.read_csv('scratch/SP_SC_data/liver_rna_names.csv')['names'].to_numpy()

In [None]:
# convert to AnnData; the matrix is converted to CSR first
# NOTE(review): assumes rows are cells and columns follow the order of rna_names — confirm.
rna_adata = ad.AnnData(
    rna.tocsr(), dtype=np.float32
)
rna_adata.var_names = rna_names

In [None]:
# process all RNA features
# per-cell total-count normalization followed by a log1p transform
sc.pp.normalize_total(rna_adata)
sc.pp.log1p(rna_adata)
# flag the 5000 most variable genes
sc.pp.highly_variable_genes(rna_adata, n_top_genes=5000)
# only retain highly variable genes
rna_adata = rna_adata[:, rna_adata.var.highly_variable].copy()
# scale each retained gene (scanpy default: zero mean, unit variance)
sc.pp.scale(rna_adata)

In [None]:
# scRNA clustering for further refining
# The Leiden labels written to rna_adata.obs['leiden'] here are what
# annotate_clusters later uses to split T cells and TAMs.
# PCA
sc.tl.pca(rna_adata, svd_solver='arpack')

# compute neighborhood graph (10 neighbors, first 40 principal components)
sc.pp.neighbors(rna_adata, n_neighbors=10, n_pcs=40)

# run Leiden clustering
sc.tl.leiden(rna_adata, resolution=0.5)

# visualize the clustering
sc.tl.umap(rna_adata)
sc.pl.umap(rna_adata, color=['leiden'])

In [None]:
# read in cell type labels

# RNA labels come from the 'Type' column of the metadata CSV; CODEX labels come
# from the CELL_TYPE column of the merged protein table.
# NOTE(review): assumes the metadata rows are in the same cell order as the RNA
# count matrix — nothing here checks it.
metadata_rna = pd.read_csv('scratch/SP_SC_data/liver_rna_meta.csv')
labels_rna = metadata_rna['Type'].to_numpy()
labels_codex = protein['CELL_TYPE'].to_numpy()

protein_adata.obs['celltype'] = labels_codex
rna_adata.obs['celltype'] = labels_rna

In [None]:
# define marker genes for hierarchical cell type refinement
# Keys are the refined cell-type labels; values are the gene names scored for that label.
# annotate_clusters only scores markers that are present in var_names, so a marker
# that was filtered out (or misspelled) is ignored without any warning.
# NOTE(review): 'CD206' looks like the protein alias of MRC1 rather than a gene
# symbol — confirm whether it can ever match an RNA feature name.
marker_genes = {
    'CD4 T cells': ['CD4', 'FOXP3', 'CD40LG', 'CCR7', 'IL2RA'],
    'CD8 T cells': ['CD8A', 'CD8B', 'GZMB', 'PRF1', 'EOMES',
                    'KLRG1'],
    'Macrophages M2-like': ['CD163', 'MRC1', 'CD206'],
    'Macrophages': ['CD68', 'ITGAM', 'CD14']
}

def annotate_clusters(sub_adata, marker_genes):
    """
    Label every cell of ``sub_adata`` with the best-scoring cell type of its Leiden cluster.

    For each cluster in ``sub_adata.obs['leiden']`` and each candidate cell type, the
    score is the mean over that type's markers (only those present in ``var_names``)
    of the mean expression of the marker within the cluster. The cluster is assigned
    the cell type with the highest score.

    Parameters
    ----------
    sub_adata: AnnData-like
        Must expose ``obs['leiden']``, ``var_names`` and support ``sub_adata[mask]`` /
        ``sub_adata[:, gene].X`` indexing.
    marker_genes: dict of str -> list of str
        Candidate cell type -> marker gene names.

    Returns
    -------
    list of str
        One cell-type label per cell, in the order of ``sub_adata.obs['leiden']``.
    """
    cluster_annotations = {}
    for cluster in sub_adata.obs['leiden'].unique():
        cluster_cells = sub_adata[sub_adata.obs['leiden'] == cluster]
        scores = {}
        for cell_type, markers in marker_genes.items():
            present_markers = [marker for marker in markers if marker in cluster_cells.var_names]
            if present_markers:
                scores[cell_type] = np.mean(
                    [np.mean(cluster_cells[:, marker].X) for marker in present_markers]
                )
            else:
                # No marker of this type survived feature filtering. Averaging an empty
                # list would give NaN, and max() never replaces a leading NaN, so the
                # marker-less type would always win. Rank it last instead.
                scores[cell_type] = -np.inf
        # annotate cluster with the cell type with the highest score
        cluster_annotations[cluster] = max(scores, key=scores.get)

    # map cluster annotations back onto the individual cells
    return [cluster_annotations[cl] for cl in sub_adata.obs['leiden']]

# ensure new categories are included in the 'celltype' column
new_categories = ['CD4 T cells', 'CD8 T cells', 'Macrophages M2-like', 'Macrophages']
rna_adata.obs['celltype'] = rna_adata.obs['celltype'].astype('category')
# Only add labels that are not categories yet: add_categories raises ValueError for
# an existing category, which would break a re-run of this cell or a metadata file
# that already contains one of these labels.
missing_categories = [
    category for category in new_categories
    if category not in rna_adata.obs['celltype'].cat.categories
]
rna_adata.obs['celltype'] = rna_adata.obs['celltype'].cat.add_categories(missing_categories)

# separate and annotate T cells and TAMs
# Each subset is split by its existing Leiden clusters and relabelled with the
# best-scoring refined type; the refined labels are written back by obs index.
t_cells = rna_adata[rna_adata.obs['celltype'] == 'T cells']
t_cell_annotations = annotate_clusters(t_cells, {'CD4 T cells': marker_genes['CD4 T cells'], 'CD8 T cells': marker_genes['CD8 T cells']})
rna_adata.obs.loc[t_cells.obs.index, 'celltype'] = t_cell_annotations

tams = rna_adata[rna_adata.obs['celltype'] == 'TAMs']
tams_annotations = annotate_clusters(tams, {'Macrophages M2-like': marker_genes['Macrophages M2-like'], 'Macrophages': marker_genes['Macrophages']})
rna_adata.obs.loc[tams.obs.index, 'celltype'] = tams_annotations

# remove original 'T cells' and 'TAMs' categories and rename 'CAFs' to 'Fibroblasts'
# remove_categories raises on a category that is already gone, so only drop the ones
# still present (keeps the cell re-runnable).
obsolete_categories = [
    category for category in ['T cells', 'TAMs']
    if category in rna_adata.obs['celltype'].cat.categories
]
rna_adata.obs['celltype'] = rna_adata.obs['celltype'].cat.remove_categories(obsolete_categories)
# NOTE(review): recent pandas versions deprecate Series.replace on categorical data when
# it changes the categories; consider cat.rename_categories once it is confirmed that no
# 'Fibroblasts' label already exists.
rna_adata.obs['celltype'] = rna_adata.obs['celltype'].replace({'CAFs': 'Fibroblasts'})

sc.pl.umap(rna_adata, color=['celltype'])

In [None]:
# refresh labels_rna so it reflects the refined labels now stored in rna_adata.obs['celltype']
labels_rna = rna_adata.obs['celltype'].to_numpy()

In [None]:
def build_sp_graph(protein_adata):
    """
    Build an undirected k-nearest-neighbour graph over the rows of ``protein_adata.X``.

    Nodes are the row indices 0..n-1. Each row is linked to the rows returned by a
    10-neighbour ball-tree query on the same matrix, and every edge carries the query
    distance as its ``weight``. Self matches are skipped.

    NOTE(review): neighbours are searched in ``protein_adata.X`` (the feature matrix),
    not in spatial X/Y coordinates, and the 10 requested neighbours normally include
    the query row itself — confirm both are intended.

    Raises
    ------
    ValueError
        If ``protein_adata.X`` is not two-dimensional.
    """
    features = protein_adata.X
    print("Shape of data_matrix:", features.shape)

    if len(features.shape) != 2:
        raise ValueError("The extracted data matrix should be a 2D array.")

    graph = nx.Graph()
    graph.add_nodes_from(range(features.shape[0]))

    # query the fitted index with its own training points
    knn = NearestNeighbors(n_neighbors=10, algorithm='ball_tree').fit(features)
    distances, indices = knn.kneighbors(features)

    for source, (neighbor_row, distance_row) in enumerate(zip(indices, distances)):
        for target, distance in zip(neighbor_row, distance_row):
            if source == target:
                continue  # no self-loops
            graph.add_edge(source, target, weight=distance)

    return graph

def precompute_nearest_neighbors(G):
    """
    Return ``{node: [neighbour, ...]}`` with at most 10 neighbours per node,
    ordered by increasing edge ``weight``.
    """
    def _closest(node):
        # sort the adjacency entries by edge weight, keep the 10 lightest
        ranked = sorted(G[node].items(), key=lambda entry: entry[1]['weight'])
        return [neighbor for neighbor, _ in ranked[:10]]

    return {node: _closest(node) for node in G.nodes}

def save_nearest_neighbors(nearest_neighbors, filename):
    """Pickle ``nearest_neighbors`` to ``filename``, opened in binary write mode."""
    with open(filename, mode='wb') as out_file:
        pickle.dump(nearest_neighbors, out_file)

# Build the neighbour graph over the CODEX cells, reduce it to a per-cell list of
# closest neighbours, and pickle that dict so the matching step can reload it later.
G = build_sp_graph(protein_adata)
nearest_neighbors = precompute_nearest_neighbors(G)
save_nearest_neighbors(nearest_neighbors, 'scratch/nearest_neighbors.pkl')

print("Nearest neighbors have been precomputed and saved.")

In [None]:
# proteins to map: panel entries (all of them also appear in biomarker_columns)
# whose gene symbol is looked up through MyGene.info below
# NOTE(review): some entries (e.g. 'DAPI', a DNA stain) are not proteins and are
# unlikely to resolve to a meaningful gene — confirm the list.
protein_names = [
    'BCL6', 'CXCL13', 'CXCR5', 'DAPI', 'DCLAMP', 'FAP', 'GRB',
    'ICOS', 'IDO1', 'IFNg', 'INOS', 'KERATIN 8_18', 'LAG3',
    'PNAD', 'TCF1', 'TOX', 'VISTA', 'eomes'
]

# function to query MyGene.info for protein-to-gene mappings
def get_protein_to_gene_mapping(protein_names, timeout=30):
    """
    Look up a gene symbol for each protein name through the MyGene.info query service.

    For every name, a free-text query restricted to human is sent and the symbol of the
    first hit that carries a Swiss-Prot UniProt entry is kept. The match is therefore
    heuristic and should be reviewed by hand.

    Parameters
    ----------
    protein_names: iterable of str
        Names to query.
    timeout: float, default=30
        Per-request timeout in seconds, so a stalled connection cannot hang the notebook.

    Returns
    -------
    dict of str -> str or None
        One entry per queried name; None when no hit with a Swiss-Prot entry
        (and a symbol) was returned.

    Raises
    ------
    requests.HTTPError
        If the service answers with an error status.
    """
    url = "https://mygene.info/v3/query"
    mapping = {}

    for protein_name in protein_names:
        params = {
            'q': protein_name,
            'fields': 'symbol,name,uniprot',
            'species': 'human',
        }
        response = requests.get(url, params=params, timeout=timeout)
        response.raise_for_status()
        data = response.json()

        # Default to None up front so that every queried name ends up in the mapping,
        # including the case "hits returned, but none of them is a Swiss-Prot entry".
        mapping[protein_name] = None
        for item in data.get('hits', []):
            # 'uniprot' may be missing from a hit; treat that as "no Swiss-Prot entry"
            uniprot = item.get('uniprot') or {}
            if 'Swiss-Prot' in uniprot and item.get('symbol') is not None:
                mapping[protein_name] = item['symbol']
                break

    return mapping

# get protein-to-gene mapping (performs one HTTP request per protein name)
protein_to_gene_mapping = get_protein_to_gene_mapping(protein_names)

# tabulate the result for manual review
mapping_df = pd.DataFrame(list(protein_to_gene_mapping.items()), columns=['Protein name', 'Gene name'])
print(mapping_df)


In [None]:
# construct the feature correspondence
# The table is read from a CSV file; the loop that consumes it unpacks each row
# as (protein name, RNA name(s)), so it is expected to have exactly two columns.
correspondence = pd.read_csv('data/complete_protein_gene_conversion.csv')
correspondence.head()

In [None]:
# Collect [rna_name, protein_name] pairs for every correspondence row whose protein
# is in protein_adata and whose gene(s) are in rna_adata.
rna_protein_correspondence = []

for i in range(correspondence.shape[0]):
    # each row unpacks into exactly two values: protein name, RNA name(s)
    curr_protein_name, curr_rna_names = correspondence.iloc[i]
    if curr_protein_name not in protein_adata.var_names:
        continue

    if isinstance(curr_rna_names, float): # check if curr_rna_names is a float (e.g. NaN)
        continue

    if curr_rna_names.find('Ignore') != -1: # some correspondence ignored (e.g. protein isoform to one gene)
        continue
    curr_rna_names = curr_rna_names.split('/') # e.g. one protein to multiple genes
    for r in curr_rna_names:
        if r in rna_adata.var_names:
            rna_protein_correspondence.append([r, curr_protein_name])

# 2-D string array: column 0 = RNA name, column 1 = protein name
# (becomes an empty 1-D array when no pair was collected)
rna_protein_correspondence = np.array(rna_protein_correspondence)

In [None]:
# Columns rna_shared and protein_shared are matched.
# One may encounter "Variable names are not unique" warning,
# this is fine and is because one RNA may encode multiple proteins and vice versa.
# ensure unique column names
rna_adata.var_names_make_unique()
protein_adata.var_names_make_unique()

# keep only the (rna, protein) pairs whose names are both present after the renaming above
rna_protein_valid_correspondence = [
    (rna_name, protein_name)
    for rna_name, protein_name in rna_protein_correspondence
    if rna_name in rna_adata.var_names and protein_name in protein_adata.var_names
]

# fail with a readable message instead of an opaque unpacking/indexing error
if not rna_protein_valid_correspondence:
    raise ValueError(
        'No RNA/protein feature pair is present in both rna_adata and protein_adata; '
        'check the correspondence table against the feature names.'
    )

# unzip the valid correspondences into two aligned tuples of names
rna_valid_names, protein_valid_names = zip(*rna_protein_valid_correspondence)

# perform the indexing: column k of rna_shared corresponds to column k of protein_shared
rna_shared = rna_adata[:, rna_valid_names].copy()
protein_shared = protein_adata[:, protein_valid_names].copy()

In [None]:
# Make sure no column is static
# A shared feature pair is kept only if BOTH sides vary enough: std > 0.5 on the
# RNA side and std > 0.05 on the protein side. The same mask is applied to both
# objects so their columns stay aligned.
mask = (
    (rna_shared.X
     # .toarray()
     .std(axis=0) > 0.5)
    & (protein_shared.X.std(axis=0) > 0.05)
)
rna_shared = rna_shared[:, mask].copy()
protein_shared = protein_shared[:, mask].copy()
print([rna_shared.shape,protein_shared.shape])

In [None]:
# plot UMAP of rna cells based only on rna markers with protein correspondence
# (neighbour graph and embedding are computed on the shared-feature subset only)
sc.pp.neighbors(rna_shared, n_neighbors=15)
sc.tl.umap(rna_shared)
sc.pl.umap(rna_shared, color='celltype')

In [None]:
# plot UMAPs of codex cells based only on protein markers with rna correspondence
# (neighbour graph and embedding are computed on the shared-feature subset only)

sc.pp.neighbors(protein_shared, n_neighbors=15)
sc.tl.umap(protein_shared)
sc.pl.umap(protein_shared, color='celltype')

In [None]:
# Replace the shared AnnData objects by plain copies of their data matrices, which
# is what the rest of the notebook works with. The isinstance guards make the cell
# safe to re-run: after the first run the names already hold arrays, which have no .X.
if isinstance(rna_shared, ad.AnnData):
    rna_shared = rna_shared.X.copy()
if isinstance(protein_shared, ad.AnnData):
    protein_shared = protein_shared.X.copy()

In [None]:
# plot UMAPs of rna cells based on all active rna markers
# NOTE(review): at this point rna_adata already holds the normalized, HVG-filtered,
# scaled values, so `.raw` stores those rather than raw counts — confirm intent.
rna_adata.raw = rna_adata
# recompute the neighbour graph with 15 neighbours and redo the UMAP from it
sc.pp.neighbors(rna_adata, n_neighbors=15)
sc.tl.umap(rna_adata)
sc.pl.umap(rna_adata, color='celltype')

In [None]:
# plot UMAPs of protein cells based on all active protein markers
# (neighbour graph built on the full protein feature matrix, 15 neighbours)

sc.pp.neighbors(protein_adata, n_neighbors=15)
sc.tl.umap(protein_adata)
sc.pl.umap(protein_adata, color='celltype')

In [None]:
# make sure no feature is static
# "active" matrices = all features of each modality; columns whose standard
# deviation is not above 1e-5 are dropped
rna_active = rna_adata.X
protein_active = protein_adata.X
rna_active = rna_active[:, rna_active.std(axis=0) > 1e-5] # these are fine since already using variable features
protein_active = protein_active[:, protein_active.std(axis=0) > 1e-5] # protein are generally variable

In [None]:
# inspect shape of the four matrices
# The two shared matrices are column-aligned (same mask applied to both), so their
# column counts should match.
print(rna_active.shape)
print(protein_active.shape)
print(rna_shared.shape)
print(protein_shared.shape)

In [None]:
# Persist the four matrices as CSV (no index column) so the fitting section can be
# run without repeating the preprocessing.
# NOTE(review): the target directory scratch/results is not created here — it must already exist.
pd.DataFrame(rna_active).to_csv('scratch/results/rna_active.csv', index=False)
pd.DataFrame(protein_active).to_csv('scratch/results/protein_active.csv', index=False)
pd.DataFrame(rna_shared).to_csv('scratch/results/rna_shared.csv', index=False)
pd.DataFrame(protein_shared).to_csv('scratch/results/protein_shared.csv', index=False)

## Fitting MaxFuse

In [None]:
# Reload the four matrices saved at the end of preprocessing as plain NumPy arrays;
# this rebinds the in-memory names to the CSV contents.
rna_active = pd.read_csv('scratch/results/rna_active.csv').values
protein_active = pd.read_csv('scratch/results/protein_active.csv').values
rna_shared = pd.read_csv('scratch/results/rna_shared.csv').values
protein_shared = pd.read_csv('scratch/results/protein_shared.csv').values

### Step I: preparations

In [None]:
# call constructor for Fusor object
# which is the main object for running MaxFuse pipeline

# flipping arr1 and arr2 for downstream r-l match_cells
# Here arr1 is the RNA modality and arr2 is the protein (CODEX) modality; no
# cell-type labels are passed to the Fusor.
fusor = mf.model.Fusor(
    shared_arr1=rna_shared,
    shared_arr2=protein_shared,
    active_arr1=rna_active,
    active_arr2=protein_active,
    labels1=None,
    labels2=None
)

In [None]:
# Partition both modalities into batches that are matched pairwise later on.
# metacell_size=2 is above 1, so find_init_pivots takes its centroid branch for arr1
# instead of using the raw arr1 rows.
fusor.split_into_batches(
    max_outward_size=5000,
    matching_ratio=3,
    metacell_size=2,
    verbose=True
)

In [None]:
# plot top singular values of active_arr1 on a random batch
# (used to choose svd_components1 for construct_graphs)
fusor.plot_singular_values(
    target='active_arr1',
    n_components=None # can also explicitly specify the number of components
)

In [None]:
# plot top singular values of active_arr2 on a random batch
# (used to choose svd_components2 for construct_graphs)
fusor.plot_singular_values(
    target='active_arr2',
    n_components=None
)

In [None]:
# Build the per-batch neighbour graphs and clusterings for both modalities, using
# 40 SVD components for arr1 (RNA) and 20 for arr2 (protein).
fusor.construct_graphs(
    n_neighbors1=15,
    n_neighbors2=15,
    svd_components1=40,
    svd_components2=20,
    resolution1=2,
    resolution2=2,
    # if two resolutions differ by less than resolution_tol,
    # we do not distinguish between them
    resolution_tol=0.1,
    verbose=True
)

### Step II: finding initial pivots

In [None]:
# plot top singular values of shared_arr1 on a random batch
# (used to choose svd_components1 for the initial matching)
fusor.plot_singular_values(
    target='shared_arr1',
    n_components=None,
)

In [None]:
# plot top singular values of shared_arr2 on a random batch
# (used to choose svd_components2 for the initial matching)
fusor.plot_singular_values(
    target='shared_arr2',
    n_components=None
)

In [None]:
# load var_names and cell type labels here
# retained_var_names_df = pd.read_csv('scratch/results/retained_rna_names.csv')
# retained_rna_var_names = var_names_df['Gene Names'].values

# NOTE(review): var_names.csv and labels_rna.csv are not written by any cell visible
# in this part of the notebook — they must already exist on disk.
var_names_df = pd.read_csv('scratch/results/var_names.csv')
rna_var_names = var_names_df['Gene Names'].values

# This replaces the in-memory labels_rna. `.values` of a DataFrame is 2-D, so each
# label is read as labels_rna[i][0] downstream.
labels_rna = pd.read_csv('scratch/results/labels_rna.csv', index_col=0)
labels_rna = labels_rna.values

# load the nearest_neighbors from the pickle file
# (pickle can execute arbitrary code on load — only read files produced by this notebook)
with open('scratch/nearest_neighbors.pkl', 'rb') as file:
    nearest_neighbors = pickle.load(file)

# load interaction data
file_path = 'scratch/rl_interactions/filtered_interaction_scores.csv'
interaction_df = pd.read_csv(file_path)

In [None]:
# Count the interaction pairs whose two genes are both among the retained RNA features.
retained_names_set = set(rna_var_names)

# each set is built from the column it is named after
gene_a_set = set(interaction_df['gene_a'])
gene_b_set = set(interaction_df['gene_b'])

# pairs are stored as (gene_a, gene_b), matching the label printed below
gene_pairs_set = set(zip(interaction_df['gene_a'], interaction_df['gene_b']))
common_pairs = {pair for pair in gene_pairs_set if pair[0] in retained_names_set and pair[1] in retained_names_set}
count_common_pairs = len(common_pairs)

print(f"Number of (gene_a, gene_b) pairs where both genes are in rna_var_names: {count_common_pairs}")
print("Common pairs:", common_pairs)


In [None]:
def find_init_pivots(
        self,
        wt1=0.3, wt2=0.3,
        svd_components1=None, svd_components2=None,
        randomized_svd=False, svd_runs=1,
        verbose=True
):
    """
    Perform initial matching and keep the distance matrix of every batch pair.

    Written as a method body (takes ``self``); it is attached to ``Fusor`` as
    ``find_initial_pivots`` further down in the notebook.

    Parameters
    ----------
    wt1: float, default=0.3
        The shrinkage weight to put on the raw data for arr1.
    wt2: float, default=0.3
        The shrinkage weight to put on the raw data for arr2.
    svd_components1: None or int, default=None
        If not None, perform SVD to reduce the dimension of self.shared_arr1.
    svd_components2: None or int, default=None
        If not None, perform SVD to reduce the dimension of self.shared_arr2.
    randomized_svd: bool, default=False
        Whether to use randomized SVD.
    svd_runs: int, default=1
        Perform multiple runs of SVD and the one with lowest Frobenious reconstruction error is selected.
    verbose: bool, default=True
        Whether to print the progress.

    Returns
    -------
    None
        Results are stored on the instance:
        ``self._init_matching`` (one matching per batch pair, in iteration order) and
        ``self.distance_matrices`` (dict keyed by ``(b1, b2)``).
    """
    self._init_matching = []
    self.distance_matrices = {}  # dictionary to store distance matrices
    for b1, b2 in self._batch1_to_batch2:
        if verbose:
            print(
                'Now at batch {}<->{}...'.format(b1, b2),
                flush=True
            )
        # with metacells enabled, arr1 rows are metacell centroids rather than single
        # cells, so the rows of the stored distance matrix index metacells
        if self.metacell_size > 1:
            arr1 = utils.get_centroids(
                arr=self.shared_arr1[self._batch_to_indices1[b1], :],
                labels=self._metacell_labels1[b1]
            )
        else:
            arr1 = self.shared_arr1[self._batch_to_indices1[b1], :]

        arr2 = self.shared_arr2[self._batch_to_indices2[b2], :]

        # exactly one smoothing input (cluster labels or graph edges) is passed on,
        # depending on self.method
        edges1, edges2, clust_labels1, clust_labels2 = None, None, None, None
        if self.method == 'centroid_shrinkage':
            clust_labels1 = self._labels1[b1]
            clust_labels2 = self._labels2[b2]
        elif self.method == 'graph_smoothing':
            edges1 = self._edges1[b1]
            edges2 = self._edges2[b2]
        else:
            raise ValueError('self.method must be one of \'centroid_shrinkage\' or \'graph_smoothing\'.')

        # notebook-level get_init_matching: returns the distance matrix together with the matching
        dist_matrix, res = get_init_matching(
                arr1=arr1,
                arr2=arr2,
                clust_labels1=clust_labels1,
                clust_labels2=clust_labels2,
                edges1=edges1,
                edges2=edges2,
                wt1=wt1,
                wt2=wt2,
                randomized_svd=randomized_svd,
                svd_runs=svd_runs,
                svd_components1=svd_components1,
                svd_components2=svd_components2,
                verbose=False
            )
        self.distance_matrices[(b1, b2)] = dist_matrix  # save the distance matrix
        self._init_matching.append(res)

    if verbose:
        print('Done!', flush=True)

def get_init_matching(
        arr1, arr2,
        clust_labels1=None, clust_labels2=None,
        edges1=None, edges2=None,
        wt1=0.3, wt2=0.3,
        randomized_svd=True,
        svd_runs=1,
        svd_components1=None, svd_components2=None,
        verbose=True
):
    """
    Assume the features of arr1 and arr2 are column-wise directly comparable,
    obtain a matching by minimizing the correlation distance between arr1 and arr2.

    Parameters
    ----------
    arr1: np.array of shape (n_samples1, n_features1)
        The first data matrix.
    arr2: np.array of shape (n_samples2, n_features2)
        The second data matrix.
    clust_labels1: None or np.array of shape (n_samples1, )
        If not None, then it is the clustering label of the first data matrix,
        and the smoothing of this matrix will be done via cluster centroid shrinkage.
    clust_labels2: None or np.array of shape (n_samples2, )
        Same as clust_labels1 but for the second data matrix.
    edges1: None or list of length two or three
        If not None, then each edge in the graph is (edges[0][i], edges[1][i]) with weight edges[2][i] (if exists)
        and the smoothing of this matrix will be done via graph smoothing.
    edges2: None or scipy.sparse.csr_matrix of shape (n_samples2, n_samples2)
        Same as edges1 but for the second data matrix.
    wt1: float, default=0.3
        The smoothing of the first data matrix will be wt1 * arr1 + (1-wt1) * shrinkage_targets,
        where the shrinkage_targets will be either the cluster centroids or the average of graph neighbors.
    wt2: float, default=0.3
        Same as wt1 but for the second data matrix.
    randomized_svd: bool, default=True
        Whether to use randomized svd.
    svd_runs: int, default=1
        Randomized SVD will result in different runs,
        so if randomized_svd=True, perform svd_runs many randomized SVDs, and pick the one with the
        smallest Frobenious reconstruction error.
        If randomized_svd=False, svd_runs is forced to be 1.
    svd_components1: None or int
        If None, then do not do SVD,
        else, number of components to keep when doing SVD de-noising for the first data matrix.
    svd_components2: None or int
        Same as svd_components1 but for the second data matrix.
    verbose: bool, default=True
        Whether to print the progress.

    Returns
    -------
    dist: distance matrix returned by match_init_cells
        The distance matrix the assignment was solved on.
    matching: tuple of length 3
        rows, cols, vals = matching,
        Each matched pair is rows[i], cols[i], their distance is vals[i].
    """
    # both matrices must have the same number of (shared) features
    assert arr1.shape[1] == arr2.shape[1]
    # labels and edges cannot be specified simultaneously
    assert (clust_labels1 is None) or (edges1 is None)
    assert (clust_labels2 is None) or (edges2 is None)

    arr1, arr2 = utils.drop_zero_variability_columns(arr_lst=[arr1, arr2])

    # smoothing and denoising
    if verbose:
        print("Denoising the data...", flush=True)

    # arr1: optional smoothing (centroid shrinkage or graph smoothing), then SVD de-noising
    if clust_labels1 is not None:
        arr1 = utils.shrink_towards_centroids(arr=arr1, labels=clust_labels1, wt=wt1)
    elif edges1 is not None:
        arr1 = utils.graph_smoothing(arr=arr1, edges=edges1, wt=wt1)
    arr1 = utils.svd_denoise(
        arr=arr1, n_components=svd_components1, randomized=randomized_svd,
        n_runs=svd_runs
    )

    # arr2: same treatment with its own labels/edges and weight
    if clust_labels2 is not None:
        arr2 = utils.shrink_towards_centroids(arr=arr2, labels=clust_labels2, wt=wt2)
    elif edges2 is not None:
        arr2 = utils.graph_smoothing(arr=arr2, edges=edges2, wt=wt2)
    arr2 = utils.svd_denoise(
        arr=arr2, n_components=svd_components2, randomized=randomized_svd,
        n_runs=svd_runs
    )

    # notebook-level match_init_cells: returns (distance matrix, matching)
    dist, res = match_init_cells(arr1=arr1, arr2=arr2, verbose=verbose)
    if verbose:
        print('Initial matching completed!', flush=True)

    return dist, res

def match_init_cells(arr1, arr2, base_dist=None, wt_on_base_dist=0, verbose=True):
    """
    Match rows of arr1 to rows of arr2 by linear assignment on the correlation
    distance, and hand back the distance matrix along with the matching.

    Parameters
    ----------
    arr1: np.array of shape (n_samples1, n_features)
        The first data matrix
    arr2: np.array of shape (n_samples2, n_features)
        The second data matrix
    base_dist: None or np.ndarray of shape (n_samples1, n_samples2)
        Baseline distance matrix
    wt_on_base_dist: float between 0 and 1
        The final distance matrix to use is (1-wt_on_base_dist) * dist[arr1, arr2] + wt_on_base_dist * base_dist
    verbose: bool, default=True
        Whether to print the progress

    Returns
    -------
    dist: np.ndarray
        The distance matrix the assignment was solved on (after interpolation, if any)
    matching: tuple (rows, cols, vals)
        Each matched pair of rows[i], cols[i], their distance is vals[i]
    """
    def _report(message):
        # progress output is emitted only in verbose mode
        if verbose:
            print(message, flush=True)

    _report('Start the matching process...')
    _report('Computing the distance matrix...')
    dist = utils.cdist_correlation(arr1, arr2)

    if base_dist is not None:
        _report(f'Interpolating {1-wt_on_base_dist} * dist[arr1, arr2] + {wt_on_base_dist} * base_dist...')
        dist = (1-wt_on_base_dist) * dist + wt_on_base_dist * base_dist

    _report('Solving linear assignment...')
    rows, cols = linear_sum_assignment(dist)
    _report('Linear assignment completed!')

    # distance of every matched pair, read in one indexing step
    matched_vals = dist[rows, cols]

    return dist, (rows, cols, matched_vals)

# Replace Fusor.find_initial_pivots with the notebook version defined above, which
# additionally stores each batch pair's distance matrix in self.distance_matrices.
Fusor.find_initial_pivots = find_init_pivots

In [None]:
# based off normal dist values
# Runs the patched initial matching; afterwards fusor.distance_matrices holds one
# distance matrix per batch pair for the receptor/ligand-aware re-matching.
fusor.find_initial_pivots(
    wt1=0.3, wt2=0.3,
    svd_components1=15, svd_components2=20
)

### Step III: finding refined pivots

In [None]:
# plot top canonical correlations in a random batch
# (used to choose cca_components for the refinement step)
fusor.plot_canonical_correlations(
    svd_components1=50,
    svd_components2=None,
    cca_components=45
)

In [None]:
# precompute scores for all interaction types in a vectorized manner
def precompute_interaction_scores(interaction_df, score_start_col=13):
    """
    Sum the positive entries of every interaction-score column.

    Parameters
    ----------
    interaction_df: pd.DataFrame
        Interaction table whose score columns start at position ``score_start_col``;
        each score column is named after an interaction type.
    score_start_col: int, default=13
        Zero-based position of the first score column (the default keeps the original
        layout, where scores start at the 14th column).

    Returns
    -------
    pd.DataFrame
        Columns ``Interaction_Type`` (score column name) and ``Total_Score`` (sum of
        that column's values that are > 0), one row per score column, in column order.
    """
    score_columns = interaction_df.columns[score_start_col:]

    # only strictly positive scores contribute; NaN and non-positive values are ignored
    interaction_scores = {
        column: interaction_df[column][interaction_df[column] > 0].sum()
        for column in score_columns
    }

    return pd.DataFrame(list(interaction_scores.items()), columns=['Interaction_Type', 'Total_Score'])

# total positive score per interaction-type column; looked up by "typeA|typeB" name later
interaction_scores_df = precompute_interaction_scores(interaction_df)

In [None]:
def compute_interaction_score(index1, index2, interaction_scores_df, labels_rna):
    """
    Return the interaction score between the cell types of two RNA cells.

    The score table is keyed by ``"typeA|typeB"``. Both orientations of the pair are
    looked up and the larger score is returned. An orientation missing from the table
    is skipped; if neither orientation is present the score is 0.0.

    Parameters
    ----------
    index1, index2: int
        Row indices into ``labels_rna``.
    interaction_scores_df: pd.DataFrame
        Table with columns ``Interaction_Type`` and ``Total_Score``.
    labels_rna: 2-D array-like
        ``labels_rna[i][0]`` is the cell-type label of RNA cell ``i``.

    Returns
    -------
    float
        The larger of the available orientation scores, or 0.0 when the pair is unknown.
    """
    sc_cell_type = labels_rna[index1][0]
    matched_sc_cell_type = labels_rna[index2][0]

    cell_type_pair = f"{sc_cell_type}|{matched_sc_cell_type}"
    reversed_cell_type_pair = f"{matched_sc_cell_type}|{sc_cell_type}"

    interaction_types = interaction_scores_df['Interaction_Type']
    available_scores = []
    for pair_key in (cell_type_pair, reversed_cell_type_pair):
        hits = interaction_scores_df.loc[interaction_types == pair_key, 'Total_Score']
        # .iloc[0] on an empty selection raises IndexError, so a table that lists a
        # pair in only one orientation used to crash the whole matching
        if not hits.empty:
            available_scores.append(hits.iloc[0])

    return max(available_scores) if available_scores else 0.0

def precompute_indices(arr, rna_active):
    """
    For every row of ``arr``, return the index of the row of ``rna_active`` with the
    highest cosine similarity (all pairwise similarities are computed in one call).
    """
    return np.argmax(cosine_similarity(arr, rna_active), axis=1)

def adjust_distance_matrix(arr1, dist, initial_cols, interaction_scores_df, labels_rna, nearest_neighbors, rl_factor=0.1, n_candidates=100):
    """
    Lower the distance of candidate matches according to receptor/ligand interaction scores.

    For every row ``i`` of ``arr1`` the ``n_candidates`` closest columns of ``dist`` are
    visited. For each candidate column ``j``, the neighbours ``nearest_neighbors[j]``
    that were matched in the initial assignment contribute the interaction score between
    row ``i`` and the row matched to that neighbour. When the running mean of the
    collected scores is positive, ``rl_factor * mean`` is subtracted from ``dist[i, j]``.

    Parameters
    ----------
    arr1: np.ndarray of shape (n_rows, n_features)
        Rows being matched; each is mapped to its most cosine-similar row of the
        module-level ``rna_active`` to obtain a cell-type label.
    dist: np.ndarray of shape (n_rows, n_cols)
        Initial distance matrix (not modified; a copy is adjusted).
    initial_cols: array-like of int
        ``initial_cols[k]`` is the column matched to row ``k`` by the initial assignment.
    interaction_scores_df: pd.DataFrame
        Score table passed through to ``compute_interaction_score``.
    labels_rna: 2-D array-like
        Cell-type labels passed through to ``compute_interaction_score``.
    nearest_neighbors: dict
        Maps a column index of ``dist`` to an iterable of neighbouring column indices.
    rl_factor: float, default=0.1
        Weight of the interaction bonus.
    n_candidates: int, default=100
        Number of closest columns examined per row.

    Returns
    -------
    np.ndarray
        Adjusted copy of ``dist``.

    Notes
    -----
    NOTE(review): ``interaction_scores`` is reset once per row, not once per candidate,
    so the mean applied to candidate ``j`` also contains the scores gathered for the
    closer candidates visited before it. This is the original behaviour and is kept
    unchanged — confirm whether a per-candidate mean was intended.
    NOTE(review): ``j`` is a column index of ``dist`` (local to the batch pair);
    ``nearest_neighbors`` must use the same indexing — confirm when several batches exist.
    """
    # precompute indices for arr1 against rna_active (module-level matrix)
    indices_arr1 = precompute_indices(arr1, rna_active)

    # Position of the first initial match for every matched column. This replaces a
    # full np.where scan of initial_cols in the innermost loop with a dict lookup.
    first_match_for_col = {}
    for position, col in enumerate(np.asarray(initial_cols).tolist()):
        first_match_for_col.setdefault(col, position)

    # The score depends only on the two RNA indices, so each pair is looked up once.
    score_cache = {}

    def _pair_score(rna_index_a, rna_index_b):
        key = (rna_index_a, rna_index_b)
        if key not in score_cache:
            score_cache[key] = compute_interaction_score(
                rna_index_a, rna_index_b, interaction_scores_df, labels_rna
            )
        return score_cache[key]

    # initialize a copy of the distance matrix for adjustments
    adjusted_dist = dist.copy()

    # iterate through each cell in arr1
    for i in range(arr1.shape[0]):
        # columns of the n_candidates smallest distances for this row
        closest_sp_sc_indices = np.argsort(dist[i])[:n_candidates]
        rna_index_i = int(indices_arr1[i])

        interaction_scores = []
        for j in closest_sp_sc_indices:
            for sp_index in nearest_neighbors[j]:
                matched_sc_index = first_match_for_col.get(sp_index)
                if matched_sc_index is not None:  # neighbour took part in the initial matching
                    interaction_scores.append(
                        _pair_score(rna_index_i, int(indices_arr1[matched_sc_index]))
                    )

            # calculate mean interaction score if any scores were calculated
            mean_interaction_score = np.mean(interaction_scores) if interaction_scores else 0
            if mean_interaction_score > 0:
                adjusted_dist[i, j] -= rl_factor * mean_interaction_score

    return adjusted_dist

def match_rl_cells(arr1, arr2, init_matching, batch_index1, batch_index2, base_dist=None, wt_on_base_dist=0, verbose=True):
    """
    Re-solve the matching of one batch pair after biasing the initial distance matrix
    with receptor/ligand interaction scores.

    The distance matrix is not recomputed from arr1/arr2: it is read from the
    module-level ``fusor.distance_matrices[(batch_index1, batch_index2)]`` filled in by
    the initial matching. The adjustment also relies on the module-level
    ``interaction_scores_df``, ``labels_rna`` and ``nearest_neighbors``.

    Parameters
    ----------
    arr1: np.array of shape (n_samples1, n_features)
        The first data matrix (used by the interaction adjustment)
    arr2: np.array of shape (n_samples2, n_features)
        The second data matrix (only its shape is reported)
    init_matching: sequence
        Initial matching; ``init_matching[0]`` are the matched rows and
        ``init_matching[1]`` the matched columns.
    batch_index1, batch_index2: hashable
        Batch identifiers forming the key into ``fusor.distance_matrices``.
    base_dist: None or np.ndarray of shape (n_samples1, n_samples2)
        Baseline distance matrix
    wt_on_base_dist: float between 0 and 1
        The final distance matrix to use is (1-wt_on_base_dist) * dist + wt_on_base_dist * base_dist
    verbose: bool, default=True
        Whether to print the progress and the diagnostic shapes/values

    Returns
    -------
    rows, cols, vals: np.ndarray
        Each matched pair of rows[i], cols[i], their adjusted distance is vals[i]
    """
    if verbose:
        print(f'arr1 shape: {arr1.shape}')
        print(f'arr2 shape: {arr2.shape}')
        print('Start the matching process...', flush=True)
        print('Computing the distance matrix...', flush=True)

    # reuse the distance matrix stored by the initial matching for this batch pair
    dist = fusor.distance_matrices[(batch_index1, batch_index2)]

    if base_dist is not None:
        if verbose:
            print(f'Interpolating {1-wt_on_base_dist} * dist[arr1, arr2] + {wt_on_base_dist} * base_dist...', flush=True)
        dist = (1-wt_on_base_dist) * dist + wt_on_base_dist * base_dist

    if verbose:
        print('Solving initial linear assignment...', flush=True)
        print(f'dist shape: {dist.shape}')
        print(f'dist : {dist}')

    # the initial assignment is taken from init_matching instead of being solved again
    initial_rows, initial_cols = init_matching[0], init_matching[1]

    if verbose:
        print(f'initial rows shape: {initial_rows.shape}')
        print(f'initial_rows: {initial_rows}')
        print(f'initial cols shape: {initial_cols.shape}')
        print(f'initial_cols: {initial_cols}')
        print('Initial linear assignment completed!', flush=True)
        print('Adjusting distance matrix based on receptor/ligand interactions...', flush=True)

    # Adjust the distance matrix based on receptor/ligand interactions
    dist = adjust_distance_matrix(arr1, dist, initial_cols, interaction_scores_df, labels_rna, nearest_neighbors)

    if verbose:
        print(f'new dist shape: {dist.shape}')
        print(f'new dist : {dist}')
        print('Solving adjusted linear assignment...', flush=True)

    # Perform adjusted matching
    adjusted_rows, adjusted_cols = linear_sum_assignment(dist)

    if verbose:
        print('Adjusted linear assignment completed!', flush=True)

    # adjusted distance of every matched pair
    return adjusted_rows, adjusted_cols, dist[adjusted_rows, adjusted_cols]

# match_utils.match_cells = match_cells

In [None]:
def ref_pivots(
        self,
        wt1=0.5, wt2=0.5,
        svd_components1=None, svd_components2=None,
        cca_components=None,
        filter_prop=0,
        n_iters=1,
        randomized_svd=False, svd_runs=1,
        cca_max_iter=2000,
        verbose=True
):
    """
    Perform refined matching for every batch pair via ``get_ref_matching``.

    The per-batch results are stored in ``self._refined_matching`` (one entry per
    pair in ``self._batch1_to_batch2``), and the SVD/CCA settings are saved on
    ``self`` for later use.

    Parameters
    ----------
    wt1: float, default=0.5
        The shrinkage weight to put on the raw data for arr1.
    wt2: float, default=0.5
        The shrinkage weight to put on the raw data for arr2.
    svd_components1: None or int, default=None
        If not None, perform SVD to reduce the dimension of self.active_arr1 before feeding it to CCA.
    svd_components2: None or int, default=None
        If not None, perform SVD to reduce the dimension of self.active_arr2 before feeding it to CCA.
    cca_components: None or int, default=None
        Number of CCA components.
        If None, it is resolved per batch to 100 or the number of columns of the
        batch's arr1 or arr2, whichever is smaller.
    filter_prop: float, default=0.
        CCA is performed on top 1-filter_prop slice of the data on which the matched distances are smallest.
    n_iters: int, default=1
        Number of refinement iterations.
    randomized_svd: bool, default=False
        Whether to perform randomized SVD.
    svd_runs: int, default=1
        Perform multiple runs of SVD and the one with lowest Frobenius reconstruction error is selected.
    cca_max_iter: int, default=2000
        Maximum iteration number for CCA.
    verbose: bool, default=True
        Whether to print the progress.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If self.method is neither 'centroid_shrinkage' nor 'graph_smoothing'.
    """
    # save cca parameters for later use (cca_components is stored exactly as passed)
    self._svd_components1_for_cca_embedding = svd_components1
    self._svd_components2_for_cca_embedding = svd_components2
    self._randomized_svd_for_cca_embedding = randomized_svd
    self._svd_runs_for_cca_embedding = svd_runs
    self._cca_components = cca_components
    self._cca_max_iter = cca_max_iter

    self._refined_matching = []
    for batch_idx, (b1, b2) in enumerate(self._batch1_to_batch2):
        if verbose:
            print(
                'Now at batch {}<->{}...'.format(b1, b2),
                flush=True
            )

        # arr1 is collapsed to metacell centroids when metacells are in use
        if self.metacell_size > 1:
            arr1 = utils.get_centroids(
                arr=self.active_arr1[self._batch_to_indices1[b1], :],
                labels=self._metacell_labels1[b1]
            )
        else:
            arr1 = self.active_arr1[self._batch_to_indices1[b1], :]

        arr2 = self.active_arr2[self._batch_to_indices2[b2], :]

        # exactly one smoothing source (cluster labels or graph edges) is forwarded
        edges1, edges2, clust_labels1, clust_labels2 = None, None, None, None
        if self.method == 'centroid_shrinkage':
            clust_labels1 = self._labels1[b1]
            clust_labels2 = self._labels2[b2]
        elif self.method == 'graph_smoothing':
            edges1 = self._edges1[b1]
            edges2 = self._edges2[b2]
        else:
            raise ValueError('self.method must be one of \'centroid_shrinkage\' or \'graph_smoothing\'.')

        # get_ref_matching asserts 1 <= cca_components, which fails on None,
        # so apply the documented fallback here
        batch_cca_components = cca_components
        if batch_cca_components is None:
            batch_cca_components = min(100, arr1.shape[1], arr2.shape[1])

        self._refined_matching.append(
            get_ref_matching(
                init_matching=self._init_matching[batch_idx],
                arr1=arr1,
                arr2=arr2,
                batch_index1=b1,  # Pass batch index for arr1
                batch_index2=b2,  # Pass batch index for arr2
                randomized_svd=randomized_svd,
                svd_runs=svd_runs,
                svd_components1=svd_components1,
                svd_components2=svd_components2,
                clust_labels1=clust_labels1,
                clust_labels2=clust_labels2,
                edges1=edges1,
                edges2=edges2,
                wt1=wt1,
                wt2=wt2,
                n_iters=n_iters,
                filter_prop=filter_prop,
                cca_components=batch_cca_components,
                cca_max_iter=cca_max_iter,
                verbose=False  # the per-batch routine is always run quietly
            )
        )

    if verbose:
        print('Done!', flush=True)

def get_ref_matching(
        init_matching, arr1, arr2,
        batch_index1, batch_index2,
        randomized_svd=False, svd_runs=1,
        svd_components1=None, svd_components2=None,
        clust_labels1=None, clust_labels2=None,
        edges1=None, edges2=None,
        wt1=0.5, wt2=0.5,
        n_iters=3, filter_prop=0,
        cca_components=15,
        cca_max_iter=2000,
        verbose=True
):
    """
    Refinement of init_matching.

    init_matching is first replaced by the output of ``match_rl_cells`` and then
    refined for n_iters rounds with ``match_utils.get_refined_matching_one_iter``.

    Parameters
    ----------
    init_matching: list
        init_matching[0][i], init_matching[1][i] is a matched pair,
        and init_matching[2][i] is the distance for this pair.
    arr1: np.array of shape (n_samples1, n_features1)
        The first data matrix.
    arr2: np.array of shape (n_samples2, n_features2)
        The second data matrix.
    batch_index1:
        Batch index of arr1; forwarded unchanged to ``match_rl_cells``.
    batch_index2:
        Batch index of arr2; forwarded unchanged to ``match_rl_cells``.
    randomized_svd: bool, default=False
        Whether to use randomized SVD
    svd_runs: int, default=1
        Randomized SVD will result in different runs,
        so if randomized_svd=True, perform svd_runs many randomized SVDs, and pick the one with the
        smallest Frobenius reconstruction error.
        If randomized_svd=False, svd_runs is forced to be 1.
    svd_components1: None or int
        If None, then do not do SVD,
        else, number of components to keep when doing SVD de-noising for the first data matrix
        before feeding into CCA.
    svd_components2: None or int
        Same as svd_components1 but for the second data matrix.
    clust_labels1: None or np.array of shape (n_samples1, )
        If not None, then it is the clustering label of the first data matrix,
        and the smoothing of this matrix will be done via cluster centroid shrinkage.
    clust_labels2: None or np.array of shape (n_samples2, )
        Same as clust_labels1 but for the second data matrix.
    edges1: None or list of length two or three
        If not None, then each edge in the graph is (edges[0][i], edges[1][i]) with weight edges[2][i] (if exists)
        and the smoothing of this matrix will be done via graph smoothing.
    edges2: None or scipy.sparse.csr_matrix of shape (n_samples2, n_samples2)
        Same as edges1 but for the second data matrix.
    wt1: float, default=0.5
        The smoothing of the first data matrix will be wt1 * (cca embedding of arr1) + (1-wt1) * shrinkage_targets,
        where the shrinkage_targets will be either the cluster centroids or the average of graph neighbors.
    wt2: float, default=0.5
        Same as wt1 but for the second data matrix.
    n_iters: int, default=3
        Number of refinement iterations.
    filter_prop: float, default=0
        Proportion of matched pairs to discard before feeding into refinement iterations.
    cca_components: int, default=15
        Number of CCA components.
    cca_max_iter: int, default=2000,
        Maximum number of CCA iterations.
    verbose: bool, default=True
        Whether to print the progress.

    Returns
    -------
    matching: list of length 3
        rows, cols, vals = matching,
        Each matched pair is rows[i], cols[i], their distance is vals[i].
    """
    ns = [len(x) for x in init_matching]
    assert ns[0] == ns[1] == ns[2]
    # labels and edges can not be specified simultaneously
    assert (clust_labels1 is None) or (edges1 is None)
    assert (clust_labels2 is None) or (edges2 is None)
    assert isinstance(n_iters, int) and n_iters >= 1
    assert 0 <= int(ns[0] * filter_prop) < ns[0]

    assert 1 <= cca_components <= min(arr1.shape[1], arr2.shape[1])

    # incorporate receptor-ligand interactions first!
    if verbose:
        print('Performing initial matching based on full feature set...', flush=True)

    # Re-match on the full feature set (before any columns are dropped or SVD is applied);
    # the result replaces the init_matching passed in.
    init_matching = match_rl_cells(arr1=arr1, arr2=arr2, init_matching=init_matching, batch_index1=batch_index1, batch_index2=batch_index2, verbose=verbose)

    arr1 = utils.drop_zero_variability_columns(arr_lst=[arr1])[0]
    arr2 = utils.drop_zero_variability_columns(arr_lst=[arr2])[0]

    if verbose:
        print('Normalizing and reducing the dimension of the data...', flush=True)
    arr1_svd = utils.svd_embedding(
        arr=arr1, n_components=svd_components1,
        randomized=randomized_svd, n_runs=svd_runs
    )
    arr2_svd = utils.svd_embedding(
        arr=arr2, n_components=svd_components2,
        randomized=randomized_svd, n_runs=svd_runs
    )

    # iterative refinement, each round starting from the previous round's matching
    cca_matching = init_matching
    for it in range(n_iters):
        if verbose:
            print('Now at iteration {}:'.format(it), flush=True)
        cca_matching = match_utils.get_refined_matching_one_iter(
            init_matching=cca_matching,
            arr1=arr1_svd, arr2=arr2_svd,
            clust_labels1=clust_labels1,
            clust_labels2=clust_labels2, edges1=edges1, edges2=edges2,
            wt1=wt1, wt2=wt2, filter_prop=filter_prop,
            cca_components=cca_components, cca_max_iter=cca_max_iter, verbose=verbose
        )

    # NOTE(review): the original ran an extra utils.cca_embedding here and discarded
    # its outputs; removed as dead work, assuming that call has no side effects.

    if verbose:
        print('Refined matching completed!', flush=True)
    return cca_matching

# Monkey-patch the MaxFuse Fusor class so that fusor.refine_pivots(...) runs ref_pivots from this notebook.
Fusor.refine_pivots = ref_pivots

In [None]:
# Refine the pivot matching. Fusor.refine_pivots is patched to ref_pivots, so each batch
# goes through get_ref_matching, which re-matches with match_rl_cells before refining.
fusor.refine_pivots(
    wt1=0.3, wt2=0.3,
    svd_components1=40, svd_components2=None,
    cca_components=30,
    n_iters=1,
    randomized_svd=False,
    svd_runs=1,
    verbose=True
)

In [None]:
# Filter the pivot matches with filter_prop=0.5 (presumably drops the worst-scoring half; confirm in the MaxFuse docs).
fusor.filter_bad_matches(target='pivot', filter_prop=0.5)

In [None]:
# Fetch the pivot matching with order=(2, 1); the notebook author flipped arr1 and arr2, hence the flipped order.
pivot_matching = fusor.get_matching(order=(2, 1),target='pivot') # flipped order because arr1 and arr2 are flipped

# Cell-type-level matching accuracy of the pivots
lv1_acc = mf.metrics.get_matching_acc(matching=pivot_matching,
    labels1=labels_rna, # labels for the first index list of the matching
    labels2=labels_codex, # labels for the second index list of the matching
    order = (2,1) # same flipped order as in get_matching above
)
lv1_acc

In [None]:
# Inspect the first pivot pair: the two matched indices and the third entry (the pair's score/distance).
[pivot_matching[0][0], pivot_matching[1][0], pivot_matching[2][0]]

### Step IV: propagation

In [None]:
# Propagate the matching beyond the pivots, using the same SVD settings as refine_pivots (40 / None).
fusor.propagate(
    svd_components1=40,
    svd_components2=None,
    wt1=0.7,
    wt2=0.7,
)

In [None]:
# Filter the propagated matches with filter_prop=0.3 (the pivots were filtered with 0.5 earlier).
fusor.filter_bad_matches(
    target='propagated',
    filter_prop=0.3
)

In [None]:
# Matching over the full data set; order=(2, 1) mirrors the flipped order used for the pivot matching.
full_matching = fusor.get_matching(order=(2, 1), target='full_data') # flipped, as for the pivots

In [None]:
# compute the cell type level matching accuracy, for the full (filtered version) dataset
# NOTE(review): unlike the pivot accuracy cell, no `order` is passed here — confirm this is intended.
lv1_acc = mf.metrics.get_matching_acc(matching=full_matching,
    labels1=labels_rna, # labels for full_matching[0]
    labels2=labels_codex # labels for full_matching[1]
)
lv1_acc

## Step V: downstream analysis

In [None]:
# Joint embedding of all arr1 rows and of only those arr2 rows that appear in full_matching[1].
rna_cca, protein_cca_sub = fusor.get_embedding(
    active_arr1=fusor.active_arr1,
    active_arr2=fusor.active_arr2[full_matching[1],:] # CODEX cells that remain after filtering
)

In [None]:
# Flatten labels_rna to 1-D (a no-op if it is already 1-D, so re-running this cell is safe).
labels_rna = labels_rna.ravel()  # this flattens `labels_rna` to 1D if it is 2D
print(labels_rna.shape)
print(labels_codex.shape)

In [None]:
np.random.seed(42)
# Subsample at most 13000 matched protein cells. The cap prevents
# np.random.choice(..., replace=False) from raising when fewer cells are available.
subs = min(13000, protein_cca_sub.shape[0])
randix = np.random.choice(protein_cca_sub.shape[0], subs, replace=False)

dim_use = 15 # dimensions of the CCA embedding to be used for UMAP etc.

# Stack all RNA cells on top of the subsampled protein cells (first dim_use CCA dimensions only)
cca_adata = ad.AnnData(
    np.concatenate((rna_cca[:, :dim_use], protein_cca_sub[randix, :dim_use]), axis=0),
    dtype=np.float32
)
cca_adata.obs['data_type'] = ['rna'] * rna_cca.shape[0] + ['protein'] * subs
# protein_cca_sub rows follow full_matching[1], so their labels are indexed the same way
cca_adata.obs['cell_type'] = list(np.concatenate((
    labels_rna, labels_codex[full_matching[1]][randix]), axis=0))

In [None]:
# Build the neighbour graph and UMAP on the joint CCA embedding, then colour by modality.
sc.pp.neighbors(cca_adata, n_neighbors=15)
sc.tl.umap(cca_adata)
sc.pl.umap(cca_adata, color='data_type')

In [None]:
# Same UMAP, coloured by the cell-type labels stored in cca_adata.obs['cell_type'].
sc.pl.umap(cca_adata, color='cell_type')

## Step VI: spatial profiling analysis

In [None]:
# rna_adata.write("scratch/results/rna_adata.h5ad")
# protein_adata.write("scratch/results/protein_adata.h5ad")
# protein.to_csv('scratch/results/protein.csv', index=False)

# custom JSON encoder that converts NumPy scalars and arrays to built-in Python types
class JSONEncoder(json.JSONEncoder):
    """json.JSONEncoder that also serializes NumPy bools, ints, floats and arrays."""

    def default(self, obj):
        """Return a JSON-serializable equivalent of obj, or defer to the base class (TypeError)."""
        # np.bool_ is not a subclass of np.integer, so it needs its own branch
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# Write full_matching to a JSON file; JSONEncoder converts the NumPy values it contains.
# NOTE(review): the file name hard-codes run settings ("rl0.1_100") — confirm it matches the current run.
with open('scratch/results/full_matching_results_rl0.1_100.json', 'w') as jsonfile:
    json.dump(full_matching, jsonfile, cls=JSONEncoder, indent=4)

In [None]:
# One row per matched pair: the two matched indices and the pair's score
matching_df = pd.DataFrame(list(zip(full_matching[0], full_matching[1], full_matching[2])),
                           columns=['mod1_indx', 'mod2_indx', 'score'])

# Attach the protein row for each match (mod2_indx is looked up in protein's index),
# which brings in the X/Y coordinates used below.
merged_df = matching_df.merge(protein, left_on='mod2_indx', right_index=True)


# All protein cells in blue, matched cells in red on top
plt.scatter(protein['X'], protein['Y'], c='blue', label='Spatial Proteomics', alpha=0.3, s=1)
plt.scatter(merged_df['X'], merged_df['Y'], c='red', label='RNA (mod1)', alpha=0.3, s=1)

plt.xlabel('Centroid X')
plt.ylabel('Centroid Y')
plt.title('RNA Mapped on Spatial Proteomics Data')
plt.legend()

plt.show()

In [None]:
labels_rna_df = pd.DataFrame(labels_rna, columns=['cell_type'])

# Attach the RNA cell type of each matched pair (mod1_indx looked up in labels_rna_df's index).
# Any 'cell_type' column left by a previous run of this cell is dropped first, so that
# re-running does not create 'cell_type_x' / 'cell_type_y' columns.
merged_df = (
    merged_df
    .drop(columns=['cell_type'], errors='ignore')
    .merge(labels_rna_df, left_on='mod1_indx', right_index=True)
)

plt.scatter(protein['X'], protein['Y'], c='blue', label='Spatial Proteomics', alpha=0.3, s=1)

# define a color map for cell types; the tab10 palette is cycled so that more than
# ten cell types no longer raise a KeyError (colors then repeat)
cell_types = merged_df['cell_type'].unique()
palette = plt.cm.get_cmap('tab10').colors
color_map = {cell_type: palette[i % len(palette)] for i, cell_type in enumerate(cell_types)}

# plot RNA data points based on cell types
for cell_type in cell_types:
    subset = merged_df[merged_df['cell_type'] == cell_type]
    plt.scatter(subset['X'], subset['Y'], c=[color_map[cell_type]], label=f'RNA ({cell_type})', alpha=0.6, s=5)

plt.xlabel('Centroid X')
plt.ylabel('Centroid Y')
plt.title('RNA Mapped on Spatial Proteomics Data')
plt.legend()

plt.show()

In [None]:
# create a confusion matrix comparing "CELL_TYPE" vs "cell_type"
# NOTE(review): this name shadows sklearn.metrics.confusion_matrix imported at the top of
# the notebook; after this cell runs, the sklearn function is no longer reachable by that name.
confusion_matrix = pd.crosstab(merged_df['CELL_TYPE'], merged_df['cell_type'])

# Heatmap of raw counts (fmt='d'): rows are CELL_TYPE, columns are cell_type
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='viridis', cbar=True,
            xticklabels=True, yticklabels=True)
plt.title('Confusion Matrix: CODEX (Ground-truth) vs scRNA-seq (Predicted)')
plt.xlabel('scRNA-seq (Predicted)')
plt.ylabel('CODEX (Ground)')
plt.xticks(rotation=45)
plt.yticks(rotation=0)

# NOTE(review): the output file name hard-codes run settings ("rl_factor0.1_500") — confirm it matches the current run.
plt.savefig('scratch/figures/rl_factor0.1_500.png', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# NOTE(review): `df` is not defined in the visible part of this notebook; it needs both
# 'CELL_TYPE' and 'cell_type' columns, so presumably merged_df was intended — confirm.
df_copy = df.copy()

# custom mappings for cell type categories
# NOTE(review): 'INFg+' is mapped to 'B cells'; the biomarker panel spells the marker 'IFNg' —
# confirm both the label spelling in the data and the intended target group.
cell_type_map = {
    'Macrophages': 'Myeloid cells',
    'Macrophages M2-like': 'Myeloid cells',
    'Dendritic cells': 'Myeloid cells',
    'Neutrophils': 'Myeloid cells',
    'B cells': 'B cells',
    'INFg+': 'B cells',
    'Epithelium (INOS+)': 'Epithelium',
    'Epithelium (INOS-)': 'Epithelium'
}

# apply mappings to the copy
df_copy['CELL_TYPE'] = df_copy['CELL_TYPE'].replace(cell_type_map)
df_copy['cell_type'] = df_copy['cell_type'].replace({
    'Macrophages M2-like': 'Macrophages',
    'Macrophages': 'Macrophages'
})

# remove 'Endothelial cells'
filtered_df = df_copy[df_copy['CELL_TYPE'] != 'Endothelial cells']

# NOTE(review): shadows sklearn.metrics.confusion_matrix imported at the top of the notebook.
confusion_matrix = pd.crosstab(filtered_df['CELL_TYPE'], filtered_df['cell_type'])

# order for plotting
y_order = ['B cells', 'CD4 T cells', 'CD8 T cells', 'Fibroblasts', 'Myeloid cells', 'Epithelium']
x_order = ['B cells', 'CD4 T cells', 'CD8 T cells', 'Fibroblasts', 'Macrophages', 'Malignant cells']

# reorder the DataFrame for plotting; categories missing from the crosstab become all-zero
# rows/columns, and categories not listed in y_order / x_order are dropped from the plot
confusion_matrix = confusion_matrix.reindex(index=y_order, columns=x_order).fillna(0)

# Plot the confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_matrix, annot=True, fmt='g', cmap='viridis', cbar=True,
            xticklabels=True, yticklabels=True)
plt.title('Confusion Matrix: CODEX (Ground-truth) vs scRNA-seq (Predicted)')
plt.xlabel('scRNA-seq (Predicted)')
plt.ylabel('CODEX (Ground)')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.show()


In [None]:
df_copy = filtered_df.copy()

# custom mappings for cell type categories
# filtered_df already has both INOS epithelium labels collapsed to 'Epithelium' by the
# previous cell, so 'Epithelium' itself must be mapped here. Without it the mapping was a
# no-op and epithelial cells could never match the predicted 'Malignant cells' label.
cell_type_map_y = {
    'Epithelium': 'Malignant cells',
    'Epithelium (INOS+)': 'Malignant cells',
    'Epithelium (INOS-)': 'Malignant cells'
}

cell_type_map_x = {
    'Macrophages': 'Myeloid cells',
    'Macrophages M2-like': 'Myeloid cells'
}

# apply mappings to the copy for y-axis (true labels) and x-axis (predicted labels)
df_copy['CELL_TYPE'] = df_copy['CELL_TYPE'].replace(cell_type_map_y)
df_copy['cell_type'] = df_copy['cell_type'].replace(cell_type_map_x)

# extract labels for F1 and ARI calculation
true_labels = df_copy['CELL_TYPE'].values  # adjusted true cell types
predicted_labels = df_copy['cell_type'].values  # adjusted predicted cell types

# calculate F1-score. Note: We use 'weighted' to account for label imbalance.
f1 = f1_score(true_labels, predicted_labels, average='weighted')
print("F1-score: {:.2f}".format(f1))

# calculate Adjusted Rand Index
ari = adjusted_rand_score(true_labels, predicted_labels)
print("Adjusted Rand Index (ARI): {:.2f}".format(ari))


## Step VII: Pseudo-spot creation

In [None]:
### Visium Regions ###
# A: Visium counts as AnnData; label_df: metadata table. Both are read as globals by
# get_matching_codex_visium_dfs (label_df via its 'acq_id' and 'diagnosis' columns).
A = ad.read_h5ad('scratch/SP_SC_data/s4769_visium_data/visium_counts.h5ad')
label_df = pd.read_csv('scratch/SP_SC_data/s4769_metadata.csv')

def get_matching_codex_visium_dfs(target_acq, max_dist=140):
    """
    Collect the in-tissue Visium spots of one acquisition and assign CODEX cells to them.

    Relies on notebook globals: ``A`` (Visium AnnData), ``label_df`` (metadata with
    'acq_id' and 'diagnosis' columns) and ``protein`` (per-cell table with 'X', 'Y'
    and 'CELL_TYPE' columns).

    Parameters
    ----------
    target_acq: str
        Acquisition id, used to look up the region label, image size and diagnosis.
    max_dist: float, default=140
        A cell is assigned to its nearest spot only if it is strictly closer than
        this distance (same units as the 'transformed coords' columns and X/Y).

    Returns
    -------
    sub_A:
        Rows of ``A`` for the retained spots, in the same order as spot_coord_df.
    spot_coord_df: pd.DataFrame
        Retained spots (in tissue, inside the image, present in ``A``).
    protein: pd.DataFrame
        The global ``protein`` table, returned unchanged.
    spot_to_cells: dict
        Maps each spot barcode to a list of (protein index, CELL_TYPE) tuples.
    """
    target_region = emdb.get_region_label_for_acquisition_id(target_acq)

    # Visium spot coordinates: keep in-tissue spots that fall inside the image bounds
    xdim, ydim = emdb.get_image_dim_for_acquisition_id(target_acq)
    spot_coord_df = pd.read_csv(f'scratch/SP_SC_data/s4769_visium_data/0{target_region}_tissue_positions_list.csv', index_col=0)
    spot_coord_df = spot_coord_df[spot_coord_df['in_tissue'] == 1]
    spot_coord_df = spot_coord_df[spot_coord_df['transformed coords x'] >= 0]
    spot_coord_df = spot_coord_df[spot_coord_df['transformed coords x'] <= xdim]
    spot_coord_df = spot_coord_df[spot_coord_df['transformed coords y'] >= 0]
    spot_coord_df = spot_coord_df[spot_coord_df['transformed coords y'] <= ydim]
    assert len(spot_coord_df) == len(set(spot_coord_df['barcode']))

    patient_id = target_region.split('_')[0]
    # .item() raises unless exactly one metadata row matches target_acq
    state = label_df[label_df['acq_id'] == target_acq]['diagnosis'].item().lower()
    inds = []
    valid_bcs = []

    # One-off identifier -> row position lookup (first occurrence wins, like list.index),
    # replacing a linear list scan per barcode.
    identifier_to_pos = {}
    for pos, obs_name in enumerate(A.obs.index):
        identifier_to_pos.setdefault(obs_name, pos)

    for bc in spot_coord_df['barcode']:
        assert len(bc) == 18
        assert bc.endswith('-1')
        identifier = f'cytassist_{patient_id}_{state}@{bc}'
        pos = identifier_to_pos.get(identifier)
        if pos is not None:
            inds.append(pos)
            valid_bcs.append(bc)
        else:
            print("Missing %s" % identifier)

    spot_coord_df = spot_coord_df.set_index('barcode').loc[valid_bcs].reset_index()
    sub_A = A[inds]
    assert sub_A.shape[0] == spot_coord_df.shape[0]

    spot_ids = np.array(spot_coord_df['barcode'])
    spot_coords = np.array(spot_coord_df[['transformed coords x', 'transformed coords y']])
    spot_to_cells = {spot_id: [] for spot_id in spot_ids}
    # With no spots there is nothing to assign (np.min on an empty array would raise)
    if spot_coords.shape[0] > 0:
        for cid, row in protein.iterrows():
            x, y, ct = row['X'], row['Y'], row['CELL_TYPE']
            dists_to_spots = np.linalg.norm(np.array([x, y]).reshape((1, 2)) - spot_coords, ord=2, axis=1)
            if np.min(dists_to_spots) < max_dist:
                matched_spot = spot_ids[np.argmin(dists_to_spots)]
                spot_to_cells[matched_spot].append((cid, ct))

    return sub_A, spot_coord_df, protein, spot_to_cells

In [None]:
# Get the matching dataframes for the main acquisition.
# `protein` is returned unchanged by the function, so rebinding it here has no effect.
sub_A, spot_coord_df, protein, spot_to_cells = get_matching_codex_visium_dfs(target_acq)

In [None]:
# keep only the CODEX cells that were assigned to some Visium spot
assigned_cell_ids = [cid for cells in spot_to_cells.values() for cid, _ in cells]
protein_subset = protein[protein.index.isin(assigned_cell_ids)]

fig, ax = plt.subplots(figsize=(12, 12))

# Visium spots as the background layer
ax.scatter(
    spot_coord_df['transformed coords x'],
    spot_coord_df['transformed coords y'],
    c='blue', label='Visium Spots', s=50, alpha=0.6,
)

# one scatter layer per CODEX cell type, each with its own tab20 colour
cell_types = protein_subset['CELL_TYPE'].unique()
colors = plt.cm.get_cmap('tab20', len(cell_types))

for i, cell_type in enumerate(cell_types):
    cells_of_type = protein_subset.loc[protein_subset['CELL_TYPE'] == cell_type]
    ax.scatter(
        cells_of_type['X'], cells_of_type['Y'],
        c=[colors(i)], label=cell_type, s=20, alpha=0.6,
    )

ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate')
ax.set_title('CODEX Ground Truth Cells')

ax.legend(loc='best')

plt.show()

In [None]:
scaler = MinMaxScaler()

# Spot and cell coordinates share one frame (cells are assigned to spots by raw distance
# in get_matching_codex_visium_dfs). Fit the scaler once on the spots and apply the SAME
# transform to the cells; two independent fits would stretch the two point sets
# differently and misalign them for the nearest-spot query.
# Plain arrays are passed so sklearn does not complain about differing column names.
spot_coord_df[['transformed coords x', 'transformed coords y']] = scaler.fit_transform(
    spot_coord_df[['transformed coords x', 'transformed coords y']].to_numpy()
)
# cells lying outside the spots' bounding box end up slightly outside [0, 1]
merged_df[['X', 'Y']] = scaler.transform(merged_df[['X', 'Y']].to_numpy())

fig, ax = plt.subplots(figsize=(12, 12))

# plot Visium spots
ax.scatter(spot_coord_df['transformed coords x'],
           spot_coord_df['transformed coords y'],
           c='blue', label='Visium Spots', s=50, alpha=0.6)

# define a color map for cell types
cell_types = merged_df['cell_type'].unique()
colors = plt.cm.get_cmap('tab20', len(cell_types))

# plot RNA data points based on cell types
for i, cell_type in enumerate(cell_types):
    subset = merged_df[merged_df['cell_type'] == cell_type]
    ax.scatter(subset['X'], subset['Y'], c=[colors(i)], label=f'RNA ({cell_type})', alpha=0.6, s=30)

ax.set_xlabel('Normalized X Coordinate')
ax.set_ylabel('Normalized Y Coordinate')
ax.set_title('Matched scRNA-seq Cells')

ax.legend(loc='best')

plt.show()

## Step VIII: Evaluating spatial accuracy

In [None]:
# Assign each matched RNA cell (placed at its CODEX cell's X/Y) to its nearest Visium spot

# build a KD-tree for the Visium spots
visium_coords = spot_coord_df[['transformed coords x', 'transformed coords y']].values
kdtree = cKDTree(visium_coords)

# find the nearest Visium spot for each RNA cell
# (`indices` are row positions into visium_coords, i.e. positions in spot_coord_df)
rna_coords = merged_df[['X', 'Y']].values
distances, indices = kdtree.query(rna_coords)

# add the nearest Visium spot index to the merged_df
merged_df['nearest_spot_index'] = indices

merged_df

In [None]:
# Re-read the same Visium counts file that was already loaded into `A` in the "Visium Regions" cell.
A = ad.read_h5ad('scratch/SP_SC_data/s4769_visium_data/visium_counts.h5ad')
# label_df = pd.read_csv('scratch/SP_SC_data/s4769_metadata.csv')

In [None]:
# RNA-side indices of the matched pairs (one entry per merged_df row, so values can repeat)
indices_to_retain = merged_df['mod1_indx'].values

# filter the scRNA_adata object to include only those cells
filtered_rna_adata = rna_adata[indices_to_retain]

# identify common genes
common_genes = np.intersect1d(filtered_rna_adata.var_names, sub_A.var_names)

len(common_genes)

In [None]:
# identify unique genes in scRNA and spatial transcriptomics datasets
unique_genes_scRNA = np.setdiff1d(filtered_rna_adata.var_names, sub_A.var_names)
unique_genes_spatial = np.setdiff1d(sub_A.var_names, filtered_rna_adata.var_names)

# print the number of unique genes
print(f"Number of unique genes in scRNA: {len(unique_genes_scRNA)}")
print(f"Number of unique genes in spatial transcriptomics: {len(unique_genes_spatial)}")

# print some of the unique genes to inspect potential naming issues
# (the printed label now states the slice size actually used; it previously said "first 10")
n_preview = 200
print(f"Unique genes in scRNA (first {n_preview}):", unique_genes_scRNA[:n_preview])
print(f"Unique genes in spatial transcriptomics (first {n_preview}):", unique_genes_spatial[:n_preview])

In [None]:
# need to filter the rna_adata object include only selected ones
# (rows follow merged_df's row order, the same order `indices` from the KD-tree query uses)
indices_to_retain = merged_df['mod1_indx'].values
filtered_rna_adata = rna_adata[indices_to_retain]

# restrict both modalities to the genes they share
common_genes = np.intersect1d(filtered_rna_adata.var_names, sub_A.var_names)
filtered_rna_adata = filtered_rna_adata[:, common_genes]
spatial_adata = sub_A[:, common_genes]

# binarize the gene expression matrix (presence/absence of genes)
# NOTE(review): assumes filtered_rna_adata.X is a dense array; a sparse matrix would need .toarray() — confirm.
scRNA_gene_presence = (filtered_rna_adata.X > 0).astype(int)
scRNA_gene_presence_df = pd.DataFrame(scRNA_gene_presence, columns=filtered_rna_adata.var_names)

# add nearest spot index to the gene presence DataFrame
scRNA_gene_presence_df['nearest_spot_index'] = indices

# aggregate RNA expressions based on the nearest Visium spot
# (max over 0/1 flags: a gene counts as present in a spot if any assigned cell expresses it)
aggregated_data = scRNA_gene_presence_df.groupby('nearest_spot_index').max().reset_index()

# compute the Jaccard index for each Visium spot
def jaccard_index(set1, set2):
    """
    Return |set1 & set2| / |set1 | set2| for two sets.

    When both sets are empty the index is undefined; NaN is returned instead of
    raising ZeroDivisionError.
    """
    union = len(set1.union(set2))
    if union == 0:
        return float('nan')
    intersection = len(set1.intersection(set2))
    return intersection / union

# convert aggregated_data and ground_truth_df to sets of expressed genes
ground_truth_gene_presence = (spatial_adata.X > 0).astype(int)
ground_truth_df = pd.DataFrame(ground_truth_gene_presence.toarray(), columns=spatial_adata.var_names)

aggregated_gene_sets = aggregated_data.drop('nearest_spot_index', axis=1).apply(lambda row: set(row[row > 0].index), axis=1)
ground_truth_gene_sets = ground_truth_df.apply(lambda row: set(row[row > 0].index), axis=1)

# one slot per Visium spot; NaN marks spots that received no RNA cells
jaccard_indices = np.full(len(spot_coord_df), np.nan)

# calculate Jaccard index for each spot that has aggregated RNA data
for i in range(len(aggregated_gene_sets)):
    spot_index = aggregated_data.loc[i, 'nearest_spot_index']
    jaccard_indices[spot_index] = jaccard_index(aggregated_gene_sets[i], ground_truth_gene_sets.iloc[spot_index])

# (a stray ']' at the end of this line made the whole cell a SyntaxError)
fig, ax = plt.subplots(figsize=(12, 12))
# all spots in blue as a base layer
ax.scatter(spot_coord_df['transformed coords x'], spot_coord_df['transformed coords y'],
           c='blue', label='Visium Spots', s=50, alpha=0.6)
norm = plt.Normalize(0, 1)
cmap = plt.cm.viridis

# plot aggregated RNA data points based on Jaccard index
scatter = ax.scatter(spot_coord_df['transformed coords x'], spot_coord_df['transformed coords y'],
                     c=jaccard_indices, cmap=cmap, norm=norm, s=50, alpha=0.6)
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Jaccard Index')
ax.set_xlabel('Normalized X Coordinate')
ax.set_ylabel('Normalized Y Coordinate')
ax.set_title('Jaccard Index of Aggregated RNA on Visium Spots')

plt.show()

In [None]:
# Despite the variable name, this frame holds the expression values from filtered_rna_adata.X, not 0/1 presence flags.
scRNA_gene_presence_df = pd.DataFrame(filtered_rna_adata.X, columns=filtered_rna_adata.var_names)

# add nearest spot index to the gene presence DataFrame
scRNA_gene_presence_df['nearest_spot_index'] = indices

# aggregate RNA expressions based on nearest Visium spot (mean over the cells assigned to each spot)
aggregated_data = scRNA_gene_presence_df.groupby('nearest_spot_index').mean()

# (spatial transcriptomics data)
ground_truth_df = pd.DataFrame(spatial_adata.X.todense(), columns=spatial_adata.var_names)

# initialize correlation indices with NaNs for length of spot_coord_df
correlation_indices = np.full(len(spot_coord_df), np.nan)

# calculate Spearman correlation for each spot
# (aggregated_data is indexed by nearest_spot_index, so index[i] is the spot position)
for i in range(len(aggregated_data)):
    spot_index = aggregated_data.index[i]
    if spot_index < len(ground_truth_df):
        correlation, _ = spearmanr(aggregated_data.iloc[i], ground_truth_df.iloc[spot_index])
        correlation_indices[spot_index] = correlation

fig, ax = plt.subplots(figsize=(12, 12))

# colour scale spans the observed correlation range, ignoring NaN (spots without data)
norm = plt.Normalize(np.nanmin(correlation_indices), np.nanmax(correlation_indices))
cmap = plt.cm.viridis

# plot aggregated RNA data points based on Spearman correlation
scatter = ax.scatter(spot_coord_df['transformed coords x'], spot_coord_df['transformed coords y'],
                     c=correlation_indices, cmap=cmap, norm=norm, s=50, alpha=0.6)

cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Spearman Correlation')
ax.set_xlabel('Normalized X Coordinate')
ax.set_ylabel('Normalized Y Coordinate')
ax.set_title('Spearman Correlation of Aggregated RNA on Visium Spots')

plt.show()


In [None]:
# plot Spearman correlation values as a histogram
# (correlation_indices holds NaN for spots that received no RNA cells)
plt.figure(figsize=(10, 6))
sns.histplot(correlation_indices, bins=30, kde=True)

plt.title('Distribution of Spearman Correlation Values')
plt.xlabel('Spearman Correlation')
plt.ylabel('Frequency')

plt.show()

In [None]:
# compute Spearman correlation for each gene across all spots
ground_truth_df = pd.DataFrame(spatial_adata.X.toarray(), columns=spatial_adata.var_names)

# align both tables on the spot positions; spots without aggregated RNA become NaN rows
aggregated_data = aggregated_data.reindex(spot_coord_df.index)
ground_truth_df = ground_truth_df.reindex(spot_coord_df.index)

gene_correlations = []

for gene in common_genes:
    if gene in aggregated_data.columns and gene in ground_truth_df.columns:
        # drop NaN values that may have resulted from the reindexing
        valid_indices = ~np.isnan(aggregated_data[gene]) & ~np.isnan(ground_truth_df[gene])
        gene_corr, _ = spearmanr(aggregated_data[gene][valid_indices], ground_truth_df[gene][valid_indices])
        gene_correlations.append((gene, gene_corr))

# sort genes by their correlation values, highest first.
# spearmanr yields NaN for constant inputs, and sorted() gives an arbitrary order when NaN
# is present, so sort with pandas and push NaN correlations to the end.
gene_correlations_df = (
    pd.DataFrame(gene_correlations, columns=['Gene', 'Spearman Correlation'])
    .sort_values('Spearman Correlation', ascending=False, na_position='last', kind='stable')
    .reset_index(drop=True)
)
# kept as a list of (gene, correlation) tuples, as before
gene_correlations_sorted = list(gene_correlations_df.itertuples(index=False, name=None))

# plot the top 30 genes
plt.figure(figsize=(12, 8))
top_n = 30
sns.barplot(x='Spearman Correlation', y='Gene', data=gene_correlations_df.head(top_n))
plt.title(f'Top {top_n} Genes by Spearman Correlation Across Modalities')
plt.xlabel('Spearman Correlation')
plt.ylabel('Gene')

plt.show()


## Step IX: Cell-type Deconvolution Using Tangram

In [None]:
# raw adata for deconvolution
scRNA_adata = ad.AnnData(
    rna.tocsr(), dtype=np.float32
)
scRNA_adata.var_names = rna_names
scRNA_adata.var_names_make_unique()
# NOTE(review): this assignment aligns on the obs index; if scRNA_adata.obs and rna_adata.obs
# have different index labels the result is NaN — confirm the indices match.
scRNA_adata.obs['cell_type'] = rna_adata.obs['celltype']

# Rebinds spatial_adata to the full sub_A (an earlier cell set it to sub_A restricted to common_genes).
spatial_adata = sub_A

In [None]:
# loading raw Visium counts for Tangram comparison
# (all_A and label_df are read as globals by new_get_matching_codex_visium_dfs)
all_A = ad.read_h5ad('scratch/SP_SC_data/s4769_visium_data/visium_all.h5ad')
label_df = pd.read_csv('scratch/SP_SC_data/s4769_metadata.csv')

In [None]:
def new_get_matching_codex_visium_dfs(target_acq, max_spot_dist=140):
    """Match Visium spots with CODEX cells for one acquisition.

    Besides its arguments, this function reads three notebook-level globals:
    ``all_A`` (AnnData whose obs names look like
    ``cytassist_<patient>_<state>@<barcode>``), ``label_df`` (metadata with
    ``acq_id`` and ``diagnosis`` columns) and ``protein`` (per-cell table with
    ``X``, ``Y`` and ``CELL_TYPE`` columns).

    Parameters
    ----------
    target_acq : str
        Acquisition id passed to the ``emdb`` lookups and used to select the
        row of ``label_df``.
    max_spot_dist : float, default 140
        A cell is assigned to its nearest spot only if that spot is strictly
        closer than this distance, measured in the units shared by the
        ``transformed coords`` columns and the ``X``/``Y`` columns.

    Returns
    -------
    sub_all_A
        Rows of ``all_A`` for the retained spots, in the order of
        ``spot_coord_df``.
    spot_coord_df : pandas.DataFrame
        In-tissue, in-bounds spots that were found in ``all_A``; ``barcode`` is
        a regular column.
    protein : pandas.DataFrame
        The global ``protein`` table, returned unchanged.
    spot_to_cells : dict
        Maps each retained barcode to a list of ``(row label, CELL_TYPE)``
        tuples for the cells assigned to it (row label = index of ``protein``).

    Raises
    ------
    AssertionError
        If barcodes are duplicated, are not 18 characters long, or do not end
        with ``-1``.
    ValueError
        If ``label_df`` does not hold exactly one row for ``target_acq``
        (raised by ``.item()``).
    """
    target_region = emdb.get_region_label_for_acquisition_id(target_acq)

    # Visium spot coordinates: keep in-tissue spots that fall inside the image.
    xdim, ydim = emdb.get_image_dim_for_acquisition_id(target_acq)
    spot_coord_df = pd.read_csv(f'scratch/SP_SC_data/s4769_visium_data/0{target_region}_tissue_positions_list.csv', index_col=0)
    keep_spot = (
        (spot_coord_df['in_tissue'] == 1)
        & spot_coord_df['transformed coords x'].between(0, xdim)
        & spot_coord_df['transformed coords y'].between(0, ydim)
    )
    spot_coord_df = spot_coord_df[keep_spot]
    assert len(spot_coord_df) == len(set(spot_coord_df['barcode']))

    patient_id = target_region.split('_')[0]
    state = label_df[label_df['acq_id'] == target_acq]['diagnosis'].item().lower()

    # Map every obs name to its first position once, so each barcode lookup is
    # a dict access instead of two linear scans over a list.
    obs_position = {}
    for position, obs_name in enumerate(all_A.obs.index):
        obs_position.setdefault(obs_name, position)

    inds = []
    valid_bcs = []
    for bc in spot_coord_df['barcode']:
        assert len(bc) == 18
        assert bc.endswith('-1')
        identifier = f'cytassist_{patient_id}_{state}@{bc}'
        position = obs_position.get(identifier)
        if position is None:
            print("Missing %s" % identifier)
        else:
            inds.append(position)
            valid_bcs.append(bc)

    # Restrict both tables to the spots present in all_A, in the same order.
    spot_coord_df = spot_coord_df.set_index('barcode').loc[valid_bcs].reset_index()
    sub_all_A = all_A[inds]
    assert sub_all_A.shape[0] == spot_coord_df.shape[0]

    spot_ids = np.array(spot_coord_df['barcode'])
    spot_coords = np.array(spot_coord_df[['transformed coords x', 'transformed coords y']])
    spot_to_cells = {spot_id: [] for spot_id in spot_ids}

    # With no retained spots there is nothing to assign cells to (and taking a
    # minimum over an empty distance array would raise).
    if len(spot_ids) > 0:
        for cid, x, y, ct in zip(protein.index, protein['X'], protein['Y'], protein['CELL_TYPE']):
            dists_to_spots = np.linalg.norm(np.array([x, y]).reshape((1, 2)) - spot_coords, ord=2, axis=1)
            nearest = np.argmin(dists_to_spots)
            # NaN distances compare False here, so such cells stay unassigned
            if dists_to_spots[nearest] < max_spot_dist:
                spot_to_cells[spot_ids[nearest]].append((cid, ct))

    return sub_all_A, spot_coord_df, protein, spot_to_cells

# Run the matching for the main sample. The third return value is the global
# `protein` table itself, so rebinding `protein` here leaves it unchanged.
sub_all_A, new_spot_coord_df, protein, new_spot_to_cells = new_get_matching_codex_visium_dfs(target_acq)

In [None]:
# prepare data for Tangram
tg.pp_adatas(scRNA_adata, spatial_adata, genes=None) # best not to normalize before this (use raw data).

# map single-cell data to spatial data using Tangram
# (random_state fixes the optimiser's initialisation so a re-run reproduces the mapping)
ad_map = tg.map_cells_to_space(scRNA_adata, spatial_adata, mode='cells', random_state=0)

# project cell types onto spatial data
tg.project_cell_annotations(ad_map, spatial_adata)

# extract deconvolution results
tangram_scores = spatial_adata.obsm['tangram_ct_pred']
tangram_scores = pd.DataFrame(tangram_scores, index=spatial_adata.obs_names, columns=scRNA_adata.obs['cell_type'].unique())

# Tangram's projection accumulates, per spot, the mapping probabilities of the
# cells of each type; rows therefore do not sum to 1. Normalise each spot so the
# values really are the per-spot proportions the plot labels claim.
proportions_df = tangram_scores.div(tangram_scores.sum(axis=1), axis=0)

plt.figure(figsize=(12, 8))
sns.heatmap(proportions_df, cmap='viridis', cbar_kws={'label': 'Proportion'})
plt.title('Deconvolution of Spatial Transcriptomics Data')
plt.xlabel('Cell Types')
plt.ylabel('Spatial Spots')
plt.show()


In [None]:
# extract the deconvolution results
proportions_df = spatial_adata.obsm['tangram_ct_pred']
proportions_df = pd.DataFrame(proportions_df, index=spatial_adata.obs_names, columns=scRNA_adata.obs['cell_type'].unique())

# Tangram's projection accumulates mapping probabilities per spot, so rows do
# not sum to 1; normalise each spot to obtain actual per-spot proportions.
proportions_df = proportions_df.div(proportions_df.sum(axis=1), axis=0)

# extract the part after the @ sign for matching with spot_coord_df
proportions_df.index = [index.split('@')[1] for index in proportions_df.index]

# ensure spot_coord_df is aligned with the proportions_df
# (only move 'barcode' into the index if it is still a column, so re-running
# this cell does not fail on the already-indexed frame)
if 'barcode' in spot_coord_df.columns:
    spot_coord_df = spot_coord_df.set_index('barcode')
spot_coord_df = spot_coord_df.loc[proportions_df.index]

# combine coordinates with proportions
plot_df = pd.concat([spot_coord_df[['transformed coords x', 'transformed coords y']], proportions_df], axis=1)

# plot the deconvolution results on a scatter plot for each cell type
for cell_type in scRNA_adata.obs['cell_type'].unique():
    # colour by a temporary 'proportion' column without altering plot_df itself
    fig = px.scatter(
        plot_df.assign(proportion=plot_df[cell_type]),
        x='transformed coords x', y='transformed coords y',
        color='proportion', color_continuous_scale='viridis',
        title=f'Deconvolution of Spatial Transcriptomics Data - {cell_type}',
        labels={'proportion': 'Proportion', 'transformed coords x': 'X Coordinate', 'transformed coords y': 'Y Coordinate'})
    fig.show()

In [None]:
# One 1200x800 scatter plot per cell type, coloured by that type's column in plot_df.
scatter_labels = {'proportion': 'Proportion', 'transformed coords x': 'X Coordinate', 'transformed coords y': 'Y Coordinate'}

for cell_type in scRNA_adata.obs['cell_type'].unique():
    # the colour channel always reads 'proportion', so point it at this cell type
    plot_df['proportion'] = plot_df[cell_type]

    fig = px.scatter(
        plot_df,
        x='transformed coords x',
        y='transformed coords y',
        color='proportion',
        color_continuous_scale='viridis',
        title=f'Deconvolution of Spatial Transcriptomics Data - {cell_type}',
        labels=scatter_labels,
    )
    fig.update_layout(width=1200, height=800)
    fig.show()

In [None]:
# Rebuild the coordinates + proportions table so it carries no helper columns.
coord_cols = ['transformed coords x', 'transformed coords y']
plot_df = pd.concat([spot_coord_df[coord_cols], proportions_df], axis=1)

# Long format: one row per (spot, cell type), as required for faceting.
melted_df = plot_df.melt(id_vars=coord_cols, var_name='cell_type', value_name='proportion')

# One facet per cell type, four facets per row.
facet_kwargs = dict(
    x='transformed coords x',
    y='transformed coords y',
    color='proportion',
    color_continuous_scale='viridis',
    facet_col='cell_type',
    facet_col_wrap=4,
    title='Deconvolution of Spatial Transcriptomics Data',
    labels={'proportion': 'Proportion', 'transformed coords x': 'X Coordinate', 'transformed coords y': 'Y Coordinate'},
)
fig = px.scatter(melted_df, **facet_kwargs)
fig.update_layout(width=1400, height=800)

fig.show()

## Step X: SPACE-GM Microenvironment Analysis of Inferred scRNAseq

In [None]:
# Microenvironment annotation table for this acquisition (first CSV column is the index).
microe_path = f'scratch/SP_SC_data/microe_annotations_with_names/pre_saved_annotations/pre-treatment-full-20_leiden_0.32-15/{target_acq}.csv'
microe = pd.read_csv(microe_path, index_col=0)

# Attach the annotations to merged_df via CELL_ID (inner join, the merge default).
merged_microe_df = pd.merge(merged_df, microe, on='CELL_ID')

# Number of rows for every (MicroE_Name, cell_type) pair.
microe_cell_type_counts = (
    merged_microe_df
    .groupby(['MicroE_Name', 'cell_type'])
    .size()
    .reset_index(name='counts')
)

microe_cell_type_counts

In [None]:
# ensure that MicroE_Name is treated as a numeric category for proper ordering
microe_cell_type_counts['MicroE_Name'] = pd.Categorical(
    microe_cell_type_counts['MicroE_Name'],
    categories=[f'MicroE_{i}' for i in range(20)],
    ordered=True)

# sort the DataFrame by 'MicroE_Name' to ensure correct order
microe_cell_type_counts.sort_values('MicroE_Name', inplace=True)

In [None]:
# pivot the data to create a matrix suitable for heatmap
pivot_table = microe_cell_type_counts.pivot(index='MicroE_Name', columns='cell_type', values='counts').fillna(0)
rows_to_move = pivot_table.iloc[1:3]  # select second and third rows
pivot_table_dropped = pivot_table.drop(pivot_table.index[1:3])  # drop the second and third rows from the original
pivot_table_reordered = pd.concat([pivot_table_dropped, rows_to_move])  # concatenate the remaining DataFrame with the rows to move
pivot_table_reordered

In [None]:
# Annotated heatmap of the reordered microenvironment x cell-type count matrix.
heat_fig, heat_ax = plt.subplots(figsize=(12, 8))

sns.heatmap(pivot_table_reordered, cmap="viridis", annot=True, fmt="g", ax=heat_ax)

# set after drawing so these replace the axis labels seaborn derives from the frame
heat_ax.set_title("Microenvironment vs Single-Cell Type Counts")
heat_ax.set_xlabel("Single-Cell Type")
heat_ax.set_ylabel("Microenvironment Name")

plt.show()

In [None]:
# Grouped bars: counts per cell type, one hue per microenvironment.
bar_fig, bar_ax = plt.subplots(figsize=(12, 8))
sns.barplot(data=microe_cell_type_counts, x='cell_type', y='counts', hue='MicroE_Name', ax=bar_ax)

bar_ax.set_title("Distribution of Single-Cell Types Across Microenvironments")
bar_ax.set_xlabel("Single-Cell Type")
bar_ax.set_ylabel("Counts")

# rotate the tick labels on the x axis by 90 degrees
plt.xticks(rotation=90)
bar_ax.legend(title='Microenvironment Name')
plt.show()


In [None]:
# performing GSEA analysis per Leiden cluster of SPACE-GM annot
# Genes are ranked by their mean expression within each cluster and passed to
# gseapy's prerank against the MSigDB Hallmark 2020 gene sets.

# ensure mod1_indx is aligned with filtered_rna_adata object
merged_microe_df['mod1_indx'] = merged_microe_df['mod1_indx'].astype(str)
filtered_rna_adata.obs.index = filtered_rna_adata.obs.index.astype(str)

# expression data for relevant indices
# (.copy() so the obs column added below is written to an independent AnnData
# rather than to a view of filtered_rna_adata)
expression_data = filtered_rna_adata[filtered_rna_adata.obs.index.isin(merged_microe_df['mod1_indx'])].copy()

# adds MicroE_Cluster as a new column in expression_data.obs
cluster_mapping = merged_microe_df.set_index('mod1_indx')['MicroE_Cluster'].to_dict()
expression_data.obs['MicroE_Cluster'] = expression_data.obs.index.map(cluster_mapping)

# adds MicroE_Name for cluster titles
cluster_name_mapping = merged_microe_df.set_index('MicroE_Cluster')['MicroE_Name'].to_dict()

results = []
for cluster in sorted(expression_data.obs['MicroE_Cluster'].unique()):
    # filter for current cluster
    cluster_data = expression_data[expression_data.obs['MicroE_Cluster'] == cluster]

    # prepare rank data using gene names
    # X.mean(axis=0) is a 1 x n_genes matrix when X is sparse; flatten it so
    # every gene gets exactly one score whether X is sparse or dense
    mean_expression = np.asarray(cluster_data.X.mean(axis=0)).ravel()
    gene_names = cluster_data.var_names

    rnk = pd.DataFrame({'gene': gene_names, 'score': mean_expression})
    rnk['gene'] = rnk['gene'].str.upper()
    rnk = rnk.drop_duplicates(subset=['gene']).sort_values(by='score', ascending=False)

    # perform GSEA
    if rnk.shape[0] > 0:
        gsea_results = gp.prerank(rnk=rnk, gene_sets='MSigDB_Hallmark_2020', min_size=10, max_size=1000, permutation_num=1000)

        # Check if results are not empty
        if not gsea_results.res2d.empty:
            results.append((cluster, gsea_results))
    else:
        print(f"No data for cluster: {cluster}")

summary_list = []

for cluster, result in results:
    summary = result.res2d[['Term', 'ES', 'NES', 'NOM p-val', 'FDR q-val']].copy()
    summary['Cluster'] = cluster
    summary_list.append(summary)

if not summary_list:
    # pd.concat raises on an empty list; report the situation instead
    print("No GSEA results to summarise or plot")
else:
    final_summary = pd.concat(summary_list).reset_index(drop=True)

    # plotting
    clusters = sorted(final_summary['Cluster'].unique())
    num_clusters = len(clusters)

    # three panels per row, with as many rows as the clusters need
    # (9 clusters give the same 3x3, 24x18 figure as a fixed grid would)
    cols = 3
    rows = -(-num_clusters // cols)  # ceiling division

    fig, axes = plt.subplots(rows, cols, figsize=(8 * cols, 6 * rows), constrained_layout=True, squeeze=False)

    for i, cluster in enumerate(clusters):
        ax = axes[i // cols, i % cols]
        cluster_data = final_summary[final_summary['Cluster'] == cluster]
        top_terms = cluster_data.sort_values(by='NES', ascending=False).head(10)

        sns.barplot(x='NES', y='Term', data=top_terms, ax=ax, palette='viridis')
        ax.set_xlabel('Normalized Enrichment Score (NES)')
        ax.set_title(cluster_name_mapping.get(cluster, 'Cluster Name Not Found'))
        ax.invert_yaxis()

    # hide the panels of the last row that have no cluster
    for unused_ax in axes.ravel()[num_clusters:]:
        unused_ax.set_visible(False)

    plt.show()