diff --git a/dorado/data_loader/DataLoader.cpp b/dorado/data_loader/DataLoader.cpp index 07e7dae3..fa7c0d16 100644 --- a/dorado/data_loader/DataLoader.cpp +++ b/dorado/data_loader/DataLoader.cpp @@ -219,6 +219,20 @@ SimplexReadPtr process_pod5_read( new_read->read_common.experiment_id = run_info_data->experiment_name; new_read->read_common.is_duplex = false; + pod5_end_reason_t end_reason_value{POD5_END_REASON_UNKNOWN}; + char end_reason_string_value[200]; + size_t end_reason_string_value_size = sizeof(end_reason_string_value); + + pod5_error_t pod5_ret = + pod5_get_end_reason(batch, read_data.end_reason, &end_reason_value, + end_reason_string_value, &end_reason_string_value_size); + if (pod5_ret != POD5_OK) { + spdlog::error("Failed to get read end_reason {} {}", row, pod5_get_error_string()); + } else if (end_reason_value == POD5_END_REASON_UNBLOCK_MUX_CHANGE || + end_reason_value == POD5_END_REASON_MUX_CHANGE) { + new_read->read_common.attributes.is_end_reason_mux_change = true; + } + // Determine the time sorted predecessor of the read // if that information is available (primarily used for offline // duplex runs). diff --git a/dorado/read_pipeline/BasecallerNode.cpp b/dorado/read_pipeline/BasecallerNode.cpp index af65f70b..2014d3b9 100644 --- a/dorado/read_pipeline/BasecallerNode.cpp +++ b/dorado/read_pipeline/BasecallerNode.cpp @@ -2,6 +2,7 @@ #include "basecall/CRFModelConfig.h" #include "basecall/ModelRunnerBase.h" +#include "read_utils.h" #include "stitch.h" #include "utils/stats.h" @@ -203,6 +204,9 @@ void BasecallerNode::working_reads_manager() { // Chunks have ownership of the working read, so destroy them to avoid a leak. working_read->called_chunks.clear(); + // Trim reads which are affected by mux change and unblocking + utils::mux_change_trim_read(read_common_data); + // Cleanup the working read. 
/// Find the index of the last base to keep when trimming the degraded tail of a
/// mux_change read.
///
/// The qstring is walked from the last base towards the first. Each base adds a score to
/// a running total according to its quality category:
///   low  [0, 7]   -> -1
///   mid  (7, 12]  -> +1
///   high (12, ..] -> +10
/// The position at which the running total is lowest (argmin, ties resolved towards the
/// start of the read) marks the first base of the degraded tail; the index immediately
/// before that position is returned.
///
/// @param qstring Quality string whose characters are qscore + 33.
/// @return `qstring.size() - 1` when no base should be trimmed, otherwise the index just
///         before the first trimmed base. This is `-1` when every base would be trimmed
///         and also for an empty qstring.
int64_t find_mux_change_trim_seq_index(const std::string& qstring) {
    const auto size = static_cast<int64_t>(qstring.size());

    // Thresholds low:[0, 7], mid:(7, 12], high:(12, 50]
    // The +33 offset is folded into the thresholds so that the qstring characters can be
    // compared directly without subtracting 33 from each of them.
    constexpr int kLowThreshold = 7 + 33;
    constexpr int kHighThreshold = 12 + 33;
    // Scores [-1, 1, 10]
    constexpr int kLowScore = -1;
    constexpr int kMidScore = 1;
    constexpr int kHighScore = 10;
    static_assert(kLowScore == -1,
                  "The early exit below assumes a base can lower the sum by at most 1");

    int64_t trim_index = size - 1;  // index just before the minimum of the cumulative sum
    int cum_sum = 0;                // running total of the category scores
    int cum_sum_min = -1;           // lowest running total that still counts as a minimum

    for (int64_t i = size - 1; i >= 0; --i) {
        // The char is used as-is; the -33 is already accounted for in the thresholds.
        const int qs = static_cast<int>(qstring[i]);

        if (qs <= kLowThreshold) {
            cum_sum += kLowScore;
        } else if (qs <= kHighThreshold) {
            cum_sum += kMidScore;
        } else {
            cum_sum += kHighScore;
        }

        if (cum_sum <= cum_sum_min) {
            cum_sum_min = cum_sum;
            trim_index = i - 1;
        }

        // Early exit: only `i` bases remain and each lowers the sum by at most 1, so the
        // running total can no longer fall to (or below) the current minimum.
        if (cum_sum > i) {
            break;
        }
    }
    return trim_index;
}
no need to count leading trimmed samples + const int64_t trim_moves_idx = + utils::sequence_to_move_table_index(read_common.moves, trim_seq_idx, sequence_size); + + if (trim_moves_idx < 0) { + spdlog::trace("mux_change_trimming {} - move table index failed", read_common.read_id); + return; + } + read_common.moves.resize(trim_moves_idx); + + // Trim the sequence and qstring + const std::pair trim_interval = {0, int(trim_seq_idx)}; + read_common.seq = utils::trim_sequence(read_common.seq, trim_interval); + read_common.qstring = utils::trim_sequence(read_common.qstring, trim_interval); + + // Trim the signal + const size_t trim_signal_idx = read_common.moves.size() * read_common.model_stride; + read_common.raw_data = read_common.raw_data.index({Slice(0, trim_signal_idx)}); + + spdlog::trace("mux_change_trimming {} - seq(before:{} after:{} net:-{})", read_common.read_id, + sequence_size, trim_seq_idx + 1, sequence_size - trim_seq_idx - 1); +} + } // namespace dorado::utils diff --git a/dorado/read_pipeline/read_utils.h b/dorado/read_pipeline/read_utils.h index 3a1ebfdd..e8267fc4 100644 --- a/dorado/read_pipeline/read_utils.h +++ b/dorado/read_pipeline/read_utils.h @@ -4,4 +4,12 @@ namespace dorado::utils { SimplexReadPtr shallow_copy_read(const SimplexRead& read); + +// Find the trimming index for degraded ends of a mux_change read. +int64_t find_mux_change_trim_seq_index(const std::string& qstring); + +// Given a read, only trims reads which have end_reason `mux_change` or `unblock_mux_change` as +// the end of the sequence has been degraded. 
+void mux_change_trim_read(ReadCommon& read_common); + } // namespace dorado::utils diff --git a/dorado/utils/sequence_utils.cpp b/dorado/utils/sequence_utils.cpp index 481139b2..896a12e4 100644 --- a/dorado/utils/sequence_utils.cpp +++ b/dorado/utils/sequence_utils.cpp @@ -6,13 +6,17 @@ #include #include #include +#include #include #include +#include #include +#include #include #include #include +#include #include namespace { @@ -189,6 +193,54 @@ std::vector sequence_to_ints(const std::string& sequence) { return sequence_ints; } +int64_t sequence_to_move_table_index(const std::vector& move_vals, + int64_t sequence_index, + int64_t sequence_size) { + const int64_t moves_sz = static_cast(move_vals.size()); + // Check out-of-bounds and input consistency + const bool oob_moves = sequence_index >= moves_sz; + const bool oob_seq = sequence_index >= sequence_size; + const bool size_invalid = sequence_size > moves_sz; + + if (move_vals.empty() || oob_moves || oob_seq || size_invalid) { + spdlog::trace( + "sequence_to_move_table_index - bad input " + "seq_index:{} seq_size:{} move.size:{} - reason empty_moves: {} " + "oob_moves: {} oob_seq {} size_invalid: {}", + sequence_index, sequence_size, moves_sz, move_vals.empty(), oob_moves, oob_seq, + size_invalid); + return -1; + } + + if (sequence_index <= sequence_size / 2) { + // Start with -1 because as soon as the first move_val==1 is encountered, + // we have moved to the first base. + int64_t seq_base_pos = -1; + for (int64_t i = 0; i < moves_sz; i++) { + if (move_vals[i] == 1) { + seq_base_pos++; + // seq_base_pos always > 0 + if (seq_base_pos == sequence_index) { + return i; + } + } + } + } else { + // Start with size because as soon as the first move_val==1 is encountered, + // we have moved to the last index (size - 1). 
+ int64_t seq_base_pos = sequence_size; + for (int64_t i = moves_sz - 1; i >= 0; --i) { + if (move_vals[i] == 1) { + seq_base_pos--; + if (seq_base_pos == sequence_index) { + return i; + } + } + } + } + return -1; +} + // Convert a move table to an array of the indices of the start/end of each base in the signal std::vector moves_to_map(const std::vector& moves, size_t block_stride, diff --git a/dorado/utils/sequence_utils.h b/dorado/utils/sequence_utils.h index e01868b1..75a4ece7 100644 --- a/dorado/utils/sequence_utils.h +++ b/dorado/utils/sequence_utils.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -24,6 +25,11 @@ inline int base_to_int(char c) { return 0b11 & ((c >> 2) ^ (c >> 1)); } // No checking is performed on the input std::vector sequence_to_ints(const std::string& sequence); +// Find the move table index for a given sequence index +int64_t sequence_to_move_table_index(const std::vector& move_vals, + int64_t sequence_index, + int64_t sequence_size); + // Convert move table to vector of indices std::vector moves_to_map(const std::vector& moves, size_t block_stride, diff --git a/tests/SequenceUtilsTest.cpp b/tests/SequenceUtilsTest.cpp index 51f0e650..32f2ab93 100644 --- a/tests/SequenceUtilsTest.cpp +++ b/tests/SequenceUtilsTest.cpp @@ -2,7 +2,9 @@ #include +#include #include +#include #define TEST_GROUP "[seq_utils]" @@ -176,3 +178,47 @@ TEST_CASE(TEST_GROUP "find rna polya - within search", TEST_GROUP) { const size_t res = dorado::utils::find_rna_polya(seq); CHECK(expected_index == res); } + +TEST_CASE("Test sequence to move table index", TEST_GROUP) { + SECTION("Happy path") { + // ---------------- seq index: 0, 1, , , , 2, 3, 4, , , 5, , 6, 7, + // ---------------- moves index: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16 + const std::vector move = {1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0}; + const size_t seq_size = move_cum_sums(move).back(); + + auto [seq_index, expected] = GENERATE(table({ + 
// Test helper: convert a list of qscores into a qstring by adding the +33 offset to each
// score and appending the result as a character.
// NOTE(review): the element type is assumed to be int8_t - confirm against the callers.
std::string to_qstr(const std::vector<int8_t>& qscore) {
    std::string qstr;
    qstr.reserve(qscore.size());
    for (const auto score : qscore) {
        qstr += static_cast<char>(score + 33);
    }
    return qstr;
}
+ std::vector vec(50, 1); + CHECK(utils::find_mux_change_trim_seq_index(to_qstr(vec)) == -1); + } + + SECTION("Trim skip single high base") { + std::vector vec(50, 50); + for (size_t i = 30; i < vec.size(); ++i) { + vec[i] = 1; + } + vec[vec.size() - 1] = 50; + CHECK(utils::find_mux_change_trim_seq_index(to_qstr(vec)) == 29); + } + + SECTION("Trim nothing") { + std::vector vec(120, 50); + CHECK(utils::find_mux_change_trim_seq_index(to_qstr(vec)) == 119); + } +}