In [None]:
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

import argparse
import os
# Set before torch is imported in the next cell.
# NOTE(review): presumably needed for deterministic cuBLAS kernels when
# set_seed(..., force_deterministic=True) is called later — confirm against
# the PyTorch reproducibility notes.
os.environ['CUBLAS_WORKSPACE_CONFIG']=':4096:8'
import sys
# Extends the import path with a project checkout.
# NOTE(review): absolute, user-specific path; presumably where the local
# modules imported below (dataset, dataloaders, utils, scbasset_ori) live.
sys.path.append('/home/chenjn/rna2adt')

# Render inline figures in retina format.
%config InlineBackend.figure_format = 'retina'
import warnings
# Silences every warning category for the rest of the kernel session.
warnings.filterwarnings('ignore')


In [None]:
import numpy as np
from tqdm import tqdm
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import fowlkes_mallows_score as FMI
from sklearn.metrics import silhouette_score as SC
import dataset
import dataloaders
import scanpy as sc
import scbasset_ori as scbasset
import sklearn
from utils import get_R
from biock import make_directory, make_logger, get_run_info
from biock.pytorch import model_summary, set_seed
from biock import HG19_FASTA_H5, HG38_FASTA_H5

from torch.utils.tensorboard import SummaryWriter

from scipy.optimize import linear_sum_assignment
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors

from utils import find_res_label

In [None]:
def cluster_acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one label matching.

    Predicted cluster ids are matched to true classes with
    `scipy.optimize.linear_sum_assignment` so that the number of agreeing
    samples is maximal; the accuracy is that count divided by the number of
    samples.

    # Arguments
        y_true: true labels, array-like with shape `(n_samples,)`; any label
            type that `np.unique` can sort (ints, strings, ...)
        y_pred: predicted labels, array-like with shape `(n_samples,)`; must
            be castable to int64. The ids themselves are arbitrary: negative
            or non-contiguous ids are handled.

    # Return
        accuracy, in [0,1]

    # Raises
        AssertionError: if the two label vectors differ in length
        ValueError: if the inputs are empty
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    assert y_pred.size == y_true.size
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")

    # Re-index both label sets to 0..k-1. Using the raw predicted ids as
    # matrix indices would wrap around for negative ids and allocate a huge
    # matrix for sparse ids; the optimal matching is invariant to relabeling.
    true_classes, true_idx = np.unique(y_true, return_inverse=True)
    pred_classes, pred_idx = np.unique(y_pred.astype(np.int64), return_inverse=True)
    true_idx = true_idx.reshape(-1)
    pred_idx = pred_idx.reshape(-1)

    # Square contingency matrix (rows: predicted cluster, cols: true class),
    # zero-padded when the two sides have a different number of labels.
    D = max(true_classes.size, pred_classes.size)
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (pred_idx, true_idx), 1)

    # linear_sum_assignment minimizes cost, so flip counts into costs.
    rows, cols = linear_sum_assignment(w.max() - w)

    return w[rows, cols].sum() * 1.0 / y_pred.size

In [None]:
def label_scores(embeddings, labels, n_neighbors=20):
    """
    Mean fraction of each sample's nearest neighbours that share its label.

    A `NearestNeighbors` index is fitted on `embeddings` and queried with the
    same points; for every sample the share of returned neighbours carrying
    the sample's own label is computed, and these shares are averaged.

    Because the query set equals the fitted set, each sample is normally
    returned as one of its own neighbours and therefore counts as an
    agreeing vote (see the scikit-learn `kneighbors` documentation).

    # Arguments
        embeddings: array-like with shape `(n_samples, n_features)`
        labels: array-like with shape `(n_samples,)`, aligned with the rows
            of `embeddings` by position
        n_neighbors: number of neighbours to query; capped at `n_samples`

    # Return
        score as a float in [0,1]
    """
    # Positional array: indexing a pandas Series with integers would be
    # label-based and can pick the wrong row or raise KeyError.
    labels = np.asarray(labels).ravel()

    k = min(n_neighbors, len(embeddings))
    nn_ = NearestNeighbors(n_neighbors=k)
    nn_.fit(embeddings)
    knns = nn_.kneighbors(embeddings, return_distance=False)

    # same[i, j] is True when the j-th neighbour of sample i shares its label.
    same = labels[knns] == labels[:, None]

    return float(same.mean(axis=1).mean())

