In [5]:
from pathlib import Path

import nshconfig_extra as CE
import nshtrainer as nt
import nshutils as nu
from jmp.lightning_datamodule import MPTrjAlexOMAT24DataModuleConfig
from jmp.lightning_module import Config, TargetsConfig
from jmp.models.gemnet.graph import (
    CutoffsConfig,
    GraphComputerConfig,
    MaxNeighborsConfig,
)
from jmp.nn.energy_head import EnergyTargetConfig
from jmp.nn.force_head import ForceTargetConfig
from jmp.nn.stress_head import StressTargetConfig

# Shared-filesystem locations. `env` is handed to the runner in a later cell;
# `cwd` is not referenced again in the visible cells.
cwd = Path("/net/csefiles/coc-fung-cluster/nima/shared/experiment-data/")
env = dict(HF_HOME="/net/csefiles/coc-fung-cluster/nima/shared/cache/huggingface")

# ---------------------------------------------------------------------------
# Model / optimisation config: start from a draft, fill fields, then finalize.
# ---------------------------------------------------------------------------
config = Config.draft()

# Checkpoint to start from.
config.pretrained_ckpt = CE.CachedPath(
    uri="/net/csefiles/coc-fung-cluster/nima/shared/checkpoints/jmp-s.pt"
)

# Graph construction: cutoffs derived from a single constant, explicit
# per-interaction neighbor caps, periodic boundary conditions enabled.
graph_cutoffs = CutoffsConfig.from_constant(8.0)
neighbor_limits = MaxNeighborsConfig(main=20, aeaint=20, aint=1000, qint=8)
config.graph_computer = GraphComputerConfig(
    cutoffs=graph_cutoffs,
    max_neighbors=neighbor_limits,
    pbc=True,
    per_graph_radius_graph=True,
)

# AdamW with linear warmup (5k steps) followed by cosine decay (500k steps).
config.optimizer = nt.config.AdamWConfig(lr=5e-5, weight_decay=1e-3)
warmup_steps = nt.config.StepsConfig(value=5000)
total_steps = nt.config.StepsConfig(value=500_000)
config.lr_scheduler = nt.config.LinearWarmupCosineDecayLRSchedulerConfig(
    warmup_duration=warmup_steps,
    max_duration=total_steps,
    warmup_start_lr_factor=0.001,
    min_lr_factor=0.1,
)

# Prediction targets: energy, forces and stress.
config.targets = TargetsConfig(
    energy=EnergyTargetConfig(max_atomic_number=120),
    force=ForceTargetConfig(),
    stress=StressTargetConfig(num_layers=5),
)

# Numerical precision settings for the trainer.
config.trainer.precision = "16-mixed-auto"
config.trainer.set_float32_matmul_precision = "medium"

config = config.finalize()
nu.display(config)

# ---------------------------------------------------------------------------
# Data config: same draft -> finalize pattern as above.
# ---------------------------------------------------------------------------
data_config = MPTrjAlexOMAT24DataModuleConfig.draft()

# Local copies of the sAlex and OMat24 datasets.
data_config.salex.local_path = Path("/storage/nima/salex-ocp/hf/")
data_config.omat24.local_path = Path("/storage/nima/omat24/hf/")

# Dataloader settings.
data_config.batch_size = 100
data_config.num_workers = 8

# NOTE(review): presumably selects the linear energy reference fitted on
# MPTrj + sAlex, applied in place on the draft — confirm against jmp.
data_config.with_linear_reference_("mptrj-salex")

data_config = data_config.finalize()
nu.display(data_config)

In [6]:
from jmp.lightning_datamodule import MPTrjAlexOMAT24DataModule
from jmp.lightning_module import Module


def run(config: Config, data_config: MPTrjAlexOMAT24DataModuleConfig):
    """Train a model for one (config, data_config) pair.

    Builds the Lightning module from ``config`` and the datamodule from
    ``data_config``, then creates an ``nt.Trainer`` from the same ``config``
    and calls ``fit``. Nothing is returned.

    Args:
        config: Finalized model/trainer configuration.
        data_config: Finalized datamodule configuration.
    """
    # Keep the construction order (module, datamodule, trainer): the
    # constructors may have side effects such as seeding or logging.
    lightning_module = Module(config)
    data = MPTrjAlexOMAT24DataModule(data_config)
    nt.Trainer(config).fit(lightning_module, data)

In [None]:
import nshrunner as nr

