From 4340df7bfbfad24ab9c461dc40f7fc77f85ef733 Mon Sep 17 00:00:00 2001 From: Richard Harris Date: Fri, 8 Mar 2024 17:17:42 +0000 Subject: [PATCH] DOR-600 fix issue where using simplex-only model complex and --modified-bases{-models}... --- dorado/cli/basecaller.cpp | 36 +++++++++---------- dorado/cli/duplex.cpp | 18 ++++------ dorado/data_loader/ModelFinder.cpp | 29 +++++++++++++++ dorado/data_loader/ModelFinder.h | 6 ++++ dorado/models/models.cpp | 17 ++++----- dorado/models/models.h | 6 ++-- .../test_simple_auto_basecaller_execution.sh | 10 +++--- 7 files changed, 75 insertions(+), 47 deletions(-) diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index 6610b43d..3cba885c 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -462,14 +462,6 @@ int basecaller(int argc, char* argv[]) { const ModelSelection model_selection = cli::parse_model_argument(model_arg); - auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()}; - if (std::count(ways.begin(), ways.end(), true) > 1) { - spdlog::error( - "Only one of --modified-bases, --modified-bases-models, or modified models set " - "via models argument can be used at once"); - std::exit(EXIT_FAILURE); - }; - auto methylation_threshold = parser.visible.get("--modified-bases-threshold"); if (methylation_threshold < 0.f || methylation_threshold > 1.f) { spdlog::error("--modified-bases-threshold must be between 0 and 1."); @@ -564,30 +556,34 @@ int basecaller(int argc, char* argv[]) { custom_primer_file = parser.visible.get("--primer-sequences"); } + // Assert that only one of --modified-bases, --modified-bases-models or mods model complex is set + auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()}; + if (std::count(ways.begin(), ways.end(), true) > 1) { + spdlog::error( + "Only one of --modified-bases, --modified-bases-models, or modified models set " + "via models argument can be used at once"); + std::exit(EXIT_FAILURE); + }; + fs::path model_path; std::vector mods_model_paths; std::set temp_download_paths; if (model_selection.is_path()) { model_path = fs::path(model_arg); - - if (mod_bases.size() > 0) { - std::transform(mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths), - [&model_arg](std::string m) { - return fs::path(models::get_modification_model(model_arg, m)); - }); - } else if (mod_bases_models.size() > 0) { - const auto split = utils::split(mod_bases_models, ','); - std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths), - [&](std::string m) { return fs::path(m); }); - } - + mods_model_paths = + dorado::get_non_complex_mods_models(model_path, mod_bases, mod_bases_models); } else { auto model_finder = cli::model_finder(model_selection, data, recursive, true); try { model_path = model_finder.fetch_simplex_model(); if (model_selection.has_mods_variant()) { + // Get mods models from complex - we assert above that there's only one method mods_model_paths = model_finder.fetch_mods_models(); + } else { + // Get mods models from args + mods_model_paths = dorado::get_non_complex_mods_models(model_path, mod_bases, + mod_bases_models); } temp_download_paths = model_finder.downloaded_models(); } catch (std::exception& e) { diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index cf0c833c..5838f3c1 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -76,6 +76,7 @@ ModelFinder get_model_finder(const std::string& model_arg, const auto model_info = ModelFinder::get_simplex_model_info(model_name); // Pass the model's ModelVariant (e.g. HAC) in here so everything matches + // There are no mods variants if model_arg is a path const auto inferred_selection = ModelSelection{ models::to_string(model_info.simplex.variant), model_info.simplex, {}}; @@ -128,25 +129,18 @@ DuplexModels load_models(const std::string& model_arg, check_sampling_rates_compatible(model_name, reads, model_config.sample_rate, recursive_file_loading); } - if (!mod_bases.empty()) { - std::transform( - mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths), - [&model_arg](const std::string& m) { - return std::filesystem::path(models::get_modification_model(model_arg, m)); - }); - } else if (!mod_bases_models.empty()) { - const auto split = utils::split(mod_bases_models, ','); - std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths), - [&](const std::string& m) { return std::filesystem::path(m); }); - } + mods_model_paths = + dorado::get_non_complex_mods_models(model_path, mod_bases, mod_bases_models); } else { try { model_path = model_finder.fetch_simplex_model(); stereo_model_path = model_finder.fetch_stereo_model(); + // Either get the mods from the model complex or resolve from --modified-bases args mods_model_paths = inferred_selection.has_mods_variant() ? model_finder.fetch_mods_models() - : std::vector{}; + : dorado::get_non_complex_mods_models(model_path, mod_bases, + mod_bases_models); } catch (const std::exception&) { utils::clean_temporary_models(model_finder.downloaded_models()); throw; diff --git a/dorado/data_loader/ModelFinder.cpp b/dorado/data_loader/ModelFinder.cpp index c8ccefc1..77586823 100644 --- a/dorado/data_loader/ModelFinder.cpp +++ b/dorado/data_loader/ModelFinder.cpp @@ -280,4 +280,33 @@ void check_sampling_rates_compatible(const std::string& model_name, } } +std::vector get_non_complex_mods_models( + const std::filesystem::path& simplex_model_path, + const std::vector& mod_bases, + const std::string& mod_bases_models) { + if (!mod_bases.empty() && !mod_bases_models.empty()) { + throw std::runtime_error( + "CLI arguments --modified-bases and --modified-bases-models are mutually " + "exclusive"); + } + + std::vector mods_model_paths; + + if (!mod_bases.empty()) { + // Foreach --modified-bases get the modified model of that type matched to the simplex model + std::transform(mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths), + [&simplex_model_path](const std::string& m) { + return std::filesystem::path( + models::get_modification_model(simplex_model_path, m)); + }); + } else if (!mod_bases_models.empty()) { + // Foreach --modified-bases-models get a path + const auto split = utils::split(mod_bases_models, ','); + std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths), + [&](const std::string& m) { return std::filesystem::path(m); }); + } + + return mods_model_paths; +} + } // namespace dorado \ No newline at end of file diff --git a/dorado/data_loader/ModelFinder.h b/dorado/data_loader/ModelFinder.h index 6137179e..c1e8e22a 100644 --- a/dorado/data_loader/ModelFinder.h +++ b/dorado/data_loader/ModelFinder.h @@ -106,4 +106,10 @@ void check_sampling_rates_compatible(const std::string& model_name, const int config_sample_rate, const bool recursive_file_loading); +// Get modified models set using `--modified-bases` or `--modified-bases-models` cli args +std::vector get_non_complex_mods_models( + const std::filesystem::path& simplex_model_path, + const std::vector& mod_bases, + const std::string& mod_bases_models); + } // namespace dorado \ No newline at end of file diff --git a/dorado/models/models.cpp b/dorado/models/models.cpp index 6f9fc57d..04c4c8a7 100644 --- a/dorado/models/models.cpp +++ b/dorado/models/models.cpp @@ -830,18 +830,18 @@ ModelInfo get_simplex_model_info(const std::string& model_name) { return matches.back(); } -std::string get_modification_model(const std::string& simplex_model, +std::string get_modification_model(const std::filesystem::path& simplex_model_path, const std::string& modification) { std::string modification_model{""}; - auto simplex_path = fs::path(simplex_model); - if (!fs::exists(simplex_path)) { - throw std::runtime_error{"unknown simplex model " + simplex_model}; + if (!fs::exists(simplex_model_path)) { + throw std::runtime_error{ + "Cannot find modification model for '" + modification + + "' reason: simplex model doesn't exist at: " + simplex_model_path.u8string()}; } - simplex_path = fs::canonical(simplex_path); - auto model_dir = simplex_path.parent_path(); - auto simplex_name = simplex_path.filename().u8string(); + auto model_dir = simplex_model_path.parent_path(); + auto simplex_name = simplex_model_path.filename().u8string(); if (is_valid_model(simplex_name)) { std::string mods_prefix = simplex_name + "_" + modification + "@v"; @@ -854,7 +854,8 @@ std::string get_modification_model(const std::string& simplex_model, } } } else { - throw std::runtime_error{"unknown simplex model " + simplex_name}; + throw std::runtime_error{"Cannot find modification model for '" + modification + + "' reason: unknown simplex model " + simplex_name}; } if (modification_model.empty()) { diff --git a/dorado/models/models.h b/dorado/models/models.h index 33a1a4cd..202fa4d4 100644 --- a/dorado/models/models.h +++ b/dorado/models/models.h @@ -51,9 +51,9 @@ bool download_models(const std::string& target_directory, const std::string& sel ModelInfo get_simplex_model_info(const std::string& model_name); // finds the matching modification model for a given modification i.e. 5mCG and a simplex model -// is the matching modification model is not found in the same model directory as the simplex -// model then it is downloaded. -std::string get_modification_model(const std::string& simplex_model, +// if the modification model is not found in the same model directory as the simplex +// model then it is downloaded into the same directory. +std::string get_modification_model(const std::filesystem::path& simplex_model, const std::string& modification); // get the sampling rate that the model is compatible with diff --git a/tests/test_simple_auto_basecaller_execution.sh b/tests/test_simple_auto_basecaller_execution.sh index 6df19667..6c014b11 100755 --- a/tests/test_simple_auto_basecaller_execution.sh +++ b/tests/test_simple_auto_basecaller_execution.sh @@ -58,10 +58,15 @@ set -e echo dorado summary test stage $dorado_bin summary $output_dir/calls.bam +echo dorado basecaller mixed model complex and --modified-bases +$dorado_bin basecaller $model_complex $pod5_dir -b ${batch} --modified-bases 5mCG_5hmCG -vv > $output_dir/calls.bam +samtools view -h $output_dir/calls.bam | grep "ML:B:C," +samtools view -h $output_dir/calls.bam | grep "MM:Z:C+h" +samtools view -h $output_dir/calls.bam | grep "MN:i:" + echo redirecting stderr to stdout: check output is still valid $dorado_bin basecaller $model_complex,5mCG_5hmCG $pod5_dir -b ${batch} --emit-moves > $output_dir/calls.bam 2>&1 samtools quickcheck -u $output_dir/calls.bam -samtools view $output_dir/calls.bam > $output_dir/calls.sam echo dorado aligner test stage $dorado_bin basecaller $model_complex $pod5_dir -b ${batch} --emit-fastq > $output_dir/ref.fq @@ -69,8 +74,6 @@ $dorado_bin aligner $output_dir/ref.fq $output_dir/calls.sam > $output_dir/calls $dorado_bin basecaller $model_complex,5mCG_5hmCG $pod5_dir -b ${batch} | $dorado_bin aligner $output_dir/ref.fq > $output_dir/calls.bam $dorado_bin basecaller $model_complex,5mCG_5hmCG $pod5_dir -b ${batch} --reference $output_dir/ref.fq > $output_dir/calls.bam samtools quickcheck -u $output_dir/calls.bam -samtools view -h $output_dir/calls.bam > $output_dir/calls.sam - if ! uname -r | grep -q tegra; then echo dorado duplex basespace test stage @@ -95,7 +98,6 @@ if ! uname -r | grep -q tegra; then fi fi - set +e if ls .temp_dorado_model-* ; then echo ".temp_dorado_models not cleaned"