Permalink
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
143 lines (106 sloc) 3.49 KB
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
def plot_2d_space(
    model, one_hot_encoder, x, y, n=None, reverse=False,
    save_to=None, dpi=256,
):
    """
    For models trained on equations ala 'a + b', plot the operands (a, b) of
    each evaluated example as a scatter plot, with correctly predicted
    examples in green and incorrectly predicted ones in red.

    :param model: object with a ``predict`` method; it is called once on
        ``x[:n]``.
    :param one_hot_encoder: not used by this function.
        NOTE(review): ``one_hot_to_string`` is not defined in this view and
        this parameter is unused -- perhaps the decoder was meant to come
        from it; confirm against the rest of the module.
    :param x: encoded equations; ``x[i]`` is decoded with
        ``one_hot_to_string``.
    :param y: encoded targets, index-aligned with ``x``.
    :param n: number of leading examples to evaluate (default: ``len(x)``).
    :param reverse: if True, every decoded string is reversed before it is
        parsed or compared.
    :param save_to: if truthy, save the figure to this path instead of
        showing it.
    :param dpi: resolution used when saving.
    """
    if n is None:
        n = len(x)
    order = -1 if reverse else 1
    predictions = model.predict(x[:n])
    correct_coords = []
    incorrect_coords = []
    for i, prediction in enumerate(predictions):
        equation_string = one_hot_to_string(x[i])[::order]
        prediction_string = one_hot_to_string(prediction)[::order]
        target_string = one_hot_to_string(y[i])[::order]
        # Operands sit on either side of '+'. The slices drop the character
        # just before '+' and the final character (presumably padding or
        # spaces -- confirm against the encoder's format); int() tolerates
        # any remaining surrounding whitespace.
        plus_index = equation_string.index('+')
        n1 = int(equation_string[:plus_index - 1])
        n2 = int(equation_string[plus_index + 1:-1])
        if prediction_string == target_string:
            correct_coords.append((n1, n2))
        else:
            incorrect_coords.append((n1, n2))

    # Create plot
    fig = plt.figure()
    # 'axisbg' was removed from matplotlib; 'facecolor' replaces it.
    ax = fig.add_subplot(1, 1, 1, facecolor="1.0")
    for coords, colour, label in (
        (correct_coords, 'green', 'correct'),
        (incorrect_coords, 'red', 'incorrect'),
    ):
        # reshape turns an empty list into a (0, 2) array so the column
        # indexing below still works when every prediction is right (or
        # every one is wrong).
        points = np.array(coords, dtype=float).reshape(-1, 2)
        ax.scatter(
            points[:, 0],
            points[:, 1],
            alpha=0.33, c=colour, edgecolors='none', s=20, label=label,
        )
    # The scatters carry labels; show them.
    ax.legend()
    if save_to:
        fig.savefig(save_to, dpi=dpi)
    else:
        plt.show()
    plt.close(fig)
def plot_error_histogram(
    model, one_hot_encoder, x, y, n=None, max_d=10, reverse=False,
    save_to=None, dpi=256,
):
    """
    For models trained on equations ala 'a + b', plot a histogram of the
    absolute difference between predicted and target integers.

    Differences are clamped to ``max_d - 1``, so the last bin collects every
    difference of ``max_d - 1`` or more. Bars are normalised to sum to 1 and
    are centred on the integer differences 0 .. max_d - 1.

    :param model: object with a ``predict`` method; it is called once on
        ``x[:n]``.
    :param one_hot_encoder: not used by this function.
        NOTE(review): ``one_hot_to_string`` is not defined in this view;
        confirm where the decoder is meant to come from.
    :param x: encoded equations passed to ``model.predict``.
    :param y: encoded targets, index-aligned with ``x``.
    :param n: number of leading examples to evaluate (default: ``len(x)``).
    :param max_d: number of bins; must be at least 1.
    :param reverse: if True, decoded strings are reversed before parsing.
    :param save_to: if truthy, save the figure to this path instead of
        showing it.
    :param dpi: resolution used when saving.
    :raises ValueError: if ``max_d < 1``, or if a decoded prediction/target
        (after stripping spaces and NUL padding) is not an integer.
    """
    if max_d < 1:
        raise ValueError('max_d must be at least 1, got {!r}'.format(max_d))
    if n is None:
        n = len(x)
    order = -1 if reverse else 1
    predictions = model.predict(x[:n])
    differences = []
    for i, prediction in enumerate(predictions):
        prediction_string = one_hot_to_string(prediction)[::order]
        target_string = one_hot_to_string(y[i])[::order]
        prediction_int = int(prediction_string.strip(' \x00'))
        target_int = int(target_string.strip(' \x00'))
        differences.append(min(abs(target_int - prediction_int), max_d - 1))

    # Create plot
    fig = plt.figure()
    # Histogram the raw (clamped) differences. Histogramming per-bucket
    # counts would plot the distribution of the counts themselves. The unit
    # bins are centred on each integer difference.
    bin_edges = np.arange(max_d + 1) - 0.5
    # With no data, a density plot would divide by zero and produce NaN bars.
    plt.hist(
        differences, bins=bin_edges, density=bool(differences),
        facecolor='green',
    )
    if save_to:
        fig.savefig(save_to, dpi=dpi)
    else:
        plt.show()
    plt.close(fig)
def plot_training_log(
    log_file_name='log.csv', metric='acc', save_to=None, dpi=256,
):
    """
    Plot a training log csv generated by CSVLogger. Will display the plot,
    or if you pass a save_to filename, render to a file.

    :param log_file_name: path of the CSV log to read with ``pd.read_csv``.
    :param metric: which pair of curves to plot:
        - 'acc': columns 1 (train) and 3 (test) of the log;
        - 'loss': columns 2 (train) and 4 (test) of the log.
        These positions assume the column order
        ``epoch, acc, loss, val_acc, val_loss`` -- confirm for your logs.
        Any other name is accepted when both ``metric`` and
        ``'val_' + metric`` are columns of the log (e.g. 'accuracy').
    :param save_to: if truthy, save the figure to this path instead of
        showing it.
    :param dpi: resolution used when saving.
    :raises ValueError: if ``metric`` is not one of the supported choices.
    """
    log = pd.read_csv(log_file_name)
    if metric == 'acc':
        col_train, col_test = log.columns[1], log.columns[3]
        title = 'Accuracy ({})'.format(col_train)
        legend_loc = 'lower right'
    elif metric == 'loss':
        col_train, col_test = log.columns[2], log.columns[4]
        title = 'Loss'
        # Loss curves fall towards the lower right; keep the legend clear.
        legend_loc = 'upper right'
    elif metric in log.columns and 'val_' + metric in log.columns:
        col_train, col_test = metric, 'val_' + metric
        title = metric
        legend_loc = 'best'
    else:
        raise ValueError(
            "Metric not supported: {!r} (expected 'acc', 'loss', or a column "
            "name with a matching 'val_' column)".format(metric)
        )

    # Draw on a fresh figure rather than whatever figure is current.
    fig = plt.figure()
    plt.plot(log[col_train])
    plt.plot(log[col_test])
    plt.title(title)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc=legend_loc)
    if save_to:
        fig.savefig(save_to, dpi=dpi)
    else:
        plt.show()
    plt.close(fig)