Skip to content

Commit

Permalink
Merge branch 'jdaw/fix-correct-alignments' into 'release-v0.7'
Browse files Browse the repository at this point in the history
Fix sub-par alignments in dorado correct

See merge request machine-learning/dorado!1054
  • Loading branch information
vellamike committed Jun 12, 2024
2 parents d0df79c + d956314 commit 3b51c1b
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 11 deletions.
9 changes: 6 additions & 3 deletions dorado/alignment/Minimap2Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,12 @@ void Minimap2Index::set_mapping_options(const Minimap2MappingOptions& mapping_op
}
}

// Equivalent to "--cap-kalloc 100m --cap-sw-mem 50m"
m_mapping_options->cap_kalloc = 100'000'000;
m_mapping_options->max_sw_mat = 50'000'000;
// Either use the default value for cap_kalloc and max_sw_mat defined in the dorado
// options initialization, or if it's set to nullopt use the minimap2 library default.
m_mapping_options->cap_kalloc =
mapping_options.cap_kalloc.value_or(m_mapping_options->cap_kalloc);
m_mapping_options->max_sw_mat =
mapping_options.max_sw_mat.value_or(m_mapping_options->max_sw_mat);
}

std::shared_ptr<mm_idx_t> Minimap2Index::load_initial_index(const std::string& index_file,
Expand Down
3 changes: 3 additions & 0 deletions dorado/alignment/Minimap2Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ struct Minimap2MappingOptions {
std::optional<std::string> cs;
std::optional<std::string> dual;
std::optional<uint64_t> mini_batch_size;
// Equivalent to "--cap-kalloc 100m --cap-sw-mem 50m"
std::optional<int64_t> cap_kalloc = 100'000'000;
std::optional<int64_t> max_sw_mat = 50'000'000;
};

inline bool operator<(const Minimap2MappingOptions& l, const Minimap2MappingOptions& r) {
Expand Down
8 changes: 7 additions & 1 deletion dorado/correct/features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,13 @@ std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWind
}
// Sort the filtered overlaps by accuracy score
std::sort(overlap_windows.begin(), overlap_windows.end(),
[](const OverlapWindow& a, const OverlapWindow& b) {
[&alignments](const OverlapWindow& a, const OverlapWindow& b) {
if (std::fabs(a.accuracy - b.accuracy) < 1e-10) {
const auto& a_qname = alignments.qnames[a.overlap_idx];
const auto& b_qname = alignments.qnames[b.overlap_idx];
return std::lexicographical_compare(a_qname.begin(), a_qname.end(),
b_qname.begin(), b_qname.end());
}
return a.accuracy > b.accuracy;
});
}
Expand Down
4 changes: 2 additions & 2 deletions dorado/correct/infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace dorado::correction {

int calculate_batch_size(const std::string& device, float memory_fraction) {
// These sizes are currently hard coded for version 1 model.
const float model_mem = 1.f; // GB
const float per_sample_mem = 0.9f; // GB
const float model_mem = 1.f; // GB
const float per_sample_mem = 1.f; // GB
float usable_memory = 0.f;
if (device == "cpu") {
#if DORADO_METAL_BUILD
Expand Down
2 changes: 1 addition & 1 deletion dorado/read_pipeline/CorrectionNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ void CorrectionNode::infer_fn(const std::string& device_str, int mtx_idx, int ba
at::from_blob(lengths.data(), {(int)lengths.size()},
at::TensorOptions().dtype(torch::kInt32).device(torch::kCPU));
auto batched_bases = collate<int>(bases_batch, (int)11, torch::kInt32);
auto batched_quals = collate<float>(quals_batch, 0.f, torch::kFloat32);
auto batched_quals = collate<float>(quals_batch, 1.f, torch::kFloat32);

std::unique_lock<std::mutex> lock(m_gpu_mutexes[mtx_idx]);
std::vector<torch::jit::IValue> inputs;
Expand Down
7 changes: 3 additions & 4 deletions dorado/read_pipeline/ErrorCorrectionMapperNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void ErrorCorrectionMapperNode::extract_alignments(const mm_reg1_t* reg,
if (m_read_mutex.find(tname) == m_read_mutex.end()) {
m_read_mutex.emplace(tname, std::make_unique<std::mutex>());
CorrectionAlignments new_aln;
new_aln.read_name = tname;
m_correction_records.emplace(tname, std::move(new_aln));
m_processed_queries_per_target.emplace(tname, std::unordered_set<std::string>());
}
Expand Down Expand Up @@ -115,10 +116,6 @@ void ErrorCorrectionMapperNode::extract_alignments(const mm_reg1_t* reg,
continue;
}

if (alignments.read_name.empty()) {
alignments.read_name = tname;
}

alignments.qnames.push_back(qname);

alignments.mm2_cigars.push_back(std::move(cigar));
Expand Down Expand Up @@ -255,6 +252,8 @@ ErrorCorrectionMapperNode::ErrorCorrectionMapperNode(const std::string& index_fi
options.occ_dist = 200;
options.cs = "short";
options.dual = "yes";
options.cap_kalloc = std::nullopt;
options.max_sw_mat = std::nullopt;

m_index = std::make_shared<alignment::Minimap2Index>();
if (!m_index->initialise(options)) {
Expand Down

0 comments on commit 3b51c1b

Please sign in to comment.