In [None]:
#@title Install dependencies and setup paths
import os
import sys
  
# Folder where the model-training cell caches its pickled spatial-LDA models.
PATH_TO_MODELS = './cycif_models/'

# Every folder the notebook needs up front; stop immediately if one is missing.
paths = [PATH_TO_MODELS]
paths_exist = dict((path, os.path.exists(path)) for path in paths)
display(paths_exist)
assert all(paths_exist.values())

In [None]:
#@title Import and set processes and set training params
# Re-import edited modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

import functools
import logging
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import pickle
import scipy
import seaborn as sns
from sklearn.model_selection import train_test_split
import time
import tqdm
# Spatial LDA imports
from spatial_lda.featurization import neighborhood_to_cluster
from spatial_lda.featurization import make_nearest_neighbor_graph
from spatial_lda.featurization import make_merged_difference_matrices
from spatial_lda.featurization import featurize_spleens
#from spatial_lda.visualization import plot_samples_in_a_row
#from spatial_lda.visualization import plot_bcell_topic_multicolor
import spatial_lda.model
# NOTE(review): scimap is not imported, but the "Scimap testing" cells at the end of the
# notebook call `sm.*`; they fail on a fresh kernel unless this import is re-enabled.
#import scimap as sm
import scanpy as sc

# N_PARALLEL_PROCESSES: worker processes passed to featurization and to model training.
# TRAIN_SIZE_FRACTION: share of cells used for training (train_test_split gets
#   test_size = 1 - TRAIN_SIZE_FRACTION); also part of the cached model file name.
# N_TOPICS_LIST: topic counts to train; overridden to [8] in the "Parameter sweep" section.
N_PARALLEL_PROCESSES = 8 #@param
TRAIN_SIZE_FRACTION = 0.99 #@param
N_TOPICS_LIST = [8, 10, 6] #@param

# Root logger at INFO so library progress messages are shown (presumably spatial_lda training output).
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
## functions
# Plotting helpers used throughout the notebook. `qual_palettes.Bold_10` supplies the
# categorical colours for topic/cluster plots; `mpatches`/`Line2D` build legend handles.
%matplotlib inline
import palettable.cartocolors.qualitative as qual_palettes
import matplotlib.patches as mpatches
import matplotlib.markers as mmark
from matplotlib.lines import Line2D

def plot_bcell_topic_multicolor(ax, sample_idx, topic_weights, spleen_dfs):
    """Scatter one sample's cells coloured by their dominant spatial-LDA topic.

    Cells whose ``isb`` flag is False are drawn as faint black crosses. Cells that
    have topic weights are drawn as dots coloured by their argmax topic, with a
    legend entry for every topic that occurs.

    Args:
        ax: matplotlib Axes to draw on.
        sample_idx: key of the sample in ``spleen_dfs``; it is also matched against
            the first element of each ``topic_weights`` index tuple.
        topic_weights: DataFrame of per-cell topic weights indexed by
            ``(sample, cell)`` tuples, one column per topic.
        spleen_dfs: dict mapping sample key -> per-cell DataFrame with
            ``sample.X``, ``sample.Y`` and boolean ``isb`` columns.
    """
    # Keep this sample's rows (this used to filter on a leaked global `s_slide`).
    sample_mask = topic_weights.index.map(lambda x: x[0]) == sample_idx
    topic_weights_slide = topic_weights[sample_mask]
    dominant_topic = np.argmax(np.array(topic_weights_slide), axis=1)

    palette = qual_palettes.Bold_10.mpl_colors[:topic_weights_slide.shape[1]]
    colors = np.array(palette)
    # The legend colour of topic t must equal the scatter colour colors[t]. Zipping the
    # present topics against the palette shifted the colours whenever a topic was absent.
    d_color = {t: palette[t] for t in sorted(set(dominant_topic))}

    cell_coords = spleen_dfs[sample_idx]
    non_b_coords = cell_coords[~cell_coords.isb]
    ax.scatter(
        non_b_coords['sample.X'],
        non_b_coords['sample.Y'],
        s=1,
        c='k',
        marker='x',
        label='Non-tumor',
        alpha=.2)

    # Look the cells up in topic-weight order so each colour lands on its own cell.
    cell_indices = topic_weights_slide.index.map(lambda x: x[1])
    coords = cell_coords.loc[cell_indices]
    ax.scatter(coords['sample.X'], coords['sample.Y'], s=3,
               c=colors[dominant_topic, :])

    handles = [mpatches.Circle((1, 1), radius=5, color=color) for color in d_color.values()]
    ax.axes.get_yaxis().set_visible(False)
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.set_title(f'Sample {sample_idx}')
    ax.axis('equal')
    # Invert the y axis (presumably image coordinates with y increasing downward).
    ax.set_ylim(ax.get_ylim()[::-1])
    ax.legend(handles, d_color.keys(),
              bbox_to_anchor=(1.05, .5), title="Topics")


def plot_samples_in_a_row(features_df, plot_fn, patient_dfs, tumor_set=None):
    """Draw one panel per sample, side by side, using ``plot_fn``.

    Args:
        features_df: DataFrame or Series indexed by ``(sample, cell)`` tuples; the
            rows of each sample are passed to ``plot_fn``.
        plot_fn: callable ``plot_fn(ax, sample, sample_rows, patient_dfs)``.
        patient_dfs: passed through unchanged to ``plot_fn``.
        tumor_set: samples to draw, in order (duplicates allowed); defaults to every
            sample present in the index.

    Returns:
        ``(fig, axes)`` where ``axes`` is a 1-D array with one Axes per panel, also
        when only a single sample is drawn.

    Raises:
        ValueError: if there is nothing to plot.
    """
    sns.set_style("white")
    tumor_idx = features_df.index.map(lambda x: x[0])
    if tumor_set is None:
        tumor_set = np.unique(tumor_idx)
    if len(tumor_set) == 0:
        raise ValueError('tumor_set is empty: nothing to plot')

    num_rows = 1
    num_cols = len(tumor_set) // num_rows
    # squeeze=False keeps `axes` indexable even for a single panel (plt.subplots
    # would otherwise return a bare Axes and `axes[i]` would fail).
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 4, num_rows * 4), dpi=300,
        squeeze=False)
    axes = axes.ravel()

    for i, tumor in enumerate(tumor_set):
        plot_fn(axes[i], tumor, features_df[tumor_idx == tumor], patient_dfs)

    sns.despine(left=True, bottom=True)
    return fig, axes
    
def _standardize_topics(topics):
    """Z-score each feature (column) across topics and return it as features x topics.

    Features with zero spread across topics are centred but not scaled, so they
    come out as 0 instead of NaN.
    """
    topics = np.asarray(topics, dtype=float)
    std = topics.std(axis=0)
    std = np.where(std == 0, 1.0, std)
    return ((topics - topics.mean(axis=0)) / std).T


def plot_topics_heatmap(topics, features, ax, normalizer=None):
    """Draw a features x topics heatmap of topic compositions.

    Args:
        topics: array of shape (n_topics, n_features), e.g. ``lda.components_``.
        features: row labels for the heatmap (one per feature).
        ax: matplotlib Axes to draw on.
        normalizer: optional callable mapping ``topics`` to a
            (n_features, n_topics) array. Defaults to a per-feature z-score across
            topics. Before, this default referenced an undefined
            ``_standardize_topics`` and raised NameError.
    """
    n_topics = topics.shape[0]
    if normalizer is not None:
        topics = normalizer(topics)
    else:
        topics = _standardize_topics(topics)

    topics = pd.DataFrame(topics, index=features,
                          columns=[f'Topic {x}' for x in range(n_topics)])
    sns.heatmap(topics, square=True, cmap='RdBu_r', ax=ax)
