Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-364-sample-sheet-alias' into 'master'
Browse files Browse the repository at this point in the history
DOR-364: Sample sheet aliasing

Closes DOR-364

See merge request machine-learning/dorado!661
  • Loading branch information
malton-ont committed Oct 26, 2023
2 parents f6bf232 + 4698eeb commit 616b951
Show file tree
Hide file tree
Showing 29 changed files with 285 additions and 130 deletions.
15 changes: 10 additions & 5 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,16 @@ void setup(std::vector<std::string> args,
num_devices, !remora_runners.empty() ? num_remora_threads : 0, enable_aligner,
!barcode_kits.empty());

std::unique_ptr<const utils::SampleSheet> sample_sheet;
BarcodingInfo::FilterSet allowed_barcodes;
if (!barcode_sample_sheet.empty()) {
sample_sheet = std::make_unique<const utils::SampleSheet>(barcode_sample_sheet, false);
allowed_barcodes = sample_sheet->get_barcode_values();
}

SamHdrPtr hdr(sam_hdr_init());
cli::add_pg_hdr(hdr.get(), args);
utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits);
utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits, sample_sheet.get());

PipelineDescriptor pipeline_desc;
auto hts_writer = pipeline_desc.add_node<HtsWriter>(
Expand All @@ -139,18 +146,16 @@ void setup(std::vector<std::string> args,
}
current_sink_node = pipeline_desc.add_node<ReadToBamType>(
{current_sink_node}, emit_moves, thread_allocations.read_converter_threads,
methylation_threshold_pct);
methylation_threshold_pct, std::move(sample_sheet), 1000);
if (estimate_poly_a) {
current_sink_node = pipeline_desc.add_node<PolyACalculator>(
{current_sink_node}, std::thread::hardware_concurrency(),
PolyACalculator::get_model_type(model_name));
}
if (!barcode_kits.empty()) {
utils::SampleSheet sample_sheet(barcode_sample_sheet);
BarcodingInfo::FilterSet allowed_barcodes = sample_sheet.get_barcode_values();
current_sink_node = pipeline_desc.add_node<BarcodeClassifierNode>(
{current_sink_node}, thread_allocations.barcoder_threads, barcode_kits,
barcode_both_ends, barcode_no_trim, allowed_barcodes);
barcode_both_ends, barcode_no_trim, std::move(allowed_barcodes));
}
current_sink_node = pipeline_desc.add_node<ReadFilterNode>(
{current_sink_node}, min_qscore, default_parameters.min_sequence_length,
Expand Down
21 changes: 14 additions & 7 deletions dorado/cli/demux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,23 +131,30 @@ int demuxer(int argc, char* argv[]) {
}

HtsReader reader(reads[0], read_list);
auto header = sam_hdr_dup(reader.header);
add_pg_hdr(header);
auto header = SamHdrPtr(sam_hdr_dup(reader.header));
add_pg_hdr(header.get());

auto barcode_sample_sheet = parser.get<std::string>("--sample-sheet");
std::unique_ptr<const utils::SampleSheet> sample_sheet;
BarcodingInfo::FilterSet allowed_barcodes;
if (!barcode_sample_sheet.empty()) {
sample_sheet = std::make_unique<const utils::SampleSheet>(barcode_sample_sheet, true);
allowed_barcodes = sample_sheet->get_barcode_values();
}

PipelineDescriptor pipeline_desc;
auto demux_writer = pipeline_desc.add_node<BarcodeDemuxerNode>(
{}, output_dir, demux_writer_threads, 0, parser.get<bool>("--emit-fastq"));
{}, output_dir, demux_writer_threads, 0, parser.get<bool>("--emit-fastq"),
std::move(sample_sheet));

if (parser.is_used("--kit-name")) {
std::vector<std::string> kit_names;
if (auto names = parser.present<std::vector<std::string>>("--kit-name")) {
kit_names = std::move(*names);
}
utils::SampleSheet sample_sheet(parser.get<std::string>("--sample-sheet"));
BarcodingInfo::FilterSet allowed_barcodes = sample_sheet.get_barcode_values();
auto demux = pipeline_desc.add_node<BarcodeClassifierNode>(
{demux_writer}, demux_threads, kit_names, parser.get<bool>("--barcode-both-ends"),
parser.get<bool>("--no-trim"), allowed_barcodes);
parser.get<bool>("--no-trim"), std::move(allowed_barcodes));
}

