diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index 0b79fddd..3ad80bcf 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -55,7 +55,8 @@ const barcode_kits::KitInfo& get_barcode_kit_info(const std::string& kit_name) { if (!kit_info) { spdlog::error( "{} is not a valid barcode kit name. Please run the help " - "command to find out available barcode kits."); + "command to find out available barcode kits.", + kit_name); std::exit(EXIT_FAILURE); } return *kit_info; @@ -103,7 +104,7 @@ void setup(std::vector args, const std::string& dump_stats_file, const std::string& dump_stats_filter, const std::string& resume_from_file, - const std::vector& barcode_kits, + const std::string& barcode_kit, bool barcode_both_ends, bool barcode_no_trim, bool adapter_no_trim, @@ -161,7 +162,7 @@ void setup(std::vector args, recursive_file_loading); const bool adapter_trimming_enabled = (!adapter_no_trim || !primer_no_trim); - const bool barcode_enabled = !barcode_kits.empty() || custom_kit; + const bool barcode_enabled = !barcode_kit.empty() || custom_kit; const auto thread_allocations = utils::default_thread_allocations( int(num_devices), !remora_runners.empty() ? int(num_remora_threads) : 0, enable_aligner, barcode_enabled, adapter_trimming_enabled); @@ -179,10 +180,12 @@ void setup(std::vector args, auto [kit_name, kit_info] = get_custom_barcode_kit_info(*custom_kit); utils::add_rg_headers_with_barcode_kit(hdr.get(), read_groups, kit_name, kit_info, sample_sheet.get()); - } else { - const auto kit_info = get_barcode_kit_info(barcode_kits[0]); - utils::add_rg_headers_with_barcode_kit(hdr.get(), read_groups, barcode_kits[0], kit_info, + } else if (!barcode_kit.empty()) { + const auto kit_info = get_barcode_kit_info(barcode_kit); + utils::add_rg_headers_with_barcode_kit(hdr.get(), read_groups, barcode_kit, kit_info, sample_sheet.get()); + } else { + utils::add_rg_headers(hdr.get(), read_groups); } utils::HtsFile hts_file("-", output_mode, thread_allocations.writer_threads); @@ -212,8 +215,9 @@ void setup(std::vector args, !primer_no_trim, std::move(custom_primer_file)); } if (barcode_enabled) { + std::vector kit_as_vector{barcode_kit}; current_sink_node = pipeline_desc.add_node( - {current_sink_node}, thread_allocations.barcoder_threads, barcode_kits, + {current_sink_node}, thread_allocations.barcoder_threads, kit_as_vector, barcode_both_ends, barcode_no_trim, std::move(allowed_barcodes), std::move(custom_kit), std::move(custom_barcode_file)); } @@ -434,7 +438,8 @@ int basecaller(int argc, char* argv[]) { parser.visible.add_argument("--kit-name") .help("Enable barcoding with the provided kit name. Choose from: " + - dorado::barcode_kits::barcode_kits_list_str() + "."); + dorado::barcode_kits::barcode_kits_list_str() + ".") + .default_value(std::string{}); parser.visible.add_argument("--barcode-both-ends") .help("Require both ends of a read to be barcoded for a double ended barcode.") .default_value(false) @@ -640,7 +645,6 @@ int basecaller(int argc, char* argv[]) { spdlog::info("> Creating basecall pipeline"); try { - /* clang format off */ setup(args, model_path, data, mods_model_paths, parser.visible.get("-x"), parser.visible.get("--reference"), parser.visible.get("-c"), parser.visible.get("-o"), parser.visible.get("-b"), @@ -654,13 +658,12 @@ int basecaller(int argc, char* argv[]) { parser.hidden.get("--dump_stats_file"), parser.hidden.get("--dump_stats_filter"), parser.visible.get("--resume-from"), - parser.visible.get>("--kit-name"), + parser.visible.get("--kit-name"), parser.visible.get("--barcode-both-ends"), no_trim_barcodes, no_trim_adapters, no_trim_primers, parser.visible.get("--sample-sheet"), std::move(custom_kit), std::move(custom_barcode_seqs), std::move(custom_primer_file), resume_parser, parser.visible.get("--estimate-poly-a"), polya_config.empty() ? nullptr : &polya_config, model_selection); - /* clang format on */ } catch (const std::exception& e) { spdlog::error("{}", e.what()); utils::clean_temporary_models(temp_download_paths); diff --git a/tests/symbol_test.cpp b/tests/symbol_test.cpp index 5c7ae870..620b9cd1 100644 --- a/tests/symbol_test.cpp +++ b/tests/symbol_test.cpp @@ -7,7 +7,6 @@ #include "api/runner_creation.h" #include "basecall/CRFModelConfig.h" #include "basecall/ModelRunner.h" -#include "demux/parse_custom_kit.h" #include "demux/parse_custom_sequences.h" #include "modbase/ModBaseModelConfig.h" #include "modbase/ModBaseRunner.h" @@ -23,6 +22,7 @@ #include "utils/barcode_kits.h" #include "utils/gpu_monitor.h" #include "utils/parameters.h" +#include "utils/parse_custom_kit.h" #include "utils/sequence_utils.h" #include "utils/string_utils.h" #include "utils/time_utils.h"