In [None]:
def test_model(model, loader, device, epoch):
    """
    Evaluate `model` on `loader` and score the clustering of its embedding.

    Predictions (sigmoid of the model's first output) and targets are
    collected over the whole loader and passed to `get_R` along both axes;
    the NaN-ignoring mean of each result is reported. The embedding returned
    by `model.get_embedding()` is clustered with Louvain on a scanpy
    neighbour graph and scored with the silhouette coefficient.

    Relies on two notebook globals: `adtT` (its `.obs` annotates the
    embedding) and `writer` (TensorBoard writer, used only when `epoch` is
    not None).

    NOTE(review): the training loop applies the loss to the raw model output
    while this function applies a sigmoid first — confirm this is intended.

    # Arguments
        model: network exposing `__call__(seq)[0]` and `get_embedding()`
        loader: iterable yielding `(seq, adt)` batches
        device: device the input batches are moved to
        epoch: global step for TensorBoard logging, or None to skip logging

    # Return
        (R, R1, sci, sci_res, embedding): mean `get_R` score with dim=0 and
        with dim=1 (logged as PCC0/PCC1, presumably Pearson correlations),
        the silhouette score, `None` (placeholder), and the float32
        embedding array.
    """
    model.eval()
    all_label = list()
    all_pred = list()

    # Inference only: skip building the autograd graph for every batch.
    with torch.no_grad():
        for seq, adt in tqdm(loader):
            seq = seq.to(device)
            output = model(seq)[0]
            output = torch.sigmoid(output).cpu().numpy().astype(np.float16)

            adt = adt.numpy().astype(np.float16)

            all_pred.append(output)
            all_label.append(adt)

    all_pred = np.concatenate(all_pred, axis=0)
    all_label = np.concatenate(all_label, axis=0)

    R = get_R(all_pred, all_label, dim=0)[0]
    R1 = get_R(all_pred, all_label, dim=1)[0]

    R = np.nanmean(R)
    R1 = np.nanmean(R1)

    embedding = model.get_embedding().detach().cpu().numpy().astype(np.float32)

    adata1 = sc.AnnData(
        embedding,
        obs=adtT.obs,
    )
    # Louvain only needs the neighbour graph; the UMAP layout that used to be
    # computed here was discarded with the local AnnData, so it is skipped.
    sc.pp.neighbors(adata1, use_rep='X')
    sc.tl.louvain(adata1)

    # silhouette_score expects a flat label vector, not an (n, 1) column.
    sci = SC(adata1.X, adata1.obs['louvain'].to_numpy())

    if epoch is not None:
        print('SCI', str(sci))
        print("="*100)

        writer.add_scalar('SC', sci, global_step=epoch)
        writer.add_scalar('PCC0', R, global_step=epoch)
        writer.add_scalar('PCC1', R1, global_step=epoch)
    sci_res = None

    return R, R1, sci, sci_res, embedding

In [None]:
def split_dataset(length, tr, va):
    """
    Randomly partition the indices 0..length-1 into train/valid/test parts.

    The indices are shuffled with NumPy's global random state. The first
    int(length * tr) go to training, the ones up to int(length * (tr + va))
    go to validation, and the remainder goes to test.

    # Arguments
        length: number of samples to split
        tr: fraction of samples for training
        va: fraction of samples for validation

    # Return
        (train_idx, valid_idx, test_idx) index arrays
    """
    shuffled = np.random.permutation(np.arange(length))

    train_end = int(length * tr)
    valid_end = int(length * (tr + va))

    train_idx = shuffled[:train_end]
    valid_idx = shuffled[train_end:valid_end]
    test_idx = shuffled[valid_end:]
    return train_idx, valid_idx, test_idx

