In [1]:
!pip install -U --quiet torchtext==0.1.1
!pip install --quiet scvi-tools[tutorials]
!pip install --quiet scarches
!pip install --quiet scrublet
!pip install --quiet xgboost

In [2]:
import time
from sklearn.metrics import f1_score

import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import gc
import scvi
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import numpy as np
import gdown
import anndata as ad
import pandas as pd
from scipy.io import mmread
import loompy
import seaborn as sns
from sklearn.svm import LinearSVC as SVC
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.linear_model import LogisticRegression as LR
from sklearn.model_selection import train_test_split
from sklearn.calibration import CalibratedClassifierCV

In [3]:
def preproccess(filePath):
    """Read a dataset and prepare it for the classifier benchmarks.

    The file at ``filePath`` is loaded with ``sc.read``, passed through
    ``sc.pp.normalize_total`` and ``sc.pp.log1p``, and finally handed to
    ``sc.pp.highly_variable_genes`` with ``n_top_genes=2000`` and
    ``subset=True``.

    Returns the object loaded by ``sc.read`` after those three steps.
    """
    dataset = sc.read(filePath)

    # Return values of the preprocessing calls are ignored: the steps are
    # relied upon to update `dataset` itself.
    sc.pp.normalize_total(dataset)
    sc.pp.log1p(dataset)

    hvg_options = dict(n_top_genes=2000, subset=True)
    sc.pp.highly_variable_genes(dataset, **hvg_options)

    return dataset

In [4]:
def run_scanvi(adata, labels_key='Cell_type', num_gen=-1, num_cell=-1):
  """Train SCVI followed by SCANVI on a random subsample and score the result.

  Args:
    adata: Annotated matrix; rows are sampled as cells, columns as genes.
    labels_key: Column of ``adata.obs`` holding the labels.
    num_gen: Number of columns to sample without replacement; -1 keeps all.
    num_cell: Number of rows to sample without replacement; -1 keeps all.

  Returns:
    ``(accuracy, weighted_f1, label_rate, seconds)``. Predictions are made
    for the same cells the models were trained on (no held-out split), and
    ``seconds`` spans both training runs, the latent extraction and the
    prediction call.
  """
  total_cells, total_genes = adata.shape[0], adata.shape[1]
  if num_cell == -1:
    num_cell = total_cells
  if num_gen == -1:
    num_gen = total_genes

  # Rows are drawn before columns so the global NumPy RNG is consumed in a
  # fixed order.
  cell_idx = np.random.choice(total_cells, num_cell, replace=False)
  gene_idx = np.random.choice(total_genes, num_gen, replace=False)
  adata = adata[cell_idx, gene_idx].copy()

  scvi.data.setup_anndata(adata, labels_key=labels_key)
  scvi_options = dict(
      n_layers=2,
      encode_covariates=True,
      deeply_inject_covariates=False,
      use_layer_norm="both",
      use_batch_norm="none",
  )
  vae = sca.models.SCVI(adata, **scvi_options)

  # Model construction above is deliberately outside the timed region.
  started = time.time()
  vae.train()
  scanvae = sca.models.SCANVI(adata, "Unknown", vae)
  scanvae.train()
  latent = sc.AnnData(vae.get_latent_representation())
  latent.obs["cell_type"] = adata.obs[labels_key].tolist()
  latent.obs['predictions'] = scanvae.predict()
  sca_time = time.time() - started

  truth = latent.obs.cell_type
  guessed = latent.obs.predictions
  sca_acc = np.mean(guessed == truth)
  sca_f1 = f1_score(truth, guessed, average='weighted')
  # Share of predictions that are one of the labels present in the data.
  sca_label_rate = sum(guessed.isin(truth)) / len(truth)
  return sca_acc, sca_f1, sca_label_rate, sca_time

