In [2]:
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from pathlib import Path

from jmp.configs.pretrain.jmp_l import jmp_l_pt_config_
from jmp.tasks.pretrain import PretrainConfig, PretrainModel
from jmp.tasks.pretrain.module import (
    NormalizationConfig,
    PretrainDatasetConfig,
    TaskConfig,
)


# Build the JMP-L pre-training config.
def _task_config(
    name: str,
    train_dataset,
    val_dataset,
    *,
    force_loss_scale: float,
    y_std: float,
    force_std: float,
    energy_loss_scale: float = 1.0,
):
    """Build one pre-training ``TaskConfig`` with zero-mean energy ("y") and force normalization.

    All JMP-L pre-training tasks in this notebook share the same structure; only the
    dataset locations, the force-loss weight and the two normalization standard
    deviations differ between them.

    Args:
        name: Task name (e.g. ``"oc20"``).
        train_dataset: ``PretrainDatasetConfig`` for the training split.
        val_dataset: ``PretrainDatasetConfig`` for the validation split.
        force_loss_scale: Weight applied to the force loss for this task.
        y_std: Standard deviation used to normalize the energy target ``"y"``.
        force_std: Standard deviation used to normalize the ``"force"`` target.
        energy_loss_scale: Weight applied to the energy loss (1.0 for every task here).

    Returns:
        A ``TaskConfig`` combining the arguments above.
    """
    return TaskConfig(
        name=name,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        energy_loss_scale=energy_loss_scale,
        force_loss_scale=force_loss_scale,
        normalization={
            "y": NormalizationConfig(mean=0.0, std=y_std),
            "force": NormalizationConfig(mean=0.0, std=force_std),
        },
    )


def jmp_l_config(
    oc20_root: Path = Path("/datasets/s2ef"),
    shared_root: Path = Path("/shared/pre-training-datasets"),
    batch_size: int = 4,
    num_workers: int = 8,
):
    """Build the finalized JMP-L pre-training config over OC20, OC22, ANI-1x and Transition-1x.

    The defaults reproduce the original hard-coded cluster paths and data settings, so
    ``jmp_l_config()`` behaves exactly as before. Override the roots to run elsewhere.

    Args:
        oc20_root: Root of the OC20 S2EF data. Expected layout beneath it:
            ``2M/train/``, ``2M/train_metadata.npz``, ``all/val_id/`` and
            ``all/val_id_metadata.npz``.
        shared_root: Root holding ``oc22/s2ef-total/{train,val_id}``,
            ``ani1x/{train,val}`` and ``trans1x/{train,val}``.
        batch_size: Value written to ``config.batch_size``.
        num_workers: Value written to ``config.num_workers``.

    Returns:
        The result of ``config.finalize()``.
    """
    # Accept plain strings as well as Path objects for the dataset roots.
    oc20_root = Path(oc20_root)
    shared_root = Path(shared_root)

    config = PretrainConfig.draft()

    # Apply the JMP-L preset to the draft config (called for its side effect;
    # the return value is not used).
    jmp_l_pt_config_(config)

    # Set data config
    config.batch_size = batch_size
    config.num_workers = num_workers

    # Per-dataset force-loss weights and normalization stds are fixed constants of
    # this pre-training recipe; energies and forces are normalized with mean 0.
    config.tasks = [
        _task_config(
            "oc20",
            PretrainDatasetConfig(
                src=oc20_root / "2M" / "train",
                metadata_path=oc20_root / "2M" / "train_metadata.npz",
            ),
            PretrainDatasetConfig(
                src=oc20_root / "all" / "val_id",
                metadata_path=oc20_root / "all" / "val_id_metadata.npz",
            ),
            force_loss_scale=73.0,
            y_std=24.901469505465872,
            force_std=0.5111534595489502,
        ),
        _task_config(
            "oc22",
            PretrainDatasetConfig(src=shared_root / "oc22" / "s2ef-total" / "train"),
            PretrainDatasetConfig(src=shared_root / "oc22" / "s2ef-total" / "val_id"),
            force_loss_scale=80.0,
            y_std=25.229595396538468,
            force_std=0.25678861141204834,
        ),
        _task_config(
            "ani1x",
            PretrainDatasetConfig(src=shared_root / "ani1x" / "train"),
            PretrainDatasetConfig(src=shared_root / "ani1x" / "val"),
            force_loss_scale=15.0,
            y_std=2.8700712783472118,
            force_std=2.131422996520996,
        ),
        _task_config(
            "transition1x",
            PretrainDatasetConfig(src=shared_root / "trans1x" / "train"),
            PretrainDatasetConfig(src=shared_root / "trans1x" / "val"),
            force_loss_scale=14.0,
            y_std=1.787466168382901,
            force_std=0.3591422140598297,
        ),
    ]

    return config.finalize()


