SOLO 完整版

In [None]:
# Install necessary packages (run in Colab)
!pip install -U scvi-tools
!pip install scanpy

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Import required libraries
import os
import time
import anndata
import scvi
import pandas as pd
import numpy as np
import scanpy as sc
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix

# Locations on the mounted Drive: the input dataset first, then the two
# directories used as model checkpoints by the sections below.
data_path = "/content/drive/MyDrive/h5ad_data/cline-ch.h5ad"
scvi_model_path = "/content/drive/MyDrive/models/scvi_model"
solo_model_path = "/content/drive/MyDrive/models/solo_model"

# Read the dataset into memory
adata = anndata.read_h5ad(data_path)

# Every cell gets the same constant batch label (single-batch assumption);
# note this overwrites any "batch" column already present in adata.obs.
adata.obs["batch"] = "batch1"

# Register the AnnData with scVI, using that column as the batch key
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

#################################
# SCVI Model: Load if exists or train a new one
#################################
# An existing checkpoint directory on Drive short-circuits training.
if os.path.exists(scvi_model_path):
    vae = scvi.model.SCVI.load(scvi_model_path, adata)
    print(f"Loaded SCVI model from: {scvi_model_path}")
else:
    vae = scvi.model.SCVI(adata)
    t0 = time.time()
    # "auto" picks the GPU when one is attached and falls back to CPU otherwise;
    # a hard-coded "gpu" aborts training on a CPU-only Colab runtime.
    vae.train(accelerator="auto", devices=1)
    print(f"SCVI training time: {time.time() - t0:.2f} seconds")
    vae.save(scvi_model_path)
    print(f"SCVI model saved to: {scvi_model_path}")

#################################
# SOLO Model: Load if exists or train a new one
#################################
# SOLO is not trained on `adata` itself: from_scvi_model() builds its own
# AnnData (latent representation plus simulated doublets). Passing the original
# `adata` to SOLO.load() therefore does not match what the model was set up
# with. The checkpoint is saved together with SOLO's own AnnData and reloaded
# from the directory alone.
solo_adata_file = os.path.join(solo_model_path, "adata.h5ad")

if os.path.exists(solo_adata_file):
    solo_model = scvi.external.SOLO.load(solo_model_path)
    print(f"Loaded SOLO model from: {solo_model_path}")
else:
    solo_model = scvi.external.SOLO.from_scvi_model(vae)
    t1 = time.time()
    # "auto" uses the GPU when available and falls back to CPU otherwise
    solo_model.train(accelerator="auto", devices=1)
    print(f"SOLO training time: {time.time() - t1:.2f} seconds")
    # overwrite=True replaces a checkpoint directory written without its
    # AnnData (such a directory cannot be reloaded on its own).
    solo_model.save(solo_model_path, save_anndata=True, overwrite=True)
    print(f"SOLO model saved to: {solo_model_path}")

#################################
# Obtain predictions and save results
#################################
# Keep only the 'doublet' column of the frame returned by solo_model.predict()
# (assumed to be a per-cell doublet score indexed by cell name — TODO confirm).
predictions = solo_model.predict()  # Model's prediction interface
predictions_df = predictions[["doublet"]].rename(columns={"doublet": "predicted_doublet"})

# Put the scores in adata's cell order and make sure every cell received one;
# otherwise the metadata and the scores would silently disagree.
if not predictions_df.index.equals(adata.obs_names):
    predictions_df = predictions_df.reindex(adata.obs_names)
if predictions_df["predicted_doublet"].isna().any():
    raise ValueError("SOLO predictions do not cover every cell in adata.obs_names.")

# Column assignment (rather than pd.concat) replaces the column left by an
# earlier run of this cell instead of appending a duplicate one.
adata.obs["predicted_doublet"] = predictions_df["predicted_doublet"]
print("SOLO predictions completed.")

# Define file paths for saving results
solo_scores_path = "/content/drive/MyDrive/models/solo_doublet_scores.csv"      # Save only predicted scores
solo_obs_path = "/content/drive/MyDrive/models/solo_obs_results.csv"              # Save full observation metadata

