<a href="https://colab.research.google.com/github/crhysc/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/flowmm_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Tutorial**: FlowMM & FlowLLM



**Authors**: Charles "Rhys" Campbell (crc00042@mix.wvu.edu)

# TABLE OF CONTENTS

- Background and Central Goal
- Installation, Configuration, and Dependencies
- Dataset ETL
- Training
  - Manifolds
  - Unconditional Training
  - Conditional Training
- Inference
  - De Novo Generation / Unconditional Evaluation
  - Reconstruction / Conditional Evaluation
- Prerelaxation
- Prepare DFT
- Compute E above hull
- Compute corrected E above hull
- Compute Stable, Unique, and Novel (SUN) structures
- Next Steps & References

# (1) BACKGROUND AND CENTRAL GOAL


# Background
### FlowMM
**FlowMM** uses Riemannian flow matching to learn how to transform simple base noise into full periodic crystal structures by jointly modeling fractional atomic coordinates and lattice parameters on the manifold defined by crystal symmetries. It tackles both **Crystal Structure Prediction** (finding the stable arrangement for a known composition) and **De Novo Generation** (proposing entirely new materials), doing so with about three times fewer integration steps than comparable diffusion-based approaches.  
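To make the geometric idea concrete, here is a minimal sketch of geodesic interpolation for fractional coordinates. It assumes each coordinate is periodic on [0, 1), so the shortest path between two points may wrap around the cell boundary. The function names and sample values are hypothetical, and this is an illustration of the concept rather than FlowMM's own code.

```python
def wrap_displacement(x0: float, x1: float) -> float:
    """Shortest signed displacement from x0 to x1 on the periodic interval [0, 1)."""
    return (x1 - x0 + 0.5) % 1.0 - 0.5


def interpolate_fractional(x0, x1, t):
    """Move a fraction t along the shortest periodic path from x0 to x1."""
    return [(a + t * wrap_displacement(a, b)) % 1.0 for a, b in zip(x0, x1)]


noise = [0.9, 0.2, 0.5]   # hypothetical base-noise sample
target = [0.1, 0.4, 0.5]  # hypothetical fractional coordinates of one atom
print(interpolate_fractional(noise, target, 0.5))
```

Note that the first coordinate travels from 0.9 to 0.1 through the cell boundary instead of backwards through 0.5, which is the behavior a periodic crystal requires.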

### FlowLLM
**FlowLLM** builds on FlowMM by swapping out the simple analytic noise prior for samples from a pretrained CrystalLLM (a LLaMA‐style model fine-tuned on crystal data). You generate initial “noisy” structures with the LLM, then use the same Riemannian flow-matching steps to refine those into accurate crystal geometries.


# Central Goal
Show readers how to install, train, and use FlowMM and FlowLLM.
  


# (2) INSTALLATION, CONFIGURATION, AND DEPENDENCIES


# Install Conda

In [None]:
!pip install -q condacolab
import condacolab, os, sys
condacolab.install()
print("Done")

**Note**: Colab and FlowMM pin different Python and CUDA versions. To work around this, most code in this notebook is run with the `!conda run` command, which starts a subprocess inside a conda environment that has the Python version FlowMM requires, independent of the Python version of the Colab kernel.
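The pattern described in the note can also be driven from Python. The sketch below only builds the argument list; the helper name `conda_run_command` is hypothetical, and the environment prefix is the one used later in this notebook, so the command will only succeed after that environment has been created.

```python
import shlex

ENV_PREFIX = "/usr/local/envs/flowmm_env"  # prefix used later in this notebook


def conda_run_command(args, env_prefix=ENV_PREFIX):
    """Return the argv list that runs `args` inside the conda environment."""
    return ["conda", "run", "-p", env_prefix, "--live-stream", *args]


cmd = conda_run_command(["python", "-c", "import sys; print(sys.version)"])
print(shlex.join(cmd))
# To execute it from a notebook cell once the environment exists:
#   import subprocess; subprocess.run(cmd, check=True)
```

Comparing the version printed by this command with `sys.version` in the Colab kernel is a quick way to confirm that the two interpreters really differ.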

# Install FlowMM

In [None]:
import os
%cd /content
if not os.path.exists('flowmm'):
  !git clone https://github.com/facebookresearch/flowmm.git
print("Done")

# Load FlowMM submodules

In [None]:
%%bash
cd /content/flowmm
sed -i 's|git@github.com:bkmi/DiffCSP-official.git|https://github.com/bkmi/DiffCSP-official.git|' .gitmodules
sed -i 's|git@github.com:bkmi/cdvae.git|https://github.com/bkmi/cdvae.git|' .gitmodules
sed -i 's|git@github.com:facebookresearch/riemannian-fm.git|https://github.com/facebookresearch/riemannian-fm.git|' .gitmodules
git submodule sync
git submodule update --init --recursive
echo "Done"
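The `sed` commands above rewrite SSH submodule URLs to HTTPS, because Colab has no SSH key registered with GitHub. The same transformation can be sketched in Python; the helper name `ssh_to_https` is hypothetical and it only handles the `git@github.com:` form seen in `.gitmodules`.

```python
import re

# Matches SSH-style GitHub URLs such as git@github.com:owner/repo.git
SSH_URL = re.compile(r"^git@github\.com:(?P<path>.+)$")


def ssh_to_https(url: str) -> str:
    """Convert a GitHub SSH URL to HTTPS; return other URLs unchanged."""
    match = SSH_URL.match(url.strip())
    if match is None:
        return url
    return "https://github.com/" + match.group("path")


for url in [
    "git@github.com:bkmi/DiffCSP-official.git",
    "git@github.com:bkmi/cdvae.git",
    "git@github.com:facebookresearch/riemannian-fm.git",
]:
    print(ssh_to_https(url))
```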

# Switch Colab Runtime to GPU
In the top menu next to the Colab logo, select **Runtime** -> **Change runtime type** -> **Any GPU**.

A GPU is not required, but the code will finish faster with one.



# Create conda environment for FlowMM
Creating the conda environment takes about 20 minutes.


In [None]:
%%time
%cd /content/flowmm
!mamba env create -p /usr/local/envs/flowmm_env -f environment.yml
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    pip install uv
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    uv pip install "jarvis-tools>=2024.5" "pymatgen>=2024.1" pandas numpy tqdm
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    uv pip install -e . \
                   -e remote/riemannian-fm \
                   -e remote/cdvae \
                   -e remote/DiffCSP-official
print("Done")

In [None]:
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    pip install -e . \
                   -e remote/riemannian-fm \
                   -e remote/cdvae \
                   -e remote/DiffCSP-official

In [None]:
%cd /content/flowmm/
import os
if not os.path.exists('remote/riemannian-fm/manifm/__init__.py'):
    !wget -q https://raw.githubusercontent.com/crhysc/utilities/refs/heads/main/__init__.py
    !mv __init__.py /content/flowmm/remote/riemannian-fm/manifm/
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    pip install -e /content/flowmm/remote/riemannian-fm/
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    python -c "import manifm; print('manifm version:', manifm.__version__)"

# Install Other dependencies


# (3) DATASET ETL (Extract-Transform-Load)


# Download data pre-processor

The data is generated with this [script](https://github.com/crhysc/utilities/blob/main/supercon_preprocess.py). It compiles a set of about 1,000 structures and their superconducting critical temperatures into the format required for FlowMM training.

In [None]:
!rm -f supercon_preprocess.py

In [None]:
%cd /content/flowmm
import os
if not os.path.exists('supercon_preprocess.py'):
  !wget -q https://raw.githubusercontent.com/crhysc/utilities/refs/heads/main/supercon_preprocess.py
%cat supercon_preprocess.py

# Run data pre-processor

In [None]:
!rm -rf /content/flowmm/data/supercon

In [None]:
%cd /content/flowmm
!conda run -p /usr/local/envs/flowmm_env --live-stream \
    python supercon_preprocess.py \
        --dataset dft_3d \
        --id-key jid \
        --target Tc_supercon \
        --train-ratio 0.8 --val-ratio 0.1 --test-ratio 0.1 \
        --seed 123 \
        --max-size 25
print("Done")
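The ratio and seed flags above control a reproducible train/validation/test split. The following is a minimal sketch of how such a split can be done. It is not taken from `supercon_preprocess.py`, whose exact logic may differ, and the function name and ID list are hypothetical.

```python
import random


def split_ids(ids, train_ratio=0.8, val_ratio=0.1, seed=123):
    """Shuffle ids with a fixed seed and cut them into train/val/test lists."""
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)  # local RNG, so global state is untouched
    n_train = int(len(shuffled) * train_ratio)
    n_val = int(len(shuffled) * val_ratio)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # the remainder becomes the test set
    return train, val, test


ids = [f"JVASP-{i}" for i in range(1000)]  # hypothetical identifiers
train, val, test = split_ids(ids)
print(len(train), len(val), len(test))
```

Fixing the seed means that rerunning the cell produces the same three files, which keeps later evaluation numbers comparable between runs.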

# Move train/test/val data to the correct spot

In [None]:
%cd /content
%mkdir -p /content/flowmm/data/supercon
%mv /content/flowmm/train.csv /content/flowmm/data/supercon/
%mv /content/flowmm/val.csv /content/flowmm/data/supercon/
%mv /content/flowmm/test.csv /content/flowmm/data/supercon/
print("Done")
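If the cell above is rerun, the shell commands report errors because the files have already been moved. A pure-Python alternative that tolerates reruns could look like the sketch below; the helper name `move_splits` is hypothetical, and the demo runs in a temporary directory so that nothing under `/content` is touched.

```python
import shutil
import tempfile
from pathlib import Path


def move_splits(src_dir, dst_dir, names=("train.csv", "val.csv", "test.csv")):
    """Move the split files that exist in src_dir into dst_dir; return their names."""
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)  # safe to call repeatedly
    moved = []
    for name in names:
        src = src_dir / name
        if src.exists():  # skip files that were already moved
            shutil.move(str(src), str(dst_dir / name))
            moved.append(name)
    return moved


# Demo with placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("train.csv", "val.csv", "test.csv"):
        (Path(tmp) / name).write_text("placeholder\n")
    print(move_splits(tmp, Path(tmp) / "data" / "supercon"))
```

In the notebook itself the call would be `move_splits("/content/flowmm", "/content/flowmm/data/supercon")`.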

# Pull the supercon Hydra config YAML from GitHub

In [None]:
%cd /content/flowmm/scripts_model/conf/data/
!wget https://raw.githubusercontent.com/crhysc/utilities/refs/heads/main/supercon.yaml
%cat supercon.yaml

# Modify FlowMM hardcode to accept our supercon dataset

First, open **Files** in the left sidebar and navigate to **/content/flowmm/src/flowmm/**. Open **cfg_utils.py** and, on line 15, add "supercon" to the *dataset_options* literal and delete all other strings in it. Once you have done that, run the following code to generate the necessary affine stats YAML.
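For orientation, the edited definition might look roughly like the sketch below. This assumes `dataset_options` is a `typing.Literal`; the exact form of line 15 in `cfg_utils.py` is not reproduced here, and the helper `is_known_dataset` is hypothetical.

```python
from typing import Literal, get_args

# Assumed shape of the edited line: only "supercon" is kept.
dataset_options = Literal["supercon"]


def is_known_dataset(name: str) -> bool:
    """Check whether a dataset name is one of the allowed options."""
    return name in get_args(dataset_options)


print(get_args(dataset_options))
print(is_known_dataset("supercon"))
```

Restricting the literal to a single entry means that scripts which loop over every dataset option only process the supercon data.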

# Generate SPD stats

In [132]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m flowmm.rfm.manifolds.spd

/content/flowmm
successfully ran create_env_file.sh
calculate the stats of p(L | N) for each dataset
dataset='supercon':   0% 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/usr/local/envs/flowmm_env/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/envs/flowmm_env/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/content/flowmm/src/flowmm/rfm/manifolds/spd.py", line 498, in <module>
    atom_density, _ = get_atom_density(dataset)
  File "/content/flowmm/src/flowmm/rfm/manifolds/spd.py", line 366, in get_atom_density
    return get_spd_data(dataset, path, "atom_density")
  File "/content/flowmm/src/flowmm/rfm/manifolds/spd.py", line 341, in get_spd_data
    mean = torch.tensor(stats[dataset]["mean"])
  File "/usr/local/envs/flowmm_env/lib/python3.9/site-packages/omegaconf/dictconfig.py", line 375, in __getitem__
    self._format_and_raise(key=key, value=None, cau

# Create the required affine stats YAML for the dataset

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m flowmm.model.standardize \
                 data=supercon

# (4) TRAINING
# Manifolds


- FlowMM lets the user choose among several manifolds through the keyword argument `model={atom_type_manifold}_{lattice_manifold}` when running `scripts_model/run.py`.

- The available atom type manifolds and lattice manifolds can be found in `scripts_model/conf/model`.
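The naming pattern above can be unpacked programmatically. The sketch below assumes that the first underscore separates the atom type manifold from the lattice manifold and that the configs are stored as `.yaml` files; both helper names are hypothetical.

```python
from pathlib import Path


def split_model_name(name: str):
    """Split 'atomtype_lattice' at the first underscore (an assumed convention)."""
    atom_type_manifold, _, lattice_manifold = name.partition("_")
    return atom_type_manifold, lattice_manifold


def list_model_configs(conf_dir="scripts_model/conf/model"):
    """Return config names found in conf_dir, or an empty list if it is missing."""
    conf_path = Path(conf_dir)
    if not conf_path.is_dir():
        return []
    return sorted(p.stem for p in conf_path.glob("*.yaml"))


# The two model names used in the training cells of this notebook.
for name in ["abits_params", "null_params"]:
    print(name, "->", split_model_name(name))
```

Running `list_model_configs()` from `/content/flowmm` shows which values can be passed to `model=`.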

# Unconditional Training

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m scripts_model.run data=supercon model=abits_params \
    data.datamodule.batch_size.train=64 \
    data.datamodule.batch_size.val=64 \
    data.datamodule.batch_size.test=64
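The `key=value` arguments in the command above are Hydra overrides. When several runs need different settings, it can help to build them from a dictionary. The sketch below only constructs the argument list; the helper name `to_overrides` is hypothetical.

```python
def to_overrides(options):
    """Turn a mapping into Hydra-style 'key=value' command-line overrides."""
    return [f"{key}={value}" for key, value in options.items()]


options = {
    "data": "supercon",
    "model": "abits_params",
    "data.datamodule.batch_size.train": 64,
    "data.datamodule.batch_size.val": 64,
    "data.datamodule.batch_size.test": 64,
}
argv = ["python", "-u", "-m", "scripts_model.run", *to_overrides(options)]
print(" ".join(argv))
```

Lowering the three batch sizes in `options` is a simple first thing to try if training runs out of GPU memory.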

# Conditional Training

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m scripts_model.run data=supercon model=null_params \
    data.datamodule.batch_size.train=64 \
    data.datamodule.batch_size.val=64 \
    data.datamodule.batch_size.test=64

# (5) INFERENCE
# De Novo Generation / Unconditional Evaluation

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks recon

In [None]:
import torch
from pprint import pprint
path = "/content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon/eval_recon.pt"
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)
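Printing the whole loaded object can produce a very long output when it contains large tensors. A compact alternative is to list each key with its type and shape. The sketch below uses a hypothetical stand-in dictionary, since the actual keys of the evaluation file are not shown here; in the notebook, pass the loaded `data` object instead.

```python
def describe(value):
    """One-line description: type name plus shape or length when available."""
    shape = getattr(value, "shape", None)  # tensors and arrays expose .shape
    if shape is not None:
        return f"{type(value).__name__} shape={tuple(shape)}"
    if isinstance(value, (dict, list, tuple)):
        return f"{type(value).__name__} len={len(value)}"
    return f"{type(value).__name__} {value!r}"


def summarize(obj, depth=0, max_depth=2):
    """Return indented summary lines for (nested) dictionaries."""
    lines = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            lines.append("  " * depth + f"{key}: {describe(value)}")
            if isinstance(value, dict) and depth + 1 < max_depth:
                lines.extend(summarize(value, depth + 1, max_depth))
    else:
        lines.append("  " * depth + describe(obj))
    return lines


example = {"values": [4, 8], "meta": {"task": "recon"}}  # hypothetical stand-in
print("\n".join(summarize(example)))
```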

# Generation

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks gen

In [None]:
import torch
from pprint import pprint
path = "/content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon/eval_gen.pt"
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)

# Optimization

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python /content/cdvae/scripts/evaluate.py \
    --model_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks opt

In [None]:
import torch
from pprint import pprint
path = "/content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon/eval_opt.pt"
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)

# (7) EVALUATION

In [None]:
!PROJECT_ROOT=/content/cdvae \
 HYDRA_JOBS=/content/cdvae/hydra_outputs \
 WABDB_DIR=/content/cdvae/wandb_outputs \
 conda run -p /usr/local/envs/cdvae_legacy --live-stream \
    python scripts/compute_metrics.py \
    --root_path /content/cdvae/hydra_outputs/singlerun/2025-05-27/supercon \
    --tasks recon gen opt

# (8) NEXT STEPS & REFERENCES

## Next Steps

1. **Hyperparameter exploration**  
   - Try different numbers of noise levels (`model.num_noise_level`) and training epochs to improve sample quality.

2. **Property-conditioned generation**  
   - Re-enable the property predictor (`model.predict_property=True`) and train with longer schedules to improve prediction accuracy.
   - After training, sample structures by specifying a target critical temperature and evaluate via DFT or empirical models.


---

## References

- **Original CDVAE paper:**  
  Li _et al._, “Crystal Diffusion Variational Autoencoder for Inverse Materials Design,” _J. Phys. Chem. Lett._ 2023, DOI: [10.1021/acs.jpclett.3c01260](https://pubs.acs.org/doi/10.1021/acs.jpclett.3c01260)

- **CDVAE GitHub repo:**  
  https://github.com/txie-93/cdvae

- **JARVIS-Materials-Design:**  
  https://github.com/JARVIS-Materials-Design/jarvis

- **Hydra configuration framework:**  
  https://hydra.cc

- **PyTorch Lightning:**  
  https://www.pytorchlightning.ai

- **condacolab:**  
  https://github.com/conda-incubator/condacolab

- **Mamba (fast conda):**  
  https://github.com/mamba-org/mamba

- **Jarvis-tools (data ETL):**  
  https://github.com/JARVIS-Materials-Design/jarvis-tools
