<a href="https://colab.research.google.com/github/crhysc/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/flowmm_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Tutorial**: FlowMM & FlowLLM



**Author**: Charles "Rhys" Campbell (crc00042@mix.wvu.edu)

# TABLE OF CONTENTS

- Background and Central Goal
- Installation, Configuration, and Dependencies
- Dataset ETL
- Training
  - Manifolds
  - Unconditional Training
  - Conditional Training
  - FlowLLM
- Inference
  - De Novo Generation / Unconditional Evaluation
  - Reconstruction / Conditional Evaluation
- Next Steps & References

# (1) BACKGROUND AND CENTRAL GOAL


# Background
### FlowMM
**FlowMM** uses Riemannian flow matching to learn how to transform simple base noise into periodic crystal structures. It jointly models fractional atomic coordinates and lattice parameters (and, for de novo generation, atom types) on manifolds that respect crystal symmetries such as periodicity. It tackles both **Crystal Structure Prediction** (finding the stable arrangement for a known composition) and **De Novo Generation** (proposing entirely new materials), while needing about three times fewer integration steps than comparable diffusion-based approaches.
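To make "flow matching on a manifold" concrete, here is a minimal sketch for the fractional-coordinate part only. It assumes the unit torus [0, 1)^n with its flat metric, which is the space periodic fractional coordinates live on. The sketch builds the shortest-path (geodesic) interpolant between a noise sample and a data sample, together with the constant velocity a flow-matching network would be trained to predict. The helper names are illustrative and are not taken from the FlowMM codebase; lattice and atom-type handling are omitted.

```python
import random


def wrap(d: float) -> float:
    """Map a periodic displacement to its shortest representative in [-0.5, 0.5)."""
    return (d + 0.5) % 1.0 - 0.5


def geodesic_point(x0: list[float], x1: list[float], t: float) -> list[float]:
    """Point at time t on the shortest path from x0 to x1 on the unit torus."""
    return [(a + t * wrap(b - a)) % 1.0 for a, b in zip(x0, x1)]


def target_velocity(x0: list[float], x1: list[float]) -> list[float]:
    """Constant velocity along that geodesic: the flow-matching regression target."""
    return [wrap(b - a) for a, b in zip(x0, x1)]


def flow_matching_loss(model, x0, x1, t) -> float:
    """Squared error between a model's predicted velocity at (x_t, t) and the target."""
    xt = geodesic_point(x0, x1, t)
    pred = model(xt, t)
    return sum((p - u) ** 2 for p, u in zip(pred, target_velocity(x0, x1)))


if __name__ == "__main__":
    rng = random.Random(0)
    x0 = [rng.random() for _ in range(3)]  # base noise: uniform on the torus
    x1 = [0.1, 0.5, 0.95]                  # a "data" set of fractional coordinates
    for t in (0.0, 0.5, 1.0):
        print(t, geodesic_point(x0, x1, t))
```

Because the path takes the shortest way around the torus, a coordinate at 0.9 moving to 0.1 travels +0.2 across the periodic boundary instead of -0.8 through the cell.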

### FlowLLM
**FlowLLM** builds on FlowMM by replacing the simple analytic noise prior with samples from a pretrained CrystalLLM (a LLaMA-style large language model fine-tuned on crystal data). Initial "noisy" structures are generated with the LLM, and the same Riemannian flow-matching steps then refine them into accurate crystal geometries.


# Central Goal
Show readers how to install, train, and run inference with FlowMM and FlowLLM.
  


# (2) INSTALLATION, CONFIGURATION, AND DEPENDENCIES


# Install Conda

In [None]:
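!pip install -q condacolab
import condacolab
# Installs conda (with mamba) into /usr/local and then restarts the Colab kernel;
# wait for the restart to finish before running the next cell
condacolab.install()
print("Done")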

**Note**: Colab pins Python and CUDA versions that differ from the ones FlowMM requires. To work around this, most cells in this notebook use `!conda run`, which runs a command as a subprocess inside the conda environment. That subprocess uses the environment's own Python interpreter, at the version FlowMM requires, instead of Colab's pinned one.
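The same pattern can be reproduced from plain Python with `subprocess`, which is handy if you prefer not to rely on IPython's `!` syntax. This is a minimal sketch: `conda_run_argv` and `run_in_env` are hypothetical helpers, and the prefix assumes the environment path used throughout this notebook.

```python
import subprocess

ENV_PREFIX = "/usr/local/envs/flowmm_env"  # environment path used by this notebook


def conda_run_argv(*cmd: str, prefix: str = ENV_PREFIX, live_stream: bool = True) -> list[str]:
    """Build the argv for `conda run -p PREFIX [--live-stream] CMD...`."""
    argv = ["conda", "run", "-p", prefix]
    if live_stream:
        argv.append("--live-stream")  # stream output instead of printing it all at the end
    return argv + list(cmd)


def run_in_env(*cmd: str, **kwargs) -> subprocess.CompletedProcess:
    """Run a command inside the conda environment; raise if it exits non-zero."""
    return subprocess.run(conda_run_argv(*cmd, **kwargs), check=True)


# Example (requires the environment created later in this notebook):
# run_in_env("python", "-c", "import sys; print(sys.version)")
```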

# Install FlowMM

In [None]:
import os
%cd /content
if not os.path.exists('flowmm'):
  !git clone https://github.com/crhysc/flowmm.git
print("Done")

# Load FlowMM submodules

In [None]:
%%bash
cd /content/flowmm
cat .gitmodules
# .gitmodules uses SSH URLs, which fail on Colab without an SSH key; switch them to HTTPS
sed -i 's|git@github.com:jiaor17/DiffCSP.git|https://github.com/jiaor17/DiffCSP.git|' .gitmodules
sed -i 's|git@github.com:crhysc/cdvae.git|https://github.com/crhysc/cdvae.git|' .gitmodules
sed -i 's|git@github.com:crhysc/riemannian-fm.git|https://github.com/crhysc/riemannian-fm.git|' .gitmodules
# Copy the new URLs into .git/config, then fetch every submodule
git submodule sync
git submodule update --init --recursive
echo "Done"
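The three `sed` substitutions above follow one rule: an SSH remote of the form `git@github.com:OWNER/REPO.git` becomes `https://github.com/OWNER/REPO.git`. If more submodules are added later, a general rewrite along these lines avoids listing each URL by hand. This is a sketch, not part of the FlowMM repository; `ssh_to_https` and `rewrite_gitmodules` are hypothetical helpers.

```python
import re
from pathlib import Path

# Matches SSH-style GitHub remotes such as git@github.com:owner/repo.git at the end of a line
SSH_GITHUB = re.compile(r"git@github\.com:([\w.-]+)/([\w.-]+?)(\.git)?$", re.MULTILINE)


def ssh_to_https(text: str) -> str:
    """Rewrite every SSH GitHub URL in `text` to its HTTPS equivalent."""
    return SSH_GITHUB.sub(lambda m: f"https://github.com/{m.group(1)}/{m.group(2)}.git", text)


def rewrite_gitmodules(path: str = "/content/flowmm/.gitmodules") -> None:
    """Apply ssh_to_https to a .gitmodules file in place."""
    p = Path(path)
    p.write_text(ssh_to_https(p.read_text()))
```

After rewriting, `git submodule sync` and `git submodule update --init --recursive` are still required.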

# Switch Colab Runtime to GPU
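At the top menu by the Colab logo, select **Runtime** -> **Change runtime type**, then pick any available GPU. Changing the runtime type starts a fresh session and discards everything installed so far, so do this before running the installation cells above.

A GPU is not required, but training and inference finish faster with one.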



# Create conda environment for FlowMM
Creating the conda environment takes about 20 minutes.


In [None]:
%%time
%cd /content/flowmm
!mamba env create -p /usr/local/envs/flowmm_env -f environment.yml
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    pip install uv
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    uv pip install "jarvis-tools>=2024.5" "pymatgen>=2024.1" pandas numpy tqdm
!conda run -p /usr/local/envs/flowmm_env --live-stream\
    uv pip install -e . \
                   -e remote/riemannian-fm \
                   -e remote/cdvae \
                   -e remote/DiffCSP-official
print("Done")

Add `__init__.py` to `manifm` and reinstall

In [None]:
%cd /content/flowmm/
import os
# Download the package initializer only if it is not already present
if not os.path.exists('remote/riemannian-fm/manifm/__init__.py'):
    !wget -q https://raw.githubusercontent.com/crhysc/utilities/refs/heads/main/__init__.py
    !mv __init__.py /content/flowmm/remote/riemannian-fm/manifm/
# Reinstall riemannian-fm in editable mode so manifm is importable as a package
!conda run -p /usr/local/envs/flowmm_env --live-stream \
    pip install -e /content/flowmm/remote/riemannian-fm/
!conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -c "import manifm; print('manifm version:', manifm.__version__)"

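# Install Other Dependencies
No separate step is needed: the environment-creation cell above already installs the remaining Python dependencies (`jarvis-tools`, `pymatgen`, `pandas`, `numpy`, and `tqdm`) into `flowmm_env`.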


# (3) DATASET ETL (Extract-Transform-Load)


# Download data pre-processor

The data is generated with this [script](https://github.com/crhysc/utilities/blob/main/supercon_preprocess.py). It extracts around 1,000 structures and their superconducting critical temperatures (`Tc_supercon`) from the JARVIS `dft_3d` dataset and writes them in the CSV format required for FlowMM training.

In [None]:
%cd /content/flowmm
import os
if not os.path.exists('supercon_preprocess.py'):
  !wget -q https://raw.githubusercontent.com/crhysc/utilities/refs/heads/main/supercon_preprocess.py
%cat supercon_preprocess.py

# Run data pre-processor

In [None]:
%cd /content/flowmm
!conda run -p /usr/local/envs/flowmm_env --live-stream \
    python supercon_preprocess.py \
        --dataset dft_3d \
        --id-key jid \
        --target Tc_supercon \
        --train-ratio 0.8 --val-ratio 0.1 --test-ratio 0.1 \
        --seed 123 \
        --max-size 25
print("Done")
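The preprocessing flags map onto a simple filter-and-split step. The sketch below shows one plausible reading of them on in-memory records; it is not the code in `supercon_preprocess.py`. It assumes that each record is a dict with the target under `Tc_supercon` and an atom count under a hypothetical `n_atoms` key, that missing targets appear as non-numeric placeholders such as `"na"` (as in JARVIS tables), and that `--max-size` caps the number of atoms per structure.

```python
import math
import random


def has_numeric_target(record: dict, target: str) -> bool:
    """True if the record carries a usable (numeric, non-NaN) target value."""
    value = record.get(target)
    return isinstance(value, (int, float)) and not math.isnan(value)


def filter_and_split(records, target="Tc_supercon", max_size=25,
                     train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=123):
    """Drop unusable records, shuffle reproducibly, and split into train/val/test."""
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0):
        raise ValueError("split ratios must sum to 1")
    # n_atoms is a hypothetical key standing in for however the structure size is stored
    kept = [r for r in records
            if has_numeric_target(r, target) and r["n_atoms"] <= max_size]
    random.Random(seed).shuffle(kept)  # fixed seed -> identical split on every run
    n_train = int(train_ratio * len(kept))
    n_val = int(val_ratio * len(kept))
    return kept[:n_train], kept[n_train:n_train + n_val], kept[n_train + n_val:]
```

Writing each split to `train.csv`, `val.csv`, and `test.csv` (the files moved in the next step) is left to the real script.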

# Move the train/val/test splits into FlowMM's data directory

In [None]:
%cd /content
# -p creates the directory if needed and does not fail when it already exists
!mkdir -p /content/flowmm/data/supercon
!mv /content/flowmm/train.csv /content/flowmm/val.csv /content/flowmm/test.csv /content/flowmm/data/supercon/
print("Done")

# Pull the supercon Hydra config YAML from GitHub

In [None]:
%cd /content/flowmm/scripts_model/conf/data/
# -O overwrites supercon.yaml on reruns instead of creating supercon.yaml.1
!wget -q -O supercon.yaml https://raw.githubusercontent.com/crhysc/utilities/refs/heads/main/supercon.yaml
%cat supercon.yaml

# Modify FlowMM's hardcoded dataset list to accept the supercon dataset

First, open **Files** in the left sidebar and navigate to **/content/flowmm/src/flowmm/**. Click **cfg_utils.py**, and on line 15, add `"supercon"` to the `dataset_options` `Literal` and delete all the other strings in it.

Next, open **Files** again, navigate to **/content/flowmm/src/flowmm/rfm/manifolds/**, and click **spd.py**. Navigate to the `if __name__ == "__main__":` block, then make the following edits (the last two are earlier in the file):

- Uncomment lines 449 through 466; this turns on `compute_stats`.
- On line 468, set `compute_stats = True`.
- On line 489, set `compute_stats = True` again.
- On line 461, change `"std": std.cpu().tolist()` to `"logmap_std": std.cpu().tolist(),`.
- On line 236, change the `"std"` string to `"logmap_std"`.
- On line 431, add `unbiased=False` inside the parentheses of the `.std()` call, so that the whole line reads `std_coefs.append((log_noise_samples.std(unbiased=False) ** (3 / 2)) / n)`.

Finally, in the same file (**/content/flowmm/src/flowmm/rfm/manifolds/spd.py**), replace all code from line 531 onward, starting with the comment `# do some testing for SPDNonIsotropicRandom`, with the following, indented to match the code it replaces:

    pL_stats = OmegaConf.load(Path(__file__).parent / "spd_pLTL_stats.yaml")  # per-dataset SPD stats

    for dataset in tqdm(list(dataset_options.__args__)):
        mean_vec = torch.tensor(pL_stats[dataset]["mean"])
        std_vec = torch.tensor(pL_stats[dataset]["logmap_std"])  # key renamed from "std"

        # sanity check: the stats should be vectors, not scalars
        if mean_vec.ndim == 0:
            raise ValueError(
                f"Loaded mean for {dataset} is scalar (wrong YAML?); shape {mean_vec.shape}"
            )

        s = manifm_SPD(Riem_geodesic=True, Riem_norm=True)
        spd = SPDNonIsotropicRandom(mean_vec, std_vec)
        r = spd.random_base(10, mean_vec.size(-1))
        lp = spd.base_logprob(r)
        print(r, lp)

        r = spd.random_base(3, 10, mean_vec.size(-1))
        lp = spd.base_logprob(r)
        print(r, lp)
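For context on the `unbiased=False` edit to line 431: PyTorch's `Tensor.std()` applies Bessel's correction by default, dividing the sum of squared deviations by N-1, while `unbiased=False` divides by N (the population standard deviation). The standard library's `statistics.stdev` and `statistics.pstdev` compute the same two quantities, so the difference can be checked without PyTorch or a GPU:

```python
import math
import statistics

samples = [1.0, 2.0, 3.0, 4.0]

unbiased = statistics.stdev(samples)  # divides by N - 1, like torch's default
biased = statistics.pstdev(samples)   # divides by N, like std(unbiased=False)

# The two estimates differ by a factor of sqrt(N / (N - 1)).
n = len(samples)
assert math.isclose(unbiased, biased * math.sqrt(n / (n - 1)))
print(f"unbiased={unbiased:.4f} biased={biased:.4f}")
```

With a single sample the N-1 version is undefined (`statistics.stdev` raises an error and PyTorch returns NaN), whereas the N version returns 0.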

# Generate necessary YAML files for training

In [None]:
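# Remove stale stats files before regenerating them; -f avoids an error if a file is missing
!rm -f /content/flowmm/src/flowmm/rfm/manifolds/atom_density.yaml
!rm -f /content/flowmm/src/flowmm/rfm/manifolds/spd_pLTL_stats.yaml
!rm -f /content/flowmm/src/flowmm/rfm/manifolds/spd_std_coef.yaml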

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 FLOWMM_DEBUG=1 \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m flowmm.rfm.manifolds.spd
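Before moving on, it can help to confirm that the regenerated `spd_pLTL_stats.yaml` contains a `supercon` entry with the `mean` and `logmap_std` keys that the modified test block in `spd.py` reads. The check below is a sketch: it assumes only those key names and that both values are equal-length lists. `check_spd_stats` and `load_stats` are hypothetical helpers, and loading the file assumes PyYAML is installed.

```python
from pathlib import Path

STATS_PATH = Path("/content/flowmm/src/flowmm/rfm/manifolds/spd_pLTL_stats.yaml")


def check_spd_stats(stats: dict, dataset: str = "supercon") -> int:
    """Return the vector length of `dataset`'s stats, raising if they look malformed."""
    if dataset not in stats:
        raise KeyError(f"{dataset!r} missing; found {sorted(stats)}")
    entry = stats[dataset]
    missing = {"mean", "logmap_std"} - set(entry)
    if missing:
        raise KeyError(f"{dataset!r} lacks keys {sorted(missing)}")
    mean, std = entry["mean"], entry["logmap_std"]
    if not (isinstance(mean, list) and isinstance(std, list)) or len(mean) != len(std):
        raise ValueError("expected equal-length lists for mean and logmap_std")
    return len(mean)


def load_stats(path: Path = STATS_PATH) -> dict:
    """Read the stats YAML file into a plain dict."""
    import yaml  # PyYAML, imported lazily so check_spd_stats works without it
    return yaml.safe_load(path.read_text())


# Usage: print(check_spd_stats(load_stats()))
```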

# Create lattice_params_stats.yaml

In [None]:
!rm -f /content/flowmm/src/flowmm/rfm/manifolds/lattice_params_stats.yaml

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m flowmm.rfm.manifolds.lattice_params

# Create the required affine stats YAML for the dataset

In [None]:
!rm -f /content/flowmm/src/flowmm/model/stats_supercon*

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m flowmm.model.standardize \
                 data=supercon
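Affine standardization stats are, in general, a per-feature shift and scale: each feature x is mapped to (x - mean) / std for training, and the inverse map restores physical units after sampling. The sketch below illustrates that general idea on plain Python lists. It is not the implementation in `flowmm.model.standardize` and makes no claim about which quantities FlowMM standardizes; the function names are hypothetical.

```python
import statistics


def fit_affine(rows: list[list[float]]) -> tuple[list[float], list[float]]:
    """Per-column mean and population std; a zero std is replaced by 1 to avoid dividing by zero."""
    cols = list(zip(*rows))
    mean = [statistics.fmean(c) for c in cols]
    std = [statistics.pstdev(c) or 1.0 for c in cols]
    return mean, std


def standardize(row, mean, std):
    """Forward affine map: shift by the mean, scale by the std."""
    return [(x - m) / s for x, m, s in zip(row, mean, std)]


def unstandardize(row, mean, std):
    """Inverse affine map back to the original units."""
    return [z * s + m for z, m, s in zip(row, mean, std)]
```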

# (4) TRAINING
# Manifolds


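- FlowMM lets you choose the manifolds used for atom types and for the lattice through the Hydra override `model={atom_type_manifold}_{lattice_manifold}` when running `scripts_model/run.py`. For example, this notebook uses `model=abits_params` for unconditional training and `model=null_params` for conditional training.

- The available atom-type and lattice manifold configurations can be found in `scripts_model/conf/model`.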
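Because `model=` is an ordinary Hydra override, the full command line can be assembled programmatically, for example to sweep manifold combinations. The helper below is a hypothetical convenience, not part of FlowMM. It only joins strings, so check that each `{atom_type_manifold}_{lattice_manifold}` pair you pass has a matching config in `scripts_model/conf/model`.

```python
def run_command(data: str, atom_type_manifold: str, lattice_manifold: str,
                **overrides) -> list[str]:
    """Argv for `python -u -m scripts_model.run` with Hydra key=value overrides."""
    argv = ["python", "-u", "-m", "scripts_model.run",
            f"data={data}", f"model={atom_type_manifold}_{lattice_manifold}"]
    # Hydra keys are dotted (train.pl_trainer.max_epochs), but Python keyword
    # arguments cannot contain dots, so double underscores stand in for them.
    argv += [f"{key.replace('__', '.')}={value}" for key, value in overrides.items()]
    return argv


# Usage: run_command("supercon", "abits", "params", train__pl_trainer__max_epochs=1)
```

The resulting list can be passed to `subprocess.run` behind a `conda run -p /usr/local/envs/flowmm_env` prefix, with `/content/flowmm` as the working directory.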

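# Unconditional Training
The training cell below runs a short smoke test: one epoch on CPU with Weights & Biases logging disabled (`WANDB_MODE=disabled`). To train on the GPU, change `train.pl_trainer.accelerator=cpu` to `train.pl_trainer.accelerator=gpu`, and raise `train.pl_trainer.max_epochs` for a full run.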

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
  HYDRA_FULL_ERROR=1 \
  WANDB_MODE=disabled \
  conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m scripts_model.run \
      data=supercon \
      model=abits_params \
      train.pl_trainer.accelerator=cpu \
      train.pl_trainer.devices=1 \
      train.model_checkpoints.save_last=True \
      logging.val_check_interval=1 \
      train.pl_trainer.max_epochs=1

# Conditional Training

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 WANDB_MODE=disabled \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m scripts_model.run data=supercon model=null_params

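# FlowLLM Training
Unlike the training cells above, this cell uses the `mp20_llama` data config instead of `supercon`. Consistent with the FlowLLM description in the background section, the flow starts from LLM-generated structures rather than analytic noise; `base_distribution_from_data=True`, as its name indicates, draws those base samples from the dataset.

In [None]:
%cd /content/flowmm
!bash create_env_file.sh && \
 echo "successfully ran create_env_file.sh" && \
 HYDRA_FULL_ERROR=1 \
 WANDB_MODE=disabled \
 conda run -p /usr/local/envs/flowmm_env --live-stream \
    python -u -m scripts_model.run data=mp20_llama model=null_params \
      base_distribution_from_data=True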

# (5) INFERENCE
# Unconditional Evaluation - De Novo Generation



In [None]:
%%bash
set -e  # stop at the first failing step
cd /content/flowmm
bash create_env_file.sh
echo "successfully ran create_env_file.sh"
# Replace these placeholders before running
ckpt=PATH_TO_CHECKPOINT
subdir=NAME_OF_SUBDIRECTORY_AT_CHECKPOINT
slope=SLOPE_OF_INFERENCE_ANTI_ANNEALING
# conda run wraps a single command, so every evaluate.py step gets its own prefix
run="conda run -p /usr/local/envs/flowmm_env --live-stream"
$run python scripts_model/evaluate.py generate "$ckpt" --subdir "$subdir" \
    --inference_anneal_slope "$slope" --stage test
$run python scripts_model/evaluate.py consolidate "$ckpt" --subdir "$subdir"
$run python scripts_model/evaluate.py old_eval_metrics "$ckpt" --subdir "$subdir" \
    --stage test
$run python scripts_model/evaluate.py lattice_metrics "$ckpt" --subdir "$subdir" \
    --stage test

In [None]:
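import torch
from pprint import pprint
# Replace with the .pt results file written by evaluate.py for your checkpoint and subdir
path = "PATH_TO_EVALUATION_RESULTS.pt"
# weights_only=False unpickles arbitrary Python objects; only use it on files you trust
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)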
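Evaluation files can hold large tensors, so `pprint` output quickly becomes hard to read. A compact alternative is to print only the structure: keys, container lengths, and tensor shapes. The `describe` function below is a generic sketch. It relies only on duck typing (anything with a `.shape` attribute is reported by its shape), so it handles PyTorch tensors and NumPy arrays without importing either, and it assumes nothing about which keys FlowMM writes.

```python
def describe(obj, indent: int = 0, max_items: int = 5) -> list[str]:
    """Return indented lines summarizing nested dicts, lists/tuples, and array-likes."""
    pad = "  " * indent
    if hasattr(obj, "shape"):  # tensors and arrays: report shape and dtype, not values
        dtype = getattr(obj, "dtype", "")
        return [f"{pad}{type(obj).__name__} shape={tuple(obj.shape)} {dtype}".rstrip()]
    if isinstance(obj, dict):
        lines = [f"{pad}dict with {len(obj)} keys"]
        for key, value in obj.items():
            lines.append(f"{pad}  {key!r}:")
            lines.extend(describe(value, indent + 2, max_items))
        return lines
    if isinstance(obj, (list, tuple)):
        lines = [f"{pad}{type(obj).__name__} of length {len(obj)}"]
        for item in obj[:max_items]:  # only show the first few entries
            lines.extend(describe(item, indent + 1, max_items))
        if len(obj) > max_items:
            lines.append(f"{pad}  ...")
        return lines
    return [f"{pad}{type(obj).__name__}: {obj!r}"[:120]]  # scalars and strings, truncated


# Usage: print("\n".join(describe(data)))
```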

# Conditional Evaluation - Crystal Structure Prediction - Reconstruction

In [None]:
%%bash
set -e  # stop at the first failing step
cd /content/flowmm
bash create_env_file.sh
echo "successfully ran create_env_file.sh"
# Replace these placeholders before running
ckpt=PATH_TO_CHECKPOINT
subdir=NAME_OF_SUBDIRECTORY_AT_CHECKPOINT
slope=SLOPE_OF_INFERENCE_ANTI_ANNEALING
# conda run wraps a single command, so every evaluate.py step gets its own prefix
run="conda run -p /usr/local/envs/flowmm_env --live-stream"
$run python scripts_model/evaluate.py reconstruct "$ckpt" --subdir "$subdir" \
    --inference_anneal_slope "$slope" --stage test
$run python scripts_model/evaluate.py consolidate "$ckpt" --subdir "$subdir"
$run python scripts_model/evaluate.py old_eval_metrics "$ckpt" --subdir "$subdir" \
    --stage test
$run python scripts_model/evaluate.py lattice_metrics "$ckpt" --subdir "$subdir" \
    --stage test

In [None]:
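import torch
from pprint import pprint
# Replace with the .pt results file written by evaluate.py for your checkpoint and subdir
path = "PATH_TO_EVALUATION_RESULTS.pt"
# weights_only=False unpickles arbitrary Python objects; only use it on files you trust
data = torch.load(path, map_location="cpu", weights_only=False)
pprint(data, width=120, indent=2)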