In [98]:
import torch
import os
import scanpy as sp
import numpy as np
import pandas as pd
import torch_scatter
import utils
import argparse

parser = argparse.ArgumentParser(description='Main program for sencells')

parser.add_argument('--output_dir', type=str, default='./outputs', help='')
parser.add_argument('--exp_name', type=str, default='data1', help='')
parser.add_argument('--device_index', type=int, default=0, help='')
parser.add_argument('--retrain', action='store_true', default=False, help='')
parser.add_argument('--timestamp', type=str,  default="", help='use default')

parser.add_argument('--seed', type=int, default=40, help='different seed for different experiments')
parser.add_argument('--n_genes', type=str, default='full', help='set 3000, 8000 or full')
parser.add_argument('--ccc', type=str, default='type2', help='type1: cell-cell edge with weight in 0 and 1. type2: cell-cell edge with weight in 0 to 1. type3: no cell-cell edge')
parser.add_argument('--gene_set', type=str, default='full', help='senmayo or fridman or cellage or goterm or goterm+fridman or senmayo+cellage or senmayo+fridman or senmayo+fridman+cellage or full')

parser.add_argument('--gat_epoch', type=int, default=30, help='use default')


# --------------------------------------------------------------------------------------------------- #
# Write these code to fit our data input. This is for our @Yi, @Ahmed, and @Hu.
parser.add_argument('--input_data_count', type=str, default="/bmbl_data/huchen/deepSAS_data/fixed_data_0525.h5ad", help='it is a path to a adata object (.h5ad)')
parser.add_argument('--input_data_CCC_file', type=str, default="", help='it is a path to a CCC file (.csv or .npy)')
# --------------------------------------------------------------------------------------------------- #


# --------------------------------------------------------------------------------------------------- #
# Subsampling argument for our following version. Please check this @Yi, @Ahmed, and @Hu.
parser.add_argument('--subsampling', action='store_true', default=False, help='subsampling')
# --------------------------------------------------------------------------------------------------- #


# For @Hao
# Hao: Just delete these 3 parameters.
# --------------------------------------------------------------------------------------------------- #
# is this sencell_num parameter not used? Please check this @Hao.
parser.add_argument('--sencell_num', type=int, default=600, help='use default')
# is this sengene_num parameter not used? Please check this @Hao.
parser.add_argument('--sengene_num', type=int, default=200, help='use default')
# is this sencell_epoch parameter not used? Please check this @Hao.
parser.add_argument('--sencell_epoch', type=int, default=40, help='use default')
# --------------------------------------------------------------------------------------------------- #




parser.add_argument('--cell_optim_epoch', type=int, default=50, help='use default')


# For @Hao
# Hao: This is the emb size for GAT hidden embedding size
# --------------------------------------------------------------------------------------------------- #
# is this emb_size parameter for what, for GAT? is default 12? Please check this @Hao.
parser.add_argument('--emb_size', type=int, default=12, help='use default')
# --------------------------------------------------------------------------------------------------- #



parser.add_argument('--batch_id', type=int, default=0, help='use default')

args = parser.parse_args([])

adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data_newfix("/bmbl_data/huchen/deepSAS_data/fixed_data_0525.h5ad")

new_data, markers_index,\
    sen_gene_ls, nonsen_gene_ls, gene_names = utils.process_data(
        adata, cluster_cell_ls, cell_cluster_arr,args)
print(new_data)

gene_cell = new_data.X.toarray().T

load_data_newfix ...
Number of cells: 24125
Number of genes: 15844
The number of cell types 32
celltype names: ['Fibrotic fibroblast', 'AT2', 'EC Venous', 'EC Arterial', 'Dendritic cells', 'CD4+ T Cells', 'KRT5-/KRT17+ cells', 'EC General capillary', 'Monocyte-derived macrophage', 'CD8+ T Cells', 'Smooth muscle cells', 'Plasma cells', 'Ciliated', 'SPP1+ macrophages', 'MUC5B+ club', 'Alveolar fibroblasts', 'Alveolar macrophages', 'SCGB3A2+/SCGB1A1+ club', 'Monocyte', 'Mast cells', 'NK cells', 'Pericyte', 'Goblet', 'AT1', 'EC Lymphatic', 'B cells', 'Inflammatory fibroblasts', 'AT2 transitional', 'Peribronchial fibroblasts', 'Adventitial fibroblasts', 'Proliferating fibroblasts', 'Basal']
---------------------------  ----
Fibrotic fibroblast          2433
AT2                          2195
EC Venous                    2172
EC Arterial                  1847
Dendritic cells              1425
CD4+ T Cells                 1211
KRT5-/KRT17+ cells           1145
EC General capillary          950