In [5]:
def run_svm(adata, labels_key='Cell_type', num_gen=-1, num_cell=-1):
  """Fit a linear SVM on a random subsample of ``adata`` and score it on a held-out third.

  Args:
    adata: Annotated matrix supporting ``adata[rows, cols]``, ``.X``,
      ``.obs`` and ``.shape``; rows are sampled as cells, columns as genes.
    labels_key: Column of ``adata.obs`` holding the class labels.
    num_gen: Number of columns to sample without replacement; -1 keeps all.
    num_cell: Number of rows to sample without replacement; -1 keeps all.

  Returns:
    ``(accuracy, weighted_f1, label_rate, seconds)``. ``label_rate`` is the
    fraction of test predictions equal to one of the labels present in the
    subsample; ``seconds`` covers only ``fit`` plus ``predict``.
  """
  if num_cell == -1:
    num_cell = adata.shape[0]
  if num_gen == -1:
    num_gen = adata.shape[1]

  # Rows are drawn before columns so the global NumPy RNG is consumed in a
  # fixed order.
  cell_idx = np.random.choice(adata.shape[0], num_cell, replace=False)
  gene_idx = np.random.choice(adata.shape[1], num_gen, replace=False)
  # Index the caller's object directly instead of copying all of it first:
  # only the sampled slice is read and nothing is written back, so no
  # defensive copy of the full dataset is needed.
  subset = adata[cell_idx, gene_idx]
  labels = subset.obs[labels_key]

  X_train, X_test, y_train, y_test = train_test_split(subset.X, labels, test_size=0.33)
  clf = SVC()

  # Keep bookkeeping out of the timed region so it reflects the model only.
  svc_start_time = time.time()
  clf.fit(X_train, y_train)
  predicted = clf.predict(X_test)
  svc_time = time.time() - svc_start_time

  svc_acc = np.mean(np.asarray(y_test) == predicted)
  svc_f1 = f1_score(y_test, predicted, average='weighted')
  # Membership is tested against the distinct labels rather than against one
  # entry per cell.
  known_labels = np.asarray(labels.unique())
  svc_label_rate = np.mean(np.isin(predicted, known_labels))
  return svc_acc, svc_f1, svc_label_rate, svc_time

In [None]:
def run_svm_rejection(adata, labels_key='Cell_type', num_gen=-1, num_cell=-1, threshold = 0.7):
  """Benchmark a calibrated linear SVM that rejects low-confidence predictions.

  Args:
    adata: Annotated matrix supporting ``adata[rows, cols]``, ``.X``,
      ``.obs`` and ``.shape``; rows are sampled as cells, columns as genes.
    labels_key: Column of ``adata.obs`` holding the class labels.
    num_gen: Number of columns to sample without replacement; -1 keeps all.
    num_cell: Number of rows to sample without replacement; -1 keeps all.
    threshold: Test cells whose highest calibrated class probability is
      below this value are labelled ``'Unknown'``.

  Returns:
    ``(accuracy, weighted_f1, label_rate, seconds)``. ``label_rate`` is the
    fraction of test predictions equal to one of the labels present in the
    subsample (rejected cells lower it, assuming ``'Unknown'`` is not itself
    a label); ``seconds`` covers fitting, prediction and rejection.
  """
  if num_cell == -1:
    num_cell = adata.shape[0]
  if num_gen == -1:
    num_gen = adata.shape[1]

  # Rows are drawn before columns so the global NumPy RNG is consumed in a
  # fixed order.
  cell_idx = np.random.choice(adata.shape[0], num_cell, replace=False)
  gene_idx = np.random.choice(adata.shape[1], num_gen, replace=False)
  # Index the caller's object directly instead of copying all of it first:
  # only the sampled slice is read and nothing is written back.
  subset = adata[cell_idx, gene_idx]
  labels = subset.obs[labels_key]

  X_train, X_test, y_train, y_test = train_test_split(subset.X, labels, test_size=0.33)
  clf = CalibratedClassifierCV(SVC())

  svc_start_time = time.time()
  clf.fit(X_train, y_train)
  # The calibrated prediction is the class with the highest calibrated
  # probability, so a single predict_proba pass yields both the label and
  # its confidence.
  proba = clf.predict_proba(X_test)
  # Cast to object so writing 'Unknown' can neither be truncated by a
  # fixed-width string dtype nor fail on a non-string dtype.
  predicted = clf.classes_[np.argmax(proba, axis=1)].astype(object)
  predicted[np.max(proba, axis=1) < threshold] = 'Unknown'
  svc_time = time.time() - svc_start_time

  svc_acc = np.mean(np.asarray(y_test) == predicted)
  svc_f1 = f1_score(y_test, predicted, average='weighted')
  # Membership is tested against the distinct labels rather than against one
  # entry per cell.
  known_labels = np.asarray(labels.unique())
  svc_label_rate = np.mean(np.isin(predicted, known_labels))
  return svc_acc, svc_f1, svc_label_rate, svc_time

