# cpury/lstm-math

Fetching contributors…
Cannot retrieve contributors at this time
143 lines (106 sloc) 3.49 KB
 import numpy as np import matplotlib.pyplot as plt import pandas as pd def plot_2d_space( model, one_hot_encoder, x, y, n=None, reverse=False, save_to=None, dpi=256, ): """ For models trained on equations ala 'a + b', this will plot a scatter plot with correct examples in green and incorrect ones in red. """ if n is None: n = len(x) order = -1 if reverse else 1 predictions = model.predict(x[:n]) correct_coords = [] incorrect_coords = [] for i, prediction in enumerate(predictions): target = y[i] equation_string = one_hot_to_string(x[i])[::order] prediction_string = one_hot_to_string(prediction)[::order] target_string = one_hot_to_string(target)[::order] equation_plus_index = equation_string.index('+') n1 = int(equation_string[:equation_plus_index-1]) n2 = int(equation_string[equation_plus_index+1:-1]) if prediction_string == target_string: # Correct correct_coords.append((n1, n2,)) else: incorrect_coords.append((n1, n2,)) # Create plot fig = plt.figure() ax = fig.add_subplot(1, 1, 1, axisbg="1.0") correct_coords = np.array(correct_coords) ax.scatter( correct_coords[:, 0], correct_coords[:, 1], alpha=0.33, c='green', edgecolors='none', s=20, label='correct' ) incorrect_coords = np.array(incorrect_coords) ax.scatter( incorrect_coords[:, 0], incorrect_coords[:, 1], alpha=0.33, c='red', edgecolors='none', s=20, label='incorrect' ) if save_to: plt.savefig(save_to, dpi=dpi) else: plt.show() plt.close() def plot_error_histogram( model, one_hot_encoder, x, y, n=None, max_d=10, reverse=False ): """ For models trained on equations ala 'a + b', this will plot a histogram of error differences. """ if n is None: n = len(x) order = -1 if reverse else 1 predictions = model.predict(x[:n]) differences = [0] * max_d for i, prediction in enumerate(predictions): target = y[i] prediction_string = one_hot_to_string(prediction)[::order] target_string = one_hot_to_string(target)[::order] prediction_int = int(prediction_string.strip(' \x00')) target_int = int(target_string.strip(' \x00')) difference = abs(target_int - prediction_int) if difference >= max_d: difference = max_d - 1 differences[difference] += 1 # Create plot fig = plt.figure() n, bins, patches = plt.hist(differences, normed=1, facecolor='green') plt.show() plt.close() def plot_training_log( log_file_name='log.csv', metric='acc', save_to=None, dpi=256, ): """ Plot a training log csv generated by CSVLogger. Will display the plot, or if you pass a save_to filename, render to a file. """ log = pd.read_csv(log_file_name) if metric == 'acc': col_train = log.columns[1] col_test = log.columns[3] title = 'Accuracy ({})'.format(col_train) elif metric == 'loss': col_train = log.columns[2] col_test = log.columns[4] title = 'Loss' else: raise ValueError('Metric not s') plt.plot(log[col_train]) plt.plot(log[col_test]) plt.title(title) plt.ylabel(metric) plt.xlabel('epoch') plt.legend(['train', 'test'], loc='lower right') if save_to: plt.savefig(save_to, dpi=dpi) else: plt.show() plt.close()