-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #89 from nlesc-nano/dev
Add state to store the training results
- Loading branch information
Showing
27 changed files
with
343 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,7 @@ swan_models | |
|
||
# data files | ||
*.hdf5 | ||
swan_state.h5 | ||
|
||
# Pytorch | ||
*.pt | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,11 @@ | ||
"""Swan API.""" | ||
from .__version__ import __version__ | ||
|
||
from .modeller import Modeller, SKModeller | ||
from swan.dataset import TorchGeometricGraphData, FingerprintsData, DGLGraphData | ||
from .modeller.models import FingerprintFullyConnected, MPNN, SE3Transformer | ||
|
||
__all__ = [ | ||
"__version__", "Modeller", "SKModeller", | ||
"TorchGeometricGraphData", "FingerprintsData", "DGLGraphData", | ||
"FingerprintFullyConnected", "MPNN", "SE3Transformer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
|
||
import abc | ||
from typing import Generic, Optional, Tuple, TypeVar, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from ..dataset.swan_data_base import SwanDataBase | ||
from ..state import StateH5 | ||
from ..type_hints import PathLike | ||
|
||
# `bound` preserves all sub-type information, which might be useful | ||
T_co = TypeVar('T_co', bound=Union[np.ndarray, torch.Tensor], covariant=True) | ||
|
||
|
||
class BaseModeller(Generic[T_co]): | ||
"""Base class for the modellers.""" | ||
|
||
def __init__(self, data: SwanDataBase, replace_state: bool) -> None: | ||
self.state = StateH5(replace_state=replace_state) | ||
self.smiles = data.dataframe.smiles.to_numpy() | ||
|
||
@abc.abstractmethod | ||
def train_model(self, frac: Tuple[float, float] = (0.8, 0.2), **kwargs): | ||
"""Train the model using the given data. | ||
Parameters | ||
---------- | ||
frac | ||
fraction to divide the dataset, by default [0.8, 0.2] | ||
""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def validate_model(self) -> Tuple[T_co, T_co]: | ||
"""compute the output of the model on the validation set | ||
Returns | ||
------- | ||
output of the network, ground truth of the data | ||
""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def predict(self, inp_data: T_co) -> T_co: | ||
"""compute output of the model for a given input | ||
Parameters | ||
---------- | ||
inp_data | ||
input data of the network | ||
Returns | ||
------- | ||
Tensor | ||
output of the network | ||
""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def load_model(self, path_model: Optional[PathLike]) -> None: | ||
"""Load the model from the Network file.""" | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
def save_model(self, *args, **kwargs): | ||
"""Store the trained model.""" | ||
raise NotImplementedError |
Oops, something went wrong.