In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip3 install scanpy
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    sys.path.append('/content/drive/MyDrive/SpatialModelProject/model_test_colab/')

import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import torch

import matplotlib.pyplot as plt
import matplotlib.font_manager
from matplotlib import rcParams

# Collect the family name of every system font file that matplotlib can open.
font_list = set()
fpaths = matplotlib.font_manager.findSystemFonts()
for font_path in fpaths:
    try:
        font_list.add(matplotlib.font_manager.get_font(font_path).family_name)
    except RuntimeError:
        # Font file could not be loaded; skip it.
        continue

# NOTE(review): font_list is not consulted here; Arial is selected unconditionally.
plot_font = 'Arial'

# Global matplotlib defaults for this notebook.
rcParams.update({
    'font.family': plot_font,
    'font.size': 10,
    'figure.dpi': 300,
    'figure.figsize': (3, 3),
    'savefig.dpi': 300,
})

import warnings
warnings.filterwarnings('ignore')

## Load Starfysh

from starfysh import (AA, utils, plot_utils, post_analysis, utils_integrate)
from starfysh import starfysh as sf_model

In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip3 install scanpy
    #!pip3 install histomicstk
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    sys.path.append('/content/drive/MyDrive/SpatialModelProject/model_test_colab/')
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.font_manager
from matplotlib import rcParams
import anndata
#import plot_utils
import seaborn as sns
import scanpy.external as sce
#import plot_utils
sns.set_style('white')
from scipy.ndimage import gaussian_filter1d
# Gather the family names of all system fonts that matplotlib is able to load.
fpaths = matplotlib.font_manager.findSystemFonts()
font_list = set()
for path in fpaths:
    try:
        font = matplotlib.font_manager.get_font(path)
        font_list.add(font.family_name)
    except RuntimeError:
        # Unloadable font file: ignore and move on.
        continue

# Use Helvetica when it is installed, otherwise fall back to FreeSans.
if 'Helvetica' in font_list:
    plot_font = 'Helvetica'
else:
    plot_font = 'FreeSans'

# Global matplotlib defaults (saved figures use a higher dpi than on-screen ones).
rcParams.update({
    'font.family': plot_font,
    'font.size': 10,
    'figure.dpi': 300,
    'figure.figsize': (3, 3),
    'savefig.dpi': 500,
})

import warnings
warnings.filterwarnings('ignore')
from statannotations.Annotator import Annotator

In [None]:
# Extend the import search path before importing starfysh.
# '../' adds the parent of the current working directory
# (presumably the repository root that holds the `starfysh` package -- TODO confirm).
sys.path.append('../')
# NOTE(review): machine-specific absolute site-packages path; this only exists on
# the original author's computer and should come from the active environment instead.
sys.path.append('/Users/siyuhe/opt/miniconda3/lib/python3.9/site-packages')
from starfysh import (AA, utils, plot_utils, post_analysis)
from starfysh import starfysh as sf_model
#sys.path.append('../simulation/')
#import utils as benchmark_utils

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import json
import matplotlib.pyplot as plt

# Per-sample filtering threshold, passed as `min_genes` to load_visium_data.
# Samples missing from this dict fall back to 50 at the call site
# (`min_genes_values.get(sample_id, 50)`). Change as needed for each sample.
min_genes_values = {
    'SLV11': 150,
    'SLV12': 50,
    'SLV13': 50,
    'SLV14': 50,
    'SLV15': 50,
    'SLV16': 150,
    'SLV17': 50,
    'SLV18': 150,
}



def load_visium_data(base_path, sample_id, min_genes=50, n_top_genes=2000, mt_thld=20):
    """
    Load one Visium HD sample (16 um bins) and build raw and normalized AnnData objects.

    Parameters:
    -----------
    base_path : str
        Directory that contains one sub-directory per sample.
    sample_id : str
        Name of the sample sub-directory; also written to ``obs['sample']``.
    min_genes : int, optional
        Threshold handed to ``sc.pp.filter_cells`` as ``min_counts`` (default: 50).
        NOTE(review): despite the parameter name, bins are filtered on total
        counts, not on the number of detected genes -- confirm this is intended.
    n_top_genes : int, optional
        Number of highly variable genes to flag (default: 2000)
    mt_thld : int, optional
        Bins whose ``pct_counts_mt`` is not below this value are dropped (default: 20)

    Returns:
    --------
    tuple
        (adata_raw, adata_norm, img_metadata)

        adata_raw : un-normalized counts restricted to the bins and genes kept in
            ``adata_norm``, carrying a copy of its ``obs`` and its
            ``highly_variable`` flag.
        adata_norm : total-count normalized, log1p-transformed data with
            mitochondrial ('MT-'/'mt-') and 'RP-'/'rp-' prefixed genes removed.
        img_metadata : dict with keys 'img' (hi-res tissue image), 'map_info'
            (array and pixel coordinates for every barcode of the expression
            matrix, i.e. before any filtering) and 'scalefactor' (parsed
            scalefactors_json.json).

    Raises:
    -------
    KeyError
        If a barcode of the expression matrix is missing from
        tissue_positions.parquet.
    """
    bin_dir = os.path.join(base_path, sample_id, 'binned_outputs', 'square_016um')
    spatial_cols = ['in_tissue', 'array_row', 'array_col',
                    'pxl_row_in_fullres', 'pxl_col_in_fullres']

    # Read tissue positions and the gene expression matrix
    tissue_positions = pd.read_parquet(
        os.path.join(bin_dir, 'spatial', 'tissue_positions.parquet'))
    adata = sc.read_10x_h5(os.path.join(bin_dir, 'filtered_feature_bc_matrix.h5'))

    # Keep only positions that have expression data, ordered like adata.obs
    matched_positions = tissue_positions[tissue_positions['barcode'].isin(adata.obs.index)]
    matched_positions = matched_positions.set_index('barcode')
    matched_positions = matched_positions.loc[adata.obs.index]

    # Add spatial coordinates to adata
    adata.obs[spatial_cols] = matched_positions[spatial_cols]

    # Make var names unique and add sample ID
    adata.var_names_make_unique()
    adata.obs['sample'] = sample_id

    # Clean up var names if needed
    if '_index' in adata.var.columns:
        adata.var_names = adata.var['_index']
        adata.var_names.name = 'Genes'
        adata.var.drop('_index', axis=1, inplace=True)

    # Initial QC metrics (recomputed below after filtering)
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

    # Basic filtering (see the note on `min_genes` in the docstring)
    sc.pp.filter_cells(adata, min_counts=min_genes)

    # Snapshot of the un-normalized data
    adata_raw = adata.copy()

    # Flag mitochondrial and 'RP-'/'rp-' prefixed genes.
    # NOTE(review): 'RP-' does not match RPS*/RPL* style symbols -- confirm the
    # intended ribosomal prefix.
    adata.var['mt'] = np.logical_or(
        adata.var_names.str.startswith('MT-'),
        adata.var_names.str.startswith('mt-'))
    adata.var['rb'] = (
        adata.var_names.str.startswith('RP-') |
        adata.var_names.str.startswith('rp-'))

    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
    mask_cell = adata.obs['pct_counts_mt'] < mt_thld
    mask_gene = np.logical_and(~adata.var['mt'], ~adata.var['rb'])

    # Materialize the subset so the in-place steps below operate on a real
    # object instead of relying on implicit view-to-copy conversion.
    adata = adata[mask_cell, mask_gene].copy()

    # Normalize and log transform
    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)

    # Find variable genes
    sc.pp.highly_variable_genes(adata, flavor='seurat', n_top_genes=n_top_genes, inplace=True)

    # Restrict the raw data to the bins/genes that survived filtering
    adata_raw = adata_raw[adata.obs_names, adata.var_names].copy()
    adata_raw.var['highly_variable'] = adata.var['highly_variable']
    # Copy so that later edits to one object's obs do not leak into the other
    adata_raw.obs = adata.obs.copy()

    # Load scale factors
    with open(os.path.join(bin_dir, 'spatial', 'scalefactors_json.json'), 'r') as f:
        scalefactor = json.load(f)

    # Load high-res image (read from the sample-level 'spatial' folder)
    img = plt.imread(os.path.join(base_path, sample_id, 'spatial/tissue_hires_image.png'))

    # Prepare map info on an independent copy to avoid chained assignment
    map_info = matched_positions[spatial_cols].copy()
    map_info['imagerow'] = pd.to_numeric(map_info['pxl_row_in_fullres'], errors='coerce').fillna(0).astype(int)
    map_info['imagecol'] = pd.to_numeric(map_info['pxl_col_in_fullres'], errors='coerce').fillna(0).astype(int)
    map_info = map_info[['array_row', 'array_col', 'imagerow', 'imagecol']]

    img_metadata = {
        'img': img,
        'map_info': map_info,
        'scalefactor': scalefactor
    }

    return adata_raw, adata, img_metadata





In [None]:
# Sample sheet: one row per Visium HD sample.
meta_info = pd.DataFrame(
    [
        ['SLV11', 'C159', 'Antrum', 'Severe'],
        ['SLV12', 'C162', 'Rectum', 'Mild'],
        ['SLV13', 'C98', 'Stomach_Body', 'Severe'],
        ['SLV14', 'C159', 'Rectum', 'Severe'],
        ['SLV15', 'C179', 'Antrum', 'Mild'],
        ['SLV16', 'C179', 'Ascending_Colon', 'Mild'],
        ['SLV17', 'ND001', 'Ascending_Colon', 'ND'],
        ['SLV18', 'C162', 'Stomach', 'Mild'],
    ],
    columns=['sample', 'patient', 'tissue_type', 'grade'],
)
base_path = '/Users/lingting/Documents/GVHD_project/visiumHD/data/'
sig_file_name = '/Users/lingting/Documents/GVHD_project/Spatial data/data/GVHD_spatial_signature_v8_major_curated_unique_epi_tcells_subset.csv'

# Per-sample results: `adata` holds the raw counts, `adata_normed` the
# normalized data (first and second return value of load_visium_data).
adata_all = []
adata_normed_all = []
img_metadata_all = {}

for sample_id, patient, tissue_type, grade in meta_info.itertuples(index=False, name=None):
    print(sample_id)
    adata, adata_normed, img_metadata = load_visium_data(
        base_path,  # root data directory
        sample_id,
        min_genes=min_genes_values.get(sample_id, 50),
        n_top_genes=2000,
        mt_thld=20,
    )

    # Keep only normalized bins that are also present in the raw object
    shared_bins = adata_normed.obs.index.isin(adata.obs.index)
    adata_normed = adata_normed[shared_bins]
    adata_all.append(adata)

    # Annotate the raw object and make its barcodes unique across samples
    adata.obs['patient'] = patient
    adata.obs['sample_type'] = tissue_type
    adata.obs['grade'] = grade
    adata.obs_names = adata.obs_names + '-' + sample_id

    # Same suffix for the normalized object, then copy the annotations over
    adata_normed.obs_names = adata_normed.obs_names + '-' + sample_id
    for annotation in ('patient', 'sample_type', 'grade'):
        adata_normed.obs[annotation] = adata.obs[annotation]
    adata_normed_all.append(adata_normed)

    # Suffix the image coordinate table and drop bins that were filtered out
    map_info = img_metadata['map_info']
    map_info.index = map_info.index + '-' + sample_id
    map_info = map_info[map_info.index.isin(adata.obs.index)]
    img_metadata['map_info'] = map_info

    img_metadata_all[sample_id] = img_metadata

# Concatenate all samples
adata_all = anndata.concat(adata_all)
adata_normed_all = anndata.concat(adata_normed_all)

# Recompute variable genes on the pooled normalized data and share the
# resulting uns/var tables with the raw object
sc.pp.highly_variable_genes(adata_normed_all)
adata_all.uns = adata_normed_all.uns
adata_all.var = adata_normed_all.var

# Save concat data
adata_all.write(os.path.join(base_path, 'adata_integrate.h5ad'))
adata_normed_all.write(os.path.join(base_path, 'adata_normed_integrate.h5ad'))

# Assign cell type to hubs

In [None]:
# Load the integrated AnnData used by all figures below.
# NOTE(review): 'adata_integrate_final.h5ad' is not the file written by the
# loading cell ('adata_integrate.h5ad'); presumably it is produced by a separate
# integration step -- confirm its provenance.
base_path = '/Users/lingting/Documents/GVHD_project/visiumHD/data/'
# read_h5ad replaces the deprecated `anndata.read` alias.
adata_integrate = anndata.read_h5ad(os.path.join(base_path, 'adata_integrate_final.h5ad'))

In [None]:
# Display the AnnData summary (dimensions and available obs/var/uns/obsm keys).
adata_integrate

In [None]:
# UMAP of the integrated data, colored by hub assignment.
umap_options = {
    'color': "hub",
    'frameon': False,
    's': 15,
    'title': 'UMAP after Starfysh sample integration',
}
sc.pl.umap(adata_integrate, **umap_options)
plt.show()

In [None]:
# For every hub: sum the two stem-cell columns per bin, then average over the
# hub's bins. Result columns are 0 (hub) and 1 (mean), sorted by the mean.
stem_columns = ['Intestine_Epithelial Stem cell', 'Stomach_stem_cells']
hub_labels = adata_integrate.obs['hub']
hub_stem_percent = pd.DataFrame(
    [
        [hub, adata_integrate.obs.loc[hub_labels == hub, stem_columns].sum(axis=1).mean()]
        for hub in hub_labels.unique()
    ]
).sort_values(by=1, ascending=False)

In [None]:
# First, let's get the order right
hub_order = hub_stem_percent.iloc[:,0].tolist()  # Get current hub order
hub_colors_ordered = [adata_integrate.uns['hub_colors'][hub] for hub in hub_order]  # Reorder colors to match

# Creating bar plot with ordered colors
plt.figure(figsize=(12,3), dpi=300)
plt.bar(range(hub_stem_percent.shape[0]), 
        hub_stem_percent.iloc[:,1]*100,
        color=hub_colors_ordered)  # Using reordered colors

# Add horizontal line at 33%
plt.axhline(y=33, color='black', linestyle='--', alpha=0.7)

# Adding labels and title
plt.xlabel('Hubs')
plt.ylabel('Stem Cell \nPercentages (%)')

# Customizing x-ticks
plt.xticks(ticks=list(range(hub_stem_percent.shape[0])), 
          labels=hub_stem_percent.iloc[:,0],
          rotation=45)

# Adding grid for y-axis only
plt.grid(False, axis='y')
plt.tight_layout()
for ext in ['pdf', 'png', 'svg']:
    plt.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/stem_percentage_hub_bar_plot.{ext}',
        bbox_inches='tight', 
        dpi=300)
plt.show()

In [None]:
# Display the per-bin hub assignment column.
adata_integrate.obs['hub']

# Figure 6 A

In [None]:
# Palette indexed by hub id. Defined once at cell level: later cells (legend,
# recolored bar plot, composition plot) reuse `new_colormap`.
new_colormap = [
    'darkslateblue',    # 0
    'cornflowerblue',   # 1
    'red',              # 2
    'blueviolet',       # 3
    'skyblue',          # 4
    'orchid',           # 5
    'yellowgreen',      # 6
    'palevioletred',    # 7
    'orange',           # 8
    'cadetblue',        # 9
    'limegreen',        # 10
    'cyan',             # 11
    'gold',             # 12
    'slategray',        # 13
    'olive',            # 14
    'blue',             # 15
    'linen',            # 16
    'mistyrose',        # 17
    'peru',             # 18
    'darkturquoise',    # 19
    'teal',             # 20
    'salmon',           # 21
    'violet',           # 22
    'dodgerblue',       # 23
    'darkgreen',        # 24
    'mediumaquamarine', # 25
    'tomato',           # 26
    'sandybrown',       # 27
    'darkkhaki',        # 28
    'lightseagreen',    # 29
    'mediumorchid',     # 30
    'crimson',          # 31
    'olivedrab',        # 32
    'steelblue',        # 33
    'plum',             # 34
]

# Figure size per sample; samples not listed use (3, 3).
sample_figsizes = {
    'SLV16': (4, 2.4),
    'SLV14': (3, 2.3),
    'SLV11': (3, 2.8),
    'SLV17': (2.5, 3),
}
dot_size = 0.5

for sample in meta_info['sample']:
    print(sample)

    adata = adata_integrate[adata_integrate.obs['sample'] == sample]
    all_loc = adata.obs[['array_col', 'array_row']]
    color_list = np.array(adata.obs['hub'])

    fig, axs = plt.subplots(1, 1, figsize=sample_figsizes.get(sample, (3, 3)), dpi=400)

    # Iterate over the hub ids actually present in this sample. Hubs are defined
    # across all samples, so a sample may contain e.g. hubs {0, 2, 7}; looping
    # over range(number of unique hubs) would silently skip hub 7.
    for hub in np.unique(color_list):
        in_hub = color_list == hub
        g = axs.scatter(all_loc.iloc[in_hub, 0],
                        -all_loc.iloc[in_hub, 1],  # flip rows so the tissue is upright
                        s=dot_size,
                        c=new_colormap[int(hub) % len(new_colormap)],
                        marker='s',  # square markers
                        alpha=1.0,
                        linewidth=0.5)
    axs.set_xticks([])
    axs.set_yticks([])
    axs.axis("off")
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_hub_cluster.{ext}',
            bbox_inches='tight',
            dpi=300)
    plt.show()
    # Release the figure; one is created per sample
    plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

# Small fonts; keep text as text in SVG output
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

# Stand-alone figure that only carries the legend
fig, ax = plt.subplots(figsize=(0.5, 4))

# One square marker handle per palette entry, labelled with its index
patches = [
    plt.Line2D([0], [0], marker='s', color='w',
               markerfacecolor=hub_color,
               markersize=8,
               label=f'{hub_index}')
    for hub_index, hub_color in enumerate(new_colormap)
]

ax.legend(handles=patches, loc='center', frameon=False,
          fontsize=8, ncol=2, title='Clusters/Hubs')

# Only the legend should be visible
ax.set_axis_off()

# Save the legend as a separate file
for ext in ['pdf', 'png', 'svg']:
    plt.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/clusters_legend.{ext}',
        bbox_inches='tight', transparent=True,
        dpi=300)

plt.show()

In [None]:
# Spatial scatter of the per-sample Leiden clusterings stored in each sample's
# h5ad file, one figure per resolution 1.0, 1.1, ..., 1.9.
# Relies on `new_colormap` and `meta_info` defined in earlier cells.

# Figure size per sample; samples not listed use (3, 3).
sample_figsizes = {
    'SLV16': (4, 2.4),
    'SLV14': (3, 2.3),
    'SLV11': (3, 2.8),
    'SLV17': (2.5, 3),
}
dot_size = 0.1

for sample in meta_info['sample']:
    print(sample)
    # read_h5ad replaces the deprecated `anndata.read` alias
    adata = anndata.read_h5ad(f'/Users/lingting/Documents/GVHD_project/visiumHD/figures/{sample}_adata.h5ad')
    all_loc = adata.obs[['array_col', 'array_row']]

    for resolution in np.arange(1, 2, 0.1):
        # Format with one decimal so float drift (1.2000000000000002) cannot
        # leak into the obs key: 'leiden_1.0', 'leiden_1.1', ...
        key_added = f'leiden_{resolution:.1f}'
        color_list = np.array(adata.obs[key_added + '_int'])

        fig, axs = plt.subplots(1, 1, figsize=sample_figsizes.get(sample, (3, 3)), dpi=400)

        # Plot every cluster id that is present; the modulo keeps the palette
        # lookup valid when a clustering has more clusters than colors.
        for cluster in np.unique(color_list):
            in_cluster = color_list == cluster
            g = axs.scatter(all_loc.iloc[in_cluster, 0],
                            -all_loc.iloc[in_cluster, 1],  # flip rows so the tissue is upright
                            s=dot_size,
                            c=new_colormap[int(cluster) % len(new_colormap)],
                            marker='s',  # square markers
                            alpha=1.0,
                            linewidth=0.5)
        axs.set_xticks([])
        axs.set_yticks([])
        axs.axis("off")
        for ext in ['pdf', 'png', 'svg']:
            fig.savefig(
                f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_{key_added}_cluster.{ext}',
                bbox_inches='tight',
                dpi=300)
        plt.show()
        # One figure per sample and resolution: release it explicitly
        plt.close(fig)

In [None]:
import matplotlib as mpl

