In [None]:
# Now universal MACE finetuning on T2
#!/usr/bin/env python3
import os
import subprocess
import sys

def main():
    # ——— Environment setup ———
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    os.environ["PYTHONPATH"] = "/home/phanim/harshitrawat/mace/mace"

    cmd = [
        "mace_run_train",
        "--name",              "mace_T3_finetune_h200_cn10",
        "--model",             "MACE",
        "--num_interactions",  "2",
        "--foundation_model",  "/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",

        "--train_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz",
        "--train_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz",
        
        "--valid_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz",
        "--valid_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz",

        "--batch_size",        "2",
        "--valid_batch_size",  "1",
        "--valid_fraction",    "0.1",

        "--device",            "cuda",

        # === Loss function weights ===
        "--forces_weight",     "50",         # Increased force weight to balance energy better
        "--energy_weight",     "75",         # Reduced from 100 → avoid dominance + stabilize energy RMSE

        # === Learning setup ===
        "--lr",                "0.001",      # Explicit learning rate (0.0001 is too low → stagnation)
        "--scheduler_patience","4",          # Reduce LR if val loss doesn’t improve in 3 epochs
        #"--clip_grad",         "10",        # Avoid exploding gradients — essential when energy_weight is high
        #"--weight_decay",      "1e-8",       # Mild regularization to prevent overfitting

        # === EMA helps smooth loss curve ===
        #"--ema_decay",         "0.999",     # Smooths validation loss and helps final convergence

        # === Domain + training settings ===
        "--r_max",             "5.0",
        "--max_num_epochs",    "130",
        "--E0s",               "average",    # Still allowed — could optionally be replaced by manual E0s
        "--seed",              "21",
        "--patience",     "5",

        "--restart_latest",                   # Resumes from checkpoint if available
    ]

    print("Running:", " \\\n    ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    main()


Running: mace_run_train \
    --name \
    mace_T3_finetune_h200_cn10 \
    --model \
    MACE \
    --num_interactions \
    2 \
    --foundation_model \
    /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
    --foundation_model_readout \
    --train_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz \
    --train_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz \
    --batch_size \
    2 \
    --valid_batch_size \
    1 \
    --valid_fraction \
    0.1 \
    --device \
    cuda \
    --forces_weight \
    50 \
    --energy_weight \
    75 \
    --lr \
    0.001 \
    --scheduler_patience \
    4 \
    --r_max \
    5.0 \
    --max_num_epochs \
    130 \
    --E0s \
    average \
    -

2025-08-14 03:33:08.799 INFO: MACE version: 0.3.14
2025-08-14 03:33:09.350 INFO: CUDA version: 12.6, CUDA device: 0


  model_foundation = torch.load(


2025-08-14 03:33:09.877 INFO: Using foundation model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model as initial checkpoint.
2025-08-14 03:33:09.878 INFO: Using heads: ['Default']
2025-08-14 03:33:09.878 INFO: Using the key specifications to parse data:
2025-08-14 03:33:09.878 INFO: Default: KeySpecification(info_keys={'energy': 'REF_energy', 'stress': 'REF_stress', 'virials': 'REF_virials', 'dipole': 'dipole', 'head': 'head', 'elec_temp': 'elec_temp', 'total_charge': 'total_charge', 'polarizability': 'polarizability', 'total_spin': 'total_spin'}, arrays_keys={'forces': 'REF_forces', 'charges': 'REF_charges'})
2025-08-14 03:33:10.107 INFO: Training set 1/1 [energy: 1200, stress: 0, virials: 0, dipole components: 0, head: 1200, elec_temp: 0, total_charge: 0, polarizability: 0, total_spin: 0, forces: 1200, charges: 0]
2025-08-14 03:33:10.110 INFO: Total Training set [energy: 1200, stress: 0, virials: 0, dipole components: 0, head: 1200, elec_t

  torch.load(f=checkpoint_info.path, map_location=device),


2025-08-14 03:33:13.842 INFO: Total number of parameters: 894362
2025-08-14 03:33:13.842 INFO: 
2025-08-14 03:33:13.842 INFO: Using ADAM as parameter optimizer
2025-08-14 03:33:13.842 INFO: Batch size: 2
2025-08-14 03:33:13.842 INFO: Number of gradient updates: 78000
2025-08-14 03:33:13.842 INFO: Learning rate: 0.001, weight decay: 5e-07
2025-08-14 03:33:13.843 INFO: WeightedEnergyForcesLoss(energy_weight=75.000, forces_weight=50.000)
2025-08-14 03:33:13.851 INFO: Loading checkpoint: ./checkpoints/mace_T3_finetune_h200_cn10_run-21_epoch-107.pt
2025-08-14 03:33:14.381 INFO: Using gradient clipping with tolerance=10.000
2025-08-14 03:33:14.381 INFO: 
2025-08-14 03:33:14.381 INFO: Started training, reporting errors on validation set
2025-08-14 03:33:14.381 INFO: Loss metrics on validation set


In [45]:
from ase.io import read, write
import numpy as np
import json, os

LABEL_PATH = "/home/phanim/harshitrawat/summer/md/mdlabels_it2.jsonl"
CIF_DIR = "/home/phanim/harshitrawat/summer/md/mdcifs_it2"
OUT_XYZ = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem.extxyz"

# Load labels from JSONL using only basename for matching
with open(LABEL_PATH) as f:
    label_data = {
        os.path.basename(entry["snapshot_file"]): entry
        for entry in map(json.loads, f)
    }

frames = []
skipped = 0

for fname in sorted(os.listdir(CIF_DIR)):
    if not fname.endswith(".cif"):
        continue
    if fname not in label_data:
        print(f"[WARN] No label found for: {fname}")
        skipped += 1
        continue

    cif_path = os.path.join(CIF_DIR, fname)
    try:
        atoms = read(cif_path)
    except Exception as e:
        print(f"[ERROR] Failed to read {fname}: {e}")
        continue

    meta = label_data[fname]

    # Set energy and forces
    atoms.info["REF_energy"] = meta["energy_eV"]
    atoms.arrays["REF_forces"] = np.array(meta["forces_per_atom_eV_per_A"])
    atoms.info["snapshot_file"] = fname  # traceability

    frames.append(atoms)

# Write the final .extxyz file
write(OUT_XYZ, frames)
print(f"\n[✓] Saved {len(frames)} structures to {OUT_XYZ}")
if skipped > 0:
    print(f"[i] Skipped {skipped} CIFs due to missing labels.")

# Optional: Check for label entries with no matching CIFs
label_fnames = set(label_data.keys())
cif_fnames = set(os.listdir(CIF_DIR))
unused_labels = label_fnames - cif_fnames
if unused_labels:
    print(f"[INFO] {len(unused_labels)} labels had no matching CIF file:")
    for name in sorted(unused_labels)[:5]:
        print("  ", name)
    if len(unused_labels) > 5:
        print("  ...")



[✓] Saved 3000 structures to /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem.extxyz


In [46]:
# Preview a labeled structure
from ase.io import read
atoms = read(OUT_XYZ, index=0)
print("File:", atoms.info["snapshot_file"])
print("Energy:", atoms.info["REF_energy"])
print("Forces shape:", atoms.arrays["REF_forces"].shape)
print("Formula:", atoms.get_chemical_formula())


File: 47f7ac20-5f9d-577f-87ae-0d21207606bf__T360K__step1000.cif
Energy: 10.264266014099121
Forces shape: (3, 3)
Formula: Li2O


In [47]:
from ase.io import read, write
import numpy as np

IN_FILE = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem.extxyz"
OUT_VAL = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz"
OUT_TEST = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz"

# Load all structures
frames = read(IN_FILE, ":")

# Shuffle with fixed seed
np.random.seed(42)
indices = np.random.permutation(len(frames))

# 90% val, 10% test
split_at = int(0.6 * len(frames))
val_indices = indices[:split_at]
test_indices = indices[split_at:]

val_frames = [frames[i] for i in val_indices]
test_frames = [frames[i] for i in test_indices]

# Save them
write(OUT_VAL, val_frames, format="extxyz", write_info=True, write_results=True)
write(OUT_TEST, test_frames, format="extxyz", write_info=True, write_results=True)

print(f"[✓] Saved {len(val_frames)} to binary_elem_val.extxyz")
print(f"[✓] Saved {len(test_frames)} to binary_elem_test.extxyz")


[✓] Saved 1800 to binary_elem_val.extxyz
[✓] Saved 1200 to binary_elem_test.extxyz
