Skip to content

Commit

Permalink
Merge branch 'DOR-600-other_modbase_with_model_complex' into 'master'
Browse files Browse the repository at this point in the history
DOR-600 fix issue where using simplex-only model complex and --modified-bases{-models}...

Closes DOR-600

See merge request machine-learning/dorado!875
  • Loading branch information
HalfPhoton committed Mar 8, 2024
2 parents b31e5c8 + 4340df7 commit bdc05e3
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 47 deletions.
36 changes: 16 additions & 20 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,14 +462,6 @@ int basecaller(int argc, char* argv[]) {

const ModelSelection model_selection = cli::parse_model_argument(model_arg);

auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()};
if (std::count(ways.begin(), ways.end(), true) > 1) {
spdlog::error(
"Only one of --modified-bases, --modified-bases-models, or modified models set "
"via models argument can be used at once");
std::exit(EXIT_FAILURE);
};

auto methylation_threshold = parser.visible.get<float>("--modified-bases-threshold");
if (methylation_threshold < 0.f || methylation_threshold > 1.f) {
spdlog::error("--modified-bases-threshold must be between 0 and 1.");
Expand Down Expand Up @@ -564,30 +556,34 @@ int basecaller(int argc, char* argv[]) {
custom_primer_file = parser.visible.get<std::string>("--primer-sequences");
}

// Assert that only one of --modified-bases, --modified-bases-models or mods model complex is set
auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()};
if (std::count(ways.begin(), ways.end(), true) > 1) {
spdlog::error(
"Only one of --modified-bases, --modified-bases-models, or modified models set "
"via models argument can be used at once");
std::exit(EXIT_FAILURE);
};

fs::path model_path;
std::vector<fs::path> mods_model_paths;
std::set<fs::path> temp_download_paths;

if (model_selection.is_path()) {
model_path = fs::path(model_arg);

if (mod_bases.size() > 0) {
std::transform(mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths),
[&model_arg](std::string m) {
return fs::path(models::get_modification_model(model_arg, m));
});
} else if (mod_bases_models.size() > 0) {
const auto split = utils::split(mod_bases_models, ',');
std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths),
[&](std::string m) { return fs::path(m); });
}

mods_model_paths =
dorado::get_non_complex_mods_models(model_path, mod_bases, mod_bases_models);
} else {
auto model_finder = cli::model_finder(model_selection, data, recursive, true);
try {
model_path = model_finder.fetch_simplex_model();
if (model_selection.has_mods_variant()) {
// Get mods models from complex - we assert above that there's only one method
mods_model_paths = model_finder.fetch_mods_models();
} else {
// Get mods models from args
mods_model_paths = dorado::get_non_complex_mods_models(model_path, mod_bases,
mod_bases_models);
}
temp_download_paths = model_finder.downloaded_models();
} catch (std::exception& e) {
Expand Down
18 changes: 6 additions & 12 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ ModelFinder get_model_finder(const std::string& model_arg,
const auto model_info = ModelFinder::get_simplex_model_info(model_name);

// Pass the model's ModelVariant (e.g. HAC) in here so everything matches
// There are no mods variants if model_arg is a path
const auto inferred_selection = ModelSelection{
models::to_string(model_info.simplex.variant), model_info.simplex, {}};

Expand Down Expand Up @@ -128,25 +129,18 @@ DuplexModels load_models(const std::string& model_arg,
check_sampling_rates_compatible(model_name, reads, model_config.sample_rate,
recursive_file_loading);
}
if (!mod_bases.empty()) {
std::transform(
mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths),
[&model_arg](const std::string& m) {
return std::filesystem::path(models::get_modification_model(model_arg, m));
});
} else if (!mod_bases_models.empty()) {
const auto split = utils::split(mod_bases_models, ',');
std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths),
[&](const std::string& m) { return std::filesystem::path(m); });
}
mods_model_paths =
dorado::get_non_complex_mods_models(model_path, mod_bases, mod_bases_models);

} else {
try {
model_path = model_finder.fetch_simplex_model();
stereo_model_path = model_finder.fetch_stereo_model();
// Either get the mods from the model complex or resolve from --modified-bases args
mods_model_paths = inferred_selection.has_mods_variant()
? model_finder.fetch_mods_models()
: std::vector<std::filesystem::path>{};
: dorado::get_non_complex_mods_models(model_path, mod_bases,
mod_bases_models);
} catch (const std::exception&) {
utils::clean_temporary_models(model_finder.downloaded_models());
throw;
Expand Down
29 changes: 29 additions & 0 deletions dorado/data_loader/ModelFinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,33 @@ void check_sampling_rates_compatible(const std::string& model_name,
}
}

std::vector<std::filesystem::path> get_non_complex_mods_models(
const std::filesystem::path& simplex_model_path,
const std::vector<std::string>& mod_bases,
const std::string& mod_bases_models) {
if (!mod_bases.empty() && !mod_bases_models.empty()) {
throw std::runtime_error(
"CLI arguments --modified-bases and --modified-bases-models are mutually "
"exclusive");
}

std::vector<std::filesystem::path> mods_model_paths;

if (!mod_bases.empty()) {
// Foreach --modified-bases get the modified model of that type matched to the simplex model
std::transform(mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths),
[&simplex_model_path](const std::string& m) {
return std::filesystem::path(
models::get_modification_model(simplex_model_path, m));
});
} else if (!mod_bases_models.empty()) {
// Foreach --modified-bases-models get a path
const auto split = utils::split(mod_bases_models, ',');
std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths),
[&](const std::string& m) { return std::filesystem::path(m); });
}