from spatial_lda.visualization import plot_adjacency_graph

def make_plot_fn(difference_matrices):
    """Return a ``plot_fn(ax, tumor_idx, features_df, patient_dfs)`` callback.

    The callback forwards its four arguments, followed by the captured
    ``difference_matrices``, to ``plot_adjacency_graph``. This makes it usable as
    the ``plot_fn`` of ``plot_samples_in_a_row``. The callback returns None.
    """
    captured_matrices = difference_matrices

    def plot_fn(ax, tumor_idx, features_df, patient_dfs):
        plot_adjacency_graph(ax, tumor_idx, features_df, patient_dfs,
                             captured_matrices)

    return plot_fn

def plot_topic_multicolor(ax, sample_idx, topic_weights, spleen_dfs):
    """Scatter a sample's cells coloured by their dominant spatial-LDA topic.

    Cells whose ``isb`` flag is False are drawn as faint black crosses. Cells in
    ``topic_weights`` are drawn as dots coloured by their argmax topic, with a
    legend entry for every topic that occurs.

    Args:
        ax: matplotlib Axes to draw on.
        sample_idx: key of the sample in ``spleen_dfs`` (also used in the title).
        topic_weights: DataFrame of topic weights for this sample's cells, indexed
            by the same cell labels as ``spleen_dfs[sample_idx]``, one column per topic.
        spleen_dfs: dict mapping sample key -> per-cell DataFrame with
            ``sample.X``, ``sample.Y`` and boolean ``isb`` columns.
    """
    palette = qual_palettes.Bold_10.mpl_colors[:topic_weights.shape[1]]
    colors = np.array(palette)
    dominant_topic = np.argmax(np.array(topic_weights), axis=1)
    # The legend colour of topic t must equal the scatter colour colors[t]. Zipping the
    # present topics against the palette shifted the colours whenever a topic was absent.
    d_color = {t: palette[t] for t in sorted(set(dominant_topic))}

    cell_coords = spleen_dfs[sample_idx]
    non_b_coords = cell_coords[~cell_coords.isb]
    ax.scatter(
        non_b_coords['sample.X'],
        non_b_coords['sample.Y'],
        s=1,
        c='k',
        marker='x',
        label='Non-tumor',
        alpha=.2)

    coords = cell_coords.loc[topic_weights.index]
    ax.scatter(coords['sample.X'], coords['sample.Y'], s=3,
               c=colors[dominant_topic, :])

    handles = [mpatches.Circle((1, 1), radius=5, color=color) for color in d_color.values()]
    ax.axes.get_yaxis().set_visible(False)
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.set_title(f'Sample {sample_idx}')
    ax.axis('equal')
    # Invert the y axis (presumably image coordinates with y increasing downward).
    ax.set_ylim(ax.get_ylim()[::-1])
    ax.legend(handles, d_color.keys(),
              bbox_to_anchor=(1.05, .5), title="Topics")
    
def plot_results(adata, s_col, x='X_centroid', y='Y_centroid'):
    """Scatter the cells of ``adata`` coloured by column ``s_col``.

    Numeric ``adata.obs`` columns are drawn with the 'plasma' colormap and a
    colorbar. Any other column is treated as categorical, with one palette colour
    per label and a legend.

    Args:
        adata: AnnData-like object providing ``to_df()`` and ``obs``. This used to
            read the global ``adata_lda`` instead of this argument.
        s_col: column (of ``obs`` or of the expression table) to colour by.
        x, y: names of the coordinate columns.

    Returns:
        (fig, ax)
    """
    fig, ax = plt.subplots()
    df = adata.to_df()
    df[adata.obs.columns] = adata.obs
    b_numeric = (s_col in adata.obs.columns
                 and pd.api.types.is_numeric_dtype(adata.obs[s_col]))
    if b_numeric:
        sp = ax.scatter(df[x], df[y], s=3, c=df.loc[:, s_col], cmap='plasma')
        fig.colorbar(sp)
    else:
        palette = qual_palettes.Bold_10.mpl_colors
        # Key the colours by the *string* labels used for the lookup below; keying by
        # the raw values mapped non-string labels (e.g. integer clusters) to NaN.
        # Cycle the palette when there are more labels than colours.
        ls_label = list(dict.fromkeys(str(v) for v in sorted(set(df.loc[:, s_col]))))
        d_color = {label: palette[i % len(palette)] for i, label in enumerate(ls_label)}
        ax.scatter(df[x], df[y], s=3,
                   c=df.loc[:, s_col].astype('str').map(d_color))
        handles = [mpatches.Circle((1, 1), radius=5, color=color) for color in d_color.values()]
        ax.legend(handles, d_color.keys(),
                  bbox_to_anchor=(1.05, .8), title=s_col)
    ax.axis('equal')
    ax.set_ylim(ax.get_ylim()[::-1])

    return fig, ax

def plot_one_tumor_all_topics(ax, tumor_idx, topic_weights, patient_dfs):
    """Scatter a tumor's cells coloured by dominant topic over faint immune cells.

    Args:
        ax: matplotlib Axes to draw on.
        tumor_idx: key of the tumor in ``patient_dfs`` (also used in the title).
        topic_weights: DataFrame indexed by ``(tumor, cell)`` tuples, one column per topic.
        patient_dfs: dict mapping tumor key -> per-cell DataFrame with ``x``, ``y``
            and boolean ``isimmune`` columns.
    """
    palette = qual_palettes.Bold_10.mpl_colors[:topic_weights.shape[1]]
    colors = np.array(palette)
    dominant_topic = np.argmax(np.array(topic_weights), axis=1)
    # The legend colour of topic t must equal the scatter colour colors[t]. Zipping the
    # present topics against the palette shifted the colours whenever a topic was absent.
    d_color = {t: palette[t] for t in sorted(set(dominant_topic))}

    cell_coords = patient_dfs[tumor_idx]
    immune_coords = cell_coords[cell_coords.isimmune]
    cell_indices = topic_weights.index.map(lambda x: x[1])
    coords = cell_coords.loc[cell_indices]
    # y is negated so the plot is drawn upside-up relative to the raw coordinates.
    ax.scatter(immune_coords['x'], -immune_coords['y'],
               s=5, c='k', label='__Immune', alpha=0.1)
    ax.scatter(coords['x'], -coords['y'], s=2,
               c=colors[dominant_topic, :])
    ax.set_title(f"Tumor {tumor_idx}")
    ax.axes.get_yaxis().set_visible(False)
    ax.axes.get_xaxis().set_visible(False)
    # A bare ax.legend() used to be called here; it was immediately replaced by this one.
    handles = [Line2D([], [], c=color, lw=0, marker="o", markersize=8, label=species)
               for species, color in d_color.items()]
    ax.legend(handles, d_color.keys(),
              bbox_to_anchor=(1.05, .9), title="Topics")

