diff --git a/deepreplay/callbacks.py b/deepreplay/callbacks.py index 87d87cd..6d20ca2 100644 --- a/deepreplay/callbacks.py +++ b/deepreplay/callbacks.py @@ -2,6 +2,7 @@ import os import numpy as np import h5py +import keras.backend as K from keras.callbacks import Callback class ReplayData(Callback): @@ -30,18 +31,36 @@ class ReplayData(Callback): Group inside the HDF5 file where the information is to be saved. If the informed group name already exists, it will throw an exception. + model: Keras Model, optional + If provided, it will set the model directly on the callback + instance and execute the `on_train_begin` method to initialize + all variables and create the corresponding group in the HDF5 + file. + This is intended for analyzing the initial conditions + of the model without ever calling its `fit` method, which is + where the callback is usually invoked. """ - def __init__(self, inputs, targets, filename, group_name): + def __init__(self, inputs, targets, filename, group_name, model=None): super(ReplayData, self).__init__() self.handler = h5py.File('{}'.format(filename), 'a') self.inputs = inputs - self.targets = targets.reshape(-1, 1) + self.targets = targets.reshape(len(targets), -1) self.filepath = os.path.split(filename)[0] self.filename = filename self.group = None self.group_name = group_name self.current_epoch = -1 self.n_epochs = 0 + if model is not None: + self.set_model(model) + self.set_params({ + 'epochs': 0, + 'samples': len(self.inputs), + 'batch_size': len(self.inputs), + }) + self.group_name = group_name + '_init' + self.on_train_begin() + self.group_name = group_name return def _append_weights(self): @@ -52,6 +71,13 @@ def _append_weights(self): for j, weights in enumerate(layer_weights): self.group['layer{}'.format(i)]['weights{}'.format(j)][self.current_epoch + 1] = weights + def get_lr(self): + optimizer = self.model.optimizer + return K.function(inputs=[], + outputs=[optimizer.lr * + (1. / (1. 
+ optimizer.decay * K.cast(optimizer.iterations, + K.dtype(optimizer.decay))))])(inputs=[])[0] + def on_train_begin(self, logs={}): self.model.save(os.path.join(self.filepath, '{}_model.h5'.format(self.group_name))) self.n_epochs = self.params['epochs'] @@ -59,7 +85,8 @@ def on_train_begin(self, logs={}): self.group = self.handler.create_group(self.group_name) self.group.attrs['samples'] = self.params['samples'] self.group.attrs['batch_size'] = self.params['batch_size'] - self.group.attrs['n_batches'] = np.ceil(self.params['samples'] / self.params['batch_size']).astype(np.int) + self.group.attrs['n_batches'] = (self.params['samples'] + self.params['batch_size'] - 1) // \ + self.params['batch_size'] self.group.attrs['n_epochs'] = self.n_epochs self.group.attrs['n_layers'] = len(self.model.layers) try: @@ -81,6 +108,8 @@ def on_train_begin(self, logs={}): for metric in self.model.metrics: self.group.create_dataset(metric, shape=(self.n_epochs,), dtype='f') + self.group.create_dataset('lr', shape=(self.n_epochs,), dtype='f') + for i, layer in enumerate(self.model.layers): layer_grp = self.group.create_group('layer{}'.format(i)) layer_weights = layer.get_weights() @@ -97,6 +126,7 @@ def on_train_end(self, logs={}): def on_epoch_begin(self, epoch, logs={}): self.current_epoch = epoch + self.group['lr'][epoch] = self.get_lr() return def on_epoch_end(self, epoch, logs={}): diff --git a/deepreplay/datasets/ball.py b/deepreplay/datasets/ball.py new file mode 100644 index 0000000..a64ba14 --- /dev/null +++ b/deepreplay/datasets/ball.py @@ -0,0 +1,55 @@ +import numpy as np + +def load_data(n_dims=10, n_points=1000, classif_radius_fraction=0.5, only_sphere=False, shuffle=True, seed=13): + """ + + Parameters + ---------- + n_dims: int, optional + Number of dimensions of the n-ball. Default is 10. + n_points: int, optional + Number of points in the n-ball. Default is 1,000. + classif_radius_fraction: float, optional + Points farther away from the center than + `classif_radius_fraction * ball radius` are + considered to be positive cases. The remaining + points are the negative cases. + only_sphere: boolean, optional + If True, generates an n-sphere, that is, a hollow n-ball. + Default is False. + shuffle: boolean, optional + If True, the points are shuffled. Default is True. + seed: int, optional + Random seed. Default is 13. + + Returns + ------- + X, y: tuple of ndarray + X is an array of shape (n_points, n_dims) containing the + points in the n-ball. + y is an array of shape (n_points, 1) containing the + classes of the samples. + """ + radius = np.sqrt(n_dims) + points = np.random.normal(size=(n_points, n_dims)) + sphere = radius * points / np.linalg.norm(points, axis=1).reshape(-1, 1) + if only_sphere: + X = sphere + else: + X = sphere * np.random.uniform(size=(n_points, 1))**(1 / n_dims) + + adjustment = 1 / np.std(X) + radius *= adjustment + X *= adjustment + + y = (np.abs(np.sum(X, axis=1)) > (radius * classif_radius_fraction)).astype(np.int) + + # But we must not feed the network with neatly organized inputs... 
+ # so let's randomize them + if shuffle: + np.random.seed(seed) + shuffled = np.random.permutation(range(X.shape[0])) + X = X[shuffled] + y = y[shuffled].reshape(-1, 1) + + return (X, y) diff --git a/deepreplay/datasets/hypercube.py b/deepreplay/datasets/hypercube.py new file mode 100644 index 0000000..ee5710e --- /dev/null +++ b/deepreplay/datasets/hypercube.py @@ -0,0 +1,37 @@ +import itertools +import numpy as np + +def load_data(n_dims=10, vertices=(-1., 1.), shuffle=True, seed=13): + """ + + Parameters + ---------- + n_dims: int, optional + Number of dimensions of the hypercube. Default is 10. + vertices: tuple of floats, optional + The two vertex values of an edge, used as coordinates in + every dimension. Default is (-1., 1.). + shuffle: boolean, optional + If True, the points are shuffled. Default is True. + seed: int, optional + Random seed. Default is 13. + + Returns + ------- + X, y: tuple of ndarray + X is an array of shape (2 ** n_dims, n_dims) containing the + coordinates of the hypercube's vertices. + y is an array of shape (2 ** n_dims, 1) containing the + classes of the samples. + """ + X = np.array(list(itertools.product(vertices, repeat=n_dims))) + y = (np.sum(np.clip(X, a_min=0, a_max=1), axis=1) >= (n_dims / 2.0)).astype(np.int) + + # But we must not feed the network with neatly organized inputs... + # so let's randomize them + if shuffle: + np.random.seed(seed) + shuffled = np.random.permutation(range(X.shape[0])) + X = X[shuffled] + y = y[shuffled].reshape(-1, 1) + + return (X, y) diff --git a/deepreplay/plot.py b/deepreplay/plot.py index a6af21e..381a53a 100644 --- a/deepreplay/plot.py +++ b/deepreplay/plot.py @@ -3,10 +3,11 @@ import matplotlib import matplotlib.pyplot as plt import matplotlib.ticker as ticker +import pandas as pd import seaborn as sns from collections import namedtuple from matplotlib import animation -matplotlib.rcParams['animation.writer'] = 'ffmpeg' +matplotlib.rcParams['animation.writer'] = 'avconv' sns.set_style('white') FeatureSpaceData = namedtuple('FeatureSpaceData', ['line', 'bent_line', 'prediction', 'target']) @@ -14,6 +15,7 @@ LossAndMetricData = namedtuple('LossAndMetricData', ['loss', 'metric', 'metric_name']) ProbHistogramData = namedtuple('ProbHistogramData', ['prob', 'target']) LossHistogramData = namedtuple('LossHistogramData', ['loss']) +LayerViolinsData = namedtuple('LayerViolinsData', ['names', 'values', 'layers', 'selected_layers']) def build_2d_grid(xlim, ylim, n_lines=11, n_points=1000): """Returns a 2D grid of boundaries given by `xlim` and `ylim`, @@ -588,4 +590,45 @@ def _update(i, lh, epoch_start=0): lh.ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%0.1f')) lh.ax.locator_params(tight=True, nbins=4) - return lh.line \ No newline at end of file + return lh.line + +class LayerViolins(Basic): + def __init__(self, ax, title): + super(LayerViolins, self).__init__(ax) + self.values = None + self.names = None + self._title = title + + def load_data(self, layer_violins_data): + self.values = layer_violins_data.values + self.names = layer_violins_data.names + self.layers = ['inputs'] + layer_violins_data.layers + self.selected_layers = layer_violins_data.selected_layers + self.palette = dict(zip(self.layers, sns.palettes.husl_palette(len(self.layers), .7))) + self.n_epochs = len(self.values) + self._prepare_plot() + return self + + def _prepare_plot(self): + self.line = self.ax.plot([], []) + + @staticmethod + def _update(i, lv, epoch_start=0): + assert len(lv.names) == len(lv.values[i]), "Layer names and values have different lengths!" 
+ epoch = i + epoch_start + + df = pd.concat([pd.DataFrame(layer_values.ravel(), + columns=[layer_name]).melt(var_name='layers', value_name='values') + for layer_name, layer_values in zip(lv.names, lv.values[i])]) + df = df[df.isin({'layers': lv.selected_layers}).values] + + lv.ax.clear() + sns.violinplot(data=df, x='layers', y='values', ax=lv.ax, cut=0, palette=lv.palette, scale='width') + lv.ax.set_xticklabels(df.layers.unique()) + lv.ax.set_xlabel('Layers') + lv.ax.set_ylabel(lv._title) + lv.ax.set_ylim([df['values'].min(), df['values'].max()]) + lv.ax.set_title('{} - Epoch: {}'.format(lv.title[0], epoch)) + + return lv.line + diff --git a/deepreplay/replay.py b/deepreplay/replay.py index b4335e3..e0823df 100644 --- a/deepreplay/replay.py +++ b/deepreplay/replay.py @@ -3,16 +3,25 @@ import h5py import keras.backend as K from keras.models import load_model -from .plot import build_2d_grid, FeatureSpace, ProbabilityHistogram, LossHistogram, LossAndMetric -from .plot import FeatureSpaceData, FeatureSpaceLines, ProbHistogramData, LossHistogramData, LossAndMetricData +from .plot import ( + build_2d_grid, FeatureSpace, ProbabilityHistogram, LossHistogram, LossAndMetric, LayerViolins +) +from .plot import ( + FeatureSpaceData, FeatureSpaceLines, ProbHistogramData, LossHistogramData, LossAndMetricData, LayerViolinsData +) +from .utils import make_batches, slice_arrays +from itertools import groupby +from operator import itemgetter TRAINING_MODE = 1 TEST_MODE = 0 +ACTIVATIONS = ['softmax', 'relu', 'elu', 'tanh', 'sigmoid', 'hard_sigmoid', 'linear', 'softplus', 'softsign', 'selu'] +Z_OPS = ['BiasAdd', 'MatMul', 'Add', 'Sub', 'Mul', 'Maximum', 'Minimum', 'RealDiv', 'ExpandDims'] + class Replay(object): """Creates an instance of Replay, to process information collected - by the callback and generate data to feed the supported visualiza- - tions. + by the callback and generate data to feed the supported visualizations. Parameters ---------- @@ -46,30 +55,68 @@ class Replay(object): animating; namedtuple containing information about classification probabilities and targets. + weights_violins: (LayerViolins, LayerViolinsData) + LayerViolins object to be used for plotting and animating; + namedtuple containing information about weight values + per layer. + + activations_violins: (LayerViolins, LayerViolinsData) + LayerViolins object to be used for plotting and animating; + namedtuple containing information about activation values + per layer. + + zvalues_violins: (LayerViolins, LayerViolinsData) + LayerViolins object to be used for plotting and animating; + namedtuple containing information about Z-values per layer. + + gradients_violins: (LayerViolins, LayerViolinsData) + LayerViolins object to be used for plotting and animating; + namedtuple containing information about gradient values + per layer. + + weights_std: ndarray + Standard deviation of the weights per layer. + + gradients_std: ndarray + Standard deviation of the gradients per layer. + training_loss: ndarray An array of shape (n_epochs, ) with training loss as reported by Keras at the end of each epoch. + + learning_rate: ndarray + An array of shape (n_epochs, ) with learning rate as reported + by Keras at the beginning of each epoch. 
""" def __init__(self, replay_filename, group_name, model_filename=''): # Set learning phase to TEST self.learning_phase = TEST_MODE + # Loads ReplayData file + self.replay_data = h5py.File('{}'.format(replay_filename), 'r') + try: + self.group = self.replay_data[group_name] + except KeyError: + self.group = self.replay_data[group_name + '_init'] + group_name += '_init' + + self.group_name = group_name + # If not informed, defaults to '_model' suffix if model_filename == '': model_filename = '{}_model.h5'.format(group_name) - # Loads Keras model self.model = load_model(model_filename) - # Loads ReplayData file - self.replay_data = h5py.File('{}'.format(replay_filename), 'r') - self.group_name = group_name - self.group = self.replay_data[self.group_name] # Retrieves some basic information from the replay data self.inputs = self.group['inputs'][:] self.targets = self.group['targets'][:] self.n_epochs = self.group.attrs['n_epochs'] self.n_layers = self.group.attrs['n_layers'] + + # Generates ranges for the number of different weight arrays in each layer + self.n_weights = [range(len(self.group['layer{}'.format(l)])) for l in range(self.n_layers)] + # Retrieves weights as a list, each element being one epoch self.weights = self._retrieve_weights() @@ -92,26 +139,104 @@ def __init__(self, replay_filename, group_name, model_filename=''): outputs=[K.binary_crossentropy(self.model.targets[0], self.model.outputs[0])]) + # Keras function to compute the gradients for trainable weights, given inputs, targets, weights and + # sample weights + self.__trainable_weights = [w for layer in self.model.layers + for w in layer.trainable_weights + if layer.trainable and ('bias' not in w.op.name)] + self.__trainable_gradients = self.model.optimizer.get_gradients(self.model.total_loss, self.__trainable_weights) + self._get_gradients = K.function(inputs=[K.learning_phase()] + self.model.inputs + self.model.targets + + self._model_weights + self.model.sample_weights, + outputs=self.__trainable_gradients) + + def get_z_op(layer): + op = layer.output.op + if op.type in Z_OPS: + return layer.output + else: + op_layer_name = op.name.split('/')[0] + for input in op.inputs: + input_layer_name = input.name.split('/')[0] + if (input.op.type in Z_OPS) and (op_layer_name == input_layer_name): + return input + return None + + __z_layers = np.array([i for i, layer in enumerate(self.model.layers) if get_z_op(layer) is not None]) + __act_layers = np.array([i for i, layer in enumerate(self.model.layers) + if layer.output.op.type.lower() in ACTIVATIONS]) + __z_layers = np.array([__z_layers[np.argmax(layer < __z_layers) - 1] for layer in __act_layers]) + self.z_act_layers = [self.model.layers[i].name for i in __z_layers] + + self._z_layers = ['inputs'] + [self.model.layers[i].name for i in __z_layers] + self._z_tensors = self.model.inputs + list(filter(lambda t: t is not None, + [get_z_op(self.model.layers[i]) for i in __z_layers])) + + self._activation_layers = ['inputs'] + [self.model.layers[i].name for i in __act_layers] + self._activation_tensors = self.model.inputs + [self.model.layers[i].output for i in __act_layers] + + # Keras function to compute the Z values given inputs and weights + self._get_zvalues = K.function(inputs=[K.learning_phase()] + self.model.inputs + self._model_weights, + outputs=self._z_tensors) + # Keras function to compute the activation values given inputs and weights + self._get_activations = K.function(inputs=[K.learning_phase()] + self.model.inputs + self._model_weights, + 
outputs=self._activation_tensors) + + # Gets names of all layers with arrays of weights of length 1 (no biases) or 2 (with biases) + # Layers without weights (e.g. Activation, BatchNorm) are not included + self.weights_layers = [layer.name for layer, weights in zip(self.model.layers, self.n_weights) + if len(weights) in (1, 2)] + # Attributes for the visualizations - Data self._feature_space_data = None self._loss_hist_data = None self._loss_and_metric_data = None self._prob_hist_data = None self._decision_boundary_data = None + self._weights_violins_data = None + self._activations_violins_data = None + self._zvalues_violins_data = None + self._gradients_data = None # Attributes for the visualizations - Plot objects self._feature_space_plot = None self._loss_hist_plot = None self._loss_and_metric_plot = None self._prob_hist_plot = None self._decision_boundary_plot = None + self._weights_violins_plot = None + self._activations_violins_plot = None + self._zvalues_violins_plot = None + self._gradients_plot = None + + def _make_batches(self, seed): + inputs = self.inputs[:] + targets = self.targets[:] + + # Re-seeds before each shuffle, so inputs and targets get the same permutation + np.random.seed(seed) + np.random.shuffle(inputs) + np.random.seed(seed) + np.random.shuffle(targets) + num_training_samples = inputs.shape[0] + + batches = make_batches(num_training_samples, self.params['batch_size']) + index_array = np.arange(num_training_samples) + + inputs_batches = [] + targets_batches = [] + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + inputs_batch, targets_batch = slice_arrays([inputs, targets], batch_ids) + inputs_batches.append(inputs_batch) + targets_batches.append(targets_batch) + + return inputs_batches, targets_batches + + @staticmethod + def __assign_gradients_to_layers(layers, gradients): + return [list(list(zip(*g))[1]) for k, g in groupby(zip(layers, gradients), itemgetter(0))] def _retrieve_weights(self): - # Generates ranges for the number of different weight arrays in each layer - n_weights = [range(len(self.group['layer{}'.format(l)])) - for l in range(self.n_layers)] # Retrieves weights for each layer and sequence of weights weights = [np.array(self.group['layer{}'.format(l)]['weights{}'.format(w)]) - for l, ws in enumerate(n_weights) + for l, ws in enumerate(self.n_weights) for w in ws] # Since initial weights are also saved, there are n_epochs + 1 elements in total return [[w[epoch] for w in weights] for epoch in range(self.n_epochs + 1)] @@ -146,10 +271,50 @@ def loss_and_metric(self): def probability_histogram(self): return self._prob_hist_plot, self._prob_hist_data + @property + def weights_violins(self): + return self._weights_violins_plot, self._weights_violins_data + + @property + def activations_violins(self): + return self._activations_violins_plot, self._activations_violins_data + + @property + def zvalues_violins(self): + return self._zvalues_violins_plot, self._zvalues_violins_data + + @property + def gradients_violins(self): + return self._gradients_plot, self._gradients_data + + @staticmethod + def __calc_std(values): + return np.array([[layer.std() for layer in epoch] for epoch in values]) + + @property + def weights_std(self): + std = None + if self._weights_violins_data is not None: + weights = self._weights_violins_data.values + std = Replay.__calc_std(weights) + return std + + @property + def gradients_std(self): + std = None + if self._gradients_data is not None: + gradients = self._gradients_data.values + std = Replay.__calc_std(gradients) + return std + @property def 
training_loss(self): return self.group['loss'][:] + @property + def learning_rate(self): + return self.group['lr'][:] + def get_training_metric(self, metric_name): """Returns corresponding metric as reported by Keras at the end of each epoch. @@ -202,6 +367,182 @@ def predict_proba(self, epoch_start=0, epoch_end=-1): probas = np.array(probas) return probas + def build_gradients(self, ax, layer_names=None, exclude_outputs=True, epoch_start=0, epoch_end=-1): + """Builds a LayerViolins object to be used for plotting and + animating. + + Parameters + ---------- + ax: AxesSubplot + Subplot of a Matplotlib figure. + layer_names: list of Strings, optional + If informed, plots only the listed layers. + exclude_outputs: boolean, optional + If True, excludes distribution of output layer. Default is True. + If `layer_names` is informed, `exclude_outputs` is ignored. + epoch_start: int, optional + First epoch to consider. + epoch_end: int, optional + Last epoch to consider. + + Returns + ------- + gradients_plot: LayerViolins + An instance of a LayerViolins object to make plots and + animations. + """ + if epoch_end == -1: + epoch_end = self.n_epochs + epoch_end = min(epoch_end, self.n_epochs) + + gradient_names = [layer.name for layer in self.model.layers for w in layer.trainable_weights + if layer.trainable and ('bias' not in w.op.name)] + gradients = [] + # For each epoch, uses the corresponding weights + for epoch in range(epoch_start, epoch_end + 1): + weights = self.weights[epoch] + + # Sample weights fixed to one! + inputs = [self.learning_phase, self.inputs, self.targets] + weights + [np.ones(shape=self.inputs.shape[0])] + grad = [w for v in Replay.__assign_gradients_to_layers(gradient_names, self._get_gradients(inputs=inputs)) + for w in v] + gradients.append(grad) + + if layer_names is None: + layer_names = self.weights_layers + if exclude_outputs: + layer_names = layer_names[:-1] + + self._gradients_data = LayerViolinsData(names=gradient_names, values=gradients, layers=self.weights_layers, + selected_layers=layer_names) + if ax is None: + self._gradients_plot = None + else: + self._gradients_plot = LayerViolins(ax, 'Gradients').load_data(self._gradients_data) + return self._gradients_plot + + def build_outputs(self, ax, before_activation=False, layer_names=None, include_inputs=True, + exclude_outputs=True, epoch_start=0, epoch_end=-1): + """Builds a LayerViolins object to be used for plotting and + animating. + + Parameters + ---------- + ax: AxesSubplot + Subplot of a Matplotlib figure. + before_activation: Boolean, optional + If True, returns Z-values, that is, before applying + the activation function. + layer_names: list of Strings, optional + If informed, plots only the listed layers. + include_inputs: boolean, optional + If True, includes distribution of inputs. Default is True. + exclude_outputs: boolean, optional + If True, excludes distribution of output layer. Default is True. + If `layer_names` is informed, `exclude_outputs` is ignored. + epoch_start: int, optional + First epoch to consider. + epoch_end: int, optional + Last epoch to consider. + + Returns + ------- + activations_violins_plot/zvalues_violins_plot: LayerViolins + An instance of a LayerViolins object to make plots and + animations. 
+ """ + if epoch_end == -1: + epoch_end = self.n_epochs + epoch_end = min(epoch_end, self.n_epochs) + + if before_activation: + title = 'Z-values' + names = self._z_layers + else: + title = 'Activations' + names = self._activation_layers + outputs = [] + # For each epoch, uses the corresponding weights + for epoch in range(epoch_start, epoch_end + 1): + weights = self.weights[epoch] + inputs = [self.learning_phase, self.inputs] + weights + if before_activation: + outputs.append(self._get_zvalues(inputs=inputs)) + else: + outputs.append(self._get_activations(inputs=inputs)) + + if layer_names is None: + layer_names = self.z_act_layers + if exclude_outputs: + layer_names = layer_names[:-1] + if include_inputs: + layer_names = ['inputs'] + layer_names + + data = LayerViolinsData(names=names, values=outputs, layers=self.z_act_layers, selected_layers=layer_names) + if ax is None: + plot = None + else: + plot = LayerViolins(ax, title).load_data(data) + if before_activation: + self._zvalues_violins_data = data + self._zvalues_violins_plot = plot + else: + self._activations_violins_data = data + self._activations_violins_plot = plot + return plot + + def build_weights(self, ax, layer_names=None, exclude_outputs=True, epoch_start=0, epoch_end=-1): + """Builds a LayerViolins object to be used for plotting and + animating. + + Parameters + ---------- + ax: AxesSubplot + Subplot of a Matplotlib figure. + layer_names: list of Strings, optional + If informed, plots only the listed layers. + exclude_outputs: boolean, optional + If True, excludes distribution of output layer. Default is True. + If `layer_names` is informed, `exclude_outputs` is ignored. + epoch_start: int, optional + First epoch to consider. + epoch_end: int, optional + Last epoch to consider. + + Returns + ------- + weights_violins_plot: LayerViolins + An instance of a LayerViolins object to make plots and + animations. + """ + if epoch_end == -1: + epoch_end = self.n_epochs + epoch_end = min(epoch_end, self.n_epochs) + + names = [layer.name for layer, weights in zip(self.model.layers, self.n_weights) if len(weights) in (1, 2)] + n_weights = [(i, len(weights)) for layer, weights in zip(self.model.layers, self.n_weights) for i in weights] + + weights = [] + # For each epoch, uses the corresponding weights + for epoch in range(epoch_start, epoch_end + 1): + # takes only the weights (i == 0), not the biases (i == 1) + weights.append([w for w, (i, n) in zip(self.weights[epoch], n_weights) if (n in (1, 2)) and (i == 0)]) + + if layer_names is None: + layer_names = self.weights_layers + if exclude_outputs: + layer_names = layer_names[:-1] + + self._weights_violins_data = LayerViolinsData(names=names, + values=weights, + layers=self.weights_layers, + selected_layers=layer_names) + if ax is None: + self._weights_violins_plot = None + else: + self._weights_violins_plot = LayerViolins(ax, 'Weights').load_data(self._weights_violins_data) + return self._weights_violins_plot + def build_loss_histogram(self, ax, epoch_start=0, epoch_end=-1): """Builds a LossHistogram object to be used for plotting and animating. diff --git a/deepreplay/utils.py b/deepreplay/utils.py new file mode 100644 index 0000000..361f3f7 --- /dev/null +++ b/deepreplay/utils.py @@ -0,0 +1,57 @@ +# These functions were extracted from Keras to assure there will +# be no break in compatibility with DeepReplay's code. + +def make_batches(size, batch_size): + """Function extracted from Keras - check keras.engine.training_utils + for the original version. 
+ + Returns a list of batch indices (tuples of indices). + # Arguments + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + # Returns + A list of tuples of array indices. + """ + num_batches = (size + batch_size - 1) // batch_size # round up + return [(i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(num_batches)] + +def slice_arrays(arrays, start=None, stop=None): + """ Function extracted from Keras - check keras.utils.generic_utils + for the original version. + + Slices an array or list of arrays. + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + Can also work on list/array of indices: `slice_arrays(x, indices)` + # Arguments + arrays: Single array or list of arrays. + start: can be an integer index (start index) + or a list/array of indices + stop: integer (stop index); should be None if + `start` was a list. + # Returns + A slice of the array(s). + """ + if arrays is None: + return [None] + elif isinstance(arrays, list): + if hasattr(start, '__len__'): + # hdf5 datasets only support list objects as indices + if hasattr(start, 'shape'): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + else: + return [None if x is None else x[start:stop] for x in arrays] + else: + if hasattr(start, '__len__'): + if hasattr(start, 'shape'): + start = start.tolist() + return arrays[start] + elif hasattr(start, '__getitem__'): + return arrays[start:stop] + else: + return [None] + diff --git a/docs/source/deepreplay.datasets.rst b/docs/source/deepreplay.datasets.rst index ecfffe7..9e3130b 100644 --- a/docs/source/deepreplay.datasets.rst +++ b/docs/source/deepreplay.datasets.rst @@ -4,6 +4,22 @@ deepreplay\.datasets package Submodules ---------- +deepreplay\.datasets\.ball module +--------------------------------- + +.. automodule:: deepreplay.datasets.ball + :members: + :undoc-members: + :show-inheritance: + +deepreplay\.datasets\.hypercube module +-------------------------------------- + +.. automodule:: deepreplay.datasets.hypercube + :members: + :undoc-members: + :show-inheritance: + deepreplay\.datasets\.parabola module ------------------------------------- diff --git a/docs/source/deepreplay.rst b/docs/source/deepreplay.rst index 5ae24e6..dc6665a 100644 --- a/docs/source/deepreplay.rst +++ b/docs/source/deepreplay.rst @@ -35,6 +35,14 @@ deepreplay\.replay module :undoc-members: :show-inheritance: +deepreplay\.utils module +------------------------ + +.. 
automodule:: deepreplay.utils + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/examples/part2_initializers.py b/examples/part2_initializers.py deleted file mode 100644 index 6d65706..0000000 --- a/examples/part2_initializers.py +++ /dev/null @@ -1,55 +0,0 @@ -from keras.layers import Dense, Activation, BatchNormalization -from keras.models import Sequential -from keras.optimizers import SGD -from keras.initializers import glorot_normal, normal -from deepreplay.datasets.parabola import load_data -from deepreplay.callbacks import ReplayData - -X, y = load_data() - -sgd = SGD(lr=0.05) - -def basic_model(activation, initializers): - model = Sequential() - model.add(Dense(units=2, - input_dim=2, - kernel_initializer=initializers[0], - activation=activation, - name='hidden')) - - model.add(Dense(units=1, - kernel_initializer=initializers[1], - activation='sigmoid', - name='output')) - return model - -def bn_model(activation, initializers): - model = Sequential() - model.add(Dense(units=2, - input_dim=2, - kernel_initializer=initializers[0], - name='hidden_linear')) - model.add(BatchNormalization(name='hidden_bn')) - model.add(Activation(activation, name='hidden_activation')) - - model.add(Dense(units=1, - kernel_initializer=initializers[1], - activation='sigmoid', - name='output')) - return model - - -for seed in range(100): - print('Using seed {}') - replay = ReplayData(X, y, filename='part2_relu.h5', group_name='seed{:03}'.format(seed)) - - glorot_initializer = glorot_normal(seed=seed) - normal_initializer = normal(seed=42) - - model = basic_model('relu', [glorot_initializer, normal_initializer]) - - model.compile(loss='binary_crossentropy', - optimizer=sgd, - metrics=['acc']) - - model.fit(X, y, epochs=150, batch_size=16, callbacks=[replay]) diff --git a/examples/part2_weight_initializers.py b/examples/part2_weight_initializers.py new file mode 100644 index 0000000..678a737 --- /dev/null +++ b/examples/part2_weight_initializers.py @@ -0,0 +1,94 @@ +from keras.initializers import normal, glorot_normal, glorot_uniform, he_normal, he_uniform +from keras.layers import Dense +from keras.models import Sequential + +from deepreplay.callbacks import ReplayData +from deepreplay.datasets.ball import load_data +from deepreplay.plot import compose_plots +from deepreplay.replay import Replay + +from matplotlib import pyplot as plt + +# Model builder function +def build_model(n_layers, input_dim, units, activation, initializer): + if isinstance(units, list): + assert len(units) == n_layers + else: + units = [units] * n_layers + + model = Sequential() + # Adds first hidden layer with input_dim parameter + model.add(Dense(units=units[0], + input_dim=input_dim, + activation=activation, + kernel_initializer=initializer, + name='h1')) + + # Adds remaining hidden layers + for i in range(2, n_layers + 1): + model.add(Dense(units=units[i-1], + activation=activation, + kernel_initializer=initializer, + name='h{}'.format(i))) + + # Adds output layer + model.add(Dense(units=1, activation='sigmoid', kernel_initializer=initializer, name='o')) + # Compiles the model + model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['acc']) + return model + +X, y = load_data(n_dims=10) + +filename = 'part2_weight_initializers.h5' + +# Examples of different initializers + +# Uses normal initializer +# activation = 'sigmoid' +# initializer = normal(mean=0, stddev=0.01, seed=13) +# title = r'Activation: sigmoid - Initializer: $\sigma = 0.01$' +# group_name = 
'sigmoid_stdev_0.01' + +# Uses Glorot Uniform +# activation = 'tanh' +# initializer = glorot_uniform(seed=13) +# title = 'Activation: Tanh - Initializer: Glorot Uniform' +# group_name = 'tanh_glorot_uniform' + +# Uses He Uniform +activation = 'relu' +initializer = he_uniform(seed=13) +title = 'Activation: ReLU - Initializer: He Uniform' +group_name = 'relu_he_uniform' + +# Builds BLOCK model +model = build_model(n_layers=5, input_dim=10, units=100, + activation=activation, initializer=initializer) + +# Since we only need initial weights, we don't even need to train the model! +# We still use the ReplayData callback, but we can pass the model as an argument instead +replaydata = ReplayData(X, y, filename=filename, group_name=group_name, model=model) + +# Now we feed the data to the actual Replay object so we can build the visualizations +replay = Replay(replay_filename=filename, group_name=group_name) + +# Using subplot2grid to assemble a complex figure... +fig = plt.figure(figsize=(12, 6)) +ax_zvalues = plt.subplot2grid((2, 2), (0, 0)) +ax_weights = plt.subplot2grid((2, 2), (0, 1)) +ax_activations = plt.subplot2grid((2, 2), (1, 0)) +ax_gradients = plt.subplot2grid((2, 2), (1, 1)) + +wv = replay.build_weights(ax_weights) +gv = replay.build_gradients(ax_gradients) +# Z-values +zv = replay.build_outputs(ax_zvalues, before_activation=True, + exclude_outputs=True, include_inputs=False) +# Activations +av = replay.build_outputs(ax_activations, exclude_outputs=True, include_inputs=False) + +# Finally, we use compose_plots to update all +# visualizations at once +fig = compose_plots([zv, wv, av, gv], + epoch=0, title=title) +fig.savefig('part2.png', format='png', dpi=120) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 516da19..544e093 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy>=1.14 tensorflow==1.5.0 -keras>=2.0.7 +keras==2.2.0 scikit-learn>=0.18 h5py matplotlib diff --git a/setup.py b/setup.py index 2dacc30..e7cc00f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ def readme(): return f.read() setup(name='deepreplay', - version='0.1.0a6', + version='0.1.1a1', install_requires=['matplotlib', 'numpy', 'h5py', 'seaborn', 'keras', 'scikit-learn'], description='"Hyper-parameters in Action!" 
visualizing tool for Keras models.', long_description=readme(), diff --git a/tests/rawdata/hyperparms_in_action.h5 b/tests/rawdata/hyperparms_in_action.h5 index c603858..b979796 100644 Binary files a/tests/rawdata/hyperparms_in_action.h5 and b/tests/rawdata/hyperparms_in_action.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch01.h5 b/tests/rawdata/part1_activation_functions_epoch01.h5 index 969f542..3e5a0fa 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch01.h5 and b/tests/rawdata/part1_activation_functions_epoch01.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch02.h5 b/tests/rawdata/part1_activation_functions_epoch02.h5 index e31552d..d21d0c3 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch02.h5 and b/tests/rawdata/part1_activation_functions_epoch02.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch03.h5 b/tests/rawdata/part1_activation_functions_epoch03.h5 index 3961fb8..333c958 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch03.h5 and b/tests/rawdata/part1_activation_functions_epoch03.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch04.h5 b/tests/rawdata/part1_activation_functions_epoch04.h5 index d377309..d2f81eb 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch04.h5 and b/tests/rawdata/part1_activation_functions_epoch04.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch05.h5 b/tests/rawdata/part1_activation_functions_epoch05.h5 index 78f7a0e..8586b81 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch05.h5 and b/tests/rawdata/part1_activation_functions_epoch05.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch06.h5 b/tests/rawdata/part1_activation_functions_epoch06.h5 index 7a47c1d..24b80dd 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch06.h5 and b/tests/rawdata/part1_activation_functions_epoch06.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch07.h5 b/tests/rawdata/part1_activation_functions_epoch07.h5 index 6b4858d..6bb5946 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch07.h5 and b/tests/rawdata/part1_activation_functions_epoch07.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch08.h5 b/tests/rawdata/part1_activation_functions_epoch08.h5 index d32bca1..5269fa7 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch08.h5 and b/tests/rawdata/part1_activation_functions_epoch08.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch09.h5 b/tests/rawdata/part1_activation_functions_epoch09.h5 index 503f79d..6eea359 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch09.h5 and b/tests/rawdata/part1_activation_functions_epoch09.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch10.h5 b/tests/rawdata/part1_activation_functions_epoch10.h5 index 5b1f7aa..380ebe2 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch10.h5 and b/tests/rawdata/part1_activation_functions_epoch10.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch11.h5 b/tests/rawdata/part1_activation_functions_epoch11.h5 index 2977f4d..eac92cf 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch11.h5 and b/tests/rawdata/part1_activation_functions_epoch11.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch12.h5 b/tests/rawdata/part1_activation_functions_epoch12.h5 index d6ff2c3..3d56bd4 100644 Binary files 
a/tests/rawdata/part1_activation_functions_epoch12.h5 and b/tests/rawdata/part1_activation_functions_epoch12.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch13.h5 b/tests/rawdata/part1_activation_functions_epoch13.h5 index 2ad2359..d371c19 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch13.h5 and b/tests/rawdata/part1_activation_functions_epoch13.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch14.h5 b/tests/rawdata/part1_activation_functions_epoch14.h5 index 7e59983..89aa8b5 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch14.h5 and b/tests/rawdata/part1_activation_functions_epoch14.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch15.h5 b/tests/rawdata/part1_activation_functions_epoch15.h5 index dc5d1d2..78ce1dc 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch15.h5 and b/tests/rawdata/part1_activation_functions_epoch15.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch16.h5 b/tests/rawdata/part1_activation_functions_epoch16.h5 index 1940d45..53464dc 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch16.h5 and b/tests/rawdata/part1_activation_functions_epoch16.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch17.h5 b/tests/rawdata/part1_activation_functions_epoch17.h5 index dcbc5b6..1d85353 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch17.h5 and b/tests/rawdata/part1_activation_functions_epoch17.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch18.h5 b/tests/rawdata/part1_activation_functions_epoch18.h5 index b424afd..584b2c2 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch18.h5 and b/tests/rawdata/part1_activation_functions_epoch18.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch19.h5 b/tests/rawdata/part1_activation_functions_epoch19.h5 index daa8400..712050c 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch19.h5 and b/tests/rawdata/part1_activation_functions_epoch19.h5 differ diff --git a/tests/rawdata/part1_activation_functions_epoch20.h5 b/tests/rawdata/part1_activation_functions_epoch20.h5 index 4d44cb9..8db07c9 100644 Binary files a/tests/rawdata/part1_activation_functions_epoch20.h5 and b/tests/rawdata/part1_activation_functions_epoch20.h5 differ diff --git a/tests/rawdata/part1_activation_functions_model.h5 b/tests/rawdata/part1_activation_functions_model.h5 index 5f16fd7..bcb3f1a 100644 Binary files a/tests/rawdata/part1_activation_functions_model.h5 and b/tests/rawdata/part1_activation_functions_model.h5 differ
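For reference, a minimal usage sketch (not part of the patch) tying together the new `learning_rate` property and the `weights_std`/`gradients_std` helpers added to replay.py. The replay filename and group name below are assumptions borrowed from the new example script, and the sketch presumes the replay file was written during an actual `fit` run, so that per-epoch data exists:

from deepreplay.replay import Replay

replay = Replay(replay_filename='part2_weight_initializers.h5',
                group_name='relu_he_uniform')

# Passing ax=None builds only the underlying data, skipping the violin plots;
# weights_std/gradients_std return None until the corresponding build_* is called
replay.build_weights(ax=None)
replay.build_gradients(ax=None)

print(replay.learning_rate)     # learning rate recorded at the beginning of each epoch
print(replay.weights_std[0])    # per-layer standard deviation of weights at the first epoch
print(replay.gradients_std[0])  # per-layer standard deviation of gradients at the first epoch

Calling the build_* methods with a real AxesSubplot instead of ax=None would additionally produce the LayerViolins plots, as in the example script above.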