PlantCaduceus

PlantCaduceus (PlantCAD for short) is a plant DNA language model based on the Caduceus architecture, which extends the efficient Mamba linear-time sequence modeling framework with bi-directionality and reverse-complement equivariance, properties specifically suited to DNA sequences. PlantCAD is pre-trained on a curated dataset of 16 Angiosperm genomes. It shows state-of-the-art cross-species performance in predicting translation initiation sites (TIS), translation termination sites (TTS), and splice donor and acceptor sites. PlantCAD's zero-shot scores enable identification of genome-wide deleterious mutations and known causal variants in Arabidopsis, sorghum, and maize.

Quick Start

New to PlantCAD? Try our Google Colab demo - no installation required!

For local usage: See installation instructions below, then use notebooks/examples.ipynb to get started.

Model summary

Pre-trained PlantCAD models have been uploaded to Hugging Face 🤗. Here is a summary of the four PlantCAD models with different parameter sizes.

Model | Sequence Length | Model Size | Embedding Size
PlantCaduceus_l20 | 512bp | 20M | 384
PlantCaduceus_l24 | 512bp | 40M | 512
PlantCaduceus_l28 | 512bp | 128M | 768
PlantCaduceus_l32 | 512bp | 225M | 1024

Model Selection Guide:

  • PlantCaduceus_l20: Good for testing and quick analysis
  • PlantCaduceus_l32: Recommended for research and production (best performance)

Prerequisites and System Requirements

For Google Colab: Just a Google account - GPU runtime recommended (free tier available)

For Local Installation: GPU recommended for reasonable performance. Dependencies will be installed automatically during setup.

Installation

Option 1: Google Colab (Recommended for beginners)

No installation required! Just open our PlantCAD Google Colab notebook and start analyzing your data.

Setup steps:

  1. Open the Colab link
  2. Important: Set runtime to GPU (Runtime → Change runtime type → Hardware accelerator: GPU)
  3. Run the cells to install dependencies
  4. Upload your data or use the provided examples

Option 2: Local installation

Step 1: Create conda environment

# Clone the repository (if you haven't already)
git clone https://github.com/kuleshov-group/PlantCaduceus.git
cd PlantCaduceus

# Create and activate environment
conda env create -f env/environment.yml
conda activate PlantCAD

Step 2: Install Python packages

pip install -r env/requirements.txt --no-build-isolation

Step 3: Verify installation

# Test core dependencies
import torch
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModel

# Test PlantCAD model loading (trust_remote_code may be required to load the custom Caduceus code from the Hub)
tokenizer = AutoTokenizer.from_pretrained('kuleshov-group/PlantCaduceus_l20', trust_remote_code=True)
print("✅ Installation successful!")

Alternative: pip-only installation

If you prefer a pip-only setup, see issue #10 for community solutions.

Troubleshooting installation

mamba_ssm issues (most common):

# If mamba_ssm import fails, reinstall with:
pip uninstall mamba-ssm
pip install mamba-ssm==2.2.0 --no-build-isolation

CUDA/GPU issues:

  • Verify CUDA installation: nvidia-smi
  • Check PyTorch CUDA support: python -c "import torch; print(torch.cuda.is_available())"
  • For CPU-only usage: models will still work but will be significantly slower (see the device-selection sketch below)
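
A minimal device-selection sketch in plain PyTorch (no PlantCAD-specific assumptions): it picks the first visible GPU and falls back to the CPU otherwise; the resulting string can be passed wherever a device is needed, e.g. the -device arguments used by the scripts later in this README.

import torch

# Prefer the first visible GPU; fall back to CPU if CUDA is unavailable
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Example: allocate a tensor (or move a model) onto the selected device
x = torch.zeros(4, device=device)
print(x.device)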

Basic Usage

Exploring model inputs and outputs

The easiest way to start is with our example notebook: notebooks/examples.ipynb

Quick example - Get sequence embeddings:

from transformers import AutoTokenizer, AutoModel
import torch

# Load model and tokenizer (trust_remote_code may be required to load the custom Caduceus code from the Hub)
model_name = 'kuleshov-group/PlantCaduceus_l20'  # Start with a smaller model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Example plant DNA sequence (512bp max)
sequence = "ATGCGATCGATCGATC..."  # Your sequence here

# Get embeddings
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state

print(f"Sequence length: {len(sequence)}")
print(f"Embedding shape: {embeddings.shape}")  # [batch_size, seq_len, embedding_dim]
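
The token-level embeddings are often reduced to one fixed-length vector per sequence before being handed to a downstream classifier. A minimal sketch of mean pooling, continuing from the variables above (the pooling strategy here is illustrative, not necessarily the one used by the repository's scripts):

# Mean-pool over the sequence dimension to get one vector per input sequence
pooled = embeddings.mean(dim=1)          # [batch_size, embedding_dim]
print(f"Pooled embedding shape: {pooled.shape}")

# Convert to NumPy if you want to pass the features to scikit-learn / XGBoost
features = pooled.cpu().numpy()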

Zero-shot mutation effect scoring

Estimate the functional impact of genetic variants using PlantCAD's log-likelihood scores.

Input format options:

  1. VCF files (recommended): Standard variant format with reference genome
  2. TSV files: Pre-processed sequences with variant information

Basic usage with VCF:

# Download example reference genome
wget https://download.maizegdb.org/Zm-B73-REFERENCE-NAM-5.0/Zm-B73-REFERENCE-NAM-5.0.fa.gz
gunzip Zm-B73-REFERENCE-NAM-5.0.fa.gz

# Run zero-shot scoring
python src/zero_shot_score.py \
    -input-vcf examples/example_maize_snp.vcf \
    -input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
    -output scored_variants.vcf \
    -model 'kuleshov-group/PlantCaduceus_l32' \
    -device 'cuda:0'

Expected output:

  • Scored VCF file with PlantCAD scores in the INFO field
  • Scores represent log-likelihood ratios between the reference and alternative alleles
  • Lower (more negative) scores indicate more likely deleterious mutations (see the conceptual sketch below)
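
For intuition, here is a conceptual sketch of such a score computed by hand with the Hugging Face masked-LM head. This is an illustration only, not the implementation in src/zero_shot_score.py; it assumes the character-level tokenizer maps each nucleotide to one token with no special tokens, exposes a mask token, and uses upper-case nucleotide tokens (adjust the index and token case for the actual tokenizer).

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_name = 'kuleshov-group/PlantCaduceus_l20'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True).eval()

sequence = "ACGT" * 128                  # placeholder 512bp window centred on the variant
pos = len(sequence) // 2                 # 0-based position of the variant in the window
ref, alt = "G", "A"                      # placeholder alleles

ids = tokenizer(sequence, return_tensors="pt", add_special_tokens=False)["input_ids"]
masked = ids.clone()
masked[0, pos] = tokenizer.mask_token_id          # mask the variant position

with torch.no_grad():
    logits = model(input_ids=masked).logits
log_probs = torch.log_softmax(logits[0, pos], dim=-1)

ref_id = tokenizer.convert_tokens_to_ids(ref)     # check the tokenizer's case convention
alt_id = tokenizer.convert_tokens_to_ids(alt)
score = (log_probs[alt_id] - log_probs[ref_id]).item()   # log P(alt) - log P(ref)
print(f"zero-shot score: {score:.3f}")            # more negative => alt less likely than ref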

Convert VCF to table format (optional, for easier processing):

bash src/format_VCF.sh \
    examples/example_maize_snp.vcf \
    Zm-B73-REFERENCE-NAM-5.0.fa \
    formatted_variants.tsv

Use table format directly:

python src/zero_shot_score.py \
    -input-table formatted_variants.tsv \
    -output results.tsv \
    -model 'kuleshov-group/PlantCaduceus_l32' \
    -device 'cuda:0' \
    -outBED  # Optional: output in BED format

In-silico mutagenesis pipeline

For large-scale simulation and analysis of genetic variants, we provide a comprehensive in-silico mutagenesis pipeline. See pipelines/in-silico-mutagenesis/README.md for detailed instructions.

Advanced Usage

Training XGBoost classifiers

Train custom classifiers on top of PlantCAD embeddings for specific annotation tasks (e.g., TIS, TTS, splice sites).

Purpose: Fine-tune prediction performance for specific annotation tasks using supervised learning.

Data format: Training data should follow the format used in our cross-species annotation dataset.

python src/train_XGBoost.py \
    -train train.tsv \
    -valid valid.tsv \
    -test test_rice.tsv \
    -model 'kuleshov-group/PlantCaduceus_l20' \
    -output ./output \
    -device 'cuda:0'

Expected outputs:

  • Trained XGBoost classifier (.json file)
  • Performance metrics on validation/test sets
  • Feature importance analysis
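
For orientation, a stripped-down sketch of the general idea: a gradient-boosted classifier fit on fixed PlantCAD embeddings. The repository's train_XGBoost.py handles embedding extraction, hyperparameters, and evaluation for you, so the placeholder data and settings below are purely illustrative.

import numpy as np
import xgboost as xgb

# Placeholder features/labels: one PlantCAD embedding vector per training sequence
# (e.g. 384-dim for PlantCaduceus_l20) and a binary label (e.g. TIS vs. background)
X_train, y_train = np.random.rand(1000, 384), np.random.randint(0, 2, 1000)
X_valid, y_valid = np.random.rand(200, 384), np.random.randint(0, 2, 200)

clf = xgb.XGBClassifier(n_estimators=200, max_depth=6, learning_rate=0.1,
                        eval_metric="logloss")
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], verbose=False)
clf.save_model("TIS_XGBoost_demo.json")   # XGBoost's JSON format, like the shipped .json classifiers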

Using pre-trained XGBoost classifiers

We provide pre-trained XGBoost classifiers for common annotation tasks in the classifiers directory.

Available classifiers:

  • TIS (Translation Initiation Sites)
  • TTS (Translation Termination Sites)
  • Splice donor/acceptor sites

Usage:

python src/predict_XGBoost.py \
    -test test_rice.tsv \
    -model 'kuleshov-group/PlantCaduceus_l20' \
    -classifier classifiers/PlantCaduceus_l20/TIS_XGBoost.json \
    -device 'cuda:0' \
    -output ./output

Expected output: Predictions with confidence scores for each sequence in your test data.
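
If you only need predictions programmatically, a hypothetical sketch of loading one of the shipped classifiers with the xgboost Python API and scoring pre-computed PlantCAD embeddings might look like the following. The exact feature layout expected by the classifiers is defined by the repository's scripts, so treat this as an illustration rather than a drop-in replacement for predict_XGBoost.py.

import numpy as np
import xgboost as xgb

# Load one of the shipped classifiers (path from the command above)
clf = xgb.XGBClassifier()
clf.load_model("classifiers/PlantCaduceus_l20/TIS_XGBoost.json")

# Placeholder PlantCAD embeddings (PlantCaduceus_l20 embeddings are 384-dimensional)
X_new = np.random.rand(5, 384)
print(clf.predict_proba(X_new)[:, 1])    # confidence that each sequence is a TIS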

Development and Training

Pre-training PlantCAD

For advanced users who want to pre-train PlantCAD models from scratch or fine-tune on custom datasets.

Requirements:

  • Large computational resources (multi-GPU recommended)
  • WandB account for experiment tracking
  • Custom genomic dataset in HuggingFace format

Basic pre-training command:

WANDB_PROJECT=PlantCAD python src/HF_pre_train.py \
    --do_train \
    --report_to wandb \
    --prediction_loss_only True \
    --remove_unused_columns False \
    --dataset_name 'kuleshov-group/Angiosperm_16_genomes' \
    --soft_masked_loss_weight_train 0.1 \
    --soft_masked_loss_weight_evaluation 0.0 \
    --weight_decay 0.01 \
    --optim adamw_torch \
    --dataloader_num_workers 16 \
    --preprocessing_num_workers 16 \
    --seed 32 \
    --save_strategy steps \
    --save_steps 1000 \
    --evaluation_strategy steps \
    --eval_steps 1000 \
    --logging_steps 10 \
    --max_steps 120000 \
    --warmup_steps 1000 \
    --save_total_limit 20 \
    --learning_rate 2E-4 \
    --lr_scheduler_type constant_with_warmup \
    --run_name test \
    --overwrite_output_dir \
    --output_dir "PlantCaduceus_train_1" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 4 \
    --tokenizer_name 'kuleshov-group/PlantCaduceus_l20' \
    --config_name 'kuleshov-group/PlantCaduceus_l20'

Key parameters:

  • dataset_name: Your custom dataset or use our Angiosperm dataset
  • max_steps: Total training steps (adjust based on dataset size)
  • learning_rate: 2E-4 works well for most cases
  • Batch sizes: Adjust based on your GPU memory (see the effective batch size note below)
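
The settings above imply the following effective batch size per optimizer step; a quick sanity check (num_gpus depends on how you launch the script):

# Effective batch size for the command above
per_device_train_batch_size = 32
gradient_accumulation_steps = 4
num_gpus = 1                       # scale this if launching on multiple GPUs
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)        # 128 sequences per optimizer step on a single GPU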

Performance benchmarks

Inference speed depends heavily on model size and GPU type. The times below were measured on 5,000 SNPs:

Model | H100 | A100 | A6000 | 3090 | A5000 | A40 | 2080
PlantCaduceus_l20 | 16s | 19s | 24s | 25s | 25s | 26s | 44s
PlantCaduceus_l24 | 21s | 27s | 35s | 37s | 42s | 38s | 71s
PlantCaduceus_l28 | 31s | 43s | 62s | 69s | 77s | 67s | 137s
PlantCaduceus_l32 | 47s | 66s | 94s | 116s | 130s | 107s | 232s

Citation

Zhai, J., Gokaslan, A., Schiff, Y., Berthel, A., Liu, Z. Y., Lai, W. L., Miller, Z. R., Scheben, A., Stitzer, M. C., Romay, M. C., Buckler, E. S., & Kuleshov, V. (2025). Cross-species modeling of plant genomes at single nucleotide resolution using a pretrained DNA language model. Proceedings of the National Academy of Sciences, 122(24), e2421738122. https://doi.org/10.1073/pnas.2421738122
