Skip to content

Commit

Permalink
Merge branch 'snurk/split-improve' into 'master'
Browse files Browse the repository at this point in the history
Fix oversplitting and improve splitting with pA-scaled models

See merge request machine-learning/dorado!728
  • Loading branch information
tijyojwad committed Dec 1, 2023
2 parents 4b743a1 + 5d484cf commit e9f060c
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 51 deletions.
7 changes: 4 additions & 3 deletions dorado/read_pipeline/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
first_node_handle = scaler_node;
}
current_node_handle = scaler_node;

auto basecaller_node = pipeline_desc.add_node<BasecallerNode>(
{}, std::move(runners), overlap, kBatchTimeoutMS, model_name, 1000, "BasecallerNode",
mean_qscore_start_pos);
Expand All @@ -76,7 +75,8 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,

// For DNA, read splitting happens after basecall.
if (enable_read_splitter && !is_rna) {
splitter::DuplexSplitSettings splitter_settings;
splitter::DuplexSplitSettings splitter_settings(model_config.signal_norm_params.strategy ==
ScalingStrategy::PA);
splitter_settings.simplex_mode = true;
auto dna_splitter = std::make_unique<const splitter::DuplexReadSplitter>(splitter_settings);
auto dna_splitter_node = pipeline_desc.add_node<ReadSplitNode>({}, std::move(dna_splitter),
Expand Down Expand Up @@ -161,7 +161,8 @@ void create_stereo_duplex_pipeline(
// If splitter_settings.enabled is set to false, the splitter node will act
// as a passthrough, meaning it won't perform any splitting operations and
// will just pass data through.
splitter::DuplexSplitSettings splitter_settings;
splitter::DuplexSplitSettings splitter_settings(model_config.signal_norm_params.strategy ==
ScalingStrategy::PA);
auto duplex_splitter = std::make_unique<const splitter::DuplexReadSplitter>(splitter_settings);
auto splitter_node = pipeline_desc.add_node<ReadSplitNode>(
{pairing_node}, std::move(duplex_splitter), splitter_node_threads, 1000);
Expand Down
71 changes: 58 additions & 13 deletions dorado/splitter/DuplexReadSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <ATen/ATen.h>
#include <spdlog/spdlog.h>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <iomanip>
Expand Down Expand Up @@ -100,11 +102,24 @@ std::optional<PosRange> check_rc_match(const std::string& seq,
return res;
}

// NB: Computes the literal (arithmetic) mean of the decoded quality values, not generally
// applicable.
// Quality characters are decoded by subtracting 33. The requested range is clipped to the
// length of qstring; it is expected (debug-asserted) to be non-empty after clipping.
float qscore_mean(const std::string& qstring, splitter::PosRange r) {
    constexpr int kQualityOffset = 33;

    const uint64_t first = r.first;
    // never read past the end of the quality string
    const uint64_t last = std::min<uint64_t>(r.second, qstring.size());
    assert(first < last);

    uint64_t total = 0;
    uint64_t pos = first;
    while (pos < last) {
        const char q = qstring[pos];
        assert(q > kQualityOffset);
        total += q - kQualityOffset;
        ++pos;
    }
    return static_cast<float>(static_cast<double>(total) / static_cast<double>(last - first));
}

} // namespace

namespace dorado::splitter {

//TODO consider precomputing and reusing ranges with high signal
struct DuplexReadSplitter::ExtRead {
SimplexReadPtr read;
at::Tensor data_as_float32;
Expand All @@ -130,26 +145,56 @@ PosRanges DuplexReadSplitter::possible_pore_regions(const DuplexReadSplitter::Ex
detect_pore_signal<float>(read.data_as_float32, m_settings.pore_thr,
m_settings.pore_cl_dist, m_settings.expect_pore_prefix);

PosRanges pore_regions;
std::vector<std::pair<float, PosRange>> candidate_regions;
for (auto pore_sample_range : pore_sample_ranges) {
auto move_start = pore_sample_range.first / read.read->read_common.model_stride;
auto move_end = pore_sample_range.second / read.read->read_common.model_stride;
assert(move_end >= move_start);
//NB move_start can get to move_sums.size(), because of the stride rounding?
if (move_start >= read.move_sums.size() || move_end >= read.move_sums.size() ||
read.move_sums[move_start] == 0) {
auto move_start = pore_sample_range.start_sample / read.read->read_common.model_stride;
auto move_end = pore_sample_range.end_sample / read.read->read_common.model_stride;
auto move_argmax = pore_sample_range.argmax_sample / read.read->read_common.model_stride;
assert(move_end >= move_argmax && move_argmax >= move_start);
if (move_end >= read.move_sums.size() || read.move_sums[move_start] == 0) {
//either at very end of the signal or basecalls have not started yet
continue;
}
auto start_pos = read.move_sums[move_start] - 1;
//NB. adding adapter length
//TODO check (- 1)
auto argmax_pos = read.move_sums[move_argmax] - 1;
auto end_pos = read.move_sums[move_end];
assert(end_pos > start_pos);
if (end_pos <= start_pos + m_settings.max_pore_region) {
pore_regions.push_back({start_pos, end_pos});
//check that detected cluster corresponds to not too many bases
if (end_pos > start_pos + m_settings.max_pore_region) {
continue;
}

assert(end_pos > argmax_pos && argmax_pos >= start_pos);
if (m_settings.use_argmax) {
//switch to position of max sample
start_pos = argmax_pos;
end_pos = argmax_pos + 1;
}

//check that mean qscore near pore is low
if (m_settings.qscore_check_span > 0 &&
qscore_mean(read.read->read_common.qstring,
{start_pos, start_pos + m_settings.qscore_check_span}) >
m_settings.mean_qscore_thr - std::numeric_limits<float>::epsilon()) {
continue;
}
candidate_regions.push_back({pore_sample_range.max_val, {start_pos, end_pos}});
}

//sorting by max signal value within the cluster
std::sort(candidate_regions.begin(), candidate_regions.end());
//take top candidates
PosRanges pore_regions;
for (size_t i = std::max(int64_t(candidate_regions.size()) - m_settings.top_candidates,
int64_t(0));
i < candidate_regions.size(); ++i) {
pore_regions.push_back(candidate_regions[i].second);
}
//sorting by first coordinate again
std::sort(pore_regions.begin(), pore_regions.end());

spdlog::trace("Detected {} potential pore regions in read {}", pore_regions.size(),
read.read->read_common.read_id);
return pore_regions;
}

Expand All @@ -158,7 +203,7 @@ bool DuplexReadSplitter::check_nearby_adapter(const SimplexRead& read,
int adapter_edist) const {
return find_best_adapter_match(m_settings.adapter, read.read_common.seq, adapter_edist,
//including spacer region in search
{r.first, std::min(r.second + m_settings.pore_adapter_range,
{r.first, std::min(r.second + m_settings.pore_adapter_span,
(uint64_t)read.read_common.seq.size())})
.has_value();
}
Expand Down
22 changes: 10 additions & 12 deletions dorado/splitter/RNAReadSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ RNAReadSplitter::ExtRead RNAReadSplitter::create_ext_read(SimplexReadPtr r) cons
detect_pore_signal<int16_t>(ext_read.read->read_common.raw_data, m_settings.pore_thr,
m_settings.pore_cl_dist, m_settings.expect_pore_prefix);
for (const auto& range : ext_read.possible_pore_regions) {
spdlog::trace("Pore range {}-{} {}", range.first, range.second,
spdlog::trace("Pore range {}-{} {}", range.start_sample, range.end_sample,
ext_read.read->read_common.read_id);
}
return ext_read;
}

std::vector<SimplexReadPtr> RNAReadSplitter::subreads(SimplexReadPtr read,
const PosRanges& spacers) const {
const SampleRanges<int16_t>& spacers) const {
std::vector<SimplexReadPtr> subreads;
subreads.reserve(spacers.size() + 1);

Expand All @@ -37,16 +37,16 @@ std::vector<SimplexReadPtr> RNAReadSplitter::subreads(SimplexReadPtr read,
return subreads;
}

uint64_t start_pos = 0;
uint64_t start_sample = 0;
for (const auto& r : spacers) {
if (start_pos < r.first) {
subreads.push_back(subread(*read, std::nullopt, {start_pos, r.first}));
if (start_sample < r.start_sample) {
subreads.push_back(subread(*read, std::nullopt, {start_sample, r.start_sample}));
}
start_pos = r.second;
start_sample = r.end_sample;
}
if (start_pos < read->read_common.get_raw_data_samples()) {
if (start_sample < read->read_common.get_raw_data_samples()) {
subreads.push_back(subread(*read, std::nullopt,
{start_pos, read->read_common.get_raw_data_samples()}));
{start_sample, read->read_common.get_raw_data_samples()}));
}

return subreads;
Expand All @@ -55,10 +55,8 @@ std::vector<SimplexReadPtr> RNAReadSplitter::subreads(SimplexReadPtr read,
std::vector<std::pair<std::string, RNAReadSplitter::SplitFinderF>>
RNAReadSplitter::build_split_finders() const {
std::vector<std::pair<std::string, SplitFinderF>> split_finders;
split_finders.push_back({"PORE_ADAPTER", [&](const ExtRead& read) {
return filter_ranges(read.possible_pore_regions,
[&](PosRange) { return true; });
}});
split_finders.push_back(
{"PORE_ADAPTER", [&](const ExtRead& read) { return read.possible_pore_regions; }});

return split_finders;
}
Expand Down
7 changes: 3 additions & 4 deletions dorado/splitter/RNAReadSplitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@ class RNAReadSplitter : public ReadSplitter {
//TODO consider precomputing and reusing ranges with high signal
struct ExtRead {
SimplexReadPtr read;
splitter::PosRanges possible_pore_regions;
splitter::SampleRanges<int16_t> possible_pore_regions;
};

using SplitFinderF = std::function<splitter::PosRanges(const ExtRead&)>;
using SplitFinderF = std::function<splitter::SampleRanges<int16_t>(const ExtRead&)>;

ExtRead create_ext_read(SimplexReadPtr r) const;
std::vector<splitter::PosRange> possible_pore_regions(const ExtRead& read) const;

std::vector<SimplexReadPtr> subreads(SimplexReadPtr read,
const splitter::PosRanges& spacers) const;
const splitter::SampleRanges<int16_t>& spacers) const;

std::vector<std::pair<std::string, SplitFinderF>> build_split_finders() const;

Expand Down
22 changes: 19 additions & 3 deletions dorado/splitter/ReadSplitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,18 @@ struct RNASplitSettings {
struct DuplexSplitSettings {
bool enabled = true;
bool simplex_mode = false;
float pore_thr = 2.2f;
uint64_t pore_cl_dist = 4000; // TODO maybe use frequency * 1sec here?
float pore_thr = 2.4f;
uint64_t pore_cl_dist = 500; // in samples
//maximal 'open pore' region to consider (bp)
uint64_t max_pore_region = 500;
//only use the position of the maximal signal as the tentative open pore
bool use_argmax = true;
//number of bases to check quality (starting with pore region start)
int qscore_check_span = 5;
//only take fixed number of candidates with maximal signal
int top_candidates = 10;
//filter tentative open pore regions with mean qscore higher than threshold
float mean_qscore_thr = 10.;
//usually template read region to the left of potential spacer region
uint64_t strand_end_flank = 1200;
//trim potentially erroneous (and/or PCR adapter) bases at end of query
Expand All @@ -40,7 +48,9 @@ struct DuplexSplitSettings {
float relaxed_flank_err = 0.275f;
int adapter_edist = 4;
int relaxed_adapter_edist = 8;
uint64_t pore_adapter_range = 100; //bp
//bp from end of tentative pore to end of adapter
//(~ max pore-adapter dist + adapter length)
uint64_t pore_adapter_span = 50;
//in bases
uint64_t expect_adapter_prefix = 200;
//in samples
Expand All @@ -54,6 +64,12 @@ struct DuplexSplitSettings {
//Sequence below corresponds to the current 'head' adapter 'AATGTACTTCGTTCAGTTACGTATTGCT'
// with 4bp clipped from the beginning (24bp left)
std::string adapter = "TACTTCGTTCAGTTACGTATTGCT";

explicit DuplexSplitSettings(bool pA_scaling) {
if (pA_scaling) {
pore_thr = 2.8f;
}
}
};

class ReadSplitter {
Expand Down
4 changes: 2 additions & 2 deletions dorado/splitter/splitter_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ __attribute__((optimize("O0")))

SimplexReadPtr subread(const SimplexRead& read,
std::optional<PosRange> seq_range,
PosRange signal_range) {
std::pair<uint64_t, uint64_t> signal_range) {
//TODO support mods
//NB: currently doesn't support mods
//assert(read.mod_base_info == nullptr && read.base_mod_probs.empty());
Expand Down Expand Up @@ -96,7 +96,7 @@ PosRanges merge_ranges(const PosRanges& ranges, uint64_t merge_dist) {
if (merged.empty() || r.first > merged.back().second + merge_dist) {
merged.push_back(r);
} else {
merged.back().second = r.second;
merged.back().second = std::max(r.second, merged.back().second);
}
}
return merged;
Expand Down
40 changes: 33 additions & 7 deletions dorado/splitter/splitter_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
#include <ATen/core/TensorBody.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <utility>
#include <vector>

namespace dorado::splitter {
Expand All @@ -30,25 +33,48 @@ SimplexReadPtr subread(const SimplexRead& read,
PosRange signal_range);

template <typename T>
std::vector<std::pair<uint64_t, uint64_t>> detect_pore_signal(const at::Tensor& signal,
T threshold,
uint64_t cluster_dist,
uint64_t ignore_prefix) {
std::vector<std::pair<uint64_t, uint64_t>> ans;
//Half-open range [start_sample, end_sample) of signal samples, plus the position
//and value of the maximal sample associated with it (as filled in by detect_pore_signal).
//T is the sample value type.
struct SampleRange {
//inclusive
uint64_t start_sample;
//exclusive
uint64_t end_sample;
//index of the sample that holds max_val
uint64_t argmax_sample;
//maximal sample value recorded for the range
T max_val;

//Stores the four values as given; no validation is performed here.
SampleRange(uint64_t start, uint64_t end, uint64_t argmax, T max)
: start_sample(start), end_sample(end), argmax_sample(argmax), max_val(max) {}
};

//Collection of SampleRange<T> entries; this is the return type of detect_pore_signal.
template <typename T>
using SampleRanges = std::vector<SampleRange<T>>;

template <typename T>
SampleRanges<T> detect_pore_signal(const at::Tensor& signal,
T threshold,
uint64_t cluster_dist,
uint64_t ignore_prefix) {
SampleRanges<T> ans;
auto pore_a = signal.accessor<T, 1>();
int64_t cl_start = -1;
int64_t cl_end = -1;

T cl_max = std::numeric_limits<T>::min();
int64_t cl_argmax = -1;
for (auto i = ignore_prefix; i < uint64_t(pore_a.size(0)); i++) {
if (pore_a[i] > threshold) {
//check if we need to start new cluster
if (cl_end == -1 || i > cl_end + cluster_dist) {
//report previous cluster
if (cl_end != -1) {
assert(cl_start != -1);
ans.push_back({cl_start, cl_end});
ans.push_back(SampleRange(cl_start, cl_end, cl_argmax, cl_max));
}
cl_start = i;
cl_max = std::numeric_limits<T>::min();
}
if (pore_a[i] >= cl_max) {
cl_max = pore_a[i];
cl_argmax = i;
}
cl_end = i + 1;
}
Expand All @@ -57,7 +83,7 @@ std::vector<std::pair<uint64_t, uint64_t>> detect_pore_signal(const at::Tensor&
if (cl_end != -1) {
assert(cl_start != -1);
assert(cl_start < pore_a.size(0) && cl_end <= pore_a.size(0));
ans.push_back({cl_start, cl_end});
ans.push_back(SampleRange(cl_start, cl_end, cl_argmax, cl_max));
}

return ans;
Expand Down
Loading

0 comments on commit e9f060c

Please sign in to comment.