// Create the Pipeline from our description.
Expand All @@ -162,7 +169,7 @@ int demuxer(int argc, char* argv[]) {
// rather than the pipeline framework.
auto& demux_writer_ref =
dynamic_cast<BarcodeDemuxerNode&>(pipeline->get_node_ref(demux_writer));
demux_writer_ref.set_header(header);
demux_writer_ref.set_header(header.get());

// Set up stats counting
std::vector<dorado::stats::StatsCallable> stats_callables;
Expand Down
7 changes: 4 additions & 3 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "read_pipeline/ProgressTracker.h"
#include "read_pipeline/ReadFilterNode.h"
#include "read_pipeline/ReadToBamTypeNode.h"
#include "utils/SampleSheet.h"
#include "utils/bam_utils.h"
#include "utils/basecaller_utils.h"
#include "utils/duplex_utils.h"
Expand Down Expand Up @@ -193,8 +194,8 @@ int duplex(int argc, char* argv[]) {
pipeline_desc.add_node_sink(aligner, hts_writer);
converted_reads_sink = aligner;
}
auto read_converter =
pipeline_desc.add_node<ReadToBamType>({converted_reads_sink}, emit_moves, 2);
auto read_converter = pipeline_desc.add_node<ReadToBamType>(
{converted_reads_sink}, emit_moves, 2, 0, nullptr, 1000);
auto duplex_read_tagger = pipeline_desc.add_node<DuplexReadTaggingNode>({read_converter});
// The minimum sequence length is set to 5 to avoid issues with duplex node printing very short sequences for mismatched pairs.
std::unordered_set<std::string> read_ids_to_filter;
Expand Down Expand Up @@ -287,7 +288,7 @@ int duplex(int argc, char* argv[]) {
read_groups.merge(
DataLoader::load_read_groups(reads, duplex_rg_name, recursive_file_loading));
std::vector<std::string> barcode_kits;
utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits);
utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits, nullptr);

int batch_size(parser.visible.get<int>("-b"));
int chunk_size(parser.visible.get<int>("-c"));
Expand Down
30 changes: 25 additions & 5 deletions dorado/data_loader/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ void string_reader(HighFive::Attribute& attribute, std::string& target_str) {
if (eol_pos < target_str.size()) {
target_str.resize(eol_pos);
}
};
}

std::string get_string_attribute(const HighFive::Group& group, const std::string& attr_name) {
std::string attribute_string;
if (group.hasAttribute(attr_name)) {
HighFive::Attribute attribute = group.getAttribute(attr_name);
string_reader(attribute, attribute_string);
}
return attribute_string;
}
} // namespace

