In [None]:
# Now universal MACE finetuning on T2
#!/usr/bin/env python3
import os
import shlex
import subprocess
import sys

def main():
    """Launch MACE multihead fine-tuning of the universal foundation model.

    Builds the ``mace_run_train`` argument list, echoes it to stderr as a
    shell-quoted, copy-pasteable command, and runs it as a child process
    (no shell involved).

    Raises:
        subprocess.CalledProcessError: if ``mace_run_train`` exits non-zero.
        FileNotFoundError: if ``mace_run_train`` is not found on ``PATH``.
    """
    # ——— Environment setup ———
    # PyTorch CUDA caching-allocator option. It is set on os.environ so the
    # child process started by subprocess.run inherits it.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    cmd = [
        "mace_run_train",
        # === General settings ===
        "--name",              "mace_T2_including_replay_w2",
        "--model",             "MACE",
        "--num_interactions",  "2",
        "--foundation_model",  "/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",        # flag, takes no value

        # --- MP replay (pretraining head, logged as 'pt_head') ---
        # NOTE(review): a copy of this file with keys renamed to REF_*
        # (..._REFkeys.xyz) is generated in this notebook; confirm which
        # of the two files the pt_head should be trained on.
        "--pt_train_file",     "/home/phanim/harshitrawat/summer/replay_data/mp_finetuning-mace_T2_mp_replay_run-42.xyz",
        "--atomic_numbers",    "[3,8,40,57]",   # Li, O, Zr, La
        "--multiheads_finetuning", "True",

        # NOTE(review): this run is described as fine-tuning on T2, yet T3 is
        # the training set and T2 only the validation set — confirm the split.
        "--train_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz",
        "--valid_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz",

        "--batch_size",        "2",
        "--valid_batch_size",  "1",

        "--device",            "cuda",

        # === Loss function weights ===
        # Energy-only fit: forces and stress get zero weight.
        "--forces_weight",     "0",
        "--energy_weight",     "10",
        "--stress_weight",     "0",

        # === Learning setup ===
        # Explicit learning rate (0.0001 was found too low → stagnation).
        "--lr",                "0.006",
        # Without this, multihead fine-tuning replaces --lr with 0.0001
        # (MACE logs the override and points at --force_mh_ft_lr).
        # NOTE(review): when resuming via --restart_latest, the restored
        # optimizer/scheduler state may still win; check the logged LR.
        "--force_mh_ft_lr",    "True",
        "--scheduler_patience", "4",         # epochs without val-loss improvement before the LR is reduced
        "--clip_grad",         "1",          # gradient-clipping tolerance; guards against exploding gradients at high energy_weight
        "--weight_decay",      "1e-8",       # mild regularization

        # === EMA ===
        # Multihead fine-tuning turns EMA on by itself (the run log reports
        # decay 0.99999), so no --ema_decay is passed here.

        # === Domain + training settings ===
        "--r_max",             "5.0",
        "--max_num_epochs",    "130",
        # Explicit per-element reference energies, keyed by atomic number.
        "--E0s",               "{3: -1.2302615750354944, 8: -23.049110738413006, 40: 23.367646191010394, 57: 15.192898072498549}",
        "--seed",              "84",
        "--patience",          "8",

        # Resume from the newest checkpoint for this --name/--seed if one exists.
        "--restart_latest",
    ]

    # Quote each argument so the echoed command can be pasted into a shell:
    # "[3,8,40,57]" is a glob pattern and the E0s dict contains spaces.
    print(
        "Running:",
        " \\\n    ".join(shlex.quote(arg) for arg in cmd),
        file=sys.stderr,
    )
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    main()


Running: mace_run_train \
    --name \
    mace_T2_including_replay_w2 \
    --model \
    MACE \
    --num_interactions \
    2 \
    --foundation_model \
    /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
    --foundation_model_readout \
    --pt_train_file \
    /home/phanim/harshitrawat/summer/replay_data/mp_finetuning-mace_T2_mp_replay_run-42.xyz \
    --atomic_numbers \
    [3,8,40,57] \
    --multiheads_finetuning \
    True \
    --train_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz \
    --batch_size \
    2 \
    --valid_batch_size \
    1 \
    --device \
    cuda \
    --forces_weight \
    0 \
    --energy_weight \
    10 \
    --stress_weight \
    0 \
    --lr \
    0.006 \
    --scheduler_patience \
    4 \
    --clip_grad \
    1 \
    --weight_decay \
    1e-8 \
    --r_max \
    5.0 \
  

2025-08-21 10:16:49.591 INFO: MACE version: 0.3.14
2025-08-21 10:16:50.231 INFO: CUDA version: 12.6, CUDA device: 0


  model_foundation = torch.load(


2025-08-21 10:16:50.719 INFO: Using foundation model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model as initial checkpoint.
2025-08-21 10:16:50.721 INFO: Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True.
2025-08-21 10:16:50.721 INFO: Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True
2025-08-21 10:16:50.721 INFO: Using heads: ['Default', 'pt_head']
2025-08-21 10:16:50.721 INFO: Using the key specifications to parse data:
2025-08-21 10:16:50.721 INFO: Default: KeySpecification(info_keys={'energy': 'REF_energy', 'stress': 'REF_stress', 'virials': 'REF_virials', 'dipole': 'dipole', 'head': 'head', 'elec_temp': 'elec_temp', 'total_charge': 'total_charge', 'polarizability': 'polarizability', 'total_spin': 'total_spin'}, arrays_keys={'forces': 'REF_forces', 'charges': 'REF_charges'})
2025-08-21 10:16:50.721 INFO: pt_head: KeySpeci

  torch.load(f=checkpoint_info.path, map_location=device),


2025-08-21 10:17:47.714 INFO: Total number of parameters: 5556810
2025-08-21 10:17:47.714 INFO: 
2025-08-21 10:17:47.714 INFO: Using ADAM as parameter optimizer
2025-08-21 10:17:47.714 INFO: Batch size: 2
2025-08-21 10:17:47.714 INFO: Using Exponential Moving Average with decay: 0.99999
2025-08-21 10:17:47.714 INFO: Number of gradient updates: 689780
2025-08-21 10:17:47.714 INFO: Learning rate: 0.0001, weight decay: 1e-08
2025-08-21 10:17:47.714 INFO: UniversalLoss(energy_weight=10.000, forces_weight=0.000, stress_weight=0.000)
2025-08-21 10:17:47.726 INFO: Loading checkpoint: ./checkpoints/mace_T2_including_replay_w2_run-84_epoch-59.pt
2025-08-21 10:17:47.817 INFO: Using gradient clipping with tolerance=1.000
2025-08-21 10:17:47.817 INFO: 
2025-08-21 10:17:47.817 INFO: Started training, reporting errors on validation set
2025-08-21 10:17:47.817 INFO: Loss metrics on validation set


✅ Keys renamed to REF_* and saved to:
/home/phanim/harshitrawat/summer/replay_data/mp_finetuning-mace_T2_mp_replay_run-42_REFkeys.xyz


✅ Header updated:
→ /home/phanim/harshitrawat/summer/replay_data/mp_finetuning-mace_T2_mp_replay_run-42_REFkeys.xyz
