diff --git a/dorado/cli/aligner.cpp b/dorado/cli/aligner.cpp index a1637f78..35999ffd 100644 --- a/dorado/cli/aligner.cpp +++ b/dorado/cli/aligner.cpp @@ -53,7 +53,14 @@ int aligner(int argc, char* argv[]) { .help("maximum number of reads to process (for debugging, 0=unlimited).") .default_value(0) .scan<'i', int>(); - parser.visible.add_argument("-v", "--verbose").default_value(false).implicit_value(true); + int verbosity = 0; + parser.visible.add_argument("-v", "--verbose") + .default_value(false) + .implicit_value(true) + .nargs(0) + .action([&](const auto&) { ++verbosity; }) + .append(); + cli::add_minimap2_arguments(parser, Aligner::dflt_options); try { @@ -67,7 +74,7 @@ int aligner(int argc, char* argv[]) { if (parser.visible.get("--verbose")) { mm_verbose = 3; - utils::SetDebugLogging(); + utils::SetVerboseLogging(static_cast(verbosity)); } auto index(parser.visible.get("index")); diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index de4c60f8..55d1ef10 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -150,7 +150,7 @@ void setup(std::vector args, if (estimate_poly_a) { current_sink_node = pipeline_desc.add_node( {current_sink_node}, std::thread::hardware_concurrency(), - PolyACalculator::get_model_type(model_name)); + is_rna_model(model_config)); } if (!barcode_kits.empty()) { current_sink_node = pipeline_desc.add_node( @@ -264,7 +264,13 @@ int basecaller(int argc, char* argv[]) { parser.visible.add_argument("data").help("the data directory or file (POD5/FAST5 format)."); - parser.visible.add_argument("-v", "--verbose").default_value(false).implicit_value(true); + int verbosity = 0; + parser.visible.add_argument("-v", "--verbose") + .default_value(false) + .implicit_value(true) + .nargs(0) + .action([&](const auto&) { ++verbosity; }) + .append(); parser.visible.add_argument("-x", "--device") .help("device string in format \"cuda:0,...,N\", \"cuda:all\", \"metal\", \"cpu\" " @@ -389,7 +395,7 @@ int basecaller(int argc, char* argv[]) { std::vector args(argv, argv + argc); if (parser.visible.get("--verbose")) { - utils::SetDebugLogging(); + utils::SetVerboseLogging(static_cast(verbosity)); } auto model = parser.visible.get("model"); diff --git a/dorado/cli/demux.cpp b/dorado/cli/demux.cpp index 2fd2bd64..f2d00f1a 100644 --- a/dorado/cli/demux.cpp +++ b/dorado/cli/demux.cpp @@ -69,7 +69,13 @@ int demuxer(int argc, char* argv[]) { parser.add_argument("-l", "--read-ids") .help("A file with a newline-delimited list of reads to demux.") .default_value(std::string("")); - parser.add_argument("-v", "--verbose").default_value(false).implicit_value(true); + int verbosity = 0; + parser.add_argument("-v", "--verbose") + .default_value(false) + .implicit_value(true) + .nargs(0) + .action([&](const auto&) { ++verbosity; }) + .append(); parser.add_argument("--emit-fastq") .help("Output in fastq format. Default is BAM.") .default_value(false) @@ -99,7 +105,7 @@ int demuxer(int argc, char* argv[]) { } if (parser.get("--verbose")) { - utils::SetDebugLogging(); + utils::SetVerboseLogging(static_cast(verbosity)); } auto reads(parser.get>("reads")); diff --git a/dorado/cli/download.cpp b/dorado/cli/download.cpp index 76464819..4d1aa8db 100644 --- a/dorado/cli/download.cpp +++ b/dorado/cli/download.cpp @@ -47,8 +47,6 @@ int download(int argc, char* argv[]) { argparse::ArgumentParser parser("dorado", DORADO_VERSION, argparse::default_arguments::help); - parser.add_argument("-v", "--verbose").default_value(false).implicit_value(true); - parser.add_argument("--model").default_value(std::string("all")).help("the model to download"); parser.add_argument("--directory") @@ -62,6 +60,14 @@ int download(int argc, char* argv[]) { .default_value(false) .implicit_value(true); + int verbosity = 0; + parser.add_argument("-v", "--verbose") + .default_value(false) + .implicit_value(true) + .nargs(0) + .action([&](const auto&) { ++verbosity; }) + .append(); + try { parser.parse_args(argc, argv); } catch (const std::exception& e) { @@ -72,7 +78,7 @@ int download(int argc, char* argv[]) { } if (parser.get("--verbose")) { - utils::SetDebugLogging(); + utils::SetVerboseLogging(static_cast(verbosity)); } auto list = parser.get("--list"); diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index 2d376748..67962aa2 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -110,7 +110,13 @@ int duplex(int argc, char* argv[]) { .help("Path to reference for alignment.") .default_value(std::string("")); - parser.visible.add_argument("-v", "--verbose").default_value(false).implicit_value(true); + int verbosity = 0; + parser.visible.add_argument("-v", "--verbose") + .default_value(false) + .implicit_value(true) + .nargs(0) + .action([&](const auto&) { ++verbosity; }) + .append(); cli::add_minimap2_arguments(parser, Aligner::dflt_options); cli::add_internal_arguments(parser); @@ -133,7 +139,7 @@ int duplex(int argc, char* argv[]) { const bool basespace_duplex = (model.compare("basespace") == 0); std::vector args(argv, argv + argc); if (parser.visible.get("--verbose")) { - utils::SetDebugLogging(); + utils::SetVerboseLogging(static_cast(verbosity)); } std::map template_complement_map; auto read_list = utils::load_read_list(parser.visible.get("--read-ids")); diff --git a/dorado/cli/summary.cpp b/dorado/cli/summary.cpp index 99816b3c..8da957fb 100644 --- a/dorado/cli/summary.cpp +++ b/dorado/cli/summary.cpp @@ -50,7 +50,13 @@ int summary(int argc, char *argv[]) { argparse::ArgumentParser parser("dorado", DORADO_VERSION, argparse::default_arguments::help); parser.add_argument("reads").help("SAM/BAM file produced by dorado basecaller."); parser.add_argument("-s", "--separator").default_value(std::string("\t")); - parser.add_argument("-v", "--verbose").default_value(false).implicit_value(true); + int verbosity = 0; + parser.add_argument("-v", "--verbose") + .default_value(false) + .implicit_value(true) + .nargs(0) + .action([&](const auto &) { ++verbosity; }) + .append(); try { parser.parse_args(argc, argv); @@ -62,7 +68,7 @@ int summary(int argc, char *argv[]) { } if (parser.get("--verbose")) { - utils::SetDebugLogging(); + utils::SetVerboseLogging(static_cast(verbosity)); } std::vector header = { diff --git a/dorado/nn/CRFModelConfig.cpp b/dorado/nn/CRFModelConfig.cpp index b50fd894..d0f84476 100644 --- a/dorado/nn/CRFModelConfig.cpp +++ b/dorado/nn/CRFModelConfig.cpp @@ -9,6 +9,18 @@ namespace dorado { +SampleType get_model_type(const std::string &model_name) { + if (model_name.find("rna004") != std::string::npos) { + return SampleType::RNA004; + } else if (model_name.find("rna002") != std::string::npos) { + return SampleType::RNA002; + } else if (model_name.find("dna") != std::string::npos) { + return SampleType::DNA; + } else { + throw std::runtime_error("Could not determine model type for " + model_name); + } +} + CRFModelConfig load_crf_model_config(const std::filesystem::path &path) { const auto config_toml = toml::parse(path / "config.toml"); @@ -96,6 +108,8 @@ CRFModelConfig load_crf_model_config(const std::filesystem::path &path) { config.signal_norm_params.quantile_scaling = false; } + config.sample_type = get_model_type(model_name); + return config; } diff --git a/dorado/nn/CRFModelConfig.h b/dorado/nn/CRFModelConfig.h index f01a1444..fb360c19 100644 --- a/dorado/nn/CRFModelConfig.h +++ b/dorado/nn/CRFModelConfig.h @@ -15,6 +15,12 @@ struct SignalNormalisationParams { bool quantile_scaling = true; }; +enum SampleType { + DNA, + RNA002, + RNA004, +}; + // Values extracted from config.toml used in construction of the model module. struct CRFModelConfig { float qscale = 1.0f; @@ -42,6 +48,8 @@ struct CRFModelConfig { // Start position for mean Q-score calculation for // short reads. int32_t mean_qscore_start_pos = -1; + + SampleType sample_type; }; CRFModelConfig load_crf_model_config(const std::filesystem::path& path); diff --git a/dorado/read_pipeline/Pipelines.cpp b/dorado/read_pipeline/Pipelines.cpp index 84a1c1ee..8331edaf 100644 --- a/dorado/read_pipeline/Pipelines.cpp +++ b/dorado/read_pipeline/Pipelines.cpp @@ -62,7 +62,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, auto scaler_node = pipeline_desc.add_node({basecaller_node}, model_config.signal_norm_params, - is_rna_model(model_config), scaler_node_threads); + model_config.sample_type, scaler_node_threads); // if we've been provided a source node, connect it to the start of our pipeline if (source_node_handle != PipelineDescriptor::InvalidNodeHandle) { @@ -130,8 +130,9 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, {splitter_node}, std::move(runners), adjusted_simplex_overlap, kSimplexBatchTimeoutMS, model_name, 1000, "BasecallerNode", true, mean_qscore_start_pos); - auto scaler_node = pipeline_desc.add_node( - {basecaller_node}, model_config.signal_norm_params, false, scaler_node_threads); + auto scaler_node = + pipeline_desc.add_node({basecaller_node}, model_config.signal_norm_params, + SampleType::DNA, scaler_node_threads); // if we've been provided a source node, connect it to the start of our pipeline if (source_node_handle != PipelineDescriptor::InvalidNodeHandle) { diff --git a/dorado/read_pipeline/PolyACalculator.cpp b/dorado/read_pipeline/PolyACalculator.cpp index b774bb8c..c3ffb6cf 100644 --- a/dorado/read_pipeline/PolyACalculator.cpp +++ b/dorado/read_pipeline/PolyACalculator.cpp @@ -256,57 +256,11 @@ SignalAnchorInfo determine_signal_anchor_and_strand_cdna(const dorado::SimplexRe return result; } -// Since the adapter in RNA is still DNA, the basecall quality of the adapter is poor because we -// infer with a model trained on RNA data. So finding a match of the adapter sequence in the RNA sequence -// doesn't work well. Instead, the raw signal is traversed to find a point -// where there's a jump in the median signal value, which is indicative of the -// transition from the DNA adapter to the RNA signal. The polyA will start right -// at that juncture. This function returns a struct with the strand -// direction (which is always reverse for dRNA), the signal anchor and the number of bases -// to omit from the tail length estimation due to any adapter effects. -SignalAnchorInfo determine_signal_anchor_and_strand_drna( - const dorado::SimplexRead& read, - dorado::PolyACalculator::ModelType model_type) { - static const std::unordered_map kOffsetMap = { - {dorado::PolyACalculator::ModelType::RNA002, 5000}, - {dorado::PolyACalculator::ModelType::RNA004, 1000}}; - static const std::unordered_map kMaxSignalPosMap = { - {dorado::PolyACalculator::ModelType::RNA002, 10000}, - {dorado::PolyACalculator::ModelType::RNA004, 5000}}; - - const int kWindowSize = 250; - const int kStride = 50; - const int kOffset = kOffsetMap.at(model_type); - const int kMaxSignalPos = kMaxSignalPosMap.at(model_type); - - const float kMinMedianForRNASignal = 0.f; - const float kMinMedianDiff = 1.f; - - int bp = -1; - int signal_len = read.read_common.get_raw_data_samples(); - auto sig_fp32 = read.read_common.raw_data.to(torch::kFloat); - float last_median = 0.f; - for (int i = kOffset; i < std::min(signal_len / 2, kMaxSignalPos); i += kStride) { - auto slice = sig_fp32.slice(0, std::max(0, i - kWindowSize / 2), - std::min(signal_len, i + kWindowSize / 2)); - float median = slice.median().item(); - if (median > kMinMedianForRNASignal) { - float diff = median - last_median; - if (diff > kMinMedianDiff) { - bp = i; - break; - } - } - last_median = median; - } - - spdlog::debug("Approx break point {}", bp); - - if (bp > 0) { - return SignalAnchorInfo{false, bp, 0}; - } else { - return SignalAnchorInfo{false, -1, 0}; - } +// RNA polyA appears at the beginning of the strand. Since the adapter +// for RNA has been trimmed off already, the polyA search can begin +// from the start of the signal. +SignalAnchorInfo determine_signal_anchor_and_strand_drna(const dorado::SimplexRead& read) { + return SignalAnchorInfo{false, 0, 0}; } } // namespace @@ -330,7 +284,7 @@ void PolyACalculator::worker_thread() { // Determine the strand direction, approximate base space anchor for the tail, and whether // the final length needs to be adjusted depending on the adapter sequence. auto [fwd, signal_anchor, trailing_Ts] = - m_is_rna ? determine_signal_anchor_and_strand_drna(*read, m_model_type) + m_is_rna ? determine_signal_anchor_and_strand_drna(*read) : determine_signal_anchor_and_strand_cdna(*read); if (signal_anchor >= 0) { @@ -358,7 +312,7 @@ void PolyACalculator::worker_thread() { // Update debug stats. total_tail_lengths_called += num_bases; ++num_called; - if (spdlog::get_level() == spdlog::level::debug) { + if (spdlog::get_level() <= spdlog::level::debug) { std::lock_guard lock(m_mutex); tail_length_counts[num_bases]++; } @@ -378,13 +332,8 @@ void PolyACalculator::worker_thread() { } } -PolyACalculator::PolyACalculator(size_t num_worker_threads, - PolyACalculator::ModelType model_type, - size_t max_reads) - : MessageSink(max_reads), - m_num_worker_threads(num_worker_threads), - m_is_rna(model_type == ModelType::RNA004 || model_type == ModelType::RNA002), - m_model_type(model_type) { +PolyACalculator::PolyACalculator(size_t num_worker_threads, bool is_rna, size_t max_reads) + : MessageSink(max_reads), m_num_worker_threads(num_worker_threads), m_is_rna(is_rna) { start_threads(); } @@ -410,14 +359,14 @@ void PolyACalculator::terminate_impl() { // Visualize a distribution of the tail lengths called. static bool done = false; - if (!done && spdlog::get_level() == spdlog::level::debug) { + if (!done && (spdlog::get_level() <= spdlog::level::debug)) { int max_val = -1; for (auto [k, v] : tail_length_counts) { max_val = std::max(v, max_val); } int factor = std::max(1, 1 + max_val / 100); for (auto [k, v] : tail_length_counts) { - spdlog::debug("{} : {}", k, std::string(v / factor, '*')); + spdlog::debug("{:03d} : {}", k, std::string(v / factor, '*')); } done = true; } @@ -437,16 +386,4 @@ stats::NamedStats PolyACalculator::sample_stats() const { return stats; } -PolyACalculator::ModelType PolyACalculator::get_model_type(const std::string& model_name) { - if (model_name.find("rna004") != std::string::npos) { - return PolyACalculator::ModelType::RNA004; - } else if (model_name.find("rna002") != std::string::npos) { - return PolyACalculator::ModelType::RNA002; - } else if (model_name.find("dna") != std::string::npos) { - return PolyACalculator::ModelType::DNA; - } else { - throw std::runtime_error("Could not determine model type for " + model_name); - } -} - } // namespace dorado diff --git a/dorado/read_pipeline/PolyACalculator.h b/dorado/read_pipeline/PolyACalculator.h index d66a3f60..0431601a 100644 --- a/dorado/read_pipeline/PolyACalculator.h +++ b/dorado/read_pipeline/PolyACalculator.h @@ -15,19 +15,12 @@ namespace dorado { class PolyACalculator : public MessageSink { public: - enum ModelType { - DNA, - RNA002, - RNA004, - }; - - PolyACalculator(size_t num_worker_threads, ModelType model_type, size_t max_reads = 1000); + PolyACalculator(size_t num_worker_threads, bool is_rna, size_t max_reads = 1000); ~PolyACalculator() { terminate_impl(); } std::string get_name() const override { return "PolyACalculator"; } stats::NamedStats sample_stats() const override; void terminate(const FlushOptions& flush_options) override { terminate_impl(); }; void restart() override; - static ModelType get_model_type(const std::string& model_name); private: void start_threads(); @@ -38,7 +31,6 @@ class PolyACalculator : public MessageSink { std::vector> m_workers; size_t m_num_worker_threads = 0; const bool m_is_rna; - const ModelType m_model_type; std::atomic total_tail_lengths_called{0}; std::atomic num_called{0}; std::atomic num_not_called{0}; diff --git a/dorado/read_pipeline/ScalerNode.cpp b/dorado/read_pipeline/ScalerNode.cpp index c8d63cf6..6f6d0ce3 100644 --- a/dorado/read_pipeline/ScalerNode.cpp +++ b/dorado/read_pipeline/ScalerNode.cpp @@ -6,10 +6,13 @@ #include #include +#include +#include #include +#include #include -#define EPS 1e-9f +static constexpr float EPS = 1e-9f; using namespace std::chrono_literals; using Slice = torch::indexing::Slice; @@ -37,6 +40,53 @@ std::pair ScalerNode::med_mad(const torch::Tensor& x) { return {med.item(), mad.item()}; } +// This function returns the approximate position where the DNA adapter +// in a dRNA read ends. The adapter location is determined by looking +// at the median signal value over a sliding window on the raw signal. +// RNA002 and RNA004 have different offsets and thresholds for the +// sliding window heuristic. +int determine_rna_adapter_pos(const dorado::SimplexRead& read, dorado::SampleType model_type) { + assert(read.read_common.raw_data.dtype() == torch::kInt16); + static const std::unordered_map kOffsetMap = { + {dorado::SampleType::RNA002, 4000}, {dorado::SampleType::RNA004, 1000}}; + static const std::unordered_map kAdapterCutoff = { + {dorado::SampleType::RNA002, 550}, {dorado::SampleType::RNA004, 700}}; + + const int kWindowSize = 250; + const int kStride = 50; + const int16_t kMedianDiff = 125; + + const int kOffset = kOffsetMap.at(model_type); + const int16_t kMinMedianForRNASignal = kAdapterCutoff.at(model_type); + + int signal_len = read.read_common.get_raw_data_samples(); + const int16_t* signal = static_cast(read.read_common.raw_data.data_ptr()); + + // Check the median value change over 5 windows. + std::array medians = {0, 0, 0, 0, 0}; + int median_pos = 0; + int break_point = 0; + const int signal_start = kOffsetMap.at(model_type); + const int signal_end = static_cast(3 * signal_len / 4); + for (int i = signal_start; i < signal_end; i += kStride) { + auto slice = torch::from_blob(const_cast(&signal[i]), + {static_cast(std::min(kWindowSize, signal_len - i))}, + torch::TensorOptions().dtype(torch::kInt16)); + int16_t median = slice.median().item(); + medians[median_pos++ % medians.size()] = median; + auto minmax = std::minmax_element(medians.begin(), medians.end()); + int16_t min_median = *minmax.first; + int16_t max_median = *minmax.second; + if ((median_pos > medians.size()) && (max_median > kMinMedianForRNASignal) && + (max_median - min_median > kMedianDiff)) { + break_point = i; + break; + } + } + + return break_point; +} + void ScalerNode::worker_thread() { torch::InferenceMode inference_mode_guard; @@ -50,6 +100,15 @@ void ScalerNode::worker_thread() { auto read = std::get(std::move(message)); + bool is_rna = (m_model_type == SampleType::RNA002 || m_model_type == SampleType::RNA004); + // Trim adapter for RNA first before scaling. + int trim_start = 0; + if (is_rna) { + trim_start = determine_rna_adapter_pos(*read, m_model_type); + read->read_common.raw_data = + read->read_common.raw_data.index({Slice(trim_start, torch::indexing::None)}); + } + assert(read->read_common.raw_data.dtype() == torch::kInt16); const auto [shift, scale] = m_scaling_params.quantile_scaling ? normalisation(read->read_common.raw_data) @@ -68,20 +127,20 @@ void ScalerNode::worker_thread() { read->read_common.shift = read->scaling * (shift + read->offset); // Don't perform DNA trimming on RNA since it looks too different and we lose useful signal. - int trim_start = 0; - if (!m_is_rna) { + if (!is_rna) { // 8000 value may be changed in future. Currently this is found to work well. int max_samples = std::min(8000, static_cast(read->read_common.get_raw_data_samples() / 2)); trim_start = utils::trim( read->read_common.raw_data.index({Slice(torch::indexing::None, max_samples)})); + read->read_common.raw_data = + read->read_common.raw_data.index({Slice(trim_start, torch::indexing::None)}); } - read->read_common.raw_data = - read->read_common.raw_data.index({Slice(trim_start, torch::indexing::None)}); read->read_common.num_trimmed_samples = trim_start; - spdlog::debug("{} {} {} {}", read->read_common.read_id, shift, scale, trim_start); + spdlog::trace("ScalerNode: {} shift: {} scale: {} trim: {}", read->read_common.read_id, + shift, scale, trim_start); // Pass the read to the next node send_message_to_sink(std::move(read)); @@ -89,13 +148,13 @@ void ScalerNode::worker_thread() { } ScalerNode::ScalerNode(const SignalNormalisationParams& config, - bool is_rna, + SampleType model_type, int num_worker_threads, size_t max_reads) : MessageSink(max_reads), m_scaling_params(config), m_num_worker_threads(num_worker_threads), - m_is_rna(is_rna) { + m_model_type(model_type) { start_threads(); } diff --git a/dorado/read_pipeline/ScalerNode.h b/dorado/read_pipeline/ScalerNode.h index fda9654c..743faae4 100644 --- a/dorado/read_pipeline/ScalerNode.h +++ b/dorado/read_pipeline/ScalerNode.h @@ -16,7 +16,7 @@ namespace dorado { class ScalerNode : public MessageSink { public: ScalerNode(const SignalNormalisationParams& config, - bool is_rna, + SampleType model_type, int num_worker_threads = 5, size_t max_reads = 1000); ~ScalerNode() { terminate_impl(); } @@ -33,7 +33,7 @@ class ScalerNode : public MessageSink { std::atomic m_num_worker_threads; SignalNormalisationParams m_scaling_params; - const bool m_is_rna; + const SampleType m_model_type; std::pair med_mad(const torch::Tensor& x); std::pair normalisation(const torch::Tensor& x); diff --git a/dorado/utils/log_utils.cpp b/dorado/utils/log_utils.cpp index 88cbf3fc..ba22abeb 100644 --- a/dorado/utils/log_utils.cpp +++ b/dorado/utils/log_utils.cpp @@ -73,9 +73,13 @@ void InitLogging() { } } -void SetDebugLogging() { +void SetVerboseLogging(VerboseLogLevel level) { if (is_safe_to_log()) { - spdlog::set_level(spdlog::level::debug); + if (level >= VerboseLogLevel::TRACE) { + spdlog::set_level(spdlog::level::trace); + } else if (level <= VerboseLogLevel::DEBUG) { + spdlog::set_level(spdlog::level::debug); + } } } diff --git a/dorado/utils/log_utils.h b/dorado/utils/log_utils.h index d01b3aae..910fb273 100644 --- a/dorado/utils/log_utils.h +++ b/dorado/utils/log_utils.h @@ -5,6 +5,11 @@ namespace dorado::utils { // Initialises the default logger to point to stderr. void InitLogging(); -void SetDebugLogging(); +enum class VerboseLogLevel { + DEBUG = 1, + TRACE = 2, +}; + +void SetVerboseLogging(VerboseLogLevel level); } // namespace dorado::utils diff --git a/tests/NodeSmokeTest.cpp b/tests/NodeSmokeTest.cpp index 9bb453dc..479955cf 100644 --- a/tests/NodeSmokeTest.cpp +++ b/tests/NodeSmokeTest.cpp @@ -3,6 +3,7 @@ #include "decode/CPUDecoder.h" #include "models/models.h" #include "nn/CRFModel.h" +#include "nn/CRFModelConfig.h" #include "nn/ModBaseModel.h" #include "nn/ModBaseRunner.h" #include "nn/ModelRunner.h" @@ -163,9 +164,10 @@ TempDir download_model(const std::string& model) { DEFINE_TEST(NodeSmokeTestRead, "ScalerNode") { auto pipeline_restart = GENERATE(false, true); - auto is_rna = GENERATE(true, false); + auto model_type = GENERATE(dorado::SampleType::DNA, dorado::SampleType::RNA002, + dorado::SampleType::RNA004); CAPTURE(pipeline_restart); - CAPTURE(is_rna); + CAPTURE(model_type); set_pipeline_restart(pipeline_restart); @@ -179,7 +181,7 @@ DEFINE_TEST(NodeSmokeTestRead, "ScalerNode") { config.quantile_b = 0.9; config.shift_multiplier = 0.51; config.scale_multiplier = 0.53; - run_smoke_test(config, is_rna, 2); + run_smoke_test(config, model_type, 2); } DEFINE_TEST(NodeSmokeTestRead, "BasecallerNode") { @@ -370,8 +372,7 @@ TEST_CASE("BarcodeClassifierNode: test simple pipeline with fastq and sam files" DEFINE_TEST(NodeSmokeTestRead, "PolyACalculator") { auto pipeline_restart = GENERATE(false, true); - auto is_rna = GENERATE(dorado::PolyACalculator::ModelType::DNA, - dorado::PolyACalculator::ModelType::RNA004); + auto is_rna = GENERATE(false, true); CAPTURE(pipeline_restart); CAPTURE(is_rna); diff --git a/tests/PolyACalculatorTest.cpp b/tests/PolyACalculatorTest.cpp index 218f55e3..f8c9ac7a 100644 --- a/tests/PolyACalculatorTest.cpp +++ b/tests/PolyACalculatorTest.cpp @@ -5,6 +5,7 @@ #include "utils/sequence_utils.h" #include +#include #include #include @@ -21,20 +22,18 @@ using namespace dorado; struct TestCase { int estimated_bases = 0; std::string test_dir; - PolyACalculator::ModelType model_type; + bool is_rna; }; TEST_CASE("PolyACalculator: Test polyT tail estimation", TEST_GROUP) { - auto [gt, data, model_type] = - GENERATE(TestCase{92, "poly_a/r9_rev_cdna", PolyACalculator::ModelType::DNA}, - TestCase{31, "poly_a/r10_fwd_cdna", PolyACalculator::ModelType::DNA}, - TestCase{29, "poly_a/rna002", PolyACalculator::ModelType::RNA002}, - TestCase{64, "poly_a/rna004", PolyACalculator::ModelType::RNA004}); + auto [gt, data, is_rna] = GENERATE( + TestCase{92, "poly_a/r9_rev_cdna", false}, TestCase{31, "poly_a/r10_fwd_cdna", false}, + TestCase{28, "poly_a/rna002", true}, TestCase{67, "poly_a/rna004", true}); dorado::PipelineDescriptor pipeline_desc; std::vector messages; auto sink = pipeline_desc.add_node({}, 100, messages); - auto estimator = pipeline_desc.add_node({sink}, 2, model_type); + auto estimator = pipeline_desc.add_node({sink}, 2, is_rna); auto pipeline = dorado::Pipeline::create(std::move(pipeline_desc)); diff --git a/tests/data/poly_a/rna002/moves.bin b/tests/data/poly_a/rna002/moves.bin index 22493b27..6f741d44 100644 Binary files a/tests/data/poly_a/rna002/moves.bin and b/tests/data/poly_a/rna002/moves.bin differ diff --git a/tests/data/poly_a/rna002/seq.txt b/tests/data/poly_a/rna002/seq.txt index bc7cc35d..8031d19e 100644 --- a/tests/data/poly_a/rna002/seq.txt +++ b/tests/data/poly_a/rna002/seq.txt @@ -1 +1 @@ -GGTATCCATGGTTACGACCTGATTTCGAAAAACTGGTAGCCAGCTATCAGGCCGGAAGAGGTCACCATGCGCTACTCAACTTCAGGCGTTACCGGGCATGGGCGATGATGCTTTAATCTACGCCCTGAGCCGTTATTTACTCTGCCAACAACCGCAGGGCCTTTCAAAGTTGCGGTCACTGTCGTGGATGTCAGTTGATGCAGGCTGGCACGCATCCCGATTACTACACCGGCCCGAAAAGGAAAAAATACGCTGGGCGTTGATGCGGTACGTGAGGTCACCGAAAAAGCTGAATGAGCACGCACGCTTAGGTGGTGCGAAAGTCGTTTGGGTAACCGATGCTGCCTTACTATCGACGCCGCGGCTAACGCATTGCTGAAAACGCTCGAAGAGCCAGCAGAAACTGGTTTTCCTGGCTACCCGCGAGCCTGAACGTTTACTGGCAACATTACGTAGTCGTTGTCGGTTACCTTTTGCGCCGCCGGAACAGTACGCCGTGACCTGGCTTTCACGCGAAGTGACAATGTCACAGGATGCATTACTTCTGCCGCATTGCGCTTAAGCGCCGGTACGTCTGGCGCGGCACTGGCGTTGTTTCAGGGAGATAACTGGCAGGCTCGTGAAACATTTGTGTCAGGCGTTGGCTAGCGTGCCATCGGGCGACTGGTATTCGCTGCTGTGGCCCTTAATCATGAACAAGCTCCGGCGCGTTTACACTGGCTGGCAACGTTGCTGATGGATGCGCTAAAACGCCATTATGGTGCTGCGCAGGTGACCAATGTTGATGTGCCGGGCCTGTACGTCGAACTGGCAAACCATCTTTCTCCCTCGCGCCTGCAGGCTATACTGGGGGATGTTTTGCCACATTCGTGAACAGTTAATGTCTGTTACAGGCATCAACCGCGAGCTTTTCATCACCGATCTTTTGCGTTGAGCATTACCTGCATGGGCGTTGTGCTACCAGCTTCCATCTTGGCTGCAAAAAAAAAAATTCCTCCTCCTCTACTCCTATCATCCATCATCATCCCATCATCCATCATCCTCTTATTCC \ No newline at end of file +GGTATCCATGGTTACGACCTGATTTCGAAAAACTGGTAGCCAGCTATCAGGCCGGAAGAGGTCACCATGCGCTACTCAACTTCAGGCGTTACCGGGCATGGGCGATGATGCTTTAATCTACGCCCTGAGCCGTTATTTACTCTGCCAACAACCGCAGGGCCTTTCAAAGTTGCGGTCACTGTCGTGGATGTCAGTTGATGCAGGCTGGCACGCATCCCGATTACTACACCCTGACCCGAAAAGGAAAAAATACGCTGGGCGTTGATGCGGTACGTGAGGTCACCGAAAAAGCTGAATGAGCACGCACGCTTAGGTGGTGCGAAAGTCGTTTGGGTAACCGATGCTGCCTTACTATCGACGCCGCGGCTAACGCATTGCTGAAAACGCTCGAAGAGCCAGCAGAAACTGGTTTTTCCTGGCTACCCGCGAGCCTGAACGTTTACTGGCAACATTACGTAGTCGTTGTCGGTTACCTTCTGCGCCGCCGCCGGAACAGTACGCCGTGACCTGGCTTTCACGCGAAGTGACAATGTCACAGGATGCATTACTTCTGCCGCATTGCGCTTAAGCGCCGGTTTTTTGGCGCGGCACTGGCGTTGTTTCAGGGAGATAACTGGCAGGCTCGTGAAACATTGTGCGGCAGGCGTTGGCTAGCGTGCCATCGGGCGACTGGTATTCGCCGTTAGCGGCCCTTAATCATGAACAAGCTCCGGCGCGTTTACACTGGCTGGCAACGTTGCTGATGGATGCGCTAAAACGCCATTATGGTGCTGCGCAGGGGTGACCAATGTTGATGTGCCGGGCCTGTACGTCGAACTGGCAAACCATCTTTCTCCCTCGCGCCTGCAGGCTATACTGGGGGATGTTTTGCCACATTCGTGAACAGTTAATGTCTGTTACAGGCATCAACCGCGAGCTTCATCACCGATTTTGCGTTGAGCATTACCTGCATGGGCGTTGTGCTACCGGTTCCTCATCTTTAAGCTGCAAAAAAAAAATTC \ No newline at end of file diff --git a/tests/data/poly_a/rna002/signal.tensor b/tests/data/poly_a/rna002/signal.tensor index 9616f0cb..038c2a64 100644 Binary files a/tests/data/poly_a/rna002/signal.tensor and b/tests/data/poly_a/rna002/signal.tensor differ diff --git a/tests/data/poly_a/rna004/moves.bin b/tests/data/poly_a/rna004/moves.bin index 4c5bcbf6..42466a66 100644 Binary files a/tests/data/poly_a/rna004/moves.bin and b/tests/data/poly_a/rna004/moves.bin differ diff --git a/tests/data/poly_a/rna004/seq.txt b/tests/data/poly_a/rna004/seq.txt index de95ddc4..fe28fdc2 100644 --- a/tests/data/poly_a/rna004/seq.txt +++ b/tests/data/poly_a/rna004/seq.txt @@ -1 +1 @@ -AAATGTCTTTGCTTACATCATCACGTTTATCGGCGCTTTAGCGCCATCCAGGCTGCCTGGGCGGTTGATTATCCGCTCAACCGGAAGCCGACTGGTTGGGCAAAATCAAACGTATACGGTGCAAGAAGAGGATAAAAACAAGGCATTGCCCGACGTTTTGATACTGCGGCAATGTTGATCTGAAGCCAATAACACTATCGCCCCGGTGCCAAAACCTGGTACGACGATAACTATTCCTTCACAACTGTATTACCTGATGCACCGCGTACAGGGGATTATCGTTAACCTTGCAGAGCTGCGCCTTTATTATTATCCGCCGGGAGAAAATATTGTGCAGGTCTATCCAATAGGTATTGATTGCAGGGGCTGGAAACGCCGGTGATGGAAACGCGTGTTGGGCAGAAAATCCCTAACCCAACCTGGACGCCTACGGCAGGCATTCGTCAGCGCTGGAGCGTGGCATTAAATTACCGCCAGTCGTTCCTGCCGGACCAAATAACCCGCTAGGACGTTACGCACTGCGCCTCGCGCATGGTAATGGCGAAATACCTTCATCATGGTATCCAGTGCGCCGGACAGCGTCGGTTTGCGCGTCAGTTCAGGGTGTATTCGCATGAATGTCCGGATATTAAAGCCTTGTTCTCCAGCGTGCGGACGGGAACGCCGGTGAAAGTGATCAACGAACCGGTGAAATATTCCGTGAGCCTGCTCAGGATGCGTTATGTTGAAGTACATCGACACTATCGGCAGAAGAACAGCAGAACGTTCAGACAATGCCACACACTGCCAGCAGGCTTTACGCAATTTAAAGACAATAAGGCTGTAGATCAGAAGTTAGTCGATAAAGCGTTGTATCGTCGGGCAGGGTATCCGGTTTCGGTGAGCAGTGGAGCAACTCCCGCAGCCAGCTAATGCGCCTTCAGTAGAGTCAGCGCAGAATGGTGAACCAGAGCAAGGAATATGTTACGCGTGACGCAGTAGGCTGCAAAAAAAAACGATCCCACCCCCATCTACCCACCAATTTATTCCATATCAACCTGCCCCACATCCCTAACTCC \ No newline at end of file +AAATGTCTTTGCTTACATCATCACGTTTATCGGCGCTTTTAGCGCCATCCAGGCTGCCTGGGCGGTTGATTATCCGCTCAACCGGAAGCCGACTGGTTGGGCAAAATCAAACGTATACGGTGCAAGAAGGGGATAAAAACAAGGCATTGCCCGACGTTTTGATACTGCGGCAATGTTGATCTGAAGCCAATAACACTATCGCCCCGGTGCCAAAACCTGGTACGACGATAACTATTCCTTCCAACTGTATTACCTGATGCACCGCGTACAGGGGATTATCGTTAACCTTGCAGAGCTGCGCCTTTATTATTATCCGCCGGGAGAAAATATTGTGCAGGTCTATCCAATAGGTATTGATTGCAGGGGCTGGAAACGCCGGTGATGGAAACGCGTGTTGGGCAGAAAATCCCTAACCCAACCTGGACGCCTACGGCAGGCATTCGTCAGCGCTGGAGCGTGGCATTAAATTACCGCCAGTCGTTCCTGCCGGACCAAATAACCCGCTAGGACGTTACGCACTGCGCCTCGCGCATGGTAATGGCGAAATACCTTCATCATGGTATCCAGTGCGCCGGACAGCGTCGGTTTGCGCGTCAGTTCAGGGTGTATTCGCATGAATGTCCGGATATTAAAGCCTTGTTCTCCAGCGTGCGGACGGGAACGCCGGTGAAAGTGATCAACGAACCGGTGAAATATTCCGCGGAGCCTAACGGGATGCGTTATGTTGAAGTACATCGACCACTATCGGCAGAAGAACAGCAGAACGTTCAGACAATGCCACACACTGCCAGCAGGCTTTACGCAATTTAAAGACAATAAGGCTGTAGATCAGAAGTTAGTCGATAAAGCGTTGTATCGTCGGGCAGGGTATCCGGTTTCGGTGAGCAGTGGAGCAACTCCCGCAGCCAGCAATGCGCCTTCAGTAGAGTCAGCGCAGAATGGTGAACCAGAGCAAGGAATATGTTACGCGTGACGCAGTAGGCTGCAAAAAAAAAAAAA \ No newline at end of file diff --git a/tests/data/poly_a/rna004/signal.tensor b/tests/data/poly_a/rna004/signal.tensor index 370d2533..a93d417c 100644 Binary files a/tests/data/poly_a/rna004/signal.tensor and b/tests/data/poly_a/rna004/signal.tensor differ