In [28]:
import torch
import os
import scanpy as sp
import numpy as np
import pandas as pd
import torch_scatter
import utils
import argparse

parser = argparse.ArgumentParser(description='Main program for sencells')

parser.add_argument('--output_dir', type=str, default='./outputs', help='')
parser.add_argument('--exp_name', type=str, default='data2', help='')
parser.add_argument('--device_index', type=int, default=0, help='')
parser.add_argument('--retrain', action='store_true', default=False, help='')
parser.add_argument('--timestamp', type=str,  default="", help='use default')

parser.add_argument('--seed', type=int, default=40, help='different seed for different experiments')
parser.add_argument('--n_genes', type=str, default='full', help='set 3000, 8000 or full')
parser.add_argument('--ccc', type=str, default='type2', help='type1: cell-cell edge with weight in 0 and 1. type2: cell-cell edge with weight in 0 to 1. type3: no cell-cell edge')
parser.add_argument('--gene_set', type=str, default='full', help='senmayo or fridman or cellage or goterm or goterm+fridman or senmayo+cellage or senmayo+fridman or senmayo+fridman+cellage or full')

parser.add_argument('--gat_epoch', type=int, default=30, help='use default')


# --------------------------------------------------------------------------------------------------- #
# Write these code to fit our data input. This is for our @Yi, @Ahmed, and @Hu.
parser.add_argument('--input_data_count', type=str, default="/bmbl_data/huchen/deepSAS_data/fixed_data_0525.h5ad", help='it is a path to a adata object (.h5ad)')
parser.add_argument('--input_data_CCC_file', type=str, default="", help='it is a path to a CCC file (.csv or .npy)')
# --------------------------------------------------------------------------------------------------- #


# --------------------------------------------------------------------------------------------------- #
# Subsampling argument for our following version. Please check this @Yi, @Ahmed, and @Hu.
parser.add_argument('--subsampling', action='store_true', default=False, help='subsampling')
# --------------------------------------------------------------------------------------------------- #


# For @Hao
# Hao: Just delete these 3 parameters.
# --------------------------------------------------------------------------------------------------- #
# is this sencell_num parameter not used? Please check this @Hao.
parser.add_argument('--sencell_num', type=int, default=600, help='use default')
# is this sengene_num parameter not used? Please check this @Hao.
parser.add_argument('--sengene_num', type=int, default=200, help='use default')
# is this sencell_epoch parameter not used? Please check this @Hao.
parser.add_argument('--sencell_epoch', type=int, default=40, help='use default')
# --------------------------------------------------------------------------------------------------- #




parser.add_argument('--cell_optim_epoch', type=int, default=50, help='use default')


# For @Hao
# Hao: This is the emb size for GAT hidden embedding size
# --------------------------------------------------------------------------------------------------- #
# is this emb_size parameter for what, for GAT? is default 12? Please check this @Hao.
parser.add_argument('--emb_size', type=int, default=12, help='use default')
# --------------------------------------------------------------------------------------------------- #



parser.add_argument('--batch_id', type=int, default=0, help='use default')

args = parser.parse_args([])

adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data_newfix("/bmbl_data/huchen/deepSAS_data/fixed_data_0525.h5ad")

new_data, markers_index,\
    sen_gene_ls, nonsen_gene_ls, gene_names = utils.process_data(
        adata, cluster_cell_ls, cell_cluster_arr,args)
print(new_data)

gene_cell = new_data.X.toarray().T

