In [1]:
import numpy as np
import seaborn as sns
import torch

import umap
import matplotlib.pyplot as plt
import pandas as pd
from community import community_louvain
from torch_geometric.utils import k_hop_subgraph,to_networkx,from_networkx
import matplotlib

import utils
import plots
from model_AE import reduction_AE
from model_GAT import Encoder,SenGAE,train_GAT
from model_Sencell import Sencell

import logging

# Configure the root logger: timestamp with millisecond precision, level name,
# and the emitting file/line in front of every message.
logging.basicConfig(
    format='%(asctime)s.%(msecs)03d [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='# %Y-%m-%d %H:%M:%S',
)

# The notebook logs through the root logger and lets everything from DEBUG up
# through. setLevel is applied separately so it takes effect even when
# basicConfig is a no-op because handlers already exist.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Part 1: load and process data
data_path="/users/PCON0022/haocheng/Basu_lab/rmarkdown/SCB01S5.h5ad"
# Load the dataset together with its cluster bookkeeping.
# cell_cluster_arr is used when drawing the UMAP plot.
adata,cluster_cell_ls,cell_cluster_arr,celltype_names=utils.load_data(data_path)
# plots.umapPlot(adata.obsm['X_umap'],clusters=cell_cluster_arr,labels=celltype_names)

# Derive the processed expression object and the senescent / non-senescent
# gene lists used to build the graph below.
(new_data,markers_index,
 sen_gene_ls,nonsen_gene_ls,gene_names)=utils.process_data(adata,cluster_cell_ls,cell_cluster_arr)

print('cell num: {}, gene num: {}'.format(new_data.shape[0],new_data.shape[1]))

# Densify the expression matrix and transpose it so rows index genes and
# columns index cells (assumes new_data.X is a sparse matrix exposing
# .toarray() -- TODO confirm).
gene_cell=new_data.X.toarray().transpose()
graph_nx=utils.build_graph_nx(gene_cell,cell_cluster_arr,sen_gene_ls,nonsen_gene_ls,gene_names)
logger.info("Part 1, data loading and processing end!")

# Part 2: generate init embedding
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device: ',device)

# True: retrain the autoencoder and overwrite the cached embedding files.
# False: load the embeddings saved by a previous run from disk.
retrain=False
if retrain:
    gene_embed,cell_embed=reduction_AE(gene_cell,device)
    # logging treats extra positional args as %-format arguments of the
    # message; passing two shapes without a format string fails at emit time
    # ("not all arguments converted"), so use explicit placeholders.
    logger.info("gene_embed shape: %s, cell_embed shape: %s",
                gene_embed.shape,cell_embed.shape)
    torch.save(gene_embed,'./gene.emb')
    torch.save(cell_embed,'./cell.emb')
else:
    gene_embed=torch.load('./gene.emb')
    cell_embed=torch.load('./cell.emb')

# Attach the embeddings to the networkx graph and build the PyG graph.
graph_nx=utils.add_nx_embedding(graph_nx,gene_embed,cell_embed)
graph_pyg=utils.build_graph_pyg(gene_cell,gene_embed,cell_embed)
logger.info("Part 2, AE end!")

# Part 3: train GAT
# graph_pyg=graph_pyg.to('cpu')
# Obtain the GAT model from both graph views; the flags are passed through to
# train_GAT unchanged (retrain=False presumably reuses a saved model -- TODO
# confirm in model_GAT).
GAT_model=train_GAT(
    graph_nx,
    graph_pyg,
    device,
    retrain=False,
    resampling=False,
)
logger.info("Part 3, training GAT end!")



from sampling import identify_sencell_marker_graph
from sampling import sub_sampling_by_random
from model_Sencell import cell_optim,update_cell_embeddings,old_cell_optim

from sampling import identify_sengene_then_sencell


# Module-level result lists, each an independent empty list.
all_gene_ls, list_sencell_cover, list_sengene_cover = [], [], []

def get_sencell_cover(old_sencell_dict,sencell_dict):
    """Return the fraction of current sencell keys already present in the old dict.

    Only the keys of the two dicts are compared. The result is also printed.

    Args:
        old_sencell_dict: dict from the previous step (keys are compared).
        sencell_dict: dict from the current step (keys are compared).

    Returns:
        len(old_keys & new_keys) / len(new_keys) as a float, or 0.0 when
        sencell_dict is empty.
    """
    # Iterating a dict yields its keys, so no intermediate list is needed.
    old_keys=set(old_sencell_dict)
    new_keys=set(sencell_dict)
    # An empty current set would otherwise divide by zero.
    cover=len(old_keys & new_keys)/len(new_keys) if new_keys else 0.0
    print('sencell cover:',cover)

    return cover

