In [1]:
# Extract the numeric job ids from the scheduler's "Submitted batch job <id>"
# acknowledgement lines: the id is the last space-separated token of each line.
jobids = [
    line.rsplit(' ', 1)[-1]
    for line in """# Submitted batch job 14001082
# Submitted batch job 14001083
# Submitted batch job 14001084""".split("\n")
]

In [3]:
import numpy as np
import seaborn as sns
import torch

import umap
import matplotlib.pyplot as plt
import pandas as pd
from community import community_louvain
from torch_geometric.utils import k_hop_subgraph,to_networkx,from_networkx
import matplotlib

import utils
import plots
from model_AE import reduction_AE
from model_GAT import Encoder,SenGAE,train_GAT
from model_Sencell import Sencell
from sampling import sub_sampling_by_random
from model_Sencell import cell_optim,update_cell_embeddings
from sampling import identify_sengene_then_sencell

import logging
import os
import argparse

# Run-mode switch. When True the settings below are taken from the notebook
# overrides; when False they come from the real command line.
is_jupyter = True

parser = argparse.ArgumentParser(description='Main program for sencells')

parser.add_argument('--output_dir', type=str, default='./outputs', help='')
parser.add_argument('--exp_name', type=str, default='', help='')
parser.add_argument('--sencell_num', type=int, default=100, help='')
parser.add_argument('--retrain', action='store_true', default=False, help='')
parser.add_argument('--device_index', type=int, default=0,
                    help='CUDA device index used when a GPU is available')

if is_jupyter:
    # There is no usable command line inside a notebook: parse an empty argv
    # to obtain the defaults, then apply the settings of the run under
    # inspection. These overrides live ONLY in this branch so that, in script
    # mode, --exp_name / --output_dir given by the user are not clobbered.
    args = parser.parse_args(args=[])
    args.exp_name = 'disease1_14001082'
    args.output_dir = './outputs/14001082'
    args.retrain = False
else:
    args = parser.parse_args()
args.is_jupyter = is_jupyter

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(args.output_dir, exist_ok=True)

logging.basicConfig(format='%(asctime)s.%(msecs)03d [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
                    datefmt='# %Y-%m-%d %H:%M:%S')

# The level is set explicitly on the root logger rather than through
# basicConfig, which is a no-op when the root logger already has handlers.
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger()

# Part 1: load and process data
# cell_cluster_arr is used later when drawing the UMAP plot.
# The dataset is selected by substring match on exp_name. 'disease1' must be
# tested before 'disease', because every name containing 'disease1' also
# contains 'disease'.
if 's5' in args.exp_name:
    adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data()
elif 'healthy' in args.exp_name:
    adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data_healthy()
elif 'disease1' in args.exp_name:
    adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data_disease1()
elif 'disease' in args.exp_name:
    adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data_disease()
else:
    # Fail here with a clear message instead of a NameError on `adata` below.
    raise ValueError(
        f"Cannot pick a dataset for exp_name={args.exp_name!r}: expected it "
        "to contain one of 's5', 'healthy', 'disease1' or 'disease'.")

# plots.umapPlot(adata.obsm['X_umap'],clusters=cell_cluster_arr,labels=celltype_names)

new_data, markers_index,\
    sen_gene_ls, nonsen_gene_ls, gene_names = utils.process_data(
        adata, cluster_cell_ls, cell_cluster_arr)

# Dense matrix, transposed so that axis 0 indexes genes and axis 1 indexes
# cells (see the gene_num / cell_num assignments right below).
gene_cell = new_data.X.toarray().T
args.gene_num = gene_cell.shape[0]
args.cell_num = gene_cell.shape[1]

print(f'cell num: {new_data.shape[0]}, gene num: {new_data.shape[1]}')

graph_nx = utils.build_graph_nx(
    gene_cell, cell_cluster_arr, sen_gene_ls, nonsen_gene_ls, gene_names)
logger.info("Part 1, data loading and processing end!")

# Part 2: generate init embedding
# Use the configured CUDA device when one is available, otherwise the CPU.
device_name = f"cuda:{args.device_index}" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print('device:', device)
args.device = device

# Locations of the cached gene / cell embeddings for this experiment.
gene_emb_path = os.path.join(args.output_dir, f'{args.exp_name}_gene.emb')
cell_emb_path = os.path.join(args.output_dir, f'{args.exp_name}_cell.emb')

if not args.retrain:
    # Reuse the embeddings written by an earlier run.
    gene_embed = torch.load(gene_emb_path)
    cell_embed = torch.load(cell_emb_path)
else:
    # Recompute the embeddings with the autoencoder and refresh the cache.
    gene_embed, cell_embed = reduction_AE(gene_cell, device)
    print(gene_embed.shape, cell_embed.shape)
    torch.save(gene_embed, gene_emb_path)
    torch.save(cell_embed, cell_emb_path)

# Attach the embeddings to the networkx graph and build the PyG graph.
graph_nx = utils.add_nx_embedding(graph_nx, gene_embed, cell_embed)
graph_pyg = utils.build_graph_pyg(gene_cell, gene_embed, cell_embed)
logger.info("Part 2, AE end!")

# Part 3: train GAT
# (A disabled `graph_pyg.to('cpu')` step used to precede the call below.)
# args.retrain drives both the `retrain` and the `resampling` switches.
GAT_model = train_GAT(
    graph_nx,
    graph_pyg,
    args,
    retrain=args.retrain,
    resampling=args.retrain,
)
logger.info("Part 3, training GAT end!")


# Three independent, initially empty accumulators.
all_gene_ls, list_sencell_cover, list_sengene_cover = [], [], []

