## <center>Run all</center>

### Step1:加载库与函数

In [None]:
%load_ext autoreload
%autoreload 2

import os
# Set before torch is imported in the next cell. PyTorch's reproducibility
# notes list ':4096:8' as one of the values cuBLAS needs for deterministic
# results; presumably required by set_seed(..., force_deterministic=True)
# further down -- confirm against biock.pytorch.set_seed.
os.environ['CUBLAS_WORKSPACE_CONFIG']=':4096:8'
import sys
# Put the project root on the import path so that the notebook's project-local
# imports can be resolved (presumably `dataset`, `dataloaders`, `scbasset_ori`
# and `utils` live there -- verify).
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
sys.path.append('/home/chenjn/rna2adt')

%config InlineBackend.figure_format = 'retina'
import warnings
warnings.filterwarnings('ignore')

In [None]:
import numpy as np
from tqdm import tqdm
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import fowlkes_mallows_score as FMI
from sklearn.metrics import silhouette_score as SC
from scipy.stats import pearsonr, spearmanr
import dataset
import dataloaders
import scanpy as sc
import scbasset_ori as scbasset
import sklearn
from biock import make_directory, make_logger, get_run_info
from biock.pytorch import model_summary, set_seed
from biock import HG19_FASTA_H5, HG38_FASTA_H5

from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer. test_model() and run() read this global
# directly to log scalars; it is closed at the very end of the notebook.
writer = SummaryWriter("logs/All/epoch_20")

from scipy.optimize import linear_sum_assignment
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import issparse

from utils import find_res_label
from scipy.spatial.distance import cosine

In [None]:
def cluster_acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one cluster/class matching.

    Predicted cluster ids are matched to true classes with the Hungarian
    algorithm (`scipy.optimize.linear_sum_assignment`) so that the number of
    agreeing samples is maximal; that count divided by the number of samples
    is returned.

    # Arguments
        y_true: true labels, array-like with shape `(n_samples,)`; any sortable
            values (strings or numbers). They are re-coded to `0..n_classes-1`
            in sorted order.
        y_pred: predicted labels, array-like with shape `(n_samples,)`; must be
            non-negative integers or values convertible to them (e.g. `'3'`).

    # Return
        accuracy, in [0,1]

    # Raises
        ValueError: if the inputs are empty or differ in size.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_pred.size != y_true.size:
        raise ValueError(
            f"y_true and y_pred must have the same size, "
            f"got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")

    # The inverse indices of the sorted unique values are exactly the integer
    # codes a label encoder fitted on the unique labels would produce.
    _, y_true = np.unique(y_true, return_inverse=True)
    y_true = y_true.astype(np.int64)
    y_pred = y_pred.astype(np.int64)

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Contingency matrix: w[p, t] = number of samples with prediction p and
    # true class t. np.add.at accumulates repeated (p, t) pairs correctly.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    rows, cols = linear_sum_assignment(w.max() - w)

    return float(w[rows, cols].sum()) / y_pred.size


