Skip to content

Commit

Permalink
Replace if-else chain with switch statement.
Browse files Browse the repository at this point in the history
Pull request #1638.
  • Loading branch information
TFiFiE authored and gcp committed Jul 25, 2018
1 parent 057dbd1 commit c80015c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 33 deletions.
53 changes: 22 additions & 31 deletions src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ void Network::benchmark(const GameState* const state, const int iterations) {
runcount.load(), elapsed, int(runcount.load() / elapsed));
}

void Network::process_bn_var(std::vector<float>& weights, const float epsilon) {
template<class container>
void process_bn_var(container& weights) {
constexpr float epsilon = 1e-5f;
for (auto&& w : weights) {
w = 1.0f / std::sqrt(w + epsilon);
}
Expand Down Expand Up @@ -206,39 +208,28 @@ std::pair<int, int> Network::load_v1_network(std::istream& wtfile) {
process_bn_var(weights);
m_batchnorm_stddevs.emplace_back(weights);
}
} else if (linecount == plain_conv_wts) {
m_conv_pol_w = std::move(weights);
} else if (linecount == plain_conv_wts + 1) {
m_conv_pol_b = std::move(weights);
} else if (linecount == plain_conv_wts + 2) {
std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w1));
} else if (linecount == plain_conv_wts + 3) {
process_bn_var(weights);
std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w2));
} else if (linecount == plain_conv_wts + 4) {
std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_w));
} else if (linecount == plain_conv_wts + 5) {
std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_b));
} else if (linecount == plain_conv_wts + 6) {
m_conv_val_w = std::move(weights);
} else if (linecount == plain_conv_wts + 7) {
m_conv_val_b = std::move(weights);
} else if (linecount == plain_conv_wts + 8) {
std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w1));
} else if (linecount == plain_conv_wts + 9) {
process_bn_var(weights);
std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w2));
} else if (linecount == plain_conv_wts + 10) {
std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_w));
} else if (linecount == plain_conv_wts + 11) {
std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_b));
} else if (linecount == plain_conv_wts + 12) {
std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_w));
} else if (linecount == plain_conv_wts + 13) {
std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_b));
} else {
switch (linecount - plain_conv_wts) {
case 0: m_conv_pol_w = std::move(weights); break;
case 1: m_conv_pol_b = std::move(weights); break;
case 2: std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w1)); break;
case 3: std::copy(cbegin(weights), cend(weights), begin(m_bn_pol_w2)); break;
case 4: std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_w)); break;
case 5: std::copy(cbegin(weights), cend(weights), begin(m_ip_pol_b)); break;
case 6: m_conv_val_w = std::move(weights); break;
case 7: m_conv_val_b = std::move(weights); break;
case 8: std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w1)); break;
case 9: std::copy(cbegin(weights), cend(weights), begin(m_bn_val_w2)); break;
case 10: std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_w)); break;
case 11: std::copy(cbegin(weights), cend(weights), begin(m_ip1_val_b)); break;
case 12: std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_w)); break;
case 13: std::copy(cbegin(weights), cend(weights), begin(m_ip2_val_b)); break;
}
}
linecount++;
}
process_bn_var(m_bn_pol_w2);
process_bn_var(m_bn_val_w2);

return {channels, static_cast<int>(residual_blocks)};
}
Expand Down
2 changes: 0 additions & 2 deletions src/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ class Network {
private:
std::pair<int, int> load_v1_network(std::istream& wtfile);
std::pair<int, int> load_network_file(const std::string& filename);
static void process_bn_var(std::vector<float>& weights,
const float epsilon = 1e-5f);

static std::vector<float> winograd_transform_f(const std::vector<float>& f,
const int outputs, const int channels);
Expand Down

0 comments on commit c80015c

Please sign in to comment.