In [9]:
import os
import shutil

from equinet.utils import (
    load_args,
    load_checkpoint,
    load_scalers,
    save_checkpoint,
)

# üîÅ Set the version you want to retro-tag with
TARGET_VERSION = "0.2.0"

# üìÇ Base directory where your .pt files live
BASE_DIR = r"C:\Users\alamz\Desktop\Servers\zenodofiles\Models\EquiNet\ANTOINE-NRTL\Without Self Activity"  # <-- change this


def update_one_checkpoint(path: str, dry_run: bool = False):
    print(f"\n=== Updating checkpoint: {path} ===")

    # 1) Load components
    args = load_args(path)
    model = load_checkpoint(path, device=None)  # adjust device arg if needed
    scalers = load_scalers(path)  # should return list of 6 scalers

    # 2) Set versions on args and model
    old_args_version = getattr(args, "version", None)
    old_model_version = getattr(model, "version", None)

    args.version = TARGET_VERSION
    setattr(model, "version", TARGET_VERSION)

    print(f"  args.version:  {old_args_version!r} -> {args.version!r}")
    print(f"  model.version: {old_model_version!r} -> {model.version!r}")

    if dry_run:
        print("  [DRY RUN] Not writing changes.")
        return

    # 3) Make a backup
    backup_path = path + ".bak"
    if not os.path.exists(backup_path):
        shutil.copy2(path, backup_path)
        print(f"  Backup written to {backup_path}")
    else:
        print(f"  Backup already exists at {backup_path}, not overwriting.")

    # 4) Save updated checkpoint
    # save_checkpoint(path, model, scaler1, scaler2, ..., args)
    # scalers is expected to be a list of 6 scalers.
    save_checkpoint(path, model, *scalers, args)
    print("  ‚úÖ Checkpoint updated.")


def main():
    # üîç 1) FIRST: test on a single file you don‚Äôt care about
    # For safety, set SINGLE_PATH to a specific file and run once.
    SINGLE_PATH = None  # e.g. "/home/.../model_0/test_model.pt"

    if SINGLE_PATH is not None:
        update_one_checkpoint(SINGLE_PATH, dry_run=False)
        return

    # üîÅ 2) Once you're confident, walk a tree of checkpoints
    for root, _, files in os.walk(BASE_DIR):
        for fname in files:
            if not fname.endswith(".pt"):
                continue
            full_path = os.path.join(root, fname)
            update_one_checkpoint(full_path, dry_run=False)


if __name__ == "__main__":
    main()



=== Updating checkpoint: C:\Users\alamz\Desktop\Servers\zenodofiles\Models\EquiNet\ANTOINE-NRTL\Without Self Activity\model.pt ===
Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "encoder.encoder.1.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.1.W_i.weight".
Loading pretrained parameter "encoder.encoder.1.W_h.weight".
Loading pretrained parameter "encoder.encoder.1.W_o.weight".
Loading pretrained parameter "encoder.encoder.1.W_o.bias".
Loading pretrained parameter "readout.1.weight".
Loading pretrained parameter "readout.1.bias".
Loading pretrained parameter "readout.4.weight".
Loading pretrained parameter "readout.4.bias".
Loading pretrained parameter "readout.7.weight".
Loadi

In [12]:
import torch
from equinet.utils import load_args, load_checkpoint

path = r"C:\Users\alamz\Desktop\Servers\zenodofiles\test\model.pt"
args = load_args(path)
model = load_checkpoint(path, device=None)

print("args.version:", getattr(args, "version", None))
print("model.version:", getattr(model, "version", None))


Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "encoder.encoder.1.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.1.W_i.weight".
Loading pretrained parameter "encoder.encoder.1.W_h.weight".
Loading pretrained parameter "encoder.encoder.1.W_o.weight".
Loading pretrained parameter "encoder.encoder.1.W_o.bias".
Loading pretrained parameter "readout.1.weight".
Loading pretrained parameter "readout.1.bias".
Loading pretrained parameter "readout.4.weight".
Loading pretrained parameter "readout.4.bias".
Loading pretrained parameter "intrinsic_vp.1.weight".
Loading pretrained parameter "intrinsic_vp.1.bias".
Loading pretrained parameter "intrinsic_vp.4.weight".
Loading pretrained paramet

In [13]:
import os
from equinet.utils import load_args, load_checkpoint, load_scalers, save_checkpoint

# <-- put the exact path you used in your tester here
ckpt_path = r"C:\Users\alamz\Desktop\Servers\zenodofiles\test\model.pt"

print("Loading from:", ckpt_path)
args = load_args(ckpt_path)
model = load_checkpoint(ckpt_path, device=None)
scalers = load_scalers(ckpt_path)

print("BEFORE:")
print("  args.version  :", getattr(args, "version", "<no attr>"))
print("  model.version :", getattr(model, "version", "<no attr>"))

# Manually set version
if not hasattr(args, "version"):
    args.version = "0.2.0"
if not hasattr(model, "version"):
    model.version = "0.2.0"

if args.version is None:
    args.version = "0.2.0"
if model.version is None:
    model.version = "0.2.0"

print("AFTER setting:")
print("  args.version  :", args.version)
print("  model.version :", model.version)

# Re-save
save_checkpoint(
    path=ckpt_path,
    model=model,
    scaler=scalers[0] if isinstance(scalers, (list, tuple)) else scalers,
    features_scaler=None,
    atom_descriptor_scaler=None,
    bond_descriptor_scaler=None,
    phase_features_scaler=None,
    args=args,
)
print("‚úÖ Checkpoint updated.")


Loading from: C:\Users\alamz\Desktop\Servers\zenodofiles\test\model.pt
Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "encoder.encoder.1.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.1.W_i.weight".
Loading pretrained parameter "encoder.encoder.1.W_h.weight".
Loading pretrained parameter "encoder.encoder.1.W_o.weight".
Loading pretrained parameter "encoder.encoder.1.W_o.bias".
Loading pretrained parameter "readout.1.weight".
Loading pretrained parameter "readout.1.bias".
Loading pretrained parameter "readout.4.weight".
Loading pretrained parameter "readout.4.bias".
Loading pretrained parameter "intrinsic_vp.1.weight".
Loading pretrained parameter "intrinsic_vp.1.bias".
Loading p

TypeError: save_checkpoint() got an unexpected keyword argument 'phase_features_scaler'