# StarfyshHD tutorial for sample integration

**Azizi Lab**

Lingting Shi, Siyu He, Yinuo Jin

2025

This tutorial provides an example of using StarfyshHD to characterize common *spatial hubs* through integrative spatial deconvolution from multiple samples. 

In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    !pip3 install scanpy
    from google.colab import drive
    drive.mount('/content/drive')
    import sys
    sys.path.append('/content/drive/MyDrive/SpatialModelProject/model_test_colab/')

import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata
import torch

import matplotlib.pyplot as plt
import matplotlib.font_manager
from matplotlib import rcParams

# Gather the family name of every system font that matplotlib can open.
fpaths = matplotlib.font_manager.findSystemFonts()
font_list = set()
for font_path in fpaths:
    try:
        font_list.add(matplotlib.font_manager.get_font(font_path).family_name)
    except RuntimeError:
        # Font file could not be loaded: ignore it and move on.
        continue

# Use Helvetica when it is installed, otherwise fall back to FreeSans.
plot_font = 'FreeSans' if 'Helvetica' not in font_list else 'Helvetica'

# Global plotting defaults for the notebook.
rcParams.update({
    'font.family': plot_font,
    'font.size': 10,
    'figure.dpi': 300,
    'figure.figsize': (3,3),
    'savefig.dpi': 500,
})

import warnings
warnings.filterwarnings('ignore')

## Load Starfysh

from starfysh import (AA, utils, plot_utils, post_analysis, utils_integrate)
from starfysh import starfysh as sf_model
#!pip install -U PhenoGraph

### (1), load datasets & marker genes

File input:

- Spatial transcriptomics
    - Count matrices: list of `adata`s
    - (Optional): list of paired histology & spot coordinates corresponding to count matrices
- Annotated signatures (marker genes) for potential cell types: `gene_sig`

Starfysh is built upon scanpy and AnnData. A typical ST/Visium sample folder consists of an expression count file (usually `filtered_feature_bc_matrix.h5`) and a subdirectory with the corresponding H&E image and spatial information, as provided by the Visium platform.

For ST data that doesn't follow this standard data structure (e.g. missing `filtered_feature_bc_matrix.h5`), please ensure that a `.h5ad` count matrix exists in the data directory.

Example directory structure:

```
├── ../data
    signature.csv

    ├── CID44971_TNBC:
        \__ CID44971_TNBC.h5ad

        ├── (Optional) spatial:
            \__ scalefactors_json.json
                tissue_hires_image.png
                tissue_positions_list.csv
                ...
    .
    .
    .
    
    ├── P4A_MBC:
        \__ filtered_feature_bc_matrix.h5

        ├── (Optional) spatial:
            \__ scalefactors_json.json
                tissue_hires_image.png
                tissue_positions_list.csv
                ...        
```

Create meta info. for samples

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import json
import matplotlib.pyplot as plt

# Per-sample filtering threshold, looked up by sample ID and passed to
# `load_visium_data` as its `min_genes` argument. Change as needed for each
# sample; samples missing from this dict fall back to 50 at the call site.
min_genes_values = {
    'SLV11': 150,
    'SLV12': 50,
    'SLV13': 50,
    'SLV14': 50,
    'SLV15': 50,
    'SLV16': 150,
    'SLV17': 50,
    'SLV18': 150,
}



def load_visium_data(base_path, sample_id, min_genes=50, n_top_genes=2000, mt_thld=20):
    """
    Load and preprocess one Visium HD sample from its `square_016um` binned outputs.

    Reads the filtered count matrix and bin positions, filters bins, drops
    genes prefixed 'MT-'/'mt-' and 'RP-'/'rp-', normalizes and log-transforms
    the counts, flags highly variable genes, and loads the image metadata.

    Parameters:
    -----------
    base_path : str
        Path to the base directory containing data
    sample_id : str
        ID of the sample to process (sub-directory of `base_path`)
    min_genes : int, optional
        Bin filtering threshold (default: 50).
        NOTE(review): despite its name, this value is forwarded to
        `sc.pp.filter_cells(..., min_counts=...)`, i.e. it is applied as a
        minimum total count per bin. Behavior is kept as-is to preserve
        existing results; confirm whether `min_genes=` was intended.
    n_top_genes : int, optional
        Number of top variable genes to select (default: 2000)
    mt_thld : int, optional
        Bins with `pct_counts_mt` at or above this value are removed (default: 20)

    Returns:
    --------
    tuple
        (adata_raw, adata_norm, img_metadata)
        - adata_raw: un-normalized counts restricted to the bins and genes
          kept in `adata_norm`, sharing its `obs` columns and
          `highly_variable` flag
        - adata_norm: filtered, total-count normalized, log1p-transformed data
        - img_metadata: dict with keys 'img', 'map_info', 'scalefactor'
    """
    coord_cols = ['in_tissue', 'array_row', 'array_col', 'pxl_row_in_fullres', 'pxl_col_in_fullres']
    bin_dir = os.path.join(base_path, sample_id, 'binned_outputs', 'square_016um')

    # Read tissue positions and gene expression data
    tissue_positions = pd.read_parquet(os.path.join(bin_dir, 'spatial', 'tissue_positions.parquet'))
    adata = sc.read_10x_h5(os.path.join(bin_dir, 'filtered_feature_bc_matrix.h5'))

    # Keep positions of barcodes present in the count matrix, in adata.obs order
    matched_positions = tissue_positions[tissue_positions['barcode'].isin(adata.obs.index)]
    matched_positions = matched_positions.set_index('barcode').loc[adata.obs.index]

    # Add spatial coordinates to adata
    adata.obs[coord_cols] = matched_positions[coord_cols]

    # Make var names unique and add sample ID
    adata.var_names_make_unique()
    adata.obs['sample'] = sample_id

    # Clean up var names if needed
    if '_index' in adata.var.columns:
        adata.var_names = adata.var['_index']
        adata.var_names.name = 'Genes'
        adata.var.drop('_index', axis=1, inplace=True)

    # Initial QC metrics, computed before any filtering
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

    # Basic filtering (see the NOTE on `min_genes` in the docstring)
    sc.pp.filter_cells(adata, min_counts=min_genes)

    # Snapshot the un-normalized counts; `adata` becomes the normalized version
    adata_raw = adata.copy()

    # Flag mitochondrial and 'RP-'/'rp-' prefixed genes, then recompute QC
    adata.var['mt'] = np.logical_or(
        adata.var_names.str.startswith('MT-'),
        adata.var_names.str.startswith('mt-'))
    adata.var['rb'] = (
        adata.var_names.str.startswith('RP-') |
        adata.var_names.str.startswith('rp-'))

    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)
    mask_cell = adata.obs['pct_counts_mt'] < mt_thld
    mask_gene = np.logical_and(~adata.var['mt'], ~adata.var['rb'])

    # Materialize the subset so the in-place steps below do not act on a view
    adata = adata[mask_cell, mask_gene].copy()

    # Normalize and log transform
    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)

    # Find variable genes
    sc.pp.highly_variable_genes(adata, flavor='seurat', n_top_genes=n_top_genes, inplace=True)

    # Restrict raw data to the bins/genes kept above. Copy so that the raw and
    # normalized objects never share (and cross-mutate) the same obs frame.
    adata_raw = adata_raw[adata.obs_names, adata.var_names].copy()
    adata_raw.var['highly_variable'] = adata.var['highly_variable']
    adata_raw.obs = adata.obs.copy()

    # Load scale factors
    with open(os.path.join(bin_dir, 'spatial', 'scalefactors_json.json'), 'r') as f:
        scalefactor = json.load(f)

    # Load high-res image (read from the sample-level `spatial` folder)
    img = plt.imread(os.path.join(base_path, sample_id, 'spatial/tissue_hires_image.png'))

    # Prepare map info on an explicit copy to avoid chained assignment on a slice
    map_info = matched_positions[coord_cols].copy()
    map_info['imagerow'] = pd.to_numeric(map_info['pxl_row_in_fullres'], errors='coerce').fillna(0).astype(int)
    map_info['imagecol'] = pd.to_numeric(map_info['pxl_col_in_fullres'], errors='coerce').fillna(0).astype(int)
    map_info = map_info[['array_row', 'array_col', 'imagerow', 'imagecol']]

    img_metadata = {
        'img': img,
        'map_info': map_info,
        'scalefactor': scalefactor
    }

    return adata_raw, adata, img_metadata