# Larger fonts for this figure; default image colormap set to viridis
mpl.rcParams.update({
    'image.cmap': 'viridis',
    'font.size': 12,
    'svg.fonttype': 'none',
    'axes.titlesize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

# Hubs in the order of hub_stem_percent (column 0)
hub_order = hub_stem_percent.iloc[:, 0].tolist()

# Each hub id selects its palette entry; the modulo wraps around if there are
# more hubs than colors, so a hub keeps the same color regardless of bar order.
hub_color_mapping = {
    hub: new_colormap[int(hub) % len(new_colormap)]
    for hub in hub_order
}
hub_colors_ordered = [hub_color_mapping[hub] for hub in hub_order]

# Bar plot of the stem-cell percentage per hub, colored from new_colormap
fig, ax = plt.subplots(figsize=(12, 3), dpi=300)
bar_positions = list(range(hub_stem_percent.shape[0]))
ax.bar(bar_positions,
       hub_stem_percent.iloc[:, 1] * 100,
       color=hub_colors_ordered)

# Reference line at 33%
ax.axhline(y=33, color='black', linestyle='--', alpha=0.7)

# Axis labels
ax.set_xlabel('Hubs')
ax.set_ylabel('Stem Cell \nPercentages (%)')

# One tick per hub, labelled with the hub id
ax.set_xticks(bar_positions)
ax.set_xticklabels(hub_stem_percent.iloc[:, 0], rotation=45)

# Disable y-axis grid lines
ax.grid(False, axis='y')
fig.tight_layout()

# Save in multiple formats
for ext in ['pdf', 'png', 'svg']:
    fig.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/stem_percentage_hub_bar_plot_new_colors.{ext}',
        bbox_inches='tight',
        dpi=300)

plt.show()

In [None]:
# Assign each hub the cell type with the largest summed proportion.
sig_file_name = '/Users/lingting/Documents/GVHD_project/Spatial data/data/GVHD_spatial_signature_v8_major_curated_unique_epi_tcells_subset_Transition_EPI_REFINE.csv'
gene_sig = pd.read_csv(sig_file_name, encoding='latin1')

# NOTE(review): min_proportion is defined but never applied below, so no minimum
# threshold is actually enforced -- either use it or remove it.
min_proportion = 0.333  # adjust this threshold as needed

# Cell-type proportion columns: every obs column from 'Plasma cells' through
# 'Stomach_Body_Epithelial_cells' (inclusive), selected by position.
obs_columns = adata_integrate.obs.columns
first_col = obs_columns.get_loc('Plasma cells')
last_col = obs_columns.get_loc('Stomach_Body_Epithelial_cells')
# Copy so that adding the 'hub' column does not write into a slice of obs
proportions_df = adata_integrate.obs[obs_columns[first_col:last_col + 1]].copy()
proportions_df['hub'] = adata_integrate.obs['hub']

# Sum proportions per hub, then normalize each hub's row to 1
composition = proportions_df.groupby('hub').sum()
composition_normalized = composition.div(composition.sum(axis=1), axis=0)

# For each hub, get the dominant cell type
hub_dominant_celltype = composition_normalized.idxmax(axis=1)

# Map this back to cells based on their hub
adata_integrate.obs['hub_celltype'] = adata_integrate.obs['hub'].map(hub_dominant_celltype)

# Print summary of the assignments
print("\nNumber of cells assigned to each hub-celltype combination:")
print(adata_integrate.obs.groupby(['hub', 'hub_celltype']).size())

# Print the proportion of the dominant cell type in each hub
dominant_proportions = composition_normalized.max(axis=1)
print("\nProportion of dominant cell type in each hub:")
print(dominant_proportions)

# Print the name of the dominant cell type in each hub (previously printed
# under the same "Proportion ..." heading as the values above)
print("\nDominant cell type in each hub:")
print(hub_dominant_celltype)

In [None]:
# Display the AnnData summary again; obs now includes the 'hub_celltype' column
# added by the preceding cell.
adata_integrate

In [None]:
import matplotlib as mpl
from scipy.sparse import issparse
from scipy.stats import zscore

# Small fonts; keep text as text in SVG output
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.linewidth': 1,
})

signature_markers = gene_sig  # one column of gene names per signature
adata = adata_integrate
clusters = adata.obs['hub']

cluster_labels = clusters.unique()

# Mean expression of every signature within every hub
signature_expression = {}
for column in signature_markers.columns:
    # Keep only the signature genes that exist in the dataset
    valid_genes = [g for g in signature_markers[column] if g in adata.var_names]
    if len(valid_genes) == 0:
        print(f"No valid genes found for signature: {column}")
        continue

    # Expression of the valid genes, densified if stored sparse
    expr = adata[:, valid_genes].X
    if issparse(expr):
        expr = expr.toarray()

    # Per-hub mean of each gene, then averaged over the signature's genes
    signature_mean = (
        pd.DataFrame(expr, index=adata.obs.index, columns=valid_genes)
        .groupby(clusters)
        .mean()
    )
    signature_expression[column] = signature_mean.mean(axis=1)

# Rows: hubs, columns: signatures
signature_expression_df = pd.DataFrame(signature_expression)

# Z-score each signature across hubs
z_score_df = signature_expression_df.apply(zscore, axis=0)

# Clustered heatmap: signatures (rows) are clustered, hubs (columns) keep their order
grid = sns.clustermap(
    z_score_df.T,
    cmap='viridis',
    method='average',
    metric='euclidean',
    figsize=(8, 6),
    linewidths=0,  # no grid lines between cells
    col_cluster=False,
)

# clustermap builds its own multi-axes figure, so labels and title are set on
# the heatmap axes / figure explicitly; pyplot-level title/xlabel/ylabel would
# act on whichever axes happens to be current, and tight_layout does not
# support this layout.
heatmap_ax = grid.ax_heatmap
heatmap_ax.set_xlabel("Hubs")
heatmap_ax.set_ylabel("Signature")
heatmap_ax.figure.suptitle("Z-Score Normalized Signature Expression by hubs", y=1.02)

for ext in ['pdf', 'png', 'svg']:
    heatmap_ax.figure.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/hub_gene_signature_heatmap.{ext}',
        bbox_inches='tight',
        dpi=300,
    )
plt.show()

In [None]:
# Copy the `qc_m` matrix from obsm into obs, one column per signature name.
# NOTE(review): assumes the columns of obsm['qc_m'] are ordered like
# gene_sig.columns -- TODO confirm.
adata_integrate.obs[gene_sig.columns] = adata_integrate.obsm['qc_m']
# Copy so that adding the 'hub' column does not write into a slice of obs
proportions_df = adata_integrate.obs[gene_sig.columns].copy()
proportions_df['hub'] = adata_integrate.obs['hub']

# Sum per hub, then normalize each hub's row to 1
composition = proportions_df.groupby('hub').sum()
composition_normalized = composition.div(composition.sum(axis=1), axis=0)

# DataFrame.plot creates its own figure when no `ax` is given, so the former
# `plt.figure(figsize=(60, 30))` only produced an unused empty figure and has
# been removed; the plot keeps the default figure size as before.
ax = composition_normalized.plot(kind='bar', stacked=True, color=new_colormap)
ax.set_xlabel('Spatial Spots (Hubs)')
ax.set_ylabel('Normalized Cell Type Proportions')
ax.set_title('Composition of Cell Types in Spatial Spots (Normalized)')
ax.legend(title='Cell Types', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=5)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=4)  # Rotate labels
for ext in ['pdf', 'png', 'svg']:
    ax.figure.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/integrated_hub_bar_plot.{ext}',
        bbox_inches='tight',
        dpi=900)

In [None]:
# Per-sample Leiden clustering followed by a signature-by-cluster heatmap.
# Relies on `signature_markers`, `meta_info`, `sns` and `sc` from earlier cells.
from scipy.sparse import issparse
from scipy.stats import zscore

# Leiden resolution (higher values give more clusters) and the obs key that
# receives the labels.
# NOTE(review): the key is named 'leiden_4' while the resolution is 3.5 -- confirm.
resolution = 3.5
key_added = 'leiden_4'

for sample in meta_info['sample']:
    print(sample)
    # read_h5ad replaces the deprecated `anndata.read` alias
    adata = anndata.read_h5ad(f'/Users/lingting/Documents/GVHD_project/visiumHD/figures/{sample}_adata.h5ad')

    # Apply Leiden clustering
    sc.tl.leiden(adata, resolution=resolution, key_added=key_added)

    clusters = adata.obs[key_added]
    cluster_labels = clusters.unique()

    # Mean expression of every signature within every cluster
    signature_expression = {}
    for column in signature_markers.columns:
        # Keep only the signature genes that exist in the dataset
        valid_genes = [g for g in signature_markers[column] if g in adata.var_names]
        if len(valid_genes) == 0:
            print(f"No valid genes found for signature: {column}")
            continue

        # Expression of the valid genes, densified if stored sparse
        expr = adata[:, valid_genes].X
        if issparse(expr):
            expr = expr.toarray()

        # Per-cluster mean of each gene, then averaged over the signature's genes
        signature_mean = (
            pd.DataFrame(expr, index=adata.obs.index, columns=valid_genes)
            .groupby(clusters)
            .mean()
        )
        signature_expression[column] = signature_mean.mean(axis=1)

    # Rows: clusters, columns: signatures
    signature_expression_df = pd.DataFrame(signature_expression)

    # Z-score each signature across clusters
    z_score_df = signature_expression_df.apply(zscore, axis=0)

    # clustermap creates its own figure, so no separate plt.figure() is opened
    # (the former one stayed empty).
    grid = sns.clustermap(
        z_score_df.T,
        cmap='viridis',
        method='average',
        metric='euclidean',
        figsize=(8, 6),
        linewidths=0,  # no grid lines between cells
        col_cluster=False,
    )

    # Label the heatmap axes explicitly; pyplot-level title/xlabel/ylabel would
    # act on whichever axes of the cluster grid happens to be current.
    heatmap_ax = grid.ax_heatmap
    heatmap_ax.set_xlabel("Cluster")
    heatmap_ax.set_ylabel("Signature")
    heatmap_ax.figure.suptitle(
        f"Z-Score Normalized Signature Expression by Cluster {resolution}", y=1.02)

    for ext in ['pdf', 'png', 'svg']:
        heatmap_ax.figure.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}hub_gene_signature_heatmap.{ext}',
            bbox_inches='tight',
            dpi=300,
        )
    plt.show()
    # One figure per sample: release it explicitly
    plt.close(heatmap_ax.figure)

# Figure 6 B

In [None]:
# Report (n_bins, n_genes) of the integrated data for every sample.
sample_labels = adata_integrate.obs['sample']
for sample in meta_info['sample']:
    adata = adata_integrate[sample_labels == sample]
    print(sample)
    print(adata.shape)

In [None]:
from scipy import ndimage
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl

# Small fonts; keep text as text in SVG output
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.linewidth': 1,
})

# Smoothed spatial density map of one cell type, one figure per sample.
# Relies on `meta_info`, `adata_integrate` and `img_metadata_all` from earlier cells.
cell_type = 'CD8+ Effector T  cells'
for sample in meta_info['sample']:
    print(sample)
    # Bins of the current sample
    adata = adata_integrate[adata_integrate.obs['sample'] == sample]
    fig, ax = plt.subplots(figsize=(3, 2.5), dpi=500)

    # Grid extent from the sample's array coordinates; columns are selected by
    # label (positional `[0]` on a labelled Series is deprecated in pandas).
    adata_all_loc = img_metadata_all[sample]['map_info'][['array_col', 'array_row']]
    x_min, x_max = adata_all_loc['array_col'].min(), adata_all_loc['array_col'].max()
    y_min, y_max = adata_all_loc['array_row'].min(), adata_all_loc['array_row'].max()
    # 5% padding on each side to avoid cropping
    x_padding = (x_max - x_min) * 0.05
    y_padding = (y_max - y_min) * 0.05
    x = np.arange(x_min - x_padding, x_max + x_padding, 1)
    y = np.arange(y_min - y_padding, y_max + y_padding, 1)
    X, Y = np.meshgrid(x, y)
    Z = np.zeros([len(x), len(y)])

    # Mean of `cell_type` over the bins in each grid cell (both cell edges
    # inclusive). The obs columns are pulled out once and each column strip is
    # pre-filtered, instead of building an AnnData subset per grid cell.
    spot_cols = adata.obs['array_col'].to_numpy()
    spot_rows = adata.obs['array_row'].to_numpy()
    spot_values = adata.obs[cell_type]
    for i in range(len(x) - 1):
        in_strip = (x[i] <= spot_cols) & (spot_cols <= x[i + 1])
        if not in_strip.any():
            continue
        strip_rows = spot_rows[in_strip]
        strip_values = spot_values[in_strip]
        for j in range(len(y) - 1):
            in_cell = (y[j] <= strip_rows) & (strip_rows <= y[j + 1])
            if in_cell.any():
                Z[i, j] = strip_values[in_cell].mean()

    # Smooth, then scale to [0, 1]; skip the scaling when the map is all zero
    # (cell type absent in this sample) to avoid dividing by zero.
    Z2 = ndimage.gaussian_filter(Z, sigma=3.0, order=0)
    peak = Z2.max()
    if peak > 0:
        Z2 = Z2 / peak

    # Filled contours plus labelled contour lines. The former `c=` / `linewidth=`
    # arguments are not contour parameters (those are `colors` / `linewidths`)
    # and were ignored by matplotlib, so they are dropped without changing the
    # rendered figure.
    cset = ax.contourf(X, -Y, Z2.transpose(), cmap='Oranges',
                       vmin=0, vmax=1, alpha=1.0)
    contour = ax.contour(X, -Y, Z2.transpose(), 2, vmin=0, vmax=1)
    ax.clabel(contour, fontsize=8, colors='k')

    # Compact colorbar without outline
    cbar = fig.colorbar(cset, ax=ax, fraction=0.046, pad=0.04)
    cbar.outline.set_visible(False)

    # Hide ticks, labels, spines and frame
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')

    # Remove whitespace padding around the axes
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # Save the plot in multiple formats
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_{cell_type}_contour_dis.{ext}',
            bbox_inches='tight',
            dpi=300,
            pad_inches=0  # Remove all padding
        )
    plt.show()
    # One figure per sample: release it explicitly
    plt.close(fig)

In [None]:
from scipy import ndimage
import matplotlib.pyplot as plt
import numpy as np

cell_type = 'CD4+ Effector T cells'
for sample in ['SLV14','SLV16']:
    print(sample)
    # Spots of the current sample only.
    adata = adata_integrate[adata_integrate.obs['sample'] == sample]
    fig, ax = plt.subplots(figsize=(3, 2.5), dpi=500)
    # The grid extent is taken from the sample's full spot map, not only from
    # the spots that are present in `adata`.
    adata_all_loc = img_metadata_all[sample]['map_info'][['array_col', 'array_row']]
    # Select by column label: the previous `.min()[0]` relied on positional
    # indexing of a label-indexed Series, which pandas has deprecated.
    x_min, x_max = adata_all_loc['array_col'].min(), adata_all_loc['array_col'].max()
    y_min, y_max = adata_all_loc['array_row'].min(), adata_all_loc['array_row'].max()
    # Pad the extent by 5% on each side so the contours are not cropped.
    x_padding = (x_max - x_min) * 0.05
    y_padding = (y_max - y_min) * 0.05
    x = np.arange(x_min - x_padding, x_max + x_padding, 1)
    y = np.arange(y_min - y_padding, y_max + y_padding, 1)
    X, Y = np.meshgrid(x, y)
    Z = np.zeros([len(x), len(y)])
    # Read the obs columns once and bin with boolean masks instead of building
    # an AnnData view for every grid cell; the per-cell value is unchanged
    # (column mean(s) of `cell_type`, summed). Bin edges are inclusive on both
    # sides, exactly as before.
    spot_col = adata.obs['array_col'].to_numpy()
    spot_row = adata.obs['array_row'].to_numpy()
    spot_val = adata.obs[cell_type]
    for i in range(len(x)-1):
        in_col = (x[i] <= spot_col) & (spot_col <= x[i+1])
        if not in_col.any():
            continue
        for j in range(len(y)-1):
            in_cell = in_col & (y[j] <= spot_row) & (spot_row <= y[j+1])
            if in_cell.any():
                Z[i, j] = spot_val[in_cell].mean().sum()
    # Smooth, then scale to [0, 1]. Without any signal the division would be
    # 0/0, so the sample is skipped instead of plotting NaNs.
    Z2 = ndimage.gaussian_filter(Z, sigma=3.0, order=0)
    peak = Z2.max()
    if peak == 0:
        print(f'{sample}: no signal for {cell_type}, contour plot skipped')
        plt.close(fig)
        continue
    Z2 = Z2 / peak
    # `linewidth=` and `c=` are not contour keywords (the real names are
    # `linewidths=` and `colors=`); Matplotlib ignored them, so they are
    # dropped here and the rendering stays the same.
    cset = ax.contourf(X, -Y, Z2.transpose(), cmap='Oranges',
                       vmin=0, vmax=1, alpha=1.0)
    contour = ax.contour(X, -Y, Z2.transpose(), 2, vmin=0, vmax=1)
    ax.clabel(contour, fontsize=8, colors='k')
    # Small colorbar without an outline.
    cbar = fig.colorbar(cset, ax=ax, fraction=0.046, pad=0.04)
    cbar.outline.set_visible(False)
    # Strip every axes decoration: spines, ticks, tick labels and frame.
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_frame_on(False)
    ax.axis('off')
    # Let the axes fill the whole figure.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    # Save the plot in multiple formats.
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_{cell_type}_contour_dis.{ext}',
            bbox_inches='tight',
            dpi=300,
            pad_inches=0
        )
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)

In [None]:
from scipy import ndimage
import matplotlib.pyplot as plt
import numpy as np
cell_type =['Stomach_stem_cells', 'Intestine_Epithelial Stem cell']
for sample in ['SLV14','SLV16']:
    print(sample)
    # Spots of the current sample only.
    adata = adata_integrate[adata_integrate.obs['sample'] == sample]
    fig, ax = plt.subplots(figsize=(3, 2.5), dpi=500)
    # The grid extent is taken from the sample's full spot map, not only from
    # the spots that are present in `adata`.
    adata_all_loc = img_metadata_all[sample]['map_info'][['array_col', 'array_row']]
    # Select by column label: the previous `.min()[0]` relied on positional
    # indexing of a label-indexed Series, which pandas has deprecated.
    x_min, x_max = adata_all_loc['array_col'].min(), adata_all_loc['array_col'].max()
    y_min, y_max = adata_all_loc['array_row'].min(), adata_all_loc['array_row'].max()
    # Pad the extent by 5% on each side so the contours are not cropped.
    x_padding = (x_max - x_min) * 0.05
    y_padding = (y_max - y_min) * 0.05
    x = np.arange(x_min - x_padding, x_max + x_padding, 1)
    y = np.arange(y_min - y_padding, y_max + y_padding, 1)
    X, Y = np.meshgrid(x, y)
    Z = np.zeros([len(x), len(y)])
    # Read the obs columns once and bin with boolean masks instead of building
    # an AnnData view for every grid cell; the per-cell value is unchanged
    # (per-column means of the listed cell types, summed). Bin edges are
    # inclusive on both sides, exactly as before.
    spot_col = adata.obs['array_col'].to_numpy()
    spot_row = adata.obs['array_row'].to_numpy()
    spot_val = adata.obs[cell_type]
    for i in range(len(x)-1):
        in_col = (x[i] <= spot_col) & (spot_col <= x[i+1])
        if not in_col.any():
            continue
        for j in range(len(y)-1):
            in_cell = in_col & (y[j] <= spot_row) & (spot_row <= y[j+1])
            if in_cell.any():
                Z[i, j] = spot_val[in_cell].mean().sum()
    # Smooth, then scale to [0, 1]. Without any signal the division would be
    # 0/0, so the sample is skipped instead of plotting NaNs.
    Z2 = ndimage.gaussian_filter(Z, sigma=3.0, order=0)
    peak = Z2.max()
    if peak == 0:
        print(f'{sample}: no signal for {cell_type}, contour plot skipped')
        plt.close(fig)
        continue
    Z2 = Z2 / peak
    # `linewidth=` and `c=` are not contour keywords (the real names are
    # `linewidths=` and `colors=`); Matplotlib ignored them, so they are
    # dropped here and the rendering stays the same.
    cset = ax.contourf(X, -Y, Z2.transpose(), cmap='Oranges',
                       vmin=0, vmax=1, alpha=1.0)
    contour = ax.contour(X, -Y, Z2.transpose(), 2, vmin=0, vmax=1)
    ax.clabel(contour, fontsize=8, colors='k')
    # Small colorbar without an outline.
    cbar = fig.colorbar(cset, ax=ax, fraction=0.046, pad=0.04)
    cbar.outline.set_visible(False)
    # Strip every axes decoration: spines, ticks, tick labels and frame.
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_frame_on(False)
    ax.axis('off')
    # Let the axes fill the whole figure.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    # Save the plot in multiple formats. `cell_type` is a list here, so its
    # repr (brackets and quotes included) ends up in the file name; the name
    # is deliberately left unchanged.
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_{cell_type}_contour_dis.{ext}',
            bbox_inches='tight',
            dpi=300,
            pad_inches=0
        )
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)

In [None]:


