Skip to content

Commit

Permalink
Merge branch 'jdaw/fix-modbase-trim-reverse' into 'master'
Browse files Browse the repository at this point in the history
Correctly trim modbase tags for reverse strand alignments

Closes DOR-523

See merge request machine-learning/dorado!797

(cherry picked from commit 59a445b)

c562c9a Correctly trim modbase tags for reverse strand alignments
  • Loading branch information
tijyojwad committed Jan 17, 2024
1 parent 9959654 commit 8c2d004
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 2 deletions.
17 changes: 15 additions & 2 deletions dorado/demux/Trimmer.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "Trimmer.h"

#include "utils/bam_utils.h"
#include "utils/sequence_utils.h"
#include "utils/trim.h"

#include <ATen/ATen.h>
Expand Down Expand Up @@ -30,6 +31,13 @@ void trim_torch_tensor(at::Tensor& raw_data, std::pair<uint64_t,uint64_t> sample
raw_data = raw_data.index({Slice(sample_trim_interval.first, sample_trim_interval.second)});
}

// Mirrors a [first, second) trim interval onto the opposite strand of a sequence
// of length `seqlen`. Adapter/barcode trim intervals are computed on the sequence
// as stored in the record, whereas modbase tags are expressed relative to the
// originally basecalled sequence; for reverse complemented alignments the interval
// must therefore be flipped before it is used to trim the modbase tags.
std::pair<int, int> reverse_complement_interval(const std::pair<int, int>& interval, int seqlen) {
    const auto& [begin, end] = interval;
    // The end of the forward interval becomes the start of the mirrored one.
    const int mirrored_begin = seqlen - end;
    const int mirrored_end = seqlen - begin;
    return std::make_pair(mirrored_begin, mirrored_end);
}

} // namespace

namespace dorado {
Expand Down Expand Up @@ -122,6 +130,8 @@ std::pair<int, int> Trimmer::determine_trim_interval(const AdapterScoreResult& r
BamPtr Trimmer::trim_sequence(BamPtr input, std::pair<int, int> trim_interval) {
bam1_t* input_record = input.get();

bool is_seq_reversed = input_record->core.flag & BAM_FREVERSE;

// Fetch components that need to be trimmed.
std::string seq = utils::extract_sequence(input_record);
std::vector<uint8_t> qual = utils::extract_quality(input_record);
Expand All @@ -141,8 +151,10 @@ BamPtr Trimmer::trim_sequence(BamPtr input, std::pair<int, int> trim_interval) {
// |---------------------- ns ------------------|
// |----ts----|--------moves signal-------------|
ns = int(trimmed_moves.size() * stride) + ts;
auto [trimmed_modbase_str, trimmed_modbase_probs] =
utils::trim_modbase_info(seq, modbase_str, modbase_probs, trim_interval);
auto [trimmed_modbase_str, trimmed_modbase_probs] = utils::trim_modbase_info(
is_seq_reversed ? utils::reverse_complement(seq) : seq, modbase_str, modbase_probs,
is_seq_reversed ? reverse_complement_interval(trim_interval, int(seq.length()))
: trim_interval);
auto n_cigar = input_record->core.n_cigar;
std::vector<uint32_t> ops;
uint32_t ref_pos_consumed = 0;
Expand Down Expand Up @@ -180,6 +192,7 @@ BamPtr Trimmer::trim_sequence(BamPtr input, std::pair<int, int> trim_interval) {
bam_aux_del(out_record, bam_aux_get(out_record, "ML"));
bam_aux_update_array(out_record, "ML", 'C', int(trimmed_modbase_probs.size()),
(uint8_t*)trimmed_modbase_probs.data());
bam_aux_update_int(out_record, "MN", trimmed_seq.length());
}

bam_aux_update_int(out_record, "ts", ts);
Expand Down
7 changes: 7 additions & 0 deletions dorado/read_pipeline/HtsWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ int HtsWriter::write(bam1_t* const record) {
}
m_primary = m_total - m_secondary - m_supplementary - m_unmapped;

// Verify that the MN tag, if it exists, and the sequence length are in sync.
if (auto tag = bam_aux_get(record, "MN"); tag != nullptr) {
if (bam_aux2i(tag) != record->core.l_qseq) {
throw std::runtime_error("MN tag and sequence length are not in sync.");
};
}

// FIXME -- HtsWriter is constructed in a state where attempting to write
// will segfault, since set_and_write_header has to have been called
// in order to set m_header.
Expand Down
30 changes: 30 additions & 0 deletions tests/TrimTest.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
#include "utils/trim.h"

#include "TestUtils.h"
#include "demux/Trimmer.h"
#include "read_pipeline/HtsReader.h"

#include <ATen/ATen.h>
#include <catch2/catch.hpp>
#include <htslib/sam.h>

#include <filesystem>
#include <random>

// Catch2 string matcher used with CHECK_THAT assertions.
using Catch::Matchers::Equals;
// Shorthand for the ATen tensor slicing helper.
using Slice = at::indexing::Slice;
using namespace dorado;

// Catch2 tag string passed as the second argument of the TEST_CASEs below.
#define TEST_GROUP "[utils][trim]"

// Short alias for building filesystem paths (e.g. to test data files).
namespace fs = std::filesystem;

TEST_CASE("Test trim signal", TEST_GROUP) {
constexpr int signal_len = 2000;

Expand Down Expand Up @@ -161,3 +170,24 @@ TEST_CASE("Test trim mod base info", TEST_GROUP) {
CHECK(probs.size() == 0);
}
}

// Regression test for trimming a reverse strand record read from a BAM file.
// Trimming of reverse strand records requires the modbase tags to be treated
// differently since they are written relative to the original sequence that
// was basecalled. The test checks that, after trimming:
//   * the sequence length equals the width of the trim interval,
//   * the MN tag matches the trimmed sequence length,
//   * the MM tag holds the expected trimmed modbase string.
TEST_CASE("Test trim of reverse strand record in BAM", TEST_GROUP) {
    const auto data_dir = fs::path(get_data_dir("trimmer"));
    const auto bam_file = data_dir / "reverse_strand_record.bam";
    HtsReader reader(bam_file.string(), std::nullopt);
    reader.read();
    auto &record = reader.record;

    Trimmer trimmer;
    const std::pair<int, int> trim_interval = {72, 647};
    auto trimmed_record = trimmer.trim_sequence(std::move(record), trim_interval);
    auto seqlen = trimmed_record->core.l_qseq;

    CHECK(seqlen == (trim_interval.second - trim_interval.first));

    // bam_aux_get returns a null pointer when the tag is absent, and bam_aux2i /
    // bam_aux2Z dereference their argument. REQUIRE the tag first so a missing
    // tag is reported as a test failure rather than crashing the test binary.
    const auto *mn_tag = bam_aux_get(trimmed_record.get(), "MN");
    REQUIRE(mn_tag != nullptr);
    CHECK(bam_aux2i(mn_tag) == seqlen);

    const auto *mm_tag = bam_aux_get(trimmed_record.get(), "MM");
    REQUIRE(mm_tag != nullptr);
    CHECK_THAT(bam_aux2Z(mm_tag), Equals("C+h?,28,24;C+m?,28,24;"));
}
Binary file added tests/data/trimmer/reverse_strand_record.bam
Binary file not shown.

0 comments on commit 8c2d004

Please sign in to comment.