
Improved memory management in Gaussian class.
- `dask.persist` is used instead of `dask.compute`, which greatly
  reduces memory consumption when building Gaussian features (a short
  sketch of the difference follows below).
- Code reformatted with Black.
- Changed the way the convergence criterion is checked in the merger
  module.
muammar committed Sep 28, 2019
1 parent 10b859f commit 6b98ffa
Showing 5 changed files with 43 additions and 14 deletions.
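For context on the first bullet: `dask.compute` gathers every delayed result into the calling process at once, while `dask.persist` triggers the same work but returns lazy `Delayed` handles (with the distributed scheduler, the values stay in worker memory), so downstream steps can consume them without the driver first holding one giant tuple of materialized objects. A minimal sketch of the difference, using a hypothetical `atomic_fingerprint` task rather than code from this repository:

```python
import dask
import numpy as np

@dask.delayed
def atomic_fingerprint(index):
    # Stand-in for an expensive per-atom feature computation.
    return np.full(8, float(index))

computations = [atomic_fingerprint(i) for i in range(1000)]

# compute() returns 1000 concrete ndarrays to the caller at once:
# results = dask.compute(*computations)

# persist() returns lazy handles to the already-scheduled results,
# which dask collections can then wrap and consume chunk by chunk:
persisted = dask.persist(*computations)
```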
4 changes: 4 additions & 0 deletions ml4chem/data/preprocessing.py
@@ -129,6 +129,7 @@ def fit(self, stacked_features, scheduler):
Scaled features using requested preprocessor.
"""

logger.info("Scaling features...")
if isinstance(stacked_features, np.ndarray):
# The Normalizer() is not supported by dask_ml.
self.preprocessor.fit(stacked_features)
@@ -139,6 +140,9 @@ def fit(self, stacked_features, scheduler):
stacked_features.compute(scheduler=scheduler)
)

logger.info("Finished scaling features.")
logger.info("")

return scaled_features

def transform(self, raw_features):
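The `fit` hunks above dispatch on the input type: a plain `np.ndarray` is fitted directly (the `Normalizer()` comment notes that dask_ml does not support every preprocessor), while a lazy dask collection is first materialized with the requested scheduler. A minimal sketch of that dispatch, with a hypothetical `MinMaxScaler` standing in for the configured preprocessor:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

def fit_scaler(stacked_features, scheduler):
    preprocessor = MinMaxScaler()  # hypothetical choice of scaler
    if isinstance(stacked_features, np.ndarray):
        # Eager input: fit the scikit-learn preprocessor directly.
        return preprocessor.fit_transform(stacked_features)
    # Lazy input (e.g. a dask array): materialize it with the
    # requested scheduler before fitting.
    computed = stacked_features.compute(scheduler=scheduler)
    return preprocessor.fit_transform(computed)
```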
33 changes: 24 additions & 9 deletions ml4chem/fingerprints/gaussian.py
@@ -234,7 +234,7 @@ def calculate_features(self, images=None, purpose="training", data=None, svm=Fal

# We start populating computations to get atomic fingerprints.
logger.info("")
logger.info("Adding atomic feature calculations to scheduler...")
logger.info("Adding atomic feature calculations to computational graph...")

ini = end = 0

@@ -280,17 +280,30 @@ def calculate_features(self, images=None, purpose="training", data=None, svm=Fal
"... finished in {} hours {} minutes {:.2f}" " seconds.".format(h, m, s)
)

# In this block we compute the fingerprints.
logger.info("")
logger.info("Computing fingerprints...")
# In this block we compute the fingerprints.

stacked_features = dask.persist(*computations, scheduler=self.scheduler)

stacked_features = dask.compute(*computations, scheduler=self.scheduler)
# dask.distributed.wait(stacked_features)

if self.preprocessor is not None:
stacked_features = np.array(stacked_features)
logger.info("Adding Dask array construction to computational graph...")
symbol = data.unique_element_symbols[purpose][0]
sample = np.zeros(len(self.GP[symbol]))
dim = (len(stacked_features), len(sample))

stacked_features = [
dask.array.from_delayed(lazy, dtype=float, shape=sample.shape)
for lazy in stacked_features
]

stacked_features = (
dask.array.stack(stacked_features, axis=0).reshape(dim).rechunk(dim)
)

# Clean
del computations
# del computations

if purpose == "training":
# To take advantage of dask_ml we need to convert our numpy array
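The hunk above is the core of the memory change: instead of materializing everything with `dask.compute` and copying it into one `np.array`, each persisted per-atom result is wrapped with `dask.array.from_delayed`, stacked into a single `(n_atoms, n_features)` dask array, and rechunked into one chunk. This also appears to be why `get_atomic_fingerprint` further down now returns `np.array(fingerprint)`: `from_delayed` expects each delayed value to be an ndarray of the declared shape and dtype. A minimal, self-contained sketch of that construction, with hypothetical sizes:

```python
import dask
import dask.array as da
import numpy as np

n_atoms, n_features = 4, 8  # hypothetical sizes
sample = np.zeros(n_features)

@dask.delayed
def fingerprint_row(i):
    # Stand-in for one atom's fingerprint vector.
    return np.full(n_features, float(i))

persisted = dask.persist(*[fingerprint_row(i) for i in range(n_atoms)])
dim = (n_atoms, n_features)

stacked = da.stack(
    [da.from_delayed(p, dtype=float, shape=sample.shape) for p in persisted],
    axis=0,
).reshape(dim).rechunk(dim)

# rechunk(dim) places the whole matrix in a single chunk, the layout
# the downstream preprocessing step works on.
print(stacked.compute())  # -> a (4, 8) NumPy array
```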
@@ -299,8 +312,7 @@ def calculate_features(self, images=None, purpose="training", data=None, svm=Fal

if self.preprocessor is not None:
scaled_feature_space = []
dim = stacked_features.shape
stacked_features = dask.array.from_array(stacked_features, chunks=dim)

stacked_features = preprocessor.fit(
stacked_features, scheduler=self.scheduler
)
@@ -324,6 +336,8 @@ def calculate_features(self, images=None, purpose="training", data=None, svm=Fal
)
feature_space.append(features)

# Clean
del computations
del stacked_features
computations = []

@@ -376,6 +390,7 @@ def calculate_features(self, images=None, purpose="training", data=None, svm=Fal
fp_time = time.time() - initial_time

h, m, s = convert_elapsed_time(fp_time)

logger.info(
"Fingerprinting finished in {} hours {} minutes {:.2f}"
" seconds.".format(h, m, s)
@@ -671,7 +686,7 @@ def get_atomic_fingerprint(
)
return symbol, fingerprint
else:
return fingerprint
return np.array(fingerprint)

def make_symmetry_functions(self, symbols, custom=None, angular_type="G3"):
"""Function to make symmetry functions
14 changes: 11 additions & 3 deletions ml4chem/models/merger.py
@@ -136,11 +136,18 @@ def train(

self.epochs = epochs

if isinstance(convergence["rmse"], float) or isinstance(convergence["rmse"], int):
convergence["rmse"] = np.array([convergence["rmse"] for model in range(len(self.models))])
# Convergence criterion
if isinstance(convergence["rmse"], float) or isinstance(
convergence["rmse"], int
):
convergence["rmse"] = np.array(
[convergence["rmse"] for model in range(len(self.models))]
)
elif isinstance(convergence["rmse"], list):
if len(convergence["rmse"]) != len(self.models):
raise("Your convergence list is not the same length of the number of models")
raise (
"Your convergence list is not the same length of the number of models"
)
convergence["rmse"] = np.array(convergence["rmse"])

logger.info(" ")
@@ -164,6 +171,7 @@ def train(
l.__name__, self.loss_weights[index]
)
)
logging.info("Convergence criterion: {}.".format(convergence))

# If no batch_size provided then the whole training set length is the batch.
if batch_size is None:
4 changes: 3 additions & 1 deletion ml4chem/potentials.py
@@ -348,7 +348,9 @@ def calculate(self, atoms, properties, system_changes):
except:
raise ("This is not a database...")

energy = self.model.get_potential_energy(fingerprints, reference_space, purpose=purpose)
energy = self.model.get_potential_energy(
fingerprints, reference_space, purpose=purpose
)
else:
input_dimension = len(list(fingerprints.values())[0][0][-1])
model = copy.deepcopy(self.model)
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@
url="https://github.com/muammar/ml4chem",
packages=setuptools.find_packages(),
scripts=["bin/ml4chem"],
data_files = [("", ["LICENSE"])],
data_files=[("", ["LICENSE"])],
classifiers=[
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