def plot_one_tumor_topic(ax, tumor_idx, topic_weights, patient_dfs):
    """Scatter one topic's weight over a tumor's cells, with immune cells in grey.

    Args:
        ax: matplotlib Axes to draw on.
        tumor_idx: key of the tumor in ``patient_dfs`` (also used in the title).
        topic_weights: Series of a single topic's weights indexed by
            ``(tumor, cell)`` tuples; its ``name`` becomes the legend label.
        patient_dfs: dict mapping tumor key -> per-cell DataFrame with ``x``, ``y``
            and boolean ``isimmune`` columns.
    """
    tumor_cells = patient_dfs[tumor_idx]
    background = tumor_cells[tumor_cells['isimmune']]
    cell_ids = [key[1] for key in topic_weights.index]
    weighted = tumor_cells.loc[cell_ids]

    # y is negated for both layers so they share the same orientation.
    ax.scatter(background['x'], -background['y'],
               s=5, c='k', label='__Immune', alpha=0.1)
    ax.scatter(weighted['x'], -weighted['y'],
               s=2, c=topic_weights, cmap="plasma", label=topic_weights.name)

    ax.set_title(f"Tumor {tumor_idx}")
    for axis_getter in (ax.axes.get_yaxis, ax.axes.get_xaxis):
        axis_getter().set_visible(False)
    ax.legend()
    
def plot_sm_topic(ax, s_slide, topic_weights, patient_dfs, topic_idx):
    """Scatter one topic's weight over the cells of one slide.

    Args:
        ax: matplotlib Axes to draw on.
        s_slide: slide id, matched against ``patient_dfs.imageid``.
        topic_weights: DataFrame of topic weights indexed by cell label, one column per topic.
        patient_dfs: per-cell DataFrame (all slides) with ``imageid``,
            ``X_centroid`` and ``Y_centroid`` columns, indexed like ``topic_weights``.
        topic_idx: column of ``topic_weights`` to plot (also the legend label).
    """
    cell_coords = patient_dfs[patient_dfs.imageid == s_slide]
    # Select this slide's topic rows by exact index membership. The previous regex
    # substring match (`index.str.contains(s_slide)`) also matched slides whose names
    # extend s_slide (e.g. 'scene03' vs 'scene034'), and then `.loc` raised KeyError.
    cell_indices = topic_weights.index[topic_weights.index.isin(cell_coords.index)]
    coords = cell_coords.loc[cell_indices]
    ax.scatter(coords['X_centroid'], -coords['Y_centroid'], s=2,
               c=topic_weights.loc[cell_indices, topic_idx], cmap="plasma",
               label=topic_idx)
    ax.set_title(f"{s_slide}")
    ax.axes.get_yaxis().set_visible(False)
    ax.axes.get_xaxis().set_visible(False)
    ax.axis('equal')
    ax.legend()

In [None]:
# Show the working directory: relative paths used below ('./data/...', './cycif_models/',
# the dated figure folders) resolve against it.
pwd

## Load data and featurize

In [None]:
# Root of the project data. Defaults to the original cluster location; set the
# JP_TMA_DATADIR environment variable to run elsewhere (an older location was
# '/home/groups/graylab_share/data/engje/Data/20200000/20200406_JP-TMAs').
datadir = os.environ.get(
    'JP_TMA_DATADIR',
    '/home/groups/graylab_share/Chin_Lab/ChinData/engje/Data/20200000/20200406_JP-TMAs')
# This data has been standardized to 1 pixel = 1 um.
df_cyc = pd.read_csv(f'{datadir}/data/20220420_JP-TMAs_IMC-TMAs_MIBI_CombinedCelltypes_all.csv',
                     index_col=0, dtype='object')

# Cell coordinates under both names used downstream ('x'/'y' and 'sample.X'/'sample.Y').
df_cyc['x'] = df_cyc.DAPI_X.astype('float64')
df_cyc['y'] = df_cyc.DAPI_Y.astype('float64')
df_cyc['sample.X'] = df_cyc['x']
df_cyc['sample.Y'] = df_cyc['y']

# 'isb' flags epithelial (tumor) cells; the name is presumably inherited from the
# spatial_lda spleen example. 'isimmune' is its complement. MIBI uses the level-2
# cell-type annotation; the other platforms use level 3.
b_mibi = df_cyc.Platform == 'MIBI'
df_cyc['isb'] = np.where(b_mibi,
                         df_cyc['leidencelltype2'] == 'epithelial',
                         df_cyc['leidencelltype3'] == 'epithelial')
df_cyc['isimmune'] = ~df_cyc['isb']

df_cyc['cluster'] = df_cyc.loc[:, 'leiden']
df_cyc['sample.Z'] = 1.0

In [None]:
# Load subtype annotation: JP-TMA1 from the curated annotation sheet, JP-TMA2-1 from the clinical table.
df_surv = pd.read_csv(f'{datadir}/annotation/BC-TMAs_clinical_data.csv', index_col=0)

df_cyc_sub = pd.read_csv(f'{datadir}/annotation/20210403_JP-TMA1_Annotation+Clinical-Subtype.csv', index_col=0)
df_cyc_sub = df_cyc_sub.rename({'ClinicalSubtype': 'ID'}, axis=1)
# Collapse unknown/missing subtypes into 'other'. Assign the result back: the chained
# `df_cyc_sub.ID.replace(..., inplace=True)` does not update df_cyc_sub under pandas copy-on-write.
d_replace = {'0': 'other', '?': 'other', np.nan: 'other'}
df_cyc_sub['ID'] = df_cyc_sub['ID'].replace(d_replace)
# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
b_tma2_surv = df_surv.index.str.contains('JP-TMA2-1')
df_cyc_sub = pd.concat([df_cyc_sub, df_surv[b_tma2_surv].rename({'subtype': 'ID'}, axis=1)])
# JP-TMA2-1 rows use their own index as the accession.
b_tma2_sub = df_cyc_sub.index.str.contains('JP-TMA2-1')
df_cyc_sub.loc[b_tma2_sub, 'Accession'] = df_cyc_sub.index[b_tma2_sub]

In [None]:
# Peek at the clinical table loaded above.
df_surv.head()

## Specify platform

In [None]:
# NOTE: the trailing `break` means only the first platform in this list is processed;
# later cells rely on the `s_plat`, `ls_slide` and `spleen_dfs` this loop leaves behind.
for s_plat in ['IMC','cycIF']: #'cycIF','MIBI','IMC'
    b_plat = df_cyc.Platform == s_plat
    ls_slide = sorted(df_cyc[b_plat].slide_scene.unique())
    # Columns with any missing value on this platform are dropped from its scenes.
    ls_drop = df_cyc.columns[df_cyc[b_plat].isna().sum() > 0]
    print(len(ls_slide))
    ## load data
    PATH_TO_SPLEEN_DF_PKL = f'./data/{s_plat}_df.pkl'
    if not os.path.exists(PATH_TO_SPLEEN_DF_PKL):
        # Scenes excluded by hand ('50_Ay11x6-202' has no tumor) ...
        excluded_slides = {
            '50_Ay11x6-202', '59_Cy8x3-472', '61_X4Y4-215', '187_X3Y2-99',
            '241_X8Y3-337', '19_Cy8x8-509', 'JP-TMA1-1_scene103',
        }
        # ... and scenes with fewer epithelial ('isb') cells than this.
        min_tumor_cells = 4
        spleen_dfs = {}
        for s_slide in ls_slide:
            df_scene = df_cyc.loc[df_cyc.slide_scene == s_slide, ~df_cyc.columns.isin(ls_drop)]
            if len(df_scene) == 0:
                continue
            if s_slide in excluded_slides or df_scene.isb.sum() < min_tumor_cells:
                print(s_slide)
                continue
            spleen_dfs[s_slide] = df_scene
        print(len(spleen_dfs.keys()))
        # Cache the per-scene tables for the next run.
        with open(PATH_TO_SPLEEN_DF_PKL, 'wb') as f:
            pickle.dump(spleen_dfs, f)
    else:
        print('Loading saved features.')
        with open(PATH_TO_SPLEEN_DF_PKL, 'rb') as f:
            spleen_dfs = pickle.load(f)
    break

