In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('/home/liwb/project/rna2adt')

import triplet_utils
import graph_utils

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import anndata as ad
import scanpy as sc
import pandas as pd
import scipy.sparse as sp
import scipy.linalg
from scipy.sparse import csr_matrix
import random

import torch
# Fix every random number generator used by the notebook so runs are repeatable.
seed = 3407
torch.manual_seed(seed)  # seed the CPU generator
torch.cuda.manual_seed(seed)  # seed the current GPU
torch.cuda.manual_seed_all(seed)  # seed all GPUs (multi-GPU setups)
np.random.seed(seed)  # NumPy global generator
random.seed(seed)  # Python's random module
# Turn off cuDNN benchmark mode and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Device selection: GPU 4 is the preferred card. The original code picked
# 'cuda:4' whenever any GPU was visible, which only fails later (invalid device
# ordinal) on machines exposing fewer than five GPUs. Fall back to the first
# GPU in that case, and to the CPU when CUDA is unavailable.
preferred_gpu = 4
if torch.cuda.is_available():
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    used_device = torch.device(f'cuda:{gpu_index}')
else:
    used_device = torch.device('cpu')

from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import fowlkes_mallows_score as FMI
from sklearn.metrics import silhouette_score as SC

from utils import find_resolution, find_res_label

from sklearn.cluster import KMeans

In [None]:
data_name = "pbmc"

# Reference ADT object; its .obs table supplies the per-cell annotations.
refdata = sc.read_h5ad(f'/data/user/liwb/project/rna2adt/data/{data_name}/ADT.h5ad')
# (alternative dataset location: /home/chenjn/rna2adt/data/SCoPE2/ADT.h5ad)

# Embedding table (first CSV column is used as the row index) wrapped as AnnData.
adata = ad.AnnData(
    pd.read_csv(f'/data/user/liwb/project/rna2adt/output/{data_name}/embeddings.csv', index_col=0)
)
# NOTE(review): this assumes the rows of embeddings.csv are in the same order
# as the cells of refdata -- confirm against the code that wrote the CSV.
adata.obs = refdata.obs

# Donor subset to analyse. Other splits that have been used:
#   ['P2', 'P6']
#   ['P3', 'P4', 'P7', 'P8']
#   ['P1', 'P5', 'P2', 'P6', 'P3', 'P4', 'P7', 'P8']
keep_donors = ['P1', 'P5']
index = adata.obs['donor'].isin(keep_donors).tolist()

# Apply the same boolean mask to both objects so they stay cell-aligned.
adata = adata[index]
refdata = refdata[index]

In [None]:
# Number of cells per batch in the selected donor subset.
batch_labels = refdata.obs['Batch']
print(batch_labels.value_counts())

# Split the reference by batch and display the donor composition of each split.
t1 = refdata[batch_labels == 'Batch1']
t2 = refdata[batch_labels == 'Batch2']
(t1.obs['donor'].value_counts(), t2.obs['donor'].value_counts())

In [None]:
# Disabled baseline: Louvain clustering + UMAP on refdata, scored against
# 'celltype.l2' with ARI / NMI / FMI. Uncomment to run.
# sc.pp.neighbors(refdata)

# sc.tl.louvain(refdata)
# refdata.obs['louvain_res'] = find_res_label(refdata, len(np.unique(adata.obs['celltype.l2'])))
# sc.tl.umap(refdata)

# sc.pl.umap(refdata, color=['Batch', 'celltype.l2'], ncols=2, wspace=0.4, show=True)
# sc.pl.umap(refdata, color=['louvain', 'louvain_res'], ncols=2, wspace=0.4, show=True)

# print(ARI(refdata.obs['louvain'], refdata.obs['celltype.l2']))
# print(NMI(refdata.obs['louvain'], refdata.obs['celltype.l2']))
# print(FMI(refdata.obs['louvain'], refdata.obs['celltype.l2']))
# print()
# print(ARI(refdata.obs['louvain_res'], refdata.obs['celltype.l2']))
# print(NMI(refdata.obs['louvain_res'], refdata.obs['celltype.l2']))
# print(FMI(refdata.obs['louvain_res'], refdata.obs['celltype.l2']))

In [None]:
# Disabled baseline: Louvain clustering + UMAP on the loaded embeddings (adata),
# scored against 'celltype.l2' with ARI / NMI / FMI. Uncomment to run.
# sc.pp.neighbors(adata)

# sc.tl.louvain(adata)
# adata.obs['louvain_res'] = find_res_label(adata, len(np.unique(adata.obs['celltype.l2'])))
# sc.tl.umap(adata)