from scipy import ndimage
import matplotlib.pyplot as plt
import numpy as np
cell_type =['CD8+ Transitioning  Resident T cells']
for sample in ['SLV14','SLV16']:
    print(sample)
    # Spots of the current sample only.
    adata = adata_integrate[adata_integrate.obs['sample'] == sample]
    fig, ax = plt.subplots(figsize=(3, 2.5), dpi=500)
    # The grid extent is taken from the sample's full spot map, not only from
    # the spots that are present in `adata`.
    adata_all_loc = img_metadata_all[sample]['map_info'][['array_col', 'array_row']]
    # Select by column label: the previous `.min()[0]` relied on positional
    # indexing of a label-indexed Series, which pandas has deprecated.
    x_min, x_max = adata_all_loc['array_col'].min(), adata_all_loc['array_col'].max()
    y_min, y_max = adata_all_loc['array_row'].min(), adata_all_loc['array_row'].max()
    # Pad the extent by 5% on each side so the contours are not cropped.
    x_padding = (x_max - x_min) * 0.05
    y_padding = (y_max - y_min) * 0.05
    x = np.arange(x_min - x_padding, x_max + x_padding, 1)
    y = np.arange(y_min - y_padding, y_max + y_padding, 1)
    X, Y = np.meshgrid(x, y)
    Z = np.zeros([len(x), len(y)])
    # Read the obs columns once and bin with boolean masks instead of building
    # an AnnData view for every grid cell; the per-cell value is unchanged
    # (per-column means of the listed cell types, summed). Bin edges are
    # inclusive on both sides, exactly as before.
    spot_col = adata.obs['array_col'].to_numpy()
    spot_row = adata.obs['array_row'].to_numpy()
    spot_val = adata.obs[cell_type]
    for i in range(len(x)-1):
        in_col = (x[i] <= spot_col) & (spot_col <= x[i+1])
        if not in_col.any():
            continue
        for j in range(len(y)-1):
            in_cell = in_col & (y[j] <= spot_row) & (spot_row <= y[j+1])
            if in_cell.any():
                Z[i, j] = spot_val[in_cell].mean().sum()
    # Smooth, then scale to [0, 1]. Without any signal the division would be
    # 0/0, so the sample is skipped instead of plotting NaNs.
    Z2 = ndimage.gaussian_filter(Z, sigma=3.0, order=0)
    peak = Z2.max()
    if peak == 0:
        print(f'{sample}: no signal for {cell_type}, contour plot skipped')
        plt.close(fig)
        continue
    Z2 = Z2 / peak
    # `linewidth=` and `c=` are not contour keywords (the real names are
    # `linewidths=` and `colors=`); Matplotlib ignored them, so they are
    # dropped here and the rendering stays the same.
    cset = ax.contourf(X, -Y, Z2.transpose(), cmap='Oranges',
                       vmin=0, vmax=1, alpha=1.0)
    contour = ax.contour(X, -Y, Z2.transpose(), 2, vmin=0, vmax=1)
    ax.clabel(contour, fontsize=8, colors='k')
    # Small colorbar without an outline.
    cbar = fig.colorbar(cset, ax=ax, fraction=0.046, pad=0.04)
    cbar.outline.set_visible(False)
    # Strip every axes decoration: spines, ticks, tick labels and frame.
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_frame_on(False)
    ax.axis('off')
    # Let the axes fill the whole figure.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    # Save the plot in multiple formats. `cell_type` is a list here, so its
    # repr (brackets and quotes included) ends up in the file name; the name
    # is deliberately left unchanged.
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_{cell_type}_contour_dis.{ext}',
            bbox_inches='tight',
            dpi=300,
            pad_inches=0
        )
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)

# Figure 6 D

In [None]:
# NOTE(review): presumably the ids of the hubs associated with stem cells — TODO confirm.
# This list is not read by any of the cells shown around it.
hub_stem_list = [4,5,8,30,32]

In [None]:
# Cell-type labels that are collapsed into the single label 'stem'.
stem_cell_types = ['Intestine_Epithelial Stem cell', 'Stomach_stem_cells']

# Overwrite obs['hub_celltype'] in place (no new column is created): values listed in
# stem_cell_types become 'stem', every other value is kept as is.
adata_integrate.obs['hub_celltype'] = adata_integrate.obs['hub_celltype'].apply(lambda x: 'stem' if x in stem_cell_types else x)

In [None]:
import matplotlib as mpl
# Figure-wide text settings: base font, axes titles and tick labels all at 8 pt.
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

def plot_categories_on_histology(
    adata,
    column,
    his_loc,
    scalefactors,
    size,
    output_folder,
    sample_id,
    plot_cat,
    category_color_map
):
    """
    Overlay selected categories of an obs column on a histology image and save a PNG.

    The spot coordinates are read from adata.obs['pxl_row_in_fullres'] and
    adata.obs['pxl_col_in_fullres'], coerced to integers (unparseable values
    become 0) and multiplied by scalefactors['tissue_hires_scalef'].
    `adata` itself is not modified.

    Parameters:
        adata: AnnData object (a view is fine).
        column: Column in adata.obs containing categorical data to plot.
        his_loc: Path to the histology image.
        scalefactors: Dictionary with scaling factors; 'tissue_hires_scalef' is used.
        size: Size of scatter plot markers.
        output_folder: Directory to save the output image.
        sample_id: Sample ID for naming the output file.
        plot_cat: List of categories to plot, drawn in this order.
        category_color_map: Dictionary mapping categories to colors; categories
            missing from it are skipped.

    Returns:
        None. The figure is written to
        <output_folder>/<sample_id>_<column>_on_histology.png and shown.
    """
    histology_image = plt.imread(his_loc)

    # Keep the scaled coordinates in local Series instead of writing
    # 'imagerow'/'imagecol' into adata.obs: the callers pass a view of the
    # integrated object, and assigning to the obs of a view makes AnnData turn
    # it into a full copy (and mutates the caller's argument).
    hires_scale = scalefactors['tissue_hires_scalef']
    spot_rows = pd.to_numeric(adata.obs['pxl_row_in_fullres'], errors='coerce').fillna(0).astype(int) * hires_scale
    spot_cols = pd.to_numeric(adata.obs['pxl_col_in_fullres'], errors='coerce').fillna(0).astype(int) * hires_scale

    categories = adata.obs[column]

    # Figure size follows the image size (0.002 inch per pixel).
    height, width = histology_image.shape[:2]
    scale_factor = 0.002
    fig, ax = plt.subplots(1, 1, figsize=(width*scale_factor, height*scale_factor), dpi=300)
    ax.imshow(histology_image)
    ax.axis('off')

    # One scatter layer per requested category; image x is the pixel column,
    # image y is the pixel row.
    for category in plot_cat:
        if category not in category_color_map:
            continue
        mask = (categories == category)
        ax.scatter(
            x=spot_cols[mask],
            y=spot_rows[mask],
            c=[category_color_map[category]],
            label=category,
            marker='s',
            s=size,
            edgecolor='none'
        )

    output_path = os.path.join(output_folder, f"{sample_id}_{column}_on_histology.png")
    fig.tight_layout()
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.show()
    plt.close(fig)

    print(f"Figure saved to {output_path}")

# Colour used for each plotted category (the keys must match the labels in
# obs['hub_celltype'] exactly, including the double space in the CD8+ label).
custom_colors = {
    'CD8+ Effector T  cells': 'Yellow',   # yellow
    'CD4+ Effector T cells': 'Cyan',    # cyan
    'stem': '#023047'                      # Dark blue
}

# Folder that receives the figures
output_folder = '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/'

# Categories to plot, in drawing order
plot_categories = ['CD8+ Effector T  cells', 'stem', 'CD4+ Effector T cells']

# One full-slide overlay per sample listed in meta_info['sample']
for sample in meta_info['sample']:
    # Load the scale factors of the 16 um binned output for this sample
    file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
    with open(file_path, 'r') as file:
        scalefactor = json.load(file)
    
    # Overlay the selected hub_celltype categories on the hires tissue image
    plot_categories_on_histology(
        adata=adata_integrate[adata_integrate.obs['sample']==sample],
        column='hub_celltype',
        his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
        scalefactors=scalefactor,
        size=5,
        output_folder=output_folder,
        sample_id=sample,
        plot_cat=plot_categories,
        category_color_map=custom_colors  # category -> colour dictionary
    )



def create_standalone_legend(
    category_color_map,
    column_name,
    output_folder,
    filename="large_legend",
    figsize=(5, 3),
    fontsize=14,
    markersize=15
):
    """
    Create a standalone legend figure with large, clear elements.

    Each category is drawn as a filled square marker in its colour. The legend
    is saved as both PNG and SVG.

    Parameters:
        category_color_map: Dictionary mapping categories to colors.
        column_name: Name of the column/category for the legend title.
        output_folder: Directory to save the output image.
        filename: Base filename for the output (without extension).
        figsize: Size of the figure (width, height) in inches.
        fontsize: Font size for legend text; the title uses fontsize + 2.
        markersize: Size of the square markers in the legend, in points.

    Returns:
        Path of the PNG file that was written.
    """
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    import os

    # Empty figure with a white background; only the legend is drawn on it.
    fig, ax = plt.subplots(figsize=figsize)
    fig.patch.set_facecolor('white')

    # Square marker proxies sized by `markersize`. The previous Patch handles
    # were combined with `markerscale=markersize`, but `markerscale` only
    # rescales marker-type handles, so `markersize` had no visible effect.
    handles = [
        Line2D(
            [], [],
            linestyle='none',
            marker='s',
            markersize=markersize,
            markerfacecolor=color,
            markeredgecolor=color,
            label=category
        )
        for category, color in category_color_map.items()
    ]

    legend = ax.legend(
        handles=handles,
        title=column_name,
        loc='center',
        frameon=True,
        fontsize=fontsize,
        title_fontsize=fontsize+2
    )

    # Make legend frame thicker
    legend.get_frame().set_linewidth(1.5)

    # Hide the axes
    ax.axis('off')

    # Raster version
    output_path = os.path.join(output_folder, f"{filename}.png")
    fig.savefig(output_path, bbox_inches='tight', dpi=300)

    # Vector version (for publications)
    svg_output_path = os.path.join(output_folder, f"{filename}.svg")
    fig.savefig(svg_output_path, bbox_inches='tight', format='svg')

    plt.close(fig)

    print(f"Legend saved to {output_path} and {svg_output_path}")

    return output_path

# Render the legend for the colour mapping currently bound to `custom_colors`
# (defined earlier in this cell).


# Folder that receives the legend files
output_folder = '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/'

# Create standalone legend
create_standalone_legend(
    category_color_map=custom_colors,
    column_name='Cell Types',  # Legend title
    output_folder=output_folder,
    filename='cell_types_legend',
    figsize=(7, 4),  # Figure size in inches (width, height)
    fontsize=16,
    markersize=20
)

In [None]:
# Display the distinct labels currently present in obs['hub_celltype'].
adata_integrate.obs['hub_celltype'].unique()

In [None]:
# Rebind the colour mapping and category list to a single category.
# NOTE(review): plot_categories_on_histology names its output
# "<sample>_<column>_on_histology.png", and the earlier loop over
# meta_info['sample'] also used column='hub_celltype' — so the PNGs written by
# that loop are overwritten here. Confirm this is intended.
custom_colors = {
    'CD4+ Regulatory T cells': 'Cyan'
}

# Folder that receives the figures
output_folder = '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/'

# Categories to plot
plot_categories = ['CD4+ Regulatory T cells']

# One full-slide overlay per sample listed in meta_info['sample']
for sample in meta_info['sample']:
    # Load the scale factors of the 16 um binned output for this sample
    file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
    with open(file_path, 'r') as file:
        scalefactor = json.load(file)
    
    # Overlay the selected category on the hires tissue image
    plot_categories_on_histology(
        adata=adata_integrate[adata_integrate.obs['sample']==sample],
        column='hub_celltype',
        his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
        scalefactors=scalefactor,
        size=1,
        output_folder=output_folder,
        sample_id=sample,
        plot_cat=plot_categories,
        category_color_map=custom_colors  # category -> colour dictionary
    )

In [None]:
def plot_category_on_histology_crop(
    adata,
    column,
    his_loc,
    scalefactors,
    size,
    output_folder,
    sample_id,
    x_start,
    x_end,
    y_start,
    y_end,
    plot_cat,
    color_map=None
):
    """
    Plot categories on a histology image and save a cropped view of the plot.

    The spot coordinates are read from adata.obs['pxl_row_in_fullres'] and
    adata.obs['pxl_col_in_fullres'], coerced to integers (unparseable values
    become 0) and multiplied by scalefactors['tissue_hires_scalef'].
    `adata` itself is not modified.

    Parameters:
        adata: AnnData object (a view is fine).
        column: Column in adata.obs containing categorical data to plot.
        his_loc: Path to the histology image.
        scalefactors: Dictionary with scaling factors; 'tissue_hires_scalef' is used.
        size: Size of scatter plot markers.
        output_folder: Directory to save the output image.
        sample_id: Sample ID for naming the output file.
        x_start, x_end, y_start, y_end: Axis limits of the crop, in the
            coordinates of the displayed image.
        plot_cat: Category, or list of categories, to plot. Its string form is
            part of the output file name.
        color_map: Dictionary mapping categories to colors. Categories missing
            from it are skipped. When None, every category present in `column`
            gets a colour from Matplotlib's tab20 palette.

    Returns:
        None (saves the cropped figure as PNG, SVG and PDF and shows it).
    """
    import os
    import matplotlib.pyplot as plt

    histology_image = plt.imread(his_loc)

    # Keep the scaled coordinates in local Series instead of writing
    # 'imagerow'/'imagecol' into adata.obs: the callers pass a view of the
    # integrated object, and assigning to the obs of a view makes AnnData turn
    # it into a full copy (and mutates the caller's argument).
    hires_scale = scalefactors['tissue_hires_scalef']
    spot_rows = pd.to_numeric(adata.obs['pxl_row_in_fullres'], errors='coerce').fillna(0).astype(int) * hires_scale
    spot_cols = pd.to_numeric(adata.obs['pxl_col_in_fullres'], errors='coerce').fillna(0).astype(int) * hires_scale

    categories = adata.obs[column]

    # A bare string would otherwise be iterated character by character.
    requested = [plot_cat] if isinstance(plot_cat, str) else list(plot_cat)

    # The default used to be None, which crashed on the membership test below;
    # fall back to a tab20 palette over the categories that are present.
    if color_map is None:
        palette = plt.cm.tab20.colors
        category_color_map = {
            cat: palette[i % len(palette)]
            for i, cat in enumerate(categories.unique())
        }
    else:
        category_color_map = color_map

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(histology_image)
    ax.axis('off')

    # One scatter layer per requested category; image x is the pixel column,
    # image y is the pixel row.
    for category in requested:
        if category not in category_color_map:
            continue
        mask = (categories == category)
        ax.scatter(
            x=spot_cols[mask],
            y=spot_rows[mask],
            c=[category_color_map[category]],
            label=category,
            marker='s',
            s=size,
            edgecolor='none'
        )

    # Crop by restricting the axis limits; y runs downwards on an image, so the
    # limits are given as (y_end, y_start).
    ax.set_xlim([x_start, x_end])
    ax.set_ylim([y_end, y_start])

    # Same file names as before; the original `plot_cat` argument is formatted
    # into the name unchanged.
    for ext in ('png', 'svg', 'pdf'):
        cropped_output_path = os.path.join(
            output_folder,
            f"{sample_id}_{column}_{plot_cat}_on_histology_crop_{x_start}.{ext}"
        )
        fig.savefig(cropped_output_path, bbox_inches='tight', dpi=300)

    plt.show()
    plt.close(fig)

    # Reports the last path written (the PDF), as before.
    print(f"Cropped figure saved to {cropped_output_path}")

output_folder = '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/'

# Crop window, given as axis limits on the displayed hires image
y_start, y_end = 2300, 2900
x_start, x_end = 1000, 1400

sample = 'SLV14'

# Scale factors of the 16 um binned output for this sample
file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
with open(file_path, 'r') as file:
    scalefactor = json.load(file)

# NOTE(review): `s` is not passed to the call below (it uses size=100) —
# confirm whether it is still needed.
s = {"SLV11": 0.1, "SLV15": 0.6}.get(sample, 1)

# Category -> colour
custom_colors = {
    'CD8+ Effector T  cells': 'Yellow',
    'CD4+ Effector T cells': 'Cyan',
    'stem': '#023047'
}
# Categories to draw, in drawing order
plot_categories = ['CD8+ Effector T  cells', 'stem', 'CD4+ Effector T cells']

plot_category_on_histology_crop(
    adata=adata_integrate[adata_integrate.obs['sample']==sample],
    column='hub_celltype',
    his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
    scalefactors=scalefactor,
    size=100,
    output_folder=output_folder,
    sample_id=sample,
    x_start=x_start,
    x_end=x_end,
    y_start=y_start,
    y_end=y_end,
    plot_cat=plot_categories,
    color_map=custom_colors
)

In [None]:
# Another crop window; `sample`, `scalefactor`, `output_folder`,
# `plot_categories` and `custom_colors` are the values left by earlier cells.
y_start, y_end = 1300, 1900
x_start, x_end = 1500, 2000
plot_category_on_histology_crop(
    adata=adata_integrate[adata_integrate.obs['sample']==sample],
    column='hub_celltype',
    his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
    scalefactors=scalefactor,
    size=100,
    output_folder=output_folder,
    sample_id=sample,
    x_start=x_start,
    x_end=x_end,
    y_start=y_start,
    y_end=y_end,
    plot_cat=plot_categories,
    color_map=custom_colors
)

In [None]:
sample = 'SLV16'

# Crop window, given as axis limits on the displayed hires image
y_start, y_end = 2200, 3300
x_start, x_end = 4500, 5200

# Scale factors of the 16 um binned output for this sample
file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
with open(file_path, 'r') as file:
    scalefactor = json.load(file)

# NOTE(review): `s` is not passed to the call below (it uses size=40) —
# confirm whether it is still needed.
s = {"SLV11": 0.1, "SLV15": 0.6}.get(sample, 1)

# `output_folder`, `plot_categories` and `custom_colors` come from earlier cells.
plot_category_on_histology_crop(
    adata=adata_integrate[adata_integrate.obs['sample']==sample],
    column='hub_celltype',
    his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
    scalefactors=scalefactor,
    size=40,
    output_folder=output_folder,
    sample_id=sample,
    x_start=x_start,
    x_end=x_end,
    y_start=y_start,
    y_end=y_end,
    plot_cat=plot_categories,
    color_map=custom_colors
)

In [None]:
sample = 'SLV14'

# Crop window, given as axis limits on the displayed hires image
y_start, y_end = 1500, 2100
x_start, x_end = 4200, 4800

# Scale factors of the 16 um binned output for this sample
file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
with open(file_path, 'r') as file:
    scalefactor = json.load(file)

# NOTE(review): `s` is not passed to the call below (it uses size=140) —
# confirm whether it is still needed.
s = {"SLV11": 0.1, "SLV15": 0.6}.get(sample, 1)

# `output_folder`, `plot_categories` and `custom_colors` come from earlier cells.
plot_category_on_histology_crop(
    adata=adata_integrate[adata_integrate.obs['sample']==sample],
    column='hub_celltype',
    his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
    scalefactors=scalefactor,
    size=140,
    output_folder=output_folder,
    sample_id=sample,
    x_start=x_start,
    x_end=x_end,
    y_start=y_start,
    y_end=y_end,
    plot_cat=plot_categories,
    color_map=custom_colors
)

# Figure 7A

In [None]:
output_folder = '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/'
def plot_proportion_on_histology_crop(
    adata,
    cat,
    his_loc,
    scalefactors,
    size,
    output_folder,
    sample_id,
    x_start,
    x_end,
    y_start,
    y_end,
    color_map=None,
    min_value=0.01,
    vmin=0,
    vmax=0.2,
    cmap='plasma'
):
    """
    Plot the values of one obs column on a histology image and save a cropped view.

    Only spots whose value in adata.obs[cat] is greater than `min_value` are
    drawn; they are coloured by that value.

    Parameters:
        adata: AnnData object.
        cat: Name of the numeric column in adata.obs to plot.
        his_loc: Path to the histology image.
        scalefactors: Dictionary with scaling factors; 'tissue_hires_scalef' is used.
        size: Size of scatter plot markers.
        output_folder: Directory to save the output image.
        sample_id: Sample ID for naming the output file.
        x_start, x_end, y_start, y_end: Axis limits of the crop, in the
            coordinates of the displayed image.
        color_map: Accepted so existing calls keep working; not used.
        min_value: Spots with a value <= min_value are not drawn (default 0.01).
        vmin, vmax: Colour scale limits (defaults 0 and 0.2).
        cmap: Matplotlib colormap for the values (default 'plasma').

    Returns:
        None. The cropped figure is saved as PNG, PDF and SVG and then closed
        without being shown.
    """
    histology_image = plt.imread(his_loc)

    # Spot positions in the coordinates of the hires image.
    hires_scale = scalefactors['tissue_hires_scalef']
    histology_x = adata.obs['pxl_row_in_fullres'] * hires_scale
    histology_y = adata.obs['pxl_col_in_fullres'] * hires_scale

    # Keep only the spots above the display threshold.
    values = adata.obs[cat]
    mask = values > min_value
    histology_x = histology_x[mask]
    histology_y = histology_y[mask]
    values = values[mask]

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(histology_image)
    ax.axis('off')

    # Image x is the pixel column, image y is the pixel row.
    ax.scatter(
        x=histology_y,
        y=histology_x,
        c=values,
        cmap=cmap,
        marker='s',
        s=size,
        vmax=vmax,
        vmin=vmin,
        edgecolor='none'
    )

    # Crop by restricting the axis limits; y runs downwards on an image, so the
    # limits are given as (y_end, y_start).
    ax.set_xlim([x_start, x_end])
    ax.set_ylim([y_end, y_start])

    # Same file names and order as before.
    for ext in ('png', 'pdf', 'svg'):
        cropped_output_path = os.path.join(
            output_folder,
            f"{sample_id}_{cat}_{y_start}_{x_start}_on_histology_crop.{ext}"
        )
        fig.savefig(cropped_output_path, bbox_inches='tight', dpi=300)

    plt.close(fig)

