Initial Setup

In [None]:
# neighborhood analysis based on Katie's code since scimap seems to be weird
import anndata as ad
import pandas as pd
from scipy import spatial
from sklearn.cluster import MiniBatchKMeans
from matplotlib import pyplot as plt
import plotly.io as pio
import plotly.graph_objects as go
import scimap as sm
import os

# Input paths are relative to the working directory, i.e. the notebook is run
# (on Linux) from the folder that contains "Multiplex_IHC_studies".
# Earlier inputs, kept for reference:
#   r"Z:\Multiplex_IHC_studies\Summer_Interns\2024\spatial\JT_TMA_spatial_analysis\DP20_full_data.h5ad"
#   r"Z:\Multiplex_IHC_studies\Summer_Interns\2024\spatial\scimapTest\TMA27ExNuclei.h5ad"
anndata_path = "Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_analysis/IM_full_data_rmDup.h5ad"
cmap_path = "Multiplex_IHC_studies/Eric_Berens/HuBrca_TMA_mIHC/DataAnalysis/ImmMimicry_ColorCodes.xlsx"

adata = ad.read_h5ad(anndata_path)

# Grade correction: every cell of subject ST-00021143 is recorded as Grade 3.
adata.obs.loc[adata.obs['Subject_ID'].eq('ST-00021143'), 'Grade'] = 3

Functions for Neighborhood Analysis

In [None]:
def relabel(data):
    '''
    Split 'PanCK+' neoplastic cells by cell-death (CD) marker status.

    A 'PanCK+' cell whose CC1, GASDERMIN, CC3, CC8, H2AX and CD71 ``*_func``
    columns are all 0 becomes 'PanCK+ NCD' (no cell death). Every other
    'PanCK+' cell has at least one of those markers and becomes 'PanCK+ CD'.

    data: DataFrame (e.g. the obs rows of one image) with a categorical
        'phenotype' column and the six ``*_func`` columns listed above.
        It is modified in place.
    Returns the same DataFrame.
    '''
    # Every label assigned below must already be a category, otherwise pandas
    # refuses to set it ("Cannot setitem on a Categorical with a new category").
    # Only add the missing ones so calling relabel() twice does not raise.
    new_labels = ["", "PanCK+ NCD", "PanCK+ CD"]
    missing = [label for label in new_labels if label not in data['phenotype'].cat.categories]
    if missing:
        data['phenotype'] = data['phenotype'].cat.add_categories(missing)

    death_markers = ['CC1_func', 'GASDERMIN_func', 'CC3_func', 'CC8_func', 'H2AX_func', 'CD71_func']
    is_panck = data['phenotype'] == 'PanCK+'
    no_death_marker = (data[death_markers] == 0).all(axis=1)
    data.loc[is_panck & no_death_marker, 'phenotype'] = 'PanCK+ NCD'
    # all remaining PanCK+ cells have at least one cell death marker
    data.loc[is_panck & ~no_death_marker, 'phenotype'] = 'PanCK+ CD'
    return data


def create_neighborhoods(tma, seeds, threshold, phenotypes):
    '''
    Count the cell types around every seed cell of one image.

    tma: DataFrame of the cells of a single image with 'X_centroid',
        'Y_centroid', 'phenotype', 'KI67_func', 'CellID', 'Subject_ID',
        'imageid', 'timepoint', 'Tx', 'Grade', 'Stage' and 'ER' columns.
    seeds: phenotypes whose cells are used as neighborhood centers.
    threshold: neighborhood radius, in the same units as the centroids.
    phenotypes: phenotypes to count. One count column per phenotype is
        produced, followed by one 'percent_<phenotype>' column per phenotype,
        in this order (downstream code slices these by position).

    Returns a DataFrame with one row per seed cell. Columns are the counts,
    the fractions (0 when the seed has no counted neighbors), 'file',
    'Subject_ID', 'CellID', 'index' (row position within tma), 'timepoint',
    'Tx', 'Grade', 'Stage', 'ER', 'KI67p_neighbors' / 'KI67n_neighbors'
    (neighbors that are 'nonIMNeoplastic cells' with KI67_func == 1 / != 1)
    and 'seed' (the seed's phenotype). The seed is not its own neighbor.
    Image-level columns are taken from the first row of tma.
    '''
    all_locations = tma[['X_centroid', 'Y_centroid']].values
    kdtree = spatial.KDTree(all_locations)
    # Pull the per-cell columns out once; indexing them inside the
    # per-neighbor loops was the hot path.
    phenotype_vals = tma['phenotype'].to_numpy()
    ki67_vals = tma['KI67_func'].values
    subject_vals = tma['Subject_ID'].values
    cellid_vals = tma['CellID'].values

    all_neighbors = []
    image_meta = None
    for i, cell_class in enumerate(phenotype_vals):
        if cell_class not in seeds:
            continue
        if image_meta is None:
            # Read lazily so images without seed cells never touch these columns.
            image_meta = {col: tma[col].values[0]
                          for col in ('imageid', 'timepoint', 'Tx', 'Grade', 'Stage', 'ER')}

        neighbors = kdtree.query_ball_point(all_locations[i], threshold)
        neighbors.remove(i)  # the seed is always within distance 0 of itself
        neighbor_classes = [phenotype_vals[j] for j in neighbors]

        # KI67 status of neighboring tumor cells
        prolif_tumor = 0
        non_prolif_tumor = 0
        for j, neighbor_class in zip(neighbors, neighbor_classes):
            if neighbor_class == 'nonIMNeoplastic cells':
                if ki67_vals[j] == 1:
                    prolif_tumor += 1
                else:
                    non_prolif_tumor += 1

        counts = [neighbor_classes.count(c) for c in phenotypes]
        total = sum(counts)
        neighbor_counts = {}
        for c, n_count in zip(phenotypes, counts):
            neighbor_counts[c] = n_count
        for c, n_count in zip(phenotypes, counts):
            neighbor_counts["percent_" + c] = n_count / total if total != 0 else 0

        # columns from anndata to add to the neighborhood df and csv file
        neighbor_counts['file'] = image_meta['imageid']
        neighbor_counts['Subject_ID'] = subject_vals[i]
        neighbor_counts['CellID'] = cellid_vals[i]
        neighbor_counts['index'] = i
        neighbor_counts['timepoint'] = image_meta['timepoint']
        neighbor_counts['Tx'] = image_meta['Tx']
        neighbor_counts['Grade'] = image_meta['Grade']
        neighbor_counts['Stage'] = image_meta['Stage']
        neighbor_counts['ER'] = image_meta['ER']
        neighbor_counts['KI67p_neighbors'] = prolif_tumor
        neighbor_counts['KI67n_neighbors'] = non_prolif_tumor
        neighbor_counts['seed'] = cell_class
        all_neighbors.append(neighbor_counts)

    keys = list(all_neighbors[0].keys()) if all_neighbors else []
    neigh_df = pd.DataFrame(all_neighbors, columns=keys)
    # fillna returns a new frame; it must be reassigned to take effect.
    neigh_df = neigh_df.fillna(0)
    return neigh_df