In [None]:
def run_XGBoost(adata, labels_key='Cell_type', num_gen=-1, num_cell=-1):
  """Fit an XGBoost classifier on a random subsample of ``adata`` and score it on a held-out third.

  Args:
    adata: Annotated matrix supporting ``adata[rows, cols]``, ``.X``,
      ``.obs`` and ``.shape``; rows are sampled as cells, columns as genes.
    labels_key: Column of ``adata.obs`` holding the class labels.
    num_gen: Number of columns to sample without replacement; -1 keeps all.
    num_cell: Number of rows to sample without replacement; -1 keeps all.

  Returns:
    ``(accuracy, weighted_f1, label_rate, seconds)``. ``label_rate`` is the
    fraction of test predictions equal to one of the labels present in the
    subsample; ``seconds`` covers only ``fit`` plus ``predict``.
  """
  if num_cell == -1:
    num_cell = adata.shape[0]
  if num_gen == -1:
    num_gen = adata.shape[1]

  # Rows are drawn before columns so the global NumPy RNG is consumed in a
  # fixed order.
  cell_idx = np.random.choice(adata.shape[0], num_cell, replace=False)
  gene_idx = np.random.choice(adata.shape[1], num_gen, replace=False)
  # Index the caller's object directly instead of copying all of it first:
  # only the sampled slice is read and nothing is written back.
  subset = adata[cell_idx, gene_idx]
  labels = subset.obs[labels_key]

  X_train, X_test, y_train, y_test = train_test_split(subset.X, labels, test_size=0.33)

  # Recent XGBoost releases only accept integer class ids 0..n_classes-1.
  # Encode from the training labels (so the ids are contiguous) and decode
  # the predictions back to the original label values afterwards.
  classes, y_train_codes = np.unique(np.asarray(y_train), return_inverse=True)
  clf = xgb.XGBClassifier()

  # Keep bookkeeping out of the timed region so it reflects the model only.
  start_time = time.time()
  clf.fit(X_train, y_train_codes)
  predicted_codes = clf.predict(X_test)
  elapsed = time.time() - start_time

  predicted = classes[np.asarray(predicted_codes).astype(int)]
  accuracy = np.mean(np.asarray(y_test) == predicted)
  f1 = f1_score(y_test, predicted, average='weighted')
  # Membership is tested against the distinct labels rather than against one
  # entry per cell.
  known_labels = np.asarray(labels.unique())
  label_rate = np.mean(np.isin(predicted, known_labels))
  return accuracy, f1, label_rate, elapsed

