In [None]:
%load_ext autoreload
%autoreload 2

import argparse
import os
os.environ['CUBLAS_WORKSPACE_CONFIG']=':4096:8'
# import sys
# # sys.path.append('/home/chenjn/rna2adt')
# sys.path.append('./')

In [None]:
import numpy as np
from tqdm import tqdm
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.metrics import adjusted_rand_score as ARI
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import fowlkes_mallows_score as FMI
from sklearn.metrics import silhouette_score as SC
import dataset
import dataloaders
import scanpy as sc
import scbasset_ori as scbasset
import sklearn
from utils import get_R
from biock import make_directory, make_logger, get_run_info
from biock.pytorch import model_summary, set_seed
from biock import HG19_FASTA_H5, HG38_FASTA_H5

from torch.utils.tensorboard import SummaryWriter

from scipy.optimize import linear_sum_assignment
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors

from utils import find_res_label

In [None]:
def cluster_acc(y_true, y_pred):
    """
    Calculate unsupervised clustering accuracy (best one-to-one label matching).

    Predicted cluster ids are matched to true labels with the Hungarian
    algorithm so that the number of samples whose cluster maps onto their
    true label is maximal; the matched fraction is returned.

    # Arguments
        y_true: true labels, array-like with shape `(n_samples,)`. Any
            sortable label type (strings, ints, ...) is accepted.
        y_pred: predicted cluster labels, array-like with shape
            `(n_samples,)`. Any sortable label type is accepted, including
            string ids and negative integers (e.g. -1 for "noise").

    # Return
        accuracy, in [0,1]

    # Raises
        ValueError: if `y_true` and `y_pred` differ in size or are empty.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_pred.size != y_true.size:
        raise ValueError(
            f"y_true and y_pred must have the same size, got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")

    # Re-index both label sets onto 0..k-1. Casting y_pred straight to int64
    # failed for non-numeric ids and let negative ids wrap around and index
    # the wrong row of the count matrix.
    _, true_idx = np.unique(y_true, return_inverse=True)
    _, pred_idx = np.unique(y_pred, return_inverse=True)
    true_idx = true_idx.ravel()
    pred_idx = pred_idx.ravel()

    # Contingency table: counts[p, t] = number of samples predicted p whose truth is t.
    counts = np.zeros((pred_idx.max() + 1, true_idx.max() + 1), dtype=np.int64)
    np.add.at(counts, (pred_idx, true_idx), 1)

    # linear_sum_assignment minimises total cost; subtracting from the max
    # turns it into maximising matched counts. Rectangular matrices are
    # supported, so the number of clusters may differ from the number of labels.
    rows, cols = linear_sum_assignment(counts.max() - counts)
    return float(counts[rows, cols].sum()) / y_pred.size

In [None]:
def label_scores(embeddings, labels, n_neighbors=20):
    """
    Mean k-nearest-neighbour label agreement of an embedding.

    For every sample, take its `n_neighbors` nearest neighbours in
    `embeddings` and compute the fraction that share the sample's label;
    return the mean of that fraction over all samples.

    Note: neighbours are queried on the fitted data itself, so each sample
    is normally returned as one of its own neighbours (distance 0) and
    counts as a match. This matches the previous implementation and keeps
    scores comparable with earlier runs.

    Args:
        embeddings: array-like of shape (n_samples, n_features).
        labels: length-n_samples labels (ndarray, list or pandas Series).
            They are accessed by position, never through a Series' index.
        n_neighbors: neighbourhood size (default 20); clipped to n_samples.

    Returns:
        float in [0, 1].

    Raises:
        ValueError: if there are no samples or `labels` has the wrong length.
    """
    # Positional access: labels[i] on a pandas Series with a string index
    # relied on the deprecated positional fallback of Series.__getitem__.
    labels = np.asarray(labels)
    n_samples = len(embeddings)
    if n_samples == 0:
        raise ValueError("label_scores requires at least one sample")
    if len(labels) != n_samples:
        raise ValueError(
            f"labels has length {len(labels)}, expected {n_samples} (one per embedding row)"
        )
    # NearestNeighbors refuses n_neighbors > n_samples, so clip for tiny inputs.
    k = min(n_neighbors, n_samples)

    nn_ = NearestNeighbors(n_neighbors=k)
    nn_.fit(embeddings)
    knns = nn_.kneighbors(embeddings, return_distance=False)

    # matches[i, j] is True when neighbour j of sample i carries sample i's label;
    # every row has k entries, so the global mean equals the mean of per-row fractions.
    matches = labels[knns] == labels[:, None]
    return float(matches.mean())

In [None]:
def _clustering_metrics(adata, cluster_key, label_col):
    """Score the clustering in ``adata.obs[cluster_key]`` against ``adata.obs[label_col]``.

    Returns a dict (in print order) with ARI, NMI, CA (cluster accuracy),
    FMI and SCI (silhouette score of ``adata.X`` under the predicted clusters).
    """
    pred = adata.obs[cluster_key]
    truth = adata.obs[label_col]
    return {
        'ARI': ARI(pred, truth),
        'NMI': NMI(pred, truth),
        'CA': cluster_acc(truth.to_numpy(), pred.to_numpy()),
        'FMI': FMI(pred, truth),
        # silhouette_score expects 1-D labels; the former `.reshape(-1, 1)`
        # only worked through sklearn's column-vector fallback.
        'SCI': SC(adata.X, pred.to_numpy()),
    }


def _format_metrics(metrics):
    """Render a metrics dict as ``'NAME: value, NAME: value, ...'``."""
    return ', '.join(f'{name}: {value}' for name, value in metrics.items())


def test_model(model, loader, device, epoch, obs=None, label_col=None):
    """
    Evaluate `model` on `loader` and score the learned cell embedding.

    Steps:
      1. Run the model over `loader` without gradients and take the
         nan-mean of `get_R` along dim 0 and dim 1 between
         sigmoid(prediction) and the ADT targets.
      2. Build an AnnData from `model.get_embedding()`, run neighbors,
         UMAP and Louvain, then add the `find_res_label` clustering
         ('louvain_res') targeting the number of distinct labels.
      3. Score both clusterings against the ground-truth labels (ARI, NMI,
         CA, FMI, silhouette) plus the kNN label score (LSI) for the
         default Louvain clustering; print them when `epoch` is not None.

    Args:
        model: network whose forward output's first item is the prediction,
            and which exposes `get_embedding()`.
        loader: yields `(seq, adt)` batches.
        device: torch device the input sequences are moved to.
        epoch: only used to decide whether the metrics are printed.
        obs: cell metadata for the embedding rows; defaults to the notebook
            global `adtT.obs`.
        label_col: column of `obs` with ground-truth labels; defaults to the
            notebook global `label_key`.

    Returns:
        (R, R1, sci, sci_res, embedding): the two nan-mean correlations, the
        silhouette scores of the 'louvain' and 'louvain_res' clusterings, and
        the float32 embedding matrix.
    """
    obs = adtT.obs if obs is None else obs
    label_col = label_key if label_col is None else label_col

    model.eval()
    preds, targets = [], []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for seq, adt in tqdm(loader):
            output = model(seq.to(device))[0]
            # NOTE(review): the training loop fits the raw output with MSE (no
            # sigmoid), yet a sigmoid is applied here before correlating —
            # confirm this is intended.
            preds.append(torch.sigmoid(output).float().cpu().numpy())
            # float32 rather than float16: float16 keeps ~3 significant digits
            # and overflows above 65504, which un-normalised counts may exceed.
            targets.append(adt.numpy().astype(np.float32))

    all_pred = np.concatenate(preds, axis=0)
    all_label = np.concatenate(targets, axis=0)

    R = np.nanmean(get_R(all_pred, all_label, dim=0)[0])
    R1 = np.nanmean(get_R(all_pred, all_label, dim=1)[0])

    embedding = model.get_embedding().detach().cpu().numpy().astype(np.float32)

    adata1 = sc.AnnData(embedding, obs=obs)
    sc.pp.neighbors(adata1, use_rep='X')
    sc.tl.umap(adata1)
    sc.tl.louvain(adata1)
    n_labels = len(np.unique(adata1.obs[label_col]))
    adata1.obs['louvain_res'] = find_res_label(adata1, n_labels)

    scores = _clustering_metrics(adata1, 'louvain', label_col)
    scores['LSI'] = label_scores(embedding, adata1.obs[label_col])
    scores_res = _clustering_metrics(adata1, 'louvain_res', label_col)

    if epoch is not None:
        print(_format_metrics(scores))
        print(_format_metrics(scores_res))

    return R, R1, scores['SCI'], scores_res['SCI'], embedding

In [None]:
def split_dataset(length, tr, va):
    """Randomly partition the indices ``0..length-1`` into train/valid/test.

    ``tr`` and ``va`` are the train and validation fractions; every index
    after position ``int(length * (tr + va))`` of the shuffled order goes to
    the test split. Draws from the global NumPy RNG, so the result follows
    ``np.random.seed``.

    Returns:
        tuple ``(train_idx, valid_idx, test_idx)`` of index arrays.
    """
    shuffled = np.random.permutation(np.arange(length))
    train_end = int(length * tr)
    valid_end = int(length * (tr + va))
    return tuple(np.split(shuffled, [train_end, valid_end]))

## Hyperparameters

In [None]:
seq_len = 1344
batch_size = 4
num_workers = 2
z_dim = 256 
lr = 0.01
max_epoch = 500
batch=None
seed = 3407

In [None]:
set_seed(seed, force_deterministic=True)

In [None]:
outdir = make_directory('./output')
# logger = make_logger(title="", filename=os.path.join(outdir, "train.log"))

## Load data

In [None]:
adt_data = '/home/chenjn/rna2adt/data/karen2018a_5/ADT.h5ad'
ref_data = '/home/chenjn/rna2adt/data/pbmc/CCND.csv'
label_key = 'immuneGroup_name'

In [None]:
adtT = dataset.SingleCellDataset(
    data=dataset.load_adata(adt_data, log1p=False, nor=False), 
    seq_ref=dataset.load_csv(ref_data),
    seq_len=seq_len, 
    batch=batch,
)

## Visualization of raw data

In [None]:
adata = sc.AnnData(
    adtT.X.T,
    obs=adtT.obs,
)
sc.pp.neighbors(adata, use_rep="X")
sc.tl.umap(adata)

In [None]:
sc.tl.louvain(adata, random_state=seed)

In [None]:
sc.pl.umap(adata, color='louvain')
sc.pl.umap(adata, color=label_key)

In [None]:
print(ARI(adata.obs['louvain'], adata.obs[label_key]))
print(NMI(adata.obs['louvain'], adata.obs[label_key]))
print(cluster_acc(adata.obs[label_key].to_numpy(), adata.obs['louvain'].values.to_numpy()))
print(FMI(adata.obs['louvain'], adata.obs[label_key]))
# print(SC(adata.X, adata.obs['louvain'].values.reshape(-1, 1)))

## Construct train and validation loaders

In [None]:
train_loader = DataLoader(
    adtT,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
    prefetch_factor=4
)

In [None]:
sampled = np.random.permutation(np.arange(len(adtT)))[:10]
valid_loader = DataLoader(
    Subset(adtT, sampled),
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

## Train the model

In [None]:
# GPU index 6 when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:6")
else:
    device = torch.device("cpu")

# One output per cell: n_cells is the second dimension of adtT.X.
model = scbasset.scBasset(
    n_cells=adtT.X.shape[1],
    hidden_size=z_dim,
    seq_len=seq_len,
    batch_ids=adtT.batche_ids,
).to(device)

In [None]:
load = 0  # when truthy, the training below is skipped

if not load:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    criterion = nn.MSELoss()
    scaler = GradScaler()
    # NOTE: the scheduler is created but its step() call is disabled below,
    # so the learning rate stays at `lr` for the whole run.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode="max",
        factor=0.95,
        patience=2,
        min_lr=1e-7
    )

    # Early-stopping state; only used by the disabled checkpointing code below.
    best_score = 0
    wait = 0
    patience = 15

    # Embedding of the evaluation with the best silhouette score. Start at
    # -inf: silhouette lies in [-1, 1], and starting at 0 left best_embedding
    # as None whenever every evaluation scored <= 0.
    best_sci = -np.inf
    best_embedding = None

    for epoch in range(max_epoch):
        loss_window = [np.nan for _ in range(10)]  # last 10 batch losses, for the progress bar
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{max_epoch}")
        model.train()  # test_model() switches the model to eval mode
        for it, (seq, adt) in enumerate(pbar):
            seq, adt = seq.to(device), adt.to(device)

            optimizer.zero_grad()
            with autocast():
                output = model(seq)[0]
                loss = criterion(output, adt)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            loss_window[it % 10] = loss.item()

            # Separate name so the `lr` hyperparameter is not overwritten.
            cur_lr = optimizer.param_groups[-1]["lr"]
            pbar.set_postfix_str(f"loss/lr={np.nanmean(loss_window):.4f}/{cur_lr:.3e}")

        if epoch % 10 == 0:
            pcc0, pcc1, sci, sci_res, embedding = test_model(model, valid_loader, device, epoch)

            if sci > best_sci:
                best_sci = sci
                best_embedding = embedding

            # Disabled: validation logging, LR scheduling on the validation
            # score, checkpointing and early stopping. Re-enable together with
            # the commented-out `logger` line in the output-directory cell.
            # logger.info("Validation{} PCC0={:.4f} PCC1={:.4f} SC={:.4f}".format((epoch + 1), pcc0, pcc1, sci))
            # val_score = sci
            # scheduler.step(val_score)
            # if val_score > best_score:
            #     best_score = val_score
            #     wait = 0
            #     torch.save(model.state_dict(), "{}/best_scb_ori_{}_{}_{}_{}_{}_{}.pt".format(outdir, str(batch), str(seq_len), str(z_dim), str(lr * 1000), str(device), str(seed)))
            #     logger.info(f"Epoch {epoch+1}: best model saved\n")
            # else:
            #     wait += 1
            #     if wait <= patience / 2:
            #         embedding = model.get_embedding().detach().cpu().numpy().astype(np.float32)
            #         sc.AnnData(embedding, obs=adtT.obs).write_h5ad("{}/best_scb_ori_emb_{}_{}_{}_{}_{}_{}.h5ad".format(outdir, str(batch), str(seq_len), str(z_dim), str(lr * 1000), str(device), str(seed)))
            #         logger.info(f"Epoch {epoch+1}: early stopping patience {wait}/{patience}, embedding saved\n")
            #     else:
            #         logger.info(f"Epoch {epoch+1}: early stopping patience {wait}/{patience}\n")
            #     if wait >= patience:
            #         logger.info(f"Epoch {epoch+1}: early stopping")
            #         break

## Visualizations

In [None]:
embedding = best_embedding

adata1 = sc.AnnData(
    embedding,
    obs=adtT.obs,
)

sc.pp.neighbors(adata1, use_rep='X')
sc.tl.umap(adata1)

In [None]:
sc.tl.louvain(adata1, random_state=seed)

adata1.obs['louvain_res'] = find_res_label(adata1, len(np.unique(adata1.obs[label_key])))

In [None]:
sc.pl.umap(adata1, color='louvain')
sc.pl.umap(adata1, color=label_key)

In [None]:
print(ARI(adata1.obs['louvain'], adata1.obs[label_key]))
print(NMI(adata1.obs['louvain'], adata1.obs[label_key]))
print(cluster_acc(adata1.obs[label_key].to_numpy(), adata1.obs['louvain'].values.to_numpy()))
print(FMI(adata1.obs['louvain'], adata1.obs[label_key]))
# print(SC(adata1.X, adata1.obs['louvain'].values.reshape(-1, 1)))