def remove_control_stains(obs=None, controls=("Tonsil", "Spleen")):
    '''
    Drop control tissue cores from a cell-metadata table.

    obs: DataFrame with a 'Subject_ID' column. Defaults to the module-level
        ``adata.obs``, which was the only behavior before.
    controls: Subject_ID values to exclude.
    Returns the rows of obs whose Subject_ID is not one of the controls.
    '''
    if obs is None:
        obs = adata.obs
    return obs[~obs['Subject_ID'].isin(list(controls))]

def elbow_method(neigh_df, num_cell_types, clusters, image_path):
    '''
    Save an elbow plot (k-means inertia vs. number of clusters).

    neigh_df: neighborhood table from create_neighborhoods(). Its first
        num_cell_types columns are neighbor counts and the next
        num_cell_types columns are the matching neighbor fractions.
    num_cell_types: number of phenotype columns in each of those two groups.
    clusters: MiniBatchKMeans is fit for k = 1 .. clusters - 1.
    image_path: where the plot image is written.

    Seeds with no counted neighbors are excluded. neigh_df is not modified.
    '''
    # Seeds with zero neighbors have an all-zero fraction vector; leave them out.
    neighbor_totals = neigh_df.iloc[:, :num_cell_types].sum(axis=1)
    kept = neigh_df[neighbor_totals != 0]
    percent_data = kept.iloc[:, num_cell_types:2 * num_cell_types].to_numpy()

    k_values = range(1, clusters)
    kmeans_error = []
    for k in k_values:
        kmeans = MiniBatchKMeans(n_clusters=k, init="k-means++", max_iter=300, n_init=10, random_state=0)
        kmeans.fit(percent_data)
        kmeans_error.append(kmeans.inertia_)

    # Use a dedicated figure so an already-open notebook figure is neither
    # drawn on nor closed.
    fig, ax = plt.subplots()
    ax.plot(k_values, kmeans_error, 'gs-')
    ax.set_title("Elbow Method Graph for Determining # of Clusters")
    ax.set_xlabel("# of clusters")
    ax.set_ylabel("WCSS")
    fig.savefig(image_path)
    plt.close(fig)

