-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Marc Oliu Simón
committed
Dec 1, 2017
0 parents
commit bd4d159
Showing
74 changed files
with
3,472 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
save | ||
.idea | ||
*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
Here we explain how the code is structured. | ||
|
||
|
||
DATASETS & MODELS: | ||
The models for each dataset are saved inside "./model/<dataset>/". Each model has a main file, one for the fRNN model | ||
named "model_frnn.py" and one for the RLadder baseline, named "main_rladder.py". These files specify the paths where | ||
to find the pre-processed data and save the trained models, as well as the training parameters: | ||
|
||
- Number of training terations (batches) | ||
- Batch size | ||
- Device to use | ||
- Topology parameters | ||
- Data loading and augmentation parameters | ||
|
||
The folder for each dataset also contains a dataset-specific loader to feed the network ("loader.py") and a data | ||
preprocessing script ("preprocess.py"). The former should only be modified if you plan on using the code on other | ||
datasets not considered here. The later should be manually run in order to prepare the dataset before trying | ||
to train any model. In the case of Moving MMNIST, the preprocessing script will download the necessary files before | ||
preprocessing. For KTH and UCF101 the script expects the uncompressed datasets to be already present. | ||
|
||
|
||
MAIN FILES: | ||
Each model has a main file associated. They are all identical, changing only the imported model file. By default, the | ||
model will train, test and analyse the results of the model. Each step can be executed independently by commenting the | ||
other actions, as intermediate results are saved to disk. There is also a "run" function commented by default. This | ||
function allows you to extract direct predictions from the model, as well as to plot the results as they are obtained. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import imageio | ||
import glob | ||
import scipy.ndimage as ndim | ||
import scipy.misc as sm | ||
import numpy as np | ||
|
||
# Prepare method strings | ||
PATH_STRINGS = '/home/moliu/Documents/Papers/Supplementary/titles/' | ||
s_titles = [ | ||
ndim.imread(PATH_STRINGS + 'frnn.png'), | ||
ndim.imread(PATH_STRINGS + 'rladder.png'), | ||
ndim.imread(PATH_STRINGS + 'prednet.png'), | ||
ndim.imread(PATH_STRINGS + 'srivastava.png'), | ||
ndim.imread(PATH_STRINGS + 'mathieu.png'), | ||
ndim.imread(PATH_STRINGS + 'villegas.png'), | ||
] | ||
|
||
|
||
def generate_captions(strings, width): | ||
titles = [255 * np.ones((20, width+15, 3), dtype=np.uint8) for _ in strings] | ||
|
||
# Prepare strings | ||
strings = [(np.stack([s, s, s], axis=2) if len(s.shape) == 2 else s) for s in strings] | ||
for i, s in enumerate(strings): | ||
s = sm.imresize(s, size=0.8) | ||
t_pad, l_pad = (20 - s.shape[0]) / 2, (width - s.shape[1]) / 2 | ||
titles[i][t_pad:t_pad+s.shape[0], l_pad:l_pad+s.shape[1]] = s | ||
|
||
return np.pad( | ||
np.concatenate(titles, axis=1)[:, :-15], | ||
((0, 0), (width+15, 0), (0, 0)), | ||
'constant', constant_values=255 | ||
) | ||
|
||
|
||
def preprocess_predictions(gt, predictions, width=80): | ||
# Append dimension if no channels | ||
gt = [(f if len(f.shape) == 3 else np.expand_dims(f, axis=2)) for f in gt] | ||
predictions = [[(f if len(f.shape) == 3 else np.expand_dims(f, axis=2)) for f in s] for s in predictions] | ||
|
||
# Replicate channel if grayscale | ||
gt = [(f if f.shape[2] == 3 else np.concatenate([f, f, f], axis=2)) for f in gt] | ||
predictions = [[(f if f.shape[2] == 3 else np.concatenate([f, f, f], axis=2)) for f in s] for s in predictions] | ||
|
||
# Reshape predictions to gt shape, prepare black frames, fill leading frames | ||
predictions = [[sm.imresize(f, gt[0].shape[:2]) for f in s] for s in predictions] | ||
predictions = [([np.zeros(gt[0].shape, dtype=np.uint8)] * 10 if len(s) == 0 else s) for s in predictions] | ||
predictions = [[gt[4]] * 5 + s for s in predictions] | ||
|
||
# Pad frames to fit expected width | ||
padding = ((0, 0), ((width - gt[0].shape[1]) / 2,)*2, (0, 0)) | ||
gt = [np.pad(f, padding, 'constant', constant_values=255) for f in gt] | ||
predictions = [[np.pad(f, padding, 'constant', constant_values=255) for f in s] for s in predictions] | ||
|
||
return gt, predictions | ||
|
||
|
||
def generate_instance_sequence(path): | ||
# List ground truth images | ||
f_gt = sorted(glob.glob(path + 'g*.png')) | ||
f_gt = f_gt[-5:] + f_gt[:10] | ||
|
||
# List prediction images | ||
f_methods = [ | ||
sorted(glob.glob(path + 'frnn_*.png')), sorted(glob.glob(path + 'rladder_*.png')), | ||
sorted(glob.glob(path + 'prednet_*.png')), sorted(glob.glob(path + 'srivastava_*.png')), | ||
sorted(glob.glob(path + 'mathieu_*.png')), sorted(glob.glob(path + 'villegas_*.png')) | ||
] | ||
|
||
# Read & preprocess frames | ||
f_gt, f_methods = [ndim.imread(f) for f in f_gt], [[ndim.imread(f) for f in m] for m in f_methods] | ||
f_gt, f_methods = preprocess_predictions(f_gt, f_methods) | ||
im_h, im_w = f_gt[0].shape[:2] | ||
|
||
# Fill frames with ground truth & predictions | ||
frame = 255 * np.ones((im_h, im_w*7 + 15*6, 3), dtype=np.uint8) | ||
frames = [np.copy(frame) for _ in range(15)] | ||
for i, (fg, fm) in enumerate(zip(f_gt, zip(*f_methods))): | ||
frames[i][:im_h, :im_w] = fg | ||
for j, (f, title) in enumerate(zip(fm, ['frnn', 'rladder', 'prednet', 'Srivastava', 'mathieu', 'villegas'])): | ||
r = (j + 1) * (im_w + 15) | ||
frames[i][:im_h, r:r+im_w] = f | ||
|
||
# Return sequence frames | ||
return frames | ||
|
||
|
||
def build_dataset(name, paths): | ||
titles = generate_captions(s_titles, 80) | ||
instances = [generate_instance_sequence(p) for p in paths] | ||
s_h, s_w = instances[0][0].shape[:2] | ||
|
||
# Merge sequences | ||
frame = 255 * np.ones((len(paths) * (s_h + 15) - 15, s_w, 3), dtype=np.float32) | ||
frames = [np.copy(frame) for _ in range(15)] | ||
for i, f in enumerate(zip(*instances)): | ||
for j, m in enumerate(f): | ||
t = j*(s_h+15) | ||
frames[i][t:t+s_h] = m | ||
|
||
imageio.mimsave(name, [np.concatenate((titles, f), axis=0) for f in frames], duration=0.5) | ||
|
||
|
||
if __name__ == '__main__': | ||
PATH_IN = '/home/moliu/Documents/Papers/Supplementary/images/qualitative/' | ||
PATH_OUT = '/home/moliu/Documents/Papers/Supplementary/gifs/' | ||
|
||
build_dataset(PATH_OUT + 'mmnist.gif', [ | ||
PATH_IN + 'mmnist_l1/s12/', PATH_IN + 'mmnist_l1/s11/', PATH_IN + 'mmnist_l1/s13/', | ||
PATH_IN + 'mmnist_l1/s17/', PATH_IN + 'mmnist_l1/s20/', PATH_IN + 'mmnist_l1/s21/', | ||
PATH_IN + 'mmnist_l1/s11_n/', PATH_IN + 'mmnist_l1/s5_n/', | ||
]) | ||
|
||
build_dataset(PATH_OUT + 'kth.gif', [ | ||
PATH_IN + 'kth_l1/s31/', PATH_IN + 'kth_l1/s37/', PATH_IN + 'kth_l1/s77/', | ||
PATH_IN + 'kth_l1/s23/', PATH_IN + 'kth_l1/s43/', PATH_IN + 'kth_l1/s75/', | ||
PATH_IN + 'kth_l1/s97/', PATH_IN + 'kth_l1/s37_2/', | ||
]) | ||
|
||
build_dataset(PATH_OUT + 'ucf101.gif', [ | ||
PATH_IN + 'ucf101_l1/s8/', PATH_IN + 'ucf101_l1/s9_last/', PATH_IN + 'ucf101_l1/s9_mean/', | ||
PATH_IN + 'ucf101_l1/s21/', PATH_IN + 'ucf101_l1/s37/', PATH_IN + 'ucf101_l1/s44/', | ||
PATH_IN + 'ucf101_l1/s28/', PATH_IN + 'ucf101_l1/s41/', | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import scipy.misc as sm | ||
import os | ||
|
||
|
||
def test_layer_subsets(model, preprocessor, indices, save_path): | ||
# Get samples to analyse | ||
preprocessor.set_loader( | ||
preprocessor.loader.instantiate(sample_indices=indices) | ||
) | ||
|
||
# Start layer removal process | ||
n_layers, preds = [14, 13, 11, 10, 8, 7, 5, 4], [] | ||
for n in [14, 13, 11, 10, 8, 7, 5, 4]: | ||
preprocessor.loader.reset() | ||
model.topology[0].topology = model.topology[0].topology[:n] | ||
t_preds = model.run(x=preprocessor, batch_size=10) / 2 + 0.5 | ||
preds.append(t_preds) | ||
|
||
# Remove color channel if images are greyscale | ||
preds = [p[..., 0] for p in preds] if preds[0].shape[-1] == 1 else preds | ||
|
||
# Save predictions | ||
for n, p in zip(n_layers, preds): | ||
for i, s in enumerate(p): | ||
# Create sequence path if it does not exist | ||
s_path = save_path + 's' + str(i) + '/' | ||
if not os.path.exists(s_path): | ||
os.makedirs(s_path) | ||
|
||
# Save sequence frames | ||
for j in range(10): | ||
sm.imsave(s_path + 'l' + str(n) + '_f' + str(j) + '.png', s[j]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import os | ||
import scipy.misc as sm | ||
import numpy as np | ||
|
||
|
||
def save_sequences(model, preprocessor, path, indices=None): | ||
# Randomly select indices if not specified | ||
indices = np.random.randint(low=0, high=preprocessor.loader.num_samples, size=(50,)) if indices is None else indices | ||
|
||
# Get preprocessor subsample with the sequences of interest | ||
preprocessor.set_loader( | ||
preprocessor.loader.instantiate(sample_indices=indices) | ||
) | ||
|
||
# Get ground truth and predicted sequences | ||
gt = np.transpose(preprocessor.retrieve()[0], axes=[1, 0, 2, 3, 4]) / 2 + 0.5 | ||
predictions = model.run(x=preprocessor, batch_size=10) / 2 + 0.5 | ||
|
||
# Remove channels dimension for greyscale images | ||
if gt.shape[-1] == 1: | ||
predictions, gt = predictions[..., 0], gt[..., 0] | ||
|
||
# Save predictions | ||
for i, (g, p) in enumerate(zip(gt, predictions)): | ||
# Create sequence path if it does not exist | ||
s_path = path + 's' + str(i) + '/' | ||
if not os.path.exists(s_path): | ||
os.makedirs(s_path) | ||
print s_path + ' -> ' + str(indices[i]) | ||
|
||
for j in range(10): | ||
sm.imsave(s_path + 'g' + str(j) + '.png', g[j]) | ||
sm.imsave(s_path + 'g' + str(j+10) + '.png', g[10+j]) | ||
sm.imsave(s_path + 'p' + str(j+10) + '.png', p[j]) | ||
|
||
# Return indices and sequences | ||
return indices, predictions, gt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
||
def print_results(title, errors): | ||
# Remove prediction singularities | ||
errors[errors == np.inf] = np.nan | ||
|
||
# Print baseline errors | ||
errors = np.nanmean(errors, axis=0) | ||
errors = np.concatenate((errors, np.mean(errors, axis=1, keepdims=True)), axis=1) | ||
print '\n\n' + title + ':' | ||
print 'MSE: ' + str(errors[0]) | ||
print 'PSNR: ' + str(errors[1]) | ||
print 'DSSIM: ' + str(errors[2]) | ||
|
||
|
||
def make_plot(measures, limits=None): | ||
def lineplot(y_label, measures, limits): | ||
lines = [] | ||
for c, m, s in measures: | ||
lines += plt.plot(range(1, 11), m, s) | ||
plt.xlabel('time step') | ||
plt.ylabel(y_label) | ||
plt.xticks(range(1, 11)) | ||
plt.xlim(1, 10) | ||
if limits[0] is not None: | ||
plt.ylim(limits[0], limits[1]) | ||
plt.grid(True, linestyle='dashed') | ||
|
||
return lines | ||
|
||
# Set limits | ||
limits = [None] * 3 if limits is None else limits | ||
limits = [([[None, None]] * 3 if l is None else l) for l in limits] | ||
|
||
# Create plots | ||
fig = plt.figure(figsize=(16, 10), dpi=70) | ||
for i, (t_measures, t_limits) in enumerate(zip(measures, limits)): | ||
fig.add_subplot(3, 3, 3*i+1) | ||
lines = lineplot('Average MSE', t_measures[0], limits=t_limits[0]) | ||
fig.add_subplot(3, 3, 3*i+2) | ||
lineplot('Average PSNR', t_measures[1], limits=t_limits[1]) | ||
fig.add_subplot(3, 3, 3*i+3) | ||
lineplot('Average DSSIM', t_measures[2], limits=t_limits[2]) | ||
|
||
# Display legend | ||
plt.subplots_adjust(top=1, bottom=0.09, left=0.05, right=0.99) | ||
labels = tuple(l for l, v, s in measures[0][0]) | ||
plt.figlegend(lines, labels, loc='lower center', ncol=10, fontsize=12, frameon=False) | ||
|
||
# Display result | ||
# plt.tight_layout() | ||
plt.show() | ||
|
||
if __name__ == '__main__': | ||
make_plot([([ | ||
('Baseline', [0.04923, 0.06121, 0.06642, 0.07147, 0.07338, 0.07512, 0.07511, 0.07567, 0.07542, 0.07590], 'k-'), | ||
('RLadder', [0.04251, 0.04252, 0.04252, 0.04254, 0.04255, 0.04253, 0.04254, 0.04255, 0.04255, 0.04254], 'g-'), | ||
('Prednet', [0.02838, 0.04408, 0.04285, 0.04285, 0.04270, 0.04255, 0.04258, 0.04255, 0.04256, 0.04254], 'r-'), | ||
('Srivastava', [0.00885, 0.01097, 0.01308, 0.01496, 0.01680, 0.01856, 0.02035, 0.02185, 0.02335, 0.02473], 'y-'), | ||
('Mathieu', [0.022462, 0.032085 , 0.037718 , 0.043201 , 0.043589 , 0.043213, 0.044363 , 0.045566, 0.046795, 0.047694], 'c-'), | ||
('Villegas', [0.04251, 0.04252, 0.04252, 0.04254, 0.04255, 0.04253, 0.04254, 0.04255, 0.04255, 0.04255], 'm-'), | ||
('fRNN', [0.00475, 0.00578, 0.00686, 0.00784, 0.00887, 0.00994, 0.01105, 0.01207, 0.01319, 0.01435], 'b-'), | ||
('RLadder (pre-trained)', [0.00760, 0.00978, 0.01217, 0.01432, 0.01651, 0.01851, 0.02047, 0.02229, 0.02401, 0.02567], 'g--'), | ||
], [ | ||
('Baseline', [13.233, 12.266, 11.937, 11.601, 11.513, 11.396, 11.407, 11.362, 11.388, 11.350], 'k-'), | ||
('RLadder', [13.860, 13.859, 13.860, 13.858, 13.856, 13.858, 13.856, 13.855, 13.855, 13.856], 'g-'), | ||
('Prednet', [15.684, 13.711, 13.828, 13.831, 13.843, 13.857, 13.853, 13.855, 13.855, 13.855], 'r-'), | ||
('Srivastava', [20.809, 19.916, 19.177, 18.601, 18.103, 17.681, 17.276, 16.960, 16.671, 16.421], 'y-'), | ||
('Mathieu', [16.4688, 14.9215, 14.2196, 13.6307, 13.5919, 13.6295, 13.5155, 13.3994, 13.2839, 13.2013], 'c-'), | ||
('Villegas', [13.860, 13.859, 13.860, 13.858, 13.856, 13.858, 13.856, 13.855, 13.855, 13.856], 'm-'), | ||
('fRNN', [24.208, 23.287, 22.566, 21.983, 21.455, 20.949, 20.471, 20.060, 19.634, 19.242], 'b-'), | ||
('RLadder (pre-trained)', [21.703, 20.660, 19.674, 18.942, 18.291, 17.764, 17.291, 16.884, 16.531, 16.212], 'g--'), | ||
], [ | ||
('Baseline', [0.15520, 0.17771, 0.19192, 0.20677, 0.21422, 0.22155, 0.22383, 0.22647, 0.22637, 0.22770], 'k-'), | ||
('RLadder', [0.13797, 0.13776, 0.13783, 0.13785, 0.13780, 0.13777, 0.13789, 0.13799, 0.13802, 0.13791], 'g-'), | ||
('Prednet', [0.11971, 0.16172, 0.15431, 0.14562, 0.14292, 0.13912, 0.13945, 0.13909, 0.13920, 0.13899], 'r-'), | ||
('Srivastava', [0.05095, 0.05916, 0.06735, 0.07426, 0.08072, 0.08661, 0.09239, 0.09707, 0.10150, 0.10544], 'y-'), | ||
('Mathieu', [0.1601, 0.2268, 0.2835, 0.3486, 0.3765, 0.4050, 0.4171, 0.4240, 0.4273, 0.4334], 'c-'), | ||
('Villegas', [0.13905, 0.13885, 0.13891, 0.13894, 0.13889, 0.13886, 0.13898, 0.13908, 0.13910, 0.13899], 'm-'), | ||
('fRNN', [0.02375, 0.02854, 0.03336, 0.03762, 0.04180, 0.04612, 0.05047, 0.05444, 0.05871, 0.06275], 'b-'), | ||
('RLadder (pre-trained)', [0.03779, 0.04691, 0.05734, 0.06629, 0.07471, 0.08238, 0.08952, 0.09586, 0.10140, 0.10631], 'g--'), | ||
]), ([ | ||
('Baseline', [0.00103, 0.00204, 0.00280, 0.00338, 0.00383, 0.00420, 0.00450, 0.00475, 0.00497, 0.00515], 'k-'), | ||
('RLadder', [0.00080, 0.00064, 0.00087, 0.00110, 0.00132, 0.00151, 0.00169, 0.00184, 0.00199, 0.00213], 'g-'), | ||
('Prednet', [0.00144, 0.00447, 0.00361, 0.00673, 0.00580, 0.00907, 0.00856, 0.01218, 0.01237, 0.01645], 'r-'), | ||
('Srivastava', [0.00839, 0.00852, 0.00893, 0.00940, 0.00983, 0.01024, 0.01061, 0.01093, 0.01121, 0.01145], 'y-'), | ||
('Mathieu', [0.0006567, 0.0010211, 0.0012615, 0.0014319, 0.0016421, 0.0017737, 0.0019078, 0.0021111, 0.0021855, 0.0023709], 'c-'), | ||
('Villegas', [0.00030, 0.00063, 0.00098, 0.00132, 0.00161, 0.00189, 0.00214, 0.00234, 0.00254, 0.00274], 'm-'), | ||
('fRNN', [0.00074, 0.00097, 0.00122, 0.00147, 0.00170, 0.00190, 0.00210, 0.00228, 0.00246, 0.00262], 'b-'), | ||
('RLadder (pre-trained)', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'g--'), | ||
], [ | ||
('Baseline', [34.780, 31.774, 30.223, 29.233, 28.515, 27.973, 27.535, 27.180, 26.878, 26.619], 'k-'), | ||
('RLadder', [31.737, 34.142, 33.018, 32.109, 31.393, 30.834, 30.375, 30.000, 29.677, 29.397], 'g-'), | ||
('Prednet', [31.360, 26.974, 27.532, 24.952, 25.300, 23.329, 23.340, 21.776, 21.508, 20.280], 'r-'), | ||
('Srivastava', [21.974, 21.922, 21.691, 21.439, 21.234, 21.048, 20.892, 20.762, 20.653, 20.559], 'y-'), | ||
('Mathieu', [33.1342, 31.8160, 31.2525, 30.7705, 30.4912, 29.9523, 29.6754, 29.3361, 29.1516, 28.8458], 'c-'), | ||
('Villegas', [37.575, 34.621, 32.709, 31.430, 30.401, 29.575, 28.940, 28.509, 28.061, 27.640], 'm-'), | ||
('fRNN', [32.044, 31.106, 30.320, 29.683, 29.165, 28.749, 28.400, 28.097, 27.829, 27.596], 'b-'), | ||
('RLadder (pre-trained)', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'g--'), | ||
], [ | ||
('Baseline', [0.02873, 0.04799, 0.06156, 0.07211, 0.08091, 0.08836, 0.09483, 0.10040, 0.10536, 0.10978], 'k-'), | ||
('RLadder', [0.03249, 0.03745, 0.04533, 0.05268, 0.05916, 0.06473, 0.06965, 0.07395, 0.07780, 0.08125], 'g-'), | ||
('Prednet', [0.04704, 0.09218, 0.08831, 0.12188, 0.12168, 0.15040, 0.15538, 0.18037, 0.19019, 0.21140], 'r-'), | ||
('Srivastava', [0.18878, 0.18911, 0.19203, 0.19530, 0.19809, 0.20076, 0.20307, 0.20497, 0.20655, 0.20788], 'y-'), | ||
('Mathieu', [0.0656, 0.0774, 0.0851, 0.0947, 0.0994, 0.1082, 0.1093, 0.1200, 0.1206, 0.1298], 'c-'), | ||
('Villegas', [0.01778, 0.03261, 0.04741, 0.06162, 0.07656, 0.09009, 0.09973, 0.10550, 0.11346, 0.12094], 'm-'), | ||
('fRNN', [0.04057, 0.05004, 0.05858, 0.06605, 0.07262, 0.07830, 0.08335, 0.08787, 0.09200, 0.09571], 'b-'), | ||
('RLadder (pre-trained)', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'g--'), | ||
]), ([ | ||
('Baseline', [0.00412, 0.00751, 0.00992, 0.01179, 0.01332, 0.01456, 0.01566, 0.01665, 0.01753, 0.01830], 'k-'), | ||
('RLadder', [0.00365, 0.00516, 0.00674, 0.00811, 0.00925, 0.01022, 0.01108, 0.01185, 0.01254, 0.01316], 'g-'), | ||
('Prednet', [0.00274, 0.00878, 0.00874, 0.01523, 0.01589, 0.02313, 0.02436, 0.03307, 0.03527, 0.04516], 'r-'), | ||
('Srivastava', [0.00908, 0.05399, 0.11943, 0.16735, 0.18014, 0.17885, 0.18194, 0.19184, 0.19989, 0.20404], 'y-'), | ||
('Mathieu', [0.00646, 0.00708, 0.00869, 0.00875, 0.00861, 0.01042, 0.01210, 0.01252, 0.01475, 0.01773], 'c-'), | ||
('Villegas', [0.00268, 0.00482, 0.00655, 0.00812, 0.00940, 0.01040, 0.01150, 0.01261, 0.01373, 0.01443], 'm-'), | ||
('fRNN', [0.00274, 0.00481, 0.00652, 0.00795, 0.00920, 0.01029, 0.01122, 0.01204, 0.01273, 0.01334], 'b-'), | ||
('RLadder (pre-trained)', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'g--'), | ||
], [ | ||
('Baseline', [28.993, 25.335, 23.812, 22.874, 22.230, 21.765, 21.362, 21.007, 20.722, 20.486], 'k-'), | ||
('RLadder', [26.332, 25.979, 24.851, 23.992, 23.373, 22.893, 22.490, 22.154, 21.876, 21.641], 'g-'), | ||
('Prednet', [29.537, 23.693, 23.612, 20.923, 20.554, 18.746, 18.403, 16.859, 16.477, 15.180], 'r-'), | ||
('Srivastava', [21.077, 13.125, 9.876 , 8.515 , 8.212 , 8.212 , 8.094 , 7.849 , 7.670 , 7.577], 'y-'), | ||
('Mathieu', [22.634, 22.406, 21.488, 21.674, 22.192, 21.262, 20.591, 20.699, 19.653, 18.804], 'c-'), | ||
('Villegas', [29.389, 26.389, 24.759, 23.765, 22.959, 22.440, 21.854, 21.401, 20.940, 20.671], 'm-'), | ||
('fRNN', [28.942, 26.411, 25.011, 24.086, 23.402, 22.878, 22.453, 22.111, 21.830, 21.593], 'b-'), | ||
('RLadder (pre-trained)', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'g--'), | ||
], [ | ||
('Baseline', [0.06650, 0.10611, 0.12946, 0.14549, 0.15735, 0.16643, 0.17415, 0.18086, 0.18654, 0.19142], 'k-'), | ||
('RLadder', [0.06874, 0.09301, 0.11197, 0.12665, 0.13804, 0.14714, 0.15475, 0.16118, 0.16664, 0.17136], 'g-'), | ||
('Prednet', [0.05345, 0.12480, 0.12507, 0.17436, 0.17825, 0.21586, 0.21961, 0.25540, 0.26117, 0.29331], 'r-'), | ||
('Srivastava', [0.13221, 0.34915, 0.41874, 0.46370, 0.47410, 0.47660, 0.47985, 0.48424, 0.48746, 0.48940], 'y-'), | ||
('Mathieu', [0.0905, 0.0943, 0.1031, 0.1036, 0.1010, 0.1059, 0.1126, 0.1157, 0.1252, 0.1429], 'c-'), | ||
('Villegas', [0.05502, 0.09051, 0.11414, 0.13304, 0.14620, 0.15725, 0.16712, 0.17599, 0.18492, 0.19083], 'm-'), | ||
('fRNN', [0.05446, 0.08526, 0.10694, 0.12300, 0.13569, 0.14580, 0.15421, 0.16115, 0.16700, 0.17195], 'b-'), | ||
('RLadder (pre-trained)', [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'g--'), | ||
])], limits=( | ||
None, | ||
[[0, 0.006], [20, 38], [0.01, 0.12]], | ||
[[0.0025, 0.02], [20, 30], [0.05, 0.20]], | ||
)) |
Oops, something went wrong.