namespace dorado {
Expand Down Expand Up @@ -156,6 +165,8 @@ SimplexReadPtr process_pod5_read(
new_read->start_sample = read_data.start_sample;
new_read->end_sample = read_data.start_sample + read_data.num_samples;
new_read->read_common.flowcell_id = run_info_data->flow_cell_id;
new_read->read_common.position_id = run_info_data->sequencer_position;
new_read->read_common.experiment_id = run_info_data->experiment_name;
new_read->read_common.is_duplex = false;

// Determine the time sorted predecessor of the read
Expand Down Expand Up @@ -466,6 +477,8 @@ std::unordered_map<std::string, ReadGroup> DataLoader::load_read_groups(
std::string device_id = run_info_data->system_name;
std::string run_id = run_info_data->acquisition_id;
std::string sample_id = run_info_data->sample_id;
std::string position_id = run_info_data->sequencer_position;
std::string experiment_id = run_info_data->experiment_name;

if (pod5_free_run_info(run_info_data) != POD5_OK) {
spdlog::error("Failed to free run info");
Expand All @@ -478,7 +491,9 @@ std::unordered_map<std::string, ReadGroup> DataLoader::load_read_groups(
flowcell_id,
device_id,
utils::get_string_timestamp_from_unix_time(exp_start_time_ms),
sample_id};
sample_id,
position_id,
experiment_id};
}
if (pod5_close_and_free_reader(file) != POD5_OK) {
spdlog::error("Failed to close and free POD5 reader");
Expand Down Expand Up @@ -799,9 +814,11 @@ void DataLoader::load_fast5_reads_from_file(const std::string& path) {
std::string fast5_filename = std::filesystem::path(path).filename().string();

HighFive::Group tracking_id_group = read.getGroup("tracking_id");
HighFive::Attribute exp_start_time_attr = tracking_id_group.getAttribute("exp_start_time");
std::string exp_start_time;
string_reader(exp_start_time_attr, exp_start_time);
std::string exp_start_time = get_string_attribute(tracking_id_group, "exp_start_time");
std::string flow_cell_id = get_string_attribute(tracking_id_group, "flow_cell_id");
std::string device_id = get_string_attribute(tracking_id_group, "device_id");
std::string group_protocol_id =
get_string_attribute(tracking_id_group, "group_protocol_id");

auto start_time_str = utils::adjust_time(exp_start_time,
static_cast<uint32_t>(start_time / sampling_rate));
Expand All @@ -820,6 +837,9 @@ void DataLoader::load_fast5_reads_from_file(const std::string& path) {
new_read->read_common.attributes.channel_number = channel_number;
new_read->read_common.attributes.start_time = start_time_str;
new_read->read_common.attributes.fast5_filename = fast5_filename;
new_read->read_common.flowcell_id = flow_cell_id;
new_read->read_common.position_id = device_id;
new_read->read_common.experiment_id = group_protocol_id;
new_read->read_common.is_duplex = false;

if (!m_allowed_read_ids || (m_allowed_read_ids->find(new_read->read_common.read_id) !=
Expand Down
3 changes: 1 addition & 2 deletions dorado/demux/BarcodeClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <spdlog/spdlog.h>

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <tuple>
Expand Down Expand Up @@ -386,7 +386,6 @@ std::vector<ScoreResults> BarcodeClassifier::calculate_adapter_score_double_ends
if (!barcode_is_permitted(allowed_barcodes, adapter_name)) {
continue;
}

spdlog::debug("Checking barcode {}", adapter_name);

auto top_mask_score = extract_mask_score(adapter, top_mask, mask_config, "top window");
Expand Down
1 change: 0 additions & 1 deletion dorado/demux/BarcodeClassifier.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#pragma once
#include "read_pipeline/ReadPipeline.h"
#include "utils/stats.h"
#include "utils/types.h"

Expand Down
2 changes: 2 additions & 0 deletions dorado/demux/BarcodeClassifierSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include "BarcodeClassifier.h"

#include <cassert>

namespace dorado::demux {

std::shared_ptr<const BarcodeClassifier> BarcodeClassifierSelector::get_barcoder(
Expand Down
9 changes: 6 additions & 3 deletions dorado/read_pipeline/BarcodeClassifierNode.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "BarcodeClassifierNode.h"

#include "demux/BarcodeClassifier.h"
#include "utils/SampleSheet.h"
#include "utils/bam_utils.h"
#include "utils/barcode_kits.h"
#include "utils/trim.h"
Expand Down Expand Up @@ -40,11 +41,13 @@ BarcodeClassifierNode::BarcodeClassifierNode(int threads,
const std::vector<std::string>& kit_names,
bool barcode_both_ends,
bool no_trim,
const BarcodingInfo::FilterSet& allowed_barcodes)
BarcodingInfo::FilterSet allowed_barcodes)
: MessageSink(10000),
m_threads(threads),
m_default_barcoding_info(
create_barcoding_info(kit_names, barcode_both_ends, !no_trim, allowed_barcodes)) {
m_default_barcoding_info(create_barcoding_info(kit_names,
barcode_both_ends,
!no_trim,
std::move(allowed_barcodes))) {
start_threads();
}

Expand Down
3 changes: 2 additions & 1 deletion dorado/read_pipeline/BarcodeClassifierNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "utils/types.h"

#include <atomic>
#include <memory>
#include <string>
#include <string_view>
#include <vector>
Expand All @@ -21,7 +22,7 @@ class BarcodeClassifierNode : public MessageSink {
const std::vector<std::string>& kit_name,
bool barcode_both_ends,
bool no_trim,
const BarcodingInfo::FilterSet& allowed_barcodes);
BarcodingInfo::FilterSet allowed_barcodes);
BarcodeClassifierNode(int threads);
~BarcodeClassifierNode();
std::string get_name() const override { return "BarcodeClassifierNode"; }
Expand Down
16 changes: 14 additions & 2 deletions dorado/read_pipeline/BarcodeDemuxerNode.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "BarcodeDemuxerNode.h"

#include "read_pipeline/ReadPipeline.h"
#include "utils/SampleSheet.h"

#include <htslib/bgzf.h>
#include <htslib/sam.h>
Expand All @@ -14,12 +15,14 @@ namespace dorado {
BarcodeDemuxerNode::BarcodeDemuxerNode(const std::string& output_dir,
size_t htslib_threads,
size_t num_reads,
bool write_fastq)
bool write_fastq,
std::unique_ptr<const utils::SampleSheet> sample_sheet)
: MessageSink(10000),
m_output_dir(output_dir),
m_htslib_threads(htslib_threads),
m_num_reads_expected(num_reads),
m_write_fastq(write_fastq) {
m_write_fastq(write_fastq),
m_sample_sheet(std::move(sample_sheet)) {
std::filesystem::create_directories(m_output_dir);
start_threads();
}
Expand Down Expand Up @@ -67,6 +70,15 @@ int BarcodeDemuxerNode::write(bam1_t* const record) {
if (bam_tag) {
bc = std::string(bam_aux2Z(bam_tag));
}

if (m_sample_sheet) {
// experiment id and position id are not stored in the bam record, so we can't recover them to use here
auto alias = m_sample_sheet->get_alias("", "", "", bc);
if (!alias.empty()) {
bc = alias;
bam_aux_update_str(record, "BC", bc.size() + 1, bc.c_str());
}
}
// Check of existence of file for that barcode.
auto res = m_files.find(bc);
htsFile* file = nullptr;
Expand Down
7 changes: 6 additions & 1 deletion dorado/read_pipeline/BarcodeDemuxerNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@

namespace dorado {

namespace utils {
class SampleSheet;
}
class BarcodeDemuxerNode : public MessageSink {
public:
BarcodeDemuxerNode(const std::string& output_dir,
size_t htslib_threads,
size_t num_reads,
bool write_fastq);
bool write_fastq,
std::unique_ptr<const utils::SampleSheet> sample_sheet);
~BarcodeDemuxerNode();
std::string get_name() const override { return "BarcodeDemuxerNode"; }
stats::NamedStats sample_stats() const override;
Expand All @@ -42,6 +46,7 @@ class BarcodeDemuxerNode : public MessageSink {
int write(bam1_t* record);
size_t m_num_reads_expected;
bool m_write_fastq{false};
std::unique_ptr<const utils::SampleSheet> m_sample_sheet;
};

} // namespace dorado
6 changes: 4 additions & 2 deletions dorado/read_pipeline/ReadPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ class ReadCommon {
std::vector<uint8_t> moves; // Move table
std::vector<uint8_t> base_mod_probs; // Modified base probabilities
std::string run_id; // Run ID - used in read group
std::string flowcell_id; // Flowcell ID - used in read group
std::string model_name; // Read group
std::string flowcell_id; // Flowcell ID - used in read group and for sample sheet aliasing
std::string position_id; // Position ID - used for sample sheet aliasing
std::string experiment_id; // Experiment ID - used for sample sheet aliasing
std::string model_name; // Read group

dorado::details::Attributes attributes;

Expand Down
19 changes: 18 additions & 1 deletion dorado/read_pipeline/ReadToBamTypeNode.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "ReadToBamTypeNode.h"

#include "utils/SampleSheet.h"

#include <spdlog/spdlog.h>

#include <algorithm>
Expand All @@ -24,6 +26,17 @@ void ReadToBamType::worker_thread() {
if (!read_common_data.is_duplex) {
is_duplex_parent = std::get<SimplexReadPtr>(message)->is_duplex_parent;
}

// alias barcode if present
if (m_sample_sheet && !read_common_data.barcode.empty()) {
auto alias = m_sample_sheet->get_alias(
read_common_data.flowcell_id, read_common_data.position_id,
read_common_data.experiment_id, read_common_data.barcode);
if (!alias.empty()) {
read_common_data.barcode = alias;
}
}

auto alns = read_common_data.extract_sam_lines(m_emit_moves, m_modbase_threshold,
is_duplex_parent);
for (auto& aln : alns) {
Expand All @@ -35,15 +48,19 @@ void ReadToBamType::worker_thread() {
ReadToBamType::ReadToBamType(bool emit_moves,
size_t num_worker_threads,
float modbase_threshold_frac,
std::unique_ptr<const utils::SampleSheet> sample_sheet,
size_t max_reads)
: MessageSink(max_reads),
m_num_worker_threads(num_worker_threads),
m_emit_moves(emit_moves),
m_modbase_threshold(
static_cast<uint8_t>(std::min(modbase_threshold_frac * 256.0f, 255.0f))) {
static_cast<uint8_t>(std::min(modbase_threshold_frac * 256.0f, 255.0f))),
m_sample_sheet(std::move(sample_sheet)) {
start_threads();
}

ReadToBamType::~ReadToBamType() { terminate_impl(); }

void ReadToBamType::start_threads() {
for (size_t i = 0; i < m_num_worker_threads; i++) {
m_workers.push_back(
Expand Down

0 comments on commit 616b951

Please sign in to comment.