Skip to content

Commit

Permalink
read_log() return a dataframe instead of tuple when data_only=True.
Browse files Browse the repository at this point in the history
  • Loading branch information
muammar committed Feb 19, 2020
1 parent 61f8e21 commit 000aba9
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions ml4chem/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
from collections import OrderedDict
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from ml4chem.data.serialization import load

Expand Down Expand Up @@ -71,10 +72,18 @@ def read_log(logfile, metric="loss", refresh=None, data_only=False):
refresh : float
Interval in seconds before refreshing log file plot.
data_only : bool
If set to True, this function returns only data in a tuple with the
following structure: (epochs, loss, training, test).
If set to True, this function returns only data in a dataframe with
the following structure:
>>> df.head()
epochs loss training test
0 1 33779.46 815.6884 793.3943
Returns
-------
pandas.DataFrame or matplotlib.pyplot object
If data_only is true we return dataframe, otherwise a figure.
"""

if refresh is not None:
Expand Down Expand Up @@ -193,7 +202,17 @@ def read_log(logfile, metric="loss", refresh=None, data_only=False):
(fig,) = plt.plot(epochs, test, label="Test")

if data_only:
return epochs, loss, training, test
data = OrderedDict()
columns = ["epochs", "loss", "training", "test"]
arr = [epochs, loss, training, test]

if metric != combined:
columns.pop(-1)
arr.pop(-1)

for i, column in enumerate(columns):
data[column] = arr[i]
return pd.DataFrame.from_dict(data)
else:
plt.show(block=True)

Expand Down

0 comments on commit 000aba9

Please sign in to comment.