# sc.pl.umap(adata, color=['Batch', 'celltype.l2'], ncols=2, wspace=0.4, show=True)
# sc.pl.umap(adata, color=['louvain', 'louvain_res'], ncols=2, wspace=0.4, show=True)

# print(ARI(adata.obs['louvain'], adata.obs['celltype.l2']))
# print(NMI(adata.obs['louvain'], adata.obs['celltype.l2']))
# print(FMI(adata.obs['louvain'], adata.obs['celltype.l2']))
# print()
# print(ARI(adata.obs['louvain_res'], adata.obs['celltype.l2']))
# print(NMI(adata.obs['louvain_res'], adata.obs['celltype.l2']))
# print(FMI(adata.obs['louvain_res'], adata.obs['celltype.l2']))

In [None]:
# Per-batch AnnData objects and their adjacency structures, in section_ids order.
Batch_list = []
adj_list = []
section_ids = ['Batch1', 'Batch2']

# Tag every cell name with its batch and cell type: "<name>_<Batch>_<celltype.l2>".
# The rename is guarded so that re-running this cell does not append the suffix
# a second time (the unguarded version grows the names on every execution).
name_suffixes = ['_' + y + '_' + z for y, z in zip(adata.obs['Batch'], adata.obs['celltype.l2'])]
if not all(x.endswith(s) for x, s in zip(adata.obs_names, name_suffixes)):
    adata.obs_names = [x + s for x, s in zip(adata.obs_names, name_suffixes)]

for section_id in section_ids:
    # Take an explicit copy: the steps below write into this object (.uns and
    # .X), and mutating an AnnData view relies on an implicit copy-on-write
    # whose warning is silenced at the top of the notebook.
    tmpdata = adata[adata.obs['Batch'] == section_id].copy()

    # Build the KNN network (k_cutoff=20); the result is read back from
    # tmpdata.uns['adj'] below.
    graph_utils.Cal_Spatial_Net(tmpdata, k_cutoff=20, model='KNN')

    # Per-cell total-count normalization, applied after the network is built.
    sc.pp.normalize_total(tmpdata)

    adj_list.append(tmpdata.uns['adj'])
    Batch_list.append(tmpdata)

In [None]:
# iter_comb sets the integration order: (0, 1) means slice 0 is aligned to
# slice 1, which acts as the reference.
iter_comb = [(0, 1)]

# To reduce GPU memory usage, each slice is treated as its own subgraph for training.
adata_af = triplet_utils.train_triplet(
    adata,
    verbose=True,
    knn_neigh=10,
    n_epochs=3000,
    iter_comb=iter_comb,
    Batch_list=Batch_list,
    device=used_device,
    margin=0.01,
    batch_key='Batch',
    label_key='celltype.l2',
)



In [None]:
# Replace the hard-coded batch names with the ones actually present in adata.
observed_batches = adata.obs['Batch'].unique()
section_ids = np.array(observed_batches)
section_ids

In [None]:
# Wrap the two arrays stored by training in adata_af.obsm as standalone AnnData
# objects that reuse the trained object's cell annotations.
trained_obs = adata_af.obs
adata1 = ad.AnnData(X=adata_af.obsm['triplet_emb'], obs=trained_obs)
adata2 = ad.AnnData(X=adata_af.obsm['triplet_out'], obs=trained_obs)

In [None]:
# Neighbourhood graph and default Louvain clustering on the triplet embedding.
sc.pp.neighbors(adata1)
sc.tl.louvain(adata1)

# Second labelling from find_res_label, which is given the number of distinct
# 'celltype.l2' labels.
n_celltypes = len(np.unique(adata1.obs['celltype.l2']))
adata1.obs['louvain_res'] = find_res_label(adata1, n_celltypes)

In [None]:
# UMAP of the triplet embedding: first coloured by batch / annotated cell type,
# then by the two cluster labellings.
sc.tl.umap(adata1)
for color_keys in (['Batch', 'celltype.l2'], ['louvain', 'louvain_res']):
    sc.pl.umap(adata1, color=color_keys, ncols=2, wspace=0.4, show=True)

# Agreement of each clustering with the annotated cell types. Printed as
# ARI, NMI, FMI per clustering, with a blank line between the two groups.
celltype_labels = adata1.obs['celltype.l2']
for position, cluster_key in enumerate(['louvain', 'louvain_res']):
    if position:
        print()
    for score in (ARI, NMI, FMI):
        print(score(adata1.obs[cluster_key], celltype_labels))