From ec1fde5500d39f78f586251fab69d8c3b27104c8 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 11 Oct 2023 14:54:47 +0100 Subject: [PATCH 01/39] Add modbasecaller to duplex pipeline --- dorado/cli/basecaller.cpp | 9 ++--- dorado/cli/duplex.cpp | 59 ++++++++++++++++++++++++++++-- dorado/read_pipeline/Pipelines.cpp | 22 ++++++----- dorado/read_pipeline/Pipelines.h | 1 + dorado/utils/parameters.h | 2 +- tests/NodeSmokeTest.cpp | 2 +- 6 files changed, 75 insertions(+), 20 deletions(-) diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index 4fd9792d..8b26ad51 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -98,8 +98,9 @@ void setup(std::vector args, } // create modbase runners first so basecall runners can pick batch sizes based on available memory - auto remora_runners = create_modbase_runners( - remora_models, device, default_parameters.remora_runners_per_caller, remora_batch_size); + auto remora_runners = create_modbase_runners(remora_models, device, + default_parameters.mod_base_runners_per_caller, + remora_batch_size); if (!remora_runners.empty() && output_mode == HtsWriter::OutputMode::FASTQ) { throw std::runtime_error("Modified base models cannot be used with FASTQ output"); @@ -115,8 +116,6 @@ void setup(std::vector args, data_path, read_list, {} /*reads_already_processed*/, recursive_file_loading); num_reads = max_reads == 0 ? num_reads : std::min(num_reads, max_reads); - bool duplex = false; - const auto thread_allocations = utils::default_thread_allocations( num_devices, !remora_runners.empty() ? num_remora_threads : 0, enable_aligner, !barcode_kits.empty()); @@ -211,7 +210,7 @@ void setup(std::vector args, } std::vector stats_callables; - ProgressTracker tracker(num_reads, duplex); + ProgressTracker tracker(num_reads, false); stats_callables.push_back( [&tracker](const stats::NamedStats& stats) { tracker.update_progress_bar(stats); }); constexpr auto kStatsPeriod = 100ms; diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index 35940d1c..24206a05 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -3,6 +3,7 @@ #include "data_loader/DataLoader.h" #include "models/models.h" #include "nn/CRFModelConfig.h" +#include "nn/ModBaseRunner.h" #include "nn/Runners.h" #include "read_pipeline/AlignerNode.h" #include "read_pipeline/BaseSpaceDuplexCallerNode.h" @@ -111,6 +112,31 @@ int duplex(int argc, char* argv[]) { parser.visible.add_argument("-v", "--verbose").default_value(false).implicit_value(true); + parser.visible.add_argument("--modified-bases") + .nargs(argparse::nargs_pattern::at_least_one) + .action([](const std::string& value) { + const auto& mods = models::modified_mods(); + if (std::find(mods.begin(), mods.end(), value) == mods.end()) { + spdlog::error( + "'{}' is not a supported modification please select from {}", value, + std::accumulate( + std::next(mods.begin()), mods.end(), mods[0], + [](std::string a, std::string b) { return a + ", " + b; })); + std::exit(EXIT_FAILURE); + } + return value; + }); + + parser.visible.add_argument("--modified-bases-models") + .default_value(std::string()) + .help("a comma separated list of modified base models"); + + parser.visible.add_argument("--modified-bases-threshold") + .default_value(default_parameters.methylation_threshold) + .scan<'f', float>() + .help("the minimum predicted methylation probability for a modified base to be emitted " + "in an all-context model, [0, 1]"); + cli::add_minimap2_arguments(parser, Aligner::dflt_options); cli::add_internal_arguments(parser); @@ -134,6 +160,24 @@ int duplex(int argc, char* argv[]) { if (parser.visible.get("--verbose")) { utils::SetDebugLogging(); } + + auto mod_bases = parser.visible.get>("--modified-bases"); + auto mod_bases_models = parser.visible.get("--modified-bases-models"); + + if (mod_bases.size() && !mod_bases_models.empty()) { + spdlog::error( + "only one of --modified-bases or --modified-bases-models should be specified."); + std::exit(EXIT_FAILURE); + } else if (mod_bases.size()) { + std::vector m; + std::transform( + mod_bases.begin(), mod_bases.end(), std::back_inserter(m), + [&model](std::string m) { return models::get_modification_model(model, m); }); + + mod_bases_models = + std::accumulate(std::next(m.begin()), m.end(), m[0], + [](std::string a, std::string b) { return a + "," + b; }); + } std::map template_complement_map; auto read_list = utils::load_read_list(parser.visible.get("--read-ids")); @@ -281,6 +325,15 @@ int duplex(int argc, char* argv[]) { } auto stereo_model_config = load_crf_model_config(stereo_model_path); + // create modbase runners first so basecall runners can pick batch sizes based on available memory + auto mod_base_runners = create_modbase_runners( + mod_bases_models, device, default_parameters.mod_base_runners_per_caller, + default_parameters.remora_batchsize); + + if (!mod_base_runners.empty() && output_mode == HtsWriter::OutputMode::FASTQ) { + throw std::runtime_error("Modified base models cannot be used with FASTQ output"); + } + // Write read group info to header. auto duplex_rg_name = std::string(model + "_" + stereo_model_name); auto read_groups = DataLoader::load_read_groups(reads, model, recursive_file_loading); @@ -341,9 +394,9 @@ int duplex(int argc, char* argv[]) { } } pipelines::create_stereo_duplex_pipeline( - pipeline_desc, std::move(runners), std::move(stereo_runners), overlap, - mean_qscore_start_pos, num_devices * 2, num_devices, - std::move(pairing_parameters), read_filter_node); + pipeline_desc, std::move(runners), std::move(stereo_runners), + std::move(mod_base_runners), overlap, mean_qscore_start_pos, num_devices * 2, + num_devices, std::move(pairing_parameters), read_filter_node); std::vector stats_reporters; pipeline = Pipeline::create(std::move(pipeline_desc), &stats_reporters); diff --git a/dorado/read_pipeline/Pipelines.cpp b/dorado/read_pipeline/Pipelines.cpp index 532832b7..7047f987 100644 --- a/dorado/read_pipeline/Pipelines.cpp +++ b/dorado/read_pipeline/Pipelines.cpp @@ -75,16 +75,18 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, } } -void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, - std::vector&& runners, - std::vector&& stereo_runners, - size_t overlap, - uint32_t mean_qscore_start_pos, - int scaler_node_threads, - int splitter_node_threads, - PairingParameters pairing_parameters, - NodeHandle sink_node_handle, - NodeHandle source_node_handle) { +void create_stereo_duplex_pipeline( + PipelineDescriptor& pipeline_desc, + std::vector&& runners, + std::vector&& stereo_runners, + std::vector>&& modbase_runners, + size_t overlap, + uint32_t mean_qscore_start_pos, + int scaler_node_threads, + int splitter_node_threads, + PairingParameters pairing_parameters, + NodeHandle sink_node_handle, + NodeHandle source_node_handle) { const auto& model_config = runners.front()->config(); const auto& stereo_model_config = stereo_runners.front()->config(); std::string model_name = diff --git a/dorado/read_pipeline/Pipelines.h b/dorado/read_pipeline/Pipelines.h index 591351d7..e4d10b0c 100644 --- a/dorado/read_pipeline/Pipelines.h +++ b/dorado/read_pipeline/Pipelines.h @@ -41,6 +41,7 @@ void create_stereo_duplex_pipeline( PipelineDescriptor& pipeline_desc, std::vector&& runners, std::vector&& stereo_runners, + std::vector>&& modbase_runners, size_t overlap, uint32_t mean_qscore_start_pos, int scaler_node_threads, diff --git a/dorado/utils/parameters.h b/dorado/utils/parameters.h index 3f0ad5ed..ec4f9a65 100644 --- a/dorado/utils/parameters.h +++ b/dorado/utils/parameters.h @@ -22,7 +22,7 @@ struct DefaultParameters { int remora_batchsize{1024}; #endif int remora_threads{4}; - int remora_runners_per_caller{2}; + int mod_base_runners_per_caller{2}; float methylation_threshold{0.05f}; // Minimum length for a sequence to be outputted. diff --git a/tests/NodeSmokeTest.cpp b/tests/NodeSmokeTest.cpp index 899323be..3628c0b1 100644 --- a/tests/NodeSmokeTest.cpp +++ b/tests/NodeSmokeTest.cpp @@ -296,7 +296,7 @@ DEFINE_TEST(NodeSmokeTestRead, "ModBaseCallerNode") { for (const auto& device_string : modbase_devices) { auto caller = dorado::create_modbase_caller({remora_model, remora_model_6mA}, batch_size, device_string); - for (size_t i = 0; i < default_params.remora_runners_per_caller; i++) { + for (size_t i = 0; i < default_params.mod_base_runners_per_caller; i++) { remora_runners.push_back(std::make_unique(caller)); } } From c3fe5fce8090fa16830352f7a149f1bdf19814de Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Fri, 13 Oct 2023 13:47:33 +0100 Subject: [PATCH 02/39] modbase calling enabled for simplex during duplex calling --- dorado/read_pipeline/Pipelines.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dorado/read_pipeline/Pipelines.cpp b/dorado/read_pipeline/Pipelines.cpp index 7047f987..53ecaa3a 100644 --- a/dorado/read_pipeline/Pipelines.cpp +++ b/dorado/read_pipeline/Pipelines.cpp @@ -102,6 +102,16 @@ void create_stereo_duplex_pipeline( {}, std::move(stereo_runners), adjusted_stereo_overlap, kStereoBatchTimeoutMS, duplex_rg_name, 1000, "StereoBasecallerNode", true, mean_qscore_start_pos); + NodeHandle last_node_handle = stereo_basecaller_node; + if (!modbase_runners.empty()) { + auto mod_base_caller_node = pipeline_desc.add_node( + {}, std::move(modbase_runners), 4, + runners.front() + ->model_stride()); // TODO - 4 for both of these parameters is incorrect + pipeline_desc.add_node_sink(stereo_basecaller_node, mod_base_caller_node); + last_node_handle = mod_base_caller_node; + } + auto simplex_model_stride = runners.front()->model_stride(); auto stereo_node = pipeline_desc.add_node({stereo_basecaller_node}, simplex_model_stride); @@ -140,7 +150,7 @@ void create_stereo_duplex_pipeline( // if we've been provided a sink node, connect it to the end of our pipeline if (sink_node_handle != PipelineDescriptor::InvalidNodeHandle) { - pipeline_desc.add_node_sink(stereo_basecaller_node, sink_node_handle); + pipeline_desc.add_node_sink(last_node_handle, sink_node_handle); } } From b23a17b74444677455ddda5bfc2d4e13b2fcf149 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Mon, 16 Oct 2023 14:27:08 +0100 Subject: [PATCH 03/39] placeholder for handling duplex template strand in worker --- dorado/read_pipeline/ModBaseCallerNode.cpp | 48 ++++++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index b66aa229..f9d13168 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -38,7 +38,7 @@ struct ModBaseCallerNode::RemoraChunk { }; struct ModBaseCallerNode::WorkingRead { - SimplexReadPtr read; // The read itself. + Message read; // The read itself. size_t num_modbase_chunks; std::atomic_size_t num_modbase_chunks_called; // Number of modbase chunks which have been scored @@ -147,16 +147,17 @@ void ModBaseCallerNode::input_worker_thread() { Message message; while (get_input_message(message)) { // If this message isn't a read, just forward it to the sink. - if (!std::holds_alternative(message)) { + if (!is_read_message(message)) { send_message_to_sink(std::move(message)); continue; } nvtx3::scoped_range range{"modbase_input_worker_thread"}; // If this message isn't a read, we'll get a bad_variant_access exception. - auto read = std::get(std::move(message)); - while (true) { + auto& read_common = get_read_common_data(message); + if (!read_common.is_duplex) { + auto read = std::get(std::move(message)); stats::Timer timer; { nvtx3::scoped_range range{"base_mod_probs_init"}; @@ -264,7 +265,34 @@ void ModBaseCallerNode::input_worker_thread() { send_message_to_sink(std::move(read)); ++m_num_non_mod_base_reads_pushed; } - break; + } else { + // ----> DUPLEX HANDLED HERE (FOR NOW) <----. + + // Step 1: prepare the base mod probs table - we are likely going to need two of these, one for each strand, + // We probably don't want them on read_common either + auto read = std::get(std::move(message)); + stats::Timer timer; + { + nvtx3::scoped_range range{"base_mod_probs_init"}; + // initialize base_mod_probs _before_ we start handing out chunks + read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, + 0); + for (size_t i = 0; i < read->read_common.seq.size(); ++i) { + // Initialize for what corresponds to 100% canonical base for each position. + int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; + if (base_id < 0) { + throw std::runtime_error("Invalid character in sequence."); + } + read->read_common + .base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1.0f; + } + } + + // mod_base_info is the modified base settings of the models that ran on this read + read->read_common.mod_base_info = m_mod_base_info; + + send_message_to_sink(std::move(read)); + ++m_num_non_mod_base_reads_pushed; } } @@ -385,11 +413,13 @@ void ModBaseCallerNode::output_worker_thread() { for (const auto& chunk : processed_chunks) { auto working_read = chunk->working_read; auto& source_read = working_read->read; + auto& source_read_common = get_read_common_data(source_read); + int64_t result_pos = chunk->context_hit; int64_t offset = m_base_prob_offsets - [utils::BaseInfo::BASE_IDS[source_read->read_common.seq[result_pos]]]; + [utils::BaseInfo::BASE_IDS[source_read_common.seq[result_pos]]]; for (size_t i = 0; i < chunk->scores.size(); ++i) { - source_read->read_common.base_mod_probs[m_num_states * result_pos + offset + i] = + source_read_common.base_mod_probs[m_num_states * result_pos + offset + i] = static_cast(std::min(std::floor(chunk->scores[i] * 256), 255.0f)); } // If all chunks for the read associated with this chunk have now been called, @@ -409,8 +439,8 @@ void ModBaseCallerNode::output_worker_thread() { if (read_iter != m_working_reads.end()) { m_working_reads.erase(read_iter); } else { - throw std::runtime_error("Expected to find read id " + - completed_read->read->read_common.read_id + + auto read_id = get_read_common_data(completed_read->read).read_id; + throw std::runtime_error("Expected to find read id " + read_id + " in working reads set but it doesn't exist."); } } From c8ccb3f23a9403e6193cf477b47e99cb10b3b189 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Fri, 17 Nov 2023 10:53:05 +0000 Subject: [PATCH 04/39] Separated duplex and simple mod calls into separate functions --- dorado/read_pipeline/ModBaseCallerNode.cpp | 253 +++++++++------------ dorado/read_pipeline/ModBaseCallerNode.h | 2 + 2 files changed, 114 insertions(+), 141 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 393a9f14..e1406d67 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -142,158 +142,129 @@ void ModBaseCallerNode::init_modbase_info() { m_base_prob_offsets[3] = m_base_prob_offsets[2] + result.base_counts[2]; } -void ModBaseCallerNode::input_worker_thread() { - at::InferenceMode inference_mode_guard; +void ModBaseCallerNode::duplex_mod_call(Message message) { + send_message_to_sink(std::move(message)); +} - Message message; - while (get_input_message(message)) { - // If this message isn't a read, just forward it to the sink. - if (!is_read_message(message)) { - send_message_to_sink(std::move(message)); - continue; +void ModBaseCallerNode::simplex_mod_call(Message message) { + auto read = std::get(std::move(message)); + stats::Timer timer; + { + nvtx3::scoped_range range{"base_mod_probs_init"}; + // initialize base_mod_probs _before_ we start handing out chunks + read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, 0); + for (size_t i = 0; i < read->read_common.seq.size(); ++i) { + // Initialize for what corresponds to 100% canonical base for each position. + int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; + if (base_id < 0) { + throw std::runtime_error("Invalid character in sequence."); + } + read->read_common.base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1; } + } + read->read_common.mod_base_info = m_mod_base_info; - nvtx3::scoped_range range{"modbase_input_worker_thread"}; - // If this message isn't a read, we'll get a bad_variant_access exception. - - auto& read_common = get_read_common_data(message); - if (!read_common.is_duplex) { - auto read = std::get(std::move(message)); - stats::Timer timer; - { - nvtx3::scoped_range range{"base_mod_probs_init"}; - // initialize base_mod_probs _before_ we start handing out chunks - read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, - 0); - for (size_t i = 0; i < read->read_common.seq.size(); ++i) { - // Initialize for what corresponds to 100% canonical base for each position. - int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; - if (base_id < 0) { - throw std::runtime_error("Invalid character in sequence."); - } - read->read_common - .base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1; - } - } - read->read_common.mod_base_info = m_mod_base_info; - - auto working_read = std::make_shared(); - working_read->num_modbase_chunks = 0; - working_read->num_modbase_chunks_called = 0; - - std::vector sequence_ints = utils::sequence_to_ints(read->read_common.seq); - - // all runners have the same set of callers, so we only need to use the first one - auto& runner = m_runners[0]; - std::vector>> chunks_to_enqueue_by_caller( - runner->num_callers()); - for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { - nvtx3::scoped_range range{"generate_chunks"}; - - auto signal_len = read->read_common.get_raw_data_samples(); - std::vector seq_to_sig_map = - utils::moves_to_map(read->read_common.moves, m_block_stride, signal_len, - read->read_common.seq.size() + 1); - - auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); - auto& params = runner->caller_params(caller_id); - auto signal = read->read_common.raw_data; - if (params.reverse_signal) { - signal = at::flip(signal, 0); - std::reverse(std::begin(seq_to_sig_map), std::end(seq_to_sig_map)); - std::transform(std::begin(seq_to_sig_map), std::end(seq_to_sig_map), - std::begin(seq_to_sig_map), [signal_len](auto signal_pos) { - return signal_len - signal_pos; - }); - } + auto working_read = std::make_shared(); + working_read->num_modbase_chunks = 0; + working_read->num_modbase_chunks_called = 0; - // scale signal based on model parameters - auto scaled_signal = - runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); - - auto context_samples = (params.context_before + params.context_after); - - // One-hot encodes the kmer at each signal step for input into the network - ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, - params.bases_after); - encoder.init(sequence_ints, seq_to_sig_map); - - auto context_hits = runner->get_motif_hits(caller_id, read->read_common.seq); - m_num_context_hits += static_cast(context_hits.size()); - chunks_to_enqueue.reserve(context_hits.size()); - for (auto context_hit : context_hits) { - nvtx3::scoped_range range{"create_chunk"}; - auto slice = encoder.get_context(context_hit); - // signal - auto input_signal = scaled_signal.index({at::indexing::Slice( - slice.first_sample, slice.first_sample + slice.num_samples)}); - if (slice.lead_samples_needed != 0 || slice.tail_samples_needed != 0) { - input_signal = at::constant_pad_nd(input_signal, - {(int64_t)slice.lead_samples_needed, - (int64_t)slice.tail_samples_needed}); - } - chunks_to_enqueue.push_back(std::make_unique( - working_read, input_signal, std::move(slice.data), context_hit)); - - ++working_read->num_modbase_chunks; - } + std::vector sequence_ints = utils::sequence_to_ints(read->read_common.seq); + + // all runners have the same set of callers, so we only need to use the first one + auto& runner = m_runners[0]; + std::vector>> chunks_to_enqueue_by_caller( + runner->num_callers()); + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + nvtx3::scoped_range range{"generate_chunks"}; + + auto signal_len = read->read_common.get_raw_data_samples(); + std::vector seq_to_sig_map = + utils::moves_to_map(read->read_common.moves, m_block_stride, signal_len, + read->read_common.seq.size() + 1); + + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + auto& params = runner->caller_params(caller_id); + auto signal = read->read_common.raw_data; + if (params.reverse_signal) { + signal = at::flip(signal, 0); + std::reverse(std::begin(seq_to_sig_map), std::end(seq_to_sig_map)); + std::transform(std::begin(seq_to_sig_map), std::end(seq_to_sig_map), + std::begin(seq_to_sig_map), + [signal_len](auto signal_pos) { return signal_len - signal_pos; }); + } + + // scale signal based on model parameters + auto scaled_signal = runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); + + auto context_samples = (params.context_before + params.context_after); + + // One-hot encodes the kmer at each signal step for input into the network + ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, + params.bases_after); + encoder.init(sequence_ints, seq_to_sig_map); + + auto context_hits = runner->get_motif_hits(caller_id, read->read_common.seq); + m_num_context_hits += static_cast(context_hits.size()); + chunks_to_enqueue.reserve(context_hits.size()); + for (auto context_hit : context_hits) { + nvtx3::scoped_range range{"create_chunk"}; + auto slice = encoder.get_context(context_hit); + // signal + auto input_signal = scaled_signal.index({at::indexing::Slice( + slice.first_sample, slice.first_sample + slice.num_samples)}); + if (slice.lead_samples_needed != 0 || slice.tail_samples_needed != 0) { + input_signal = at::constant_pad_nd( + input_signal, + {(int64_t)slice.lead_samples_needed, (int64_t)slice.tail_samples_needed}); } - m_chunk_generation_ms += timer.GetElapsedMS(); + chunks_to_enqueue.push_back(std::make_unique( + working_read, input_signal, std::move(slice.data), context_hit)); - if (working_read->num_modbase_chunks != 0) { - // Hand over our ownership to the working read - working_read->read = std::move(read); + ++working_read->num_modbase_chunks; + } + } + m_chunk_generation_ms += timer.GetElapsedMS(); - // Put the read in the working list - { - std::lock_guard working_reads_lock(m_working_reads_mutex); - m_working_reads.insert(std::move(working_read)); - ++m_working_reads_size; - } + if (working_read->num_modbase_chunks != 0) { + // Hand over our ownership to the working read + working_read->read = std::move(read); - // push the chunks to the chunk queues - // needs to be done after working_read->read is set as chunks could be processed - // before we set that value otherwise - for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { - auto& chunk_queue = m_chunk_queues.at(caller_id); - auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); - for (auto& chunk : chunks_to_enqueue) { - chunk_queue->try_push(std::move(chunk)); - } - } - } else { - // No modbases to call, pass directly to next node - send_message_to_sink(std::move(read)); - ++m_num_non_mod_base_reads_pushed; - } - } else { - // ----> DUPLEX HANDLED HERE (FOR NOW) <----. - - // Step 1: prepare the base mod probs table - we are likely going to need two of these, one for each strand, - // We probably don't want them on read_common either - auto read = std::get(std::move(message)); - stats::Timer timer; - { - nvtx3::scoped_range range{"base_mod_probs_init"}; - // initialize base_mod_probs _before_ we start handing out chunks - read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, - 0); - for (size_t i = 0; i < read->read_common.seq.size(); ++i) { - // Initialize for what corresponds to 100% canonical base for each position. - int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; - if (base_id < 0) { - throw std::runtime_error("Invalid character in sequence."); - } - read->read_common - .base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1.0f; - } + // Put the read in the working list + { + std::lock_guard working_reads_lock(m_working_reads_mutex); + m_working_reads.insert(std::move(working_read)); + ++m_working_reads_size; + } + + // push the chunks to the chunk queues + // needs to be done after working_read->read is set as chunks could be processed + // before we set that value otherwise + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + auto& chunk_queue = m_chunk_queues.at(caller_id); + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + for (auto& chunk : chunks_to_enqueue) { + chunk_queue->try_push(std::move(chunk)); } + } + } else { + // No modbases to call, pass directly to next node + send_message_to_sink(std::move(read)); + ++m_num_non_mod_base_reads_pushed; + } +} - // mod_base_info is the modified base settings of the models that ran on this read - read->read_common.mod_base_info = m_mod_base_info; +void ModBaseCallerNode::input_worker_thread() { + at::InferenceMode inference_mode_guard; - send_message_to_sink(std::move(read)); - ++m_num_non_mod_base_reads_pushed; + Message message; + while (get_input_message(message)) { + // If this message isn't a read, just forward it to the sink. + if (!is_read_message(message)) { + send_message_to_sink(std::move(message)); + } else if (std::holds_alternative(message)) { + simplex_mod_call(std::move(message)); + } else if (std::holds_alternative(message)) { + duplex_mod_call(std::move(message)); } } diff --git a/dorado/read_pipeline/ModBaseCallerNode.h b/dorado/read_pipeline/ModBaseCallerNode.h index 60bae7f2..298a9266 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.h +++ b/dorado/read_pipeline/ModBaseCallerNode.h @@ -32,6 +32,8 @@ class ModBaseCallerNode : public MessageSink { stats::NamedStats sample_stats() const override; void terminate(const FlushOptions& flush_options) override { terminate_impl(); } void restart() override; + void simplex_mod_call(Message message); + void duplex_mod_call(Message message); private: void start_threads(); From 8ded2fae4e81e81a44e2d10ddc989c5904516086 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Mon, 20 Nov 2023 17:05:56 +0000 Subject: [PATCH 05/39] Computing simplex to duplex edlib alignment --- dorado/read_pipeline/ModBaseCallerNode.cpp | 46 ++++++++++- dorado/utils/sequence_utils.cpp | 92 ++++++++++++++++++++++ dorado/utils/sequence_utils.h | 8 ++ 3 files changed, 145 insertions(+), 1 deletion(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index e1406d67..e81bda56 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -143,7 +143,51 @@ void ModBaseCallerNode::init_modbase_info() { } void ModBaseCallerNode::duplex_mod_call(Message message) { - send_message_to_sink(std::move(message)); + // Let's do this only for the template strand for now. + + auto read = std::get(std::move(message)); + stats::Timer timer; + + { + nvtx3::scoped_range range{"base_mod_probs_init"}; + // initialize base_mod_probs _before_ we start handing out chunks + read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, 0); + for (size_t i = 0; i < read->read_common.seq.size(); ++i) { + // Initialize for what corresponds to 100% canonical base for each position. + int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; + if (base_id < 0) { + throw std::runtime_error("Invalid character in sequence."); + } + read->read_common.base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1; + } + } + + read->read_common.mod_base_info = m_mod_base_info; + + auto working_read = std::make_shared(); + working_read->num_modbase_chunks = 0; + working_read->num_modbase_chunks_called = 0; + + std::vector sequence_ints = utils::sequence_to_ints(read->read_common.seq); + + // all runners have the same set of callers, so we only need to use the first one + auto& runner = m_runners[0]; + std::vector>> chunks_to_enqueue_by_caller( + runner->num_callers()); + + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + nvtx3::scoped_range range{"generate_chunks"}; + + //auto signal_len = read->stereo_feature_inputs.template_signal.size(0); + + std::vector template_moves = utils::realign_moves( + read->stereo_feature_inputs.template_seq, read->read_common.seq, + read->stereo_feature_inputs.template_moves); + // Next - build the sig to seq map. + // What we need first is a new moves table for the template to the duplex read. + } + + send_message_to_sink(std::move(read)); } void ModBaseCallerNode::simplex_mod_call(Message message) { diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index b2fdf8f5..bc82cfcf 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -1,12 +1,16 @@ #include "sequence_utils.h" #include "simd.h" +#include "types.h" +#include +#include #include #include #include #include +#include #include #include #include @@ -182,6 +186,94 @@ std::vector moves_to_map(const std::vector& moves, return seq_to_sig_map; } +OverlapResult compute_overlap(std::string query_seq, std::string target_seq) { + OverlapResult overlap_result = {false, 0, 0, 0, 0}; + + // Add mm2 based overlap check. + mm_idxopt_t m_idx_opt; + mm_mapopt_t m_map_opt; + mm_set_opt(0, &m_idx_opt, &m_map_opt); + mm_set_opt("map-hifi", &m_idx_opt, &m_map_opt); + + std::vector seqs = {query_seq.c_str()}; + std::vector names = {"query"}; + mm_idx_t* m_index = mm_idx_str(m_idx_opt.w, m_idx_opt.k, 0, m_idx_opt.bucket_bits, 1, + seqs.data(), names.data()); + mm_mapopt_update(&m_map_opt, m_index); + + MmTbufPtr mbuf = MmTbufPtr(mm_tbuf_init()); + + int hits = 0; + mm_reg1_t* reg = mm_map(m_index, int(target_seq.length()), target_seq.c_str(), &hits, + mbuf.get(), &m_map_opt, "target"); + + mm_idx_destroy(m_index); + + if (hits > 0) { + int32_t target_start = 0; + int32_t target_end = 0; + int32_t query_start = 0; + int32_t query_end = 0; + + auto best_map = std::max_element( + reg, reg + hits, + [](const mm_reg1_t& l, const mm_reg1_t& r) { return l.mapq < r.mapq; }); + target_start = best_map->rs; + target_end = best_map->re; + query_start = best_map->qs; + query_end = best_map->qe; + + overlap_result = {true, target_start, target_end, query_start, query_end}; + } + + for (int i = 0; i < hits; ++i) { + free(reg[i].p); + } + free(reg); + + return overlap_result; +} + +std::vector realign_moves(std::string query_sequence, + std::string target_sequence, + std::vector moves) { + //Initially let's just spread the moves evenly, we can come back to this later + std::vector new_moves; + int num_moves = std::accumulate(moves.begin(), moves.end(), 0); + int input_seq_size = query_sequence.size(); + int target_seq_size = target_sequence.size(); + std::cerr << num_moves; + std::cerr << input_seq_size; + std::cerr << target_seq_size; + + auto [is_overlap, query_start, query_end, target_start, target_end] = compute_overlap( + query_sequence, + target_sequence); // We are going to compute the overlap between the two reads + + // Now let's perform an alignmnet: + + EdlibAlignConfig align_config = edlibDefaultAlignConfig(); + align_config.task = EDLIB_TASK_PATH; + + auto target_sequence_component = + target_sequence.substr(target_start, target_end - target_start); + auto query_sequence_component = query_sequence.substr(query_start, query_end - query_start); + + EdlibAlignResult edlib_result = edlibAlign( + target_sequence_component.data(), static_cast(target_sequence_component.length()), + query_sequence_component.data(), static_cast(query_sequence_component.length()), + align_config); + + std::cerr << "TSC:" << target_sequence_component << std::endl; + std::cerr << "QSC:" << query_sequence_component << std::endl; + + // Now that we have the alignment, we need to compute the new move table, by walking along the alignment + + edlibFreeAlignResult(edlib_result); + + return new_moves; +} + std::vector move_cum_sums(const std::vector& moves) { std::vector ans(moves.size(), 0); if (!moves.empty()) { diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index 750dbf50..f39a0b79 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -26,6 +26,11 @@ std::vector moves_to_map(const std::vector& moves, // Compute cumulative sums of the move table std::vector move_cum_sums(const std::vector& moves); +// Result of overlapping two reads +using OverlapResult = std::tuple; + +OverlapResult compute_overlap(std::string query_seq, std::string target_seq); + // Compute reverse complement of a nucleotide sequence. // Bases are specified as capital letters. // Undefined output if characters other than A, C, G, T appear. @@ -39,4 +44,7 @@ class BaseInfo { int count_trailing_chars(const std::string_view adapter, char c); +std::vector realign_moves(std::string query_sequence, + std::string target_sequence, + std::vector moves); } // namespace dorado::utils From 260cd9370415b14ff765495b3c6a16cdad814e02 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Tue, 21 Nov 2023 15:03:26 +0000 Subject: [PATCH 06/39] Computing realigned move table! --- CMakeLists.txt | 6 +- dorado/utils/CMakeLists.txt | 2 +- dorado/utils/sequence_utils.cpp | 98 +++++++++++++++++++++++++++++++-- 3 files changed, 96 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15de9d1a..a6dfa029 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -277,7 +277,7 @@ endif() add_library(dorado_lib ${LIB_SOURCE_FILES}) -enable_warnings_as_errors(dorado_lib) +#enable_warnings_as_errors(dorado_lib) set_target_properties(dorado_lib PROPERTIES @@ -374,7 +374,7 @@ if(NOT DORADO_LIB_ONLY) set_target_properties(dorado PROPERTIES LINK_OPTIONS "/ignore:4099") endif() - enable_warnings_as_errors(dorado) + #enable_warnings_as_errors(dorado) if (DORADO_ENABLE_PCH) target_precompile_headers(dorado REUSE_FROM dorado_lib) @@ -515,7 +515,7 @@ if(NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" ) - enable_warnings_as_errors(dorado_io_lib) + #enable_warnings_as_errors(dorado_io_lib) if (NOT DORADO_LIB_ONLY) add_subdirectory(tests) diff --git a/dorado/utils/CMakeLists.txt b/dorado/utils/CMakeLists.txt index 49a0a979..155f4895 100644 --- a/dorado/utils/CMakeLists.txt +++ b/dorado/utils/CMakeLists.txt @@ -120,4 +120,4 @@ if (ECM_ENABLE_SANITIZERS AND (CMAKE_CXX_COMPILER_ID MATCHES "GNU") AND (CMAKE_C set_source_files_properties(duplex_utils.cpp PROPERTIES COMPILE_OPTIONS "-O0") endif() -enable_warnings_as_errors(dorado_utils) \ No newline at end of file +#enable_warnings_as_errors(dorado_utils) \ No newline at end of file diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index bc82cfcf..41232e05 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -238,13 +238,12 @@ std::vector realign_moves(std::string query_sequence, std::string target_sequence, std::vector moves) { //Initially let's just spread the moves evenly, we can come back to this later - std::vector new_moves; int num_moves = std::accumulate(moves.begin(), moves.end(), 0); int input_seq_size = query_sequence.size(); int target_seq_size = target_sequence.size(); - std::cerr << num_moves; - std::cerr << input_seq_size; - std::cerr << target_seq_size; + //std::cerr << num_moves; + //std::cerr << input_seq_size; + //std::cerr << target_seq_size; auto [is_overlap, query_start, query_end, target_start, target_end] = compute_overlap( query_sequence, @@ -264,13 +263,100 @@ std::vector realign_moves(std::string query_sequence, query_sequence_component.data(), static_cast(query_sequence_component.length()), align_config); - std::cerr << "TSC:" << target_sequence_component << std::endl; - std::cerr << "QSC:" << query_sequence_component << std::endl; + //std::cerr << "TSC:" << target_sequence_component << std::endl; + //std::cerr << "QSC:" << query_sequence_component << std::endl; // Now that we have the alignment, we need to compute the new move table, by walking along the alignment + const auto alignment_size = + static_cast(edlib_result.endLocations[0] - edlib_result.startLocations[0]); + std::vector alignment; + alignment.resize(alignment_size); + std::memcpy(alignment.data(), &edlib_result.alignment[edlib_result.startLocations[0]], + alignment_size); + + std::vector new_moves; + + // Let's keep two cursor positions - one for the new move table and one for the old: + int new_move_cursor = 0; + int old_move_cursor = + 0; // Need to update to be the query start. // QUESTION do we need to worry about the start and end locations. + // Let's keep two cursor positions - one for the query sequence and one for the target: + int query_seq_cursor = query_start; + int target_seq_cursor = target_start; + + int moves_found = 0; + + while (moves_found < moves.size() && + moves_found < + query_start) { // TODO - is "query start" zero indexed? need to think about that + moves_found += moves[old_move_cursor]; + ++old_move_cursor; + } + --old_move_cursor; // We have gone one too far. + int old_moves_offset = old_move_cursor; + + // static constexpr unsigned char kAlignMatch = 0; + // static constexpr unsigned char kAlignInsertionToTarget = 1; + // static constexpr unsigned char kAlignInsertionToQuery = 2; + + // First thing to do - let's just print out the alignment line by line so we know it's working. + for (auto alignment_entry : alignment) { + if (alignment_entry == + 0) { //Match, need to update the new move table and move the cursor of the old move table. + std::cerr << query_sequence[query_seq_cursor] << "/" + << target_sequence[target_seq_cursor] << std::endl; + int a = moves[old_move_cursor]; + std::cerr << a; + new_moves.push_back(1); // We have a match so we need a 1 + new_move_cursor++; + old_move_cursor++; + + while (moves[old_move_cursor] == 0) { + if (old_move_cursor < (new_move_cursor + old_moves_offset)) { + new_moves.push_back(1); + old_move_cursor++; + } else { + new_moves.push_back(0); + } + // Unless there's a new/old mismatch - in which case we need to catch up by adding 1s. TODO this later. + new_move_cursor++; + old_move_cursor++; + } + // Update the Query and target seq cursors + query_seq_cursor++; + target_seq_cursor++; + } + if (alignment_entry == 1) { //Insertion to target + //std::cerr << "-" << "/" << target_sequence[target_seq_cursor] << std::endl; + // If we have an insertion in the target, we need to add a 1 to the new move table, and increment the new move table cursor. the old move table cursor and new are now out of sync and need fixing. + new_moves.push_back(1); + new_move_cursor++; + target_seq_cursor++; + } + if (alignment_entry == 2) { //Insertion to Query + // We have a query insertion, all we need to do is add zeros to the new move table to make it up, the signal can be assigned to the leftmost nucleotide in the sequence. + new_moves.push_back(0); + new_move_cursor++; + old_move_cursor++; + while (moves[old_move_cursor] == 0) { + new_moves.push_back(0); + old_move_cursor++; + new_move_cursor++; + } + // Update the Query and target seq cursors + query_seq_cursor++; + } + } + edlibFreeAlignResult(edlib_result); + // Need to return: + // 1. Moves start + // 2. Target sequence Start + // 3. Moves end + // 3. Target sequence end + return new_moves; } From c82ddb7387093018c23f8ddf354a6942baaa05d5 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 22 Nov 2023 14:42:18 +0000 Subject: [PATCH 07/39] Sending duplex read template chunks for mod base calling --- dorado/read_pipeline/ModBaseCallerNode.cpp | 80 ++++++++++++++++++++-- dorado/utils/sequence_utils.cpp | 32 ++------- dorado/utils/sequence_utils.h | 6 +- 3 files changed, 85 insertions(+), 33 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index e81bda56..1f6ca7b7 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -15,6 +15,7 @@ #include #include +#include using namespace std::chrono_literals; namespace dorado { @@ -168,8 +169,6 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { working_read->num_modbase_chunks = 0; working_read->num_modbase_chunks_called = 0; - std::vector sequence_ints = utils::sequence_to_ints(read->read_common.seq); - // all runners have the same set of callers, so we only need to use the first one auto& runner = m_runners[0]; std::vector>> chunks_to_enqueue_by_caller( @@ -178,13 +177,84 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { nvtx3::scoped_range range{"generate_chunks"}; - //auto signal_len = read->stereo_feature_inputs.template_signal.size(0); - - std::vector template_moves = utils::realign_moves( + auto [moves_offset, target_start, new_move_table] = utils::realign_moves( read->stereo_feature_inputs.template_seq, read->read_common.seq, read->stereo_feature_inputs.template_moves); // Next - build the sig to seq map. // What we need first is a new moves table for the template to the duplex read. + + auto signal_len = new_move_table.size() * m_block_stride; + auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); + auto new_seq = read->read_common.seq.substr(target_start, num_moves); + std::vector sequence_ints = utils::sequence_to_ints(new_seq); + + std::vector seq_to_sig_map = + utils::moves_to_map(new_move_table, m_block_stride, signal_len, num_moves + 1); + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + auto& params = runner->caller_params(caller_id); + auto signal = read->stereo_feature_inputs.template_signal.slice(0, moves_offset, + moves_offset + signal_len); + + // scale signal based on model parameters + auto scaled_signal = runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); + + auto context_samples = (params.context_before + params.context_after); + + // One-hot encodes the kmer at each signal step for input into the network + ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, + params.bases_after); + encoder.init(sequence_ints, seq_to_sig_map); + + auto context_hits = runner->get_motif_hits(caller_id, new_seq); + m_num_context_hits += static_cast(context_hits.size()); + chunks_to_enqueue.reserve(context_hits.size()); + + for (auto context_hit : context_hits) { + nvtx3::scoped_range range{"create_chunk"}; + auto slice = encoder.get_context(context_hit); + // signal + auto input_signal = scaled_signal.index({at::indexing::Slice( + slice.first_sample, slice.first_sample + slice.num_samples)}); + if (slice.lead_samples_needed != 0 || slice.tail_samples_needed != 0) { + input_signal = at::constant_pad_nd( + input_signal, + {(int64_t)slice.lead_samples_needed, (int64_t)slice.tail_samples_needed}); + } + chunks_to_enqueue.push_back(std::make_unique( + working_read, input_signal, std::move(slice.data), context_hit)); + + ++working_read->num_modbase_chunks; + } + std::cerr << "Context hits done" << std::endl; + } + + m_chunk_generation_ms += timer.GetElapsedMS(); + + if (working_read->num_modbase_chunks != 0) { + // Hand over our ownership to the working read + working_read->read = std::move(read); + + // Put the read in the working list + { + std::lock_guard working_reads_lock(m_working_reads_mutex); + m_working_reads.insert(std::move(working_read)); + ++m_working_reads_size; + } + + // push the chunks to the chunk queues + // needs to be done after working_read->read is set as chunks could be processed + // before we set that value otherwise + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + auto& chunk_queue = m_chunk_queues.at(caller_id); + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + for (auto& chunk : chunks_to_enqueue) { + chunk_queue->try_push(std::move(chunk)); + } + } + } else { + // No modbases to call, pass directly to next node + send_message_to_sink(std::move(read)); + ++m_num_non_mod_base_reads_pushed; } send_message_to_sink(std::move(read)); diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 41232e05..dce5a7fb 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -234,17 +234,11 @@ OverlapResult compute_overlap(std::string query_seq, std::string target_seq) { return overlap_result; } -std::vector realign_moves(std::string query_sequence, - std::string target_sequence, - std::vector moves) { - //Initially let's just spread the moves evenly, we can come back to this later - int num_moves = std::accumulate(moves.begin(), moves.end(), 0); - int input_seq_size = query_sequence.size(); - int target_seq_size = target_sequence.size(); - //std::cerr << num_moves; - //std::cerr << input_seq_size; - //std::cerr << target_seq_size; - +// Query is the read that the moves table is associated with. A new moves table will be generated +// Which is aligned to the target sequence. +std::tuple> realign_moves(std::string query_sequence, + std::string target_sequence, + std::vector moves) { auto [is_overlap, query_start, query_end, target_start, target_end] = compute_overlap( query_sequence, target_sequence); // We are going to compute the overlap between the two reads @@ -263,9 +257,6 @@ std::vector realign_moves(std::string query_sequence, query_sequence_component.data(), static_cast(query_sequence_component.length()), align_config); - //std::cerr << "TSC:" << target_sequence_component << std::endl; - //std::cerr << "QSC:" << query_sequence_component << std::endl; - // Now that we have the alignment, we need to compute the new move table, by walking along the alignment const auto alignment_size = @@ -280,7 +271,7 @@ std::vector realign_moves(std::string query_sequence, // Let's keep two cursor positions - one for the new move table and one for the old: int new_move_cursor = 0; int old_move_cursor = - 0; // Need to update to be the query start. // QUESTION do we need to worry about the start and end locations. + 0; // Need to update to be the query start. // TODO do we need to worry about the start and end locations. // Let's keep two cursor positions - one for the query sequence and one for the target: int query_seq_cursor = query_start; int target_seq_cursor = target_start; @@ -296,18 +287,10 @@ std::vector realign_moves(std::string query_sequence, --old_move_cursor; // We have gone one too far. int old_moves_offset = old_move_cursor; - // static constexpr unsigned char kAlignMatch = 0; - // static constexpr unsigned char kAlignInsertionToTarget = 1; - // static constexpr unsigned char kAlignInsertionToQuery = 2; - // First thing to do - let's just print out the alignment line by line so we know it's working. for (auto alignment_entry : alignment) { if (alignment_entry == 0) { //Match, need to update the new move table and move the cursor of the old move table. - std::cerr << query_sequence[query_seq_cursor] << "/" - << target_sequence[target_seq_cursor] << std::endl; - int a = moves[old_move_cursor]; - std::cerr << a; new_moves.push_back(1); // We have a match so we need a 1 new_move_cursor++; old_move_cursor++; @@ -328,7 +311,6 @@ std::vector realign_moves(std::string query_sequence, target_seq_cursor++; } if (alignment_entry == 1) { //Insertion to target - //std::cerr << "-" << "/" << target_sequence[target_seq_cursor] << std::endl; // If we have an insertion in the target, we need to add a 1 to the new move table, and increment the new move table cursor. the old move table cursor and new are now out of sync and need fixing. new_moves.push_back(1); new_move_cursor++; @@ -357,7 +339,7 @@ std::vector realign_moves(std::string query_sequence, // 3. Moves end // 3. Target sequence end - return new_moves; + return {old_moves_offset, target_start, new_moves}; } std::vector move_cum_sums(const std::vector& moves) { diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index f39a0b79..0ed1b69f 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -44,7 +44,7 @@ class BaseInfo { int count_trailing_chars(const std::string_view adapter, char c); -std::vector realign_moves(std::string query_sequence, - std::string target_sequence, - std::vector moves); +std::tuple> realign_moves(std::string query_sequence, + std::string target_sequence, + std::vector moves); } // namespace dorado::utils From 0b9187bc9aa716fbb34edd135f9761640018dcc6 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 22 Nov 2023 15:04:16 +0000 Subject: [PATCH 08/39] Duplex reads for template being generated, requires re-adjustment of offsets --- dorado/read_pipeline/ModBaseCallerNode.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 1f6ca7b7..60f15e71 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -257,7 +257,7 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { ++m_num_non_mod_base_reads_pushed; } - send_message_to_sink(std::move(read)); + //send_message_to_sink(std::move(read)); } void ModBaseCallerNode::simplex_mod_call(Message message) { From 0fd89bf296f50574964853aa3527017e7d647ec5 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 23 Nov 2023 14:36:10 +0000 Subject: [PATCH 09/39] Bugfix --- dorado/read_pipeline/ModBaseCallerNode.cpp | 63 +++++++++++++++++++--- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 60f15e71..e66264e9 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -12,10 +12,12 @@ #include #include #include +#include #include #include #include + using namespace std::chrono_literals; namespace dorado { @@ -143,12 +145,28 @@ void ModBaseCallerNode::init_modbase_info() { m_base_prob_offsets[3] = m_base_prob_offsets[2] + result.base_counts[2]; } +void serializeVector(const std::vector& vec, const std::string& filename) { + // Open a file in binary mode + std::ofstream file(filename, std::ios::binary); + + // Write the size of the vector (number of elements) + long size = vec.size(); + file.write(reinterpret_cast(&size), sizeof(long)); + + // Write the vector data + file.write(reinterpret_cast(vec.data()), size * sizeof(unsigned char)); + + // Close the file + file.close(); +} + void ModBaseCallerNode::duplex_mod_call(Message message) { // Let's do this only for the template strand for now. auto read = std::get(std::move(message)); stats::Timer timer; + // TODO: Does `base_mod_probs` need to be the same size as `new_seq`? { nvtx3::scoped_range range{"base_mod_probs_init"}; // initialize base_mod_probs _before_ we start handing out chunks @@ -177,24 +195,55 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { nvtx3::scoped_range range{"generate_chunks"}; + // Next - build the sig to seq map. + // What we need first is a new moves table for the template to the duplex read. auto [moves_offset, target_start, new_move_table] = utils::realign_moves( read->stereo_feature_inputs.template_seq, read->read_common.seq, read->stereo_feature_inputs.template_moves); - // Next - build the sig to seq map. - // What we need first is a new moves table for the template to the duplex read. auto signal_len = new_move_table.size() * m_block_stride; auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); - auto new_seq = read->read_common.seq.substr(target_start, num_moves); + auto new_seq = read->read_common.seq.substr( + target_start, num_moves); // temporary- change 0 back to target_start? std::vector sequence_ints = utils::sequence_to_ints(new_seq); std::vector seq_to_sig_map = utils::moves_to_map(new_move_table, m_block_stride, signal_len, num_moves + 1); auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); auto& params = runner->caller_params(caller_id); - auto signal = read->stereo_feature_inputs.template_signal.slice(0, moves_offset, - moves_offset + signal_len); - + auto signal = read->stereo_feature_inputs.template_signal.slice( + 0, moves_offset * m_block_stride, moves_offset * m_block_stride + signal_len); + + if (read->read_common.read_id == + "fa3d4195-5ee1-4ab7-b048-9ce004292b62;dae07e1e-d2a9-44eb-8e8c-282491a977a9") { + serializeVector(new_move_table, "duplex_move_table.bin"); + torch::save(signal, "duplex_signal.pt"); + // Open a file in write mode + std::ofstream file("duplex_seq.txt"); + + // Write the string to the file + file << new_seq; + // Close the file + file.close(); + + serializeVector(read->stereo_feature_inputs.template_moves, "simplex_move_table.bin"); + torch::save(read->stereo_feature_inputs.template_signal, "simplex_signal.pt"); + // Open a file in write mode + std::ofstream sfile("simplex_seq.txt"); + + // Write the string to the file + sfile << read->stereo_feature_inputs + .template_seq; // TODO understnad why this is necessary + // Close the file + sfile.close(); + + std::cerr << "Found and serialised read of interest" << std::endl; + + std::cerr << std::endl; + std::cerr << new_seq.substr(0, 100) << std::endl; + std::cerr << read->read_common.seq.substr(0, 100) << std::endl; + std::cerr << "Found and serialised read of interest" << std::endl; + } // scale signal based on model parameters auto scaled_signal = runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); @@ -256,8 +305,6 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { send_message_to_sink(std::move(read)); ++m_num_non_mod_base_reads_pushed; } - - //send_message_to_sink(std::move(read)); } void ModBaseCallerNode::simplex_mod_call(Message message) { From 9fb48d8298d683646b466fa3a509027d50ee6304 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 23 Nov 2023 17:28:16 +0000 Subject: [PATCH 10/39] Fixed several off-by-N errors in duplex/simplex alignment --- dorado/read_pipeline/ModBaseCallerNode.cpp | 3 ++- dorado/utils/sequence_utils.cpp | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index e66264e9..a4f02541 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -270,7 +270,8 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { {(int64_t)slice.lead_samples_needed, (int64_t)slice.tail_samples_needed}); } chunks_to_enqueue.push_back(std::make_unique( - working_read, input_signal, std::move(slice.data), context_hit)); + working_read, input_signal, std::move(slice.data), + context_hit + target_start)); // TODO do we need to update the context hit here ++working_read->num_modbase_chunks; } diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index dce5a7fb..334197a3 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -243,6 +243,7 @@ std::tuple> realign_moves(std::string query_seque query_sequence, target_sequence); // We are going to compute the overlap between the two reads + // TODO sanity check if and why this is needed // Now let's perform an alignmnet: EdlibAlignConfig align_config = edlibDefaultAlignConfig(); @@ -297,14 +298,13 @@ std::tuple> realign_moves(std::string query_seque while (moves[old_move_cursor] == 0) { if (old_move_cursor < (new_move_cursor + old_moves_offset)) { - new_moves.push_back(1); old_move_cursor++; } else { new_moves.push_back(0); + new_move_cursor++; + old_move_cursor++; } // Unless there's a new/old mismatch - in which case we need to catch up by adding 1s. TODO this later. - new_move_cursor++; - old_move_cursor++; } // Update the Query and target seq cursors query_seq_cursor++; @@ -339,7 +339,7 @@ std::tuple> realign_moves(std::string query_seque // 3. Moves end // 3. Target sequence end - return {old_moves_offset, target_start, new_moves}; + return {old_moves_offset, target_start - 1, new_moves}; } std::vector move_cum_sums(const std::vector& moves) { From 3f9dbd1f2b9f0ef69931d98cb196e932e958de30 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Fri, 24 Nov 2023 14:32:41 +0000 Subject: [PATCH 11/39] Fixed bug with mismatches --- dorado/utils/sequence_utils.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 334197a3..15654dd1 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -290,8 +290,9 @@ std::tuple> realign_moves(std::string query_seque // First thing to do - let's just print out the alignment line by line so we know it's working. for (auto alignment_entry : alignment) { - if (alignment_entry == - 0) { //Match, need to update the new move table and move the cursor of the old move table. + if ((alignment_entry == 0) || + (alignment_entry == + 3)) { //Match or mismatch, need to update the new move table and move the cursor of the old move table. new_moves.push_back(1); // We have a match so we need a 1 new_move_cursor++; old_move_cursor++; @@ -309,14 +310,12 @@ std::tuple> realign_moves(std::string query_seque // Update the Query and target seq cursors query_seq_cursor++; target_seq_cursor++; - } - if (alignment_entry == 1) { //Insertion to target + } else if (alignment_entry == 1) { //Insertion to target // If we have an insertion in the target, we need to add a 1 to the new move table, and increment the new move table cursor. the old move table cursor and new are now out of sync and need fixing. new_moves.push_back(1); new_move_cursor++; target_seq_cursor++; - } - if (alignment_entry == 2) { //Insertion to Query + } else if (alignment_entry == 2) { //Insertion to Query // We have a query insertion, all we need to do is add zeros to the new move table to make it up, the signal can be assigned to the leftmost nucleotide in the sequence. new_moves.push_back(0); new_move_cursor++; From b9590828e9075a8fdb87e6472d1a0169a3a5e0f2 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Mon, 27 Nov 2023 14:57:06 +0000 Subject: [PATCH 12/39] Complement chunks created and sent --- dorado/read_pipeline/ModBaseCallerNode.cpp | 231 +++++++++++---------- 1 file changed, 120 insertions(+), 111 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index a4f02541..a33d0a97 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -183,128 +183,137 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { read->read_common.mod_base_info = m_mod_base_info; - auto working_read = std::make_shared(); - working_read->num_modbase_chunks = 0; - working_read->num_modbase_chunks_called = 0; - - // all runners have the same set of callers, so we only need to use the first one - auto& runner = m_runners[0]; - std::vector>> chunks_to_enqueue_by_caller( - runner->num_callers()); - - for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { - nvtx3::scoped_range range{"generate_chunks"}; - - // Next - build the sig to seq map. - // What we need first is a new moves table for the template to the duplex read. - auto [moves_offset, target_start, new_move_table] = utils::realign_moves( - read->stereo_feature_inputs.template_seq, read->read_common.seq, - read->stereo_feature_inputs.template_moves); - - auto signal_len = new_move_table.size() * m_block_stride; - auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); - auto new_seq = read->read_common.seq.substr( - target_start, num_moves); // temporary- change 0 back to target_start? - std::vector sequence_ints = utils::sequence_to_ints(new_seq); - - std::vector seq_to_sig_map = - utils::moves_to_map(new_move_table, m_block_stride, signal_len, num_moves + 1); - auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); - auto& params = runner->caller_params(caller_id); - auto signal = read->stereo_feature_inputs.template_signal.slice( - 0, moves_offset * m_block_stride, moves_offset * m_block_stride + signal_len); - - if (read->read_common.read_id == - "fa3d4195-5ee1-4ab7-b048-9ce004292b62;dae07e1e-d2a9-44eb-8e8c-282491a977a9") { - serializeVector(new_move_table, "duplex_move_table.bin"); - torch::save(signal, "duplex_signal.pt"); - // Open a file in write mode - std::ofstream file("duplex_seq.txt"); - - // Write the string to the file - file << new_seq; - // Close the file - file.close(); - - serializeVector(read->stereo_feature_inputs.template_moves, "simplex_move_table.bin"); - torch::save(read->stereo_feature_inputs.template_signal, "simplex_signal.pt"); - // Open a file in write mode - std::ofstream sfile("simplex_seq.txt"); - - // Write the string to the file - sfile << read->stereo_feature_inputs - .template_seq; // TODO understnad why this is necessary - // Close the file - sfile.close(); - - std::cerr << "Found and serialised read of interest" << std::endl; - - std::cerr << std::endl; - std::cerr << new_seq.substr(0, 100) << std::endl; - std::cerr << read->read_common.seq.substr(0, 100) << std::endl; - std::cerr << "Found and serialised read of interest" << std::endl; - } - // scale signal based on model parameters - auto scaled_signal = runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); + { + auto working_read = std::make_shared(); + working_read->num_modbase_chunks = 0; + working_read->num_modbase_chunks_called = 0; - auto context_samples = (params.context_before + params.context_after); + // all runners have the same set of callers, so we only need to use the first one + auto& runner = m_runners[0]; + std::vector>> chunks_to_enqueue_by_caller( + runner->num_callers()); - // One-hot encodes the kmer at each signal step for input into the network - ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, - params.bases_after); - encoder.init(sequence_ints, seq_to_sig_map); + std::vector all_context_hits; - auto context_hits = runner->get_motif_hits(caller_id, new_seq); - m_num_context_hits += static_cast(context_hits.size()); - chunks_to_enqueue.reserve(context_hits.size()); - - for (auto context_hit : context_hits) { - nvtx3::scoped_range range{"create_chunk"}; - auto slice = encoder.get_context(context_hit); - // signal - auto input_signal = scaled_signal.index({at::indexing::Slice( - slice.first_sample, slice.first_sample + slice.num_samples)}); - if (slice.lead_samples_needed != 0 || slice.tail_samples_needed != 0) { - input_signal = at::constant_pad_nd( - input_signal, - {(int64_t)slice.lead_samples_needed, (int64_t)slice.tail_samples_needed}); + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + for (const auto& strand_type : {"template", "complement"}) { + nvtx3::scoped_range range{"generate_chunks"}; + + auto& simplex_moves = (strcmp(strand_type, "template") == 0) + ? read->stereo_feature_inputs.template_moves + : read->stereo_feature_inputs.complement_moves; + + auto& simplex_signal = (strcmp(strand_type, "template") == 0) + ? read->stereo_feature_inputs.template_signal + : read->stereo_feature_inputs.complement_signal; + + auto& simplex_seq = (strcmp(strand_type, "template") == 0) + ? read->stereo_feature_inputs.template_seq + : read->stereo_feature_inputs.complement_seq; + + auto duplex_seq = (strcmp(strand_type, "template") == 0) + ? read->read_common.seq + : utils::reverse_complement(read->read_common.seq); + + auto [moves_offset, target_start, new_move_table] = + utils::realign_moves(simplex_seq, duplex_seq, simplex_moves); + + auto signal_len = new_move_table.size() * m_block_stride; + auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); + auto new_seq = duplex_seq.substr(target_start, num_moves); + std::vector sequence_ints = utils::sequence_to_ints(new_seq); + + std::vector seq_to_sig_map = utils::moves_to_map( + new_move_table, m_block_stride, signal_len, num_moves + 1); + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + auto& params = runner->caller_params(caller_id); + auto signal = simplex_signal.slice(0, moves_offset * m_block_stride, + moves_offset * m_block_stride + signal_len); + + // scale signal based on model parameters + auto scaled_signal = + runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); + + auto context_samples = (params.context_before + params.context_after); + + // One-hot encodes the kmer at each signal step for input into the network + ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, + params.bases_after); + encoder.init(sequence_ints, seq_to_sig_map); + + auto context_hits = runner->get_motif_hits(caller_id, new_seq); + m_num_context_hits += static_cast(context_hits.size()); + chunks_to_enqueue.reserve(context_hits.size()); + + for (auto context_hit : context_hits) { + nvtx3::scoped_range range{"create_chunk"}; + auto slice = encoder.get_context(context_hit); + // signal + auto input_signal = scaled_signal.index({at::indexing::Slice( + slice.first_sample, slice.first_sample + slice.num_samples)}); + if (slice.lead_samples_needed != 0 || slice.tail_samples_needed != 0) { + input_signal = at::constant_pad_nd(input_signal, + {(int64_t)slice.lead_samples_needed, + (int64_t)slice.tail_samples_needed}); + } + + // Update the context hit into the duplex reference context + unsigned long context_hit_in_duplex_space; + if (std::strcmp(strand_type, "template") == 0) { + context_hit_in_duplex_space = context_hit + target_start; + } else { + //std::cerr<< strand_type << std::endl; + context_hit_in_duplex_space = + read->read_common.seq.size() - + (context_hit + target_start + + 1); // Sanity check: Need to check the plus 1, does it need to go somewhere else? + /* std:: cerr<< read->read_common.seq.size() << std::endl; + std::cerr << context_hit << std::endl; + std::cerr << target_start << std::endl; + std::cerr << "Checkpoint" << std::endl;*/ + } + + chunks_to_enqueue.push_back(std::make_unique( + working_read, input_signal, std::move(slice.data), + context_hit_in_duplex_space)); // TODO do we need to update the context hit here + + all_context_hits.push_back(context_hit_in_duplex_space); + ++working_read->num_modbase_chunks; + } + //std::cerr << "CH: " << all_context_hits << std::endl; + //std::cerr << "Read ID " << read->read_common.read_id << std::endl; + //std::cerr << "Context hits done! "<< std::endl; } - chunks_to_enqueue.push_back(std::make_unique( - working_read, input_signal, std::move(slice.data), - context_hit + target_start)); // TODO do we need to update the context hit here - - ++working_read->num_modbase_chunks; } - std::cerr << "Context hits done" << std::endl; - } - m_chunk_generation_ms += timer.GetElapsedMS(); + m_chunk_generation_ms += timer.GetElapsedMS(); - if (working_read->num_modbase_chunks != 0) { - // Hand over our ownership to the working read - working_read->read = std::move(read); + if (working_read->num_modbase_chunks != 0) { + // Hand over our ownership to the working read + working_read->read = std::move(read); - // Put the read in the working list - { - std::lock_guard working_reads_lock(m_working_reads_mutex); - m_working_reads.insert(std::move(working_read)); - ++m_working_reads_size; - } + // Put the read in the working list + { + std::lock_guard working_reads_lock(m_working_reads_mutex); + m_working_reads.insert(std::move(working_read)); + ++m_working_reads_size; + } - // push the chunks to the chunk queues - // needs to be done after working_read->read is set as chunks could be processed - // before we set that value otherwise - for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { - auto& chunk_queue = m_chunk_queues.at(caller_id); - auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); - for (auto& chunk : chunks_to_enqueue) { - chunk_queue->try_push(std::move(chunk)); + // push the chunks to the chunk queues + // needs to be done after working_read->read is set as chunks could be processed + // before we set that value otherwise + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + auto& chunk_queue = m_chunk_queues.at(caller_id); + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + for (auto& chunk : chunks_to_enqueue) { + chunk_queue->try_push(std::move(chunk)); + } } + } else { + // No modbases to call, pass directly to next node + send_message_to_sink(std::move(read)); + ++m_num_non_mod_base_reads_pushed; } - } else { - // No modbases to call, pass directly to next node - send_message_to_sink(std::move(read)); - ++m_num_non_mod_base_reads_pushed; } } From ba95a39c2cdfc7b2c5ceac81831bb0ed4b14a8a7 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Tue, 28 Nov 2023 11:21:32 +0000 Subject: [PATCH 13/39] Producing (incorrect) complement probabilities --- dorado/read_pipeline/ModBaseCallerNode.cpp | 13 ++---- dorado/read_pipeline/ReadPipeline.cpp | 49 +++++++++++++++++++++- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index a33d0a97..5fe91fbb 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -263,19 +263,14 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { context_hit_in_duplex_space = context_hit + target_start; } else { //std::cerr<< strand_type << std::endl; - context_hit_in_duplex_space = - read->read_common.seq.size() - - (context_hit + target_start + - 1); // Sanity check: Need to check the plus 1, does it need to go somewhere else? - /* std:: cerr<< read->read_common.seq.size() << std::endl; - std::cerr << context_hit << std::endl; - std::cerr << target_start << std::endl; - std::cerr << "Checkpoint" << std::endl;*/ + context_hit_in_duplex_space = read->read_common.seq.size() - + (context_hit + target_start + + 1); // TODO: Do I need a plus one here? Why? } chunks_to_enqueue.push_back(std::make_unique( working_read, input_signal, std::move(slice.data), - context_hit_in_duplex_space)); // TODO do we need to update the context hit here + context_hit_in_duplex_space)); all_context_hits.push_back(context_hit_in_duplex_space); ++working_read->num_modbase_chunks; diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index 7f5860b8..b2d94822 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -139,6 +140,9 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { return; } + //if(read_id == "fed86c0e-8ace-4e78-8f35-b525233de130;b2958b74-31c8-4aa4-91e1-db81d64e40e3") { // Temporary measuure - only mod call this read + std::cerr << "Found it" << std::endl; + const size_t num_channels = mod_base_info->alphabet.size(); const std::string cardinal_bases = "ACGT"; char current_cardinal = 0; @@ -171,13 +175,15 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { threshold); // Iterate over the provided alphabet and find all the channels we need to write out - for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { + for (size_t channel_idx = 0; channel_idx < num_channels; + channel_idx++) { // Loop over each channel. Writing out the if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { // A cardinal base current_cardinal = mod_base_info->alphabet[channel_idx][0]; } else { // A modification on the previous cardinal base std::string bam_name = mod_base_info->alphabet[channel_idx]; + std::cerr << bam_name << std::endl; if (!utils::validate_bam_tag_code(bam_name)) { return; } @@ -203,6 +209,47 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } } + if (read_id == + "fed86c0e-8ace-4e78-8f35-b525233de130;b2958b74-31c8-4aa4-91e1-" + "db81d64e40e3") { //temporary debug thing + // Now let's do the complement + for (size_t channel_idx = 0; channel_idx < num_channels; + channel_idx++) { // Loop over each channel. Writing out the + if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { + // A cardinal base + current_cardinal = mod_base_info->alphabet[channel_idx][0]; + } else { + if (current_cardinal == 'C') { + // A modification on the previous cardinal base + std::string bam_name = mod_base_info->alphabet[channel_idx]; + std::cerr << bam_name << std::endl; + if (!utils::validate_bam_tag_code(bam_name)) { + return; + } + + // Write out the results we found + modbase_string += std::string(1, 'G') + "-" + bam_name; //TODO need to RC + modbase_string += base_has_context[current_cardinal] ? "?" : "."; + int skipped_bases = 0; + for (size_t base_idx = 0; base_idx < seq.size(); base_idx++) { + if (seq[base_idx] == 'G') { // complement + if (true) { // Not sure this one is right + modbase_string += "," + std::to_string(skipped_bases); + skipped_bases = 0; + modbase_prob.push_back(base_mod_probs[base_idx * num_channels + + 4]); // Channel index 4 for G + } else { + // Skip this base + skipped_bases++; + } + } + } + modbase_string += ";"; + } + } + } + } + bam_aux_append(aln, "MM", 'Z', int(modbase_string.length() + 1), (uint8_t *)modbase_string.c_str()); bam_aux_update_array(aln, "ML", 'C', int(modbase_prob.size()), (uint8_t *)modbase_prob.data()); From 997e482aabfaabf5a03fc8a4e16cbab80dfffbc8 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Tue, 28 Nov 2023 23:39:53 +0000 Subject: [PATCH 14/39] Fixed complement modds, mod basecalling now working --- dorado/read_pipeline/ModBaseCallerNode.cpp | 71 +++++++++++++++++++--- dorado/read_pipeline/ReadPipeline.cpp | 6 +- dorado/utils/sequence_utils.cpp | 8 +-- dorado/utils/sequence_utils.h | 6 +- 4 files changed, 73 insertions(+), 18 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 5fe91fbb..9c9237e1 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -203,20 +203,31 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { ? read->stereo_feature_inputs.template_moves : read->stereo_feature_inputs.complement_moves; - auto& simplex_signal = (strcmp(strand_type, "template") == 0) - ? read->stereo_feature_inputs.template_signal - : read->stereo_feature_inputs.complement_signal; + auto simplex_signal = + (strcmp(strand_type, "template") == 0) + ? read->stereo_feature_inputs.template_signal + : at::flip(read->stereo_feature_inputs.complement_signal, 0); + ; - auto& simplex_seq = (strcmp(strand_type, "template") == 0) - ? read->stereo_feature_inputs.template_seq - : read->stereo_feature_inputs.complement_seq; + auto simplex_seq = (strcmp(strand_type, "template") == 0) + ? read->stereo_feature_inputs.template_seq + : utils::reverse_complement( + read->stereo_feature_inputs.complement_seq); auto duplex_seq = (strcmp(strand_type, "template") == 0) ? read->read_common.seq : utils::reverse_complement(read->read_common.seq); - auto [moves_offset, target_start, new_move_table] = - utils::realign_moves(simplex_seq, duplex_seq, simplex_moves); + auto [moves_offset, target_start, new_move_table, query_start] = + utils::realign_moves(simplex_seq, duplex_seq, + simplex_moves); // TODO: Check that this is OK + + std::cerr << std::endl; + std::cerr << "Read ID: " << read->read_common.read_id << std::endl; + std::cerr << "Target Start: " << target_start << std::endl; + std::cerr << "Query Start: " << query_start << std::endl; + std::cerr << "Duplex Seq: " << duplex_seq << std::endl; + std::cerr << "Simplx Seq: " << simplex_seq << std::endl; auto signal_len = new_move_table.size() * m_block_stride; auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); @@ -230,6 +241,37 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { auto signal = simplex_signal.slice(0, moves_offset * m_block_stride, moves_offset * m_block_stride + signal_len); + // Serialise some stuff for debugging + if (read->read_common.read_id == + "fa3d4195-5ee1-4ab7-b048-9ce004292b62;dae07e1e-d2a9-44eb-8e8c-282491a977a9") { + serializeVector(new_move_table, "duplex_move_table.bin"); + torch::save(signal, "duplex_signal.pt"); + // Open a file in write mode + std::ofstream file("duplex_seq.txt"); + + // Write the string to the file + file << new_seq; + // Close the file + file.close(); + + serializeVector(simplex_moves, "simplex_move_table.bin"); + torch::save(simplex_signal, "simplex_signal.pt"); + // Open a file in write mode + std::ofstream sfile("simplex_seq.txt"); + + // Write the string to the file + sfile << simplex_seq; // TODO understnad why this is necessary + // Close the file + sfile.close(); + + std::cerr << "Found and serialised read of interest" << std::endl; + + std::cerr << std::endl; + std::cerr << new_seq.substr(0, 100) << std::endl; + std::cerr << read->read_common.seq.substr(0, 100) << std::endl; + std::cerr << "Found and serialised read of interest" << std::endl; + } + // scale signal based on model parameters auto scaled_signal = runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); @@ -330,6 +372,11 @@ void ModBaseCallerNode::simplex_mod_call(Message message) { } read->read_common.mod_base_info = m_mod_base_info; + if (read->read_common.read_id == "dab67b24-8c0e-4f62-af16-9a946fd3f682") { + std::cerr << "Original seq:" << read->read_common.seq << std::endl; + std::cerr << std::endl; + } + auto working_read = std::make_shared(); working_read->num_modbase_chunks = 0; working_read->num_modbase_chunks_called = 0; @@ -555,9 +602,15 @@ void ModBaseCallerNode::output_worker_thread() { auto& source_read_common = get_read_common_data(source_read); int64_t result_pos = chunk->context_hit; + //TODO - this is just an experiment, roll it back + /* int64_t offset = m_base_prob_offsets [utils::BaseInfo::BASE_IDS[source_read_common.seq[result_pos]]]; - for (size_t i = 0; i < chunk->scores.size(); ++i) { +*/ + int64_t offset = m_base_prob_offsets[utils::BaseInfo::BASE_IDS['C']]; + + auto num_chunk_scores = chunk->scores.size(); + for (size_t i = 0; i < num_chunk_scores; ++i) { source_read_common.base_mod_probs[m_num_states * result_pos + offset + i] = static_cast(std::min(std::floor(chunk->scores[i] * 256), 255.0f)); } diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index b2d94822..caffacd3 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -236,8 +236,10 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { if (true) { // Not sure this one is right modbase_string += "," + std::to_string(skipped_bases); skipped_bases = 0; - modbase_prob.push_back(base_mod_probs[base_idx * num_channels + - 4]); // Channel index 4 for G + modbase_prob.push_back( + base_mod_probs[base_idx * num_channels + channel_idx]); + //modbase_prob.push_back(base_mod_probs[base_idx * num_channels + + // 4]); // Channel index 4 for G } else { // Skip this base skipped_bases++; diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 15654dd1..b8c496af 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -236,9 +236,9 @@ OverlapResult compute_overlap(std::string query_seq, std::string target_seq) { // Query is the read that the moves table is associated with. A new moves table will be generated // Which is aligned to the target sequence. -std::tuple> realign_moves(std::string query_sequence, - std::string target_sequence, - std::vector moves) { +std::tuple, int> realign_moves(std::string query_sequence, + std::string target_sequence, + std::vector moves) { auto [is_overlap, query_start, query_end, target_start, target_end] = compute_overlap( query_sequence, target_sequence); // We are going to compute the overlap between the two reads @@ -338,7 +338,7 @@ std::tuple> realign_moves(std::string query_seque // 3. Moves end // 3. Target sequence end - return {old_moves_offset, target_start - 1, new_moves}; + return {old_moves_offset, target_start - 1, new_moves, query_start}; } std::vector move_cum_sums(const std::vector& moves) { diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index 0ed1b69f..e22cc1ba 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -44,7 +44,7 @@ class BaseInfo { int count_trailing_chars(const std::string_view adapter, char c); -std::tuple> realign_moves(std::string query_sequence, - std::string target_sequence, - std::vector moves); +std::tuple, int> realign_moves(std::string query_sequence, + std::string target_sequence, + std::vector moves); } // namespace dorado::utils From 15e9dbb22076a232a686afff868a26ed7a711274 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 29 Nov 2023 10:54:23 +0000 Subject: [PATCH 15/39] Fixed complement modds, mod basecalling now working --- dorado/read_pipeline/ReadPipeline.cpp | 67 +++++++++++++-------------- 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index caffacd3..67e1bbf0 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -209,49 +209,44 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } } - if (read_id == - "fed86c0e-8ace-4e78-8f35-b525233de130;b2958b74-31c8-4aa4-91e1-" - "db81d64e40e3") { //temporary debug thing - // Now let's do the complement - for (size_t channel_idx = 0; channel_idx < num_channels; - channel_idx++) { // Loop over each channel. Writing out the - if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { - // A cardinal base - current_cardinal = mod_base_info->alphabet[channel_idx][0]; - } else { - if (current_cardinal == 'C') { - // A modification on the previous cardinal base - std::string bam_name = mod_base_info->alphabet[channel_idx]; - std::cerr << bam_name << std::endl; - if (!utils::validate_bam_tag_code(bam_name)) { - return; - } + // Now let's do the complement + for (size_t channel_idx = 0; channel_idx < num_channels; + channel_idx++) { // Loop over each channel. Writing out the + if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { + // A cardinal base + current_cardinal = mod_base_info->alphabet[channel_idx][0]; + } else { + if (current_cardinal == 'C') { + // A modification on the previous cardinal base + std::string bam_name = mod_base_info->alphabet[channel_idx]; + std::cerr << bam_name << std::endl; + if (!utils::validate_bam_tag_code(bam_name)) { + return; + } - // Write out the results we found - modbase_string += std::string(1, 'G') + "-" + bam_name; //TODO need to RC - modbase_string += base_has_context[current_cardinal] ? "?" : "."; - int skipped_bases = 0; - for (size_t base_idx = 0; base_idx < seq.size(); base_idx++) { - if (seq[base_idx] == 'G') { // complement - if (true) { // Not sure this one is right - modbase_string += "," + std::to_string(skipped_bases); - skipped_bases = 0; - modbase_prob.push_back( - base_mod_probs[base_idx * num_channels + channel_idx]); - //modbase_prob.push_back(base_mod_probs[base_idx * num_channels + - // 4]); // Channel index 4 for G - } else { - // Skip this base - skipped_bases++; - } + // Write out the results we found + modbase_string += std::string(1, 'G') + "-" + bam_name; //TODO need to RC + modbase_string += base_has_context[current_cardinal] ? "?" : "."; + int skipped_bases = 0; + for (size_t base_idx = 0; base_idx < seq.size(); base_idx++) { + if (seq[base_idx] == 'G') { // complement + if (true) { // Not sure this one is right + modbase_string += "," + std::to_string(skipped_bases); + skipped_bases = 0; + modbase_prob.push_back( + base_mod_probs[base_idx * num_channels + channel_idx]); + //modbase_prob.push_back(base_mod_probs[base_idx * num_channels + + // 4]); // Channel index 4 for G + } else { + // Skip this base + skipped_bases++; } } - modbase_string += ";"; } + modbase_string += ";"; } } } - bam_aux_append(aln, "MM", 'Z', int(modbase_string.length() + 1), (uint8_t *)modbase_string.c_str()); bam_aux_update_array(aln, "ML", 'C', int(modbase_prob.size()), (uint8_t *)modbase_prob.data()); From 31826b57b244aecf716eb3c535736d41c26fef84 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 29 Nov 2023 13:35:52 +0000 Subject: [PATCH 16/39] Improved false +ves with motif matching in reverse complement --- dorado/read_pipeline/ReadPipeline.cpp | 53 ++++++++++++++++++++------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index 67e1bbf0..d7a233cc 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -159,6 +159,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { std::unordered_map base_has_context = { {'A', false}, {'C', false}, {'G', false}, {'T', false}}; utils::ModBaseContext context_handler; + std::cerr << mod_base_info->context << std::endl; if (!mod_base_info->context.empty()) { if (!context_handler.decode(mod_base_info->context)) { throw std::runtime_error("Invalid base modification context string."); @@ -174,6 +175,32 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { context_handler.update_mask(modbase_mask, seq, mod_base_info->alphabet, base_mod_probs, threshold); + if (is_duplex) { + auto reverse_complemented_seq = utils::reverse_complement(seq); + + // Compute the reverse complement mask + auto modbase_mask_rc = context_handler.get_sequence_mask(reverse_complemented_seq); + + // Update the context mask using the reversed sequence + context_handler.update_mask(modbase_mask_rc, reverse_complemented_seq, + mod_base_info->alphabet, base_mod_probs, + 0); // TODO: Setting threshold to zero as a temporary measure + + // Reverse the mask in-place + std::reverse(modbase_mask_rc.begin(), modbase_mask_rc.end()); + + // Combine the original and reverse complement masks + // Using std::transform for better readability and potential efficiency + std::transform(modbase_mask.begin(), modbase_mask.end(), modbase_mask_rc.begin(), + modbase_mask.begin(), std::plus<>()); + } + + std::map nucleotide_complements; + nucleotide_complements['A'] = 'T'; + nucleotide_complements['T'] = 'A'; + nucleotide_complements['C'] = 'G'; + nucleotide_complements['G'] = 'C'; + // Iterate over the provided alphabet and find all the channels we need to write out for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { // Loop over each channel. Writing out the @@ -209,14 +236,15 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } } - // Now let's do the complement - for (size_t channel_idx = 0; channel_idx < num_channels; - channel_idx++) { // Loop over each channel. Writing out the - if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { - // A cardinal base - current_cardinal = mod_base_info->alphabet[channel_idx][0]; - } else { - if (current_cardinal == 'C') { + if (is_duplex) { + // Now let's do the complement + for (size_t channel_idx = 0; channel_idx < num_channels; + channel_idx++) { // Loop over each channel. Writing out the + if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { + // A cardinal base + current_cardinal = mod_base_info->alphabet[channel_idx][0]; + } else { + auto cardinal_complement = nucleotide_complements[current_cardinal]; // A modification on the previous cardinal base std::string bam_name = mod_base_info->alphabet[channel_idx]; std::cerr << bam_name << std::endl; @@ -225,18 +253,17 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } // Write out the results we found - modbase_string += std::string(1, 'G') + "-" + bam_name; //TODO need to RC + modbase_string += + std::string(1, cardinal_complement) + "-" + bam_name; //TODO need to RC modbase_string += base_has_context[current_cardinal] ? "?" : "."; int skipped_bases = 0; for (size_t base_idx = 0; base_idx < seq.size(); base_idx++) { - if (seq[base_idx] == 'G') { // complement - if (true) { // Not sure this one is right + if (seq[base_idx] == cardinal_complement) { // complement + if (modbase_mask[base_idx]) { // Not sure this one is right modbase_string += "," + std::to_string(skipped_bases); skipped_bases = 0; modbase_prob.push_back( base_mod_probs[base_idx * num_channels + channel_idx]); - //modbase_prob.push_back(base_mod_probs[base_idx * num_channels + - // 4]); // Channel index 4 for G } else { // Skip this base skipped_bases++; From 621bc5c3112194bcfb8acf638398a0087f8b8275 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 29 Nov 2023 14:37:47 +0000 Subject: [PATCH 17/39] Remove some no-longer-needed debug code --- dorado/read_pipeline/ModBaseCallerNode.cpp | 40 ---------------------- dorado/read_pipeline/ReadPipeline.cpp | 6 ---- 2 files changed, 46 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 9c9237e1..f4a905b0 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -241,37 +241,6 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { auto signal = simplex_signal.slice(0, moves_offset * m_block_stride, moves_offset * m_block_stride + signal_len); - // Serialise some stuff for debugging - if (read->read_common.read_id == - "fa3d4195-5ee1-4ab7-b048-9ce004292b62;dae07e1e-d2a9-44eb-8e8c-282491a977a9") { - serializeVector(new_move_table, "duplex_move_table.bin"); - torch::save(signal, "duplex_signal.pt"); - // Open a file in write mode - std::ofstream file("duplex_seq.txt"); - - // Write the string to the file - file << new_seq; - // Close the file - file.close(); - - serializeVector(simplex_moves, "simplex_move_table.bin"); - torch::save(simplex_signal, "simplex_signal.pt"); - // Open a file in write mode - std::ofstream sfile("simplex_seq.txt"); - - // Write the string to the file - sfile << simplex_seq; // TODO understnad why this is necessary - // Close the file - sfile.close(); - - std::cerr << "Found and serialised read of interest" << std::endl; - - std::cerr << std::endl; - std::cerr << new_seq.substr(0, 100) << std::endl; - std::cerr << read->read_common.seq.substr(0, 100) << std::endl; - std::cerr << "Found and serialised read of interest" << std::endl; - } - // scale signal based on model parameters auto scaled_signal = runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); @@ -304,7 +273,6 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { if (std::strcmp(strand_type, "template") == 0) { context_hit_in_duplex_space = context_hit + target_start; } else { - //std::cerr<< strand_type << std::endl; context_hit_in_duplex_space = read->read_common.seq.size() - (context_hit + target_start + 1); // TODO: Do I need a plus one here? Why? @@ -317,9 +285,6 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { all_context_hits.push_back(context_hit_in_duplex_space); ++working_read->num_modbase_chunks; } - //std::cerr << "CH: " << all_context_hits << std::endl; - //std::cerr << "Read ID " << read->read_common.read_id << std::endl; - //std::cerr << "Context hits done! "<< std::endl; } } @@ -372,11 +337,6 @@ void ModBaseCallerNode::simplex_mod_call(Message message) { } read->read_common.mod_base_info = m_mod_base_info; - if (read->read_common.read_id == "dab67b24-8c0e-4f62-af16-9a946fd3f682") { - std::cerr << "Original seq:" << read->read_common.seq << std::endl; - std::cerr << std::endl; - } - auto working_read = std::make_shared(); working_read->num_modbase_chunks = 0; working_read->num_modbase_chunks_called = 0; diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index d7a233cc..bf9cb489 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -140,9 +140,6 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { return; } - //if(read_id == "fed86c0e-8ace-4e78-8f35-b525233de130;b2958b74-31c8-4aa4-91e1-db81d64e40e3") { // Temporary measuure - only mod call this read - std::cerr << "Found it" << std::endl; - const size_t num_channels = mod_base_info->alphabet.size(); const std::string cardinal_bases = "ACGT"; char current_cardinal = 0; @@ -159,7 +156,6 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { std::unordered_map base_has_context = { {'A', false}, {'C', false}, {'G', false}, {'T', false}}; utils::ModBaseContext context_handler; - std::cerr << mod_base_info->context << std::endl; if (!mod_base_info->context.empty()) { if (!context_handler.decode(mod_base_info->context)) { throw std::runtime_error("Invalid base modification context string."); @@ -210,7 +206,6 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } else { // A modification on the previous cardinal base std::string bam_name = mod_base_info->alphabet[channel_idx]; - std::cerr << bam_name << std::endl; if (!utils::validate_bam_tag_code(bam_name)) { return; } @@ -247,7 +242,6 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { auto cardinal_complement = nucleotide_complements[current_cardinal]; // A modification on the previous cardinal base std::string bam_name = mod_base_info->alphabet[channel_idx]; - std::cerr << bam_name << std::endl; if (!utils::validate_bam_tag_code(bam_name)) { return; } From a6d8245e7dceb41ee028e35ec4002ae466589d7a Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 29 Nov 2023 14:52:31 +0000 Subject: [PATCH 18/39] Remove some no-longer-needed debug code --- dorado/read_pipeline/ModBaseCallerNode.cpp | 8 -------- dorado/read_pipeline/ReadPipeline.cpp | 8 ++++---- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index f4a905b0..958123a3 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -166,7 +166,6 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { auto read = std::get(std::move(message)); stats::Timer timer; - // TODO: Does `base_mod_probs` need to be the same size as `new_seq`? { nvtx3::scoped_range range{"base_mod_probs_init"}; // initialize base_mod_probs _before_ we start handing out chunks @@ -222,13 +221,6 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { utils::realign_moves(simplex_seq, duplex_seq, simplex_moves); // TODO: Check that this is OK - std::cerr << std::endl; - std::cerr << "Read ID: " << read->read_common.read_id << std::endl; - std::cerr << "Target Start: " << target_start << std::endl; - std::cerr << "Query Start: " << query_start << std::endl; - std::cerr << "Duplex Seq: " << duplex_seq << std::endl; - std::cerr << "Simplx Seq: " << simplex_seq << std::endl; - auto signal_len = new_move_table.size() * m_block_stride; auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); auto new_seq = duplex_seq.substr(target_start, num_moves); diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index bf9cb489..d3c34bdb 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -232,7 +232,8 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } if (is_duplex) { - // Now let's do the complement + // Having done the strand in the forward direction, if the read is duplex we need to also process its complement + // There is some code repetition here, but it makes it more readable. for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { // Loop over each channel. Writing out the if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { @@ -246,9 +247,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { return; } - // Write out the results we found - modbase_string += - std::string(1, cardinal_complement) + "-" + bam_name; //TODO need to RC + modbase_string += std::string(1, cardinal_complement) + "-" + bam_name; modbase_string += base_has_context[current_cardinal] ? "?" : "."; int skipped_bases = 0; for (size_t base_idx = 0; base_idx < seq.size(); base_idx++) { @@ -268,6 +267,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } } } + bam_aux_append(aln, "MM", 'Z', int(modbase_string.length() + 1), (uint8_t *)modbase_string.c_str()); bam_aux_update_array(aln, "ML", 'C', int(modbase_prob.size()), (uint8_t *)modbase_prob.data()); From d5e306807227a05ac28cd8457920846e173ee4b1 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 29 Nov 2023 17:03:52 +0000 Subject: [PATCH 19/39] Generic handling of non-modC probabilities --- dorado/read_pipeline/ModBaseCallerNode.cpp | 37 ++++++++++++++++------ dorado/utils/sequence_utils.cpp | 13 +------- dorado/utils/sequence_utils.h | 16 ++++++++++ 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 958123a3..65f2d034 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -28,17 +28,20 @@ struct ModBaseCallerNode::RemoraChunk { RemoraChunk(std::shared_ptr read, at::Tensor input_signal, std::vector kmer_data, - size_t position) + size_t position, + bool is_template_direction) : working_read(std::move(read)), signal(std::move(input_signal)), encoded_kmers(std::move(kmer_data)), - context_hit(position) {} + context_hit(position), + is_template_direction(is_template_direction) {} std::shared_ptr working_read; at::Tensor signal; std::vector encoded_kmers; size_t context_hit; std::vector scores; + bool is_template_direction; }; struct ModBaseCallerNode::WorkingRead { @@ -198,22 +201,23 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { for (const auto& strand_type : {"template", "complement"}) { nvtx3::scoped_range range{"generate_chunks"}; - auto& simplex_moves = (strcmp(strand_type, "template") == 0) + bool is_template_direction = (strcmp(strand_type, "template") == 0); + auto& simplex_moves = is_template_direction ? read->stereo_feature_inputs.template_moves : read->stereo_feature_inputs.complement_moves; auto simplex_signal = - (strcmp(strand_type, "template") == 0) + is_template_direction ? read->stereo_feature_inputs.template_signal : at::flip(read->stereo_feature_inputs.complement_signal, 0); ; - auto simplex_seq = (strcmp(strand_type, "template") == 0) + auto simplex_seq = is_template_direction ? read->stereo_feature_inputs.template_seq : utils::reverse_complement( read->stereo_feature_inputs.complement_seq); - auto duplex_seq = (strcmp(strand_type, "template") == 0) + auto duplex_seq = is_template_direction ? read->read_common.seq : utils::reverse_complement(read->read_common.seq); @@ -262,7 +266,7 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { // Update the context hit into the duplex reference context unsigned long context_hit_in_duplex_space; - if (std::strcmp(strand_type, "template") == 0) { + if (is_template_direction) { context_hit_in_duplex_space = context_hit + target_start; } else { context_hit_in_duplex_space = read->read_common.seq.size() - @@ -272,7 +276,7 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { chunks_to_enqueue.push_back(std::make_unique( working_read, input_signal, std::move(slice.data), - context_hit_in_duplex_space)); + context_hit_in_duplex_space, is_template_direction)); all_context_hits.push_back(context_hit_in_duplex_space); ++working_read->num_modbase_chunks; @@ -383,7 +387,7 @@ void ModBaseCallerNode::simplex_mod_call(Message message) { {(int64_t)slice.lead_samples_needed, (int64_t)slice.tail_samples_needed}); } chunks_to_enqueue.push_back(std::make_unique( - working_read, input_signal, std::move(slice.data), context_hit)); + working_read, input_signal, std::move(slice.data), context_hit, true)); ++working_read->num_modbase_chunks; } @@ -554,12 +558,25 @@ void ModBaseCallerNode::output_worker_thread() { auto& source_read_common = get_read_common_data(source_read); int64_t result_pos = chunk->context_hit; + + int64_t offset; + + if (chunk->is_template_direction) { + offset = m_base_prob_offsets + [utils::BaseInfo::BASE_IDS[source_read_common.seq[result_pos]]]; + } else { + //Offset into mod base probabilties is based on the complement of the base + offset = m_base_prob_offsets + [utils::BaseInfo::BASE_IDS[dorado::utils::complement_table + [source_read_common.seq[result_pos]]]]; + } + //TODO - this is just an experiment, roll it back /* int64_t offset = m_base_prob_offsets [utils::BaseInfo::BASE_IDS[source_read_common.seq[result_pos]]]; */ - int64_t offset = m_base_prob_offsets[utils::BaseInfo::BASE_IDS['C']]; + //int64_t offset = m_base_prob_offsets[utils::BaseInfo::BASE_IDS['C']]; auto num_chunk_scores = chunk->scores.size(); for (size_t i = 0; i < num_chunk_scores; ++i) { diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index b8c496af..4a4dbe3e 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -30,23 +30,12 @@ reverse_complement_impl(const std::string& sequence) { std::string rev_comp_sequence; rev_comp_sequence.resize(num_bases); - // Compile-time constant lookup table. - static constexpr auto kComplementTable = [] { - std::array a{}; - // Valid input will only touch the entries set here. - a['A'] = 'T'; - a['T'] = 'A'; - a['C'] = 'G'; - a['G'] = 'C'; - return a; - }(); - // Run every template base through the table, reading in reverse order. const char* template_ptr = &sequence[num_bases - 1]; char* complement_ptr = &rev_comp_sequence[0]; for (size_t i = 0; i < num_bases; ++i) { const auto template_base = *template_ptr--; - *complement_ptr++ = kComplementTable[template_base]; + *complement_ptr++ = dorado::utils::complement_table[template_base]; } return rev_comp_sequence; } diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index e22cc1ba..faef822a 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -47,4 +47,20 @@ int count_trailing_chars(const std::string_view adapter, char c); std::tuple, int> realign_moves(std::string query_sequence, std::string target_sequence, std::vector moves); + +// Compile-time constant lookup table. +static constexpr auto complement_table = [] { + std::array a{}; + // Valid input will only touch the entries set here. + a['A'] = 'T'; + a['T'] = 'A'; + a['C'] = 'G'; + a['G'] = 'C'; + a['a'] = 't'; + a['t'] = 'a'; + a['c'] = 'g'; + a['g'] = 'c'; + return a; +}(); + } // namespace dorado::utils From 1f05ccdd102f10dc4e9bff326b607cb78ce1fef9 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 29 Nov 2023 17:14:01 +0000 Subject: [PATCH 20/39] Generic handling of non-modC probabiliti --- dorado/read_pipeline/ModBaseCallerNode.cpp | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 65f2d034..eae81831 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -560,23 +560,12 @@ void ModBaseCallerNode::output_worker_thread() { int64_t result_pos = chunk->context_hit; int64_t offset; + const auto& baseIds = utils::BaseInfo::BASE_IDS; + const auto& seq = source_read_common.seq[result_pos]; - if (chunk->is_template_direction) { - offset = m_base_prob_offsets - [utils::BaseInfo::BASE_IDS[source_read_common.seq[result_pos]]]; - } else { - //Offset into mod base probabilties is based on the complement of the base - offset = m_base_prob_offsets - [utils::BaseInfo::BASE_IDS[dorado::utils::complement_table - [source_read_common.seq[result_pos]]]]; - } - - //TODO - this is just an experiment, roll it back - /* - int64_t offset = m_base_prob_offsets - [utils::BaseInfo::BASE_IDS[source_read_common.seq[result_pos]]]; -*/ - //int64_t offset = m_base_prob_offsets[utils::BaseInfo::BASE_IDS['C']]; + offset = chunk->is_template_direction + ? m_base_prob_offsets[baseIds[seq]] + : m_base_prob_offsets[baseIds[dorado::utils::complement_table[seq]]]; auto num_chunk_scores = chunk->scores.size(); for (size_t i = 0; i < num_chunk_scores; ++i) { From f455fad2e76e89db87007de5f6a19ce7ef3ea789 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Wed, 29 Nov 2023 17:32:33 +0000 Subject: [PATCH 21/39] Re-enabled warnings as errors --- CMakeLists.txt | 6 +++--- dorado/utils/CMakeLists.txt | 2 +- dorado/utils/sequence_utils.cpp | 24 +++--------------------- 3 files changed, 7 insertions(+), 25 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a6dfa029..15de9d1a 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -277,7 +277,7 @@ endif() add_library(dorado_lib ${LIB_SOURCE_FILES}) -#enable_warnings_as_errors(dorado_lib) +enable_warnings_as_errors(dorado_lib) set_target_properties(dorado_lib PROPERTIES @@ -374,7 +374,7 @@ if(NOT DORADO_LIB_ONLY) set_target_properties(dorado PROPERTIES LINK_OPTIONS "/ignore:4099") endif() - #enable_warnings_as_errors(dorado) + enable_warnings_as_errors(dorado) if (DORADO_ENABLE_PCH) target_precompile_headers(dorado REUSE_FROM dorado_lib) @@ -515,7 +515,7 @@ if(NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" ) - #enable_warnings_as_errors(dorado_io_lib) + enable_warnings_as_errors(dorado_io_lib) if (NOT DORADO_LIB_ONLY) add_subdirectory(tests) diff --git a/dorado/utils/CMakeLists.txt b/dorado/utils/CMakeLists.txt index 155f4895..49a0a979 100644 --- a/dorado/utils/CMakeLists.txt +++ b/dorado/utils/CMakeLists.txt @@ -120,4 +120,4 @@ if (ECM_ENABLE_SANITIZERS AND (CMAKE_CXX_COMPILER_ID MATCHES "GNU") AND (CMAKE_C set_source_files_properties(duplex_utils.cpp PROPERTIES COMPILE_OPTIONS "-O0") endif() -#enable_warnings_as_errors(dorado_utils) \ No newline at end of file +enable_warnings_as_errors(dorado_utils) \ No newline at end of file diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 4a4dbe3e..33a37df7 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -260,17 +260,11 @@ std::tuple, int> realign_moves(std::string query_ // Let's keep two cursor positions - one for the new move table and one for the old: int new_move_cursor = 0; - int old_move_cursor = - 0; // Need to update to be the query start. // TODO do we need to worry about the start and end locations. - // Let's keep two cursor positions - one for the query sequence and one for the target: - int query_seq_cursor = query_start; - int target_seq_cursor = target_start; + int old_move_cursor = 0; int moves_found = 0; - while (moves_found < moves.size() && - moves_found < - query_start) { // TODO - is "query start" zero indexed? need to think about that + while (moves_found < moves.size() && moves_found < query_start) { moves_found += moves[old_move_cursor]; ++old_move_cursor; } @@ -282,7 +276,7 @@ std::tuple, int> realign_moves(std::string query_ if ((alignment_entry == 0) || (alignment_entry == 3)) { //Match or mismatch, need to update the new move table and move the cursor of the old move table. - new_moves.push_back(1); // We have a match so we need a 1 + new_moves.push_back(1); // We have a match so we need a 1 (move) new_move_cursor++; old_move_cursor++; @@ -294,16 +288,12 @@ std::tuple, int> realign_moves(std::string query_ new_move_cursor++; old_move_cursor++; } - // Unless there's a new/old mismatch - in which case we need to catch up by adding 1s. TODO this later. } // Update the Query and target seq cursors - query_seq_cursor++; - target_seq_cursor++; } else if (alignment_entry == 1) { //Insertion to target // If we have an insertion in the target, we need to add a 1 to the new move table, and increment the new move table cursor. the old move table cursor and new are now out of sync and need fixing. new_moves.push_back(1); new_move_cursor++; - target_seq_cursor++; } else if (alignment_entry == 2) { //Insertion to Query // We have a query insertion, all we need to do is add zeros to the new move table to make it up, the signal can be assigned to the leftmost nucleotide in the sequence. new_moves.push_back(0); @@ -314,19 +304,11 @@ std::tuple, int> realign_moves(std::string query_ old_move_cursor++; new_move_cursor++; } - // Update the Query and target seq cursors - query_seq_cursor++; } } edlibFreeAlignResult(edlib_result); - // Need to return: - // 1. Moves start - // 2. Target sequence Start - // 3. Moves end - // 3. Target sequence end - return {old_moves_offset, target_start - 1, new_moves, query_start}; } From 014d98ab24dae30a9e1462a258f70cdfb0675164 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 07:46:50 +0000 Subject: [PATCH 22/39] missing include --- dorado/read_pipeline/ModBaseCallerNode.cpp | 3 +-- dorado/utils/sequence_utils.cpp | 1 - dorado/utils/sequence_utils.h | 1 + 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 63f68cc6..152b9ad6 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -222,8 +222,7 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { : utils::reverse_complement(read->read_common.seq); auto [moves_offset, target_start, new_move_table, query_start] = - utils::realign_moves(simplex_seq, duplex_seq, - simplex_moves); // TODO: Check that this is OK + utils::realign_moves(simplex_seq, duplex_seq, simplex_moves); auto signal_len = new_move_table.size() * m_block_stride; auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 14f67d5b..c481aacb 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index 9954f684..1002a642 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include From ba7a3ba5f06fd9febb2fad88453c5014351ed293 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 07:56:56 +0000 Subject: [PATCH 23/39] missing include --- dorado/utils/sequence_utils.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index c481aacb..06cb143c 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include From 2c50cbd35f24151db8f9c25aaf05ac6d48c2c8b0 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 08:22:04 +0000 Subject: [PATCH 24/39] Fix windows-only compilation warnings --- dorado/read_pipeline/ModBaseCallerNode.cpp | 25 +++++----------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 152b9ad6..1af2344f 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -148,21 +148,6 @@ void ModBaseCallerNode::init_modbase_info() { m_base_prob_offsets[3] = m_base_prob_offsets[2] + result.base_counts[2]; } -void serializeVector(const std::vector& vec, const std::string& filename) { - // Open a file in binary mode - std::ofstream file(filename, std::ios::binary); - - // Write the size of the vector (number of elements) - long size = vec.size(); - file.write(reinterpret_cast(&size), sizeof(long)); - - // Write the vector data - file.write(reinterpret_cast(vec.data()), size * sizeof(unsigned char)); - - // Close the file - file.close(); -} - void ModBaseCallerNode::duplex_mod_call(Message message) { // Let's do this only for the template strand for now. @@ -266,11 +251,11 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { // Update the context hit into the duplex reference context unsigned long context_hit_in_duplex_space; if (is_template_direction) { - context_hit_in_duplex_space = context_hit + target_start; + context_hit_in_duplex_space = + static_cast(context_hit + target_start); } else { - context_hit_in_duplex_space = read->read_common.seq.size() - - (context_hit + target_start + - 1); // TODO: Do I need a plus one here? Why? + context_hit_in_duplex_space = static_cast( + read->read_common.seq.size() - (context_hit + target_start + 1)); } chunks_to_enqueue.push_back(std::make_unique( @@ -375,7 +360,7 @@ void ModBaseCallerNode::simplex_mod_call(Message message) { m_num_context_hits += static_cast(context_hits.size()); chunks_to_enqueue.reserve(context_hits.size()); for (auto context_hit : context_hits) { - nvtx3::scoped_range range{"create_chunk"}; + nvtx3::scoped_range nvtxrange{"create_chunk"}; auto slice = encoder.get_context(context_hit); // signal auto input_signal = scaled_signal.index({at::indexing::Slice( From 85627f4bd95a03d7ff0a5b95ca1747edb379c2cd Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 09:34:23 +0000 Subject: [PATCH 25/39] Reverse modbase probs when working in complement space --- dorado/read_pipeline/ReadPipeline.cpp | 31 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index 66c4d7fe..ca494e1e 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -175,15 +175,32 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { threshold); if (is_duplex) { + // If this is a duplex read, we need to compute the reverse complement mask and combine it auto reverse_complemented_seq = utils::reverse_complement(seq); // Compute the reverse complement mask auto modbase_mask_rc = context_handler.get_sequence_mask(reverse_complemented_seq); + auto reverseMatrix = [](const std::vector &matrix, int m_num_states) { + int numRows = matrix.size() / m_num_states; + std::vector reversedMatrix(matrix.size()); + + for (int i = 0; i < numRows; ++i) { + for (int j = 0; j < m_num_states; ++j) { + reversedMatrix[i * m_num_states + j] = + matrix[(numRows - 1 - i) * m_num_states + j]; + } + } + + return reversedMatrix; + }; + + int num_states = base_mod_probs.size() / seq.size(); // Update the context mask using the reversed sequence - context_handler.update_mask(modbase_mask_rc, reverse_complemented_seq, - mod_base_info->alphabet, base_mod_probs, - 0); // TODO: Setting threshold to zero as a temporary measure + context_handler.update_mask( + modbase_mask_rc, reverse_complemented_seq, mod_base_info->alphabet, + reverseMatrix(base_mod_probs, num_states), + threshold); // TODO: Setting threshold to zero as a temporary measure // Reverse the mask in-place std::reverse(modbase_mask_rc.begin(), modbase_mask_rc.end()); @@ -194,12 +211,6 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { modbase_mask.begin(), std::plus<>()); } - std::map nucleotide_complements; - nucleotide_complements['A'] = 'T'; - nucleotide_complements['T'] = 'A'; - nucleotide_complements['C'] = 'G'; - nucleotide_complements['G'] = 'C'; - // Iterate over the provided alphabet and find all the channels we need to write out for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { // Loop over each channel. Writing out the @@ -243,7 +254,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { // A cardinal base current_cardinal = mod_base_info->alphabet[channel_idx][0]; } else { - auto cardinal_complement = nucleotide_complements[current_cardinal]; + auto cardinal_complement = utils::complement_table[current_cardinal]; // A modification on the previous cardinal base std::string bam_name = mod_base_info->alphabet[channel_idx]; if (!utils::validate_bam_tag_code(bam_name)) { From 6158a71014c689e2dc097bbee8676c273683169b Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 12:21:09 +0000 Subject: [PATCH 26/39] Fix warning as error --- dorado/read_pipeline/ReadPipeline.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index ca494e1e..7f03ef30 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -182,7 +182,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { auto modbase_mask_rc = context_handler.get_sequence_mask(reverse_complemented_seq); auto reverseMatrix = [](const std::vector &matrix, int m_num_states) { - int numRows = matrix.size() / m_num_states; + int numRows = static_cast(matrix.size()) / static_cast(m_num_states); std::vector reversedMatrix(matrix.size()); for (int i = 0; i < numRows; ++i) { @@ -195,7 +195,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { return reversedMatrix; }; - int num_states = base_mod_probs.size() / seq.size(); + int num_states = static_cast(base_mod_probs.size()) / static_cast(seq.size()); // Update the context mask using the reversed sequence context_handler.update_mask( modbase_mask_rc, reverse_complemented_seq, mod_base_info->alphabet, From 8920f72d16fa7d36840aaffd114bcb7cdd212b6b Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 12:55:29 +0000 Subject: [PATCH 27/39] [skip-ci] Remove uncessary comment --- dorado/read_pipeline/ReadPipeline.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index 7f03ef30..d0c6b2e8 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -248,8 +248,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { if (is_duplex) { // Having done the strand in the forward direction, if the read is duplex we need to also process its complement // There is some code repetition here, but it makes it more readable. - for (size_t channel_idx = 0; channel_idx < num_channels; - channel_idx++) { // Loop over each channel. Writing out the + for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { // A cardinal base current_cardinal = mod_base_info->alphabet[channel_idx][0]; From d41572ba5b6c676f74aa35ec4bd27c3bc4f212e6 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 13:03:26 +0000 Subject: [PATCH 28/39] [skip-ci] Addressing PR comment --- dorado/utils/sequence_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 06cb143c..0aba2732 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -169,7 +169,7 @@ std::vector moves_to_map(const std::vector& moves, return seq_to_sig_map; } -OverlapResult compute_overlap(std::string query_seq, std::string target_seq) { +OverlapResult compute_overlap(const std::string& query_seq, const std::string& target_seq) { OverlapResult overlap_result = {false, 0, 0, 0, 0}; // Add mm2 based overlap check. From 4cc591dbb8747636df86ae8f2641e8ad3304a11c Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 13:08:01 +0000 Subject: [PATCH 29/39] [skip-ci] Addressing PR comment --- dorado/utils/sequence_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index 1002a642..cab1f7bf 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -31,7 +31,7 @@ std::vector move_cum_sums(const std::vector& moves); // Result of overlapping two reads using OverlapResult = std::tuple; -OverlapResult compute_overlap(std::string query_seq, std::string target_seq); +OverlapResult compute_overlap(const std::string& query_seq, const std::string& target_seq); // Compute reverse complement of a nucleotide sequence. // Bases are specified as capital letters. From d2e7f66850d28368a1dd147a4ef49f2861a3f348 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 13:09:15 +0000 Subject: [PATCH 30/39] [skip-ci] Addressing PR comment --- dorado/read_pipeline/ReadPipeline.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index d0c6b2e8..530018d5 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -212,8 +212,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } // Iterate over the provided alphabet and find all the channels we need to write out - for (size_t channel_idx = 0; channel_idx < num_channels; - channel_idx++) { // Loop over each channel. Writing out the + for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { // A cardinal base current_cardinal = mod_base_info->alphabet[channel_idx][0]; From 4c2015577ea4c49a821258dc4bfc6dcb8f1a8346 Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 13:17:10 +0000 Subject: [PATCH 31/39] [skip-ci] Addressing PR comment --- dorado/utils/sequence_utils.cpp | 6 +++--- dorado/utils/sequence_utils.h | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 0aba2732..0d008dce 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -219,9 +219,9 @@ OverlapResult compute_overlap(const std::string& query_seq, const std::string& t // Query is the read that the moves table is associated with. A new moves table will be generated // Which is aligned to the target sequence. -std::tuple, int> realign_moves(std::string query_sequence, - std::string target_sequence, - std::vector moves) { +std::tuple, int> realign_moves(const std::string& query_sequence, + const std::string& target_sequence, + const std::vector& moves) { auto [is_overlap, query_start, query_end, target_start, target_end] = compute_overlap( query_sequence, target_sequence); // We are going to compute the overlap between the two reads diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index cab1f7bf..7e8500d3 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -46,9 +46,9 @@ class BaseInfo { int count_trailing_chars(const std::string_view adapter, char c); -std::tuple, int> realign_moves(std::string query_sequence, - std::string target_sequence, - std::vector moves); +std::tuple, int> realign_moves(const std::string& query_sequence, + const std::string& target_sequence, + const std::vector& moves); // Compile-time constant lookup table. static constexpr auto complement_table = [] { From 96d78883681d084ac989c77655dd269dbcbb387b Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 13:18:25 +0000 Subject: [PATCH 32/39] [skip-ci] Addressing PR comment --- dorado/utils/sequence_utils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index 7e8500d3..7c4d27a0 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace dorado::utils { From 320f2b01dba10a0ec9a9c02ae210990f7677e7ed Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 13:50:51 +0000 Subject: [PATCH 33/39] [skip-ci] Addressing PR comment --- dorado/utils/sequence_utils.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 0d008dce..aa274267 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -241,17 +241,6 @@ std::tuple, int> realign_moves(const std::string& query_sequence_component.data(), static_cast(query_sequence_component.length()), align_config); - // Now that we have the alignment, we need to compute the new move table, by walking along the alignment - - const auto alignment_size = - static_cast(edlib_result.endLocations[0] - edlib_result.startLocations[0]); - std::vector alignment; - alignment.resize(alignment_size); - std::memcpy(alignment.data(), &edlib_result.alignment[edlib_result.startLocations[0]], - alignment_size); - - std::vector new_moves; - // Let's keep two cursor positions - one for the new move table and one for the old: int new_move_cursor = 0; int old_move_cursor = 0; @@ -265,8 +254,12 @@ std::tuple, int> realign_moves(const std::string& --old_move_cursor; // We have gone one too far. int old_moves_offset = old_move_cursor; - // First thing to do - let's just print out the alignment line by line so we know it's working. - for (auto alignment_entry : alignment) { + const auto alignment_size = + static_cast(edlib_result.endLocations[0] - edlib_result.startLocations[0]); + // Now that we have the alignment, we need to compute the new move table, by walking along the alignment + std::vector new_moves; + for (size_t i = 0; i < alignment_size; i++) { + auto alignment_entry = edlib_result.alignment[i]; if ((alignment_entry == 0) || (alignment_entry == 3)) { //Match or mismatch, need to update the new move table and move the cursor of the old move table. From a4f789887a8aa449dd552fa3def14d879fdbce0f Mon Sep 17 00:00:00 2001 From: Mike Vella Date: Thu, 30 Nov 2023 14:00:12 +0000 Subject: [PATCH 34/39] Fixing modbase threads in Duplex --- dorado/cli/duplex.cpp | 7 +++++-- dorado/read_pipeline/Pipelines.cpp | 4 ++-- dorado/read_pipeline/Pipelines.h | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index f710e35f..b9508679 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -495,11 +495,14 @@ int duplex(int argc, char* argv[]) { throw std::runtime_error("Mean q-score start position cannot be < 0"); } } + pipelines::create_stereo_duplex_pipeline( pipeline_desc, std::move(runners), std::move(stereo_runners), std::move(mod_base_runners), overlap, mean_qscore_start_pos, - int(num_devices * 2), int(num_devices), std::move(pairing_parameters), - read_filter_node, PipelineDescriptor::InvalidNodeHandle); + int(num_devices * 2), int(num_devices), + int(default_parameters.remora_threads * num_devices), + std::move(pairing_parameters), read_filter_node, + PipelineDescriptor::InvalidNodeHandle); pipeline = Pipeline::create(std::move(pipeline_desc), &stats_reporters); if (pipeline == nullptr) { diff --git a/dorado/read_pipeline/Pipelines.cpp b/dorado/read_pipeline/Pipelines.cpp index baac70df..049c98a2 100644 --- a/dorado/read_pipeline/Pipelines.cpp +++ b/dorado/read_pipeline/Pipelines.cpp @@ -114,6 +114,7 @@ void create_stereo_duplex_pipeline( uint32_t mean_qscore_start_pos, int scaler_node_threads, int splitter_node_threads, + int modbase_node_threads, PairingParameters pairing_parameters, NodeHandle sink_node_handle, NodeHandle source_node_handle) { @@ -135,8 +136,7 @@ void create_stereo_duplex_pipeline( NodeHandle last_node_handle = stereo_basecaller_node; if (!modbase_runners.empty()) { auto mod_base_caller_node = pipeline_desc.add_node( - {}, std::move(modbase_runners), - size_t(4), // TODO - what shold this be? + {}, std::move(modbase_runners), modbase_node_threads, size_t(runners.front()->model_stride()), 1000); pipeline_desc.add_node_sink(stereo_basecaller_node, mod_base_caller_node); last_node_handle = mod_base_caller_node; diff --git a/dorado/read_pipeline/Pipelines.h b/dorado/read_pipeline/Pipelines.h index a982e512..f4a27211 100644 --- a/dorado/read_pipeline/Pipelines.h +++ b/dorado/read_pipeline/Pipelines.h @@ -46,6 +46,7 @@ void create_stereo_duplex_pipeline( uint32_t mean_qscore_start_pos, int scaler_node_threads, int splitter_node_threads, + int modbase_node_threads, PairingParameters pairing_parameters, NodeHandle sink_node_handle, NodeHandle source_node_handle); From 429546f8c704a01b95942d4b63ccaf9266897d8d Mon Sep 17 00:00:00 2001 From: Steve Malton Date: Thu, 30 Nov 2023 15:22:20 +0000 Subject: [PATCH 35/39] Remove done TODO --- dorado/read_pipeline/ReadPipeline.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index 530018d5..42b7f668 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -197,10 +197,9 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { int num_states = static_cast(base_mod_probs.size()) / static_cast(seq.size()); // Update the context mask using the reversed sequence - context_handler.update_mask( - modbase_mask_rc, reverse_complemented_seq, mod_base_info->alphabet, - reverseMatrix(base_mod_probs, num_states), - threshold); // TODO: Setting threshold to zero as a temporary measure + context_handler.update_mask(modbase_mask_rc, reverse_complemented_seq, + mod_base_info->alphabet, + reverseMatrix(base_mod_probs, num_states), threshold); // Reverse the mask in-place std::reverse(modbase_mask_rc.begin(), modbase_mask_rc.end()); From 15f5dd8d82c1f3bd4a3bdc10f6b0697e9fb6f14c Mon Sep 17 00:00:00 2001 From: Steve Malton Date: Thu, 30 Nov 2023 15:22:59 +0000 Subject: [PATCH 36/39] Remove unused include --- dorado/read_pipeline/ModBaseCallerNode.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 1af2344f..4a096810 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include #include From 756c69406c67d9e66a667ca2aed9cffe92c237f4 Mon Sep 17 00:00:00 2001 From: Steve Malton Date: Thu, 30 Nov 2023 15:23:57 +0000 Subject: [PATCH 37/39] Change signature to clarify variable lifetime --- dorado/read_pipeline/ModBaseCallerNode.cpp | 4 ++-- dorado/read_pipeline/ModBaseCallerNode.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 4a096810..7d4bec6e 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -147,7 +147,7 @@ void ModBaseCallerNode::init_modbase_info() { m_base_prob_offsets[3] = m_base_prob_offsets[2] + result.base_counts[2]; } -void ModBaseCallerNode::duplex_mod_call(Message message) { +void ModBaseCallerNode::duplex_mod_call(Message&& message) { // Let's do this only for the template strand for now. auto read = std::get(std::move(message)); @@ -298,7 +298,7 @@ void ModBaseCallerNode::duplex_mod_call(Message message) { } } -void ModBaseCallerNode::simplex_mod_call(Message message) { +void ModBaseCallerNode::simplex_mod_call(Message&& message) { auto read = std::get(std::move(message)); stats::Timer timer; { diff --git a/dorado/read_pipeline/ModBaseCallerNode.h b/dorado/read_pipeline/ModBaseCallerNode.h index 1f98326a..a97d032f 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.h +++ b/dorado/read_pipeline/ModBaseCallerNode.h @@ -32,8 +32,8 @@ class ModBaseCallerNode : public MessageSink { stats::NamedStats sample_stats() const override; void terminate(const FlushOptions&) override { terminate_impl(); } void restart() override; - void simplex_mod_call(Message message); - void duplex_mod_call(Message message); + void simplex_mod_call(Message&& message); + void duplex_mod_call(Message&& message); private: void start_threads(); From 0d58cc7c0facc29fdbe9c6293ddbc4ddcb01dc28 Mon Sep 17 00:00:00 2001 From: Steve Malton Date: Thu, 30 Nov 2023 15:25:16 +0000 Subject: [PATCH 38/39] Reduce copies in duplex mods --- dorado/read_pipeline/ModBaseCallerNode.cpp | 68 +++++++++++----------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 7d4bec6e..d50f3dfa 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -181,40 +181,42 @@ void ModBaseCallerNode::duplex_mod_call(Message&& message) { std::vector all_context_hits; - for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { - for (const auto& strand_type : {"template", "complement"}) { - nvtx3::scoped_range range{"generate_chunks"}; - - bool is_template_direction = (strcmp(strand_type, "template") == 0); - auto& simplex_moves = is_template_direction - ? read->stereo_feature_inputs.template_moves - : read->stereo_feature_inputs.complement_moves; - - auto simplex_signal = - is_template_direction - ? read->stereo_feature_inputs.template_signal - : at::flip(read->stereo_feature_inputs.complement_signal, 0); - ; - - auto simplex_seq = is_template_direction - ? read->stereo_feature_inputs.template_seq - : utils::reverse_complement( - read->stereo_feature_inputs.complement_seq); + for (const bool is_template_direction : {true, false}) { + auto simplex_signal = + is_template_direction + ? read->stereo_feature_inputs.template_signal + : at::flip(read->stereo_feature_inputs.complement_signal, 0); + + // const-ref extends lifetime of temporary + const auto& simplex_moves = is_template_direction + ? read->stereo_feature_inputs.template_moves + : read->stereo_feature_inputs.complement_moves; + + // const-ref extends lifetime of temporary + const auto& simplex_seq = + is_template_direction + ? read->stereo_feature_inputs.template_seq + : utils::reverse_complement(read->stereo_feature_inputs.complement_seq); + + // const-ref extends lifetime of temporary + const auto& duplex_seq = is_template_direction + ? read->read_common.seq + : utils::reverse_complement(read->read_common.seq); + + auto [moves_offset, target_start, new_move_table, query_start] = + utils::realign_moves(simplex_seq, duplex_seq, simplex_moves); + + auto signal_len = new_move_table.size() * m_block_stride; + auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); + auto new_seq = duplex_seq.substr(target_start, num_moves); + std::vector sequence_ints = utils::sequence_to_ints(new_seq); + + // no reverse_signal in duplex, so we can do this once for all callers + std::vector seq_to_sig_map = + utils::moves_to_map(new_move_table, m_block_stride, signal_len, num_moves + 1); - auto duplex_seq = is_template_direction - ? read->read_common.seq - : utils::reverse_complement(read->read_common.seq); - - auto [moves_offset, target_start, new_move_table, query_start] = - utils::realign_moves(simplex_seq, duplex_seq, simplex_moves); - - auto signal_len = new_move_table.size() * m_block_stride; - auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); - auto new_seq = duplex_seq.substr(target_start, num_moves); - std::vector sequence_ints = utils::sequence_to_ints(new_seq); - - std::vector seq_to_sig_map = utils::moves_to_map( - new_move_table, m_block_stride, signal_len, num_moves + 1); + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + nvtx3::scoped_range range{"generate_chunks"}; auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); auto& params = runner->caller_params(caller_id); auto signal = simplex_signal.slice(0, moves_offset * m_block_stride, From e95a23fe6c51c192c103e888d4e950bdf0d1d6bf Mon Sep 17 00:00:00 2001 From: Richard Harris Date: Thu, 30 Nov 2023 17:26:36 +0000 Subject: [PATCH 39/39] add some top level tests --- tests/test_simple_basecaller_execution.bat | 2 ++ tests/test_simple_basecaller_execution.sh | 22 ++++++++++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/test_simple_basecaller_execution.bat b/tests/test_simple_basecaller_execution.bat index 4661ad0d..ddfd0333 100644 --- a/tests/test_simple_basecaller_execution.bat +++ b/tests/test_simple_basecaller_execution.bat @@ -16,6 +16,8 @@ echo dorado aligner test stage echo dorado duplex basespace test stage %dorado_bin% duplex basespace tests/data/basespace/pairs.bam --threads 1 --pairs tests/data/basespace/pairs.txt > calls.bam +%dorado_bin% duplex hac tests/data/duplex/pod5 --threads 1 > $output_dir/duplex_calls.bam +%dorado_bin% duplex hac,5mCG_5hmCG tests/data/duplex/pod5 --threads 1 > $output_dir/duplex_calls_mods.bam echo dorado demux test stage %dorado_bin% demux tests/data/barcode_demux/double_end_variant/EXP-PBC096_BC83.fastq --threads 1 --kit-name EXP-PBC096 --output-dir ./demux diff --git a/tests/test_simple_basecaller_execution.sh b/tests/test_simple_basecaller_execution.sh index 137b2feb..4be0c369 100755 --- a/tests/test_simple_basecaller_execution.sh +++ b/tests/test_simple_basecaller_execution.sh @@ -153,7 +153,7 @@ if ! uname -r | grep -q tegra; then samtools quickcheck -u $output_dir/duplex_calls.bam num_duplex_reads=$(samtools view $output_dir/duplex_calls.bam | grep dx:i:1 | wc -l | awk '{print $1}') if [[ $num_duplex_reads -ne "2" ]]; then - echo "Duplex basecalling missing reads." + echo "Duplex basecalling missing reads - in-line" exit 1 fi @@ -162,7 +162,25 @@ if ! uname -r | grep -q tegra; then samtools quickcheck -u $output_dir/duplex_calls.bam num_duplex_reads=$(samtools view $output_dir/duplex_calls.bam | grep dx:i:1 | wc -l | awk '{print $1}') if [[ $num_duplex_reads -ne "2" ]]; then - echo "Duplex basecalling missing reads." + echo "Duplex basecalling missing reads - pairs file" + exit 1 + fi + + echo dorado in-line duplex from model complex + $dorado_bin duplex hac@v4.2.0 $data_dir/duplex/pod5 > $output_dir/duplex_calls_complex.bam + samtools quickcheck -u $output_dir/duplex_calls_complex.bam + num_duplex_reads=$(samtools view $output_dir/duplex_calls_complex.bam | grep dx:i:1 | wc -l | awk '{print $1}') + if [[ $num_duplex_reads -ne "2" ]]; then + echo "Duplex basecalling missing reads - model complex" + exit 1 + fi + + echo dorado in-line modbase duplex from model complex + $dorado_bin duplex hac,5mCG_5hmCG $data_dir/duplex/pod5 > $output_dir/duplex_calls_mods.bam + samtools quickcheck -u $output_dir/duplex_calls_mods.bam + num_duplex_reads=$(samtools view $output_dir/duplex_calls_mods.bam | grep dx:i:1 | wc -l | awk '{print $1}') + if [[ $num_duplex_reads -ne "2" ]]; then + echo "Duplex basecalling missing reads - mods" exit 1 fi fi