load_data_newfix ...
Number of cells: 24125
Number of genes: 15844
The number of cell types 32
celltype names: ['Fibrotic fibroblast', 'AT2', 'EC Venous', 'EC Arterial', 'Dendritic cells', 'CD4+ T Cells', 'KRT5-/KRT17+ cells', 'EC General capillary', 'Monocyte-derived macrophage', 'CD8+ T Cells', 'Smooth muscle cells', 'Plasma cells', 'Ciliated', 'SPP1+ macrophages', 'MUC5B+ club', 'Alveolar fibroblasts', 'Alveolar macrophages', 'SCGB3A2+/SCGB1A1+ club', 'Monocyte', 'Mast cells', 'NK cells', 'Pericyte', 'Goblet', 'AT1', 'EC Lymphatic', 'B cells', 'Inflammatory fibroblasts', 'AT2 transitional', 'Peribronchial fibroblasts', 'Adventitial fibroblasts', 'Proliferating fibroblasts', 'Basal']
---------------------------  ----
Fibrotic fibroblast          2433
AT2                          2195
EC Venous                    2172
EC Arterial                  1847
Dendritic cells              1425
CD4+ T Cells                 1211
KRT5-/KRT17+ cells           1145
EC General capillary          950

In [29]:
file_path = "/bmbl_data/huchen/sencell_data1/outputs/data1/data1_sencellgene-epoch4.data"

# Load the saved object
loaded_data = torch.load(file_path)

# Unpack the loaded data
sencell_dict, sen_gene_ls, attention_scores, edge_index_selfloop = loaded_data

print(f"Number of SnCs: {len(sencell_dict.keys())}")
print(f"Number of SnGs: {len(sen_gene_ls)}")

Number of SnCs: 777
Number of SnGs: 299


In [30]:
file_path_repro = "/bmbl_data/huchen/sencell_data1/outputs/data1/data1_sencellgene-epoch4_repro.data"

# Load the saved object
loaded_data = torch.load(file_path_repro)

# Unpack the loaded data
sencell_dict_repro, sen_gene_ls_repro, attention_scores_repro, edge_index_selfloop_repro = loaded_data

print(f"Number of SnCs: {len(sencell_dict_repro.keys())}")
print(f"Number of SnGs: {len(sen_gene_ls_repro)}")

Number of SnCs: 746
Number of SnGs: 299


In [31]:
repeat_cells = len(set(sencell_dict.keys()) & set(sencell_dict_repro.keys()))
repeat_genes = len(set(sen_gene_ls_repro.tolist()) & set(sen_gene_ls.tolist()))
print(f"Overlap cells: {repeat_cells}, Overlap genes: {repeat_genes}")

Overlap cells: 702, Overlap genes: 267


In [32]:
sencell_indexs=list(sencell_dict.keys())
sencell_cluster = []
for i in sencell_indexs:
    ct=new_data.obs.iloc[i-new_data.shape[1]].clusters
    sencell_cluster.append(ct)
sencell_df = pd.DataFrame({'sencell_index': sencell_indexs, 'sencell_cluster': sencell_cluster})
sencell_df.head()

Unnamed: 0,sencell_index,sencell_cluster
0,15851,SPP1+ macrophages
1,15853,SPP1+ macrophages
2,15870,AT2
3,15898,SPP1+ macrophages
4,15900,AT2


In [33]:
def AttentionEachCell(gene_cell, sen_gene_ls, edge_index_selfloop):
    """
    Calculate SnCs score based on senescent gene list
    Same as in the main4.py
    """
    # Convert the list of senescent genes to a tensor index
    # gene_index = torch.tensor(sen_gene_ls, dtype=torch.long)
    gene_index = sen_gene_ls
    
    # Create a boolean mask for senescent genes
    total_nodes = gene_cell.shape[0] + gene_cell.shape[1]
    gene_mask = torch.zeros(total_nodes, dtype=torch.bool)
    gene_mask[gene_index] = True

    # Identify edges where the target is a cell node
    cell_offset = gene_cell.shape[0]
    edge_mask_cell = edge_index_selfloop[1] >= cell_offset
    
    # Identify edges where the source is a senescent gene
    edge_mask_gene = gene_mask[edge_index_selfloop[0]]
    
    # Combined mask for edges from senescent genes to cells
    edge_mask_selected = edge_mask_cell & edge_mask_gene
    
    # Indices of selected edges
    edges_selected_indices = edge_mask_selected.nonzero().squeeze()

    # Target cell indices and attention scores for selected edges
    selected_edges_targets = edge_index_selfloop[1][edges_selected_indices]
    selected_attention_scores = attention_scores[edges_selected_indices].squeeze()
    
    # Adjust cell indices to start from 0
    cell_indices_in_range = selected_edges_targets - cell_offset
    
    # Number of cells
    num_cells = gene_cell.shape[1]
    
    # Sum and count attention scores per cell
    attention_sums = torch_scatter.scatter(selected_attention_scores, cell_indices_in_range,
                                           dim_size=num_cells, reduce='sum')
    attention_counts = torch_scatter.scatter(torch.ones_like(selected_attention_scores),
                                             cell_indices_in_range, dim_size=num_cells, reduce='sum')
    
    # Compute mean attention score per cell, handle division by zero
    attention_s_per_cell = torch.zeros(num_cells)
    valid_cells = attention_counts > 0
    # NOTE: Mean operation for each attention edge
    attention_s_per_cell[valid_cells] = attention_sums[valid_cells] / attention_counts[valid_cells]

    return attention_s_per_cell.detach().tolist()