In [None]:
# Sample sheet: one row per Visium HD sample.
meta_info = pd.DataFrame(
    [
        ['SLV11', 'C159', 'Antrum', 'Severe'],
        ['SLV12', 'C162', 'Rectum', 'Mild'],
        ['SLV13', 'C98', 'Stomach_Body', 'Severe'],
        ['SLV14', 'C159', 'Rectum', 'Severe'],
        ['SLV15', 'C179', 'Antrum', 'Mild'],
        ['SLV16', 'C179', 'Ascending_Colon', 'Mild'],
        ['SLV17', 'ND001', 'Ascending_Colon', 'ND'],
        ['SLV18', 'C162', 'Stomach', 'Mild'],
    ],
    columns=['sample','patient','tissue_type','grade'],
)
base_path = '/Users/lingting/Documents/GVHD_project/visiumHD/data/'
sig_file_name = '/Users/lingting/Documents/GVHD_project/Spatial data/data/GVHD_spatial_signature_v8_major_curated_unique_epi_tcells_subset.csv'

# Per-sample containers; the two lists are concatenated after the loop.
adata_all = []
adata_normed_all = []
img_metadata_all = {}

for sample_id, patient, tissue_type, grade in meta_info.itertuples(index=False, name=None):
    print(sample_id)
    adata, adata_normed, img_metadata = load_visium_data(
        base_path,  # root data directory
        sample_id,
        min_genes=min_genes_values.get(sample_id, 50),
        n_top_genes=2000,
        mt_thld=20,
    )
    adata_normed = adata_normed[adata_normed.obs.index.isin(adata.obs.index)]

    # Annotate the raw object and make barcodes unique across samples.
    annotations = {'patient': patient, 'sample_type': tissue_type, 'grade': grade}
    for key, value in annotations.items():
        adata.obs[key] = value
    adata.obs_names = adata.obs_names + '-' + sample_id
    adata_all.append(adata)

    # Same suffix and annotations for the normalized object.
    adata_normed.obs_names = adata_normed.obs_names + '-' + sample_id
    for key in annotations:
        adata_normed.obs[key] = adata.obs[key]
    adata_normed_all.append(adata_normed)

    # Suffix the spot coordinates too and keep only spots that survived filtering.
    suffixed_map = img_metadata['map_info']
    suffixed_map.index = suffixed_map.index + '-' + sample_id
    map_info = suffixed_map[suffixed_map.index.isin(adata.obs.index)]
    img_metadata['map_info'] = map_info
    img_metadata_all[sample_id] = img_metadata

# Concatenate all samples and copy the HVG annotation onto the raw object.
adata_all = anndata.concat(adata_all)
adata_normed_all = anndata.concat(adata_normed_all)

sc.pp.highly_variable_genes(adata_normed_all)
adata_all.uns = adata_normed_all.uns
adata_all.var = adata_normed_all.var

# Save concat data
adata_all.write(os.path.join(base_path, 'adata_integrate.h5ad'))
adata_normed_all.write(os.path.join(base_path, 'adata_normed_integrate.h5ad'))

In [None]:

# Root data directory (same value as assigned in the loading cell).
base_path = '/Users/lingting/Documents/GVHD_project/visiumHD/data/'
# Uncomment to reload the saved concatenated objects instead of rebuilding them:
#adata_all= anndata.read(os.path.join(base_path, 'adata_integrate.h5ad'))
#adata_normed_all = anndata.read(os.path.join(base_path, 'adata_normed_integrate.h5ad'))

In [None]:
# Display a summary of the concatenated (un-normalized) AnnData object.
adata_all

In [None]:
# Read the curated signature table and pass it through utils.filter_gene_sig
# together with the concatenated expression matrix.
# NOTE(review): presumably this keeps only genes present in the data — confirm in starfysh.utils.
sig_file_name = '/Users/lingting/Documents/GVHD_project/Spatial data/data/GVHD_spatial_signature_v8_major_curated_unique_epi_tcells_subset_Transition_EPI_REFINE.csv'
gene_sig = utils.filter_gene_sig(
    pd.read_csv(sig_file_name, encoding='latin1'),
    adata_all.to_df(),
)
gene_sig.head()

In [None]:
# Display the per-sample image metadata dict (keyed by sample ID).
img_metadata_all

### (2). Preprocessing

Starfysh calculates cell-type proportion scores for each spot per sample, using the annotated signatures (one per cell type) as the model's prior.

Each spot will be ranked (`anchors_df`) based on the prior, which represents the rough estimation of the given cell-type's enrichment