def cluster_neighborhoods(neigh_df, num_cell_types, add_cols, k, csv_path):
    '''
    Cluster neighborhoods by their phenotype fractions with MiniBatchKMeans.

    neigh_df: neighborhood table from create_neighborhoods(). Its first
        num_cell_types columns are neighbor counts and the next
        num_cell_types columns are the neighbor fractions that are clustered.
    add_cols: extra neigh_df columns copied into the output.
    k: number of clusters (0-based ids 0..k-1 in the 'cluster' column).
    csv_path: the first returned frame is written here (without the index).

    Returns (percent_neigh_df, neigh_df):
      percent_neigh_df: the fraction columns plus 'cluster', 'file', 'CellID',
        'index', 'seed' and add_cols.
      neigh_df: the clustered input rows with 'sum' (neighbor total) and
        'cluster' columns added.
    Seeds with no counted neighbors are dropped. The caller's frame is not modified.
    '''
    neigh_df = neigh_df.copy()
    neigh_df['sum'] = neigh_df.iloc[:, :num_cell_types].sum(axis=1)
    neigh_df = neigh_df[neigh_df['sum'] != 0].copy()
    # .copy() so adding columns below does not write into a slice of neigh_df
    percent_neigh_df = neigh_df.iloc[:, num_cell_types:2 * num_cell_types].copy()
    percent_data = percent_neigh_df.to_numpy()

    # NOTE(review): clustering parameters were flagged as possibly needing tuning.
    kmeans = MiniBatchKMeans(n_clusters=k, init='k-means++', max_iter=300, n_init=10, random_state=0)
    predictions = kmeans.fit_predict(percent_data)

    percent_neigh_df['cluster'] = predictions
    for col in ['file', 'CellID', 'index', 'seed', *add_cols]:
        percent_neigh_df[col] = neigh_df[col]
    percent_neigh_df.to_csv(csv_path, index=False)
    neigh_df['cluster'] = predictions
    return percent_neigh_df, neigh_df

def read_color_map(path):
    '''
    Load a phenotype -> hex color mapping from an Excel workbook.

    path: workbook whose first sheet has 'Population' (phenotype name) and
        'Hex' (color) columns. Other columns, such as 'Color', are ignored,
        so they no longer need to be present.
    Returns {population: hex}. If a population is listed twice, the last row wins.
    '''
    cmap_df = pd.read_excel(path)
    return dict(zip(cmap_df['Population'], cmap_df['Hex']))

def stacked_bar_plot(df, x_axis, image_name, x_axis_title, y_axis_title, cmap, title):
    '''
    Show and save a stacked plotly bar chart with one series per row of df.

    df: rows are phenotype series whose index labels are either
        'percent_<phenotype>' or '<phenotype>'. Columns are the x-axis categories.
    x_axis: x-axis category labels, one per column of df.
    image_name: output image path (plotly.io.write_image).
    cmap: phenotype name -> color. Every plotted phenotype must be a key.
    title: chart title.
    '''
    prefix = "percent_"
    layout = {'xaxis': {'title': x_axis_title}, 'yaxis': {'title': y_axis_title}, 'barmode': 'stack', 'height': 800, 'title': title}
    traces = []
    for label, row in df.iterrows():
        # Strip only the leading 'percent_' so phenotype names that contain
        # '_' themselves still match their cmap entry.
        cell_class = label[len(prefix):] if label.startswith(prefix) else label
        traces.append({'type': 'bar', 'x': x_axis, 'y': list(row), 'name': cell_class, 'marker': {'color': cmap[cell_class]}})
    fig = {'data': traces, 'layout': layout}
    pio.show(fig, renderer='notebook')
    pio.write_image(fig, image_name)

def plot_cluster(cluster_df, num_cell_types, image_name, k, cmap, title):
    '''
    Plot the mean phenotype composition of each neighborhood cluster.

    cluster_df: first output of cluster_neighborhoods(). Its first
        num_cell_types columns are the 'percent_<phenotype>' fractions, and it
        has a 0-based 'cluster' column.
    k: number of clusters. Ids 0..k-1 are shown as '1'..'k'.
    image_name, cmap, title: forwarded to stacked_bar_plot().
    cluster_df is not modified.
    '''
    # .copy() so adding 'cluster' does not write into a slice of cluster_df
    composition = cluster_df.iloc[:, :num_cell_types].copy()
    composition['cluster'] = cluster_df['cluster'].values
    # rows: phenotype fractions, columns: clusters
    mean_composition = composition.groupby(['cluster']).mean().T
    mean_composition = mean_composition.rename(columns={j: str(j + 1) for j in range(k)})
    stacked_bar_plot(mean_composition, list(mean_composition.columns), image_name, "Cluster", "Fraction Present", cmap, title)

# Check whether clusters are distributed across TMAs: look for bias where a
# neighborhood cluster comes from only one or a few TMAs instead of many.

def get_cluster_vals(cluster_df, k):
    '''
    One-hot encode each row's cluster assignment.

    cluster_df: DataFrame with a 0-based 'cluster' column (one row per seed cell).
    k: number of clusters.
    Returns k lists, each with len(cluster_df) ints. Element [j][i] is 1 when
    row i is in cluster j and 0 otherwise.
    '''
    # Vectorized comparison; per-row .iloc lookups repeated k times were very slow.
    assignments = cluster_df['cluster'].to_numpy()
    return [(assignments == j).astype(int).tolist() for j in range(k)]

def get_sort_indices(indices, order):
    '''
    Order index labels by the first matching key in ``order``.

    For each key in ``order`` (in turn), every label in ``indices`` that
    contains the key as a substring is appended. Two keys ignore their longer
    look-alikes: "NT_ST" skips labels containing "LMNT_ST", and "LumB" skips
    labels containing "LumBH". A label matching several keys is listed once
    per match; labels matching none are omitted.
    '''
    look_alikes = {"NT_ST": "LMNT_ST", "LumB": "LumBH"}
    ordered = []
    for key in order:
        excluded = look_alikes.get(key)
        for label in indices:
            if key not in label:
                continue
            if excluded is not None and excluded in label:
                continue
            ordered.append(label)
    return ordered

