Skip to content

Commit

Permalink
Addressed code review results (2nd round)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ihavnoid committed Jul 25, 2018
1 parent c2ab6a9 commit 23d2feb
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
6 changes: 5 additions & 1 deletion src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@ T relative_difference(const T a, const T b) {
return fabs(fa - fb) / std::min(fa, fb);
}

#endif

#ifdef USE_OPENCL_SELFCHECK
void Network::compare_net_outputs(Netresult& data,
Netresult& ref) {
// We accept an error up to 20%, but output values
Expand Down Expand Up @@ -540,7 +543,7 @@ void Network::compare_net_outputs(Netresult& data,
LOCK(m_selfcheck_mutex, selfcheck_lock);
if (selfcheck_fail) {
m_selfcheck_fails.push_back(true);
if (std::count(m_selfcheck_fails.begin(), m_selfcheck_fails.end(), true) >= max_failures) {
if (std::count(begin(m_selfcheck_fails), end(m_selfcheck_fails), true) >= max_failures) {
printf("Error in OpenCL calculation: Update your GPU drivers or reduce the amount of games "
"played simultaneously.\n");
throw std::runtime_error("OpenCL self-check mismatch.");
Expand Down Expand Up @@ -681,6 +684,7 @@ Network::Netresult Network::get_output_internal(
m_forward->forward(input_data, policy_data, value_data);
}
#else
m_forward->forward(input_data, policy_data, value_data);
(void) selfcheck;
#endif

Expand Down
2 changes: 1 addition & 1 deletion src/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ class Network {
std::vector<float>::iterator black,
std::vector<float>::iterator white,
const int symmetry);
void compare_net_outputs(Netresult& data, Netresult& ref);
bool probe_cache(const GameState* const state, Network::Netresult& result);
std::unique_ptr<ForwardPipe> m_forward;
#ifdef USE_OPENCL_SELFCHECK
void compare_net_outputs(Netresult& data, Netresult& ref);
std::unique_ptr<ForwardPipe> m_forward_cpu;

// records the result of most recent selfchecks
Expand Down
32 changes: 23 additions & 9 deletions src/Tuner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,26 @@
#endif

const auto TUNER_FILE_LOCAL = std::string("leelaz_opencl_tuning");

// Per-precision tuner traits. The primary templates are only declared here;
// each supported net_t provides its own explicit specialization.
template <typename net_t> static std::string getTunerKernel();
template <typename net_t> static float getTunerMaxError();

// Kernel name placed in (and matched against) tuning lines when net_t is float.
template <> std::string getTunerKernel<float>() {
    return "XgemmBatched";
}

// Error bound for float: tune_sgemm only accepts a candidate whose measured
// max_error is strictly below this value.
template <> float getTunerMaxError<float>() {
    constexpr auto float_error_bound = 1e-4f;
    return float_error_bound;
}

#ifdef USE_HALF
const auto TUNER_KERNEL = std::string("XgemmBatchedHalf");
constexpr auto MAX_ERROR = 1e-2f;
#else
const auto TUNER_KERNEL = std::string("XgemmBatched");
constexpr auto MAX_ERROR = 1e-4f;
// Kernel name placed in (and matched against) tuning lines when net_t is
// half_float::half.
template <> std::string getTunerKernel<half_float::half>() {
    return "XgemmBatchedHalf";
}

// Error bound for half_float::half: tune_sgemm only accepts a candidate whose
// measured max_error is strictly below this value.
template <> float getTunerMaxError<half_float::half>() {
    constexpr auto half_error_bound = 5e-2f;
    return half_error_bound;
}
#endif

using namespace Utils;
Expand Down Expand Up @@ -406,11 +420,11 @@ std::string Tuner<net_t>::tune_sgemm(const int m, const int n, const int k,
sum += elapsed;
} catch (const cl::Error&) {
// Failed to enqueue kernel. Set error to max.
max_error = MAX_ERROR;
max_error = getTunerMaxError<net_t>();
break;
}
}
if (max_error < MAX_ERROR && (best_time == 0 || sum < best_time)) {
if (max_error < getTunerMaxError<net_t>() && (best_time == 0 || sum < best_time)) {
auto param_str = parameters_to_string(p);
auto kernel_ms = 1e-6f * (sum / runs);
// Timing is in nanoseconds (10^-9), Giga = 10^9, so this works out
Expand Down Expand Up @@ -450,7 +464,7 @@ void Tuner<net_t>::store_sgemm_tuners(const int m, const int n, const int k,
tuning_params << m << ";" << n << ";" << k << ";" << batch_size;

auto tuning_line_prefix = std::to_string(TUNER_VERSION) + ";"
+ TUNER_KERNEL + ";" + tuning_params.str() + ";";
+ getTunerKernel<net_t>() + ";" + tuning_params.str() + ";";
auto tuning_line = tuning_line_prefix + tuners + ";" + device_name;

// Write back previous data as long as it's not the device and
Expand Down Expand Up @@ -492,7 +506,7 @@ std::string Tuner<net_t>::sgemm_tuners_from_line(std::string line,
return "";
}

if (s[1] != TUNER_KERNEL) {
if (s[1] != getTunerKernel<net_t>()) {
return "";
}

Expand Down

0 comments on commit 23d2feb

Please sign in to comment.