Skip to content

Commit

Permalink
Merge branch 'INSTX-3094-split-lstm' into 'master'
Browse files: browse the repository at this point in the history
[INSTX-3094] Split the LSTM kernels up into individual command buffers

Closes INSTX-3094

See merge request machine-learning/dorado!763
  • Loading branch information
blawrence-ont committed Dec 12, 2023
2 parents 364d15d + 083e26d commit 6d31793
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
18 changes: 12 additions & 6 deletions dorado/nn/MetalCRFModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ constexpr int kTileSize = 8;
constexpr int kSIMDGroupWidth = 32;

// Returns true on success.
bool finishCommandBuffer(const char *label, MTL::CommandBuffer *cb, int try_count) {
bool finishCommandBuffer(std::string_view label, MTL::CommandBuffer *cb, int try_count) {
cb->commit();
cb->waitUntilCompleted();

Expand Down Expand Up @@ -489,9 +489,12 @@ struct MetalBlockImpl : Module {
if (!finishCommandBuffer("convolutions", command_buffer, try_count)) {
return nullptr;
}
command_buffer = m_command_queue->commandBuffer();

std::string lstm_label = "lstm_rnn0";
for (auto &rnn : {rnn1, rnn2, rnn3, rnn4, rnn5}) {
lstm_label.back()++;
command_buffer = m_command_queue->commandBuffer();

const int kResBufSize = dtype_bytes * kernel_simd_groups * 2 * kTileSize * kTileSize;
const int kOutBufSize = dtype_bytes * kernel_simd_groups * kTileSize * kTileSize;
const std::vector<int> tg_buffer_lens{kResBufSize, kOutBufSize};
Expand All @@ -503,10 +506,12 @@ struct MetalBlockImpl : Module {
tg_buffer_lens, kernel_thread_groups,
kernel_simd_groups * kSIMDGroupWidth);
}

if (!finishCommandBuffer(lstm_label, command_buffer, try_count)) {
return nullptr;
}
}
if (!finishCommandBuffer("lstm", command_buffer, try_count)) {
return nullptr;
}

command_buffer = m_command_queue->commandBuffer();

// The output buffers of conv/LSTM layers are not used by the decoding, so
Expand Down Expand Up @@ -977,7 +982,8 @@ class MetalCaller {
m_decoder_options.beam_cut, m_decoder_options.blank_score,
m_decoder_options.q_shift, m_decoder_options.q_scale, score_scale);

(*task->out_chunks)[chunk_idx] = DecodedChunk{sequence, qstring, moves};
(*task->out_chunks)[chunk_idx] =
DecodedChunk{std::move(sequence), std::move(qstring), std::move(moves)};

// Wake the waiting thread which called `call_chunks()` if we're done decoding
std::unique_lock<std::mutex> task_lock(task->mut);
Expand Down
5 changes: 5 additions & 0 deletions dorado/utils/metal_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ void launch_kernel(ComputePipelineState *const pipeline,

command_buffer->commit();
command_buffer->waitUntilCompleted();

auto status = command_buffer->status();
if (status != MTL::CommandBufferStatusCompleted) {
spdlog::warn("Synchronous metal command buffer failed: {}", fmt::underlying(status));
}
}

void launch_kernel_no_wait(ComputePipelineState *const pipeline,
Expand Down

0 comments on commit 6d31793

Please sign in to comment.