Skip to content

Commit

Permalink
Merge branch 'jdaw/rna-no-trim' into 'master'
Browse files: browse the repository at this point in the history
Control RNA adapter trimming

See merge request machine-learning/dorado!782
  • Loading branch information
tijyojwad committed Dec 21, 2023
2 parents a510d53 + e87a8e9 commit e42761c
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 8 deletions.
9 changes: 5 additions & 4 deletions dorado/api/pipeline_creation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
std::vector<modbase::RunnerPtr>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
bool trim_adapter,
int scaler_node_threads,
bool enable_read_splitter,
int splitter_node_threads,
Expand Down Expand Up @@ -57,9 +58,9 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
current_node_handle = rna_splitter_node;
}

auto scaler_node =
pipeline_desc.add_node<ScalerNode>({}, model_config.signal_norm_params,
model_config.sample_type, scaler_node_threads, 1000);
auto scaler_node = pipeline_desc.add_node<ScalerNode>({}, model_config.signal_norm_params,
model_config.sample_type, trim_adapter,
scaler_node_threads, 1000);
if (current_node_handle != PipelineDescriptor::InvalidNodeHandle) {
pipeline_desc.add_node_sink(current_node_handle, scaler_node);
} else {
Expand Down Expand Up @@ -174,7 +175,7 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc,
model_name, 1000, "BasecallerNode", mean_qscore_start_pos);

auto scaler_node = pipeline_desc.add_node<ScalerNode>(
{basecaller_node}, model_config.signal_norm_params, basecall::SampleType::DNA,
{basecaller_node}, model_config.signal_norm_params, basecall::SampleType::DNA, false,
scaler_node_threads, 1000);

// if we've been provided a source node, connect it to the start of our pipeline
Expand Down
1 change: 1 addition & 0 deletions dorado/api/pipeline_creation.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
std::vector<modbase::RunnerPtr>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
bool trim_adapter,
int scaler_node_threads,
bool enable_read_splitter,
int splitter_node_threads,
Expand Down
2 changes: 1 addition & 1 deletion dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ void setup(std::vector<std::string> args,
}
pipelines::create_simplex_pipeline(
pipeline_desc, std::move(runners), std::move(remora_runners), overlap,
mean_qscore_start_pos, thread_allocations.scaler_node_threads,
mean_qscore_start_pos, !adapter_no_trim, thread_allocations.scaler_node_threads,
true /* Enable read splitting */, thread_allocations.splitter_node_threads,
thread_allocations.remora_threads, current_sink_node,
PipelineDescriptor::InvalidNodeHandle);
Expand Down
6 changes: 4 additions & 2 deletions dorado/read_pipeline/ScalerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void ScalerNode::worker_thread() {
bool is_rna = (m_model_type == SampleType::RNA002 || m_model_type == SampleType::RNA004);
// Trim adapter for RNA first before scaling.
int trim_start = 0;
if (is_rna) {
if (is_rna && m_trim_adapter) {
trim_start = determine_rna_adapter_pos(*read, m_model_type);
read->read_common.raw_data =
read->read_common.raw_data.index({Slice(trim_start, at::indexing::None)});
Expand Down Expand Up @@ -202,12 +202,14 @@ void ScalerNode::worker_thread() {

ScalerNode::ScalerNode(const SignalNormalisationParams& config,
SampleType model_type,
bool trim_adapter,
int num_worker_threads,
size_t max_reads)
: MessageSink(max_reads),
m_num_worker_threads(num_worker_threads),
m_scaling_params(config),
m_model_type(model_type) {
m_model_type(model_type),
m_trim_adapter(trim_adapter) {
start_threads();
}

Expand Down
2 changes: 2 additions & 0 deletions dorado/read_pipeline/ScalerNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ScalerNode : public MessageSink {
public:
ScalerNode(const basecall::SignalNormalisationParams& config,
basecall::SampleType model_type,
bool trim_adapter,
int num_worker_threads,
size_t max_reads);
~ScalerNode() { terminate_impl(); }
Expand All @@ -34,6 +35,7 @@ class ScalerNode : public MessageSink {

basecall::SignalNormalisationParams m_scaling_params;
const basecall::SampleType m_model_type;
const bool m_trim_adapter;

std::pair<float, float> med_mad(const at::Tensor& x);
std::pair<float, float> normalisation(const at::Tensor& x);
Expand Down
3 changes: 2 additions & 1 deletion tests/NodeSmokeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ DEFINE_TEST(NodeSmokeTestRead, "ScalerNode") {
auto model_type =
GENERATE(dorado::basecall::SampleType::DNA, dorado::basecall::SampleType::RNA002,
dorado::basecall::SampleType::RNA004);
auto trim_adapter = GENERATE(true, false);
CAPTURE(pipeline_restart);
CAPTURE(model_type);

Expand All @@ -187,7 +188,7 @@ DEFINE_TEST(NodeSmokeTestRead, "ScalerNode") {
config.quantile.quantile_b = 0.9f;
config.quantile.shift_multiplier = 0.51f;
config.quantile.scale_multiplier = 0.53f;
run_smoke_test<dorado::ScalerNode>(config, model_type, 2, 1000);
run_smoke_test<dorado::ScalerNode>(config, model_type, trim_adapter, 2, 1000);
}

DEFINE_TEST(NodeSmokeTestRead, "BasecallerNode") {
Expand Down

0 comments on commit e42761c

Please sign in to comment.