In [1]:
jobids="""# Submitted batch job 14001082
# Submitted batch job 14001083
# Submitted batch job 14001084"""
jobids=[i.split(' ')[-1] for i in jobids.split("\n")]

In [2]:
import numpy as np
import seaborn as sns
import torch

import umap
import matplotlib.pyplot as plt
import pandas as pd
from community import community_louvain
from torch_geometric.utils import k_hop_subgraph,to_networkx,from_networkx
import matplotlib

import utils
import plots
from model_AE import reduction_AE
from model_GAT import Encoder,SenGAE,train_GAT
from model_Sencell import Sencell
from sampling import sub_sampling_by_random
from model_Sencell import cell_optim,update_cell_embeddings
from sampling import identify_sengene_then_sencell

import logging
import os
import argparse

is_jupyter=True

parser = argparse.ArgumentParser(description='Main program for sencells')

parser.add_argument('--output_dir', type=str, default='./outputs', help='')
parser.add_argument('--exp_name', type=str, default='', help='')
parser.add_argument('--sencell_num', type=int, default=100, help='')
parser.add_argument('--retrain', action='store_true', default=False, help='')


if is_jupyter:
    args = parser.parse_args(args=[])
    args.exp_name='disease1'
    args.retrain=False  
else:
    args = parser.parse_args()
args.is_jupyter=is_jupyter

args.exp_name='disease1_14001082'
args.output_dir='./outputs/14001082'
args.device_index=0

if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)

logging.basicConfig(format='%(asctime)s.%(msecs)03d [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
                    datefmt='# %Y-%m-%d %H:%M:%S')

logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger()

# Part 1: load and process data
# cell_cluster_arr在画umap的时候用
if 's5' in args.exp_name:
    adata,cluster_cell_ls,cell_cluster_arr,celltype_names=utils.load_data()
elif 'healthy' in args.exp_name:
    adata,cluster_cell_ls,cell_cluster_arr,celltype_names=utils.load_data_healthy()
elif 'disease1' in args.exp_name:
    adata, cluster_cell_ls, cell_cluster_arr, celltype_names = utils.load_data_disease1()
elif 'disease' in args.exp_name:
    adata,cluster_cell_ls,cell_cluster_arr,celltype_names=utils.load_data_disease()

# plots.umapPlot(adata.obsm['X_umap'],clusters=cell_cluster_arr,labels=celltype_names)

new_data, markers_index,\
    sen_gene_ls, nonsen_gene_ls, gene_names = utils.process_data(
        adata, cluster_cell_ls, cell_cluster_arr)

gene_cell = new_data.X.toarray().T
args.gene_num = gene_cell.shape[0]
args.cell_num = gene_cell.shape[1]

print(f'cell num: {new_data.shape[0]}, gene num: {new_data.shape[1]}')

graph_nx = utils.build_graph_nx(
    gene_cell, cell_cluster_arr, sen_gene_ls, nonsen_gene_ls, gene_names)
logger.info("Part 1, data loading and processing end!")

# Part 2: generate init embedding
device = torch.device(f"cuda:{args.device_index}" if torch.cuda.is_available() else "cpu")
print('device:', device)
args.device = device

if args.retrain:
    gene_embed, cell_embed = reduction_AE(gene_cell, device)
    print(gene_embed.shape, cell_embed.shape)
    torch.save(gene_embed, os.path.join(
        args.output_dir, f'{args.exp_name}_gene.emb'))
    torch.save(cell_embed, os.path.join(
        args.output_dir, f'{args.exp_name}_cell.emb'))
else:
    gene_embed = torch.load(os.path.join(
        args.output_dir, f'{args.exp_name}_gene.emb'))
    cell_embed = torch.load(os.path.join(
        args.output_dir, f'{args.exp_name}_cell.emb'))

graph_nx = utils.add_nx_embedding(graph_nx, gene_embed, cell_embed)
graph_pyg = utils.build_graph_pyg(gene_cell, gene_embed, cell_embed)
logger.info("Part 2, AE end!")

# Part 3: train GAT
# graph_pyg=graph_pyg.to('cpu')

GAT_model = train_GAT(graph_nx, graph_pyg, args,
                      retrain=args.retrain, resampling=args.retrain)
logger.info("Part 3, training GAT end!")


all_gene_ls = []

list_sencell_cover = []
list_sengene_cover = []

