In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('/home/chenjn/rna2adt_push/')

import triplet_utils
import graph_utils

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import anndata as ad
import scanpy as sc
import pandas as pd
import scipy.sparse as sp
import scipy.linalg
from scipy.sparse import csr_matrix
import random

import torch
seed = 3407
# Seed every RNG the pipeline may touch: torch (CPU and CUDA), NumPy's legacy
# global generator, and Python's `random` module.
torch.manual_seed(seed)  # CPU generator
torch.cuda.manual_seed(seed)  # current GPU
torch.cuda.manual_seed_all(seed)  # all GPUs, for multi-GPU setups
np.random.seed(seed)  # NumPy module
random.seed(seed)  # Python random module
# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# GPU 6 is the preferred card on the original workstation. A CUDA index that does
# not exist fails as soon as a tensor is placed on it, so fall back to cuda:0 on
# machines with fewer GPUs, and to the CPU when CUDA is unavailable.
preferred_gpu = 6
if torch.cuda.is_available():
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    used_device = torch.device(f'cuda:{gpu_index}')
else:
    used_device = torch.device('cpu')

from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import fowlkes_mallows_score as FMI
from sklearn.metrics import silhouette_score as SC

from utils import find_resolution, find_res_label

from sklearn.cluster import KMeans

In [None]:
data_name = "cross_species"
label_key = 'cluster_o'
batch_key = 'batch_id'
# Sections to integrate. Other section ids previously listed for this dataset:
# 'Human_with_h1n1', 'Human_ifgn', 'Cynomolgus_monkeys_ifgn', 'Rheus_macaques_ifgn'.
section_keys = ['Human_with_h1n1', 'Rheus_macaques_ifgn']

data_root = '/home/chenjn/rna2adt_fork/A_run_test'
# Reference ADT data: supplies the cell annotations (obs) and the baseline evaluation.
refdata = sc.read_h5ad(f'{data_root}/data/{data_name}/ADT.h5ad')
# Input embedding used for triplet training.
adata = sc.read_h5ad(f'{data_root}/new_emb_sc/{data_name}.h5ad')

# Annotations are transferred by row position only.
# NOTE(review): assumes both files list cells in the same order — confirm upstream.
assert adata.n_obs == refdata.n_obs, (
    f'embedding has {adata.n_obs} cells but reference has {refdata.n_obs}'
)
# Copy so the two objects do not share (and silently co-mutate) one obs DataFrame.
adata.obs = refdata.obs.copy()

In [None]:
# Restrict both objects to the sections selected in `section_keys`.
# .copy() turns the subsets into real AnnData objects instead of views, so later
# in-place edits (obs_names, obs columns, uns) do not trigger implicit copies.
refdata = refdata[refdata.obs[batch_key].isin(section_keys).to_numpy()].copy()
print(refdata)

adata = adata[adata.obs[batch_key].isin(section_keys).to_numpy()].copy()
print(adata)

# Fail fast on a typo in section_keys: every requested section must have cells,
# otherwise the per-section graph construction later gets an empty subset.
missing_sections = set(section_keys) - set(adata.obs[batch_key].unique())
assert not missing_sections, (
    f'sections not found in {batch_key!r}: {sorted(missing_sections)}'
)

In [None]:
# Baseline: cluster the reference ADT data and score the clusters against the
# annotated labels.
sc.pp.neighbors(refdata)

sc.tl.louvain(refdata)
# The target cluster count comes from refdata's own labels (it previously read
# adata's), matching how the adata1 evaluation is done.
# NOTE(review): find_res_label is a project helper from utils; presumably it searches
# Louvain resolutions for one giving this many clusters — confirm.
refdata.obs['louvain_res'] = find_res_label(refdata, len(np.unique(refdata.obs[label_key])))
sc.tl.umap(refdata)

sc.pl.umap(refdata, color=[batch_key, label_key], ncols=2, wspace=0.4, show=True)
sc.pl.umap(refdata, color=['louvain', 'louvain_res'], ncols=2, wspace=0.4, show=True)

# Prints ARI, NMI, FMI for the default-resolution clustering, a blank line, then
# the same three scores for the label-count-matched clustering.
for position, cluster_key in enumerate(('louvain', 'louvain_res')):
    if position:
        print()
    for score_fn in (ARI, NMI, FMI):
        print(score_fn(refdata.obs[cluster_key], refdata.obs[label_key]))

In [None]:
# Disabled: the same clustering / UMAP / metric evaluation as the refdata cell,
# applied to the input embedding `adata` (before triplet training). Uncomment to
# compare the input embedding against the reference baseline.
# sc.pp.neighbors(adata)

