From 01c061ad7fe091dda0aea14691c32c58e478a394 Mon Sep 17 00:00:00 2001
From: cmclean
Date: Sat, 13 Sep 2025 11:59:50 +0000
Subject: [PATCH 1/4] Added BBKNN (TS) method.

---
 CHANGELOG.md                         |   1 +
 src/methods/bbknn_ts/config.vsh.yaml |  78 +++++++++
 src/methods/bbknn_ts/script.py       | 248 +++++++++++++++++++++++++++
 3 files changed, 327 insertions(+)
 create mode 100644 src/methods/bbknn_ts/config.vsh.yaml
 create mode 100644 src/methods/bbknn_ts/script.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 859869e4..10f8596d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@

 ## New functionality

+* Added `method/bbknn_ts` component.
 * Added `metrics/kbet_pg` and `metrics/kbet_pg_label` components (PR #52).
 * Added `method/drvi` component (PR #61).

diff --git a/src/methods/bbknn_ts/config.vsh.yaml b/src/methods/bbknn_ts/config.vsh.yaml
new file mode 100644
index 00000000..4b1e6f46
--- /dev/null
+++ b/src/methods/bbknn_ts/config.vsh.yaml
@@ -0,0 +1,78 @@
+# The API specifies which type of component this is.
+# It contains specifications for:
+# - The input/output files
+# - Common parameters
+# - A unit test
+__merge__: ../../api/comp_method.yaml
+
+# A unique identifier for your component (required).
+# Can contain only lowercase letters or underscores.
+name: bbknn_ts
+# A relatively short label, used when rendering visualisations (required)
+label: BBKNN (TS)
+# A one sentence summary of how this method works (required). Used when
+# rendering summary tables.
+summary: "A combination of ComBat and BBKNN discovered and implemented by Gemini."
+# A multi-line description of how this component works (required). Used
+# when rendering reference documentation.
+description: |
+  The BBKNN (TS) solution applies standard scRNA-seq preprocessing steps, including total
+  count normalization, log-transformation, and scaling of gene expression data. Batch effect
+  correction is performed using scanpy.pp.combat directly on the gene expression matrix
+  (before dimensionality reduction). Dimensionality reduction is then applied using PCA on
+  the ComBat-corrected data, and this PCA embedding (adata.obsm['X_pca']) is designated as
+  the integrated embedding (adata.obsm['X_emb']). A custom batch-aware nearest neighbors
+  graph is constructed based on this integrated embedding; for each cell, neighbors are
+  independently identified within its own batch and other batches, up to
+  n_neighbors_per_batch. These candidate neighbors are merged, keeping the minimum distance
+  for duplicate entries, and the top total_k_neighbors are selected for each cell. Finally,
+  a symmetric sparse distance matrix and a binary connectivities matrix are generated to
+  represent the integrated neighborhood graph. This code was entirely written by the AI
+  system described in the associated publication.
+references:
+  bibtex: |
+    @misc{GoogleScienceAI2025,
+      title={An {AI} system to help scientists write expert-level empirical software},
+      author={Eser Ayg\"un and Anastasiya Belyaeva and Gheorghe Comanici and Marc Coram and Hao Cui and Jake Garrison and Renee Johnston and Anton Kast and Cory Y. McLean and Peter Norgaard and Zahra Shamsi and David Smalling and James Thompson and Subhashini Venugopalan and Brian P. Williams and Chujun He and Sarah Martinson and Martyna Plomecka and Lai Wei and Yuchen Zhou and Qian-Ze Zhu and Matthew Abraham and Erica Brand and Anna Bulanova and Jeffrey A. Cardille and Chris Co and Scott Ellsworth and Grace Joseph and Malcolm Kane and Ryan Krueger and Johan Kartiwa and Dan Liebling and Jan-Matthis Lueckmann and Paul Raccuglia and Xuefei (Julie) Wang and Katherine Chou and James Manyika and Yossi Matias and John C. Platt and Lizzie Dorfman and Shibl Mourad and Michael P. Brenner},
+      year={2025},
+      eprint={2509.06503},
+      archivePrefix={arXiv},
+      primaryClass={cs.AI},
+      url={https://arxiv.org/abs/2509.06503}
+    }
+links:
+  # URL to the documentation for this method (required).
+  documentation: https://google-research.github.io/score/
+  # URL to the code repository for this method (required).
+  repository: https://github.com/google-research/score
+
+
+# Metadata for your component
+info:
+  method_types: [embedding]
+  # Which normalisation method this component prefers to use (required).
+  preferred_normalization: log_cp10k
+
+# Component-specific parameters (optional)
+arguments:
+  - name: "--n_pca_components"
+    type: "integer"
+    default: 100
+    description: Number of PCA components.
+  - name: "--n_neighbors_per_batch"
+    type: "integer"
+    default: 10
+    description: Number of neighbors to use within each batch.
+  - name: "--total_k_neighbors"
+    type: "integer"
+    default: 50
+    description: Total number of nearest neighbors to retain for the final graph.
+
+# Resources required to run the component
+resources:
+  # The script of your component (required)
+  - type: python_script
+    path: script.py
+  - path: /src/utils/read_anndata_partial.py
+
+engines:
+  # Specifications for the Docker image for this component.
+  - type: docker
+    image: openproblems/base_python:1.0.0
+
+runners:
+  # This platform allows running the component natively
+  - type: executable
+  # Allows turning the component into a Nextflow module / pipeline.
+  - type: nextflow
+    directives:
+      label: [midtime,midmem,midcpu]
diff --git a/src/methods/bbknn_ts/script.py b/src/methods/bbknn_ts/script.py
new file mode 100644
index 00000000..7046f24b
--- /dev/null
+++ b/src/methods/bbknn_ts/script.py
@@ -0,0 +1,248 @@
+"""BBKNN (TS) code for removing batch effects.
+
+This is the top-performing batch correction implementation discovered by the AI
+system described in https://arxiv.org/abs/2509.06503.
+"""
+## VIASH START
+# Note: this section is auto-generated by viash at runtime. To edit it, make changes
+# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
+par = {
+    'input': 'resources_test/.../input.h5ad',
+    'output': 'output.h5ad'
+}
+meta = {
+    'name': 'bbknn_ts'
+}
+## VIASH END
+
+################################################################################
+####### LLM-written implementation of batch correction code starts here. #######
+################################################################################
+from typing import Any
+from sklearn.decomposition import TruncatedSVD
+from sklearn.neighbors import NearestNeighbors
+from scipy.sparse import lil_matrix, csr_matrix
+import numpy as np
+import scanpy as sc
+import anndata as ad
+import heapq  # For efficiently getting top K elements from merged lists
+
+# Define parameters for the config.
+# These values are chosen to balance computational cost and integration performance
+# for datasets with up to ~300k cells and 2k genes.
+config = {
+    'n_pca_components': 100,  # Number of PCA components. Recommended: 50-200.
+                              # Captures sufficient variance while reducing dimensionality.
+    'n_neighbors_per_batch': 10,  # Number of neighbors to find within each batch. Recommended: 5-15.
+                                  # This defines the local batch context for each cell.
+    'total_k_neighbors': 50,  # Total number of nearest neighbors to retain for the final graph. Recommended: 15-100.
+                              # This forms the global batch-integrated graph.
+}
+
+
+def eliminate_batch_effect_fn(
+    adata: ad.AnnData, config: dict[str, Any]
+) -> ad.AnnData:
+    # Create a copy to ensure the original input adata remains unchanged.
+    adata_integrated = adata.copy()
+
+    # --- Preprocessing: Normalize, log-transform, scale ---
+    # These are standard initial steps for scRNA-seq data.
+    # Use adata.X, which contains raw counts.
+    sc.pp.normalize_total(adata_integrated, target_sum=1e4)
+    sc.pp.log1p(adata_integrated)
+    sc.pp.scale(adata_integrated, max_value=10)  # Clip values to avoid extreme outliers
+
+    # --- Batch Correction: ComBat on the gene expression matrix ---
+    # This step applies a robust, linear model-based batch correction directly
+    # on the gene expression data before dimensionality reduction.
+    # ComBat modifies adata_integrated.X in place.
+    sc.pp.combat(adata_integrated, key='batch')
+
+    # --- Dimensionality Reduction: PCA on the ComBat-corrected data ---
+    # n_comps cannot exceed min(n_obs - 1, n_vars). Robustly handle small datasets.
+    n_pca_components = config.get('n_pca_components', 100)
+    actual_n_pca_components = min(n_pca_components, adata_integrated.n_vars, adata_integrated.n_obs - 1)
+
+    # Handle edge cases for PCA and graph construction where data is too small.
+    # If PCA cannot be run meaningfully, return a minimal AnnData object to avoid errors.
+    if actual_n_pca_components <= 0 or adata_integrated.n_obs <= 1:
+        print(f"Warning: Too few observations ({adata_integrated.n_obs}) or dimensions ({adata_integrated.n_vars}) for PCA/graph construction. Returning trivial embedding.")
+        # Provide a placeholder embedding and empty graph structure.
+        adata_integrated.obsm['X_emb'] = np.zeros((adata_integrated.n_obs, 1))
+        adata_integrated.obsp['connectivities'] = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
+        adata_integrated.obsp['distances'] = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
+        adata_integrated.uns['neighbors'] = {
+            'params': {
+                'n_neighbors': 0,
+                'method': 'degenerate',
+                'n_pcs': 0,
+                'n_neighbors_per_batch': 0,
+                'pca_batch_correction': 'none',
+            },
+            'connectivities_key': 'connectivities',
+            'distances_key': 'distances',
+        }
+        return adata_integrated
+
+    sc.tl.pca(adata_integrated, n_comps=actual_n_pca_components, svd_solver='arpack')
+
+    # Set the ComBat-corrected PCA embedding as the integrated output embedding.
+    # This 'X_emb' will be directly evaluated by metrics like ASW, LISI, PCR.
+    adata_integrated.obsm['X_emb'] = adata_integrated.obsm['X_pca']
+
+    # --- Custom Batch-Aware Nearest Neighbors Graph Construction ---
+    # This implements the expert advice: find neighbors independently within batches, then merge.
+    # This part of the code remains largely the same, but now operates on the
+    # ComBat-corrected PCA embedding (adata_integrated.obsm['X_emb']).
+    k_batch_neighbors = config.get('n_neighbors_per_batch', 10)
+    total_k_neighbors = config.get('total_k_neighbors', 50)
+
+    # A list of dictionaries to store unique neighbors and their minimum distances for each cell.
+    # Using dictionaries allows efficient updating if a cell is found as a neighbor from multiple batches.
+    merged_neighbors_per_cell = [{} for _ in range(adata_integrated.n_obs)]
+
+    # Group cell indices by batch for efficient querying.
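+    # For example, with hypothetical batch labels 'b0' and 'b1', the grouping below
+    # yields batch_to_indices == {'b0': array([0, 2, ...]), 'b1': array([1, 3, ...])}.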
+    batches = adata_integrated.obs['batch'].values
+    unique_batches = np.unique(batches)
+    batch_to_indices = {b: np.where(batches == b)[0] for b in unique_batches}
+
+    # Pre-fit NearestNeighbors models for each batch's data using the corrected PCA embedding.
+    # This avoids refitting the model for every query.
+    batch_nn_models = {}
+    for b_id in unique_batches:
+        batch_cell_indices = batch_to_indices[b_id]
+        # Only fit a model when the batch is non-empty; k is capped at the batch size
+        # so a small batch cannot raise an error when k_batch_neighbors is large.
+        if len(batch_cell_indices) > 0:
+            k_fit_effective = min(k_batch_neighbors + 1, len(batch_cell_indices))  # +1 so the self-match can be found and excluded later
+            if k_fit_effective > 0:  # Only fit if there are points available
+                nn_model = NearestNeighbors(n_neighbors=k_fit_effective, metric='euclidean', algorithm='auto')
+                nn_model.fit(adata_integrated.obsm['X_emb'][batch_cell_indices])
+                batch_nn_models[b_id] = nn_model
+
+    # Iterate over all pairs of query and target batches to find neighbors.
+    for query_batch_id in unique_batches:
+        query_global_indices = batch_to_indices[query_batch_id]
+        if len(query_global_indices) == 0:
+            continue  # Skip empty query batches
+
+        query_data = adata_integrated.obsm['X_emb'][query_global_indices]
+
+        for target_batch_id in unique_batches:
+            if target_batch_id not in batch_nn_models:
+                continue  # Skip target batches that were too small to fit an NN model
+
+            nn_model = batch_nn_models[target_batch_id]
+            target_global_indices = batch_to_indices[target_batch_id]
+
+            # Ensure n_neighbors does not exceed the number of points in the target batch.
+            # The -1 conservatively leaves room for the self-match, which matters when
+            # the target batch is the query batch itself.
+            k_for_query = min(k_batch_neighbors, len(target_global_indices) - 1)
+            if k_for_query <= 0:  # No valid neighbors can be found in this target batch
+                continue

+            # Query neighbors for all cells in the current query batch against the target batch's data.
+            distances, indices_in_target_batch = nn_model.kneighbors(query_data, n_neighbors=k_for_query, return_distance=True)
+
+            for i_query_local in range(len(query_global_indices)):
+                current_cell_global_idx = query_global_indices[i_query_local]
+
+                dists_for_cell = distances[i_query_local]
+                global_neighbors_for_cell = target_global_indices[indices_in_target_batch[i_query_local]]
+
+                for k_idx in range(len(global_neighbors_for_cell)):
+                    neighbor_global_idx = global_neighbors_for_cell[k_idx]
+                    dist = dists_for_cell[k_idx]
+
+                    # Exclude self-loops: a cell should not be its own neighbor in graph construction.
+                    if neighbor_global_idx == current_cell_global_idx:
+                        continue
+
+                    # Store the neighbor and its distance. If already present, keep the
+                    # minimum distance (the closest observed connection).
+                    if (neighbor_global_idx not in merged_neighbors_per_cell[current_cell_global_idx] or
+                            dist < merged_neighbors_per_cell[current_cell_global_idx][neighbor_global_idx]):
+                        merged_neighbors_per_cell[current_cell_global_idx][neighbor_global_idx] = dist
+
+    # Convert collected neighbors and distances into sparse matrices.
+    rows = []
+    cols = []
+    data_distances = []
+
+    for i in range(adata_integrated.n_obs):
+        # Retrieve all candidate neighbors for cell 'i', sort by distance, and take the top 'total_k_neighbors'.
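+        # merged_neighbors_per_cell[i] maps a neighbor's global index to its minimum
+        # observed distance, e.g. {17: 0.42, 903: 0.57} (hypothetical values), so
+        # heapq.nsmallest below keeps the total_k_neighbors closest candidates.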
+        current_cell_candidates = list(merged_neighbors_per_cell[i].items())
+
+        if not current_cell_candidates:  # If a cell has no valid neighbors after all filtering
+            continue
+
+        # Use heapq for efficient selection of the smallest distances.
+        selected_neighbors = heapq.nsmallest(total_k_neighbors, current_cell_candidates, key=lambda item: item[1])
+
+        for neighbor_idx, dist in selected_neighbors:
+            rows.append(i)
+            cols.append(neighbor_idx)
+            data_distances.append(dist)
+
+    # Create distance matrix. Handle case with no neighbors found at all for the entire dataset.
+    if not rows:
+        distances_matrix = csr_matrix((adata_integrated.n_obs, adata_integrated.n_obs))
+    else:
+        distances_matrix = csr_matrix((data_distances, (rows, cols)), shape=(adata_integrated.n_obs, adata_integrated.n_obs))
+
+    # Symmetrize the distance matrix: if A is a neighbor of B, then B is also a neighbor of A,
+    # with the distance being the maximum of the two observed distances (ensures undirected graph).
+    distances_matrix = distances_matrix.maximum(distances_matrix.T)
+    distances_matrix.eliminate_zeros()  # Remove any explicit zeros created by max operation
+
+    # Create connectivities matrix (binary representation of connections).
+    connectivities_matrix = distances_matrix.copy()
+    connectivities_matrix.data[:] = 1.0  # All non-zero entries become 1.0 (connected).
+    connectivities_matrix.eliminate_zeros()
+    connectivities_matrix = connectivities_matrix.astype(float)
+
+    # Store the custom graph in adata.obsp. These keys are used by scib metrics.
+    adata_integrated.obsp['connectivities'] = connectivities_matrix
+    adata_integrated.obsp['distances'] = distances_matrix
+
+    # Store parameters in adata.uns['neighbors'] for completeness and scanpy/scib compatibility.
+    adata_integrated.uns['neighbors'] = {
+        'params': {
+            'n_neighbors': total_k_neighbors,
+            'method': 'custom_batch_aware_combat_pca',  # Reflects the integration strategy
+            'metric': 'euclidean',
+            'n_pcs': actual_n_pca_components,
+            'n_neighbors_per_batch': k_batch_neighbors,
+            'pca_batch_correction': 'combat',  # Indicates ComBat was applied before PCA
+        },
+        'connectivities_key': 'connectivities',
+        'distances_key': 'distances',
+    }
+
+    return adata_integrated
+
+################################################################################
+######## LLM-written implementation of batch correction code ends here. ########
+################################################################################
+
+################################################################################
+# Start of human-written code. #################################################
+################################################################################
+# This is just boilerplate to satisfy the OpenProblems-codebase-specific setup
+# for running evaluation.
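+# `meta["resources_dir"]` is populated by viash at runtime; it contains the
+# bundled read_anndata_partial.py helper declared under `resources` in
+# config.vsh.yaml.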
+import sys
+sys.path.append(meta["resources_dir"])
+from read_anndata_partial import read_anndata
+
+print('Read input', flush=True)
+input_adata = read_anndata(
+    par['input'],
+    X='layers/counts',
+    obs='obs',
+    var='var',
+    uns='uns'
+)
+
+output = eliminate_batch_effect_fn(input_adata, config=config)
+output.uns['method_id'] = 'bbknn_ts'
+output.write_h5ad(par['output'], compression='gzip')

From 8de3ea1fa6c153041e5a6d7b1e6718685ca2c01a Mon Sep 17 00:00:00 2001
From: cmclean
Date: Fri, 19 Sep 2025 19:38:09 +0000
Subject: [PATCH 2/4] Add bbknn_ts to run_benchmark methods

---
 src/workflows/run_benchmark/config.vsh.yaml | 1 +
 src/workflows/run_benchmark/main.nf         | 1 +
 2 files changed, 2 insertions(+)

diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index 09905ad0..b9b72632 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -91,6 +91,7 @@ dependencies:
   - name: methods/batchelor_fastmnn
   - name: methods/batchelor_mnn_correct
   - name: methods/bbknn
+  - name: methods/bbknn_ts
   - name: methods/combat
   - name: methods/geneformer
   - name: methods/harmony
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index 6196f749..0c77ef56 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -19,6 +19,7 @@ methods = [
   batchelor_fastmnn,
   batchelor_mnn_correct,
   bbknn,
+  bbknn_ts,
   combat,
   geneformer,
   harmony,

From c048c0f28b3a935e6dad68702f79bc295cd50b1c Mon Sep 17 00:00:00 2001
From: cmclean
Date: Fri, 19 Sep 2025 19:41:58 +0000
Subject: [PATCH 3/4] Add PR# to BBKNN (TS) method addition log

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10f8596d..a778ae92 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,7 +2,7 @@

 ## New functionality

-* Added `method/bbknn_ts` component.
+* Added `method/bbknn_ts` component (PR #84).
 * Added `metrics/kbet_pg` and `metrics/kbet_pg_label` components (PR #52).
 * Added `method/drvi` component (PR #61).

From b04bdf7bb2cecb53bda5c22ee6bf2bb8ac4cdfda Mon Sep 17 00:00:00 2001
From: cmclean
Date: Sun, 28 Sep 2025 15:50:51 +0000
Subject: [PATCH 4/4] Clean up bbknn_ts YAML and CHANGELOG.md for readability.

---
 CHANGELOG.md                         |  4 +++-
 src/methods/bbknn_ts/config.vsh.yaml | 30 +++-------------------------
 2 files changed, 6 insertions(+), 28 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a778ae92..db349ade 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,12 +2,14 @@

 ## New functionality

-* Added `method/bbknn_ts` component (PR #84).
 * Added `metrics/kbet_pg` and `metrics/kbet_pg_label` components (PR #52).
+
 * Added `method/drvi` component (PR #61).

 * Added `ARI_batch` and `NMI_batch` to `metrics/clustering_overlap` (PR #68).

+* Added `method/bbknn_ts` component (PR #84).
+
 ## Minor changes

 * Un-pin the scPRINT version and update parameters (PR #51)
diff --git a/src/methods/bbknn_ts/config.vsh.yaml b/src/methods/bbknn_ts/config.vsh.yaml
index 4b1e6f46..7ac3189a 100644
--- a/src/methods/bbknn_ts/config.vsh.yaml
+++ b/src/methods/bbknn_ts/config.vsh.yaml
@@ -1,20 +1,8 @@
-# The API specifies which type of component this is.
-# It contains specifications for:
-# - The input/output files
-# - Common parameters
-# - A unit test
 __merge__: ../../api/comp_method.yaml

-# A unique identifier for your component (required).
-# Can contain only lowercase letters or underscores.
 name: bbknn_ts
-# A relatively short label, used when rendering visualisations (required)
 label: BBKNN (TS)
-# A one sentence summary of how this method works (required). Used when
-# rendering summary tables.
 summary: "A combination of ComBat and BBKNN discovered and implemented by Gemini."
-# A multi-line description of how this component works (required). Used
-# when rendering reference documentation.
 description: |
   The BBKNN (TS) solution applies standard scRNA-seq preprocessing steps, including total
   count normalization, log-transformation, and scaling of gene expression data. Batch effect
   correction is performed using scanpy.pp.combat directly on the gene expression matrix
   (before dimensionality reduction). Dimensionality reduction is then applied using PCA on
   the ComBat-corrected data, and this PCA embedding (adata.obsm['X_pca']) is designated as
   the integrated embedding (adata.obsm['X_emb']). A custom batch-aware nearest neighbors
   graph is constructed based on this integrated embedding; for each cell, neighbors are
   independently identified within its own batch and other batches, up to
   n_neighbors_per_batch. These candidate neighbors are merged, keeping the minimum distance
   for duplicate entries, and the top total_k_neighbors are selected for each cell. Finally,
   a symmetric sparse distance matrix and a binary connectivities matrix are generated to
   represent the integrated neighborhood graph. This code was entirely written by the AI
   system described in the associated publication.
 references:
   bibtex: |
     @misc{GoogleScienceAI2025,
       title={An {AI} system to help scientists write expert-level empirical software},
       author={Eser Ayg\"un and Anastasiya Belyaeva and Gheorghe Comanici and Marc Coram and Hao Cui and Jake Garrison and Renee Johnston and Anton Kast and Cory Y. McLean and Peter Norgaard and Zahra Shamsi and David Smalling and James Thompson and Subhashini Venugopalan and Brian P. Williams and Chujun He and Sarah Martinson and Martyna Plomecka and Lai Wei and Yuchen Zhou and Qian-Ze Zhu and Matthew Abraham and Erica Brand and Anna Bulanova and Jeffrey A. Cardille and Chris Co and Scott Ellsworth and Grace Joseph and Malcolm Kane and Ryan Krueger and Johan Kartiwa and Dan Liebling and Jan-Matthis Lueckmann and Paul Raccuglia and Xuefei (Julie) Wang and Katherine Chou and James Manyika and Yossi Matias and John C. Platt and Lizzie Dorfman and Shibl Mourad and Michael P. Brenner},
       year={2025},
       eprint={2509.06503},
       archivePrefix={arXiv},
       primaryClass={cs.AI},
       url={https://arxiv.org/abs/2509.06503}
     }
 links:
-  # URL to the documentation for this method (required).
   documentation: https://google-research.github.io/score/
-  # URL to the code repository for this method (required).
   repository: https://github.com/google-research/score

-
-# Metadata for your component
 info:
   method_types: [embedding]
-  # Which normalisation method this component prefers to use (required).
   preferred_normalization: log_cp10k

-# Component-specific parameters (optional)
 arguments:
   - name: "--n_pca_components"
     type: "integer"
     default: 100
-    description: Number of PCA components.
+    description: "Number of PCA components."
   - name: "--n_neighbors_per_batch"
     type: "integer"
     default: 10
-    description: Number of neighbors to use within each batch.
+    description: "Number of neighbors to use within each batch."
   - name: "--total_k_neighbors"
     type: "integer"
     default: 50
-    description: Total number of nearest neighbors to retain for the final graph.
+    description: "Total number of nearest neighbors to retain for the final graph."

-# Resources required to run the component
 resources:
-  # The script of your component (required)
   - type: python_script
     path: script.py
   - path: /src/utils/read_anndata_partial.py

 engines:
-  # Specifications for the Docker image for this component.
   - type: docker
     image: openproblems/base_python:1.0.0

 runners:
-  # This platform allows running the component natively
   - type: executable
-  # Allows turning the component into a Nextflow module / pipeline.
   - type: nextflow
     directives:
       label: [midtime,midmem,midcpu]
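
To try the integration routine outside the benchmark harness, the following minimal sketch builds a toy two-batch AnnData and calls `eliminate_batch_effect_fn` from `script.py` above. The toy data, sizes, and parameter values are illustrative assumptions rather than part of the patch, and the function (together with its imports) is assumed to have been pasted into the session, since importing `script.py` directly would also execute its viash I/O boilerplate.

```python
# Minimal usage sketch (assumes scanpy, anndata, scikit-learn, and scipy are
# installed, and that eliminate_batch_effect_fn plus its imports from script.py
# are available in the session). All names and values below are illustrative.
import numpy as np
import anndata as ad

rng = np.random.default_rng(0)
n_cells, n_genes = 200, 50
# Toy counts, with a crude additive offset in the second batch to mimic a batch effect.
counts = rng.poisson(2.0, size=(n_cells, n_genes)).astype(np.float32)
counts[100:] += rng.poisson(1.0, size=(100, n_genes)).astype(np.float32)

adata = ad.AnnData(X=counts)
adata.obs['batch'] = ['b0'] * 100 + ['b1'] * 100

toy_config = {
    'n_pca_components': 20,        # small, since the toy data has only 50 genes
    'n_neighbors_per_batch': 5,
    'total_k_neighbors': 15,
}

integrated = eliminate_batch_effect_fn(adata, toy_config)
print(integrated.obsm['X_emb'].shape)         # (200, 20)
print(integrated.obsp['connectivities'].nnz)  # number of edges in the merged graph
```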