In [102]:
file_path = "/bmbl_data/huchen/sencell_data1_base/outputs/data1/data1_sencellgene-epoch4base_decimal.data"

# Load the saved object
loaded_data = torch.load(file_path)

# Unpack the loaded data
sencell_dict, sen_gene_ls, attention_scores, edge_index_selfloop = loaded_data

print(f"Number of SnCs: {len(sencell_dict.keys())}")
print(f"Number of SnGs: {len(sen_gene_ls)}")

Number of SnCs: 784
Number of SnGs: 299


In [103]:
file_path_repro = "/bmbl_data/huchen/sencell_data1_repro/outputs/data1/data1_sencellgene-epoch4repro_decimal.data"

# Load the saved object
loaded_data = torch.load(file_path_repro)

# Unpack the loaded data
sencell_dict_repro, sen_gene_ls_repro, attention_scores_repro, edge_index_selfloop_repro = loaded_data

print(f"Number of SnCs: {len(sencell_dict_repro.keys())}")
print(f"Number of SnGs: {len(sen_gene_ls_repro)}")

Number of SnCs: 771
Number of SnGs: 299


In [104]:
repeat_cells = len(set(sencell_dict.keys()) & set(sencell_dict_repro.keys()))
repeat_genes = len(set(sen_gene_ls_repro.tolist()) & set(sen_gene_ls.tolist()))
print(f"Overlap cells: {repeat_cells}, Overlap genes: {repeat_genes}")

Overlap cells: 763, Overlap genes: 270


In [105]:
inital_masker = pd.read_csv("initial_sene_markers.csv")

In [106]:
base_filtered_gene_names = []
for gene_idx in sen_gene_ls:
    gene_idx = gene_idx.item()
    base_filtered_gene_names.append(new_data.var.index[gene_idx])
base_gene_idx2names = {'gene_idx': sen_gene_ls, 'gene_name': base_filtered_gene_names}
base_gene_idx2names_df = pd.DataFrame(base_gene_idx2names)
base_predict_sens_list = list(set(base_gene_idx2names_df['gene_name']).difference(set(inital_masker['gene_name'])))
print(base_predict_sens_list)
print(len(base_predict_sens_list))

['CERS1', 'FUT5', 'HEPN1', 'LAMB4', 'SCYGR8', 'ZNF80', 'ALX4', 'ACSM6', 'BRSK2', 'JAKMIP1', 'HOXA11', 'C12orf40', 'CD8B2', 'CALCA', 'CELA1', 'ASCL1', 'DEFB4B', 'PRSS51', 'SLC4A10', 'PBOV1', 'LDHAL6A', 'TMPRSS11E', 'L1CAM', 'C4orf50']
24


In [107]:
repro_filtered_gene_names = []
for gene_idx in sen_gene_ls_repro:
    gene_idx = gene_idx.item()
    repro_filtered_gene_names.append(new_data.var.index[gene_idx])
repro_gene_idx2names = {'gene_idx': sen_gene_ls, 'gene_name': repro_filtered_gene_names}
repro_gene_idx2names_df = pd.DataFrame(repro_gene_idx2names)
repro_predict_sens_list = list(set(repro_gene_idx2names_df['gene_name']).difference(set(inital_masker['gene_name'])))
print(repro_predict_sens_list)
print(len(base_predict_sens_list))

['CERS1', 'FUT5', 'LAMB4', 'SCYGR8', 'ZNF80', 'ALX4', 'ACSM6', 'BRSK2', 'JAKMIP1', 'HOXA11', 'C12orf40', 'CALCA', 'CILP2', 'CELA1', 'ASCL1', 'DEFB4B', 'SKA1', 'ITPRID1', 'SLC4A10', 'PBOV1', 'LDHAL6A', 'TMPRSS11E', 'L1CAM', 'C4orf50']
24


In [108]:
print(set(base_predict_sens_list) & set(repro_predict_sens_list))
print(len(set(base_predict_sens_list) & set(repro_predict_sens_list)))

