Skip to content

Commit

Permalink
3D interactive plotting added to visualization module
Browse files Browse the repository at this point in the history
- This new change helps visualizing 3D latent spaces.
- README file was updated.
- Added more information about usage of plot_atomic_features() function,
  and an example image.
  • Loading branch information
muammar committed Oct 23, 2019
1 parent 5fd13d5 commit 459146d
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 47 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
![alt text](https://raw.githubusercontent.com/muammar/ml4chem/master/docs/source/_static/ml4chem.png "Logo")

--------------------------------------------------------------------------------

## About
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/Django.svg)](https://github.com/muammar/mkchromecast/)
[![Build Status](https://travis-ci.com/muammar/ml4chem.svg?branch=master)](https://travis-ci.com/muammar/ml4chem)
[![License](https://img.shields.io/badge/license-BSD-green)](https://github.com/muammar/ml4chem/blob/master/LICENSE)
[![Downloads](https://img.shields.io/github/downloads/muammar/ml4chem/total.svg?maxAge=2592000?style=flat-square)](https://github.com/muammar/ml4chem/releases)
[![GitHub release](https://img.shields.io/github/release/muammar/ml4chem.svg)](https://github.com/muammar/ml4chem/releases/latest)

## About


This package is written in Python 3, and intends to offer modern and rich
features to perform machine learning workflows for chemical physics.
Expand Down
49 changes: 30 additions & 19 deletions bin/ml4chem
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,44 @@ import time
import click
import os
import sys
path = os.path.dirname(os.path.abspath(__file__)).strip('bin')

path = os.path.dirname(os.path.abspath(__file__)).strip("bin")
sys.path.append(path)
from ml4chem.data.visualization import read_log, plot_atomic_features

@click.command()
@click.option('--plot', default=None, help='Plot information from file.')
@click.option('--file', default=None, help='Path to log file or database.')
@click.option('--refresh', default=None, type=float,
help='Useful for sleeping before reading log files.')

@click.command()
@click.option("--plot", default=None, help="Plot information from file.")
@click.option(
"--backend",
default="seaborn",
help='Select backed to plot, supported "plotly", and "seaborn". Default is "searborn".',
)
@click.option("--file", default=None, help="Path to log file or database.")
@click.option(
"--refresh",
default=None,
type=float,
help="Useful for sleeping before reading log files.",
)
def main(**args):
"""ML4Chem command line tool"""
training_plots = ['rmse', 'loss']
dim_visualization = ['pca', 'tsne']
training_plots = ["rmse", "loss"]
dim_visualization = ["pca", "tsne"]

_file = args["file"]

_file = args['file']
if args["plot"] is not None:
if args["plot"] in training_plots:
metric = args["plot"]
refresh = args["refresh"]
plt = read_log(_file, metric=metric, refresh=refresh)

if args['plot'] is not None:
if args['plot'] in training_plots:
metric = args['plot']
refresh = args['refresh']
plt = read_log(_file, metric=metric,
refresh=refresh)
elif args["plot"].lower() in dim_visualization:
method = args["plot"]
backend = args["backend"]
plt = plot_atomic_features(_file, method=method, backend=backend)

elif args['plot'].lower() in dim_visualization:
method = args['plot']
plt = plot_atomic_features(_file, method=method)

if __name__ == '__main__':
if __name__ == "__main__":
main()
Binary file added docs/source/_static/tsne_visual.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 19 additions & 3 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ For more information please refer to :mod:`ml4chem.data.handler`.
Visualization
===================

We also offer a :mod:`ml4chem.data.visualization` module to plot some
interesting graphics about your model, or even monitor the progress of the
loss function minimization.
We also offer a :mod:`ml4chem.data.visualization` module to plot interesting
graphics about your model, features, or even monitor the progress of the loss
function and error minimization.

Two backends are supported to plot in ML4Chem: Seaborn and Plotly.

An example is shown below::

from ml4chem.data.visualization import plot_atomic_features
plot_atomic_features("latent_space.db",
method="tsne",
dimensions=3,
backend="plotly")

This will produce an interactive plot with plotly where dimensionality was
reduced using T-SNE.

.. image:: https://raw.githubusercontent.com/muammar/ml4chem/master/docs/source/_static/tsne_visual.png

8 changes: 0 additions & 8 deletions docs/source/ml4chem.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ Subpackages
Submodules
----------

ml4chem.active module
---------------------

.. automodule:: ml4chem.active
:members:
:undoc-members:
:show-inheritance:

ml4chem.metrics module
----------------------

Expand Down
91 changes: 75 additions & 16 deletions ml4chem/data/visualization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from ml4chem.data.serialization import load
import time
Expand All @@ -12,19 +13,19 @@ def parity(predictions, true, scores=False, filename=None, **kwargs):
Parameters
----------
predictions : list or numpy.array
predictions : list or ndarray
Model predictions in a list.
true : list or numpy.array
true : list or ndarray
Targets or true values.
scores : bool
Print scores in parity plot.
filename : str
A name to save the plot to a file. If filename is non exisntent, we
A name to save the plot to a file. If filename is non existent, we
call plt.show().
Notes
-----
kargs accepts all valid keyword arguments for matplotlib.pyplot.savefig.
kwargs accepts all valid keyword arguments for matplotlib.pyplot.savefig.
"""

min_val = min(true)
Expand Down Expand Up @@ -173,7 +174,7 @@ def read_log(logfile, metric="loss", refresh=None):
plt.show(block=True)


def plot_atomic_features(latent_space, method="PCA", dimensions=2):
def plot_atomic_features(latent_space, method="PCA", dimensions=3, backend="seaborn"):
"""Plot high dimensional atomic feature vectors
This function can take a feature space dictionary, or a database file
Expand All @@ -191,9 +192,24 @@ def plot_atomic_features(latent_space, method="PCA", dimensions=2):
dimensions : int, optional
Number of dimensions to reduce the high dimensional atomic feature
vectors, by default 2.
backend : str, optional
Select the backend to plot features. Supported are "plotly" and
"seaborn", by default "plotly".
"""

method = method.lower()
backend = backend.lower()

if backend == "seaborn":
# This hack is needed because it seems plotly import overwrite everything.
import matplotlib.pyplot as plt

axis = ["x", "y", "z"]

if dimensions > 3:
raise NotImplementedError
elif dimensions == 2:
axis.pop(-1)

if isinstance(latent_space, str):
latent_space = load(latent_space)

Expand All @@ -220,37 +236,80 @@ def plot_atomic_features(latent_space, method="PCA", dimensions=2):
if method == "pca":
from sklearn.decomposition import PCA

labels = {"x": "PCA-1", "y": "PCA-2"}
labels = {str(axis[i]): "PCA-{}".format(i + 1) for i in range(len(axis))}
pca = PCA(n_components=dimensions)
pca_result = pca.fit_transform(full_ls)

to_pandas = []

entry = []
for i, element in enumerate(pca_result):
to_pandas.append([full_symbols[i], element[0], element[1]])
entry = [full_symbols[i]]
for d in range(dimensions):
entry.append(element[d])
to_pandas.append(entry)

columns = ["Symbol"]
args = {}

columns = ["Symbol", "PCA-1", "PCA-2"]
for key in axis:
columns.append(labels[key])
args[key] = labels[key]

df = pd.DataFrame(to_pandas, columns=columns)
sns.scatterplot(**labels, data=df, hue="Symbol")

if dimensions == 3 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter_3d(df, **args)
plt.update_traces(marker=dict(size=4))
elif dimensions == 2 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter(df, **args)
elif dimensions == 3 and backend == "seaborn":
raise ("This backend is for 2D visualization")
elif dimensions == 2 and backend == "seaborn":
sns.scatterplot(**labels, data=df, hue="Symbol")

elif method == "tsne":
from sklearn import manifold

labels = {"x": "t-SNE-1", "y": "t-SNE-2"}
labels = {str(axis[i]): "t-SNE-{}".format(i + 1) for i in range(len(axis))}

tsne = manifold.TSNE(n_components=dimensions)

tsne_result = tsne.fit_transform(full_ls)

to_pandas = []

entry = []
for i, element in enumerate(tsne_result):
to_pandas.append([full_symbols[i], element[0], element[1]])
entry = [full_symbols[i]]
for d in range(dimensions):
entry.append(element[d])
to_pandas.append(entry)

columns = ["Symbol", "t-SNE-1", "t-SNE-2"]
columns = ["Symbol"]
args = {}

for key in axis:
columns.append(labels[key])
args[key] = labels[key]

df = pd.DataFrame(to_pandas, columns=columns)
sns.scatterplot(**labels, data=df, hue="Symbol")

plt.show()
if dimensions == 3 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter_3d(df, **args)
plt.update_traces(marker=dict(size=4))
elif dimensions == 2 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter(df, **args)
elif dimensions == 3 and backend == "seaborn":
raise ("This backend is for 2D visualization")
elif dimensions == 2 and backend == "seaborn":
sns.scatterplot(**labels, data=df, hue="Symbol")

try:
plt.show()
except:
pass

0 comments on commit 459146d

Please sign in to comment.