# Save predicted doublet scores and full obs metadata as CSV files
predictions_df.to_csv(solo_scores_path, index=True)
adata.obs.to_csv(solo_obs_path, index=True)
print("Predictions and observation metadata have been saved to:")
print(f"  - {solo_scores_path}")
print(f"  - {solo_obs_path}")

#################################
# Performance evaluation (if ground-truth labels are available)
#################################
if "doublet_label" in adata.obs.columns:
    # Convert ground-truth labels to integers
    # (assumes 0 = singlet and 1 = doublet — TODO confirm against the dataset)
    y_true = adata.obs["doublet_label"].astype(int).to_numpy()
    # Re-order the scores to adata.obs so y_true[i] and y_score[i] refer to the same cell
    y_score = predictions_df["predicted_doublet"].reindex(adata.obs.index).to_numpy()

    # Compute overall performance metrics: AUROC and AUPRC.
    # Both need positives and negatives; with a single class roc_auc_score raises.
    if np.unique(y_true).size < 2:
        print("Warning: 'doublet_label' contains a single class; AUROC/AUPRC are reported as NaN.")
        overall_auroc = float("nan")
        overall_auprc = float("nan")
    else:
        overall_auroc = roc_auc_score(y_true, y_score)
        overall_auprc = average_precision_score(y_true, y_score)

    performance_df = pd.DataFrame({
        "AUROC": [overall_auroc],
        "AUPRC": [overall_auprc]
    })
    performance_file = "/content/drive/MyDrive/models/solo_performance_overall.csv"
    performance_df.to_csv(performance_file, index=False)
    print(f"Overall performance metrics saved to {performance_file}")

    # Threshold analysis: calculate sensitivity and specificity across 101 threshold points from 0 to 1
    thresholds = np.linspace(0, 1, 101)
    thresh_results = []
    for t in thresholds:
        y_pred = (y_score > t).astype(int)
        # labels=[0, 1] pins the matrix to 2x2 even when only one class occurs
        # in both y_true and y_pred, so the four-way unpacking cannot fail.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        thresh_results.append({
            "threshold": t,
            "sensitivity": sensitivity,
            "specificity": specificity
        })
    thresh_df = pd.DataFrame(thresh_results)
    thresh_file = "/content/drive/MyDrive/models/solo_threshold_analysis.csv"
    thresh_df.to_csv(thresh_file, index=False)
    print(f"Threshold analysis saved to {thresh_file}")
else:
    print("Warning: 'doublet_label' column not found in adata.obs; skipping performance evaluation.")


Collecting scvi-tools
  Downloading scvi_tools-1.3.1.post1-py3-none-any.whl.metadata (22 kB)
Collecting anndata>=0.11 (from scvi-tools)
  Downloading anndata-0.11.4-py3-none-any.whl.metadata (9.3 kB)
Collecting docrep>=0.3.2 (from scvi-tools)
  Downloading docrep-0.3.2.tar.gz (33 kB)
  Preparing metadata (setup.py) ... done
Collecting lightning>=2.0 (from scvi-tools)
  Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)
Collecting ml-collections>=0.1.1 (from scvi-tools)
  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)
Collecting mudata>=0.1.2 (from scvi-tools)
  Downloading mudata-0.3.1-py3-none-any.whl.metadata (8.3 kB)
Collecting numpyro>=0.12.1 (from scvi-tools)
  Downloading numpyro-0.18.0-py3-none-any.whl.metadata (37 kB)
Collecting pyro-ppl>=1.6.0 (from scvi-tools)
  Downloading pyro_ppl-1.9.1-py3-none-any.whl.metadata (7.8 kB)