{'CERS1', 'FUT5', 'LAMB4', 'SCYGR8', 'ZNF80', 'ALX4', 'ACSM6', 'BRSK2', 'JAKMIP1', 'HOXA11', 'C12orf40', 'CALCA', 'CELA1', 'ASCL1', 'DEFB4B', 'SLC4A10', 'PBOV1', 'LDHAL6A', 'TMPRSS11E', 'L1CAM', 'C4orf50'}
21


In [109]:
sencell_indexs=list(sencell_dict.keys())
sencell_cluster = []
for i in sencell_indexs:
    ct=new_data.obs.iloc[i-new_data.shape[1]].clusters
    sencell_cluster.append(ct)
sencell_df = pd.DataFrame({'sencell_index': sencell_indexs, 'sencell_cluster': sencell_cluster})
sencell_df.head()

Unnamed: 0,sencell_index,sencell_cluster
0,15851,SPP1+ macrophages
1,15853,SPP1+ macrophages
2,15870,AT2
3,15898,SPP1+ macrophages
4,15900,AT2


In [110]:
def AttentionEachCell(gene_cell, sen_gene_ls, edge_index_selfloop):
    """
    Calculate SnCs score based on senescent gene list
    Same as in the main4.py
    """
    # Convert the list of senescent genes to a tensor index
    # gene_index = torch.tensor(sen_gene_ls, dtype=torch.long)
    gene_index = sen_gene_ls
    
    # Create a boolean mask for senescent genes
    total_nodes = gene_cell.shape[0] + gene_cell.shape[1]
    gene_mask = torch.zeros(total_nodes, dtype=torch.bool)
    gene_mask[gene_index] = True

    # Identify edges where the target is a cell node
    cell_offset = gene_cell.shape[0]
    edge_mask_cell = edge_index_selfloop[1] >= cell_offset
    
    # Identify edges where the source is a senescent gene
    edge_mask_gene = gene_mask[edge_index_selfloop[0]]
    
    # Combined mask for edges from senescent genes to cells
    edge_mask_selected = edge_mask_cell & edge_mask_gene
    
    # Indices of selected edges
    edges_selected_indices = edge_mask_selected.nonzero().squeeze()

    # Target cell indices and attention scores for selected edges
    selected_edges_targets = edge_index_selfloop[1][edges_selected_indices]
    selected_attention_scores = attention_scores[edges_selected_indices].squeeze()
    
    # Adjust cell indices to start from 0
    cell_indices_in_range = selected_edges_targets - cell_offset
    
    # Number of cells
    num_cells = gene_cell.shape[1]
    
    # Sum and count attention scores per cell
    attention_sums = torch_scatter.scatter(selected_attention_scores, cell_indices_in_range,
                                           dim_size=num_cells, reduce='sum')
    attention_counts = torch_scatter.scatter(torch.ones_like(selected_attention_scores),
                                             cell_indices_in_range, dim_size=num_cells, reduce='sum')
    
    # Compute mean attention score per cell, handle division by zero
    attention_s_per_cell = torch.zeros(num_cells)
    valid_cells = attention_counts > 0
    # NOTE: Mean operation for each attention edge
    attention_s_per_cell[valid_cells] = attention_sums[valid_cells] / attention_counts[valid_cells]

    return attention_s_per_cell.detach().tolist()

attention_s_per_cell = AttentionEachCell(gene_cell, sen_gene_ls, edge_index_selfloop)

In [111]:
if_senCs = np.zeros(new_data.shape[0],dtype=np.int64)
if_senCs[[i - new_data.shape[1] for i in sencell_indexs]] = 1
list(if_senCs)

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,


In [112]:
SnC_scores = pd.DataFrame({'cell_id':list(range(new_data.shape[0])), 'cell_name': new_data.obs_names, 'cell_type':new_data.obs['clusters'], 'ifSnCs': list(if_senCs), 'SnC scores': attention_s_per_cell})

In [113]:
SnC_scores