def cluster_dist_stacked_bar_plot(cluster_df, x_axis_val, k, image_name, image_width=None, image_height=None, sort_by=None, percent=False, groupby=None):
    '''
    Show and save a stacked bar chart of cluster membership per category.

    cluster_df: first output of cluster_neighborhoods() (0-based 'cluster'
        column plus metadata columns such as 'file', 'Subject_ID', 'Grade').
    x_axis_val: column whose values become the bars (also the x-axis title).
    k: number of clusters. Stacked series are named 'cluster 1'..'cluster k'.
    image_name, image_width, image_height: passed to fig.write_image.
    sort_by: None (groupby order); "Subtype" or "Tx_Timepoint" (fixed orderings
        matched by substring via get_sort_indices); or a 'cluster N' column
        name, in which case bars are sorted by it in descending order.
    percent: if true, each bar is normalized so its clusters sum to 1.
    groupby: optional column to group by instead of x_axis_val. Each group is
        then labelled '<Subtype>_<last "_"-part of file>' from the group's
        first row, so cluster_df must also have 'Subtype' and 'file' columns.
    '''
    group_col = x_axis_val if groupby is None else groupby

    # One-hot cluster membership per row, summed per group -> counts table
    # with one 'cluster N' column per cluster.
    cluster_vals = get_cluster_vals(cluster_df, k)
    data_dict = {group_col: cluster_df[group_col]}
    for cluster in range(k):
        data_dict['cluster ' + str(cluster + 1)] = cluster_vals[cluster]
    cluster_df_filtered = pd.DataFrame(data=data_dict).groupby([group_col]).sum()

    if groupby is not None:
        # Look labels up by group key. groupby() returns its keys sorted, not in
        # the order they first appear in cluster_df, so attaching labels
        # positionally would mislabel the bars.
        first_rows = cluster_df.drop_duplicates(subset=[groupby])
        labels = {
            key: subtype + "_" + file_name.split("_")[-1]
            for key, subtype, file_name in zip(first_rows[groupby], first_rows['Subtype'], first_rows['file'])
        }
        cluster_df_filtered.index = pd.Index([labels[key] for key in cluster_df_filtered.index], name='Subtype')

    if percent:
        # y axis shows the share of each cluster (a bar with no cells becomes NaN)
        cluster_df_filtered = cluster_df_filtered.div(cluster_df_filtered.sum(axis=1), axis=0)

    if sort_by == "Subtype":
        order = ["NB", "ILC", "LumA", "LumB", "LumBH", "Her2", "TN"]
        cluster_df_filtered = cluster_df_filtered.reindex(get_sort_indices(cluster_df_filtered.index, order))
    elif sort_by == "Tx_Timepoint":
        # NT - LMNT - PTX - Ent - 2x - 3xNR - 3xR - 4x - 4xAST (ST), then 3xNR - 3xR - 4x - Bpost - Bpre (LT)
        order = ["NT_ST", "LMNT_ST", "PTX_ST", "Ent_ST", "2x_ST", "3xNR_ST", "3xR_ST", "4x_ST", "4xAST_ST", "3xNR_LT", "3xR_LT", "4x_LT", "Bpost_LT", "Bpre_LT"]
        cluster_df_filtered = cluster_df_filtered.reindex(get_sort_indices(cluster_df_filtered.index, order))
    elif sort_by is not None:
        cluster_df_filtered = cluster_df_filtered.sort_values(by=[sort_by], ascending=False)

    # Transpose: rows become clusters (one stacked series each), columns become bars.
    by_cluster = cluster_df_filtered.T
    x_axis = list(by_cluster.columns)
    data = [go.Bar(name='cluster ' + str(i + 1), x=x_axis, y=by_cluster.iloc[i, 0:]) for i in range(len(by_cluster))]
    y_axis_title = "percent" if percent else "count"
    layout = {'xaxis': {'title': x_axis_val}, 'yaxis': {'title': y_axis_title}}
    fig = go.Figure(data=data, layout=layout)
    fig.update_layout(barmode='stack')
    fig.show(renderer='notebook')
    fig.write_image(image_name, width=image_width, height=image_height)

