Skip to content

Commit

Permalink
Merge branch 'rh/end_reason_trim' into 'master'
Browse files Browse the repository at this point in the history
End reason mux_change trimming

See merge request machine-learning/dorado!839
  • Loading branch information
tijyojwad committed Mar 11, 2024
2 parents 13ba5af + e2556ee commit 9d3af87
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 0 deletions.
14 changes: 14 additions & 0 deletions dorado/data_loader/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ SimplexReadPtr process_pod5_read(
new_read->read_common.experiment_id = run_info_data->experiment_name;
new_read->read_common.is_duplex = false;

pod5_end_reason_t end_reason_value{POD5_END_REASON_UNKNOWN};
char end_reason_string_value[200];
size_t end_reason_string_value_size = sizeof(end_reason_string_value);

pod5_error_t pod5_ret =
pod5_get_end_reason(batch, read_data.end_reason, &end_reason_value,
end_reason_string_value, &end_reason_string_value_size);
if (pod5_ret != POD5_OK) {
spdlog::error("Failed to get read end_reason {} {}", row, pod5_get_error_string());
} else if (end_reason_value == POD5_END_REASON_UNBLOCK_MUX_CHANGE ||
end_reason_value == POD5_END_REASON_MUX_CHANGE) {
new_read->read_common.attributes.is_end_reason_mux_change = true;
}

// Determine the time sorted predecessor of the read
// if that information is available (primarily used for offline
// duplex runs).
Expand Down
4 changes: 4 additions & 0 deletions dorado/read_pipeline/BasecallerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "basecall/CRFModelConfig.h"
#include "basecall/ModelRunnerBase.h"
#include "read_utils.h"
#include "stitch.h"
#include "utils/stats.h"

Expand Down Expand Up @@ -203,6 +204,9 @@ void BasecallerNode::working_reads_manager() {
// Chunks have ownership of the working read, so destroy them to avoid a leak.
working_read->called_chunks.clear();

// Trim reads which are affected by mux change and unblocking
utils::mux_change_trim_read(read_common_data);

// Cleanup the working read.
{
std::unique_lock<std::mutex> working_reads_lock(m_working_reads_mutex);
Expand Down
3 changes: 3 additions & 0 deletions dorado/read_pipeline/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
namespace dorado {

namespace details {

// Per-read acquisition attributes carried with each read. Every member has a default member
// initializer so a default-constructed Attributes never holds an indeterminate value.
struct Attributes {
    uint32_t mux{std::numeric_limits<uint32_t>::max()};  // Channel mux
    int32_t read_number{-1};  // Per-channel number of each read as it was acquired by minknow
    int32_t channel_number{-1};  // Channel ID
    std::string start_time{};    // Read acquisition start time
    std::string fast5_filename{};
    // Previously left uninitialised; zero is used as the "not yet set" value.
    uint64_t num_samples{0};
    // Indicates if this read had end reason `mux_change` or `unblock_mux_change`
    bool is_end_reason_mux_change{false};
};
} // namespace details

Expand Down
110 changes: 110 additions & 0 deletions dorado/read_pipeline/read_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
#include "read_utils.h"

#include "utils/math_utils.h"
#include "utils/sequence_utils.h"
#include "utils/trim.h"

#include <ATen/ATen.h>
#include <spdlog/spdlog.h>

#include <cmath>
#include <cstdint>
#include <optional>
#include <string_view>

using Slice = at::indexing::Slice;

namespace dorado::utils {
SimplexReadPtr shallow_copy_read(const SimplexRead& read) {
auto copy = std::make_unique<SimplexRead>();
Expand Down Expand Up @@ -46,4 +60,100 @@ SimplexReadPtr shallow_copy_read(const SimplexRead& read) {
return copy;
}

// Locate where the degraded tail of a mux_change read begins.
//
// The qstring is walked from its last base towards its first. Every base adds a score that
// depends on its quality band, and the position where the running total is lowest (argmin) is
// taken as the start of the degraded tail. The value returned is the index just before that
// position: size - 1 when the total never drops to -1 or below, and -1 when the whole string
// (or an empty string) is selected.
int64_t find_mux_change_trim_seq_index(const std::string& qstring) {
    // Quality bands are low:[0, 7], mid:(7, 12], high:(12, 50]. The limits already include the
    // phred+33 offset so the raw qstring characters can be compared without subtracting it.
    const int kPhredOffset = 33;
    const int kLowBandMax = 7 + kPhredOffset;
    const int kMidBandMax = 12 + kPhredOffset;

    // Per-band contribution to the running total. The early exit below is only valid while
    // the low band score is exactly -1.
    const int kLowBandScore = -1;
    const int kMidBandScore = 1;
    const int kHighBandScore = 10;

    const auto length = static_cast<int64_t>(qstring.size());

    int64_t keep_until = length - 1;  // result: index preceding the lowest running total
    int running_total = 0;            // sum of band scores seen so far
    int lowest_total = -1;            // the total must reach this before anything is trimmed

    for (int64_t pos = length; pos-- > 0;) {
        const int encoded_q = static_cast<int>(qstring[pos]);

        running_total += (encoded_q <= kLowBandMax)   ? kLowBandScore
                         : (encoded_q <= kMidBandMax) ? kMidBandScore
                                                      : kHighBandScore;

        if (running_total <= lowest_total) {
            lowest_total = running_total;
            keep_until = pos - 1;
        }

        // Only `pos` bases remain and each can lower the total by at most one, so once the
        // total exceeds `pos` it can never get back down to a new minimum.
        if (running_total > pos) {
            break;
        }
    }

    return keep_until;
}

// Trim the degraded tail of a read whose end reason was `mux_change` or `unblock_mux_change`.
//
// The read is left untouched when:
//   - it is not flagged as a mux change read,
//   - its qstring is shorter than 100 bases,
//   - the trim index falls within the first 30% of the read (excessive trimming),
//   - the trim index lies within the last kMinMuxChangeTrim bases (nothing worth trimming),
//   - the trim index cannot be mapped onto the move table.
// Otherwise the move table, sequence, qstring and raw signal are all truncated in place.
void mux_change_trim_read(ReadCommon& read_common) {
    if (!read_common.attributes.is_end_reason_mux_change) {
        return;
    }

    const auto sequence_size = static_cast<int64_t>(read_common.qstring.size());

    // Do nothing for zero or very short sequences
    if (sequence_size < 100) {
        return;
    }

    const int64_t trim_seq_idx = find_mux_change_trim_seq_index(read_common.qstring);

    // Excessive trimming - do nothing
    if (trim_seq_idx < std::floor(sequence_size * 0.3f)) {
        spdlog::trace("mux_change_trimming {} - size: {} trim: {} excessive trimming",
                      read_common.read_id, sequence_size, trim_seq_idx);
        return;
    }

    const int kMinMuxChangeTrim = 5;
    // Nothing to do
    if (trim_seq_idx >= sequence_size - kMinMuxChangeTrim) {
        // The format string has a single placeholder, so only the read id is passed.
        spdlog::trace("mux_change_trimming {} - no trim", read_common.read_id);
        return;
    }

    // Trim the move table - We only trim from the back so no need to count leading trimmed samples
    const int64_t trim_moves_idx =
            utils::sequence_to_move_table_index(read_common.moves, trim_seq_idx, sequence_size);

    if (trim_moves_idx < 0) {
        spdlog::trace("mux_change_trimming {} - move table index failed", read_common.read_id);
        return;
    }
    // Resizing to the index of the move that emits base `trim_seq_idx` drops that base and
    // everything after it, leaving `trim_seq_idx` bases in the move table.
    // NOTE(review): the tests for find_mux_change_trim_seq_index treat its result as the last
    // good base, whereas it is used here as an exclusive end - confirm that dropping that one
    // extra base is intended.
    read_common.moves.resize(trim_moves_idx);

    // Trim the sequence and qstring
    const std::pair<int, int> trim_interval = {0, static_cast<int>(trim_seq_idx)};
    read_common.seq = utils::trim_sequence(read_common.seq, trim_interval);
    read_common.qstring = utils::trim_sequence(read_common.qstring, trim_interval);

    // Trim the signal to the samples covered by the remaining moves
    const size_t trim_signal_idx = read_common.moves.size() * read_common.model_stride;
    read_common.raw_data = read_common.raw_data.index({Slice(0, trim_signal_idx)});

    // Report the sizes actually produced by the trim rather than values derived from the index,
    // so the trace cannot drift from what was done to the read.
    const auto trimmed_size = static_cast<int64_t>(read_common.qstring.size());
    spdlog::trace("mux_change_trimming {} - seq(before:{} after:{} net:-{})", read_common.read_id,
                  sequence_size, trimmed_size, sequence_size - trimmed_size);
}

} // namespace dorado::utils
8 changes: 8 additions & 0 deletions dorado/read_pipeline/read_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,12 @@

namespace dorado::utils {
// Returns a newly allocated SimplexRead built from `read`.
// NOTE(review): which fields are copied is defined in read_utils.cpp - confirm there.
SimplexReadPtr shallow_copy_read(const SimplexRead& read);

// Find the trimming index for degraded ends of a mux_change read.
// The qstring is scanned from its last base backwards. The result is qstring.size() - 1 when
// no degraded tail is detected and -1 when the whole (or an empty) qstring is selected.
int64_t find_mux_change_trim_seq_index(const std::string& qstring);

// Given a read, only trims reads which have end_reason `mux_change` or `unblock_mux_change` as
// the end of the sequence has been degraded.
// The read is modified in place: moves, seq, qstring and raw_data are truncated. Reads shorter
// than 100 bases, and reads where the trim would be excessive or negligible, are left unchanged.
void mux_change_trim_read(ReadCommon& read_common);

} // namespace dorado::utils
52 changes: 52 additions & 0 deletions dorado/utils/sequence_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
#include <edlib.h>
#include <minimap.h>
#include <nvtx3/nvtx3.hpp>
#include <spdlog/spdlog.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <numeric>
#include <optional>
#include <vector>

namespace {
Expand Down Expand Up @@ -189,6 +193,54 @@ std::vector<int> sequence_to_ints(const std::string& sequence) {
return sequence_ints;
}

// Map a base index in the called sequence onto the move table.
//
// Returns the position in `move_vals` of the entry equal to 1 that emits base
// `sequence_index`, or -1 when the inputs are inconsistent (empty move table, index beyond the
// move table or the sequence, sequence longer than the move table) or no such entry exists.
int64_t sequence_to_move_table_index(const std::vector<uint8_t>& move_vals,
                                     int64_t sequence_index,
                                     int64_t sequence_size) {
    const auto num_moves = static_cast<int64_t>(move_vals.size());

    // Validate the request against both the move table and the stated sequence length.
    const bool no_moves = move_vals.empty();
    const bool index_past_moves = !(sequence_index < num_moves);
    const bool index_past_seq = !(sequence_index < sequence_size);
    const bool seq_longer_than_moves = num_moves < sequence_size;

    if (no_moves || index_past_moves || index_past_seq || seq_longer_than_moves) {
        spdlog::trace(
                "sequence_to_move_table_index - bad input "
                "seq_index:{} seq_size:{} move.size:{} - reason empty_moves: {} "
                "oob_moves: {} oob_seq {} size_invalid: {}",
                sequence_index, sequence_size, num_moves, no_moves, index_past_moves,
                index_past_seq, seq_longer_than_moves);
        return -1;
    }

    // Search from whichever end of the move table is expected to be closer to the target.
    const bool scan_from_front = sequence_index <= sequence_size / 2;

    if (scan_from_front) {
        // The n-th entry equal to 1 (counting from zero) emits base n.
        int64_t next_base = 0;
        for (int64_t pos = 0; pos < num_moves; ++pos) {
            if (move_vals[pos] != 1) {
                continue;
            }
            if (next_base == sequence_index) {
                return pos;
            }
            ++next_base;
        }
        return -1;
    }

    // Walking backwards, the first entry equal to 1 emits the last base (sequence_size - 1).
    int64_t base_idx = sequence_size;
    for (int64_t pos = num_moves - 1; pos >= 0; --pos) {
        if (move_vals[pos] != 1) {
            continue;
        }
        --base_idx;
        if (base_idx == sequence_index) {
            return pos;
        }
    }
    return -1;
}

// Convert a move table to an array of the indices of the start/end of each base in the signal
std::vector<uint64_t> moves_to_map(const std::vector<uint8_t>& moves,
size_t block_stride,
Expand Down
6 changes: 6 additions & 0 deletions dorado/utils/sequence_utils.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <array>
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
Expand All @@ -24,6 +25,11 @@ inline int base_to_int(char c) { return 0b11 & ((c >> 2) ^ (c >> 1)); }
// No checking is performed on the input
std::vector<int> sequence_to_ints(const std::string& sequence);

// Find the move table index for a given sequence index
// Returns the index in `move_vals` of the entry equal to 1 that corresponds to base
// `sequence_index`. Returns -1 if `move_vals` is empty, `sequence_index` is not smaller than
// both `move_vals.size()` and `sequence_size`, `sequence_size` exceeds `move_vals.size()`,
// or no matching entry is found.
int64_t sequence_to_move_table_index(const std::vector<uint8_t>& move_vals,
int64_t sequence_index,
int64_t sequence_size);

// Convert move table to vector of indices
std::vector<uint64_t> moves_to_map(const std::vector<uint8_t>& moves,
size_t block_stride,
Expand Down
46 changes: 46 additions & 0 deletions tests/SequenceUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

#include <catch2/catch.hpp>

#include <cstdint>
#include <cstdlib>
#include <optional>

#define TEST_GROUP "[seq_utils]"

Expand Down Expand Up @@ -176,3 +178,47 @@ TEST_CASE(TEST_GROUP "find rna polya - within search", TEST_GROUP) {
const size_t res = dorado::utils::find_rna_polya(seq);
CHECK(expected_index == res);
}

// Exercises sequence_to_move_table_index: a table-driven lookup over a hand-built move table,
// followed by three inconsistent inputs that must each produce a negative result.
TEST_CASE("Test sequence to move table index", TEST_GROUP) {
SECTION("Happy path") {
// The two rows below pair each base index with the move table position that emits it.
// ---------------- seq index: 0, 1, , , , 2, 3, 4, , , 5, , 6, 7,
// ---------------- moves index: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16
const std::vector<uint8_t> move = {1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0};
// The table holds eight entries equal to 1; seq_size is presumably 8 - confirm against
// move_cum_sums.
const size_t seq_size = move_cum_sums(move).back();

// Each tuple is (sequence index, expected move table index).
auto [seq_index, expected] = GENERATE(table<size_t, int64_t>({
std::make_tuple(0, 0),
std::make_tuple(1, 1),
std::make_tuple(2, 5),
std::make_tuple(3, 6),
std::make_tuple(4, 7),
std::make_tuple(5, 10),
std::make_tuple(6, 12),
std::make_tuple(7, 13),
}));

CAPTURE(seq_index);
const auto res = sequence_to_move_table_index(move, seq_index, seq_size);
CHECK(expected == res);
}

// An empty move table can never contain the requested base.
SECTION("Empty moves") {
const std::vector<uint8_t> move = {};
const auto res = sequence_to_move_table_index(move, 0, 0);
CHECK(res < 0);
}

// A sequence index beyond the stated sequence size must be rejected.
SECTION("Bad sequence index") {
const std::vector<uint8_t> move = {0, 1, 0, 1, 0};
const size_t seq_size = move_cum_sums(move).back();
const auto res = sequence_to_move_table_index(move, seq_size + 1, seq_size);
CHECK(res < 0);
}

// A sequence cannot be longer than the move table that produced it.
SECTION("Bad sequence size") {
const std::vector<uint8_t> move = {0, 1, 0, 1, 0};
const size_t bad_seq_size = move.size() + 1;
const auto res = sequence_to_move_table_index(move, 0, bad_seq_size);
CHECK(res < 0);
}
}
40 changes: 40 additions & 0 deletions tests/TrimTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
#include "TestUtils.h"
#include "demux/Trimmer.h"
#include "read_pipeline/HtsReader.h"
#include "read_pipeline/read_utils.h"

#include <ATen/ATen.h>
#include <catch2/catch.hpp>
#include <htslib/sam.h>

#include <filesystem>
#include <random>
#include <string>
#include <vector>

using Catch::Matchers::Equals;
using Slice = at::indexing::Slice;
Expand Down Expand Up @@ -191,3 +194,40 @@ TEST_CASE("Test trim of reverse strand record in BAM", TEST_GROUP) {
CHECK_THAT(bam_aux2Z(bam_aux_get(trimmed_record.get(), "MM")),
Equals("C+h?,28,24;C+m?,28,24;"));
}

// Test helper: encode raw quality scores as a quality string by adding the phred+33 offset
// to every score.
std::string to_qstr(std::vector<int8_t> qscore) {
    std::string encoded;
    for (const int8_t score : qscore) {
        encoded.push_back(static_cast<char>(score + 33));
    }
    return encoded;
}

// Checks find_mux_change_trim_seq_index on synthetic quality strings built with to_qstr.
// Quality 50 falls in the function's high band and quality 1 in its low band.
TEST_CASE("Test find_mux_change_trim_seq_index", TEST_GROUP) {
// High quality read with a low quality tail starting at base 40: the result is the index
// just before the tail.
SECTION("Trim simple") {
std::vector<int8_t> vec(50, 50);
for (size_t i = 40; i < vec.size(); ++i) {
vec[i] = 1;
}
CHECK(utils::find_mux_change_trim_seq_index(to_qstr(vec)) == 39);
}

// Entirely low quality: the whole read is selected, reported as -1.
SECTION("Trim all") {
std::vector<int8_t> vec(50, 1);
CHECK(utils::find_mux_change_trim_seq_index(to_qstr(vec)) == -1);
}

// A lone high quality base at the very end must not stop the low quality run before it
// from being found.
SECTION("Trim skip single high base") {
std::vector<int8_t> vec(50, 50);
for (size_t i = 30; i < vec.size(); ++i) {
vec[i] = 1;
}
vec[vec.size() - 1] = 50;
CHECK(utils::find_mux_change_trim_seq_index(to_qstr(vec)) == 29);
}

// Entirely high quality: the index of the last base is returned, i.e. nothing to trim.
SECTION("Trim nothing") {
std::vector<int8_t> vec(120, 50);
CHECK(utils::find_mux_change_trim_seq_index(to_qstr(vec)) == 119);
}
}

0 comments on commit 9d3af87

Please sign in to comment.