- PlantCAD overview
- Quick Start
- Model summary
- Prerequisites and system requirements
- Installation
- Basic Usage
- Advanced Usage
- Development and Training
- Citation
PlantCaduceus (PlantCAD for short) is a plant DNA language model based on the Caduceus architecture, which extends the efficient, linear-time Mamba sequence modeling framework with bi-directionality and reverse complement equivariance, two properties designed specifically for DNA sequences. PlantCAD is pre-trained on a curated dataset of 16 angiosperm genomes. It showed state-of-the-art cross-species performance in predicting translation initiation sites (TIS), translation termination sites (TTS), splice donor sites, and splice acceptor sites. PlantCAD's zero-shot scoring also enables genome-wide identification of deleterious mutations and known causal variants in Arabidopsis, sorghum, and maize.
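To make "reverse complement equivariance" concrete, here is a toy sketch in plain Python. It does not use PlantCAD and is not how the model implements the property; the helper names `reverse_complement` and `gc_track` are hypothetical. It only illustrates the idea that a per-base output computed on the reverse complement strand should match the original per-base output read in reverse.

```python
# Toy illustration only: a per-base function whose output on the reverse
# complement equals its output on the original strand, reversed.
COMPLEMENT = str.maketrans("ACGTacgt", "TGCAtgca")


def reverse_complement(seq: str) -> str:
    """Complement every base, then reverse the sequence."""
    return seq.translate(COMPLEMENT)[::-1]


def gc_track(seq: str) -> list:
    """Hypothetical per-base 'model': 1 for G or C, 0 otherwise."""
    return [1 if base in "GCgc" else 0 for base in seq]


seq = "ATGCGATC"
rc = reverse_complement(seq)

# Equivariance check: f(RC(x)) equals f(x) read in the opposite direction.
is_equivariant = gc_track(rc) == gc_track(seq)[::-1]
print(rc, is_equivariant)
```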
New to PlantCAD? Try our Google Colab demo - no installation required!
For local usage: See the installation instructions below, then open notebooks/examples.ipynb to get started.
Pre-trained PlantCAD models are available on HuggingFace 🤗. The table below summarizes the four PlantCAD models, which differ in parameter count.
Model | Sequence Length | Model Size | Embedding Size |
---|---|---|---|
PlantCaduceus_l20 | 512bp | 20M | 384 |
PlantCaduceus_l24 | 512bp | 40M | 512 |
PlantCaduceus_l28 | 512bp | 128M | 768 |
PlantCaduceus_l32 | 512bp | 225M | 1024 |
Model Selection Guide:
- PlantCaduceus_l20: Good for testing and quick analysis
- PlantCaduceus_l32: Recommended for research and production (best performance)
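If you select a model programmatically, the table above can be encoded as a small lookup. This is a minimal sketch: `PLANTCAD_MODELS` and `pick_model` are hypothetical names, and the only facts used are the parameter counts and embedding sizes from the table and the `kuleshov-group/` prefix used elsewhere in this README.

```python
# Parameter counts (in millions) and embedding sizes copied from the table above.
PLANTCAD_MODELS = {
    "PlantCaduceus_l20": {"params_m": 20, "embedding_size": 384},
    "PlantCaduceus_l24": {"params_m": 40, "embedding_size": 512},
    "PlantCaduceus_l28": {"params_m": 128, "embedding_size": 768},
    "PlantCaduceus_l32": {"params_m": 225, "embedding_size": 1024},
}


def pick_model(max_params_m: int) -> str:
    """Return the HuggingFace ID of the largest model within a parameter budget.

    Hypothetical helper; the budget is expressed in millions of parameters.
    """
    candidates = [
        (spec["params_m"], name)
        for name, spec in PLANTCAD_MODELS.items()
        if spec["params_m"] <= max_params_m
    ]
    if not candidates:
        raise ValueError(f"No PlantCAD model has <= {max_params_m}M parameters")
    return "kuleshov-group/" + max(candidates)[1]


print(pick_model(50))    # quick tests on modest hardware
print(pick_model(1000))  # no practical budget: largest model
```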
For Google Colab: Just a Google account - GPU runtime recommended (free tier available)
For Local Installation: GPU recommended for reasonable performance. Dependencies will be installed automatically during setup.
No installation required! Just open our PlantCAD Google Colab notebook and start analyzing your data.
Setup steps:
- Open the Colab link
- Important: Set runtime to GPU (Runtime → Change runtime type → Hardware accelerator: GPU)
- Run the cells to install dependencies
- Upload your data or use the provided examples
Step 1: Create conda environment
# Clone the repository (if you haven't already)
git clone https://github.com/kuleshov-group/PlantCaduceus.git
cd PlantCaduceus
# Create and activate environment
conda env create -f env/environment.yml
conda activate PlantCAD
Step 2: Install Python packages
pip install -r env/requirements.txt --no-build-isolation
Step 3: Verify installation
# Test core dependencies
import torch
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModel
# Test PlantCAD model loading
tokenizer = AutoTokenizer.from_pretrained('kuleshov-group/PlantCaduceus_l20')
print("✅ Installation successful!")
Alternative: pip-only installation. If you prefer a pip-only installation, see issue #10 for community solutions.
mamba_ssm issues (most common):
# If mamba_ssm import fails, reinstall with:
pip uninstall mamba-ssm
pip install mamba-ssm==2.2.0 --no-build-isolation
CUDA/GPU issues:
- Verify CUDA installation:
nvidia-smi
- Check PyTorch CUDA support:
python -c "import torch; print(torch.cuda.is_available())"
- For CPU-only usage: Models will work but be significantly slower
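The two checks above can be folded into a small helper that picks a device string for the `-device` option used by the scripts later in this README. This is a sketch, not part of the repository: `pick_device` is a hypothetical name, and it falls back to `"cpu"` both when PyTorch is missing and when no CUDA device is visible.

```python
def pick_device() -> str:
    """Return 'cuda:0' when PyTorch can see a GPU, otherwise 'cpu'."""
    try:
        import torch  # imported lazily so the helper also runs without PyTorch
    except ImportError:
        return "cpu"
    return "cuda:0" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```

On CPU the models still run, only more slowly, so starting with PlantCaduceus_l20 is a reasonable choice in that case.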
The easiest way to start is with our example notebook: notebooks/examples.ipynb
Quick example - Get sequence embeddings:
from transformers import AutoTokenizer, AutoModel
import torch
# Load model and tokenizer
model_name = 'kuleshov-group/PlantCaduceus_l20' # Start with smaller model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Example plant DNA sequence (512bp max)
sequence = "ATGCGATCGATCGATC..." # Your sequence here
# Get embeddings
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state
print(f"Sequence length: {len(sequence)}")
print(f"Embedding shape: {embeddings.shape}") # [batch_size, seq_len, embedding_dim]
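Because the models take at most 512 bp per input, longer sequences need to be split before embedding. The following is a minimal sketch of one way to do that in plain Python; `split_into_windows` is a hypothetical helper, and it assumes one token per base and ignores any special tokens the tokenizer may add.

```python
def split_into_windows(sequence: str, window: int = 512, step: int = 512):
    """Yield (start, subsequence) pairs that cover the whole sequence.

    Use step < window for overlapping windows. The last window may be shorter.
    """
    if window <= 0 or step <= 0:
        raise ValueError("window and step must be positive")
    seq = sequence.upper()
    for start in range(0, len(seq), step):
        yield start, seq[start:start + window]
        if start + window >= len(seq):
            break  # the sequence end has been covered


long_sequence = "ACGT" * 300  # 1,200 bp placeholder sequence
windows = list(split_into_windows(long_sequence))
for start, chunk in windows:
    print(start, len(chunk))
```

Each chunk can then be passed to the tokenizer and model as in the example above; how to combine per-window embeddings (for example, averaging overlaps) is left to the application.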
Estimate the functional impact of genetic variants using PlantCAD's log-likelihood scores.
Input format options:
- VCF files (recommended): Standard variant format with reference genome
- TSV files: Pre-processed sequences with variant information
Basic usage with VCF:
# Download example reference genome
wget https://download.maizegdb.org/Zm-B73-REFERENCE-NAM-5.0/Zm-B73-REFERENCE-NAM-5.0.fa.gz
gunzip Zm-B73-REFERENCE-NAM-5.0.fa.gz
# Run zero-shot scoring
python src/zero_shot_score.py \
-input-vcf examples/example_maize_snp.vcf \
-input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
-output scored_variants.vcf \
-model 'kuleshov-group/PlantCaduceus_l32' \
-device 'cuda:0'
Expected output:
- Scored VCF file with PlantCAD scores in the INFO field
- Scores represent log-likelihood ratios between reference and alternative alleles
- More negative scores indicate mutations that are more likely to be deleterious
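As a worked example of how to read these scores, the sketch below computes a log-likelihood ratio from two hypothetical probabilities. It assumes the score is log p(alt) - log p(ref), which is consistent with negative scores flagging deleterious variants but is an assumption about the sign convention, not a statement of the script's exact implementation. The probabilities are made up for illustration.

```python
import math


def llr_score(p_ref: float, p_alt: float) -> float:
    """Log-likelihood ratio of the alternative allele against the reference.

    Assumed convention: log p(alt) - log p(ref), so a variant the model
    finds much less likely than the reference gets a strongly negative score.
    """
    return math.log(p_alt) - math.log(p_ref)


# Hypothetical probabilities at a strongly constrained site:
constrained = llr_score(p_ref=0.90, p_alt=0.01)
# Hypothetical probabilities at a weakly constrained site:
tolerant = llr_score(p_ref=0.40, p_alt=0.35)

print(round(constrained, 2), round(tolerant, 2))
```

Under this convention the first variant scores about -4.5 and the second about -0.13, so the first would rank as the more likely deleterious mutation.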
Convert VCF to table format (optional, for easier processing):
bash src/format_VCF.sh \
examples/example_maize_snp.vcf \
Zm-B73-REFERENCE-NAM-5.0.fa \
formatted_variants.tsv
Use table format directly:
python src/zero_shot_score.py \
-input-table formatted_variants.tsv \
-output results.tsv \
-model 'kuleshov-group/PlantCaduceus_l32' \
-device 'cuda:0' \
-outBED # Optional: output in BED format
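When scoring many files, it can help to build the command line in Python rather than by hand. The sketch below assembles the argument list for `src/zero_shot_score.py` using only the options shown above; `zero_shot_command` is a hypothetical helper, and `-outBED` is restricted to table input because that is the only combination demonstrated in this README.

```python
import shlex
from typing import List, Optional


def zero_shot_command(
    output: str,
    *,
    vcf: Optional[str] = None,
    fasta: Optional[str] = None,
    table: Optional[str] = None,
    model: str = "kuleshov-group/PlantCaduceus_l32",
    device: str = "cuda:0",
    out_bed: bool = False,
) -> List[str]:
    """Build the argument list for src/zero_shot_score.py (hypothetical helper)."""
    cmd = ["python", "src/zero_shot_score.py"]
    if table is not None:
        cmd += ["-input-table", table]
    elif vcf is not None and fasta is not None:
        cmd += ["-input-vcf", vcf, "-input-fasta", fasta]
    else:
        raise ValueError("Provide either table, or both vcf and fasta")
    cmd += ["-output", output, "-model", model, "-device", device]
    if out_bed:
        if table is None:
            raise ValueError("-outBED is only shown with table input in this README")
        cmd.append("-outBED")
    return cmd


cmd = zero_shot_command("results.tsv", table="formatted_variants.tsv", out_bed=True)
print(shlex.join(cmd))
# To execute from the repository root:
# import subprocess; subprocess.run(cmd, check=True)
```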
For large-scale simulation and analysis of genetic variants, we provide a comprehensive in-silico mutagenesis pipeline. See pipelines/in-silico-mutagenesis/README.md for detailed instructions.
Train custom classifiers on top of PlantCAD embeddings for specific annotation tasks (e.g., TIS, TTS, splice sites).
Purpose: Fine-tune prediction performance for specific annotation tasks using supervised learning.
Data format: Training data should follow the format used in our cross-species annotation dataset.
python src/train_XGBoost.py \
-train train.tsv \
-valid valid.tsv \
-test test_rice.tsv \
-model 'kuleshov-group/PlantCaduceus_l20' \
-output ./output \
-device 'cuda:0'
Expected outputs:
- Trained XGBoost classifier (.json file)
- Performance metrics on validation/test sets
- Feature importance analysis
We provide pre-trained XGBoost classifiers for common annotation tasks in the classifiers directory.
Available classifiers:
- TIS (Translation Initiation Sites)
- TTS (Translation Termination Sites)
- Splice donor/acceptor sites
python src/predict_XGBoost.py \
-test test_rice.tsv \
-model 'kuleshov-group/PlantCaduceus_l20' \
-classifier classifiers/PlantCaduceus_l20/TIS_XGBoost.json \
-device 'cuda:0' \
-output ./output
Expected output: Predictions with confidence scores for each sequence in your test data.
For advanced users who want to pre-train PlantCAD models from scratch or fine-tune on custom datasets.
Requirements:
- Large computational resources (multi-GPU recommended)
- WandB account for experiment tracking
- Custom genomic dataset in HuggingFace format
Basic pre-training command:
WANDB_PROJECT=PlantCAD python src/HF_pre_train.py \
--do_train \
--report_to wandb \
--prediction_loss_only True \
--remove_unused_columns False \
--dataset_name 'kuleshov-group/Angiosperm_16_genomes' \
--soft_masked_loss_weight_train 0.1 \
--soft_masked_loss_weight_evaluation 0.0 \
--weight_decay 0.01 \
--optim adamw_torch \
--dataloader_num_workers 16 \
--preprocessing_num_workers 16 \
--seed 32 \
--save_strategy steps \
--save_steps 1000 \
--evaluation_strategy steps \
--eval_steps 1000 \
--logging_steps 10 \
--max_steps 120000 \
--warmup_steps 1000 \
--save_total_limit 20 \
--learning_rate 2E-4 \
--lr_scheduler_type constant_with_warmup \
--run_name test \
--overwrite_output_dir \
--output_dir "PlantCaduceus_train_1" \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--gradient_accumulation_steps 4 \
--tokenizer_name 'kuleshov-group/PlantCaduceus_l20' \
--config_name 'kuleshov-group/PlantCaduceus_l20'
Key parameters:
- dataset_name: Your custom dataset, or use our Angiosperm dataset
- max_steps: Total training steps (adjust based on dataset size)
- learning_rate: 2E-4 works well for most cases
- Batch sizes: Adjust based on your GPU memory
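When adjusting batch sizes, it is useful to keep the effective batch size in view. The arithmetic below is a sketch using the values from the command above; the GPU count of 8 is a hypothetical example, and the 512 bp sequence length is taken from the model summary table and assumed to match the training windows.

```python
# Values from the pre-training command above.
per_device_train_batch_size = 32
gradient_accumulation_steps = 4
max_steps = 120_000

# Assumptions for illustration (not stated in the command).
num_gpus = 8            # hypothetical multi-GPU node
sequence_length = 512   # bp per training example, from the model table


def effective_batch_size(per_device: int, accumulation: int, gpus: int) -> int:
    """Sequences contributing to each optimizer step across all devices."""
    return per_device * accumulation * gpus


effective = effective_batch_size(
    per_device_train_batch_size, gradient_accumulation_steps, num_gpus
)
sequences_seen = effective * max_steps
bases_seen = sequences_seen * sequence_length

print(f"Effective batch size: {effective}")
print(f"Sequences seen: {sequences_seen:,}")
print(f"Bases seen: {bases_seen:,}")
```

If you halve the per-device batch size to fit GPU memory, doubling `gradient_accumulation_steps` keeps the effective batch size unchanged.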
Inference speed depends strongly on model size and GPU type. Inference time for 5,000 SNPs:
Model | H100 | A100 | A6000 | 3090 | A5000 | A40 | 2080 |
---|---|---|---|---|---|---|---|
PlantCaduceus_l20 | 16s | 19s | 24s | 25s | 25s | 26s | 44s |
PlantCaduceus_l24 | 21s | 27s | 35s | 37s | 42s | 38s | 71s |
PlantCaduceus_l28 | 31s | 43s | 62s | 69s | 77s | 67s | 137s |
PlantCaduceus_l32 | 47s | 66s | 94s | 116s | 130s | 107s | 232s |
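The table can be used for rough runtime planning. The sketch below extrapolates from the 5,000-SNP benchmark under the assumption that runtime scales linearly with the number of SNPs, which ignores model loading and I/O overhead; `estimate_seconds` is a hypothetical helper and only a subset of the table is encoded.

```python
# Seconds to process 5,000 SNPs, copied from the table above (subset of GPUs).
BENCHMARK_SNPS = 5_000
BENCHMARK_SECONDS = {
    "PlantCaduceus_l20": {"H100": 16, "A100": 19, "3090": 25, "2080": 44},
    "PlantCaduceus_l24": {"H100": 21, "A100": 27, "3090": 37, "2080": 71},
    "PlantCaduceus_l28": {"H100": 31, "A100": 43, "3090": 69, "2080": 137},
    "PlantCaduceus_l32": {"H100": 47, "A100": 66, "3090": 116, "2080": 232},
}


def estimate_seconds(model: str, gpu: str, n_snps: int) -> float:
    """Linear extrapolation from the benchmark (assumes constant throughput)."""
    return BENCHMARK_SECONDS[model][gpu] * n_snps / BENCHMARK_SNPS


seconds = estimate_seconds("PlantCaduceus_l32", "H100", 1_000_000)
print(f"Estimated time: {seconds:.0f} s ({seconds / 3600:.1f} h)")
```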
Zhai, J., Gokaslan, A., Schiff, Y., Berthel, A., Liu, Z. Y., Lai, W. L., Miller, Z. R., Scheben, A., Stitzer, M. C., Romay, M. C., Buckler, E. S., & Kuleshov, V. (2025). Cross-species modeling of plant genomes at single nucleotide resolution using a pretrained DNA language model. Proceedings of the National Academy of Sciences, 122(24), e2421738122. https://doi.org/10.1073/pnas.2421738122