attention_s_per_cell = AttentionEachCell(gene_cell, sen_gene_ls, edge_index_selfloop)

In [34]:
if_senCs = np.zeros(new_data.shape[0],dtype=np.int64)
if_senCs[[i - new_data.shape[1] for i in sencell_indexs]] = 1
list(if_senCs)

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,


In [35]:
SnC_scores = pd.DataFrame({'cell_id':list(range(new_data.shape[0])), 'cell_name': new_data.obs_names, 'cell_type':new_data.obs['clusters'], 'ifSnCs': list(if_senCs), 'SnC scores': attention_s_per_cell})

In [36]:
SnC_scores

Unnamed: 0,cell_id,cell_name,cell_type,ifSnCs,SnC scores
WL_20146_AAACAAGCAGTAATACATACGTCA-1,0,WL_20146_AAACAAGCAGTAATACATACGTCA-1,Monocyte-derived macrophage,0,0.001014
WL_20146_AAACCAGGTTAGGCGGATACGTCA-1,1,WL_20146_AAACCAGGTTAGGCGGATACGTCA-1,AT2,0,0.001170
WL_20146_AAACGGGCAAGCCACGATACGTCA-1,2,WL_20146_AAACGGGCAAGCCACGATACGTCA-1,SPP1+ macrophages,0,0.000972
WL_20146_AAACGTTCAACAAGTTATACGTCA-1,3,WL_20146_AAACGTTCAACAAGTTATACGTCA-1,AT2,0,0.001495
WL_20146_AAACTGTCAGAAACTTATACGTCA-1,4,WL_20146_AAACTGTCAGAAACTTATACGTCA-1,SPP1+ macrophages,0,0.000355
...,...,...,...,...,...
OSU10161_LL_TTTGAGAAGCAAGTTAATGTTGAC-1,24120,OSU10161_LL_TTTGAGAAGCAAGTTAATGTTGAC-1,Goblet,0,0.001106
OSU10161_LL_TTTGCGGGTCCGCTAAATGTTGAC-1,24121,OSU10161_LL_TTTGCGGGTCCGCTAAATGTTGAC-1,Smooth muscle cells,0,0.001995
OSU10161_LL_TTTGCGGGTGGTTCTGATGTTGAC-1,24122,OSU10161_LL_TTTGCGGGTGGTTCTGATGTTGAC-1,SCGB3A2+/SCGB1A1+ club,0,0.001412
OSU10161_LL_TTTGCTGAGTCTTGCAATGTTGAC-1,24123,OSU10161_LL_TTTGCTGAGTCTTGCAATGTTGAC-1,SCGB3A2+/SCGB1A1+ club,0,0.001913


In [37]:
SnC_scores.to_csv("new_output/data1_Cell_Table1_SnC_scores.csv", index=False)

In [38]:
cluster_count = pd.DataFrame(new_data.obs['clusters'].value_counts())
sencell_cluster_count = pd.DataFrame(sencell_df['sencell_cluster'].value_counts())
sencell_cluster_count = sencell_cluster_count['sencell_cluster']
merged_df = cluster_count.join(sencell_cluster_count).fillna(0)
merged_df = merged_df.rename(columns={'clusters': 'number_of_cells', 'sencell_cluster': "number_of_SnCs"})
merged_df = merged_df.astype(int)
merged_df.index.name = "cell_type"

