Skip to content

Commit

Permalink
read_log() plots test error.
Browse files Browse the repository at this point in the history
  • Loading branch information
muammar committed Feb 17, 2020
1 parent fb5885c commit 31ca2d3
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
[![Downloads](https://img.shields.io/github/downloads/muammar/ml4chem/total.svg?maxAge=2592000?style=flat-square)](https://github.com/muammar/ml4chem/releases)
![PyPI - Downloads](https://img.shields.io/pypi/dm/ml4chem)
[![GitHub release](https://img.shields.io/github/release/muammar/ml4chem.svg)](https://github.com/muammar/ml4chem/releases/latest)
[![GitHub release](https://readthedocs.org/projects/ml4chem/badge/?version=latest)](https://ml4chem.dev)
[![Slack channel](https://img.shields.io/badge/slack-ml4chem-yellow.svg?logo=slack)](https://ml4chem.slack.com/)


Expand Down
4 changes: 3 additions & 1 deletion bin/ml4chem
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ from ml4chem.visualization import read_log, plot_atomic_features
)
def main(**args):
"""ML4Chem command line tool"""
training_plots = ["rmse", "loss"]
training_plots = ["training", "loss", "test", "combined"]
dim_visualization = ["pca", "tsne"]

_file = args["file"]
Expand All @@ -40,6 +40,8 @@ def main(**args):
backend = args["backend"]
plot_atomic_features(_file, method=method, backend=backend)

else:
raise NotImplementedError

if __name__ == "__main__":
main()
62 changes: 40 additions & 22 deletions ml4chem/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def read_log(logfile, metric="loss", refresh=None):
logfile : str
Path to logfile.
metric : str
Metric to plot. Supported are loss and rmse.
The keys,values of the dictionary are:
- "loss": Loss function values.
- "training": Training error.
- "test": Test error.
- "combined": training + test errors in same plot.
refresh : float
Interval in seconds before refreshing log file plot.
"""
Expand All @@ -88,7 +94,8 @@ def read_log(logfile, metric="loss", refresh=None):
start = False
epochs = []
loss = []
rmse = []
training = []
test = []

initiliazed = False
while refresh is not None:
Expand All @@ -101,36 +108,43 @@ def read_log(logfile, metric="loss", refresh=None):
line = line.split()
epochs.append(int(line[0]))
loss.append(float(line[3]))
rmse.append(float(line[4]))
training.append(float(line[4]))
test.append(float(line[6]))
except ValueError:
pass

if initiliazed is False:
if metric == "loss":
(fig,) = plt.plot(epochs, loss, label="loss")
(fig,) = plt.plot(epochs, loss, label="Loss")

elif metric == "rmse":
(fig,) = plt.plot(epochs, rmse, label="rmse")
elif metric == "training":
(fig,) = plt.plot(epochs, training, label="Training")

else:
(fig,) = plt.plot(epochs, loss, label="loss")
(fig,) = plt.plot(epochs, rmse, label="rmse")
elif metric == "test":
(fig,) = plt.plot(epochs, test, label="test")

elif metric == "combined":
(fig,) = plt.plot(epochs, training, label="Training")
(fig,) = plt.plot(epochs, test, label="Test")
else:
if metric == "loss":
fig.set_data(epochs, loss)

elif metric == "rmse":
fig.set_data(epochs, rmse)
elif metric == "training":
fig.set_data(epochs, training)

else:
fig.set_data(epochs, loss)
fig.set_data(epochs, rmse)
elif metric == "test":
fig.set_data(epochs, test)

elif metric == "combined":
fig.set_data(epochs, training)
fig.set_data(epochs, test)

# Updating annotation
if metric == "loss":
values = loss
elif metric == "rmse":
values = rmse
else:
values = training

reported = values[-1]
x = int(epochs[-1] * 0.9)
Expand All @@ -156,19 +170,23 @@ def read_log(logfile, metric="loss", refresh=None):
line = line.split()
epochs.append(int(line[0]))
loss.append(float(line[3]))
rmse.append(float(line[4]))
training.append(float(line[4]))
test.append(float(line[6]))
except ValueError:
pass

if metric == "loss":
(fig,) = plt.plot(epochs, loss, label="loss")

elif metric == "rmse":
(fig,) = plt.plot(epochs, rmse, label="rmse")
elif metric == "training":
(fig,) = plt.plot(epochs, training, label="training")

else:
(fig,) = plt.plot(epochs, loss, label="loss")
(fig,) = plt.plot(epochs, rmse, label="rmse")
elif metric == "test":
(fig,) = plt.plot(epochs, test, label="training")

elif metric == "combined":
(fig,) = plt.plot(epochs, training, label="training")
(fig,) = plt.plot(epochs, test, label="test")

plt.show(block=True)

Expand Down
2 changes: 1 addition & 1 deletion readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ python:
version: 3.7
install:
- requirements: docs/requirements.txt
# pip_install: true
pip_install: true

0 comments on commit 31ca2d3

Please sign in to comment.