Skip to content

Commit

Permalink
Merge branch 'mvella/hemimethylation' into 'master'
Browse files Browse the repository at this point in the history
Hemi-Methylation support

See merge request machine-learning/dorado!741
  • Loading branch information
tijyojwad committed Dec 1, 2023
2 parents d871eb1 + eb2c235 commit b7d4b38
Show file tree
Hide file tree
Showing 13 changed files with 650 additions and 155 deletions.
9 changes: 4 additions & 5 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,16 @@ void setup(std::vector<std::string> args,
const bool enable_aligner = !ref.empty();

// create modbase runners first so basecall runners can pick batch sizes based on available memory
auto remora_runners = create_modbase_runners(
remora_models, device, default_parameters.remora_runners_per_caller, remora_batch_size);
auto remora_runners = create_modbase_runners(remora_models, device,
default_parameters.mod_base_runners_per_caller,
remora_batch_size);

auto [runners, num_devices] = create_basecall_runners(model_config, device, num_runners, 0,
batch_size, chunk_size, 1.f, false);

auto read_groups = DataLoader::load_read_groups(data_path, model_name, modbase_model_names,
recursive_file_loading);

bool duplex = false;

const auto thread_allocations = utils::default_thread_allocations(
int(num_devices), !remora_runners.empty() ? int(num_remora_threads) : 0, enable_aligner,
!barcode_kits.empty());
Expand Down Expand Up @@ -251,7 +250,7 @@ void setup(std::vector<std::string> args,
}

std::vector<dorado::stats::StatsCallable> stats_callables;
ProgressTracker tracker(int(num_reads), duplex);
ProgressTracker tracker(int(num_reads), false);
stats_callables.push_back(
[&tracker](const stats::NamedStats& stats) { tracker.update_progress_bar(stats); });
constexpr auto kStatsPeriod = 100ms;
Expand Down
117 changes: 93 additions & 24 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "data_loader/ModelFinder.h"
#include "models/models.h"
#include "nn/CRFModelConfig.h"
#include "nn/ModBaseRunner.h"
#include "nn/Runners.h"
#include "read_pipeline/AlignerNode.h"
#include "read_pipeline/BaseSpaceDuplexCallerNode.h"
Expand All @@ -21,6 +22,7 @@
#include "utils/log_utils.h"
#include "utils/parameters.h"
#include "utils/stats.h"
#include "utils/string_utils.h"
#include "utils/sys_stats.h"
#include "utils/torch_utils.h"
#include "utils/types.h"
Expand All @@ -33,6 +35,8 @@
#include <thread>
#include <unordered_set>

namespace fs = std::filesystem;

namespace dorado {

namespace {
Expand All @@ -45,17 +49,30 @@ struct DuplexModels {
CRFModelConfig stereo_model_config;
std::string stereo_model_name;

std::vector<fs::path> mods_model_paths;

std::set<std::filesystem::path> temp_paths{};
};

DuplexModels load_models(const std::string& model_arg,
const std::vector<std::string>& mod_bases,
const std::string& mod_bases_models,
const std::string& reads,
const bool recursive_file_loading,
const bool skip_model_compatibility_check) {
const ModelSelection model_selection = cli::parse_model_argument(model_arg);

auto ways = {model_selection.has_mods_variant(), !mod_bases.empty(), !mod_bases_models.empty()};
if (std::count(ways.begin(), ways.end(), true) > 1) {
spdlog::error(
"only one of --modified-bases, --modified-bases-models, or modified models set "
"via models argument can be used at once",
model_arg);
std::exit(EXIT_FAILURE);
};

if (model_selection.is_path()) {
const auto model_path = std::filesystem::canonical(std::filesystem::path(model_arg));
const auto model_path = fs::canonical(fs::path(model_arg));
const auto model_name = model_path.filename().string();
const auto model_config = load_crf_model_config(model_path);

Expand All @@ -75,23 +92,30 @@ DuplexModels load_models(const std::string& model_arg,
throw std::runtime_error(err.str());
}
const auto stereo_model_name = utils::get_stereo_model_name(model_arg, data_sample_rate);
const auto stereo_model_path =
model_path.parent_path() / std::filesystem::path(stereo_model_name);
const auto stereo_model_path = model_path.parent_path() / fs::path(stereo_model_name);

if (!std::filesystem::exists(stereo_model_path)) {
if (!fs::exists(stereo_model_path)) {
if (!models::download_models(model_path.parent_path().u8string(), stereo_model_name)) {
throw std::runtime_error("Failed to download model: " + stereo_model_name);
}
}
const auto stereo_model_config = load_crf_model_config(stereo_model_path);

return DuplexModels{model_path, model_name, model_config,
stereo_model_path, stereo_model_config, stereo_model_name};
}
std::vector<fs::path> mods_model_paths;
if (!mod_bases.empty()) {
std::transform(mod_bases.begin(), mod_bases.end(), std::back_inserter(mods_model_paths),
[&model_arg](std::string m) {
return fs::path(models::get_modification_model(model_arg, m));
});
} else if (!mod_bases_models.empty()) {
const auto split = utils::split(mod_bases_models, ',');
std::transform(split.begin(), split.end(), std::back_inserter(mods_model_paths),
[&](std::string m) { return fs::path(m); });
}

if (model_selection.has_mods_variant()) {
spdlog::error("Modified bases models are not supported for duplex");
std::exit(EXIT_FAILURE);
return DuplexModels{model_path, model_name, model_config,
stereo_model_path, stereo_model_config, stereo_model_name,
mods_model_paths};
}

auto model_finder = cli::model_finder(model_selection, reads, recursive_file_loading, true);
Expand All @@ -104,13 +128,14 @@ DuplexModels load_models(const std::string& model_arg,
const auto stereo_model_name = stereo_model_path.filename().string();
const auto stereo_model_config = load_crf_model_config(stereo_model_path);

return DuplexModels{model_path,
model_name,
model_config,
stereo_model_path,
stereo_model_config,
stereo_model_name,
model_finder.downloaded_models()};
const std::vector<fs::path> mods_model_paths = model_selection.has_mods_variant()
? model_finder.fetch_mods_models()
: std::vector<fs::path>{};

return DuplexModels{model_path, model_name,
model_config, stereo_model_path,
stereo_model_config, stereo_model_name,
mods_model_paths, model_finder.downloaded_models()};
}
} // namespace

Expand Down Expand Up @@ -203,6 +228,28 @@ int duplex(int argc, char* argv[]) {
.action([&](const auto&) { ++verbosity; })
.append();

parser.visible.add_argument("--modified-bases")
.nargs(argparse::nargs_pattern::at_least_one)
.action([](const std::string& value) {
const auto& mods = models::modified_model_variants();
if (std::find(mods.begin(), mods.end(), value) == mods.end()) {
spdlog::error("'{}' is not a supported modification please select from {}",
value, utils::join(mods, ", "));
std::exit(EXIT_FAILURE);
}
return value;
});

parser.visible.add_argument("--modified-bases-models")
.default_value(std::string())
.help("a comma separated list of modified base models");

parser.visible.add_argument("--modified-bases-threshold")
.default_value(default_parameters.methylation_threshold)
.scan<'f', float>()
.help("the minimum predicted methylation probability for a modified base to be emitted "
"in an all-context model, [0, 1]");

cli::add_minimap2_arguments(parser, alignment::dflt_options);
cli::add_internal_arguments(parser);

Expand All @@ -226,6 +273,10 @@ int duplex(int argc, char* argv[]) {
if (parser.visible.get<bool>("--verbose")) {
utils::SetVerboseLogging(static_cast<dorado::utils::VerboseLogLevel>(verbosity));
}

auto mod_bases = parser.visible.get<std::vector<std::string>>("--modified-bases");
auto mod_bases_models = parser.visible.get<std::string>("--modified-bases-models");

std::map<std::string, std::string> template_complement_map;
auto read_list = utils::load_read_list(parser.visible.get<std::string>("--read-ids"));

Expand Down Expand Up @@ -331,6 +382,11 @@ int duplex(int argc, char* argv[]) {
return 1; // Exit with an error code
}

if (!mod_bases.empty() || !mod_bases_models.empty()) {
spdlog::error("Basespace duplex does not support modbase models");
return EXIT_FAILURE;
}

spdlog::info("> Loading reads");
auto read_map = read_bam(reads, read_list_from_pairs);

Expand All @@ -355,15 +411,25 @@ int duplex(int argc, char* argv[]) {
kStatsPeriod, stats_reporters, stats_callables, max_stats_records);
} else { // Execute a Stereo Duplex pipeline.

if (!DataLoader::is_read_data_present(reads, recursive_file_loading)) {
std::string err = "No POD5 or FAST5 data found in path: " + reads;
throw std::runtime_error(err);
}

const bool skip_model_compatibility_check =
parser.hidden.get<bool>("--skip-model-compatibility-check");

const DuplexModels models = load_models(model, reads, recursive_file_loading,
skip_model_compatibility_check);
const DuplexModels models =
load_models(model, mod_bases, mod_bases_models, reads, recursive_file_loading,
skip_model_compatibility_check);

if (!DataLoader::is_read_data_present(reads, recursive_file_loading)) {
std::string err = "No POD5 or FAST5 data found in path: " + reads;
throw std::runtime_error(err);
// create modbase runners first so basecall runners can pick batch sizes based on available memory
auto mod_base_runners = create_modbase_runners(
models.mods_model_paths, device, default_parameters.mod_base_runners_per_caller,
default_parameters.remora_batchsize);

if (!mod_base_runners.empty() && output_mode == HtsWriter::OutputMode::FASTQ) {
throw std::runtime_error("Modified base models cannot be used with FASTQ output");
}

// Write read group info to header.
Expand Down Expand Up @@ -429,9 +495,12 @@ int duplex(int argc, char* argv[]) {
throw std::runtime_error("Mean q-score start position cannot be < 0");
}
}

pipelines::create_stereo_duplex_pipeline(
pipeline_desc, std::move(runners), std::move(stereo_runners), overlap,
mean_qscore_start_pos, int(num_devices * 2), int(num_devices),
pipeline_desc, std::move(runners), std::move(stereo_runners),
std::move(mod_base_runners), overlap, mean_qscore_start_pos,
int(num_devices * 2), int(num_devices),
int(default_parameters.remora_threads * num_devices),
std::move(pairing_parameters), read_filter_node,
PipelineDescriptor::InvalidNodeHandle);

Expand Down

0 comments on commit b7d4b38

Please sign in to comment.