Skip to content

Commit

Permalink
Merge branch 'jdaw/rna-adapter-trimming' into 'master'
Browse files Browse the repository at this point in the history
RNA adapter trimming

Closes DOR-373

See merge request machine-learning/dorado!658
  • Loading branch information
tijyojwad committed Oct 30, 2023
2 parents bb2a133 + 820c223 commit 2dc1f03
Show file tree
Hide file tree
Showing 23 changed files with 184 additions and 127 deletions.
11 changes: 9 additions & 2 deletions dorado/cli/aligner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ int aligner(int argc, char* argv[]) {
.help("maximum number of reads to process (for debugging, 0=unlimited).")
.default_value(0)
.scan<'i', int>();
parser.visible.add_argument("-v", "--verbose").default_value(false).implicit_value(true);
int verbosity = 0;
parser.visible.add_argument("-v", "--verbose")
.default_value(false)
.implicit_value(true)
.nargs(0)
.action([&](const auto&) { ++verbosity; })
.append();

cli::add_minimap2_arguments(parser, Aligner::dflt_options);

try {
Expand All @@ -67,7 +74,7 @@ int aligner(int argc, char* argv[]) {

if (parser.visible.get<bool>("--verbose")) {
mm_verbose = 3;
utils::SetDebugLogging();
utils::SetVerboseLogging(static_cast<dorado::utils::VerboseLogLevel>(verbosity));
}

auto index(parser.visible.get<std::string>("index"));
Expand Down
12 changes: 9 additions & 3 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void setup(std::vector<std::string> args,
if (estimate_poly_a) {
current_sink_node = pipeline_desc.add_node<PolyACalculator>(
{current_sink_node}, std::thread::hardware_concurrency(),
PolyACalculator::get_model_type(model_name));
is_rna_model(model_config));
}
if (!barcode_kits.empty()) {
current_sink_node = pipeline_desc.add_node<BarcodeClassifierNode>(
Expand Down Expand Up @@ -264,7 +264,13 @@ int basecaller(int argc, char* argv[]) {

parser.visible.add_argument("data").help("the data directory or file (POD5/FAST5 format).");

parser.visible.add_argument("-v", "--verbose").default_value(false).implicit_value(true);
int verbosity = 0;
parser.visible.add_argument("-v", "--verbose")
.default_value(false)
.implicit_value(true)
.nargs(0)
.action([&](const auto&) { ++verbosity; })
.append();

parser.visible.add_argument("-x", "--device")
.help("device string in format \"cuda:0,...,N\", \"cuda:all\", \"metal\", \"cpu\" "
Expand Down Expand Up @@ -389,7 +395,7 @@ int basecaller(int argc, char* argv[]) {
std::vector<std::string> args(argv, argv + argc);

if (parser.visible.get<bool>("--verbose")) {
utils::SetDebugLogging();
utils::SetVerboseLogging(static_cast<dorado::utils::VerboseLogLevel>(verbosity));
}

auto model = parser.visible.get<std::string>("model");
Expand Down
10 changes: 8 additions & 2 deletions dorado/cli/demux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ int demuxer(int argc, char* argv[]) {
parser.add_argument("-l", "--read-ids")
.help("A file with a newline-delimited list of reads to demux.")
.default_value(std::string(""));
parser.add_argument("-v", "--verbose").default_value(false).implicit_value(true);
int verbosity = 0;
parser.add_argument("-v", "--verbose")
.default_value(false)
.implicit_value(true)
.nargs(0)
.action([&](const auto&) { ++verbosity; })
.append();
parser.add_argument("--emit-fastq")
.help("Output in fastq format. Default is BAM.")
.default_value(false)
Expand Down Expand Up @@ -99,7 +105,7 @@ int demuxer(int argc, char* argv[]) {
}

if (parser.get<bool>("--verbose")) {
utils::SetDebugLogging();
utils::SetVerboseLogging(static_cast<dorado::utils::VerboseLogLevel>(verbosity));
}

auto reads(parser.get<std::vector<std::string>>("reads"));
Expand Down
12 changes: 9 additions & 3 deletions dorado/cli/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ int download(int argc, char* argv[]) {

argparse::ArgumentParser parser("dorado", DORADO_VERSION, argparse::default_arguments::help);

parser.add_argument("-v", "--verbose").default_value(false).implicit_value(true);

parser.add_argument("--model").default_value(std::string("all")).help("the model to download");

parser.add_argument("--directory")
Expand All @@ -62,6 +60,14 @@ int download(int argc, char* argv[]) {
.default_value(false)
.implicit_value(true);

int verbosity = 0;
parser.add_argument("-v", "--verbose")
.default_value(false)
.implicit_value(true)
.nargs(0)
.action([&](const auto&) { ++verbosity; })
.append();

try {
parser.parse_args(argc, argv);
} catch (const std::exception& e) {
Expand All @@ -72,7 +78,7 @@ int download(int argc, char* argv[]) {
}

if (parser.get<bool>("--verbose")) {
utils::SetDebugLogging();
utils::SetVerboseLogging(static_cast<dorado::utils::VerboseLogLevel>(verbosity));
}

auto list = parser.get<bool>("--list");
Expand Down
10 changes: 8 additions & 2 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ int duplex(int argc, char* argv[]) {
.help("Path to reference for alignment.")
.default_value(std::string(""));

parser.visible.add_argument("-v", "--verbose").default_value(false).implicit_value(true);
int verbosity = 0;
parser.visible.add_argument("-v", "--verbose")
.default_value(false)
.implicit_value(true)
.nargs(0)
.action([&](const auto&) { ++verbosity; })
.append();

cli::add_minimap2_arguments(parser, Aligner::dflt_options);
cli::add_internal_arguments(parser);
Expand All @@ -133,7 +139,7 @@ int duplex(int argc, char* argv[]) {
const bool basespace_duplex = (model.compare("basespace") == 0);
std::vector<std::string> args(argv, argv + argc);
if (parser.visible.get<bool>("--verbose")) {
utils::SetDebugLogging();
utils::SetVerboseLogging(static_cast<dorado::utils::VerboseLogLevel>(verbosity));
}
std::map<std::string, std::string> template_complement_map;
auto read_list = utils::load_read_list(parser.visible.get<std::string>("--read-ids"));
Expand Down
10 changes: 8 additions & 2 deletions dorado/cli/summary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ int summary(int argc, char *argv[]) {
argparse::ArgumentParser parser("dorado", DORADO_VERSION, argparse::default_arguments::help);
parser.add_argument("reads").help("SAM/BAM file produced by dorado basecaller.");
parser.add_argument("-s", "--separator").default_value(std::string("\t"));
parser.add_argument("-v", "--verbose").default_value(false).implicit_value(true);
int verbosity = 0;
parser.add_argument("-v", "--verbose")
.default_value(false)
.implicit_value(true)
.nargs(0)
.action([&](const auto &) { ++verbosity; })
.append();

try {
parser.parse_args(argc, argv);
Expand All @@ -62,7 +68,7 @@ int summary(int argc, char *argv[]) {
}

if (parser.get<bool>("--verbose")) {
utils::SetDebugLogging();
utils::SetVerboseLogging(static_cast<dorado::utils::VerboseLogLevel>(verbosity));
}

std::vector<std::string> header = {
Expand Down
14 changes: 14 additions & 0 deletions dorado/nn/CRFModelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@

namespace dorado {

// Maps a model name to its SampleType by looking for a known tag inside the
// name. Tags are tested in the order "rna004", "rna002", "dna"; the first one
// found decides the result.
//
// Throws std::runtime_error when none of the tags occurs in model_name.
SampleType get_model_type(const std::string &model_name) {
    // True when `tag` occurs anywhere inside the model name.
    const auto name_contains = [&model_name](const char *tag) {
        return model_name.find(tag) != std::string::npos;
    };

    if (name_contains("rna004")) {
        return SampleType::RNA004;
    }
    if (name_contains("rna002")) {
        return SampleType::RNA002;
    }
    if (name_contains("dna")) {
        return SampleType::DNA;
    }

    throw std::runtime_error("Could not determine model type for " + model_name);
}

CRFModelConfig load_crf_model_config(const std::filesystem::path &path) {
const auto config_toml = toml::parse(path / "config.toml");

Expand Down Expand Up @@ -96,6 +108,8 @@ CRFModelConfig load_crf_model_config(const std::filesystem::path &path) {
config.signal_norm_params.quantile_scaling = false;
}

config.sample_type = get_model_type(model_name);

return config;
}

Expand Down
8 changes: 8 additions & 0 deletions dorado/nn/CRFModelConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ struct SignalNormalisationParams {
bool quantile_scaling = true;
};

// Sample type a model is associated with. One of these values is stored in
// CRFModelConfig::sample_type.
//
// The enumerator values are spelled out explicitly; they are the same values
// the compiler would assign implicitly.
enum SampleType {
    DNA = 0,
    RNA002 = 1,
    RNA004 = 2,
};

// Values extracted from config.toml used in construction of the model module.
struct CRFModelConfig {
float qscale = 1.0f;
Expand Down Expand Up @@ -42,6 +48,8 @@ struct CRFModelConfig {
// Start position for mean Q-score calculation for
// short reads.
int32_t mean_qscore_start_pos = -1;

SampleType sample_type;
};

CRFModelConfig load_crf_model_config(const std::filesystem::path& path);
Expand Down
7 changes: 4 additions & 3 deletions dorado/read_pipeline/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,

auto scaler_node =
pipeline_desc.add_node<ScalerNode>({basecaller_node}, model_config.signal_norm_params,
is_rna_model(model_config), scaler_node_threads);
model_config.sample_type, scaler_node_threads);

// if we've been provided a source node, connect it to the start of our pipeline
if (source_node_handle != PipelineDescriptor::InvalidNodeHandle) {
Expand Down Expand Up @@ -130,8 +130,9 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc,
{splitter_node}, std::move(runners), adjusted_simplex_overlap, kSimplexBatchTimeoutMS,
model_name, 1000, "BasecallerNode", true, mean_qscore_start_pos);

auto scaler_node = pipeline_desc.add_node<ScalerNode>(
{basecaller_node}, model_config.signal_norm_params, false, scaler_node_threads);
auto scaler_node =
pipeline_desc.add_node<ScalerNode>({basecaller_node}, model_config.signal_norm_params,
SampleType::DNA, scaler_node_threads);

// if we've been provided a source node, connect it to the start of our pipeline
if (source_node_handle != PipelineDescriptor::InvalidNodeHandle) {
Expand Down
85 changes: 11 additions & 74 deletions dorado/read_pipeline/PolyACalculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,57 +256,11 @@ SignalAnchorInfo determine_signal_anchor_and_strand_cdna(const dorado::SimplexRe
return result;
}

// Since the adapter in RNA is still DNA, the basecall quality of the adapter is poor because we
// infer with a model trained on RNA data. So finding a match of the adapter sequence in the RNA sequence
// doesn't work well. Instead, the raw signal is traversed to find a point
// where there's a jump in the median signal value, which is indicative of the
// transition from the DNA adapter to the RNA signal. The polyA will start right
// at that juncture. This function returns a struct with the strand
// direction (which is always reverse for dRNA), the signal anchor and the number of bases
// to omit from the tail length estimation due to any adapter effects.
SignalAnchorInfo determine_signal_anchor_and_strand_drna(
const dorado::SimplexRead& read,
dorado::PolyACalculator::ModelType model_type) {
static const std::unordered_map<dorado::PolyACalculator::ModelType, int> kOffsetMap = {
{dorado::PolyACalculator::ModelType::RNA002, 5000},
{dorado::PolyACalculator::ModelType::RNA004, 1000}};
static const std::unordered_map<dorado::PolyACalculator::ModelType, int> kMaxSignalPosMap = {
{dorado::PolyACalculator::ModelType::RNA002, 10000},
{dorado::PolyACalculator::ModelType::RNA004, 5000}};

const int kWindowSize = 250;
const int kStride = 50;
const int kOffset = kOffsetMap.at(model_type);
const int kMaxSignalPos = kMaxSignalPosMap.at(model_type);

const float kMinMedianForRNASignal = 0.f;
const float kMinMedianDiff = 1.f;

int bp = -1;
int signal_len = read.read_common.get_raw_data_samples();
auto sig_fp32 = read.read_common.raw_data.to(torch::kFloat);
float last_median = 0.f;
for (int i = kOffset; i < std::min(signal_len / 2, kMaxSignalPos); i += kStride) {
auto slice = sig_fp32.slice(0, std::max(0, i - kWindowSize / 2),
std::min(signal_len, i + kWindowSize / 2));
float median = slice.median().item<float>();
if (median > kMinMedianForRNASignal) {
float diff = median - last_median;
if (diff > kMinMedianDiff) {
bp = i;
break;
}
}
last_median = median;
}

spdlog::debug("Approx break point {}", bp);

if (bp > 0) {
return SignalAnchorInfo{false, bp, 0};
} else {
return SignalAnchorInfo{false, -1, 0};
}
// For direct RNA the polyA tail sits at the beginning of the strand. The RNA
// adapter has already been trimmed off by the time this runs, so the polyA
// search can be anchored at the very start of the signal.
//
// `read` is not inspected; the same anchor is returned for every read.
SignalAnchorInfo determine_signal_anchor_and_strand_drna(const dorado::SimplexRead& read) {
    // Value order follows the caller's structured binding:
    // {fwd, signal_anchor, trailing_Ts}.
    const bool fwd = false;                // strand direction flag
    const int signal_anchor = 0;           // start searching at the first sample
    const int trailing_adapter_bases = 0;  // no tail-length adjustment
    return SignalAnchorInfo{fwd, signal_anchor, trailing_adapter_bases};
}

} // namespace
Expand All @@ -330,7 +284,7 @@ void PolyACalculator::worker_thread() {
// Determine the strand direction, approximate base space anchor for the tail, and whether
// the final length needs to be adjusted depending on the adapter sequence.
auto [fwd, signal_anchor, trailing_Ts] =
m_is_rna ? determine_signal_anchor_and_strand_drna(*read, m_model_type)
m_is_rna ? determine_signal_anchor_and_strand_drna(*read)
: determine_signal_anchor_and_strand_cdna(*read);

if (signal_anchor >= 0) {
Expand Down Expand Up @@ -358,7 +312,7 @@ void PolyACalculator::worker_thread() {
// Update debug stats.
total_tail_lengths_called += num_bases;
++num_called;
if (spdlog::get_level() == spdlog::level::debug) {
if (spdlog::get_level() <= spdlog::level::debug) {
std::lock_guard<std::mutex> lock(m_mutex);
tail_length_counts[num_bases]++;
}
Expand All @@ -378,13 +332,8 @@ void PolyACalculator::worker_thread() {
}
}

PolyACalculator::PolyACalculator(size_t num_worker_threads,
PolyACalculator::ModelType model_type,
size_t max_reads)
: MessageSink(max_reads),
m_num_worker_threads(num_worker_threads),
m_is_rna(model_type == ModelType::RNA004 || model_type == ModelType::RNA002),
m_model_type(model_type) {
PolyACalculator::PolyACalculator(size_t num_worker_threads, bool is_rna, size_t max_reads)
: MessageSink(max_reads), m_num_worker_threads(num_worker_threads), m_is_rna(is_rna) {
start_threads();
}

Expand All @@ -410,14 +359,14 @@ void PolyACalculator::terminate_impl() {

// Visualize a distribution of the tail lengths called.
static bool done = false;
if (!done && spdlog::get_level() == spdlog::level::debug) {
if (!done && (spdlog::get_level() <= spdlog::level::debug)) {
int max_val = -1;
for (auto [k, v] : tail_length_counts) {
max_val = std::max(v, max_val);
}
int factor = std::max(1, 1 + max_val / 100);
for (auto [k, v] : tail_length_counts) {
spdlog::debug("{} : {}", k, std::string(v / factor, '*'));
spdlog::debug("{:03d} : {}", k, std::string(v / factor, '*'));
}
done = true;
}
Expand All @@ -437,16 +386,4 @@ stats::NamedStats PolyACalculator::sample_stats() const {
return stats;
}

// Resolves a model name to a PolyACalculator::ModelType by substring lookup.
// Tags are tried in table order ("rna004", "rna002", "dna") and the first tag
// found in the name decides the result.
//
// Throws std::runtime_error when no tag occurs in model_name.
PolyACalculator::ModelType PolyACalculator::get_model_type(const std::string& model_name) {
    struct TagEntry {
        const char* tag;
        ModelType type;
    };
    // Table order is the match priority.
    static const TagEntry kTagTable[] = {
            {"rna004", ModelType::RNA004},
            {"rna002", ModelType::RNA002},
            {"dna", ModelType::DNA},
    };

    for (const auto& entry : kTagTable) {
        if (model_name.find(entry.tag) != std::string::npos) {
            return entry.type;
        }
    }

    throw std::runtime_error("Could not determine model type for " + model_name);
}

} // namespace dorado
10 changes: 1 addition & 9 deletions dorado/read_pipeline/PolyACalculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,12 @@ namespace dorado {

class PolyACalculator : public MessageSink {
public:
enum ModelType {
DNA,
RNA002,
RNA004,
};

PolyACalculator(size_t num_worker_threads, ModelType model_type, size_t max_reads = 1000);
PolyACalculator(size_t num_worker_threads, bool is_rna, size_t max_reads = 1000);
~PolyACalculator() { terminate_impl(); }
std::string get_name() const override { return "PolyACalculator"; }
stats::NamedStats sample_stats() const override;
void terminate(const FlushOptions& flush_options) override { terminate_impl(); };
void restart() override;
static ModelType get_model_type(const std::string& model_name);

private:
void start_threads();
Expand All @@ -38,7 +31,6 @@ class PolyACalculator : public MessageSink {
std::vector<std::unique_ptr<std::thread>> m_workers;
size_t m_num_worker_threads = 0;
const bool m_is_rna;
const ModelType m_model_type;
std::atomic<size_t> total_tail_lengths_called{0};
std::atomic<int> num_called{0};
std::atomic<int> num_not_called{0};
Expand Down

0 comments on commit 2dc1f03

Please sign in to comment.