In [None]:
# Neighbourhood features per cell, cached per platform (pretty fast to recompute).
PATH_TO_SPLEEN_FEATURES_PKL = f'./data/{s_plat}_cells_features.pkl'
print(PATH_TO_SPLEEN_FEATURES_PKL)
if not os.path.exists(PATH_TO_SPLEEN_FEATURES_PKL):
    print('Featurizing samples ...')
    # radius is presumably in coordinate units (1 px = 1 um for this data) - confirm in spatial_lda.
    spleen_cells_features = featurize_spleens(
        spleen_dfs, neighborhood_to_cluster,
        radius=100, n_processes=N_PARALLEL_PROCESSES)
    with open(PATH_TO_SPLEEN_FEATURES_PKL, 'wb') as f:
        pickle.dump(spleen_cells_features, f)
else:
    print('Loading saved features.')
    with open(PATH_TO_SPLEEN_FEATURES_PKL, 'rb') as f:
        spleen_cells_features = pickle.load(f)

## Split training set and compute difference matrices


In [None]:
# Restrict the cell features to one clinical subtype.
s_subtype = 'ER+'  # or 'TNBC'

if s_plat == 'cycIF':
    ls_index = df_cyc_sub[df_cyc_sub.ID == s_subtype].index
else:
    ls_index = df_cyc[df_cyc.subtype == s_subtype].slide_scene
# Scenes of this subtype that survived the scene filtering above.
ls_index_dict = ls_index[ls_index.isin(spleen_dfs.keys())].unique()
spleen_cells_features_sub = spleen_cells_features[
    spleen_cells_features.index.map(lambda x: x[0]).isin(ls_index_dict)]
if s_plat == 'IMC':
    # Drop feature column '20' for IMC when present. Reassign instead of an in-place
    # drop on a filtered frame, and test for the column explicitly instead of a bare
    # `except:` that would also hide unrelated errors.
    if '20' in spleen_cells_features_sub.columns:
        spleen_cells_features_sub = spleen_cells_features_sub.drop(columns='20')
    else:
        print('no 20')

In [None]:
# Per-scene cell tables for the selected subtype only.
spleen_dfs_sub = {key: spleen_dfs[key] for key in ls_index_dict}

In [None]:
%%time
#pretty fast

spleen_difference_matrices = make_merged_difference_matrices(#spleen_cells_features, spleen_dfs,
                                                             spleen_cells_features_sub, spleen_dfs_sub,
                                                             'sample.X', 'sample.Y')
#all_sample_idxs = spleen_cells_features.index.map(lambda x: x[0])
all_sample_idxs = spleen_cells_features_sub.index.map(lambda x: x[0])
_sets = train_test_split(spleen_cells_features_sub, #spleen_cells_features, 
                         test_size=1. - TRAIN_SIZE_FRACTION,
                         stratify=all_sample_idxs)
train_spleen_cells_features, test_spleen_cells__features = _sets
train_difference_matrices = make_merged_difference_matrices(
    train_spleen_cells_features, spleen_dfs_sub, #spleen_dfs,
    'sample.X', 'sample.Y')
spleen_idxs = train_spleen_cells_features.index.map(lambda x: x[0])

## Visualize graphs used to tie neighboring environments

In [None]:
import matplotlib as mpl
# Embed fonts as TrueType (Type 42) in saved PDFs so their text stays editable in vector editors.
mpl.rcParams['pdf.fonttype'] = 42
#fig.savefig('LDA_example.pdf')

In [None]:
## Plot the neighbour graphs used to tie neighbouring environments (cycIF only).
s_date = '20241112'
if s_plat == 'cycIF':
    # savefig below fails if the dated output folder does not exist yet.
    os.makedirs(s_date, exist_ok=True)
    ls_sllide = ['JP-TMA1-1_scene034', 'JP-TMA2-1_scene35']
    if s_subtype == 'ER+':
        # The same core is drawn in both panels.
        ls_sllide = ['JP-TMA1-1_scene044', 'JP-TMA1-1_scene044']
    _plot_fn = make_plot_fn(spleen_difference_matrices)

    fig, axes_graph = plot_samples_in_a_row(spleen_cells_features_sub, _plot_fn, spleen_dfs_sub,
                                            tumor_set=ls_sllide)
    plt.tight_layout()
    fig.savefig(f"{s_date}/graphs_LDA_{s_subtype}_{s_plat}.pdf")

# Spatial LDA results



## Parameter sweep

In [None]:
# Analyse only the 8-topic model: this overrides N_TOPICS_LIST ([8, 10, 6]) from the config cell.
num_topics = 8
N_TOPICS_LIST = [8]

In [None]:
# Platform left selected by the scene-filtering loop (first entry of its list).
s_plat

In [None]:
# Subtype selected when the training features were restricted.
s_subtype

In [None]:
# The slow step: each n_topics model is trained once and pickled under PATH_TO_MODELS;
# later runs with the same platform/subtype/penalty/topics/train fraction just reload it.
from spatial_lda.model import order_topics_consistently
spatial_lda_models = {}
difference_penalty = 0.25
for n_topics in N_TOPICS_LIST:
    path_to_train_model = (f'{PATH_TO_MODELS}{s_plat}_{s_subtype}_training'
                           f'_penalty={difference_penalty}'
                           f'_topics={n_topics}'
                           f'_trainfrac={TRAIN_SIZE_FRACTION}.pkl')
    print(path_to_train_model)
    if os.path.exists(path_to_train_model):
        with open(path_to_train_model, 'rb') as f:
            spatial_lda_models[n_topics] = pickle.load(f)
        continue

    print(f'Running n_topics={n_topics}, d={difference_penalty}\n')
    spatial_lda_model = spatial_lda.model.train(
        sample_features=train_spleen_cells_features,
        difference_matrices=train_difference_matrices,
        difference_penalty=difference_penalty,
        n_topics=n_topics,
        n_parallel_processes=N_PARALLEL_PROCESSES,
        verbosity=1,
        admm_rho=0.1,
        primal_dual_mu=1e+5)
    spatial_lda_models[n_topics] = spatial_lda_model
    with open(path_to_train_model, 'wb') as f:
        pickle.dump(spatial_lda_model, f)

order_topics_consistently(spatial_lda_models.values())

## Load models with different number of topics and the same difference penalty

In [None]:
#can select more
# `lda_5` / `topic_weights_5` are historical names: they hold the model with
# `num_topics` topics (8 at this point).
lda_5 = spatial_lda_models[num_topics]
topic_weights_5 = lda_5.topic_weights

# NOTE(review): `samples` is not used by any later cell in this notebook.
samples = ls_slide[0:5] #[s_slide,s_slide]

### Visualizing topics on regular cells

In [None]:
if s_plat == 'cycIF':
    s_slide = 'JP-TMA1-1_scene034'
    if s_subtype == 'ER+':
        s_slide = 'JP-TMA1-1_scene044'
    num_topics = 8
    # savefig below fails if the dated output folder does not exist yet.
    os.makedirs(s_date, exist_ok=True)
    # The same core is drawn in both panels.
    fig, ax = plot_samples_in_a_row(spatial_lda_models[num_topics].topic_weights, plot_one_tumor_all_topics,
                                    spleen_dfs, tumor_set=[s_slide, s_slide])
    # NOTE(review): this replaces the "Topics" legend that plot_one_tumor_all_topics drew
    # on the first panel with an automatic legend. That panel's artists are unlabeled or
    # '_'-prefixed, so the first panel ends up without topic entries; confirm this is intended.
    ax[0].legend(frameon=False)
    plt.tight_layout()
    fig.savefig(f"{s_date}/graphs_LDA_topics_{s_subtype}_{s_plat}.pdf")

