Skip to content

Commit

Permalink
New backend_kwargs keyword argument for plot_atomic_features.
Browse files Browse the repository at this point in the history
  • Loading branch information
muammar committed Feb 20, 2020
1 parent 000aba9 commit 8a04838
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions ml4chem/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def read_log(logfile, metric="loss", refresh=None, data_only=False):
columns = ["epochs", "loss", "training", "test"]
arr = [epochs, loss, training, test]

if metric != combined:
if metric != "combined":
columns.pop(-1)
arr.pop(-1)

Expand All @@ -218,7 +218,13 @@ def read_log(logfile, metric="loss", refresh=None, data_only=False):


def plot_atomic_features(
latent_space, method="PCA", dimensions=2, backend="seaborn", **kwargs
latent_space,
method="PCA",
dimensions=2,
backend="seaborn",
data_only=False,
backend_kwargs=None,
**kwargs
):
"""Plot high dimensional atomic feature vectors
Expand All @@ -240,7 +246,22 @@ def plot_atomic_features(
backend : str, optional
Select the backend to plot features. Supported are "plotly" and
"seaborn", by default "plotly".
backend_kwargs : dict
Dictionary with extra keyword arguments to extend functionality of
backends that cannot be set with the defaults keyword arguments of
the plot_atomic_features function.
For more information see:
- https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
- https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
data_only : bool
If set to True, this function returns only data in a dataframe with
the following structure:
"""
if backend_kwargs == None:
backend_kwargs = {}

method = method.lower()
backend = backend.lower()
dot_size = kwargs.get("dot_size", 2)
Expand Down Expand Up @@ -288,7 +309,7 @@ def plot_atomic_features(
from sklearn.decomposition import PCA

labels = {str(axis[i]): "PCA-{}".format(i + 1) for i in range(len(axis))}
pca = PCA(n_components=dimensions)
pca = PCA(n_components=dimensions, **backend_kwargs)
pca_result = pca.fit_transform(full_ls)

to_pandas = []
Expand Down Expand Up @@ -327,7 +348,7 @@ def plot_atomic_features(

labels = {str(axis[i]): "t-SNE-{}".format(i + 1) for i in range(len(axis))}

tsne = manifold.TSNE(n_components=dimensions, perplexity=5)
tsne = manifold.TSNE(n_components=dimensions, **backend_kwargs)

tsne_result = tsne.fit_transform(full_ls)

Expand Down Expand Up @@ -362,9 +383,13 @@ def plot_atomic_features(
elif dimensions == 2 and backend == "seaborn":
sns.scatterplot(**labels, data=df, hue="Symbol")

try:
plt.show()
except:
pass
if data_only:
return df

else:
try:
plt.show()
except:
pass

return plt, df
return plt, df

0 comments on commit 8a04838

Please sign in to comment.