In [None]:
# Load saved concat data
#adata_all = sc.read_h5ad(os.path.join(data_path, 'adata_integrate.h5ad'))
#adata_normed_all = sc.read_h5ad(os.path.join(data_path, 'adata_normed_integrate.h5ad'))

In [None]:
# Display `adata`: whatever the most recently executed loop left bound
# (the last sample's un-normalized data when cells are run in order).
adata

In [None]:
# Display `adata_normed`: the normalized counterpart left bound by the most
# recently executed loop.
adata_normed

In [None]:
# Display the filtered signature table (one column per cell type).
gene_sig

In [None]:
from starfysh.AA import ArchetypalAnalysis
import scanpy as sc

In [None]:
def _mutual_best_map(scores, row_labels, col_labels, descending):
    """
    Map each row to its best-scoring column when the match is mutual and unambiguous.

    For row `k`, columns are visited from best to worst. The first column `i`
    whose best row is `k` ends the search; the pair is recorded only if no
    other entry in row `k` has exactly the same score as `scores[k, i]`.

    Parameters
    ----------
    scores : np.ndarray
        2-D array of row-vs-column scores
    row_labels, col_labels : sequence
        Labels used as keys / values of the returned dict
    descending : bool
        True if larger scores are better, False if smaller scores are better

    Returns
    -------
    dict
        row label -> column label
    """
    def ranked(values):
        order = np.argsort(values)
        return order[::-1] if descending else order

    mapping = {}
    for k in range(scores.shape[0]):
        for i in ranked(scores[k]):
            if ranked(scores[:, i])[0] == k:
                if (scores[k, :] == scores[k, i]).sum() == 1:
                    mapping[row_labels[k]] = col_labels[i]
                break
    return mapping


def assign_archetypes(anchor_df, r=30):
    """
    Assign best 1-1 mapping of archetype community to its closest anchor community (cell-type specific anchor spots)
    Criteria: choose the top cell type in which its anchors belongs to the top r neighbors to the given archetype

    NOTE: this function reads the module-level `aa_model` (a fitted
    archetypal-analysis object); it is not passed in as an argument.

    Parameters
    ----------
    anchor_df : pd.DataFrame
        Dataframe of anchor spot indices (one column per cell type)

    r : int
        Resolution parameter to threshold archetype - anchor mapping
        (number of nearest neighbors retrieved per archetype)

    Returns
    -------
    overlaps_df : pd.DataFrame
        Number of spots shared between the anchors of cell type `i` (rows)
        and the r nearest neighbors of archetype `j` (columns)

    distance_df : pd.DataFrame
        Sum of Euclidean distances from the anchor spots of cell type `i`
        to archetype `j`

    map_dict : dict
        Cell type -> mapped archetype, based on largest overlap

    map_dict2 : dict
        Cell type -> mapped archetype, based on smallest summed distance
    """
    assert aa_model.arche_df is not None, "Please compute archetypes & assign nearest-neighbors first!"

    arche_labels = aa_model.arche_df.columns
    x_concat = np.vstack([aa_model.count, aa_model.archetype])
    anchor_nbrs = anchor_df.values
    # r-nearest nbrs to each archetype; archetype rows sit after the spots in x_concat
    archetypal_nbrs = aa_model._get_knns(x_concat, n_nbrs=r, indices=aa_model.n_spots+aa_model.major_idx).T

    print(archetypal_nbrs.shape)
    overlaps = np.array(
        [
            [
                len(np.intersect1d(anchor_nbrs[:, i], archetypal_nbrs[:, j]))
                for j in range(archetypal_nbrs.shape[1])
            ]
            for i in range(anchor_nbrs.shape[1])
        ]
    )
    overlaps_df = pd.DataFrame(overlaps, index=anchor_df.columns, columns=arche_labels)

    # Summed anchor-to-archetype distance for every (cell type, archetype) pair
    distances = np.zeros([anchor_df.shape[1], len(arche_labels)])
    for i in range(distances.shape[0]):
        for j in range(distances.shape[1]):
            distances[i, j] = sum(
                np.linalg.norm(aa_model.count[kk, :] - aa_model.archetype[j, :])
                for kk in anchor_df.iloc[:, i]
            )
    distance_df = pd.DataFrame(distances, index=anchor_df.columns, columns=arche_labels)

    # Larger overlap is better; smaller distance is better
    map_dict = _mutual_best_map(overlaps, anchor_df.columns, arche_labels, descending=True)
    map_dict2 = _mutual_best_map(distances, anchor_df.columns, arche_labels, descending=False)

    return overlaps_df, distance_df, map_dict, map_dict2

In [None]:
def append_sigs(gene_sig, factor, sigs, n_genes):
    """
    Extend the signature column `factor` with up to `n_genes` extra marker genes.

    Existing (non-NaN) genes of `gene_sig[factor]` come first, followed by the
    first `n_genes` entries of `sigs` that are not already present. The merged
    list is de-duplicated in that order, so the result is deterministic and,
    if it must be truncated, the original signature genes are the ones kept.

    Parameters
    ----------
    gene_sig : pd.DataFrame
        Signature table (one column per cell type); modified in place
    factor : str
        Column (cell type) to extend
    sigs : sequence
        Candidate marker genes, best first
    n_genes : int
        Maximum number of candidates taken from `sigs`

    Returns
    -------
    gene_sig : pd.DataFrame
        The same DataFrame object, with `factor` replaced by the merged list,
        truncated or NaN-padded to the number of rows
    """
    sigs = sigs[:n_genes]
    existing = [g for g in gene_sig[factor] if str(g) != 'nan']
    extra = [g for g in sigs if str(g) != 'nan']
    # dict.fromkeys de-duplicates while preserving insertion order (unlike set)
    merged = list(dict.fromkeys(existing + extra))

    # Adjust length to match the number of rows in gene_sig
    n_rows = gene_sig.shape[0]
    if len(merged) > n_rows:
        merged = merged[:n_rows]
    else:
        merged += [np.nan] * (n_rows - len(merged))

    gene_sig[factor] = merged
    return gene_sig

In [None]:
# Display the sample IDs listed in the sample sheet.
meta_info['sample']

In [None]:
import pickle
import json
import numpy as np