In [None]:
def run_logistic_regression(adata, labels_key='Cell_type', num_gen=-1, num_cell=-1):
  """Fit logistic regression on a random subsample of ``adata`` and score it on a held-out third.

  Args:
    adata: Annotated matrix supporting ``adata[rows, cols]``, ``.X``,
      ``.obs`` and ``.shape``; rows are sampled as cells, columns as genes.
    labels_key: Column of ``adata.obs`` holding the class labels.
    num_gen: Number of columns to sample without replacement; -1 keeps all.
    num_cell: Number of rows to sample without replacement; -1 keeps all.

  Returns:
    ``(accuracy, weighted_f1, label_rate, seconds)``. ``label_rate`` is the
    fraction of test predictions equal to one of the labels present in the
    subsample; ``seconds`` covers only ``fit`` plus ``predict``.
  """
  if num_cell == -1:
    num_cell = adata.shape[0]
  if num_gen == -1:
    num_gen = adata.shape[1]

  # Rows are drawn before columns so the global NumPy RNG is consumed in a
  # fixed order.
  cell_idx = np.random.choice(adata.shape[0], num_cell, replace=False)
  gene_idx = np.random.choice(adata.shape[1], num_gen, replace=False)
  # Index the caller's object directly instead of copying all of it first:
  # only the sampled slice is read and nothing is written back.
  subset = adata[cell_idx, gene_idx]
  labels = subset.obs[labels_key]

  X_train, X_test, y_train, y_test = train_test_split(subset.X, labels, test_size=0.33)
  clf = LR()

  # Keep bookkeeping out of the timed region so it reflects the model only.
  start_time = time.time()
  clf.fit(X_train, y_train)
  predicted = clf.predict(X_test)
  elapsed = time.time() - start_time

  accuracy = np.mean(np.asarray(y_test) == predicted)
  f1 = f1_score(y_test, predicted, average='weighted')
  # Membership is tested against the distinct labels rather than against one
  # entry per cell.
  known_labels = np.asarray(labels.unique())
  label_rate = np.mean(np.isin(predicted, known_labels))
  return accuracy, f1, label_rate, elapsed

In [None]:
def run_random_forest(adata, labels_key='Cell_type', num_gen=-1, num_cell=-1):
  """Fit a random forest on a random subsample of ``adata`` and score it on a held-out third.

  Args:
    adata: Annotated matrix supporting ``adata[rows, cols]``, ``.X``,
      ``.obs`` and ``.shape``; rows are sampled as cells, columns as genes.
    labels_key: Column of ``adata.obs`` holding the class labels.
    num_gen: Number of columns to sample without replacement; -1 keeps all.
    num_cell: Number of rows to sample without replacement; -1 keeps all.

  Returns:
    ``(accuracy, weighted_f1, label_rate, seconds)``. ``label_rate`` is the
    fraction of test predictions equal to one of the labels present in the
    subsample; ``seconds`` covers only ``fit`` plus ``predict``.
  """
  if num_cell == -1:
    num_cell = adata.shape[0]
  if num_gen == -1:
    num_gen = adata.shape[1]

  # Rows are drawn before columns so the global NumPy RNG is consumed in a
  # fixed order.
  cell_idx = np.random.choice(adata.shape[0], num_cell, replace=False)
  gene_idx = np.random.choice(adata.shape[1], num_gen, replace=False)
  # Index the caller's object directly instead of copying all of it first:
  # only the sampled slice is read and nothing is written back.
  subset = adata[cell_idx, gene_idx]
  labels = subset.obs[labels_key]

  X_train, X_test, y_train, y_test = train_test_split(subset.X, labels, test_size=0.33)
  clf = RFC()

  # Keep bookkeeping out of the timed region so it reflects the model only.
  start_time = time.time()
  clf.fit(X_train, y_train)
  predicted = clf.predict(X_test)
  elapsed = time.time() - start_time

  accuracy = np.mean(np.asarray(y_test) == predicted)
  f1 = f1_score(y_test, predicted, average='weighted')
  # Membership is tested against the distinct labels rather than against one
  # entry per cell.
  known_labels = np.asarray(labels.unique())
  label_rate = np.mean(np.isin(predicted, known_labels))
  return accuracy, f1, label_rate, elapsed

In [6]:
def _sweep(fun, adata, labels_key, swept, sizes, fixed):
    """Call ``fun`` once per value in ``sizes`` and return four per-metric lists.

    ``swept`` is the name of the keyword being varied (``'num_gen'`` or
    ``'num_cell'``); ``fixed`` is a dict with the keyword held constant.
    The lists follow the order of the tuple returned by ``fun``.
    """
    columns = ([], [], [], [])
    for size in sizes:
        print(f"Running: {swept}={size}, fun={fun}")
        result = fun(adata, labels_key, **{swept: size}, **fixed)
        for column, value in zip(columns, result):
            column.append(value)
        # Release whatever the run allocated before starting the next one.
        del result
        gc.collect()
    return columns