In [None]:
if s_plat == 'cycIF':
    # savefig below fails if the dated output folder does not exist yet.
    os.makedirs(s_date, exist_ok=True)
    topic_weights_sel = spatial_lda_models[num_topics].topic_weights
    # One figure per topic. This was hard-coded to topics 0-6, which silently skipped
    # the last topic of the 8-topic model.
    for t in range(num_topics):
        print(t)
        fig, ax = plot_samples_in_a_row(topic_weights_sel.iloc[:, t], plot_one_tumor_topic, spleen_dfs,
                                        tumor_set=[s_slide, s_slide])
        ax[0].legend(frameon=False)
        plt.tight_layout()
        fig.savefig(f"{s_date}/graphs_LDA_topic_{t}_{s_subtype}_{s_plat}.pdf")

## Clinical annotation

In [None]:
# Load cycIF patients and subtypes.
df_cyc_sub = pd.read_csv(f'{datadir}/annotation/20210403_JP-TMA1_Annotation+Clinical-Subtype.csv', index_col=0)
df_cyc_sub = df_cyc_sub.rename({'ClinicalSubtype': 'ID'}, axis=1)
# Collapse unknown/missing subtypes into 'other'. Assign the result back: the chained
# `df_cyc_sub.ID.replace(..., inplace=True)` does not update df_cyc_sub under pandas copy-on-write.
d_replace = {'0': 'other', '?': 'other', np.nan: 'other'}
df_cyc_sub['ID'] = df_cyc_sub['ID'].replace(d_replace)
b_tma2 = df_cyc_sub.index.str.contains('JP-TMA2-1')
df_cyc_sub.loc[b_tma2, 'Accession'] = df_cyc_sub.loc[b_tma2, 'TMA_scene.1']
# Scene -> subtype and scene -> patient accession lookups.
d_cyc_sub = dict(zip(df_cyc_sub.index.tolist(), df_cyc_sub.ID.tolist()))
d_patient = dict(zip(df_cyc_sub.index, df_cyc_sub.Accession))

# Survival outcomes.
df_surv = pd.read_csv(f'{datadir}/data/cycIF_clinical_outcome.csv', index_col=0)
df_surv['Platform'] = 'cycIF'
# Fix the misspelled column name from the source file.
df_surv = df_surv.rename({'Recurence_time': 'Recurrence_time'}, axis=1)

df_cyc_a = pd.read_csv(f'{datadir}/annotation/JP-TMAs_Clinical_Variables.csv', index_col=0)

# Accession -> subtype, applied to the survival table's index.
d_sub_both = dict(zip(df_cyc_sub.Accession, df_cyc_sub.ID))
df_surv['subtype'] = df_surv.index.map(d_sub_both)

In [None]:
# Peek at the survival table with its mapped subtypes.
df_surv.head()
#df_cyc_a.head()

In [None]:
# Per-cell topic weights of the selected model.
topic_weights_5

In [None]:
# Save the mean topic weights per patient for the survival analysis below.
if s_plat == 'cycIF':
    s_sample = 'JP-TMA1'
else:
    s_sample = 'IMC'
    if s_subtype == 'TNBC':
        # Drop IMC feature '20' (presumably to match the training features, which drop it
        # for IMC). Rebinding with errors='ignore' keeps the cell safe to re-run; the
        # previous in-place drop raised KeyError on the second run.
        spleen_cells_features = spleen_cells_features.drop(columns='20', errors='ignore')
s_grouper = 'Patient'
s_cell = 'all'
num_topics = 8
lda_5 = spatial_lda_models[num_topics]
topic_weights_5 = lda_5.topic_weights
# Add the patient column on a copy. Writing it into the model's own topic_weights
# (and dropping it afterwards) left the model mutated whenever the cell failed midway.
df_out = (topic_weights_5
          .assign(Patient=topic_weights_5.index.map(lambda x: x[0]).map(d_patient))
          .groupby(s_grouper)
          .mean())
s_out = f'results_{s_sample}_SpatialLDA_by{s_grouper}_by{s_cell}_k{num_topics}.csv'
print(s_out)
df_out.to_csv(f'{datadir}/{s_out}')

In [None]:
# Plot a heatmap of topic composition.
def normalizer(ar):
    """Scale each row (topic) of ``ar`` to sum to 1 and return it transposed (features x topics)."""
    return (ar / np.sum(ar, axis=1, keepdims=True)).T


os.makedirs(f'{datadir}/{s_date}', exist_ok=True)
fig, ax = plt.subplots(dpi=300)
# Label the rows with the columns the model was trained on. spleen_cells_features can
# carry extra columns (IMC feature '20' is dropped from the training subset but, for
# ER+, not from the full table), which made the heatmap construction fail.
plot_topics_heatmap(lda_5.components_, train_spleen_cells_features.columns, ax, normalizer)
plt.tight_layout()
fig.savefig(f'{datadir}/{s_date}/heatmap_spatial_lda_k{num_topics}_{s_plat}_{s_subtype}.pdf')

# Scimap

Not used in the final analysis; the cells below are kept commented out for reference.

In [None]:
# for s_plat in ['cycIF','IMC','MIBI',]:
#     results_file = f'{datadir}/{s_plat}_spatial_lda.h5ad'
#     if not os.path.exists(results_file):
#         print(f'generating {s_plat}')
#         ls_slide = sorted(df_cyc[df_cyc.Platform==s_plat].slide_scene.unique())
#         ls_drop = df_cyc.columns[df_cyc[df_cyc.Platform==s_plat].isna().sum() > 0]
#         df_plat = df_cyc[df_cyc.Platform==s_plat]
#         print(len(ls_slide))
#         #create adata
#         ls_drop_splda = ['leiden', 'leidencelltype3', 'leidencelltype4','leidencelltype2',
#                'gatedcelltype3', 'gatedcelltype5', 'leidencelltype5', 'slide','core','corec',
#                'celltype', 'DAPI_X', 'DAPI_Y', 'slide_scene', 'subtype', 'Patient',
#                'Platform','sample.Y', 'sample.X','x', 'y', 'isb', 'isimmune', 'cluster', 'sample.Z']
#         adata = sc.AnnData(df_plat.drop(set(ls_drop).union(set(ls_drop_splda)),axis=1).astype('float64')) #
#         adata.obs['X_centroid'] = df_plat.loc[adata.obs.index,'DAPI_X'].astype('float64')#*.325 already corrected
#         adata.obs['Y_centroid'] = df_plat.loc[adata.obs.index,'DAPI_Y'].astype('float64')#*.325
#         adata.obs['phenotype'] = df_plat.loc[adata.obs.index,'leiden'].replace({'CD4 T cell': 'CD3 T cell','CD8 T cell': 'CD3 T cell'})
#         adata.obs['imageid'] = df_plat.loc[adata.obs.index,'slide_scene']
#         adata.write_h5ad(results_file)
#         sm.tl.spatial_aggregate(adata, x_coordinate='X_centroid', y_coordinate='Y_centroid', purity=50,
#                               phenotype='phenotype', method='radius', radius=30, knn=10, imageid='imageid', subset=None, label='spatial_aggregate')
#         pd.DataFrame(adata.obs['spatial_aggregate']).to_csv(f'{datadir}/{s_plat}_spatial_aggregate_scimap.csv')
        
