Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-499-basecall-deps' into 'master'
Browse files Browse the repository at this point in the history
DOR-499: Break basecall<->models library dependency

Closes DOR-499

See merge request machine-learning/dorado!784
  • Loading branch information
tijyojwad committed Dec 21, 2023
2 parents e42761c + 17999f6 commit 1893d69
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 57 deletions.
1 change: 0 additions & 1 deletion dorado/basecall/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ target_link_libraries(dorado_basecall
${TORCH_LIBRARIES}
dorado_utils
PRIVATE
dorado_models_lib
${KOI_LIBRARIES}
spdlog::spdlog
)
Expand Down
44 changes: 28 additions & 16 deletions dorado/basecall/CRFModelConfig.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#include "CRFModelConfig.h"

#include "models/models.h"

#include <spdlog/spdlog.h>
#include <toml.hpp>
#include <toml/value.hpp>

#include <cstddef>
#include <set>
Expand All @@ -31,6 +28,27 @@ SublayerType sublayer_type(const toml::value &segment) {
return mapping_iter->second;
}

// the mean Q-score of short reads are artificially lowered because of
// some lower quality bases at the beginning of the read. to correct for
// that, mean Q-score calculation should ignore the first few bases. The
// number of bases to ignore is dependent on the model.
uint32_t get_mean_qscore_start_pos_by_model_name(const std::string &model_name) {
static const std::unordered_map<std::string, uint16_t> mean_qscore_start_pos_by_model = {
// To add model specific start positions for older models,
// create an entry keyed by model name with the value as
// the desired start position.
// e.g. {"dna_r10.4.1_e8.2_5khz_400bps_fast@v4.2.0", 10}
};

auto iter = mean_qscore_start_pos_by_model.find(model_name);
if (iter != mean_qscore_start_pos_by_model.end()) {
return iter->second;
} else {
// Assume start position of 60 as default.
return 60;
}
}

} // namespace
namespace dorado::basecall {

Expand Down Expand Up @@ -217,6 +235,13 @@ CRFModelConfig load_crf_model_config(const std::filesystem::path &path) {
config.qscale = toml::find<float>(qscore, "scale");
if (qscore.contains("mean_qscore_start_pos")) {
config.mean_qscore_start_pos = toml::find<int32_t>(qscore, "mean_qscore_start_pos");
} else {
// If information is not present in the config, find start position by model name.
std::string model_name = config.model_path.filename().string();
config.mean_qscore_start_pos = get_mean_qscore_start_pos_by_model_name(model_name);
}
if (config.mean_qscore_start_pos < 0) {
throw std::runtime_error("Mean q-score start position cannot be < 0");
}
} else {
spdlog::debug("> no qscore calibration found");
Expand Down Expand Up @@ -307,19 +332,6 @@ CRFModelConfig load_crf_model_config(const std::filesystem::path &path) {
return config;
}

int32_t get_model_mean_qscore_start_pos(const CRFModelConfig &model_config) {
int32_t mean_qscore_start_pos = model_config.mean_qscore_start_pos;
if (mean_qscore_start_pos < 0) {
// If unsuccessful, find start position by model name.
std::string model_name = model_config.model_path.filename().string();
mean_qscore_start_pos = models::get_mean_qscore_start_pos_by_model_name(model_name);
}
if (mean_qscore_start_pos < 0) {
throw std::runtime_error("Mean q-score start position cannot be < 0");
}
return mean_qscore_start_pos;
}

bool is_rna_model(const CRFModelConfig &model_config) {
auto path = std::filesystem::canonical(model_config.model_path);
auto filename = path.filename();
Expand Down
2 changes: 0 additions & 2 deletions dorado/basecall/CRFModelConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ struct CRFModelConfig {

CRFModelConfig load_crf_model_config(const std::filesystem::path& path);

int32_t get_model_mean_qscore_start_pos(const CRFModelConfig& model_config);

bool is_rna_model(const CRFModelConfig& model_config);

} // namespace dorado::basecall
7 changes: 1 addition & 6 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,7 @@ void setup(std::vector<std::string> args,
std::unordered_set<std::string>{}, thread_allocations.read_filter_threads);

auto mean_qscore_start_pos = model_config.mean_qscore_start_pos;
if (mean_qscore_start_pos < 0) {
mean_qscore_start_pos = models::get_mean_qscore_start_pos_by_model_name(model_name);
if (mean_qscore_start_pos < 0) {
throw std::runtime_error("Mean q-score start position cannot be < 0");
}
}

pipelines::create_simplex_pipeline(
pipeline_desc, std::move(runners), std::move(remora_runners), overlap,
mean_qscore_start_pos, !adapter_no_trim, thread_allocations.scaler_node_threads,
Expand Down
7 changes: 0 additions & 7 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,13 +512,6 @@ int duplex(int argc, char* argv[]) {
}

auto mean_qscore_start_pos = models.model_config.mean_qscore_start_pos;
if (mean_qscore_start_pos < 0) {
mean_qscore_start_pos =
models::get_mean_qscore_start_pos_by_model_name(models.stereo_model_name);
if (mean_qscore_start_pos < 0) {
throw std::runtime_error("Mean q-score start position cannot be < 0");
}
}

pipelines::create_stereo_duplex_pipeline(
pipeline_desc, std::move(runners), std::move(stereo_runners),
Expand Down
20 changes: 1 addition & 19 deletions dorado/models/models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,14 +770,6 @@ const std::vector<ModelInfo> models = {

} // namespace modified

const std::unordered_map<std::string, uint16_t> mean_qscore_start_pos_by_model = {

// To add model specific start positions for older models,
// create an entry keyed by model name with the value as
// the desired start position.
// e.g. {"dna_r10.4.1_e8.2_5khz_400bps_fast@v4.2.0", 10}
};

std::string calculate_checksum(std::string_view data) {
// Hash the data.
std::array<unsigned char, SHA256_DIGEST_LENGTH> hash{};
Expand Down Expand Up @@ -1090,17 +1082,7 @@ SamplingRate get_sample_rate_by_model_name(const std::string& model_name) {
return iter->second.sampling_rate;
} else {
// This can only happen if a model_info.chemistry not in chemistries which should be impossible.
throw std::logic_error("Couldn't find chemsitry: " + to_string(model_info.chemistry));
}
}

uint32_t get_mean_qscore_start_pos_by_model_name(const std::string& model_name) {
auto iter = mean_qscore_start_pos_by_model.find(model_name);
if (iter != mean_qscore_start_pos_by_model.end()) {
return iter->second;
} else {
// Assume start position of 60 as default.
return 60;
throw std::logic_error("Couldn't find chemistry: " + to_string(model_info.chemistry));
}
}

Expand Down
6 changes: 0 additions & 6 deletions dorado/models/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,6 @@ std::string get_modification_model(const std::string& simplex_model,
// get the sampling rate that the model is compatible with
SamplingRate get_sample_rate_by_model_name(const std::string& model_name);

// the mean Q-score of short reads are artificially lowered because of
// some lower quality bases at the beginning of the read. to correct for
// that, mean Q-score calculation should ignore the first few bases. The
// number of bases to ignore is dependent on the model.
uint32_t get_mean_qscore_start_pos_by_model_name(const std::string& model_name);

// Extract the model name from the model path.
std::string extract_model_name_from_path(const std::filesystem::path& model_path);

Expand Down

0 comments on commit 1893d69

Please sign in to comment.