return mods_model_paths;
}

} // namespace dorado
6 changes: 6 additions & 0 deletions dorado/data_loader/ModelFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,10 @@ void check_sampling_rates_compatible(const std::string& model_name,
const int config_sample_rate,
const bool recursive_file_loading);

// Get modified models set using `--modified-bases` or `--modified-bases-models` cli args
std::vector<std::filesystem::path> get_non_complex_mods_models(
const std::filesystem::path& simplex_model_path,
const std::vector<std::string>& mod_bases,
const std::string& mod_bases_models);

} // namespace dorado
17 changes: 9 additions & 8 deletions dorado/models/models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,18 +830,18 @@ ModelInfo get_simplex_model_info(const std::string& model_name) {
return matches.back();
}

std::string get_modification_model(const std::string& simplex_model,
std::string get_modification_model(const std::filesystem::path& simplex_model_path,
const std::string& modification) {
std::string modification_model{""};
auto simplex_path = fs::path(simplex_model);

if (!fs::exists(simplex_path)) {
throw std::runtime_error{"unknown simplex model " + simplex_model};
if (!fs::exists(simplex_model_path)) {
throw std::runtime_error{
"Cannot find modification model for '" + modification +
"' reason: simplex model doesn't exist at: " + simplex_model_path.u8string()};
}

simplex_path = fs::canonical(simplex_path);
auto model_dir = simplex_path.parent_path();
auto simplex_name = simplex_path.filename().u8string();
auto model_dir = simplex_model_path.parent_path();
auto simplex_name = simplex_model_path.filename().u8string();

if (is_valid_model(simplex_name)) {
std::string mods_prefix = simplex_name + "_" + modification + "@v";
Expand All @@ -854,7 +854,8 @@ std::string get_modification_model(const std::string& simplex_model,
}
}
} else {
throw std::runtime_error{"unknown simplex model " + simplex_name};
throw std::runtime_error{"Cannot find modification model for '" + modification +
"' reason: unknown simplex model " + simplex_name};
}

if (modification_model.empty()) {
Expand Down
6 changes: 3 additions & 3 deletions dorado/models/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ bool download_models(const std::string& target_directory, const std::string& sel
ModelInfo get_simplex_model_info(const std::string& model_name);

// finds the matching modification model for a given modification i.e. 5mCG and a simplex model
// is the matching modification model is not found in the same model directory as the simplex
// model then it is downloaded.
std::string get_modification_model(const std::string& simplex_model,
// if the modification model is not found in the same model directory as the simplex
// model then it is downloaded into the same directory.
std::string get_modification_model(const std::filesystem::path& simplex_model,
const std::string& modification);

// get the sampling rate that the model is compatible with
Expand Down
10 changes: 6 additions & 4 deletions tests/test_simple_auto_basecaller_execution.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,22 @@ set -e
echo dorado summary test stage
$dorado_bin summary $output_dir/calls.bam

echo dorado basecaller mixed model complex and --modified-bases
$dorado_bin basecaller $model_complex $pod5_dir -b ${batch} --modified-bases 5mCG_5hmCG -vv > $output_dir/calls.bam
samtools view -h $output_dir/calls.bam | grep "ML:B:C,"
samtools view -h $output_dir/calls.bam | grep "MM:Z:C+h"
samtools view -h $output_dir/calls.bam | grep "MN:i:"

echo redirecting stderr to stdout: check output is still valid
$dorado_bin basecaller $model_complex,5mCG_5hmCG $pod5_dir -b ${batch} --emit-moves > $output_dir/calls.bam 2>&1
samtools quickcheck -u $output_dir/calls.bam
samtools view $output_dir/calls.bam > $output_dir/calls.sam

echo dorado aligner test stage
$dorado_bin basecaller $model_complex $pod5_dir -b ${batch} --emit-fastq > $output_dir/ref.fq
$dorado_bin aligner $output_dir/ref.fq $output_dir/calls.sam > $output_dir/calls.bam
$dorado_bin basecaller $model_complex,5mCG_5hmCG $pod5_dir -b ${batch} | $dorado_bin aligner $output_dir/ref.fq > $output_dir/calls.bam
$dorado_bin basecaller $model_complex,5mCG_5hmCG $pod5_dir -b ${batch} --reference $output_dir/ref.fq > $output_dir/calls.bam
samtools quickcheck -u $output_dir/calls.bam
samtools view -h $output_dir/calls.bam > $output_dir/calls.sam


if ! uname -r | grep -q tegra; then
echo dorado duplex basespace test stage
Expand All @@ -95,7 +98,6 @@ if ! uname -r | grep -q tegra; then
fi
fi


set +e
if ls .temp_dorado_model-* ; then
echo ".temp_dorado_models not cleaned"
Expand Down

0 comments on commit bdc05e3

Please sign in to comment.