In [None]:
# Crop window, given as axis limits on the displayed hires image
y_start, y_end = 2300, 2900
x_start, x_end = 1000, 1400

sample = 'SLV14'

# Scale factors of the 16 um binned output for this sample
file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
with open(file_path, 'r') as file:
    scalefactor = json.load(file)

# One cropped figure per entry of adata_integrate.uns['cell_types']
for cat in adata_integrate.uns['cell_types']:
    print(cat)
    plot_proportion_on_histology_crop(
        adata=adata_integrate[adata_integrate.obs['sample']==sample],
        cat=cat,
        his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
        scalefactors=scalefactor,
        size=100,
        output_folder=output_folder,
        sample_id=sample,
        x_start=x_start,
        x_end=x_end,
        y_start=y_start,
        y_end=y_end,
        color_map=None
    )

In [None]:
# Another crop window; `sample`, `scalefactor` and `output_folder` are the
# values left by earlier cells.
y_start, y_end = 1800, 2000
x_start, x_end = 1600, 2100

# One cropped figure per entry of adata_integrate.uns['cell_types']
for cat in adata_integrate.uns['cell_types']:
    print(cat)
    plot_proportion_on_histology_crop(
        adata=adata_integrate[adata_integrate.obs['sample']==sample],
        cat=cat,
        his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
        scalefactors=scalefactor,
        size=100,
        output_folder=output_folder,
        sample_id=sample,
        x_start=x_start,
        x_end=x_end,
        y_start=y_start,
        y_end=y_end,
        color_map=None
    )

In [None]:
# Another crop window; `sample`, `scalefactor` and `output_folder` are the
# values left by earlier cells.
y_start, y_end = 1500, 1800
x_start, x_end = 4300, 4700

# One cropped figure per entry of adata_integrate.uns['cell_types']
for cat in adata_integrate.uns['cell_types']:
    print(cat)
    plot_proportion_on_histology_crop(
        adata=adata_integrate[adata_integrate.obs['sample']==sample],
        cat=cat,
        his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
        scalefactors=scalefactor,
        size=100,
        output_folder=output_folder,
        sample_id=sample,
        x_start=x_start,
        x_end=x_end,
        y_start=y_start,
        y_end=y_end,
        color_map=None
    )

In [None]:
# Another crop window; `sample`, `scalefactor` and `output_folder` are the
# values left by earlier cells.
y_start, y_end = 1400, 1800
x_start, x_end = 1700, 2300

# One cropped figure per entry of adata_integrate.uns['cell_types']
for cat in adata_integrate.uns['cell_types']:
    print(cat)
    plot_proportion_on_histology_crop(
        adata=adata_integrate[adata_integrate.obs['sample']==sample],
        cat=cat,
        his_loc=f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',
        scalefactors=scalefactor,
        size=100,
        output_folder=output_folder,
        sample_id=sample,
        x_start=x_start,
        x_end=x_end,
        y_start=y_start,
        y_end=y_end,
        color_map=None
    )