def investigate_cluster_distribution(cluster_df, k):
    '''
    Draw the cluster-distribution plots used for the current analysis.

    Active plot: for each Subject_ID, the fraction of its neighborhoods in
    each cluster, with subjects ordered by their 'cluster 4' fraction from
    highest to lowest. The commented-out calls below are alternative views;
    uncomment the ones you need.
    '''
    subject_plot = dict(
        cluster_df=cluster_df,
        x_axis_val="Subject_ID",
        k=k,
        image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/FINAL_IM/sorted_cluster_distribution.png",
        image_width=1000,
        sort_by="cluster 4",
        percent=True,
    )
    cluster_dist_stacked_bar_plot(**subject_plot)

    # Stacked counts per image (x axis = image id, stacks = cluster):
    # cluster_dist_stacked_bar_plot(cluster_df, "file", k, "Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/IM2/IM_cluster_distribution.png", 2000)

    # Per timepoint (optionally restricted to grade 3 cells):
    # grade3 = cluster_df.loc[cluster_df['Grade'] == 3]
    # cluster_dist_stacked_bar_plot(cluster_df=cluster_df, x_axis_val="timepoint", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/IM2/IM_grouped_tx_cluster_distribution.png", image_height=600)

    # ER status:
    # cluster_dist_stacked_bar_plot(cluster_df=cluster_df, x_axis_val="ER", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/IM2/IM_grouped_er_cluster_distribution.png", image_height=600, percent=True)

    # Raw counts per subject:
    # cluster_dist_stacked_bar_plot(cluster_df=cluster_df, x_axis_val="Subject_ID", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/IM2/IM_subject_cluster_distribution.png", image_width=2000)

    # Grade:
    # cluster_dist_stacked_bar_plot(cluster_df=cluster_df, x_axis_val="Grade", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/IM2/IM_grade_cluster_distribution.png", image_height=600)

    # Treatment/timepoint in a fixed order:
    # cluster_dist_stacked_bar_plot(cluster_df=cluster_df, x_axis_val="Tx_Timepoint", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/DP20_graphs/DP_sorted_tx_cluster_distribution.png", image_width=1000, image_height=600, sort_by="Tx_Timepoint", percent=False)

def find_cell_cluster(source_row, cluster_df):
    '''
    Look up the neighborhood cluster of one cell.

    source_row: a row of adata.obs. Its name (the obs index label) is split
        on '_': the first two parts joined by '_' give the image ("file"),
        and the last part is the integer CellID.
    cluster_df: output of cluster_neighborhoods() with 'file', 'CellID' and
        0-based 'cluster' columns.
    Returns the 0-based cluster as an int, or None unless exactly one row of
    cluster_df matches (e.g. the cell was not a seed cell).
    Raises ValueError if the last name part is not an integer.
    '''
    name_parts = source_row.name.split("_")
    file = name_parts[0] + "_" + name_parts[1]
    cellid = int(name_parts[-1])
    matches = cluster_df.loc[(cluster_df["file"] == file) & (cluster_df['CellID'] == cellid), 'cluster']
    # Check the match count explicitly; int() on a Series is deprecated in
    # pandas and only worked here by raising TypeError for 0 or >1 rows.
    if len(matches) != 1:
        return None
    return int(matches.iloc[0])