Unnamed: 0,cell_id,cell_name,cell_type,ifSnCs,SnC scores
WL_20146_AAACAAGCAGTAATACATACGTCA-1,0,WL_20146_AAACAAGCAGTAATACATACGTCA-1,Monocyte-derived macrophage,0,0.000069
WL_20146_AAACCAGGTTAGGCGGATACGTCA-1,1,WL_20146_AAACCAGGTTAGGCGGATACGTCA-1,AT2,0,0.000062
WL_20146_AAACGGGCAAGCCACGATACGTCA-1,2,WL_20146_AAACGGGCAAGCCACGATACGTCA-1,SPP1+ macrophages,0,0.000070
WL_20146_AAACGTTCAACAAGTTATACGTCA-1,3,WL_20146_AAACGTTCAACAAGTTATACGTCA-1,AT2,0,0.000048
WL_20146_AAACTGTCAGAAACTTATACGTCA-1,4,WL_20146_AAACTGTCAGAAACTTATACGTCA-1,SPP1+ macrophages,0,0.000000
...,...,...,...,...,...
OSU10161_LL_TTTGAGAAGCAAGTTAATGTTGAC-1,24120,OSU10161_LL_TTTGAGAAGCAAGTTAATGTTGAC-1,Goblet,0,0.000061
OSU10161_LL_TTTGCGGGTCCGCTAAATGTTGAC-1,24121,OSU10161_LL_TTTGCGGGTCCGCTAAATGTTGAC-1,Smooth muscle cells,0,0.000059
OSU10161_LL_TTTGCGGGTGGTTCTGATGTTGAC-1,24122,OSU10161_LL_TTTGCGGGTGGTTCTGATGTTGAC-1,SCGB3A2+/SCGB1A1+ club,0,0.000055
OSU10161_LL_TTTGCTGAGTCTTGCAATGTTGAC-1,24123,OSU10161_LL_TTTGCTGAGTCTTGCAATGTTGAC-1,SCGB3A2+/SCGB1A1+ club,0,0.000056


In [114]:
SnC_scores.to_csv("new_output/data1_Cell_Table1_SnC_scores.csv", index=False)

In [115]:
new_data.obs = pd.merge(new_data.obs, SnC_scores, left_index=True, right_index=True)

In [116]:
new_data.obs.columns

