Skip to content

Commit

Permalink
Merge branch 'smalton/auto-download-cleanup' into 'master'
Browse files Browse the repository at this point in the history
Auto download cleanup

See merge request machine-learning/dorado!755
  • Loading branch information
malton-ont committed Dec 7, 2023
2 parents c552351 + 80703a3 commit d6e2a80
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 34 deletions.
7 changes: 4 additions & 3 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,8 @@ int basecaller(int argc, char* argv[]) {
auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()};
if (std::count(ways.begin(), ways.end(), true) > 1) {
spdlog::error(
"only one of --modified-bases, --modified-bases-models, or modified models set "
"via models argument can be used at once",
model_arg);
"Only one of --modified-bases, --modified-bases-models, or modified models set "
"via models argument can be used at once");
std::exit(EXIT_FAILURE);
};

Expand Down Expand Up @@ -578,6 +577,7 @@ int basecaller(int argc, char* argv[]) {
temp_download_paths = model_finder.downloaded_models();
} catch (std::exception& e) {
spdlog::error(e.what());
utils::clean_temporary_models(model_finder.downloaded_models());
std::exit(EXIT_FAILURE);
}
}
Expand Down Expand Up @@ -605,6 +605,7 @@ int basecaller(int argc, char* argv[]) {
parser.visible.get<bool>("--estimate-poly-a"), model_selection);
} catch (const std::exception& e) {
spdlog::error("{}", e.what());
utils::clean_temporary_models(temp_download_paths);
return 1;
}

Expand Down
61 changes: 30 additions & 31 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,16 @@ ModelFinder get_model_finder(const std::string& model_arg,
// Get the model name
const auto model_path = std::filesystem::canonical(std::filesystem::path(model_arg));
const auto model_name = model_path.filename().string();
try {
// Try to find the model
const auto model_info = ModelFinder::get_simplex_model_info(model_name);

// Pass the model's ModelVariant (e.g. HAC) in here so everything matches
const auto inferred_selection = ModelSelection{
models::to_string(model_info.simplex.variant), model_info.simplex, {}};

// Return the ModelFinder which hasn't needed to inspect any data
return ModelFinder{model_info.chemistry, inferred_selection, false};
} catch (const std::exception& e) {
spdlog::error(e.what());
std::exit(EXIT_FAILURE);
}

// Try to find the model
const auto model_info = ModelFinder::get_simplex_model_info(model_name);

// Pass the model's ModelVariant (e.g. HAC) in here so everything matches
const auto inferred_selection = ModelSelection{
models::to_string(model_info.simplex.variant), model_info.simplex, {}};

// Return the ModelFinder which hasn't needed to inspect any data
return ModelFinder{model_info.chemistry, inferred_selection, false};
}

// Model complex given, inspect data to find chemistry.
Expand All @@ -104,11 +100,9 @@ DuplexModels load_models(const std::string& model_arg,
auto ways = {inferred_selection.has_mods_variant(), !mod_bases.empty(),
!mod_bases_models.empty()};
if (std::count(ways.begin(), ways.end(), true) > 1) {
spdlog::error(
"only one of --modified-bases, --modified-bases-models, or modified models set "
"via models argument can be used at once",
model_arg);
std::exit(EXIT_FAILURE);
throw std::runtime_error(
"Only one of --modified-bases, --modified-bases-models, or modified models set "
"via models argument can be used at once");
};

if (inferred_selection.model.variant == ModelVariant::FAST) {
Expand Down Expand Up @@ -136,21 +130,26 @@ DuplexModels load_models(const std::string& model_arg,
if (!mod_bases.empty()) {
std::transform(
mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths),
[&model_arg](std::string m) {
[&model_arg](const std::string& m) {
return std::filesystem::path(models::get_modification_model(model_arg, m));
});
} else if (!mod_bases_models.empty()) {
const auto split = utils::split(mod_bases_models, ',');
std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths),
[&](std::string m) { return std::filesystem::path(m); });
[&](const std::string& m) { return std::filesystem::path(m); });
}

} else {
model_path = model_finder.fetch_simplex_model();
stereo_model_path = model_finder.fetch_stereo_model();
mods_model_paths = inferred_selection.has_mods_variant()
? model_finder.fetch_mods_models()
: std::vector<std::filesystem::path>{};
try {
model_path = model_finder.fetch_simplex_model();
stereo_model_path = model_finder.fetch_stereo_model();
mods_model_paths = inferred_selection.has_mods_variant()
? model_finder.fetch_mods_models()
: std::vector<std::filesystem::path>{};
} catch (const std::exception&) {
utils::clean_temporary_models(model_finder.downloaded_models());
throw;
}
}

const auto model_name = model_finder.get_simplex_model_name();
Expand Down Expand Up @@ -280,16 +279,13 @@ int duplex(int argc, char* argv[]) {
cli::add_minimap2_arguments(parser, alignment::dflt_options);
cli::add_internal_arguments(parser);

std::set<fs::path> temp_model_paths;
try {
cli::parse(parser, argc, argv);

auto device(parser.visible.get<std::string>("-x"));
auto model(parser.visible.get<std::string>("model"));

if (model.find("fast") != std::string::npos) {
spdlog::warn("Fast models are currently not recommended for duplex basecalling.");
}

auto reads(parser.visible.get<std::string>("reads"));
std::string pairs_file = parser.visible.get<std::string>("--pairs");
auto threads = static_cast<size_t>(parser.visible.get<int>("--threads"));
Expand Down Expand Up @@ -450,6 +446,8 @@ int duplex(int argc, char* argv[]) {
load_models(model, mod_bases, mod_bases_models, reads, recursive_file_loading,
skip_model_compatibility_check);

temp_model_paths = models.temp_paths;

// create modbase runners first so basecall runners can pick batch sizes based on available memory
auto mod_base_runners = create_modbase_runners(
models.mods_model_paths, device, default_parameters.mod_base_runners_per_caller,
Expand Down Expand Up @@ -556,7 +554,7 @@ int duplex(int argc, char* argv[]) {
loader.load_reads(reads, parser.visible.get<bool>("--recursive"),
ReadOrder::BY_CHANNEL);

utils::clean_temporary_models(models.temp_paths);
utils::clean_temporary_models(temp_model_paths);
}

// Wait for the pipeline to complete. When it does, we collect
Expand All @@ -576,6 +574,7 @@ int duplex(int argc, char* argv[]) {
: std::optional<std::regex>(dump_stats_filter));
}
} catch (const std::exception& e) {
utils::clean_temporary_models(temp_model_paths);
spdlog::error(e.what());
return 1;
}
Expand Down

0 comments on commit d6e2a80

Please sign in to comment.