In [None]:
# scGPT similarity-search prototype learning (zero-shot)

In [1]:
# install scgpt
!pip install scgpt

# install wandb
!pip install wandb louvain faiss-cpu

# install scanpy (single cell analysis)
!pip install --upgrade scanpy

# install Google Drive Public File Downloader
!pip install -q -U gdown

Collecting scgpt
  Downloading scgpt-0.2.1-py3-none-any.whl (829 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/829.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.2/829.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m819.2/829.2 kB[0m [31m12.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m829.2/829.2 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cell-gears<0.0.3 (from scgpt)
  Downloading cell-gears-0.0.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting datasets<3.0.0,>=2.3.0 (from scgpt)
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m35.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting leidenalg>=0.8.10 (from scgpt)
  D

In [2]:
import os
import sys
import gdown
import anndata

from pathlib import Path
import numpy as np
from scipy.stats import mode
import scanpy as sc
import warnings
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd
import sys
import matplotlib.pyplot as plt

sys.path.insert(0, "../")

import scgpt as scg

# extra dependency for similarity search
try:
    import faiss

    faiss_imported = True
except ImportError:
    faiss_imported = False
    print(
        "faiss not installed! We highly recommend installing it for fast similarity search."
    )
    print("To install it, see https://github.com/facebookresearch/faiss/wiki/Installing-Faiss")

warnings.filterwarnings("ignore", category=ResourceWarning)




Load model

In [3]:
# current path
cwd = os.getcwd()

# Directory that receives the pretrained scGPT files
# (the recorded download fetched args.json, best_model.pt and vocab.json).
model_dir = os.path.join(cwd, "scGPT")
if not os.path.exists(model_dir):
    # Create the directory from Python rather than via `!mkdir -p $model_dir`:
    # the unquoted shell interpolation breaks when the working directory
    # contains spaces or shell metacharacters.
    os.makedirs(model_dir, exist_ok=True)

    # only blood
    # NOTE(review): the download is skipped whenever model_dir already exists,
    # even if an earlier download was interrupted — delete the folder to retry.
    gdown.download_folder(
        "https://drive.google.com/drive/folders/1kkug5C7NjvXIwQGGaGoqXTk_Lb_pDrBU?usp=sharing",
        output=model_dir,
    )

  and should_run_async(code)
Retrieving folder contents
Retrieving folder contents completed
Building directory structure
Building directory structure completed


Processing file 1y4UJVflGl-b2qm-fvpxIoQ3XcC2umjj0 args.json
Processing file 1MJaavaG0ZZkC_yPO4giGRnuCe3F1zt30 best_model.pt
Processing file 127FdcUyY1EM7rQfAS0YI4ms6LwjmnT9J vocab.json


Downloading...
From: https://drive.google.com/uc?id=1y4UJVflGl-b2qm-fvpxIoQ3XcC2umjj0
To: /content/scGPT/args.json
100%|██████████| 902/902 [00:00<00:00, 2.13MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1MJaavaG0ZZkC_yPO4giGRnuCe3F1zt30
From (redirected): https://drive.google.com/uc?id=1MJaavaG0ZZkC_yPO4giGRnuCe3F1zt30&confirm=t&uuid=57db2063-ea7f-407c-9046-fda09f9a5067
To: /content/scGPT/best_model.pt
100%|██████████| 156M/156M [00:03<00:00, 40.3MB/s]
Downloading...
From: https://drive.google.com/uc?id=127FdcUyY1EM7rQfAS0YI4ms6LwjmnT9J
To: /content/scGPT/vocab.json
100%|██████████| 761k/761k [00:00<00:00, 165MB/s]
Download completed


Prepare parameters

In [14]:
# Mount Google Drive (the datasets are read from it further down)
from google.colab import drive
drive.mount('/content/drive')

# Dataset to analyse: 'scott' or 'yeg'
#dataset_name = 'yeg'
dataset_name = 'scott'

# Global parameters read by the functions defined below
cell_type_key = "Type"      # obs column that holds the cell-type label
gene_col = "gene_name"      # var column handed to scGPT as the gene name

# Evaluation method:
#   'origdata' - run SSLP on the whole data
#   'transfer' - train/test on only COVID-19 or only healthy samples (transferability)
#   'sampling' - keep only a ratio p of the training data (data efficiency)
#method = 'origdata'
#method = 'transfer'
method = 'sampling'

if method == 'origdata':
  print('U R good to go')
elif method == 'transfer':
  print('please set train_test')
  # Two letters: the first selects the training group, the second the test
  # group (C = COVID-19, H = healthy), as interpreted by transfer().
  train_test = 'CC'
  #train_test = 'CH'
  #train_test = 'HC'
  #train_test = 'HH'
elif method == 'sampling':
  print('please set p')
  # fraction of each cell type kept from the training data
  p = 0.01

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
please set p


Functions for each method

In [6]:
# Embeddings for the original (complete) data
def origdata(adata, train_id, test_id, umap_original, umap_embedding):
  """Split `adata` into train/test by sample id and embed both parts with scGPT.

  `adata` is modified in place: a "gene_name" column is added to `.var` and a
  "batch_id" column (0 = train, 1 = test) to `.obs`.

  Args:
    adata: annotated data whose `.obs` has an 'orig.ident' column.
    train_id: accepted for symmetry with `test_id` but not read here; every
      cell that is not in `test_id` is used for training.
    test_id: 'orig.ident' values that form the test split.
    umap_original: if truthy, compute and plot a UMAP of `adata.X` coloured by
      the global `cell_type_key`.
    umap_embedding: if truthy, compute and plot a UMAP of the scGPT embedding
      of the training split.

  Returns:
    Tuple `(ref_embed_adata, test_embed_adata, adata_train, adata_test)`.

  Relies on the globals `model_dir`, `gene_col` and `cell_type_key`.
  """

  def _embed(part):
    # Shared scGPT call for both splits
    return scg.tasks.embed_data(
        part,
        model_dir,
        gene_col=gene_col,
        batch_size=16,
    )

  # gene names are taken from the var index
  adata.var["gene_name"] = adata.var.index

  # everything starts as train (0); cells of a test sample become 1
  adata.obs["batch_id"] = 0
  is_test_cell = adata.obs['orig.ident'].isin(test_id)
  adata.obs.loc[is_test_cell, "batch_id"] = 1

  # slice out the two splits
  adata_test = adata[adata.obs["batch_id"] == 1]
  adata_train = adata[adata.obs["batch_id"] == 0]

  # UMAP of the raw matrix, before scGPT
  if umap_original:
    sc.pp.neighbors(adata, use_rep='X')
    sc.tl.umap(adata)
    sc.pl.umap(
        adata,
        color=[cell_type_key],
        wspace=0.4,
        frameon=False,
        ncols=1,
    )

  # scGPT embeddings: reference (train) first, then test
  ref_embed_adata = _embed(adata_train)
  test_embed_adata = _embed(adata_test)

  # report the shapes of the embedded objects
  print('ref data', ref_embed_adata.shape)
  print('test data', test_embed_adata.shape)

  # UMAP of the scGPT embedding of the reference data
  if umap_embedding:
    sc.pp.neighbors(ref_embed_adata, use_rep="X_scGPT")
    sc.tl.umap(ref_embed_adata)
    sc.pl.umap(ref_embed_adata, color=cell_type_key, frameon=False)

  return (ref_embed_adata, test_embed_adata, adata_train, adata_test)

In [7]:
# this function calculates embeddings for the transferability experiment
def transfer(adata, train_covid_id, train_healthy_id, test_covid_id, test_healthy_id, train_test):
  """Embed one train group and one test group (COVID-19 or healthy) with scGPT.

  `adata` is modified in place: a "gene_name" column is added to `.var` and a
  "batch_id" column to `.obs` (0 = train COVID, 1 = train healthy,
  2 = test COVID, 3 = test healthy).

  Args:
    adata: annotated data whose `.obs` has an 'orig.ident' column.
    train_covid_id: not read by this function; every cell that matches none of
      the other three id lists keeps batch id 0 and is treated as a
      COVID training cell.
    train_healthy_id: 'orig.ident' values of the healthy training samples.
    test_covid_id: 'orig.ident' values of the COVID test samples.
    test_healthy_id: 'orig.ident' values of the healthy test samples.
    train_test: two-letter code, first letter = training group, second letter
      = test group ('C' for COVID, 'H' for healthy): 'CC', 'CH', 'HC' or 'HH'.

  Returns:
    Tuple `(ref_embed_adata, test_embed_adata, adata_train, adata_test)`.

  Raises:
    ValueError: if `train_test` is not one of the four supported codes.

  Relies on the globals `model_dir` and `gene_col`.
  """

  # Fail early with a clear message; previously an unknown code surfaced later
  # as an UnboundLocalError on `adata_train`.
  valid_codes = ('CC', 'CH', 'HC', 'HH')
  if train_test not in valid_codes:
    raise ValueError(
        "train_test must be one of %s, got %r" % (valid_codes, train_test)
    )

  # add gene name
  adata.var["gene_name"] = adata.var.index
  adata.obs["batch_id"]  = 0

  # change batch id
  adata.obs.loc[adata.obs['orig.ident'].isin(train_healthy_id), "batch_id"] = 1
  adata.obs.loc[adata.obs['orig.ident'].isin(test_covid_id), "batch_id"] = 2
  adata.obs.loc[adata.obs['orig.ident'].isin(test_healthy_id), "batch_id"] = 3

  # separate train and test data
  adata_train_covid = adata[adata.obs["batch_id"] == 0]
  adata_train_healthy = adata[adata.obs["batch_id"] == 1]

  adata_test_covid = adata[adata.obs["batch_id"] == 2]
  adata_test_healthy = adata[adata.obs["batch_id"] == 3]

  print('train covid', adata_train_covid.shape)
  print('train healthy',adata_train_healthy.shape)
  print('test covid',adata_test_covid.shape)
  print('test healthy',adata_test_healthy.shape)

  # first letter picks the training group, second letter the test group
  train_groups = {'C': adata_train_covid, 'H': adata_train_healthy}
  test_groups = {'C': adata_test_covid, 'H': adata_test_healthy}
  adata_train = train_groups[train_test[0]]
  adata_test = test_groups[train_test[1]]

  print('train', adata_train.shape)
  print('test', adata_test.shape)

  # extract embedding from scgpt
  ref_embed_adata = scg.tasks.embed_data(
      adata_train,
      model_dir,
      gene_col=gene_col,
      batch_size=16,
  )

  test_embed_adata = scg.tasks.embed_data(
      adata_test,
      model_dir,
      gene_col=gene_col,
      batch_size=16,
  )

  print('ref data', ref_embed_adata.shape)
  # the test group may be COVID or healthy, so the label stays neutral
  print('test data', test_embed_adata.shape)

  return (ref_embed_adata, test_embed_adata, adata_train, adata_test)

In [8]:
def sampling(adata, train_id, test_id, p, seed=None):
  """Embed a per-cell-type random subsample of the training split with scGPT.

  `adata` is modified in place: a "gene_name" column is added to `.var` and a
  "batch_id" column (0 = train, 1 = test) to `.obs`.

  Args:
    adata: annotated data whose `.obs` has an 'orig.ident' column and the
      cell-type column named by the global `cell_type_key`.
    train_id: not read by this function; every cell that is not in `test_id`
      belongs to the training split.
    test_id: 'orig.ident' values that form the test split.
    p: sampling ratio. From each cell type `int(count * p)` training cells are
      drawn without replacement, so a type with fewer than `1 / p` cells
      contributes no reference cell.
    seed: optional integer seed for the draw. With the default `None` the
      global NumPy random state is used, as before.

  Returns:
    Tuple `(ref_embed_adata, test_embed_adata, adata_train, adata_test)`.
    `ref_embed_adata` is built from the subsample, while `adata_train` is the
    complete training split.

  Relies on the globals `model_dir`, `gene_col` and `cell_type_key`.
  """

  # add gene name
  adata.var["gene_name"] = adata.var.index

  # batch id train = 0
  adata.obs["batch_id"]  = 0

  # batch id test = 1
  adata.obs.loc[adata.obs['orig.ident'].isin(test_id), "batch_id"] = 1

  # make test and train
  adata_test = adata[adata.obs["batch_id"] == 1]
  adata_train = adata[adata.obs["batch_id"] == 0]

  # reset index so that obs labels coincide with row positions, which is what
  # the integer indexing `adata_train[ind]` below relies on
  adata_train.obs.reset_index(drop=True, inplace=True)
  adata_test.obs.reset_index(drop=True, inplace=True)

  print('train', adata_train.shape)
  print('test', adata_test.shape)

  # find all cell types and their sizes
  # (cell_type_key replaces the hard-coded 'Type' so that the stratification
  # uses the same label column as SSLP)
  train_types = adata_train.obs[cell_type_key]
  type_counts = train_types.value_counts()
  sample_sizes = (type_counts * p).astype(int)

  # a dedicated generator makes the draw reproducible when a seed is given
  rng = np.random if seed is None else np.random.RandomState(seed)

  # sample from train data with ratio p from each cell type
  ind = []
  for cell_type, sample_size in sample_sizes.items():
      if sample_size == 0:
          # nothing to draw for this cell type
          continue
      type_ind = adata_train.obs[train_types == cell_type].index
      sampled_ind = rng.choice(type_ind, size=sample_size, replace=False)
      ind.extend(sampled_ind)

  adata_train_sample = adata_train[ind]

  # extract embedding from scgpt
  ref_embed_adata = scg.tasks.embed_data(
      adata_train_sample,
      model_dir,
      gene_col=gene_col,
      batch_size=16,
  )

  test_embed_adata = scg.tasks.embed_data(
      adata_test,
      model_dir,
      gene_col=gene_col,
      batch_size=16,
  )
  print('ref data', ref_embed_adata.shape)
  print('test data', test_embed_adata.shape)

  return (ref_embed_adata, test_embed_adata, adata_train, adata_test)


Functions for SSLP

In [9]:
# functions for SSLP

# Fallback similarity, needed only when faiss is not installed
def l2_sim(a, b):
    """Return the negated row-wise Euclidean distance between `a` and `b`.

    `a - b` follows NumPy broadcasting and the norm is taken along axis 1, so
    the result holds one value per row; larger (closer to zero) means more
    similar.
    """
    row_distances = np.linalg.norm(a - b, axis=1)
    return -row_distances

# Fallback neighbour search, needed only when faiss is not installed
def get_similar_vectors(vector, ref, top_k=10):
    """Return the indices and similarities of the `top_k` rows of `ref` closest to `vector`.

    Similarity is the negated L2 distance from `l2_sim`; the result is ordered
    from most to least similar.
    """
    similarities = l2_sim(vector, ref)

    # ascending argsort reversed -> most similar first, then keep top_k
    ranked = np.argsort(similarities)[::-1]
    best = ranked[:top_k]
    return best, similarities[best]

# This function calculates the result dictionary
def SSLP(ref_embed_adata, test_embed_adata, adata_test, k=10):
  """Predict test cell types by majority vote among nearest reference cells.

  For every test cell the `k` nearest reference cells (L2 distance between
  the "X_scGPT" embeddings) are looked up and the most frequent reference
  label becomes the prediction. faiss is used when it was imported
  (`faiss_imported`); otherwise the NumPy fallback `get_similar_vectors`.

  Args:
    ref_embed_adata: reference data with `.obsm["X_scGPT"]` and the label
      column named by the global `cell_type_key` in `.obs`.
    test_embed_adata: query data with `.obsm["X_scGPT"]`.
    adata_test: data whose `.obs[cell_type_key]` holds the ground-truth labels
      in the same row order as `test_embed_adata`.
    k: number of neighbours that vote (default 10); clamped to the number of
      reference cells.

  Returns:
    Dict with "accuracy", "precision", "recall" and "macro_f1"
    (the last three macro-averaged).

  Raises:
    ValueError: if the reference set contains no cells.
  """

  # scGPT embeddings as C-contiguous float32 matrices, the layout faiss expects
  ref_cell_embeddings = np.ascontiguousarray(ref_embed_adata.obsm["X_scGPT"], dtype=np.float32)
  test_embed = np.ascontiguousarray(test_embed_adata.obsm["X_scGPT"], dtype=np.float32)

  print('ref embedding', ref_cell_embeddings.shape)
  print('test embedding', test_embed.shape)

  n_ref = ref_cell_embeddings.shape[0]
  if n_ref == 0:
    raise ValueError("reference set is empty; cannot predict cell types")

  # never ask for more neighbours than there are reference cells
  n_neighbors = min(k, n_ref)

  if faiss_imported:
    index = faiss.IndexFlatL2(ref_cell_embeddings.shape[1])
    index.add(ref_cell_embeddings)
    # search returns (distances, indices); only the indices are needed
    _, neighbor_idx = index.search(test_embed, n_neighbors)
  else:
    # brute-force fallback, one query cell at a time
    neighbor_idx = [
        get_similar_vectors(cell[np.newaxis, ...], ref_cell_embeddings, n_neighbors)[0]
        for cell in test_embed
    ]

  # Neighbour indices are row positions, so the labels are selected with
  # .iloc; plain [] would be label-based on an integer obs index.
  ref_labels = ref_embed_adata.obs[cell_type_key]
  preds = [ref_labels.iloc[idx].value_counts().index[0] for idx in neighbor_idx]

  gt = adata_test.obs[cell_type_key].to_numpy()

  res_dict = {
      "accuracy": accuracy_score(gt, preds),
      "precision": precision_score(gt, preds, average="macro"),
      "recall": recall_score(gt, preds, average="macro"),
      "macro_f1": f1_score(gt, preds, average="macro"),
  }

  return res_dict

Load data

In [15]:
if dataset_name == 'scott':
  # Scott dataset: metadata CSV and gzipped raw-count text matrix on Drive
  data_dir = '/content/drive/My Drive/scgpt/data/Scott/'
  meta = pd.read_csv(
      os.path.join(data_dir, 'GSE155673_metadata_metedata_cell_coordinate_tsne.csv')
  )
  # the matrix is transposed right after reading
  # (presumably the file stores genes as rows — TODO confirm)
  adata = sc.read_text(
      os.path.join(data_dir, 'GSE155673_rawdata_raw_counts.txt.gz')
  ).T
  # NOTE(review): assumes the metadata rows are in the same order as the
  # cells of the count matrix — verify
  adata.obs = meta



In [None]:
if dataset_name == 'yeg':
  # Yeg dataset: metadata CSV and gzipped raw-count text matrix on Drive
  data_dir = '/content/drive/My Drive/scgpt/data/Yeg/'
  meta = pd.read_csv(
      os.path.join(data_dir, 'GSE166992_metadata_metedata_cell_coordinate_tsne.csv')
  )
  # the matrix is transposed right after reading
  # (presumably the file stores genes as rows — TODO confirm)
  adata = sc.read_text(
      os.path.join(data_dir, 'GSE166992_rawdata_raw_counts.txt.gz')
  ).T
  # NOTE(review): assumes the metadata rows are in the same order as the
  # cells of the count matrix — verify
  adata.obs = meta

Apply SSLP based on method and dataset

In [18]:
if dataset_name == 'scott':

  if method == 'origdata':

    # UMAP switches (0 disables the corresponding plot)
    umap_original = 1
    umap_embedding = 1

    # sample ids of the train and test split
    train_id = [
        'GSM4712885', 'GSM4712887', 'GSM4712889', 'GSM4712891',
        'GSM4712893', 'GSM4712895', 'GSM4712897', 'GSM4712907',
    ]
    test_id = ['GSM4712899', 'GSM4712901', 'GSM4712903', 'GSM4712905']

    # embed with origdata, then score with SSLP
    ref_embed_adata, test_embed_adata, adata_train, adata_test = origdata(
        adata, train_id, test_id, umap_original, umap_embedding
    )
    result = SSLP(ref_embed_adata, test_embed_adata, adata_test)
    print(result)

  elif method == 'transfer':

    # one id list per group; the batch ids are assigned inside transfer()
    train_covid_id = ['GSM4712885', 'GSM4712887', 'GSM4712889', 'GSM4712891']  # batch id = 0
    train_healthy_id = ['GSM4712893', 'GSM4712895', 'GSM4712897', 'GSM4712907']  # batch id = 1
    test_covid_id = ['GSM4712899', 'GSM4712901', 'GSM4712903']  # batch id = 2
    test_healthy_id = ['GSM4712905']  # batch id = 3

    # embed with transfer, then score with SSLP
    ref_embed_adata, test_embed_adata, adata_train, adata_test = transfer(
        adata, train_covid_id, train_healthy_id, test_covid_id, test_healthy_id, train_test
    )
    result = SSLP(ref_embed_adata, test_embed_adata, adata_test)
    print(result)

  elif method == 'sampling':

    # sample ids of the train and test split
    train_id = [
        'GSM4712885', 'GSM4712887', 'GSM4712889', 'GSM4712891',
        'GSM4712893', 'GSM4712895', 'GSM4712897', 'GSM4712907',
    ]
    test_id = ['GSM4712899', 'GSM4712901', 'GSM4712903', 'GSM4712905']
    #test_id = ['GSM4712899', 'GSM4712901', 'GSM4712903'] # covid

    # embed a subsample with sampling, then score with SSLP
    ref_embed_adata, test_embed_adata, adata_train, adata_test = sampling(
        adata, train_id, test_id, p
    )
    result = SSLP(ref_embed_adata, test_embed_adata, adata_test)
    print(result)

train (25439, 20370)
test (19312, 20370)
scGPT - INFO - match 16373/20370 genes in vocabulary of size 36574.


  adata.var["id_in_vocab"] = [
  self.pid = os.fork()
Embedding cells: 100%|██████████| 16/16 [00:05<00:00,  3.16it/s]
  adata.obsm["X_scGPT"] = cell_embeddings
  adata.var["id_in_vocab"] = [


scGPT - INFO - match 16373/20370 genes in vocabulary of size 36574.


  self.pid = os.fork()
Embedding cells: 100%|██████████| 1207/1207 [04:48<00:00,  4.19it/s]
  adata.obsm["X_scGPT"] = cell_embeddings


ref data (251, 16373)
test data (19312, 16373)


  adata_concat = test_embed_adata.concatenate(ref_embed_adata, batch_key="dataset")


ref embedding (251, 512)
test embedding (19312, 512)


  _warn_prf(average, modifier, msg_start, len(result))


{'accuracy': 0.9063794531897266, 'precision': 0.6535804331051592, 'recall': 0.6742825334763339, 'macro_f1': 0.6612239403890874}


In [None]:
if dataset_name == 'yeg':

  if method == 'origdata':

    # UMAP switches (0 disables the corresponding plot)
    umap_original = 1
    umap_embedding = 1

    # sample ids of the train and test split
    train_id = [
        'GSM5090446', 'GSM5090448',
        'GSM5090447', 'GSM5090449',
    ]
    test_id = ['GSM5090454', 'GSM5090453', 'GSM5090451']

    # embed with origdata, then score with SSLP
    ref_embed_adata, test_embed_adata, adata_train, adata_test = origdata(
        adata, train_id, test_id, umap_original, umap_embedding
    )
    result = SSLP(ref_embed_adata, test_embed_adata, adata_test)
    print(result)

  elif method == 'transfer':

    # one id list per group; the batch ids are assigned inside transfer()
    train_covid_id = ['GSM5090447', 'GSM5090449']  # batch id = 0
    train_healthy_id = ['GSM5090446', 'GSM5090448']  # batch id = 1
    test_covid_id = ['GSM5090453', 'GSM5090451']  # batch id = 2
    test_healthy_id = ['GSM5090454']  # batch id = 3

    # embed with transfer, then score with SSLP
    ref_embed_adata, test_embed_adata, adata_train, adata_test = transfer(
        adata, train_covid_id, train_healthy_id, test_covid_id, test_healthy_id, train_test
    )
    result = SSLP(ref_embed_adata, test_embed_adata, adata_test)
    print(result)

  elif method == 'sampling':

    # sample ids of the train and test split
    train_id = [
        'GSM5090446', 'GSM5090448',
        'GSM5090447', 'GSM5090449',
    ]
    test_id = ['GSM5090454', 'GSM5090453', 'GSM5090451']
    #test_id = ['GSM5090453', 'GSM5090451'] # only covid

    # embed a subsample with sampling, then score with SSLP
    ref_embed_adata, test_embed_adata, adata_train, adata_test = sampling(
        adata, train_id, test_id, p
    )
    result = SSLP(ref_embed_adata, test_embed_adata, adata_test)
    print(result)