In [1]:
import numpy as np
import seaborn as sns
import torch

import umap
import matplotlib.pyplot as plt
import pandas as pd
from community import community_louvain
from torch_geometric.utils import k_hop_subgraph,to_networkx,from_networkx
import matplotlib

import utils
import plots
from model_AE import reduction_AE
from model_GAT import Encoder,SenGAE,train_GAT
from model_Sencell import Sencell

import logging
import os
import argparse


# Notebook configuration, expressed as argparse options so that a single
# `args` namespace can be handed to the project helpers further down.
parser = argparse.ArgumentParser(description='Main program for sencells')

parser.add_argument('--output_dir', type=str, default='./outputs', help='')
parser.add_argument('--exp_name', type=str, default='', help='')

# An explicit empty argv is parsed, so sys.argv is never consulted and every
# option starts at its declared default.
args = parser.parse_args(args=[])

# Experiment tag; it is used below as the filename prefix of the cached
# embeddings written to / read from args.output_dir.
args.exp_name='disease'

# exist_ok=True creates the directory (and any missing parents) and is a no-op
# when it already exists, avoiding the race between a separate existence check
# and the creation call.
os.makedirs(args.output_dir, exist_ok=True)

logging.basicConfig(format='%(asctime)s.%(msecs)03d [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
                    datefmt='# %Y-%m-%d %H:%M:%S')

# The level is set on the root logger explicitly rather than through
# basicConfig, so it takes effect even if handlers were already installed.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Part 1: load and process data
# cell_cluster_arr is used later when drawing the UMAP plot
adata,cluster_cell_ls,cell_cluster_arr,celltype_names=utils.load_data_disease()
# plots.umapPlot(adata.obsm['X_umap'],clusters=cell_cluster_arr,labels=celltype_names)

# process_data returns the filtered dataset together with the marker indices,
# the two gene lists (sen / nonsen) and the gene names consumed below.
new_data,markers_index,\
sen_gene_ls,nonsen_gene_ls,gene_names=utils.process_data(adata,cluster_cell_ls,cell_cluster_arr)

# new_data is laid out as (cells, genes): axis 0 is reported as the cell count,
# axis 1 as the gene count.
print(f'cell num: {new_data.shape[0]}, gene num: {new_data.shape[1]}')

# Densify and transpose so that rows are genes and columns are cells.
# NOTE(review): .toarray() implies new_data.X is a sparse matrix — confirm;
# the dense copy holds every (gene, cell) entry in memory.
gene_cell=new_data.X.toarray().T
graph_nx=utils.build_graph_nx(gene_cell,cell_cluster_arr,sen_gene_ls,nonsen_gene_ls,gene_names)
logger.info("Part 1, data loading and processing end!")

# Part 2: generate init embedding
# Prefer the first CUDA device and fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device:', device)
args.device = device

# Cache locations of the two embeddings, keyed by the experiment name.
gene_emb_path = os.path.join(args.output_dir, f'{args.exp_name}_gene.emb')
cell_emb_path = os.path.join(args.output_dir, f'{args.exp_name}_cell.emb')

# With retrain left at False the embeddings are read back from the cache files;
# flip it to True to recompute them with the autoencoder and refresh the cache.
retrain = False
if not retrain:
    gene_embed = torch.load(gene_emb_path)
    cell_embed = torch.load(cell_emb_path)
else:
    gene_embed, cell_embed = reduction_AE(gene_cell, device)
    print(gene_embed.shape, cell_embed.shape)
    torch.save(gene_embed, gene_emb_path)
    torch.save(cell_embed, cell_emb_path)

# Attach the embeddings to the networkx graph and build the PyG counterpart.
graph_nx = utils.add_nx_embedding(graph_nx, gene_embed, cell_embed)
graph_pyg = utils.build_graph_pyg(gene_cell, gene_embed, cell_embed)
logger.info("Part 2, AE end!")

# Part 3: train GAT
# graph_pyg=graph_pyg.to('cpu')
# Store the node counts on args: axis 0 of gene_cell is taken as the number of
# genes and axis 1 as the number of cells.
args.gene_num=gene_cell.shape[0]
args.cell_num=gene_cell.shape[1]



# Both graph representations and the args namespace are passed to train_GAT.
# NOTE(review): presumably retrain=False makes train_GAT load a previously
# saved model instead of training — confirm in model_GAT.
GAT_model=train_GAT(graph_nx,graph_pyg,args,retrain=False,resampling=False)
logger.info("Part 3, training GAT end!")



from sampling import identify_sencell_marker_graph
from sampling import sub_sampling_by_random
from model_Sencell import cell_optim,update_cell_embeddings,old_cell_optim

from sampling import identify_sengene_then_sencell


# Module-level accumulators, all starting empty. Nothing in this part of the
# notebook appends to them.
# NOTE(review): presumably the two *_cover lists collect the ratios returned by
# get_sencell_cover / get_sengene_cover across iterations — confirm.
all_gene_ls=[]