In [39]:
merged_df.to_csv("new_output/data1_Cell_Table2_SnCs_per_ct.csv")

In [12]:
ct_sencell_indexs={}
row_numbers=np.array(sencell_indexs)-new_data.shape[1]

# Dict {ct: cell_index}
for i in row_numbers:
    ct_=new_data.obs.iloc[i]['clusters']
    if ct_ in ct_sencell_indexs:
        ct_sencell_indexs[ct_].append(i+new_data.shape[1])
    else:
        ct_sencell_indexs[ct_]=[i+new_data.shape[1]]

In [13]:
ct_sencell_indexs.keys()

dict_keys(['SPP1+ macrophages', 'AT2', 'MUC5B+ club', 'Monocyte-derived macrophage', 'CD4+ T Cells', 'EC Arterial', 'Dendritic cells', 'Fibrotic fibroblast', 'Alveolar macrophages', 'Smooth muscle cells', 'Mast cells', 'KRT5-/KRT17+ cells', 'EC Venous', 'SCGB3A2+/SCGB1A1+ club', 'Plasma cells', 'Goblet', 'Ciliated', 'Pericyte', 'EC General capillary'])

In [14]:
def AttentionEachGene(gene_cell, cell_indices, edge_index_selfloop):
    """
    gene attention score from edges connected with cell_indices,
    the same as main4.py, but separate for each cell type.
    """
    num_genes = gene_cell.shape[0]
    num_cells = gene_cell.shape[1]
    total_nodes = num_genes + num_cells

    # Create a mask for the selected cells
    cell_mask = torch.zeros(total_nodes, dtype=torch.bool)
    cell_mask[cell_indices] = True

    # Create a mask for edges where the target node is a gene
    edge_mask_gene = edge_index_selfloop[1] < num_genes
    
    edge_mask_cell = cell_mask[edge_index_selfloop[0]]

    edge_mask_selected = edge_mask_gene & edge_mask_cell

    edges_selected_indices = edge_mask_selected.nonzero().squeeze()

    # Target cell indices and attention scores for selected edges
    selected_edges_targets = edge_index_selfloop[1][edges_selected_indices]
    selected_attention_scores = attention_scores[edges_selected_indices].squeeze()

    # Compute per-gene sums and counts using torch_scatter
    # Counts: Number of times each gene appears in masked_genes
    counts = torch_scatter.scatter(torch.ones_like(selected_attention_scores), selected_edges_targets,
                     dim=0, dim_size=num_genes, reduce='sum')
    # Sums: Sum of attention scores per gene
    sums = torch_scatter.scatter(selected_attention_scores, selected_edges_targets,
                   dim=0, dim_size=num_genes, reduce='sum')
      
    # Avoid division by zero
    res = torch.zeros(num_genes, dtype=torch.float32)
    nonzero_mask = counts > 0
    res[nonzero_mask] = sums[nonzero_mask] / counts[nonzero_mask] # Mean

    return res.detach()

ct2gene_score = {}
# SnGs for each cell type
for ct, cell_indices in ct_sencell_indexs.items():
    score_per_gene = AttentionEachGene(gene_cell, cell_indices, edge_index_selfloop)
    sene_gene_score = score_per_gene[sen_gene_ls].tolist()
    ct2gene_score[ct] = sene_gene_score

In [15]:
ct2gene_score_df = pd.DataFrame(ct2gene_score)
ct2gene_score_df