cluster 数量： 24
celltype names: ['Macrophages', 'T cell lineage', 'Monocytes', 'Unknown', 'B cell lineage', 'Innate lymphoid cell NK', 'Dendritic cells', 'Multiciliated lineage', 'AT2', 'Secretory', 'Mast cells', 'Fibroblasts', 'EC capillary', 'EC venous', 'Lymphatic EC mature', 'AT1', 'Basal', 'EC arterial', 'Myofibroblasts', 'None', 'Submucosal Secretory', 'Lymphatic EC differentiating', 'SM activated stress response', 'Rare']
----------------------------  -----
Macrophages                   18151
T cell lineage                 5401
Monocytes                      4842
Unknown                        2830
B cell lineage                 1540
Innate lymphoid cell NK        1324
Dendritic cells                 772
Multiciliated lineage           738
AT2                             724
Secretory                       389
Mast cells                      366
Fibroblasts                     337
EC capillary                    277
EC venous                       186
Lymphatic EC mature         

  self.data[key] = value


highly_genes num:  2000
After highly genes dropped duplicate:  1924
Total gene num: 2286
cell num: 38407, gene num: 2286
The number of edges: 3688551


# 2022-12-07 08:14:59.162 [INFO] [2682742472.py:79] Part 1, data loading and processing end!


device: cuda:0
the number of edges: 3688551
edge index:  torch.Size([2, 3688551])
node feature:  torch.Size([12724, 128])
Pyg graph: Data(x=[12724, 128], edge_index=[2, 7377102], y=[40693])


RuntimeError: The size of tensor a (7061225) must match the size of tensor b (6745357) at non-singleton dimension 0

In [None]:
device = torch.device(f"cuda:{args.device_index}" if torch.cuda.is_available() else "cpu")
print('device:', device)
args.device = device

In [4]:

cellmodel = Sencell().to(device)
optimizer = torch.optim.Adam(cellmodel.parameters(), lr=0.001,
                             weight_decay=1e-3)

all_marker_index = sen_gene_ls

iteration_results = []
for iteration in range(5):
    logger.info(f"iteration: {iteration}")
    sampled_graph, sencell_dict, nonsencell_dict, cell_clusters, big_graph_index_dict = sub_sampling_by_random(graph_nx,
                                                                                                               sen_gene_ls,
                                                                                                               nonsen_gene_ls,
                                                                                                               GAT_model,
                                                                                                               args,
                                                                                                               all_marker_index,
                                                                                                               n_gene=len(
                                                                                                                   all_marker_index),
                                                                                                               gene_rate=0.3, cell_rate=0.5,
                                                                                                               debug=False)
    old_sengene_indexs = all_marker_index
    for epoch in range(10):
        logger.info(f"epoch: {epoch}")
        old_sencell_dict = sencell_dict
        cellmodel, sencell_dict, nonsencell_dict = cell_optim(cellmodel, optimizer,
                                                              sencell_dict, nonsencell_dict, args,
                                                              train=False)
        # sampled_graph=update_cell_embeddings(sampled_graph,sencell_dict,nonsencell_dict)
        sencell_dict, nonsencell_dict, \
            sen_gene_indexs, nonsen_gene_indexs = identify_sengene_then_sencell(sampled_graph, GAT_model,
                                                                                sencell_dict, nonsencell_dict,
                                                                                cell_clusters,
                                                                                big_graph_index_dict,
                                                                                len(all_marker_index), args)

        ratio_cell = utils.get_sencell_cover(old_sencell_dict, sencell_dict)
        ratio_gene = utils.get_sengene_cover(
            old_sengene_indexs, sen_gene_indexs)
        old_sengene_indexs = sen_gene_indexs
        break
    break

    iteration_results.append([sen_gene_indexs, sencell_dict])


# 2022-12-06 13:23:44.652 [INFO] [2833460121.py:9] iteration: 0


Start sampling subgraph randomly ...
    Sengene num: 362, Nonsengen num: 1924
subgraph total node num: (40693,)
After sampling, gene num:  tensor(2286)


# 2022-12-06 13:55:16.389 [INFO] [2833460121.py:22] epoch: 0


obj saved ./outputs/13995010/disease_13995010_cell_score_dict_test
    Sencell num: 100, Nonsencell num: 1000
rechoice sengene num: 362 rechoice nonsengene num: 1924
obj saved ./outputs/13995010/disease_13995010_cell_score_dict_test
    Sencell num: 100, Nonsencell num: 1000
sencell cover: 0.82
sengene cover: 0.4171270718232044
