A neural network framework for accelerating Density Functional Theory (DFT) calculations through machine learning. This repository implements state-of-the-art neural network architectures for predicting molecular properties, Hamiltonian matrices, and wavefunctions.
Contents:

- Features
- Architecture
- Installation
- Quick Start
- Training
- Configuration
- Models
- Datasets
- Scripts
- Project Structure
- Citation
- License
## Features

- Multiple Model Architectures: Support for QHNet, Equiformer, ViSNet, and LSR variants
- Hamiltonian Prediction: Direct prediction of molecular Hamiltonian matrices
- Multi-task Learning: Simultaneous prediction of energies, forces, and Hamiltonian matrices
- Symmetry-aware Models: Incorporation of physical symmetries (SO(2), SO(3))
- Distributed Training: Support for multi-GPU training with PyTorch Lightning
- Flexible Configuration: Hydra-based configuration management
- Experiment Tracking: Integration with Weights & Biases (wandb)
## Architecture

The framework supports several neural network architectures:
- QHNet: Quantum Hamiltonian Network for molecular property prediction
- Equiformer: Equivariant transformer architecture with attention mechanisms
- ViSNet: Vector-Scalar Interaction networks
- LSR Models: Large-scale representation models with improved efficiency
Core components:

- Representation Models: Extract molecular features from atomic coordinates
- Output Modules: Task-specific heads for energy, forces, and Hamiltonian prediction
- Hamiltonian Heads: Specialized modules for predicting quantum mechanical Hamiltonians
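Conceptually, these pieces compose into a single model whose backbone output feeds the task heads. A minimal sketch (class and attribute names here are illustrative, not the repository's actual API):

```python
import torch.nn as nn

class DFTModel(nn.Module):
    """Illustrative composition: a backbone feeds task-specific heads."""

    def __init__(self, backbone: nn.Module, energy_head: nn.Module, hami_head: nn.Module):
        super().__init__()
        self.backbone = backbone        # representation model (e.g., a QHNet variant)
        self.energy_head = energy_head  # output module for scalar energies
        self.hami_head = hami_head      # Hamiltonian head

    def forward(self, batch):
        feats = self.backbone(batch)    # invariant/equivariant per-atom features
        return {
            "energy": self.energy_head(feats),
            "hamiltonian": self.hami_head(feats),  # block-structured matrix prediction
        }
```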
## Installation

Prerequisites:

- CUDA driver version ≥ 520 (for CUDA Toolkit 12 support)
- Python 3.10+
- Conda/Mamba package manager
Create a virtual environment with all necessary packages:
```bash
# Clone the repository
git clone https://github.com/your-repo/Wavefunction-Alignment-Net.git
cd Wavefunction-Alignment-Net

# Create conda environment
conda env create -n madft_nn -f environment.yaml

# Activate environment
conda activate madft_nn

# Install the custom CUDA DFT wheel (if needed)
pip install cudft-0.2.6-cp310-cp310-linux_x86_64.whl
```

Key dependencies:

- PyTorch 2.1.0 with CUDA 12.1
- PyTorch Lightning
- e3nn (for equivariant neural networks)
- PyTorch Geometric
- PySCF 2.4.0
- Weights & Biases
- Hydra configuration framework
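After activating the environment, a quick check confirms that the installed PyTorch and CUDA versions match the expected ones:

```python
import torch

print(torch.__version__)          # expect 2.1.0
print(torch.version.cuda)         # expect 12.1
print(torch.cuda.is_available())  # should be True on a working GPU setup
```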
## Quick Start

Train a Hamiltonian prediction model on the QH9 dataset:
```bash
python pipelines/train.py \
    --config-name=config.yaml \
    --wandb=True \
    --wandb-group="test" \
    --dataset-path="/path/to/QH9_new.db" \
    --ngpus=4 \
    --lr=0.0005 \
    --enable-hami=True \
    --hami-weight=1 \
    --batch-size=32 \
    --lr-warmup-steps=1000 \
    --lr-schedule="polynomial" \
    --max-steps=300000 \
    --used-cache=True \
    --train-ratio=0.9 \
    --val-ratio=0.06 \
    --test-ratio=0.04 \
    --gradient-clip-val=5.0 \
    --dataset-size=100000
```

## Training

The training pipeline (pipelines/train.py) supports:
- Multi-GPU distributed training
- Automatic mixed precision
- Gradient clipping
- Learning rate scheduling (cosine, polynomial, reduce-on-plateau)
- Early stopping
- Model checkpointing
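These features correspond to standard PyTorch Lightning Trainer options. A rough sketch of how the Quick Start flags might map onto a Trainer (the pipeline's actual wiring and metric names may differ):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = pl.Trainer(
    devices=4,                    # --ngpus=4
    strategy="ddp",               # multi-GPU distributed training
    precision="16-mixed",         # automatic mixed precision
    gradient_clip_val=5.0,        # --gradient-clip-val=5.0
    max_steps=300_000,            # --max-steps=300000
    callbacks=[
        EarlyStopping(monitor="val_loss"),                  # early stopping
        ModelCheckpoint(monitor="val_loss", save_top_k=3),  # model checkpointing
    ],
)
```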
Key training parameters:
- model_backbone: Choose from QHNet, Equiformerv2, LSR_QHNet, ViSNet
- enable_hami: Enable Hamiltonian prediction
- enable_energy: Enable energy prediction
- enable_forces: Enable force prediction
- hami_weight, energy_weight, forces_weight: Per-task loss weights (see the sketch after this list)
- lr_schedule: Learning rate scheduler type
- max_steps: Maximum training steps
- batch_size: Training batch size
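The per-task weights combine the individual losses into one training objective. A minimal standalone sketch of the weighted sum (the pipeline's actual loss may add normalization or masking):

```python
def total_loss(losses: dict, weights: dict) -> float:
    """Weighted multi-task objective: sum each task's loss times its weight."""
    return sum(weights[task] * losses[task] for task in weights)

# Example using the weights from the multi-task command below:
total_loss(
    losses={"hami": 0.02, "energy": 0.5, "forces": 1.3},
    weights={"hami": 1.0, "energy": 0.1, "forces": 0.01},
)
```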
Enable multiple prediction tasks simultaneously:
```bash
python pipelines/train.py \
    --enable-hami=True \
    --enable-energy=True \
    --enable-forces=True \
    --hami-weight=1.0 \
    --energy-weight=0.1 \
    --forces-weight=0.01
```

## Configuration

The project uses Hydra for configuration management. Configuration files are located in config/:
- config.yaml: Default configuration
- config_equiformer.yaml: Equiformer-specific settings
- config_lsrqhnet.yaml: LSR-QHNet configuration
- config_lsrm.yaml: LSRM model configuration
- model/: Model-specific configurations
- schedule/: Learning rate schedule configurations
Create custom configurations by:
- Creating a new YAML file in config/
- Overriding parameters via command line
- Using Hydra's composition feature
Example custom config:
```yaml
defaults:
  - model: equiformerv2
  - schedule: polynomial
  - _self_

model_backbone: Equiformerv2SO2
batch_size: 64
lr: 1e-3
max_steps: 500000
```
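The entry point presumably follows the standard Hydra pattern, which is what enables both YAML composition and command-line overrides. A minimal sketch (the repository's --flag style in the Quick Start suggests an additional CLI wrapper, omitted here):

```python
import hydra
from omegaconf import DictConfig

# config_path is relative to this file; for pipelines/train.py the
# repo-root config/ directory would be one level up.
@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Fields compose from the defaults list plus any overrides,
    # e.g. `python pipelines/train.py batch_size=64 lr=1e-3` in plain Hydra syntax.
    print(cfg.model_backbone, cfg.batch_size, cfg.lr)

if __name__ == "__main__":
    main()
```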
## Models

Available model backbones:

- QHNet Variants:
  - QHNet_backbone: Basic QHNet architecture
  - QHNet_backbone_No: QHNet variant without certain features
  - QHNetBackBoneSO2: SO(2)-symmetric QHNet
- Equiformer Variants:
  - Equiformerv2: Equivariant transformer v2
  - Equiformerv2SO2: SO(2)-symmetric Equiformer
  - GraphAttentionTransformer: Graph attention-based transformer
- LSR Models:
  - LSR_QHNet_backbone: Large-scale representation QHNet
  - LSR_Equiformerv2SO2: LSR variant of Equiformer
  - LSRM: Standalone LSR model
- Other Models:
  - ViSNet: Vector-scalar interaction network
  - spherical_visnet: Spherical harmonics-based ViSNet

Model components:

- Representation Models: Extract invariant/equivariant features
- Output Modules:
  - EquivariantScalar_viaTP: Tensor product-based scalar output
  - Energy/force prediction heads
- Hamiltonian Heads:
  - HamiHead: Basic Hamiltonian prediction
  - HamiHeadSymmetry: Symmetry-aware Hamiltonian prediction
## Datasets

Supported datasets:

- QH9: Quantum Hamiltonian dataset with ~130k molecules
- Water Clusters: Various water cluster configurations (W3, W6, W15, etc.)
- PubChem: Subset of PubChem molecules
- Malondialdehyde: Conformational dataset
Datasets are stored in SQLite or LMDB format with the following structure:
- Atomic coordinates
- Atomic numbers
- Hamiltonian matrices (if available)
- Energy values
- Force vectors
- Additional quantum mechanical properties
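As an illustration of inspecting a record, assuming the SQLite files follow the ASE database schema (an assumption; the data_prepare utilities below are the authoritative interface, and LMDB files would need a different reader):

```python
from ase.db import connect

db = connect("dataset/QH9_new.db")  # illustrative path; see --dataset-path above
row = next(db.select(limit=1))      # fetch a single record
print(row.numbers)                  # atomic numbers
print(row.positions)                # atomic coordinates
print(row.data)                     # extra arrays, e.g. Hamiltonian / energy / forces
```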
Use the data preparation utilities in src/madftnn/data_prepare/ to:
- Convert raw data to database format
- Generate train/validation/test splits (see the sketch below)
- Apply data augmentation
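For instance, ratio-based splitting with the Quick Start ratios (0.9/0.06/0.04) can be done along these lines (a standalone sketch, not the repository's exact helper):

```python
import numpy as np

def make_splits(n: int, train_ratio: float = 0.9, val_ratio: float = 0.06, seed: int = 42):
    """Shuffle indices and cut them by ratio; the remainder becomes the test set."""
    idx = np.random.default_rng(seed).permutation(n)
    n_train = int(train_ratio * n)
    n_val = int(val_ratio * n)
    return np.split(idx, [n_train, n_train + n_val])

train_idx, val_idx, test_idx = make_splits(100_000)  # matches --dataset-size=100000
```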
## Scripts

Located in scripts/:
- run_qh9.sh: Train on QH9 dataset
- run_water.sh: Train on water clusters
- run_pubchem.sh: Train on PubChem molecules
- run_malondialdehyde.sh: Train on malondialdehyde conformations
- test_script.sh: Testing and validation scripts
- job.sh: HPC job submission script
## Project Structure

```
Wavefunction-Alignment-Net/
├── config/                 # Configuration files
│   ├── model/             # Model-specific configs
│   └── schedule/          # LR schedule configs
├── src/madftnn/           # Main source code
│   ├── models/            # Neural network models
│   ├── training/          # Training utilities
│   ├── dataset/           # Dataset handling
│   ├── utility/           # Helper functions
│   └── analysis/          # Analysis tools
├── pipelines/             # Training and testing pipelines
├── scripts/               # Utility scripts
├── dataset/               # Dataset storage
├── outputs/               # Training outputs
├── wandb/                 # Weights & Biases logs
├── local_files/           # Local configurations (git-ignored)
└── environment.yaml       # Conda environment specification
```
To implement a custom model:
- Create a new model class in src/madftnn/models/
- Inherit from appropriate base class
- Register the model in model.py
- Create corresponding configuration
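A skeletal version of the first two steps (the repository's actual base class and registration mechanism live in src/madftnn/models/; the names below are illustrative):

```python
import torch
import torch.nn as nn

class MyBackbone(nn.Module):  # in practice, inherit the appropriate base class
    """Toy representation model: embeds atomic numbers into per-atom features."""

    def __init__(self, hidden_dim: int = 128, max_z: int = 100):
        super().__init__()
        self.embed = nn.Embedding(max_z, hidden_dim)

    def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
        # atomic_numbers: (num_atoms,) integer tensor of element types
        return self.embed(atomic_numbers)  # (num_atoms, hidden_dim)
```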
For multi-node training:
```bash
python pipelines/train.py \
    --ngpus=8 \
    --num-nodes=4 \
    --strategy=ddp
```

All experiments are tracked using Weights & Biases:
- Set wandb.open=True to enable tracking
- Configure project and group names
- Monitor training progress in real-time
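With PyTorch Lightning, this is typically wired through a WandbLogger; a minimal sketch matching the Quick Start flags (the project name is illustrative):

```python
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="madft_nn", group="test")  # group matches --wandb-group
# trainer = pl.Trainer(logger=logger, ...)
```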
## Citation

If you use this code in your research, please cite:
```bibtex
@inproceedings{li2025enhancing,
  title={Enhancing the Scalability and Applicability of Kohn-Sham Hamiltonians for Molecular Systems},
  author={Yunyang Li and Zaishuo Xia and Lin Huang and Xinran Wei and Samuel Harshe and Han Yang and Erpai Luo and Zun Wang and Jia Zhang and Chang Liu and Bin Shao and Mark Gerstein},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=twEvvkQqPS}
}
```

## License

This project is licensed under the MIT License; see the LICENSE file for details.
Acknowledgments:

- Built on PyTorch and PyTorch Lightning
- Uses e3nn for equivariant neural networks
- Incorporates ideas from QHNet, Equiformer, and ViSNet architectures
For questions and support, please open an issue on GitHub or contact the maintainers.