Skip to content

Commit

Permalink
Fixed bug with dropout when error function is fused with last activat…
Browse files Browse the repository at this point in the history
…ion function
  • Loading branch information
milakov committed Nov 29, 2014
1 parent 70a2bf5 commit 4535b3d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
5 changes: 2 additions & 3 deletions nnforge/cuda/network_updater_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,6 @@ namespace nnforge
cuda_safe_call(cudaEventQuery(data_processed_event));
}

random_generator gen = rnd::get_random_generator();
nnforge_uniform_int_distribution<unsigned int> dist(0, static_cast<unsigned int>(random_uniform_list.size() - 1));
unsigned int mask = static_cast<unsigned int>(random_uniform_list.size() - 1);
unsigned int entries_processed_count = 0;
Expand Down Expand Up @@ -673,8 +672,8 @@ namespace nnforge
std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::reverse_iterator net_data_custom_it = net_data_custom.rbegin();
std::vector<std::vector<cuda_linear_buffer_device_smart_ptr> >::reverse_iterator gradient_it = gradient.rbegin();
std::vector<std::vector<const_cuda_linear_buffer_device_smart_ptr> >::reverse_iterator schema_data_it = updater_schema_data.rbegin();
unsigned int reverse_layer_id = static_cast<unsigned int>(updater_list.size() + testing_layer_count) - 1 - (error_function_fused_with_activation ? 1 : 0);
layer_configuration_specific_list::const_reverse_iterator layer_config_it = layer_config_list.rbegin() + 1;
unsigned int reverse_layer_id = static_cast<unsigned int>(updater_list.size() + testing_layer_count) - 1;
layer_configuration_specific_list::const_reverse_iterator layer_config_it = layer_config_list.rbegin() + (1 + (error_function_fused_with_activation ? 1 : 0));
for(std::vector<layer_updater_cuda_smart_ptr>::reverse_iterator it = updater_list.rbegin(); it != updater_list.rend(); ++it, ++input_and_all_buffers_pack_it, ++schema_data_it, ++gradient_it, ++output_errors_it, ++net_data_it, ++net_data_custom_it, --reverse_layer_id, ++layer_config_it)
{
(*it)->enqueue_update_weights(
Expand Down
5 changes: 5 additions & 0 deletions nnforge/network_updater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ namespace nnforge
{
}

void network_updater::set_random_generator_seed(int seed)
{
gen = rnd::get_random_generator(seed);
}

void network_updater::set_input_configuration_specific(const layer_configuration_specific& input_configuration_specific)
{
if ((layer_config_list.size() > 0) && (layer_config_list[0] == input_configuration_specific))
Expand Down
5 changes: 4 additions & 1 deletion nnforge/network_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ namespace nnforge
// set_input_configuration_specific should be called prior to this method call for this method to succeed
float get_flops_for_single_entry() const;

void set_random_generator_seed(int seed);

protected:
network_updater(
network_schema_smart_ptr schema,
Expand Down Expand Up @@ -78,12 +80,13 @@ namespace nnforge
std::vector<float> random_uniform_list;
float flops;

random_generator gen;

private:
network_updater();
network_updater(const network_updater&);
network_updater& operator =(const network_updater&);

random_generator gen;
static const unsigned int random_list_bits;
};

Expand Down
5 changes: 4 additions & 1 deletion nnforge/neural_network_toolset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,9 @@ namespace nnforge
network_updater_smart_ptr updater = updater_factory->create(
schema,
get_error_function());
// !!! Seeding to constant number will not guarantee that updater will run deterministically:
// Different updaters might use different internal batch sizes
updater->set_random_generator_seed(12349087);

supervised_data_reader_smart_ptr training_data_reader = get_data_reader_for_training(true);
training_data_reader = supervised_data_reader_smart_ptr(new supervised_limited_entry_count_data_reader(training_data_reader, profile_updater_entry_count));
Expand Down Expand Up @@ -1526,7 +1529,7 @@ namespace nnforge
batch_size,
weight_decay,
momentum,
std::map<unsigned int, float>());
get_dropout_rate_map());
boost::chrono::duration<float> sec = boost::chrono::high_resolution_clock::now() - start;

/*
Expand Down
5 changes: 2 additions & 3 deletions nnforge/plain/network_updater_plain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ namespace nnforge
}
}

random_generator gen = rnd::get_random_generator();
nnforge_uniform_int_distribution<unsigned int> dist(0, static_cast<unsigned int>(random_uniform_list.size() - 1));
unsigned int mask = static_cast<unsigned int>(random_uniform_list.size() - 1);
bool entries_remained_for_loading = true;
Expand Down Expand Up @@ -410,7 +409,7 @@ namespace nnforge
layer_data_custom_list::const_reverse_iterator data_custom_it = data->data_custom_list.rbegin() + (error_function_fused_with_activation ? 1 : 0);
layer_data_list::reverse_iterator gradient_it = gradient->rbegin() + (error_function_fused_with_activation ? 1 : 0);
additional_buffer_smart_ptr output_errors = initial_error_buf;
unsigned int reverse_layer_id = static_cast<unsigned int>(updater_list.size() + testing_layer_count) - 1 - (error_function_fused_with_activation ? 1 : 0);
unsigned int reverse_layer_id = static_cast<unsigned int>(updater_list.size() + testing_layer_count) - 1;
for(std::vector<const_layer_updater_plain_smart_ptr>::const_reverse_iterator it = updater_list.rbegin(); it != updater_list.rend(); ++it, ++layer_it, ++input_config_it, ++updater_buffers_it, ++data_it, ++data_custom_it, ++gradient_it, --reverse_layer_id)
{
if (it != updater_list.rend() - 1)
Expand Down Expand Up @@ -680,7 +679,7 @@ namespace nnforge
float val = *(in_it + i);
unsigned int random_elem_id = (i + offset_in_random_list) & mask;
bool under_threshold = (*(rnd_it + random_elem_id) < dropout_rate);
val *= under_threshold ? 0.0F : scale;
val *= (under_threshold ? 0.0F : scale);
*(in_it + i) = val;
}
}
Expand Down

0 comments on commit 4535b3d

Please sign in to comment.