def label_scores(embeddings, labels):
    """
    Mean fraction of each sample's nearest neighbours that share its label.

    A k-nearest-neighbour index is fitted on `embeddings` and queried with the
    same points, with `k = min(20, n_samples // 3)`. For every sample the
    share of returned neighbours carrying the same label is computed, and the
    average of those shares is returned.

    # Arguments
        embeddings: array-like with shape `(n_samples, n_features)`.
        labels: array-like with shape `(n_samples,)`, aligned with
            `embeddings` by position. A pandas Series is accepted; its index
            is ignored.

    # Return
        score, in [0,1]
    """
    # Positional access is what is wanted here. Converting up front avoids
    # pandas' label-based `Series[int]` lookup, which is wrong (or deprecated)
    # when the Series index is not 0..n-1.
    labels = np.asarray(labels)

    n_neigh = min(20, len(embeddings) // 3)
    nn_ = NearestNeighbors(n_neighbors=n_neigh)
    nn_.fit(embeddings)
    knns = nn_.kneighbors(embeddings, return_distance=False)

    # knns has shape (n_samples, n_neigh); compare every neighbour's label
    # with the label of the query sample in that row.
    same_label = labels[knns] == labels[:, None]

    # All rows have the same number of neighbours, so the overall mean equals
    # the mean of the per-sample fractions.
    return float(same_label.mean())


def get_CSS(data1, data2, dim=1,  func=cosine):
    """
    Average a pairwise score over matching slices of two 2-D arrays.

    # Arguments
        data1, data2: 2-D arrays of identical shape.
        dim: which axis enumerates the slices. `dim=1` compares column `g` of
            `data1` with column `g` of `data2`; `dim=0` compares rows.
        func: callable `func(u, v) -> scalar`. The default is
            `scipy.spatial.distance.cosine`, which is the cosine *distance*
            (1 - cosine similarity), so lower values mean closer vectors.

    # Return
        mean of the per-slice scores (`np.mean`, so NaN for zero slices).

    # Raises
        ValueError: if `dim` is neither 0 nor 1.
    """
    if dim not in (0, 1):
        raise ValueError(f"dim must be 0 or 1, got {dim!r}")

    n_slices = data1.shape[dim]
    if dim == 1:
        scores = [func(data1[:, g], data2[:, g]) for g in range(n_slices)]
    else:
        scores = [func(data1[g, :], data2[g, :]) for g in range(n_slices)]

    return np.mean(np.array(scores))


def get_R(data1, data2, dim=1, func=pearsonr):
    """
    Apply a correlation function to matching slices of two 2-D arrays.

    # Arguments
        data1, data2: 2-D arrays of identical shape.
        dim: which axis enumerates the slices. `dim=1` correlates column `g`
            of `data1` with column `g` of `data2`; `dim=0` correlates rows.
        func: callable `func(u, v)` returning a `(statistic, p_value)` pair,
            e.g. `scipy.stats.pearsonr` (default) or `scipy.stats.spearmanr`.

    # Return
        `(r, p)`: two numpy arrays with one entry per slice, holding the
        statistics and the p-values respectively.

    # Raises
        ValueError: if `dim` is neither 0 nor 1.
    """
    if dim not in (0, 1):
        raise ValueError(f"dim must be 0 or 1, got {dim!r}")

    r1, p1 = [], []
    for g in range(data1.shape[dim]):
        if dim == 1:
            r, pv = func(data1[:, g], data2[:, g])
        else:
            r, pv = func(data1[g, :], data2[g, :])
        r1.append(r)
        p1.append(pv)

    return np.array(r1), np.array(p1)


def test_model(model, loader, device, epoch):
    """
    Evaluate `model` on `loader` and score the embedding it has learned.

    Steps:
      1. Run the model over every batch of `loader` (no gradients), apply a
         sigmoid to the first model output and collect predictions/targets.
      2. Compute row-wise and column-wise Pearson correlation (`get_R`),
         `get_CSS` scores and the mean squared error between the two.
      3. Take `model.get_embedding()`, wrap it in an AnnData together with
         `adtT.obs`, build the neighbour graph, run UMAP and Louvain, and
         compute clustering metrics against `adata.obs[label_key]`.
      4. If `epoch` is not None, print the metrics and log them to `writer`.

    This function reads three module-level names that are NOT parameters:
    `adtT` (its `.obs` annotates the embedding), `label_key` (obs column with
    reference labels, or None to skip supervised metrics) and `writer`
    (TensorBoard SummaryWriter).

    # Arguments
        model: network exposing `__call__(seq)` (first output is used) and
            `get_embedding()`.
        loader: iterable of `(seq, adt)` tensor batches.
        device: device the input sequences are moved to.
        epoch: global step for logging; None disables printing/logging.

    # Return
        `(R, R1, sci, sci_res, embedding)`: NaN-mean of the row-wise and of the
        column-wise Pearson r, silhouette score of the Louvain clustering,
        `None` (placeholder), and the embedding as a float32 numpy array.
    """
    model.eval()
    all_label = list()
    all_pred = list()

    # Inference only: without no_grad an autograd graph would be built (and
    # held in memory) for every batch even though it is never used.
    with torch.no_grad():
        for seq, adt in tqdm(loader):
            seq = seq.to(device)
            output = model(seq)[0].detach()
            # NOTE(review): run() trains with MSE on the raw model output,
            # whereas a sigmoid is applied here before comparing with `adt` --
            # confirm this asymmetry is intended.
            output = torch.sigmoid(output).cpu().numpy().astype(np.float16)

            adt = adt.numpy().astype(np.float16)

            all_pred.append(output)
            all_label.append(adt)

    all_pred = np.concatenate(all_pred, axis=0)
    all_label = np.concatenate(all_label, axis=0)

    R = get_R(all_pred, all_label, dim=0)[0]
    R1 = get_R(all_pred, all_label, dim=1)[0]
    print("T2",all_pred.shape,all_label.shape)

    css = get_CSS(all_pred, all_label, dim=0)
    css1 = get_CSS(all_pred, all_label, dim=1)

    # NOTE(review): both arrays are float16, so the difference and the square
    # are evaluated in half precision; large targets can lose precision or
    # overflow -- consider casting to float32 if the values are not small.
    All_Loss = np.mean((all_pred - all_label) ** 2)

    # Slices with constant values yield NaN correlations; ignore them.
    R = np.nanmean(R)
    R1 = np.nanmean(R1)

    embedding = model.get_embedding().detach().cpu().numpy().astype(np.float32)

    adata1 = sc.AnnData(
        embedding,
        obs=adtT.obs,
    )
    sc.pp.neighbors(adata1, use_rep='X')
    # NOTE(review): the UMAP coordinates are not used inside this function.
    sc.tl.umap(adata1)

    sc.tl.louvain(adata1)

    # The silhouette score needs no reference labels, so it is always computed.
    sci = SC(adata1.X, adata1.obs['louvain'].values.reshape(-1, 1))

    if label_key is not None:
        ari = ARI(adata1.obs['louvain'], adata1.obs[label_key])
        nmi = NMI(adata1.obs['louvain'], adata1.obs[label_key])
        ca = cluster_acc(adata1.obs[label_key].to_numpy(), adata1.obs['louvain'].values.to_numpy())
        fmi = FMI(adata1.obs['louvain'], adata1.obs[label_key])
        lsi = label_scores(embedding, adata1.obs[label_key])
    else:
        # No reference labels: supervised metrics are reported as zero.
        ari = 0.
        nmi = 0.
        ca = 0.
        fmi = 0.
        lsi = 0.

    # Placeholder kept so the return signature stays stable.
    sci_res = None

    if epoch is not None:
        print('ARI: ' + str(ari) + ', NMI: ' + str(nmi) + ', CA: ' + str(ca) + ', FMI: ', str(fmi))
        print('SCI: ' + str(sci) + ', LSI: ' + str(lsi) + ', css: ', str(css) + ', css1: ', str(css1) + ', All_loss', All_Loss)

        scalars = {
            'ARI': ari,
            'NMI': nmi,
            'CA': ca,
            'FMI': fmi,
            'SC': sci,
            'PCC0': R,
            'PCC1': R1,
            'lsi': lsi,
            'css': css,
            'css1': css1,
            'All_Loss': All_Loss,
        }
        for tag, value in scalars.items():
            writer.add_scalar(tag, value, global_step=epoch)

    return R, R1, sci, sci_res, embedding


def split_dataset(length, tr, va):
    """
    Randomly partition the indices `0..length-1` into three index arrays.

    # Arguments
        length: number of items to split.
        tr: fraction of items assigned to the first (training) part.
        va: fraction of items assigned to the second (validation) part.

    # Return
        `(train_idx, valid_idx, test_idx)`; the third part receives whatever
        is left after the first two. Uses numpy's global random state.
    """
    shuffled = np.random.permutation(np.arange(length))
    train_end = int(length * tr)
    valid_end = int(length * (tr + va))

    train_idx = shuffled[:train_end]
    valid_idx = shuffled[train_end:valid_end]
    test_idx = shuffled[valid_end:]
    return train_idx, valid_idx, test_idx


def plot_ump(adata,label_key,clust_way="louvain"):
    """
    Draw two UMAP scatter plots of `adata`.

    The first is coloured by the reference labels in `adata.obs[label_key]`,
    the second by the clustering stored in `adata.obs[clust_way]`. UMAP
    coordinates must already be present (e.g. via `sc.tl.umap`).

    # Arguments
        adata: AnnData with a computed UMAP.
        label_key: obs column used to colour the first plot.
        clust_way: obs column with cluster assignments for the second plot.
    """
    sc.pl.umap(adata, color=label_key)
    # Honour the requested clustering column instead of always using 'louvain'.
    sc.pl.umap(adata, color=clust_way)

def clusters_val(adata,label_key,clust_way="louvain"):
    """
    Print clustering quality metrics for the clustering in `adata.obs[clust_way]`.

    With reference labels (`label_key` not None) this prints ARI, NMI,
    clustering accuracy, Fowlkes-Mallows index and the nearest-neighbour label
    score. The silhouette score of the clustering on `adata.X` is printed in
    every case.

    # Arguments
        adata: AnnData whose `.X` holds the embedding and whose `.obs` holds
            the cluster (and optionally reference label) columns.
        label_key: obs column with reference labels, or None to print only the
            unsupervised silhouette score.
        clust_way: obs column with cluster assignments.
    """
    print("="*100)
    print("="*100)
    print("="*10,clust_way,"="*10)

    clusters = adata.obs[clust_way]
    has_labels = label_key is not None

    if has_labels:
        truth = adata.obs[label_key]
        print("ARI:",ARI(clusters, truth))
        print("NMI",NMI(clusters, truth))
        print("CA:",cluster_acc(truth.to_numpy(), clusters.values.to_numpy()))
        print("FMI:",FMI(clusters, truth))
    print()

    if has_labels:
        # label_scores expects a dense array, so densify a sparse matrix first.
        if issparse(adata.X):
            dense_X = adata.X.toarray()
        else:
            dense_X = adata.X
        print("Label_Score:",label_scores(dense_X,adata.obs[label_key]))
    # Score the clustering that was asked for, not a hard-coded 'louvain'.
    print("SC",SC(adata.X, clusters))
    print()


### Step2:超参数设置

In [None]:
# Sequence length forwarded to dataset.SingleCellDataset and scbasset.scBasset.
seq_len = 1344
# Mini-batch size and worker count for the training DataLoader built in run().
batch_size = 4
num_workers = 1
# Passed as `hidden_size` to scbasset.scBasset inside run().
z_dim = 256 
# Learning rate handed to torch.optim.Adam through run(lr=...).
lr = 0.01
# Number of training epochs handed to run(max_epoch=...).
max_epoch = 20
# Forwarded as `batch` to dataset.SingleCellDataset; example value: ['P1', 'P2'].
batch = None # ['P1', 'P2']
# Seed given to set_seed() in the next cell.
seed = 3407
# seed = 0

In [None]:
# Seed the random number generators through biock's helper;
# force_deterministic=True presumably also switches PyTorch to deterministic
# algorithms -- confirm in biock.pytorch.set_seed.
set_seed(seed, force_deterministic=True)
# biock helper; presumably creates ./output if missing and returns its path --
# TODO confirm. `outdir` is not referenced anywhere else in the visible code.
outdir = make_directory('./output')
# NOTE(review): the GPU index 7 is hard-coded, so this only works on hosts
# exposing at least eight CUDA devices; falls back to CPU when CUDA is absent.
device = torch.device("cuda:7") if torch.cuda.is_available() else torch.device("cpu")

### Step3:设置数据

In [None]:
def run(data_name,adtT,batch_size,num_workers,lr=1e-3,max_epoch=100,label_key=None):
    """
    Train scBasset on `adtT`, keep the best embedding, then cluster and report it.

    After every epoch `test_model` is called on the training loader; the
    embedding of the epoch with the highest silhouette score is kept. That
    embedding is clustered with Louvain, written to an `.h5ad` file, scored
    with `clusters_val` and plotted with `plot_ump`.

    Module-level names read by this function: `scbasset`, `z_dim`, `seq_len`,
    `device`, `writer`. Note that `test_model` reads the module-level `adtT`
    and `label_key`, not the arguments given here, so callers must keep the
    globals in sync with what they pass in.

    # Arguments
        data_name: name used for the output file `<data_name>.h5ad`.
        adtT: dataset yielding `(seq, adt)` pairs; must expose `.X`, `.obs`
            and `.batche_ids`.
        batch_size: training mini-batch size.
        num_workers: number of DataLoader worker processes.
        lr: Adam learning rate.
        max_epoch: number of training epochs.
        label_key: obs column with reference labels for the final report.

    # Raises
        ValueError: if the loader yields no batch (dataset smaller than
            `batch_size` while incomplete batches are dropped).
        RuntimeError: if no epoch produced a usable embedding (e.g.
            `max_epoch` is 0 or every silhouette score was NaN).
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    # DataLoader rejects prefetch_factor unless worker processes are used.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 4
    train_loader = DataLoader(adtT, **loader_kwargs)

    model = scbasset.scBasset(n_cells=adtT.X.shape[1], hidden_size=z_dim, seq_len=seq_len, batch_ids=adtT.batche_ids).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    scaler = GradScaler()

    # Silhouette scores lie in [-1, 1]; starting from -inf guarantees that the
    # first valid epoch is recorded even when every score is negative.
    best_sci = float("-inf")
    best_epoch = None
    best_embedding = None

    for epoch in range(max_epoch):
        # Rolling window of the last 10 batch losses for the progress bar.
        pool = [np.nan for _ in range(10)]
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epoch}")
        epoch_loss = 0
        num_batches = 0
        model.train()
        for it, (seq, adt) in enumerate(pbar):
            seq, adt = seq.to(device), adt.to(device)

            optimizer.zero_grad()
            with autocast():
                output = model(seq)[0]
                loss = criterion(output, adt)

            loss_value = loss.item()
            epoch_loss += loss_value
            num_batches += 1
            # Mixed-precision step: scale the loss, step, then update the scale.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            pool[it % 10] = loss_value

            current_lr = optimizer.param_groups[-1]["lr"]
            pbar.set_postfix_str(f"loss/lr={np.nanmean(pool):.4f}/{current_lr:.3e}")

        if num_batches == 0:
            raise ValueError(
                "train_loader produced no batches; the dataset is smaller "
                "than batch_size and incomplete batches are dropped"
            )
        avg_loss = epoch_loss / num_batches
        writer.add_scalar('Loss/train', avg_loss, epoch)

        # Evaluate after every epoch (on the training loader).
        pcc0, pcc1, sci, sci_res, embedding = test_model(model, train_loader, device, epoch)

        if sci > best_sci:
            best_sci = sci
            best_epoch = epoch
            best_embedding = embedding

    if best_embedding is None:
        raise RuntimeError(
            "no epoch produced a valid silhouette score, so there is no "
            "embedding to report (is max_epoch >= 1?)"
        )

    embedding = best_embedding
    adata1 = sc.AnnData(
        embedding,
        obs=adtT.obs,
    )
    sc.pp.neighbors(adata1, use_rep='X')
    sc.tl.louvain(adata1)
    # NOTE(review): absolute, user-specific output directory.
    adata1.write(f"/home/chenjn/rna2adt/A_run_test/new_emb_sc/{data_name}.h5ad")

    clusters_val(adata1,label_key,clust_way="louvain")
    print(f"Best epoch{best_epoch}")
    print("="*60)

    sc.tl.umap(adata1)
    plot_ump(adata1,label_key,clust_way="louvain")


### Step4:Run(reap2 训练数据 early_stop机制)

In [None]:
# Datasets to process and, in the same order, the obs column holding the
# reference cell-type labels of each.
data_names = ["pbmc"] 
label_keys = ['celltype.l2']

if len(data_names) != len(label_keys):
    raise ValueError(
        f"data_names and label_keys must have the same length, "
        f"got {len(data_names)} and {len(label_keys)}"
    )

try:
    # `label_key` and `adtT` must stay module-level names: test_model() reads
    # them as globals rather than receiving them as arguments.
    for i, (data, label_key) in enumerate(zip(data_names, label_keys)):
        adt_data = '/home/chenjn/rna2adt/data/' + data + '/ADT.h5ad'
        # NOTE(review): the reference CSV always points at the pbmc directory,
        # whatever `data` is -- confirm this is intended before adding datasets.
        ref_data = '/home/chenjn/rna2adt/data/pbmc/CCND.csv'

        # Loaded only to report the matrix dimensions below.
        tem_data = sc.read_h5ad(adt_data)

        print("="*30,f"P{i}:{data}","="*30)
        print(f"Cell:{tem_data.shape[0]}    ADT:{tem_data.shape[1]}")
        print("-"*60)

        adtT = dataset.SingleCellDataset(
            data=dataset.load_adata(adt_data, hvg=False, log1p=False, nor=False), 
            seq_ref=dataset.load_csv(ref_data),
            seq_len=seq_len, 
            batch=batch,
        )

        run(data,adtT,batch_size,num_workers,lr=lr,max_epoch=max_epoch,label_key=label_key)
finally:
    # Flush and close the TensorBoard log even if a run fails part-way.
    writer.close()