def save_data(gene_sig_all, indiv_args_dict, base_path, use_pickle=True):
    """
    Write the gene signatures and the per-sample arguments to `base_path`.

    Two files are created, `gene_sig_all.<ext>` and `indiv_args_dict.<ext>`,
    where `<ext>` is `pkl` or `json` depending on `use_pickle`.

    Parameters:
    -----------
    gene_sig_all : list
        List of gene signatures
    indiv_args_dict : dict
        Dictionary of individual arguments
    base_path : str
        Directory in which the files are written
    use_pickle : bool, optional
        True: pickle (handles arbitrary Python objects);
        False: JSON (portable, but objects must be JSON-serializable)
    """
    payloads = (
        ('gene_sig_all', gene_sig_all),
        ('indiv_args_dict', indiv_args_dict),
    )
    for stem, obj in payloads:
        if use_pickle:
            with open(f'{base_path}/{stem}.pkl', 'wb') as f:
                pickle.dump(obj, f)
        else:
            with open(f'{base_path}/{stem}.json', 'w') as f:
                json.dump(obj, f)

def load_data(base_path, use_pickle=True):
    """
    Read back the two files written by `save_data` from `base_path`.

    Only load pickle files you created yourself: unpickling executes
    arbitrary code embedded in the file.

    Parameters:
    -----------
    base_path : str
        Directory where the files are stored
    use_pickle : bool, optional
        True: read `*.pkl` with pickle; False: read `*.json` with json

    Returns:
    --------
    tuple
        (gene_sig_all, indiv_args_dict)
    """
    loaded = []
    for stem in ('gene_sig_all', 'indiv_args_dict'):
        if use_pickle:
            with open(f'{base_path}/{stem}.pkl', 'rb') as f:
                loaded.append(pickle.load(f))
        else:
            with open(f'{base_path}/{stem}.json', 'r') as f:
                loaded.append(json.load(f))

    gene_sig_all, indiv_args_dict = loaded
    return gene_sig_all, indiv_args_dict

# Example usage:
"""
# To save:
base_path = '/path/to/save/directory'
save_data(gene_sig_all, indiv_args_dict, base_path, use_pickle=True)

# To load:
gene_sig_all, indiv_args_dict = load_data(base_path, use_pickle=True)
"""

In [None]:
#gene_sig_all, indiv_args_dict = load_data(base_path, use_pickle=True)

In [None]:
# Build one `utils.VisiumArguments` object per sample and collect them in
# `indiv_args_dict` (keyed by sample ID).
import importlib
importlib.reload(utils)  # pick up local edits to starfysh.utils without restarting the kernel
gene_sig_all = []  # stays empty while the archetype refinement below is commented out
indiv_args_dict = {}

for sample_id in meta_info['sample']:
    print(sample_id)
    # Subset the concatenated objects to the current sample
    adata = adata_all[adata_all.obs['sample']==sample_id]
    adata_normed = adata_normed_all[adata_normed_all.obs['sample']==sample_id]
    
    # Recompute highly variable genes within this sample and copy the
    # resulting annotations onto the un-normalized subset
    sc.pp.highly_variable_genes(adata_normed)
    adata.uns = adata_normed.uns
    adata.var = adata_normed.var
    img_metadata = img_metadata_all[sample_id]
    
    args = utils.VisiumArguments(adata,
                                 adata_normed,
                                 gene_sig,
                                 img_metadata,
                                 window_size=1,
                                 sample_id=sample_id,
                                 n_img_chan=1
                                )

    # Disabled: archetype-based signature refinement. When enabled it fits an
    # archetypal-analysis model per sample, maps archetypes to cell types via
    # `assign_archetypes`, and appends archetype markers to short signatures
    # (< 50 genes) with `append_sigs`, storing the result in `gene_sig_all`.
    # adata, adata_normed = args.get_adata()
    # anchors_df = args.get_anchors()
    # aa_model = AA.ArchetypalAnalysis(adata_orig=adata
    #                                 )
    # archetype, arche_dict, major_idx, evs = aa_model.compute_archetypes(cn=100,
    #                                                                 converge=1e-3,
    #                                                                 display=False)

    # arche_df = aa_model.find_archetypal_spots()
    
    # markers_df = aa_model.find_markers(n_markers=30, display=False)
    
    # percent_anchor = 0.01
    # map_df, distance_df,map_dict,map_dict2 = assign_archetypes(anchors_df[:int(percent_anchor*adata.shape[0])], 
    #                                  r=int(percent_anchor*adata.shape[0])
    #                                 )  # here we subset the `r` top archetypes
    # gene_sig_arch = gene_sig.copy()


    # for cell_type in map_dict.keys():
    #     if (~pd.isna(gene_sig[cell_type])).sum() < 50:
    #         arch = map_dict[cell_type]
    #         gene_sig_arch = append_sigs(gene_sig_arch, cell_type, markers_df[arch], n_genes=30)
    # gene_sig_all.append(gene_sig_arch)
    indiv_args_dict[sample_id] = args

In [None]:
# Persist the signatures and per-sample arguments as pickle files under base_path.
# NOTE(review): the single `gene_sig` table is passed in the `gene_sig_all`
# slot, so `gene_sig_all.pkl` holds one DataFrame rather than a per-sample list.
save_data(gene_sig, indiv_args_dict, base_path)

In [None]:
#gene_sig_all

# Merge the per-sample refined signatures into one table (`gene_sig_new`),
# keeping the unique genes of every cell type across samples.
if gene_sig_all:
    # Outer-merge on all columns, one sample at a time
    gene_temp = gene_sig_all[0]
    for sample_sig in gene_sig_all[1:]:
        gene_temp = pd.merge(gene_temp, sample_sig, on=list(gene_temp.columns), how='outer')

    # One row per cell type before the transpose; columns get padded to equal length
    gene_sig_new = pd.DataFrame(
        [gene_temp[cell_type].unique() for cell_type in gene_sig.columns]
    ).transpose()
    # Label with the same columns that were iterated above
    gene_sig_new.columns = gene_sig.columns
else:
    # No refined signatures were produced (archetype refinement disabled):
    # fall back to the curated signatures instead of failing on an empty list.
    gene_sig_new = gene_sig.copy()

In [None]:
# Recompute highly variable genes on the concatenated normalized data and
# copy the resulting `uns` / `var` onto the un-normalized object.
sc.pp.highly_variable_genes(adata_normed_all)
adata_all.uns = adata_normed_all.uns
adata_all.var = adata_normed_all.var

# Bundle all samples into one integrated argument object. The curated
# `gene_sig` is used here; the merged `gene_sig_new` alternative is disabled.
integrated_args = utils_integrate.VisiumArguments_integrate(adata_all,
                                                            adata_normed_all,
                                                            #gene_sig_new,
                                                            gene_sig,
                                                            img_metadata_all,
                                                            indiv_args_dict,                                        
                                                            window_size=1,
                                                            sample_id=meta_info['sample']
                                                           )

