In [10]:
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

from pathlib import Path

from jmp.configs.finetune.jmp_l import jmp_l_ft_config_
from jmp.configs.finetune.md22 import jmp_l_md22_config_
from jmp.tasks.finetune.base import FinetuneConfigBase, FinetuneModelBase
from jmp.tasks.finetune.md22 import MD22Config, MD22Model

ckpt_path = Path("/data/shared/ishan_stuff/jmp-s.pt")
base_path = Path("/home/sanjeevr/MLFF-distill/data/md22")

# We create a list of all configurations that we want to run.
configs: list[tuple[FinetuneConfigBase, type[FinetuneModelBase]]] = []

config = MD22Config.draft()
jmp_l_ft_config_(config, ckpt_path)  # This loads the base JMP-L fine-tuning config
# This loads the rMD17-specific configuration
jmp_l_md22_config_(config, "Ac-Ala3-NHMe", base_path)
config = config.finalize()  # Actually construct the config object
print(config)

configs.append((config, MD22Model))

id='vjri0m7m' trainer=TrainerConfig(optimizer=OptimizerConfig(log_grad_norm=True, gradient_clipping=GradientClippingConfig(value=1.0, algorithm='value')), supports_skip_batch_exception=False, supports_parameter_hooks=False, set_float32_matmul_precision='medium', precision='16-mixed', max_epochs=500, max_time='07:00:00:00', inference_mode=False) meta={'ckpt_path': PosixPath('/data/shared/ishan_stuff/jmp-s.pt'), 'ema_backbone': True} train_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/home/sanjeevr/MLFF-distill/data/md22/lmdb/Ac-Ala3-NHMe/train'), metadata_path=PosixPath('/home/sanjeevr/MLFF-distill/data/md22/lmdb/Ac-Ala3-NHMe/train/metadata.npz')) val_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/home/sanjeevr/MLFF-distill/data/md22/lmdb/Ac-Ala3-NHMe/val'), metadata_path=PosixPath('/home/sanjeevr/MLFF-distill/data/md22/lmdb/Ac-Ala3-NHMe/val/metadata.npz')) test_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/home/sanjeevr/MLFF-distill/data/md22/lmdb/Ac-Ala3-NHMe/test'), metada

In [11]:
from jmp.lightning import Runner, Trainer
from jmp.utils.finetune_state_dict import (
    filter_state_dict,
    retreive_state_dict_for_finetuning,
)


def run(config: FinetuneConfigBase, model_cls: type[FinetuneModelBase]) -> None:
    if (ckpt_path := config.meta.get("ckpt_path")) is None:
        raise ValueError("No checkpoint path provided")

    model = model_cls(config)

    # Load the checkpoint
    state_dict = retreive_state_dict_for_finetuning(
        ckpt_path, load_emas=config.meta.get("ema_backbone", False)
    )
    embedding = filter_state_dict(state_dict, "embedding.atom_embedding.")
    backbone = filter_state_dict(state_dict, "backbone.")
    model.load_backbone_state_dict(backbone=backbone, embedding=embedding, strict=True)

    trainer = Trainer(config)
    trainer.fit(model)


runner = Runner(run)
runner.fast_dev_run(configs)

Fast dev run:   0%|          | 0/1 [00:00<?, ?it/s]

Failed to import lovely-tensors. Ignoring pretty PyTorch tensor formatting


Failed to import rich. Falling back to default Python logging.
CRITICAL:jmp.lightning.trainer.trainer:Setting config.trainer.default_root_dir='/home/sanjeevr/MLFF-distill/JMP/config/lightning_logs/l1gfnmct'.
Seed set to 0
CRITICAL:jmp.lightning.util.seed:Set global seed to 0.
CRITICAL:jmp.lightning.runner:Auto-wrapping run in Trainer context
CRITICAL:jmp.tasks.finetune.base:Using regular backbone


Unrecognized arguments:  dict_keys(['learnable_rbf', 'learnable_rbf_stds', 'unique_basis_per_layer', 'dropout', 'edge_dropout'])


