diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index 29bb95b4..4603b4e8 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -458,9 +458,8 @@ int basecaller(int argc, char* argv[]) { auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()}; if (std::count(ways.begin(), ways.end(), true) > 1) { spdlog::error( - "only one of --modified-bases, --modified-bases-models, or modified models set " - "via models argument can be used at once", - model_arg); + "Only one of --modified-bases, --modified-bases-models, or modified models set " + "via models argument can be used at once"); std::exit(EXIT_FAILURE); }; @@ -578,6 +577,7 @@ int basecaller(int argc, char* argv[]) { temp_download_paths = model_finder.downloaded_models(); } catch (std::exception& e) { spdlog::error(e.what()); + utils::clean_temporary_models(model_finder.downloaded_models()); std::exit(EXIT_FAILURE); } } @@ -605,6 +605,7 @@ int basecaller(int argc, char* argv[]) { parser.visible.get("--estimate-poly-a"), model_selection); } catch (const std::exception& e) { spdlog::error("{}", e.what()); + utils::clean_temporary_models(temp_download_paths); return 1; } diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index b9a4b30a..ffd14ad4 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -70,20 +70,16 @@ ModelFinder get_model_finder(const std::string& model_arg, // Get the model name const auto model_path = std::filesystem::canonical(std::filesystem::path(model_arg)); const auto model_name = model_path.filename().string(); - try { - // Try to find the model - const auto model_info = ModelFinder::get_simplex_model_info(model_name); - - // Pass the model's ModelVariant (e.g. HAC) in here so everything matches - const auto inferred_selection = ModelSelection{ - models::to_string(model_info.simplex.variant), model_info.simplex, {}}; - - // Return the ModelFinder which hasn't needed to inspect any data - return ModelFinder{model_info.chemistry, inferred_selection, false}; - } catch (const std::exception& e) { - spdlog::error(e.what()); - std::exit(EXIT_FAILURE); - } + + // Try to find the model + const auto model_info = ModelFinder::get_simplex_model_info(model_name); + + // Pass the model's ModelVariant (e.g. HAC) in here so everything matches + const auto inferred_selection = ModelSelection{ + models::to_string(model_info.simplex.variant), model_info.simplex, {}}; + + // Return the ModelFinder which hasn't needed to inspect any data + return ModelFinder{model_info.chemistry, inferred_selection, false}; } // Model complex given, inspect data to find chemistry. @@ -104,11 +100,9 @@ DuplexModels load_models(const std::string& model_arg, auto ways = {inferred_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()}; if (std::count(ways.begin(), ways.end(), true) > 1) { - spdlog::error( - "only one of --modified-bases, --modified-bases-models, or modified models set " - "via models argument can be used at once", - model_arg); - std::exit(EXIT_FAILURE); + throw std::runtime_error( + "Only one of --modified-bases, --modified-bases-models, or modified models set " + "via models argument can be used at once"); }; if (inferred_selection.model.variant == ModelVariant::FAST) { @@ -136,21 +130,26 @@ DuplexModels load_models(const std::string& model_arg, if (!mod_bases.empty()) { std::transform( mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths), - [&model_arg](std::string m) { + [&model_arg](const std::string& m) { return std::filesystem::path(models::get_modification_model(model_arg, m)); }); } else if (!mod_bases_models.empty()) { const auto split = utils::split(mod_bases_models, ','); std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths), - [&](std::string m) { return std::filesystem::path(m); }); + [&](const std::string& m) { return std::filesystem::path(m); }); } } else { - model_path = model_finder.fetch_simplex_model(); - stereo_model_path = model_finder.fetch_stereo_model(); - mods_model_paths = inferred_selection.has_mods_variant() - ? model_finder.fetch_mods_models() - : std::vector{}; + try { + model_path = model_finder.fetch_simplex_model(); + stereo_model_path = model_finder.fetch_stereo_model(); + mods_model_paths = inferred_selection.has_mods_variant() + ? model_finder.fetch_mods_models() + : std::vector{}; + } catch (const std::exception&) { + utils::clean_temporary_models(model_finder.downloaded_models()); + throw; + } } const auto model_name = model_finder.get_simplex_model_name(); @@ -280,16 +279,13 @@ int duplex(int argc, char* argv[]) { cli::add_minimap2_arguments(parser, alignment::dflt_options); cli::add_internal_arguments(parser); + std::set temp_model_paths; try { cli::parse(parser, argc, argv); auto device(parser.visible.get("-x")); auto model(parser.visible.get("model")); - if (model.find("fast") != std::string::npos) { - spdlog::warn("Fast models are currently not recommended for duplex basecalling."); - } - auto reads(parser.visible.get("reads")); std::string pairs_file = parser.visible.get("--pairs"); auto threads = static_cast(parser.visible.get("--threads")); @@ -450,6 +446,8 @@ int duplex(int argc, char* argv[]) { load_models(model, mod_bases, mod_bases_models, reads, recursive_file_loading, skip_model_compatibility_check); + temp_model_paths = models.temp_paths; + // create modbase runners first so basecall runners can pick batch sizes based on available memory auto mod_base_runners = create_modbase_runners( models.mods_model_paths, device, default_parameters.mod_base_runners_per_caller, @@ -556,7 +554,7 @@ int duplex(int argc, char* argv[]) { loader.load_reads(reads, parser.visible.get("--recursive"), ReadOrder::BY_CHANNEL); - utils::clean_temporary_models(models.temp_paths); + utils::clean_temporary_models(temp_model_paths); } // Wait for the pipeline to complete. When it does, we collect @@ -576,6 +574,7 @@ int duplex(int argc, char* argv[]) { : std::optional(dump_stats_filter)); } } catch (const std::exception& e) { + utils::clean_temporary_models(temp_model_paths); spdlog::error(e.what()); return 1; }