Skip to content

Commit

Permalink
Merge branch 'sabercrombie/message_sink_input_threads2' into 'master'
Browse files Browse the repository at this point in the history
Consolidate pipeline node input thread handling

See merge request machine-learning/dorado!799
  • Loading branch information
StuartAbercrombie committed Jan 12, 2024
2 parents 4018823 + 31c9a59 commit 8dfd180
Show file tree
Hide file tree
Showing 52 changed files with 689 additions and 972 deletions.
75 changes: 40 additions & 35 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -187,59 +187,64 @@ set(LIB_SOURCE_FILES
dorado/api/runner_creation.h
dorado/api/pipeline_creation.cpp
dorado/api/pipeline_creation.h
dorado/read_pipeline/FakeDataLoader.cpp
dorado/read_pipeline/FakeDataLoader.h
dorado/read_pipeline/ReadPipeline.cpp
dorado/read_pipeline/ReadPipeline.h
dorado/read_pipeline/ClientInfo.h
dorado/read_pipeline/DefaultClientInfo.h
dorado/read_pipeline/ScalerNode.cpp
dorado/read_pipeline/ScalerNode.h
dorado/read_pipeline/StereoDuplexEncoderNode.cpp
dorado/read_pipeline/StereoDuplexEncoderNode.h
dorado/read_pipeline/AdapterDetectorNode.cpp
dorado/read_pipeline/AdapterDetectorNode.h
dorado/read_pipeline/AlignerNode.cpp
dorado/read_pipeline/AlignerNode.h
dorado/read_pipeline/BarcodeClassifierNode.cpp
dorado/read_pipeline/BarcodeClassifierNode.h
dorado/read_pipeline/BarcodeDemuxerNode.cpp
dorado/read_pipeline/BarcodeDemuxerNode.h
dorado/read_pipeline/BasecallerNode.cpp
dorado/read_pipeline/BasecallerNode.h
dorado/read_pipeline/ModBaseCallerNode.cpp
dorado/read_pipeline/ModBaseCallerNode.h
dorado/read_pipeline/ReadFilterNode.cpp
dorado/read_pipeline/ReadFilterNode.h
dorado/read_pipeline/ReadToBamTypeNode.cpp
dorado/read_pipeline/ReadToBamTypeNode.h
dorado/read_pipeline/SubreadTaggerNode.cpp
dorado/read_pipeline/SubreadTaggerNode.h
dorado/read_pipeline/BaseSpaceDuplexCallerNode.cpp
dorado/read_pipeline/BaseSpaceDuplexCallerNode.h
dorado/read_pipeline/AlignerNode.cpp
dorado/read_pipeline/AlignerNode.h
dorado/read_pipeline/ClientInfo.h
dorado/read_pipeline/DefaultClientInfo.h
dorado/read_pipeline/DuplexReadTaggingNode.cpp
dorado/read_pipeline/DuplexReadTaggingNode.h
dorado/read_pipeline/FakeDataLoader.cpp
dorado/read_pipeline/FakeDataLoader.h
dorado/read_pipeline/HtsReader.cpp
dorado/read_pipeline/HtsReader.h
dorado/read_pipeline/HtsWriter.cpp
dorado/read_pipeline/HtsWriter.h
dorado/read_pipeline/ProgressTracker.h
dorado/read_pipeline/ResumeLoaderNode.cpp
dorado/read_pipeline/ResumeLoaderNode.h
dorado/read_pipeline/DuplexReadTaggingNode.cpp
dorado/read_pipeline/DuplexReadTaggingNode.h
dorado/read_pipeline/BarcodeClassifierNode.cpp
dorado/read_pipeline/BarcodeClassifierNode.h
dorado/read_pipeline/BarcodeDemuxerNode.cpp
dorado/read_pipeline/BarcodeDemuxerNode.h
dorado/read_pipeline/AdapterDetectorNode.cpp
dorado/read_pipeline/AdapterDetectorNode.h
dorado/read_pipeline/MessageSink.cpp
dorado/read_pipeline/MessageSink.h
dorado/read_pipeline/ModBaseCallerNode.cpp
dorado/read_pipeline/ModBaseCallerNode.h
dorado/read_pipeline/NullNode.h
dorado/read_pipeline/NullNode.cpp
dorado/read_pipeline/PairingNode.cpp
dorado/read_pipeline/PairingNode.h
dorado/read_pipeline/PolyACalculator.cpp
dorado/read_pipeline/PolyACalculator.h
dorado/read_pipeline/PolyACalculatorNode.cpp
dorado/read_pipeline/PolyACalculatorNode.h
dorado/read_pipeline/ProgressTracker.h
dorado/read_pipeline/ReadFilterNode.cpp
dorado/read_pipeline/ReadFilterNode.h
dorado/read_pipeline/ReadPipeline.cpp
dorado/read_pipeline/ReadPipeline.h
dorado/read_pipeline/ReadSplitNode.cpp
dorado/read_pipeline/ReadSplitNode.h
dorado/read_pipeline/ReadToBamTypeNode.cpp
dorado/read_pipeline/ReadToBamTypeNode.h
dorado/read_pipeline/ResumeLoaderNode.cpp
dorado/read_pipeline/ResumeLoaderNode.h
dorado/read_pipeline/ScalerNode.cpp
dorado/read_pipeline/ScalerNode.h
dorado/read_pipeline/StereoDuplexEncoderNode.cpp
dorado/read_pipeline/StereoDuplexEncoderNode.h
dorado/read_pipeline/SubreadTaggerNode.cpp
dorado/read_pipeline/SubreadTaggerNode.h
dorado/read_pipeline/messages.cpp
dorado/read_pipeline/messages.h
dorado/read_pipeline/flush_options.h
dorado/read_pipeline/read_utils.cpp
dorado/read_pipeline/read_utils.h
dorado/read_pipeline/stereo_features.cpp
dorado/read_pipeline/stereo_features.h
dorado/read_pipeline/stitch.cpp
dorado/read_pipeline/stitch.h
dorado/read_pipeline/ReadSplitNode.cpp
dorado/read_pipeline/ReadSplitNode.h
dorado/splitter/DuplexReadSplitter.cpp
dorado/splitter/DuplexReadSplitter.h
dorado/splitter/RNAReadSplitter.cpp
Expand Down
6 changes: 3 additions & 3 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "read_pipeline/BarcodeClassifierNode.h"
#include "read_pipeline/HtsReader.h"
#include "read_pipeline/HtsWriter.h"
#include "read_pipeline/PolyACalculator.h"
#include "read_pipeline/PolyACalculatorNode.h"
#include "read_pipeline/ProgressTracker.h"
#include "read_pipeline/ReadFilterNode.h"
#include "read_pipeline/ReadToBamTypeNode.h"
Expand Down Expand Up @@ -157,11 +157,11 @@ void setup(std::vector<std::string> args,
thread_allocations.aligner_threads);
current_sink_node = aligner;
}
current_sink_node = pipeline_desc.add_node<ReadToBamType>(
current_sink_node = pipeline_desc.add_node<ReadToBamTypeNode>(
{current_sink_node}, emit_moves, thread_allocations.read_converter_threads,
methylation_threshold_pct, std::move(sample_sheet), 1000);
if (estimate_poly_a) {
current_sink_node = pipeline_desc.add_node<PolyACalculator>(
current_sink_node = pipeline_desc.add_node<PolyACalculatorNode>(
{current_sink_node}, std::thread::hardware_concurrency(),
is_rna_model(model_config), 1000);
}
Expand Down
2 changes: 1 addition & 1 deletion dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ int duplex(int argc, char* argv[]) {
pipeline_desc.add_node_sink(aligner, hts_writer);
converted_reads_sink = aligner;
}
auto read_converter = pipeline_desc.add_node<ReadToBamType>(
auto read_converter = pipeline_desc.add_node<ReadToBamTypeNode>(
{converted_reads_sink}, emit_moves, 2, 0.0f, nullptr, 1000);
auto duplex_read_tagger = pipeline_desc.add_node<DuplexReadTaggingNode>({read_converter});
// The minimum sequence length is set to 5 to avoid issues with duplex node printing very short sequences for mismatched pairs.
Expand Down
38 changes: 6 additions & 32 deletions dorado/read_pipeline/AdapterDetectorNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,51 +12,25 @@
#include <algorithm>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>

namespace dorado {

// A Node which encapsulates running adapter and primer detection on each read.
AdapterDetectorNode::AdapterDetectorNode(int threads, bool trim_adapters, bool trim_primers)
: MessageSink(10000),
m_threads(threads),
: MessageSink(10000, threads),
m_trim_adapters(trim_adapters),
m_trim_primers(trim_primers) {
start_threads();
start_input_processing(&AdapterDetectorNode::input_thread_fn, this);
}

AdapterDetectorNode::AdapterDetectorNode(int threads)
: MessageSink(10000), m_threads(threads), m_trim_adapters(true), m_trim_primers(true) {
start_threads();
: MessageSink(10000, threads), m_trim_adapters(true), m_trim_primers(true) {
start_input_processing(&AdapterDetectorNode::input_thread_fn, this);
}

// Launches the node's worker threads: one std::thread per m_threads, each
// running AdapterDetectorNode::worker_thread on this instance. Every thread is
// wrapped in a unique_ptr and appended to m_workers.
// NOTE(review): this appears to be the pre-change implementation that the
// commit replaces with start_input_processing() — confirm against the diff.
void AdapterDetectorNode::start_threads() {
for (size_t i = 0; i < m_threads; i++) {
m_workers.push_back(std::make_unique<std::thread>(
std::thread(&AdapterDetectorNode::worker_thread, this)));
}
}

// Shuts the node down: calls terminate_input_queue() first, then joins every
// joinable worker thread and finally empties m_workers.
// NOTE(review): presumably terminating the input queue is what lets the
// workers leave their get_input_message() loop so join() can return — confirm
// in MessageSink.
void AdapterDetectorNode::terminate_impl() {
terminate_input_queue();
for (auto& m : m_workers) {
// joinable() guard skips threads that have already been joined.
if (m->joinable()) {
m->join();
}
}
m_workers.clear();
}

// Restarts the node: calls restart_input_queue() and then spawns a new set of
// worker threads through start_threads().
void AdapterDetectorNode::restart() {
restart_input_queue();
start_threads();
}

// Destructor delegates to terminate_impl(), which joins the worker threads
// before the object's members are destroyed.
AdapterDetectorNode::~AdapterDetectorNode() { terminate_impl(); }

void AdapterDetectorNode::worker_thread() {
void AdapterDetectorNode::input_thread_fn() {
Message message;
while (get_input_message(message)) {
if (std::holds_alternative<BamPtr>(message)) {
Expand Down
17 changes: 6 additions & 11 deletions dorado/read_pipeline/AdapterDetectorNode.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#pragma once

#include "demux/AdapterDetector.h"
#include "read_pipeline/ReadPipeline.h"
#include "read_pipeline/MessageSink.h"
#include "utils/stats.h"
#include "utils/types.h"

#include <atomic>
#include <memory>
#include <string>
#include <string_view>
#include <vector>

namespace dorado {
Expand All @@ -16,27 +16,22 @@ class AdapterDetectorNode : public MessageSink {
public:
AdapterDetectorNode(int threads, bool trim_adapters, bool trim_primers);
AdapterDetectorNode(int threads);
~AdapterDetectorNode() override;
~AdapterDetectorNode() override { stop_input_processing(); }
std::string get_name() const override { return "AdapterDetectorNode"; }
stats::NamedStats sample_stats() const override;
void terminate(const FlushOptions&) override { terminate_impl(); }
void restart() override;
void terminate(const FlushOptions&) override { stop_input_processing(); }
void restart() override { start_input_processing(&AdapterDetectorNode::input_thread_fn, this); }

private:
void start_threads();

size_t m_threads{1};
bool m_trim_adapters;
bool m_trim_primers;
std::vector<std::unique_ptr<std::thread>> m_workers;
std::atomic<int> m_num_records{0};
demux::AdapterDetector m_detector;

void worker_thread();
void input_thread_fn();
void process_read(BamPtr& read);
void process_read(SimplexRead& read);

void terminate_impl();
};

} // namespace dorado
36 changes: 5 additions & 31 deletions dorado/read_pipeline/AlignerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,16 @@ AlignerNode::AlignerNode(std::shared_ptr<alignment::IndexFileAccess> index_file_
const std::string& filename,
const alignment::Minimap2Options& options,
int threads)
: MessageSink(10000),
m_threads(threads),
: MessageSink(10000, threads),
m_index_for_bam_messages(
load_and_get_index(*index_file_access, filename, options, threads)),
m_index_file_access(std::move(index_file_access)) {
start_threads();
start_input_processing(&AlignerNode::input_thread_fn, this);
}

AlignerNode::AlignerNode(std::shared_ptr<alignment::IndexFileAccess> index_file_access, int threads)
: MessageSink(10000),
m_threads(threads),
m_index_file_access(std::move(index_file_access)) {
start_threads();
: MessageSink(10000, threads), m_index_file_access(std::move(index_file_access)) {
start_input_processing(&AlignerNode::input_thread_fn, this);
}

std::shared_ptr<const alignment::Minimap2Index> AlignerNode::get_index(
Expand All @@ -80,29 +77,6 @@ std::shared_ptr<const alignment::Minimap2Index> AlignerNode::get_index(
return index;
}

// Launches the aligner's worker threads: one std::thread per m_threads, each
// running AlignerNode::worker_thread on this instance, stored by value in
// m_workers.
// NOTE(review): this appears to be the pre-change implementation that the
// commit replaces with start_input_processing() — confirm against the diff.
void AlignerNode::start_threads() {
for (size_t i = 0; i < m_threads; i++) {
m_workers.push_back(std::thread(&AlignerNode::worker_thread, this));
}
}

// Shuts the node down: calls terminate_input_queue() first, then joins every
// joinable worker thread and finally empties m_workers.
// NOTE(review): presumably terminating the input queue is what lets the
// workers leave their get_input_message() loop so join() can return — confirm
// in MessageSink.
void AlignerNode::terminate_impl() {
terminate_input_queue();
for (auto& m : m_workers) {
// joinable() guard skips threads that have already been joined.
if (m.joinable()) {
m.join();
}
}
m_workers.clear();
}

// Restarts the node: calls restart_input_queue() and then spawns a new set of
// worker threads through start_threads().
void AlignerNode::restart() {
restart_input_queue();
start_threads();
}

// Destructor delegates to terminate_impl(), which joins the worker threads
// before the object's members are destroyed.
AlignerNode::~AlignerNode() { terminate_impl(); }

alignment::HeaderSequenceRecords AlignerNode::get_sequence_records_for_header() const {
assert(m_index_for_bam_messages != nullptr &&
"get_sequence_records_for_header only valid if AlignerNode constructed with index file");
Expand All @@ -122,7 +96,7 @@ void AlignerNode::align_read_common(ReadCommon& read_common, mm_tbuf_t* tbuf) {
alignment::Minimap2Aligner(index).align(read_common, tbuf);
}

void AlignerNode::worker_thread() {
void AlignerNode::input_thread_fn() {
Message message;
mm_tbuf_t* tbuf = mm_tbuf_init();
auto align_read = [this, tbuf](auto&& read) {
Expand Down
17 changes: 6 additions & 11 deletions dorado/read_pipeline/AlignerNode.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#pragma once
#include "ReadPipeline.h"

#include "alignment/IndexFileAccess.h"
#include "alignment/Minimap2Options.h"
#include "read_pipeline/MessageSink.h"
#include "utils/stats.h"
#include "utils/types.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct bam1_t;
Expand All @@ -27,23 +26,19 @@ class AlignerNode : public MessageSink {
const alignment::Minimap2Options& options,
int threads);
AlignerNode(std::shared_ptr<alignment::IndexFileAccess> index_file_access, int threads);
~AlignerNode();
~AlignerNode() { stop_input_processing(); }
std::string get_name() const override { return "AlignerNode"; }
stats::NamedStats sample_stats() const override;
void terminate(const FlushOptions&) override { terminate_impl(); }
void restart() override;
void terminate(const FlushOptions&) override { stop_input_processing(); }
void restart() override { start_input_processing(&AlignerNode::input_thread_fn, this); }

alignment::HeaderSequenceRecords get_sequence_records_for_header() const;

private:
void start_threads();
void terminate_impl();
void worker_thread();
void input_thread_fn();
std::shared_ptr<const alignment::Minimap2Index> get_index(const ReadCommon& read_common);
void align_read_common(ReadCommon& read_common, mm_tbuf_t* tbuf);

size_t m_threads;
std::vector<std::thread> m_workers;
std::shared_ptr<const alignment::Minimap2Index> m_index_for_bam_messages{};
std::shared_ptr<alignment::IndexFileAccess> m_index_file_access{};
};
Expand Down
Loading

0 comments on commit 8dfd180

Please sign in to comment.