diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d34f45a..64d6091e 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -161,6 +161,8 @@ enable_testing() add_subdirectory(dorado/utils) add_subdirectory(dorado/models) +add_subdirectory(dorado/basecall) +add_subdirectory(dorado/modbase) if("${CUDA_VERSION}" STREQUAL "") set(CUDA_VERSION ${CUDAToolkit_VERSION}) @@ -179,21 +181,10 @@ set(LIB_SOURCE_FILES dorado/alignment/Minimap2Index.h dorado/alignment/Minimap2IndexSupportTypes.h dorado/alignment/Minimap2Options.h - dorado/nn/CRFModel.h - dorado/nn/CRFModel.cpp - dorado/nn/CRFModelConfig.h - dorado/nn/CRFModelConfig.cpp - dorado/nn/ModelRunner.h - dorado/nn/ModBaseModel.cpp - dorado/nn/ModBaseModel.h - dorado/nn/ModBaseModelConfig.cpp - dorado/nn/ModBaseModelConfig.h - dorado/nn/ModBaseRunner.cpp - dorado/nn/ModBaseRunner.h - dorado/nn/Runners.cpp - dorado/nn/Runners.h - dorado/read_pipeline/Pipelines.cpp - dorado/read_pipeline/Pipelines.h + dorado/api/runner_creation.cpp + dorado/api/runner_creation.h + dorado/api/pipeline_creation.cpp + dorado/api/pipeline_creation.h dorado/read_pipeline/FakeDataLoader.cpp dorado/read_pipeline/FakeDataLoader.h dorado/read_pipeline/ReadPipeline.cpp @@ -263,36 +254,8 @@ set(LIB_SOURCE_FILES dorado/demux/Trimmer.h dorado/demux/parse_custom_kit.cpp dorado/demux/parse_custom_kit.h - dorado/decode/beam_search.cpp - dorado/decode/beam_search.h - dorado/decode/CPUDecoder.cpp - dorado/decode/CPUDecoder.h - dorado/modbase/modbase_encoder.cpp - dorado/modbase/modbase_encoder.h - dorado/modbase/modbase_scaler.cpp - dorado/modbase/modbase_scaler.h - dorado/modbase/ModBaseContext.cpp - dorado/modbase/ModBaseContext.h - dorado/modbase/MotifMatcher.cpp - dorado/modbase/MotifMatcher.h ) -if (DORADO_GPU_BUILD) - if(APPLE) - list(APPEND LIB_SOURCE_FILES - dorado/nn/MetalCRFModel.h - dorado/nn/MetalCRFModel.cpp - ) - else() - list(APPEND LIB_SOURCE_FILES - dorado/decode/GPUDecoder.cpp - dorado/decode/GPUDecoder.h - dorado/nn/CudaCRFModel.h - dorado/nn/CudaCRFModel.cpp - ) - endif() -endif() - add_library(dorado_lib ${LIB_SOURCE_FILES}) enable_warnings_as_errors(dorado_lib) @@ -312,7 +275,6 @@ target_include_directories(dorado_lib PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/dorado ${CMAKE_CURRENT_SOURCE_DIR}/dorado - ${KOI_INCLUDE} ${POD5_INCLUDE} ) # 3rdparty libs should be considered SYSTEM headers @@ -348,8 +310,9 @@ target_link_libraries(dorado_lib vbz_hdf_plugin edlib dorado_utils + dorado_basecall + dorado_modbase PRIVATE - ${KOI_LIBRARIES} minimap2 ) diff --git a/cmake/Metal.cmake b/cmake/Metal.cmake index caf02e91..212fe673 100644 --- a/cmake/Metal.cmake +++ b/cmake/Metal.cmake @@ -5,7 +5,7 @@ find_library(APPLE_FWK_METAL Metal REQUIRED) find_library(IOKIT IOKit REQUIRED) set(AIR_FILES) -set(METAL_SOURCES dorado/nn/metal/nn.metal) +set(METAL_SOURCES dorado/basecall/metal/nn.metal) if (CMAKE_SYSTEM_NAME STREQUAL "iOS") set(XCRUN_SDK iphoneos) diff --git a/dorado/read_pipeline/Pipelines.cpp b/dorado/api/pipeline_creation.cpp similarity index 82% rename from dorado/read_pipeline/Pipelines.cpp rename to dorado/api/pipeline_creation.cpp index d1fb622f..b7a20896 100644 --- a/dorado/read_pipeline/Pipelines.cpp +++ b/dorado/api/pipeline_creation.cpp @@ -1,14 +1,14 @@ -#include "Pipelines.h" - -#include "BasecallerNode.h" -#include "ModBaseCallerNode.h" -#include "PairingNode.h" -#include "ReadSplitNode.h" -#include "ScalerNode.h" -#include "StereoDuplexEncoderNode.h" -#include "nn/CRFModelConfig.h" -#include "nn/ModBaseRunner.h" -#include "nn/ModelRunner.h" +#include "pipeline_creation.h" + +#include "basecall/CRFModelConfig.h" +#include "basecall/ModelRunner.h" +#include "modbase/ModBaseRunner.h" +#include "read_pipeline/BasecallerNode.h" +#include "read_pipeline/ModBaseCallerNode.h" +#include "read_pipeline/PairingNode.h" +#include "read_pipeline/ReadSplitNode.h" +#include "read_pipeline/ScalerNode.h" +#include "read_pipeline/StereoDuplexEncoderNode.h" #include "splitter/DuplexReadSplitter.h" #include "splitter/RNAReadSplitter.h" @@ -17,8 +17,8 @@ namespace dorado::pipelines { void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, - std::vector&& runners, - std::vector>&& modbase_runners, + std::vector&& runners, + std::vector&& modbase_runners, size_t overlap, uint32_t mean_qscore_start_pos, int scaler_node_threads, @@ -76,7 +76,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, // For DNA, read splitting happens after basecall. if (enable_read_splitter && !is_rna) { splitter::DuplexSplitSettings splitter_settings(model_config.signal_norm_params.strategy == - ScalingStrategy::PA); + basecall::ScalingStrategy::PA); splitter_settings.simplex_mode = true; auto dna_splitter = std::make_unique(splitter_settings); auto dna_splitter_node = pipeline_desc.add_node({}, std::move(dna_splitter), @@ -105,19 +105,18 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, } } -void create_stereo_duplex_pipeline( - PipelineDescriptor& pipeline_desc, - std::vector&& runners, - std::vector&& stereo_runners, - std::vector>&& modbase_runners, - size_t overlap, - uint32_t mean_qscore_start_pos, - int scaler_node_threads, - int splitter_node_threads, - int modbase_node_threads, - PairingParameters pairing_parameters, - NodeHandle sink_node_handle, - NodeHandle source_node_handle) { +void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, + std::vector&& runners, + std::vector&& stereo_runners, + std::vector&& modbase_runners, + size_t overlap, + uint32_t mean_qscore_start_pos, + int scaler_node_threads, + int splitter_node_threads, + int modbase_node_threads, + PairingParameters pairing_parameters, + NodeHandle sink_node_handle, + NodeHandle source_node_handle) { const auto& model_config = runners.front()->config(); const auto& stereo_model_config = stereo_runners.front()->config(); std::string model_name = @@ -162,7 +161,7 @@ void create_stereo_duplex_pipeline( // as a passthrough, meaning it won't perform any splitting operations and // will just pass data through. splitter::DuplexSplitSettings splitter_settings(model_config.signal_norm_params.strategy == - ScalingStrategy::PA); + basecall::ScalingStrategy::PA); auto duplex_splitter = std::make_unique(splitter_settings); auto splitter_node = pipeline_desc.add_node( {pairing_node}, std::move(duplex_splitter), splitter_node_threads, 1000); @@ -174,9 +173,9 @@ void create_stereo_duplex_pipeline( {splitter_node}, std::move(runners), adjusted_simplex_overlap, kSimplexBatchTimeoutMS, model_name, 1000, "BasecallerNode", mean_qscore_start_pos); - auto scaler_node = - pipeline_desc.add_node({basecaller_node}, model_config.signal_norm_params, - SampleType::DNA, scaler_node_threads, 1000); + auto scaler_node = pipeline_desc.add_node( + {basecaller_node}, model_config.signal_norm_params, basecall::SampleType::DNA, + scaler_node_threads, 1000); // if we've been provided a source node, connect it to the start of our pipeline if (source_node_handle != PipelineDescriptor::InvalidNodeHandle) { diff --git a/dorado/read_pipeline/Pipelines.h b/dorado/api/pipeline_creation.h similarity index 51% rename from dorado/read_pipeline/Pipelines.h rename to dorado/api/pipeline_creation.h index f4a27211..ab6dbe45 100644 --- a/dorado/read_pipeline/Pipelines.h +++ b/dorado/api/pipeline_creation.h @@ -1,6 +1,6 @@ #pragma once -#include "ReadPipeline.h" +#include "read_pipeline/ReadPipeline.h" #include #include @@ -11,10 +11,16 @@ namespace dorado { -class ModBaseRunner; +namespace basecall { class ModelRunnerBase; +using RunnerPtr = std::unique_ptr; +} // namespace basecall + +namespace modbase { +class ModBaseRunner; +using RunnerPtr = std::unique_ptr; +} // namespace modbase -using Runner = std::shared_ptr; using PairingParameters = std::variant>; namespace pipelines { @@ -23,8 +29,8 @@ namespace pipelines { /// If source_node_handle is valid, set this to be the source of the simplex pipeline /// If sink_node_handle is valid, set this to be the sink of the simplex pipeline void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, - std::vector&& runners, - std::vector>&& modbase_runners, + std::vector&& runners, + std::vector&& modbase_runners, size_t overlap, uint32_t mean_qscore_start_pos, int scaler_node_threads, @@ -37,19 +43,18 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc, /// Create a duplex basecall pipeline description /// If source_node_handle is valid, set this to be the source of the simplex pipeline /// If sink_node_handle is valid, set this to be the sink of the simplex pipeline -void create_stereo_duplex_pipeline( - PipelineDescriptor& pipeline_desc, - std::vector&& runners, - std::vector&& stereo_runners, - std::vector>&& modbase_runners, - size_t overlap, - uint32_t mean_qscore_start_pos, - int scaler_node_threads, - int splitter_node_threads, - int modbase_node_threads, - PairingParameters pairing_parameters, - NodeHandle sink_node_handle, - NodeHandle source_node_handle); +void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc, + std::vector&& runners, + std::vector&& stereo_runners, + std::vector&& modbase_runners, + size_t overlap, + uint32_t mean_qscore_start_pos, + int scaler_node_threads, + int splitter_node_threads, + int modbase_node_threads, + PairingParameters pairing_parameters, + NodeHandle sink_node_handle, + NodeHandle source_node_handle); } // namespace pipelines diff --git a/dorado/nn/Runners.cpp b/dorado/api/runner_creation.cpp similarity index 74% rename from dorado/nn/Runners.cpp rename to dorado/api/runner_creation.cpp index 94ed34f6..80224986 100644 --- a/dorado/nn/Runners.cpp +++ b/dorado/api/runner_creation.cpp @@ -1,15 +1,13 @@ -#include "Runners.h" +#include "runner_creation.h" -#include "ModBaseRunner.h" -#include "ModelRunner.h" -#include "decode/CPUDecoder.h" -#include "nn/CRFModel.h" +#include "basecall/crf_utils.h" +#include "basecall/decode/CPUDecoder.h" #if DORADO_GPU_BUILD #ifdef __APPLE__ -#include "nn/MetalCRFModel.h" +#include "basecall/MetalCRFModel.h" #else -#include "nn/CudaCRFModel.h" +#include "basecall/CudaCRFModel.h" #include "utils/cuda_utils.h" #endif #endif // DORADO_GPU_BUILD @@ -21,8 +19,8 @@ namespace dorado { -std::pair, size_t> create_basecall_runners( - const dorado::CRFModelConfig& model_config, +std::pair, size_t> create_basecall_runners( + const basecall::CRFModelConfig& model_config, const std::string& device, size_t num_gpu_runners, size_t num_cpu_runners, @@ -34,7 +32,7 @@ std::pair, size_t> create_basecall_runners( (void)guard_gpus; #endif - std::vector runners; + std::vector runners; // Default is 1 device. CUDA path may alter this. size_t num_devices = 1; @@ -48,22 +46,23 @@ std::pair, size_t> create_basecall_runners( batch_size = 128; } if (num_cpu_runners == 0) { - num_cpu_runners = auto_calculate_num_runners(model_config, batch_size, memory_fraction); + num_cpu_runners = + basecall::auto_calculate_num_runners(model_config, batch_size, memory_fraction); } spdlog::debug("- CPU calling: set batch size to {}, num_cpu_runners to {}", batch_size, num_cpu_runners); for (size_t i = 0; i < num_cpu_runners; i++) { - runners.push_back(std::make_shared>( + runners.push_back(std::make_unique>( model_config, device, int(chunk_size), int(batch_size))); } } #if DORADO_GPU_BUILD #ifdef __APPLE__ else if (device == "metal") { - auto caller = dorado::create_metal_caller(model_config, int(chunk_size), int(batch_size)); + auto caller = basecall::create_metal_caller(model_config, int(chunk_size), int(batch_size)); for (size_t i = 0; i < num_gpu_runners; i++) { - runners.push_back(std::make_shared(caller)); + runners.push_back(std::make_unique(caller)); } if (batch_size == 0) { spdlog::info(" - set batch size to {}", runners.back()->batch_size()); @@ -84,11 +83,11 @@ std::pair, size_t> create_basecall_runners( } cxxpool::thread_pool pool{num_devices}; - std::vector> callers; - std::vector>> futures; + std::vector> callers; + std::vector>> futures; for (auto device_string : devices) { - futures.push_back(pool.push(dorado::create_cuda_caller, model_config, int(chunk_size), + futures.push_back(pool.push(basecall::create_cuda_caller, model_config, int(chunk_size), int(batch_size), device_string, memory_fraction, guard_gpus)); } @@ -99,7 +98,7 @@ std::pair, size_t> create_basecall_runners( for (size_t j = 0; j < num_devices; j++) { for (size_t i = 0; i < num_gpu_runners; i++) { - runners.push_back(std::make_shared(callers[j])); + runners.push_back(std::make_unique(callers[j])); } if (batch_size == 0) { spdlog::info(" - set batch size for {} to {}", devices[j], @@ -121,7 +120,7 @@ std::pair, size_t> create_basecall_runners( auto model_stride = runners.front()->model_stride(); #endif auto adjusted_chunk_size = runners.front()->chunk_size(); - assert(std::all_of(runners.begin(), runners.end(), [&](auto runner) { + assert(std::all_of(runners.begin(), runners.end(), [&](const auto& runner) { return runner->model_stride() == model_stride && runner->chunk_size() == adjusted_chunk_size; })); @@ -132,10 +131,10 @@ std::pair, size_t> create_basecall_runners( chunk_size = adjusted_chunk_size; } - return {runners, num_devices}; + return {std::move(runners), num_devices}; } -std::vector> create_modbase_runners( +std::vector create_modbase_runners( const std::vector& remora_models, const std::string& device, size_t remora_runners_per_caller, @@ -145,7 +144,7 @@ std::vector> create_modbase_runners( } // generate model callers before nodes or it affects the speed calculations - std::vector> remora_runners; + std::vector remora_runners; std::vector modbase_devices; int remora_callers = 1; @@ -168,10 +167,10 @@ std::vector> create_modbase_runners( #endif // DORADO_GPU_BUILD for (const auto& device_string : modbase_devices) { for (int i = 0; i < remora_callers; ++i) { - auto caller = dorado::create_modbase_caller(remora_models, int(remora_batch_size), - device_string); + auto caller = modbase::create_modbase_caller(remora_models, int(remora_batch_size), + device_string); for (size_t j = 0; j < remora_runners_per_caller; j++) { - remora_runners.push_back(std::make_unique(caller)); + remora_runners.push_back(std::make_unique(caller)); } } }; diff --git a/dorado/nn/Runners.h b/dorado/api/runner_creation.h similarity index 65% rename from dorado/nn/Runners.h rename to dorado/api/runner_creation.h index 9276fdf2..b9d09074 100644 --- a/dorado/nn/Runners.h +++ b/dorado/api/runner_creation.h @@ -1,5 +1,8 @@ #pragma once +#include "basecall/ModelRunner.h" +#include "modbase/ModBaseRunner.h" + #include #include #include @@ -8,14 +11,12 @@ namespace dorado { +namespace basecall { struct CRFModelConfig; -class ModelRunnerBase; -class ModBaseRunner; - -using Runner = std::shared_ptr; +} // namespace basecall -std::pair, size_t> create_basecall_runners( - const dorado::CRFModelConfig& model_config, +std::pair, size_t> create_basecall_runners( + const basecall::CRFModelConfig& model_config, const std::string& device, size_t num_gpu_runners, size_t num_cpu_runners, @@ -24,7 +25,7 @@ std::pair, size_t> create_basecall_runners( float memory_fraction, bool guard_gpus); -std::vector> create_modbase_runners( +std::vector create_modbase_runners( const std::vector& remora_models, const std::string& device, size_t remora_runners_per_caller, diff --git a/dorado/basecall/CMakeLists.txt b/dorado/basecall/CMakeLists.txt new file mode 100644 index 00000000..88bf5df9 --- /dev/null +++ b/dorado/basecall/CMakeLists.txt @@ -0,0 +1,65 @@ +add_library(dorado_basecall STATIC + crf_utils.h + crf_utils.cpp + CRFModel.h + CRFModel.cpp + CRFModelConfig.h + CRFModelConfig.cpp + ModelRunner.h + decode/beam_search.cpp + decode/beam_search.h + decode/CPUDecoder.cpp + decode/CPUDecoder.h +) + +if (DORADO_GPU_BUILD) + if(APPLE) + target_sources(dorado_basecall + PRIVATE + MetalCRFModel.h + MetalCRFModel.cpp + ) + else() + target_sources(dorado_basecall + PRIVATE + CudaCRFModel.h + CudaCRFModel.cpp + decode/GPUDecoder.cpp + decode/GPUDecoder.h + ) + endif() +endif() + +target_include_directories(dorado_basecall + SYSTEM + PUBLIC + ${DORADO_3RD_PARTY_SOURCE}/toml11 + PRIVATE + ${DORADO_3RD_PARTY_SOURCE}/NVTX/c/include +) + + +target_link_libraries(dorado_basecall + PUBLIC + ${TORCH_LIBRARIES} + dorado_utils + PRIVATE + dorado_models_lib + ${KOI_LIBRARIES} + spdlog::spdlog +) + +target_include_directories(dorado_basecall + PRIVATE + ${KOI_INCLUDE} +) + +enable_warnings_as_errors(dorado_basecall) + +if (DORADO_ENABLE_PCH) + target_precompile_headers(dorado_basecall REUSE_FROM dorado_utils) + # these are defined publicly by minimap2, which we're not using. + # we need to define them here so that the environment matches the one + # used when building the PCH + target_compile_definitions(dorado_basecall PRIVATE PTW32_CLEANUP_C PTW32_STATIC_LIB) +endif() diff --git a/dorado/nn/CRFModel.cpp b/dorado/basecall/CRFModel.cpp similarity index 92% rename from dorado/nn/CRFModel.cpp rename to dorado/basecall/CRFModel.cpp index 2dd7db46..0d9e3a7a 100644 --- a/dorado/nn/CRFModel.cpp +++ b/dorado/basecall/CRFModel.cpp @@ -1,9 +1,9 @@ #include "CRFModel.h" #include "CRFModelConfig.h" +#include "crf_utils.h" #include "utils/gpu_profiling.h" #include "utils/math_utils.h" -#include "utils/memory_utils.h" #include "utils/module_utils.h" #include "utils/tensor_utils.h" @@ -12,7 +12,7 @@ #if DORADO_GPU_BUILD && !defined(__APPLE__) #define USE_KOI 1 -#include "../utils/cuda_utils.h" +#include "utils/cuda_utils.h" #include #include @@ -42,14 +42,17 @@ using namespace torch::nn; namespace F = torch::nn::functional; using Slice = torch::indexing::Slice; -#if USE_KOI +namespace dorado::basecall { + +namespace { -KoiActivation get_koi_activation(dorado::Activation act) { - if (act == dorado::Activation::SWISH) { +#if USE_KOI +KoiActivation get_koi_activation(Activation act) { + if (act == Activation::SWISH) { return KOI_SWISH; - } else if (act == dorado::Activation::SWISH_CLAMP) { + } else if (act == Activation::SWISH_CLAMP) { return KOI_SWISH_CLAMP; - } else if (act == dorado::Activation::TANH) { + } else if (act == Activation::TANH) { return KOI_TANH; } else { throw std::logic_error("Unrecognised activation function id."); @@ -91,24 +94,24 @@ KoiActivation get_koi_activation(dorado::Activation act) { enum class TensorLayout { NTC, TNC, CUTLASS_TNC_F16, CUTLASS_TNC_I8, CUBLAS_TN2C }; // TODO: These should really be part of Koi -static bool koi_can_use_cutlass() { +bool koi_can_use_cutlass() { cudaDeviceProp *prop = at::cuda::getCurrentDeviceProperties(); return ((prop->major == 8 || prop->major == 9) && prop->minor == 0); } -static bool koi_can_use_quantised_lstm() { +bool koi_can_use_quantised_lstm() { cudaDeviceProp *prop = at::cuda::getCurrentDeviceProperties(); // DP4A is supported on Pascal and later, except for TX2 (sm_62). return (prop->major > 6) || (prop->major == 6 && prop->minor != 2); } -static TensorLayout get_koi_lstm_input_layout(int layer_size, dorado::Activation activation) { +TensorLayout get_koi_lstm_input_layout(int layer_size, Activation activation) { TensorLayout layout = TensorLayout::CUBLAS_TN2C; if (koi_can_use_quantised_lstm() && (layer_size == 96 || layer_size == 128)) { layout = TensorLayout::NTC; } else if (koi_can_use_cutlass() && layer_size <= 1024 && layer_size > 128 && (layer_size % 128) == 0) { - layout = (activation == dorado::Activation::TANH) ? TensorLayout::CUTLASS_TNC_I8 - : TensorLayout::CUTLASS_TNC_F16; + layout = (activation == Activation::TANH) ? TensorLayout::CUTLASS_TNC_I8 + : TensorLayout::CUTLASS_TNC_F16; } // Apply override (Cutlass override can only be applied if conditions are met) @@ -273,14 +276,13 @@ class WorkingMemory { #endif // if USE_KOI -namespace { template ModuleHolder populate_model(Model &&model, const std::filesystem::path &path, const at::TensorOptions &options, bool decomposition, bool linear_layer_bias) { - auto state_dict = dorado::load_crf_model_weights(path, decomposition, linear_layer_bias); + auto state_dict = load_crf_model_weights(path, decomposition, linear_layer_bias); model->load_state_dict(state_dict); model->to(options.dtype_opt().value().toScalarType()); model->to(options.device_opt().value()); @@ -290,9 +292,8 @@ ModuleHolder populate_model(Model &&model, auto holder = ModuleHolder(module); return holder; } -} // namespace -namespace dorado { +} // namespace namespace nn { @@ -906,45 +907,6 @@ TORCH_MODULE(CRFModel); } // namespace nn -std::vector load_crf_model_weights(const std::filesystem::path &dir, - bool decomposition, - bool linear_layer_bias) { - auto tensors = std::vector{ - - "0.conv.weight.tensor", "0.conv.bias.tensor", - - "1.conv.weight.tensor", "1.conv.bias.tensor", - - "2.conv.weight.tensor", "2.conv.bias.tensor", - - "4.rnn.weight_ih_l0.tensor", "4.rnn.weight_hh_l0.tensor", - "4.rnn.bias_ih_l0.tensor", "4.rnn.bias_hh_l0.tensor", - - "5.rnn.weight_ih_l0.tensor", "5.rnn.weight_hh_l0.tensor", - "5.rnn.bias_ih_l0.tensor", "5.rnn.bias_hh_l0.tensor", - - "6.rnn.weight_ih_l0.tensor", "6.rnn.weight_hh_l0.tensor", - "6.rnn.bias_ih_l0.tensor", "6.rnn.bias_hh_l0.tensor", - - "7.rnn.weight_ih_l0.tensor", "7.rnn.weight_hh_l0.tensor", - "7.rnn.bias_ih_l0.tensor", "7.rnn.bias_hh_l0.tensor", - - "8.rnn.weight_ih_l0.tensor", "8.rnn.weight_hh_l0.tensor", - "8.rnn.bias_ih_l0.tensor", "8.rnn.bias_hh_l0.tensor", - - "9.linear.weight.tensor"}; - - if (linear_layer_bias) { - tensors.push_back("9.linear.bias.tensor"); - } - - if (decomposition) { - tensors.push_back("10.linear.weight.tensor"); - } - - return utils::load_tensors(dir, tensors); -} - ModuleHolder load_crf_model(const CRFModelConfig &model_config, const at::TensorOptions &options) { auto model = nn::CRFModel(model_config); @@ -952,30 +914,4 @@ ModuleHolder load_crf_model(const CRFModelConfig &model_config, model_config.out_features.has_value(), model_config.bias); } -size_t auto_calculate_num_runners(const CRFModelConfig &model_config, - size_t batch_size, - float memory_fraction) { - auto model_name = std::filesystem::canonical(model_config.model_path).filename().string(); - - // very hand-wavy determination - // these numbers were determined empirically by running 1, 2, 4 and 8 runners for each model - auto required_ram_per_runner_GB = 0.f; - if (model_name.find("_fast@v") != std::string::npos) { - required_ram_per_runner_GB = 1.5; - } else if (model_name.find("_hac@v") != std::string::npos) { - required_ram_per_runner_GB = 4.5; - } else if (model_name.find("_sup@v") != std::string::npos) { - required_ram_per_runner_GB = 12.5; - } else { - return 1; - } - - // numbers were determined with a batch_size of 128, assume this just scales - required_ram_per_runner_GB *= batch_size / 128.f; - - auto free_ram_GB = utils::available_host_memory_GB() * memory_fraction; - auto num_runners = static_cast(free_ram_GB / required_ram_per_runner_GB); - return std::clamp(num_runners, size_t(1), std::size_t(std::thread::hardware_concurrency())); -} - -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/basecall/CRFModel.h b/dorado/basecall/CRFModel.h new file mode 100644 index 00000000..e7045a32 --- /dev/null +++ b/dorado/basecall/CRFModel.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace dorado::basecall { + +struct CRFModelConfig; + +torch::nn::ModuleHolder load_crf_model(const CRFModelConfig& model_config, + const at::TensorOptions& options); + +} // namespace dorado::basecall diff --git a/dorado/nn/CRFModelConfig.cpp b/dorado/basecall/CRFModelConfig.cpp similarity index 98% rename from dorado/nn/CRFModelConfig.cpp rename to dorado/basecall/CRFModelConfig.cpp index 35fc389e..ea787e42 100644 --- a/dorado/nn/CRFModelConfig.cpp +++ b/dorado/basecall/CRFModelConfig.cpp @@ -12,6 +12,8 @@ #include #include +namespace { + enum SublayerType { CLAMP, CONVOLUTION, LINEAR, LINEAR_CRF_ENCODER, LSTM, PERMUTE, UNRECOGNISED }; static const std::unordered_map sublayer_map = { {"clamp", SublayerType::CLAMP}, {"convolution", SublayerType::CONVOLUTION}, @@ -29,7 +31,8 @@ SublayerType sublayer_type(const toml::value &segment) { return mapping_iter->second; } -namespace dorado { +} // namespace +namespace dorado::basecall { // Parse the config to determine if there are any clamp layers bool has_clamp(const std::vector &sublayers) { @@ -157,7 +160,7 @@ SampleType get_model_type(const std::string &model_name) { std::string SignalNormalisationParams::to_string() const { std::string str = "SignalNormalisationParams {"; - str += " strategy:" + dorado::to_string(strategy); + str += " strategy:" + dorado::basecall::to_string(strategy); if (strategy == ScalingStrategy::QUANTILE) { str += quantile.to_string(); } else if (strategy == ScalingStrategy::PA && standarisation.standardise) { @@ -173,7 +176,7 @@ std::string ConvParams::to_string() const { str += " size:" + std::to_string(size); str += " winlen:" + std::to_string(winlen); str += " stride:" + std::to_string(stride); - str += " activation:" + dorado::to_string(activation); + str += " activation:" + dorado::basecall::to_string(activation); str += "}"; return str; }; @@ -362,4 +365,4 @@ ScalingStrategy scaling_strategy_from_string(const std::string &strategy) { throw std::runtime_error("Unknown scaling strategy: `" + strategy + "`"); } -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/nn/CRFModelConfig.h b/dorado/basecall/CRFModelConfig.h similarity index 97% rename from dorado/nn/CRFModelConfig.h rename to dorado/basecall/CRFModelConfig.h index e48d0e79..d6165c3f 100644 --- a/dorado/nn/CRFModelConfig.h +++ b/dorado/basecall/CRFModelConfig.h @@ -1,13 +1,11 @@ #pragma once -#include "utils/math_utils.h" - #include #include #include #include -namespace dorado { +namespace dorado::basecall { enum class Activation { SWISH, SWISH_CLAMP, TANH }; std::string to_string(const Activation& activation); @@ -114,4 +112,4 @@ int32_t get_model_mean_qscore_start_pos(const CRFModelConfig& model_config); bool is_rna_model(const CRFModelConfig& model_config); -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/nn/CudaCRFModel.cpp b/dorado/basecall/CudaCRFModel.cpp similarity index 94% rename from dorado/nn/CudaCRFModel.cpp rename to dorado/basecall/CudaCRFModel.cpp index 7b430475..09bbc3a7 100644 --- a/dorado/nn/CudaCRFModel.cpp +++ b/dorado/basecall/CudaCRFModel.cpp @@ -19,7 +19,7 @@ using namespace std::chrono_literals; -namespace dorado { +namespace dorado::basecall { class CudaCaller { public: @@ -32,16 +32,16 @@ class CudaCaller { : m_config(model_config), m_device(device), m_exclusive_gpu_access(exclusive_gpu_access) { - m_decoder_options = DecoderOptions(); + m_decoder_options = decode::DecoderOptions(); m_decoder_options.q_shift = model_config.qbias; m_decoder_options.q_scale = model_config.qscale; - m_decoder = std::make_unique(model_config.clamp ? 5.f : 0.f); + m_decoder = std::make_unique(model_config.clamp ? 5.f : 0.f); m_num_input_features = model_config.num_features; // adjust chunk size to be a multiple of the stride m_out_chunk_size = chunk_size / model_config.stride; m_in_chunk_size = m_out_chunk_size * model_config.stride; - m_options = at::TensorOptions().dtype(GPUDecoder::dtype).device(device); + m_options = at::TensorOptions().dtype(decode::GPUDecoder::dtype).device(device); assert(m_options.device().is_cuda()); at::InferenceMode guard; @@ -95,7 +95,7 @@ class CudaCaller { return 64; } - int determine_batch_size(const dorado::CRFModelConfig &model_config, + int determine_batch_size(const CRFModelConfig &model_config, int chunk_size_in, float memory_limit_fraction, bool run_benchmark) { @@ -155,7 +155,7 @@ class CudaCaller { // Determine size of working memory for decoder divided by (batch_size * chunk_size) // Decoder needs roughly (beam_width * 4) + num_states + 10 extra bytes // where num_states = 4^(state_len+1) - // See `dorado::GPUDecoder::gpu_part()`, block beginning with `if (!initialized) {` + // See `dorado::basecall::decode::GPUDecoder::gpu_part()`, block beginning with `if (!initialized) {` // for more details. int64_t decode_bytes_per_chunk_timestep = 10 + m_decoder_options.beam_width * 4 + (1ull << (model_config.state_len * 2 + 2)); @@ -230,15 +230,15 @@ class CudaCaller { bool done{false}; }; - std::vector call_chunks(at::Tensor &input, - at::Tensor &output, - int num_chunks, - c10::cuda::CUDAStream stream) { + std::vector call_chunks(at::Tensor &input, + at::Tensor &output, + int num_chunks, + c10::cuda::CUDAStream stream) { NVTX3_FUNC_RANGE(); c10::cuda::CUDAStreamGuard stream_guard(stream); if (num_chunks == 0) { - return std::vector(); + return std::vector(); } auto task = std::make_shared(input.to(m_options.device())); @@ -373,8 +373,8 @@ class CudaCaller { const CRFModelConfig m_config; std::string m_device; at::TensorOptions m_options; - std::unique_ptr m_decoder; - DecoderOptions m_decoder_options; + std::unique_ptr m_decoder; + decode::DecoderOptions m_decoder_options; torch::nn::ModuleHolder m_module{nullptr}; std::atomic m_terminate{false}; std::deque> m_input_queue; @@ -416,7 +416,7 @@ void CudaModelRunner::accept_chunk(int chunk_idx, const at::Tensor &chunk) { m_input.index_put_({chunk_idx, torch::indexing::Ellipsis}, chunk); } -std::vector CudaModelRunner::call_chunks(int num_chunks) { +std::vector CudaModelRunner::call_chunks(int num_chunks) { ++m_num_batches_called; stats::Timer timer; auto decoded_chunks = m_caller->call_chunks(m_input, m_output, num_chunks, m_stream); @@ -448,4 +448,4 @@ stats::NamedStats CudaModelRunner::sample_stats() const { return stats; } -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/nn/CudaCRFModel.h b/dorado/basecall/CudaCRFModel.h similarity index 89% rename from dorado/nn/CudaCRFModel.h rename to dorado/basecall/CudaCRFModel.h index c216164f..ce00ebfc 100644 --- a/dorado/nn/CudaCRFModel.h +++ b/dorado/basecall/CudaCRFModel.h @@ -1,7 +1,7 @@ #pragma once +#include "CRFModel.h" #include "ModelRunner.h" -#include "nn/CRFModel.h" #include #include @@ -10,7 +10,7 @@ #include #include -namespace dorado { +namespace dorado::basecall { struct CRFModelConfig; class CudaCaller; @@ -26,7 +26,7 @@ class CudaModelRunner final : public ModelRunnerBase { public: explicit CudaModelRunner(std::shared_ptr caller); void accept_chunk(int chunk_idx, const at::Tensor& chunk) final; - std::vector call_chunks(int num_chunks) final; + std::vector call_chunks(int num_chunks) final; const CRFModelConfig& config() const final; size_t model_stride() const final; size_t chunk_size() const final; @@ -46,4 +46,4 @@ class CudaModelRunner final : public ModelRunnerBase { std::atomic m_num_batches_called = 0; }; -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/nn/MetalCRFModel.cpp b/dorado/basecall/MetalCRFModel.cpp similarity index 98% rename from dorado/nn/MetalCRFModel.cpp rename to dorado/basecall/MetalCRFModel.cpp index f09f1ef4..34952b1d 100644 --- a/dorado/nn/MetalCRFModel.cpp +++ b/dorado/basecall/MetalCRFModel.cpp @@ -4,6 +4,7 @@ #include "MetalCRFModel.h" #include "CRFModelConfig.h" +#include "crf_utils.h" #include "decode/beam_search.h" #include "utils/math_utils.h" #include "utils/metal_utils.h" @@ -64,7 +65,7 @@ bool finishCommandBuffer(std::string_view label, MTL::CommandBuffer *cb, int try } // namespace -namespace dorado { +namespace dorado::basecall { namespace nn { @@ -617,7 +618,7 @@ class MetalCaller { m_device = get_mtl_device(); - m_decoder_options = DecoderOptions(); + m_decoder_options = decode::DecoderOptions(); m_decoder_options.q_shift = model_config.qbias; m_decoder_options.q_scale = model_config.qscale; @@ -821,14 +822,14 @@ class MetalCaller { } struct NNTask { - NNTask(at::Tensor *input_, int num_chunks_, std::vector *out_chunks_) + NNTask(at::Tensor *input_, int num_chunks_, std::vector *out_chunks_) : input(input_), out_chunks(out_chunks_), num_chunks(num_chunks_) {} at::Tensor *input; std::mutex mut; std::condition_variable cv; bool ready{false}; - std::vector *out_chunks; + std::vector *out_chunks; int num_chunks; int decode_chunks_started{0}; int decode_chunks_finished{0}; @@ -836,7 +837,9 @@ class MetalCaller { uint64_t decode_complete_event_id = static_cast(0); }; - void call_chunks(at::Tensor &input, int num_chunks, std::vector &out_chunks) { + void call_chunks(at::Tensor &input, + int num_chunks, + std::vector &out_chunks) { if (num_chunks == 0) { return; } @@ -977,7 +980,7 @@ class MetalCaller { const int out_buf_idx = chunk_idx / m_out_batch_size; const int buf_chunk_idx = chunk_idx % m_out_batch_size; - auto [sequence, qstring, moves] = beam_search_decode( + auto [sequence, qstring, moves] = decode::beam_search_decode( m_scores_int8.at(out_buf_idx).index({Slice(), buf_chunk_idx}), m_bwd.at(out_buf_idx)[buf_chunk_idx], m_posts_int16.at(out_buf_idx)[buf_chunk_idx], m_decoder_options.beam_width, @@ -985,7 +988,7 @@ class MetalCaller { m_decoder_options.q_shift, m_decoder_options.q_scale, score_scale); (*task->out_chunks)[chunk_idx] = - DecodedChunk{std::move(sequence), std::move(qstring), std::move(moves)}; + decode::DecodedChunk{std::move(sequence), std::move(qstring), std::move(moves)}; // Wake the waiting thread which called `call_chunks()` if we're done decoding std::unique_lock task_lock(task->mut); @@ -1037,7 +1040,7 @@ class MetalCaller { std::mutex m_decode_lock; std::condition_variable m_decode_cv; std::vector> m_decode_threads; - DecoderOptions m_decoder_options; + decode::DecoderOptions m_decoder_options; nn::MetalModel m_model{nullptr}; NS::SharedPtr m_device; NS::SharedPtr m_bwd_scan_cps, m_fwd_scan_add_softmax_cps; @@ -1084,9 +1087,9 @@ void MetalModelRunner::accept_chunk(int chunk_idx, const at::Tensor &chunk) { } } -std::vector MetalModelRunner::call_chunks(int num_chunks) { +std::vector MetalModelRunner::call_chunks(int num_chunks) { ++m_num_batches_called; - std::vector out_chunks(num_chunks); + std::vector out_chunks(num_chunks); m_caller->call_chunks(m_input, num_chunks, out_chunks); return out_chunks; } @@ -1105,4 +1108,4 @@ stats::NamedStats MetalModelRunner::sample_stats() const { return stats; } -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/nn/MetalCRFModel.h b/dorado/basecall/MetalCRFModel.h similarity index 89% rename from dorado/nn/MetalCRFModel.h rename to dorado/basecall/MetalCRFModel.h index 84b5dbd9..927faef5 100644 --- a/dorado/nn/MetalCRFModel.h +++ b/dorado/basecall/MetalCRFModel.h @@ -10,7 +10,7 @@ #include #include -namespace dorado { +namespace dorado::basecall { struct CRFModelConfig; class MetalCaller; @@ -23,7 +23,7 @@ class MetalModelRunner final : public ModelRunnerBase { public: explicit MetalModelRunner(std::shared_ptr caller); void accept_chunk(int chunk_idx, const at::Tensor& chunk) final; - std::vector call_chunks(int num_chunks) final; + std::vector call_chunks(int num_chunks) final; const CRFModelConfig& config() const final; size_t model_stride() const final; size_t chunk_size() const final; @@ -41,4 +41,4 @@ class MetalModelRunner final : public ModelRunnerBase { std::atomic m_num_batches_called = 0; }; -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/nn/ModelRunner.h b/dorado/basecall/ModelRunner.h similarity index 87% rename from dorado/nn/ModelRunner.h rename to dorado/basecall/ModelRunner.h index 625398fc..a9dee625 100644 --- a/dorado/nn/ModelRunner.h +++ b/dorado/basecall/ModelRunner.h @@ -10,12 +10,13 @@ #include #include -namespace dorado { +namespace dorado::basecall { class ModelRunnerBase { public: + virtual ~ModelRunnerBase() = default; virtual void accept_chunk(int chunk_idx, const at::Tensor &chunk) = 0; - virtual std::vector call_chunks(int num_chunks) = 0; + virtual std::vector call_chunks(int num_chunks) = 0; virtual const CRFModelConfig &config() const = 0; virtual size_t model_stride() const = 0; virtual size_t chunk_size() const = 0; @@ -26,7 +27,7 @@ class ModelRunnerBase { virtual stats::NamedStats sample_stats() const = 0; }; -using Runner = std::shared_ptr; +using RunnerPtr = std::unique_ptr; template class ModelRunner final : public ModelRunnerBase { @@ -36,7 +37,7 @@ class ModelRunner final : public ModelRunnerBase { int chunk_size, int batch_size); void accept_chunk(int chunk_idx, const at::Tensor &chunk) final; - std::vector call_chunks(int num_chunks) final; + std::vector call_chunks(int num_chunks) final; const CRFModelConfig &config() const final { return m_config; }; size_t model_stride() const final { return m_config.stride; } size_t chunk_size() const final { return m_input.size(2); } @@ -51,7 +52,7 @@ class ModelRunner final : public ModelRunnerBase { at::Tensor m_input; at::TensorOptions m_options; std::unique_ptr m_decoder; - DecoderOptions m_decoder_options; + decode::DecoderOptions m_decoder_options; torch::nn::ModuleHolder m_module{nullptr}; // Performance monitoring stats. @@ -66,7 +67,7 @@ ModelRunner::ModelRunner(const CRFModelConfig &model_config, int chunk_size, int batch_size) : m_config(model_config) { - m_decoder_options = DecoderOptions(); + m_decoder_options = decode::DecoderOptions(); m_decoder_options.q_shift = model_config.qbias; m_decoder_options.q_scale = model_config.qscale; m_decoder = std::make_unique(); @@ -82,7 +83,7 @@ ModelRunner::ModelRunner(const CRFModelConfig &model_config, } template -std::vector ModelRunner::call_chunks(int num_chunks) { +std::vector ModelRunner::call_chunks(int num_chunks) { at::InferenceMode guard; dorado::stats::Timer timer; auto scores = m_module->forward(m_input.to(m_options.device_opt().value())); @@ -109,4 +110,4 @@ stats::NamedStats ModelRunner::sample_stats() const { return stats; } -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/basecall/crf_utils.cpp b/dorado/basecall/crf_utils.cpp new file mode 100644 index 00000000..47ed3fd9 --- /dev/null +++ b/dorado/basecall/crf_utils.cpp @@ -0,0 +1,76 @@ +#include "crf_utils.h" + +#include "CRFModelConfig.h" +#include "utils/memory_utils.h" +#include "utils/tensor_utils.h" + +#include +#include + +namespace dorado::basecall { +std::vector load_crf_model_weights(const std::filesystem::path &dir, + bool decomposition, + bool linear_layer_bias) { + auto tensors = std::vector{ + + "0.conv.weight.tensor", "0.conv.bias.tensor", + + "1.conv.weight.tensor", "1.conv.bias.tensor", + + "2.conv.weight.tensor", "2.conv.bias.tensor", + + "4.rnn.weight_ih_l0.tensor", "4.rnn.weight_hh_l0.tensor", + "4.rnn.bias_ih_l0.tensor", "4.rnn.bias_hh_l0.tensor", + + "5.rnn.weight_ih_l0.tensor", "5.rnn.weight_hh_l0.tensor", + "5.rnn.bias_ih_l0.tensor", "5.rnn.bias_hh_l0.tensor", + + "6.rnn.weight_ih_l0.tensor", "6.rnn.weight_hh_l0.tensor", + "6.rnn.bias_ih_l0.tensor", "6.rnn.bias_hh_l0.tensor", + + "7.rnn.weight_ih_l0.tensor", "7.rnn.weight_hh_l0.tensor", + "7.rnn.bias_ih_l0.tensor", "7.rnn.bias_hh_l0.tensor", + + "8.rnn.weight_ih_l0.tensor", "8.rnn.weight_hh_l0.tensor", + "8.rnn.bias_ih_l0.tensor", "8.rnn.bias_hh_l0.tensor", + + "9.linear.weight.tensor"}; + + if (linear_layer_bias) { + tensors.push_back("9.linear.bias.tensor"); + } + + if (decomposition) { + tensors.push_back("10.linear.weight.tensor"); + } + + return utils::load_tensors(dir, tensors); +} + +size_t auto_calculate_num_runners(const CRFModelConfig &model_config, + size_t batch_size, + float memory_fraction) { + auto model_name = std::filesystem::canonical(model_config.model_path).filename().string(); + + // very hand-wavy determination + // these numbers were determined empirically by running 1, 2, 4 and 8 runners for each model + auto required_ram_per_runner_GB = 0.f; + if (model_name.find("_fast@v") != std::string::npos) { + required_ram_per_runner_GB = 1.5; + } else if (model_name.find("_hac@v") != std::string::npos) { + required_ram_per_runner_GB = 4.5; + } else if (model_name.find("_sup@v") != std::string::npos) { + required_ram_per_runner_GB = 12.5; + } else { + return 1; + } + + // numbers were determined with a batch_size of 128, assume this just scales + required_ram_per_runner_GB *= batch_size / 128.f; + + auto free_ram_GB = utils::available_host_memory_GB() * memory_fraction; + auto num_runners = static_cast(free_ram_GB / required_ram_per_runner_GB); + return std::clamp(num_runners, size_t(1), std::size_t(std::thread::hardware_concurrency())); +} + +} // namespace dorado::basecall diff --git a/dorado/nn/CRFModel.h b/dorado/basecall/crf_utils.h similarity index 64% rename from dorado/nn/CRFModel.h rename to dorado/basecall/crf_utils.h index 03bddd75..57881699 100644 --- a/dorado/nn/CRFModel.h +++ b/dorado/basecall/crf_utils.h @@ -1,11 +1,11 @@ #pragma once -#include +#include #include #include -namespace dorado { +namespace dorado::basecall { struct CRFModelConfig; @@ -13,11 +13,8 @@ std::vector load_crf_model_weights(const std::filesystem::path& dir, bool decomposition, bool bias); -torch::nn::ModuleHolder load_crf_model(const CRFModelConfig& model_config, - const at::TensorOptions& options); - size_t auto_calculate_num_runners(const CRFModelConfig& model_config, size_t batch_size, float memory_fraction); -} // namespace dorado +} // namespace dorado::basecall diff --git a/dorado/decode/CPUDecoder.cpp b/dorado/basecall/decode/CPUDecoder.cpp similarity index 98% rename from dorado/decode/CPUDecoder.cpp rename to dorado/basecall/decode/CPUDecoder.cpp index f95d37c6..f405c003 100644 --- a/dorado/decode/CPUDecoder.cpp +++ b/dorado/basecall/decode/CPUDecoder.cpp @@ -86,7 +86,7 @@ at::Tensor backward_scores(const at::Tensor& scores, const float fixed_stay_scor } // namespace -namespace dorado { +namespace dorado::basecall::decode { std::vector CPUDecoder::beam_search(const at::Tensor& scores, const int num_chunks, @@ -144,4 +144,4 @@ std::vector CPUDecoder::beam_search(const at::Tensor& scores, return chunk_results; } -} // namespace dorado +} // namespace dorado::basecall::decode diff --git a/dorado/decode/CPUDecoder.h b/dorado/basecall/decode/CPUDecoder.h similarity index 83% rename from dorado/decode/CPUDecoder.h rename to dorado/basecall/decode/CPUDecoder.h index b40ad4f7..b2011d5c 100644 --- a/dorado/decode/CPUDecoder.h +++ b/dorado/basecall/decode/CPUDecoder.h @@ -4,7 +4,7 @@ #include -namespace dorado { +namespace dorado::basecall::decode { class CPUDecoder final : Decoder { public: @@ -14,4 +14,4 @@ class CPUDecoder final : Decoder { constexpr static at::ScalarType dtype = at::ScalarType::Float; }; -} // namespace dorado +} // namespace dorado::basecall::decode diff --git a/dorado/decode/Decoder.h b/dorado/basecall/decode/Decoder.h similarity index 89% rename from dorado/decode/Decoder.h rename to dorado/basecall/decode/Decoder.h index 3c354769..09ef76bf 100644 --- a/dorado/decode/Decoder.h +++ b/dorado/basecall/decode/Decoder.h @@ -5,7 +5,7 @@ #include #include -namespace dorado { +namespace dorado::basecall::decode { struct DecodedChunk { std::string sequence; @@ -30,4 +30,4 @@ class Decoder { const DecoderOptions& options) = 0; }; -} // namespace dorado +} // namespace dorado::basecall::decode diff --git a/dorado/decode/GPUDecoder.cpp b/dorado/basecall/decode/GPUDecoder.cpp similarity index 97% rename from dorado/decode/GPUDecoder.cpp rename to dorado/basecall/decode/GPUDecoder.cpp index 7718f3f0..7561371b 100644 --- a/dorado/decode/GPUDecoder.cpp +++ b/dorado/basecall/decode/GPUDecoder.cpp @@ -1,8 +1,8 @@ #include "GPUDecoder.h" -#include "../utils/cuda_utils.h" -#include "../utils/gpu_profiling.h" #include "Decoder.h" +#include "utils/cuda_utils.h" +#include "utils/gpu_profiling.h" #include #include @@ -11,7 +11,7 @@ extern "C" { #include "koi.h" } -namespace dorado { +namespace dorado::basecall::decode { at::Tensor GPUDecoder::gpu_part(at::Tensor scores, DecoderOptions options) { c10::cuda::CUDAGuard device_guard(scores.device()); @@ -113,4 +113,4 @@ std::vector GPUDecoder::beam_search(const at::Tensor &scores, return cpu_part(gpu_part(scores, options)); } -} // namespace dorado +} // namespace dorado::basecall::decode diff --git a/dorado/decode/GPUDecoder.h b/dorado/basecall/decode/GPUDecoder.h similarity index 91% rename from dorado/decode/GPUDecoder.h rename to dorado/basecall/decode/GPUDecoder.h index bf7f90bb..46ff4d28 100644 --- a/dorado/decode/GPUDecoder.h +++ b/dorado/basecall/decode/GPUDecoder.h @@ -4,7 +4,7 @@ #include -namespace dorado { +namespace dorado::basecall::decode { class GPUDecoder final : Decoder { public: @@ -24,4 +24,4 @@ class GPUDecoder final : Decoder { float m_score_clamp_val; }; -} // namespace dorado +} // namespace dorado::basecall::decode diff --git a/dorado/decode/beam_search.cpp b/dorado/basecall/decode/beam_search.cpp similarity index 99% rename from dorado/decode/beam_search.cpp rename to dorado/basecall/decode/beam_search.cpp index b7393c5f..9dad798b 100644 --- a/dorado/decode/beam_search.cpp +++ b/dorado/basecall/decode/beam_search.cpp @@ -117,6 +117,8 @@ uint32_t crc32c(uint32_t crc, uint32_t new_bits) { } // anonymous namespace +namespace dorado::basecall::decode { + template float beam_search(const T* const scores, size_t scores_block_stride, @@ -586,3 +588,5 @@ std::tuple> beam_search_decode( return {std::move(sequence), std::move(qstring), std::move(moves)}; } + +} // namespace dorado::basecall::decode diff --git a/dorado/decode/beam_search.h b/dorado/basecall/decode/beam_search.h similarity index 86% rename from dorado/decode/beam_search.h rename to dorado/basecall/decode/beam_search.h index 9b38d8f3..5561a280 100644 --- a/dorado/decode/beam_search.h +++ b/dorado/basecall/decode/beam_search.h @@ -8,6 +8,7 @@ #include #include +namespace dorado::basecall::decode { std::tuple> beam_search_decode( const at::Tensor& scores_t, const at::Tensor& back_guides_t, @@ -18,3 +19,4 @@ std::tuple> beam_search_decode( float q_shift, float q_scale, float byte_score_scale); +} // namespace dorado::basecall::decode diff --git a/dorado/nn/metal/nn.metal b/dorado/basecall/metal/nn.metal similarity index 100% rename from dorado/nn/metal/nn.metal rename to dorado/basecall/metal/nn.metal diff --git a/dorado/cli/basecaller.cpp b/dorado/cli/basecaller.cpp index 63a5634a..9edc07b3 100644 --- a/dorado/cli/basecaller.cpp +++ b/dorado/cli/basecaller.cpp @@ -1,18 +1,17 @@ #include "Version.h" +#include "api/pipeline_creation.h" +#include "api/runner_creation.h" +#include "basecall/CRFModelConfig.h" #include "cli/cli_utils.h" #include "data_loader/DataLoader.h" #include "data_loader/ModelFinder.h" #include "models/kits.h" #include "models/models.h" -#include "nn/CRFModelConfig.h" -#include "nn/ModBaseRunner.h" -#include "nn/Runners.h" #include "read_pipeline/AdapterDetectorNode.h" #include "read_pipeline/AlignerNode.h" #include "read_pipeline/BarcodeClassifierNode.h" #include "read_pipeline/HtsReader.h" #include "read_pipeline/HtsWriter.h" -#include "read_pipeline/Pipelines.h" #include "read_pipeline/PolyACalculator.h" #include "read_pipeline/ProgressTracker.h" #include "read_pipeline/ReadFilterNode.h" @@ -86,7 +85,7 @@ void setup(std::vector args, argparse::ArgumentParser& resume_parser, bool estimate_poly_a, const ModelSelection& model_selection) { - const auto model_config = load_crf_model_config(model_path); + const auto model_config = basecall::load_crf_model_config(model_path); const std::string model_name = models::extract_model_name_from_path(model_path); const std::string modbase_model_names = models::extract_model_names_from_paths(remora_models); diff --git a/dorado/cli/duplex.cpp b/dorado/cli/duplex.cpp index ffd14ad4..241bd28a 100644 --- a/dorado/cli/duplex.cpp +++ b/dorado/cli/duplex.cpp @@ -1,18 +1,17 @@ #include "Version.h" +#include "api/pipeline_creation.h" +#include "api/runner_creation.h" +#include "basecall/CRFModelConfig.h" #include "cli/cli_utils.h" #include "data_loader/DataLoader.h" #include "data_loader/ModelFinder.h" #include "models/kits.h" #include "models/metadata.h" #include "models/models.h" -#include "nn/CRFModelConfig.h" -#include "nn/ModBaseRunner.h" -#include "nn/Runners.h" #include "read_pipeline/AlignerNode.h" #include "read_pipeline/BaseSpaceDuplexCallerNode.h" #include "read_pipeline/DuplexReadTaggingNode.h" #include "read_pipeline/HtsWriter.h" -#include "read_pipeline/Pipelines.h" #include "read_pipeline/ProgressTracker.h" #include "read_pipeline/ReadFilterNode.h" #include "read_pipeline/ReadToBamTypeNode.h" @@ -48,10 +47,10 @@ namespace { struct DuplexModels { std::filesystem::path model_path; std::string model_name; - CRFModelConfig model_config; + basecall::CRFModelConfig model_config; std::filesystem::path stereo_model; - CRFModelConfig stereo_model_config; + basecall::CRFModelConfig stereo_model_config; std::string stereo_model_name; std::vector mods_model_paths; @@ -122,7 +121,7 @@ DuplexModels load_models(const std::string& model_arg, } if (!skip_model_compatibility_check) { - const auto model_config = load_crf_model_config(model_path); + const auto model_config = basecall::load_crf_model_config(model_path); const auto model_name = model_path.filename().string(); check_sampling_rates_compatible(model_name, reads, model_config.sample_rate, recursive_file_loading); @@ -153,10 +152,10 @@ DuplexModels load_models(const std::string& model_arg, } const auto model_name = model_finder.get_simplex_model_name(); - const auto model_config = load_crf_model_config(model_path); + const auto model_config = basecall::load_crf_model_config(model_path); const auto stereo_model_name = stereo_model_path.filename().string(); - const auto stereo_model_config = load_crf_model_config(stereo_model_path); + const auto stereo_model_config = basecall::load_crf_model_config(stereo_model_path); return DuplexModels{model_path, model_name, model_config, stereo_model_path, @@ -488,7 +487,7 @@ int duplex(int argc, char* argv[]) { create_basecall_runners(models.model_config, device, num_runners, 0, batch_size, chunk_size, 0.9f, true); - std::vector stereo_runners; + std::vector stereo_runners; // The fraction argument for GPU memory allocates the fraction of the // _remaining_ memory to the caller. So, we allocate all of the available // memory after simplex caller has been instantiated to the duplex caller. diff --git a/dorado/modbase/CMakeLists.txt b/dorado/modbase/CMakeLists.txt new file mode 100644 index 00000000..27a4e9a9 --- /dev/null +++ b/dorado/modbase/CMakeLists.txt @@ -0,0 +1,42 @@ +add_library(dorado_modbase STATIC + ModBaseContext.cpp + ModBaseContext.h + ModbaseEncoder.cpp + ModbaseEncoder.h + ModBaseModel.cpp + ModBaseModel.h + ModBaseModelConfig.cpp + ModBaseModelConfig.h + ModBaseRunner.cpp + ModBaseRunner.h + ModbaseScaler.cpp + ModbaseScaler.h + MotifMatcher.cpp + MotifMatcher.h +) + +target_include_directories(dorado_modbase + SYSTEM + PUBLIC + PRIVATE + ${DORADO_3RD_PARTY_SOURCE}/NVTX/c/include + ${DORADO_3RD_PARTY_SOURCE}/toml11 +) + +target_link_libraries(dorado_modbase + PUBLIC + ${TORCH_LIBRARIES} + PRIVATE + dorado_utils + spdlog::spdlog +) + +enable_warnings_as_errors(dorado_modbase) + +if (DORADO_ENABLE_PCH) + target_precompile_headers(dorado_modbase REUSE_FROM dorado_utils) + # these are defined publicly by minimap2, which we're not using. + # we need to define them here so that the environment matches the one + # used when building the PCH + target_compile_definitions(dorado_modbase PRIVATE PTW32_CLEANUP_C PTW32_STATIC_LIB) +endif() diff --git a/dorado/modbase/ModBaseContext.cpp b/dorado/modbase/ModBaseContext.cpp index b7845be5..c38a579b 100644 --- a/dorado/modbase/ModBaseContext.cpp +++ b/dorado/modbase/ModBaseContext.cpp @@ -5,14 +5,16 @@ #include -namespace dorado::utils { +namespace dorado::modbase { ModBaseContext::ModBaseContext() {} ModBaseContext::~ModBaseContext() {} -const std::string& ModBaseContext::motif(char base) const { return m_motifs[base_to_int(base)]; } +const std::string& ModBaseContext::motif(char base) const { + return m_motifs[utils::base_to_int(base)]; +} -size_t ModBaseContext::motif_offset(char base) const { return m_offsets[base_to_int(base)]; } +size_t ModBaseContext::motif_offset(char base) const { return m_offsets[utils::base_to_int(base)]; } void ModBaseContext::set_context(std::string motif, size_t offset) { if (motif.size() < 2) { @@ -20,7 +22,7 @@ void ModBaseContext::set_context(std::string motif, size_t offset) { return; } char base = motif.at(offset); - auto index = base_to_int(base); + auto index = utils::base_to_int(base); m_motif_matchers[index] = std::make_unique(motif, offset); m_motifs[index] = std::move(motif); m_offsets[index] = offset; @@ -99,7 +101,7 @@ void ModBaseContext::update_mask(std::vector& mask, // A cardinal base. current_cardinal = modbase_alphabet[channel_idx][0]; } else { - if (!m_motifs[base_to_int(current_cardinal)].empty()) { + if (!m_motifs[utils::base_to_int(current_cardinal)].empty()) { // This cardinal base has a context associated with modifications, so the mask should // not be updated, regardless of the threshold. continue; @@ -115,4 +117,4 @@ void ModBaseContext::update_mask(std::vector& mask, } } -} // namespace dorado::utils \ No newline at end of file +} // namespace dorado::modbase \ No newline at end of file diff --git a/dorado/modbase/ModBaseContext.h b/dorado/modbase/ModBaseContext.h index c806091f..75e06166 100644 --- a/dorado/modbase/ModBaseContext.h +++ b/dorado/modbase/ModBaseContext.h @@ -24,12 +24,9 @@ * be "CXT:XG:_:_". */ -namespace dorado { +namespace dorado::modbase { struct ModBaseModelConfig; class MotifMatcher; -} // namespace dorado - -namespace dorado::utils { class ModBaseContext { public: @@ -97,4 +94,4 @@ class ModBaseContext { std::array, 4> m_motif_matchers; }; -} // namespace dorado::utils +} // namespace dorado::modbase diff --git a/dorado/nn/ModBaseModel.cpp b/dorado/modbase/ModBaseModel.cpp similarity index 99% rename from dorado/nn/ModBaseModel.cpp rename to dorado/modbase/ModBaseModel.cpp index de1a03be..1ccae387 100644 --- a/dorado/nn/ModBaseModel.cpp +++ b/dorado/modbase/ModBaseModel.cpp @@ -28,7 +28,7 @@ ModuleHolder populate_model(Model&& model, } } // namespace -namespace dorado { +namespace dorado::modbase { namespace nn { @@ -257,4 +257,4 @@ ModuleHolder load_modbase_model(const std::filesystem::path& model_pa throw std::runtime_error("Unknown model type in config file."); } -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/nn/ModBaseModel.h b/dorado/modbase/ModBaseModel.h similarity index 78% rename from dorado/nn/ModBaseModel.h rename to dorado/modbase/ModBaseModel.h index b1ce4f80..a8c6111f 100644 --- a/dorado/nn/ModBaseModel.h +++ b/dorado/modbase/ModBaseModel.h @@ -4,10 +4,10 @@ #include -namespace dorado { +namespace dorado::modbase { torch::nn::ModuleHolder load_modbase_model( const std::filesystem::path& model_path, at::TensorOptions options); -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/nn/ModBaseModelConfig.cpp b/dorado/modbase/ModBaseModelConfig.cpp similarity index 98% rename from dorado/nn/ModBaseModelConfig.cpp rename to dorado/modbase/ModBaseModelConfig.cpp index 711bae41..1111c7fd 100644 --- a/dorado/nn/ModBaseModelConfig.cpp +++ b/dorado/modbase/ModBaseModelConfig.cpp @@ -6,7 +6,7 @@ #include -namespace dorado { +namespace dorado::modbase { ModBaseModelConfig load_modbase_model_config(const std::filesystem::path& model_path) { ModBaseModelConfig config; @@ -130,4 +130,4 @@ ModBaseInfo get_modbase_info( return result; } -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/nn/ModBaseModelConfig.h b/dorado/modbase/ModBaseModelConfig.h similarity index 96% rename from dorado/nn/ModBaseModelConfig.h rename to dorado/modbase/ModBaseModelConfig.h index cd7d9a22..84fabbdc 100644 --- a/dorado/nn/ModBaseModelConfig.h +++ b/dorado/modbase/ModBaseModelConfig.h @@ -6,7 +6,7 @@ #include #include -namespace dorado { +namespace dorado::modbase { struct ModBaseModelConfig { std::vector mod_long_names; ///< The long names of the modified bases. @@ -32,4 +32,4 @@ ModBaseModelConfig load_modbase_model_config(const std::filesystem::path & model ModBaseInfo get_modbase_info( const std::vector> & base_mod_params); -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/nn/ModBaseRunner.cpp b/dorado/modbase/ModBaseRunner.cpp similarity index 99% rename from dorado/nn/ModBaseRunner.cpp rename to dorado/modbase/ModBaseRunner.cpp index d5faa603..21307d55 100644 --- a/dorado/nn/ModBaseRunner.cpp +++ b/dorado/modbase/ModBaseRunner.cpp @@ -2,8 +2,8 @@ #include "ModBaseModel.h" #include "ModBaseModelConfig.h" -#include "modbase/MotifMatcher.h" -#include "modbase/modbase_scaler.h" +#include "ModbaseScaler.h" +#include "MotifMatcher.h" #include "utils/sequence_utils.h" #include "utils/stats.h" #include "utils/tensor_utils.h" @@ -22,7 +22,7 @@ using namespace std::chrono_literals; -namespace dorado { +namespace dorado::modbase { class ModBaseCaller { public: @@ -363,4 +363,4 @@ stats::NamedStats ModBaseRunner::sample_stats() const { return stats; } -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/nn/ModBaseRunner.h b/dorado/modbase/ModBaseRunner.h similarity index 93% rename from dorado/nn/ModBaseRunner.h rename to dorado/modbase/ModBaseRunner.h index db84bb78..63ee9466 100644 --- a/dorado/nn/ModBaseRunner.h +++ b/dorado/modbase/ModBaseRunner.h @@ -11,7 +11,7 @@ #include #include -namespace dorado { +namespace dorado::modbase { struct ModBaseModelConfig; class ModBaseCaller; @@ -52,4 +52,6 @@ class ModBaseRunner { std::atomic m_num_batches_called = 0; }; -} // namespace dorado +using RunnerPtr = std::unique_ptr; + +} // namespace dorado::modbase diff --git a/dorado/modbase/modbase_encoder.cpp b/dorado/modbase/ModbaseEncoder.cpp similarity index 99% rename from dorado/modbase/modbase_encoder.cpp rename to dorado/modbase/ModbaseEncoder.cpp index 5a14f916..8b2728c0 100644 --- a/dorado/modbase/modbase_encoder.cpp +++ b/dorado/modbase/ModbaseEncoder.cpp @@ -1,4 +1,4 @@ -#include "modbase_encoder.h" +#include "ModbaseEncoder.h" #include "utils/sequence_utils.h" #include "utils/simd.h" @@ -9,7 +9,7 @@ #include #include -namespace dorado { +namespace dorado::modbase { ModBaseEncoder::ModBaseEncoder(size_t block_stride, size_t context_samples, @@ -275,4 +275,4 @@ std::vector ModBaseEncoder::encode_kmer(const std::vector& seq, m_kmer_len); } -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/modbase/modbase_encoder.h b/dorado/modbase/ModbaseEncoder.h similarity index 97% rename from dorado/modbase/modbase_encoder.h rename to dorado/modbase/ModbaseEncoder.h index 3360aab3..6d5e7dc7 100644 --- a/dorado/modbase/modbase_encoder.h +++ b/dorado/modbase/ModbaseEncoder.h @@ -4,7 +4,7 @@ #include #include -namespace dorado { +namespace dorado::modbase { class ModBaseEncoder { private: @@ -61,4 +61,4 @@ class ModBaseEncoder { Context get_context(size_t seq_pos) const; }; -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/modbase/modbase_scaler.cpp b/dorado/modbase/ModbaseScaler.cpp similarity index 97% rename from dorado/modbase/modbase_scaler.cpp rename to dorado/modbase/ModbaseScaler.cpp index 1b5a0d18..c8c3971a 100644 --- a/dorado/modbase/modbase_scaler.cpp +++ b/dorado/modbase/ModbaseScaler.cpp @@ -1,4 +1,4 @@ -#include "modbase_scaler.h" +#include "ModbaseScaler.h" #include "utils/math_utils.h" @@ -8,7 +8,7 @@ #include #include -namespace dorado { +namespace dorado::modbase { ModBaseScaler::ModBaseScaler(const std::vector& kmer_levels, size_t kmer_len, @@ -98,4 +98,4 @@ std::pair ModBaseScaler::calc_offset_scale( return {new_offset, new_scale}; } -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/modbase/modbase_scaler.h b/dorado/modbase/ModbaseScaler.h similarity index 97% rename from dorado/modbase/modbase_scaler.h rename to dorado/modbase/ModbaseScaler.h index eb3db861..2c8e1304 100644 --- a/dorado/modbase/modbase_scaler.h +++ b/dorado/modbase/ModbaseScaler.h @@ -7,7 +7,7 @@ #include #include -namespace dorado { +namespace dorado::modbase { /// Calculates new scaling values for improved modified base detection class ModBaseScaler { @@ -64,4 +64,4 @@ class ModBaseScaler { ModBaseScaler(const std::vector& kmer_levels, size_t kmer_len, size_t centre_index); }; -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/modbase/MotifMatcher.cpp b/dorado/modbase/MotifMatcher.cpp index 0911c802..ec601293 100644 --- a/dorado/modbase/MotifMatcher.cpp +++ b/dorado/modbase/MotifMatcher.cpp @@ -1,6 +1,6 @@ #include "MotifMatcher.h" -#include "nn/ModBaseModelConfig.h" +#include "ModBaseModelConfig.h" #include @@ -41,7 +41,7 @@ std::string expand_motif_regex(const std::string& motif) { } // namespace -namespace dorado { +namespace dorado::modbase { MotifMatcher::MotifMatcher(const ModBaseModelConfig& model_config) : MotifMatcher(model_config.motif, model_config.motif_offset) {} @@ -68,4 +68,4 @@ std::vector MotifMatcher::get_motif_hits(std::string_view seq) const { return context_hits; } -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/modbase/MotifMatcher.h b/dorado/modbase/MotifMatcher.h index f203f3cf..0501a34b 100644 --- a/dorado/modbase/MotifMatcher.h +++ b/dorado/modbase/MotifMatcher.h @@ -4,7 +4,7 @@ #include #include -namespace dorado { +namespace dorado::modbase { struct ModBaseModelConfig; class MotifMatcher { @@ -19,4 +19,4 @@ class MotifMatcher { const size_t m_motif_offset; }; -} // namespace dorado +} // namespace dorado::modbase diff --git a/dorado/read_pipeline/BasecallerNode.cpp b/dorado/read_pipeline/BasecallerNode.cpp index e29bcaed..c39882f6 100644 --- a/dorado/read_pipeline/BasecallerNode.cpp +++ b/dorado/read_pipeline/BasecallerNode.cpp @@ -1,6 +1,7 @@ #include "BasecallerNode.h" -#include "decode/CPUDecoder.h" +#include "basecall/ModelRunner.h" +#include "basecall/decode/CPUDecoder.h" #include "stitch.h" #include "utils/stats.h" @@ -116,7 +117,7 @@ void BasecallerNode::input_worker_thread() { void BasecallerNode::basecall_current_batch(int worker_id) { NVTX3_FUNC_RANGE(); - auto model_runner = m_model_runners[worker_id]; + auto &model_runner = m_model_runners[worker_id]; dorado::stats::Timer timer; auto decode_results = model_runner->call_chunks(int(m_batched_chunks[worker_id].size())); m_call_chunks_ms += timer.GetElapsedMS(); @@ -288,7 +289,7 @@ void BasecallerNode::basecall_worker_thread(int worker_id) { namespace { // Calculates the input queue size. -size_t CalcMaxChunksIn(const std::vector &model_runners) { +size_t CalcMaxChunksIn(const std::vector &model_runners) { // Allow 2 batches per model runner on the chunks_in queue size_t max_chunks_in = 0; // Allows optimal batch size to be used for every GPU @@ -300,7 +301,7 @@ size_t CalcMaxChunksIn(const std::vector &model_runners) { } // namespace -BasecallerNode::BasecallerNode(std::vector model_runners, +BasecallerNode::BasecallerNode(std::vector model_runners, size_t overlap, int batch_timeout_ms, std::string model_name, diff --git a/dorado/read_pipeline/BasecallerNode.h b/dorado/read_pipeline/BasecallerNode.h index 9a3236e5..5191c2cb 100644 --- a/dorado/read_pipeline/BasecallerNode.h +++ b/dorado/read_pipeline/BasecallerNode.h @@ -1,6 +1,5 @@ #pragma once -#include "../nn/ModelRunner.h" #include "ReadPipeline.h" #include "utils/AsyncQueue.h" #include "utils/stats.h" @@ -14,13 +13,18 @@ namespace dorado { +namespace basecall { +class ModelRunnerBase; +using RunnerPtr = std::unique_ptr; +} // namespace basecall + class BasecallerNode : public MessageSink { struct BasecallingRead; struct BasecallingChunk; public: // Chunk size and overlap are in raw samples - BasecallerNode(std::vector model_runners, + BasecallerNode(std::vector model_runners, size_t overlap, int batch_timeout_ms, std::string model_name, @@ -46,7 +50,7 @@ class BasecallerNode : public MessageSink { void working_reads_manager(); // Vector of model runners (each with their own GPU access etc) - std::vector m_model_runners; + std::vector m_model_runners; // Chunk length size_t m_chunk_size; // Minimum overlap between two adjacent chunks in a read. Overlap is used to reduce edge effects and improve accuracy. diff --git a/dorado/read_pipeline/ModBaseCallerNode.cpp b/dorado/read_pipeline/ModBaseCallerNode.cpp index 974d340f..ae56520f 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.cpp +++ b/dorado/read_pipeline/ModBaseCallerNode.cpp @@ -1,9 +1,9 @@ #include "ModBaseCallerNode.h" #include "modbase/ModBaseContext.h" -#include "modbase/modbase_encoder.h" -#include "nn/ModBaseModelConfig.h" -#include "nn/ModBaseRunner.h" +#include "modbase/ModBaseModelConfig.h" +#include "modbase/ModBaseRunner.h" +#include "modbase/ModbaseEncoder.h" #include "utils/math_utils.h" #include "utils/sequence_utils.h" #include "utils/stats.h" @@ -50,7 +50,7 @@ struct ModBaseCallerNode::WorkingRead { num_modbase_chunks_called; // Number of modbase chunks which have been scored }; -ModBaseCallerNode::ModBaseCallerNode(std::vector> model_runners, +ModBaseCallerNode::ModBaseCallerNode(std::vector model_runners, size_t remora_threads, size_t block_stride, size_t max_reads) @@ -125,9 +125,9 @@ void ModBaseCallerNode::restart() { } void ModBaseCallerNode::init_modbase_info() { - std::vector> base_mod_params; + std::vector> base_mod_params; auto& runner = m_runners[0]; - utils::ModBaseContext context_handler; + modbase::ModBaseContext context_handler; for (size_t caller_id = 0; caller_id < runner->num_callers(); ++caller_id) { const auto& params = runner->caller_params(caller_id); if (!params.motif.empty()) { @@ -137,7 +137,7 @@ void ModBaseCallerNode::init_modbase_info() { m_num_states += params.base_mod_count; } - auto result = get_modbase_info(base_mod_params); + auto result = modbase::get_modbase_info(base_mod_params); m_mod_base_info = std::make_shared( std::move(result.alphabet), std::move(result.long_names), context_handler.encode()); @@ -229,8 +229,8 @@ void ModBaseCallerNode::duplex_mod_call(Message&& message) { auto context_samples = (params.context_before + params.context_after); // One-hot encodes the kmer at each signal step for input into the network - ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, - params.bases_after); + modbase::ModBaseEncoder encoder(m_block_stride, context_samples, + params.bases_before, params.bases_after); encoder.init(sequence_ints, seq_to_sig_map); auto context_hits = runner->get_motif_hits(caller_id, new_seq); @@ -355,8 +355,8 @@ void ModBaseCallerNode::simplex_mod_call(Message&& message) { auto context_samples = (params.context_before + params.context_after); // One-hot encodes the kmer at each signal step for input into the network - ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, - params.bases_after); + modbase::ModBaseEncoder encoder(m_block_stride, context_samples, params.bases_before, + params.bases_after); encoder.init(sequence_ints, seq_to_sig_map); auto context_hits = runner->get_motif_hits(caller_id, read->read_common.seq); diff --git a/dorado/read_pipeline/ModBaseCallerNode.h b/dorado/read_pipeline/ModBaseCallerNode.h index a97d032f..b3dfe9e4 100644 --- a/dorado/read_pipeline/ModBaseCallerNode.h +++ b/dorado/read_pipeline/ModBaseCallerNode.h @@ -16,14 +16,17 @@ namespace dorado { +namespace modbase { class ModBaseRunner; +using RunnerPtr = std::unique_ptr; +} // namespace modbase class ModBaseCallerNode : public MessageSink { struct RemoraChunk; struct WorkingRead; public: - ModBaseCallerNode(std::vector> model_runners, + ModBaseCallerNode(std::vector model_runners, size_t remora_threads, size_t block_stride, size_t max_reads); @@ -56,7 +59,7 @@ class ModBaseCallerNode : public MessageSink { // Worker thread, processes chunk results back into the reads void output_worker_thread(); - std::vector> m_runners; + std::vector m_runners; size_t m_num_input_workers = 0; size_t m_block_stride; size_t m_batch_size; diff --git a/dorado/read_pipeline/ReadPipeline.cpp b/dorado/read_pipeline/ReadPipeline.cpp index 42b7f668..d79371a0 100644 --- a/dorado/read_pipeline/ReadPipeline.cpp +++ b/dorado/read_pipeline/ReadPipeline.cpp @@ -158,7 +158,7 @@ void ReadCommon::generate_modbase_tags(bam1_t *aln, uint8_t threshold) const { // Create a mask indicating which bases are modified. std::unordered_map base_has_context = { {'A', false}, {'C', false}, {'G', false}, {'T', false}}; - utils::ModBaseContext context_handler; + modbase::ModBaseContext context_handler; if (!mod_base_info->context.empty()) { if (!context_handler.decode(mod_base_info->context)) { throw std::runtime_error("Invalid base modification context string."); diff --git a/dorado/read_pipeline/ScalerNode.cpp b/dorado/read_pipeline/ScalerNode.cpp index 6fa77e79..e42eed82 100644 --- a/dorado/read_pipeline/ScalerNode.cpp +++ b/dorado/read_pipeline/ScalerNode.cpp @@ -1,6 +1,6 @@ #include "ScalerNode.h" -#include "nn/CRFModelConfig.h" +#include "basecall/CRFModelConfig.h" #include "utils/tensor_utils.h" #include "utils/trim.h" @@ -20,6 +20,9 @@ using namespace std::chrono_literals; using Slice = at::indexing::Slice; namespace dorado { +using SampleType = basecall::SampleType; +using ScalingStrategy = basecall::ScalingStrategy; +using SignalNormalisationParams = basecall::SignalNormalisationParams; std::pair ScalerNode::normalisation(const at::Tensor& x) { // Calculate shift and scale factors for normalisation. diff --git a/dorado/read_pipeline/ScalerNode.h b/dorado/read_pipeline/ScalerNode.h index 64f65c11..e4cea5ef 100644 --- a/dorado/read_pipeline/ScalerNode.h +++ b/dorado/read_pipeline/ScalerNode.h @@ -1,6 +1,6 @@ #pragma once #include "ReadPipeline.h" -#include "nn/CRFModelConfig.h" +#include "basecall/CRFModelConfig.h" #include "utils/stats.h" #include @@ -15,8 +15,8 @@ namespace dorado { class ScalerNode : public MessageSink { public: - ScalerNode(const SignalNormalisationParams& config, - SampleType model_type, + ScalerNode(const basecall::SignalNormalisationParams& config, + basecall::SampleType model_type, int num_worker_threads, size_t max_reads); ~ScalerNode() { terminate_impl(); } @@ -32,8 +32,8 @@ class ScalerNode : public MessageSink { std::vector> m_worker_threads; std::atomic m_num_worker_threads; - SignalNormalisationParams m_scaling_params; - const SampleType m_model_type; + basecall::SignalNormalisationParams m_scaling_params; + const basecall::SampleType m_model_type; std::pair med_mad(const at::Tensor& x); std::pair normalisation(const at::Tensor& x); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8d34f125..ec727dc0 100755 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -75,6 +75,8 @@ foreach(TEST_BIN dorado_tests dorado_smoke_tests) dorado_lib dorado_io_lib dorado_models_lib + dorado_basecall + dorado_modbase minimap2 ${ZLIB_LIBRARIES} ${POD5_LIBRARIES} diff --git a/tests/CRFModelConfigTest.cpp b/tests/CRFModelConfigTest.cpp index e6dc0342..1333f226 100644 --- a/tests/CRFModelConfigTest.cpp +++ b/tests/CRFModelConfigTest.cpp @@ -1,4 +1,4 @@ -#include "nn/CRFModelConfig.h" +#include "basecall/CRFModelConfig.h" #include "TestUtils.h" @@ -12,7 +12,7 @@ #define CUT_TAG "[CRFModelConfig]" -using namespace dorado; +using namespace dorado::basecall; namespace fs = std::filesystem; TEST_CASE(CUT_TAG ": test dna_r9.4.1 hac@v3.3 model load", CUT_TAG) { diff --git a/tests/ModBaseEncoderTest.cpp b/tests/ModBaseEncoderTest.cpp index e29754e2..feace473 100644 --- a/tests/ModBaseEncoderTest.cpp +++ b/tests/ModBaseEncoderTest.cpp @@ -1,4 +1,5 @@ -#include "modbase/modbase_encoder.h" +#include "modbase/ModbaseEncoder.h" + #include "utils/sequence_utils.h" #include @@ -15,7 +16,7 @@ TEST_CASE("Encode sequence for modified basecalling", TEST_GROUP) { auto seq_to_sig_map = dorado::utils::moves_to_map(moves, BLOCK_STRIDE, moves.size() * BLOCK_STRIDE, std::nullopt); - dorado::ModBaseEncoder encoder(BLOCK_STRIDE, SLICE_BLOCKS * BLOCK_STRIDE, 1, 1); + dorado::modbase::ModBaseEncoder encoder(BLOCK_STRIDE, SLICE_BLOCKS * BLOCK_STRIDE, 1, 1); encoder.init(seq_ints, seq_to_sig_map); auto slice0 = encoder.get_context(0); // The T in the NTA 3mer. diff --git a/tests/MotifMatcherTest.cpp b/tests/MotifMatcherTest.cpp index 77e380dc..ad9ed2bc 100644 --- a/tests/MotifMatcherTest.cpp +++ b/tests/MotifMatcherTest.cpp @@ -1,6 +1,6 @@ #include "modbase/MotifMatcher.h" -#include "nn/ModBaseModelConfig.h" +#include "modbase/ModBaseModelConfig.h" #include @@ -24,7 +24,7 @@ const std::string SEQ = "AACCGGTTACGTGGACTGACACTAAA"; } // namespace TEST_CASE(TEST_GROUP ": test motifs", TEST_GROUP) { - dorado::ModBaseModelConfig config; + dorado::modbase::ModBaseModelConfig config; auto [motif, motif_offset, expected_results] = GENERATE(table>({ // clang-format off @@ -42,7 +42,7 @@ TEST_CASE(TEST_GROUP ": test motifs", TEST_GROUP) { config.motif = motif; config.motif_offset = motif_offset; - dorado::MotifMatcher matcher(config); + dorado::modbase::MotifMatcher matcher(config); auto hits = matcher.get_motif_hits(SEQ); CHECK(hits == expected_results); } diff --git a/tests/NodeSmokeTest.cpp b/tests/NodeSmokeTest.cpp index 8f2cca06..ef86c218 100644 --- a/tests/NodeSmokeTest.cpp +++ b/tests/NodeSmokeTest.cpp @@ -1,12 +1,12 @@ #include "MessageSinkUtils.h" #include "TestUtils.h" -#include "decode/CPUDecoder.h" +#include "basecall/CRFModel.h" +#include "basecall/CRFModelConfig.h" +#include "basecall/ModelRunner.h" +#include "basecall/decode/CPUDecoder.h" +#include "modbase/ModBaseModel.h" +#include "modbase/ModBaseRunner.h" #include "models/models.h" -#include "nn/CRFModel.h" -#include "nn/CRFModelConfig.h" -#include "nn/ModBaseModel.h" -#include "nn/ModBaseRunner.h" -#include "nn/ModelRunner.h" #include "read_pipeline/AdapterDetectorNode.h" #include "read_pipeline/BarcodeClassifierNode.h" #include "read_pipeline/BasecallerNode.h" @@ -21,9 +21,9 @@ #if DORADO_GPU_BUILD #ifdef __APPLE__ -#include "nn/MetalCRFModel.h" +#include "basecall/MetalCRFModel.h" #else -#include "nn/CudaCRFModel.h" +#include "basecall/CudaCRFModel.h" #include "utils/cuda_utils.h" #endif #endif // DORADO_GPU_BUILD @@ -35,6 +35,10 @@ #include #include +#ifndef _WIN32 +#include +#endif + namespace fs = std::filesystem; namespace { @@ -165,8 +169,9 @@ TempDir download_model(const std::string& model) { DEFINE_TEST(NodeSmokeTestRead, "ScalerNode") { auto pipeline_restart = GENERATE(false, true); - auto model_type = GENERATE(dorado::SampleType::DNA, dorado::SampleType::RNA002, - dorado::SampleType::RNA004); + auto model_type = + GENERATE(dorado::basecall::SampleType::DNA, dorado::basecall::SampleType::RNA002, + dorado::basecall::SampleType::RNA004); CAPTURE(pipeline_restart); CAPTURE(model_type); @@ -177,8 +182,8 @@ DEFINE_TEST(NodeSmokeTestRead, "ScalerNode") { read->read_common.raw_data = read->read_common.raw_data.to(torch::kI16); }); - dorado::SignalNormalisationParams config; - config.strategy = dorado::ScalingStrategy::QUANTILE; + dorado::basecall::SignalNormalisationParams config; + config.strategy = dorado::basecall::ScalingStrategy::QUANTILE; config.quantile.quantile_a = 0.2f; config.quantile.quantile_b = 0.9f; config.quantile.shift_multiplier = 0.51f; @@ -202,20 +207,20 @@ DEFINE_TEST(NodeSmokeTestRead, "BasecallerNode") { const auto& default_params = dorado::utils::default_parameters; const auto model_dir = download_model(model_name); const auto model_path = (model_dir.m_path / model_name).string(); - auto model_config = dorado::load_crf_model_config(model_path); + auto model_config = dorado::basecall::load_crf_model_config(model_path); // Use a fixed batch size otherwise we slow down CI autobatchsizing. std::size_t batch_size = 128; // Create runners - std::vector runners; + std::vector runners; if (gpu) { #if DORADO_GPU_BUILD #ifdef __APPLE__ - auto caller = - dorado::create_metal_caller(model_config, default_params.chunksize, batch_size); + auto caller = dorado::basecall::create_metal_caller(model_config, default_params.chunksize, + batch_size); for (int i = 0; i < default_params.num_runners; i++) { - runners.push_back(std::make_shared(caller)); + runners.push_back(std::make_unique(caller)); } #else // __APPLE__ auto devices = dorado::utils::parse_cuda_device_string("cuda:all"); @@ -223,10 +228,10 @@ DEFINE_TEST(NodeSmokeTestRead, "BasecallerNode") { SKIP("No CUDA devices found"); } for (const auto& device : devices) { - auto caller = dorado::create_cuda_caller(model_config, default_params.chunksize, - int(batch_size), device, 1.f, true); + auto caller = dorado::basecall::create_cuda_caller( + model_config, default_params.chunksize, int(batch_size), device, 1.f, true); for (int i = 0; i < default_params.num_runners; i++) { - runners.push_back(std::make_shared(caller)); + runners.push_back(std::make_unique(caller)); } } #endif // __APPLE__ @@ -238,7 +243,8 @@ DEFINE_TEST(NodeSmokeTestRead, "BasecallerNode") { set_num_reads(5); set_expected_messages(5); batch_size = 8; - runners.push_back(std::make_shared>( + runners.push_back(std::make_unique< + dorado::basecall::ModelRunner>( model_config, "cpu", default_params.chunksize, int(batch_size))); } @@ -272,10 +278,10 @@ DEFINE_TEST(NodeSmokeTestRead, "ModBaseCallerNode") { const char model_name[] = "dna_r10.4.1_e8.2_400bps_fast@v4.2.0"; const auto model_dir = download_model(model_name); const std::size_t model_stride = - dorado::load_crf_model_config(model_dir.m_path / model_name).stride; + dorado::basecall::load_crf_model_config(model_dir.m_path / model_name).stride; // Create runners - std::vector> remora_runners; + std::vector remora_runners; std::vector modbase_devices; int batch_size = default_params.remora_batchsize; if (gpu) { @@ -299,10 +305,10 @@ DEFINE_TEST(NodeSmokeTestRead, "ModBaseCallerNode") { batch_size = 8; // reduce batch size so we're not doing work on empty entries } for (const auto& device_string : modbase_devices) { - auto caller = dorado::create_modbase_caller({remora_model, remora_model_6mA}, batch_size, - device_string); + auto caller = dorado::modbase::create_modbase_caller({remora_model, remora_model_6mA}, + batch_size, device_string); for (int i = 0; i < default_params.mod_base_runners_per_caller; i++) { - remora_runners.push_back(std::make_unique(caller)); + remora_runners.push_back(std::make_unique(caller)); } }