Index(['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'Sample',
       'Treatment', 'Injury', 'Status', 'Area', 'Type', 'Age', 'Age_Status',
       'Sex', 'Treatment_Age_Status', 'Treatment_Age_Status_Sex',
       'Treatment_Age_Status_Sex_Area', 'Names_Sample', 'Names_Age_Status',
       'Names_Sample_Age_Status', 'nCount_SCT', 'nFeature_SCT',
       'integrated_snn_res.0.2', 'seurat_clusters', 'integrated_snn_res.0.5',
       'predicted.id', 'prediction.score.SCGB3A2..SCGB1A1..club',
       'prediction.score.Plasma.cells', 'prediction.score.CD4..T.Cells',
       'prediction.score.Goblet', 'prediction.score.B.cells',
       'prediction.score.EC.Arterial', 'prediction.score.Smooth.muscle.cells',
       'prediction.score.Ciliated',
       'prediction.score.Proliferating.fibroblasts', 'prediction.score.AT2',
       'prediction.score.SPP1..macrophages',
       'prediction.score.Monocyte.derived.macrophage',
       'prediction.score.Alveolar.fibroblasts', 'prediction.score.Mast.

In [117]:
new_data.obs['ifSnCs'] = new_data.obs['ifSnCs'].astype(str)

In [118]:
def DEGTable(new_data):
    # prepare adata for deg
    adata_deg=new_data.copy()
    sp.pp.normalize_total(adata_deg, target_sum=1e4)
    sp.pp.log1p(adata_deg)
    sp.pp.scale(adata_deg)

    cell_types = adata_deg.obs['clusters'].unique()
    for cell_type in cell_types:
        adata_deg_sub=adata_deg[adata_deg.obs['clusters']==cell_type].copy()
        value_counts = (adata_deg_sub.obs['ifSnCs'] == '1').sum()
        if value_counts <= 5:
            continue

        print(f"{cell_type}, Number of SnCs: {value_counts}")    
        sp.tl.rank_genes_groups(adata_deg_sub, groupby='ifSnCs', groups=["1"], 
                                reference="0", method='wilcoxon')
        # Extract the results into a DataFrame
        degs = pd.DataFrame({
            'gene': adata_deg_sub.uns['rank_genes_groups']['names']['1'],
            'p_val': adata_deg_sub.uns['rank_genes_groups']['pvals']['1'],
            'logFC': adata_deg_sub.uns['rank_genes_groups']['logfoldchanges']['1'],
            'p_val_adj': adata_deg_sub.uns['rank_genes_groups']['pvals_adj']['1']
        })
        
        degs=degs.sort_values(by='logFC')

        save_ct_name = cell_type.replace(' ', '_')
        save_ct_name = cell_type.replace('/', '_')
        print(save_ct_name)
        degs.to_csv(f"new_output/DEG_res/{save_ct_name}_DEG_results.csv", index=False)

DEGTable(new_data)

Monocyte-derived macrophage, Number of SnCs: 41


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Monocyte-derived macrophage
AT2, Number of SnCs: 120


  self.stats[group_name, 'logfoldchanges'] = np.log2(


AT2
SPP1+ macrophages, Number of SnCs: 17


  self.stats[group_name, 'logfoldchanges'] = np.log2(


SPP1+ macrophages
CD8+ T Cells, Number of SnCs: 10


  self.stats[group_name, 'logfoldchanges'] = np.log2(


CD8+ T Cells
EC Arterial, Number of SnCs: 33


  self.stats[group_name, 'logfoldchanges'] = np.log2(


EC Arterial
Plasma cells, Number of SnCs: 20


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Plasma cells
MUC5B+ club, Number of SnCs: 21


  self.stats[group_name, 'logfoldchanges'] = np.log2(


MUC5B+ club
Mast cells, Number of SnCs: 21


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Mast cells
CD4+ T Cells, Number of SnCs: 76


  self.stats[group_name, 'logfoldchanges'] = np.log2(


CD4+ T Cells
SCGB3A2+/SCGB1A1+ club, Number of SnCs: 21


  self.stats[group_name, 'logfoldchanges'] = np.log2(


SCGB3A2+_SCGB1A1+ club
Smooth muscle cells, Number of SnCs: 21


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Smooth muscle cells
Dendritic cells, Number of SnCs: 34


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Dendritic cells
Goblet, Number of SnCs: 11


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Goblet
Ciliated, Number of SnCs: 43


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Ciliated
Fibrotic fibroblast, Number of SnCs: 122


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Fibrotic fibroblast
Alveolar macrophages, Number of SnCs: 33


  self.stats[group_name, 'logfoldchanges'] = np.log2(


Alveolar macrophages
EC Venous, Number of SnCs: 48


  self.stats[group_name, 'logfoldchanges'] = np.log2(


EC Venous
KRT5-/KRT17+ cells, Number of SnCs: 55


  self.stats[group_name, 'logfoldchanges'] = np.log2(


KRT5-_KRT17+ cells
EC General capillary, Number of SnCs: 23


  self.stats[group_name, 'logfoldchanges'] = np.log2(


EC General capillary
Pericyte, Number of SnCs: 14
Pericyte


  self.stats[group_name, 'logfoldchanges'] = np.log2(


In [119]:
cluster_count = pd.DataFrame(new_data.obs['clusters'].value_counts())
sencell_cluster_count = pd.DataFrame(sencell_df['sencell_cluster'].value_counts())
sencell_cluster_count = sencell_cluster_count['sencell_cluster']
merged_df = cluster_count.join(sencell_cluster_count).fillna(0)
merged_df = merged_df.rename(columns={'clusters': 'number_of_cells', 'sencell_cluster': "number_of_SnCs"})
merged_df = merged_df.astype(int)
merged_df.index.name = "cell_type"

In [120]:
merged_df.to_csv("new_output/data1_Cell_Table2_SnCs_per_ct.csv")

In [121]:
ct_sencell_indexs={}
row_numbers=np.array(sencell_indexs)-new_data.shape[1]

# Dict {ct: cell_index}
for i in row_numbers:
    ct_=new_data.obs.iloc[i]['clusters']
    if ct_ in ct_sencell_indexs:
        ct_sencell_indexs[ct_].append(i+new_data.shape[1])
    else:
        ct_sencell_indexs[ct_]=[i+new_data.shape[1]]

In [122]:
ct_sencell_indexs.keys()

dict_keys(['SPP1+ macrophages', 'AT2', 'MUC5B+ club', 'Monocyte-derived macrophage', 'CD4+ T Cells', 'EC Arterial', 'CD8+ T Cells', 'Dendritic cells', 'Alveolar macrophages', 'Smooth muscle cells', 'Mast cells', 'KRT5-/KRT17+ cells', 'Fibrotic fibroblast', 'EC Venous', 'SCGB3A2+/SCGB1A1+ club', 'Plasma cells', 'Goblet', 'Ciliated', 'Pericyte', 'EC General capillary'])

Based on SnCs in each cell types, calculate SnG score for each cell type.

In [123]:
def AttentionEachGene(gene_cell, cell_indices, edge_index_selfloop):
    """
    gene attention score from edges connected with cell_indices,
    the same as main4.py, but separate for each cell type.
    """
    num_genes = gene_cell.shape[0]
    num_cells = gene_cell.shape[1]
    total_nodes = num_genes + num_cells

    # Create a mask for the selected cells
    cell_mask = torch.zeros(total_nodes, dtype=torch.bool)
    cell_mask[cell_indices] = True

    # Create a mask for edges where the target node is a gene
    edge_mask_gene = edge_index_selfloop[1] < num_genes
    
    edge_mask_cell = cell_mask[edge_index_selfloop[0]]

    edge_mask_selected = edge_mask_gene & edge_mask_cell

    edges_selected_indices = edge_mask_selected.nonzero().squeeze()

    # Target cell indices and attention scores for selected edges
    selected_edges_targets = edge_index_selfloop[1][edges_selected_indices]
    selected_attention_scores = attention_scores[edges_selected_indices].squeeze()

    # Compute per-gene sums and counts using torch_scatter
    # Counts: Number of times each gene appears in masked_genes
    counts = torch_scatter.scatter(torch.ones_like(selected_attention_scores), selected_edges_targets,
                     dim=0, dim_size=num_genes, reduce='sum')
    # Sums: Sum of attention scores per gene
    sums = torch_scatter.scatter(selected_attention_scores, selected_edges_targets,
                   dim=0, dim_size=num_genes, reduce='sum')
      
    # Avoid division by zero
    res = torch.zeros(num_genes, dtype=torch.float32)
    nonzero_mask = counts > 0
    res[nonzero_mask] = sums[nonzero_mask] / counts[nonzero_mask] # Mean

    return res.detach()

ct2gene_score = {}
# SnGs for each cell type
for ct, cell_indices in ct_sencell_indexs.items():
    score_per_gene = AttentionEachGene(gene_cell, cell_indices, edge_index_selfloop)
    sene_gene_score = score_per_gene[sen_gene_ls].tolist()
    ct2gene_score[ct] = sene_gene_score

In [124]:
ct2gene_score_df = pd.DataFrame(ct2gene_score)
ct2gene_score_df

Unnamed: 0,SPP1+ macrophages,AT2,MUC5B+ club,Monocyte-derived macrophage,CD4+ T Cells,EC Arterial,CD8+ T Cells,Dendritic cells,Alveolar macrophages,Smooth muscle cells,Mast cells,KRT5-/KRT17+ cells,Fibrotic fibroblast,EC Venous,SCGB3A2+/SCGB1A1+ club,Plasma cells,Goblet,Ciliated,Pericyte,EC General capillary
0,0.00000,0.000275,0.000275,0.000273,0.000267,0.000320,0.000270,0.000280,0.000265,0.000270,0.00028,0.000270,0.000273,0.000280,0.0000,0.00028,0.000280,0.000270,0.00028,0.000275
1,0.00028,0.000281,0.000280,0.000279,0.000263,0.000280,0.000290,0.000273,0.000268,0.000285,0.00029,0.000277,0.000276,0.000285,0.0000,0.00000,0.000285,0.000273,0.00028,0.000288
2,0.00029,0.000288,0.000275,0.000293,0.000269,0.000293,0.000000,0.000295,0.000270,0.000000,0.00028,0.000272,0.000282,0.000290,0.0000,0.00000,0.000000,0.000283,0.00029,0.000290
3,0.00031,0.000283,0.000270,0.000270,0.000278,0.000340,0.000285,0.000285,0.000267,0.000280,0.00000,0.000284,0.000279,0.000285,0.0000,0.00030,0.000290,0.000290,0.00028,0.000290
4,0.00000,0.000291,0.000278,0.000305,0.000277,0.000293,0.000000,0.000290,0.000278,0.000000,0.00030,0.000284,0.000285,0.000300,0.0003,0.00000,0.000000,0.000282,0.00000,0.000295
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
294,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.103037,0.106740,0.0000,0.00000,0.000000,0.000000,0.00000,0.000000
295,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.10181,0.000000,0.104755,0.000000,0.0000,0.00000,0.000000,0.000000,0.00000,0.000000
296,0.00000,0.106400,0.000000,0.000000,0.100573,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.106205,0.0000,0.00000,0.000000,0.000000,0.00000,0.000000
297,0.00000,0.000000,0.000000,0.000000,0.101590,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.102445,0.000000,0.0000,0.00000,0.000000,0.000000,0.10743,0.000000


In [125]:
# align gene index with gene names
var_df = new_data.var.reset_index()

sene_gene_names = []

for gene_idx in sen_gene_ls:
    gene_idx = gene_idx.item()
    gene_name = var_df.iloc[gene_idx]['index']
    sene_gene_names.append(gene_name)
ct2gene_score_df.index = sene_gene_names

In [126]:
ct2gene_score_df

Unnamed: 0,SPP1+ macrophages,AT2,MUC5B+ club,Monocyte-derived macrophage,CD4+ T Cells,EC Arterial,CD8+ T Cells,Dendritic cells,Alveolar macrophages,Smooth muscle cells,Mast cells,KRT5-/KRT17+ cells,Fibrotic fibroblast,EC Venous,SCGB3A2+/SCGB1A1+ club,Plasma cells,Goblet,Ciliated,Pericyte,EC General capillary
MAP2K3,0.00000,0.000275,0.000275,0.000273,0.000267,0.000320,0.000270,0.000280,0.000265,0.000270,0.00028,0.000270,0.000273,0.000280,0.0000,0.00028,0.000280,0.000270,0.00028,0.000275
SMURF2,0.00028,0.000281,0.000280,0.000279,0.000263,0.000280,0.000290,0.000273,0.000268,0.000285,0.00029,0.000277,0.000276,0.000285,0.0000,0.00000,0.000285,0.000273,0.00028,0.000288
MXD4,0.00029,0.000288,0.000275,0.000293,0.000269,0.000293,0.000000,0.000295,0.000270,0.000000,0.00028,0.000272,0.000282,0.000290,0.0000,0.00000,0.000000,0.000283,0.00029,0.000290
RAB5B,0.00031,0.000283,0.000270,0.000270,0.000278,0.000340,0.000285,0.000285,0.000267,0.000280,0.00000,0.000284,0.000279,0.000285,0.0000,0.00030,0.000290,0.000290,0.00028,0.000290
NEK6,0.00000,0.000291,0.000278,0.000305,0.000277,0.000293,0.000000,0.000290,0.000278,0.000000,0.00030,0.000284,0.000285,0.000300,0.0003,0.00000,0.000000,0.000282,0.00000,0.000295
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
PRSS51,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.103037,0.106740,0.0000,0.00000,0.000000,0.000000,0.00000,0.000000
CERS1,0.00000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.10181,0.000000,0.104755,0.000000,0.0000,0.00000,0.000000,0.000000,0.00000,0.000000
CELA1,0.00000,0.106400,0.000000,0.000000,0.100573,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.000000,0.106205,0.0000,0.00000,0.000000,0.000000,0.00000,0.000000
LAMB4,0.00000,0.000000,0.000000,0.000000,0.101590,0.000000,0.000000,0.000000,0.000000,0.000000,0.00000,0.000000,0.102445,0.000000,0.0000,0.00000,0.000000,0.000000,0.10743,0.000000


In [127]:
ct2gene_score_df.to_csv("new_output/data1_Gene_Table1_SnG_scores_per_ct.csv")

## gene table2

In [128]:
def GeneTable2(new_data, ct2gene_score_df):
    cell_types = new_data.obs['clusters'].unique()

    df_list = []
    for ct in cell_types:
        save_ct_name = ct.replace(' ', '_')
        save_ct_name = ct.replace('/', '_')
        degs_ct_path = f"new_output/DEG_res/{save_ct_name}_DEG_results.csv"
        if os.path.exists(degs_ct_path):
            degs_ct_df = pd.read_csv(degs_ct_path)
            sub_degs_ct_df = degs_ct_df[degs_ct_df['gene'].isin(ct2gene_score_df.index)]
            sub_degs_ct_df = sub_degs_ct_df[sub_degs_ct_df['logFC'] >= 0.25] # fliter DEG genes based on logFC
            sub_degs_ct_df['cell_type'] = ct
            sub_degs_ct_df[f'SnG_score'] = sub_degs_ct_df['gene'].map(ct2gene_score_df[ct].to_dict()) # SnG scores fo this cell_type
            df_list.append(sub_degs_ct_df)

    all_df = pd.concat(df_list)
    all_df = all_df[all_df["SnG_score"] != 0.]
    return all_df
total_df = GeneTable2(new_data, ct2gene_score_df)
total_df

Unnamed: 0,gene,p_val,logFC,p_val_adj,cell_type,SnG_score
6872,TES,0.828351,0.373577,1.0,Monocyte-derived macrophage,0.000463
7342,SPIN1,0.676307,0.469145,1.0,Monocyte-derived macrophage,0.000350
7935,MORC3,0.996902,0.614606,1.0,Monocyte-derived macrophage,0.000400
8205,SPP1,0.589848,0.678412,1.0,Monocyte-derived macrophage,0.003493
8302,CDKN2AIP,0.905149,0.707207,1.0,Monocyte-derived macrophage,0.001090
...,...,...,...,...,...,...
10966,ING1,0.439142,6.949714,1.0,Pericyte,0.002080
11000,ITPKA,0.662609,7.403682,1.0,Pericyte,0.028630
11028,EWSR1,0.791477,8.464504,1.0,Pericyte,0.000340
11035,LAMB4,0.662609,8.950857,1.0,Pericyte,0.107430


In [129]:
total_df.to_csv("new_output/data1_Gene_Table2_DEG_ct_SnG_score.csv", index=False)

## Gene Table 3

In [130]:
# Group by 'gene' and aggregate cell_type into a list or join them
grouped_df = total_df.groupby('gene').agg({
    'cell_type': lambda x: ', '.join(x.unique()),  # Combine unique cell_types
    'p_val': 'mean',   # Example of how you can aggregate other columns
    'logFC': 'mean',
    'p_val_adj': 'mean',
    'SnG_score': 'mean'
}).reset_index()

In [131]:
grouped_df

Unnamed: 0,gene,cell_type,p_val,logFC,p_val_adj,SnG_score
0,AAK1,"AT2, MUC5B+ club, Mast cells, Fibrotic fibrobl...",0.641095,0.977616,1.0,0.000396
1,ABI3,"EC Venous, EC General capillary",0.861669,0.667252,1.0,0.000798
2,ACSM6,Fibrotic fibroblast,0.765960,10.113629,1.0,0.106690
3,ACVR1B,"Alveolar macrophages, EC General capillary",0.905044,2.319039,1.0,0.000571
4,ALDH1A3,KRT5-/KRT17+ cells,0.619541,2.400855,1.0,0.002215
...,...,...,...,...,...,...
237,TUBGCP2,Goblet,0.646926,2.004862,1.0,0.000595
238,TYK2,"AT2, Plasma cells, EC Venous",0.434308,1.846329,1.0,0.000309
239,VEGFC,Fibrotic fibroblast,0.865018,0.552924,1.0,0.000781
240,ZNF148,"MUC5B+ club, Ciliated, EC General capillary",0.846931,1.382244,1.0,0.000735


In [132]:
grouped_df['hallmarker'] = grouped_df['gene'].isin(inital_masker['gene_name'])

In [133]:
grouped_df

Unnamed: 0,gene,cell_type,p_val,logFC,p_val_adj,SnG_score,hallmarker
0,AAK1,"AT2, MUC5B+ club, Mast cells, Fibrotic fibrobl...",0.641095,0.977616,1.0,0.000396,True
1,ABI3,"EC Venous, EC General capillary",0.861669,0.667252,1.0,0.000798,True
2,ACSM6,Fibrotic fibroblast,0.765960,10.113629,1.0,0.106690,False
3,ACVR1B,"Alveolar macrophages, EC General capillary",0.905044,2.319039,1.0,0.000571,True
4,ALDH1A3,KRT5-/KRT17+ cells,0.619541,2.400855,1.0,0.002215,True
...,...,...,...,...,...,...,...
237,TUBGCP2,Goblet,0.646926,2.004862,1.0,0.000595,True
238,TYK2,"AT2, Plasma cells, EC Venous",0.434308,1.846329,1.0,0.000309,True
239,VEGFC,Fibrotic fibroblast,0.865018,0.552924,1.0,0.000781,True
240,ZNF148,"MUC5B+ club, Ciliated, EC General capillary",0.846931,1.382244,1.0,0.000735,True


In [134]:
grouped_df.to_csv("new_output/data1_Gene_Table3_gene_ct_count.csv", index=False)