In [1]:
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from pathlib import Path

from jmp.configs.finetune.jmp_l import jmp_l_ft_config_
from jmp.configs.finetune.rmd17 import jmp_l_rmd17_config_
from jmp.tasks.finetune.base import FinetuneConfigBase, FinetuneModelBase
from jmp.tasks.finetune.rmd17 import RMD17Config, RMD17Model

# Root directory of the rMD17 data and the pretrained checkpoint to fine-tune from.
base_path = Path("/mnt/shared/datasets/rmd17/")
ckpt_path = Path("/mnt/shared/checkpoints/fm_gnoc_large_2_epoch.ckpt")

# Start from a draft config, let the two helpers fill it in, then finalize it.
# NOTE(review): the recorded output of this cell warns that BaseConfig._rng is
# None, so generated IDs are not reproducible; the warning itself suggests
# calling BaseConfig.set_seed(...) before any ID is generated — confirm the API.
config = RMD17Config.draft()
jmp_l_ft_config_(config, ckpt_path)  # base JMP-L fine-tuning settings
jmp_l_rmd17_config_(config, "aspirin", base_path)  # rMD17-specific settings
config = config.finalize()  # turn the draft into the actual config object
print(config)

# Every (config, model class) pair we want to run; a single entry for now.
configs: list[tuple[FinetuneConfigBase, type[FinetuneModelBase]]] = [
    (config, RMD17Model)
]

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading modules with transformers dependency. No module named 'transformers'
cannot import name 'HuggingFaceModel' from 'deepchem.models.torch_models' (/opt/conda/envs/jmp/lib/python3.11/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading some Jax models, missing a dependency. No module named 'jax'
Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'


id='mpo5imfg' trainer=TrainerConfig(optimizer=OptimizerConfig(log_grad_norm=True, gradient_clipping=GradientClippingConfig(value=1.0, algorithm='value')), supports_skip_batch_exception=False, supports_parameter_hooks=False, set_float32_matmul_precision='medium', precision='16-mixed', max_epochs=100000, max_time='07:00:00:00', inference_mode=False) meta={'ckpt_path': PosixPath('/mnt/shared/checkpoints/fm_gnoc_large_2_epoch.ckpt'), 'ema_backbone': True} train_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/train'), metadata_path=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/train/metadata.npz')) val_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/val'), metadata_path=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/val/metadata.npz')) test_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/test'), metadata_path=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/t

/workspaces/repositories/fm/src/jmp/lightning/model/config.py:709: BaseConfig._rng is None. The generated IDs will not be reproducible. To fix this, call BaseConfig.set_seed(...) before generating any IDs.


In [2]:
from jmp.lightning import Runner, Trainer
from jmp.utils.finetune_state_dict import (
    filter_state_dict,
    retreive_state_dict_for_finetuning,
)


def run(config: FinetuneConfigBase, model_cls: type[FinetuneModelBase]) -> None:
    """Build a fine-tuning model, load pretrained weights into it, and fit it.

    Args:
        config: Fine-tuning config. ``config.meta["ckpt_path"]`` must be set;
            ``config.meta["ema_backbone"]`` (default ``False``) is forwarded as
            ``load_emas`` when the checkpoint is read.
        model_cls: Model class, instantiated as ``model_cls(config)``.

    Raises:
        ValueError: If ``config.meta`` has no ``"ckpt_path"`` entry (or it is
            ``None``).
    """
    checkpoint = config.meta.get("ckpt_path")
    if checkpoint is None:
        raise ValueError("No checkpoint path provided")

    model = model_cls(config)

    # Read the pretrained weights, then split out the two groups of keys the
    # model expects (same order as before: embedding first, then backbone).
    use_ema = config.meta.get("ema_backbone", False)
    pretrained = retreive_state_dict_for_finetuning(checkpoint, load_emas=use_ema)
    embedding_weights = filter_state_dict(pretrained, "embedding.atom_embedding.")
    backbone_weights = filter_state_dict(pretrained, "backbone.")
    model.load_backbone_state_dict(
        backbone=backbone_weights,
        embedding=embedding_weights,
        strict=True,
    )

    Trainer(config).fit(model)


# Hand `run` to a Runner and execute it over `configs` in fast-dev-run mode.
# NOTE(review): in the recorded output, fitting stops once `max_steps=1` is
# reached and the cell displays `[None]` — presumably one return value of
# `run` per config; confirm against the Runner implementation.
runner = Runner(run)
runner.fast_dev_run(configs)

Fast dev run:   0%|          | 0/1 [00:00<?, ?it/s]

Failed to import rich. Falling back to default Python logging.
CRITICAL:jmp.lightning.trainer.trainer:Setting config.trainer.default_root_dir='/workspaces/repositories/fm/config/lightning_logs/3aghcddl'.
Seed set to 0
CRITICAL:jmp.lightning.util.seed:Set global seed to 0.
CRITICAL:jmp.lightning.runner:Auto-wrapping run in Trainer context
CRITICAL:jmp.tasks.finetune.base:Using regular backbone


Unrecognized arguments:  dict_keys(['learnable_rbf', 'learnable_rbf_stds', 'unique_basis_per_layer', 'dropout', 'edge_dropout'])


CRITICAL:jmp.tasks.finetune.base:Freezing 0 parameters (0.00%) out of 160,874,752 total parameters (160,874,752 trainable)
CRITICAL:jmp.utils.finetune_state_dict:Loaded 585 EMA parameters
CRITICAL:jmp.utils.finetune_state_dict:Loaded state dict from /mnt/shared/checkpoints/fm_gnoc_large_2_epoch.ckpt
CRITICAL:jmp.utils.state_dict:pattern='out_blocks.0.scale_rbf_F.*' matched keys ['out_blocks.0.scale_rbf_F.scale_factor'], which were ignored during loading.
CRITICAL:jmp.utils.state_dict:pattern='out_blocks.0.seq_forces.*' matched keys ['out_blocks.0.seq_forces.0.dense_mlp.0.linear.weight', 'out_blocks.0.seq_forces.0.dense_mlp.1.linear.weight', 'out_blocks.0.seq_forces.1.dense_mlp.0.linear.weight', 'out_blocks.0.seq_forces.1.dense_mlp.1.linear.weight', 'out_blocks.0.seq_forces.2.dense_mlp.0.linear.weight', 'out_blocks.0.seq_forces.2.dense_mlp.1.linear.weight'], which were ignored during loading.
CRITICAL:jmp.utils.state_dict:pattern='out_blocks.0.dense_rbf_F.*' matched keys ['out_blocks.0.

Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.
CRITICAL:jmp.lightning.trainer.trainer:Ran 1 finalizers for Trainer cleanup.
Seed set to 0
CRITICAL:jmp.lightning.util.seed:Reset global seed.


[None]