def draw(adata, labels_key='Cell_type', funs=[run_svm, run_XGBoost], num_gen = [400, 800, 1200, 1600, 2000], num_cell = [5000, 10000, 15000, 20000, 25000, 30000]):
    """Benchmark each function in ``funs`` and plot the metrics in a 4x2 grid.

    Every function is called as ``fun(adata, labels_key, num_gen=..., num_cell=...)``
    and must return ``(accuracy, f1, label_rate, seconds)``. Two sweeps are
    run per function:

    * over ``num_gen`` with the cell count fixed at ``num_cell[-1]``
      (top four panels);
    * over ``num_cell`` with the gene count fixed at ``num_gen[-1]``
      (bottom four panels).

    Curves are labelled with the function's name; the label rate is shown
    as a percentage to match its axis label. The figure is displayed with
    ``plt.show()`` and nothing is returned.
    """
    metric_labels = ('Accuracy', 'F1 score', 'Labeled (%)', 'Time (s)')
    n_metrics = len(metric_labels)
    label_rate_pos = metric_labels.index('Labeled (%)')

    fig, ax = plt.subplots(4, 2)
    # Row-major order: panels 0-3 are the gene sweep, panels 4-7 the cell
    # sweep, each in the order of `metric_labels`.
    panels = ax.ravel()
    for k, axis in enumerate(panels):
        axis.grid(True, alpha=0.25)
        axis.set_xlabel('Number of genes' if k < n_metrics else 'Number of cells')
        axis.set_ylabel(metric_labels[k % n_metrics])

    for fun in funs:
        # Fall back to the repr for callables without a __name__ (e.g. partials).
        curve_name = getattr(fun, '__name__', str(fun))

        by_gene = _sweep(fun, adata, labels_key, 'num_gen', num_gen,
                         {'num_cell': num_cell[-1]})
        by_cell = _sweep(fun, adata, labels_key, 'num_cell', num_cell,
                         {'num_gen': num_gen[-1]})

        for k, axis in enumerate(panels):
            xs, metrics = (num_gen, by_gene) if k < n_metrics else (num_cell, by_cell)
            ys = metrics[k % n_metrics]
            if k % n_metrics == label_rate_pos:
                # The benchmark functions return a 0-1 fraction; the axis is in percent.
                ys = [100 * y for y in ys]
            axis.plot(xs, ys, label=curve_name, marker='D')

    fig.set_size_inches(24, 24)
    # One legend per sweep, attached to the timing panels.
    ax[1,1].legend(loc='upper right', bbox_to_anchor=(1, 0.5))
    ax[3,1].legend(loc='upper right', bbox_to_anchor=(1, 0.5))

    plt.show()

In [None]:
def draw_heatmap(models, datasets, values):
    """Show a datasets-by-models heatmap with every cell annotated by its value.

    Args:
        models: Tick labels for the x axis, one per column of ``values``.
        datasets: Tick labels for the y axis, one per row of ``values``.
        values: 2-D array-like indexed as ``values[dataset, model]``; nested
            lists are accepted as well as arrays.

    The figure is displayed with ``plt.show()`` and nothing is returned.
    """
    # The annotation loop uses 2-D indexing, which plain nested lists do not
    # support; coerce once so both input forms work.
    values = np.asarray(values)

    fig, ax = plt.subplots()
    im = ax.imshow(values)
    ax.grid(False)
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel('Accuracy', rotation=-90, va="bottom")

    ax.set_xticks(np.arange(len(models)))
    ax.set_yticks(np.arange(len(datasets)))
    ax.set_xticklabels(models)
    ax.set_yticklabels(datasets)

    # Slant the model names so long labels do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Text coordinates are (x, y) = (column, row).
    for i in range(len(datasets)):
        for j in range(len(models)):
            ax.text(j, i, values[i, j],
                    ha="center", va="center", color="w")

    ax.set_title("Classification performance across datasets")
    fig.set_size_inches(6, 6)
    fig.tight_layout()
    plt.show()