Merge branch 'beam_search' into 'master'
Minor improvements to CPU beam search

See merge request machine-learning/dorado!666
StuartAbercrombie committed Oct 31, 2023
2 parents 2dc1f03 + 267aa02 commit 062e3fd
Showing 7 changed files with 36 additions and 190 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
@@ -200,8 +200,6 @@ set(LIB_SOURCE_FILES
dorado/demux/BarcodeClassifierSelector.cpp
dorado/demux/BarcodeClassifierSelector.h
dorado/decode/beam_search.cpp
-dorado/decode/fast_hash.cpp
-dorado/decode/fast_hash.h
dorado/decode/beam_search.h
dorado/decode/CPUDecoder.cpp
dorado/decode/CPUDecoder.h
3 changes: 1 addition & 2 deletions dorado/decode/CPUDecoder.cpp
@@ -131,8 +131,7 @@ std::vector<DecodedChunk> CPUDecoder::beam_search(const torch::Tensor& scores,
for (int i = 0; i < t_num_chunks; i++) {
auto decode_result = beam_search_decode(
t_scores[i], bwd[i], posts[i], options.beam_width, options.beam_cut,
-options.blank_score, options.q_shift, options.q_scale,
-options.temperature, 1.0f);
+options.blank_score, options.q_shift, options.q_scale, 1.0f);
chunk_results[t_first_chunk + i] = DecodedChunk{
std::get<0>(decode_result),
std::get<1>(decode_result),
56 changes: 34 additions & 22 deletions dorado/decode/beam_search.cpp
@@ -1,6 +1,5 @@
#include "beam_search.h"

#include "fast_hash.h"
#include "utils/simd.h"

#include <math.h>
@@ -33,15 +32,15 @@ struct BeamElement {
// This is the data we need to retain for only the previous timestep (block) in the beam
// (and what we construct for the new timestep)
struct BeamFrontElement {
-uint64_t hash;
+uint32_t hash;
state_t state;
uint8_t prev_element_index;
bool stay;
};

-float log_sum_exp(float x, float y, float t) {
-float abs_diff = std::abs(x - y) / t;
-return std::max(x, y) + ((abs_diff < 17.0f) ? (std::log1p(std::exp(-abs_diff)) * t) : 0.0f);
+float log_sum_exp(float x, float y) {
+float abs_diff = std::abs(x - y);
+return std::max(x, y) + ((abs_diff < 17.0f) ? (std::log1p(std::exp(-abs_diff))) : 0.0f);
}

int get_num_states(size_t num_trans_states) {
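The simplified log_sum_exp above (temperature fixed at 1) relies on the identity log(exp(x) + exp(y)) = max(x, y) + log1p(exp(-|x - y|)); the 17.0f cutoff drops the correction term once exp(-|x - y|) is negligible at single precision (exp(-17) is roughly 4e-8). A minimal standalone check of that formulation, not part of this commit:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Same formulation as the patched log_sum_exp (temperature fixed at 1).
static float log_sum_exp_ref(float x, float y) {
    float abs_diff = std::abs(x - y);
    return std::max(x, y) + ((abs_diff < 17.0f) ? std::log1p(std::exp(-abs_diff)) : 0.0f);
}

int main() {
    // Compare against the naive (overflow-prone) form for a few score pairs.
    const float pairs[][2] = {{-1.5f, -2.0f}, {0.0f, -20.0f}, {-3.0f, -3.0f}};
    for (const auto& p : pairs) {
        float naive = std::log(std::exp(p[0]) + std::exp(p[1]));
        std::printf("%g vs %g\n", log_sum_exp_ref(p[0], p[1]), naive);
    }
    return 0;
}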
@@ -102,6 +101,22 @@ std::tuple<std::string, std::string> generate_sequence(const std::vector<uint8_t
return make_tuple(sequence, qstring);
}

+// Incorporates NUM_NEW_BITS into a Castagnoli CRC32, aka CRC32C
+// (not the same polynomial as CRC32 as used in zip/ethernet).
+template <int NUM_NEW_BITS>
+uint32_t crc32c(uint32_t crc, uint32_t new_bits) {
+// Note that this is the reversed polynomial.
+constexpr uint32_t POLYNOMIAL = 0x82f63b78u;
+for (int i = 0; i < NUM_NEW_BITS; ++i) {
+auto b = (new_bits ^ crc) & 1;
+crc >>= 1;
+if (b)
+crc ^= POLYNOMIAL;
+new_bits >>= 1;
+}
+return crc;
+}

} // anonymous namespace

template <typename T>
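Because the bit-serial loop in the new crc32c above is the standard reflected CRC-32C update (reversed polynomial 0x82f63b78), it can be sanity-checked against the widely quoted CRC-32C check value: feeding the bytes of "123456789" with an initial value of 0xFFFFFFFF and a final XOR of 0xFFFFFFFF should produce 0xE3069283. A standalone sketch of that check; the init and final-XOR conventions here are the usual CRC ones, not the CRC_SEED used by the beam search:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Copy of the bit-serial CRC-32C update added in this commit.
template <int NUM_NEW_BITS>
uint32_t crc32c(uint32_t crc, uint32_t new_bits) {
    constexpr uint32_t POLYNOMIAL = 0x82f63b78u;  // reversed CRC-32C polynomial
    for (int i = 0; i < NUM_NEW_BITS; ++i) {
        auto b = (new_bits ^ crc) & 1;
        crc >>= 1;
        if (b)
            crc ^= POLYNOMIAL;
        new_bits >>= 1;
    }
    return crc;
}

int main() {
    // Standard CRC-32C of "123456789" is 0xE3069283 (init and xorout 0xFFFFFFFF).
    const char* msg = "123456789";
    uint32_t crc = 0xffffffffu;
    for (size_t i = 0; i < std::strlen(msg); ++i) {
        crc = crc32c<8>(crc, static_cast<uint8_t>(msg[i]));
    }
    std::printf("0x%08x (expect 0xe3069283)\n", static_cast<unsigned>(crc ^ 0xffffffffu));
    return 0;
}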
@@ -117,7 +132,6 @@ float beam_search(const T* const scores,
std::vector<int32_t>& states,
std::vector<uint8_t>& moves,
std::vector<float>& qual_data,
-float temperature,
float score_scale) {
const size_t num_states = 1 << num_state_bits;
const auto states_mask = static_cast<state_t>(num_states - 1);
@@ -127,9 +141,9 @@ float beam_search(const T* const scores,
}

// Some values we need
-constexpr uint64_t HASH_SEED = 0x880355f21e6d1965ULL;
+constexpr uint32_t CRC_SEED = 0x12345678u;
const float log_beam_cut =
-(beam_cut > 0.0f) ? (temperature * logf(beam_cut)) : std::numeric_limits<float>::max();
+(beam_cut > 0.0f) ? logf(beam_cut) : std::numeric_limits<float>::max();

// Create the beam. We need to keep beam_width elements for each block, plus the initial state
std::vector<BeamElement> beam_vector(max_beam_width * (num_blocks + 1));
@@ -163,7 +177,7 @@ float beam_search(const T* const scores,
state++) {
if (back_guide[state] >= beam_init_threshold) {
// Note that this first element has a prev_element_index of 0
-prev_beam_front[beam_element] = {fasthash::chainfasthash64(HASH_SEED, state),
+prev_beam_front[beam_element] = {crc32c<32>(CRC_SEED, state),
static_cast<state_t>(state), 0, false};
prev_scores[beam_element] = 0.0f;
++beam_element;
@@ -208,8 +222,8 @@ float beam_search(const T* const scores,
// Essentially a k=1 Bloom filter, indicating the presence of steps with particular
// sequence hashes. Avoids comparing stay hashes against all possible progenitor
// states where none of them has the requisite sequence hash.
-const uint64_t HASH_PRESENT_BITS = 4096;
-const uint64_t HASH_PRESENT_MASK = HASH_PRESENT_BITS - 1;
+const uint32_t HASH_PRESENT_BITS = 4096;
+const uint32_t HASH_PRESENT_MASK = HASH_PRESENT_BITS - 1;
std::bitset<HASH_PRESENT_BITS> step_hash_present; // Default constructor zeros content.

// Generate list of candidate elements for this timestep (block).
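The "k=1 Bloom filter" described above is just a 4096-bit table indexed by the low 12 bits of each step hash: it may report false positives but never false negatives, so a stay whose bit is clear can skip the comparison against all step candidates. A minimal illustration of the filter in isolation (the hash values are made up for the example):

#include <bitset>
#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t HASH_PRESENT_BITS = 4096;
    constexpr uint32_t HASH_PRESENT_MASK = HASH_PRESENT_BITS - 1;
    std::bitset<HASH_PRESENT_BITS> step_hash_present;  // zero-initialised

    // Record the hashes of some (hypothetical) step candidates.
    const uint32_t step_hashes[] = {0x1234abcdu, 0x00c0ffeeu, 0xdeadbeefu};
    for (uint32_t h : step_hashes)
        step_hash_present[h & HASH_PRESENT_MASK] = true;

    // A stay candidate only needs the full hash comparison if its bit is set.
    uint32_t stay_hash = 0x0badf00du;
    if (step_hash_present[stay_hash & HASH_PRESENT_MASK])
        std::printf("possible match: compare full hashes\n");
    else
        std::printf("definitely no matching step: skip the scan\n");
    return 0;
}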
@@ -219,7 +233,7 @@ float beam_search(const T* const scores,
const auto& previous_element = prev_beam_front[prev_elem_idx];

// Expand all the possible steps
-for (size_t new_base = 0; new_base < NUM_BASES; new_base++) {
+for (int new_base = 0; new_base < NUM_BASES; new_base++) {
state_t new_state =
(state_t((previous_element.state << NUM_BASE_BITS) & states_mask) |
new_base);
@@ -228,7 +242,8 @@ float beam_search(const T* const scores,
(((previous_element.state << NUM_BASE_BITS) >> num_state_bits)));
float new_score = prev_scores[prev_elem_idx] + fetch_block_score(move_idx) +
static_cast<float>(block_back_scores[new_state]);
-uint64_t new_hash = fasthash::chainfasthash64(previous_element.hash, new_state);
+uint32_t new_hash = crc32c<NUM_BASE_BITS>(previous_element.hash, new_base);

step_hash_present[new_hash & HASH_PRESENT_MASK] = true;

// Add new element to the candidate list
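With the change above, the sequence hash is rolled forward NUM_BASE_BITS (two) bits per emitted base rather than re-hashing the whole state, so elements that have emitted the same bases from the same starting point carry identical 32-bit hashes, which is the equality the stay/step folding below tests. A small sketch of how the templated bit count extends the hash incrementally (the seed and starting state are illustrative assumptions, and crc32c is the template defined earlier in this file):

#include <cstdint>
#include <cstdio>

// Bit-serial CRC-32C update, as added earlier in this file.
template <int NUM_NEW_BITS>
uint32_t crc32c(uint32_t crc, uint32_t new_bits) {
    constexpr uint32_t POLYNOMIAL = 0x82f63b78u;
    for (int i = 0; i < NUM_NEW_BITS; ++i) {
        auto b = (new_bits ^ crc) & 1;
        crc >>= 1;
        if (b)
            crc ^= POLYNOMIAL;
        new_bits >>= 1;
    }
    return crc;
}

int main() {
    constexpr uint32_t CRC_SEED = 0x12345678u;
    constexpr int NUM_BASE_BITS = 2;

    // Hash an (illustrative) initial state, then extend it base by base (bases are 0..3).
    uint32_t hash = crc32c<32>(CRC_SEED, /*state=*/42u);
    for (uint32_t base : {0u, 2u, 3u, 1u}) {
        // Incorporating a 2-bit base in one call matches feeding its two bits
        // through crc32c<1> one at a time.
        uint32_t one_call = crc32c<NUM_BASE_BITS>(hash, base);
        uint32_t bit_by_bit = crc32c<1>(crc32c<1>(hash, base & 1u), base >> 1);
        std::printf("base %u: %08x %s\n", static_cast<unsigned>(base),
                    static_cast<unsigned>(one_call),
                    one_call == bit_by_bit ? "(matches bit-by-bit)" : "(mismatch)");
        hash = one_call;
    }
    return 0;
}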
@@ -266,18 +281,16 @@ float beam_search(const T* const scores,
current_beam_front[step_elem_idx].hash) {
if (current_scores[stay_elem_idx] > current_scores[step_elem_idx]) {
// Fold the step into the stay
-const float folded_score =
-log_sum_exp(current_scores[stay_elem_idx],
-current_scores[step_elem_idx], temperature);
+const float folded_score = log_sum_exp(current_scores[stay_elem_idx],
+current_scores[step_elem_idx]);
current_scores[stay_elem_idx] = folded_score;
max_score = std::max(max_score, folded_score);
// The step element will end up last, sorted by score
current_scores[step_elem_idx] = std::numeric_limits<float>::lowest();
} else {
// Fold the stay into the step
-const float folded_score =
-log_sum_exp(current_scores[stay_elem_idx],
-current_scores[step_elem_idx], temperature);
+const float folded_score = log_sum_exp(current_scores[stay_elem_idx],
+current_scores[step_elem_idx]);
current_scores[step_elem_idx] = folded_score;
max_score = std::max(max_score, folded_score);
// The stay element will end up last, sorted by score
@@ -508,7 +521,6 @@ std::tuple<std::string, std::string, std::vector<uint8_t>> beam_search_decode(
float fixed_stay_score,
float q_shift,
float q_scale,
-float temperature,
float byte_score_scale) {
const int num_blocks = int(scores_t.size(0));
const int num_states = get_num_states(scores_t.size(1));
@@ -542,14 +554,14 @@ std::tuple<std::string, std::string, std::vector<uint8_t>> beam_search_decode(

beam_search<float>(scores, scores_block_stride, back_guides, posts, num_state_bits,
num_blocks, max_beam_width, beam_cut, fixed_stay_score, states, moves,
-qual_data, temperature, 1.0f);
+qual_data, 1.0f);
} else if (scores_t.dtype() == torch::kInt8) {
const auto scores = scores_block_contig.data_ptr<int8_t>();
const auto back_guides = back_guides_contig->data_ptr<float>();
const auto posts = posts_contig->data_ptr<float>();
beam_search<int8_t>(scores, scores_block_stride, back_guides, posts, num_state_bits,
num_blocks, max_beam_width, beam_cut, fixed_stay_score, states, moves,
-qual_data, temperature, byte_score_scale);
+qual_data, byte_score_scale);
} else {
throw std::runtime_error(std::string("beam_search_decode: unsupported tensor type ") +
std::string(scores_t.dtype().name()));
1 change: 0 additions & 1 deletion dorado/decode/beam_search.h
@@ -17,5 +17,4 @@ std::tuple<std::string, std::string, std::vector<uint8_t>> beam_search_decode(
float fixed_stay_score,
float q_shift,
float q_scale,
-float temperature,
float byte_score_scale);
79 changes: 0 additions & 79 deletions dorado/decode/fast_hash.cpp

This file was deleted.

83 changes: 0 additions & 83 deletions dorado/decode/fast_hash.h

This file was deleted.

2 changes: 1 addition & 1 deletion dorado/nn/MetalCRFModel.cpp
@@ -891,7 +891,7 @@ class MetalCaller {
m_bwd.at(out_buf_idx)[buf_chunk_idx], m_posts.at(out_buf_idx)[buf_chunk_idx],
m_decoder_options.beam_width, m_decoder_options.beam_cut,
m_decoder_options.blank_score, m_decoder_options.q_shift,
-m_decoder_options.q_scale, m_decoder_options.temperature, score_scale);
+m_decoder_options.q_scale, score_scale);

(*task->out_chunks)[chunk_idx] = DecodedChunk{sequence, qstring, moves};

