Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-550-modbase-stream-fix' into 'master'
Browse files Browse the repository at this point in the history
DOR-550: Modbase streams fix

Closes DOR-550

See merge request machine-learning/dorado!826
  • Loading branch information
malton-ont committed Feb 2, 2024
2 parents c40ba61 + 6170326 commit ec106d6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions dorado/modbase/ModBaseCaller.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ModBaseCaller {

std::vector<at::Tensor> create_input_sig_tensors() const;
std::vector<at::Tensor> create_input_seq_tensors() const;
c10::Device device() const { return m_options.device(); }

at::Tensor call_chunks(size_t model_id,
at::Tensor& input_sigs,
Expand Down
12 changes: 6 additions & 6 deletions dorado/modbase/ModBaseRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

namespace {
#if DORADO_CUDA_BUILD
std::vector<c10::optional<c10::Stream>> get_streams_from_tensors(
const std::vector<at::Tensor>& tensors) {
std::vector<c10::optional<c10::Stream>> get_streams_from_caller(
const std::shared_ptr<dorado::modbase::ModBaseCaller>& caller) {
std::vector<c10::optional<c10::Stream>> streams;
for (const auto& tensor : tensors) {
if (tensor.device().is_cuda()) {
streams.push_back(c10::cuda::getStreamFromPool(false, tensor.device().index()));
for (size_t i = 0; i < caller->num_model_callers(); ++i) {
if (caller->device().is_cuda()) {
streams.push_back(c10::cuda::getStreamFromPool(false, caller->device().index()));
} else {
streams.emplace_back();
}
Expand All @@ -37,7 +37,7 @@ ModBaseRunner::ModBaseRunner(std::shared_ptr<ModBaseCaller> caller)
m_input_seqs(m_caller->create_input_seq_tensors())
#if DORADO_CUDA_BUILD
,
m_streams(get_streams_from_tensors(m_input_sigs))
m_streams(get_streams_from_caller(m_caller))
#endif
{
}
Expand Down

0 comments on commit ec106d6

Please sign in to comment.