Skip to content

Commit

Permalink
Merge branch 'jdaw/support-twist-barcodes' into 'master'
Browse files Browse the repository at this point in the history
[DOR-327] Support for Twist barcodes

See merge request machine-learning/dorado!788
  • Loading branch information
tijyojwad committed Jan 5, 2024
2 parents 899c4a2 + 0032b37 commit 7506d44
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 55 deletions.
19 changes: 10 additions & 9 deletions dorado/demux/BarcodeClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,11 @@ std::unordered_map<std::string, dorado::barcode_kits::KitInfo> process_custom_ki
const std::optional<std::string>& custom_kit) {
std::unordered_map<std::string, dorado::barcode_kits::KitInfo> kit_map;
if (custom_kit) {
auto [kit_name, kit_info] = demux::parse_custom_arrangement(*custom_kit);
kit_map[kit_name] = kit_info;
auto custom_arrangement = demux::parse_custom_arrangement(*custom_kit);
if (custom_arrangement) {
const auto& [kit_name, kit_info] = *custom_arrangement;
kit_map[kit_name] = kit_info;
}
}
return kit_map;
}
Expand Down Expand Up @@ -208,17 +211,15 @@ std::vector<BarcodeClassifier::BarcodeCandidateKit> BarcodeClassifier::generate_
const std::vector<std::string>& kit_names) {
std::vector<BarcodeCandidateKit> candidates_list;

const auto& kit_info_map = barcode_kits::get_kit_infos();

std::vector<std::string> final_kit_names;
if (!m_custom_kit.empty()) {
for (auto& [kit_name, _] : m_custom_kit) {
final_kit_names.push_back(kit_name);
}
} else if (kit_names.empty()) {
for (auto& [kit_name, _] : kit_info_map) {
final_kit_names.push_back(kit_name);
}
throw std::runtime_error(
"Either custom kit must include kit arrangement or a kit name needs to be passed "
"in.");
} else {
final_kit_names = kit_names;
}
Expand Down Expand Up @@ -613,11 +614,11 @@ BarcodeScoreResult BarcodeClassifier::find_best_barcode(
auto best_bottom_score = std::max_element(
scores.begin(), scores.end(),
[](const auto& l, const auto& r) { return l.bottom_score < r.bottom_score; });
spdlog::trace("Check double ends: top bc {}, bottom bc {}", best_top_score->barcode_name,
best_bottom_score->barcode_name);
if ((best_top_score->score > m_scoring_params.min_soft_barcode_threshold) &&
(best_bottom_score->score > m_scoring_params.min_soft_barcode_threshold) &&
(best_top_score->barcode_name != best_bottom_score->barcode_name)) {
spdlog::trace("Two ends confidently predict different BCs: top bc {}, bottom bc {}",
best_top_score->barcode_name, best_bottom_score->barcode_name);
return UNCLASSIFIED;
}
}
Expand Down
20 changes: 13 additions & 7 deletions dorado/demux/BarcodeClassifierSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ namespace dorado::demux {

std::shared_ptr<const BarcodeClassifier> BarcodeClassifierSelector::get_barcoder(
const BarcodingInfo& barcode_kit_info) {
assert(!barcode_kit_info.kit_name.empty());
if (barcode_kit_info.kit_name.empty() && !barcode_kit_info.custom_kit.has_value()) {
throw std::runtime_error("Either kit name or custom kit file must be specified!");
}
const auto kit_id = barcode_kit_info.kit_name.empty() ? *barcode_kit_info.custom_kit
: barcode_kit_info.kit_name;
std::lock_guard<std::mutex> lock(m_mutex);
if (!m_barcoder_lut.count(barcode_kit_info.kit_name)) {
m_barcoder_lut.emplace(barcode_kit_info.kit_name,
std::make_shared<const BarcodeClassifier>(
std::vector<std::string>{barcode_kit_info.kit_name},
barcode_kit_info.custom_kit, barcode_kit_info.custom_seqs));
if (!m_barcoder_lut.count(kit_id)) {
m_barcoder_lut.emplace(
kit_id, std::make_shared<const BarcodeClassifier>(
barcode_kit_info.kit_name.empty()
? std::vector<std::string>{}
: std::vector<std::string>{barcode_kit_info.kit_name},
barcode_kit_info.custom_kit, barcode_kit_info.custom_seqs));
}
return m_barcoder_lut.at(barcode_kit_info.kit_name);
return m_barcoder_lut.at(kit_id);
}

} // namespace dorado::demux
7 changes: 3 additions & 4 deletions dorado/demux/parse_custom_kit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,12 @@ bool check_normalized_id_pattern(const std::string& pattern) {
return true;
}

std::pair<std::string, barcode_kits::KitInfo> parse_custom_arrangement(
std::optional<std::pair<std::string, barcode_kits::KitInfo>> parse_custom_arrangement(
const std::string& arrangement_file) {
const toml::value config_toml = toml::parse(arrangement_file);

if (!config_toml.contains("arrangement")) {
throw std::runtime_error(
"Custom barcode arrangement file must have [arrangement] section.");
return std::nullopt;
}

barcode_kits::KitInfo new_kit;
Expand Down Expand Up @@ -108,7 +107,7 @@ std::pair<std::string, barcode_kits::KitInfo> parse_custom_arrangement(
(barcode1_pattern != barcode2_pattern);
}

return {kit_name, new_kit};
return std::make_pair(kit_name, new_kit);
}

BarcodeKitScoringParams parse_scoring_params(const std::string& arrangement_file) {
Expand Down
3 changes: 2 additions & 1 deletion dorado/demux/parse_custom_kit.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "utils/barcode_kits.h"

#include <optional>
#include <string>
#include <unordered_map>

Expand All @@ -15,7 +16,7 @@ struct BarcodeKitScoringParams {
float min_barcode_score_dist = 0.25f;
};

std::pair<std::string, barcode_kits::KitInfo> parse_custom_arrangement(
std::optional<std::pair<std::string, barcode_kits::KitInfo>> parse_custom_arrangement(
const std::string& arrangement_file);

std::unordered_map<std::string, std::string> parse_custom_barcode_sequences(
Expand Down
12 changes: 11 additions & 1 deletion dorado/read_pipeline/BarcodeClassifierNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ BarcodeClassifierNode::BarcodeClassifierNode(int threads,
if (m_default_barcoding_info->kit_name.empty()) {
spdlog::debug("Barcode with new kit from {}", *m_default_barcoding_info->custom_kit);
} else {
spdlog::info("Barcode for {}", m_default_barcoding_info->kit_name);
spdlog::debug("Barcode for {}", m_default_barcoding_info->kit_name);
}
start_threads();
}
Expand Down Expand Up @@ -136,6 +136,10 @@ void BarcodeClassifierNode::barcode(BamPtr& read) {
auto bc = generate_barcode_string(bc_res);
bam_aux_append(irecord, "BC", 'Z', int(bc.length() + 1), (uint8_t*)bc.c_str());
m_num_records++;
{
std::lock_guard lock(m_barcode_count_mutex);
m_barcode_count[bc]++;
}

if (m_default_barcoding_info->trim) {
int seqlen = irecord->core.l_qseq;
Expand Down Expand Up @@ -171,6 +175,12 @@ void BarcodeClassifierNode::barcode(SimplexRead& read) {
stats::NamedStats BarcodeClassifierNode::sample_stats() const {
auto stats = stats::from_obj(m_work_queue);
stats["num_barcodes_demuxed"] = m_num_records.load();
{
for (const auto& [bc_name, bc_count] : m_barcode_count) {
std::string key = "bc." + bc_name;
stats[key] = static_cast<float>(bc_count);
}
}
return stats;
}

Expand Down
7 changes: 7 additions & 0 deletions dorado/read_pipeline/BarcodeClassifierNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
#include "utils/types.h"

#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <string_view>
Expand Down Expand Up @@ -46,6 +48,11 @@ class BarcodeClassifierNode : public MessageSink {
void barcode(SimplexRead& read);

void terminate_impl();

// Track how many reads were classified as each barcode for debugging
// purposes.
std::map<std::string, std::atomic<size_t>> m_barcode_count;
std::mutex m_barcode_count_mutex;
};

} // namespace dorado
3 changes: 2 additions & 1 deletion dorado/read_pipeline/HtsReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ void HtsReader::read(Pipeline& pipeline, int max_reads) {
}
}
pipeline.push_message(BamPtr(bam_dup1(record.get())));
if (max_reads > 0 && ++num_reads >= max_reads) {
++num_reads;
if (max_reads > 0 && num_reads >= max_reads) {
break;
}
if (num_reads % 50000 == 0) {
Expand Down
29 changes: 29 additions & 0 deletions dorado/read_pipeline/ProgressTracker.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "utils/stats.h"
#include "utils/string_utils.h"
#include "utils/tty_utils.h"

#ifdef WIN32
Expand Down Expand Up @@ -69,6 +70,21 @@ class ProgressTracker {
rate_str << std::scientific << m_num_barcodes_demuxed / (duration / 1000.0);
spdlog::info("> {} reads demuxed @ classifications/s: {}", m_num_barcodes_demuxed,
rate_str.str());
// Report how many reads were classified into each
// barcode.
if (spdlog::get_level() <= spdlog::level::debug) {
spdlog::debug("Barcode distribution :");
size_t unclassified = 0;
size_t total = 0;
for (const auto& [bc_name, bc_count] : m_barcode_count) {
spdlog::debug("{} : {}", bc_name, bc_count);
total += bc_count;
if (bc_name == "unclassified") {
unclassified += bc_count;
}
}
spdlog::debug("Classified rate {}%", (1.f - float(unclassified) / total) * 100.f);
}
}
}

Expand Down Expand Up @@ -133,6 +149,17 @@ class ProgressTracker {
std::cerr << "\r> Output records written: " << m_num_simplex_reads_written;
std::cerr << "\r";
}

// Collect per barcode stats.
if (m_num_barcodes_demuxed > 0 && (spdlog::get_level() <= spdlog::level::debug)) {
for (const auto& [stat, val] : stats) {
const std::string prefix = "BarcodeClassifierNode.bc.";
if (utils::starts_with(stat, prefix)) {
auto bc_name = stat.substr(prefix.length());
m_barcode_count[bc_name] = static_cast<int>(val);
}
}
}
}

private:
Expand All @@ -150,6 +177,8 @@ class ProgressTracker {

int m_num_reads_expected;

std::map<std::string, size_t> m_barcode_count;

std::chrono::time_point<std::chrono::system_clock> m_initialization_time;
std::chrono::time_point<std::chrono::system_clock> m_end_time;

Expand Down

0 comments on commit 7506d44

Please sign in to comment.