In [None]:
# Display the integrated argument object.
integrated_args

In [None]:
import pickle
import os

# Location of the pickled integrated arguments (inside the data directory)
file_path = os.path.join(base_path, 'integrated_args.pkl')

In [None]:
# Save the integrated arguments so later sessions can skip preprocessing
with open(file_path, 'wb') as f:
    pickle.dump(integrated_args, f)

In [None]:
#if os.path.exists(file_path):
    #with open(file_path, 'rb') as f:
        #integrated_args = pickle.load(f)
#else:
    #print(f"File not found: {file_path}")

## Run Starfysh for joint spatial deconvolution & sample integration

### (1). Model parameters

In [None]:
# Training settings forwarded to `utils_integrate.run_starfysh`
n_repeats = 3   # passed as `n_repeats`
epochs = 100    # passed as `epochs`
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # GPU when available
poe_on = False  # passed as `poe`; also selects AVAE_PoE vs. AVAE when reloading the model

In [None]:
# Display the integrated argument object before training.
integrated_args

In [None]:
# Display the concatenated normalized AnnData object.
adata_normed_all

### (2). Model training

In [None]:
# Train Starfysh on the integrated data; returns the model and a loss record.
# NOTE(review): presumably the best of the `n_repeats` runs is returned — confirm in utils_integrate.
model, loss = utils_integrate.run_starfysh(integrated_args,
                                           n_repeats=n_repeats,
                                           epochs=epochs,
                                           poe=poe_on,
                                           device=device
                                          )

In [None]:
# Save only the model weights (state dict); the architecture must be rebuilt before loading
torch.save(model.state_dict(), os.path.join(base_path, 'integ_model_arch.pt'))

### (3). Downstream analysis

#### Parse Starfysh inference output

In [None]:
# Rebuild the model architecture that matches the saved weights, then load them.
poe_on = False

if poe_on:
    model = sf_model.AVAE_PoE(
        adata=integrated_args.adata,
        gene_sig=integrated_args.sig_mean_norm,
        patch_r=integrated_args.params['patch_r'],
        win_loglib=1,
        alpha_mul=50,
        n_img_chan=integrated_args.params['n_img_chan']
    )
else:
    model = sf_model.AVAE(
        adata=integrated_args.adata,
        gene_sig=integrated_args.sig_mean_norm,
        win_loglib=1,
        alpha_mul=50,
    )

model = model.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU machine can still be restored on a CPU-only machine.
state_dict = torch.load(os.path.join(base_path, 'integ_model_arch.pt'), map_location=device)
model.load_state_dict(state_dict)

# If running this snippet to load saved model parameters,
# need to update image patches to compute inference outputs
from starfysh.dataloader import IntegrativePoEDataset
trainset = IntegrativePoEDataset(adata=integrated_args.adata, args=integrated_args)
integrated_args._update_img_patches(trainset)

In [None]:
# Free cached GPU memory before running inference.
torch.cuda.empty_cache()

In [None]:
# Older call signature (two return values), kept for reference:
# inference_outputs, generative_outputs = sf_model.model_eval_integrate(model,
#                                                                       integrated_args.adata,
#                                                                       integrated_args,
#                                                                       poe=False,
#                                                                       device=device
#                                                                      )

# Evaluate the trained model on the integrated data. The third return value
# is bound to `adata_integrate_starfysh` and written to disk in a later cell.
inference_outputs, generative_outputs, adata_integrate_starfysh = sf_model.model_eval_integrate(model,
                                                                                      integrated_args.adata,
                                                                                      integrated_args,
                                                                                      poe=poe_on,
                                                                                      device=device
                                                                                     )

#### Visualize learnt joint embeddings

Compare UMAPs before & after sample integration

In [None]:
import pickle
import os

# Create a dictionary to store all outputs
outputs_dict = {
    'inference_outputs': inference_outputs,
    'generative_outputs': generative_outputs,
    'adata_integrate': adata_integrate_starfysh
}

# Save all outputs in one pickle file
with open(os.path.join(base_path, 'model_outputs.pkl'), 'wb') as f:
    pickle.dump(outputs_dict, f)

# Additionally, save the AnnData object separately. It gets its own file name:
# 'adata_integrate.h5ad' already holds the concatenated pre-model data written
# during loading, and writing here under that name would overwrite it.
adata_integrate_starfysh.write(os.path.join(base_path, 'adata_integrate_starfysh.h5ad'))

# To load later:
# with open(os.path.join(base_path, 'model_outputs.pkl'), 'rb') as f:
#     outputs_dict = pickle.load(f)
# inference_outputs = outputs_dict['inference_outputs']
# generative_outputs = outputs_dict['generative_outputs']
# adata_integrate = sc.read_h5ad(os.path.join(base_path, 'adata_integrate_starfysh.h5ad'))

In [None]:
# Work on an independent copy of the AnnData held by the integrated arguments.
adata_integrate = integrated_args.adata.copy()

In [None]:
# Display the dimensions (n_obs, n_vars) of the integrated object.
adata_integrate.shape

In [None]:
# Checkpoint the integrated object to disk.
adata_integrate.write(os.path.join(base_path, 'adata_integrate_after_starfysh.h5ad'))

In [None]:
#adata_integrate = anndata.read(os.path.join(base_path, 'adata_integrate_after_starfysh.h5ad'))
# This one does not work for some reason

In [None]:
# Reload the final integrated object. This file is written by the hub
# calculation cells further down, so it only exists after a previous full run.
base_path = '/Users/lingting/Documents/GVHD_project/visiumHD/data/'
# `anndata.read` is a deprecated alias; `read_h5ad` is the supported reader.
adata_integrate = anndata.read_h5ad(os.path.join(base_path, 'adata_integrate_final.h5ad'))

In [None]:
# Before integration: neighbor graph and UMAP computed directly on `adata_all`
sc.pp.neighbors(adata_all)
sc.tl.umap(adata_all)

In [None]:
import matplotlib as mpl

# Figure typography for the exported panels
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

meta = 'sample'
fig, ax = plt.subplots(1, 1, figsize=(3,3), dpi=300)

# Plot the spots in random order so no single sample is systematically drawn on top
shuffled_order = np.random.permutation(len(adata_all))
ax = sc.pl.umap(
    adata_all[shuffled_order],
    color=meta,
    frameon=False,
    s=2,
    ax=ax,
    title='UMAP before Starfysh sample integration',
)

