In [None]:
# Now universal MACE finetuning on T2
#!/usr/bin/env python3
import os
import subprocess
import sys

def main():
    # ——— Environment setup ———
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    cmd = [
        "mace_run_train",
        # === General settings ===
        "--name",              "mace_T2_including_replay_w2",
        "--model",             "MACE",
        "--num_interactions",  "2",
        "--foundation_model",  "/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",
    # --- MP replay (pretraining head) ---
        "--pt_train_file","/home/phanim/harshitrawat/summer/replay_data/mp_finetuning-mace_T2_mp_replay_run-42.xyz",              # <- MP replay shortcut
        "--atomic_numbers","[3,8,40,57]",    # Li, O, Zr, La
        "--multiheads_finetuning","True",

        "--train_file","/home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz",
        "--valid_file","/home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz",

        "--batch_size",        "2",
        "--valid_batch_size",  "1",

        "--device",            "cuda",

        # === Loss function weights ===
        "--forces_weight",     "0",         # Increased force weight to balance energy better
        "--energy_weight",     "10",   
        "--stress_weight", "0",             # Reduced from 100 → avoid dominance + stabilize energy RMSE

        # === Learning setup ===
        "--lr",                "0.006",      # Explicit learning rate (0.0001 is too low → stagnation)
        "--scheduler_patience","4",          # Reduce LR if val loss doesn’t improve in 3 epochs
        "--clip_grad",         "1",        # Avoid exploding gradients — essential when energy_weight is high
        "--weight_decay",      "1e-8",       # Mild regularization to prevent overfitting

        # === EMA helps smooth loss curve ===
        #"--ema_decay",         "0.999",     # Smooths validation loss and helps final convergence

        # === Domain + training settings ===
        "--r_max",             "5.0",
        "--max_num_epochs",    "130",
        "--E0s",               "{3: -1.2302615750354944, 8: -23.049110738413006, 40: 23.367646191010394, 57: 15.192898072498549}",    # Still allowed — could optionally be replaced by manual E0s
        "--seed",              "84",
        "--patience",     "8",

        "--restart_latest",                   # Resumes from checkpoint if available
    ]

    print("Running:", " \\\n    ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    main()


Running: mace_run_train \
    --name \
    mace_T2_including_replay_w2 \
    --model \
    MACE \
    --num_interactions \
    2 \
    --foundation_model \
    /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
    --foundation_model_readout \
    --pt_train_file \
    /home/phanim/harshitrawat/summer/replay_data/mp_finetuning-mace_T2_mp_replay_run-42.xyz \
    --atomic_numbers \
    [3,8,40,57] \
    --multiheads_finetuning \
    True \
    --train_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz \
    --batch_size \
    2 \
    --valid_batch_size \
    1 \
    --device \
    cuda \
    --forces_weight \
    0 \
    --energy_weight \
    10 \
    --stress_weight \
    0 \
    --lr \
    0.006 \
    --scheduler_patience \
    4 \
    --clip_grad \
    1 \
    --weight_decay \
    1e-8 \
    --r_max \
    5.0 \
  

2025-08-21 10:16:49.591 INFO: MACE version: 0.3.14
2025-08-21 10:16:50.231 INFO: CUDA version: 12.6, CUDA device: 0


  model_foundation = torch.load(


2025-08-21 10:16:50.719 INFO: Using foundation model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model as initial checkpoint.
2025-08-21 10:16:50.721 INFO: Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True.
2025-08-21 10:16:50.721 INFO: Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True
2025-08-21 10:16:50.721 INFO: Using heads: ['Default', 'pt_head']
2025-08-21 10:16:50.721 INFO: Using the key specifications to parse data:
2025-08-21 10:16:50.721 INFO: Default: KeySpecification(info_keys={'energy': 'REF_energy', 'stress': 'REF_stress', 'virials': 'REF_virials', 'dipole': 'dipole', 'head': 'head', 'elec_temp': 'elec_temp', 'total_charge': 'total_charge', 'polarizability': 'polarizability', 'total_spin': 'total_spin'}, arrays_keys={'forces': 'REF_forces', 'charges': 'REF_charges'})
2025-08-21 10:16:50.721 INFO: pt_head: KeySpeci

  torch.load(f=checkpoint_info.path, map_location=device),


2025-08-21 10:17:47.714 INFO: Total number of parameters: 5556810
2025-08-21 10:17:47.714 INFO: 
2025-08-21 10:17:47.714 INFO: Using ADAM as parameter optimizer
2025-08-21 10:17:47.714 INFO: Batch size: 2
2025-08-21 10:17:47.714 INFO: Using Exponential Moving Average with decay: 0.99999
2025-08-21 10:17:47.714 INFO: Number of gradient updates: 689780
2025-08-21 10:17:47.714 INFO: Learning rate: 0.0001, weight decay: 1e-08
2025-08-21 10:17:47.714 INFO: UniversalLoss(energy_weight=10.000, forces_weight=0.000, stress_weight=0.000)
2025-08-21 10:17:47.726 INFO: Loading checkpoint: ./checkpoints/mace_T2_including_replay_w2_run-84_epoch-59.pt
2025-08-21 10:17:47.817 INFO: Using gradient clipping with tolerance=1.000
2025-08-21 10:17:47.817 INFO: 
2025-08-21 10:17:47.817 INFO: Started training, reporting errors on validation set
2025-08-21 10:17:47.817 INFO: Loss metrics on validation set


✅ Keys renamed to REF_* and saved to:
/home/phanim/harshitrawat/summer/replay_data/mp_finetuning-mace_T2_mp_replay_run-42_REFkeys.xyz


In [4]:
# Now universal MACE finetuning on T2
#!/usr/bin/env python3
import os
import subprocess
import sys

def main():
    # ——— Environment setup ———
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    cmd = [
        "mace_run_train",
        # === General settings ===
        "--name",              "mace_T2_including_replay_w2_a",
        "--model",             "MACE",
        "--num_interactions",  "2",
        "--foundation_model",  "/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",
    # --- MP replay (pretraining head) ---
        "--pt_train_file","/home/phanim/harshitrawat/summer/replay_data/mp_finetuning-just_to_get_file_combinations_run-84.xyz",              # <- MP replay shortcut
        "--atomic_numbers","[3,8,40,57]",    # Li, O, Zr, La
        "--multiheads_finetuning","True",

        "--train_file","/home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz",
        "--valid_file","/home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz",

        "--batch_size",        "2",
        "--valid_batch_size",  "1",

        "--device",            "cuda",

        # === Loss function weights ===
        "--forces_weight",     "0",         # Increased force weight to balance energy better
        "--energy_weight",     "10",   
        "--stress_weight", "0",             # Reduced from 100 → avoid dominance + stabilize energy RMSE

        # === Learning setup ===
        "--lr",                "0.006",      # Explicit learning rate (0.0001 is too low → stagnation)
        "--scheduler_patience","4",          # Reduce LR if val loss doesn’t improve in 3 epochs
        "--clip_grad",         "1",        # Avoid exploding gradients — essential when energy_weight is high
        "--weight_decay",      "1e-8",       # Mild regularization to prevent overfitting

        # === EMA helps smooth loss curve ===
        "--ema_decay",         "0.999",     # Smooths validation loss and helps final convergence

        # === Domain + training settings ===
        "--r_max",             "5.0",
        "--max_num_epochs",    "33",
        "--E0s",               "{3: -1.2302615750354944, 8: -23.049110738413006, 40: 23.367646191010394, 57: 15.192898072498549}",    # Still allowed — could optionally be replaced by manual E0s
        "--seed",              "84",
        "--patience",     "8",

        "--restart_latest",                   # Resumes from checkpoint if available
    ]

    print("Running:", " \\\n    ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    main()


Running: mace_run_train \
    --name \
    mace_T2_including_replay_w2_a \
    --model \
    MACE \
    --num_interactions \
    2 \
    --foundation_model \
    /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
    --foundation_model_readout \
    --pt_train_file \
    /home/phanim/harshitrawat/summer/replay_data/mp_finetuning-just_to_get_file_combinations_run-84.xyz \
    --atomic_numbers \
    [3,8,40,57] \
    --multiheads_finetuning \
    True \
    --train_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz \
    --batch_size \
    2 \
    --valid_batch_size \
    1 \
    --device \
    cuda \
    --forces_weight \
    0 \
    --energy_weight \
    10 \
    --stress_weight \
    0 \
    --lr \
    0.006 \
    --scheduler_patience \
    4 \
    --clip_grad \
    1 \
    --weight_decay \
    1e-8 \
    --ema_de

2025-08-26 13:32:26.203 INFO: MACE version: 0.3.14
2025-08-26 13:32:26.842 INFO: CUDA version: 12.6, CUDA device: 0


  model_foundation = torch.load(


2025-08-26 13:32:27.331 INFO: Using foundation model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model as initial checkpoint.
2025-08-26 13:32:27.339 INFO: Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True.
2025-08-26 13:32:27.339 INFO: Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True
2025-08-26 13:32:27.339 INFO: Using heads: ['Default', 'pt_head']
2025-08-26 13:32:27.339 INFO: Using the key specifications to parse data:
2025-08-26 13:32:27.339 INFO: Default: KeySpecification(info_keys={'energy': 'REF_energy', 'stress': 'REF_stress', 'virials': 'REF_virials', 'dipole': 'dipole', 'head': 'head', 'elec_temp': 'elec_temp', 'total_charge': 'total_charge', 'polarizability': 'polarizability', 'total_spin': 'total_spin'}, arrays_keys={'forces': 'REF_forces', 'charges': 'REF_charges'})
2025-08-26 13:32:27.339 INFO: pt_head: KeySpeci

  torch.load(f=checkpoint_info.path, map_location=device),


2025-08-26 13:33:20.212 INFO: Total number of parameters: 896586
2025-08-26 13:33:20.212 INFO: 
2025-08-26 13:33:20.212 INFO: Using ADAM as parameter optimizer
2025-08-26 13:33:20.212 INFO: Batch size: 2
2025-08-26 13:33:20.212 INFO: Using Exponential Moving Average with decay: 0.99999
2025-08-26 13:33:20.212 INFO: Number of gradient updates: 175098
2025-08-26 13:33:20.212 INFO: Learning rate: 0.0001, weight decay: 1e-08
2025-08-26 13:33:20.212 INFO: UniversalLoss(energy_weight=10.000, forces_weight=0.000, stress_weight=0.000)
2025-08-26 13:33:20.228 INFO: Loading checkpoint: ./checkpoints/mace_T2_including_replay_w2_a_run-84_epoch-31.pt
2025-08-26 13:33:20.692 INFO: Using gradient clipping with tolerance=1.000
2025-08-26 13:33:20.692 INFO: 
2025-08-26 13:33:20.692 INFO: Started training, reporting errors on validation set
2025-08-26 13:33:20.692 INFO: Loss metrics on validation set
2025-08-26 13:33:59.401 INFO: Initial: head: pt_head, loss=0.00000368, RMSE_E_per_atom=    0.86 meV, RMS

In [5]:
from mace.calculators import MACECalculator
mace_calc = MACECalculator(model_paths=["/home/phanim/harshitrawat/summer/iteration_3/mace_T2_including_replay_w2_a_compiled.model"], device="cuda")  # or "cpu"
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core import Structure
adaptor = AseAtomsAdaptor()

pmg_structure = Structure.from_file("/home/phanim/harshitrawat/summer/formation_energy/cifs/Li.cif")  # e.g. for Li
ase_atoms = adaptor.get_atoms(pmg_structure)
ase_atoms.calc = mace_calc
total_energy = ase_atoms.get_potential_energy()
mu_model_Li = total_energy / len(ase_atoms)
print(f"Li: μ_model = {mu_model_Li:.6f} eV/atom")
# Let us do this for La, Zr, and O as well
pmg_structure = Structure.from_file("/home/phanim/harshitrawat/summer/formation_energy/cifs/La.cif")
ase_atoms = adaptor.get_atoms(pmg_structure)
ase_atoms.calc = mace_calc
total_energy = ase_atoms.get_potential_energy()
mu_model_La = total_energy / len(ase_atoms)
print(f"La: μ_model = {mu_model_La:.6f} eV/atom")
pmg_structure = Structure.from_file("/home/phanim/harshitrawat/summer/formation_energy/cifs/Zr.cif")
ase_atoms = adaptor.get_atoms(pmg_structure)
ase_atoms.calc = mace_calc
total_energy = ase_atoms.get_potential_energy()
mu_model_Zr = total_energy / len(ase_atoms)
print(f"Zr: μ_model = {mu_model_Zr:.6f} eV/atom")
pmg_structure = Structure.from_file("/home/phanim/harshitrawat/summer/formation_energy/cifs/O2.cif")  # Needs to be a periodic solid O2 structure
ase_atoms = adaptor.get_atoms(pmg_structure)
ase_atoms.calc = mace_calc
total_energy = ase_atoms.get_potential_energy()
mu_model_O = total_energy / len(ase_atoms)
print(f"O: μ_model = {mu_model_O:.6f} eV/atom")

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))
  torch.load(f=model_path, map_location=device)


Using head Default out of ['pt_head', 'Default']
No dtype selected, switching to float64 to match model dtype.
Li: μ_model = -1.775718 eV/atom


  struct = parser.parse_structures(primitive=primitive)[0]


La: μ_model = 18.380837 eV/atom
Zr: μ_model = 26.286409 eV/atom
O: μ_model = -20.756856 eV/atom


In [6]:
import os
# ---- 1. Load the MACE model -------------------------------------------
calculator = MACECalculator(model_paths=["/home/phanim/harshitrawat/summer/iteration_3/mace_T2_including_replay_w2_a_compiled.model"], device="cuda")  # or "cpu"

# ---- 2. Reference μ_model from MACE -----------------------------------

mu_mace = {
    "Li": -1.775718,
    "La": 18.380837,
    "Zr": 26.286409,
    "O":  -20.756856,
}

# ---- 3. CIF files ------------------------------------------------------
cif_dir = "/home/phanim/harshitrawat/summer/formation_energy/cifs"
compounds = {
    "mp-841.cif": "Li2O2",
    "mp-1960.cif": "Li2O",
    "mp-942733.cif": "Li7La3Zr2O12",
    "mp-2858.cif": "ZrO2",
    "mp-1968.cif": "La2O3",
}

# ---- 4. Predict formation energy per atom -----------------------------
for fname, label in compounds.items():
    struct = Structure.from_file(os.path.join(cif_dir, fname))
    comp = struct.composition
    n_atoms = comp.num_atoms

    # Convert to ASE
    ase_atoms = AseAtomsAdaptor.get_atoms(struct)

    # Assign calculator and predict energy
    ase_atoms.calc = calculator
    energy_total = ase_atoms.get_potential_energy()  # eV (total)

    # Reference energy from MACE chemical potentials
    ref_total = sum(comp[el] * mu_mace[el.symbol] for el in comp.elements)

    # Formation energy per atom
    e_form = (energy_total - ref_total) / n_atoms

    print(f"{label:15s}:  E_form (MACE_T2_w2_it3) = {e_form: .6f} eV/atom")

  torch.load(f=model_path, map_location=device)


Using head Default out of ['pt_head', 'Default']
No dtype selected, switching to float64 to match model dtype.


  struct = parser.parse_structures(primitive=primitive)[0]


Li2O2          :  E_form (MACE_T2_w2_it3) = -0.061559 eV/atom
Li2O           :  E_form (MACE_T2_w2_it3) = -0.104004 eV/atom
Li7La3Zr2O12   :  E_form (MACE_T2_w2_it3) = -1.424687 eV/atom
ZrO2           :  E_form (MACE_T2_w2_it3) = -2.953763 eV/atom


  struct = parser.parse_structures(primitive=primitive)[0]


La2O3          :  E_form (MACE_T2_w2_it3) = -2.602734 eV/atom


In [None]:
Li2O2          :  E_form (MACE_T2_w2_it3) = -0.061559 eV/atom
Li2O           :  E_form (MACE_T2_w2_it3) = -0.104004 eV/atom
Li7La3Zr2O12   :  E_form (MACE_T2_w2_it3) = -1.424687 eV/atom
ZrO2           :  E_form (MACE_T2_w2_it3) = -2.953763 eV/atom
/home/phanim/harshitrawat/miniconda3/envs/mace_0.3.8/lib/python3.10/site-packages/pymatgen/core/structure.py:3107: UserWarning: Issues encountered while parsing CIF: 8 fractional coordinates rounded to ideal values to avoid issues with finite precision.
  struct = parser.parse_structures(primitive=primitive)[0]
La2O3          :  E_form (MACE_T2_w2_it3) = -2.602734 eV/atom