#         break
#         adata_lda = sm.tl.spatial_lda (adata[adata.obs.imageid.isin(ls_slide)], x_coordinate='X_centroid', y_coordinate='Y_centroid',
#                                phenotype='phenotype', method='radius', radius=30, knn=10, 
#                                imageid='imageid', num_motifs=10, random_state=0, subset=None, label='spatial_lda')
#         adata_lda.write_h5ad(results_file)
#         fig, ax = plt.subplots(figsize=(8,8),dpi=200)
#         sns.heatmap(adata_lda.uns['spatial_lda_probability'],cmap='RdBu_r',ax=ax)
#         plt.tight_layout()
#         fig.savefig(f'{datadir}/{s_date}/heatmap_{s_plat}_sm_spatial_lda_k{len(adata_lda.uns["spatial_lda_probability"])}.png')
#     else:
#         print(f'loading {s_plat}')
#         adata = sc.read_h5ad(results_file)
#         sm.tl.spatial_aggregate(adata, x_coordinate='X_centroid', y_coordinate='Y_centroid', purity=50,
#                               phenotype='phenotype', method='radius', radius=30, knn=10, imageid='imageid', subset=None, label='spatial_aggregate')
#         pd.DataFrame(adata.obs['spatial_aggregate']).to_csv(f'{datadir}/{s_plat}_spatial_aggregate_scimap.csv')
      
        

In [None]:
#df_plat.drop(set(ls_drop).union(set(ls_drop_splda)),axis=1)

In [None]:
# for idx in [1,3,5,6,7,8]:
#     fig, ax = plt.subplots(figsize=(4,4),dpi=200)
#     plot_sm_topic(ax, s_slide, topic_weights=adata_lda.uns['spatial_lda'],patient_dfs=adata_lda.obs,topic_idx=f'Topic-{idx}')
#     #break

In [None]:
# adata_lda = sm.tl.spatial_cluster(adata_lda, df_name='spatial_lda', method='kmeans', k=10,
#                    n_pcs=None, resolution=1, phenograph_clustering_metric='euclidean',
#                    nearest_neighbors=30, random_state=0, label='spatial_kmeans_lda', output_dir=None)

# adata_lda.uns['spatial_lda'].columns = [item.replace('Motif_','Topic-') for item in adata_lda.uns['spatial_lda'].columns]

In [None]:
# df = adata_lda.to_df()
# df[adata_lda.obs.columns] = adata_lda.obs
# df_plot = df.loc[:,['spatial_kmeans_lda','phenotype']].groupby('spatial_kmeans_lda').value_counts(normalize=True).unstack().fillna(0)
# fig, ax = plt.subplots(figsize=(8,8),dpi=200)
# plot_topics_heatmap(np.array(df_plot), df_plot.columns,ax,normalizer)
# ax.set_xticklabels(df_plot.index)
# plt.tight_layout()
# fig.savefig(f'{datadir}/{s_date}/heatmap_spatial_kmeans_lda_k{len(df_plot)}.png')

In [None]:

# #mean topics per tissue
# s_sample = 'JP-TMA1'
# s_grouper='Patient'
# s_cell= 'all'
# for num_topics in [10]:
#     topic_weights_5 = adata_lda.uns['spatial_lda']
#     topic_weights_5['Patient'] = topic_weights_5.index.map(lambda x: x.split('_cell')[0]).map(d_patient)
#     df_out = topic_weights_5.groupby(s_grouper).mean()
#     s_out = f'results_{s_sample}_smSpatialLDA_by{s_grouper}_by{s_cell}_k{num_topics}.csv'
#     print(s_out)
#     df_out.to_csv(f'{datadir}/{s_out}')

In [None]:
# # mean topic clusters per tissue

# adata_lda.obs['Patient'] = adata_lda.obs.imageid.map(d_patient)
# k = adata_lda.obs['spatial_kmeans_lda'].nunique()
# s_sample = 'JP-TMA1'
# s_grouper='Patient'
# s_cell= 'all'

# df_out = (adata_lda.obs.groupby(['spatial_kmeans_lda','Patient']).count().X_centroid/adata_lda.obs.groupby(['Patient']).count().X_centroid).unstack().fillna(0).T
# s_out = f'results_{s_sample}_KmeansLDA_by{s_grouper}_by{s_cell}_k{k}.csv'
# print(s_out)
# df_out.to_csv(f'{datadir}/{s_out}')

## the plot

In [None]:
# for s_slide in ls_slide[0:10]:
#     s_slide = 'JP-TMA1-1_scene034'
#     plot_results(adata_lda[adata_lda.obs.imageid==s_slide], s_col='spatial_kmeans_lda', x='X_centroid',y='Y_centroid')
#     break
# fig, ax = plt.subplots()
# plot_topics_heatmap(np.array(df_plot), df_plot.columns, ax=ax, normalizer=normalizer)

In [None]:
# num_topics = 8
# print(num_topics)
# lda_5 = spatial_lda_models[num_topics]
# topic_weights_5 = lda_5.topic_weights

# for s_slide in spleen_dfs.keys():
#     fig, ax = plt.subplots(1,2,figsize=(8,4))
#     ax=ax.ravel()
#     plot_bcell_topic_multicolor(ax[0], sample_idx=s_slide, topic_weights=topic_weights_5, spleen_dfs=spleen_dfs)
#     plot_topics_heatmap(lda_5.components_, spleen_cells_features.columns, ax[1], normalizer)
#     plt.tight_layout()
    

# Survival analysis

In [None]:
## find best cutpoint 
import lifelines
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import multivariate_logrank_test
from lifelines import exceptions
import warnings
# Silence lifelines' ApproximationWarning for the many log-rank tests run below.
warnings.filterwarnings("ignore",category = exceptions.ApproximationWarning)

def single_km(df_all, s_cell, s_subtype, s_plat, s_col, savedir, alpha=0.05, cutp=0.5,
              s_time='Survival_time', s_censor='Survival'):
    """Dichotomise patients at a quantile of ``s_col`` and compare survival with a log-rank test.

    Rows of ``df_all`` for platform ``s_plat`` and subtype ``s_subtype`` are split into
    'low' (``s_col`` <= the ``cutp`` quantile) and 'high'. When the log-rank p-value is
    below ``alpha``, Kaplan-Meier curves of both groups are saved as a PNG in
    ``savedir`` (created if missing).

    Args:
        df_all: table with ``Platform``, ``subtype``, ``s_col``, ``s_time`` and
            ``s_censor`` columns. It is not modified.
        s_cell: cell population name, used only in the plot title.
        s_subtype: subtype to select.
        s_plat: platform to select.
        s_col: variable to dichotomise.
        savedir: output folder for figures.
        alpha: significance threshold for plotting.
        cutp: quantile used as the cut point.
        s_time: duration column.
        s_censor: event-indicator column.

    Returns:
        The selected complete rows (string index) with an added ``abundance`` column,
        or None when no complete rows match.
    """
    # Work on a filtered copy; the caller's frame and its index are left untouched.
    df = df_all[(df_all.Platform == s_plat) & (df_all.subtype == s_subtype)].copy()
    df.index = df.index.astype('str')
    df = df.loc[:, [s_col, s_time, s_censor]].dropna()
    if len(df) == 0:
        return None

    i_cut = np.quantile(df.loc[:, s_col], cutp)
    df['abundance'] = np.where(df.loc[:, s_col] <= i_cut, 'low', 'high')
    if df['abundance'].nunique() < 2:
        # Every value fell on one side of the cut point: there is nothing to compare.
        return df

    results = multivariate_logrank_test(event_durations=df.loc[:, s_time],
                                        groups=df.abundance,
                                        event_observed=df.loc[:, s_censor])
    p_value = results.p_value
    if not p_value < alpha:  # also skips NaN p-values
        return df

    s_title1 = f'{s_subtype} {s_plat}'
    s_title2 = f'{s_col} in {s_cell} cells n={len(df)}'
    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(3, 3), dpi=300)
    for s_group in ['high', 'low']:
        df_abun = df[df.abundance == s_group]
        try:
            kmf.fit(df_abun.loc[:, s_time], df_abun.loc[:, s_censor], label=s_group)
            kmf.plot(ax=ax, ci_show=False, show_censors=True)
        except Exception:
            # Report p=1 when a curve cannot be fitted. The previous
            # `results.summary.p[0] = 1` wrote into a freshly built summary frame and
            # had no effect.
            p_value = 1.0
    ax.set_title(f'{s_title1}\n{s_title2}\np={p_value:.2}', fontsize=10)
    ax.set_xlabel(s_censor)
    ax.legend(loc='upper right', title=f'{cutp}({i_cut:.2})')
    plt.tight_layout()
    os.makedirs(savedir, exist_ok=True)
    fig.savefig(f"{savedir}/KM_{s_title1.replace(' ','_')}_{s_title2.replace(' ','_')}_{cutp}_{s_censor}.png", dpi=300)
    return df

