In [12]:
# Accumulator filled by the experiment loop: one list of cell names per experiment.
results = list()

In [14]:
import numpy as np
import seaborn as sns
import torch

import umap
import matplotlib.pyplot as plt
import pandas as pd
from torch_geometric.utils import k_hop_subgraph, to_networkx, from_networkx
import matplotlib

import utils
import plots
from model_AE import reduction_AE
from model_GAT import Encoder, SenGAE, train_GAT, train_GAT_new
from model_Sencell import Sencell
from sampling import sub_sampling_by_random
from model_Sencell import cell_optim, update_cell_embeddings
from sampling import identify_sengene_then_sencell

import logging
import os
import argparse
import random


# When True, command-line parsing is bypassed and the arguments are filled in
# programmatically by the experiment loop.
is_jupyter = True

parser = argparse.ArgumentParser(description='Main program for sencells')

# (flag, keyword arguments) pairs, registered in this exact order so the
# generated usage/help text keeps the same option ordering.
_OPTION_SPECS = (
    # Run / environment options
    ('--output_dir', {'type': str, 'default': './outputs'}),
    ('--exp_name', {'type': str, 'default': ''}),
    ('--device_index', {'type': int, 'default': 0}),
    ('--retrain', {'action': 'store_true', 'default': False}),
    # Training hyper-parameters
    ('--gat_epoch', {'type': int, 'default': 30}),
    ('--sencell_num', {'type': int, 'default': 100}),
    ('--cell_optim_epoch', {'type': int, 'default': 15}),
)

for _flag, _spec in _OPTION_SPECS:
    parser.add_argument(_flag, help='', **_spec)

