Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-577-server-polya' into 'master'
Browse files Browse the repository at this point in the history
[DOR-577] PolyA calculator node refactor

Closes DOR-577

See merge request machine-learning/dorado!865
  • Loading branch information
malton-ont committed Apr 3, 2024
2 parents c8d0231 + f567056 commit 39f3c65
Show file tree
Hide file tree
Showing 26 changed files with 515 additions and 410 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ add_library(dorado_lib
dorado/read_pipeline/BaseSpaceDuplexCallerNode.cpp
dorado/read_pipeline/BaseSpaceDuplexCallerNode.h
dorado/read_pipeline/ClientInfo.h
dorado/read_pipeline/DefaultClientInfo.cpp
dorado/read_pipeline/DefaultClientInfo.h
dorado/read_pipeline/DuplexReadTaggingNode.cpp
dorado/read_pipeline/DuplexReadTaggingNode.h
Expand Down
16 changes: 11 additions & 5 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "read_pipeline/AdapterDetectorNode.h"
#include "read_pipeline/AlignerNode.h"
#include "read_pipeline/BarcodeClassifierNode.h"
#include "read_pipeline/DefaultClientInfo.h"
#include "read_pipeline/HtsReader.h"
#include "read_pipeline/HtsWriter.h"
#include "read_pipeline/PolyACalculatorNode.h"
Expand Down Expand Up @@ -119,7 +120,7 @@ void setup(std::vector<std::string> args,
const std::optional<std::string>& custom_primer_file,
argparse::ArgumentParser& resume_parser,
bool estimate_poly_a,
const std::string* const polya_config,
const std::string& polya_config,
const ModelSelection& model_selection) {
const auto model_config = basecall::load_crf_model_config(model_path);
const std::string model_name = models::extract_model_name_from_path(model_path);
Expand Down Expand Up @@ -221,8 +222,7 @@ void setup(std::vector<std::string> args,
methylation_threshold_pct, std::move(sample_sheet), 1000);
if (estimate_poly_a) {
current_sink_node = pipeline_desc.add_node<PolyACalculatorNode>(
{current_sink_node}, std::thread::hardware_concurrency(),
is_rna_model(model_config), 1000, polya_config);
{current_sink_node}, std::thread::hardware_concurrency(), 1000);
}
if (adapter_trimming_enabled) {
current_sink_node = pipeline_desc.add_node<AdapterDetectorNode>(
Expand Down Expand Up @@ -322,6 +322,12 @@ void setup(std::vector<std::string> args,
DataLoader loader(*pipeline, "cpu", thread_allocations.loader_threads, max_reads, read_list,
reads_already_processed);

DefaultClientInfo::PolyTailSettings polytail_settings{estimate_poly_a,
is_rna_model(model_config), polya_config};
auto default_client_info = std::make_shared<DefaultClientInfo>(polytail_settings);
auto func = [default_client_info](ReadCommon& read) { read.client_info = default_client_info; };
loader.add_read_initialiser(func);

// Run pipeline.
loader.load_reads(data_path, recursive_file_loading, ReadOrder::UNRESTRICTED);

Expand Down Expand Up @@ -682,8 +688,8 @@ int basecaller(int argc, char* argv[]) {
parser.visible.get<bool>("--barcode-both-ends"), no_trim_barcodes, no_trim_adapters,
no_trim_primers, parser.visible.get<std::string>("--sample-sheet"),
std::move(custom_kit), std::move(custom_barcode_seqs), std::move(custom_primer_file),
resume_parser, parser.visible.get<bool>("--estimate-poly-a"),
polya_config.empty() ? nullptr : &polya_config, model_selection);
resume_parser, parser.visible.get<bool>("--estimate-poly-a"), polya_config,
model_selection);
} catch (const std::exception& e) {
spdlog::error("{}", e.what());
utils::clean_temporary_models(temp_download_paths);
Expand Down
11 changes: 11 additions & 0 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "models/models.h"
#include "read_pipeline/AlignerNode.h"
#include "read_pipeline/BaseSpaceDuplexCallerNode.h"
#include "read_pipeline/DefaultClientInfo.h"
#include "read_pipeline/DuplexReadTaggingNode.h"
#include "read_pipeline/HtsWriter.h"
#include "read_pipeline/ProgressTracker.h"
Expand Down Expand Up @@ -406,6 +407,11 @@ int duplex(int argc, char* argv[]) {

constexpr auto kStatsPeriod = 100ms;

auto default_client_info = std::make_shared<DefaultClientInfo>();
auto client_info_init_func = [default_client_info](ReadCommon& read) {
read.client_info = default_client_info;
};

if (basespace_duplex) { // Execute a Basespace duplex pipeline.
if (pairs_file.empty()) {
spdlog::error("The --pairs argument is required for the basespace model.");
Expand All @@ -420,6 +426,10 @@ int duplex(int argc, char* argv[]) {
spdlog::info("> Loading reads");
auto read_map = read_bam(reads, read_list_from_pairs);

for (auto& [key, read] : read_map) {
client_info_init_func(read->read_common);
}

spdlog::info("> Starting Basespace Duplex Pipeline");
threads = threads == 0 ? std::thread::hardware_concurrency() : threads;

Expand Down Expand Up @@ -541,6 +551,7 @@ int duplex(int argc, char* argv[]) {
hts_file.set_and_write_header(hdr.get());

DataLoader loader(*pipeline, "cpu", num_devices, 0, std::move(read_list), {});
loader.add_read_initialiser(client_info_init_func);

stats_sampler = std::make_unique<dorado::stats::StatsSampler>(
kStatsPeriod, stats_reporters, stats_callables, max_stats_records);
Expand Down
9 changes: 9 additions & 0 deletions dorado/data_loader/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ void DataLoader::load_pod5_reads_from_file_by_read_ids(const std::string& path,

for (auto& v : futures) {
auto read = v.get();
initialise_read(read->read_common);
check_read(read);
m_pipeline.push_message(std::move(read));
m_loaded_read_count++;
Expand Down Expand Up @@ -914,6 +915,7 @@ void DataLoader::load_pod5_reads_from_file(const std::string& path) {

for (auto& v : futures) {
auto read = v.get();
initialise_read(read->read_common);
check_read(read);
m_pipeline.push_message(std::move(read));
m_loaded_read_count++;
Expand Down Expand Up @@ -1021,12 +1023,19 @@ void DataLoader::load_fast5_reads_from_file(const std::string& path) {

if (!m_allowed_read_ids || (m_allowed_read_ids->find(new_read->read_common.read_id) !=
m_allowed_read_ids->end())) {
initialise_read(new_read->read_common);
m_pipeline.push_message(std::move(new_read));
m_loaded_read_count++;
}
}
}

// Applies every registered read initialiser to `read_common`, in the order
// in which they were added via add_read_initialiser().
void DataLoader::initialise_read(ReadCommon& read_common) const {
    // Bind by const reference: iterating by value would copy-construct each
    // std::function (and whatever state it captured) once per loaded read.
    for (const auto& initialiser : m_read_initialisers) {
        initialiser(read_common);
    }
}

void DataLoader::check_read(const SimplexReadPtr& read) {
if (read->read_common.chemistry == models::Chemistry::UNKNOWN &&
m_log_unknown_chemistry.exchange(false)) {
Expand Down
10 changes: 10 additions & 0 deletions dorado/data_loader/DataLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <array>
#include <filesystem>
#include <functional>
#include <map>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -70,13 +71,20 @@ class DataLoader {
uint32_t read_number;
};

using ReadInitialiserF = std::function<void(ReadCommon&)>;
// Registers a callback to be run on the ReadCommon of each read this loader
// initialises. Callbacks are stored in the order they are added.
void add_read_initialiser(ReadInitialiserF func) {
    m_read_initialisers.emplace_back(std::move(func));
}

private:
void load_fast5_reads_from_file(const std::string& path);
void load_pod5_reads_from_file(const std::string& path);
void load_pod5_reads_from_file_by_read_ids(const std::string& path,
const std::vector<ReadID>& read_ids);
void load_read_channels(std::filesystem::path data_path, bool recursive_file_loading);

void initialise_read(ReadCommon& read) const;

Pipeline& m_pipeline; // Where should the loaded reads go?
std::atomic<size_t> m_loaded_read_count{0};
std::string m_device;
Expand All @@ -90,6 +98,8 @@ class DataLoader {
std::unordered_map<std::string, size_t> m_read_id_to_index;
int m_max_channel{0};

std::vector<ReadInitialiserF> m_read_initialisers;

// Issue warnings if read is potentially problematic
inline void check_read(const SimplexReadPtr& read);
// A flag to warn only once if the data chemistry is unknown
Expand Down
1 change: 1 addition & 0 deletions dorado/poly_tail/dna_poly_tail_calculator.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "dna_poly_tail_calculator.h"

#include "read_pipeline/messages.h"
#include "utils/math_utils.h"
#include "utils/sequence_utils.h"

Expand Down
1 change: 1 addition & 0 deletions dorado/poly_tail/plasmid_poly_tail_calculator.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "plasmid_poly_tail_calculator.h"

#include "read_pipeline/messages.h"
#include "utils/sequence_utils.h"

#include <edlib.h>
Expand Down
3 changes: 2 additions & 1 deletion dorado/poly_tail/poly_tail_calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "dna_poly_tail_calculator.h"
#include "plasmid_poly_tail_calculator.h"
#include "poly_tail_config.h"
#include "read_pipeline/messages.h"
#include "rna_poly_tail_calculator.h"
#include "utils/sequence_utils.h"

Expand Down Expand Up @@ -233,7 +234,7 @@ int PolyTailCalculator::calculate_num_bases(const SimplexRead& read,

std::unique_ptr<PolyTailCalculator> PolyTailCalculatorFactory::create(
bool is_rna,
const std::string* const config_file) {
const std::string& config_file) {
auto config = prepare_config(config_file);
if (is_rna) {
return std::make_unique<RNAPolyTailCalculator>(std::move(config));
Expand Down
9 changes: 6 additions & 3 deletions dorado/poly_tail/poly_tail_calculator.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
#pragma once

#include "poly_tail_config.h"
#include "read_pipeline/messages.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace dorado {
class SimplexRead;
}

namespace dorado::poly_tail {

struct SignalAnchorInfo {
Expand Down Expand Up @@ -67,8 +71,7 @@ class PolyTailCalculator {

class PolyTailCalculatorFactory {
public:
static std::unique_ptr<PolyTailCalculator> create(bool is_rna,
const std::string* const config_file);
static std::unique_ptr<PolyTailCalculator> create(bool is_rna, const std::string& config_file);
};

} // namespace dorado::poly_tail
8 changes: 4 additions & 4 deletions dorado/poly_tail/poly_tail_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ PolyTailConfig prepare_config(std::istream& is) {
return config;
}

PolyTailConfig prepare_config(const std::string* const config_file) {
if (config_file != nullptr) {
std::ifstream file(*config_file); // Open the file for reading
PolyTailConfig prepare_config(const std::string& config_file) {
if (!config_file.empty()) {
std::ifstream file(config_file); // Open the file for reading
if (!file.is_open()) {
throw std::runtime_error("Failed to open file " + *config_file);
throw std::runtime_error("Failed to open file " + config_file);
}

// Read the file contents into a string
Expand Down
2 changes: 1 addition & 1 deletion dorado/poly_tail/poly_tail_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct PolyTailConfig {
// Prepare the PolyA configuration struct. If a configuration
// file is available, parse it to extract parameters. Otherwise
// prepare the default configuration.
PolyTailConfig prepare_config(const std::string* const config_file);
PolyTailConfig prepare_config(const std::string& config_file);

// Overloaded function that parses the configuration passed
// in as an input stream.
Expand Down
1 change: 1 addition & 0 deletions dorado/poly_tail/rna_poly_tail_calculator.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "rna_poly_tail_calculator.h"

#include "read_pipeline/messages.h"
#include "utils/math_utils.h"

#include <algorithm>
Expand Down
8 changes: 8 additions & 0 deletions dorado/read_pipeline/ClientInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
#include "alignment/Minimap2Options.h"

#include <cstdint>
#include <memory>
#include <string>

namespace dorado {

namespace poly_tail {
class PolyTailCalculator;
}

// TODO replace this explicit dependency on an alignment struct with type
// erasure (possibly by using an inversion of control container as we do in
// basecall_server)
Expand All @@ -20,6 +25,9 @@ class ClientInfo {
virtual ~ClientInfo() = default;

virtual const AlignmentInfo& alignment_info() const = 0;
virtual const std::unique_ptr<const poly_tail::PolyTailCalculator>& poly_a_calculator()
const = 0;

virtual int32_t client_id() const = 0;
virtual bool is_disconnected() const = 0;
};
Expand Down
14 changes: 14 additions & 0 deletions dorado/read_pipeline/DefaultClientInfo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include "DefaultClientInfo.h"

namespace dorado {

const AlignmentInfo DefaultClientInfo::empty_alignment_info{};

DefaultClientInfo::DefaultClientInfo(const PolyTailSettings& polytail_settings)
: m_poly_a_calculator(polytail_settings.active
? poly_tail::PolyTailCalculatorFactory::create(
polytail_settings.is_rna,
polytail_settings.config_file)
: nullptr) {}

} // namespace dorado
18 changes: 17 additions & 1 deletion dorado/read_pipeline/DefaultClientInfo.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
#pragma once

#include "ClientInfo.h"
#include "poly_tail/poly_tail_calculator.h"

namespace dorado {

// ClientInfo implementation that returns fixed values for everything except
// the poly-tail calculator: an empty AlignmentInfo, client id -1, and never
// disconnected.
class DefaultClientInfo final : public ClientInfo {
// Shared empty alignment info handed out by alignment_info().
inline static const AlignmentInfo empty_alignment_info{};
static const AlignmentInfo empty_alignment_info;
// Poly-tail calculator owned by this client info. Null when the object is
// default-constructed or when PolyTailSettings::active is false.
const std::unique_ptr<const poly_tail::PolyTailCalculator> m_poly_a_calculator;

public:
// Inputs used to decide whether, and how, to create the poly-tail calculator.
struct PolyTailSettings {
// When false, no calculator is created.
bool active{false};
// Forwarded to PolyTailCalculatorFactory::create() as `is_rna`.
bool is_rna{false};
// Forwarded to PolyTailCalculatorFactory::create() as `config_file`.
std::string config_file{};
};

// Leaves the poly-tail calculator null.
DefaultClientInfo() = default;
// Creates a poly-tail calculator from `polytail_settings` if it is active.
DefaultClientInfo(const PolyTailSettings& polytail_settings);
~DefaultClientInfo() = default;

// Always returns the shared empty alignment info.
const AlignmentInfo& alignment_info() const override { return empty_alignment_info; }
// Returns a reference to the owned calculator pointer; the pointer may be
// null, so callers must check it before dereferencing.
const std::unique_ptr<const poly_tail::PolyTailCalculator>& poly_a_calculator() const override {
return m_poly_a_calculator;
};

// Always -1.
int32_t client_id() const override { return -1; }
// Always false.
bool is_disconnected() const override { return false; }
};
Expand Down
27 changes: 16 additions & 11 deletions dorado/read_pipeline/PolyACalculatorNode.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "PolyACalculatorNode.h"

#include "ClientInfo.h"
#include "poly_tail/poly_tail_calculator.h"
#include "utils/math_utils.h"
#include "utils/sequence_utils.h"

Expand All @@ -23,18 +25,25 @@ void PolyACalculatorNode::input_thread_fn() {
// If this message isn't a read, we'll get a bad_variant_access exception.
auto read = std::get<SimplexReadPtr>(std::move(message));

auto signal_info = m_calculator->determine_signal_anchor_and_strand(*read);
const auto& calculator = read->read_common.client_info->poly_a_calculator();
if (!calculator) {
send_message_to_sink(std::move(read));
num_not_called++;
continue;
}

auto signal_info = calculator->determine_signal_anchor_and_strand(*read);

if (signal_info.signal_anchor >= 0) {
int num_bases = m_calculator->calculate_num_bases(*read, signal_info);
int num_bases = calculator->calculate_num_bases(*read, signal_info);
if (signal_info.split_tail) {
auto split_bases = std::max(
0, m_calculator->calculate_num_bases(*read, {signal_info.is_fwd_strand, 0,
0, signal_info.split_tail}));
0, calculator->calculate_num_bases(*read, {signal_info.is_fwd_strand, 0, 0,
signal_info.split_tail}));
num_bases += split_bases;
}

if (num_bases > 0 && num_bases < m_calculator->max_tail_length()) {
if (num_bases > 0 && num_bases < calculator->max_tail_length()) {
// Update debug stats.
total_tail_lengths_called += num_bases;
++num_called;
Expand All @@ -55,12 +64,8 @@ void PolyACalculatorNode::input_thread_fn() {
}
}

PolyACalculatorNode::PolyACalculatorNode(size_t num_worker_threads,
bool is_rna,
size_t max_reads,
const std::string* const config_file)
: MessageSink(max_reads, static_cast<int>(num_worker_threads)),
m_calculator(poly_tail::PolyTailCalculatorFactory::create(is_rna, config_file)) {
PolyACalculatorNode::PolyACalculatorNode(size_t num_worker_threads, size_t max_reads)
: MessageSink(max_reads, static_cast<int>(num_worker_threads)) {
start_input_processing(&PolyACalculatorNode::input_thread_fn, this);
}

Expand Down
Loading

0 comments on commit 39f3c65

Please sign in to comment.