In [None]:
# Index the per-patient results CSVs written above, parsing metadata from their names, e.g.
# 'results_JP-TMA1_SpatialLDA_byPatient_byall_k8.csv' -> type 'SpatialLDA', cell 'all' and
# 'subtype' 'k8' (the 'subtype' column holds the last name field, here the topic count).
# Parsing into a dict also avoids overwriting the global `s_subtype` with that field.
s_sample = 'JP-TMA1'  # other runs: '20220413_JP-TMAs_IMC-TMAs_MIBI', '20220411_JP-TMAs_IMC-TMAs_MIBI', '20220409_JP-TMAs_IMC-TMAs'
file_records = {}
for s_file in sorted(os.listdir(datadir)):
    if f'results_{s_sample}' not in s_file or not s_file.endswith('.csv'):
        continue
    ls_part = s_file.split('.csv')[0].split('_')
    file_records[s_file] = {
        'subtype': ls_part[-1],
        'type': ls_part[2],
        'partition': 'spatial',
        'cell': ls_part[-2].split('by')[1],
    }
df_file = pd.DataFrame.from_dict(file_records, orient='index')

In [None]:
%matplotlib inline
# Results files to run the survival analysis on; names follow
# results_{s_sample}_SpatialLDA_by{s_grouper}_by{s_cell}_k{num_topics}.csv.
# NOTE(review): the tissue-means cell above writes the k8 file; k6 presumably comes from an earlier run.
ls_file = [#'results_JP-TMA1_KmeansLDA_byPatient_byall_k10.csv',
       #'results_JP-TMA1_smSpatialLDA_byPatient_byall_k10.csv',
       'results_JP-TMA1_SpatialLDA_byPatient_byall_k6.csv',
       #'results_JP-TMA1_SpatialLDA_byPatient_byall_k8.csv'
          ]

In [None]:
# Parsed metadata of the available results files.
df_file

In [None]:
alpha = 0.05
s_date = '20220808'
df_surv.index = df_surv.index.astype('str')
savedir = f'{datadir}/{s_date}'
os.makedirs(savedir, exist_ok=True)
s_plat = 'cycIF'

for s_index in ls_file:  # or df_file.index for every results file
    s_type = df_file.loc[s_index, 'type']
    # Was read from the 'type' column, which put e.g. 'SpatialLDA' into the
    # "<marker> in <cell> cells" plot titles.
    s_cell = df_file.loc[s_index, 'cell']
    df_all = pd.read_csv(f'{datadir}/{s_index}', index_col=0)
    df_all.index = df_all.index.astype('str')
    ls_marker = df_all.columns
    df_all = df_all.merge(df_surv, left_index=True, right_index=True)
    for s_subtype in ['ER+', 'TNBC']:  # ,'HER2+'
        for s_time, s_censor in [('Survival_time', 'Survival'), ('Recurrence_time', 'Recurrence')]:
            for s_col in ls_marker:
                for cutp in [.5, .66, 0.33]:
                    single_km(df_all, s_cell, s_subtype, s_plat, s_col, savedir, alpha, cutp, s_time, s_censor)
    # NOTE(review): only the first entry of ls_file is processed; remove this break to run them all.
    break

## Scimap testing (skipped output results)

USE:
- spatial_count: radius from center cell followed by spatial cluster

- spatial_aggregate: very straightforward. How does Getis-Ord compare?

FUTURE:

- spatial_interaction: permutation test for significant proximity: slow

- spatial_expression: product of the expression matrix and a weighted proximity matrix; NOT working well

- spatial_pscore: cell-cell proximity between cell types of interest




- spatial_similarity_search: could be useful in the future



In [None]:
# NOTE(review): `sm` (scimap) is only imported in a commented-out line of the imports cell, and
# `adata_lda` is only created in commented-out cells of the Scimap section, so this cell and the
# ones below fail on a fresh kernel.
adata_lda = sm.tl.spatial_aggregate(adata_lda, x_coordinate='X_centroid', y_coordinate='Y_centroid', purity=50,
                              phenotype='phenotype', method='radius', radius=30, knn=10, imageid='imageid', subset=None, label='spatial_aggregate')


In [None]:
# Label missing spatial_aggregate values as 'non-significant' so every cell gets a colour.
# NOTE(review): if this column is categorical, fillna with a new category raises - confirm the dtype.
adata_lda.obs['spatial_aggregate']= adata_lda.obs['spatial_aggregate'].fillna('non-significant')

In [None]:
# spatial_aggregate labels on the first ten slides.
for s_slide in ls_slide[:10]:
    adata_slide = adata_lda[adata_lda.obs.imageid == s_slide]
    fig, ax = plot_results(adata_slide, s_col='spatial_aggregate', x='X_centroid', y='Y_centroid')

In [None]:
# scimap neighbourhood phenotype counts (method='radius', radius=30), stored under label 'spatial_count'.
sm.tl.spatial_count(adata_lda, x_coordinate='X_centroid', y_coordinate='Y_centroid', 
              phenotype='phenotype', method='radius', radius=30, knn=10, imageid='imageid', subset=None, label='spatial_count')

In [None]:
# k-means (k=10) on the 'spatial_count' table; cluster labels stored as 'spatial_count_kmeans'.
sm.tl.spatial_cluster(adata_lda, df_name='spatial_count', method='kmeans', k=10, n_pcs=None, 
                resolution=1, phenograph_clustering_metric='euclidean', nearest_neighbors=30, random_state=0, label='spatial_count_kmeans', output_dir=None)

In [None]:
# Neighbourhood-count clusters on the first ten slides.
for s_slide in ls_slide[:10]:
    adata_slide = adata_lda[adata_lda.obs.imageid == s_slide]
    fig, ax = plot_results(adata_slide, s_col='spatial_count_kmeans', x='X_centroid', y='Y_centroid')

In [None]:
# scimap per-phenotype distance table, stored under label 'spatial_distance'.
sm.tl.spatial_distance(adata_lda, x_coordinate='X_centroid', y_coordinate='Y_centroid',
                 z_coordinate=None, phenotype='phenotype', subset=None, imageid='imageid', label='spatial_distance')

