<a href="https://colab.research.google.com/github/crhysc/jarvis-tools-notebooks/blob/master/cdvae_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inverse Design of Next-Generation Superconductors Using Data-Driven Deep Generative Models

# Tutorial: CDVAE, Crystal Diffusion Variational AutoEncoder



[Reference DOI](https://pubs.acs.org/doi/10.1021/acs.jpclett.3c01260)

Authors: Charles "Rhys" Campbell (crc00042@mix.wvu.edu), Kamal Choudhary (kamal.choudhary@nist.gov)

# (1) INTRODUCTION AND MOTIVATION


# (2) INSTALLATION, CONFIGURATION, AND DEPENDENCIES


# Install Conda

In [None]:
!pip install -q condacolab
import condacolab
# condacolab.install() restarts the kernel once; continue with the next cell after the restart
condacolab.install()
print("Done")

# Install CDVAE

In [None]:
import os
%cd /content
if not os.path.exists('cdvae'):
  !git clone https://github.com/txie-93/cdvae.git
print("Done")

# Switch Colab Runtime to GPU
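From the top menu next to the Colab logo, select **Runtime** -> **Change runtime type** and pick any available GPU. Changing the runtime type starts a fresh session, so re-run the cells above afterwards.

If a GPU is assigned, create the GPU-based conda environment.  

If the request fails because of usage limits, create the CPU-based conda environment instead.  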
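
The choice between the two environments can also be made programmatically. The following is a minimal sketch, assuming that `nvidia-smi` is on the `PATH` only when a GPU runtime is attached; `choose_env_file` is a hypothetical helper, while `env.yml` and `env.cpu.yml` are the file names used by the environment cells in this notebook.

```python
import shutil


def choose_env_file(gpu_available: bool) -> str:
    """Return the CDVAE conda environment file that matches the runtime."""
    return "env.yml" if gpu_available else "env.cpu.yml"


# Assumption: nvidia-smi is discoverable only on a GPU runtime.
gpu_available = shutil.which("nvidia-smi") is not None
print("Use", choose_env_file(gpu_available))
```

The printed file name tells you which of the two environment cells to run.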



# Create **GPU**-based conda environment for CDVAE

#### Creating the **GPU** legacy env takes 7 minutes


In [None]:
%%time
%cd /content/cdvae
!mamba env create -p /usr/local/envs/cdvae_legacy -f env.yml
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "torchmetrics<0.8" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install mkl=2024.0 --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install "monty==2022.9.9"
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "pymatgen>=2022.0.8,<2023" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install pandas jarvis-tools
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install --upgrade torch_geometric==1.7.0
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install -e .
print("Done")

In [None]:
!conda run -p /usr/local/envs/cdvae_legacy python -c "import sys; print(sys.version)"
# Expected output: Python 3.8.x, confirming the command runs inside the conda env

# Create **CPU**-based conda environment for CDVAE

#### Creating the **CPU** legacy env takes 10 minutes


In [None]:
%%time
%cd /content/cdvae
!mamba env create -p /usr/local/envs/cdvae_legacy -f env.cpu.yml
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "torchmetrics<0.8" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install mkl=2024.0 --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install "monty==2022.9.9"
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    mamba install -c conda-forge "pymatgen>=2022.0.8,<2023" --yes
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install pandas jarvis-tools
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install --upgrade torch_geometric==1.7.0
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install -e .
print("Done")

# Install Other dependencies


In [None]:
!pip install torch-geometric
!conda run -p /usr/local/envs/cdvae_legacy --live-stream\
    pip install pandas jarvis-tools

# (3) DATASET ETL (Extract-Transform-Load)


# Download data pre-processor

The training data is generated by this [script](https://github.com/JARVIS-Materials-Design/cdvae/blob/main/scripts/generate_data_cdvae.py), which lives in the JARVIS-Materials-Design repository. It compiles roughly 1,000 structures and their superconducting critical temperatures (Tc) into the CSV format that CDVAE expects for training.

In [None]:
!wget https://raw.githubusercontent.com/JARVIS-Materials-Design/cdvae/refs/heads/main/scripts/generate_data_cdvae.py

# Run data pre-processor

In [None]:
!conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python generate_data_cdvae.py
print("Done")
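
Before moving the files, it can help to confirm that the pre-processor actually wrote all three splits. The sketch below counts the data rows in each CSV using only the standard library. `split_sizes` is a hypothetical helper, and the directory passed to it is an assumption based on where the next cell expects the files to be.

```python
import csv
from pathlib import Path


def split_sizes(data_dir, names=("train", "val", "test")):
    """Count data rows (header excluded) in each split CSV found in data_dir."""
    sizes = {}
    for name in names:
        path = Path(data_dir) / f"{name}.csv"
        if not path.exists():
            sizes[name] = None  # split file has not been written
            continue
        # newline="" lets the csv module handle quoted multi-line fields
        with path.open(newline="") as handle:
            reader = csv.reader(handle)
            next(reader, None)  # skip the header row
            sizes[name] = sum(1 for _ in reader)
    return sizes


# Assumed location; adjust if the script writes its output elsewhere.
print(split_sizes("/content/cdvae/scripts"))
```

A `None` entry means the corresponding file was not found in that directory.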

# Move train/test/val data to the correct spot

In [None]:
%cd /content
# -p keeps this cell re-runnable if the directory already exists
!mkdir -p /content/cdvae/data/supercon
%mv /content/cdvae/scripts/train.csv /content/cdvae/data/supercon/
%mv /content/cdvae/scripts/val.csv /content/cdvae/data/supercon/
%mv /content/cdvae/scripts/test.csv /content/cdvae/data/supercon/
print("Done")
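
The same move can be written in plain Python, which makes it safe to re-run and reports which files were actually found. This is a sketch only; `move_splits` is a hypothetical helper, not part of CDVAE or jarvis-tools.

```python
import shutil
from pathlib import Path


def move_splits(src_dir, dst_dir, names=("train.csv", "val.csv", "test.csv")):
    """Move split files from src_dir to dst_dir and return the names moved."""
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    moved = []
    for name in names:
        if (src / name).exists():
            shutil.move(str(src / name), str(dst / name))
            moved.append(name)
    return moved
```

In this notebook the call would be `move_splits("/content/cdvae/scripts", "/content/cdvae/data/supercon")`; an empty return list means no split files were present in the source directory.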

# Pull the supercon Hydra config YAML from JARVIS

In [None]:
%cd /content/cdvae/conf/data/
!wget https://raw.githubusercontent.com/JARVIS-Materials-Design/cdvae/refs/heads/main/conf/data/supercon.yaml

# (4) TRAIN WITHOUT PROPERTY PREDICTOR

# If using **GPU**

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    model.num_noise_level=2 \
    data.train_max_epochs=2

# If using **CPU**

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    model.num_noise_level=2 \
    data.train_max_epochs=2 \
    train.pl_trainer.gpus=0
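
The GPU and CPU commands differ only in their Hydra overrides, so it can be convenient to keep the overrides in a dictionary and render the command from it. The sketch below does this; `hydra_overrides` and `training_command` are hypothetical helpers, and the keys are the ones used in the cells above.

```python
import shlex


def hydra_overrides(params):
    """Render a flat dict as Hydra command-line overrides (key=value)."""
    return [f"{key}={value}" for key, value in params.items()]


def training_command(params, module="cdvae.run"):
    """Build the python invocation used in the training cells."""
    args = ["python", "-u", "-m", module] + hydra_overrides(params)
    return " ".join(shlex.quote(arg) for arg in args)


cpu_params = {
    "data": "supercon",
    "expname": "supercon",
    "model.num_noise_level": 2,
    "data.train_max_epochs": 2,
    "train.pl_trainer.gpus": 0,  # drop this entry on a GPU runtime
}
print(training_command(cpu_params))
```

The environment variables and the `conda run` prefix from the cells above still need to be placed in front of the rendered command.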

# (5) TRAIN WITH PROPERTY PREDICTOR

# If using **GPU**

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    model.num_noise_level=2 \
    data.train_max_epochs=2 \
    model.predict_property=True

# If using **CPU**

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 HYDRA_FULL_ERROR=1 \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 WANDB_ANONYMOUS=allow \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python -u -m cdvae.run data=supercon expname=supercon \
    data.train_max_epochs=2 \
    model.num_noise_level=2 \
    model.predict_property=True \
    train.pl_trainer.gpus=0

# (6) INFERENCE

# Reconstruction

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks recon
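
The `--model_path` above contains the date on which training was run, so it will differ for each reader. A minimal sketch for locating the most recent run directory follows. It assumes the `singlerun/<YYYY-MM-DD>/<expname>` layout visible in the path above, and `latest_run_dir` is a hypothetical helper.

```python
from pathlib import Path


def latest_run_dir(hydra_jobs, expname="supercon"):
    """Return the newest <hydra_jobs>/singlerun/<date>/<expname> directory, or None."""
    root = Path(hydra_jobs) / "singlerun"
    if not root.is_dir():
        return None
    # ISO-formatted dates sort chronologically as plain strings.
    for date_dir in sorted(root.iterdir(), reverse=True):
        candidate = date_dir / expname
        if candidate.is_dir():
            return candidate
    return None


print(latest_run_dir("/content/cdvae/hydra_outputs"))
```

The printed path can be substituted for the hard-coded `--model_path` value and for the `path` variables in the loading cells.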

In [72]:
import torch
from pprint import pprint
path = "/content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon/eval_recon.pt"
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)

{ 'all_atom_types_stack': [],
  'all_frac_coords_stack': [],
  'angles': tensor([[[88.4881, 85.0363, 90.2158],
         [91.6090, 84.9120, 90.4209],
         [91.8586, 87.8223, 91.1908],
         [83.4981, 82.6112, 92.8762],
         [80.8399, 82.4497, 77.8793],
         [88.7957, 86.7069, 88.3294],
         [83.0859, 78.7776, 84.8059],
         [85.0294, 84.5192, 89.9592],
         [85.9765, 85.6731, 87.3198],
         [93.1717, 90.7440, 96.3543],
         [93.6946, 85.4302, 94.4052],
         [87.3645, 82.1149, 88.5198],
         [88.9543, 82.4099, 85.8510],
         [90.2550, 84.2351, 90.6315],
         [83.6130, 83.1166, 86.3351],
         [91.4460, 86.0063, 90.3574],
         [80.8413, 81.3563, 80.6149],
         [81.7868, 78.1572, 83.4049],
         [88.5204, 85.6616, 93.6883],
         [83.7781, 80.4717, 81.8280],
         [88.2029, 84.8290, 92.5654],
         [84.5454, 79.5651, 84.8938],
         [79.4835, 80.0969, 83.5648],
         [89.1225, 88.0191, 91.8006],
         [88.78

# Generation

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks gen

In [73]:
import torch
from pprint import pprint
path = "/content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon/eval_gen.pt"
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)

{ 'all_atom_types_stack': [],
  'all_frac_coords_stack': [],
  'angles': tensor([[[88.4881, 85.0363, 90.2158],
         [91.6090, 84.9120, 90.4209],
         [91.8586, 87.8223, 91.1908],
         [83.4981, 82.6112, 92.8762],
         [80.8399, 82.4497, 77.8793],
         [88.7957, 86.7069, 88.3294],
         [83.0859, 78.7776, 84.8059],
         [85.0294, 84.5192, 89.9592],
         [85.9765, 85.6731, 87.3198],
         [93.1717, 90.7440, 96.3543],
         [93.6946, 85.4302, 94.4052],
         [87.3645, 82.1149, 88.5198],
         [88.9543, 82.4099, 85.8510],
         [90.2550, 84.2351, 90.6315],
         [83.6130, 83.1166, 86.3351],
         [91.4460, 86.0063, 90.3574],
         [80.8413, 81.3563, 80.6149],
         [81.7868, 78.1572, 83.4049],
         [88.5204, 85.6616, 93.6883],
         [83.7781, 80.4717, 81.8280],
         [88.2029, 84.8290, 92.5654],
         [84.5454, 79.5651, 84.8938],
         [79.4835, 80.0969, 83.5648],
         [89.1225, 88.0191, 91.8006],
         [88.78

# Optimization

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks opt

In [None]:
import torch
from pprint import pprint
path = "/content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon/eval_opt.pt"
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)
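
Printing the full dictionary produces very long output because every tensor is shown element by element. A shorter overview lists only the type and shape of each entry. The sketch below assumes `data` is the dictionary returned by `torch.load` in the cell above; `summarize` is a hypothetical helper that works on any object exposing a `shape` attribute.

```python
def summarize(data):
    """Map each key to a short description instead of printing full tensors."""
    summary = {}
    for key, value in data.items():
        shape = getattr(value, "shape", None)
        if shape is not None:
            # Tensors and arrays: report the type name and shape only
            summary[key] = f"{type(value).__name__}{tuple(shape)}"
        elif isinstance(value, (list, tuple, dict)):
            summary[key] = f"{type(value).__name__}(len={len(value)})"
        else:
            summary[key] = repr(value)
    return summary


# In the notebook, after loading `data`:
# for key, description in summarize(data).items():
#     print(f"{key}: {description}")
```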

# (7) NEXT STEPS & REFERENCES

## Next Steps

1. **Hyperparameter exploration**  
   - Try different numbers of noise levels (`model.num_noise_level`) and training epochs to improve sample quality.  
   - Use Hydra’s multirun mode (the `--multirun` / `-m` flag) to sweep learning rates, batch sizes, and noise schedules in a single launch.

2. **Property-conditioned generation**  
   - Enable the property predictor (`model.predict_property=True`) and train with longer schedules to learn accurate Tc predictions.  
   - After training, sample structures by specifying a target critical temperature and evaluate via DFT or empirical models.

3. **Dataset expansion**  
   - Incorporate additional superconducting materials from JARVIS or other open databases to grow beyond the ~1,000-structure subset.  
   - Preprocess new data with `generate_data_cdvae.py` (adjust train/val/test splits as needed).

4. **Model evaluation and validation**  
   - Use the `scripts/evaluate.py` tasks `recon`, `gen`, and `opt` to quantify reconstruction error, chemical validity, and property matching.  
   - Compute domain-specific metrics: e.g. composition novelty, structural clustering in latent space, and MAE on Tc.

5. **Downstream integration**  
   - Wrap inference in a simple web demo (Gradio or Streamlit) so collaborators can interactively explore generated candidates.  
   - Integrate into a computational pipeline (e.g. DVC + GitHub Actions) for reproducible end-to-end workflows.

6. **Advanced research directions**  
   - Extend the diffusion VAE approach to other material classes (e.g. thermoelectrics, battery cathodes).  
   - Investigate Lorentzian or graph-based diffusion processes to directly handle crystal symmetry constraints.
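
Two of the metrics named in step 4 are simple enough to sketch directly. The code below computes the mean absolute error on Tc and a basic composition novelty score, defined here as the fraction of generated formulas that do not appear in the training set. Both function names are hypothetical, the formulas in the example are illustrative inputs rather than model output, and the novelty check assumes formulas have already been normalized to one consistent string form.

```python
def mean_absolute_error(predicted, target):
    """Mean absolute error between two equal-length sequences of numbers."""
    if len(predicted) != len(target) or len(predicted) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(p - t) for p, t in zip(predicted, target)) / len(predicted)


def composition_novelty(generated, training):
    """Fraction of generated formulas that are absent from the training set."""
    if not generated:
        return 0.0
    known = set(training)
    return sum(1 for formula in generated if formula not in known) / len(generated)


# Illustrative inputs only
print(mean_absolute_error([10.0, 20.0], [12.0, 17.0]))
print(composition_novelty(["MgB2", "NbN", "LaH10", "NbN"], ["MgB2", "Nb3Sn"]))
```

Duplicates in the generated list are counted individually here; deduplicate first if the unique-composition rate is the quantity of interest.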

---

## References

- **Reference paper for this tutorial:**  
  Wines _et al._, “Inverse Design of Next-Generation Superconductors Using Data-Driven Deep Generative Models,” _J. Phys. Chem. Lett._ 2023, DOI: [10.1021/acs.jpclett.3c01260](https://pubs.acs.org/doi/10.1021/acs.jpclett.3c01260)

- **CDVAE GitHub repo:**  
  https://github.com/txie-93/cdvae

- **JARVIS-Materials-Design:**  
  https://github.com/JARVIS-Materials-Design/jarvis

- **Hydra configuration framework:**  
  https://hydra.cc

- **PyTorch Lightning:**  
  https://www.pytorchlightning.ai

- **condacolab:**  
  https://github.com/conda-incubator/condacolab

- **Mamba (fast conda):**  
  https://github.com/mamba-org/mamba

- **Jarvis-tools (data ETL):**  
  https://github.com/JARVIS-Materials-Design/jarvis-tools