config = jmp_l_config()
print(config)

# Experiments to run: one (config, model class) pair each, consumed by
# `runner.fast_dev_run` in the next cell.
configs: list[tuple[PretrainConfig, type[PretrainModel]]] = [(config, PretrainModel)]

id='fy0th22x' trainer=TrainerConfig(optimizer=OptimizerConfig(log_grad_norm=True, gradient_clipping=GradientClippingConfig(value=1.0)), supports_skip_batch_exception=False, supports_parameter_hooks=False, set_float32_matmul_precision='medium', precision='16-mixed', use_distributed_sampler=False) optimizer=AdamWConfig(lr=0.0003, weight_decay=0.1, betas=(0.9, 0.95)) lr_scheduler=LinearWarmupCosineAnnealingSchedulerConfig(warmup_steps=2000, max_epochs=2, warmup_start_lr_factor=0.2, min_lr_factor=0.1) edge_dropout=0.1 backbone=BackboneConfig(num_spherical=7, num_radial=128, num_blocks=4, emb_size_atom=256, emb_size_edge=512, emb_size_trip_in=64, emb_size_trip_out=64, emb_size_quad_in=32, emb_size_quad_out=32, emb_size_aint_in=64, emb_size_aint_out=64, emb_size_rbf=16, emb_size_cbf=16, emb_size_sbf=32, num_before_skip=2, num_after_skip=2, num_concat=1, num_atom=3, num_output_afteratom=3, num_atom_emb_layers=2, direct_forces=True, sbf={'name': 'legendre_outer'}, quad_interaction=True, atom_e



In [3]:
from jmp.lightning import Runner, Trainer


def run(config: PretrainConfig, model_cls: type[PretrainModel]) -> None:
    """Build a model of type ``model_cls`` from ``config`` and fit it with a ``Trainer``.

    The model is constructed before the Trainer, and both are built from the same
    ``config``. Returns nothing.
    """
    pretrain_model = model_cls(config)
    Trainer(config).fit(pretrain_model)


# Smoke test: run each config for only this many batches (it is forwarded to the
# trainer as fast_dev_run, which stops at max_steps=FAST_DEV_RUN_BATCHES).
FAST_DEV_RUN_BATCHES = 16

runner = Runner(run)
# `run` returns None, so the returned list only holds None values (previously shown
# as `[None]`); the trailing `;` keeps that noise out of the cell output.
runner.fast_dev_run(configs, n_batches=FAST_DEV_RUN_BATCHES);

Fast dev run:   0%|          | 0/1 [00:00<?, ?it/s]

Failed to import rich. Falling back to default Python logging.
CRITICAL:ll.trainer.trainer:Setting config.trainer.default_root_dir='/workspaces/repositories/fm/config/lightning_logs/ns2igr3i'.
Seed set to 0
CRITICAL:ll.util.seed:Set global seed to 0.
CRITICAL:ll.runner:Auto-wrapping run in Trainer context


Unrecognized arguments:  dict_keys(['learnable_rbf', 'learnable_rbf_stds', 'unique_basis_per_layer', 'dropout', 'edge_dropout'])


CRITICAL:ll.trainer.trainer:Disabling loggers because fast_dev_run is enabled.
CRITICAL:ll.trainer.trainer:Setting num_nodes to 1 (no SLURM detected).
CRITICAL:ll.trainer.trainer:LightningTrainer.__init__ with args=() and kwargs={'accelerator': 'auto', 'strategy': 'auto', 'devices': 'auto', 'num_nodes': 1, 'precision': '16-mixed', 'logger': None, 'fast_dev_run': 16, 'max_epochs': None, 'min_epochs': None, 'max_steps': -1, 'min_steps': None, 'max_time': None, 'limit_train_batches': None, 'limit_val_batches': None, 'limit_test_batches': None, 'limit_predict_batches': None, 'overfit_batches': 0.0, 'val_check_interval': None, 'check_val_every_n_epoch': 1, 'num_sanity_val_steps': None, 'log_every_n_steps': 50, 'enable_checkpointing': None, 'enable_progress_bar': None, 'enable_model_summary': None, 'accumulate_grad_batches': 1, 'deterministic': None, 'benchmark': None, 'inference_mode': True, 'use_distributed_sampler': False, 'detect_anomaly': False, 'barebones': False, 'plugins': [], 'sync_

Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=16` reached.
CRITICAL:ll.trainer.trainer:Ran 1 finalizers for Trainer cleanup.
Seed set to 0
CRITICAL:ll.util.seed:Reset global seed.


[None]