In [1]:
import os
import warnings

import commot as ct
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import h5py

# Project folder. The default is hard-coded to one machine, so it can be
# overridden through the MRSA_KIDNEY_DIR environment variable without editing
# the notebook.
# NOTE(review): `path` is not referenced by this cell or by the sample-loop
# cell, which use paths relative to the working directory -- confirm whether it
# is still needed.
path = os.environ.get('MRSA_KIDNEY_DIR', 'C:/Users/kmccr/Desktop/MRSA Mouse Kidney for Takashi/')

In [None]:
# Sample IDs processed by the loop below. Each ID needs matching
# spatial_/expr_/meta_/scalefactors_ CSVs and an image_ .h5 file
# in 'exported from Seurat/'.
samp_list = [
    "A2",
    "A4",
    "S2",
    "S4",
]

# Per-sample COMMOT cell-cell communication (CCC) pipeline:
#   1. build an AnnData object from the CSVs exported from Seurat,
#   2. attach scale factors and the tissue image so scanpy/COMMOT can plot on it,
#   3. run COMMOT spatial communication with the CellChat mouse LR database,
#   4. rank the resulting communication matrices and save one CCC plot per entry.
# All inputs and outputs use paths relative to the current working directory.

# The CellChat database does not depend on the sample, so load it once.
# Each sample filters a fresh copy (see below), so no iteration can see
# another iteration's changes.
df_cellchat = ct.pp.ligand_receptor_database(database='CellChat', species='mouse')

# Prefix of the adata.obsp keys written by
# spatial_communication(database_name='cellchat').
commot_prefix = 'commot-cellchat-'