In [None]:
# Run configuration shared by the cells below.
# seq_len: passed to both SingleCellDataset and scBasset.
seq_len = 1344
# batch_size / num_workers: DataLoader settings for both loaders.
batch_size = 4
num_workers = 1
# z_dim: passed to scBasset as hidden_size.
z_dim = 256 
# lr: Adam learning rate.
lr = 0.001
# max_epoch: upper bound of the training loop.
max_epoch = 10000
# batch: forwarded to SingleCellDataset(batch=...) and used in the TensorBoard run name.
batch=None
# seed: used by set_seed and by the final Louvain clustering.
seed = 3407
# seed = 0

In [None]:
# Seed the run via the biock helper and request deterministic execution.
# NOTE(review): which RNGs biock's set_seed covers is not visible here — confirm.
set_seed(seed, force_deterministic=True)

In [None]:
# Output directory relative to the notebook's working directory.
# `outdir` is only referenced by commented-out code in the visible cells.
outdir = make_directory('./output')
# logger = make_logger(title="", filename=os.path.join(outdir, "train.log"))

In [None]:
# Dataset tag. NOTE(review): not referenced by any other visible cell.
data = 'ECCITE'

In [None]:
# Input locations (absolute, machine-specific paths).
# adt_data: .h5ad file loaded through dataset.load_adata.
adt_data = '/data/user/liwb/project/rna2adt/A_run_test/downstream/clinical_proteomic_application/data/ECCITE_seq_processed.h5ad'
# ref_data: CSV loaded through dataset.load_csv and passed as seq_ref.
ref_data = '/data/user/liwb/project/rna2adt/A_run_test/data/pbmc/CCND.csv'
# label_key: NOTE(review): not referenced by any other visible cell.
label_key = 'CellType'

In [None]:
# Raw AnnData for inspection; reuse `adt_data` so this cannot drift from the
# file the dataset is built from.
temp = sc.read_h5ad(adt_data)

In [None]:
# Build the training dataset from the AnnData file and the sequence reference CSV.
# `adtT` is also read as a global by test_model (for `.obs`).
# NOTE(review): the meaning of log1p/nor is defined in dataset.load_adata,
# which is not visible here — presumably log-transform and normalization.
adtT = dataset.SingleCellDataset(
    data=dataset.load_adata(adt_data, log1p=True, nor=True), 
    seq_ref=dataset.load_csv(ref_data),
    seq_len=seq_len, 
    batch=batch,
)

In [None]:
# Baseline AnnData built from the transposed dataset matrix, with a
# neighbour graph and UMAP computed directly on it.
# NOTE(review): `adata` is not referenced by any later visible cell.
adata = sc.AnnData(
    adtT.X.T,
    obs=adtT.obs,
)
sc.pp.neighbors(adata, use_rep="X")
sc.tl.umap(adata)

In [None]:
# Training loader: shuffled, and the last incomplete batch is dropped.
# The same loader is later passed to test_model for evaluation, so
# evaluation also sees shuffled batches without the dropped remainder.
train_loader = DataLoader(
    adtT,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
    prefetch_factor=4
)

