In [3]:
# Now universal MACE finetuning on T2
#!/usr/bin/env python3
import os
import subprocess
import sys

def main():
    """Launch a MACE fine-tuning run starting from the universal foundation model.

    Builds the ``mace_run_train`` command line, echoes it to stderr in a
    copy-pasteable, backslash-continued form, and runs it as a child process.

    Raises:
        subprocess.CalledProcessError: if ``mace_run_train`` exits non-zero.
    """
    # ——— Environment setup ———
    # Hand these variables to the child process only, instead of mutating the
    # notebook kernel's own os.environ (which would persist across cells).
    env = dict(os.environ)
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # NOTE(review): the traceback recorded under this cell shows `mace` being
    # imported from site-packages, so this PYTHONPATH may not take effect — confirm.
    env["PYTHONPATH"] = "/home/phanim/harshitrawat/mace/mace"

    cmd = [
        "mace_run_train",
        "--name",              "mace_T3_finetune_h200_cn10",
        "--model",             "MACE",
        "--num_interactions",  "2",
        "--foundation_model",  "/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",

        # NOTE(review): the split cell in this notebook puts 40% of
        # binary_elem.extxyz in *_test and 60% in *_val, so this trains on the
        # smaller split and validates on the larger one — confirm intended.
        "--train_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz",
        "--valid_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz",

        "--batch_size",        "2",
        "--valid_batch_size",  "1",
        # Fraction of the training set held out for validation; presumably
        # unused when --valid_file is given (as here) — confirm.
        "--valid_fraction",    "0.1",

        "--device",            "cuda",

        # === Loss function weights ===
        "--forces_weight",     "50",         # Increased force weight to balance energy better
        "--energy_weight",     "25",         # Reduced from 100 → avoid dominance + stabilize energy RMSE

        # === Learning setup ===
        # NOTE(review): in the recorded log, validation loss rises after the
        # restart (epoch 112 → 115, RMSE_F ~21.6 → ~34.9 eV/A); consider a lower lr.
        "--lr",                "0.001",      # Explicit learning rate (0.0001 is too low → stagnation)
        "--scheduler_patience", "8",         # Reduce LR if val loss doesn't improve for 8 epochs
        "--clip_grad",         "10",         # Clip gradients at 10 to avoid explosions
        "--weight_decay",      "1e-8",       # Mild regularization to prevent overfitting

        # === EMA helps smooth loss curve ===
        #"--ema_decay",         "0.999",     # Smooths validation loss and helps final convergence

        # === Domain + training settings ===
        "--r_max",             "5.0",
        "--max_num_epochs",    "130",
        "--E0s",               "average",    # Still allowed — could optionally be replaced by manual E0s
        "--seed",              "21",
        "--patience",          "10",

        # Resumes from the latest checkpoint for this --name if one exists
        # (the log shows it picking up epoch 112); change --name for a fresh run.
        "--restart_latest",
    ]

    print("Running:", " \\\n    ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True, env=env)

if __name__ == "__main__":
    main()


Running: mace_run_train \
    --name \
    mace_T3_finetune_h200_cn10 \
    --model \
    MACE \
    --num_interactions \
    2 \
    --foundation_model \
    /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model \
    --foundation_model_readout \
    --train_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz \
    --valid_file \
    /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz \
    --batch_size \
    2 \
    --valid_batch_size \
    1 \
    --valid_fraction \
    0.1 \
    --device \
    cuda \
    --forces_weight \
    50 \
    --energy_weight \
    25 \
    --lr \
    0.001 \
    --scheduler_patience \
    8 \
    --clip_grad \
    10 \
    --weight_decay \
    1e-8 \
    --r_max \
    5.0 \
    --max_num_epochs \
    130 \
    --E0s \
    average \
    --seed \
    21 \
    --patience \
    10 \
    --restart_latest
  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirnam

2025-08-14 21:42:50.954 INFO: MACE version: 0.3.14
2025-08-14 21:42:51.661 INFO: CUDA version: 12.6, CUDA device: 0


  model_foundation = torch.load(


2025-08-14 21:42:52.339 INFO: Using foundation model /home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model as initial checkpoint.
2025-08-14 21:42:52.340 INFO: Using heads: ['Default']
2025-08-14 21:42:52.340 INFO: Using the key specifications to parse data:
2025-08-14 21:42:52.340 INFO: Default: KeySpecification(info_keys={'energy': 'REF_energy', 'stress': 'REF_stress', 'virials': 'REF_virials', 'dipole': 'dipole', 'head': 'head', 'elec_temp': 'elec_temp', 'total_charge': 'total_charge', 'polarizability': 'polarizability', 'total_spin': 'total_spin'}, arrays_keys={'forces': 'REF_forces', 'charges': 'REF_charges'})
2025-08-14 21:42:52.598 INFO: Training set 1/1 [energy: 1200, stress: 0, virials: 0, dipole components: 0, head: 1200, elec_temp: 0, total_charge: 0, polarizability: 0, total_spin: 0, forces: 1200, charges: 0]
2025-08-14 21:42:52.602 INFO: Total Training set [energy: 1200, stress: 0, virials: 0, dipole components: 0, head: 1200, elec_t



2025-08-14 21:42:56.352 INFO: Total number of parameters: 894362
2025-08-14 21:42:56.352 INFO: 
2025-08-14 21:42:56.352 INFO: Using ADAM as parameter optimizer
2025-08-14 21:42:56.352 INFO: Batch size: 2
2025-08-14 21:42:56.352 INFO: Number of gradient updates: 78000
2025-08-14 21:42:56.352 INFO: Learning rate: 0.001, weight decay: 1e-08
2025-08-14 21:42:56.352 INFO: WeightedEnergyForcesLoss(energy_weight=25.000, forces_weight=50.000)
2025-08-14 21:42:56.509 INFO: Loading checkpoint: ./checkpoints/mace_T3_finetune_h200_cn10_run-21_epoch-112.pt
2025-08-14 21:42:56.543 INFO: Using gradient clipping with tolerance=10.000
2025-08-14 21:42:56.543 INFO: 
2025-08-14 21:42:56.543 INFO: Started training, reporting errors on validation set
2025-08-14 21:42:56.543 INFO: Loss metrics on validation set


  torch.load(f=checkpoint_info.path, map_location=device),


2025-08-14 21:44:14.802 INFO: Initial: head: Default, loss=15211.38385703, RMSE_E_per_atom=  763.95 meV, RMSE_F=20340.81 meV / A
2025-08-14 21:46:00.373 INFO: Epoch 112: head: Default, loss=16184.11430121, RMSE_E_per_atom=  842.42 meV, RMSE_F=21603.70 meV / A
2025-08-14 21:47:43.576 INFO: Epoch 113: head: Default, loss=32196.48533347, RMSE_E_per_atom=  661.55 meV, RMSE_F=34615.29 meV / A
2025-08-14 21:49:26.379 INFO: Epoch 114: head: Default, loss=24416.95996169, RMSE_E_per_atom=  776.69 meV, RMSE_F=29142.52 meV / A
2025-08-14 21:51:09.169 INFO: Epoch 115: head: Default, loss=82011.67640299, RMSE_E_per_atom=  985.35 meV, RMSE_F=34935.34 meV / A


Traceback (most recent call last):
  File [35m"/home/phanim/harshitrawat/miniconda3/bin/mace_run_train"[0m, line [35m8[0m, in [35m<module>[0m
    sys.exit([31mmain[0m[1;31m()[0m)
             [31m~~~~[0m[1;31m^^[0m
  File [35m"/home/phanim/harshitrawat/miniconda3/lib/python3.13/site-packages/mace/cli/run_train.py"[0m, line [35m77[0m, in [35mmain[0m
    [31mrun[0m[1;31m(args)[0m
    [31m~~~[0m[1;31m^^^^^^[0m
  File [35m"/home/phanim/harshitrawat/miniconda3/lib/python3.13/site-packages/mace/cli/run_train.py"[0m, line [35m837[0m, in [35mrun[0m
    [31mtools.train[0m[1;31m([0m
    [31m~~~~~~~~~~~[0m[1;31m^[0m
        [1;31mmodel=model,[0m
        [1;31m^^^^^^^^^^^^[0m
    ...<23 lines>...
        [1;31mrank=rank,[0m
        [1;31m^^^^^^^^^^[0m
    [1;31m)[0m
    [1;31m^[0m
  File [35m"/home/phanim/harshitrawat/miniconda3/lib/python3.13/site-packages/mace/tools/train.py"[0m, line [35m261[0m, in [35mtrain[0m
    valid_loss_head, eva

KeyboardInterrupt: 

In [None]:
# Now universal MACE finetuning on T2
#!/usr/bin/env python3
import os
import subprocess
import sys

def main():
    """Launch a MACE fine-tuning run starting from the universal foundation model.

    Builds the ``mace_run_train`` command line, echoes it to stderr in a
    copy-pasteable, backslash-continued form, and runs it as a child process.

    Raises:
        subprocess.CalledProcessError: if ``mace_run_train`` exits non-zero.
    """
    # ——— Environment setup ———
    # Hand these variables to the child process only, instead of mutating the
    # notebook kernel's own os.environ (which would persist across cells).
    env = dict(os.environ)
    env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # NOTE(review): a traceback elsewhere in this notebook shows `mace` being
    # imported from site-packages, so this PYTHONPATH may not take effect — confirm.
    env["PYTHONPATH"] = "/home/phanim/harshitrawat/mace/mace"

    cmd = [
        "mace_run_train",
        "--name",              "mace_T3_finetune_h200_cn10",
        "--model",             "MACE",
        "--num_interactions",  "2",
        "--foundation_model",  "/home/phanim/harshitrawat/summer/mace_models/universal/2024-01-07-mace-128-L2_epoch-199.model",
        "--foundation_model_readout",

        # NOTE(review): the split cell in this notebook puts 40% of
        # binary_elem.extxyz in *_test and 60% in *_val, so this trains on the
        # smaller split and validates on the larger one — confirm intended.
        "--train_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz",
        "--valid_file",        "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz",

        "--batch_size",        "2",
        "--valid_batch_size",  "1",
        # Fraction of the training set held out for validation; presumably
        # unused when --valid_file is given (as here) — confirm.
        "--valid_fraction",    "0.1",

        "--device",            "cuda",

        # === Loss function weights ===
        "--forces_weight",     "50",         # Increased force weight to balance energy better
        "--energy_weight",     "25",         # Reduced from 100 → avoid dominance + stabilize energy RMSE

        # === Learning setup ===
        "--lr",                "0.001",      # Explicit learning rate (0.0001 is too low → stagnation)
        "--scheduler_patience", "8",         # Reduce LR if val loss doesn't improve for 8 epochs
        "--clip_grad",         "10",         # Clip gradients at 10 to avoid explosions
        "--weight_decay",      "1e-8",       # Mild regularization to prevent overfitting

        # === EMA helps smooth loss curve ===
        #"--ema_decay",         "0.999",     # Smooths validation loss and helps final convergence

        # === Domain + training settings ===
        "--r_max",             "5.0",
        "--max_num_epochs",    "130",
        "--E0s",               "average",    # Still allowed — could optionally be replaced by manual E0s
        "--seed",              "21",
        "--patience",          "10",

        # Resumes from the latest checkpoint for this --name if one exists;
        # change --name for a fresh run.
        "--restart_latest",
    ]

    print("Running:", " \\\n    ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True, env=env)

if __name__ == "__main__":
    main()


In [45]:
from ase.io import read, write
import numpy as np
import json, os

LABEL_PATH = "/home/phanim/harshitrawat/summer/md/mdlabels_it2.jsonl"
CIF_DIR = "/home/phanim/harshitrawat/summer/md/mdcifs_it2"
OUT_XYZ = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem.extxyz"

# Load labels from JSONL, keyed by basename so they match the CIF filenames.
# Blank lines are skipped; a later entry with the same basename wins.
with open(LABEL_PATH) as f:
    label_data = {
        os.path.basename(entry["snapshot_file"]): entry
        for entry in (json.loads(line) for line in f if line.strip())
    }

frames = []
skipped = 0      # CIFs with no matching label
failed = 0       # CIFs that could not be parsed
mismatched = 0   # CIFs whose force label does not match the atom count

for fname in sorted(os.listdir(CIF_DIR)):
    if not fname.endswith(".cif"):
        continue
    if fname not in label_data:
        print(f"[WARN] No label found for: {fname}")
        skipped += 1
        continue

    cif_path = os.path.join(CIF_DIR, fname)
    try:
        atoms = read(cif_path)
    except Exception as e:
        print(f"[ERROR] Failed to read {fname}: {e}")
        failed += 1
        continue

    meta = label_data[fname]
    forces = np.asarray(meta["forces_per_atom_eV_per_A"], dtype=float)
    # Assigning straight into atoms.arrays bypasses ASE's length check, so a
    # label whose atom count differs from the CIF would be written silently.
    # NOTE(review): this checks counts only, not that the CIF atom order
    # matches the label order — confirm.
    if forces.shape != (len(atoms), 3):
        print(f"[ERROR] Force shape {forces.shape} does not match {len(atoms)} atoms in {fname}")
        mismatched += 1
        continue

    # Energy and forces under the REF_* keys mace_run_train parses by default.
    atoms.info["REF_energy"] = meta["energy_eV"]
    atoms.set_array("REF_forces", forces)
    atoms.info["snapshot_file"] = fname  # traceability

    frames.append(atoms)

# Write the final .extxyz file
write(OUT_XYZ, frames)
print(f"\n[✓] Saved {len(frames)} structures to {OUT_XYZ}")
if skipped > 0:
    print(f"[i] Skipped {skipped} CIFs due to missing labels.")
if failed > 0:
    print(f"[i] Skipped {failed} CIFs that could not be read.")
if mismatched > 0:
    print(f"[i] Skipped {mismatched} CIFs whose forces did not match the atom count.")

# Optional: check for label entries with no matching CIF file
label_fnames = set(label_data)
cif_fnames = set(os.listdir(CIF_DIR))
unused_labels = label_fnames - cif_fnames
if unused_labels:
    print(f"[INFO] {len(unused_labels)} labels had no matching CIF file:")
    for name in sorted(unused_labels)[:5]:
        print("  ", name)
    if len(unused_labels) > 5:
        print("  ...")



[✓] Saved 3000 structures to /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem.extxyz


In [46]:
# Sanity-check the first frame of the extxyz file written above:
# its provenance, reference energy, force-array shape and formula.
from ase.io import read
first_frame = read(OUT_XYZ, index=0)
info, arrays = first_frame.info, first_frame.arrays
print("File:", info["snapshot_file"])
print("Energy:", info["REF_energy"])
print("Forces shape:", arrays["REF_forces"].shape)
print("Formula:", first_frame.get_chemical_formula())


File: 47f7ac20-5f9d-577f-87ae-0d21207606bf__T360K__step1000.cif
Energy: 10.264266014099121
Forces shape: (3, 3)
Formula: Li2O


In [47]:
from ase.io import read, write
import numpy as np

IN_FILE = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem.extxyz"
OUT_VAL = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz"
OUT_TEST = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz"

SPLIT_SEED = 42
VAL_FRACTION = 0.6  # 60% -> *_val, remaining 40% -> *_test

# Load all structures
frames = read(IN_FILE, ":")

# Shuffle with a local fixed-seed generator. RandomState(seed).permutation
# gives the same order as np.random.seed(seed) + np.random.permutation, but
# leaves numpy's global RNG state untouched for later cells.
indices = np.random.RandomState(SPLIT_SEED).permutation(len(frames))

# 60% val, 40% test
# NOTE(review): the MACE cells train on *_test and validate on *_val, i.e.
# train on the smaller 40% share — confirm that is intended.
split_at = int(VAL_FRACTION * len(frames))
val_indices = indices[:split_at]
test_indices = indices[split_at:]

val_frames = [frames[i] for i in val_indices]
test_frames = [frames[i] for i in test_indices]
assert len(val_frames) + len(test_frames) == len(frames)

# Save them
write(OUT_VAL, val_frames, format="extxyz", write_info=True, write_results=True)
write(OUT_TEST, test_frames, format="extxyz", write_info=True, write_results=True)

print(f"[✓] Saved {len(val_frames)} to binary_elem_val.extxyz")
print(f"[✓] Saved {len(test_frames)} to binary_elem_test.extxyz")


[✓] Saved 1800 to binary_elem_val.extxyz
[✓] Saved 1200 to binary_elem_test.extxyz


In [50]:
from ase.io import read, write

# Input files
# NOTE(review): same inputs as the T2_it_2.extxyz merge further down; one of
# the two outputs may be redundant — confirm.
files_to_merge = [
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz",
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz"
]

# Output file
output_file = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/all_train.extxyz"

# Read all frames (index=":" selects every frame in a file)
all_frames = []
for f in files_to_merge:
    all_frames.extend(read(f, index=":"))

# Write merged file; count from the frames already in memory instead of
# re-reading every input file from disk.
write(output_file, all_frames)
print(f"[✓] Merged {len(all_frames)} frames into {output_file}")


[✓] Merged 2812 frames into /home/phanim/harshitrawat/summer/T1_T2_T3_data/all_train.extxyz


In [54]:
from ase.io import read, write

def merge_extxyz(input_paths, output_path):
    """Concatenate every frame of ``input_paths`` (in order) into ``output_path``.

    Prints the number of frames read from each input and returns the merged
    list of frames.
    """
    merged = []
    for path in input_paths:
        frames = read(path, index=":")
        merged.extend(frames)
        print(f"[✓] Read {len(frames)} frames from {path}")
    write(output_path, merged)
    return merged


# === TRAIN MERGE ===
# NOTE(review): the inputs are the T3 CHGNet labels but the output is named
# T2_it_2 — confirm which dataset this file is meant to hold.
train_files = [
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz",
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz"
]
train_output = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_it_2.extxyz"

train_frames = merge_extxyz(train_files, train_output)
print(f"[✅] Wrote {len(train_frames)} total frames to {train_output}\n")


# === VALIDATION MERGE ===
valid_files = [
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz",
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz"
]
valid_output = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/val_it2.extxyz"

valid_frames = merge_extxyz(valid_files, valid_output)
print(f"[✅] Wrote {len(valid_frames)} total frames to {valid_output}")


[✓] Read 1612 frames from /home/phanim/harshitrawat/summer/T1_T2_T3_data/T3_chgnet_labeled.extxyz
[✓] Read 1200 frames from /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz
[✅] Wrote 2812 total frames to /home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_it_2.extxyz

[✓] Read 705 frames from /home/phanim/harshitrawat/summer/T1_T2_T3_data/T2_chgnet_labeled.extxyz
[✓] Read 1800 frames from /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_val.extxyz
[✅] Wrote 2505 total frames to /home/phanim/harshitrawat/summer/T1_T2_T3_data/val_it2.extxyz


In [55]:
from ase.io import read, write

# === TRAIN MERGE ===
# T1 CHGNet labels + the binary-element "test" split -> T1_it_2.extxyz
train_files = [
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T1_chgnet_labeled.extxyz",
    "/home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz"
]
train_output = "/home/phanim/harshitrawat/summer/T1_T2_T3_data/T1_it_2.extxyz"

# Concatenate every frame of each input, in order, reporting per-file counts.
train_frames = []
for src_path in train_files:
    chunk = read(src_path, index=":")
    train_frames += chunk
    print(f"[✓] Read {len(chunk)} frames from {src_path}")

write(train_output, train_frames)
print(f"[✅] Wrote {len(train_frames)} total frames to {train_output}\n")


[✓] Read 6337 frames from /home/phanim/harshitrawat/summer/T1_T2_T3_data/T1_chgnet_labeled.extxyz
[✓] Read 1200 frames from /home/phanim/harshitrawat/summer/T1_T2_T3_data/binary_elem_test.extxyz
[✅] Wrote 7537 total frames to /home/phanim/harshitrawat/summer/T1_T2_T3_data/T1_it_2.extxyz



In [11]:
from chgnet.model.model import CHGNet
from chgnet.model.dynamics import CHGNetCalculator
from pymatgen.core import Element, Structure
from pymatgen.io.ase import AseAtomsAdaptor
import torch
import os
import json

# === CHGNet GPU-safe loading ===
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CHGNet.load(use_device="cpu", verbose=True)  # load on CPU first, then move
model = model.to(device)
calc = CHGNetCalculator(model=model, use_device=device)

adaptor = AseAtomsAdaptor()

# === CIFs to evaluate ===
# Presumably the <El>.cif files written by the Materials Project cell — confirm.
elements = {
    "Li": "Li.cif",
    "La": "La.cif",
    "Zr": "Zr.cif",
    "O":  "O.cif"
}

e0s = {}
for el, fname in elements.items():
    if not os.path.exists(fname):
        print(f"❌ Missing CIF: {fname}")
        continue
    struct = Structure.from_file(fname)
    atoms = adaptor.get_atoms(struct)
    atoms.calc = calc
    energy = atoms.get_potential_energy()
    mu = energy / len(atoms)  # per-atom energy of the elemental structure
    e0s[el] = round(mu, 6)
    print(f"{el}: μ_model = {mu:.6f} eV/atom")

# A missing CIF would otherwise silently produce an incomplete E0 table.
missing = sorted(set(elements) - set(e0s))
if missing:
    print(f"⚠️ No E₀ computed for {missing}; the --E0s table below is incomplete.")

# Format for MACE: keys are atomic numbers.
# NOTE(review): these are per-atom energies of bulk elemental phases
# (chemical potentials), not isolated-atom energies — confirm that is what
# you want as MACE's E0s.
e0s_z = {Element(k).Z: v for k, v in e0s.items()}
print("\n✅ CHGNet E₀s (use with --E0s in MACE):")
print(json.dumps(e0s_z, indent=2))
# JSON turns the keys into strings ("3"); the --E0s command-line value is a
# Python dict literal with integer keys, so also print that copy-pasteable form.
print("--E0s value:", repr(e0s_z))


  state = torch.load(path, map_location=torch.device("cpu"))


CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu
CHGNet will run on cuda:0
Li: μ_model = -1.882089 eV/atom
La: μ_model = -4.894711 eV/atom
Zr: μ_model = -8.509101 eV/atom


  struct = parser.parse_structures(primitive=primitive)[0]


O: μ_model = -4.913321 eV/atom

✅ CHGNet E₀s (use with --E0s in MACE):
{
  "3": -1.882089,
  "57": -4.894711,
  "40": -8.509101,
  "8": -4.913321
}


In [10]:
from pymatgen.ext.matproj import MPRester
from pymatgen.core import Element
from pymatgen.analysis.phase_diagram import PhaseDiagram

import os
from getpass import getpass

# Read the Materials Project API key from the environment, or prompt for it,
# instead of hard-coding it in the notebook where it leaks with every copy.
# NOTE(review): any key previously pasted into this cell should be treated as
# exposed and rotated.
API_KEY = os.environ.get("MP_API_KEY") or getpass("Materials Project API key: ")

ELEMENTS = ["Li", "La", "Zr", "O"]

with MPRester(API_KEY) as mpr:
    print("🔍 Fetching elemental entries...")
    entries = mpr.get_entries_in_chemsys(ELEMENTS, inc_structure="final")
    # Keep only pure-element entries (a single distinct element).
    el_entries = [e for e in entries if len(e.composition) == 1]

# Named phase_diagram rather than `pd` to avoid shadowing the pandas alias.
phase_diagram = PhaseDiagram(el_entries)
stable = {}

for el in ELEMENTS:
    # el_refs maps each Element to its lowest-energy elemental reference entry.
    # An explicit lookup replaces a bare `except:` that would hide any error.
    entry = phase_diagram.el_refs.get(Element(el))
    if entry is None:
        print(f"⚠️ Could not find stable entry for {el}")
        continue
    stable[el] = entry

for el, entry in stable.items():
    # Assumes the entry carries a structure (requested via inc_structure) — confirm.
    structure = entry.structure
    filename = f"{el}.cif"
    structure.to(fmt="cif", filename=filename)
    print(f"✅ Saved {el}: {filename}")


🔍 Fetching elemental entries...


MPRestError: HTTPSConnectionPool(host='api.materialsproject.org', port=443): Max retries exceeded with url: /materials/thermo/?_fields=entries&chemsys=Li,La,Zr,O,La-Li,Li-Zr,Li-O,La-Zr,La-O,O-Zr,La-Li-Zr,La-Li-O,Li-O-Zr,La-O-Zr,La-Li-O-Zr&_per_page=1000&_page=1 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x7fef2072e290>: Failed to resolve 'api.materialsproject.org' ([Errno -3] Temporary failure in name resolution)"))

In [8]:
import os
from getpass import getpass

# `!export` runs in a throwaway subshell, so it never reaches this kernel's
# environment, and it leaves the API key in plain text in the notebook.
# Prompt instead; the key lives only in this kernel's process environment.
# NOTE(review): the key previously written here should be treated as exposed and rotated.
if not os.environ.get("MP_API_KEY"):
    os.environ["MP_API_KEY"] = getpass("Materials Project API key: ")