for samp_name in samp_list:

    ### 1. DATA
    # Import spatial coordinates, expression matrix and metadata
    spatial = pd.read_csv(f'exported from Seurat/spatial_{samp_name}.csv', index_col=0)
    # Drop the first 3 columns; they are assumed to be non-coordinate fields
    # of the Seurat export -- TODO confirm against the export script.
    spatial = spatial.iloc[:, 3:]

    expr = pd.read_csv(f'exported from Seurat/expr_{samp_name}.csv', index_col=0)
    # Strip the reference-genome prefix from gene names ('mm10-Xyz' -> 'Xyz')
    expr.index = expr.index.str.replace('^mm10-', '', regex=True)
    meta = pd.read_csv(f'exported from Seurat/meta_{samp_name}.csv', index_col=0)

    # The three tables are combined by POSITION below, so warn if their
    # barcodes are not identical and in the same order.
    if not (expr.columns.equals(meta.index) and spatial.index.equals(meta.index)):
        warnings.warn(
            f'{samp_name}: barcodes differ between expr columns, meta index and '
            'spatial index; rows are combined by position and may be misaligned.'
        )

    # Create the AnnData object: transpose so spots are rows and genes are columns
    adata = sc.AnnData(X=expr.T)
    adata.obs = meta  # per-spot metadata (its index becomes adata.obs_names)

    # Store the coordinates with the first two columns swapped (presumably
    # Seurat's row/col order -> the x/y order scanpy expects; TODO confirm).
    # Swap on a copy so the `spatial` DataFrame is not modified through a view.
    coords = spatial.to_numpy(copy=True)
    coords[:, [0, 1]] = coords[:, [1, 0]]
    adata.obsm["spatial"] = coords

    # Preprocessing: keep raw counts in .raw, then normalize and log-transform
    adata.var_names_make_unique()
    adata.raw = adata
    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)

    ### 2. SCALE FACTORS
    # One-row CSV -> {column: scalar}
    sf = pd.read_csv(f'exported from Seurat/scalefactors_{samp_name}.csv')
    sfdict = {key: values[0] for key, values in sf.to_dict(orient='list').items()}
    # The image below is stored under the 'hires' key, so give it the lowres
    # scale factor (presumably the exported image is the lowres one -- confirm).
    sfdict['hires'] = sfdict['lowres']
    # Rename to the scalefactor keys used in adata.uns["spatial"][...]["scalefactors"]
    key_map = {
        'spot': 'spot_diameter_fullres',
        'hires': 'tissue_hires_scalef',
        'fiducial': 'fiducial_diameter_fullres',
        'lowres': 'tissue_lowres_scalef'
    }
    sfdict = {key_map.get(k, k): float(v) for k, v in sfdict.items()}

    ### 3. IMAGE
    with h5py.File(f'exported from Seurat/image_{samp_name}.h5', "r") as f:
        image_array = np.array(f["dataset"])
    # Reverse the axis order; presumably the R-written dataset is stored as
    # (channel, width, height), giving (height, width, channel) -- TODO confirm.
    image_array = image_array.transpose(2, 1, 0)

    # Store the image and scale factors in the uns attribute of AnnData
    adata.uns["spatial"] = {}
    adata.uns["spatial"]["slice"] = {
        "images": {'hires': image_array},
        "scalefactors": sfdict
    }

    ### SPATIAL PLOT
    sc.pl.spatial(adata, color='transfer_subset', show=False)
    plt.title(f'{samp_name} Spatial Plot', fontsize=16)
    plt.savefig(f'{samp_name} Spatial Plot.png', dpi=300, bbox_inches='tight')
    plt.close()  # release the figure; the saved PNG is the output

    ### Spatial communication inference using CellChatDB ligand-receptor database
    # Keep only pairs whose ligand and receptor are expressed in >= 5% of spots.
    df_cellchat_filtered = ct.pp.filter_lr_database(df_cellchat.copy(), adata, min_cell_pct=0.05)

    ### RUN
    # dis_thr is presumably in the units of adata.obsm['spatial'] -- confirm.
    ct.tl.spatial_communication(adata, database_name='cellchat', df_ligrec=df_cellchat_filtered, dis_thr=500, heteromeric=True, pathway_sum=True)
    # Save the AnnData object to an HDF5 file. On a re-run the steps above can
    # be skipped by reloading it with sc.read(f'{samp_name} adata.h5ad').
    adata.write(f'{samp_name} adata.h5ad')

    ### Find top pathways and sort
    # Rank COMMOT/CellChat matrices in adata.obsp by total absolute signal.
    # abs(...).sum() works on sparse matrices directly, avoiding an
    # n_spots x n_spots dense copy per key. Keys without the COMMOT prefix are
    # skipped because they are not valid pathway names for the calls below.
    records = [
        {'Name': key, 'Element Sum': float(abs(adata.obsp[key]).sum())}
        for key in adata.obsp.keys()
        if key.startswith(commot_prefix)
    ]
    # Explicit columns keep sort_values valid even when no keys matched.
    sums_df = pd.DataFrame(records, columns=['Name', 'Element Sum'])
    sums_df = sums_df.sort_values(by='Element Sum', ascending=False)

    # Pathway list (names without the COMMOT prefix), strongest first
    pathways = [name[len(commot_prefix):] for name in sums_df['Name']]

    # Iterate over pathways and save CCC plots. savefig does not create
    # folders, so make sure the per-sample output folder exists.
    out_dir = f'CCC Plots from COMMOT/{samp_name}'
    os.makedirs(out_dir, exist_ok=True)

    for count, pathway in enumerate(pathways, start=1):
        ct.tl.communication_direction(adata, database_name='cellchat', pathway_name=pathway, k=5)
        ct.pl.plot_cell_communication(
            adata, database_name='cellchat', pathway_name=pathway, plot_method='stream',
            background_legend=True, scale=0.00002, ndsize=8, grid_density=0.5,
            summary='sender', background='image', clustering='transfer_subset',
            cmap='viridis', arrow_color="royalblue",
            normalize_v=True, normalize_v_quantile=0.995,
        )
        # Resize the figure after plotting
        fig = plt.gcf()
        fig.set_size_inches(12, 8)
        plt.title(f"{samp_name}, Cell Communication Plot for Pathway: {pathway}", fontsize=16)
        fig.savefig(f'{out_dir}/{samp_name} CCC Plot {count} - {pathway}.png', dpi=300)
        # Close each figure; otherwise every plot for every sample stays in memory.
        plt.close(fig)