Merge branch 'sabercrombie/merged_kernels_int16' into 'master'
Merge Metal forward scan/add_softmax kernels and switch to int16 output,...

See merge request machine-learning/dorado!744
StuartAbercrombie committed Dec 1, 2023
2 parents e9f060c + 8ceed02 commit a7fa371
Showing 3 changed files with 135 additions and 102 deletions.
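Before the per-file diffs, a minimal sketch of the quantization round trip this merge introduces: the Metal kernel now emits posterior probabilities as int16, and the CPU decoder rescales them on read. The helper names below (quantize_post, dequantize_post) are illustrative and do not appear in the change.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// GPU side (see nn.metal below): clamp the probability to [0, 1] and scale it
// into the non-negative half of the int16 range.
inline int16_t quantize_post(float p) {
    return static_cast<int16_t>(std::round(std::clamp(p, 0.0f, 1.0f) * 32767.0f));
}

// CPU side (see beam_search.cpp below): beam_search now takes a posts_scale
// factor (1/32767 for the int8/int16 path, 1.0 for the float path) and
// multiplies each posterior value it reads.
inline float dequantize_post(int16_t q, float posts_scale = 1.0f / 32767.0f) {
    return static_cast<float>(q) * posts_scale;
}
```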
63 changes: 39 additions & 24 deletions dorado/decode/beam_search.cpp
@@ -90,11 +90,10 @@ std::tuple<std::string, std::string> generate_sequence(const std::vector<uint8_t
for (size_t i = 0; i < seqLen; ++i) {
sequence[i] = alphabet[int(sequence[i])];
baseProbs[i] = 1.0f - (baseProbs[i] / totalProbs[i]);
baseProbs[i] = -10.0f * log10f(baseProbs[i]);
baseProbs[i] = -10.0f * std::log10(baseProbs[i]);
float qscore = baseProbs[i] * scale + shift;
qscore = std::min(50.0f, qscore);
qscore = std::max(1.0f, qscore);
qstring[i] = char(33.5f + qscore);
qscore = std::clamp(qscore, 1.0f, 50.0f);
qstring[i] = static_cast<char>(33.5f + qscore);
}

return make_tuple(sequence, qstring);
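The hunk above only reshapes the tail of the per-base quality calculation; as a worked illustration of what it computes (error probability to Phred score to ASCII), with scale = 1 and shift = 0 assumed for the example:

```cpp
#include <algorithm>
#include <cmath>

// Worked example of the quality encoding above: an error probability of 0.01
// gives Phred 20 and, with scale = 1 and shift = 0 (assumptions here; the
// model config can override them), the ASCII character '5' (33.5 + 20 -> 53).
char phred_char(float err_prob, float scale = 1.0f, float shift = 0.0f) {
    float q = -10.0f * std::log10(err_prob);         // 0.01 -> 20.0
    q = std::clamp(q * scale + shift, 1.0f, 50.0f);  // same clamp as the new code
    return static_cast<char>(33.5f + q);             // 0.5 offset rounds to nearest
}
```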
@@ -118,11 +117,11 @@ uint32_t crc32c(uint32_t crc, uint32_t new_bits) {

} // anonymous namespace

template <typename T>
template <typename T, typename U>
float beam_search(const T* const scores,
size_t scores_block_stride,
const float* const back_guide,
const float* const posts,
const U* const posts,
int num_state_bits,
size_t num_blocks,
size_t max_beam_width,
@@ -131,7 +130,8 @@ float beam_search(const T* const scores,
std::vector<int32_t>& states,
std::vector<uint8_t>& moves,
std::vector<float>& qual_data,
float score_scale) {
float score_scale,
float posts_scale) {
const size_t num_states = 1ull << num_state_bits;
const auto states_mask = static_cast<state_t>(num_states - 1);

@@ -459,9 +459,12 @@ float beam_search(const T* const scores,

// Compute a probability for this block, based on the path kmer. See the following explanation:
// https://git.oxfordnanolabs.local/machine-learning/notebooks/-/blob/master/bonito-basecaller-qscores.ipynb
const float* timestep_posts = posts + ((block_idx + 1) << num_state_bits);
const U* const timestep_posts = posts + ((block_idx + 1) << num_state_bits);
const auto fetch_post = [timestep_posts, posts_scale](size_t idx) {
return static_cast<float>(timestep_posts[idx]) * posts_scale;
};

float block_prob = float(timestep_posts[state]);
float block_prob = fetch_post(state);

// Get indices of left- and right-shifted kmers
int l_shift_idx = state >> NUM_BASE_BITS;
Expand Down Expand Up @@ -492,15 +495,15 @@ float beam_search(const T* const scores,
}
}
if (count_state) {
block_prob += float(timestep_posts[candidate_state]);
block_prob += fetch_post(candidate_state);
}
}

block_prob = std::clamp(block_prob, 0.0f, 1.0f);
block_prob = std::pow(block_prob, 0.4f); // Power fudge factor

// Calculate a placeholder qscore for the "wrong" bases
float wrong_base_prob = (1.0f - block_prob) / 3.0f;
const float wrong_base_prob = (1.0f - block_prob) / 3.0f;

for (size_t base = 0; base < NUM_BASES; base++) {
qual_data[block_idx * NUM_BASES + base] =
Expand Down Expand Up @@ -528,12 +531,9 @@ std::tuple<std::string, std::string, std::vector<uint8_t>> beam_search_decode(
throw std::runtime_error("num_states must be an integral power of 2");
}

// Posterior probabilities and back guides must be floats regardless of scores type.
if (posts_t.dtype() != at::ScalarType::Float ||
back_guides_t.dtype() != at::ScalarType::Float) {
throw std::runtime_error(
"beam_search_decode: mismatched tensor types provided for posts and "
"guides");
// Back guides must be floats regardless of scores type.
if (back_guides_t.dtype() != at::ScalarType::Float) {
throw std::runtime_error("beam_search_decode: back guides type must be float");
}

// back guides and posts should be contiguous
@@ -548,20 +548,35 @@ std::tuple<std::string, std::string, std::vector<uint8_t>> beam_search_decode(

const size_t scores_block_stride = scores_block_contig.stride(0);
if (scores_t.dtype() == at::ScalarType::Float) {
// If the scores are floats, so must the other tensors.
if (posts_t.dtype() != at::ScalarType::Float) {
throw std::runtime_error(
"beam_search_decode: only float posts are supported for float scores");
}

const auto scores = scores_block_contig.data_ptr<float>();
const auto back_guides = back_guides_contig->data_ptr<float>();
const auto posts = posts_contig->data_ptr<float>();

beam_search<float>(scores, scores_block_stride, back_guides, posts, num_state_bits,
num_blocks, max_beam_width, beam_cut, fixed_stay_score, states, moves,
qual_data, 1.0f);
beam_search<float, float>(scores, scores_block_stride, back_guides, posts, num_state_bits,
num_blocks, max_beam_width, beam_cut, fixed_stay_score, states,
moves, qual_data, 1.0f, 1.0f);
} else if (scores_t.dtype() == at::kChar) {
// If the scores are 8 bit, the posterior probabilities must be 16 bit (Apple path).
if (posts_t.dtype() != at::ScalarType::Short) {
throw std::runtime_error(
"beam_search_decode: only int16 posts are supported for int8 scores");
}

const auto scores = scores_block_contig.data_ptr<int8_t>();
const auto back_guides = back_guides_contig->data_ptr<float>();
const auto posts = posts_contig->data_ptr<float>();
beam_search<int8_t>(scores, scores_block_stride, back_guides, posts, num_state_bits,
num_blocks, max_beam_width, beam_cut, fixed_stay_score, states, moves,
qual_data, byte_score_scale);
const auto posts = posts_contig->data_ptr<int16_t>();
const float posts_scale = static_cast<float>(1.0 / 32767.0);
beam_search<int8_t, int16_t>(scores, scores_block_stride, back_guides, posts,
num_state_bits, num_blocks, max_beam_width, beam_cut,
fixed_stay_score, states, moves, qual_data, byte_score_scale,
posts_scale);

} else {
throw std::runtime_error(std::string("beam_search_decode: unsupported tensor type ") +
std::string(scores_t.dtype().name()));
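To summarise the dispatch that beam_search_decode now performs, here is a condensed decision table in code form; DecodeScales and select_scales are illustrative names for this sketch, not part of the change.

```cpp
#include <stdexcept>
#include <ATen/ATen.h>

// Each scores dtype pins the expected posts dtype and the pair of scale
// factors forwarded to beam_search; any other combination throws.
struct DecodeScales {
    float score_scale;
    float posts_scale;
};

inline DecodeScales select_scales(at::ScalarType scores_dtype, at::ScalarType posts_dtype,
                                  float byte_score_scale) {
    if (scores_dtype == at::ScalarType::Float && posts_dtype == at::ScalarType::Float) {
        return {1.0f, 1.0f};                         // float scores, float posts
    }
    if (scores_dtype == at::ScalarType::Char && posts_dtype == at::ScalarType::Short) {
        return {byte_score_scale, 1.0f / 32767.0f};  // int8 scores, int16 posts (Apple path)
    }
    throw std::runtime_error("beam_search_decode: unsupported scores/posts dtype combination");
}
```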
39 changes: 17 additions & 22 deletions dorado/nn/MetalCRFModel.cpp
@@ -673,16 +673,18 @@ class MetalCaller {

m_decode_complete_event = NS::TransferPtr(m_device->newSharedEvent());
m_bwd_scan_cps = make_cps(m_device.get(), "backward_scan", {}, std::nullopt);
m_fwd_scan_cps = make_cps(m_device.get(), "forward_scan", {}, std::nullopt);
m_add_softmax_cps = make_cps(m_device.get(), "add_softmax", {}, std::nullopt);
m_fwd_scan_add_softmax_cps =
make_cps(m_device.get(), "forward_scan_add_softmax", {}, std::nullopt);

int T = m_out_chunk_size;
int C = model_config.outsize;
int Cs = m_states;

for (int i = 0; i < m_out_split; ++i) {
m_scores_int8.push_back(torch::empty({T, m_out_batch_size, C}, torch::kInt8));
m_posts.push_back(torch::empty({m_out_batch_size, T + 1, Cs}));
// Unfortunately torch doesn't have Uint16, or we would use it. We could offset,
// or rely on undefined overflow behaviour, but for now we waste the sign bit.
m_posts_int16.push_back(torch::empty({m_out_batch_size, T + 1, Cs}, torch::kInt16));
m_bwd.push_back(torch::empty({m_out_batch_size, T + 1, Cs}));
}
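The comment above notes that torch has no Uint16 dtype, so the int16 posts waste the sign bit. A small sketch of what that costs in resolution, assuming the 1/32767 scale used on the decode side:

```cpp
#include <cstdio>

int main() {
    // Probabilities in [0, 1] map to [0, 32767], so only 15 of the 16 bits
    // carry information; a hypothetical uint16 encoding would halve the step.
    const float int16_step = 1.0f / 32767.0f;   // ~3.1e-5 per level
    const float uint16_step = 1.0f / 65535.0f;  // ~1.5e-5 per level
    std::printf("int16 step %.3g, uint16 step %.3g\n", int16_step, uint16_step);
    return 0;
}
```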

@@ -806,30 +808,23 @@ class MetalCaller {
continue;
}

// The same buffer is used for the forward scan results and the output of
// m_add_softmax_cps.
auto &fwd = m_posts;
// This stage is operating on the split outputs of the linear layer, so
// the effective batch size is m_out_batch_size.
std::vector<int32_t> scan_args_{m_out_chunk_size, m_out_batch_size, m_states};
auto scan_args = create_vec_buffer(m_device.get(), scan_args_);

for (int i = 0; i < m_out_split; ++i) {
// TODO: optimise grid size
launch_kernel_no_wait(m_fwd_scan_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(m_scores_int8.at(i)),
mtl_for_tensor(fwd.at(i))},
{}, m_out_batch_size, m_states);

launch_kernel_no_wait(m_bwd_scan_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(m_scores_int8.at(i)),
mtl_for_tensor(m_bwd.at(i))},
{}, m_out_batch_size, m_states);

launch_kernel_no_wait(m_add_softmax_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(fwd.at(i)),
mtl_for_tensor(m_bwd.at(i))},
{}, m_out_batch_size, m_states);
launch_kernel_no_wait(
m_fwd_scan_add_softmax_cps.get(), cb,
{scan_args.get(), mtl_for_tensor(m_scores_int8.at(i)),
mtl_for_tensor(m_bwd.at(i)), mtl_for_tensor(m_posts_int16.at(i))},
{}, m_out_batch_size, m_states);
}
if (finishCommandBuffer("linear/scan/softmax", cb, try_count)) {
cb_success = true;
@@ -876,16 +871,16 @@ class MetalCaller {
// Model outputs are split across m_out_split buffers.
assert(m_scores_int8.size() == static_cast<size_t>(m_out_split));
assert(m_bwd.size() == static_cast<size_t>(m_out_split));
assert(m_posts.size() == static_cast<size_t>(m_out_split));
assert(m_posts_int16.size() == static_cast<size_t>(m_out_split));
const int out_buf_idx = chunk_idx / m_out_batch_size;
const int buf_chunk_idx = chunk_idx % m_out_batch_size;

auto [sequence, qstring, moves] = beam_search_decode(
m_scores_int8.at(out_buf_idx).index({Slice(), buf_chunk_idx}),
m_bwd.at(out_buf_idx)[buf_chunk_idx], m_posts.at(out_buf_idx)[buf_chunk_idx],
m_decoder_options.beam_width, m_decoder_options.beam_cut,
m_decoder_options.blank_score, m_decoder_options.q_shift,
m_decoder_options.q_scale, score_scale);
m_bwd.at(out_buf_idx)[buf_chunk_idx],
m_posts_int16.at(out_buf_idx)[buf_chunk_idx], m_decoder_options.beam_width,
m_decoder_options.beam_cut, m_decoder_options.blank_score,
m_decoder_options.q_shift, m_decoder_options.q_scale, score_scale);

(*task->out_chunks)[chunk_idx] = DecodedChunk{sequence, qstring, moves};

@@ -942,10 +937,10 @@ class MetalCaller {
DecoderOptions m_decoder_options;
nn::MetalModel m_model{nullptr};
NS::SharedPtr<MTL::Device> m_device;
NS::SharedPtr<MTL::ComputePipelineState> m_bwd_scan_cps, m_fwd_scan_cps, m_add_softmax_cps;
NS::SharedPtr<MTL::ComputePipelineState> m_bwd_scan_cps, m_fwd_scan_add_softmax_cps;
// Used to signal completion of an NNTask's decoding.
NS::SharedPtr<MTL::SharedEvent> m_decode_complete_event;
std::vector<at::Tensor> m_scores_int8, m_posts, m_bwd;
std::vector<at::Tensor> m_scores_int8, m_posts_int16, m_bwd;
int m_in_chunk_size, m_out_chunk_size, m_batch_size, m_states;
// Number of pieces the linear output is split into, for reasons of
// buffer size constraints.
135 changes: 79 additions & 56 deletions dorado/nn/metal/nn.metal
@@ -126,86 +126,109 @@ kernel void backward_scan(
}
}

kernel void forward_scan(
// Performs the forward scan, writing out posterior probabilities as it goes.
// Forward scan results exist only transiently in threadgroup memory.
kernel void forward_scan_add_softmax(
device const ScanArgs* const args,
device const int8_t* const scores_in,
device ftype_out* const out,
device const ftype_out* const bwd,
device int16_t* const post_int16,
KERNEL_INDEX_INPUTS)
{
constexpr int kNumBases = 4;
constexpr int kNumTransitions = kNumBases + 1;
constexpr float kFixedStayScore = 2.0f;

const int T = args->T;
const int N = args->N;
const int num_states = args->C;
const int ts_states = num_states * kNumBases;
const int T = args->T + 1; // Time steps over which we iterate.
const int N = args->N; // Batch size.
const int num_states = args->C; // kmer state space size
const int chunk = gid; // Batch element index.
const int kMsb = num_states / kNumBases;
const int chunk = gid;
const int ts_states = num_states * kNumBases;

// This batch element's scores.
device const int8_t* const chunk_scores = scores_in + chunk * ts_states;

// TG buffers used to reduce max/sum across SIMD groups.
constexpr int kMaxSIMDGroups = 32;
threadgroup float sg_max_vals[kMaxSIMDGroups], sg_sums[kMaxSIMDGroups];

// Alternating forward guide buffers used for successive time steps.
constexpr int kMaxStates = 1024;
threadgroup float ts_fwd[2][kMaxStates];

// The forward guide input for the first step is 0.
ts_fwd[0][tid] = 0.0f;
threadgroup_barrier(mem_flags::mem_threadgroup);

device const int8_t* const chunk_in = scores_in + chunk * ts_states;
device ftype_out* const chunk_out = out + chunk * (T+1) * num_states;
device ftype_out* const alpha_init = chunk_out;
for (int c = tid; c < num_states; c += threads) {
alpha_init[c] = 0.0f;
}
for (int ts = 0; ts < T; ++ts) {
threadgroup_barrier(mem_flags::mem_device);
device const auto* const ts_in = chunk_in + N * ts_states * ts;
device ftype_out* const ts_alpha_in = alpha_init + num_states * ts;
device ftype_out* const ts_alpha_out = ts_alpha_in + num_states;
// We read forward guide values written to TG memory in the previous step as
// inputs to this step. However, there has already been a TG barrier since
// they were written.

// This time step's scores.
device const auto* const ts_scores = chunk_scores + N * ts_states * ts;

// Alternating TG buffer twiddling.
threadgroup const auto* const ts_alpha_in = ts_fwd[ts & 1];
threadgroup auto* const ts_alpha_out = ts_fwd[(ts & 1) ^ 1];

// Calculate the next time step's forward guide from this time step's scores
// and forward guide. It's written to threadgroup memory for use in the
// next iteration.
const int state = tid;
const int stay_state_idx = state;
const int step_state_idx_a = state / kNumBases;
const int step_trans_idx_a = state * kNumBases;

float vals[kNumTransitions];
float max_val = vals[0] = ts_alpha_in[stay_state_idx] + kFixedStayScore;
float fwd_max_val = vals[0] = ts_alpha_in[stay_state_idx] + kFixedStayScore;
for (int base = 0; base < kNumBases; ++base) {
vals[base + 1] = ts_alpha_in[step_state_idx_a + base * kMsb] +
ScaleByteScore(ts_in[step_trans_idx_a + base]);
max_val = max(max_val, vals[base + 1]);
ScaleByteScore(ts_scores[step_trans_idx_a + base]);
fwd_max_val = max(fwd_max_val, vals[base + 1]);
}
float sum = 0.0f;
float fwd_sum = 0.0f;
for (int i = 0; i < kNumTransitions; ++i) {
sum += exp(vals[i] - max_val);
fwd_sum += exp(vals[i] - fwd_max_val);
}
ts_alpha_out[tid] = max_val + log(sum);
}
}

kernel void add_softmax(
device const ScanArgs* const args,
device ftype_out* const fwd_post,
device const ftype_out* const bwd,
KERNEL_INDEX_INPUTS)
{
int T = args->T + 1;
int C = args->C;
int chunk = gid;
int simd_lane = tid & 31;

for (int ts = sid; ts < T; ts += simdgroups) {
int ts_idx = (chunk * T + ts) * C;
float max_val = -1e38;
for (int i = simd_lane; i < C; i += 32) {
float val = fwd_post[ts_idx + i] + bwd[ts_idx + i];
max_val = max(max_val, val);
fwd_post[ts_idx + i] = val;
}
max_val = simd_max(max_val);
float sum = 0;
for (int i = simd_lane; i < C; i += 32) {
float val = exp(fwd_post[ts_idx + i] - max_val);
sum += val;
fwd_post[ts_idx + i] = val;
ts_alpha_out[tid] = fwd_max_val + log(fwd_sum);

// Load the forward guide value calculated in the last time step for use
// in this time step's posterior probability calculation.
const float fwd_val = ts_alpha_in[tid];

// Calculate fwd/bwd guide product in log space.
const int ts_idx = (chunk * T + ts) * num_states;
const float val = fwd_val + bwd[ts_idx + tid];

// Determine max across this SIMD group, and write the result to
// a threadgroup array with an entry for each SIMD group.
sg_max_vals[sid] = simd_max(val);
threadgroup_barrier(mem_flags::mem_threadgroup);

// Find the max across all SIMD groups, and hence all entries
// for this time step.
float max_val = sg_max_vals[0];
for (uint i = 1; i < simdgroups; ++i) {
max_val = max(max_val, sg_max_vals[i]);
}
sum = simd_sum(sum);
float rcp_sum = 1.f / sum;
for (int i = simd_lane; i < C; i += 32) {
fwd_post[ts_idx + i] *= rcp_sum;

// Determine the sum of the exponentiated shifted log probabilities
// across this SIMD group, and write the result to a threadgroup array
// with an entry for each SIMD group.
const float exp_val = exp(val - max_val);
sg_sums[sid] = simd_sum(exp_val);
threadgroup_barrier(mem_flags::mem_threadgroup);

// Find the sum across all SIMD groups, and hence all entries
// for this time step.
float sum = sg_sums[0];
for (uint i = 1; i < simdgroups; ++i) {
sum += sg_sums[i];
}

// Write out the posterior probability, scaled to int16 range.
post_int16[ts_idx + tid] = static_cast<int16_t>(round(clamp(exp_val / sum, 0.0f, 1.0f) * 32767.0f));
}
}

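As a reading aid for the merged kernel above, here is a scalar CPU sketch of the same computation under simplifying assumptions: a single chunk, no SIMD-group reductions, score_scale standing in for the kernel's ScaleByteScore helper, the two stages reordered so the posterior for a time step is written before the next forward step is computed, and the final unused forward step skipped. It is illustrative, not a line-for-line port of the Metal code.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference for forward_scan_add_softmax, one chunk only (the kernel
// processes a whole batch and keeps the forward guide in threadgroup memory).
// scores: [T, num_states * 4] int8 transition scores.
// bwd:    [T + 1, num_states] backward guide, already computed.
// post:   [T + 1, num_states] int16 posterior output, scaled by 32767.
void forward_scan_add_softmax_ref(const int8_t* scores, const float* bwd, int16_t* post,
                                  int T, int num_states, float score_scale) {
    constexpr int kNumBases = 4;
    constexpr float kFixedStayScore = 2.0f;
    const int msb = num_states / kNumBases;

    std::vector<float> fwd(num_states, 0.0f), fwd_next(num_states);
    for (int ts = 0; ts <= T; ++ts) {
        // Posterior for time step ts: softmax over fwd(ts) + bwd(ts), quantized.
        float max_val = -1e38f;
        for (int i = 0; i < num_states; ++i) {
            max_val = std::max(max_val, fwd[i] + bwd[ts * num_states + i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < num_states; ++i) {
            sum += std::exp(fwd[i] + bwd[ts * num_states + i] - max_val);
        }
        for (int i = 0; i < num_states; ++i) {
            const float p = std::exp(fwd[i] + bwd[ts * num_states + i] - max_val) / sum;
            post[ts * num_states + i] =
                    static_cast<int16_t>(std::round(std::clamp(p, 0.0f, 1.0f) * 32767.0f));
        }
        if (ts == T) {
            break;  // No scores remain; the last posterior has been written.
        }
        // Forward recursion: log-sum-exp over the stay transition and the four
        // step transitions into each state, mirroring the loop in the kernel.
        const int8_t* ts_scores = scores + ts * num_states * kNumBases;
        for (int state = 0; state < num_states; ++state) {
            float vals[kNumBases + 1];
            float fwd_max = vals[0] = fwd[state] + kFixedStayScore;
            for (int base = 0; base < kNumBases; ++base) {
                vals[base + 1] = fwd[state / kNumBases + base * msb] +
                                 ts_scores[state * kNumBases + base] * score_scale;
                fwd_max = std::max(fwd_max, vals[base + 1]);
            }
            float fwd_sum = 0.0f;
            for (float v : vals) {
                fwd_sum += std::exp(v - fwd_max);
            }
            fwd_next[state] = fwd_max + std::log(fwd_sum);
        }
        std::swap(fwd, fwd_next);
    }
}
```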