list_sencell_cover=[]
list_sengene_cover=[]

def get_sencell_cover(old_sencell_dict,sencell_dict):
    """Return the share of keys in ``sencell_dict`` that also occur in ``old_sencell_dict``.

    Args:
        old_sencell_dict: mapping from the previous round; only its keys are used.
        sencell_dict: mapping from the current round; only its keys are used.

    Returns:
        ``len(common keys) / len(keys of sencell_dict)`` as a float, or ``0.0``
        when ``sencell_dict`` is empty (the ratio is undefined in that case and
        would otherwise raise ZeroDivisionError).

    The ratio is also printed with the prefix ``'sencell cover:'``.
    """
    previous_keys=set(old_sencell_dict)
    current_keys=set(sencell_dict)

    if current_keys:
        cover=len(previous_keys & current_keys)/len(current_keys)
    else:
        # Nothing to cover: report zero instead of dividing by zero.
        cover=0.0

    print('sencell cover:',cover)
    return cover

def get_sengene_cover(old_sengene_ls,sengene_ls):
    """Return the share of distinct items in ``sengene_ls`` that also occur in ``old_sengene_ls``.

    Args:
        old_sengene_ls: iterable of hashable items from the previous round.
        sengene_ls: iterable of hashable items from the current round.

    Returns:
        ``len(common items) / len(distinct items of sengene_ls)`` as a float, or
        ``0.0`` when ``sengene_ls`` is empty (the ratio is undefined in that
        case and would otherwise raise ZeroDivisionError).

    The ratio is also printed with the prefix ``'sengene cover:'``.
    """
    previous_items=set(old_sengene_ls)
    current_items=set(sengene_ls)

    if current_items:
        cover=len(previous_items & current_items)/len(current_items)
    else:
        # Nothing to cover: report zero instead of dividing by zero.
        cover=0.0

    print('sengene cover:',cover)
    return cover


# 2022-11-30 08:18:32.160 [DEBUG] [attrs.py:77] Creating converter from 3 to 5


cluster 数量： 14
celltype names: ['0', '1', '2', '3', '5', 'Epithelial cells', '7', '8', '10', 'Pericytes', 'Fibroblasts-Pericytes', '14', '9', 'Endothelial cells']
---------------------  -----
0                      10712
1                       7273
2                       5435
3                       3923
5                       2695
Epithelial cells        2385
7                       1630
8                       1607
10                       900
Pericytes                571
Fibroblasts-Pericytes    466
14                       361
9                        263
Endothelial cells        186
---------------------  -----
各marker list所包含的gene数：
  Markers1    Markers2    Markers3    Markers4
----------  ----------  ----------  ----------
       126          78         145          84
total marker genes:  380
highly_genes num:  2000
After highly genes dropped duplicate:  1881
highly genes里面有全0！！
Total gene num: 2186
cell num: 38407, gene num: 2186
The number of edges: 5240183


# 2022-11-30 08:18:56.067 [INFO] [3306991342.py:53] Part 1, data loading and processing end!


device: cuda:0
the number of edges: 5240183
edge index:  torch.Size([2, 5240183])
node feature:  torch.Size([40593, 128])
Pyg graph: Data(x=[40593, 128], edge_index=[2, 10480366], y=[40593])


# 2022-11-30 08:19:05.125 [INFO] [3306991342.py:72] Part 2, AE end!
# 2022-11-30 08:19:05.190 [INFO] [3306991342.py:82] Part 3, training GAT end!


graph.is_directed(): False


In [2]:
%%time
import sampling
# all_marker_index is another name for the same list object as sen_gene_ls
# (no copy is made).
all_marker_index=sen_gene_ls
    
# Initialised empty; this cell never appends to it.
iteration_results=[]
# NOTE(review): the unconditional `break` at the end of the body stops the loop
# after the first pass, so only iteration 0 runs despite range(5). The recorded
# run of this single pass took about 19 minutes of wall time.
for iteration in range(5):
    logger.info(f"iteration: {iteration}")
    # Draw a random subgraph; n_gene is set to the number of marker indices.
    sampled_graph,sencell_dict,nonsencell_dict,cell_clusters,big_graph_index_dict=sampling.sub_sampling_by_random(graph_nx,
                                                            sen_gene_ls,
                                                            nonsen_gene_ls,
                                                            GAT_model,
                                                            args,
                                                            all_marker_index,
                                                            n_gene=len(all_marker_index),                                                        
                                                            gene_rate=0.3,cell_rate=0.5,
                                                            debug=False)
    break

# 2022-11-30 08:19:05.199 [INFO] [<timed exec>:6] iteration: 0


Start sampling subgraph randomly ...
    Sengene num: 362, Nonsengen num: 362
subgraph total node num: (39131,)
After sampling, gene num:  tensor(724)
obj saved ./outputs/disease_cell_score_dict
    Sencell num: 200, Nonsencell num: 2000
CPU times: user 19min 35s, sys: 3.91 s, total: 19min 39s
Wall time: 19min 41s