load data disease1!
cluster 数量： 22
celltype names: ['Macrophages', 'T cell lineage', 'Unknown', 'Monocytes', 'Innate lymphoid cell NK', 'AT2', 'Dendritic cells', 'Multiciliated lineage', 'Fibroblasts', 'B cell lineage', 'AT1', 'EC capillary', 'Lymphatic EC mature', 'Secretory', 'Mast cells', 'EC venous', 'Basal', 'EC arterial', 'Myofibroblasts', 'None', 'SM activated stress response', 'Lymphatic EC differentiating']
----------------------------  ----
Macrophages                   6381
T cell lineage                1010
Unknown                        812
Monocytes                      340
Innate lymphoid cell NK        251
AT2                            241
Dendritic cells                240
Multiciliated lineage          214
Fibroblasts                    195
B cell lineage                 124
AT1                            114
EC capillary                   100
Lymphatic EC mature             93
Secretory                       92
Mast cells                      81
EC venous           

  self.data[key] = value


highly_genes num:  2000
After highly genes dropped duplicate:  1926
Total gene num: 2279
cell num: 10445, gene num: 2279
The number of edges: 892449


# 2022-12-07 08:18:39.267 [INFO] [3258086314.py:81] Part 1, data loading and processing end!


device: cuda:0
the number of edges: 892449
edge index:  torch.Size([2, 892449])
node feature:  torch.Size([12724, 128])
Pyg graph: Data(x=[12724, 128], edge_index=[2, 1784898], y=[12724])


# 2022-12-07 08:18:40.166 [INFO] [3258086314.py:103] Part 2, AE end!
# 2022-12-07 08:18:40.212 [INFO] [3258086314.py:110] Part 3, training GAT end!


graph.is_directed(): False


In [4]:
# Select the compute device: the configured CUDA index if a GPU is available,
# the CPU otherwise. The choice is also stored on `args`.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device_index}")
else:
    device = torch.device("cpu")
print('device:', device)
args.device = device

device: cuda:0


In [5]:

# Cell-level model and its optimizer. The optimizer is handed to cell_optim,
# which is invoked with train=False below.
cellmodel = Sencell().to(device)
optimizer = torch.optim.Adam(cellmodel.parameters(), lr=0.001,
                             weight_decay=1e-3)

all_marker_index = sen_gene_ls

# Loop configuration.
N_ITERATIONS = 5   # outer loop: one random sub-sampling per iteration
N_EPOCHS = 10      # inner loop: refinement rounds on one sampled graph
# When True, stop after the first epoch of the first iteration. This keeps the
# previous single-pass behaviour, which used two unconditional `break`
# statements; those also made the append to `iteration_results` unreachable.
# Set to False to run every iteration and epoch.
SINGLE_PASS_DEBUG = True

# One [sen_gene_indexs, sencell_dict] entry per completed iteration.
iteration_results = []
for iteration in range(N_ITERATIONS):
    logger.info(f"iteration: {iteration}")
    sampled_graph, sencell_dict, nonsencell_dict, cell_clusters, big_graph_index_dict = sub_sampling_by_random(
        graph_nx,
        sen_gene_ls,
        nonsen_gene_ls,
        GAT_model,
        args,
        all_marker_index,
        n_gene=len(all_marker_index),
        gene_rate=0.3, cell_rate=0.5,
        debug=False)
    old_sengene_indexs = all_marker_index
    for epoch in range(N_EPOCHS):
        logger.info(f"epoch: {epoch}")
        # NOTE(review): this keeps a reference, not a copy; the cover ratio
        # below is only meaningful if the calls that follow return new
        # objects rather than mutating this one in place - confirm.
        old_sencell_dict = sencell_dict
        cellmodel, sencell_dict, nonsencell_dict = cell_optim(cellmodel, optimizer,
                                                              sencell_dict, nonsencell_dict, args,
                                                              train=False)
        # sampled_graph=update_cell_embeddings(sampled_graph,sencell_dict,nonsencell_dict)
        sencell_dict, nonsencell_dict, \
            sen_gene_indexs, nonsen_gene_indexs = identify_sengene_then_sencell(sampled_graph, GAT_model,
                                                                                sencell_dict, nonsencell_dict,
                                                                                cell_clusters,
                                                                                big_graph_index_dict,
                                                                                len(all_marker_index), args)

        # Overlap between the previous and the new selections.
        ratio_cell = utils.get_sencell_cover(old_sencell_dict, sencell_dict)
        ratio_gene = utils.get_sengene_cover(
            old_sengene_indexs, sen_gene_indexs)
        old_sengene_indexs = sen_gene_indexs
        if SINGLE_PASS_DEBUG:
            break

    # Record the result BEFORE possibly leaving the loop, so that even the
    # single-pass debug run ends up in iteration_results.
    iteration_results.append([sen_gene_indexs, sencell_dict])
    if SINGLE_PASS_DEBUG:
        break


# 2022-12-07 08:19:21.472 [INFO] [2833460121.py:9] iteration: 0


Start sampling subgraph randomly ...
    Sengene num: 353, Nonsengen num: 1926
subgraph total node num: (12724,)
After sampling, gene num:  tensor(2279)


# 2022-12-07 08:26:48.914 [INFO] [2833460121.py:22] epoch: 0


obj saved ./outputs/14001082/disease1_14001082_cell_score_dict_test
    Sencell num: 100, Nonsencell num: 1000
rechoice sengene num: 353 rechoice nonsengene num: 1926
obj saved ./outputs/14001082/disease1_14001082_cell_score_dict_test
    Sencell num: 100, Nonsencell num: 1000
sencell cover: 0.47
sengene cover: 0.3909348441926346