# Export the same figure in every format
out_stem = f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/tissue_{meta}_integrated_all_umap_before_intergate'
for ext in ('pdf', 'png', 'svg'):
    fig.savefig(f'{out_stem}.{ext}', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# After integration: expose each column of the `qc_m` embedding as an obs
# column named after the corresponding signature (cell type), so it can be
# used as a plotting color. Requires `adata_integrate.obsm['qc_m']` to exist.
adata_integrate.obs[gene_sig.columns]=pd.DataFrame(adata_integrate.obsm['qc_m'],
                                                   columns=gene_sig.columns,
                                                   index=adata_integrate.obs_names)

In [None]:
#adata_sub = adata_integrate[adata_integrate.obs.index.isin(adata_sub.obs.index)]


In [None]:
# Display a summary of the integrated AnnData object.
adata_integrate

In [None]:
# Per-spot sum over all cell-type columns.
# NOTE(review): presumably a sanity check that proportions sum to ~1 — confirm.
adata_integrate.obs[gene_sig.columns].sum(axis = 1)

In [None]:
# Large UMAP of the integrated object colored by sample (display only, not saved).
# Requires UMAP coordinates to be present on `adata_integrate`.
fig, ax = plt.subplots(1, 1, figsize=(10,10), dpi=200)
ax = sc.pl.umap(adata_integrate,color='sample',
                frameon=False, s=15, ax=ax,
                title='UMAP after Starfysh sample integration'
                )

plt.show()

##### It's observed that Starfysh's integration through learnt `qc` embedding space corrected for batch-effects introduced by different samples / tumor subtypes via joint deconvolution

In [None]:
# Display the integrated AnnData object again (now including the cell-type obs columns).
adata_integrate

#### Display deconvolution profiles


In [None]:
# One UMAP per cell type, colored by that cell type's obs column and exported
# as pdf/png/svg.
for color in gene_sig.columns:

    fig, ax = plt.subplots(1, 1,figsize=(3, 3), dpi=500)
    sc.pl.umap(adata_integrate, color=color,
               frameon=False, vmax=0.2,  # color scale capped at 0.2
               s=10, cmap='Spectral_r', ax=ax,
               title=color
              )
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{color}_tissue_integrated_all_umap.{ext}', bbox_inches='tight', dpi=300)
    # Show once per figure, after every format has been written
    # (previously called inside the extension loop, i.e. three times per figure)
    plt.show()

#### Spatial Hub calculation

In [None]:
import scanpy.external as sce

# Copy the learnt embeddings from the model output onto the AnnData object
for emb_key in ('qz_m', 'qc_m'):
    adata_integrate.obsm[emb_key] = np.array(inference_outputs[emb_key].detach().cpu().numpy())

# Store `qc_m` under 'X_pca' as well.
# NOTE(review): presumably so that sc.pp.neighbors builds the graph on it — confirm.
adata_integrate.obsm['X_pca'] = adata_integrate.obsm['qc_m']

# The model outputs are no longer needed: drop them and reclaim memory
del inference_outputs

import gc
gc.collect()

# Neighbor graph, then Leiden clustering, with a checkpoint after each step
sc.pp.neighbors(adata_integrate)
adata_integrate.write(os.path.join(base_path, 'adata_integrate_final1.h5ad'))
sc.tl.leiden(adata_integrate)
adata_integrate.write(os.path.join(base_path, 'adata_integrate_final2.h5ad'))

In [None]:
# Calculate hubs via Phenograph clustering (Louvain, k=500 nearest neighbors)
sce.tl.phenograph(adata_integrate, clustering_algo="louvain", k=500)
# Expose the cluster labels as a categorical 'hub' column
adata_integrate.obs['hub'] = adata_integrate.obs['pheno_louvain'].astype('category')
adata_integrate.write(os.path.join(base_path, 'adata_integrate_final.h5ad'))

In [None]:
# Compute the UMAP on the integrated object and overwrite the final file so it includes the coordinates
sc.tl.umap(adata_integrate)
adata_integrate.write(os.path.join(base_path, 'adata_integrate_final.h5ad'))

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import scanpy as sc
import numpy as np

# Figure typography for the exported panels
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

# Custom palette: position in the list = hub index (0-34)
new_colormap = [
    'darkslateblue', 'cornflowerblue', 'red', 'blueviolet', 'skyblue',              # 0-4
    'orchid', 'yellowgreen', 'palevioletred', 'orange', 'cadetblue',                # 5-9
    'limegreen', 'cyan', 'gold', 'slategray', 'olive',                              # 10-14
    'blue', 'linen', 'mistyrose', 'peru', 'darkturquoise',                          # 15-19
    'teal', 'salmon', 'violet', 'dodgerblue', 'darkgreen',                          # 20-24
    'mediumaquamarine', 'tomato', 'sandybrown', 'darkkhaki', 'lightseagreen',       # 25-29
    'mediumorchid', 'crimson', 'olivedrab', 'steelblue', 'plum',                    # 30-34
]

# Color the UMAP by spatial hub
meta = 'hub'

fig, ax = plt.subplots(1, 1, figsize=(3,3), dpi=300)

# Plot the spots in random order so no single hub is systematically drawn on top
shuffled_order = np.random.permutation(len(adata_integrate))
ax = sc.pl.umap(
    adata_integrate[shuffled_order],
    color=meta,
    frameon=False,
    s=2,
    ax=ax,
    title='UMAP after Starfysh sample integration',
    palette=new_colormap,
)

# Export the same figure in every format
out_stem = f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/tissue_{meta}_integrated_all_umap'
for ext in ('pdf', 'png', 'svg'):
    fig.savefig(f'{out_stem}.{ext}', bbox_inches='tight', dpi=300)

plt.show()

# To color by another obs column, set `meta` to that column name and
# repeat the plotting code above.

In [None]:
import matplotlib as mpl

# Figure typography for the exported panels
mpl.rcParams.update({
    'font.size': 8,
    'svg.fonttype': 'none',
    'axes.titlesize': 8,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

# Color the UMAP by sample of origin
meta = 'sample'
fig, ax = plt.subplots(1, 1, figsize=(3,3), dpi=300)

# Plot the spots in random order so no single sample is systematically drawn on top
shuffled_order = np.random.permutation(len(adata_integrate))
ax = sc.pl.umap(
    adata_integrate[shuffled_order],
    color=meta,
    frameon=False,
    s=2,
    ax=ax,
    title='UMAP after Starfysh sample integration',
)

# Export the same figure in every format
out_stem = f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/tissue_{meta}_integrated_all_umap'
for ext in ('pdf', 'png', 'svg'):
    fig.savefig(f'{out_stem}.{ext}', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import scanpy as sc
import numpy as np

# Small fonts for the 3x3 inch panel; 'svg.fonttype': 'none' keeps text as text in SVG output.
mpl.rcParams.update({'font.size': 8, 'svg.fonttype': 'none'})
mpl.rcParams['axes.titlesize'] = 8
mpl.rcParams['xtick.labelsize'] = 8
mpl.rcParams['ytick.labelsize'] = 8

# Colour the integrated UMAP by the 'grade' annotation.
meta = 'grade'

# Fixed colour per grade category; edit the hex codes to restyle the figure.
grade_colors = {
    'ND': '#09BB8C',
    'Mild': '#005A8F',
    'Severe': '#B85000',
}

fig, ax = plt.subplots(1, 1, figsize=(3, 3), dpi=300)

# Draw the spots in random order so no grade is systematically painted on top.
shuffled = adata_integrate[np.random.permutation(len(adata_integrate))]

# show=False makes scanpy return the Axes (instead of None) and leaves the
# rendering to the single plt.show() below, after the figure has been saved.
ax = sc.pl.umap(
    shuffled,
    color=meta,
    frameon=False,
    s=2,
    ax=ax,
    title='UMAP after Starfysh sample integration',
    palette=grade_colors,
    show=False,
)

# Save the figure in every format needed downstream.
for ext in ['pdf', 'png', 'svg']:
    fig.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/tissue_{meta}_integrated_all_umap.{ext}',
        bbox_inches='tight',
        dpi=300)

plt.show()

In [None]:
# One UMAP panel per grade, each showing only the spots of that grade.
meta = 'grade'

unique_grades = adata_integrate.obs[meta].unique()

for grade in unique_grades:
    # Restrict to the spots annotated with the current grade.
    mask = adata_integrate.obs[meta] == grade
    adata_subset = adata_integrate[mask]

    fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=200)

    # Draw the spots in random order, as in the all-sample panels.
    shuffled = adata_subset[np.random.permutation(len(adata_subset))]

    # show=False makes scanpy return the Axes (instead of None) and leaves the
    # rendering to the plt.show() below, after the figure has been saved.
    ax = sc.pl.umap(
        shuffled,
        color=meta,
        frameon=False,
        s=15,
        ax=ax,
        title=f'UMAP after Starfysh integration - Grade {grade}',
        show=False,
    )

    # Save in multiple formats.
    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/tissue_{meta}_grade_{grade}_integrated_umap.{ext}',
            bbox_inches='tight',
            dpi=900
        )
    plt.show()

In [None]:
# Force the spot grid coordinates to integers; values that cannot be parsed
# as numbers are replaced by 0.
for coord_key in ('array_col', 'array_row'):
    numeric_coords = pd.to_numeric(adata_integrate.obs[coord_key], errors='coerce')
    adata_integrate.obs[coord_key] = numeric_coords.fillna(0).astype(int)

In [None]:
# Colour list indexed by position (36 named matplotlib colours).
new_colormap = [
    # 0-5
    'red', 'navy', 'forestgreen', 'magenta', 'gold', 'cyan',
    # 6-11
    'darkorange', 'darkmagenta', 'darkgreen', 'indigo', 'deeppink', 'darkturquoise',
    # 12-17
    'orangered', 'seagreen', 'darkviolet', 'darkcyan', 'peru', 'steelblue',
    # 18-23
    'saddlebrown', 'darkslateblue', 'hotpink', 'lightseagreen', 'goldenrod', 'purple',
    # 24-29
    'tomato', 'mediumturquoise', 'mediumvioletred', 'midnightblue', 'sandybrown', 'darkolivegreen',
    # 30-35
    'darkgoldenrod', 'sienna', 'darkkhaki', 'darkseagreen', 'rosybrown', 'mediumaquamarine',
]

In [None]:
# Display the summary of the integrated AnnData object.
adata_integrate

In [None]:
# Per-sample spatial map: every spot is drawn at its array coordinates and
# coloured by the hub it was assigned to.
s = 0.5  # marker size

for sample in meta_info['sample']:
    adata = adata_integrate[adata_integrate.obs['sample'] == sample]
    all_loc = adata.obs[['array_col', 'array_row']]
    fig, axs = plt.subplots(1, 1, figsize=(5, 5), dpi=400)
    color_list = np.array(adata.obs['hub'])

    # Iterate over the hub ids actually present in this sample. Counting the
    # unique hubs and looping over range(count) would silently drop any hub
    # whose id is >= count when a sample lacks one of the lower-numbered hubs.
    for hub in np.unique(color_list):
        in_hub = color_list == hub
        axs.scatter(
            all_loc.iloc[in_hub, 0],
            -all_loc.iloc[in_hub, 1],  # negate the row coordinate to flip the y axis
            s=s,
            c=new_colormap[int(hub)],  # hub id indexes the colour list
            edgecolors="none",
            label=str(hub),
        )
        # vmin/vmax are not passed: they only apply when `c` is numeric data
        # mapped through a colormap, not a single named colour.

    axs.set_xticks([])
    axs.set_yticks([])
    axs.axis("off")
    axs.legend(bbox_to_anchor=(1, 0.5))

    for ext in ['pdf', 'png', 'svg']:
        fig.savefig(
            f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/{sample}_tissue_hub_integrated.{ext}',
            bbox_inches='tight',
            dpi=900)

    # Render each sample's figure as it is finished instead of accumulating
    # open figures until the end of the loop.
    plt.show()

In [None]:
# For every signature (column of gene_sig), average the expression of its
# marker genes per spot and show that score on the integrated UMAP.
markers_df = gene_sig
select_sigs = markers_df.columns

for c in select_sigs:
    # Columns are padded with NaN when signatures differ in length; drop the padding.
    marker_genes = np.array(
        [gene for gene in markers_df[c].tolist() if str(gene) != 'nan'],
        dtype=str,
    )
    real_genes = np.intersect1d(adata_integrate.var.index, marker_genes)
    print(c)
    print(marker_genes)
    print("Genes that are in the adata")
    print(real_genes)

    score_key = str(c) + 'gene_expression'

    if real_genes.size == 0:
        # Nothing to average: record NaN explicitly rather than dividing by
        # zero, and skip the (empty) plot.
        adata_integrate.obs[score_key] = np.nan
        print("No marker genes of this signature are present; skipping plot")
        continue

    # Sum over the selected genes for each spot; np.array(...).flatten()
    # turns a possible (n, 1) matrix result into a flat vector.
    val = np.sum(adata_integrate[:, real_genes].X, axis=1)
    val = np.array(val).flatten()
    adata_integrate.obs[score_key] = val / len(real_genes)

    all_loc = adata_integrate.obsm['X_umap']
    fig, axs = plt.subplots(1, 1, figsize=(1.8, 2), dpi=400)
    g = axs.scatter(
        all_loc[:, 0],
        all_loc[:, 1],
        s=1,
        c=adata_integrate.obs[score_key],
    )
    axs.set_xticks([])
    axs.set_yticks([])
    axs.axis("off")
    fig.colorbar(g, label=c)

    # Render each signature's figure as it is finished instead of
    # accumulating open figures until the end of the loop.
    plt.show()

In [None]:
# Display the per-gene annotation table of the integrated object.
adata_integrate.var

In [None]:
# List the distinct hub labels present in the integrated object.
adata_integrate.obs['hub'].unique()

In [None]:
# Display the signature (marker gene) table.
gene_sig

In [None]:
# Display the per-spot sample labels.
adata_integrate.obs['sample']

In [None]:
# Display the colour list currently bound to new_colormap.
new_colormap

In [None]:
# Stacked bar chart: for every sample, the fraction of its spots assigned to each hub.
hub_type_list = [adata_integrate.obs['hub'].unique()]
hub_order = hub_type_list[0]  # hubs in order of first appearance in obs
proportions_df = adata_integrate.obs[gene_sig.columns]
sc.settings.set_figure_params(dpi=300, facecolor='white')

# Colour cycle ordered like the bar segments, so hub h is drawn in new_colormap[h].
plt.rcParams["axes.prop_cycle"] = plt.cycler(
    "color", [new_colormap[int(h)] for h in hub_order]
)

qc_m_df = proportions_df.copy()
qc_m_df['pheno_louvain'] = adata_integrate.obs['hub']
qc_m_df['sample'] = adata_integrate.obs['sample']
qc_m_df['sample_type'] = adata_integrate.obs['sample_type']
qc_m_df_reorder = qc_m_df.sort_values("pheno_louvain")
qc_m_df_reorder['count'] = 1

# Spot counts per (sample, hub); combinations that never occur count as 0.
agg_tips = qc_m_df_reorder.groupby(['sample', 'pheno_louvain'])['count'].sum().unstack().fillna(0)

# Normalise every sample (row) to sum to 1. A vectorised division always
# yields a float table, whereas writing float rows back into an
# integer-typed count table relies on implicit upcasting.
agg_tips = agg_tips.div(agg_tips.sum(axis=1), axis=0)

# Put the columns in the same order as the colour cycle defined above.
agg_tips = agg_tips[hub_order]

fig, ax = plt.subplots(figsize=(4, 4))
agg_tips.plot(kind='bar', stacked=True, ax=ax, linewidth=0)

# Set plot labels and limits
plt.xlabel('Samples')
plt.ylabel('Percent')
plt.ylim([0, 1])  # rows are normalised to sum to 1
plt.grid(False)

# Legend outside the axes, one entry per hub column
legend_labels = agg_tips.columns
plt.legend(legend_labels, loc='upper right', frameon=False, bbox_to_anchor=(1.5, 1.0), fontsize='small')

# Adjust layout and display the plot
plt.tight_layout()
plt.show()

In [None]:
# Stacked bar chart of the cell-type composition of every hub.

# Copy the per-spot values stored in obsm['qc_m'] into obs, one column per signature.
adata_integrate.obs[gene_sig.columns] = adata_integrate.obsm['qc_m']

# Work on an explicit copy so that adding the 'hub' column does not write
# into a slice of adata_integrate.obs.
proportions_df = adata_integrate.obs[gene_sig.columns].copy()
proportions_df['hub'] = adata_integrate.obs['hub']

# Sum the per-spot values within each hub, then rescale each hub (row) to sum to 1.
composition = proportions_df.groupby('hub').sum()
composition_normalized = composition.div(composition.sum(axis=1), axis=0)

# DataFrame.plot opens its own figure, so no separate plt.figure(...) call is
# made here: it would only leave an empty, unused figure behind.
composition_normalized.plot(kind='bar', stacked=True, color=new_colormap)
plt.xlabel('Spatial Spots (Hubs)')
plt.ylabel('Normalized Cell Type Proportions')
plt.title('Composition of Cell Types in Spatial Spots (Normalized)')
plt.legend(title='Cell Types', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=26)
plt.xticks(rotation=45, ha='right', fontsize=4)  # Rotate labels
plt.tight_layout()
for ext in ['pdf', 'png', 'svg']:
    plt.savefig(
        f'/Users/lingting/Documents/GVHD_project/Paper_ready_pipeline/Figures/spatial_figures/integrated_hub_bar_plot.{ext}',
        bbox_inches='tight',
        dpi=900)
plt.show()

In [None]:
# Store a string-typed copy of the hub labels under 'hub_str'.
adata_integrate.obs['hub_str'] = adata_integrate.obs['hub'].astype(str)

In [None]:
# UMAP coloured by the string-typed hub label. No `ax` is passed, so scanpy
# creates a fresh figure instead of drawing onto whatever Axes (or None) an
# earlier cell happened to leave bound to the global name `ax`.
sc.pl.umap(
    adata_integrate,
    color='hub_str',
    frameon=False,
    s=15,
    title='UMAP after Starfysh sample integration',
)

In [None]:
# UMAP coloured by sample. No `ax` is passed, so scanpy creates a fresh
# figure instead of depending on whatever an earlier cell left bound to the
# global name `ax`.
sc.pl.umap(
    adata_integrate,
    color='sample',
    frameon=False,
    s=0.2,
    title='UMAP after Starfysh sample integration',
)

In [None]:
# List the per-spot annotation columns currently available.
adata_integrate.obs.columns

In [None]:
# Build a combined "<sample>_<grade>" label for every spot.
sample_labels = adata_integrate.obs['sample'].astype(str)
grade_labels = adata_integrate.obs['grade'].astype(str)
adata_integrate.obs['SampleID_Grade'] = sample_labels + '_' + grade_labels

In [None]:
# Plot the UMAP coloured by the combined sample/grade label.
sc.pl.umap(adata_integrate, color='SampleID_Grade')

In [None]:
# Write the integrated AnnData object to an .h5ad file under base_path.
integrated_h5ad_path = os.path.join(base_path, 'adata_integrate_final.h5ad')
adata_integrate.write(integrated_h5ad_path)

---