In [None]:
# Copy the distance columns (one per phenotype) into obs so plot_results can colour by them.
adata_lda.obs[adata_lda.uns['spatial_distance'].columns] = adata_lda.uns['spatial_distance']

In [None]:
# Distance of every cell to the 'CD8 T cell' phenotype.
s_col = 'CD8 T cell'
fig, ax = plot_results(adata_lda, s_col=s_col, x='X_centroid', y='Y_centroid')
ax.set_title(f'distance to {s_col}')

In [None]:
# Neighbourhood-weighted expression (radius 30) from the raw matrix.
# Clustering on this has not worked well so far.
# NOTE(review): `subset=s_slide` reuses the loop variable left over from the
# preceding plotting loops (the last entry of ls_slide[0:10]). This presumably
# restricts the computation to that one image, even though the later cells
# cluster and plot ten slides. That may explain the poor results; confirm
# whether subset=None was intended.
adata_lda.raw = adata_lda
sm.tl.spatial_expression(
    adata_lda,
    x_coordinate='X_centroid',
    y_coordinate='Y_centroid',
    method='radius',
    radius=30,
    knn=10,
    imageid='imageid',
    use_raw=True,
    log=True,
    subset=s_slide,
    label='spatial_expression',
    output_dir=None,
)

In [None]:
# k-means (k=10, fixed seed) on the neighbourhood-weighted expression.
sm.tl.spatial_cluster(
    adata_lda,
    df_name='spatial_expression',
    method='kmeans',
    k=10,
    n_pcs=None,
    resolution=1,
    phenograph_clustering_metric='euclidean',
    nearest_neighbors=30,
    random_state=0,
    label='spatial_kmeans_exp',
    output_dir=None,
)

In [None]:
# One figure per slide (first ten) coloured by expression-based cluster.
# `s_slide` keeps its name: later cells read the value left by this loop.
for s_slide in ls_slide[:10]:
    fig, ax = plot_results(
        adata_lda[adata_lda.obs.imageid == s_slide],
        s_col='spatial_kmeans_exp',
        x='X_centroid',
        y='Y_centroid',
    )


In [None]:
# Heatmap of mean marker expression per expression-based cluster.
s_col = 'spatial_kmeans_exp'

# Expression matrix with the cell metadata appended as extra columns.
df = adata_lda.to_df()
df[adata_lda.obs.columns] = adata_lda.obs

# Keep the cluster column plus the markers, average per cluster; empty groups -> 0.
df_plot = (
    df.loc[:, [s_col] + adata_lda.var.index.tolist()]
    .groupby(s_col)
    .mean()
    .fillna(0)
)

fig, ax = plt.subplots(figsize=(12, 12))
plot_topics_heatmap(np.array(df_plot), df_plot.columns, ax, normalizer)
plt.tight_layout()

In [None]:
# Repeat the neighbourhood-expression clustering on the rescaled matrix
# (use_raw=False), stored under the label 'spatial_kmeans_rexp'.
# NOTE(review): the return value of sm.pp.rescale is discarded. This assumes it
# modifies adata_lda in place; confirm.
# NOTE(review): `subset=s_slide` is the leftover loop variable from the
# plotting loop above (last entry of ls_slide[0:10]).
sm.pp.rescale(adata_lda)
sm.tl.spatial_expression(
    adata_lda,
    x_coordinate='X_centroid',
    y_coordinate='Y_centroid',
    method='radius',
    radius=30,
    knn=10,
    imageid='imageid',
    use_raw=False,
    log=True,
    subset=s_slide,
    label='spatial_expression',
    output_dir=None,
)
sm.tl.spatial_cluster(
    adata_lda,
    df_name='spatial_expression',
    method='kmeans',
    k=10,
    n_pcs=None,
    resolution=1,
    phenograph_clustering_metric='euclidean',
    nearest_neighbors=30,
    random_state=0,
    label='spatial_kmeans_rexp',
    output_dir=None,
)

In [None]:
# Heatmap of mean (rescaled) marker expression per cluster from the rescaled run.
# Fix: this cell follows the rescale/re-cluster step, which writes
# 'spatial_kmeans_rexp'. The copy-pasted 'spatial_kmeans_exp' grouped by the
# old clusters instead.
s_col = 'spatial_kmeans_rexp'

# Expression matrix with the cell metadata appended as extra columns.
df = adata_lda.to_df()
df[adata_lda.obs.columns] = adata_lda.obs

# Keep the cluster column plus the markers, average per cluster; empty groups -> 0.
df_plot = (
    df.loc[:, [s_col] + adata_lda.var.index.tolist()]
    .groupby(s_col)
    .mean()
    .fillna(0)
)

fig, ax = plt.subplots(figsize=(12, 12))
plot_topics_heatmap(np.array(df_plot), df_plot.columns, ax, normalizer)
plt.tight_layout()

In [None]:
# One figure per slide (first ten) coloured by the rescaled-expression cluster.
# `s_slide` keeps its name: later cells read the value left by this loop.
for s_slide in ls_slide[:10]:
    fig, ax = plot_results(
        adata_lda[adata_lda.obs.imageid == s_slide],
        s_col='spatial_kmeans_rexp',
        x='X_centroid',
        y='Y_centroid',
    )

In [None]:
#takes forever, but interesting
sm.tl.spatial_interaction(adata_lda, x_coordinate='X_centroid', y_coordinate='Y_centroid', z_coordinate=None, phenotype='phenotype',
                    method='radius', radius=30, knn=10, permutation=1000, imageid='imageid', subset=None, pval_method='zscore', label='spatial_interaction')

In [None]:
# Clustered summary heatmap of the interaction results (p < 0.05);
# the underlying table is kept in map_data.
map_data = sm.pl.spatial_interaction(
    adata_lda,
    spatial_interaction='spatial_interaction',
    summarize_plot=True,
    p_val=0.05,
    row_cluster=True,
    col_cluster=True,
    cmap='vlag',
    nonsig_color='grey',
    subset_phenotype=None,
    subset_neighbour_phenotype=None,
    binary_view=False,
    return_data=True,
)

In [None]:
# Proximity score between proliferating tumour cells and CD8 T cells, per image,
# within a 20-unit radius.
# NOTE(review): this and the next two cells use `adata`, not `adata_lda` like the
# rest of this section. Confirm that this is the intended object.
sm.tl.spatial_pscore(
    adata,
    proximity=['Prolif. t.', 'CD8 T cell'],
    score_by='imageid',
    x_coordinate='X_centroid',
    y_coordinate='Y_centroid',
    phenotype='phenotype',
    method='radius',
    radius=20,
    knn=3,
    imageid='imageid',
    subset=None,
    label='spatial_pscore',
)

In [None]:
# Display the proximity-score result stored by the previous cell.
adata.uns['spatial_pscore']

In [None]:
# Bar plot of the 'Proximity Density' score from the pscore result.
sm.pl.spatial_pscore(
    adata,
    label='spatial_pscore',
    plot_score='Proximity Density',
    order_xaxis=None,
    color='grey',
)

In [None]:
#save it!
results_file = f'scimap_test_{s_slide}.h5ad'
if not os.path.exists(results_file):
    adata_lda.write(results_file)

In [None]:
# Reload the saved results into `adata` when the file exists.
# NOTE(review): the filename depends on the leftover loop variable s_slide. If
# the file is missing, this cell prints nothing and `adata` keeps its old value.
results_file = f'scimap_test_{s_slide}.h5ad'
if os.path.exists(results_file):
                print('loading data')
                adata = sc.read_h5ad(results_file)