def spatial_scatterPlot(adata,
                        colorBy,
                        topLayer=None,
                        x_coordinate='X_centroid',
                        y_coordinate='Y_centroid',
                        imageid='imageid',
                        layer=None,
                        subset=None,
                        s=None,
                        ncols=None,
                        alpha=1,
                        dpi=200,
                        fontsize=None,
                        plotLegend=True,
                        cmap='RdBu_r',
                        catCmap='tab20',
                        vmin=None,
                        vmax=None,
                        customColors=None,
                        figsize=(5, 5),
                        invert_yaxis=True,
                        saveDir=None,
                        fileName='scimapScatterPlot.png',
                        title=None,
                        **kwargs):
    '''
    Scatter-plot cell centroids colored by marker values and/or obs columns.

    One subplot is drawn per entry of colorBy that exists in var (expression)
    or obs (metadata).

    adata: AnnData object or path to an .h5ad file. An object is copied, not modified.
    colorBy: column name(s) to color by. Columns with numeric dtype (kind
        i/u/f/c) use a continuous colormap and colorbar; all others are treated
        as categorical.
    topLayer: categorical value(s) drawn last so they sit on top.
    x_coordinate, y_coordinate: obs columns holding the coordinates.
    imageid: obs column matched against ``subset``.
    layer: None -> X, 'raw' -> raw.X, otherwise layers[layer].
    subset: image id(s) to keep before plotting.
    s: point size. The default is 10000 / n_cells / n_subplots.
    ncols: subplot columns. The default is a square grid.
    alpha, dpi, fontsize, figsize: matplotlib styling.
    plotLegend: draw a colorbar (continuous) or legend (categorical).
    cmap: colormap for continuous data. vmin/vmax: its color limits.
    catCmap: colormap for categories not given in customColors.
    customColors: dict category -> color for categorical data.
    invert_yaxis: flip the y axis (image coordinates).
    saveDir, fileName: when saveDir is set, the figure is saved to
        saveDir/fileName (creating saveDir) and closed; otherwise it is shown.
    title: subplot title. When None, each subplot is titled with its own column name.
    **kwargs: forwarded to every ax.scatter call.
    '''
    import anndata as ad
    import matplotlib.pyplot as plt
    import pandas as pd
    import math
    import numpy as np
    import matplotlib.patches as mpatches
    import os

    # Load the AnnData object. read_h5ad is the supported reader; the older
    # ad.read alias is deprecated.
    if isinstance(adata, str):
        adata = ad.read_h5ad(adata)
    else:
        adata = adata.copy()

    # subset data if needed
    if subset is not None:
        if isinstance(subset, str):
            subset = [subset]
        if layer == 'raw':
            bdata = adata.copy()
            bdata.X = adata.raw.X
            bdata = bdata[bdata.obs[imageid].isin(subset)]
        else:
            bdata = adata.copy()
            bdata = bdata[bdata.obs[imageid].isin(subset)]
    else:
        bdata = adata.copy()

    # isolate the expression data (cells x markers)
    if layer is None:
        data = pd.DataFrame(bdata.X, index=bdata.obs.index, columns=bdata.var.index)
    elif layer == 'raw':
        data = pd.DataFrame(bdata.raw.X, index=bdata.obs.index, columns=bdata.var.index)
    else:
        data = pd.DataFrame(bdata.layers[layer], index=bdata.obs.index, columns=bdata.var.index)

    # isolate the meta data
    meta = bdata.obs

    # toplayer logic
    if isinstance(topLayer, str):
        topLayer = [topLayer]

    # identify the things to color
    if isinstance(colorBy, str):
        colorBy = [colorBy]
    # extract columns from data and meta
    data_cols = [col for col in data.columns if col in colorBy]
    meta_cols = [col for col in meta.columns if col in colorBy]
    # combine extracted columns from data and meta
    colorColumns = pd.concat([data[data_cols], meta[meta_cols]], axis=1)

    # identify the x and y coordinates
    x = meta[x_coordinate]
    y = meta[y_coordinate]

    # auto identify rows and columns in the grid plot
    def calculate_grid_dimensions(num_items, num_columns=None):
        """
        Calculates the number of rows and columns for a square grid
        based on the number of items.
        """
        if num_columns is None:
            num_rows_columns = int(math.ceil(math.sqrt(num_items)))
            return num_rows_columns, num_rows_columns
        else:
            num_rows = int(math.ceil(num_items / num_columns))
            return num_rows, num_columns

    # calculate the number of rows and columns
    nrows, ncols = calculate_grid_dimensions(len(colorColumns.columns), num_columns=ncols)

    # Estimate point size
    if s is None:
        s = (10000 / bdata.shape[0]) / len(colorColumns.columns)

    # Define the categorical colormap (optional)
    cmap_cat = plt.get_cmap(catCmap)

    # FIGURE
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, dpi=dpi)

    # Flatten the axs array for easier indexing
    if nrows == 1 and ncols == 1:
        axs = [axs]  # wrap single subplot in a list
    else:
        axs = axs.flatten()

    # Loop over the columns of the DataFrame
    for i, col in enumerate(colorColumns):
        # Select the current axis
        ax = axs[i]

        # invert y-axis
        if invert_yaxis is True:
            ax.invert_yaxis()

        # Scatter plot for continuous data
        if colorColumns[col].dtype.kind in 'iufc':
            scatter = ax.scatter(x=x, y=y,
                                 c=colorColumns[col],
                                 cmap=cmap,
                                 s=s,
                                 vmin=vmin,
                                 vmax=vmax,
                                 linewidths=0,
                                 alpha=alpha, **kwargs)
            if plotLegend is True:
                cbar = plt.colorbar(scatter, ax=ax, pad=0)
                cbar.ax.tick_params(labelsize=fontsize)

        # Scatter plot for categorical data
        else:
            # Get the unique categories in the column
            categories = colorColumns[col].unique()

            # Map the categories to colors using either the custom colors or the categorical colormap
            if customColors:
                colors = {cat: customColors[cat] for cat in categories if cat in customColors}
            else:
                colors = {cat: cmap_cat(i) for i, cat in enumerate(categories)}

            # Ensure topLayer categories are plotted last
            categories_to_plot_last = [cat for cat in topLayer if cat in categories] if topLayer else []
            categories_to_plot_first = [cat for cat in categories if cat not in categories_to_plot_last]

            # Plot non-topLayer categories first
            for cat in categories_to_plot_first:
                cat_mask = colorColumns[col] == cat
                ax.scatter(x=x[cat_mask], y=y[cat_mask],
                           c=[colors.get(cat, cmap_cat(np.where(categories == cat)[0][0]))],
                           s=s, linewidths=0, alpha=alpha, **kwargs)

            # Then plot topLayer categories
            for cat in categories_to_plot_last:
                cat_mask = colorColumns[col] == cat
                ax.scatter(x=x[cat_mask], y=y[cat_mask],
                           c=[colors.get(cat, cmap_cat(np.where(categories == cat)[0][0]))],
                           s=s, linewidths=0, alpha=alpha, **kwargs)

            if plotLegend is True:
                # Adjust legend to include all categories
                sorted_categories = sorted(categories)
                handles = [mpatches.Patch(color=colors.get(cat, cmap_cat(np.where(categories == cat)[0][0])), label=cat) for cat in sorted_categories]
                ax.legend(handles=handles, bbox_to_anchor=(1.0, 1.0), loc='upper left', bbox_transform=ax.transAxes, fontsize=fontsize)

        # Title each subplot with its own column unless an explicit title was
        # given. Reassigning `title` here would stick the first column's name
        # on every later subplot.
        ax.set_title(col if title is None else title)
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])

    # Remove any empty subplots
    num_plots = len(colorColumns.columns)
    for i in range(num_plots, nrows * ncols):
        ax = axs[i]
        fig.delaxes(ax)

    # Adjust the layout of the subplots grid
    plt.tick_params(axis='both', labelsize=fontsize)
    plt.tight_layout()

    # save figure
    if saveDir:
        if not os.path.exists(saveDir):
            os.makedirs(saveDir)
        full_path = os.path.join(saveDir, fileName)
        plt.savefig(full_path, dpi=dpi)
        plt.close()
        print(f"Saved plot to {full_path}")
    else:
        plt.show()