In [None]:
def plot_separate_colorbar(
    vmin=0,
    vmax=0.2,
    cmap='plasma',
    label=None,
    orientation='vertical',
    output_folder=None,
    sample_id=None,
    cat=None,
    fontsize=20,
    ticksize=16,
    figsize=(1.2, 4)
):
    """
    Draw a standalone colorbar, optionally save it as PNG/PDF/SVG, and show it.

    Parameters:
        vmin: Minimum value of the color scale.
        vmax: Maximum value of the color scale.
        cmap: Colormap to use (default: 'plasma').
        label: Label for the colorbar; omitted when None or empty.
        orientation: 'vertical' or 'horizontal'. For 'horizontal' the two
            entries of `figsize` are swapped.
        output_folder: Directory to save the images in. When None (the
            default) nothing is written and the colorbar is only displayed.
        sample_id: Sample ID used in the output file names.
        cat: Category name used in the output file names.
        fontsize: Font size for the colorbar label.
        ticksize: Font size for the tick labels.
        figsize: Figure size (width, height) for the vertical layout.

    Returns:
        None. The figure is closed after being shown.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    import os

    # figsize is specified for the vertical layout; flip it for a horizontal bar.
    if orientation == 'horizontal':
        figsize = (figsize[1], figsize[0])

    fig, ax = plt.subplots(figsize=figsize)

    # A ScalarMappable with an empty array is enough to drive a colorbar
    # without plotting any data.
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    # The colorbar fills the only axes of the figure.
    cbar = fig.colorbar(sm, cax=ax, orientation=orientation)

    if label:
        cbar.set_label(label, fontsize=fontsize)
    cbar.ax.tick_params(labelsize=ticksize)

    fig.tight_layout()

    # Saving is optional: with the default output_folder=None the original
    # code crashed in os.path.join before anything was displayed.
    if output_folder is not None:
        for filename in (
            f"{sample_id}_{cat}_colorbar.png",
            f"{sample_id}_{cat}_colorbar.pdf",
            f"{sample_id}_{cat}_colorbar.svg",
        ):
            fig.savefig(os.path.join(output_folder, filename), bbox_inches='tight', dpi=300)

    plt.show()
    plt.close(fig)



import matplotlib as mpl
# Small fonts and editable SVG text ('svg.fonttype': 'none') for the
# publication-sized colorbar produced below.
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,   # colorbar tick label size
    'ytick.labelsize': 8,   # colorbar tick label size
    'axes.linewidth': 1,
})

# Unlabelled 0–0.2 'plasma' colorbar; cat='' leaves the category part of the
# output file name empty.
plot_separate_colorbar(
    vmin=0,
    vmax=0.2,
    cmap='plasma',
    label='',
    output_folder=output_folder,
    sample_id=sample_id,
    cat='',
)

In [None]:
def plot_dual_violin(data1, data2, y1_label, y2_label,y_label, title=None, sample_id=None, figsize=(10,6)):
    """
    Draw a split violin plot comparing two distributions.

    Both distributions share a single x position (`sample_id`); the first is
    drawn in blue ('#1f77b4') and the second in orange ('#ff7f0e').

    Args:
        data1: Array-like, values of the first distribution (blue half).
        data2: Array-like, values of the second distribution (orange half).
        y1_label: String, legend label for `data1`.
        y2_label: String, legend label for `data2`.
        y_label: String, label of the y axis.
        title: String, optional plot title. When given, the two-sided
            Mann-Whitney U p-value for data1 vs data2 is appended on a
            second line.
        sample_id: String, sample identifier used as the x tick label.
        figsize: Tuple, figure dimensions.
    Returns:
        tuple: (figure, (ax, ax)) — the same axis object twice, kept for
        callers that unpack two axes.
    """
    # Local imports so the function does not depend on which notebook cells
    # have already been executed.
    import numpy as np
    import pandas as pd
    from scipy import stats

    fig, ax = plt.subplots(figsize=figsize)

    # Long-format table expected by seaborn: one row per observation.
    df = pd.DataFrame({
        'value': np.concatenate([data1, data2]),
        'group': np.concatenate([
            np.full(len(data1), sample_id),
            np.full(len(data2), sample_id)
        ]),
        'category': np.concatenate([
            np.full(len(data1), y1_label),
            np.full(len(data2), y2_label)
        ])
    })

    custom_palette = {y1_label: '#1f77b4', y2_label: '#ff7f0e'}

    sns.violinplot(
        data=df,
        x='group',
        y='value',
        hue='category',
        split=True,
        inner='quart',
        fill=False,
        ax=ax,
        palette=custom_palette
    )

    # The p-value is only ever shown in the title, so only compute it then.
    if title:
        _, pvalue = stats.mannwhitneyu(data1, data2, alternative='two-sided')
        ax.set_title(f'{title}\np-value = {pvalue:.2e}')

    ax.set_xlabel('')
    ax.set_ylabel(y_label)
    # Legend outside the axes so it does not cover the violins.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()

    return fig, (ax, ax)

In [None]:
# Display hub_stem_list (not defined in this section; presumably created in an earlier cell).
hub_stem_list

In [None]:
# List the distinct hub_celltype labels present in adata_integrate.obs.
adata_integrate.obs['hub_celltype'].unique()

# Figure 6 E Cell type proportion

In [None]:
# List the distinct hub_celltype labels before the fibroblast filter in the next cell.
adata_integrate.obs['hub_celltype'].unique()

In [None]:
# Keep only observations whose hub_celltype label does not mention "fibroblast"
# (case-insensitive substring match). na=False treats missing labels as
# non-matches, so the mask stays boolean and `~` cannot fail on NaN entries.
is_fibroblast = adata_integrate.obs['hub_celltype'].str.contains('Fibroblast', case=False, na=False)
adata_filtered = adata_integrate[~is_fibroblast].copy()

# Alternative with exact label matches instead of a substring search:
# adata_filtered = adata_integrate[~adata_integrate.obs['hub_celltype'].isin(['Fibroblast', 'fibroblasts', 'Hub Fibroblast'])].copy()

# Verify the removal: number of distinct labels left after filtering
print(len(adata_filtered.obs['hub_celltype'].unique()))

In [None]:
# UMAP of the fibroblast-filtered data, colored by obs['SampleID_Grade'].
sc.pl.umap(adata_filtered, color='SampleID_Grade')

In [None]:
# UMAP of the fibroblast-filtered data, colored by obs['hub_celltype'].
sc.pl.umap(adata_filtered, color='hub_celltype')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from scipy.spatial import distance

# Create a new column in adata_integrate to store the island information
#adata_filtered.obs['sample_island'] = adata_filtered.obs['sample'].astype(str)



In [None]:
import matplotlib as mpl
# Publication styling: 8 pt text everywhere, text kept as text in SVG output,
# and 1 pt axes lines.
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,   # colorbar tick label size
    'ytick.labelsize': 8,   # colorbar tick label size
    'axes.linewidth': 1,
})

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
from scipy import stats
from copy import deepcopy
from statannotations.Annotator import Annotator
adata_filtered.obs['SampleID_Grade'] = adata_filtered.obs['SampleID_Grade'].astype('category')
adata_filtered.obs['hub_celltype'] = adata_filtered.obs['hub_celltype'].astype('category')
# obs_1 identifies the sample (its name also carries the grade), obs_2 the cell type.
obs_1 = 'SampleID_Grade'
obs_2 = 'hub_celltype'
num = 1

# Both grades below 'severe' are pooled into one box.
palette = { 'mild/no': '#005A8F', 'severe': '#B85000'}

for adata in [adata_filtered]:
    obs2_clusters = adata.obs[obs_2].cat.categories.tolist()
    obs1_clusters = adata.obs[obs_1].cat.categories.tolist()

    # Contingency table counts[cell type, sample], built from the categorical
    # codes in one vectorised pass instead of a Python loop with list.index()
    # per observation. Code -1 marks a missing label and is skipped.
    obs2_codes = adata.obs[obs_2].cat.codes.to_numpy()
    obs1_codes = adata.obs[obs_1].cat.codes.to_numpy()
    labelled = (obs2_codes >= 0) & (obs1_codes >= 0)
    counts = np.zeros((len(obs2_clusters), len(obs1_clusters)))
    np.add.at(counts, (obs2_codes[labelled], obs1_codes[labelled]), 1)

    # 0/0 for an empty sample or cell type yields NaN; those entries are
    # dropped again before plotting.
    with np.errstate(divide='ignore', invalid='ignore'):
        # (cell type, sample): each sample column sums to 1
        obs2_to_obs1_array = counts / counts.sum(axis=0)
        # (sample, cell type): each cell-type column sums to 1
        obs1_to_obs2_array = (counts / counts.sum(axis=1, keepdims=True)).T

    df_obs2 = pd.DataFrame(obs2_to_obs1_array, index=obs2_clusters, columns=obs1_clusters)
    df_obs2 = df_obs2.reindex(columns=sorted(obs1_clusters))

    # Cell-type composition of every sample.
    df_obs2.T.plot(kind="bar", stacked=True, color=adata.uns[obs_2 + '_colors'])
    plt.legend(bbox_to_anchor=(1.01, 0.5), loc='center left')
    plt.show()

    # Sample composition of every cell type.
    df_obs1 = pd.DataFrame(obs1_to_obs2_array, index=obs1_clusters, columns=obs2_clusters)
    df_obs1.T.plot(kind="bar", stacked=True, color=adata.uns[obs_1 + '_colors'])
    plt.legend(bbox_to_anchor=(1.01, 0.5), loc='center left')
    plt.show()

    print(num)
    num += 1
    p_list = []

    # Samples are assigned to a grade by substring match on the column name.
    grade_columns = {
        grade: df_obs2.columns[df_obs2.columns.str.contains(grade)]
        for grade in ('Severe', 'Mild', 'ND')
    }

    for i in adata.obs[obs_2].unique():
        # Per-sample proportions of cell type `i`, split by grade.
        proportions = df_obs2.loc[i]
        severe = proportions[grade_columns['Severe']].dropna().values
        mild = proportions[grade_columns['Mild']].dropna().values
        nd = proportions[grade_columns['ND']].dropna().values

        # 'ND' and 'Mild' samples are pooled into the 'mild/no' group.
        # ignore_index avoids duplicate row labels in the plotted frame.
        result = pd.concat(
            [
                pd.DataFrame({'Value': nd, 'Grade': 'mild/no'}),
                pd.DataFrame({'Value': mild, 'Grade': 'mild/no'}),
                pd.DataFrame({'Value': severe, 'Grade': 'severe'}),
            ],
            ignore_index=True,
        )

        fig, ax = plt.subplots(figsize=(2, 1.5), dpi=300)
        sns.boxplot(x='Grade', y='Value', data=result, ax=ax, palette=palette, order=['mild/no', 'severe'],showfliers=False)
        sns.stripplot(x='Grade', y='Value', data=result, ax=ax, color='black', alpha=0.5, jitter=True, order=['mild/no', 'severe'])

        ax.spines[['right', 'top']].set_visible(False)
        ax.set_xlabel('Grade')
        ax.set_ylabel('Cell type proportion')
        ax.set_title(i)
        ax.get_xaxis().tick_bottom()
        ax.get_yaxis().tick_left()

        # Independent two-sample t-test between the pooled groups, drawn on the axes.
        pairs = [('mild/no', 'severe')]
        annotator = Annotator(ax, pairs, data=result, x='Grade', y='Value')
        annotator.configure(test='t-test_ind', text_format='simple')
        annotator.apply_and_annotate()

        plt.tight_layout()

        # Save figure with 300 DPI
        plt.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/fig_{i}_boxplot_300dpi.png', dpi=300, transparent=True)
        plt.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/fig_{i}_boxplot_300dpi.pdf', dpi=300, transparent=True)
        plt.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/fig_{i}_boxplot_300dpi.svg', dpi=300, transparent=True)

        plt.show()
        # One 300-dpi figure per cell type: release it so figures do not pile up in memory.
        plt.close(fig)


# Figure 6 G: DEGs in CD8+ effector T cells

In [None]:
# Subset to the CD8+ effector T cells. The explicit .copy() turns the AnnData
# view into an independent object, because the following cells normalise it
# and write to its .uns in place.
# NOTE(review): the label contains two spaces ('T  cells') — confirm it matches obs['hub_celltype'] exactly.
cd8_effector = adata_integrate[adata_integrate.obs['hub_celltype']== 'CD8+ Effector T  cells'].copy()

In [None]:
# Library-size normalisation followed by log1p on cd8_effector.
# NOTE(review): neither call assigns a result, so cd8_effector is presumably modified in
# place; re-running this cell would then normalise and log-transform a second time.
# NOTE(review): the scanpy docs mark normalize_per_cell as deprecated in favour of
# normalize_total — confirm for the installed version.
sc.pp.normalize_per_cell(cd8_effector)
sc.pp.log1p(cd8_effector)


In [None]:
# Wilcoxon differential-expression test between the levels of obs['grade'], using the
# values in .X (use_raw=False); then pull the result table for the 'ND' group.
sc.tl.rank_genes_groups(cd8_effector, groupby='grade', use_raw=False, method='wilcoxon')
degs_df = sc.get.rank_genes_groups_df(cd8_effector, group = 'ND')

In [None]:
# Volcano plot of the 'ND' DE table held in degs_df.
# Genes with |log fold change| >= 10 are dropped first.
degs_df = degs_df.loc[np.abs(degs_df.logfoldchanges) < 10]

crit_pval = 0.05
crit_l2fc = 0.5
# "Non-significant" = small effect AND large adjusted p-value;
# "significant" = large effect AND small adjusted p-value.
nonsig_degs_df = degs_df.loc[((np.abs(degs_df['logfoldchanges']) <= crit_l2fc) &
                           (degs_df['pvals_adj'] >= crit_pval))]
sig_degs_df = degs_df.loc[((np.abs(degs_df['logfoldchanges']) > crit_l2fc) &
                           (degs_df['pvals_adj'] < crit_pval))]

# Genes to highlight and label; currently every significant gene.
marker_intersect = sig_degs_df['names']
highlighted_degs_df = sig_degs_df.loc[sig_degs_df['names'].isin(marker_intersect)]

# -log10(adjusted p) for every gene. Adjusted p-values of exactly 0 would map
# to +inf and drop out of the plot, so they are floored at the smallest
# non-zero adjusted p-value in the table.
nonzero_pvals = degs_df.loc[degs_df['pvals_adj'] > 0, 'pvals_adj']
pval_floor = nonzero_pvals.min() if len(nonzero_pvals) else np.finfo(float).tiny
neg_log10_padj = -np.log10(degs_df['pvals_adj'].clip(lower=pval_floor))

fig, ax = plt.subplots(figsize=(3,3), dpi=300)

# Layers from back to front: all genes, non-significant, significant, highlighted.
ax.scatter(degs_df['logfoldchanges'],
            neg_log10_padj, c='gray',
            s=3)

ax.scatter(nonsig_degs_df['logfoldchanges'],
            neg_log10_padj.loc[nonsig_degs_df.index],
            s=3)

ax.scatter(sig_degs_df['logfoldchanges'],
            neg_log10_padj.loc[sig_degs_df.index],
            c='g',
            s=3)

ax.scatter(highlighted_degs_df['logfoldchanges'],
            neg_log10_padj.loc[highlighted_degs_df.index],
            c='r',
            s=3)

# Significance thresholds.
ax.axvline(x=crit_l2fc, c='gray', linestyle='--', linewidth=0.5)
ax.axvline(x=-crit_l2fc, c='gray', linestyle='--', linewidth=0.5)
ax.axhline(y=-np.log10(crit_pval), c='gray', linestyle='--', linewidth=0.5)

oxphos_lfcs = highlighted_degs_df['logfoldchanges']
oxphos_pvals = highlighted_degs_df['pvals_adj']
oxphos_genes = highlighted_degs_df['names']


#ax.set_xlim(-2, 2)  # Adjust these limits as needed
#ax.set_ylim(-1, 30)  # Adjust these limits as needed


# Label each highlighted gene slightly up and to the right of its point.
for lfc, neg_log_p, gene in zip(oxphos_lfcs, neg_log10_padj.loc[highlighted_degs_df.index], oxphos_genes):
    ax.text(lfc+0.1, neg_log_p+0.5, gene, fontsize=3)

ax.set_xlabel('Log2FC', fontsize=12)  # Adjust fontsize as needed
ax.set_ylabel('-log10(adj.pvalue)', fontsize=12)  # Adjust fontsize as needed
ax.set_title('ND vs Rest', fontsize=12)
ax.spines[['right', 'top']].set_visible(False)
plt.show()

print(marker_intersect)

In [None]:
# Volcano plot of the 'Severe' DE table.
degs_df = sc.get.rank_genes_groups_df(cd8_effector, group = 'Severe')
# Genes with |log fold change| >= 10 are dropped first.
degs_df = degs_df.loc[np.abs(degs_df.logfoldchanges) < 10]

crit_pval = 0.05
crit_l2fc = 0.5
# "Non-significant" = small effect AND large adjusted p-value;
# "significant" = large effect AND small adjusted p-value.
nonsig_degs_df = degs_df.loc[((np.abs(degs_df['logfoldchanges']) <= crit_l2fc) &
                           (degs_df['pvals_adj'] >= crit_pval))]
sig_degs_df = degs_df.loc[((np.abs(degs_df['logfoldchanges']) > crit_l2fc) &
                           (degs_df['pvals_adj'] < crit_pval))]

# Genes to highlight and label; currently every significant gene.
marker_intersect = sig_degs_df['names']
highlighted_degs_df = sig_degs_df.loc[sig_degs_df['names'].isin(marker_intersect)]

# -log10(adjusted p) for every gene. Adjusted p-values of exactly 0 would map
# to +inf and drop out of the plot, so they are floored at the smallest
# non-zero adjusted p-value in the table.
nonzero_pvals = degs_df.loc[degs_df['pvals_adj'] > 0, 'pvals_adj']
pval_floor = nonzero_pvals.min() if len(nonzero_pvals) else np.finfo(float).tiny
neg_log10_padj = -np.log10(degs_df['pvals_adj'].clip(lower=pval_floor))

fig, ax = plt.subplots(figsize=(3,3), dpi=300)

# Layers from back to front: all genes, non-significant, significant, highlighted.
ax.scatter(degs_df['logfoldchanges'],
            neg_log10_padj, c='gray',
            s=3)

ax.scatter(nonsig_degs_df['logfoldchanges'],
            neg_log10_padj.loc[nonsig_degs_df.index],
            s=3)

ax.scatter(sig_degs_df['logfoldchanges'],
            neg_log10_padj.loc[sig_degs_df.index],
            c='g',
            s=3)

ax.scatter(highlighted_degs_df['logfoldchanges'],
            neg_log10_padj.loc[highlighted_degs_df.index],
            c='r',
            s=3)

# Significance thresholds.
ax.axvline(x=crit_l2fc, c='gray', linestyle='--', linewidth=0.5)
ax.axvline(x=-crit_l2fc, c='gray', linestyle='--', linewidth=0.5)
ax.axhline(y=-np.log10(crit_pval), c='gray', linestyle='--', linewidth=0.5)

oxphos_lfcs = highlighted_degs_df['logfoldchanges']
oxphos_pvals = highlighted_degs_df['pvals_adj']
oxphos_genes = highlighted_degs_df['names']


#ax.set_xlim(-2, 2)  # Adjust these limits as needed
#ax.set_ylim(-1, 30)  # Adjust these limits as needed


# Label each highlighted gene slightly up and to the right of its point.
for lfc, neg_log_p, gene in zip(oxphos_lfcs, neg_log10_padj.loc[highlighted_degs_df.index], oxphos_genes):
    ax.text(lfc+0.1, neg_log_p+0.5, gene, fontsize=5)

ax.set_xlabel('Log2FC', fontsize=12)  # Adjust fontsize as needed
ax.set_ylabel('-log10(adj.pvalue)', fontsize=12)  # Adjust fontsize as needed
ax.set_title('Severe vs Rest', fontsize=12)
ax.spines[['right', 'top']].set_visible(False)
plt.show()

print(marker_intersect)

In [None]:
def plot_marker_heatmap(adata, gene_list, groupby='Severe', figsize=(2, 1.5), group_colors=None,
                        heatmap_groupby=None):
    """
    Plot a heatmap of z-scored expression values for marker genes.

    Parameters:
    -----------
    adata : AnnData
        The annotated data matrix. It is not modified; scaling happens on a
        copy restricted to `gene_list`.
    gene_list : list
        List of gene names to include in the heatmap. Genes missing from
        `adata.var_names` are reported and skipped.
    groupby : str
        Column in adata.obs that must exist, that the dendrogram is computed
        for, and that `group_colors` refers to.
    figsize : tuple
        Figure size (width, height).
    group_colors : dict, optional
        Dictionary mapping every category of `groupby` to a color. Applied
        only when all categories are covered; otherwise a warning is printed
        and the default colors are kept.
    heatmap_groupby : str or list of str, optional
        Column(s) in adata.obs passed to `sc.pl.heatmap` to group the cells.
        Defaults to ['sample', 'patient', 'sample_type', 'grade'], the
        grouping this function previously hard-coded. Pass `groupby` here to
        group the heatmap by the same column as the dendrogram.

    Returns:
    --------
    fig : matplotlib.figure.Figure or None
        The figure containing the heatmap, or None when no requested gene is
        present or `groupby` is not a column of adata.obs.
    """
    # Drop genes that are not in the dataset, telling the user which ones.
    missing_genes = [gene for gene in gene_list if gene not in adata.var_names]
    if missing_genes:
        print(f"Warning: The following genes are not in the dataset: {missing_genes}")
        gene_list = [gene for gene in gene_list if gene in adata.var_names]

    if len(gene_list) == 0:
        print("No valid genes to plot!")
        return None

    # Check the grouping column before doing any copying or scaling.
    if groupby not in adata.obs.columns:
        print(f"Error: '{groupby}' column not found in adata.obs")
        return None

    if heatmap_groupby is None:
        heatmap_groupby = ['sample', 'patient', 'sample_type', 'grade']

    # Work on a copy holding only the genes of interest, z-scored per gene
    # and clipped at 4.
    adata_subset = adata[:, gene_list].copy()
    sc.pp.scale(adata_subset, max_value=4)

    if group_colors is not None and isinstance(group_colors, dict):
        grouping = adata_subset.obs[groupby]
        categories = grouping.cat.categories if hasattr(grouping, 'cat') else grouping.unique()
        uncolored = [cat for cat in categories if cat not in group_colors]
        if uncolored:
            # A color list shorter than the category list would be mismatched
            # against the categories, so it is only stored when complete.
            print(f"Warning: no color given for {uncolored}; keeping default colors for '{groupby}'")
        elif len(categories) > 0:
            adata_subset.uns[f'{groupby}_colors'] = [group_colors[cat] for cat in categories]

    # Start from an empty current figure.
    plt.clf()

    # NOTE(review): the dendrogram is computed for `groupby`, while the heatmap
    # is grouped by `heatmap_groupby`; when the two differ, scanpy presumably
    # derives its own dendrogram for the heatmap grouping — confirm.
    sc.tl.dendrogram(adata_subset, groupby=groupby)

    sc.pl.heatmap(adata_subset,
                  var_names=gene_list,
                  groupby=heatmap_groupby,
                  use_raw=False,
                  dendrogram=True,
                  swap_axes=True,
                  show_gene_labels=True,
                  cmap='RdBu_r',
                  vmin=-2,
                  vmax=2,
                  figsize=figsize,
                  show=False)

    # sc.pl.heatmap draws on the current figure; return it so callers can save it.
    fig = plt.gcf()
    plt.tight_layout()

    return fig

In [None]:
# Heatmap of the genes in marker_intersect (as last assigned above) for the CD8+ effector cells.
# NOTE(review): plot_marker_heatmap returns None when no gene is found or the groupby column
# is missing; the savefig calls below would then fail.
fig = plot_marker_heatmap(cd8_effector, marker_intersect.to_list(), groupby='sample', figsize = (4,5))

# Save figure with 300 DPI
fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/cd8effector_gene_sig.png', dpi=300, bbox_inches='tight')
fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/cd8effector_gene_sig.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Colors for the three levels of obs['grade'].
# NOTE(review): scanpy presumably pairs these with the categories in category order — confirm which grade gets which color.
cd8_effector.uns['grade_colors']  =  [ '#005A8F','#09BB8C', '#B85000']

In [None]:
# Overlap between the significant DEGs (sig_degs_df as last assigned above)
# and the CD8+ effector T cell signature in gene_sig.
genes_from_degs = set(sig_degs_df['names'])
cd8_effector_genes = set(gene_sig['CD8+ Effector T cells'])

# Genes present in both sets
common_genes = genes_from_degs & cd8_effector_genes

# List form for downstream use
common_genes_list = list(common_genes)
print(common_genes_list)

In [None]:
import matplotlib as mpl
# Figure styling for the marker heatmap: diverging default colormap, 8 pt
# text, and text kept as text in SVG output.
mpl.rcParams.update({
    'image.cmap': 'RdBu_r',   # Or your preferred colormap
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,     # Colorbar tick label size
    'ytick.labelsize': 8,     # Colorbar tick label size
})

# Cytotoxicity-related genes shown in the heatmap, in display order.
gene_list = [
    'NKG7', 'CST7',
    'GZMK', 'GZMH', 'GZMB', 'GZMM', 'GZMA',
    'PRF1', 'LYST', 'APOBEC3G',
]

groupby ='sample'

fig = plot_marker_heatmap(
    cd8_effector,
    gene_list,
    groupby=groupby,
    figsize = (5,2.5),
)

# Save figure with 300 DPI
fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{groupby}cd8effector_gene_deg.png', dpi=300, bbox_inches='tight', transparent = True)
fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{groupby}cd8effector_gene_deg.pdf', dpi=300, bbox_inches='tight', transparent = True)
fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{groupby}cd8effector_gene_deg.svg', dpi=300, bbox_inches='tight', transparent = True)

In [None]:
import scipy.sparse as sparse
import scipy.cluster.hierarchy as sch
# Display order of the five hierarchical clusters for each sample. These
# orders were chosen by eye; samples without an entry use the default order.
# The lookup compares sample names by value; the previous `sample is 'SLV13'`
# identity checks did not reliably match names read from a DataFrame.
cluster_order_by_sample = {
    'SLV13': ['5','4', '3','1', '2'],
    'SLV12': ['3','2','5','1','4'],
    'SLV18': ['3','4','5','1','2'],
    'SLV14': ['1','5','3','2', '4'],
    'SLV15': ['2','5','3','1', '4'],
    'SLV16': ['2','5','3','1', '4'],
    'SLV17': ['3','5', '2','1','4'],
}
default_cluster_order = ['3', '5', '4', '1','2']

# Loop-invariant setup: the genes available in the data and the figure settings.
valid_genes = [gene for gene in gene_list if gene in cd8_effector.var_names]
sc.settings.set_figure_params(dpi=300, frameon=False)

for sample in meta_info['sample']:
    print(sample)
    # Cells of this sample, restricted to the genes that are plotted. Scaling
    # is per gene, so scaling only these genes gives the same values as
    # scaling every gene and selecting afterwards, at a fraction of the cost.
    cells_subset = cd8_effector[cd8_effector.obs['sample']==sample, valid_genes].copy()
    print(f"Found {len(valid_genes)} genes from your list in the dataset")

    # Hierarchical clustering needs at least two observations.
    if cells_subset.n_obs < 2:
        print(f"Skipping {sample}: fewer than 2 cells, nothing to cluster")
        continue

    sc.pp.scale(cells_subset, max_value=4)

    gene_expr = cells_subset.X
    print(f"Expression matrix shape: {gene_expr.shape}")
    print(f"Data type: {type(gene_expr)}")

    # linkage() needs a dense, finite matrix.
    gene_expr_dense = gene_expr.toarray() if sparse.issparse(gene_expr) else gene_expr
    if np.isnan(gene_expr_dense).any() or np.isinf(gene_expr_dense).any():
        print("Warning: Matrix contains NaN or infinite values")
        gene_expr_dense = np.nan_to_num(gene_expr_dense)

    print(f"Dense matrix shape: {gene_expr_dense.shape}")

    # Cluster the cells on their scaled expression profile and cut the tree
    # into at most five clusters.
    cell_linkage = sch.linkage(gene_expr_dense, method='average', metric='correlation')
    cell_clusters = sch.fcluster(cell_linkage, 5, criterion='maxclust')
    cells_subset.obs['gene_clusters'] = cell_clusters.astype(str)

    # Ordered categorical so the heatmap draws the clusters in the chosen order.
    desired_cluster_order = cluster_order_by_sample.get(sample, default_cluster_order)
    cells_subset.obs['gene_clusters_ordered'] = cells_subset.obs['gene_clusters'].astype(
        pd.CategoricalDtype(categories=desired_cluster_order, ordered=True)
    )

    # Figure width is proportional to this sample's share of all CD8+ effector cells.
    sc.pl.heatmap(cells_subset,
                  var_names=valid_genes,
                  groupby='gene_clusters_ordered',
                  use_raw=False,
                  swap_axes=True,
                  #show_gene_labels=True,
                  cmap='RdBu_r',
                  vmin=-2,
                  vmax=2,
                  figsize=((cells_subset.shape[0]/cd8_effector.shape[0])*4.5,1.3),
                  show=False)  # Important: don't show yet

    fig = plt.gcf()
    # Remove the standard margins
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    # Keep only the largest axes (the heatmap itself); group bars, colorbar
    # and other helper axes are removed.
    axes = fig.get_axes()
    if len(axes) > 1:
        def _axes_area(candidate):
            pos = candidate.get_position()
            return pos.width * pos.height

        main_ax = max(axes, key=_axes_area)
        for ax in axes:
            if ax is not main_ax:
                ax.remove()

    # Strip every decoration from what is left: spines, ticks, lines, grid,
    # collection borders and the gene-name tick labels.
    for ax in fig.get_axes():
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(size=0, width=0, which='both')
        for line in ax.get_lines():
            line.set_visible(False)
        ax.grid(False)
        for collection in ax.collections:
            collection.set_linewidth(0)
        ax.set_yticklabels([])
        ax.tick_params(axis='y', which='both', length=0, labelleft=False)

    plt.tight_layout()
    # Save figure with 300 DPI
    fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}cd8effector_gene_deg.png', dpi=300, bbox_inches='tight', transparent = True, pad_inches = 0)
    fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}cd8effector_gene_deg.pdf', dpi=300, bbox_inches='tight', transparent = True, pad_inches = 0)
    fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}cd8effector_gene_deg.svg', dpi=300, bbox_inches='tight', transparent = True, pad_inches = 0)
    plt.show()
    # One figure per sample: release it so figures do not accumulate.
    plt.close(fig)

In [None]:
def plot_marker_heatmap_color_bar(adata, gene_list, groupby='Severe', figsize=(12, 10), 
                    group_colors=None, category_colors=None, transpose=True, 
                    show_cell_labels=False):
    """
    Draw the annotation color bars, group dendrogram and legend that accompany
    a marker-gene heatmap.

    The expression heatmap body itself is NOT rendered by this function: the
    axes reserved for it are created (the color bars share an axis with them)
    but are left empty. Cells are ordered by the category order of the scanpy
    dendrogram computed on `groupby`, and one color bar is drawn for each of the
    obs columns 'sample', 'patient', 'sample_type' and 'grade' that exists
    (falling back to `groupby` alone when none of them is present).

    Parameters:
    -----------
    adata : AnnData
        The annotated data matrix. It is not modified; a gene-subset copy is
        scaled and reordered internally.
    gene_list : list
        Gene names used to subset the data. Genes missing from
        `adata.var_names` are reported and dropped.
    groupby : str
        Column in adata.obs used to compute the dendrogram and to order cells.
        If it is missing, the first available annotation column is used.
    figsize : tuple
        Figure size (width, height) in inches. Default is (12, 10).
    group_colors : dict, optional
        Mapping category -> color, applied to the color bar whose column is
        `groupby`.
    category_colors : dict, optional
        Dictionary of dictionaries mapping category values to colors for the
        other annotation columns.
        Example: {'grade': {'Grade 1': 'green', 'Grade 2': 'orange'}, 
                  'sample_type': {'Tumor': 'red', 'Normal': 'blue'}}
    transpose : bool, optional
        If True, cells run along the x-axis (dendrogram on top, horizontal
        color bars). If False, cells run along the y-axis (dendrogram on the
        left, vertical color bars on the right).
    show_cell_labels : bool, optional
        Accepted for interface compatibility. It only concerned the heatmap
        body, which is not drawn, so it currently has no effect.

    Returns:
    --------
    fig : matplotlib.figure.Figure or None
        The figure with dendrogram, color bars and legend, or None when no
        requested gene or no usable annotation column is available.
    """
    import scanpy as sc
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from matplotlib import gridspec
    from matplotlib.colors import ListedColormap
    from matplotlib.lines import Line2D
    from scipy.cluster.hierarchy import dendrogram as scipy_dendrogram

    if category_colors is None:
        category_colors = {}

    # Drop genes that are absent from the dataset, keeping the caller's order
    missing_genes = [gene for gene in gene_list if gene not in adata.var_names]
    if missing_genes:
        print(f"Warning: The following genes are not in the dataset: {missing_genes}")
        gene_list = [gene for gene in gene_list if gene in adata.var_names]

    if len(gene_list) == 0:
        print("No valid genes to plot!")
        return None

    # Work on a gene-subset copy so the caller's object is never modified
    adata_subset = adata[:, gene_list].copy()

    # Z-score each gene; kept because the dendrogram below may be computed from X
    sc.pp.scale(adata_subset, max_value=4)

    # obs columns that get a color bar
    annotation_columns = ['sample', 'patient', 'sample_type', 'grade']
    valid_annotations = [col for col in annotation_columns if col in adata_subset.obs.columns]

    if not valid_annotations:
        print(f"Warning: None of the requested annotation columns {annotation_columns} found in adata.obs")
        valid_annotations = [groupby] if groupby in adata_subset.obs.columns else []

    if len(valid_annotations) == 0:
        print("Error: No valid grouping columns found in adata.obs")
        return None

    # valid_annotations is non-empty here, so a fallback groupby always exists
    if groupby not in adata_subset.obs.columns:
        groupby = valid_annotations[0]
        print(f"Warning: Specified groupby not found, using '{groupby}' instead")

    # Dendrogram over the groupby categories; its leaf order drives the cell order
    sc.tl.dendrogram(adata_subset, groupby=groupby)
    dendrogram_info = adata_subset.uns.get(f'dendrogram_{groupby}', {})

    ordered_categories = dendrogram_info.get('categories_ordered')
    if ordered_categories is not None:
        indices = []
        for cat in ordered_categories:
            indices.extend(np.where(adata_subset.obs[groupby] == cat)[0])
        adata_subset = adata_subset[indices].copy()

    fig = plt.figure(figsize=figsize)
    n_annotations = len(valid_annotations)

    if transpose:
        # Rows: dendrogram, one thin row per annotation, (empty) heatmap area
        gs = gridspec.GridSpec(
            n_annotations + 2,
            1,
            height_ratios=[0.1] + [0.05] * n_annotations + [0.8],
            hspace=0.02
        )
        ax_dendrogram = fig.add_subplot(gs[0])
        ax_heatmap = fig.add_subplot(gs[-1])
        # Color bars share the cell axis (x) with the heatmap area
        ax_annots = [
            (fig.add_subplot(gs[i + 1], sharex=ax_heatmap), annot)
            for i, annot in enumerate(valid_annotations)
        ]
    else:
        # Columns: dendrogram, (empty) heatmap area, one thin column per annotation
        gs = gridspec.GridSpec(
            1,
            n_annotations + 2,
            width_ratios=[0.1, 0.8] + [0.05] * n_annotations,
            wspace=0.02
        )
        ax_dendrogram = fig.add_subplot(gs[0, 0])
        ax_heatmap = fig.add_subplot(gs[0, 1])
        # Color bars share the cell axis (y) with the heatmap area
        ax_annots = [
            (fig.add_subplot(gs[0, i + 2], sharey=ax_heatmap), annot)
            for i, annot in enumerate(valid_annotations)
        ]

    if 'linkage' in dendrogram_info:
        # Leaf labels are suppressed: the axis is hidden below, and scipy expects
        # labels in observation order rather than in leaf order anyway.
        scipy_dendrogram(
            dendrogram_info['linkage'],
            ax=ax_dendrogram,
            no_labels=True,
            orientation='top' if transpose else 'left'
        )
    ax_dendrogram.set_axis_off()

    def _categories_and_colors(annot):
        """Return (sorted categories, colors) for one annotation column."""
        categories = sorted(adata_subset.obs[annot].unique())
        if annot == groupby and group_colors is not None:
            lookup = group_colors
        elif annot in category_colors:
            lookup = category_colors[annot]
        else:
            # No user colors: fall back to seaborn's tab10 palette
            palette = sns.color_palette("tab10", len(categories))
            return categories, [palette[i % 10] for i in range(len(categories))]
        # Categories missing from the lookup get matplotlib cycle colors C0..C9
        return categories, [lookup.get(cat, f"C{i % 10}") for i, cat in enumerate(categories)]

    # Resolve colors once so the color bars and the legend can never disagree
    annotation_styles = {annot: _categories_and_colors(annot) for annot in valid_annotations}

    for ax, annot in ax_annots:
        categories, colors = annotation_styles[annot]

        # Encode categories as integers 0..k-1 and draw them as a 1-pixel-wide heatmap
        cat_to_num = {cat: i for i, cat in enumerate(categories)}
        numeric_data = np.array([cat_to_num[x] for x in adata_subset.obs[annot].values])
        color_data = numeric_data.reshape(1, -1) if transpose else numeric_data.reshape(-1, 1)

        sns.heatmap(
            color_data,
            ax=ax,
            cmap=ListedColormap(colors),
            cbar=False,
            vmin=0,
            vmax=len(categories) - 1,
            xticklabels=False
        )

        # A single tick carrying the annotation name labels each bar
        if transpose:
            ax.set_xticks([])
            ax.set_yticks([0.5])
            ax.set_yticklabels([annot], rotation=0)
        else:
            ax.set_yticks([])
            ax.set_xticks([0.5])
            ax.set_xticklabels([annot], rotation=90)

    # One legend for all annotations: an invisible-handle header row per
    # annotation followed by its indented categories
    all_handles = []
    all_labels = []
    header_rows = set()
    for annot in valid_annotations:
        categories, colors = annotation_styles[annot]

        header_rows.add(len(all_labels))
        all_handles.append(Line2D([0], [0], color='none'))
        all_labels.append(f"{annot}")

        for cat, color in zip(categories, colors):
            all_handles.append(plt.Rectangle((0, 0), 1, 1, color=color))
            all_labels.append(f"  {cat}")  # Indent to show hierarchy

    # The legend lives in its own invisible axes along the bottom of the figure
    legend_ax = fig.add_axes([0.05, 0.02, 0.9, 0.1])  # [left, bottom, width, height]
    legend_ax.axis('off')

    legend = legend_ax.legend(
        all_handles,
        all_labels,
        loc='center',
        ncol=1,
        frameon=True,
        fontsize='small'
    )

    # Emphasize the annotation header rows
    for i, text in enumerate(legend.get_texts()):
        if i in header_rows:
            text.set_weight('bold')
            text.set_size('medium')

    fig.suptitle(f"Gene Expression Heatmap (grouped by {groupby})", y=0.98)

    # Leave room for the legend at the bottom, then tighten the rest
    fig.subplots_adjust(bottom=0.15)
    fig.tight_layout(rect=[0, 0.15, 0.95, 0.92])

    return fig

# Example usage: color bars for the CD8 effector marker heatmap.
# NOTE(review): `cd8_effector` and `gene_list` are not defined in this cell and
# must come from earlier cells.

# Grade colors
grade_colors = {'ND': '#09BB8C', 'Mild': '#005A8F', 'Severe': '#B85000'}

# Shared color panel used to build the per-category palettes below
new_colors = [
    '#3674B0',  # medium blue
    '#B3C6E5',  # light blue
    '#F18C31',  # orange
    '#F9CB9C',  # light orange/peach
    '#4CA743',  # green
    '#B6E39C',  # light green
    '#D73E3E',  # red
    '#F1BBBA'   # light pink/salmon
]

# Sample colors
sample_colors = {
    'SLV11': new_colors[0],
    'SLV13': new_colors[1],
    'SLV14': new_colors[2],
    'SLV12': new_colors[3],
    'SLV15': new_colors[4],
    'SLV18': new_colors[5],
    'SLV16': new_colors[6],
    'SLV17': new_colors[7]
    # Add more samples if needed
}

# Patient colors
patient_colors = {
    'C159': new_colors[1],
    'C162': new_colors[3],
    'C179': new_colors[5],
    'C98': new_colors[7],
    'ND001': '#E5C2FF'
    # Add more patients if needed
}

# Sample type colors
sample_type_colors = {
    'Antrum': new_colors[1],
    'Ascending_Colon': new_colors[3],
    'Rectum': new_colors[5],
    'Stomach': new_colors[7],
    'Stomach_Body': '#E5C2FF'
    # Add more sample types if needed
}

# One color lookup per annotation column drawn as a color bar
category_colors = {
    'sample': sample_colors,
    'patient': patient_colors,
    'sample_type': sample_type_colors,
    'grade': grade_colors
}

# Print the category_colors dictionary to verify
print(category_colors)

import matplotlib as mpl
mpl.rcParams['image.cmap'] = 'RdBu_r'  # Or your preferred colormap
mpl.rcParams.update({'font.size': 8, 'svg.fonttype': 'none'})
mpl.rcParams['axes.titlesize'] = 8
mpl.rcParams['xtick.labelsize'] = 8  # x-axis tick label size
mpl.rcParams['ytick.labelsize'] = 8  # y-axis tick label size

# The grouping column is named once so that the function call and the output
# file names cannot drift apart (the save paths previously interpolated a
# `groupby` variable that this cell never assigns).
color_bar_groupby = 'sample'

fig = plot_marker_heatmap_color_bar(
    cd8_effector, 
    gene_list=gene_list, 
    groupby=color_bar_groupby,  # Keep sample as primary groupby
    figsize=(5, 2.5),
    group_colors=sample_colors,  # For the primary groupby (sample)
    category_colors=category_colors,  # For other categories including grade
    transpose=True,
    show_cell_labels=False
)

# Save the figure at 300 DPI in every required format
spatial_figure_dir = '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures'
for ext in ('png', 'pdf', 'svg'):
    fig.savefig(
        f'{spatial_figure_dir}/{color_bar_groupby}cd8effector_gene_deg_color_bar.{ext}',
        dpi=300, bbox_inches='tight', transparent=True
    )

# Figure 6F: CD8 distance to stem across grade

In [None]:
import numpy as np
import pandas as pd
import scipy.spatial
import scipy.stats as stats
import anndata
import matplotlib.pyplot as plt
import seaborn as sns

def compute_nndistance(adata, cell_type1, cell_type2, sample_id):
    """
    Compute, for every `cell_type1` spot of one sample, the Euclidean distance
    to the nearest `cell_type2` spot of the same sample.

    Coordinates are read from `adata.obs['array_row']` / `adata.obs['array_col']`,
    so distances are expressed in those units (presumably spot-grid indices
    rather than pixels or microns -- TODO confirm). Spots are selected with
    positional boolean masks, which stays correct even when `adata.obs` has
    duplicated index labels (e.g. the same barcode in several samples).

    Parameters:
        adata: object exposing an `obs` DataFrame with the columns
            'array_row', 'array_col', 'hub_celltype' and 'sample'.
        cell_type1 (str): 'hub_celltype' value of the query spots.
        cell_type2 (str): 'hub_celltype' value of the target spots.
        sample_id: value of the 'sample' column to restrict the analysis to.

    Returns:
        np.ndarray: one nearest-neighbor distance per `cell_type1` spot.

    Raises:
        ValueError: if the sample has no `cell_type1` or no `cell_type2` spot.
    """
    obs = adata.obs
    coords = obs[['array_row', 'array_col']].to_numpy()

    # Positional masks: no label -> position round-trip, so duplicated obs
    # names cannot pull in spots that belong to another sample.
    in_sample = (obs['sample'] == sample_id).to_numpy()
    cell1_coords = coords[in_sample & (obs['hub_celltype'] == cell_type1).to_numpy()]
    cell2_coords = coords[in_sample & (obs['hub_celltype'] == cell_type2).to_numpy()]

    # Both groups are needed to define a nearest-neighbor distance
    if len(cell1_coords) == 0 or len(cell2_coords) == 0:
        raise ValueError(f"Not enough {cell_type1} or {cell_type2} cells in sample {sample_id}.")

    # KD-tree on the targets, queried with every source spot
    nnd, _ = scipy.spatial.KDTree(cell2_coords).query(cell1_coords)

    return nnd

# Function to compare NND between two samples
def compare_nnd_between_samples(nnd_sample1, nnd_sample2, label1, label2):
    """
    Compare the nearest neighbor distance (NND) distributions of two samples.

    Runs a two-sided Mann-Whitney U test (the Wilcoxon rank-sum test), computes
    Cohen's d with the root of the averaged sample variances as pooled standard
    deviation, and draws a box plot overlaid with a swarm plot and dashed mean
    lines. The plot is written to "nnd_comparison_{label1}_vs_{label2}.png" in
    the current working directory and then shown.

    Parameters:
        nnd_sample1 (np.array): Nearest neighbor distances for sample 1.
        nnd_sample2 (np.array): Nearest neighbor distances for sample 2.
        label1 (str): Label for sample 1 (e.g., 'SLV14').
        label2 (str): Label for sample 2 (e.g., 'SLV16').

    Returns:
        dict: test statistic, p-value, Cohen's d and the two sample means.
    """
    # Rank-based two-sided test between the two distance distributions
    rank_sum = stats.mannwhitneyu(nnd_sample1, nnd_sample2, alternative='two-sided')

    # Effect size: mean difference over the pooled (averaged-variance) std
    mean1 = np.mean(nnd_sample1)
    mean2 = np.mean(nnd_sample2)
    var1 = np.var(nnd_sample1, ddof=1)
    var2 = np.var(nnd_sample2, ddof=1)
    cohens_d = (mean1 - mean2) / np.sqrt((var1 + var2) / 2)

    # Long-format table: one row per distance, tagged with its sample label
    sample_column = [label1] * len(nnd_sample1)
    sample_column += [label2] * len(nnd_sample2)
    results_df = pd.DataFrame({
        "Sample": sample_column,
        "NND": np.concatenate([nnd_sample1, nnd_sample2])
    })

    plt.figure(figsize=(8, 6))
    sns.boxplot(x="Sample", y="NND", data=results_df, palette=["red", "blue"])
    sns.swarmplot(x="Sample", y="NND", data=results_df, color="black", alpha=0.6)
    plt.xlabel("GVHD Sample")
    plt.ylabel("Nearest Neighbor Distance (NND)")
    plt.title(f"NND Comparison: {label1} vs. {label2}")
    # One dashed horizontal line per sample mean, colored like its box
    for label, mean, line_color in ((label1, mean1, "red"), (label2, mean2, "blue")):
        plt.axhline(y=mean, color=line_color, linestyle="dashed", label=f"{label} Mean: {mean:.2f}")
    plt.legend()
    plt.savefig(f"nnd_comparison_{label1}_vs_{label2}.png", dpi=300)
    plt.show()

    summary = {
        "Wilcoxon Test Statistic": rank_sum.statistic,
        "Wilcoxon p-value": rank_sum.pvalue,
        "Cohen's d": cohens_d,
        f"Mean NND {label1}": mean1,
        f"Mean NND {label2}": mean2
    }
    return summary

In [None]:
# Ad-hoc inspection: display the 'scalefactors' entry stored under
# uns['spatial'] for one sample.
# NOTE(review): `sample_id` is not assigned in this cell (elsewhere it is only a
# function parameter / loop variable), so this line depends on kernel state
# left by an earlier execution -- confirm or replace with an explicit sample ID.
adata_integrate.uns['spatial'][sample_id]['scalefactors']

In [None]:
import numpy as np
import pandas as pd
import scipy.spatial
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from statsmodels.stats.multitest import multipletests

def compute_nndistance_multiple(adata, cell_type1, cell_type2, sample_ids):
    """
    Run `compute_nndistance` once per sample and collect the results.

    A sample for which `compute_nndistance` raises ValueError is reported with
    a printed warning and left out of the result. Distances are stored exactly
    as returned (no rescaling to physical units is applied).

    Parameters:
        adata: AnnData object
        cell_type1: str, source cell type
        cell_type2: str, target cell type
        sample_ids: list of str, sample IDs to analyze

    Returns:
        dict: sample ID -> NND array, for the samples that could be computed
    """
    distances_by_sample = {}
    for sid in sample_ids:
        try:
            distances_by_sample[sid] = compute_nndistance(adata, cell_type1, cell_type2, sid)
        except ValueError as err:
            # Skip this sample but keep processing the remaining ones
            print(f"Warning: {err}")
    return distances_by_sample

def compare_multiple_samples_nnd(nnd_dict):
    """
    Summarize and pairwise-compare NND distributions across samples.

    Every pair of samples is compared with a two-sided Mann-Whitney U test and
    Cohen's d (pooled std = root of the averaged sample variances); the
    p-values are then adjusted with the Benjamini-Hochberg procedure
    ('fdr_bh'). When fewer than two samples are supplied there is nothing to
    compare, and an empty comparison table with the usual columns is returned
    instead of failing.

    Parameters:
        nnd_dict: dict, mapping sample IDs to NND arrays

    Returns:
        tuple: (DataFrame with pairwise comparisons, DataFrame with one row of
        Mean / Median / Std / N per sample)
    """
    # Per-sample descriptive statistics (np.std default: population std)
    stats_dict = {
        sample: {
            'Mean': np.mean(nnd),
            'Median': np.median(nnd),
            'Std': np.std(nnd),
            'N': len(nnd)
        }
        for sample, nnd in nnd_dict.items()
    }
    stats_df = pd.DataFrame(stats_dict).T

    comparisons = []
    for (sample1, nnd1), (sample2, nnd2) in combinations(nnd_dict.items(), 2):
        # Mann-Whitney U test
        stat, pval = stats.mannwhitneyu(nnd1, nnd2, alternative='two-sided')

        # Cohen's d
        mean1, mean2 = np.mean(nnd1), np.mean(nnd2)
        std_pooled = np.sqrt((np.var(nnd1, ddof=1) + np.var(nnd2, ddof=1)) / 2)
        cohens_d = (mean1 - mean2) / std_pooled

        comparisons.append({
            'Sample1': sample1,
            'Sample2': sample2,
            'Mann-Whitney-U': stat,
            'p-value': pval,
            'Cohen\'s d': cohens_d
        })

    # Zero or one sample: no pairs exist, so there is no 'p-value' column to
    # correct. Return an empty table with the expected schema.
    if not comparisons:
        empty_comparison_df = pd.DataFrame(columns=[
            'Sample1', 'Sample2', 'Mann-Whitney-U', 'p-value', 'Cohen\'s d',
            'p-value-adjusted'
        ])
        return empty_comparison_df, stats_df

    comparison_df = pd.DataFrame(comparisons)

    # Apply multiple testing correction
    comparison_df['p-value-adjusted'] = multipletests(
        comparison_df['p-value'], method='fdr_bh'
    )[1]

    return comparison_df, stats_df

def visualize_multiple_nnd(nnd_dict, title="NND Comparison Across Samples", kind="violin", ylim=None):
    """
    Plot the NND distribution of each sample side by side.

    This single function replaces two definitions of the same name (a box-plot
    version that was silently shadowed by a violin-plot version). The default
    `kind="violin"` keeps the look of the definition that was actually in
    effect; `kind="box"` gives access to the box-plot variant.

    Parameters:
        nnd_dict: dict, mapping sample IDs to NND arrays. Samples are drawn in
            the dict's insertion order.
        title: str, accepted for backward compatibility; no title is drawn.
        kind: str, "violin" (default) or "box".
        ylim: optional (low, high) y-axis limits. Defaults to [0, 18] for
            violins and [0, 30] for boxes.

    Returns:
        matplotlib.figure.Figure: the newly created figure.

    Raises:
        ValueError: if `kind` is neither "violin" nor "box".
    """
    if kind not in ("violin", "box"):
        raise ValueError(f"kind must be 'violin' or 'box', got {kind!r}")

    # Long-format table: one row per (sample, distance)
    df = pd.DataFrame(
        [(sample, d) for sample, nnd in nnd_dict.items() for d in nnd],
        columns=['Sample', 'NND']
    )

    # Explicit x-axis order (first appearance), shared by the seaborn call and
    # by the mean markers so that both always refer to the same positions.
    sample_order = list(dict.fromkeys(df['Sample']))

    # NOTE(review): only four colors are listed; with more samples seaborn is
    # expected to reuse them -- confirm this is acceptable for the figure.
    my_color_map = ['purple', 'orange', 'teal', 'green']

    fig = plt.figure(figsize=(2.8, 1.8), dpi=900)
    sns.set_theme(style="ticks", font_scale=1)

    if kind == "box":
        sns.boxplot(
            data=df,
            x='Sample',
            y='NND',
            order=sample_order,
            fliersize=0,
            palette=my_color_map
        )
        default_ylim = [0, 30]
    else:
        # Minimalist violins: no inner marks, equal widths
        ax = sns.violinplot(
            data=df,
            x='Sample',
            y='NND',
            order=sample_order,
            palette=my_color_map,
            inner=None,
            linewidth=1,
            scale="width",
            saturation=1
        )

        # Outline the violins; must happen before hlines adds its own collection
        for violin in ax.collections:
            violin.set_edgecolor('k')

        # White marker at each sample's MEAN distance.
        # NOTE(review): earlier comments called these "median" lines although
        # the mean was computed; the computed statistic is kept unchanged.
        sample_means = df.groupby('Sample')['NND'].mean()
        for i, sample in enumerate(sample_order):
            plt.hlines(
                y=sample_means[sample],
                xmin=i - 0.3,
                xmax=i + 0.3,
                colors='white',
                linestyles='-',
                linewidth=2
            )
        default_ylim = [0, 18]

    plt.xlabel('')
    plt.ylim(default_ylim if ylim is None else ylim)
    plt.xticks(rotation=45)
    plt.tight_layout()

    return fig

# Example usage: stem -> nearest CD8+ effector T cell distances for the first
# sample group (presumably the stomach samples, going by the output file name).
sample_ids = ['SLV11', 'SLV13', 'SLV15', 'SLV18']

# Compute NND for all samples
nnd_dict = compute_nndistance_multiple(
    adata_integrate, 
    cell_type1="stem", 
    cell_type2="CD8+ Effector T  cells", 
    sample_ids=sample_ids
)

# Perform statistical comparison
comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

# Print results
print("\n--- Sample Statistics ---")
print(stats_df)
print("\n--- Pairwise Comparisons ---")
print(comparison_df)

# Create visualization and save it under one consistent base name
# (the PNG/SVG were previously written as "..._stomatch.*" next to "..._stomach.pdf")
fig = visualize_multiple_nnd(nnd_dict)
for ext in ('pdf', 'png', 'svg'):
    fig.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_stem_stomach.{ext}',
        bbox_inches='tight', dpi=300
    )
plt.show()

In [None]:
# Second sample group (presumably the colon samples, going by the output file name)
sample_ids = ['SLV12', 'SLV14', 'SLV16', 'SLV17']

# Distance from every "stem" spot to its nearest CD8+ effector T cell spot, per sample
nnd_dict = compute_nndistance_multiple(
    adata_integrate,
    cell_type1="stem",
    cell_type2="CD8+ Effector T  cells",
    sample_ids=sample_ids,
)

# Pairwise tests plus per-sample summary statistics
comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

# Report both tables
print("\n--- Sample Statistics ---")
print(stats_df)
print("\n--- Pairwise Comparisons ---")
print(comparison_df)

# Plot, then export the same figure as PDF, PNG and SVG
fig = visualize_multiple_nnd(nnd_dict)
for ext in ('pdf', 'png', 'svg'):
    fig.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_stem_colon.{ext}',
        bbox_inches='tight',
        dpi=300,
    )
plt.show()

In [None]:
# All eight samples together (the two four-sample groups analysed in the
# previous two cells). Nothing is saved here; the figure is only displayed.
# NOTE(review): visualize_multiple_nnd lists four palette colors, so with eight
# samples colors are presumably reused -- confirm the plot is still readable.
sample_ids = ['SLV11', 'SLV13', 'SLV15', 'SLV18','SLV12', 'SLV14', 'SLV16', 'SLV17']

# Compute NND (stem spot -> nearest CD8+ effector T cell spot) for all samples
nnd_dict = compute_nndistance_multiple(
    adata_integrate, 
    cell_type1="stem", 
    cell_type2="CD8+ Effector T  cells", 
    sample_ids=sample_ids
)

# Perform statistical comparison; `stats_df` is read again by the next cell
comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

# Print results
print("\n--- Sample Statistics ---")
print(stats_df)
print("\n--- Pairwise Comparisons ---")
print(comparison_df)

# Create visualization
fig = visualize_multiple_nnd(nnd_dict)
plt.show()

In [None]:
# Expose the sample ID (currently the index) as a regular column.
# NOTE(review): this adds the column to the shared `stats_df` in place, and
# `meta_info` must already exist in the kernel.
stats_df['sample'] = stats_df.index

# Lookup: sample ID -> "<patient>_<tissue_type>"
mapping = {
    sample: f"{patient}_{tissue}"
    for sample, patient, tissue in zip(
        meta_info['sample'], meta_info['patient'], meta_info['tissue_type']
    )
}

# Work on a copy with a fresh integer index, then swap the sample IDs for the
# patient/tissue labels (samples missing from `mapping` become NaN)
new_stats_df = stats_df.copy().reset_index()
new_stats_df['sample'] = new_stats_df['sample'].map(mapping)

# Display the new dataframe
new_stats_df

In [None]:
import matplotlib as mpl
# Global matplotlib defaults for the figures that follow (these persist for the
# rest of the kernel session, not just this cell).
mpl.rcParams.update({'font.size': 10, 'svg.fonttype': 'none'})
mpl.rcParams['axes.titlesize'] = 10
mpl.rcParams['xtick.labelsize'] = 10  # x-axis tick label size
mpl.rcParams['ytick.labelsize'] = 10  # y-axis tick label size
mpl.rcParams['axes.linewidth'] = 1

def plot_nnd_vs_cell_proportion(adata, nnd_stats_df, cell_type, sample_ids, palette=None,
                                sample_name_map=None, bin_size_um=16):
    """
    Plot the mean NND against the proportion of a specific cell type in each sample,
    with points colored by grade.

    Parameters:
    -----------
    adata : AnnData
        AnnData object whose ``.obs`` has a "sample" column, a "hub_celltype"
        column and (optionally) a "grade" column.
    nnd_stats_df : DataFrame
        DataFrame with NND statistics; must contain a 'sample' column holding the
        display labels produced by ``sample_name_map`` and a 'Mean' column
        (mean NND expressed in bins).
    cell_type : str
        Value of ``obs["hub_celltype"]`` whose per-sample proportion is plotted.
    sample_ids : list
        List of raw sample IDs (values of ``obs["sample"]``) to include.
    palette : dict, optional
        Color palette mapping grades to colors.
    sample_name_map : dict, optional
        Maps raw sample IDs to the labels used in ``nnd_stats_df['sample']``.
        When omitted, the notebook-level ``mapping`` dict is used, which is what
        earlier versions of this function always did.
    bin_size_um : float, optional
        Factor that converts 'Mean' from bins to micrometres (default 16, the
        value that used to be hard-coded).

    Returns:
    --------
    fig : matplotlib figure
        Figure with scatter plot and, for more than two samples, a regression line.

    Raises:
    -------
    ValueError
        If none of ``sample_ids`` occurs in ``adata.obs["sample"]``.
    """
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    from scipy import stats

    # Set default palette if none provided
    if palette is None:
        palette = {'ND': '#09BB8C', 'Mild': '#005A8F', 'Severe': '#B85000'}

    # Backward-compatible fallback to the notebook-global label mapping
    if sample_name_map is None:
        sample_name_map = mapping

    # Only the annotations are needed, so work on .obs instead of slicing the
    # whole AnnData object for every sample.
    obs = adata.obs
    has_grade = "grade" in obs.columns

    records = []
    for sample_id in sample_ids:
        sample_obs = obs.loc[obs["sample"] == sample_id]
        total_cells = len(sample_obs)
        if total_cells == 0:
            print(f"Warning: Sample {sample_id} not found in data")
            continue

        # Vectorised count of the requested cell type
        target_cells = int((sample_obs["hub_celltype"] == cell_type).sum())

        # The grade is taken from the first row of the sample
        # (assumes it is constant within a sample).
        grade = sample_obs["grade"].iloc[0] if has_grade else "unknown"

        records.append({
            'sample': sample_id,
            'proportion': target_cells / total_cells,
            'count': target_cells,
            'total': total_cells,
            'grade': grade
        })

    if not records:
        raise ValueError("None of the requested sample_ids were found in adata.obs['sample']")

    prop_df = pd.DataFrame(records)

    # Translate raw sample IDs into the labels used by nnd_stats_df
    prop_df['sample'] = prop_df['sample'].map(sample_name_map)

    # Merge with NND statistics and convert the mean distance from bins to micrometres
    plot_df = pd.merge(nnd_stats_df, prop_df, on='sample')
    plot_df['Mean'] = plot_df['Mean'] * bin_size_um

    # Create figure
    fig, ax = plt.subplots(figsize=(6, 3))

    # Scatter plot colored by grade
    sns.scatterplot(
        x='proportion',
        y='Mean',
        data=plot_df,
        hue='grade',
        palette=palette,
        s=100,
        ax=ax
    )

    # Label every point with its sample name
    for _, row in plot_df.iterrows():
        ax.annotate(row['sample'], (row['proportion'], row['Mean']),
                    xytext=(5, 5), textcoords='offset points')

    # Regression line and Pearson correlation need at least 3 points to be meaningful
    if len(plot_df) > 2:
        x = plot_df['proportion']
        y = plot_df['Mean']

        m, b = np.polyfit(x, y, 1)
        ax.plot(x, m*x + b, color='gray', linestyle='--')

        correlation, p_value = stats.pearsonr(x, y)
        ax.text(0.05, 0.95, f"r = {correlation:.3f}, p = {p_value:.3f}",
                transform=ax.transAxes, fontsize=10,
                verticalalignment='top')

    ax.set_xlabel(f'Proportion of {cell_type}')
    # Distances are in micrometres after the bin-size scaling above
    ax.set_ylabel('NND to Intestinal Stem cells (µm)')

    # Add legend
    ax.legend(title='Grade')

    plt.tight_layout()
    return fig

# Mean NND vs. proportion of CD8+ effector T cells, one point per sample
fig = plot_nnd_vs_cell_proportion(
    adata_integrate,
    new_stats_df,
    sample_ids=sample_ids,
    cell_type="CD8+ Effector T  cells",
    palette={'ND': '#09BB8C', 'Mild': '#005A8F', 'Severe': '#B85000'},
)

# Export with a transparent background in every publication format
for figure_path in (
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/Mean NND vs Proportion.pdf',
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/Mean NND vs Proportion.png',
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/Mean NND vs Proportion.svg',
):
    fig.savefig(figure_path, transparent=True, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Cell-type labels that are collapsed into the single label 'epithelial'.
# NOTE(review): 'Stomach_epithelial_cells' is listed twice; possibly a different
# label was intended for the third entry — confirm.
stem_cell_types = ['Instestinal_Epithelial cells', 'Stomach_epithelial_cells','Stomach_epithelial_cells']

# Overwrite hub_celltype in place: listed labels become 'epithelial', all others are kept
adata_integrate.obs['hub_celltype'] = adata_integrate.obs['hub_celltype'].apply(
    lambda label: 'epithelial' if label in stem_cell_types else label
)

In [None]:
# Sample panel for this comparison
sample_ids = ['SLV11', 'SLV13', 'SLV15', 'SLV18', 'SLV12', 'SLV14', 'SLV16', 'SLV17']

# Per-sample nearest-neighbour distances between epithelial and CD8+ effector T cells
nnd_dict = compute_nndistance_multiple(
    adata_integrate,
    sample_ids=sample_ids,
    cell_type1="epithelial",
    cell_type2="CD8+ Effector T  cells",
)

# Summary statistics per sample plus the pairwise sample comparisons
comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

# Print each table under its own heading
for heading, table in (
    ("\n--- Sample Statistics ---", stats_df),
    ("\n--- Pairwise Comparisons ---", comparison_df),
):
    print(heading)
    print(table)

# Plot the NND distributions
fig = visualize_multiple_nnd(nnd_dict)
plt.show()

In [None]:
# Expose the sample IDs (currently the index of stats_df) as a regular column
stats_df['sample'] = stats_df.index

# Sample ID -> "<patient>_<tissue_type>" display label, built from meta_info
mapping = {
    row['sample']: f"{row['patient']}_{row['tissue_type']}"
    for _, row in meta_info.iterrows()
}

# Relabelled copy; stats_df itself keeps the raw sample IDs
new_stats_df = stats_df.copy().reset_index()
new_stats_df['sample'] = new_stats_df['sample'].map(mapping)

# A bare expression only renders when it is the last statement of a cell, so the
# table is displayed explicitly here (display is provided by the IPython kernel).
display(new_stats_df)

# Mean NND vs. proportion of CD8+ effector T cells.
# NOTE(review): plot_nnd_vs_cell_proportion labels the y-axis as distance to
# intestinal stem cells, whereas stats_df in this part of the notebook was
# computed against "epithelial" — confirm the axis label before publishing.
fig = plot_nnd_vs_cell_proportion(
    adata_integrate,
    new_stats_df,
    cell_type="CD8+ Effector T  cells",
    sample_ids=sample_ids,
    palette={'ND': '#09BB8C', 'Mild': '#005A8F', 'Severe': '#B85000'}
)

In [None]:
# Loop-invariant inputs, built once instead of on every iteration:
# the sample panel and the sample ID -> "<patient>_<tissue_type>" labels.
sample_ids = ['SLV11', 'SLV13', 'SLV15', 'SLV18','SLV12', 'SLV14', 'SLV16', 'SLV17']
mapping = {
    row['sample']: f"{row['patient']}_{row['tissue_type']}"
    for _, row in meta_info.iterrows()
}

# Repeat the "distance to stem cells" analysis for every hub-level cell type
for cell_type in list(adata_integrate.obs['hub_celltype'].unique()):
    # Compute NND for all samples
    nnd_dict = compute_nndistance_multiple(
        adata_integrate,
        cell_type1="stem",
        cell_type2=cell_type,
        sample_ids=sample_ids
    )

    # Perform statistical comparison
    comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

    # Print results
    print("\n--- Sample Statistics ---")
    print(stats_df)
    print("\n--- Pairwise Comparisons ---")
    print(comparison_df)

    # NND distribution figure for this cell type
    fig = visualize_multiple_nnd(nnd_dict)
    plt.show()

    # Relabel samples as "<patient>_<tissue_type>" on a copy of the statistics
    stats_df['sample'] = stats_df.index
    new_stats_df = stats_df.copy().reset_index()
    new_stats_df['sample'] = new_stats_df['sample'].map(mapping)

    # A bare expression inside a loop is never rendered; display explicitly
    display(new_stats_df)
    print(cell_type)

    # Mean NND vs. proportion of the current cell type
    fig = plot_nnd_vs_cell_proportion(
        adata_integrate,
        new_stats_df,
        cell_type=cell_type,
        sample_ids=sample_ids,
        palette={'ND': '#09BB8C', 'Mild': '#005A8F', 'Severe': '#B85000'}
    )
    # Show the scatter now, under its own cell-type heading; otherwise it is only
    # rendered when a later plt.show() (or the end of the cell) flushes it.
    plt.show()

In [None]:
# Sample panel for this comparison
sample_ids = ['SLV12', 'SLV14', 'SLV16', 'SLV17']

# Per-sample nearest-neighbour distances between CD8+ effector and CD4+ regulatory T cells
nnd_dict = compute_nndistance_multiple(
    adata_integrate,
    sample_ids=sample_ids,
    cell_type1="CD8+ Effector T  cells",
    cell_type2="CD4+ Regulatory T cells",
)

# Summary statistics per sample plus the pairwise sample comparisons
comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

# Print each table under its own heading
for heading, table in (
    ("\n--- Sample Statistics ---", stats_df),
    ("\n--- Pairwise Comparisons ---", comparison_df),
):
    print(heading)
    print(table)

# Plot the NND distributions and export them in every publication format
fig = visualize_multiple_nnd(nnd_dict)
for figure_path in (
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_Treg_colon.pdf',
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_Treg_colon.png',
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_Treg_colon.svg',
):
    fig.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Sample panel for this comparison
sample_ids = ['SLV11', 'SLV13', 'SLV15', 'SLV18']

# Per-sample nearest-neighbour distances between CD8+ and CD4+ effector T cells
nnd_dict = compute_nndistance_multiple(
    adata_integrate,
    sample_ids=sample_ids,
    cell_type1="CD8+ Effector T  cells",
    cell_type2="CD4+ Effector T cells",
)

# Summary statistics per sample plus the pairwise sample comparisons
comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

# Print each table under its own heading
for heading, table in (
    ("\n--- Sample Statistics ---", stats_df),
    ("\n--- Pairwise Comparisons ---", comparison_df),
):
    print(heading)
    print(table)

# Plot the NND distributions and export them in every publication format
fig = visualize_multiple_nnd(nnd_dict)
for figure_path in (
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_cd4_stomach.pdf',
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_cd4_stomach.png',
    '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_cd4_stomach.svg',
):
    fig.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Sample panel for this comparison
sample_ids = ['SLV11', 'SLV13', 'SLV15', 'SLV18']

# Compute NND between CD8+ effector and CD4+ regulatory T cells for all samples
nnd_dict = compute_nndistance_multiple(
    adata_integrate,
    cell_type1="CD8+ Effector T  cells",
    cell_type2="CD4+ Regulatory T cells",
    sample_ids=sample_ids
)

# Perform statistical comparison
comparison_df, stats_df = compare_multiple_samples_nnd(nnd_dict)

# Print results
print("\n--- Sample Statistics ---")
print(stats_df)
print("\n--- Pairwise Comparisons ---")
print(comparison_df)

# Create visualization.
# The files are named "..._Treg_stomach" (mirroring "nnd_cd8_effector_Treg_colon"):
# this cell previously wrote to "nnd_cd8_effector_cd4_stomach.*" and thereby
# overwrote the CD4-effector figure saved by the preceding cell.
fig = visualize_multiple_nnd(nnd_dict)
for extension in ('pdf', 'png', 'svg'):
    fig.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/nnd_cd8_effector_Treg_stomach.{extension}',
        bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Inspection only: the repr of `adata` is rendered as this cell's output.
adata

In [None]:
# Inspection only: the repr of `adata` is rendered as this cell's output
# (same expression as the preceding cell).
adata

In [None]:
# Per-sample metadata: one record per sample as (sample, patient, tissue_type, grade).
# Built directly as a DataFrame so `meta_info` is never bound to a plain list.
meta_info = pd.DataFrame(
    [
        ('SLV11', 'C159', 'Antrum', 'Severe'),
        ('SLV12', 'C162', 'Rectum', 'Mild'),
        ('SLV13', 'C98', 'Stomach_Body', 'Severe'),
        ('SLV14', 'C159', 'Rectum', 'Severe'),
        ('SLV15', 'C179', 'Antrum', 'Mild'),
        ('SLV16', 'C179', 'Ascending_Colon', 'Mild'),
        ('SLV17', 'ND001', 'Ascending_Colon', 'ND'),
        ('SLV18', 'C162', 'Stomach', 'Mild'),
    ],
    columns=['sample', 'patient', 'tissue_type', 'grade'],
)

In [None]:
import matplotlib as mpl

# Plot defaults for the hub overlays: diverging default colormap, 8 pt text,
# and SVG text kept as editable text ('svg.fonttype': 'none').
mpl.rcParams.update({
    'image.cmap': 'RdBu_r',
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,   # x-axis tick labels
    'ytick.labelsize': 8,   # y-axis tick labels
})
# One named colour per hub index (36 entries; list position == hub index)
new_colormap = [
    'darkslateblue', 'cornflowerblue', 'red', 'blueviolet',              # 0-3
    'skyblue', 'orchid', 'yellowgreen', 'palevioletred',                 # 4-7
    'orange', 'cadetblue', 'limegreen', 'cyan',                          # 8-11
    'gold', 'slategray', 'olive', 'blue',                                # 12-15
    'linen', 'mistyrose', 'peru', 'darkturquoise',                       # 16-19
    'teal', 'salmon', 'violet', 'dodgerblue',                            # 20-23
    'darkgreen', 'mediumaquamarine', 'tomato', 'sandybrown',             # 24-27
    'darkkhaki', 'lightseagreen', 'mediumorchid', 'crimson',             # 28-31
    'olivedrab', 'steelblue', 'plum', 'chocolate',                       # 32-35
]

# Hub index -> colour, restricted to indices that have an entry in new_colormap
new_order = list(range(36))
exact_mapping = {
    hub_index: new_colormap[hub_index]
    for hub_index in new_order
    if hub_index < len(new_colormap)
}

print(exact_mapping)


import json

output_folder = '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/'
y_start, y_end = 2300, 2900  # Vertical range of the crop
x_start, x_end = 1000, 1400  # Horizontal range of the crop

sample = 'SLV14'
# Scale factors of this sample. They are loaded here instead of relying on a
# `scalefactor` variable left behind by an earlier cell, which may belong to a
# different sample (file_path used to be built but never read).
file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
with open(file_path) as scalefactor_file:
    scalefactor = json.load(scalefactor_file)

# Overlay the selected hubs on the cropped histology image
plot_category_on_histology_crop(
    adata_integrate[adata_integrate.obs['sample']==sample],  # adata
    'hub',                                              # column
    f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',  # his_loc
    scalefactor,                                           # scalefactors
    100,                                                   # size
    output_folder,                                         # output_folder
    sample,                                                # sample_id
    x_start,                                               # x_start
    x_end,                                                 # x_end
    y_start,                                               # y_start
    y_end,                                                 # y_end
    [0,18,14,13,21,20,29,28,4,30,1,11,24,23,9,8],# plot_cat
    color_map=exact_mapping                             # color_map (keyword argument)
)

In [None]:
import json

y_start, y_end = 1800, 2000  # Vertical range of the crop
x_start, x_end = 1600, 2100  # Horizontal range of the crop
sample = 'SLV14'
# Scale factors of this sample. They are loaded here instead of relying on a
# `scalefactor` variable left behind by an earlier cell, which may belong to a
# different sample (file_path used to be built but never read).
file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
with open(file_path) as scalefactor_file:
    scalefactor = json.load(scalefactor_file)

# Overlay the selected hubs on the cropped histology image
plot_category_on_histology_crop(
    adata_integrate[adata_integrate.obs['sample']==sample],  # adata
    'hub',                                              # column
    f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',  # his_loc
    scalefactor,                                           # scalefactors
    100,                                                   # size
    output_folder,                                         # output_folder
    sample,                                                # sample_id
    x_start,                                               # x_start
    x_end,                                                 # x_end
    y_start,                                               # y_start
    y_end,                                                 # y_end
    [0,18,14,13,21,20,29,28,4,30,1,11,24,23,9,8],# plot_cat
    color_map=exact_mapping                             # color_map (keyword argument)
)

In [None]:
import json

y_start, y_end = 1400, 1800  # Vertical range of the crop
x_start, x_end = 1700, 2300  # Horizontal range of the crop
sample = 'SLV14'
# Scale factors of this sample. They are loaded here instead of relying on a
# `scalefactor` variable left behind by an earlier cell, which may belong to a
# different sample (file_path used to be built but never read).
file_path = f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/binned_outputs/square_016um/spatial/scalefactors_json.json'
with open(file_path) as scalefactor_file:
    scalefactor = json.load(scalefactor_file)

# Overlay the selected hubs on the cropped histology image
plot_category_on_histology_crop(
    adata_integrate[adata_integrate.obs['sample']==sample],  # adata
    'hub',                                              # column
    f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png',  # his_loc
    scalefactor,                                           # scalefactors
    100,                                                   # size
    output_folder,                                         # output_folder
    sample,                                                # sample_id
    x_start,                                               # x_start
    x_end,                                                 # x_end
    y_start,                                               # y_start
    y_end,                                                 # y_end
    [0,18,14,13,21,20,29,28,4,30,1,11,24,23,9,8],# plot_cat
    color_map=exact_mapping                             # color_map (keyword argument)
)

In [None]:
def plot_legend_only(
    adata,
    column,
    output_folder,
    sample_id,
    color_map=None,
    title_fontsize='medium',
    legend_fontsize='small',
    markerscale=2,
    figsize=(8, 6)
):
    """
    Create and save just the legend for the categories in the specified column.

    Parameters:
        adata: AnnData object.
        column: Column in adata.obs containing categorical data to plot.
        output_folder: Directory to save the output image (created if missing).
        sample_id: Sample ID for naming the output file.
        color_map: Either a sequence of colors, assigned to the sorted unique
            categories by position (cycling if there are more categories than
            colors), or a dict mapping each category value to its color.
            Defaults to matplotlib's tab20 colors.
        title_fontsize: Font size for the legend title.
        legend_fontsize: Font size for the legend items.
        markerscale: Scale factor for the legend markers.
        figsize: Size of the figure.

    Returns:
        None (saves and shows the legend).

    Raises:
        ValueError: if color_map is a dict that lacks a color for a category.
    """
    import os
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches

    # Extract categories from adata.obs
    categories = adata.obs[column]

    # Sorted so that positional color assignment is deterministic
    unique_categories = sorted(categories.unique())
    if color_map is None:
        color_map = plt.cm.tab20.colors  # Default colormap if none is provided

    if isinstance(color_map, dict):
        # Category-keyed colors: look each category up by value, so the legend
        # stays correct when only a subset of the categories is present.
        missing = [cat for cat in unique_categories if cat not in color_map]
        if missing:
            raise ValueError(f"color_map has no color for categories: {missing}")
        category_color_map = {cat: color_map[cat] for cat in unique_categories}
    else:
        # Positional colors: i-th sorted category gets the i-th color (cycling)
        category_color_map = {cat: color_map[i % len(color_map)] for i, cat in enumerate(unique_categories)}

    # Create figure with a white background
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_facecolor('white')

    # Hide axes
    ax.set_axis_off()

    # Create patches for the legend
    patches = [mpatches.Patch(color=color, label=category)
               for category, color in category_color_map.items()]

    # Add only the legend, laid out in a single column
    legend = ax.legend(
        handles=patches,
        title=column,
        loc='center',
        fontsize=legend_fontsize,
        markerscale=markerscale,
        title_fontsize=title_fontsize,
        ncol=1
    )

    # Tight layout
    plt.tight_layout()

    # Save the legend; make sure the target directory exists first
    os.makedirs(output_folder, exist_ok=True)
    legend_output_path = os.path.join(output_folder, f"{sample_id}_{column}_legend.png")
    fig.savefig(legend_output_path, bbox_inches='tight', dpi=300, facecolor='white')

    plt.show()
    plt.close(fig)
    print(f"Legend saved to {legend_output_path}")


# Legend for the hubs present in the current sample.
# sample_id is taken from `sample` so the file name always matches the data
# that was actually subset (it used to be hard-coded to "SLV14").
# NOTE(review): new_colormap is applied by position over the sorted hubs that
# occur in this subset, whereas the crop plots above receive exact_mapping
# (hub index -> colour); if the subset does not contain a contiguous 0..k range
# of hubs the legend colours may not match the plots — confirm.
plot_legend_only(
    adata_integrate[adata_integrate.obs['sample']==sample],
    column="hub",
    output_folder=output_folder,
    sample_id=sample,
    color_map=new_colormap
    # Optional parameters:
    # title_fontsize='large',   # Adjust title size
    # legend_fontsize='medium', # Adjust legend text size
    # markerscale=3,           # Adjust marker size
    # figsize=(10, 8)          # Change figure dimensions
)

In [None]:
sample = 'SLV14'
adata_processed = adata_integrate[adata_integrate.obs['sample']==sample]
img = plt.imread(f'/Users/lingting/Documents/GVHD_project/visiumHD/data/{sample}/spatial/tissue_hires_image.png')


def _crop_region(y_start, y_end, x_start, x_end):
    """Show one crop of `img` and return (cropped image, bins of `adata_processed` inside it).

    The window is given in hires-image pixels; bin coordinates are converted from
    full-resolution pixels with scalefactor['tissue_hires_scalef']. Both window
    edges are inclusive for the bin selection.
    """
    # Slice the image array to extract the region of interest and display it
    cropped_img = img[y_start:y_end, x_start:x_end]
    plt.figure(figsize=(6, 6), dpi=200)
    plt.imshow(cropped_img)
    plt.axis('off')
    plt.show()

    # Bin centres in hires-image pixel coordinates
    hires_scale = scalefactor['tissue_hires_scalef']
    rows_hires = adata_processed.obs['pxl_row_in_fullres'] * hires_scale
    cols_hires = adata_processed.obs['pxl_col_in_fullres'] * hires_scale

    inside_window = (
        (rows_hires >= y_start) & (rows_hires <= y_end)
        & (cols_hires >= x_start) & (cols_hires <= x_end)
    )
    return cropped_img, adata_processed[inside_window]


# First region stored as "loss1"
y_start, y_end = 2300, 2900
x_start, x_end = 1000, 1400
cropped_img, subset_adata_loss1 = _crop_region(y_start, y_end, x_start, x_end)

# Region stored as plain `subset_adata`
y_start, y_end = 1400, 1800
x_start, x_end = 1700, 2300
cropped_img, subset_adata = _crop_region(y_start, y_end, x_start, x_end)

# Second region stored as "loss2"
y_start, y_end = 1800, 2000
x_start, x_end = 1600, 2100
cropped_img, subset_adata_loss2 = _crop_region(y_start, y_end, x_start, x_end)

In [None]:
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests
# Hub labels of the two "loss" regions pooled together, and of the comparison region
combined_loss = pd.concat([subset_adata_loss1.obs['hub'], subset_adata_loss2.obs['hub']])
control_hubs = subset_adata.obs['hub']

# Every hub seen in either group. The test loop runs over this union so that
# hubs absent from the loss regions but present in the comparison region are
# tested as well (previously only hubs occurring in the loss regions were).
all_hubs = np.union1d(combined_loss.unique(), control_hubs.unique())

# Group sizes
n_loss = len(combined_loss)
n_subset = len(control_hubs)

records = []
for hub in all_hubs:
    loss_count = int((combined_loss == hub).sum())
    subset_count = int((control_hubs == hub).sum())

    # 2x2 contingency table: rows = loss / comparison region, columns = this hub / any other hub
    table = np.array([
        [loss_count, n_loss - loss_count],
        [subset_count, n_subset - subset_count]
    ])

    # Fisher's exact test (scipy default: two-sided)
    odds_ratio, pvalue = fisher_exact(table)

    records.append({
        'hub': hub,
        'loss_count': loss_count,
        'loss_proportion': float(loss_count) / n_loss,
        'subset_count': subset_count,
        'subset_proportion': float(subset_count) / n_subset,
        'odds_ratio': float(odds_ratio),
        'pvalue': float(pvalue),
    })

# One row per hub, built in a single step
results_df = pd.DataFrame(records, columns=[
    'hub', 'loss_count', 'loss_proportion', 'subset_count',
    'subset_proportion', 'odds_ratio', 'pvalue',
])

# Benjamini-Hochberg correction across all tested hubs
results_df['adjusted_pvalue'] = multipletests(results_df['pvalue'],
                                            alpha=0.05,
                                            method='fdr_bh')[1]
results_df['significant'] = results_df['adjusted_pvalue'] < 0.05

# Sort by adjusted p-value
results_df = results_df.sort_values('adjusted_pvalue')

In [None]:
# Signed difference in hub proportion: loss regions minus comparison region
results_df['proportion_diff'] = results_df['loss_proportion'].sub(results_df['subset_proportion'])

In [None]:
# UMAP of the integrated samples, coloured by the hub-level cell-type label
sc.pl.umap(
    adata_integrate,
    color="hub_celltype",
    title='UMAP after Starfysh sample integration',
    s=15,
    frameon=False,
)
plt.show()

In [None]:
import matplotlib as mpl

# 8 pt text with 10 pt axis labels; SVG text kept as editable text
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,    # x-axis tick labels
    'ytick.labelsize': 8,    # y-axis tick labels
    'axes.labelsize': 10,    # axis labels
})

def plot_significant_differences(results_df):
    """
    Horizontal bar chart of hub-proportion differences for significant hubs.

    Parameters:
        results_df: DataFrame (or dict of equal-meaning columns) with at least
            'hub', 'loss_proportion', 'subset_proportion' and 'adjusted_pvalue'.
            The input is not modified.

    Returns:
        (fig, sig_df): the matplotlib figure and the rows with
        adjusted_pvalue < 0.05, sorted by 'proportion_diff'
        (loss_proportion - subset_proportion). When no row is significant,
        ``(None, <empty DataFrame>)`` is returned so that callers can always
        unpack two values.
    """
    # Ensure we're working with a DataFrame
    if isinstance(results_df, dict):
        # pd.Series wrapping pads columns of unequal length with NaN
        df = pd.DataFrame({k: pd.Series(v) for k, v in results_df.items()})
    else:
        df = results_df.copy()

    # Calculate proportion difference
    df['proportion_diff'] = df['loss_proportion'] - df['subset_proportion']

    # Filter for significant results
    sig_df = df[df['adjusted_pvalue'] < 0.05].copy()

    if len(sig_df) == 0:
        print("No significant differences found!")
        # Same arity as the normal return path
        return None, sig_df

    # Most negative difference first (drawn as the bottom bar)
    sig_df = sig_df.sort_values('proportion_diff')

    # Figure height grows with the number of bars
    fig, ax = plt.subplots(figsize=(3, max(3, len(sig_df) * 0.2)))

    # Create bars
    bars = ax.barh(range(len(sig_df)), sig_df['proportion_diff'])

    # Color bars based on direction
    for bar in bars:
        if bar.get_width() < 0:
            bar.set_color('lightcoral')  # Decreased in Loss
        else:
            bar.set_color('lightblue')   # Increased in Loss

    # Add significance markers just beyond the end of each bar
    for i, p in enumerate(sig_df['adjusted_pvalue']):
        if p < 0.001:
            star = '***'
        elif p < 0.01:
            star = '**'
        else:
            star = '*'
        x = sig_df['proportion_diff'].iloc[i]
        if x < 0:
            ax.text(x - 0.01, i, star, ha='right', va='center')
        else:
            ax.text(x + 0.01, i, star, ha='left', va='center')

    # Customize plot
    ax.set_yticks(range(len(sig_df)))
    ax.set_yticklabels(sig_df['hub'])
    ax.axvline(0, color='black', linestyle='-', linewidth=0.5)

    # Labels
    ax.set_xlabel('Proportion Difference \n(Loss - Control)')

    # Percentage label centred inside each bar
    for i, v in enumerate(sig_df['proportion_diff']):
        ax.text(v / 2, i,
                f'{v:.1%}',
                ha='center', va='center',
                color='black' if abs(v) > 0.1 else 'white')

    # Add grid
    ax.grid(True, axis='x', linestyle='--', alpha=0.3)

    plt.tight_layout()

    # Print numerical summary. Each column gets its own formatter: proportions as
    # percentages, adjusted p-values in scientific notation. (A single
    # float_format applied after rounding printed the p-values as "0.0%".)
    print("\nSignificant Differences Summary:")
    summary_df = sig_df[['hub', 'loss_proportion', 'subset_proportion',
                        'proportion_diff', 'adjusted_pvalue']]
    print(summary_df.to_string(formatters={
        'loss_proportion': '{:.1%}'.format,
        'subset_proportion': '{:.1%}'.format,
        'proportion_diff': '{:.1%}'.format,
        'adjusted_pvalue': '{:.2e}'.format,
    }))

    return fig, sig_df


# Build the bar chart of significant hub differences.
# When nothing is significant the function yields no figure (either a bare None
# or a (None, empty frame) pair); handle both so the cell does not crash and
# `sig_df` is still defined for the cells that follow.
result = plot_significant_differences(results_df)
if result is None:
    fig, sig_df = None, results_df.iloc[0:0]
else:
    fig, sig_df = result

if fig is not None:
    # Export with a transparent background in every publication format
    for figure_path in (
        '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/SLV14_fisher_test.pdf',
        '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/SLV14_fisher_test.png',
        '/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/SLV14_fisher_test.svg',
    ):
        fig.savefig(figure_path, bbox_inches='tight', dpi=300, transparent=True)
    plt.show()


In [None]:
import matplotlib as mpl

# 8 pt text with 10 pt titles and axis labels; SVG text kept as editable text
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 10,
    'xtick.labelsize': 8,    # x-axis tick labels
    'ytick.labelsize': 8,    # y-axis tick labels
    'axes.labelsize': 10,    # axis labels
})
# Palette for the stacked composition bars (32 colours; list position == column order).
# This rebinds `new_colormap`, replacing any list of the same name defined by an earlier cell.
new_colormap = [
    'darkslateblue', 'cornflowerblue', 'red', 'blueviolet',              # 0-3
    'skyblue', 'orchid', 'yellowgreen', 'palevioletred',                 # 4-7
    'orange', 'cadetblue', 'limegreen', 'cyan',                          # 8-11
    'gold', 'slategray', 'olive', 'blue',                                # 12-15
    'linen', 'mistyrose', 'peru', 'darkturquoise',                       # 16-19
    'teal', 'salmon', 'violet', 'dodgerblue',                            # 20-23
    'lightcoral', 'mediumaquamarine', 'crimson', 'darkkhaki',            # 24-27
    'mediumpurple', 'indianred', 'darkseagreen', 'sienna',               # 28-31
]
sample = "SLV14"
# Cell-type columns of .obs that make up a hub's composition, plus the 'hub' label itself
gene_sig = ['Instestinal_Epithelial cells', 'Enterocyte', 'Stomach_epithelial_cells', 'Stomach_stem_cells',  'CD8+ Effector T  cells', 'CD4+ Effector T cells', 'CD4+ Central Memory T cells', 'CD8+ Cytotoxic Unconventional T cells', 'CD8+ Proliferating T cells', 'CD4+ Regulatory T cells', 'CD8+ Homeostatic Unconventional T cells', 'CD8+ Tissue Resident Memory T cells', 'CD8+ Transitioning  Resident T cells', 'Intestine_Epithelial Stem cell', 'Stomach_Body_Epithelial_cells','hub']

# Explicit copy: 'hub' is already part of gene_sig, so no column has to be
# assigned onto a selection of .obs afterwards.
proportions_df = adata_integrate.obs[gene_sig].copy()

# Hubs in the order of sig_df (the order used for the significance bar chart)
ordered_hubs = sig_df['hub'].tolist()

# Sum the cell-type columns per hub, then normalise each hub's row to 1
composition = proportions_df.groupby('hub').sum()
composition_normalized = composition.div(composition.sum(axis=1), axis=0)

# Keep only the significant hubs, in sig_df order (reindex both filters and orders)
composition_normalized = composition_normalized.reindex(ordered_hubs)

# Stacked horizontal bars: one bar per hub, one segment per cell type
fig, ax = plt.subplots(figsize=(3,3))
composition_normalized.plot(kind='barh', stacked=True, color=new_colormap, ax=ax)

# Remove the legend
ax.get_legend().remove()

# Axis labels for the horizontal orientation
ax.set_ylabel('Spatial Spots (Hubs)')
ax.set_xlabel('Cell Type Proportions')

# Save the figure in every publication format
for figure_path in (
    f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_cryptloss_hub.pdf',
    f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_cryptloss_hub.png',
    f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_cryptloss_hub.svg',
):
    fig.savefig(figure_path, bbox_inches='tight', dpi=300, transparent=True)
plt.show()

In [None]:
import matplotlib as mpl
from matplotlib.patches import Patch

# Font configuration for the legend-only figure; 'svg.fonttype': 'none'
# keeps text as text in the SVG output instead of converting it to paths.
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

# One legend entry per column of composition_normalized, colored with the
# palette entry at the same position (wrapping around if the palette is shorter).
cell_types = composition_normalized.columns.tolist()
legend_elements = [
    Patch(facecolor=new_colormap[position % len(new_colormap)], label=label)
    for position, label in enumerate(cell_types)
]

# A figure that holds nothing but the legend.
fig_legend = plt.figure(figsize=(5, 2))
fig_legend.legend(
    handles=legend_elements,
    loc='center',
    ncol=1,
    frameon=False,
)

# Write the legend out in each of the three formats.
legend_targets = (
    f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_legend.pdf',
    f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_legend.png',
    f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_legend.svg',
)
for legend_target in legend_targets:
    fig_legend.savefig(legend_target, bbox_inches='tight', dpi=300, transparent=True)

# Release the legend figure once it has been saved.
plt.close(fig_legend)

In [None]:
def plot_dual_violin(data1, data2, y1_label, y2_label,y_label, title=None, sample_id=None, figsize=(10,6)):
    """
    Draw a split violin plot comparing two distributions side by side.

    Both distributions share a single x position; data1 forms one half of the
    violin and data2 the other. When a title is given, a two-sided
    Mann-Whitney U test between data1 and data2 is run and its p-value is
    appended to the title.

    Args:
        data1: Array-like, values of the first group (drawn in '#003049')
        data2: Array-like, values of the second group (drawn in '#669bbc')
        y1_label: String, category name of the first group
        y2_label: String, category name of the second group
        y_label: String, label of the y axis
        title: String, optional plot title; the p-value is only computed and
            shown when this is truthy
        sample_id: String, optional name of the single x-axis group; its tick
            label is hidden, so it never appears in the figure
        figsize: Tuple, figure dimensions
    Returns:
        tuple: (figure, (ax, ax)) where both entries of the inner tuple are
        the same axis object
    """
    # Imported locally so the function does not depend on another cell
    # having imported scipy.stats first.
    from scipy import stats

    fig, ax = plt.subplots(figsize=figsize)

    # None is a missing value to pandas, so a None group would leave the plot
    # without a usable x category; fall back to a placeholder name instead.
    group_name = 'sample' if sample_id is None else sample_id

    # Long-format table for seaborn: one row per observation.
    n_first, n_second = len(data1), len(data2)
    df = pd.DataFrame({
        'value': np.concatenate([data1, data2]),
        'group': np.full(n_first + n_second, group_name),
        'category': np.repeat([y1_label, y2_label], [n_first, n_second]),
    })

    custom_palette = {y1_label: '#003049', y2_label: '#669bbc'}

    # split=True puts the two hue levels on opposite halves of one violin.
    sns.violinplot(
        data=df,
        x='group',
        y='value',
        hue='category',
        split=True,
        inner='quart',
        fill=False,
        ax=ax,
        palette=custom_palette
    )

    # The test result is only ever shown in the title, so skip it otherwise.
    if title:
        stat, pvalue = stats.mannwhitneyu(data1, data2, alternative='two-sided')
        ax.set_title(f'{title}\np-value = {pvalue:.2e}')

    ax.set_ylabel(y_label)

    # Drop the legend if seaborn created one.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

    ax.set_xlabel('')       # Remove x-axis label
    ax.set_xticklabels([])  # Remove the group tick label

    fig.tight_layout()

    # Callers unpack the second element as (ax1, ax2); both are the same axis.
    return fig, (ax, ax)

In [None]:
import matplotlib as mpl
mpl.rcParams.update({'font.size': 8, 'svg.fonttype': 'none'})
mpl.rcParams['axes.titlesize'] = 10
mpl.rcParams['xtick.labelsize'] = 8   # x-axis tick label size
mpl.rcParams['ytick.labelsize'] = 8   # y-axis tick label size
mpl.rcParams['axes.labelsize'] = 10   # Axis label size
from scipy import stats

# Cell-type columns to compare between intact and crypt-loss regions.
gene_sig = [
    'Instestinal_Epithelial cells',
    'Enterocyte',
    'Stomach_epithelial_cells',
    'Stomach_stem_cells',
    'CD8+ Effector T  cells',
    'CD4+ Effector T cells',
    'CD4+ Central Memory T cells',
    'CD8+ Cytotoxic Unconventional T cells',
    'CD8+ Proliferating T cells',
    'CD4+ Regulatory T cells',
    'CD8+ Homeostatic Unconventional T cells',
    'CD8+ Tissue Resident Memory T cells',
    'CD8+ Transitioning  Resident T cells',
    'Intestine_Epithelial Stem cell',
    'Stomach_Body_Epithelial_cells',
]

# One split violin per cell type: intact region vs. the two loss regions pooled.
for cell_type in gene_sig:
    print(cell_type)

    # Pool the observations of both loss subsets into a single series.
    combined_loss = pd.concat([subset_adata_loss1.obs[cell_type], subset_adata_loss2.obs[cell_type]])

    # Values are compared on a natural-log scale.
    # NOTE(review): np.log gives -inf for a value of exactly 0 — confirm the
    # proportions are strictly positive.
    # NOTE(review): the plot receives `sample_id` while the file name below
    # uses `sample` — confirm both refer to the same sample.
    fig, (ax1, ax2) = plot_dual_violin(
        np.log(subset_adata.obs[cell_type]),
        np.log(combined_loss),
        y1_label='Crypt Intact',
        y2_label='Crypt Loss',
        sample_id=sample_id,
        title=" ",  # truthy blank title: only the p-value line is shown
        y_label='log(Proportion)',
        figsize=(1.5,1.5)
    )

    # Save figures
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_Severe_{cell_type}_crypt_loss_proportion_diff.{ext}',
            bbox_inches='tight', transparent = True,
            dpi=900)

    plt.show()
    # Close explicitly so figures do not accumulate across loop iterations.
    plt.close(fig)

In [None]:
def create_legend_only(y1_label, y2_label, custom_palette=None, figsize=(3, 1), save_path="legend.png"):
    """
    Render a two-entry legend on its own figure and write it to disk.

    Args:
        y1_label: String, text of the first legend entry
        y2_label: String, text of the second legend entry
        custom_palette: Dict mapping each label to a color; when omitted,
            y1_label is '#003049' and y2_label is '#669bbc'
        figsize: Tuple, figure dimensions
        save_path: String, file the legend is saved to

    Returns:
        The matplotlib Figure that was saved (already closed on return)
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches

    # Fall back to the built-in two-color scheme when no palette is supplied.
    palette = (
        custom_palette
        if custom_palette is not None
        else {y1_label: '#003049', y2_label: '#669bbc'}
    )

    # Fresh figure that will carry only the legend.
    fig = plt.figure(figsize=figsize)

    # One colored swatch per label, first label first.
    swatches = [
        mpatches.Patch(color=palette[name], label=name)
        for name in (y1_label, y2_label)
    ]

    # Centered, frameless legend with the axes hidden.
    plt.legend(handles=swatches, loc='center', frameon=False)
    plt.axis('off')

    # Write to disk, then close the current figure.
    plt.savefig(save_path, bbox_inches='tight', transparent=True)
    plt.close()

    return fig


# Example usage: save the 'Crypt Intact' / 'Crypt Loss' violin legend once per
# output format (PNG, PDF, SVG) using the default two-color palette.
# Each call returns the Figure it saved; the calls are kept as separate
# statements so that the last one stays a bare expression.
create_legend_only('Crypt Intact', 'Crypt Loss', save_path="/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/violin_legend.png")
create_legend_only('Crypt Intact', 'Crypt Loss', save_path="/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/violin_legend.pdf")
create_legend_only('Crypt Intact', 'Crypt Loss', save_path="/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/violin_legend.svg")