<a href="https://colab.research.google.com/github/crhysc/jarvis-tools-notebooks/blob/master/cdvae_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inverse Design of Next-Generation Superconductors Using Data-Driven Deep Generative Models

# Tutorial: CDVAE, Crystal Diffusion Variational AutoEncoder



[Reference DOI](https://pubs.acs.org/doi/10.1021/acs.jpclett.3c01260)

Authors: Charles "Rhys" Campbell (crc00042@mix.wvu.edu), Kamal Choudhary (kamal.choudhary@nist.gov)

# (1) INTRODUCTION AND MOTIVATION


# (2) INSTALLATION, CONFIGURATION, AND DEPENDENCIES


# Install Conda

In [None]:
!pip install -q condacolab
import condacolab
# condacolab.install() restarts the Colab kernel once; wait for the restart before running later cells
condacolab.install()
print("Done")

# Install CDVAE

In [None]:
import os
%cd /content
if not os.path.exists('cdvae'):
  !git clone https://github.com/txie-93/cdvae.git
print("Done")

# Switch Colab Runtime to GPU
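From the top menu next to the Colab logo, select **Runtime** -> **Change runtime type** and pick any available GPU. Changing the runtime type starts a fresh session, so rerun the install cells above afterwards.    

If a GPU is assigned, create the GPU-based conda environment.  

If the request fails due to usage limits, create the CPU-based conda environment instead.  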
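
To decide which of the two environments to build, it can help to check from Python whether the runtime actually exposes a GPU. The sketch below assumes that a working `nvidia-smi` on the `PATH` is a good enough signal; the helper name `gpu_available` is illustrative and not part of CDVAE.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if `nvidia-smi` exists and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False  # no NVIDIA driver tools on this runtime
    try:
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    # `nvidia-smi -L` prints one line per detected GPU
    return result.returncode == 0 and "GPU" in result.stdout


# env.yml and env.cpu.yml are the two files used in the cells below
env_file = "env.yml" if gpu_available() else "env.cpu.yml"
print(f"Suggested environment file: {env_file}")
```

If the suggestion is `env.cpu.yml`, skip the GPU cell and run the CPU cell instead.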



# Create **GPU**-based conda environment for CDVAE

#### Creating the **GPU** legacy env takes about 7 minutes


In [None]:
%%time
%cd /content/cdvae
!mamba env create -p /usr/local/envs/cdvae_legacy -f env.yml
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "torchmetrics<0.8" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install mkl=2024.0 --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install "monty==2022.9.9"
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "pymatgen>=2022.0.8,<2023" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install pandas jarvis-tools
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install --upgrade torch_geometric==1.7.0
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install -e .
print("Done")

In [46]:
!conda run -p /usr/local/envs/cdvae_legacy python -c "import sys; print(sys.version)"
# proves that conda is running python 3.8.*

3.8.20 | packaged by conda-forge | (default, Sep 30 2024, 17:52:49) 
[GCC 13.3.0]



# Create **CPU**-based conda environment for CDVAE

#### Creating the **CPU** legacy env takes about 10 minutes


In [None]:
%%time
%cd /content/cdvae
!mamba env create -p /usr/local/envs/cdvae_legacy -f env.cpu.yml
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "torchmetrics<0.8" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install mkl=2024.0 --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install "monty==2022.9.9"
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "pymatgen>=2022.0.8,<2023" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install pandas jarvis-tools
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install --upgrade torch_geometric==1.7.0
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install -e .
print("Done")

In [50]:
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install --upgrade torch_geometric==1.7.0

Collecting torch_geometric==1.7.0
  Downloading torch_geometric-1.7.0.tar.gz (212 kB)
  Preparing metadata (setup.py) ... done
Collecting rdflib (from torch_geometric==1.7.0)
  Downloading rdflib-7.1.4-py3-none-any.whl.metadata (11 kB)
Collecting h5py (from torch_geometric==1.7.0)
  Downloading h5py-3.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting isodate<1.0.0,>=0.7.2 (from rdflib->torch_geometric==1.7.0)
  Downloading isodate-0.7.2-py3-none-any.whl.metadata (11 kB)
Downloading h5py-3.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.3/5.3 MB 68.4 MB/s eta 0:00:00
Downloading rdflib-7.1.4-py3-none-any.whl (565 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 565.1/565.1 kB 18.6 MB/s eta 0:00:00
Downloading isodate-0.7.2-py3-none-any.whl (22 kB)
Building wheels for collected 

# Install Dataset ETL dependencies


In [None]:
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install pandas jarvis-tools

# (3) DATASET ETL (Extract-Transform-Load)


# Download data pre-processor

The training data is generated with this [script](https://github.com/JARVIS-Materials-Design/cdvae/blob/main/scripts/generate_data_cdvae.py), which lives in the JARVIS-Materials-Design `cdvae` repository. It compiles roughly 1,000 structures and their superconducting critical temperatures into the CSV format that CDVAE expects for training.

In [None]:
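# Work from scripts/, where the move step below looks for the generated CSVs
%cd /content/cdvae/scripts
!wget https://raw.githubusercontent.com/JARVIS-Materials-Design/cdvae/refs/heads/main/scripts/generate_data_cdvae.py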

# Run data pre-processor

In [None]:
!conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python generate_data_cdvae.py
print("Done")

# Move train/test/val data to the correct spot

In [None]:
%cd /content
%mkdir -p /content/cdvae/data/supercon
%mv /content/cdvae/scripts/train.csv /content/cdvae/data/supercon/
%mv /content/cdvae/scripts/val.csv /content/cdvae/data/supercon/
%mv /content/cdvae/scripts/test.csv /content/cdvae/data/supercon/
print("Done")
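
Before training, it is worth confirming that all three split files reached `data/supercon`. The following is a minimal sketch that counts data rows per split using only the standard library; the helper name `summarize_splits` is illustrative, and the only assumption about the CSV layout is that the first row is a header.

```python
import csv
from pathlib import Path


def summarize_splits(data_dir, names=("train", "val", "test")):
    """Map each split name to its number of data rows, or None if the CSV is missing."""
    # Fields may hold long multi-line text (an assumption), so lift the csv size cap.
    csv.field_size_limit(2**31 - 1)
    summary = {}
    for name in names:
        path = Path(data_dir) / f"{name}.csv"
        if not path.is_file():
            summary[name] = None
            continue
        with path.open(newline="") as handle:
            reader = csv.reader(handle)
            next(reader, None)  # skip the header row
            summary[name] = sum(1 for _ in reader)
    return summary


print(summarize_splits("/content/cdvae/data/supercon"))
```

A `None` entry means that split was not moved into place and the move cell should be rerun.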

# Pull the supercon Hydra config YAML from JARVIS

In [47]:
%cd /content/cdvae/conf/data/
!wget https://raw.githubusercontent.com/JARVIS-Materials-Design/cdvae/refs/heads/main/conf/data/supercon.yaml



# (4) TRAIN WITHOUT PROPERTY PREDICTOR

# If using **GPU**

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    model.num_noise_level=2 \
    data.train_max_epochs=2

# If using **CPU**

In [45]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    model.num_noise_level=2 \
    data.train_max_epochs=2 \
    train.pl_trainer.gpus=0

[2025-05-27 18:03:45,062][hydra.utils][INFO] - Instantiating <cdvae.pl_data.datamodule.CrystDataModule>
100% 846/846 [01:03<00:00, 13.33it/s]
  X = torch.tensor(X, dtype=torch.float)
[2025-05-27 18:04:51,364][hydra.utils][INFO] - Instantiating <cdvae.pl_modules.model.CDVAE>
[2025-05-27 18:05:07,867][hydra.utils][INFO] - Passing scaler from datamodule to model <StandardScalerTorch(means: 3.714592695236206, stds: 4.966126441955566)>
[2025-05-27 18:05:07,869][hydra.utils][INFO] - Adding callback <LearningRateMonitor>
[2025-05-27 18:05:07,869][hydra.utils][INFO] - Adding callback <EarlyStopping>
[2025-05-27 18:05:07,871][hydra.utils][INFO] - Adding callback <ModelCheckpoint>
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
[2025-05-27 18:05:07,875][hydra.utils][INFO] - Instantiating <WandbLogger>
[2025-05-27 18:05:07,876][hydra.utils][INFO] - W&B is now watching <{cfg.logging.wandb_watch.log}>!
[34m[1mwandb[0m: Currently logged in as: [33manony-moose-9110992

# (5) TRAIN WITH PROPERTY PREDICTOR

# If using **GPU**

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    model.num_noise_level=2 \
    data.train_max_epochs=2 \
    model.predict_property=True

# If using **CPU**

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    data.train_max_epochs=2 \
    model.num_noise_level=2 \
    model.predict_property=True \
    train.pl_trainer.gpus=0
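
The four training cells in sections (4) and (5) differ only in two overrides, so the command can also be assembled programmatically. The sketch below is one possible way to do that; `build_train_command` is a hypothetical helper, not part of CDVAE, and it uses only the environment variables and overrides that appear in the cells above.

```python
import os
import shlex

ROOT = "/content/cdvae"
ENV_PREFIX = "/usr/local/envs/cdvae_legacy"


def build_train_command(use_gpu, predict_property=False, epochs=2, noise_levels=2):
    """Return (env, argv) for one training run, using only overrides shown above."""
    env = dict(
        os.environ,
        PROJECT_ROOT=ROOT,
        HYDRA_JOBS=f"{ROOT}/hydra_outputs",
        HYDRA_FULL_ERROR="1",
        WABDB_DIR=f"{ROOT}/wandb_outputs",  # spelled exactly as in the cells above
        WANDB_ANONYMOUS="allow",
    )
    argv = [
        "conda", "run", "-p", ENV_PREFIX, "--live-stream",
        "python", "-u", "-m", "cdvae.run",
        "data=supercon", "expname=supercon",
        f"model.num_noise_level={noise_levels}",
        f"data.train_max_epochs={epochs}",
    ]
    if not use_gpu:
        argv.append("train.pl_trainer.gpus=0")  # force CPU training
    if predict_property:
        argv.append("model.predict_property=True")
    return env, argv


env, argv = build_train_command(use_gpu=False, predict_property=True)
print(shlex.join(argv))
# To launch: import subprocess; subprocess.run(argv, env=env, check=True)
```

Printing the command first makes it easy to compare against the shell cells before launching anything.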

# (6) INFERENCE

# Reconstruction

In [51]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks recon

100% 105/105 [00:08<00:00, 12.59it/s]
Evaluate model on the reconstruction task.
batch 0 in 5
  X = torch.tensor(X, dtype=torch.float)
100% 2/2 [01:07<00:00, 33.79s/it]
batch 1 in 5
100% 2/2 [00:58<00:00, 29.29s/it]
batch 2 in 5
100% 2/2 [01:19<00:00, 39.77s/it]
batch 3 in 5
100% 2/2 [00:58<00:00, 29.32s/it]
batch 4 in 5
100% 2/2 [00:14<00:00,  7.29s/it]


# Generation

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks gen

# Optimization

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks opt
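
The `--model_path` values above contain the date of the training run (`2025-05-27`), so they will not match a run made on another day. A small sketch for locating the newest run directory follows; it assumes the `singlerun/<date>/<expname>` layout seen in those paths, and `latest_run_dir` is an illustrative helper name.

```python
from pathlib import Path


def latest_run_dir(hydra_jobs, expname="supercon"):
    """Return the newest <hydra_jobs>/singlerun/<YYYY-MM-DD>/<expname> directory, or None."""
    singlerun = Path(hydra_jobs) / "singlerun"
    if not singlerun.is_dir():
        return None
    # ISO-style date folder names sort chronologically as plain strings.
    for date_dir in sorted(singlerun.iterdir(), key=lambda p: p.name, reverse=True):
        candidate = date_dir / expname
        if candidate.is_dir():
            return candidate
    return None


print(latest_run_dir("/content/cdvae/hydra_outputs"))
```

The printed path can be pasted into `--model_path` for the `recon`, `gen`, and `opt` tasks.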

# (7) NEXT STEPS & REFERENCES

## Next Steps

1. **Hyperparameter exploration**  
   - Try different numbers of noise levels (`model.num_noise_level`) and training epochs to improve sample quality.  
   - Use Hydra’s multirun mode (the `--multirun` or `-m` flag) to sweep learning rates, batch sizes, and noise schedules; runs execute in parallel only if a parallel launcher plugin is configured.

2. **Property-conditioned generation**  
   - Re-enable the property predictor (`model.predict_property=True`) and train with longer schedules to learn accurate Tc predictions.  
   - After training, sample structures by specifying a target critical temperature and evaluate via DFT or empirical models.

3. **Dataset expansion**  
   - Incorporate additional superconducting materials from JARVIS or other open databases to grow beyond the ~1,000-structure subset.  
   - Preprocess new data with `generate_data_cdvae.py` (adjust train/val/test splits as needed).

4. **Model evaluation and validation**  
   - Use the `scripts/evaluate.py` tasks `recon`, `gen`, and `opt` to quantify reconstruction error, chemical validity, and property matching.  
   - Compute domain-specific metrics: e.g. composition novelty, structural clustering in latent space, and MAE on Tc.

5. **Downstream integration**  
   - Wrap inference in a simple web demo (Gradio or Streamlit) so collaborators can interactively explore generated candidates.  
   - Integrate into a computational pipeline (e.g. DVC + GitHub Actions) for reproducible end-to-end workflows.

6. **Advanced research directions**  
   - Extend the diffusion VAE approach to other material classes (e.g. thermoelectrics, battery cathodes).  
   - Investigate Lorentzian or graph-based diffusion processes to directly handle crystal symmetry constraints.
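
For the hyperparameter exploration in step 1, a sweep can be planned by expanding a grid of Hydra overrides into one override list per run. The sketch below only generates the `key=value` strings; the grid values are illustrative, not recommendations, and `sweep_overrides` is a hypothetical helper.

```python
from itertools import product


def sweep_overrides(grid):
    """Expand {key: [values]} into one list of `key=value` overrides per combination."""
    keys = list(grid)
    return [
        [f"{key}={value}" for key, value in zip(keys, values)]
        for values in product(*(grid[key] for key in keys))
    ]


grid = {
    "model.num_noise_level": [2, 10, 50],  # illustrative values only
    "data.train_max_epochs": [2, 20],
}
for overrides in sweep_overrides(grid):
    print(" ".join(overrides))
```

Each printed line can be appended to the `python -u -m cdvae.run data=supercon expname=supercon` command used in the training cells.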

---

## References

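- **Reference paper for this tutorial:**  
  Wines, Xie, and Choudhary, “Inverse Design of Next-Generation Superconductors Using Data-Driven Deep Generative Models,” _J. Phys. Chem. Lett._ 2023, DOI: [10.1021/acs.jpclett.3c01260](https://pubs.acs.org/doi/10.1021/acs.jpclett.3c01260)

- **Original CDVAE paper:**  
  Xie _et al._, “Crystal Diffusion Variational Autoencoder for Periodic Material Generation,” ICLR 2022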

- **CDVAE GitHub repo:**  
  https://github.com/txie-93/cdvae

- **JARVIS-Materials-Design:**  
  https://github.com/JARVIS-Materials-Design/jarvis

- **Hydra configuration framework:**  
  https://hydra.cc

- **PyTorch Lightning:**  
  https://www.pytorchlightning.ai

- **condacolab:**  
  https://github.com/conda-incubator/condacolab

- **Mamba (fast conda):**  
  https://github.com/mamba-org/mamba

- **Jarvis-tools (data ETL):**  
  https://github.com/JARVIS-Materials-Design/jarvis-tools