# Single-entry sweep: one (model config, data config) pair using a fast-dev-run
# variant of the config.
# NOTE(review): the argument 256 is presumably a batch limit — confirm against
# nshtrainer.
debug_config = config.fast_dev_run(256)
configs = [(debug_config, data_config)]

# NOTE(review): `cwd` from the first cell is unused and the runner works out of
# "." instead — confirm this is intentional.
runner_config = nr.RunnerConfig(working_dir=".", env=env)
runner = nr.Runner(run, runner_config)

# Execute `run` on each entry in `configs` via the runner's local mode.
runner.local(configs)

  0%|          | 0/1 [00:00<?, ?it/s]

Seed set to 0
INFO:jmp.models.gemnet.layers.radial_basis_dynamic_cutoff:[RadialBasis] Using absolute cutoff of 12.0 Angstroms.
INFO:jmp.models.gemnet.layers.radial_basis_dynamic_cutoff:[RadialBasis] Using absolute cutoff of 12.0 Angstroms.
INFO:jmp.models.gemnet.layers.radial_basis_dynamic_cutoff:[RadialBasis] Using absolute cutoff of 12.0 Angstroms.
INFO:jmp.models.gemnet.layers.radial_basis_dynamic_cutoff:[RadialBasis] Using absolute cutoff of 12.0 Angstroms.


Unrecognized arguments:  dict_keys(['ln', 'dropout', 'replace_scale_factors_with_ln', 'learnable_rbf', 'learnable_rbf_stds', 'unique_basis_per_layer', 'old_gaussian_implementation', 'edge_dropout'])


INFO:jmp.models.gemnet.layers.radial_basis_dynamic_cutoff:[RadialBasis] Using absolute cutoff of 12.0 Angstroms.
INFO:jmp.models.gemnet.layers.radial_basis_dynamic_cutoff:[RadialBasis] Using absolute cutoff of 12.0 Angstroms.
INFO:jmp.models.gemnet.layers.radial_basis_dynamic_cutoff:[RadialBasis] Using absolute cutoff of 12.0 Angstroms.
CRITICAL:root:Found the following scale factors: [('int_blocks.0.trip_interaction.scale_rbf', 'int_blocks.0.trip_interaction.scale_rbf'), ('int_blocks.0.trip_interaction.scale_cbf_sum', 'int_blocks.0.trip_interaction.scale_cbf_sum'), ('int_blocks.0.quad_interaction.scale_rbf', 'int_blocks.0.quad_interaction.scale_rbf'), ('int_blocks.0.quad_interaction.scale_cbf', 'int_blocks.0.quad_interaction.scale_cbf'), ('int_blocks.0.quad_interaction.scale_sbf_sum', 'int_blocks.0.quad_interaction.scale_sbf_sum'), ('int_blocks.0.atom_edge_interaction.scale_rbf', 'int_blocks.0.atom_edge_interaction.scale_rbf'), ('int_blocks.0.atom_edge_interaction.scale_cbf_sum', 'int

Loading dataset from disk:   0%|          | 0/212 [00:00<?, ?it/s]

Loading dataset from disk:   0%|          | 0/212 [00:00<?, ?it/s]

CRITICAL:nshtrainer.callbacks.debug_flag:Fast dev run detected, setting debug flag to True.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
Loading `train_dataloader` to estimate number of stepping batches.


Loading dataset from disk:   0%|          | 0/212 [00:00<?, ?it/s]


  | Name           | Type              | Params | Mode 
-------------------------------------------------------------
0 | backbone       | GemNetOCBackbone  | 38.9 M | train
1 | energy_head    | EnergyOutputHead  | 263 K  | train
2 | force_head     | ForceOutputHead   | 1.1 M  | train
3 | stress_head    | StressOutputHead  | 2.1 M  | train
4 | graph_computer | GraphComputer     | 0      | train
5 | train_metrics  | ForceFieldMetrics | 0      | train
6 | val_metrics    | ForceFieldMetrics | 0      | train
7 | test_metrics   | ForceFieldMetrics | 0      | train
-------------------------------------------------------------
42.3 M    Trainable params
0         Non-trainable params
42.3 M    Total params
169.161   Total estimated model params size (MB)
INFO:nshtrainer.trainer.signal_connector:No auto-requeue signals found. Reverting to default Lightning behavior.


Loading dataset from disk:   0%|          | 0/212 [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]