Add BBKNN (TS) method. #84
base: main
Changes from all commits: 01c061a, 8de3ea1, c048c0f, b04bdf7, a117d68
@@ -0,0 +1,54 @@
__merge__: ../../api/comp_method.yaml

name: bbknn_ts
label: BBKNN (TS)
summary: "A combination of ComBat and BBKNN discovered and implemented by Gemini."
description: |
  "The BBKNN (TS) solution applies standard scRNA-seq preprocessing steps, including total count normalization, log-transformation, and scaling of gene expression data. Batch effect correction is performed using scanpy.pp.combat directly on the gene expression matrix (before dimensionality reduction). Dimensionality reduction is then applied using PCA on the ComBat-corrected data, and this PCA embedding (adata.obsm['X_pca']) is designated as the integrated embedding (adata.obsm['X_emb']). A custom batch-aware nearest neighbors graph is constructed based on this integrated embedding; for each cell, neighbors are independently identified within its own batch and other batches, up to n_neighbors_per_batch. These candidate neighbors are merged, keeping the minimum distance for duplicate entries, and the top total_k_neighbors are selected for each cell. Finally, a symmetric sparse distance matrix and a binary connectivities matrix are generated to represent the integrated neighborhood graph. This code was entirely written by the AI system described in the associated publication."

references:
  bibtex: |
    @misc{GoogleScienceAI2025,
      title={An {AI} system to help scientists write expert-level empirical software},
      author={Eser Ayg\"un and Anastasiya Belyaeva and Gheorghe Comanici and Marc Coram and Hao Cui and Jake Garrison and Renee Johnston and Anton Kast and Cory Y. McLean and Peter Norgaard and Zahra Shamsi and David Smalling and James Thompson and Subhashini Venugopalan and Brian P. Williams and Chujun He and Sarah Martinson and Martyna Plomecka and Lai Wei and Yuchen Zhou and Qian-Ze Zhu and Matthew Abraham and Erica Brand and Anna Bulanova and Jeffrey A. Cardille and Chris Co and Scott Ellsworth and Grace Joseph and Malcolm Kane and Ryan Krueger and Johan Kartiwa and Dan Liebling and Jan-Matthis Lueckmann and Paul Raccuglia and Xuefei (Julie) Wang and Katherine Chou and James Manyika and Yossi Matias and John C. Platt and Lizzie Dorfman and Shibl Mourad and Michael P. Brenner},
      year={2025},
      eprint={2509.06503},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2509.06503}
    }
links:
  documentation: https://google-research.github.io/score/
  repository: https://github.com/google-research/score

info:
  method_types: [embedding]
  preferred_normalization: log_cp10k

arguments:
  - name: "--n_pca_components"
    type: "integer"
    default: 100
    description: "Number of PCA components."
  - name: "--n_neighbors_per_batch"
    type: "integer"
    default: 10
    description: "Number of neighbors to use within each batch."
  - name: "--total_k_neighbors"
    type: "integer"
    default: 50
    description: "Total number of nearest neighbors to retain for the final graph."

resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py

engines:
  - type: docker
    image: openproblems/base_python:1.0.0

runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime,midmem,midcpu]
Review discussion (marked resolved):

Reviewer: Based on a quick glance, this code will likely work as the developer intended; however, it does not make use of the pre-computed processing step and recomputes everything by itself instead. This in itself won't cause the code to fail or give wrong results, but it won't follow the benchmark setup as intended.

Author: Can you elaborate a little on the comment that "it won't follow the benchmark setup as intended"? Is that because we indicate this is an embedding method with preferred normalization of log_cp10k? Would it be considered to follow the benchmark setup as intended if we indicated it is a …? I'd rather not edit the code to do that since right now it's purely LLM implemented; I just want to make sure we're not missing something more fundamental. Thanks!
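To make the neighbor-merge step from the config description concrete before the full script below, here is a minimal standalone sketch (illustration only, not part of the PR; the toy data and all variable names are invented): each batch is queried separately with a per-batch kNN model, candidates are merged per cell keeping the minimum distance for duplicates, and the merged list is truncated to the closest k_total entries.

import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
emb = rng.normal(size=(20, 5))           # toy embedding: 20 cells, 5 dims
batches = np.array([0] * 10 + [1] * 10)  # two toy batches
k_per_batch, k_total = 3, 4

candidates = [{} for _ in range(len(emb))]  # per cell: neighbor -> min distance
for b in np.unique(batches):
    idx = np.where(batches == b)[0]
    nn = NearestNeighbors(n_neighbors=min(k_per_batch + 1, len(idx))).fit(emb[idx])
    dist, local = nn.kneighbors(emb)  # query ALL cells against this batch
    for i in range(len(emb)):
        for d, j in zip(dist[i], idx[local[i]]):
            if j != i and d < candidates[i].get(j, np.inf):  # skip self, keep min distance
                candidates[i][j] = d

# Keep only the k_total closest merged candidates per cell.
top = [sorted(c.items(), key=lambda kv: kv[1])[:k_total] for c in candidates]
print(top[0])  # [(neighbor_index, distance), ...] for cell 0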
@@ -0,0 +1,248 @@
"""BBKNN (TS) code for removing batch effects.

This is the top-performing batch correction implementation discovered by the AI
system described in https://arxiv.org/abs/2509.06503.
"""
## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
    'input': 'resources_test/.../input.h5ad',
    'output': 'output.h5ad'
}
meta = {
    'name': 'bbknn_ts'
}
## VIASH END

################################################################################
####### LLM-written implementation of batch correction code starts here. #######
################################################################################
from typing import Any
from sklearn.decomposition import TruncatedSVD
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import lil_matrix, csr_matrix
import numpy as np
import scanpy as sc
import anndata as ad
import heapq  # For efficiently getting top K elements from merged lists

# Define parameters for the config.
# These values are chosen to balance computational cost and integration performance
# for datasets with up to ~300k cells and 2k genes.
config = {
    'n_pca_components': 100,  # Number of PCA components. Recommended: 50-200.
                              # Captures sufficient variance while reducing dimensionality.
    'n_neighbors_per_batch': 10,  # Number of neighbors to find within each batch. Recommended: 5-15.
                                  # This defines the local batch context for each cell.
    'total_k_neighbors': 50,  # Total number of nearest neighbors to retain for the final graph. Recommended: 15-100.
                              # This forms the global batch-integrated graph.
}


def eliminate_batch_effect_fn(
    adata: ad.AnnData, config: dict[str, Any]
) -> ad.AnnData:
    # Create a copy to ensure the original input adata remains unchanged.
    adata_integrated = adata.copy()

    # --- Preprocessing: Normalize, log-transform, scale ---
    # These are standard initial steps for scRNA-seq data.
    # Use adata.X which contains raw counts.
    sc.pp.normalize_total(adata_integrated, target_sum=1e4)
    sc.pp.log1p(adata_integrated)
    sc.pp.scale(adata_integrated, max_value=10)  # Clip values to avoid extreme outliers

    # --- Batch Correction: ComBat on the gene expression matrix ---
    # This step applies a more robust linear model-based batch correction
    # directly on the gene expression data before dimensionality reduction.
    # ComBat modifies adata_integrated.X in place.
    sc.pp.combat(adata_integrated, key='batch')

    # --- Dimensionality Reduction: PCA on the ComBat-corrected data ---
    # n_comps cannot exceed min(n_obs - 1, n_vars). Robustly handle small datasets.
    n_pca_components = config.get('n_pca_components', 100)
    actual_n_pca_components = min(n_pca_components, adata_integrated.n_vars, adata_integrated.n_obs - 1)

    # Handle edge cases for PCA and graph construction where data is too small.
    # If PCA cannot be run meaningfully, return a minimal AnnData object to avoid errors.
    if actual_n_pca_components <= 0 or adata_integrated.n_obs <= 1:
        print(f"Warning: Too few observations ({adata_integrated.n_obs}) or dimensions ({adata_integrated.n_vars}) for PCA/graph construction. Returning trivial embedding.")
        # Provide a placeholder embedding and empty graph structure.
        adata_integrated.obsm['X_emb'] = np.zeros((adata_integrated.n_obs, 1))
        adata_integrated.obsp['connectivities'] = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
        adata_integrated.obsp['distances'] = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
        adata_integrated.uns['neighbors'] = {
            'params': {
                'n_neighbors': 0,
                'method': 'degenerate',
                'n_pcs': 0,
                'n_neighbors_per_batch': 0,
                'pca_batch_correction': 'none',
            },
            'connectivities_key': 'connectivities',
            'distances_key': 'distances',
        }
        return adata_integrated

    sc.tl.pca(adata_integrated, n_comps=actual_n_pca_components, svd_solver='arpack')

    # Set the ComBat-corrected PCA embedding as the integrated output embedding.
    # This 'X_emb' will be directly evaluated by metrics like ASW, LISI, PCR.
    adata_integrated.obsm['X_emb'] = adata_integrated.obsm['X_pca']

    # --- Custom Batch-Aware Nearest Neighbors Graph Construction ---
    # This implements the expert advice: find neighbors independently within batches, then merge.
    # This part of the code remains largely the same, but now operates on the
    # ComBat-corrected PCA embedding (adata_integrated.obsm['X_emb']).
    k_batch_neighbors = config.get('n_neighbors_per_batch', 10)
    total_k_neighbors = config.get('total_k_neighbors', 50)

    # A list of dictionaries to store unique neighbors and their minimum distances for each cell.
    # Using dictionaries allows efficient updating if a cell is found as a neighbor from multiple batches.
    merged_neighbors_per_cell = [{} for _ in range(adata_integrated.n_obs)]

    # Group cell indices by batch for efficient querying.
    batches = adata_integrated.obs['batch'].values
    unique_batches = np.unique(batches)
    batch_to_indices = {b: np.where(batches == b)[0] for b in unique_batches}

    # Pre-fit NearestNeighbors models for each batch's data using the corrected PCA embedding.
    # This avoids refitting the model for every query.
    batch_nn_models = {}
    for b_id in unique_batches:
        batch_cell_indices = batch_to_indices[b_id]
        # Ensure there are enough cells to fit a NearestNeighbors model (at least k_batch_neighbors + 1 for self-exclusion, or just > 0 for min k=1)
        if len(batch_cell_indices) > 0:
            # Fit with a k that is at most the batch size to avoid errors if k_batch_neighbors is too high for a small batch.
            k_fit_effective = min(k_batch_neighbors + 1, len(batch_cell_indices))  # +1 to ensure self-loop can be found and excluded
            if k_fit_effective > 0:  # Only fit if there are points available
                nn_model = NearestNeighbors(n_neighbors=k_fit_effective, metric='euclidean', algorithm='auto')
                nn_model.fit(adata_integrated.obsm['X_emb'][batch_cell_indices])
                batch_nn_models[b_id] = nn_model

    # Iterate through all possible query batches and target batches to find neighbors.
    for query_batch_id in unique_batches:
        query_global_indices = batch_to_indices[query_batch_id]
        if len(query_global_indices) == 0:
            continue  # Skip empty query batches

        query_data = adata_integrated.obsm['X_emb'][query_global_indices]

        for target_batch_id in unique_batches:
            if target_batch_id not in batch_nn_models:
                continue  # Skip target batches that were too small to fit an NN model

            nn_model = batch_nn_models[target_batch_id]
            target_global_indices = batch_to_indices[target_batch_id]

            # Ensure n_neighbors does not exceed the number of points in the target batch.
            k_for_query = min(k_batch_neighbors, len(target_global_indices) - 1)  # -1 to avoid finding self as neighbor if batch is query batch
            if k_for_query <= 0:  # No valid neighbors can be found in this target batch
                continue

            # Query neighbors for all cells in the current query batch against the target batch's data.
            distances, indices_in_target_batch = nn_model.kneighbors(query_data, n_neighbors=k_for_query, return_distance=True)

            for i_query_local in range(len(query_global_indices)):
                current_cell_global_idx = query_global_indices[i_query_local]

                dists_for_cell = distances[i_query_local]
                global_neighbors_for_cell = target_global_indices[indices_in_target_batch[i_query_local]]

                for k_idx in range(len(global_neighbors_for_cell)):
                    neighbor_global_idx = global_neighbors_for_cell[k_idx]
                    dist = dists_for_cell[k_idx]

                    # Exclude self-loops: a cell should not be its own neighbor in graph construction.
                    if neighbor_global_idx == current_cell_global_idx:
                        continue

                    # Store neighbor and its distance. If already present, keep the minimum distance (closest connection).
                    if (neighbor_global_idx not in merged_neighbors_per_cell[current_cell_global_idx] or
                            dist < merged_neighbors_per_cell[current_cell_global_idx][neighbor_global_idx]):
                        merged_neighbors_per_cell[current_cell_global_idx][neighbor_global_idx] = dist

    # Convert collected neighbors and distances into sparse matrices.
    rows = []
    cols = []
    data_distances = []

    for i in range(adata_integrated.n_obs):
        # Retrieve all candidate neighbors for cell 'i', sort by distance, and take the top 'total_k_neighbors'.
        current_cell_candidates = list(merged_neighbors_per_cell[i].items())

        if not current_cell_candidates:  # If a cell has no valid neighbors after all filtering
            continue

        # Use heapq for efficient selection of the smallest distances.
        selected_neighbors = heapq.nsmallest(total_k_neighbors, current_cell_candidates, key=lambda item: item[1])

        for neighbor_idx, dist in selected_neighbors:
            rows.append(i)
            cols.append(neighbor_idx)
            data_distances.append(dist)

    # Create distance matrix. Handle case with no neighbors found at all for the entire dataset.
    if not rows:
        distances_matrix = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
    else:
        distances_matrix = csr_matrix((data_distances, (rows, cols)), shape=(adata_integrated.n_obs, adata_integrated.n_obs))

    # Symmetrize the distance matrix: if A is a neighbor of B, then B is also a neighbor of A,
    # with the distance being the maximum of the two observed distances (ensures undirected graph).
    distances_matrix = distances_matrix.maximum(distances_matrix.T)
    distances_matrix.eliminate_zeros()  # Remove any explicit zeros created by max operation

    # Create connectivities matrix (binary representation of connections).
    connectivities_matrix = distances_matrix.copy()
    connectivities_matrix.data[:] = 1.0  # All non-zero entries become 1.0 (connected).
    connectivities_matrix.eliminate_zeros()
    connectivities_matrix = connectivities_matrix.astype(float)

    # Store the custom graph in adata.obsp. These keys are used by scib metrics.
    adata_integrated.obsp['connectivities'] = connectivities_matrix
    adata_integrated.obsp['distances'] = distances_matrix

    # Store parameters in adata.uns['neighbors'] for completeness and scanpy/scib compatibility.
    adata_integrated.uns['neighbors'] = {
        'params': {
            'n_neighbors': total_k_neighbors,
            'method': 'custom_batch_aware_combat_pca',  # Reflects the integration strategy
            'metric': 'euclidean',
            'n_pcs': actual_n_pca_components,
            'n_neighbors_per_batch': k_batch_neighbors,
            'pca_batch_correction': 'combat',  # Indicates ComBat was applied before PCA
        },
        'connectivities_key': 'connectivities',
        'distances_key': 'distances',
    }

    return adata_integrated

################################################################################
######## LLM-written implementation of batch correction code ends here. ########
################################################################################

################################################################################
# Start of human-written code. #################################################
################################################################################
# This is just boilerplate to satisfy the OpenProblems-codebase-specific setup
# for running evaluation.
import sys
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

print('Read input', flush=True)
input_adata = read_anndata(
    par['input'],
    X='layers/counts',
    obs='obs',
    var='var',
    uns='uns'
)

output = eliminate_batch_effect_fn(input_adata, config=config)
output.uns['method_id'] = 'bbknn_ts'
output.write_h5ad(par['output'], compression='gzip')
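As a quick sanity check of the component's output, a minimal sketch (assuming the output.h5ad path from the viash placeholder above; the key names are the ones the script writes):

import anndata as ad

out = ad.read_h5ad('output.h5ad')
print(out.obsm['X_emb'].shape)          # ComBat-corrected PCA embedding
print(out.obsp['distances'].nnz,        # symmetric sparse distance matrix
      out.obsp['connectivities'].nnz)   # binary connectivities, same sparsity pattern
print(out.uns['neighbors']['params'])   # custom graph parameters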
@@ -19,6 +19,7 @@ methods = [
    batchelor_fastmnn,
    batchelor_mnn_correct,
    bbknn,
    bbknn_ts,
    combat,
    geneformer,
    harmony,