Unnamed: 0,SPP1+ macrophages,AT2,MUC5B+ club,Monocyte-derived macrophage,CD4+ T Cells,EC Arterial,Dendritic cells,Fibrotic fibroblast,Alveolar macrophages,Smooth muscle cells,Mast cells,KRT5-/KRT17+ cells,EC Venous,SCGB3A2+/SCGB1A1+ club,Plasma cells,Goblet,Ciliated,Pericyte,EC General capillary
0,0.000228,0.000231,0.000226,0.000233,0.000220,0.000231,0.000233,0.000225,0.000000,0.000239,0.000228,0.000226,0.000232,0.000238,0.000227,0.000234,0.000242,0.000228,0.000230
1,0.000000,0.000233,0.000225,0.000237,0.000220,0.000238,0.000231,0.000227,0.000225,0.000227,0.000236,0.000227,0.000235,0.000236,0.000000,0.000000,0.000228,0.000226,0.000230
2,0.000239,0.000000,0.000000,0.000238,0.000230,0.000238,0.000238,0.000235,0.000234,0.000000,0.000000,0.000238,0.000240,0.000246,0.000000,0.000000,0.000000,0.000238,0.000000
3,0.000000,0.000242,0.000240,0.000249,0.000234,0.000248,0.000243,0.000240,0.000239,0.000243,0.000241,0.000240,0.000245,0.000000,0.000000,0.000253,0.000251,0.000241,0.000246
4,0.000244,0.000245,0.000243,0.000242,0.000229,0.000244,0.000241,0.000241,0.000239,0.000247,0.000245,0.000240,0.000245,0.000000,0.000000,0.000253,0.000248,0.000247,0.000248
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
294,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.101894,0.000000,0.000000,0.097682,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
295,0.000000,0.000000,0.100395,0.000000,0.000000,0.000000,0.000000,0.099176,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.100602,0.000000,0.000000
296,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.099513,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
297,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.099162,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [33]:
# align gene index with gene names
var_df = new_data.var.reset_index()

sene_gene_names = []

for gene_idx in sen_gene_ls:
    gene_idx = gene_idx.item()
    gene_name = var_df.iloc[gene_idx]['index']
    sene_gene_names.append(gene_name)
ct2gene_score_df.index = sene_gene_names

In [34]:
ct2gene_score_df

Unnamed: 0,SPP1+ macrophages,AT2,MUC5B+ club,Monocyte-derived macrophage,CD4+ T Cells,EC Arterial,Dendritic cells,Fibrotic fibroblast,Alveolar macrophages,Smooth muscle cells,Mast cells,KRT5-/KRT17+ cells,EC Venous,SCGB3A2+/SCGB1A1+ club,Plasma cells,Goblet,Ciliated,Pericyte,EC General capillary
OPTN,0.000228,0.000231,0.000226,0.000233,0.000220,0.000231,0.000233,0.000225,0.000000,0.000239,0.000228,0.000226,0.000232,0.000238,0.000227,0.000234,0.000242,0.000228,0.000230
NEK6,0.000000,0.000233,0.000225,0.000237,0.000220,0.000238,0.000231,0.000227,0.000225,0.000227,0.000236,0.000227,0.000235,0.000236,0.000000,0.000000,0.000228,0.000226,0.000230
TNFRSF1B,0.000239,0.000000,0.000000,0.000238,0.000230,0.000238,0.000238,0.000235,0.000234,0.000000,0.000000,0.000238,0.000240,0.000246,0.000000,0.000000,0.000000,0.000238,0.000000
RGL2,0.000000,0.000242,0.000240,0.000249,0.000234,0.000248,0.000243,0.000240,0.000239,0.000243,0.000241,0.000240,0.000245,0.000000,0.000000,0.000253,0.000251,0.000241,0.000246
SMURF2,0.000244,0.000245,0.000243,0.000242,0.000229,0.000244,0.000241,0.000241,0.000239,0.000247,0.000245,0.000240,0.000245,0.000000,0.000000,0.000253,0.000248,0.000247,0.000248
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CERS1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.101894,0.000000,0.000000,0.097682,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
ALX4,0.000000,0.000000,0.100395,0.000000,0.000000,0.000000,0.000000,0.099176,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.100602,0.000000,0.000000
BRSK2,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.099513,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
SCYGR8,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.099162,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [35]:
ct2gene_score_df.to_csv("new_output/data1_Gene_Table1_SnG_scores_per_ct.csv")