def prolif_tumor_ratio(cluster_4):
    '''
    Box-plot the KI67 ratio of neighborhoods near vs. not near immune-mimicked cells.

    cluster_4: neighborhood table with 'seed', 'KI67_ratio',
        'percent_KI67- IMNeoplastic cells CD69hi' and
        'percent_KI67+ IMNeoplastic cells CD69hi' columns.
        cluster_neighborhoods() does not add 'KI67_ratio'; compute it first,
        e.g. KI67p_neighbors / (KI67p_neighbors + KI67n_neighbors) from the
        create_neighborhoods() output.

    A neighborhood is "near" when its seed is an IMNeoplastic CD69hi cell or
    it has a non-zero fraction of such neighbors. It is "not near" when both
    fractions are 0 and the seed is something else. Rows satisfying neither
    test (e.g. NaN fractions) appear in no box.
    '''
    im_seeds = ['KI67- IMNeoplastic cells CD69hi', 'KI67+ IMNeoplastic cells CD69hi']
    pct_ki67_neg = cluster_4['percent_KI67- IMNeoplastic cells CD69hi']
    pct_ki67_pos = cluster_4['percent_KI67+ IMNeoplastic cells CD69hi']
    seed_is_im = cluster_4['seed'].isin(im_seeds)

    near_im = cluster_4.loc[(pct_ki67_neg > 0) | (pct_ki67_pos > 0) | seed_is_im]
    not_near_im = cluster_4.loc[(pct_ki67_neg == 0) & (pct_ki67_pos == 0) & ~seed_is_im]

    fig = go.Figure()
    fig.add_trace(
        go.Box(y=near_im['KI67_ratio'],
               name='Near Immune-mimicked',
               boxmean=True)
    )
    fig.add_trace(
        go.Box(y=not_near_im['KI67_ratio'],
               name='Not near Immune-mimicked',
               boxmean=True)
    )
    fig.update_layout(showlegend=False, title="Proliferating Tumor Ratio for Tumor Cells Near to and Far from Immune-Mimicked Cells", width=700, xaxis_title="Tumor Neighbor Location", yaxis_title="KI67 ratio")
    fig.show(renderer='notebook')

    # Per-cluster alternative (clusters shown 1-based):
    # import plotly.express as px
    # box_fig = px.box(near_im.assign(cluster=near_im['cluster'] + 1), x='cluster', y='KI67_ratio', points=False, title="KI67+ Neoplastic / Neoplastic Ratios for Clusters")
    # box_fig.write_image("Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/IM2/IM_ki67_ratio_box_no_outliers.png")

Main Analysis Script (edit the parameters below as needed)

In [None]:
# generate neighborhoods
# remove_control_stains() returns adata.obs without the control cores
# (a DataFrame, despite the "adata" in the variable name).
filtered_adata = remove_control_stains()
tma_images = list(filtered_adata["imageid"].unique())
phenotypes = list(filtered_adata['phenotype'].unique())

# Seed every tumor cell population so each gets a cluster assignment later:
# KI67+ / KI67- IMNeoplastic cells CD69hi and nonIMNeoplastic cells.
seeds = ["KI67+ IMNeoplastic cells CD69hi", "KI67- IMNeoplastic cells CD69hi", "nonIMNeoplastic cells"]
neighborhood_radius = 120

# Build the per-image tables, then concatenate once. Concatenating inside
# the loop would re-copy the growing table for every image.
neighborhood_frames = []
for tma in tma_images:
    tma_region = filtered_adata[filtered_adata["imageid"] == tma]
    neighborhood_frames.append(create_neighborhoods(tma_region, seeds, neighborhood_radius, phenotypes))
    print("Calculated neighborhood for", tma)
df = pd.concat(neighborhood_frames, ignore_index=True)

# changed path for running on linux
neighborhood_path = "Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/spreadsheets/IM2/IM_create_all_neighborhoods_SUBJ.csv"
df.to_csv(neighborhood_path, index=False)
# save neighbor counts to adata.uns
adata.uns['custom_neighborhoods'] = df

# The neighborhood table has one count column and one percent column per
# entry of `phenotypes`, and the clustering code slices them by that width.
# Counting phenotypes over the unfiltered adata (which still contains the
# control cores) can give a different number and shift every slice.
types = len(phenotypes)