In [None]:
# Loader over 10 randomly chosen dataset items (drawn from NumPy's global RNG).
# NOTE(review): `valid_loader` is not used by any visible cell; evaluation
# runs on `train_loader` instead.
sampled = np.random.permutation(np.arange(len(adtT)))[:10]
valid_loader = DataLoader(
    Subset(adtT, sampled),
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Use GPU index 5 when CUDA is available, otherwise the CPU.
# NOTE(review): hard-coded device index; fails on hosts with fewer than 6 GPUs.
device = torch.device("cuda:5") if torch.cuda.is_available() else torch.device("cpu")
# n_cells is taken from the second axis of the dataset matrix.
model = scbasset.scBasset(n_cells=adtT.X.shape[1], hidden_size=z_dim, seq_len=seq_len, batch_ids=adtT.batche_ids).to(device)
# logger.info("model parameters: {} {} {} {} ".format(str(seq_len), str(z_dim), str(lr * 1000), str(device)))
# TensorBoard writer; the run directory name encodes batch, seq_len and z_dim.
# test_model reads `writer` as a global.
writer = SummaryWriter('/data/user/liwb/project/rna2adt/A_run_test/downstream/clinical_proteomic_application/log/'+ str(batch) + '_' + str(seq_len) + '_' + str(z_dim))

In [None]:
# Train the model and keep the embedding with the best silhouette score.
# Set `load` to a truthy value to skip training entirely.
load = 0
best_epoch = -1

# Defined outside the `if` so the reporting cell can always read them.
# Silhouette scores can be negative, so start below any attainable value;
# starting at 0 would discard every embedding of a run that never exceeds 0.
best_sci = -np.inf
best_embedding = None

if not load:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    criterion = nn.MSELoss()
    scaler = GradScaler()
    # NOTE: the scheduler is created but never stepped, so the learning rate
    # stays at its initial value for the whole run.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode="max",
        factor=0.95,
        patience=2,
        min_lr=1e-7
    )

    # Early-stopping bookkeeping; currently unused (see TODO below).
    best_score = 0
    wait = 0
    patience = 15

    for epoch in range(max_epoch):
        # Rolling window of the last 10 batch losses for the progress bar.
        pool = [np.nan for _ in range(10)]
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epoch}")
        model.train()
        for it, (seq, adt) in enumerate(pbar):
            seq, adt = seq.to(device), adt.to(device)

            optimizer.zero_grad()
            with autocast():
                output = model(seq)[0]
                loss = criterion(output, adt)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            pool[it % 10] = loss.item()

            # Separate name so the `lr` hyper-parameter is not overwritten.
            current_lr = optimizer.param_groups[-1]["lr"]
            pbar.set_postfix_str(f"loss/lr={np.nanmean(pool):.4f}/{current_lr:.3e}")

        # Evaluate every 100 epochs. The training loader is reused, so the
        # evaluation sees shuffled batches without the dropped last batch.
        if (epoch+1) % 100 == 0:
            pcc0, pcc1, sci, sci_res, embedding = test_model(model, train_loader, device, epoch)

            if sci > best_sci:
                best_sci = sci
                best_embedding = embedding
                best_epoch = epoch+1

            # TODO: early stopping and scheduler.step() on the silhouette
            # score were prototyped here but are disabled; `scheduler`,
            # `best_score`, `wait` and `patience` are therefore unused.

In [None]:
# Report, save and visualize the best embedding found during training.
print("Best_epoch:",best_epoch)
print("Best_sci:",best_sci)

# Fail early instead of risking an overwrite of the saved file with an
# AnnData that has no data matrix.
if best_embedding is None:
    raise RuntimeError("no embedding was selected during training; nothing to save or plot")

embedding = best_embedding

adata1 = sc.AnnData(
    embedding,
    obs=adtT.obs,
)

# Written before neighbours/UMAP/Louvain are computed, so the saved file
# holds only the embedding and the obs annotations.
adata1.write_h5ad("/data/user/liwb/project/rna2adt/A_run_test/downstream/clinical_proteomic_application/emb/ECCITE_step1.h5ad")

sc.pp.neighbors(adata1, use_rep='X')
sc.tl.umap(adata1)
# NOTE(review): test_model calls sc.tl.louvain without random_state, so the
# clusters plotted here may differ from those behind Best_sci.
sc.tl.louvain(adata1, random_state=seed)
sc.pl.umap(adata1, color='donor_type')
sc.pl.umap(adata1, color='louvain')

In [None]:
# Display the summary repr of the final AnnData (embedding plus clustering results).
adata1