# sc.tl.louvain(adata)
# adata.obs['louvain_res'] = find_res_label(adata, len(np.unique(adata.obs[label_key])))
# sc.tl.umap(adata)

# sc.pl.umap(adata, color=[batch_key, label_key], ncols=2, wspace=0.4, show=True)
# sc.pl.umap(adata, color=['louvain', 'louvain_res'], ncols=2, wspace=0.4, show=True)

# print(ARI(adata.obs['louvain'], adata.obs[label_key]))
# print(NMI(adata.obs['louvain'], adata.obs[label_key])) 
# print(FMI(adata.obs['louvain'], adata.obs[label_key]))
# print()
# print(ARI(adata.obs['louvain_res'], adata.obs[label_key]))
# print(NMI(adata.obs['louvain_res'], adata.obs[label_key])) 
# print(FMI(adata.obs['louvain_res'], adata.obs[label_key]))

In [None]:
# Build one AnnData per section plus its kNN graph; both lists feed triplet training.
Batch_list = []
adj_list = []  # NOTE(review): not passed to train_triplet in the visible code
section_ids = section_keys

# Make cell names self-describing: "<name>_<batch>_<label>". f-strings also accept
# non-string batch/label values, which plain `+` concatenation would reject.
# NOTE: not idempotent — re-running this cell appends the suffixes a second time.
adata.obs_names = [
    f'{name}_{batch}_{label}'
    for name, batch, label in zip(adata.obs_names, adata.obs[batch_key], adata.obs[label_key])
]

for section_id in section_ids:
    # Materialize the subset: Cal_Spatial_Net writes into .uns, which on a view
    # would otherwise force an implicit copy.
    tmpdata = adata[adata.obs[batch_key] == section_id].copy()

    # kNN graph with k=20; the result is read back from tmpdata.uns['adj'] below.
    graph_utils.Cal_Spatial_Net(tmpdata, k_cutoff=20, model='KNN')

    adj_list.append(tmpdata.uns['adj'])
    Batch_list.append(tmpdata)

In [None]:
# `iter_comb` sets the integration order: a pair (i, j) aligns slice i to slice j,
# which serves as the reference. Here consecutive slices are chained: (0, 1), (1, 2), ...
iter_comb = list(zip(range(len(section_keys) - 1), range(1, len(section_keys))))

# To keep GPU memory usage down, each slice is trained as its own subgraph.
adata_af = triplet_utils.train_triplet(
    adata,
    Batch_list=Batch_list,
    iter_comb=iter_comb,
    batch_key=batch_key,
    label_key=label_key,
    knn_neigh=10,
    n_epochs=3000,
    margin=0.01,
    device=used_device,
    verbose=True,
)



In [None]:
# Wrap the two learned representations as standalone AnnData objects carrying the
# training cells' annotations. Depending on the anndata version, the constructor may
# keep a reference to the passed DataFrame, so each object gets its own copy.
# Otherwise clustering columns added to adata1 (e.g. 'louvain') could leak into
# adata2 or adata_af.
adata1 = ad.AnnData(adata_af.obsm['triplet_emb'], obs=adata_af.obs.copy())
adata2 = ad.AnnData(adata_af.obsm['triplet_out'], obs=adata_af.obs.copy())

In [None]:
# Distinct batch labels present in `adata`, in order of first appearance.
# NOTE(review): this rebinds `section_ids` (earlier an alias of section_keys); the
# order here may differ from section_keys.
section_ids = adata.obs[batch_key].unique()
section_ids = np.array(section_ids)
section_ids

In [None]:
# kNN graph on the learned embedding, then default-resolution Louvain.
sc.pp.neighbors(adata1)
sc.tl.louvain(adata1)

# Second clustering targeted at the number of annotated labels.
# NOTE(review): find_res_label is a project helper from utils; presumably it searches
# Louvain resolutions for one giving this many clusters — confirm.
n_target_clusters = len(np.unique(adata1.obs[label_key]))
adata1.obs['louvain_res'] = find_res_label(adata1, n_target_clusters)

In [None]:
# Visualize the learned embedding, then score both clusterings against the labels.
sc.tl.umap(adata1)
for color_pair in ([batch_key, label_key], ['louvain', 'louvain_res']):
    sc.pl.umap(adata1, color=color_pair, ncols=2, wspace=0.4, show=True)

# Prints ARI, NMI, FMI for the default-resolution clustering, a blank line, then
# the same three scores for the label-count-matched clustering.
for position, cluster_key in enumerate(('louvain', 'louvain_res')):
    if position > 0:
        print()
    for score_fn in (ARI, NMI, FMI):
        print(score_fn(adata1.obs[cluster_key], adata1.obs[label_key]))