Merge branch 'smalton/modbase-streams' into 'master'
DOR-483: Modbase streams

Closes DOR-483

See merge request machine-learning/dorado!768
malton-ont committed Dec 13, 2023
2 parents decb9e7 + a3fae3e commit 6ed81c5
Showing 4 changed files with 34 additions and 22 deletions.
2 changes: 1 addition & 1 deletion dorado/cli/basecaller.cpp
@@ -192,7 +192,7 @@ void setup(std::vector<std::string> args,
pipeline_desc, std::move(runners), std::move(remora_runners), overlap,
mean_qscore_start_pos, thread_allocations.scaler_node_threads,
true /* Enable read splitting */, thread_allocations.splitter_node_threads,
int(thread_allocations.remora_threads * num_devices), current_sink_node,
thread_allocations.remora_threads, current_sink_node,
PipelineDescriptor::InvalidNodeHandle);

// Create the Pipeline from our description.
47 changes: 28 additions & 19 deletions dorado/nn/ModBaseRunner.cpp
@@ -36,6 +36,9 @@ class ModBaseCaller {
at::Tensor out;
bool done{false};
int num_chunks;
#if DORADO_GPU_BUILD && !defined(__APPLE__)
c10::optional<c10::Stream> stream;
#endif
};

struct ModBaseData {
@@ -46,9 +49,6 @@ class ModBaseCaller {
std::deque<std::shared_ptr<ModBaseTask>> input_queue;
std::mutex input_lock;
std::condition_variable input_cv;
#if DORADO_GPU_BUILD && !defined(__APPLE__)
c10::optional<c10::Stream> stream;
#endif
const int batch_size;

std::vector<size_t> get_motif_hits(const std::string& seq) const {
@@ -70,8 +70,6 @@ class ModBaseCaller {

#if DORADO_GPU_BUILD && !defined(__APPLE__)
if (opts.device().is_cuda()) {
stream = c10::cuda::getStreamFromPool(false, opts.device().index());

auto sig_len = static_cast<int64_t>(params.context_before + params.context_after);
auto kmer_len = params.bases_after + params.bases_before + 1;

@@ -149,12 +147,13 @@ class ModBaseCaller {
int num_chunks) {
NVTX3_FUNC_RANGE();
auto& caller_data = m_caller_data[model_id];

#if DORADO_GPU_BUILD && !defined(__APPLE__)
c10::cuda::OptionalCUDAStreamGuard stream_guard(caller_data->stream);
#endif
auto task = std::make_shared<ModBaseTask>(input_sigs.to(m_options.device()),
input_seqs.to(m_options.device()), num_chunks);
#if DORADO_GPU_BUILD && !defined(__APPLE__)
if (m_options.device().is_cuda()) {
task->stream = c10::cuda::getCurrentCUDAStream(m_options.device().index());
}
#endif
{
std::lock_guard<std::mutex> lock(caller_data->input_lock);
caller_data->input_queue.push_front(task);
@@ -171,17 +170,10 @@ class ModBaseCaller {

void modbase_task_thread_fn(size_t model_id) {
auto& caller_data = m_caller_data[model_id];
#if DORADO_GPU_BUILD && !defined(__APPLE__)
const bool has_stream = caller_data->stream.has_value();
#endif
while (true) {
nvtx3::scoped_range loop{"modbase_task_thread_fn"};
at::InferenceMode guard;
#if DORADO_GPU_BUILD && !defined(__APPLE__)
// If caller_data->stream is set, sets the current stream to caller_data->stream, and the current device to
// the device associated with the stream. Resets both to their prior state on destruction
c10::cuda::OptionalCUDAStreamGuard stream_guard(caller_data->stream);
#endif

std::unique_lock<std::mutex> input_lock(caller_data->input_lock);
while (caller_data->input_queue.empty() && !m_terminate.load()) {
caller_data->input_cv.wait_for(input_lock, 100ms);
@@ -195,12 +187,18 @@ class ModBaseCaller {
caller_data->input_queue.pop_back();
input_lock.unlock();

#if DORADO_GPU_BUILD && !defined(__APPLE__)
// If task->stream is set, sets the current stream to task->stream, and the current device to
// the device associated with the stream. Resets both to their prior state on destruction
c10::cuda::OptionalCUDAStreamGuard stream_guard(task->stream);
#endif

std::unique_lock<std::mutex> task_lock(task->mut);
stats::Timer timer;
task->out = caller_data->module_holder->forward(task->input_sigs, task->input_seqs);
#if DORADO_GPU_BUILD && !defined(__APPLE__)
if (has_stream) {
caller_data->stream->synchronize();
if (task->stream.has_value()) {
task->stream->synchronize();
}
// Only meaningful if we're syncing the stream.
m_model_ms += timer.GetElapsedMS();
@@ -281,6 +279,14 @@ ModBaseRunner::ModBaseRunner(std::shared_ptr<ModBaseCaller> caller) : m_caller(s
m_input_seqs.push_back(torch::empty(
{caller_data->batch_size, sig_len, utils::BaseInfo::NUM_BASES * kmer_len},
seq_input_options));
#if DORADO_GPU_BUILD && !defined(__APPLE__)
if (m_caller->m_options.device().is_cuda()) {
m_streams.push_back(
c10::cuda::getStreamFromPool(false, m_caller->m_options.device().index()));
} else {
m_streams.emplace_back();
}
#endif
}
}

@@ -311,6 +317,9 @@ void ModBaseRunner::accept_chunk(int model_id,
}

at::Tensor ModBaseRunner::call_chunks(int model_id, int num_chunks) {
#if DORADO_GPU_BUILD && !defined(__APPLE__)
c10::cuda::OptionalCUDAStreamGuard guard(m_streams[model_id]);
#endif
return m_caller->call_chunks(model_id, m_input_sigs[model_id], m_input_seqs[model_id],
num_chunks);
}
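
The stream handling introduced in this file boils down to one pattern: each runner draws its own CUDA stream from the pool, makes it current while submitting work, records the current stream on the task, and the worker thread re-activates and synchronizes that stream around the model's forward pass. Below is a condensed sketch of that pattern; the class and member names are simplified stand-ins rather than the actual Dorado types, and it assumes a CUDA-enabled libtorch build.

    #include <c10/core/Device.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>
    #include <c10/util/Optional.h>

    struct Task {
        // Stream that was current when the task was submitted.
        c10::optional<c10::Stream> stream;
    };

    struct Runner {
        c10::Device m_device;
        c10::optional<c10::Stream> m_stream;  // one stream per runner, drawn from the pool

        explicit Runner(c10::Device device) : m_device(device) {
            if (m_device.is_cuda()) {
                m_stream = c10::cuda::getStreamFromPool(/*isHighPriority=*/false, m_device.index());
            }
        }

        Task submit() const {
            // Make this runner's stream (and its device) current for the
            // host-to-device copies and the enqueue; both are restored when
            // the guard goes out of scope.
            c10::cuda::OptionalCUDAStreamGuard guard(m_stream);
            Task task;
            if (m_device.is_cuda()) {
                // Record the current stream so the worker thread can
                // re-activate it and synchronize on it after forward().
                task.stream = c10::cuda::getCurrentCUDAStream(m_device.index());
            }
            return task;
        }
    };

    void run_task(const Task& task) {
        // Re-activate the submitting runner's stream for the forward pass.
        c10::cuda::OptionalCUDAStreamGuard guard(task.stream);
        // ... module_holder->forward(...) would run here ...
        if (task.stream.has_value()) {
            task.stream->synchronize();
        }
    }

On CPU (and Apple) builds the optionals stay empty, so the OptionalCUDAStreamGuard constructions are no-ops and the CUDA-specific code remains behind the DORADO_GPU_BUILD guard.
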
5 changes: 4 additions & 1 deletion dorado/nn/ModBaseRunner.h
@@ -3,7 +3,9 @@
#include "utils/stats.h"

#include <ATen/core/TensorBody.h>

#if DORADO_GPU_BUILD && !defined(__APPLE__)
#include <c10/cuda/CUDAStream.h>
#endif
#include <atomic>
#include <filesystem>
#include <string>
@@ -44,6 +46,7 @@ class ModBaseRunner {
std::shared_ptr<ModBaseCaller> m_caller;
std::vector<at::Tensor> m_input_sigs;
std::vector<at::Tensor> m_input_seqs;
std::vector<c10::optional<c10::Stream>> m_streams;

// Performance monitoring stats.
std::atomic<int64_t> m_num_batches_called = 0;
2 changes: 1 addition & 1 deletion dorado/utils/parameters.cpp
@@ -15,7 +15,7 @@ ThreadAllocations default_thread_allocations(int num_devices,
allocs.writer_threads = num_devices * 2;
allocs.read_converter_threads = num_devices * 2;
allocs.read_filter_threads = num_devices * 2;
allocs.remora_threads = num_remora_threads;
allocs.remora_threads = num_devices * num_remora_threads;
allocs.scaler_node_threads = num_devices * 4;
allocs.splitter_node_threads = num_devices;
allocs.loader_threads = num_devices;
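
Together with the basecaller.cpp change at the top of this commit, this moves the per-device scaling of the modbase (remora) thread count into default_thread_allocations, so call sites can pass thread_allocations.remora_threads through unchanged. A minimal sketch of the resulting allocation logic follows; it is reduced to the fields touched by this diff, and the real function and struct have additional parameters and members not shown here.

    // Reduced sketch: only the fields relevant to this diff are shown.
    struct ThreadAllocations {
        int remora_threads{0};
        int scaler_node_threads{0};
        int splitter_node_threads{0};
    };

    ThreadAllocations default_thread_allocations(int num_devices, int num_remora_threads) {
        ThreadAllocations allocs;
        // Scaling by the device count now happens here, once, instead of at
        // the pipeline-construction call site in basecaller.cpp.
        allocs.remora_threads = num_devices * num_remora_threads;
        allocs.scaler_node_threads = num_devices * 4;
        allocs.splitter_node_threads = num_devices;
        return allocs;
    }
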
