Skip to content

Commit

Permalink
First draft trainer framework
Browse files Browse the repository at this point in the history
  • Loading branch information
ipcamit committed Mar 4, 2024
1 parent 86ab5d2 commit 2f57f56
Show file tree
Hide file tree
Showing 11 changed files with 1,163 additions and 181 deletions.
13 changes: 13 additions & 0 deletions kliff/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
This module contains exceptions to be raised in kliff modules, along with details on
where they are raised.
"""


class TrainerError(Exception):

Check warning on line 7 in kliff/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

kliff/_exceptions.py#L7

Added line #L7 was not covered by tests
"""
Exceptions to be raised in Trainer and associated classes.
"""

def __init__(self, message):
super().__init__(message)

Check warning on line 13 in kliff/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

kliff/_exceptions.py#L12-L13

Added lines #L12 - L13 were not covered by tests
229 changes: 203 additions & 26 deletions kliff/dataset/dataset.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions kliff/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .kim_trainer import KIMTrainer
from .kliff_trainer import Trainer

Check warning on line 2 in kliff/trainer/__init__.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/__init__.py#L1-L2

Added lines #L1 - L2 were not covered by tests
19 changes: 19 additions & 0 deletions kliff/trainer/kim_residuals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any, Dict

Check warning on line 1 in kliff/trainer/kim_residuals.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/kim_residuals.py#L1

Added line #L1 was not covered by tests

import numpy as np

Check warning on line 3 in kliff/trainer/kim_residuals.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/kim_residuals.py#L3

Added line #L3 was not covered by tests


def MSE_residuals(

Check warning on line 6 in kliff/trainer/kim_residuals.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/kim_residuals.py#L6

Added line #L6 was not covered by tests
predictions: np.ndarray,
targets: np.ndarray,
) -> np.ndarray:
r"""
Compute the mean squared error (MSE) of the residuals.
Args:
Returns:
The MSE of the residuals.
"""
residuals = predictions - targets
return np.mean(residuals**2)

Check warning on line 19 in kliff/trainer/kim_residuals.py

View check run for this annotation

Codecov / codecov/patch

kliff/trainer/kim_residuals.py#L18-L19

Added lines #L18 - L19 were not covered by tests
Loading

0 comments on commit 2f57f56

Please sign in to comment.