diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index 934c537f..de4c60f8 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -123,9 +123,16 @@ void setup(std::vector args, num_devices, !remora_runners.empty() ? num_remora_threads : 0, enable_aligner, !barcode_kits.empty()); + std::unique_ptr sample_sheet; + BarcodingInfo::FilterSet allowed_barcodes; + if (!barcode_sample_sheet.empty()) { + sample_sheet = std::make_unique(barcode_sample_sheet, false); + allowed_barcodes = sample_sheet->get_barcode_values(); + } + SamHdrPtr hdr(sam_hdr_init()); cli::add_pg_hdr(hdr.get(), args); - utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits); + utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits, sample_sheet.get()); PipelineDescriptor pipeline_desc; auto hts_writer = pipeline_desc.add_node( @@ -139,18 +146,16 @@ void setup(std::vector args, } current_sink_node = pipeline_desc.add_node( {current_sink_node}, emit_moves, thread_allocations.read_converter_threads, - methylation_threshold_pct); + methylation_threshold_pct, std::move(sample_sheet), 1000); if (estimate_poly_a) { current_sink_node = pipeline_desc.add_node( {current_sink_node}, std::thread::hardware_concurrency(), PolyACalculator::get_model_type(model_name)); } if (!barcode_kits.empty()) { - utils::SampleSheet sample_sheet(barcode_sample_sheet); - BarcodingInfo::FilterSet allowed_barcodes = sample_sheet.get_barcode_values(); current_sink_node = pipeline_desc.add_node( {current_sink_node}, thread_allocations.barcoder_threads, barcode_kits, - barcode_both_ends, barcode_no_trim, allowed_barcodes); + barcode_both_ends, barcode_no_trim, std::move(allowed_barcodes)); } current_sink_node = pipeline_desc.add_node( {current_sink_node}, min_qscore, default_parameters.min_sequence_length, diff --git a/dorado/cli/demux.cpp b/dorado/cli/demux.cpp index 68a6c997..2fd2bd64 100644 --- a/dorado/cli/demux.cpp +++ b/dorado/cli/demux.cpp @@ -131,23 +131,30 @@ int demuxer(int argc, char* argv[]) { } HtsReader reader(reads[0], read_list); - auto header = sam_hdr_dup(reader.header); - add_pg_hdr(header); + auto header = SamHdrPtr(sam_hdr_dup(reader.header)); + add_pg_hdr(header.get()); + + auto barcode_sample_sheet = parser.get("--sample-sheet"); + std::unique_ptr sample_sheet; + BarcodingInfo::FilterSet allowed_barcodes; + if (!barcode_sample_sheet.empty()) { + sample_sheet = std::make_unique(barcode_sample_sheet, true); + allowed_barcodes = sample_sheet->get_barcode_values(); + } PipelineDescriptor pipeline_desc; auto demux_writer = pipeline_desc.add_node( - {}, output_dir, demux_writer_threads, 0, parser.get("--emit-fastq")); + {}, output_dir, demux_writer_threads, 0, parser.get("--emit-fastq"), + std::move(sample_sheet)); if (parser.is_used("--kit-name")) { std::vector kit_names; if (auto names = parser.present>("--kit-name")) { kit_names = std::move(*names); } - utils::SampleSheet sample_sheet(parser.get("--sample-sheet")); - BarcodingInfo::FilterSet allowed_barcodes = sample_sheet.get_barcode_values(); auto demux = pipeline_desc.add_node( {demux_writer}, demux_threads, kit_names, parser.get("--barcode-both-ends"), - parser.get("--no-trim"), allowed_barcodes); + parser.get("--no-trim"), std::move(allowed_barcodes)); } // Create the Pipeline from our description. @@ -162,7 +169,7 @@ int demuxer(int argc, char* argv[]) { // rather than the pipeline framework. auto& demux_writer_ref = dynamic_cast(pipeline->get_node_ref(demux_writer)); - demux_writer_ref.set_header(header); + demux_writer_ref.set_header(header.get()); // Set up stats counting std::vector stats_callables; diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index 48f65423..2d376748 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -12,6 +12,7 @@ #include "read_pipeline/ProgressTracker.h" #include "read_pipeline/ReadFilterNode.h" #include "read_pipeline/ReadToBamTypeNode.h" +#include "utils/SampleSheet.h" #include "utils/bam_utils.h" #include "utils/basecaller_utils.h" #include "utils/duplex_utils.h" @@ -193,8 +194,8 @@ int duplex(int argc, char* argv[]) { pipeline_desc.add_node_sink(aligner, hts_writer); converted_reads_sink = aligner; } - auto read_converter = - pipeline_desc.add_node({converted_reads_sink}, emit_moves, 2); + auto read_converter = pipeline_desc.add_node( + {converted_reads_sink}, emit_moves, 2, 0, nullptr, 1000); auto duplex_read_tagger = pipeline_desc.add_node({read_converter}); // The minimum sequence length is set to 5 to avoid issues with duplex node printing very short sequences for mismatched pairs. std::unordered_set read_ids_to_filter; @@ -287,7 +288,7 @@ int duplex(int argc, char* argv[]) { read_groups.merge( DataLoader::load_read_groups(reads, duplex_rg_name, recursive_file_loading)); std::vector barcode_kits; - utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits); + utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits, nullptr); int batch_size(parser.visible.get("-b")); int chunk_size(parser.visible.get("-c")); diff --git a/dorado/data_loader/DataLoader.cpp b/dorado/data_loader/DataLoader.cpp index e335a214..2c90d79c 100644 --- a/dorado/data_loader/DataLoader.cpp +++ b/dorado/data_loader/DataLoader.cpp @@ -89,7 +89,16 @@ void string_reader(HighFive::Attribute& attribute, std::string& target_str) { if (eol_pos < target_str.size()) { target_str.resize(eol_pos); } -}; +} + +std::string get_string_attribute(const HighFive::Group& group, const std::string& attr_name) { + std::string attribute_string; + if (group.hasAttribute(attr_name)) { + HighFive::Attribute attribute = group.getAttribute(attr_name); + string_reader(attribute, attribute_string); + } + return attribute_string; +} } // namespace namespace dorado { @@ -156,6 +165,8 @@ SimplexReadPtr process_pod5_read( new_read->start_sample = read_data.start_sample; new_read->end_sample = read_data.start_sample + read_data.num_samples; new_read->read_common.flowcell_id = run_info_data->flow_cell_id; + new_read->read_common.position_id = run_info_data->sequencer_position; + new_read->read_common.experiment_id = run_info_data->experiment_name; new_read->read_common.is_duplex = false; // Determine the time sorted predecessor of the read @@ -466,6 +477,8 @@ std::unordered_map DataLoader::load_read_groups( std::string device_id = run_info_data->system_name; std::string run_id = run_info_data->acquisition_id; std::string sample_id = run_info_data->sample_id; + std::string position_id = run_info_data->sequencer_position; + std::string experiment_id = run_info_data->experiment_name; if (pod5_free_run_info(run_info_data) != POD5_OK) { spdlog::error("Failed to free run info"); @@ -478,7 +491,9 @@ std::unordered_map DataLoader::load_read_groups( flowcell_id, device_id, utils::get_string_timestamp_from_unix_time(exp_start_time_ms), - sample_id}; + sample_id, + position_id, + experiment_id}; } if (pod5_close_and_free_reader(file) != POD5_OK) { spdlog::error("Failed to close and free POD5 reader"); @@ -799,9 +814,11 @@ void DataLoader::load_fast5_reads_from_file(const std::string& path) { std::string fast5_filename = std::filesystem::path(path).filename().string(); HighFive::Group tracking_id_group = read.getGroup("tracking_id"); - HighFive::Attribute exp_start_time_attr = tracking_id_group.getAttribute("exp_start_time"); - std::string exp_start_time; - string_reader(exp_start_time_attr, exp_start_time); + std::string exp_start_time = get_string_attribute(tracking_id_group, "exp_start_time"); + std::string flow_cell_id = get_string_attribute(tracking_id_group, "flow_cell_id"); + std::string device_id = get_string_attribute(tracking_id_group, "device_id"); + std::string group_protocol_id = + get_string_attribute(tracking_id_group, "group_protocol_id"); auto start_time_str = utils::adjust_time(exp_start_time, static_cast(start_time / sampling_rate)); @@ -820,6 +837,9 @@ void DataLoader::load_fast5_reads_from_file(const std::string& path) { new_read->read_common.attributes.channel_number = channel_number; new_read->read_common.attributes.start_time = start_time_str; new_read->read_common.attributes.fast5_filename = fast5_filename; + new_read->read_common.flowcell_id = flow_cell_id; + new_read->read_common.position_id = device_id; + new_read->read_common.experiment_id = group_protocol_id; new_read->read_common.is_duplex = false; if (!m_allowed_read_ids || (m_allowed_read_ids->find(new_read->read_common.read_id) != diff --git a/dorado/demux/BarcodeClassifier.cpp b/dorado/demux/BarcodeClassifier.cpp index ba432639..b211ec09 100644 --- a/dorado/demux/BarcodeClassifier.cpp +++ b/dorado/demux/BarcodeClassifier.cpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include #include @@ -386,7 +386,6 @@ std::vector BarcodeClassifier::calculate_adapter_score_double_ends if (!barcode_is_permitted(allowed_barcodes, adapter_name)) { continue; } - spdlog::debug("Checking barcode {}", adapter_name); auto top_mask_score = extract_mask_score(adapter, top_mask, mask_config, "top window"); diff --git a/dorado/demux/BarcodeClassifier.h b/dorado/demux/BarcodeClassifier.h index 99dad4f1..9df9cf81 100644 --- a/dorado/demux/BarcodeClassifier.h +++ b/dorado/demux/BarcodeClassifier.h @@ -1,5 +1,4 @@ #pragma once -#include "read_pipeline/ReadPipeline.h" #include "utils/stats.h" #include "utils/types.h" diff --git a/dorado/demux/BarcodeClassifierSelector.cpp b/dorado/demux/BarcodeClassifierSelector.cpp index 06405c47..5a189af7 100644 --- a/dorado/demux/BarcodeClassifierSelector.cpp +++ b/dorado/demux/BarcodeClassifierSelector.cpp @@ -2,6 +2,8 @@ #include "BarcodeClassifier.h" +#include + namespace dorado::demux { std::shared_ptr BarcodeClassifierSelector::get_barcoder( diff --git a/dorado/read_pipeline/BarcodeClassifierNode.cpp b/dorado/read_pipeline/BarcodeClassifierNode.cpp index 3d4383b7..cf3eb9fb 100644 --- a/dorado/read_pipeline/BarcodeClassifierNode.cpp +++ b/dorado/read_pipeline/BarcodeClassifierNode.cpp @@ -1,6 +1,7 @@ #include "BarcodeClassifierNode.h" #include "demux/BarcodeClassifier.h" +#include "utils/SampleSheet.h" #include "utils/bam_utils.h" #include "utils/barcode_kits.h" #include "utils/trim.h" @@ -40,11 +41,13 @@ BarcodeClassifierNode::BarcodeClassifierNode(int threads, const std::vector& kit_names, bool barcode_both_ends, bool no_trim, - const BarcodingInfo::FilterSet& allowed_barcodes) + BarcodingInfo::FilterSet allowed_barcodes) : MessageSink(10000), m_threads(threads), - m_default_barcoding_info( - create_barcoding_info(kit_names, barcode_both_ends, !no_trim, allowed_barcodes)) { + m_default_barcoding_info(create_barcoding_info(kit_names, + barcode_both_ends, + !no_trim, + std::move(allowed_barcodes))) { start_threads(); } diff --git a/dorado/read_pipeline/BarcodeClassifierNode.h b/dorado/read_pipeline/BarcodeClassifierNode.h index a91cbdcb..a905e9f6 100644 --- a/dorado/read_pipeline/BarcodeClassifierNode.h +++ b/dorado/read_pipeline/BarcodeClassifierNode.h @@ -5,6 +5,7 @@ #include "utils/types.h" #include +#include #include #include #include @@ -21,7 +22,7 @@ class BarcodeClassifierNode : public MessageSink { const std::vector& kit_name, bool barcode_both_ends, bool no_trim, - const BarcodingInfo::FilterSet& allowed_barcodes); + BarcodingInfo::FilterSet allowed_barcodes); BarcodeClassifierNode(int threads); ~BarcodeClassifierNode(); std::string get_name() const override { return "BarcodeClassifierNode"; } diff --git a/dorado/read_pipeline/BarcodeDemuxerNode.cpp b/dorado/read_pipeline/BarcodeDemuxerNode.cpp index 45dce7c1..0200d79b 100644 --- a/dorado/read_pipeline/BarcodeDemuxerNode.cpp +++ b/dorado/read_pipeline/BarcodeDemuxerNode.cpp @@ -1,6 +1,7 @@ #include "BarcodeDemuxerNode.h" #include "read_pipeline/ReadPipeline.h" +#include "utils/SampleSheet.h" #include #include @@ -14,12 +15,14 @@ namespace dorado { BarcodeDemuxerNode::BarcodeDemuxerNode(const std::string& output_dir, size_t htslib_threads, size_t num_reads, - bool write_fastq) + bool write_fastq, + std::unique_ptr sample_sheet) : MessageSink(10000), m_output_dir(output_dir), m_htslib_threads(htslib_threads), m_num_reads_expected(num_reads), - m_write_fastq(write_fastq) { + m_write_fastq(write_fastq), + m_sample_sheet(std::move(sample_sheet)) { std::filesystem::create_directories(m_output_dir); start_threads(); } @@ -67,6 +70,15 @@ int BarcodeDemuxerNode::write(bam1_t* const record) { if (bam_tag) { bc = std::string(bam_aux2Z(bam_tag)); } + + if (m_sample_sheet) { + // experiment id and position id are not stored in the bam record, so we can't recover them to use here + auto alias = m_sample_sheet->get_alias("", "", "", bc); + if (!alias.empty()) { + bc = alias; + bam_aux_update_str(record, "BC", bc.size() + 1, bc.c_str()); + } + } // Check of existence of file for that barcode. auto res = m_files.find(bc); htsFile* file = nullptr; diff --git a/dorado/read_pipeline/BarcodeDemuxerNode.h b/dorado/read_pipeline/BarcodeDemuxerNode.h index b5c5f4af..17008b84 100644 --- a/dorado/read_pipeline/BarcodeDemuxerNode.h +++ b/dorado/read_pipeline/BarcodeDemuxerNode.h @@ -14,12 +14,16 @@ namespace dorado { +namespace utils { +class SampleSheet; +} class BarcodeDemuxerNode : public MessageSink { public: BarcodeDemuxerNode(const std::string& output_dir, size_t htslib_threads, size_t num_reads, - bool write_fastq); + bool write_fastq, + std::unique_ptr sample_sheet); ~BarcodeDemuxerNode(); std::string get_name() const override { return "BarcodeDemuxerNode"; } stats::NamedStats sample_stats() const override; @@ -42,6 +46,7 @@ class BarcodeDemuxerNode : public MessageSink { int write(bam1_t* record); size_t m_num_reads_expected; bool m_write_fastq{false}; + std::unique_ptr m_sample_sheet; }; } // namespace dorado diff --git a/dorado/read_pipeline/ReadPipeline.h b/dorado/read_pipeline/ReadPipeline.h index 23b53241..2425a2f3 100644 --- a/dorado/read_pipeline/ReadPipeline.h +++ b/dorado/read_pipeline/ReadPipeline.h @@ -41,8 +41,10 @@ class ReadCommon { std::vector moves; // Move table std::vector base_mod_probs; // Modified base probabilities std::string run_id; // Run ID - used in read group - std::string flowcell_id; // Flowcell ID - used in read group - std::string model_name; // Read group + std::string flowcell_id; // Flowcell ID - used in read group and for sample sheet aliasing + std::string position_id; // Position ID - used for sample sheet aliasing + std::string experiment_id; // Experiment ID - used for sample sheet aliasing + std::string model_name; // Read group dorado::details::Attributes attributes; diff --git a/dorado/read_pipeline/ReadToBamTypeNode.cpp b/dorado/read_pipeline/ReadToBamTypeNode.cpp index c3465fb4..6d725edb 100644 --- a/dorado/read_pipeline/ReadToBamTypeNode.cpp +++ b/dorado/read_pipeline/ReadToBamTypeNode.cpp @@ -1,5 +1,7 @@ #include "ReadToBamTypeNode.h" +#include "utils/SampleSheet.h" + #include #include @@ -24,6 +26,17 @@ void ReadToBamType::worker_thread() { if (!read_common_data.is_duplex) { is_duplex_parent = std::get(message)->is_duplex_parent; } + + // alias barcode if present + if (m_sample_sheet && !read_common_data.barcode.empty()) { + auto alias = m_sample_sheet->get_alias( + read_common_data.flowcell_id, read_common_data.position_id, + read_common_data.experiment_id, read_common_data.barcode); + if (!alias.empty()) { + read_common_data.barcode = alias; + } + } + auto alns = read_common_data.extract_sam_lines(m_emit_moves, m_modbase_threshold, is_duplex_parent); for (auto& aln : alns) { @@ -35,15 +48,19 @@ void ReadToBamType::worker_thread() { ReadToBamType::ReadToBamType(bool emit_moves, size_t num_worker_threads, float modbase_threshold_frac, + std::unique_ptr sample_sheet, size_t max_reads) : MessageSink(max_reads), m_num_worker_threads(num_worker_threads), m_emit_moves(emit_moves), m_modbase_threshold( - static_cast(std::min(modbase_threshold_frac * 256.0f, 255.0f))) { + static_cast(std::min(modbase_threshold_frac * 256.0f, 255.0f))), + m_sample_sheet(std::move(sample_sheet)) { start_threads(); } +ReadToBamType::~ReadToBamType() { terminate_impl(); } + void ReadToBamType::start_threads() { for (size_t i = 0; i < m_num_worker_threads; i++) { m_workers.push_back( diff --git a/dorado/read_pipeline/ReadToBamTypeNode.h b/dorado/read_pipeline/ReadToBamTypeNode.h index f3d2be67..d4c3001e 100644 --- a/dorado/read_pipeline/ReadToBamTypeNode.h +++ b/dorado/read_pipeline/ReadToBamTypeNode.h @@ -11,13 +11,18 @@ namespace dorado { +namespace utils { +class SampleSheet; +} + class ReadToBamType : public MessageSink { public: ReadToBamType(bool emit_moves, size_t num_worker_threads, - float modbase_threshold_frac = 0, - size_t max_reads = 1000); - ~ReadToBamType() { terminate_impl(); } + float modbase_threshold_frac, + std::unique_ptr sample_sheet, + size_t max_reads); + ~ReadToBamType(); std::string get_name() const override { return "ReadToBamType"; } void terminate(const FlushOptions& flush_options) override { terminate_impl(); }; void restart() override; @@ -33,6 +38,7 @@ class ReadToBamType : public MessageSink { bool m_emit_moves; uint8_t m_modbase_threshold; + std::unique_ptr m_sample_sheet; }; } // namespace dorado diff --git a/dorado/read_pipeline/StereoDuplexEncoderNode.cpp b/dorado/read_pipeline/StereoDuplexEncoderNode.cpp index 7835be7a..25032efa 100644 --- a/dorado/read_pipeline/StereoDuplexEncoderNode.cpp +++ b/dorado/read_pipeline/StereoDuplexEncoderNode.cpp @@ -270,6 +270,9 @@ DuplexReadPtr StereoDuplexEncoderNode::stereo_encode(const ReadPair& read_pair) read->read_common.raw_data = tmp; // use the encoded signal read->read_common.is_duplex = true; read->read_common.run_id = template_read.read_common.run_id; + read->read_common.flowcell_id = template_read.read_common.flowcell_id; + read->read_common.position_id = template_read.read_common.position_id; + read->read_common.experiment_id = template_read.read_common.experiment_id; edlibFreeAlignResult(result); diff --git a/dorado/read_pipeline/read_utils.cpp b/dorado/read_pipeline/read_utils.cpp index 27918427..37ff8840 100644 --- a/dorado/read_pipeline/read_utils.cpp +++ b/dorado/read_pipeline/read_utils.cpp @@ -21,6 +21,9 @@ SimplexReadPtr shallow_copy_read(const SimplexRead& read) { copy->read_common.qstring = read.read_common.qstring; copy->read_common.moves = read.read_common.moves; copy->read_common.run_id = read.read_common.run_id; + copy->read_common.flowcell_id = read.read_common.flowcell_id; + copy->read_common.position_id = read.read_common.position_id; + copy->read_common.experiment_id = read.read_common.experiment_id; copy->read_common.model_name = read.read_common.model_name; copy->read_common.base_mod_probs = read.read_common.base_mod_probs; diff --git a/dorado/utils/SampleSheet.cpp b/dorado/utils/SampleSheet.cpp index dec6e024..2e5e934c 100644 --- a/dorado/utils/SampleSheet.cpp +++ b/dorado/utils/SampleSheet.cpp @@ -52,6 +52,11 @@ bool is_alias_forbidden(const std::string& input) { return true; } + // Unclassified + if (input == "unclassified") { + return true; + } + return false; } @@ -77,6 +82,8 @@ bool get_line(std::istream& input, namespace dorado::utils { +SampleSheet::SampleSheet() : m_skip_index_matching(false) {} + SampleSheet::SampleSheet(const std::string& filename, bool skip_index_matching) : m_filename(filename), m_skip_index_matching(skip_index_matching) { if (!filename.empty()) { @@ -147,6 +154,17 @@ void SampleSheet::load(std::istream& file_stream, const std::string& filename) { std::string("Unable to infer barcode aliases from sample sheet file: " + filename + " does not contain a unique mapping of barcode ids.")); } + + if (m_type == Type::barcode) { + std::unordered_set barcodes; + // Grab the barcode idx once so that we're not doing it repeatedly + const auto barcode_idx = m_col_indices.at("barcode"); + // Grab the barcodes + for (const auto& row : m_rows) { + barcodes.emplace(row[barcode_idx]); + } + m_allowed_barcodes = std::move(barcodes); + } } // check if we can generate a unique alias without the flowcell/position information @@ -191,9 +209,15 @@ std::string SampleSheet::get_alias(const std::string& flow_cell_id, return ""; } + std::string_view barcode_only(barcode); + if (auto pos = barcode_only.find('_'); pos != std::string::npos) { + // trim off the kit name + barcode_only = barcode_only.substr(pos + 1); + } + for (const auto& row : m_rows) { if (match_index(row, flow_cell_id, position_id, experiment_id) && - get(row, "barcode") == barcode) { + get(row, "barcode") == barcode_only) { return get(row, "alias"); } } @@ -202,27 +226,14 @@ std::string SampleSheet::get_alias(const std::string& flow_cell_id, return ""; } -BarcodingInfo::FilterSet SampleSheet::get_barcode_values() const { - std::unordered_set barcodes; - - switch (m_type) { - case Type::barcode: { - // Grab the barcode idx once so that we're not doing it repeatedly - const auto barcode_idx = m_col_indices.at("barcode"); +BarcodingInfo::FilterSet SampleSheet::get_barcode_values() const { return m_allowed_barcodes; } - // Grab the barcodes - for (const auto& row : m_rows) { - barcodes.emplace(row[barcode_idx]); - } - break; - } - case Type::none: - [[fallthrough]]; - default: - return std::nullopt; +bool SampleSheet::barcode_is_permitted(const std::string& barcode_name) const { + if (!m_allowed_barcodes.has_value()) { + return true; } - return barcodes; + return m_allowed_barcodes->count(barcode_name) != 0; } void SampleSheet::validate_headers(const std::vector& col_names, @@ -304,18 +315,43 @@ void SampleSheet::validate_alias(const Row& row, const std::string& key) const { bool SampleSheet::check_index(const std::string& flow_cell_id, const std::string& position_id) const { - return m_skip_index_matching || ((m_index[FLOW_CELL_ID] == !flow_cell_id.empty()) && - (m_index[POSITION_ID] == !position_id.empty())); + if (m_skip_index_matching) { + return true; + } + + bool ok = m_index.any(); // one of the indicies must be set + if (m_index[FLOW_CELL_ID]) { + // if we're expecting a flow cell id, we must provide one + ok &= !flow_cell_id.empty(); + } + if (m_index[POSITION_ID]) { + // if we're expecting a position id, we must provide one + ok &= !position_id.empty(); + } + return ok; } bool SampleSheet::match_index(const Row& row, const std::string& flow_cell_id, const std::string& position_id, const std::string& experiment_id) const { - return m_skip_index_matching || - ((!m_index[FLOW_CELL_ID] || get(row, "flow_cell_id") == flow_cell_id) && - (!m_index[POSITION_ID] || get(row, "position_id") == position_id) && - (get(row, "experiment_id") == experiment_id)); + if (m_skip_index_matching) { + return true; + } + + if (get(row, "experiment_id") != experiment_id) { + return false; + } + + if (m_index[FLOW_CELL_ID] && (get(row, "flow_cell_id") != flow_cell_id)) { + return false; + } + + if (m_index[POSITION_ID] && (get(row, "position_id") != position_id)) { + return false; + } + + return true; } std::string SampleSheet::get(const Row& row, const std::string& key) const { diff --git a/dorado/utils/SampleSheet.h b/dorado/utils/SampleSheet.h index 2c335e6d..3623a762 100644 --- a/dorado/utils/SampleSheet.h +++ b/dorado/utils/SampleSheet.h @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace dorado::utils { @@ -18,36 +19,37 @@ class SampleSheet { POSITION_ID, }; + SampleSheet(); + // Calls load on the passed filename, if it is not empty. // If skip_index_matching is true the lookup by flowcell/experiment id will be skipped when fetching an alias. // In this case, the constructor will throw if the sample sheet contains entries for more that one flow_cell_id, // position_id or experiment_id, or if any barcode is re-used. - explicit SampleSheet(const std::string & filename = std::string(), - bool skip_index_matching = false); + explicit SampleSheet(const std::string& filename, bool skip_index_matching); // load a sample sheet from a file. Throws a std::runtime_error for any failure condition. - void load(const std::string & filename); + void load(const std::string& filename); // (Testability) load a sample sheet from the given stream. // Throws a std::runtime_error for any failure condition using the filename in the // error message. - void load(std::istream & input_stream, const std::string & filename); + void load(std::istream& input_stream, const std::string& filename); // Return the sample sheet filename. - const std::string & get_filename() const { return m_filename; } + const std::string& get_filename() const { return m_filename; } // Return the sample sheet type based on the column headers it contains. Type get_type() const { return m_type; } - bool contains_column(const std::string & column) const { return m_col_indices.count(column); } + bool contains_column(const std::string& column) const { return m_col_indices.count(column); } // For a given flow_cell_id, position_id, experiment_id and barcode, get the named alias. // Returns an empty string if one does not exist in the loaded sample sheet, or if the sample // sheet is not of type "barcode". - std::string get_alias(const std::string & flow_cell_id, - const std::string & position_id, - const std::string & experiment_id, - const std::string & barcode) const; + std::string get_alias(const std::string& flow_cell_id, + const std::string& position_id, + const std::string& experiment_id, + const std::string& barcode) const; /** * Get all of the barcodes that are present in the sample sheet. @@ -55,6 +57,11 @@ class SampleSheet { */ BarcodingInfo::FilterSet get_barcode_values() const; + /** + * Check whether the a list of allowed barcodes is set and, if so, whether the provided barcode is in it. + */ + bool barcode_is_permitted(const std::string& barcode_name) const; + private: using Row = std::vector; std::string m_filename; @@ -63,16 +70,17 @@ class SampleSheet { std::unordered_map m_col_indices; std::vector m_rows; bool m_skip_index_matching; - - void validate_headers(const std::vector & col_names, const std::string & filename); - bool check_index(const std::string & flow_cell_id, const std::string & position_id) const; - bool match_index(const Row & row, - const std::string & flow_cell_id, - const std::string & position_id, - const std::string & experiment_id) const; - std::string get(const Row & row, const std::string & key) const; - void validate_text(const Row & row, const std::string & key) const; - void validate_alias(const Row & row, const std::string & key) const; + BarcodingInfo::FilterSet m_allowed_barcodes; + + void validate_headers(const std::vector& col_names, const std::string& filename); + bool check_index(const std::string& flow_cell_id, const std::string& position_id) const; + bool match_index(const Row& row, + const std::string& flow_cell_id, + const std::string& position_id, + const std::string& experiment_id) const; + std::string get(const Row& row, const std::string& key) const; + void validate_text(const Row& row, const std::string& key) const; + void validate_alias(const Row& row, const std::string& key) const; bool is_barcode_mapping_unique() const; }; @@ -97,7 +105,7 @@ enum class EolFileFormat { // identifiers decorated with "_eol" suffix to avoid * the NL character, unless opened for binary reading. But for our purposes we are not interested in * the underlying truth of the file format just the style in which it appears in the stream. */ -EolFileFormat get_eol_file_format(std::istream & input); +EolFileFormat get_eol_file_format(std::istream& input); } // namespace details diff --git a/dorado/utils/bam_utils.cpp b/dorado/utils/bam_utils.cpp index 83151adb..b4bd3c3d 100644 --- a/dorado/utils/bam_utils.cpp +++ b/dorado/utils/bam_utils.cpp @@ -1,5 +1,6 @@ #include "bam_utils.h" +#include "SampleSheet.h" #include "barcode_kits.h" #include "sequence_utils.h" @@ -17,7 +18,8 @@ namespace dorado::utils { void add_rg_hdr(sam_hdr_t* hdr, const std::unordered_map& read_groups, - const std::vector& barcode_kits) { + const std::vector& barcode_kits, + const utils::SampleSheet* const sample_sheet) { const auto& barcode_kit_infos = barcode_kits::get_kit_infos(); const auto& barcode_sequences = barcode_kits::get_barcodes(); @@ -65,9 +67,24 @@ void add_rg_hdr(sam_hdr_t* hdr, const auto& kit_info = kit_iter->second; for (const auto& barcode_name : kit_info.barcodes) { const auto additional_tags = "\tBC:" + barcode_sequences.at(barcode_name); + const auto normalized_barcode_name = barcode_kits::normalize_barcode_name(barcode_name); for (const auto& read_group : read_groups) { - auto id = read_group.first + '_' + - barcode_kits::generate_standard_barcode_name(kit_name, barcode_name); + std::string alias; + auto id = read_group.first + '_'; + if (sample_sheet) { + if (!sample_sheet->barcode_is_permitted(normalized_barcode_name)) { + continue; + } + + alias = sample_sheet->get_alias( + read_group.second.flowcell_id, read_group.second.position_id, + read_group.second.experiment_id, normalized_barcode_name); + } + if (!alias.empty()) { + id += alias; + } else { + id += barcode_kits::generate_standard_barcode_name(kit_name, barcode_name); + } const std::string read_group_tags = to_string(read_group.second); emit_read_group(read_group_tags, id, additional_tags); } diff --git a/dorado/utils/bam_utils.h b/dorado/utils/bam_utils.h index 1e4c50c9..ac2d7953 100644 --- a/dorado/utils/bam_utils.h +++ b/dorado/utils/bam_utils.h @@ -10,6 +10,8 @@ struct sam_hdr_t; namespace dorado::utils { +class SampleSheet; + using sq_t = std::vector>; struct AlignmentOps { @@ -23,7 +25,8 @@ struct AlignmentOps { void add_rg_hdr(sam_hdr_t* hdr, const std::unordered_map& read_groups, - const std::vector& barcode_kits); + const std::vector& barcode_kits, + const utils::SampleSheet* const sample_sheet); void add_sq_hdr(sam_hdr_t* hdr, const sq_t& seqs); diff --git a/dorado/utils/types.cpp b/dorado/utils/types.cpp index ac98466f..2f3c32cb 100644 --- a/dorado/utils/types.cpp +++ b/dorado/utils/types.cpp @@ -9,12 +9,13 @@ std::shared_ptr create_barcoding_info( const std::vector& kit_names, bool barcode_both_ends, bool trim_barcode, - const BarcodingInfo::FilterSet& allowed_barcodes) { + BarcodingInfo::FilterSet allowed_barcodes) { if (kit_names.empty()) { return {}; } - auto result = BarcodingInfo{kit_names[0], barcode_both_ends, trim_barcode, allowed_barcodes}; + auto result = BarcodingInfo{kit_names[0], barcode_both_ends, trim_barcode, + std::move(allowed_barcodes)}; return std::make_shared(std::move(result)); } diff --git a/dorado/utils/types.h b/dorado/utils/types.h index b67ef3e0..7096e5fa 100644 --- a/dorado/utils/types.h +++ b/dorado/utils/types.h @@ -25,7 +25,7 @@ std::shared_ptr create_barcoding_info( const std::vector &kit_names, bool barcode_both_ends, bool trim_barcode, - const BarcodingInfo::FilterSet &allowed_barcodes); + BarcodingInfo::FilterSet allowed_barcodes); struct ReadGroup { std::string run_id; @@ -34,6 +34,8 @@ struct ReadGroup { std::string device_id; std::string exp_start_time; std::string sample_id; + std::string position_id; + std::string experiment_id; }; struct BamDestructor { diff --git a/tests/BamUtilsTest.cpp b/tests/BamUtilsTest.cpp index 9e41e716..2117ce8f 100644 --- a/tests/BamUtilsTest.cpp +++ b/tests/BamUtilsTest.cpp @@ -64,7 +64,7 @@ TEST_CASE("BamUtilsTest: add_rg_hdr read group headers", TEST_GROUP) { SECTION("No read groups generate no headers") { dorado::SamHdrPtr sam_header(sam_hdr_init()); CHECK(sam_hdr_count_lines(sam_header.get(), "RG") == 0); - dorado::utils::add_rg_hdr(sam_header.get(), {}, {}); + dorado::utils::add_rg_hdr(sam_header.get(), {}, {}, nullptr); CHECK(sam_hdr_count_lines(sam_header.get(), "RG") == 0); } @@ -77,7 +77,7 @@ TEST_CASE("BamUtilsTest: add_rg_hdr read group headers", TEST_GROUP) { SECTION("Read groups") { dorado::SamHdrPtr sam_header(sam_hdr_init()); - dorado::utils::add_rg_hdr(sam_header.get(), read_groups, {}); + dorado::utils::add_rg_hdr(sam_header.get(), read_groups, {}, nullptr); // Check the IDs of the groups are all there. CHECK(sam_hdr_count_lines(sam_header.get(), "RG") == read_groups.size()); @@ -97,7 +97,7 @@ TEST_CASE("BamUtilsTest: add_rg_hdr read group headers", TEST_GROUP) { SECTION("Read groups with barcodes") { dorado::SamHdrPtr sam_header(sam_hdr_init()); - dorado::utils::add_rg_hdr(sam_header.get(), read_groups, barcode_kits); + dorado::utils::add_rg_hdr(sam_header.get(), read_groups, barcode_kits, nullptr); // Check the IDs of the groups are all there. size_t total_barcodes = 0; @@ -130,7 +130,7 @@ TEST_CASE("BamUtilsTest: add_rg_hdr read group headers", TEST_GROUP) { SECTION("Read groups with unknown barcode kit") { dorado::SamHdrPtr sam_header(sam_hdr_init()); - CHECK_THROWS(dorado::utils::add_rg_hdr(sam_header.get(), read_groups, {"blah"})); + CHECK_THROWS(dorado::utils::add_rg_hdr(sam_header.get(), read_groups, {"blah"}, nullptr)); } } diff --git a/tests/BarcodeDemuxerNodeTest.cpp b/tests/BarcodeDemuxerNodeTest.cpp index f5093671..93a20bd3 100644 --- a/tests/BarcodeDemuxerNodeTest.cpp +++ b/tests/BarcodeDemuxerNodeTest.cpp @@ -3,6 +3,7 @@ #include "MessageSinkUtils.h" #include "TestUtils.h" #include "read_pipeline/HtsReader.h" +#include "utils/SampleSheet.h" #include "utils/bam_utils.h" #include "utils/sequence_utils.h" #include "utils/types.h" @@ -46,8 +47,8 @@ TEST_CASE("BarcodeDemuxerNode: check correct output files are created", TEST_GRO // the pipeline object is closed. This needs to be looked at. // TODO: Address open file issue on windows. dorado::PipelineDescriptor pipeline_desc; - auto demuxer = - pipeline_desc.add_node({}, tmp_dir.string(), 8, 10, false); + auto demuxer = pipeline_desc.add_node({}, tmp_dir.string(), 8, 10, + false, nullptr); auto pipeline = dorado::Pipeline::create(std::move(pipeline_desc)); diff --git a/tests/NodeSmokeTest.cpp b/tests/NodeSmokeTest.cpp index 7da59539..9bb453dc 100644 --- a/tests/NodeSmokeTest.cpp +++ b/tests/NodeSmokeTest.cpp @@ -14,6 +14,7 @@ #include "read_pipeline/ReadFilterNode.h" #include "read_pipeline/ReadToBamTypeNode.h" #include "read_pipeline/ScalerNode.h" +#include "utils/SampleSheet.h" #include "utils/parameters.h" #if DORADO_GPU_BUILD @@ -327,8 +328,8 @@ DEFINE_TEST(NodeSmokeTestBam, "ReadToBamType") { set_pipeline_restart(pipeline_restart); - run_smoke_test(emit_moves, 2, - dorado::utils::default_parameters.methylation_threshold); + run_smoke_test( + emit_moves, 2, dorado::utils::default_parameters.methylation_threshold, nullptr, 1000); } DEFINE_TEST(NodeSmokeTestRead, "BarcodeClassifierNode") { diff --git a/tests/SampleSheetTests.cpp b/tests/SampleSheetTests.cpp index 91cff220..959115b7 100644 --- a/tests/SampleSheetTests.cpp +++ b/tests/SampleSheetTests.cpp @@ -29,8 +29,7 @@ TEST_CASE(CUT_TAG " load valid no-barcode sample sheet", CUT_TAG) { // Test that all the alias functions return empty strings std::string alias; - REQUIRE_NOTHROW(alias = sample_sheet.get_alias("FA026858", "pos_id", "sequencing_20200522", - "barcode10")); + REQUIRE_NOTHROW(alias = sample_sheet.get_alias("PAO25751", "pos_id", "", "barcode10")); CHECK(alias == ""); } @@ -42,28 +41,23 @@ TEST_CASE(CUT_TAG " load valid single barcode sample sheet", CUT_TAG) { // Test first entry loads correctly std::string alias; - REQUIRE_NOTHROW( - alias = sample_sheet.get_alias("FA026858", "", "sequencing_20200522", "barcode01")); + REQUIRE_NOTHROW(alias = sample_sheet.get_alias("PAO25751", "", "", "barcode01")); CHECK(alias == "patient_id_5"); // Test last entry loads correctly - REQUIRE_NOTHROW( - alias = sample_sheet.get_alias("FA026858", "", "sequencing_20200522", "barcode08")); + REQUIRE_NOTHROW(alias = sample_sheet.get_alias("PAO25751", "", "", "barcode08")); CHECK(alias == "patient_id_4"); - // TODO: is this what we want? - // Test that asking for position_id when it's not there stops you getting an alias - REQUIRE_NOTHROW(alias = sample_sheet.get_alias("FA026858", "pos_id", "sequencing_20200522", - "barcode01")); - CHECK(alias == ""); + // Test that providing position_id when it's not in the sample sheet doesn't stop you getting an alias + REQUIRE_NOTHROW(alias = sample_sheet.get_alias("PAO25751", "pos_id", "", "barcode01")); + CHECK(alias == "patient_id_5"); // Test that asking for neither position_id or flowcell_id stops you getting an alias - REQUIRE_NOTHROW(alias = sample_sheet.get_alias("", "", "sequencing_20200522", "barcode01")); + REQUIRE_NOTHROW(alias = sample_sheet.get_alias("", "", "", "barcode01")); CHECK(alias == ""); // Test non-existent entry - REQUIRE_NOTHROW( - alias = sample_sheet.get_alias("FA026858", "", "sequencing_20200522", "barcode10")); + REQUIRE_NOTHROW(alias = sample_sheet.get_alias("PAO25751", "", "", "barcode10")); CHECK(alias == ""); } @@ -102,7 +96,8 @@ TEST_CASE(CUT_TAG " load sample sheet cross platform (parameterised)", CUT_TAG) CAPTURE(eol_chars); const std::string HEADER_LINE{"flow_cell_id,kit,sample_id,experiment_id,barcode,alias,type"}; const std::string RECORD_LINE{ - "FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode01,patient_id_5,test_" + "PAO25751,SQK-RBK004,barcoding_run,,barcode01," + "patient_id_5,test_" "sample"}; dorado::utils::SampleSheet sample_sheet{}; std::stringstream input_file{HEADER_LINE + eol_chars + RECORD_LINE + eol_chars}; @@ -112,8 +107,7 @@ TEST_CASE(CUT_TAG " load sample sheet cross platform (parameterised)", CUT_TAG) REQUIRE(sample_sheet.get_type() == dorado::utils::SampleSheet::Type::barcode); std::string alias; - REQUIRE_NOTHROW( - alias = sample_sheet.get_alias("FA026858", "", "sequencing_20200522", "barcode01")); + REQUIRE_NOTHROW(alias = sample_sheet.get_alias("PAO25751", "", "", "barcode01")); CHECK(alias == "patient_id_5"); } diff --git a/tests/data/barcode_demux/sample_sheet.csv b/tests/data/barcode_demux/sample_sheet.csv index d73594cb..5dfb5ae8 100644 --- a/tests/data/barcode_demux/sample_sheet.csv +++ b/tests/data/barcode_demux/sample_sheet.csv @@ -1,2 +1,2 @@ flow_cell_id,kit,sample_id,experiment_id,barcode,alias,type -PAO25751,SQK-RBK114-96,no_sample,not_set,barcode01,patient_id_1,test_sample \ No newline at end of file +PAO25751,SQK-RBK114-96,no_sample,,barcode01,patient_id_1,test_sample \ No newline at end of file diff --git a/tests/data/sample_sheets/single_barcode.csv b/tests/data/sample_sheets/single_barcode.csv index 37496d91..823eb166 100644 --- a/tests/data/sample_sheets/single_barcode.csv +++ b/tests/data/sample_sheets/single_barcode.csv @@ -1,9 +1,9 @@ flow_cell_id,kit,sample_id,experiment_id,barcode,alias,type -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode01,patient_id_5,test_sample -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode02,patient_id_6,test_sample -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode03,patient_id_7,test_sample -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode04,patient_id_8,test_sample -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode05,patient_id_1,test_sample -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode06,patient_id_2,test_sample -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode07,patient_id_3,test_sample -FA026858,SQK-RBK004,barcoding_run,sequencing_20200522,barcode08,patient_id_4,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode01,patient_id_5,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode02,patient_id_6,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode03,patient_id_7,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode04,patient_id_8,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode05,patient_id_1,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode06,patient_id_2,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode07,patient_id_3,test_sample +PAO25751,SQK-RBK114-96,barcoding_run,,barcode08,patient_id_4,test_sample diff --git a/tests/test_simple_basecaller_execution.sh b/tests/test_simple_basecaller_execution.sh index 572dd989..81b3f937 100755 --- a/tests/test_simple_basecaller_execution.sh +++ b/tests/test_simple_basecaller_execution.sh @@ -180,23 +180,28 @@ fi echo dorado basecaller barcoding read groups test_barcoding_read_groups() ( - expected_read_groups_barcode01=$1 - expected_read_groups_barcode04=$2 - expected_read_groups_unclassified=$3 - sample_sheet=$4 + while (( "$#" >= 2 )); do + barcode=$1 + export expected_read_groups_${barcode}=$2 + shift 2 + done + sample_sheet=$1 output_name=read_group_test${sample_sheet:+_sample_sheet} $dorado_bin basecaller -b ${batch} --kit-name SQK-RBK114-96 ${sample_sheet:+--sample-sheet ${sample_sheet}} ${model_5k} $data_dir/barcode_demux/read_group_test > $output_dir/${output_name}.bam samtools quickcheck -u $output_dir/${output_name}.bam split_dir=$output_dir/${output_name} mkdir $split_dir samtools split -u $split_dir/unknown.bam -f "$split_dir/rg_%!.bam" $output_dir/${output_name}.bam - # There should be 4 reads with BC01, 3 with BC04, and 2 unclassified groups. for bam in $split_dir/rg_*.bam; do if [[ $bam =~ "_SQK-RBK114-96_" ]]; then # Arrangement is |_|, so trim the kit from the prefix and the .bam from the suffix. barcode=${bam#*_SQK-RBK114-96_} barcode=${barcode%.bam*} + elif [[ $bam =~ "_${model_name_5k}_" ]]; then + # Arrangement is ||, so trim the model from the prefix and the .bam from the suffix. + barcode=${bam#*_${model_name_5k}_} + barcode=${barcode%.bam*} else barcode="unclassified" fi @@ -217,8 +222,10 @@ test_barcoding_read_groups() ( fi ) -test_barcoding_read_groups 4 3 2 -test_barcoding_read_groups 4 0 5 $data_dir/barcode_demux/sample_sheet.csv +# There should be 4 reads with BC01, 3 with BC04, and 2 unclassified groups. +test_barcoding_read_groups barcode01 4 barcode04 3 unclassified 2 +# There should be 4 reads with BC01 aliased to patient_id_1, and 5 unclassified groups. +test_barcoding_read_groups patient_id_1 4 unclassified 5 $data_dir/barcode_demux/sample_sheet.csv # Test demux only on a pre-classified BAM file $dorado_bin demux --no-classify --output-dir "$output_dir/demux_only_test/" $output_dir/read_group_test.bam