In [1]:
import torch
import os
import scanpy as sp
import numpy as np
import pandas as pd
import torch_scatter
import utils
import argparse

parser = argparse.ArgumentParser(description='Main program for sencells')

parser.add_argument('--output_dir', type=str, default='./outputs', help='')
parser.add_argument('--exp_name', type=str, default='data2', help='')
parser.add_argument('--device_index', type=int, default=0, help='')
parser.add_argument('--retrain', action='store_true', default=False, help='')
parser.add_argument('--timestamp', type=str,  default="", help='use default')

parser.add_argument('--seed', type=int, default=40, help='different seed for different experiments')
parser.add_argument('--n_genes', type=str, default='full', help='set 3000, 8000 or full')
parser.add_argument('--ccc', type=str, default='type2', help='type1: cell-cell edge with weight in 0 and 1. type2: cell-cell edge with weight in 0 to 1. type3: no cell-cell edge')
parser.add_argument('--gene_set', type=str, default='full', help='senmayo or fridman or cellage or goterm or goterm+fridman or senmayo+cellage or senmayo+fridman or senmayo+fridman+cellage or full')

parser.add_argument('--gat_epoch', type=int, default=30, help='use default')


# --------------------------------------------------------------------------------------------------- #
# Write these code to fit our data input. This is for our @Yi, @Ahmed, and @Hu.
parser.add_argument('--input_data_count', type=str, default="/bmbl_data/huchen/deepSAS_data/deepSAS_data2.h5ad", help='it is a path to a adata object (.h5ad)')
parser.add_argument('--input_data_CCC_file', type=str, default="", help='it is a path to a CCC file (.csv or .npy)')
# --------------------------------------------------------------------------------------------------- #


# --------------------------------------------------------------------------------------------------- #
# Subsampling argument for our following version. Please check this @Yi, @Ahmed, and @Hu.
parser.add_argument('--subsampling', action='store_true', default=False, help='subsampling')
# --------------------------------------------------------------------------------------------------- #


# For @Hao
# Hao: Just delete these 3 parameters.
# --------------------------------------------------------------------------------------------------- #
# is this sencell_num parameter not used? Please check this @Hao.
parser.add_argument('--sencell_num', type=int, default=600, help='use default')
# is this sengene_num parameter not used? Please check this @Hao.
parser.add_argument('--sengene_num', type=int, default=200, help='use default')
# is this sencell_epoch parameter not used? Please check this @Hao.
parser.add_argument('--sencell_epoch', type=int, default=40, help='use default')
# --------------------------------------------------------------------------------------------------- #




parser.add_argument('--cell_optim_epoch', type=int, default=50, help='use default')


# For @Hao
# Hao: This is the emb size for GAT hidden embedding size
# --------------------------------------------------------------------------------------------------- #
# is this emb_size parameter for what, for GAT? is default 12? Please check this @Hao.
parser.add_argument('--emb_size', type=int, default=12, help='use default')
# --------------------------------------------------------------------------------------------------- #



parser.add_argument('--batch_id', type=int, default=0, help='use default')

args = parser.parse_args([])

adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data2(args.input_data_count)

new_data, markers_index,\
    sen_gene_ls, nonsen_gene_ls, gene_names = utils.process_data(
        adata, cluster_cell_ls, cell_cluster_arr,args)
print(new_data)

gene_cell = new_data.X.toarray().T

  from .autonotebook import tqdm as notebook_tqdm


