Skip to content

Commit

Permalink
Now plots of features have a return.
Browse files Browse the repository at this point in the history
  • Loading branch information
muammar committed Oct 28, 2019
1 parent 3fadbcc commit a09839e
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions ml4chem/data/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def parity(predictions, true, scores=False, filename=None, **kwargs):
Parameters
----------
predictions : list or ndarray
predictions : list or ndarray
Model predictions in a list.
true : list or ndarray
Targets or true values.
Expand Down Expand Up @@ -198,6 +198,12 @@ def plot_atomic_features(latent_space, method="PCA", dimensions=2, backend="seab
"""
method = method.lower()
backend = backend.lower()
dot_size = 2.0

supported_methods = ["pca", "tsne"]

if method not in supported_methods:
raise NotImplementedError

if backend == "seaborn":
# This hack is needed because it seems plotly import overwrite everything.
Expand Down Expand Up @@ -261,11 +267,11 @@ def plot_atomic_features(latent_space, method="PCA", dimensions=2, backend="seab
if dimensions == 3 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter_3d(df, **args)
plt.update_traces(marker=dict(size=2))
plt.update_traces(marker=dict(size=dot_size))
elif dimensions == 2 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter(df, **args)
plt.update_traces(marker=dict(size=2))
plt.update_traces(marker=dict(size=dot_size))
elif dimensions == 3 and backend == "seaborn":
raise ("This backend is for 2D visualization")
elif dimensions == 2 and backend == "seaborn":
Expand Down Expand Up @@ -301,11 +307,11 @@ def plot_atomic_features(latent_space, method="PCA", dimensions=2, backend="seab
if dimensions == 3 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter_3d(df, **args)
plt.update_traces(marker=dict(size=2))
plt.update_traces(marker=dict(size=dot_size))
elif dimensions == 2 and backend == "plotly":
args["color"] = "Symbol"
plt = px.scatter(df, **args)
plt.update_traces(marker=dict(size=2))
plt.update_traces(marker=dict(size=dot_size))
elif dimensions == 3 and backend == "seaborn":
raise ("This backend is for 2D visualization")
elif dimensions == 2 and backend == "seaborn":
Expand All @@ -315,3 +321,5 @@ def plot_atomic_features(latent_space, method="PCA", dimensions=2, backend="seab
plt.show()
except:
pass

return plt

0 comments on commit a09839e

Please sign in to comment.