def get_sengene_cover(old_sengene_ls,sengene_ls):
    """Return the fraction of distinct current sengenes already in the old list.

    Both inputs are deduplicated with set() before comparing. The result is
    also printed.

    Args:
        old_sengene_ls: iterable of gene indices from the previous step.
        sengene_ls: iterable of gene indices from the current step.

    Returns:
        len(old & new) / len(new) as a float, or 0.0 when sengene_ls is empty.
    """
    old_genes=set(old_sengene_ls)
    new_genes=set(sengene_ls)
    # An empty current gene set would otherwise divide by zero.
    cover=len(old_genes & new_genes)/len(new_genes) if new_genes else 0.0
    print('sengene cover:',cover)

    return cover
    
    
# Cell-level model placed on the selected device, optimized with Adam
# (learning rate 1e-3, weight decay 1e-3).
cellmodel=Sencell().to(device)
optimizer=torch.optim.Adam(
    params=cellmodel.parameters(),
    lr=1e-3,
    weight_decay=1e-3,
)



# 2022-11-22 12:29:29.842 [DEBUG] [attrs.py:77] Creating converter from 3 to 5


cluster 数量： 21
celltype names: ['Macrophages', 'T cell lineage', 'Unknown', 'B cell lineage', 'Innate lymphoid cell NK', 'AT2', 'Monocytes', 'Multiciliated lineage', 'Dendritic cells', 'EC capillary', 'Mast cells', 'Fibroblasts', 'Secretory', 'EC venous', 'Lymphatic EC mature', 'AT1', 'Basal', 'EC arterial', 'Myofibroblasts', 'None', 'Submucosal Secretory']
-----------------------  ----
Macrophages              6941
T cell lineage            749
Unknown                   618
B cell lineage            374
Innate lymphoid cell NK   327
AT2                       294
Monocytes                 228
Multiciliated lineage     194
Dendritic cells           177
EC capillary              138
Mast cells                100
Fibroblasts                93
Secretory                  86
EC venous                  74
Lymphatic EC mature        68
AT1                        27
Basal                      26
EC arterial                20
Myofibroblasts             17
None                        6
Submucosal

# 2022-11-22 12:29:36.011 [INFO] [3655344800.py:39] Part 1, data loading and processing end!


device:  cuda:0
the number of edges: 1745053
edge index:  torch.Size([2, 1745053])
node feature:  torch.Size([12803, 128])
Pyg graph: Data(x=[12803, 128], edge_index=[2, 3490106], y=[12803])


# 2022-11-22 12:29:39.848 [INFO] [3655344800.py:57] Part 2, AE end!
# 2022-11-22 12:29:39.856 [INFO] [3655344800.py:62] Part 3, training GAT end!


graph.is_directed(): False


In [None]:
# Marker genes used to seed each sub-sampling round.
all_marker_index=sen_gene_ls

# One [sen_gene_indexs, sencell_dict] pair per outer iteration.
iteration_results=[]
for iteration in range(5):
    logger.info(f"iteration: {iteration}")
    # NOTE(review): gene_num + cell_num = 2245 + 10558 = 12803, which matches
    # the node count printed for the PyG graph; presumably the totals of
    # gene and cell nodes in graph_nx -- confirm before reusing on new data.
    sampled_graph,sencell_dict,nonsencell_dict,cell_clusters,big_graph_index_dict=sub_sampling_by_random(graph_nx,
                                                            sen_gene_ls,
                                                            nonsen_gene_ls,
                                                            GAT_model,
                                                            all_marker_index,
                                                            n_gene=len(all_marker_index),
                                                            gene_num=2245,cell_num=10558,
                                                            gene_rate=0.3,cell_rate=0.5,
                                                            debug=False)
    old_sengene_indexs=all_marker_index
    for epoch in range(10):
        logger.info(f"epoch: {epoch}")
        old_sencell_dict=sencell_dict
        # Optimize the cell model, then push the new cell embeddings back
        # into the sampled graph before re-identifying sengenes/sencells.
        cellmodel,sencell_dict,nonsencell_dict=cell_optim(cellmodel,optimizer,
                                                          sencell_dict,nonsencell_dict,device,
                                                         train=True)
        sampled_graph=update_cell_embeddings(sampled_graph,sencell_dict,nonsencell_dict)
        sencell_dict,nonsencell_dict, \
        sen_gene_indexs,nonsen_gene_indexs=identify_sengene_then_sencell(sampled_graph,GAT_model,
                                                                      sencell_dict,nonsencell_dict,
                                                                      cell_clusters,
                                                                      big_graph_index_dict,
                                                                      len(all_marker_index))

        # Keep the per-epoch overlap with the previous epoch instead of
        # discarding it, so convergence can be inspected after the loop.
        # The lists live at module level, so re-running this cell appends.
        list_sencell_cover.append(get_sencell_cover(old_sencell_dict,sencell_dict))
        list_sengene_cover.append(get_sengene_cover(old_sengene_indexs,sen_gene_indexs))
        old_sengene_indexs=sen_gene_indexs
    iteration_results.append([sen_gene_indexs,sencell_dict])