7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -3,11 +3,16 @@
## New functionality

* Added `metrics/kbet_pg` and `metrics/kbet_pg_label` components (PR #52).

* Added `methods/stacas` new method (PR #58).
- Adds an unsupervised version of the STACAS tool for integration of single-cell transcriptomics data. This functionality enables correction of batch effects while preserving biological variability, without requiring prior cell type annotations.
* Added `method/drvi` component (PR #61).

* Added `methods/drvi` component (PR #61).

* Added `ARI_batch` and `NMI_batch` to `metrics/clustering_overlap` (PR #68).

* Added `methods/bbknn_ts` component (PR #84).

* Added `metrics/cilisi` new metric component (PR #57).
- ciLISI measures batch mixing in a cell type-aware manner by computing iLISI within each cell type and normalizing
the scores between 0 and 1. Unlike iLISI, ciLISI preserves sensitivity to biological variance and avoids favoring
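The ciLISI entry above only summarizes the metric, so here is a rough, self-contained sketch of the idea. It is not the PR's `metrics/cilisi` implementation; the helper name and the assumption of an integrated embedding plus `batch` and `cell_type` labels are illustrative only.

```python
# Rough sketch of a cell-type-aware iLISI ("ciLISI"); hypothetical helper, not the
# PR's metrics/cilisi component. Assumes an embedding array plus per-cell batch
# and cell_type labels (e.g. adata.obsm["X_emb"], adata.obs["batch"], adata.obs["cell_type"]).
import numpy as np
from sklearn.neighbors import NearestNeighbors

def cilisi_sketch(emb, batches, cell_types, n_neighbors=30):
    per_type_scores = []
    for ct in np.unique(cell_types):
        mask = cell_types == ct
        emb_ct, batch_ct = emb[mask], batches[mask]
        n_batches = len(np.unique(batch_ct))
        if n_batches < 2 or mask.sum() < 3:
            continue  # iLISI is not informative for a single batch or tiny groups
        k = min(n_neighbors, int(mask.sum()) - 1)
        _, idx = NearestNeighbors(n_neighbors=k).fit(emb_ct).kneighbors(emb_ct)
        ilisi = []
        for neigh in idx:
            _, counts = np.unique(batch_ct[neigh], return_counts=True)
            p = counts / counts.sum()
            ilisi.append(1.0 / np.sum(p ** 2))  # inverse Simpson index over batches
        # Rescale from [1, n_batches] to [0, 1], as described in the changelog entry.
        scaled = (np.asarray(ilisi) - 1.0) / (n_batches - 1.0)
        per_type_scores.append(scaled.mean())
    return float(np.mean(per_type_scores)) if per_type_scores else float("nan")
```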
54 changes: 54 additions & 0 deletions src/methods/bbknn_ts/config.vsh.yaml
@@ -0,0 +1,54 @@
__merge__: ../../api/comp_method.yaml

name: bbknn_ts
label: BBKNN (TS)
summary: "A combination of ComBat and BBKNN discovered and implemented by Gemini."
description: |
"The BBKNN (TS) solution applies standard scRNA-seq preprocessing steps, including total count normalization, log-transformation, and scaling of gene expression data. Batch effect correction is performed using scanpy.pp.combat directly on the gene expression matrix (before dimensionality reduction). Dimensionality reduction is then applied using PCA on the ComBat-corrected data, and this PCA embedding (adata.obsm['X_pca']) is designated as the integrated embedding (adata.obsm['X_emb']). A custom batch-aware nearest neighbors graph is constructed based on this integrated embedding; for each cell, neighbors are independently identified within its own batch and other batches, up to n_neighbors_per_batch. These candidate neighbors are merged, keeping the minimum distance for duplicate entries, and the top total_k_neighbors are selected for each cell. Finally, a symmetric sparse distance matrix and a binary connectivities matrix are generated to represent the integrated neighborhood graph. This code was entirely written by the AI system described in the associated publication."
references:
bibtex: |
@misc{GoogleScienceAI2025,
title={An {AI} system to help scientists write expert-level empirical software},
author={Eser Ayg\"un and Anastasiya Belyaeva and Gheorghe Comanici and Marc Coram and Hao Cui and Jake Garrison and Renee Johnston and Anton Kast and Cory Y. McLean and Peter Norgaard and Zahra Shamsi and David Smalling and James Thompson and Subhashini Venugopalan and Brian P. Williams and Chujun He and Sarah Martinson and Martyna Plomecka and Lai Wei and Yuchen Zhou and Qian-Ze Zhu and Matthew Abraham and Erica Brand and Anna Bulanova and Jeffrey A. Cardille and Chris Co and Scott Ellsworth and Grace Joseph and Malcolm Kane and Ryan Krueger and Johan Kartiwa and Dan Liebling and Jan-Matthis Lueckmann and Paul Raccuglia and Xuefei (Julie) Wang and Katherine Chou and James Manyika and Yossi Matias and John C. Platt and Lizzie Dorfman and Shibl Mourad and Michael P. Brenner},
year={2025},
eprint={2509.06503},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2509.06503}
}
links:
documentation: https://google-research.github.io/score/
repository: https://github.com/google-research/score

info:
method_types: [embedding]
preferred_normalization: log_cp10k

arguments:
- name: "--n_pca_components"
type: "integer"
default: 100
description: "Number of PCA components."
- name: "--n_neighbors_per_batch"
type: "integer"
default: 10
description: "Number of neighbors to use within each batch."
- name: "--total_k_neighbors"
type: "integer"
default: 50
description: "Total number of nearest neighbors to retain for the final graph."

resources:
- type: python_script
path: script.py
- path: /src/utils/read_anndata_partial.py

engines:
- type: docker
image: openproblems/base_python:1.0.0

runners:
- type: executable
- type: nextflow
directives:
label: [midtime,midmem,midcpu]
248 changes: 248 additions & 0 deletions src/methods/bbknn_ts/script.py
Collaborator:
Based on a quick glance, this code will likely work as the developer intended; however, it does not make use of the pre-computed processing step and instead recomputes everything by itself. This in itself won't cause the code to fail or give wrong results, but it won't follow the benchmark setup as intended.
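A loose illustration of that point, based on the read step in the boilerplate at the bottom of script.py (hypothetical; the `layers/normalized` path is taken from the author's reply below):

```python
# Hypothetical variant of the script's read step: reuse the benchmark's
# pre-computed normalization instead of re-normalizing raw counts in the method.
input_adata = read_anndata(
    par['input'],
    X='layers/normalized',  # pre-processed layer instead of 'layers/counts'
    obs='obs',
    var='var',
    uns='uns'
)
# With this input, the sc.pp.normalize_total / sc.pp.log1p calls inside
# eliminate_batch_effect_fn would become redundant.
```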

Author:
Can you elaborate a little on the comment that "it won't follow the benchmark setup as intended"? Is that because we indicate this is an embedding method with preferred normalization of log_cp10k but use the "layers/counts" as input rather than "layers/normalized"?

Would it be considered to follow the benchmark setup as intended if we indicated it is a [feature] method and set adata_integrated.layers['corrected_counts'] = adata_integrated.X? IIUC that could only improve its overall score on the v2.0.0 benchmark as then the HVG metric would also be computed.

I'd rather not edit the code to do that, since right now it's purely LLM-implemented; I just want to make sure we're not missing something more fundamental. Thanks!
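For reference, the [feature] variant described in this reply would be a small change to the human-written boilerplate. The following is a hedged sketch only, not a proposed edit, since the author prefers to keep the LLM-written code untouched:

```python
# Hypothetical tweak: also expose the ComBat-corrected expression matrix as a
# feature-level output (together with method_types: [feature] in config.vsh.yaml),
# so the HVG-overlap metric could be computed as well.
output = eliminate_batch_effect_fn(input_adata, config=config)
output.layers['corrected_counts'] = output.X  # ComBat-corrected, scaled expression
output.uns['method_id'] = 'bbknn_ts'
output.write_h5ad(par['output'], compression='gzip')
```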

@@ -0,0 +1,248 @@
"""BBKNN (TS) code for removing batch effects.

This is the top-performing batch correction implementation discovered by the AI
system described in https://arxiv.org/abs/2509.06503.
"""
## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
'input': 'resources_test/.../input.h5ad',
'output': 'output.h5ad'
}
meta = {
'name': 'bbknn_ts'
}
## VIASH END

################################################################################
####### LLM-written implementation of batch correction code starts here. #######
################################################################################
from typing import Any
from sklearn.decomposition import TruncatedSVD
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import lil_matrix, csr_matrix
import numpy as np
import scanpy as sc
import anndata as ad
import heapq # For efficiently getting top K elements from merged lists

# Define parameters for the config.
# These values are chosen to balance computational cost and integration performance
# for datasets with up to ~300k cells and 2k genes.
config = {
'n_pca_components': 100, # Number of PCA components. Recommended: 50-200.
# Captures sufficient variance while reducing dimensionality.
'n_neighbors_per_batch': 10, # Number of neighbors to find within each batch. Recommended: 5-15.
# This defines the local batch context for each cell.
'total_k_neighbors': 50, # Total number of nearest neighbors to retain for the final graph. Recommended: 15-100.
# This forms the global batch-integrated graph.
}


def eliminate_batch_effect_fn(
adata: ad.AnnData, config: dict[str, Any]
) -> ad.AnnData:
# Create a copy to ensure the original input adata remains unchanged.
adata_integrated = adata.copy()

# --- Preprocessing: Normalize, log-transform, scale ---
# These are standard initial steps for scRNA-seq data.
# Use adata.X which contains raw counts.
sc.pp.normalize_total(adata_integrated, target_sum=1e4)
sc.pp.log1p(adata_integrated)
sc.pp.scale(adata_integrated, max_value=10) # Clip values to avoid extreme outliers

# --- Batch Correction: ComBat on the gene expression matrix ---
# This step applies a more robust linear model-based batch correction
# directly on the gene expression data before dimensionality reduction.
# ComBat modifies adata_integrated.X in place.
sc.pp.combat(adata_integrated, key='batch')

# --- Dimensionality Reduction: PCA on the ComBat-corrected data ---
# n_comps cannot exceed min(n_obs - 1, n_vars). Robustly handle small datasets.
n_pca_components = config.get('n_pca_components', 100)
actual_n_pca_components = min(n_pca_components, adata_integrated.n_vars, adata_integrated.n_obs - 1)

# Handle edge cases for PCA and graph construction where data is too small.
# If PCA cannot be run meaningfully, return a minimal AnnData object to avoid errors.
if actual_n_pca_components <= 0 or adata_integrated.n_obs <= 1:
print(f"Warning: Too few observations ({adata_integrated.n_obs}) or dimensions ({adata_integrated.n_vars}) for PCA/graph construction. Returning trivial embedding.")
# Provide a placeholder embedding and empty graph structure.
adata_integrated.obsm['X_emb'] = np.zeros((adata_integrated.n_obs, 1))
adata_integrated.obsp['connectivities'] = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
adata_integrated.obsp['distances'] = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
adata_integrated.uns['neighbors'] = {
'params': {
'n_neighbors': 0,
'method': 'degenerate',
'n_pcs': 0,
'n_neighbors_per_batch': 0,
'pca_batch_correction': 'none',
},
'connectivities_key': 'connectivities',
'distances_key': 'distances',
}
return adata_integrated

sc.tl.pca(adata_integrated, n_comps=actual_n_pca_components, svd_solver='arpack')

# Set the ComBat-corrected PCA embedding as the integrated output embedding.
# This 'X_emb' will be directly evaluated by metrics like ASW, LISI, PCR.
adata_integrated.obsm['X_emb'] = adata_integrated.obsm['X_pca']


# --- Custom Batch-Aware Nearest Neighbors Graph Construction ---
# This implements the expert advice: find neighbors independently within batches, then merge.
# This part of the code remains largely the same, but now operates on the
# ComBat-corrected PCA embedding (adata_integrated.obsm['X_emb']).
k_batch_neighbors = config.get('n_neighbors_per_batch', 10)
total_k_neighbors = config.get('total_k_neighbors', 50)

# A list of dictionaries to store unique neighbors and their minimum distances for each cell.
# Using dictionaries allows efficient updating if a cell is found as a neighbor from multiple batches.
merged_neighbors_per_cell = [{} for _ in range(adata_integrated.n_obs)]

# Group cell indices by batch for efficient querying.
batches = adata_integrated.obs['batch'].values
unique_batches = np.unique(batches)
batch_to_indices = {b: np.where(batches == b)[0] for b in unique_batches}

# Pre-fit NearestNeighbors models for each batch's data using the corrected PCA embedding.
# This avoids refitting the model for every query.
batch_nn_models = {}
for b_id in unique_batches:
batch_cell_indices = batch_to_indices[b_id]
# Ensure there are enough cells to fit a NearestNeighbors model (at least k_batch_neighbors + 1 for self-exclusion, or just > 0 for min k=1)
if len(batch_cell_indices) > 0:
# Fit with a k that is at most the batch size to avoid errors if k_batch_neighbors is too high for a small batch.
k_fit_effective = min(k_batch_neighbors + 1, len(batch_cell_indices)) # +1 to ensure self-loop can be found and excluded
if k_fit_effective > 0: # Only fit if there are points available
nn_model = NearestNeighbors(n_neighbors=k_fit_effective, metric='euclidean', algorithm='auto')
nn_model.fit(adata_integrated.obsm['X_emb'][batch_cell_indices])
batch_nn_models[b_id] = nn_model

# Iterate through all possible query batches and target batches to find neighbors.
for query_batch_id in unique_batches:
query_global_indices = batch_to_indices[query_batch_id]
if len(query_global_indices) == 0:
continue # Skip empty query batches

query_data = adata_integrated.obsm['X_emb'][query_global_indices]

for target_batch_id in unique_batches:
if target_batch_id not in batch_nn_models:
continue # Skip target batches that were too small to fit an NN model

nn_model = batch_nn_models[target_batch_id]
target_global_indices = batch_to_indices[target_batch_id]

# Ensure n_neighbors does not exceed the number of points in the target batch.
k_for_query = min(k_batch_neighbors, len(target_global_indices) -1) # -1 to avoid finding self as neighbor if batch is query batch
if k_for_query <= 0: # No valid neighbors can be found in this target batch
continue

# Query neighbors for all cells in the current query batch against the target batch's data.
distances, indices_in_target_batch = nn_model.kneighbors(query_data, n_neighbors=k_for_query, return_distance=True)

for i_query_local in range(len(query_global_indices)):
current_cell_global_idx = query_global_indices[i_query_local]

dists_for_cell = distances[i_query_local]
global_neighbors_for_cell = target_global_indices[indices_in_target_batch[i_query_local]]

for k_idx in range(len(global_neighbors_for_cell)):
neighbor_global_idx = global_neighbors_for_cell[k_idx]
dist = dists_for_cell[k_idx]

# Exclude self-loops: a cell should not be its own neighbor in graph construction.
if neighbor_global_idx == current_cell_global_idx:
continue

# Store neighbor and its distance. If already present, keep the minimum distance (closest connection).
if (neighbor_global_idx not in merged_neighbors_per_cell[current_cell_global_idx] or
dist < merged_neighbors_per_cell[current_cell_global_idx][neighbor_global_idx]):
merged_neighbors_per_cell[current_cell_global_idx][neighbor_global_idx] = dist

# Convert collected neighbors and distances into sparse matrices.
rows = []
cols = []
data_distances = []

for i in range(adata_integrated.n_obs):
# Retrieve all candidate neighbors for cell 'i', sort by distance, and take the top 'total_k_neighbors'.
current_cell_candidates = list(merged_neighbors_per_cell[i].items())

if not current_cell_candidates: # If a cell has no valid neighbors after all filtering
continue

# Use heapq for efficient selection of the smallest distances.
selected_neighbors = heapq.nsmallest(total_k_neighbors, current_cell_candidates, key=lambda item: item[1])

for neighbor_idx, dist in selected_neighbors:
rows.append(i)
cols.append(neighbor_idx)
data_distances.append(dist)

# Create distance matrix. Handle case with no neighbors found at all for the entire dataset.
if not rows:
distances_matrix = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
else:
distances_matrix = csr_matrix((data_distances, (rows, cols)), shape=(adata_integrated.n_obs, adata_integrated.n_obs))

# Symmetrize the distance matrix: if A is a neighbor of B, then B is also a neighbor of A,
# with the distance being the maximum of the two observed distances (ensures undirected graph).
distances_matrix = distances_matrix.maximum(distances_matrix.T)
distances_matrix.eliminate_zeros() # Remove any explicit zeros created by max operation

# Create connectivities matrix (binary representation of connections).
connectivities_matrix = distances_matrix.copy()
connectivities_matrix.data[:] = 1.0 # All non-zero entries become 1.0 (connected).
connectivities_matrix.eliminate_zeros()
connectivities_matrix = connectivities_matrix.astype(float)

# Store the custom graph in adata.obsp. These keys are used by scib metrics.
adata_integrated.obsp['connectivities'] = connectivities_matrix
adata_integrated.obsp['distances'] = distances_matrix

# Store parameters in adata.uns['neighbors'] for completeness and scanpy/scib compatibility.
adata_integrated.uns['neighbors'] = {
'params': {
'n_neighbors': total_k_neighbors,
'method': 'custom_batch_aware_combat_pca', # Reflects the integration strategy
'metric': 'euclidean',
'n_pcs': actual_n_pca_components,
'n_neighbors_per_batch': k_batch_neighbors,
'pca_batch_correction': 'combat', # Indicates ComBat was applied before PCA
},
'connectivities_key': 'connectivities',
'distances_key': 'distances',
}

return adata_integrated

################################################################################
######## LLM-written implementation of batch correction code ends here. ########
################################################################################

################################################################################
# Start of human-written code. #################################################
################################################################################
# This is just boilerplate to satisfy the OpenProblems-codebase-specific setup
# for running evaluation.
import sys
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

print('Read input', flush=True)
input_adata = read_anndata(
par['input'],
X='layers/counts',
obs='obs',
var='var',
uns='uns'
)

output = eliminate_batch_effect_fn(input_adata, config=config)
output.uns['method_id'] = 'bbknn_ts'
output.write_h5ad(par['output'], compression='gzip')
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
@@ -91,6 +91,7 @@ dependencies:
- name: methods/batchelor_fastmnn
- name: methods/batchelor_mnn_correct
- name: methods/bbknn
- name: methods/bbknn_ts
- name: methods/combat
- name: methods/geneformer
- name: methods/harmony
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/main.nf
@@ -19,6 +19,7 @@ methods = [
batchelor_fastmnn,
batchelor_mnn_correct,
bbknn,
bbknn_ts,
combat,
geneformer,
harmony,