def _seed_everything(seed):
    """Seed the torch (CPU and CUDA), NumPy and `random` generators and pin cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this assignment only
    # reaches child processes, not the hash randomisation of this kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _select_loader(exp_name):
    """Return the `utils` loader whose key is a substring of `exp_name`.

    The keys are tested in a fixed order and the first hit wins, so
    'disease1' has to be tested before 'disease'.

    Raises:
        ValueError: if no key matches. Without this, the loop would silently
            reuse the dataset of the previous iteration (or fail later with
            a NameError on the first one).
    """
    if 's5' in exp_name:
        return utils.load_data
    if 'h1' in exp_name:
        return utils.load_data_healthy1
    if 'h2' in exp_name:
        return utils.load_data_healthy2
    if 'd1' in exp_name:
        return utils.load_data_d1
    if 'd2' in exp_name:
        return utils.load_data_d2
    if 'healthy' in exp_name:
        return utils.load_data_healthy
    if 'disease1' in exp_name:
        return utils.load_data_disease1
    if 'disease' in exp_name:
        return utils.load_data_disease
    if 'hd' in exp_name:
        return utils.load_data_healthy_disease
    raise ValueError(f"No data loader is registered for exp_name {exp_name!r}")


# Configure logging once; basicConfig does nothing once the root logger
# already has handlers, so repeating it per iteration had no effect.
logging.basicConfig(format='%(asctime)s.%(msecs)03d [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
                    datefmt='# %Y-%m-%d %H:%M:%S')
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger()

# Start from an empty accumulator so that re-running this cell does not
# append a second copy of every experiment.
results = []

for exp_name, output_dir in zip(['h1_7', 'h2_8', 'healthy_9'],
                                ['./outputs/7/', './outputs/8/', './outputs/9/']):

    if is_jupyter:
        # No command line in a notebook: take the defaults and override them.
        args = parser.parse_args(args=[])
        args.exp_name = exp_name
        args.output_dir = output_dir
        args.retrain = False
    else:
        # NOTE(review): on this path the loop variables are ignored, so the
        # same command-line experiment is processed three times - confirm intent.
        args = parser.parse_args()
    args.is_jupyter = is_jupyter

    # set random seed
    seed = 42
    _seed_everything(seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # Part 1: load and process data
    # cell_cluster_arr is used when plotting the UMAP
    print("\n====== Part 1: load and process data ======")
    adata, cluster_cell_ls, cell_cluster_arr, celltype_names = _select_loader(args.exp_name)()

    # plots.umapPlot(adata.obsm['X_umap'],clusters=cell_cluster_arr,labels=celltype_names)

    new_data, markers_index,\
        sen_gene_ls, nonsen_gene_ls, gene_names = utils.process_data(
            adata, cluster_cell_ls, cell_cluster_arr)

    # Transposed dense copy of new_data.X: genes on axis 0, cells on axis 1
    # (new_data.shape is printed below as (cell num, gene num)).
    gene_cell = new_data.X.toarray().T
    args.gene_num = gene_cell.shape[0]
    args.cell_num = gene_cell.shape[1]

    print(f'cell num: {new_data.shape[0]}, gene num: {new_data.shape[1]}')

    graph_nx = utils.build_graph_nx(
        new_data, gene_cell, cell_cluster_arr, sen_gene_ls, nonsen_gene_ls, gene_names)
    logger.info("Part 1, data loading and processing end!")

    # Load the artifacts saved for this experiment. os.path.join keeps the
    # paths valid when output_dir has no trailing slash (e.g. the parser
    # default './outputs').
    cell_score_dict, big_graph_index_dict = utils.load_objs(
        os.path.join(args.output_dir, f"{args.exp_name}_cell_score_dict"))

    file_path = os.path.join(args.output_dir, f'{args.exp_name}_outputs.data')
    # NOTE(review): torch.load unpickles arbitrary objects; only load files
    # produced by a trusted run of this project.
    sencell_dict, sen_gene_indexs = torch.load(file_path)

    # Map each key of sencell_dict to a graph node via big_graph_index_dict
    # and collect that node's 'name' attribute.
    cell_names = [graph_nx.nodes[big_graph_index_dict[i]]['name']
                  for i in sencell_dict.keys()]

    results.append(cell_names)


                     orig.ident  nCount_RNA  nFeature_RNA  percent.ribo  \
AAACCCAGTAGCTTAC-1_7    SCB01S9      7963.0          2384     23.320357   
AAACGAAAGTGCAAAT-1_7    SCB01S9     10125.0          3448     17.106173   
AAACGAACACCTCAGG-1_7    SCB01S9      7893.0          2779     23.083745   
AAACGCTCAAGACTGG-1_7    SCB01S9      1361.0           873      5.069802   
AAACGCTCACCCTATC-1_7    SCB01S9     14995.0          4031     19.093031   

                      percent.mito  count.mad.higher integrated_snn_res.0.2  \
AAACCCAGTAGCTTAC-1_7     10.573904             False                      4   
AAACGAAAGTGCAAAT-1_7     14.775309             False                      9   
AAACGAACACCTCAGG-1_7      6.398074             False                      4   
AAACGCTCAAGACTGG-1_7      8.670096             False                      1   
AAACGCTCACCCTATC-1_7      3.867956             False                      6   

                     seurat_clusters       cell_type_seurat  disease loca

  self.data[key] = value


highly_genes num:  2000
After highly genes dropped duplicate:  1897
Total gene num: 2240
cell num: 3801, gene num: 2240
The number of edges: 1234461


# 2023-01-04 09:51:55.435 [INFO] [3529215334.py:108] Part 1, data loading and processing end!



                     orig.ident  nCount_RNA  nFeature_RNA  percent.ribo  \
AAACCCAAGACCTCAT-1_7    SCB01S9      9608.0          2997     23.043297   
AAACCCAGTCTCGGAC-1_7    SCB01S9     15611.0          4070     20.491961   
AAACGAAAGGCAGGGA-1_7    SCB01S9     19211.0          4306     25.547863   
AAACGAACACAGCATT-1_7    SCB01S9      9237.0          3090     19.313630   
AAACGAACACGGCTAC-1_7    SCB01S9      9373.0          2766     25.658807   

                      percent.mito  count.mad.higher integrated_snn_res.0.2  \
AAACCCAAGACCTCAT-1_7      2.445878             False                      1   
AAACCCAGTCTCGGAC-1_7      3.330985             False                      1   
AAACGAAAGGCAGGGA-1_7      5.189735             False                      6   
AAACGAACACAGCATT-1_7     13.034535             False                      4   
AAACGAACACGGCTAC-1_7      3.606103             False                      0   

                     seurat_clusters       cell_type_seurat  disease loca

  self.data[key] = value


highly_genes num:  2000
After highly genes dropped duplicate:  1911
Total gene num: 2254
cell num: 3801, gene num: 2254
The number of edges: 1180068


# 2023-01-04 09:52:05.124 [INFO] [3529215334.py:108] Part 1, data loading and processing end!



cluster 数量： 13
celltype names: ['Fibroblasts', 'Basal', 'Unknown', 'Secretory', 'None', 'Myofibroblasts', 'Multiciliated lineage', 'Rare', 'Submucosal Secretory', 'Mast cells', 'EC capillary', 'EC arterial', 'T cell lineage']
---------------------  ----
Fibroblasts            2822
Basal                  2819
Unknown                1534
Secretory               154
None                    114
Myofibroblasts          109
Multiciliated lineage    42
Rare                      2
Submucosal Secretory      2
Mast cells                1
EC capillary              1
EC arterial               1
T cell lineage            1
---------------------  ----
各marker list所包含的gene数：
  Markers1    Markers2    Markers3    Markers4
----------  ----------  ----------  ----------
       126          78         145          84
total marker genes:  380
highly_genes num:  2000
After highly genes dropped duplicate:  1901
Total gene num: 2244
cell num: 7602, gene num: 2244
The number of edges: 2276908


# 2023-01-04 09:52:17.305 [INFO] [3529215334.py:108] Part 1, data loading and processing end!


In [22]:
def overlap_fraction(reference, other):
    """Return the share of items in `reference` that also occur in `other`.

    Args:
        reference: set whose size is the denominator.
        other: iterable of items to intersect with `reference`.

    Returns:
        A float in [0, 1], or NaN when `reference` is empty (the plain
        division would raise ZeroDivisionError in that case).
    """
    if not reference:
        return float('nan')
    return len(reference.intersection(other)) / len(reference)


# One set of cell names per experiment, in the order they were appended to results.
set1 = set(results[0])
set2 = set(results[1])
set3 = set(results[2])

# Fraction of the third experiment's cells that were also found in the first / second.
print(overlap_fraction(set3, set1))
print(overlap_fraction(set3, set2))

0.24675324675324675
0.03896103896103896


In [21]:
# Cell names present in both the third and the second result sets.
set3 & set2

{'CTCAGTCAGACAACAT-1_8', 'GTCCCATTCCTGGTCT-1_7', 'TGGTAGTTCTCCCATG-1_8'}