Skip to content

Commit

Permalink
Merge pull request #113 from luigibonati/checkpoint
Browse files Browse the repository at this point in the history
Fix #103: CVs cannot be loaded from checkpoint
  • Loading branch information
luigibonati committed Dec 22, 2023
2 parents 9e35e07 + 59988fc commit c6381fd
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mlcolvar/cvs/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ def __init__(
"""
super().__init__(*args, **kwargs)
self.save_hyperparameters()

# The parent class sets in_features and out_features based on their own
# init arguments so we don't need to save them here (see #103).
self.save_hyperparameters(ignore=['in_features', 'out_features'])

# MODEL
self.initialize_blocks()
Expand Down
100 changes: 100 additions & 0 deletions mlcolvar/tests/test_cvs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Shared tests for the objects and functions in the mlcolvar.cvs package.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

import os
import tempfile

import lightning
import pytest
import torch

import mlcolvar.cvs
from mlcolvar.data import DictDataset, DictModule


# =============================================================================
# GLOBAL VARIABLES
# =============================================================================

N_STATES = 2
N_DESCRIPTORS = 15
LAYERS = [N_DESCRIPTORS, 5, 5, N_STATES-1]

# =============================================================================
# FIXTURES
# =============================================================================

@pytest.fixture(scope="module")
def dataset():
"""Dataset with all fields required by all CV types."""
n_samples = 10

# Weights should be optional so we don't add them.
data = {
"data": torch.randn((n_samples, N_DESCRIPTORS)),
"data_lag": torch.randn((n_samples, N_DESCRIPTORS)),
"target": torch.randn(n_samples),
"weights": torch.rand(n_samples),
"weights_lag": torch.rand(n_samples),
}

# With sequential sampling, this make sure that all labels are represented
# in the validation and training set so that LDA/TDA don't complain.
labels = torch.arange(N_STATES, dtype=torch.get_default_dtype())
data["labels"] = labels.repeat(n_samples // N_STATES + 1)[:n_samples]

return DictDataset(data)


# =============================================================================
# TESTS
# =============================================================================

@pytest.mark.parametrize("cv_model", [
mlcolvar.cvs.DeepLDA(layers=LAYERS, n_states=N_STATES),
mlcolvar.cvs.DeepTDA(n_states=N_STATES, n_cvs=1, target_centers=[-1., 1.], target_sigmas=[0.1, 0.1], layers=LAYERS),
mlcolvar.cvs.RegressionCV(layers=LAYERS),
mlcolvar.cvs.DeepTICA(layers=LAYERS, n_cvs=1),
mlcolvar.cvs.AutoEncoderCV(encoder_layers=LAYERS),
mlcolvar.cvs.VariationalAutoEncoderCV(n_cvs=1, encoder_layers=LAYERS[:-1]),
])
def test_resume_from_checkpoint(cv_model, dataset):
"""CVs correctly resume from a checkpoint."""
datamodule = DictModule(dataset, lengths=[1.0,0.], batch_size=len(dataset))

# Run a few steps of training in a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir_path:
# Simulate a couple of epochs of training.
trainer = lightning.Trainer(
max_epochs=2,
enable_checkpointing=True,
logger=False,
enable_progress_bar=False,
enable_model_summary=False,
default_root_dir=tmp_dir_path,
)
trainer.fit(cv_model, datamodule)

# Now load from checkpoint.
file_name = 'epoch={}-step={}.ckpt'.format(trainer.current_epoch-1, trainer.global_step)
checkpoint_file_path = os.path.join(tmp_dir_path, 'checkpoints', file_name)
cv_model2 = cv_model.__class__.load_from_checkpoint(checkpoint_file_path)

# Check that state is the same.
x = dataset['data']
cv_model.eval()
cv_model2.eval()
assert torch.allclose(cv_model(x), cv_model2(x))

0 comments on commit c6381fd

Please sign in to comment.