Skip to content

Commit

Permalink
Merge branch 'DOR-616_bam_sort_threading' into 'master'
Browse files Browse the repository at this point in the history
DOR-616 Added threading for compression to bam sorting

Closes DOR-616

See merge request machine-learning/dorado!901
  • Loading branch information
kdolan1973 committed Mar 20, 2024
2 parents c06dfa1 + 6ff78a4 commit c0075a0
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 21 deletions.
8 changes: 5 additions & 3 deletions dorado/cli/aligner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,11 @@ int aligner(int argc, char* argv[]) {

// Report progress during output file finalisation.
tracker.set_description("Sorting output files");
hts_file.finalise([&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
});
hts_file.finalise(
[&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
},
writer_threads);

tracker.summarize();

Expand Down
8 changes: 5 additions & 3 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,11 @@ void setup(std::vector<std::string> args,

// Report progress during output file finalisation.
tracker.set_description("Sorting output files");
hts_file.finalise([&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
});
hts_file.finalise(
[&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
},
thread_allocations.writer_threads);

// Give the user a nice summary.
tracker.summarize();
Expand Down
11 changes: 7 additions & 4 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ int duplex(int argc, char* argv[]) {
SamHdrPtr hdr(sam_hdr_init());
cli::add_pg_hdr(hdr.get(), args);

utils::HtsFile hts_file("-", output_mode, 4);
constexpr int WRITER_THREADS = 4;
utils::HtsFile hts_file("-", output_mode, WRITER_THREADS);

PipelineDescriptor pipeline_desc;
auto hts_writer = PipelineDescriptor::InvalidNodeHandle;
Expand Down Expand Up @@ -555,9 +556,11 @@ int duplex(int argc, char* argv[]) {

// Report progress during output file finalisation.
tracker.set_description("Sorting output files");
hts_file.finalise([&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
});
hts_file.finalise(
[&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
},
WRITER_THREADS);

tracker.summarize();
if (!dump_stats_file.empty()) {
Expand Down
8 changes: 5 additions & 3 deletions dorado/cli/trim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,11 @@ int trim(int argc, char* argv[]) {

// Report progress during output file finalisation.
tracker.set_description("Sorting output files");
hts_file.finalise([&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
});
hts_file.finalise(
[&](size_t progress) {
tracker.update_post_processing_progress(static_cast<float>(progress));
},
trim_writer_threads);
tracker.summarize();

spdlog::info("> finished adapter/primer trimming");
Expand Down
12 changes: 7 additions & 5 deletions dorado/read_pipeline/BarcodeDemuxerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ void BarcodeDemuxerNode::finalise_hts_files(
const size_t num_files = m_files.size();
size_t current_file_idx = 0;
for (auto& [bc, hts_file] : m_files) {
hts_file->finalise([&](size_t progress) {
// Give each file/barcode the same contribution to the total progress.
const size_t total_progress = (current_file_idx * 100 + progress) / num_files;
progress_callback(total_progress);
});
hts_file->finalise(
[&](size_t progress) {
// Give each file/barcode the same contribution to the total progress.
const size_t total_progress = (current_file_idx * 100 + progress) / num_files;
progress_callback(total_progress);
},
m_htslib_threads);
++current_file_idx;
}

Expand Down
11 changes: 10 additions & 1 deletion dorado/utils/hts_file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ HtsFile::~HtsFile() {
// in order to generate a map of sort coordinates to virtual file offsets. we can then jump around in the
// file to write out the records in the sorted order. finally we can delete the unsorted file.
// in case an error occurs, the unsorted file is left on disk, so users can recover their data.
void HtsFile::finalise(const ProgressCallback& progress_callback) {
void HtsFile::finalise(const ProgressCallback& progress_callback, int writer_threads) {
assert(progress_callback);

// Rough divisions of how far through we are at the start of each section.
Expand Down Expand Up @@ -112,6 +112,15 @@ void HtsFile::finalise(const ProgressCallback& progress_callback) {
HtsFilePtr in_file(hts_open(temp_filename.c_str(), "rb"));
HtsFilePtr out_file(hts_open(filepath.string().c_str(), "wb"));

if (bgzf_mt(in_file->fp.bgzf, writer_threads, 128) < 0) {
spdlog::error("Could not enable multi threading for BAM reading.");
return;
}
if (bgzf_mt(out_file->fp.bgzf, writer_threads, 128) < 0) {
spdlog::error("Could not enable multi threading for BAM generation.");
return;
}

SamHdrPtr in_header(sam_hdr_read(in_file.get()));
SamHdrPtr out_header(sam_hdr_dup(in_header.get()));
sam_hdr_change_HD(out_header.get(), "SO", "coordinate");
Expand Down
2 changes: 1 addition & 1 deletion dorado/utils/hts_file.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class HtsFile {
int write(const bam1_t* record);

bool finalise_is_noop() const { return m_finalise_is_noop; }
void finalise(const ProgressCallback& progress_callback);
void finalise(const ProgressCallback& progress_callback, int writer_threads);
};

} // namespace dorado::utils
2 changes: 1 addition & 1 deletion tests/BamWriterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class HtsWriterTestsFixture {
auto& writer_ref = dynamic_cast<HtsWriter&>(pipeline->get_node_ref(writer));
stats = writer_ref.sample_stats();

hts_file.finalise([](size_t) { /* noop */ });
hts_file.finalise([](size_t) { /* noop */ }, num_threads);
}

stats::NamedStats stats;
Expand Down

0 comments on commit c0075a0

Please sign in to comment.