load_data2 ...
Number of cells: 15246
Number of genes: 15822
The number of cell types 49
celltype names: ['Alveolar fibroblasts', 'AT2', 'CD4 T cells', 'Monocyte-derived Mph', 'EC general capillary', 'CD8 T cells', 'Classical monocytes', 'Alveolar macrophages', 'Pericytes', 'Plasma cells', 'Peribronchial fibroblasts', 'Interstitial Mph perivascular', 'EC aerocyte capillary', 'AT1', 'NK cells', 'Adventitial fibroblasts', 'Smooth muscle', 'EC arterial', 'Multiciliated (non-nasal)', 'Myofibroblasts', 'Mast cells', 'EC venous pulmonary', 'Non-classical monocytes', 'B cells', 'Basal resting', 'pre-TB secretory', 'Goblet (nasal)', 'Lymphatic EC differentiating', 'Plasmacytoid DCs', 'EC venous systemic', 'Suprabasal', 'AT2 proliferating', 'Lymphatic EC mature', 'AT0', 'Club (non-nasal)', 'Goblet (subsegmental)', 'T cells proliferating', 'Migratory DCs', 'DC2', 'SM activated stress response', 'Club (nasal)', 'Smooth muscle FAM83D+', 'Alveolar Mph proliferating', 'Hillock-like', 'Deuterosomal',

In [3]:
file_path = "/bmbl_data/huchen/sencell_data1/outputs/data2/data2_sencellgene-epoch4_repro.data"

# Load the saved object
loaded_data = torch.load(file_path)

# Unpack the loaded data
sencell_dict, sen_gene_ls, attention_scores, edge_index_selfloop = loaded_data

print(f"Number of SnCs: {len(sencell_dict.keys())}")
print(f"Number of SnGs: {len(sen_gene_ls)}")

Number of SnCs: 632
Number of SnGs: 297


In [8]:
file_path_repro = "/bmbl_data/huchen/sencell_data1/outputs/data2/data2_sencellgene-epoch4_repro.data"

# Load the saved object
loaded_data = torch.load(file_path_repro)

# Unpack the loaded data
sencell_dict_repro, sen_gene_ls_repro, attention_scores_repro, edge_index_selfloop_repro = loaded_data

print(f"Number of SnCs: {len(sencell_dict_repro.keys())}")
print(f"Number of SnGs: {len(sen_gene_ls_repro)}")

Number of SnCs: 746
Number of SnGs: 299


In [17]:
repeat_cells = len(set(sencell_dict.keys()) & set(sencell_dict_repro.keys()))
repeat_genes = len(set(sen_gene_ls_repro.tolist()) & set(sen_gene_ls.tolist()))
print(f"Overlap cells: {repeat_cells}, Overlap genes: {repeat_genes}")

Overlap cells: 702, Overlap genes: 267


In [6]:
sencell_indexs=list(sencell_dict.keys())
sencell_cluster = []
for i in sencell_indexs:
    ct=new_data.obs.iloc[i-new_data.shape[1]].ct
    sencell_cluster.append(ct)
sencell_df = pd.DataFrame({'sencell_index': sencell_indexs, 'sencell_cluster': sencell_cluster})
sencell_df.head()

Unnamed: 0,sencell_index,sencell_cluster
0,15835,Plasma cells
1,15858,Plasma cells
2,15901,Alveolar fibroblasts
3,15917,AT1
4,15920,B cells


In [7]:
def AttentionEachCell(gene_cell, sen_gene_ls, edge_index_selfloop):
    """
    Calculate SnCs score based on senescent gene list
    Same as in the main4.py
    """
    # Convert the list of senescent genes to a tensor index
    # gene_index = torch.tensor(sen_gene_ls, dtype=torch.long)
    gene_index = sen_gene_ls
    
    # Create a boolean mask for senescent genes``
    total_nodes = gene_cell.shape[0] + gene_cell.shape[1]
    gene_mask = torch.zeros(total_nodes, dtype=torch.bool)
    gene_mask[gene_index] = True

    # Identify edges where the target is a cell node
    cell_offset = gene_cell.shape[0]
    edge_mask_cell = edge_index_selfloop[1] >= cell_offset
    
    # Identify edges where the source is a senescent gene
    edge_mask_gene = gene_mask[edge_index_selfloop[0]]
    
    # Combined mask for edges from senescent genes to cells
    edge_mask_selected = edge_mask_cell & edge_mask_gene
    
    # Indices of selected edges
    edges_selected_indices = edge_mask_selected.nonzero().squeeze()

    # Target cell indices and attention scores for selected edges
    selected_edges_targets = edge_index_selfloop[1][edges_selected_indices]
    selected_attention_scores = attention_scores[edges_selected_indices].squeeze()
    
    # Adjust cell indices to start from 0
    cell_indices_in_range = selected_edges_targets - cell_offset
    
    # Number of cells
    num_cells = gene_cell.shape[1]
    
    # Sum and count attention scores per cell
    attention_sums = torch_scatter.scatter(selected_attention_scores, cell_indices_in_range,
                                           dim_size=num_cells, reduce='sum')
    attention_counts = torch_scatter.scatter(torch.ones_like(selected_attention_scores),
                                             cell_indices_in_range, dim_size=num_cells, reduce='sum')
    
    # Compute mean attention score per cell, handle division by zero
    attention_s_per_cell = torch.zeros(num_cells)
    valid_cells = attention_counts > 0
    # NOTE: Mean operation for each attention edge
    attention_s_per_cell[valid_cells] = attention_sums[valid_cells] / attention_counts[valid_cells]

    return attention_s_per_cell.detach().tolist()

attention_s_per_cell = AttentionEachCell(gene_cell, sen_gene_ls, edge_index_selfloop)

In [8]:
if_senCs = np.zeros(new_data.shape[0],dtype=np.int64)
if_senCs[[i - new_data.shape[1] for i in sencell_indexs]] = 1
list(if_senCs)

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,


In [9]:
SnC_scores = pd.DataFrame({'cell_id':list(range(new_data.shape[0])), 'cell_name': new_data.obs_names, 'ifSnCs': list(if_senCs), 'SnC scores': attention_s_per_cell})

In [10]:
SnC_scores

Unnamed: 0,cell_id,cell_name,ifSnCs,SnC scores
0,0,AAACAAGCAGCAATTGAAGTAGAG-1-OSU20126_Healthy_Old,0,0.002451
1,1,AAACCAATCATGCGTCAAGTAGAG-1-OSU20126_Healthy_Old,0,0.001460
2,2,AAACCAGGTAAAGCATAAGTAGAG-1-OSU20126_Healthy_Old,0,0.001370
3,3,AAACCAGGTATTACCAAAGTAGAG-1-OSU20126_Healthy_Old,0,0.001873
4,4,AAACCAGGTCAATTCAAAGTAGAG-1-OSU20126_Healthy_Old,0,0.001073
...,...,...,...,...
15241,15241,TTTGGACGTTTGACTAATCATGTG-1-OSU10172_IPF_LL,0,0.000257
15242,15242,TTTGGCGGTGTTTGCGATCATGTG-1-OSU10172_IPF_LL,0,0.000144
15243,15243,TTTGGCGGTTTGCTCCATCATGTG-1-OSU10172_IPF_LL,0,0.002481
15244,15244,TTTGTGAGTACAAAGTATCATGTG-1-OSU10172_IPF_LL,0,0.000442


In [11]:
SnC_scores.to_csv("new_output/data2_Cell_Table1_SnC_scores.csv", index=False)

In [12]:
sencell_df.head()

Unnamed: 0,sencell_index,sencell_cluster
0,15835,Plasma cells
1,15858,Plasma cells
2,15901,Alveolar fibroblasts
3,15917,AT1
4,15920,B cells


In [13]:
cluster_count = pd.DataFrame(new_data.obs['ct'].value_counts())
sencell_cluster_count = pd.DataFrame(sencell_df['sencell_cluster'].value_counts())
sencell_cluster_count = sencell_cluster_count['sencell_cluster']
merged_df = cluster_count.join(sencell_cluster_count).fillna(0)
merged_df = merged_df.rename(columns={'ct': 'Number of cells', 'sencell_cluster': "Number of SnCs"})
merged_df = merged_df.astype(int)
merged_df.index.name = "Cell type"

In [14]:
merged_df.to_csv("new_output/data2_Cell_Table2_SnCs_per_ct.csv")

In [15]:
ct_sencell_indexs={}
row_numbers=np.array(sencell_indexs)-new_data.shape[1]

# Dict {ct: cell_index}
for i in row_numbers:
    ct_=new_data.obs.iloc[i]['ct']
    if ct_ in ct_sencell_indexs:
        ct_sencell_indexs[ct_].append(i+new_data.shape[1])
    else:
        ct_sencell_indexs[ct_]=[i+new_data.shape[1]]

In [16]:
ct_sencell_indexs.keys()

dict_keys(['Plasma cells', 'Alveolar fibroblasts', 'AT1', 'B cells', 'Monocyte-derived Mph', 'CD8 T cells', 'Pericytes', 'CD4 T cells', 'EC general capillary', 'Classical monocytes', 'AT2', 'Alveolar macrophages', 'Peribronchial fibroblasts', 'EC aerocyte capillary', 'Mast cells', 'Myofibroblasts'])

In [17]:
def AttentionEachGene(gene_cell, cell_indices, edge_index_selfloop):
    """
    gene attention score from edges connected with cell_indices,
    the same as main4.py, but separate for each cell type.
    """
    num_genes = gene_cell.shape[0]
    num_cells = gene_cell.shape[1]
    total_nodes = num_genes + num_cells

    # Create a mask for the selected cells
    cell_mask = torch.zeros(total_nodes, dtype=torch.bool)
    cell_mask[cell_indices] = True

    # Create a mask for edges where the target node is a gene
    edge_mask_gene = edge_index_selfloop[1] < num_genes
    
    edge_mask_cell = cell_mask[edge_index_selfloop[0]]

    edge_mask_selected = edge_mask_gene & edge_mask_cell

    edges_selected_indices = edge_mask_selected.nonzero().squeeze()

    # Target cell indices and attention scores for selected edges
    selected_edges_targets = edge_index_selfloop[1][edges_selected_indices]
    selected_attention_scores = attention_scores[edges_selected_indices].squeeze()

    # Compute per-gene sums and counts using torch_scatter
    # Counts: Number of times each gene appears in masked_genes
    counts = torch_scatter.scatter(torch.ones_like(selected_attention_scores), selected_edges_targets,
                     dim=0, dim_size=num_genes, reduce='sum')
    # Sums: Sum of attention scores per gene
    sums = torch_scatter.scatter(selected_attention_scores, selected_edges_targets,
                   dim=0, dim_size=num_genes, reduce='sum')
      
    # Avoid division by zero
    res = torch.zeros(num_genes, dtype=torch.float32)
    nonzero_mask = counts > 0
    res[nonzero_mask] = sums[nonzero_mask] / counts[nonzero_mask] # Mean

    return res.detach()

ct2gene_score = {}
# SnGs for each cell type
for ct, cell_indices in ct_sencell_indexs.items():
    score_per_gene = AttentionEachGene(gene_cell, cell_indices, edge_index_selfloop)
    sene_gene_score = score_per_gene[sen_gene_ls].tolist()
    ct2gene_score[ct] = sene_gene_score

In [18]:
ct2gene_score_df = pd.DataFrame(ct2gene_score)
ct2gene_score_df

Unnamed: 0,Plasma cells,Alveolar fibroblasts,AT1,B cells,Monocyte-derived Mph,CD8 T cells,Pericytes,CD4 T cells,EC general capillary,Classical monocytes,AT2,Alveolar macrophages,Peribronchial fibroblasts,EC aerocyte capillary,Mast cells,Myofibroblasts
0,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014
1,0.000000,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014
2,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000000
3,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014
4,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
292,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
293,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000
294,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
295,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [19]:
# align gene index with gene names
var_df = new_data.var.reset_index()

sene_gene_names = []

for gene_idx in sen_gene_ls:
    gene_idx = gene_idx.item()
    gene_name = var_df.iloc[gene_idx]['index']
    sene_gene_names.append(gene_name)
ct2gene_score_df.index = sene_gene_names

In [20]:
ct2gene_score_df

Unnamed: 0,Plasma cells,Alveolar fibroblasts,AT1,B cells,Monocyte-derived Mph,CD8 T cells,Pericytes,CD4 T cells,EC general capillary,Classical monocytes,AT2,Alveolar macrophages,Peribronchial fibroblasts,EC aerocyte capillary,Mast cells,Myofibroblasts
TIMP2,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014
IFI16,0.000000,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014
PECAM1,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000000
MMP2,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014
THBS1,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014,0.000014,0.000014,0.000000,0.000014,0.000014
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
MGAM,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
CDC20B,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000
CYP4F3,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
KIAA1257,0.000000,0.000000,0.000000,0.000000,0.000016,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [21]:
ct2gene_score_df.to_csv("new_output/data2_Gene_Table1_SnG_scores_per_ct.csv")