CRITICAL:jmp.tasks.finetune.base:Freezing 0 parameters (0.00%) out of 160,874,752 total parameters (160,874,752 trainable)
CRITICAL:jmp.utils.finetune_state_dict:Loaded 405 EMA parameters
CRITICAL:jmp.utils.finetune_state_dict:Loaded state dict from /data/shared/ishan_stuff/jmp-s.pt
CRITICAL:jmp.utils.state_dict:pattern='out_blocks.0.scale_rbf_F.*' matched keys ['out_blocks.0.scale_rbf_F.scale_factor'], which were ignored during loading.
CRITICAL:jmp.utils.state_dict:pattern='out_blocks.0.seq_forces.*' matched keys ['out_blocks.0.seq_forces.0.dense_mlp.0.linear.weight', 'out_blocks.0.seq_forces.0.dense_mlp.1.linear.weight', 'out_blocks.0.seq_forces.1.dense_mlp.0.linear.weight', 'out_blocks.0.seq_forces.1.dense_mlp.1.linear.weight', 'out_blocks.0.seq_forces.2.dense_mlp.0.linear.weight', 'out_blocks.0.seq_forces.2.dense_mlp.1.linear.weight'], which were ignored during loading.
CRITICAL:jmp.utils.state_dict:pattern='out_blocks.0.dense_rbf_F.*' matched keys ['out_blocks.0.dense_rbf_F.linea

RuntimeError: Error(s) in loading state_dict for GemNetOCBackbone:
	size mismatch for bases.mlp_rbf_qint.linear.weight: copying a param with shape torch.Size([16, 128]) from checkpoint, the shape in current model is torch.Size([32, 128]).
	size mismatch for bases.mlp_sbf_qint.weight: copying a param with shape torch.Size([128, 49, 32]) from checkpoint, the shape in current model is torch.Size([128, 49, 64]).
	size mismatch for bases.mlp_rbf_aeint.linear.weight: copying a param with shape torch.Size([16, 128]) from checkpoint, the shape in current model is torch.Size([32, 128]).
	size mismatch for bases.mlp_rbf_eaint.linear.weight: copying a param with shape torch.Size([16, 128]) from checkpoint, the shape in current model is torch.Size([32, 128]).
	size mismatch for bases.mlp_rbf_aint.weight: copying a param with shape torch.Size([16, 128]) from checkpoint, the shape in current model is torch.Size([32, 128]).
	size mismatch for bases.mlp_rbf_tint.linear.weight: copying a param with shape torch.Size([16, 128]) from checkpoint, the shape in current model is torch.Size([32, 128]).
	size mismatch for bases.mlp_rbf_h.linear.weight: copying a param with shape torch.Size([16, 128]) from checkpoint, the shape in current model is torch.Size([32, 128]).
	size mismatch for bases.mlp_rbf_out.linear.weight: copying a param with shape torch.Size([16, 128]) from checkpoint, the shape in current model is torch.Size([32, 128]).
	size mismatch for bases.edge_emb.dense.linear.weight: copying a param with shape torch.Size([512, 640]) from checkpoint, the shape in current model is torch.Size([1024, 640]).
	size mismatch for int_blocks.0.dense_ca.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.trip_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.trip_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.0.trip_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.0.trip_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.0.trip_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.0.trip_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.0.quad_interaction.dense_db.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.quad_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.0.quad_interaction.mlp_cbf.linear.weight: copying a param with shape torch.Size([32, 16]) from checkpoint, the shape in current model is torch.Size([64, 16]).
	size mismatch for int_blocks.0.quad_interaction.mlp_sbf.bilinear.linear.weight: copying a param with shape torch.Size([32, 1024]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for int_blocks.0.quad_interaction.down_projection.linear.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.0.quad_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.0.quad_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.0.atom_edge_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([256, 16]) from checkpoint, the shape in current model is torch.Size([256, 32]).
	size mismatch for int_blocks.0.atom_edge_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.0.atom_edge_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.0.atom_edge_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.0.edge_atom_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.edge_atom_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.0.edge_atom_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.0.edge_atom_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.0.edge_atom_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([256, 128]).
	size mismatch for int_blocks.0.atom_interaction.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([64, 2048]).
	size mismatch for int_blocks.0.layers_before_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.layers_before_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.layers_before_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.layers_before_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.layers_after_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.layers_after_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.layers_after_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.layers_after_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.atom_update.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.0.atom_update.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for int_blocks.0.concat_layer.dense.linear.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([1024, 1536]).
	size mismatch for int_blocks.0.residual_m.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.0.residual_m.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.dense_ca.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.trip_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.trip_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.1.trip_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.1.trip_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.1.trip_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.1.trip_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.1.quad_interaction.dense_db.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.quad_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.1.quad_interaction.mlp_cbf.linear.weight: copying a param with shape torch.Size([32, 16]) from checkpoint, the shape in current model is torch.Size([64, 16]).
	size mismatch for int_blocks.1.quad_interaction.mlp_sbf.bilinear.linear.weight: copying a param with shape torch.Size([32, 1024]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for int_blocks.1.quad_interaction.down_projection.linear.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.1.quad_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.1.quad_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.1.atom_edge_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([256, 16]) from checkpoint, the shape in current model is torch.Size([256, 32]).
	size mismatch for int_blocks.1.atom_edge_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.1.atom_edge_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.1.atom_edge_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.1.edge_atom_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.edge_atom_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.1.edge_atom_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.1.edge_atom_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.1.edge_atom_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([256, 128]).
	size mismatch for int_blocks.1.atom_interaction.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([64, 2048]).
	size mismatch for int_blocks.1.layers_before_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.layers_before_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.layers_before_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.layers_before_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.layers_after_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.layers_after_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.layers_after_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.layers_after_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.atom_update.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.1.atom_update.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for int_blocks.1.concat_layer.dense.linear.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([1024, 1536]).
	size mismatch for int_blocks.1.residual_m.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.1.residual_m.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.dense_ca.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.trip_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.trip_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.2.trip_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.2.trip_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.2.trip_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.2.trip_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.2.quad_interaction.dense_db.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.quad_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.2.quad_interaction.mlp_cbf.linear.weight: copying a param with shape torch.Size([32, 16]) from checkpoint, the shape in current model is torch.Size([64, 16]).
	size mismatch for int_blocks.2.quad_interaction.mlp_sbf.bilinear.linear.weight: copying a param with shape torch.Size([32, 1024]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for int_blocks.2.quad_interaction.down_projection.linear.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.2.quad_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.2.quad_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.2.atom_edge_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([256, 16]) from checkpoint, the shape in current model is torch.Size([256, 32]).
	size mismatch for int_blocks.2.atom_edge_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.2.atom_edge_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.2.atom_edge_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.2.edge_atom_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.edge_atom_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.2.edge_atom_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.2.edge_atom_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.2.edge_atom_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([256, 128]).
	size mismatch for int_blocks.2.atom_interaction.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([64, 2048]).
	size mismatch for int_blocks.2.layers_before_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.layers_before_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.layers_before_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.layers_before_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.layers_after_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.layers_after_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.layers_after_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.layers_after_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.atom_update.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.2.atom_update.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for int_blocks.2.concat_layer.dense.linear.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([1024, 1536]).
	size mismatch for int_blocks.2.residual_m.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.2.residual_m.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.dense_ca.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.trip_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.trip_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.3.trip_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.3.trip_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.3.trip_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.3.trip_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.3.quad_interaction.dense_db.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.quad_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.3.quad_interaction.mlp_cbf.linear.weight: copying a param with shape torch.Size([32, 16]) from checkpoint, the shape in current model is torch.Size([64, 16]).
	size mismatch for int_blocks.3.quad_interaction.mlp_sbf.bilinear.linear.weight: copying a param with shape torch.Size([32, 1024]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for int_blocks.3.quad_interaction.down_projection.linear.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.3.quad_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.3.quad_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.3.atom_edge_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([256, 16]) from checkpoint, the shape in current model is torch.Size([256, 32]).
	size mismatch for int_blocks.3.atom_edge_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.3.atom_edge_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.3.atom_edge_interaction.up_projection_ac.linear.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([1024, 128]).
	size mismatch for int_blocks.3.edge_atom_interaction.dense_ba.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.edge_atom_interaction.mlp_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.3.edge_atom_interaction.mlp_cbf.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([128, 1024]).
	size mismatch for int_blocks.3.edge_atom_interaction.down_projection.linear.weight: copying a param with shape torch.Size([64, 512]) from checkpoint, the shape in current model is torch.Size([64, 1024]).
	size mismatch for int_blocks.3.edge_atom_interaction.up_projection_ca.linear.weight: copying a param with shape torch.Size([256, 64]) from checkpoint, the shape in current model is torch.Size([256, 128]).
	size mismatch for int_blocks.3.atom_interaction.bilinear.linear.weight: copying a param with shape torch.Size([64, 1024]) from checkpoint, the shape in current model is torch.Size([64, 2048]).
	size mismatch for int_blocks.3.layers_before_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.layers_before_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.layers_before_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.layers_before_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.layers_after_skip.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.layers_after_skip.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.layers_after_skip.1.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.layers_after_skip.1.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.atom_update.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for int_blocks.3.atom_update.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for int_blocks.3.concat_layer.dense.linear.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([1024, 1536]).
	size mismatch for int_blocks.3.residual_m.0.dense_mlp.0.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for int_blocks.3.residual_m.0.dense_mlp.1.linear.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
	size mismatch for out_blocks.0.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for out_blocks.0.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.0.seq_energy_pre.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.1.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for out_blocks.1.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.1.seq_energy_pre.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.2.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for out_blocks.2.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.2.seq_energy_pre.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.3.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for out_blocks.3.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.3.seq_energy_pre.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.4.dense_rbf.linear.weight: copying a param with shape torch.Size([512, 16]) from checkpoint, the shape in current model is torch.Size([1024, 32]).
	size mismatch for out_blocks.4.layers.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_blocks.4.seq_energy_pre.0.linear.weight: copying a param with shape torch.Size([256, 512]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for out_mlp_E.out_mlp.0.linear.weight: copying a param with shape torch.Size([256, 1280]) from checkpoint, the shape in current model is torch.Size([256, 1792]).