-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #86 from nlesc-nano/dev
add interface to scikit regressors
- Loading branch information
Showing
7 changed files
with
164 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = '0.5.0' | ||
__version__ = '0.6.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from .modeller import Modeller | ||
_all__ = ["Modeller"] | ||
from .scikit_modeller import SKModeller | ||
|
||
_all__ = ["Modeller", "SKModeller"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
"""Module to create statistical models using scikit learn.""" | ||
|
||
import logging | ||
import pickle | ||
from pathlib import Path | ||
from typing import Optional, Tuple, Union | ||
|
||
import numpy as np | ||
from sklearn import gaussian_process, svm, tree | ||
|
||
from ..dataset.fingerprints_data import FingerprintsData | ||
|
||
PathLike = Union[str, Path] | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class SKModeller: | ||
"""Create statistical models using the scikit learn library.""" | ||
|
||
def __init__(self, data: FingerprintsData, name: str, **kwargs): | ||
"""Class constructor. | ||
Parameters | ||
---------- | ||
data | ||
FingerprintsData object containing the dataset | ||
name | ||
scikit learn model to use | ||
""" | ||
self.fingerprints = data.fingerprints.numpy() | ||
self.labels = data.dataset.labels.numpy() | ||
self.path_model = "swan_skmodeller.pkl" | ||
|
||
supported_models = { | ||
"decisiontree": tree.DecisionTreeRegressor, | ||
"svm": svm.SVR, | ||
"gaussianprocess": gaussian_process.GaussianProcessRegressor | ||
} | ||
|
||
if name.lower() in supported_models: | ||
self.model = supported_models[name.lower()](**kwargs) | ||
else: | ||
raise RuntimeError(f"There is not model name: {name}") | ||
|
||
LOGGER.info(f"Created {name} model") | ||
|
||
def split_data(self, frac: Tuple[float, float]): | ||
"""Split the data into a training and validation set. | ||
Parameters | ||
---------- | ||
frac | ||
fraction to divide the dataset, by default [0.8, 0.2] | ||
""" | ||
# Generate random indices to train and validate the model | ||
size = len(self.fingerprints) | ||
indices = np.arange(size) | ||
np.random.shuffle(indices) | ||
|
||
ntrain = int(size * frac[0]) | ||
self.features_trainset = self.fingerprints[indices[:ntrain]] | ||
self.features_validset = self.fingerprints[indices[ntrain:]] | ||
self.labels_trainset = self.labels[indices[:ntrain]] | ||
self.labels_validset = self.labels[indices[ntrain:]] | ||
|
||
def train_model(self, frac: Tuple[float, float] = (0.8, 0.2)): | ||
"""Train the model using the given data. | ||
Parameters | ||
---------- | ||
frac | ||
fraction to divide the dataset, by default [0.8, 0.2] | ||
""" | ||
self.split_data(frac) | ||
self.model.fit(self.features_trainset, self.labels_trainset) | ||
self.save_model() | ||
|
||
def save_model(self): | ||
"""Store the trained model.""" | ||
with open(self.path_model, 'wb') as handler: | ||
pickle.dump(self.model, handler) | ||
|
||
def validate_model(self) -> Tuple[np.ndarray, np.ndarray]: | ||
"""Check the model prediction power.""" | ||
predicted = self.model.predict(self.features_validset) | ||
expected = self.labels_validset | ||
score = self.model.score(self.features_validset, expected) | ||
LOGGER.info(f"Validation R^2 score: {score}") | ||
return predicted, expected | ||
|
||
def load_model(self, path_model: Optional[PathLike]) -> None: | ||
"""Load the model from the state file.""" | ||
path_model = self.path_model if path_model is None else path_model | ||
with open(path_model, 'rb') as handler: | ||
self.model = pickle.load(handler) | ||
|
||
def predict(self, inp_data: np.ndarray) -> np.ndarray: | ||
"""Used the previously trained model to predict properties. | ||
Parameters | ||
---------- | ||
inp_data | ||
Matrix containing a given fingerprint for each row | ||
Returns | ||
------- | ||
Array containing the predicted results | ||
""" | ||
return self.model.predict(inp_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
from scipy import stats | ||
from sklearn.gaussian_process.kernels import ConstantKernel | ||
|
||
from swan.dataset import FingerprintsData | ||
from swan.modeller import SKModeller | ||
|
||
from .utils_test import PATH_TEST | ||
|
||
DATA = FingerprintsData(PATH_TEST / "thousand.csv", properties=["gammas"], sanitize=False) | ||
DATA.scale_labels() | ||
|
||
|
||
def run_test(model: str, **kwargs): | ||
"""Run the training and validation step for the given model.""" | ||
modeller = SKModeller(DATA, model) | ||
modeller.train_model() | ||
predicted, expected = modeller.validate_model() | ||
reg = stats.linregress(predicted.flatten(), expected.flatten()) | ||
assert not np.isnan(reg.rvalue) | ||
|
||
|
||
def test_decision_tree(): | ||
"""Check the interface to the Decisiontree class.""" | ||
run_test("decisiontree") | ||
|
||
|
||
def test_svm(): | ||
"""Check the interface to the support vector machine.""" | ||
run_test("svm") | ||
|
||
|
||
def test_gaussian_process(): | ||
"""Check the interface to the support vector machine.""" | ||
kernel = ConstantKernel(constant_value=10) | ||
run_test("gaussianprocess", kernel=kernel) |