# generate elbow plot
clusters = 15
image_path = "Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/FINAL_IM/Elbow_Plot.png"
elbow_method(df, types, clusters, image_path)

# determine number of clusters based on elbow plot and generate cluster csv
add_cols = ['timepoint', 'Tx', 'Grade', 'Stage', 'ER', 'Subject_ID']
k = 4
csv_path = "Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/spreadsheets/IM2/IM_neighborhood_cluster4.csv"
cluster_4, neigh_cluster_4 = cluster_neighborhoods(df, types, add_cols, k, csv_path)

# generate image of cluster cell type composition (cmap_path is set in the setup cell)
cmap = read_color_map(cmap_path)
image_name = "Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/FINAL_IM/IM_cluster_composition4.png"
plot_cluster(cluster_4, types, image_name, k, cmap, "Neighborhood Cluster Cell Type Composition")

# generate cluster graphs, edit this function to change what you want to see
investigate_cluster_distribution(cluster_4, k)

# Subject groupings (high / mid / low) for plotting clusters in a particular order.
# NOTE(review): subj_order is not used by any plot below yet.
high = ['ST-00011028', 'ST-00006692', 'ST-00018308', 'ST-00021050', 'ST-00006399', 'ST-00018307', 'ST-00021143', 'ST-00018343', 'ST-00016958']
mid = ['ST-00006415', 'ST-00014645', 'ST-00018134 - RT', 'ST-00015949', 'ST-00018134 - LT', 'ST-00006630', 'ST-00017876', 'ST-00019469', 'ST-00017405']
low = ['ST-00014447', 'ST-00013076', 'ST-00006152', 'ST-00017865', 'ST-00022280', 'ST-00006624', 'ST-00006509', 'ST-00023046']
subj_order = high + mid + low
# by grade and ER status
# cluster_dist_stacked_bar_plot(cluster_df=cluster_4, x_axis_val="Subject_ID", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/FINAL_IM/count_cluster_distribution.png", image_width=1000)
cluster_dist_stacked_bar_plot(cluster_df=cluster_4, x_axis_val="Grade", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/FINAL_IM/grade_cluster_distribution.png", image_height=600, percent=True)
cluster_dist_stacked_bar_plot(cluster_df=cluster_4, x_axis_val="ER", k=k, image_name="Multiplex_IHC_studies/Summer_Interns/2024/JT_spatial_output/IM_graphs/FINAL_IM/er_cluster_distribution.png", image_height=600, percent=True)


Visualization of Clusters

In [None]:
# visualize clusters
# Label every cell in adata with its neighborhood cluster for plotting. Only
# seed cells were clustered. All other cells, plus "Other Cells" and
# Subtype == 'Control' cells, are labelled "no cluster".
# Cells are matched to cluster_4 by (file, CellID) parsed from the obs index,
# the same way find_cell_cluster() parses it. A dict built once replaces a
# scan of cluster_4 for every cell. Keys that occur more than once in
# cluster_4 are dropped, matching find_cell_cluster() returning None for
# ambiguous matches.
unique_seed_rows = cluster_4.drop_duplicates(subset=['file', 'CellID'], keep=False)
cluster_lookup = dict(zip(zip(unique_seed_rows['file'], unique_seed_rows['CellID']), unique_seed_rows['cluster']))

cell_obs = adata.obs
n_cells = len(cell_obs)
clusters = []
for i, (obs_name, phenotype, subtype) in enumerate(zip(cell_obs.index, cell_obs['phenotype'], cell_obs['Subtype'])):
    if phenotype == "Other Cells" or subtype == 'Control':
        clusters.append("no cluster")
    else:
        # obs names are split on '_': first two parts = image ("file"), last part = CellID
        name_parts = obs_name.split("_")
        key = (name_parts[0] + "_" + name_parts[1], int(name_parts[-1]))
        found_cluster = cluster_lookup.get(key)
        if found_cluster is None:  # not a seed cell, no cluster
            clusters.append("no cluster")
        else:
            clusters.append("cluster " + str(int(found_cluster) + 1))
    if i % 100000 == 0:
        print(f"About {100 * i / n_cells:.0f}% done.")

adata.obs['cluster'] = clusters
cluster_cmap = {'cluster 1': '#45ACEA', 'cluster 2': '#D127A1', 'cluster 3': '#00CC9B', 'cluster 4': '#FFC134', 'no cluster': '#B4C0D2'}
# for tma in tma_images:

image = "formatted_S15ROI41"
title = image.split("_")[-1]
fn = title + ".png"
# NOTE(review): this folder has an extra "spatial/" component compared with the
# other JT_spatial_output paths in this notebook; confirm it is intended.
save_dir = "Multiplex_IHC_studies/Summer_Interns/2024/spatial/JT_spatial_output/scatterplots"
spatial_scatterPlot(adata, colorBy=['cluster'], subset=[image], figsize=(5, 5), s=0.7, fontsize=5, title=title, customColors=cluster_cmap, fileName=fn, saveDir=save_dir)
# note: only seed cells carry a cluster label, since only they were clustered