Collecting sparse>=0.14.0 (from scvi-tools)
  Downloading sparse-0.17.0-py2.py3-none-any.whl.metadata (5.3

scrublet完整版

In [None]:
!pip install scrublet anndata
import anndata
import scanpy as sc
import scrublet as scr
import time
import pandas as pd
from google.colab import files
import numpy as np

# Cache location: a previous run's Scrublet output is looked up here
# (step 2, the existence check, happens in the try/except below)
scrublet_results_path = "/content/drive/MyDrive/models/scrublet_results.csv"

# 1. Load the data
adata = anndata.read_h5ad("/content/drive/MyDrive/h5ad_data/cline-ch.h5ad")

# Matrix handed to Scrublet. Scrublet works on UMI counts; this assumes
# adata.X still holds raw counts — TODO confirm for this dataset.
counts_matrix = adata.X

# Reuse cached Scrublet results when the CSV exists; otherwise run Scrublet.
try:
    scrublet_results = pd.read_csv(scrublet_results_path, index_col=0)
except FileNotFoundError:
    scrublet_results = None

if scrublet_results is not None:
    # read_csv may parse numeric-looking cell names as integers, whereas
    # adata.obs is indexed by strings; without this cast the index-aligned
    # assignments below would silently fill the columns with NaN.
    scrublet_results.index = scrublet_results.index.astype(str)
    if not adata.obs_names.isin(scrublet_results.index).all():
        raise ValueError(
            "Cached Scrublet results at {} do not cover every cell of the loaded data; "
            "delete the file to recompute.".format(scrublet_results_path)
        )
    adata.obs["scrublet_score"] = scrublet_results["scrublet_score"]
    adata.obs["scrublet_doublet_prediction"] = scrublet_results["scrublet_doublet_prediction"]
    print("Loaded Scrublet results from:", scrublet_results_path)
    # Record the processing time as 'preloaded' to indicate results were already calculated
    scrublet_time_record = "preloaded"
else:
    # If no existing Scrublet results are found, perform Scrublet doublet detection
    scrub = scr.Scrublet(counts_matrix)
    t0 = time.time()
    scrublet_scores, predicted_doublets = scrub.scrub_doublets()
    scrublet_time = time.time() - t0
    scrublet_time_record = f"{scrublet_time:.2f}"
    print("Scrublet processing time: {:.2f} seconds".format(scrublet_time))

    # Scrublet returns None for the calls when it cannot pick a score threshold
    # automatically; the scores are still valid, so warn instead of failing.
    if predicted_doublets is None:
        print("Warning: Scrublet could not determine a doublet threshold automatically; "
              "'scrublet_doublet_prediction' will be empty.")

    # 3. Store the results in the AnnData object
    adata.obs["scrublet_score"] = scrublet_scores
    adata.obs["scrublet_doublet_prediction"] = predicted_doublets

    # 4. Export the Scrublet results
    scrublet_results = adata.obs[["scrublet_score", "scrublet_doublet_prediction"]]
    scrublet_results.to_csv(scrublet_results_path)
    print("Scrublet results saved to:", scrublet_results_path)

# 5. Export the complete Scrublet metadata for future use in R (if needed)
scrublet_obs_path = "/content/drive/MyDrive/models/scrublet_obs_results.csv"
adata.obs.to_csv(scrublet_obs_path, index=True)
print("Scrublet obs data saved to:", scrublet_obs_path)

# 6. Save the Scrublet processing time information
runtime_log_path = "/content/drive/MyDrive/models/scrublet_runtime.txt"
# When the results were loaded from cache there is no new timing to record.
# Exclusive-create mode ("x") then writes the log only if none exists, so the
# runtime measured by the run that actually computed the results is not
# replaced by the "preloaded" placeholder.
log_mode = "x" if scrublet_time_record == "preloaded" else "w"
try:
    with open(runtime_log_path, log_mode) as f:
        f.write("Scrublet processing time: {} seconds\n".format(scrublet_time_record))
        f.write("Scrublet results saved at: {}\n".format(scrublet_results_path))
    print("Scrublet runtime info saved to:", runtime_log_path)
except FileExistsError:
    print("Scrublet runtime info from the original run kept at:", runtime_log_path)


Loaded Scrublet results from: /content/drive/MyDrive/models/scrublet_results.csv
Scrublet obs data saved to: /content/drive/MyDrive/models/scrublet_obs_results.csv
Scrublet runtime info saved to: /content/drive/MyDrive/models/scrublet_runtime.txt
