Enforcing a common base class for derived deep learning models, plus general improvements.
These changes are needed for the ML4Chem publication.

- All deep learning models now inherit from the DeepLearningModel
  base class (a minimal sketch is shown below).
- The visualization module moved from `ml4chem.data.visualization` to
  `ml4chem.visualization`.
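
A minimal sketch of the new pattern, mirroring the diffs below (the class
name and layer sizes here are illustrative and not part of this commit):

    from ml4chem.models.base import DeepLearningModel

    class MyModel(DeepLearningModel):
        NAME = "MyModel"

        @classmethod
        def name(cls):
            """Return name of the class."""
            return cls.NAME

        def __init__(self, hiddenlayers=(3, 3), activation="relu", **kwargs):
            # As in the diffs below, super() is anchored at DeepLearningModel,
            # so torch.nn.Module.__init__ still runs through the MRO.
            super(DeepLearningModel, self).__init__()
            self.hiddenlayers = hiddenlayers
            self.activation = activation
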
muammar committed Jan 31, 2020
1 parent 10d913a commit 495539f
Showing 7 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion bin/ml4chem
@@ -5,7 +5,7 @@ import sys

path = os.path.dirname(os.path.abspath(__file__)).strip("bin")
sys.path.append(path)
-from ml4chem.data.visualization import read_log, plot_atomic_features
+from ml4chem.visualization import read_log, plot_atomic_features


@click.command()
4 changes: 2 additions & 2 deletions docs/source/data.rst
@@ -40,15 +40,15 @@ For more information please refer to :mod:`ml4chem.data.handler`.
Visualization
===================

-We also offer a :mod:`ml4chem.data.visualization` module to plot interesting
+We also offer a :mod:`ml4chem.visualization` module to plot interesting
graphics about your model, features, or even monitor the progress of the loss
function and error minimization.

Two backends are supported to plot in ML4Chem: Seaborn and Plotly.

An example is shown below::

-from ml4chem.data.visualization import plot_atomic_features
+from ml4chem.visualization import plot_atomic_features
fig = plot_atomic_features("latent_space.db",
method="pca",
dimensions=3,
4 changes: 2 additions & 2 deletions docs/source/ml4chem.data.rst
@@ -44,10 +44,10 @@ ml4chem.data.utils module
:undoc-members:
:show-inheritance:

-ml4chem.data.visualization module
+ml4chem.visualization module
---------------------------------

-.. automodule:: ml4chem.data.visualization
+.. automodule:: ml4chem.visualization
:members:
:undoc-members:
:show-inheritance:
7 changes: 5 additions & 2 deletions ml4chem/models/autoencoders.py
@@ -7,6 +7,7 @@
import numpy as np
from collections import OrderedDict
from ml4chem.metrics import compute_rmse
+from ml4chem.models.base import DeepLearningModel
from ml4chem.models.loss import MSELoss
from ml4chem.optim.handler import get_optimizer, get_lr_scheduler
from ml4chem.utils import convert_elapsed_time, get_chunks, lod_to_list
@@ -16,7 +17,7 @@
logger = logging.getLogger()


-class AutoEncoder(torch.nn.Module):
+class AutoEncoder(DeepLearningModel):
"""Fully connected atomic autoencoder
@@ -67,7 +68,7 @@ def name(cls):
def __init__(
self, hiddenlayers=None, activation="relu", one_for_all=False, **kwargs
):
-super(AutoEncoder, self).__init__()
+super(DeepLearningModel, self).__init__()

self.hiddenlayers = hiddenlayers
self.activation = activation
@@ -956,6 +957,8 @@ def trainer(self):
if self.anneal:
annealing = annealer.update(epoch)
print(annealing)
+else:
+annealing = None

self.optimizer.zero_grad() # clear previous gradients

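The added `else` branch guarantees that `annealing` is always bound when
annealing is disabled, so any later read of the variable inside `trainer()`
no longer risks an UnboundLocalError. A stripped-down illustration with
hypothetical names:

    def training_step(anneal, epoch, annealer=None):
        # Bind `annealing` on both branches, as in the diff above, so any
        # later reference never sees an unbound name.
        if anneal:
            annealing = annealer.update(epoch)
        else:
            annealing = None
        return annealing  # stand-in for the later use in trainer()

    print(training_step(anneal=False, epoch=0))  # -> None
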
3 changes: 2 additions & 1 deletion ml4chem/models/base.py
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
+import torch


-class DeepLearningModel(ABC):
+class DeepLearningModel(ABC, torch.nn.Module):
@abstractmethod
def name(cls):
"""Return name of the class"""
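Because `name` is declared with `@abstractmethod`, the base class enforces
the contract: a subclass that does not implement `name` cannot be
instantiated. A quick illustration (the `Incomplete` class is hypothetical):

    from ml4chem.models.base import DeepLearningModel

    class Incomplete(DeepLearningModel):
        pass

    try:
        Incomplete()
    except TypeError as error:
        # e.g. "Can't instantiate abstract class Incomplete with abstract
        # method name"
        print(error)
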
5 changes: 3 additions & 2 deletions ml4chem/models/neuralnetwork.py
@@ -7,6 +7,7 @@
import numpy as np
from collections import OrderedDict
from ml4chem.metrics import compute_rmse
+from ml4chem.models.base import DeepLearningModel
from ml4chem.models.loss import AtomicMSELoss
from ml4chem.optim.handler import get_optimizer, get_lr_scheduler
from ml4chem.utils import convert_elapsed_time, get_chunks, get_number_of_parameters
@@ -17,7 +18,7 @@
logger = logging.getLogger()


-class NeuralNetwork(torch.nn.Module):
+class NeuralNetwork(DeepLearningModel):
"""Atom-centered Neural Network Regression with Pytorch
This model is based on Ref. 1 by Behler and Parrinello.
@@ -48,7 +49,7 @@ def name(cls):
return cls.NAME

def __init__(self, hiddenlayers=(3, 3), activation="relu", **kwargs):
-super(NeuralNetwork, self).__init__()
+super(DeepLearningModel, self).__init__()
self.hiddenlayers = hiddenlayers
self.activation = activation

2 changes: 1 addition & 1 deletion ml4chem/data/visualization.py → ml4chem/visualization.py
@@ -199,7 +199,7 @@ def plot_atomic_features(
"""
method = method.lower()
backend = backend.lower()
-dot_size = kwargs["dot_size"]
+dot_size = kwargs.get("dot_size", 2)

supported_methods = ["pca", "tsne"]

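With `kwargs.get`, `dot_size` now falls back to 2 when the caller omits it,
instead of raising a KeyError. A usage sketch following the docs example
above (the `backend` keyword and the database file are assumptions of this
sketch):

    from ml4chem.visualization import plot_atomic_features

    # No dot_size keyword: the function now defaults to dot_size=2.
    fig = plot_atomic_features("latent_space.db", method="pca",
                               backend="seaborn", dimensions=3)
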
