Skip to content

Commit

Permalink
Merge pull request #89 from nlesc-nano/dev
Browse files Browse the repository at this point in the history
Add state to store the training results
  • Loading branch information
felipeZ committed Jun 1, 2021
2 parents c3f6c7c + 26f72a8 commit 0dbb0a8
Show file tree
Hide file tree
Showing 27 changed files with 343 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: install conda dependencies
env:
TORCH: "1.8"
run: conda install scipy rdkit pytorch==${TORCH} torchvision cpuonly dgl -c pytorch -c conda-forge -c dglteam
run: conda install h5py scipy rdkit pytorch==${TORCH} torchvision cpuonly dgl -c pytorch -c conda-forge -c dglteam

- name: install torch-geometric dependencies
env:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ swan_models

# data files
*.hdf5
swan_state.h5

# Pytorch
*.pt
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# 0.6.0 [Unreleased]
## New
* Add interface to scikit regressors (#85)
* Add interface to HDF5 to store the training results (#88)

## Changed
* Fix prediction functionality (#81)
Expand Down
19 changes: 8 additions & 11 deletions scripts/plot_means.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#!/usr/bin/env python

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pandas.core.algorithms import mode
import seaborn as sns
import json
from pathlib import Path
Expand All @@ -16,19 +14,18 @@


def read_data():
ground_true = pd.read_csv(PATH_GROUND_TRUE, index_col=0)
ground_true.drop("smiles", axis=1, inplace=True)
results = {}
for m in MODELS:
results[m] = {}
for n in NSAMPLES:
df = pd.read_json(f"means_{m}_{n}.json")
new = (df - ground_true) ** 2
mse = new.sum() / len(df)
results[m][n] = mse.to_dict()
results[m][n] = df.sum().to_dict()

with open(MSE_FILE, 'w') as f:
json.dump(results, f, indent=4)
data = [pd.DataFrame(transpose_data(results[m])) for m in MODELS]
for df in data:
df.sort_index(inplace=True)
for df, model in zip(data, MODELS):
df.to_csv(f"{model}.csv")


def plot_data(model: str):
Expand Down Expand Up @@ -63,8 +60,8 @@ def transpose_data(data):


def main():
# read_data()
plot_data(MODELS[0])
read_data()
# plot_data(MODELS[0])


if __name__ == "__main__":
Expand Down
8 changes: 6 additions & 2 deletions scripts/run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import torch
from swan.dataset import TorchGeometricGraphData, FingerprintsData, DGLGraphData
from swan.modeller import Modeller
from swan.modeller.models import FingerprintFullyConnected, MPNN, InvariantPolynomial
from swan.modeller.models.se3_transformer import TFN, SE3Transformer
from swan.modeller.models import FingerprintFullyConnected, MPNN, SE3Transformer
from swan.utils.log_config import configure_logger
from swan.utils.plot import create_scatter_plot

Expand Down Expand Up @@ -77,8 +76,13 @@
researcher.data.scale_labels()
trained_data = researcher.train_model(nepoch=nepoch, batch_size=batch_size)
predicted_train, expected_train = [x.cpu().detach().numpy() for x in trained_data]
print("train regression")
create_scatter_plot(predicted_train, expected_train, properties, "trained_scatterplot")

# Print validation scatterplot
print("validation regression")
predicted_validation, expected_validation = [x.cpu().detach().numpy() for x in researcher.validate_model()]
create_scatter_plot(predicted_validation, expected_validation, properties, "validation_scatterplot")

print("properties stored in the HDF5")
researcher.state.show()
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ addopts = --cov --cov-report xml --cov-report term --cov-report html
source-dir = docs
build-dir = docs/_build
all_files = 1
builder = html
builder = html

[mypy]
plugins = numpy.typing.mypy_plugin
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
install_requires=[
'e3nn@git+https://github.com/e3nn/e3nn@main',
'equivariant_attention@git+https://github.com/nlesc-nano/se3-transformer-public@dev',
'mendeleev', 'numpy', 'pandas', 'pyyaml', 'scikit-learn',
'h5py', 'mendeleev', 'numpy', 'pandas', 'pyyaml', 'scikit-learn',
'scipy', 'seaborn', 'schema',
'torch-geometric'],

Expand Down
9 changes: 9 additions & 0 deletions swan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
"""Swan API."""
from .__version__ import __version__

from .modeller import Modeller, SKModeller
from swan.dataset import TorchGeometricGraphData, FingerprintsData, DGLGraphData
from .modeller.models import FingerprintFullyConnected, MPNN, SE3Transformer

__all__ = [
"__version__", "Modeller", "SKModeller",
"TorchGeometricGraphData", "FingerprintsData", "DGLGraphData",
"FingerprintFullyConnected", "MPNN", "SE3Transformer"]
5 changes: 1 addition & 4 deletions swan/dataset/data_graph_base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
"""Base class for the Graph data representation."""

from pathlib import Path
from typing import List, Optional, Union


from .geometry import guess_positions
from .swan_data_base import SwanDataBase
from ..type_hints import PathLike


__all__ = ["SwanGraphData"]


PathLike = Union[str, Path]


class SwanGraphData(SwanDataBase):
"""Base class for the Data represented as graphs."""

Expand Down
5 changes: 1 addition & 4 deletions swan/dataset/dgl_graph_data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Interface to build a Dataset for DGL. see: https://www.dgl.ai/"""
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
from torch.utils.data import DataLoader, Dataset

from .data_graph_base import SwanGraphData
from .graph.molecular_graph import create_molecular_dgl_graph
from ..type_hints import PathLike

try:
import dgl
Expand All @@ -16,9 +16,6 @@
__all__ = ["DGLGraphData"]


PathLike = Union[str, Path]


def collate_fn(samples):
"""Aggregate graphs."""
graphs, y = map(list, zip(*samples))
Expand Down
4 changes: 1 addition & 3 deletions swan/dataset/fingerprints_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Module to process dataset."""
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch
from torch.utils.data import Dataset

from .features.featurizer import generate_fingerprints
from .swan_data_base import SwanDataBase

PathLike = Union[str, Path]
from ..type_hints import PathLike

__all__ = ["FingerprintsData"]

Expand Down
5 changes: 2 additions & 3 deletions swan/dataset/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import json
import multiprocessing
from functools import partial
from pathlib import Path
from typing import Collection, List, Tuple, Union
from typing import Collection, List, Tuple

import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

PathLike = Union[str, Path]
from ..type_hints import PathLike


def read_geometries_from_files(file_geometries: PathLike) -> Tuple[List[Chem.rdchem.Mol], List[np.ndarray]]:
Expand Down
19 changes: 11 additions & 8 deletions swan/dataset/swan_data_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import torch
from rdkit.Chem import PandasTools
from sklearn.preprocessing import RobustScaler
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import DataLoader, Dataset, Subset

from ..type_hints import PathLike
from .geometry import read_geometries_from_files
from .sanitize_data import sanitize_data

PathLike = Union[str, Path]

__all__ = ["SwanDataBase"]


Expand Down Expand Up @@ -114,8 +113,8 @@ def clean_dataframe(self, sanitize: bool) -> None:

def create_data_loader(self,
frac: Tuple[float, float] = (0.8, 0.2),
batch_size: int = 64) -> None:
"""create the train/valid data loaders
batch_size: int = 64) -> Tuple[np.ndarray, np.ndarray]:
"""create the train/valid data loaders using non-overlapping datasets.
Parameters
----------
Expand All @@ -126,17 +125,21 @@ def create_data_loader(self,
"""
ntotal = len(self.dataset)
ntrain = int(frac[0] * ntotal)
nvalid = ntotal - ntrain

self.train_dataset, self.valid_dataset = random_split(
self.dataset, [ntrain, nvalid])
indices = np.arange(ntotal)
np.random.shuffle(indices)

self.train_dataset = Subset(self.dataset, indices[:ntrain])
self.valid_dataset = Subset(self.dataset, indices[ntrain:])

self.train_loader = self.data_loader_fun(dataset=self.train_dataset,
batch_size=batch_size)

self.valid_loader = self.data_loader_fun(dataset=self.valid_dataset,
batch_size=batch_size)

return indices[:ntrain], indices[ntrain:]

def scale_labels(self) -> None:
"""Create a new column with the transformed target."""
self.labels = self.transformer.fit_transform(self.labels)
Expand Down
6 changes: 2 additions & 4 deletions swan/dataset/torch_geometric_graph_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
.. autoclass:: TorchGeometricGraphData
"""
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch
import torch_geometric as tg
from torch_geometric.data import Data

from .graph.molecular_graph import create_molecular_torch_geometric_graph
from ..type_hints import PathLike
from .data_graph_base import SwanGraphData

PathLike = Union[str, Path]
from .graph.molecular_graph import create_molecular_torch_geometric_graph


class TorchGeometricGraphData(SwanGraphData):
Expand Down
68 changes: 68 additions & 0 deletions swan/modeller/base_modeller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@

import abc
from typing import Generic, Optional, Tuple, TypeVar, Union

import numpy as np
import torch

from ..dataset.swan_data_base import SwanDataBase
from ..state import StateH5
from ..type_hints import PathLike

# `bound` preserves all sub-type information, which might be useful
T_co = TypeVar('T_co', bound=Union[np.ndarray, torch.Tensor], covariant=True)


class BaseModeller(Generic[T_co]):
    """Common interface shared by the modellers.

    Every method decorated with ``abc.abstractmethod`` only raises
    ``NotImplementedError`` here and is meant to be overridden.
    """
    # NOTE(review): the class neither inherits from ``abc.ABC`` nor sets
    # ``metaclass=abc.ABCMeta``, so the ``abstractmethod`` markers below are
    # not enforced when the class is instantiated -- confirm this is intended.

    def __init__(self, data: SwanDataBase, replace_state: bool) -> None:
        """Create the state handler and keep the smiles of the dataset.

        Parameters
        ----------
        data
            dataset whose ``dataframe.smiles`` column is converted with
            ``to_numpy()`` and stored in ``self.smiles``
        replace_state
            flag forwarded unchanged to ``StateH5``
        """
        self.state = StateH5(replace_state=replace_state)
        smiles_column = data.dataframe.smiles
        self.smiles = smiles_column.to_numpy()

    @abc.abstractmethod
    def train_model(self, frac: Tuple[float, float] = (0.8, 0.2), **kwargs):
        """Train the model using the given data.

        Parameters
        ----------
        frac
            fraction to divide the dataset, by default (0.8, 0.2)
        **kwargs
            extra options; their meaning is left to the implementation
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def validate_model(self) -> Tuple[T_co, T_co]:
        """Compute the output of the model on the validation set.

        Returns
        -------
        output of the model, ground truth of the data
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def predict(self, inp_data: T_co) -> T_co:
        """Compute the output of the model for a given input.

        Parameters
        ----------
        inp_data
            input data of the model

        Returns
        -------
        output of the model, of the same generic type as ``inp_data``
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def load_model(self, path_model: Optional[PathLike]) -> None:
        """Load the model from the file at ``path_model``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def save_model(self, *args, **kwargs):
        """Store the trained model."""
        raise NotImplementedError()

0 comments on commit 0dbb0a8

Please sign in to comment.