Skip to content

Commit

Permalink
Merge branch 'cherry-pick-862872a2' into 'release-v0.6.0'
Browse files Browse the repository at this point in the history
Merge branch 'DOR-553_RG_HDR_FOR_CUSTOM_BARCODES' into 'release-v0.6.0'

See merge request machine-learning/dorado!910
  • Loading branch information
tijyojwad committed Mar 26, 2024
2 parents bab07f5 + 7e25c9e commit 9dc052d
Show file tree
Hide file tree
Showing 18 changed files with 243 additions and 185 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,10 @@ add_library(dorado_lib
dorado/demux/BarcodeClassifier.h
dorado/demux/BarcodeClassifierSelector.cpp
dorado/demux/BarcodeClassifierSelector.h
dorado/demux/parse_custom_sequences.cpp
dorado/demux/parse_custom_sequences.h
dorado/demux/Trimmer.cpp
dorado/demux/Trimmer.h
dorado/demux/parse_custom_kit.cpp
dorado/demux/parse_custom_kit.h
dorado/poly_tail/dna_poly_tail_calculator.cpp
dorado/poly_tail/dna_poly_tail_calculator.h
dorado/poly_tail/plasmid_poly_tail_calculator.cpp
Expand Down
51 changes: 45 additions & 6 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "utils/fs_utils.h"
#include "utils/log_utils.h"
#include "utils/parameters.h"
#include "utils/parse_custom_kit.h"
#include "utils/stats.h"
#include "utils/string_utils.h"
#include "utils/sys_stats.h"
Expand All @@ -47,6 +48,32 @@

namespace dorado {

namespace {

// Resolves the built-in barcode kit registered under `kit_name`.
//
// On success the kit description obtained from barcode_kits::get_kit_info is
// returned by reference. If the lookup yields nothing, an error is logged and
// the process exits with EXIT_FAILURE, so this function never returns an
// invalid reference to its caller.
//
// NOTE(review): the returned reference is only safe if get_kit_info hands back
// a pointer-like handle to storage that outlives this call - confirm.
const barcode_kits::KitInfo& get_barcode_kit_info(const std::string& kit_name) {
    const auto found_kit = barcode_kits::get_kit_info(kit_name);
    if (found_kit) {
        return *found_kit;
    }

    // Unknown kit name: there is no sensible fallback, so report it and stop.
    spdlog::error(
            "{} is not a valid barcode kit name. Please run the help "
            "command to find out available barcode kits.",
            kit_name);
    std::exit(EXIT_FAILURE);
}

// Loads a user-supplied barcode arrangement from `custom_kit_file`.
//
// Returns the (kit name, kit info) pair produced by
// barcode_kits::parse_custom_arrangement. If the parser yields no result, an
// error naming the file is logged and the process exits with EXIT_FAILURE.
std::pair<std::string, barcode_kits::KitInfo> get_custom_barcode_kit_info(
        const std::string& custom_kit_file) {
    auto custom_kit_info = barcode_kits::parse_custom_arrangement(custom_kit_file);
    if (!custom_kit_info) {
        spdlog::error("Unable to load custom barcode arrangement file: {}", custom_kit_file);
        std::exit(EXIT_FAILURE);
    }
    // The local holder is destroyed on return, so hand its payload over by
    // move instead of copy-constructing the name string and kit description.
    return std::move(*custom_kit_info);
}

} // namespace

using dorado::utils::default_parameters;
using OutputMode = dorado::utils::HtsFile::OutputMode;
using namespace std::chrono_literals;
Expand Down Expand Up @@ -77,7 +104,7 @@ void setup(std::vector<std::string> args,
const std::string& dump_stats_file,
const std::string& dump_stats_filter,
const std::string& resume_from_file,
const std::vector<std::string>& barcode_kits,
const std::string& barcode_kit,
bool barcode_both_ends,
bool barcode_no_trim,
bool adapter_no_trim,
Expand Down Expand Up @@ -135,7 +162,7 @@ void setup(std::vector<std::string> args,
recursive_file_loading);

const bool adapter_trimming_enabled = (!adapter_no_trim || !primer_no_trim);
const bool barcode_enabled = !barcode_kits.empty() || custom_kit;
const bool barcode_enabled = !barcode_kit.empty() || custom_kit;
const auto thread_allocations = utils::default_thread_allocations(
int(num_devices), !remora_runners.empty() ? int(num_remora_threads) : 0, enable_aligner,
barcode_enabled, adapter_trimming_enabled);
Expand All @@ -149,7 +176,17 @@ void setup(std::vector<std::string> args,

SamHdrPtr hdr(sam_hdr_init());
cli::add_pg_hdr(hdr.get(), args);
utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits, sample_sheet.get());
if (custom_kit) {
auto [kit_name, kit_info] = get_custom_barcode_kit_info(*custom_kit);
utils::add_rg_headers_with_barcode_kit(hdr.get(), read_groups, kit_name, kit_info,
sample_sheet.get());
} else if (!barcode_kit.empty()) {
const auto kit_info = get_barcode_kit_info(barcode_kit);
utils::add_rg_headers_with_barcode_kit(hdr.get(), read_groups, barcode_kit, kit_info,
sample_sheet.get());
} else {
utils::add_rg_headers(hdr.get(), read_groups);
}

utils::HtsFile hts_file("-", output_mode, thread_allocations.writer_threads);

Expand Down Expand Up @@ -178,8 +215,9 @@ void setup(std::vector<std::string> args,
!primer_no_trim, std::move(custom_primer_file));
}
if (barcode_enabled) {
std::vector<std::string> kit_as_vector{barcode_kit};
current_sink_node = pipeline_desc.add_node<BarcodeClassifierNode>(
{current_sink_node}, thread_allocations.barcoder_threads, barcode_kits,
{current_sink_node}, thread_allocations.barcoder_threads, kit_as_vector,
barcode_both_ends, barcode_no_trim, std::move(allowed_barcodes),
std::move(custom_kit), std::move(custom_barcode_file));
}
Expand Down Expand Up @@ -405,7 +443,8 @@ int basecaller(int argc, char* argv[]) {

parser.visible.add_argument("--kit-name")
.help("Enable barcoding with the provided kit name. Choose from: " +
dorado::barcode_kits::barcode_kits_list_str() + ".");
dorado::barcode_kits::barcode_kits_list_str() + ".")
.default_value(std::string{});
parser.visible.add_argument("--barcode-both-ends")
.help("Require both ends of a read to be barcoded for a double ended barcode.")
.default_value(false)
Expand Down Expand Up @@ -624,7 +663,7 @@ int basecaller(int argc, char* argv[]) {
parser.hidden.get<std::string>("--dump_stats_file"),
parser.hidden.get<std::string>("--dump_stats_filter"),
parser.visible.get<std::string>("--resume-from"),
parser.visible.get<std::vector<std::string>>("--kit-name"),
parser.visible.get<std::string>("--kit-name"),
parser.visible.get<bool>("--barcode-both-ends"), no_trim_barcodes, no_trim_adapters,
no_trim_primers, parser.visible.get<std::string>("--sample-sheet"),
std::move(custom_kit), std::move(custom_barcode_seqs), std::move(custom_primer_file),
Expand Down
3 changes: 1 addition & 2 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,7 @@ int duplex(int argc, char* argv[]) {
recursive_file_loading);
read_groups.merge(DataLoader::load_read_groups(reads, duplex_rg_name, "",
recursive_file_loading));
std::vector<std::string> barcode_kits;
utils::add_rg_hdr(hdr.get(), read_groups, barcode_kits, nullptr);
utils::add_rg_headers(hdr.get(), read_groups);

int batch_size(parser.visible.get<int>("-b"));
int chunk_size(parser.visible.get<int>("-c"));
Expand Down
3 changes: 2 additions & 1 deletion dorado/demux/AdapterDetector.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#include "AdapterDetector.h"

#include "parse_custom_kit.h"
#include "parse_custom_sequences.h"
#include "utils/alignment_utils.h"
#include "utils/parse_custom_kit.h"
#include "utils/sequence_utils.h"
#include "utils/types.h"

Expand Down
7 changes: 4 additions & 3 deletions dorado/demux/BarcodeClassifier.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "BarcodeClassifier.h"

#include "parse_custom_kit.h"
#include "parse_custom_sequences.h"
#include "utils/alignment_utils.h"
#include "utils/barcode_kits.h"
#include "utils/parse_custom_kit.h"
#include "utils/sequence_utils.h"
#include "utils/types.h"

Expand Down Expand Up @@ -124,7 +125,7 @@ std::unordered_map<std::string, dorado::barcode_kits::KitInfo> process_custom_ki
const std::optional<std::string>& custom_kit) {
std::unordered_map<std::string, dorado::barcode_kits::KitInfo> kit_map;
if (custom_kit) {
auto custom_arrangement = demux::parse_custom_arrangement(*custom_kit);
auto custom_arrangement = dorado::barcode_kits::parse_custom_arrangement(*custom_kit);
if (custom_arrangement) {
const auto& [kit_name, kit_info] = *custom_arrangement;
kit_map[kit_name] = kit_info;
Expand All @@ -151,7 +152,7 @@ dorado::barcode_kits::BarcodeKitScoringParams set_scoring_params(
if (custom_kit) {
// If a custom kit is passed, parse it for any scoring
// params that need to override the default params.
return dorado::demux::parse_scoring_params(*custom_kit, params);
return barcode_kits::parse_scoring_params(*custom_kit, params);
} else {
return params;
}
Expand Down
4 changes: 2 additions & 2 deletions dorado/demux/BarcodeClassifier.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once
#include "parse_custom_kit.h"
#include "utils/barcode_kits.h"
#include "utils/parse_custom_kit.h"
#include "utils/stats.h"
#include "utils/types.h"

Expand Down Expand Up @@ -31,7 +31,7 @@ class BarcodeClassifier {
private:
const std::unordered_map<std::string, dorado::barcode_kits::KitInfo> m_custom_kit;
const std::unordered_map<std::string, std::string> m_custom_seqs;
const dorado::barcode_kits::BarcodeKitScoringParams m_scoring_params;
const barcode_kits::BarcodeKitScoringParams m_scoring_params;
const std::vector<BarcodeCandidateKit> m_barcode_candidates;

std::vector<BarcodeCandidateKit> generate_candidates(const std::vector<std::string>& kit_names);
Expand Down
31 changes: 31 additions & 0 deletions dorado/demux/parse_custom_sequences.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "parse_custom_sequences.h"

#include "utils/bam_utils.h"
#include "utils/types.h"

#include <htslib/sam.h>

namespace dorado::demux {

// Reads every record from `sequences_file` via htslib and returns a map from
// record name (qname) to its sequence. If a name occurs more than once, the
// last record read wins.
//
// Throws std::runtime_error if the file cannot be opened or if htslib reports
// a read/parse error.
std::unordered_map<std::string, std::string> parse_custom_sequences(
        const std::string& sequences_file) {
    dorado::HtsFilePtr file(hts_open(sequences_file.c_str(), "r"));
    if (!file) {
        // hts_open yields a null handle on failure; passing that on to
        // sam_read1 would dereference a null pointer.
        throw std::runtime_error("Failed to open custom sequence file " + sequences_file);
    }

    BamPtr record;
    record.reset(bam_init1());

    std::unordered_map<std::string, std::string> sequences;

    // sam_read1 returns -1 at end of file and a value below -1 on error.
    // No header is supplied for the read.
    int sam_ret_val = 0;
    while ((sam_ret_val = sam_read1(file.get(), nullptr, record.get())) != -1) {
        if (sam_ret_val < -1) {
            throw std::runtime_error("Failed to parse custom sequence file " + sequences_file);
        }
        const std::string qname = bam_get_qname(record.get());
        // Assign the extracted sequence directly so no intermediate copy is made.
        sequences[qname] = utils::extract_sequence(record.get());
    }

    return sequences;
}

} // namespace dorado::demux
11 changes: 11 additions & 0 deletions dorado/demux/parse_custom_sequences.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <string>
#include <unordered_map>

namespace dorado::demux {

/// Parses the records in `sequences_file` and returns a map from each record's
/// name to its sequence. When a name appears more than once, the entry for the
/// last record read is kept.
///
/// @param sequences_file Path of the sequence file to read.
/// @return Map of record name -> sequence string.
/// @throws std::runtime_error if a record fails to parse.
std::unordered_map<std::string, std::string> parse_custom_sequences(
const std::string& sequences_file);

} // namespace dorado::demux
4 changes: 3 additions & 1 deletion dorado/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ add_library(dorado_utils
module_utils.h
parameters.cpp
parameters.h
parse_custom_kit.cpp
parse_custom_kit.h
PostCondition.h
SampleSheet.cpp
SampleSheet.h
Expand Down Expand Up @@ -59,7 +61,7 @@ add_library(dorado_utils
types.h
uuid_utils.cpp
uuid_utils.h
)
)

if (DORADO_GPU_BUILD)
if(APPLE)
Expand Down

0 comments on commit 9dc052d

Please sign in to comment.