diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index ccf5782a..8d0c83c0 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -124,8 +124,9 @@ void setup(std::vector args, const bool enable_aligner = !ref.empty(); // create modbase runners first so basecall runners can pick batch sizes based on available memory - auto remora_runners = create_modbase_runners( - remora_models, device, default_parameters.remora_runners_per_caller, remora_batch_size); + auto remora_runners = create_modbase_runners(remora_models, device, + default_parameters.mod_base_runners_per_caller, + remora_batch_size); auto [runners, num_devices] = create_basecall_runners(model_config, device, num_runners, 0, batch_size, chunk_size, 1.f, false); @@ -133,8 +134,6 @@ void setup(std::vector args, auto read_groups = DataLoader::load_read_groups(data_path, model_name, modbase_model_names, recursive_file_loading); - bool duplex = false; - const auto thread_allocations = utils::default_thread_allocations( int(num_devices), !remora_runners.empty() ? int(num_remora_threads) : 0, enable_aligner, !barcode_kits.empty()); @@ -251,7 +250,7 @@ void setup(std::vector args, } std::vector stats_callables; - ProgressTracker tracker(int(num_reads), duplex); + ProgressTracker tracker(int(num_reads), false); stats_callables.push_back( [&tracker](const stats::NamedStats& stats) { tracker.update_progress_bar(stats); }); constexpr auto kStatsPeriod = 100ms; diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index 229acb29..b9508679 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -4,6 +4,7 @@ #include "data_loader/ModelFinder.h" #include "models/models.h" #include "nn/CRFModelConfig.h" +#include "nn/ModBaseRunner.h" #include "nn/Runners.h" #include "read_pipeline/AlignerNode.h" #include "read_pipeline/BaseSpaceDuplexCallerNode.h" @@ -21,6 +22,7 @@ #include "utils/log_utils.h" #include "utils/parameters.h" #include "utils/stats.h" +#include "utils/string_utils.h" #include "utils/sys_stats.h" #include "utils/torch_utils.h" #include "utils/types.h" @@ -33,6 +35,8 @@ #include #include +namespace fs = std::filesystem; + namespace dorado { namespace { @@ -45,17 +49,30 @@ struct DuplexModels { CRFModelConfig stereo_model_config; std::string stereo_model_name; + std::vector mods_model_paths; + std::set temp_paths{}; }; DuplexModels load_models(const std::string& model_arg, + const std::vector& mod_bases, + const std::string& mod_bases_models, const std::string& reads, const bool recursive_file_loading, const bool skip_model_compatibility_check) { const ModelSelection model_selection = cli::parse_model_argument(model_arg); + auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()}; + if (std::count(ways.begin(), ways.end(), true) > 1) { + spdlog::error( + "only one of --modified-bases, --modified-bases-models, or modified models set " + "via models argument can be used at once", + model_arg); + std::exit(EXIT_FAILURE); + }; + if (model_selection.is_path()) { - const auto model_path = std::filesystem::canonical(std::filesystem::path(model_arg)); + const auto model_path = fs::canonical(fs::path(model_arg)); const auto model_name = model_path.filename().string(); const auto model_config = load_crf_model_config(model_path); @@ -75,23 +92,30 @@ DuplexModels load_models(const std::string& model_arg, throw std::runtime_error(err.str()); } const auto stereo_model_name = utils::get_stereo_model_name(model_arg, data_sample_rate); - const auto stereo_model_path = - model_path.parent_path() / std::filesystem::path(stereo_model_name); + const auto stereo_model_path = model_path.parent_path() / fs::path(stereo_model_name); - if (!std::filesystem::exists(stereo_model_path)) { + if (!fs::exists(stereo_model_path)) { if (!models::download_models(model_path.parent_path().u8string(), stereo_model_name)) { throw std::runtime_error("Failed to download model: " + stereo_model_name); } } const auto stereo_model_config = load_crf_model_config(stereo_model_path); - return DuplexModels{model_path, model_name, model_config, - stereo_model_path, stereo_model_config, stereo_model_name}; - } + std::vector mods_model_paths; + if (!mod_bases.empty()) { + std::transform(mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths), + [&model_arg](std::string m) { + return fs::path(models::get_modification_model(model_arg, m)); + }); + } else if (!mod_bases_models.empty()) { + const auto split = utils::split(mod_bases_models, ','); + std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths), + [&](std::string m) { return fs::path(m); }); + } - if (model_selection.has_mods_variant()) { - spdlog::error("Modified bases models are not supported for duplex"); - std::exit(EXIT_FAILURE); + return DuplexModels{model_path, model_name, model_config, + stereo_model_path, stereo_model_config, stereo_model_name, + mods_model_paths}; } auto model_finder = cli::model_finder(model_selection, reads, recursive_file_loading, true); @@ -104,13 +128,14 @@ DuplexModels load_models(const std::string& model_arg, const auto stereo_model_name = stereo_model_path.filename().string(); const auto stereo_model_config = load_crf_model_config(stereo_model_path); - return DuplexModels{model_path, - model_name, - model_config, - stereo_model_path, - stereo_model_config, - stereo_model_name, - model_finder.downloaded_models()}; + const std::vector mods_model_paths = model_selection.has_mods_variant() + ? model_finder.fetch_mods_models() + : std::vector{}; + + return DuplexModels{model_path, model_name, + model_config, stereo_model_path, + stereo_model_config, stereo_model_name, + mods_model_paths, model_finder.downloaded_models()}; } } // namespace @@ -203,6 +228,28 @@ int duplex(int argc, char* argv[]) { .action([&](const auto&) { ++verbosity; }) .append(); + parser.visible.add_argument("--modified-bases") + .nargs(argparse::nargs_pattern::at_least_one) + .action([](const std::string& value) { + const auto& mods = models::modified_model_variants(); + if (std::find(mods.begin(), mods.end(), value) == mods.end()) { + spdlog::error("'{}' is not a supported modification please select from {}", + value, utils::join(mods, ", ")); + std::exit(EXIT_FAILURE); + } + return value; + }); + + parser.visible.add_argument("--modified-bases-models") + .default_value(std::string()) + .help("a comma separated list of modified base models"); + + parser.visible.add_argument("--modified-bases-threshold") + .default_value(default_parameters.methylation_threshold) + .scan<'f', float>() + .help("the minimum predicted methylation probability for a modified base to be emitted " + "in an all-context model, [0, 1]"); + cli::add_minimap2_arguments(parser, alignment::dflt_options); cli::add_internal_arguments(parser); @@ -226,6 +273,10 @@ int duplex(int argc, char* argv[]) { if (parser.visible.get("--verbose")) { utils::SetVerboseLogging(static_cast(verbosity)); } + + auto mod_bases = parser.visible.get>("--modified-bases"); + auto mod_bases_models = parser.visible.get("--modified-bases-models"); + std::map template_complement_map; auto read_list = utils::load_read_list(parser.visible.get("--read-ids")); @@ -331,6 +382,11 @@ int duplex(int argc, char* argv[]) { return 1; // Exit with an error code } + if (!mod_bases.empty() || !mod_bases_models.empty()) { + spdlog::error("Basespace duplex does not support modbase models"); + return EXIT_FAILURE; + } + spdlog::info("> Loading reads"); auto read_map = read_bam(reads, read_list_from_pairs); @@ -355,15 +411,25 @@ int duplex(int argc, char* argv[]) { kStatsPeriod, stats_reporters, stats_callables, max_stats_records); } else { // Execute a Stereo Duplex pipeline. + if (!DataLoader::is_read_data_present(reads, recursive_file_loading)) { + std::string err = "No POD5 or FAST5 data found in path: " + reads; + throw std::runtime_error(err); + } + const bool skip_model_compatibility_check = parser.hidden.get("--skip-model-compatibility-check"); - const DuplexModels models = load_models(model, reads, recursive_file_loading, - skip_model_compatibility_check); + const DuplexModels models = + load_models(model, mod_bases, mod_bases_models, reads, recursive_file_loading, + skip_model_compatibility_check); - if (!DataLoader::is_read_data_present(reads, recursive_file_loading)) { - std::string err = "No POD5 or FAST5 data found in path: " + reads; - throw std::runtime_error(err); + // create modbase runners first so basecall runners can pick batch sizes based on available memory + auto mod_base_runners = create_modbase_runners( + models.mods_model_paths, device, default_parameters.mod_base_runners_per_caller, + default_parameters.remora_batchsize); + + if (!mod_base_runners.empty() && output_mode == HtsWriter::OutputMode::FASTQ) { + throw std::runtime_error("Modified base models cannot be used with FASTQ output"); } // Write read group info to header. @@ -429,9 +495,12 @@ int duplex(int argc, char* argv[]) { throw std::runtime_error("Mean q-score start position cannot be < 0"); } } + pipelines::create_stereo_duplex_pipeline( - pipeline_desc, std::move(runners), std::move(stereo_runners), overlap, - mean_qscore_start_pos, int(num_devices * 2), int(num_devices), + pipeline_desc, std::move(runners), std::move(stereo_runners), + std::move(mod_base_runners), overlap, mean_qscore_start_pos, + int(num_devices * 2), int(num_devices), + int(default_parameters.remora_threads * num_devices), std::move(pairing_parameters), read_filter_node, PipelineDescriptor::InvalidNodeHandle); diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 9a2af406..d50f3dfa 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -15,6 +15,8 @@ #include #include +#include + using namespace std::chrono_literals; namespace dorado { @@ -25,21 +27,24 @@ struct ModBaseCallerNode::RemoraChunk { RemoraChunk(std::shared_ptr read, at::Tensor input_signal, std::vector kmer_data, - size_t position) + size_t position, + bool is_template_direction) : working_read(std::move(read)), signal(std::move(input_signal)), encoded_kmers(std::move(kmer_data)), - context_hit(position) {} + context_hit(position), + is_template_direction(is_template_direction) {} std::shared_ptr working_read; at::Tensor signal; std::vector encoded_kmers; size_t context_hit; std::vector scores; + bool is_template_direction; }; struct ModBaseCallerNode::WorkingRead { - SimplexReadPtr read; // The read itself. + Message read; // The read itself. size_t num_modbase_chunks; std::atomic_size_t num_modbase_chunks_called; // Number of modbase chunks which have been scored @@ -142,69 +147,80 @@ void ModBaseCallerNode::init_modbase_info() { m_base_prob_offsets[3] = m_base_prob_offsets[2] + result.base_counts[2]; } -void ModBaseCallerNode::input_worker_thread() { - at::InferenceMode inference_mode_guard; - - Message message; - while (get_input_message(message)) { - // If this message isn't a read, just forward it to the sink. - if (!std::holds_alternative(message)) { - send_message_to_sink(std::move(message)); - continue; +void ModBaseCallerNode::duplex_mod_call(Message&& message) { + // Let's do this only for the template strand for now. + + auto read = std::get(std::move(message)); + stats::Timer timer; + + { + nvtx3::scoped_range range{"base_mod_probs_init"}; + // initialize base_mod_probs _before_ we start handing out chunks + read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, 0); + for (size_t i = 0; i < read->read_common.seq.size(); ++i) { + // Initialize for what corresponds to 100% canonical base for each position. + int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; + if (base_id < 0) { + throw std::runtime_error("Invalid character in sequence."); + } + read->read_common.base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1; } + } - nvtx3::scoped_range range{"modbase_input_worker_thread"}; - // If this message isn't a read, we'll get a bad_variant_access exception. - auto read = std::get(std::move(message)); + read->read_common.mod_base_info = m_mod_base_info; - while (true) { - stats::Timer timer; - { - nvtx3::scoped_range range_init{"base_mod_probs_init"}; - // initialize base_mod_probs _before_ we start handing out chunks - read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, - 0); - for (size_t i = 0; i < read->read_common.seq.size(); ++i) { - // Initialize for what corresponds to 100% canonical base for each position. - int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; - if (base_id < 0) { - throw std::runtime_error("Invalid character in sequence."); - } - read->read_common - .base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1; - } - } - read->read_common.mod_base_info = m_mod_base_info; + { + auto working_read = std::make_shared(); + working_read->num_modbase_chunks = 0; + working_read->num_modbase_chunks_called = 0; - auto working_read = std::make_shared(); - working_read->num_modbase_chunks = 0; - working_read->num_modbase_chunks_called = 0; + // all runners have the same set of callers, so we only need to use the first one + auto& runner = m_runners[0]; + std::vector>> chunks_to_enqueue_by_caller( + runner->num_callers()); - std::vector sequence_ints = utils::sequence_to_ints(read->read_common.seq); + std::vector all_context_hits; - // all runners have the same set of callers, so we only need to use the first one - auto& runner = m_runners[0]; - std::vector>> chunks_to_enqueue_by_caller( - runner->num_callers()); - for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { - nvtx3::scoped_range range_gen_chunks{"generate_chunks"}; + for (const bool is_template_direction : {true, false}) { + auto simplex_signal = + is_template_direction + ? read->stereo_feature_inputs.template_signal + : at::flip(read->stereo_feature_inputs.complement_signal, 0); + + // const-ref extends lifetime of temporary + const auto& simplex_moves = is_template_direction + ? read->stereo_feature_inputs.template_moves + : read->stereo_feature_inputs.complement_moves; + + // const-ref extends lifetime of temporary + const auto& simplex_seq = + is_template_direction + ? read->stereo_feature_inputs.template_seq + : utils::reverse_complement(read->stereo_feature_inputs.complement_seq); + + // const-ref extends lifetime of temporary + const auto& duplex_seq = is_template_direction + ? read->read_common.seq + : utils::reverse_complement(read->read_common.seq); + + auto [moves_offset, target_start, new_move_table, query_start] = + utils::realign_moves(simplex_seq, duplex_seq, simplex_moves); - auto signal_len = read->read_common.get_raw_data_samples(); - std::vector seq_to_sig_map = - utils::moves_to_map(read->read_common.moves, m_block_stride, signal_len, - read->read_common.seq.size() + 1); + auto signal_len = new_move_table.size() * m_block_stride; + auto num_moves = std::accumulate(new_move_table.begin(), new_move_table.end(), 0); + auto new_seq = duplex_seq.substr(target_start, num_moves); + std::vector sequence_ints = utils::sequence_to_ints(new_seq); + // no reverse_signal in duplex, so we can do this once for all callers + std::vector seq_to_sig_map = + utils::moves_to_map(new_move_table, m_block_stride, signal_len, num_moves + 1); + + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + nvtx3::scoped_range range{"generate_chunks"}; auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); auto& params = runner->caller_params(caller_id); - auto signal = read->read_common.raw_data; - if (params.reverse_signal) { - signal = at::flip(signal, 0); - std::reverse(std::begin(seq_to_sig_map), std::end(seq_to_sig_map)); - std::transform(std::begin(seq_to_sig_map), std::end(seq_to_sig_map), - std::begin(seq_to_sig_map), [signal_len](auto signal_pos) { - return signal_len - signal_pos; - }); - } + auto signal = simplex_signal.slice(0, moves_offset * m_block_stride, + moves_offset * m_block_stride + signal_len); // scale signal based on model parameters auto scaled_signal = @@ -217,9 +233,10 @@ void ModBaseCallerNode::input_worker_thread() { params.bases_after); encoder.init(sequence_ints, seq_to_sig_map); - auto context_hits = runner->get_motif_hits(caller_id, read->read_common.seq); + auto context_hits = runner->get_motif_hits(caller_id, new_seq); m_num_context_hits += static_cast(context_hits.size()); chunks_to_enqueue.reserve(context_hits.size()); + for (auto context_hit : context_hits) { nvtx3::scoped_range range_create_chunk{"create_chunk"}; auto slice = encoder.get_context(context_hit); @@ -231,41 +248,177 @@ void ModBaseCallerNode::input_worker_thread() { {(int64_t)slice.lead_samples_needed, (int64_t)slice.tail_samples_needed}); } + + // Update the context hit into the duplex reference context + unsigned long context_hit_in_duplex_space; + if (is_template_direction) { + context_hit_in_duplex_space = + static_cast(context_hit + target_start); + } else { + context_hit_in_duplex_space = static_cast( + read->read_common.seq.size() - (context_hit + target_start + 1)); + } + chunks_to_enqueue.push_back(std::make_unique( - working_read, input_signal, std::move(slice.data), context_hit)); + working_read, input_signal, std::move(slice.data), + context_hit_in_duplex_space, is_template_direction)); + all_context_hits.push_back(context_hit_in_duplex_space); ++working_read->num_modbase_chunks; } } - m_chunk_generation_ms += timer.GetElapsedMS(); + } - if (working_read->num_modbase_chunks != 0) { - // Hand over our ownership to the working read - working_read->read = std::move(read); + m_chunk_generation_ms += timer.GetElapsedMS(); - // Put the read in the working list - { - std::lock_guard working_reads_lock(m_working_reads_mutex); - m_working_reads.insert(std::move(working_read)); - ++m_working_reads_size; - } + if (working_read->num_modbase_chunks != 0) { + // Hand over our ownership to the working read + working_read->read = std::move(read); - // push the chunks to the chunk queues - // needs to be done after working_read->read is set as chunks could be processed - // before we set that value otherwise - for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { - auto& chunk_queue = m_chunk_queues.at(caller_id); - auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); - for (auto& chunk : chunks_to_enqueue) { - chunk_queue->try_push(std::move(chunk)); - } + // Put the read in the working list + { + std::lock_guard working_reads_lock(m_working_reads_mutex); + m_working_reads.insert(std::move(working_read)); + ++m_working_reads_size; + } + + // push the chunks to the chunk queues + // needs to be done after working_read->read is set as chunks could be processed + // before we set that value otherwise + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + auto& chunk_queue = m_chunk_queues.at(caller_id); + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + for (auto& chunk : chunks_to_enqueue) { + chunk_queue->try_push(std::move(chunk)); } - } else { - // No modbases to call, pass directly to next node - send_message_to_sink(std::move(read)); - ++m_num_non_mod_base_reads_pushed; } - break; + } else { + // No modbases to call, pass directly to next node + send_message_to_sink(std::move(read)); + ++m_num_non_mod_base_reads_pushed; + } + } +} + +void ModBaseCallerNode::simplex_mod_call(Message&& message) { + auto read = std::get(std::move(message)); + stats::Timer timer; + { + nvtx3::scoped_range range{"base_mod_probs_init"}; + // initialize base_mod_probs _before_ we start handing out chunks + read->read_common.base_mod_probs.resize(read->read_common.seq.size() * m_num_states, 0); + for (size_t i = 0; i < read->read_common.seq.size(); ++i) { + // Initialize for what corresponds to 100% canonical base for each position. + int base_id = utils::BaseInfo::BASE_IDS[read->read_common.seq[i]]; + if (base_id < 0) { + throw std::runtime_error("Invalid character in sequence."); + } + read->read_common.base_mod_probs[i * m_num_states + m_base_prob_offsets[base_id]] = 1; + } + } + read->read_common.mod_base_info = m_mod_base_info; + + auto working_read = std::make_shared(); + working_read->num_modbase_chunks = 0; + working_read->num_modbase_chunks_called = 0; + + std::vector sequence_ints = utils::sequence_to_ints(read->read_common.seq); + + // all runners have the same set of callers, so we only need to use the first one + auto& runner = m_runners[0]; + std::vector>> chunks_to_enqueue_by_caller( + runner->num_callers()); + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + nvtx3::scoped_range range{"generate_chunks"}; + + auto signal_len = read->read_common.get_raw_data_samples(); + std::vector seq_to_sig_map = + utils::moves_to_map(read->read_common.moves, m_block_stride, signal_len, + read->read_common.seq.size() + 1); + + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + auto& params = runner->caller_params(caller_id); + auto signal = read->read_common.raw_data; + if (params.reverse_signal) { + signal = at::flip(signal, 0); + std::reverse(std::begin(seq_to_sig_map), std::end(seq_to_sig_map)); + std::transform(std::begin(seq_to_sig_map), std::end(seq_to_sig_map), + std::begin(seq_to_sig_map), + [signal_len](auto signal_pos) { return signal_len - signal_pos; }); + } + + // scale signal based on model parameters + auto scaled_signal = runner->scale_signal(caller_id, signal, sequence_ints, seq_to_sig_map); + + auto context_samples = (params.context_before + params.context_after); + + // One-hot encodes the kmer at each signal step for input into the network + ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, + params.bases_after); + encoder.init(sequence_ints, seq_to_sig_map); + + auto context_hits = runner->get_motif_hits(caller_id, read->read_common.seq); + m_num_context_hits += static_cast(context_hits.size()); + chunks_to_enqueue.reserve(context_hits.size()); + for (auto context_hit : context_hits) { + nvtx3::scoped_range nvtxrange{"create_chunk"}; + auto slice = encoder.get_context(context_hit); + // signal + auto input_signal = scaled_signal.index({at::indexing::Slice( + slice.first_sample, slice.first_sample + slice.num_samples)}); + if (slice.lead_samples_needed != 0 || slice.tail_samples_needed != 0) { + input_signal = at::constant_pad_nd( + input_signal, + {(int64_t)slice.lead_samples_needed, (int64_t)slice.tail_samples_needed}); + } + chunks_to_enqueue.push_back(std::make_unique( + working_read, input_signal, std::move(slice.data), context_hit, true)); + + ++working_read->num_modbase_chunks; + } + } + m_chunk_generation_ms += timer.GetElapsedMS(); + + if (working_read->num_modbase_chunks != 0) { + // Hand over our ownership to the working read + working_read->read = std::move(read); + + // Put the read in the working list + { + std::lock_guard working_reads_lock(m_working_reads_mutex); + m_working_reads.insert(std::move(working_read)); + ++m_working_reads_size; + } + + // push the chunks to the chunk queues + // needs to be done after working_read->read is set as chunks could be processed + // before we set that value otherwise + for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { + auto& chunk_queue = m_chunk_queues.at(caller_id); + auto& chunks_to_enqueue = chunks_to_enqueue_by_caller.at(caller_id); + for (auto& chunk : chunks_to_enqueue) { + chunk_queue->try_push(std::move(chunk)); + } + } + } else { + // No modbases to call, pass directly to next node + send_message_to_sink(std::move(read)); + ++m_num_non_mod_base_reads_pushed; + } +} + +void ModBaseCallerNode::input_worker_thread() { + at::InferenceMode inference_mode_guard; + + Message message; + while (get_input_message(message)) { + // If this message isn't a read, just forward it to the sink. + if (!is_read_message(message)) { + send_message_to_sink(std::move(message)); + } else if (std::holds_alternative(message)) { + simplex_mod_call(std::move(message)); + } else if (std::holds_alternative(message)) { + duplex_mod_call(std::move(message)); } } @@ -387,11 +540,21 @@ void ModBaseCallerNode::output_worker_thread() { for (const auto& chunk : processed_chunks) { auto working_read = chunk->working_read; auto& source_read = working_read->read; + auto& source_read_common = get_read_common_data(source_read); + int64_t result_pos = chunk->context_hit; - int64_t offset = m_base_prob_offsets - [utils::BaseInfo::BASE_IDS[source_read->read_common.seq[result_pos]]]; - for (size_t i = 0; i < chunk->scores.size(); ++i) { - source_read->read_common.base_mod_probs[m_num_states * result_pos + offset + i] = + + int64_t offset; + const auto& baseIds = utils::BaseInfo::BASE_IDS; + const auto& seq = source_read_common.seq[result_pos]; + + offset = chunk->is_template_direction + ? m_base_prob_offsets[baseIds[seq]] + : m_base_prob_offsets[baseIds[dorado::utils::complement_table[seq]]]; + + auto num_chunk_scores = chunk->scores.size(); + for (size_t i = 0; i < num_chunk_scores; ++i) { + source_read_common.base_mod_probs[m_num_states * result_pos + offset + i] = static_cast(std::min(std::floor(chunk->scores[i] * 256), 255.0f)); } // If all chunks for the read associated with this chunk have now been called, @@ -411,8 +574,8 @@ void ModBaseCallerNode::output_worker_thread() { if (read_iter != m_working_reads.end()) { m_working_reads.erase(read_iter); } else { - throw std::runtime_error("Expected to find read id " + - completed_read->read->read_common.read_id + + auto read_id = get_read_common_data(completed_read->read).read_id; + throw std::runtime_error("Expected to find read id " + read_id + " in working reads set but it doesn't exist."); } } diff --git a/dorado/read_pipeline/ModBaseCallerNode.h b/dorado/read_pipeline/ModBaseCallerNode.h index fa19510c..a97d032f 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.h +++ b/dorado/read_pipeline/ModBaseCallerNode.h @@ -32,6 +32,8 @@ class ModBaseCallerNode : public MessageSink { stats::NamedStats sample_stats() const override; void terminate(const FlushOptions&) override { terminate_impl(); } void restart() override; + void simplex_mod_call(Message&& message); + void duplex_mod_call(Message&& message); private: void start_threads(); diff --git a/dorado/read_pipeline/Pipelines.cpp b/dorado/read_pipeline/Pipelines.cpp index b633d13d..049c98a2 100644 --- a/dorado/read_pipeline/Pipelines.cpp +++ b/dorado/read_pipeline/Pipelines.cpp @@ -105,16 +105,19 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, } } -void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, - std::vector&& runners, - std::vector&& stereo_runners, - size_t overlap, - uint32_t mean_qscore_start_pos, - int scaler_node_threads, - int splitter_node_threads, - PairingParameters pairing_parameters, - NodeHandle sink_node_handle, - NodeHandle source_node_handle) { +void create_stereo_duplex_pipeline( + PipelineDescriptor& pipeline_desc, + std::vector&& runners, + std::vector&& stereo_runners, + std::vector>&& modbase_runners, + size_t overlap, + uint32_t mean_qscore_start_pos, + int scaler_node_threads, + int splitter_node_threads, + int modbase_node_threads, + PairingParameters pairing_parameters, + NodeHandle sink_node_handle, + NodeHandle source_node_handle) { const auto& model_config = runners.front()->config(); const auto& stereo_model_config = stereo_runners.front()->config(); std::string model_name = @@ -130,6 +133,15 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, {}, std::move(stereo_runners), adjusted_stereo_overlap, kStereoBatchTimeoutMS, duplex_rg_name, 1000, "StereoBasecallerNode", mean_qscore_start_pos); + NodeHandle last_node_handle = stereo_basecaller_node; + if (!modbase_runners.empty()) { + auto mod_base_caller_node = pipeline_desc.add_node( + {}, std::move(modbase_runners), modbase_node_threads, + size_t(runners.front()->model_stride()), 1000); + pipeline_desc.add_node_sink(stereo_basecaller_node, mod_base_caller_node); + last_node_handle = mod_base_caller_node; + } + auto simplex_model_stride = runners.front()->model_stride(); auto stereo_node = pipeline_desc.add_node({stereo_basecaller_node}, int(simplex_model_stride)); @@ -172,7 +184,7 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, // if we've been provided a sink node, connect it to the end of our pipeline if (sink_node_handle != PipelineDescriptor::InvalidNodeHandle) { - pipeline_desc.add_node_sink(stereo_basecaller_node, sink_node_handle); + pipeline_desc.add_node_sink(last_node_handle, sink_node_handle); } } diff --git a/dorado/read_pipeline/Pipelines.h b/dorado/read_pipeline/Pipelines.h index c06767ef..f4a27211 100644 --- a/dorado/read_pipeline/Pipelines.h +++ b/dorado/read_pipeline/Pipelines.h @@ -37,16 +37,19 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, /// Create a duplex basecall pipeline description /// If source_node_handle is valid, set this to be the source of the simplex pipeline /// If sink_node_handle is valid, set this to be the sink of the simplex pipeline -void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, - std::vector&& runners, - std::vector&& stereo_runners, - size_t overlap, - uint32_t mean_qscore_start_pos, - int scaler_node_threads, - int splitter_node_threads, - PairingParameters pairing_parameters, - NodeHandle sink_node_handle, - NodeHandle source_node_handle); +void create_stereo_duplex_pipeline( + PipelineDescriptor& pipeline_desc, + std::vector&& runners, + std::vector&& stereo_runners, + std::vector>&& modbase_runners, + size_t overlap, + uint32_t mean_qscore_start_pos, + int scaler_node_threads, + int splitter_node_threads, + int modbase_node_threads, + PairingParameters pairing_parameters, + NodeHandle sink_node_handle, + NodeHandle source_node_handle); } // namespace pipelines diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index d8baed28..42b7f668 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -173,6 +174,42 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { context_handler.update_mask(modbase_mask, seq, mod_base_info->alphabet, base_mod_probs, threshold); + if (is_duplex) { + // If this is a duplex read, we need to compute the reverse complement mask and combine it + auto reverse_complemented_seq = utils::reverse_complement(seq); + + // Compute the reverse complement mask + auto modbase_mask_rc = context_handler.get_sequence_mask(reverse_complemented_seq); + + auto reverseMatrix = [](const std::vector &matrix, int m_num_states) { + int numRows = static_cast(matrix.size()) / static_cast(m_num_states); + std::vector reversedMatrix(matrix.size()); + + for (int i = 0; i < numRows; ++i) { + for (int j = 0; j < m_num_states; ++j) { + reversedMatrix[i * m_num_states + j] = + matrix[(numRows - 1 - i) * m_num_states + j]; + } + } + + return reversedMatrix; + }; + + int num_states = static_cast(base_mod_probs.size()) / static_cast(seq.size()); + // Update the context mask using the reversed sequence + context_handler.update_mask(modbase_mask_rc, reverse_complemented_seq, + mod_base_info->alphabet, + reverseMatrix(base_mod_probs, num_states), threshold); + + // Reverse the mask in-place + std::reverse(modbase_mask_rc.begin(), modbase_mask_rc.end()); + + // Combine the original and reverse complement masks + // Using std::transform for better readability and potential efficiency + std::transform(modbase_mask.begin(), modbase_mask.end(), modbase_mask_rc.begin(), + modbase_mask.begin(), std::plus<>()); + } + // Iterate over the provided alphabet and find all the channels we need to write out for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { @@ -206,9 +243,44 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { } } + if (is_duplex) { + // Having done the strand in the forward direction, if the read is duplex we need to also process its complement + // There is some code repetition here, but it makes it more readable. + for (size_t channel_idx = 0; channel_idx < num_channels; channel_idx++) { + if (cardinal_bases.find(mod_base_info->alphabet[channel_idx]) != std::string::npos) { + // A cardinal base + current_cardinal = mod_base_info->alphabet[channel_idx][0]; + } else { + auto cardinal_complement = utils::complement_table[current_cardinal]; + // A modification on the previous cardinal base + std::string bam_name = mod_base_info->alphabet[channel_idx]; + if (!utils::validate_bam_tag_code(bam_name)) { + return; + } + + modbase_string += std::string(1, cardinal_complement) + "-" + bam_name; + modbase_string += base_has_context[current_cardinal] ? "?" : "."; + int skipped_bases = 0; + for (size_t base_idx = 0; base_idx < seq.size(); base_idx++) { + if (seq[base_idx] == cardinal_complement) { // complement + if (modbase_mask[base_idx]) { // Not sure this one is right + modbase_string += "," + std::to_string(skipped_bases); + skipped_bases = 0; + modbase_prob.push_back( + base_mod_probs[base_idx * num_channels + channel_idx]); + } else { + // Skip this base + skipped_bases++; + } + } + } + modbase_string += ";"; + } + } + } + int seq_len = int(seq.length()); bam_aux_append(aln, "MN", 'i', sizeof(seq_len), (uint8_t *)&seq_len); - bam_aux_append(aln, "MM", 'Z', int(modbase_string.length() + 1), (uint8_t *)modbase_string.c_str()); bam_aux_update_array(aln, "ML", 'C', int(modbase_prob.size()), (uint8_t *)modbase_prob.data()); diff --git a/dorado/utils/parameters.h b/dorado/utils/parameters.h index 1dde2fec..fe0be06f 100644 --- a/dorado/utils/parameters.h +++ b/dorado/utils/parameters.h @@ -22,7 +22,7 @@ struct DefaultParameters { int remora_batchsize{1024}; #endif int remora_threads{4}; - int remora_runners_per_caller{2}; + int mod_base_runners_per_caller{2}; float methylation_threshold{0.05f}; // Minimum length for a sequence to be outputted. diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index f81b1f82..aa274267 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -1,12 +1,16 @@ #include "sequence_utils.h" #include "simd.h" +#include "types.h" +#include +#include #include #include #include #include +#include #include #include #include @@ -26,23 +30,12 @@ reverse_complement_impl(const std::string& sequence) { std::string rev_comp_sequence; rev_comp_sequence.resize(num_bases); - // Compile-time constant lookup table. - static constexpr auto kComplementTable = [] { - std::array a{}; - // Valid input will only touch the entries set here. - a['A'] = 'T'; - a['T'] = 'A'; - a['C'] = 'G'; - a['G'] = 'C'; - return a; - }(); - // Run every template base through the table, reading in reverse order. const char* template_ptr = &sequence[num_bases - 1]; char* complement_ptr = &rev_comp_sequence[0]; for (size_t i = 0; i < num_bases; ++i) { const auto template_base = *template_ptr--; - *complement_ptr++ = kComplementTable[template_base]; + *complement_ptr++ = dorado::utils::complement_table[template_base]; } return rev_comp_sequence; } @@ -176,6 +169,136 @@ std::vector moves_to_map(const std::vector& moves, return seq_to_sig_map; } +OverlapResult compute_overlap(const std::string& query_seq, const std::string& target_seq) { + OverlapResult overlap_result = {false, 0, 0, 0, 0}; + + // Add mm2 based overlap check. + mm_idxopt_t m_idx_opt; + mm_mapopt_t m_map_opt; + mm_set_opt(0, &m_idx_opt, &m_map_opt); + mm_set_opt("map-hifi", &m_idx_opt, &m_map_opt); + + std::vector seqs = {query_seq.c_str()}; + std::vector names = {"query"}; + mm_idx_t* m_index = mm_idx_str(m_idx_opt.w, m_idx_opt.k, 0, m_idx_opt.bucket_bits, 1, + seqs.data(), names.data()); + mm_mapopt_update(&m_map_opt, m_index); + + MmTbufPtr mbuf = MmTbufPtr(mm_tbuf_init()); + + int hits = 0; + mm_reg1_t* reg = mm_map(m_index, int(target_seq.length()), target_seq.c_str(), &hits, + mbuf.get(), &m_map_opt, "target"); + + mm_idx_destroy(m_index); + + if (hits > 0) { + int32_t target_start = 0; + int32_t target_end = 0; + int32_t query_start = 0; + int32_t query_end = 0; + + auto best_map = std::max_element( + reg, reg + hits, + [](const mm_reg1_t& l, const mm_reg1_t& r) { return l.mapq < r.mapq; }); + target_start = best_map->rs; + target_end = best_map->re; + query_start = best_map->qs; + query_end = best_map->qe; + + overlap_result = {true, target_start, target_end, query_start, query_end}; + } + + for (int i = 0; i < hits; ++i) { + free(reg[i].p); + } + free(reg); + + return overlap_result; +} + +// Query is the read that the moves table is associated with. A new moves table will be generated +// Which is aligned to the target sequence. +std::tuple, int> realign_moves(const std::string& query_sequence, + const std::string& target_sequence, + const std::vector& moves) { + auto [is_overlap, query_start, query_end, target_start, target_end] = compute_overlap( + query_sequence, + target_sequence); // We are going to compute the overlap between the two reads + + // TODO sanity check if and why this is needed + // Now let's perform an alignmnet: + + EdlibAlignConfig align_config = edlibDefaultAlignConfig(); + align_config.task = EDLIB_TASK_PATH; + + auto target_sequence_component = + target_sequence.substr(target_start, target_end - target_start); + auto query_sequence_component = query_sequence.substr(query_start, query_end - query_start); + + EdlibAlignResult edlib_result = edlibAlign( + target_sequence_component.data(), static_cast(target_sequence_component.length()), + query_sequence_component.data(), static_cast(query_sequence_component.length()), + align_config); + + // Let's keep two cursor positions - one for the new move table and one for the old: + int new_move_cursor = 0; + int old_move_cursor = 0; + + int moves_found = 0; + + while (moves_found < int(moves.size()) && moves_found < int(query_start)) { + moves_found += moves[old_move_cursor]; + ++old_move_cursor; + } + --old_move_cursor; // We have gone one too far. + int old_moves_offset = old_move_cursor; + + const auto alignment_size = + static_cast(edlib_result.endLocations[0] - edlib_result.startLocations[0]); + // Now that we have the alignment, we need to compute the new move table, by walking along the alignment + std::vector new_moves; + for (size_t i = 0; i < alignment_size; i++) { + auto alignment_entry = edlib_result.alignment[i]; + if ((alignment_entry == 0) || + (alignment_entry == + 3)) { //Match or mismatch, need to update the new move table and move the cursor of the old move table. + new_moves.push_back(1); // We have a match so we need a 1 (move) + new_move_cursor++; + old_move_cursor++; + + while (moves[old_move_cursor] == 0) { + if (old_move_cursor < (new_move_cursor + old_moves_offset)) { + old_move_cursor++; + } else { + new_moves.push_back(0); + new_move_cursor++; + old_move_cursor++; + } + } + // Update the Query and target seq cursors + } else if (alignment_entry == 1) { //Insertion to target + // If we have an insertion in the target, we need to add a 1 to the new move table, and increment the new move table cursor. the old move table cursor and new are now out of sync and need fixing. + new_moves.push_back(1); + new_move_cursor++; + } else if (alignment_entry == 2) { //Insertion to Query + // We have a query insertion, all we need to do is add zeros to the new move table to make it up, the signal can be assigned to the leftmost nucleotide in the sequence. + new_moves.push_back(0); + new_move_cursor++; + old_move_cursor++; + while (moves[old_move_cursor] == 0) { + new_moves.push_back(0); + old_move_cursor++; + new_move_cursor++; + } + } + } + + edlibFreeAlignResult(edlib_result); + + return {old_moves_offset, target_start - 1, new_moves, query_start}; +} + std::vector move_cum_sums(const std::vector& moves) { std::vector ans(moves.size(), 0); if (!moves.empty()) { diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index 5df82283..7c4d27a0 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -1,8 +1,10 @@ #pragma once +#include #include #include #include +#include #include namespace dorado::utils { @@ -27,6 +29,11 @@ std::vector moves_to_map(const std::vector& moves, // Compute cumulative sums of the move table std::vector move_cum_sums(const std::vector& moves); +// Result of overlapping two reads +using OverlapResult = std::tuple; + +OverlapResult compute_overlap(const std::string& query_seq, const std::string& target_seq); + // Compute reverse complement of a nucleotide sequence. // Bases are specified as capital letters. // Undefined output if characters other than A, C, G, T appear. @@ -40,4 +47,23 @@ class BaseInfo { int count_trailing_chars(const std::string_view adapter, char c); +std::tuple, int> realign_moves(const std::string& query_sequence, + const std::string& target_sequence, + const std::vector& moves); + +// Compile-time constant lookup table. +static constexpr auto complement_table = [] { + std::array a{}; + // Valid input will only touch the entries set here. + a['A'] = 'T'; + a['T'] = 'A'; + a['C'] = 'G'; + a['G'] = 'C'; + a['a'] = 't'; + a['t'] = 'a'; + a['c'] = 'g'; + a['g'] = 'c'; + return a; +}(); + } // namespace dorado::utils diff --git a/tests/NodeSmokeTest.cpp b/tests/NodeSmokeTest.cpp index a23f554f..8f2cca06 100644 --- a/tests/NodeSmokeTest.cpp +++ b/tests/NodeSmokeTest.cpp @@ -301,7 +301,7 @@ DEFINE_TEST(NodeSmokeTestRead, "ModBaseCallerNode") { for (const auto& device_string : modbase_devices) { auto caller = dorado::create_modbase_caller({remora_model, remora_model_6mA}, batch_size, device_string); - for (int i = 0; i < default_params.remora_runners_per_caller; i++) { + for (int i = 0; i < default_params.mod_base_runners_per_caller; i++) { remora_runners.push_back(std::make_unique(caller)); } } diff --git a/tests/test_simple_basecaller_execution.bat b/tests/test_simple_basecaller_execution.bat index 11cc7b9a..8d7ecb3b 100644 --- a/tests/test_simple_basecaller_execution.bat +++ b/tests/test_simple_basecaller_execution.bat @@ -23,6 +23,14 @@ echo dorado duplex basespace test stage %dorado_bin% duplex basespace tests/data/basespace/pairs.bam --threads 1 --pairs tests/data/basespace/pairs.txt > calls.bam if %errorlevel% neq 0 exit /b %errorlevel% +echo dorado duplex hac complex +%dorado_bin% duplex hac tests/data/duplex/pod5 --threads 1 > $output_dir/duplex_calls.bam +if %errorlevel% neq 0 exit /b %errorlevel% + +echo dorado duplex hac complex with mods +%dorado_bin% duplex hac,5mCG_5hmCG tests/data/duplex/pod5 --threads 1 > $output_dir/duplex_calls_mods.bam +if %errorlevel% neq 0 exit /b %errorlevel% + echo dorado demux test stage %dorado_bin% demux tests/data/barcode_demux/double_end_variant/EXP-PBC096_BC83.fastq --threads 1 --kit-name EXP-PBC096 --output-dir ./demux if %errorlevel% neq 0 exit /b %errorlevel% diff --git a/tests/test_simple_basecaller_execution.sh b/tests/test_simple_basecaller_execution.sh index 137b2feb..4be0c369 100755 --- a/tests/test_simple_basecaller_execution.sh +++ b/tests/test_simple_basecaller_execution.sh @@ -153,7 +153,7 @@ if ! uname -r | grep -q tegra; then samtools quickcheck -u $output_dir/duplex_calls.bam num_duplex_reads=$(samtools view $output_dir/duplex_calls.bam | grep dx:i:1 | wc -l | awk '{print $1}') if [[ $num_duplex_reads -ne "2" ]]; then - echo "Duplex basecalling missing reads." + echo "Duplex basecalling missing reads - in-line" exit 1 fi @@ -162,7 +162,25 @@ if ! uname -r | grep -q tegra; then samtools quickcheck -u $output_dir/duplex_calls.bam num_duplex_reads=$(samtools view $output_dir/duplex_calls.bam | grep dx:i:1 | wc -l | awk '{print $1}') if [[ $num_duplex_reads -ne "2" ]]; then - echo "Duplex basecalling missing reads." + echo "Duplex basecalling missing reads - pairs file" + exit 1 + fi + + echo dorado in-line duplex from model complex + $dorado_bin duplex hac@v4.2.0 $data_dir/duplex/pod5 > $output_dir/duplex_calls_complex.bam + samtools quickcheck -u $output_dir/duplex_calls_complex.bam + num_duplex_reads=$(samtools view $output_dir/duplex_calls_complex.bam | grep dx:i:1 | wc -l | awk '{print $1}') + if [[ $num_duplex_reads -ne "2" ]]; then + echo "Duplex basecalling missing reads - model complex" + exit 1 + fi + + echo dorado in-line modbase duplex from model complex + $dorado_bin duplex hac,5mCG_5hmCG $data_dir/duplex/pod5 > $output_dir/duplex_calls_mods.bam + samtools quickcheck -u $output_dir/duplex_calls_mods.bam + num_duplex_reads=$(samtools view $output_dir/duplex_calls_mods.bam | grep dx:i:1 | wc -l | awk '{print $1}') + if [[ $num_duplex_reads -ne "2" ]]; then + echo "Duplex basecalling missing reads - mods" exit 1 fi fi