Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-477-basecall-refactor' into 'master'
Browse files Browse the repository at this point in the history
DOR-477: Basecall refactor

Closes DOR-477

See merge request machine-learning/dorado!766
  • Loading branch information
malton-ont committed Dec 14, 2023
2 parents 6ed81c5 + d109ab2 commit aae47b1
Show file tree
Hide file tree
Showing 55 changed files with 516 additions and 391 deletions.
53 changes: 8 additions & 45 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ enable_testing()

add_subdirectory(dorado/utils)
add_subdirectory(dorado/models)
add_subdirectory(dorado/basecall)
add_subdirectory(dorado/modbase)

if("${CUDA_VERSION}" STREQUAL "")
set(CUDA_VERSION ${CUDAToolkit_VERSION})
Expand All @@ -179,21 +181,10 @@ set(LIB_SOURCE_FILES
dorado/alignment/Minimap2Index.h
dorado/alignment/Minimap2IndexSupportTypes.h
dorado/alignment/Minimap2Options.h
dorado/nn/CRFModel.h
dorado/nn/CRFModel.cpp
dorado/nn/CRFModelConfig.h
dorado/nn/CRFModelConfig.cpp
dorado/nn/ModelRunner.h
dorado/nn/ModBaseModel.cpp
dorado/nn/ModBaseModel.h
dorado/nn/ModBaseModelConfig.cpp
dorado/nn/ModBaseModelConfig.h
dorado/nn/ModBaseRunner.cpp
dorado/nn/ModBaseRunner.h
dorado/nn/Runners.cpp
dorado/nn/Runners.h
dorado/read_pipeline/Pipelines.cpp
dorado/read_pipeline/Pipelines.h
dorado/api/runner_creation.cpp
dorado/api/runner_creation.h
dorado/api/pipeline_creation.cpp
dorado/api/pipeline_creation.h
dorado/read_pipeline/FakeDataLoader.cpp
dorado/read_pipeline/FakeDataLoader.h
dorado/read_pipeline/ReadPipeline.cpp
Expand Down Expand Up @@ -263,36 +254,8 @@ set(LIB_SOURCE_FILES
dorado/demux/Trimmer.h
dorado/demux/parse_custom_kit.cpp
dorado/demux/parse_custom_kit.h
dorado/decode/beam_search.cpp
dorado/decode/beam_search.h
dorado/decode/CPUDecoder.cpp
dorado/decode/CPUDecoder.h
dorado/modbase/modbase_encoder.cpp
dorado/modbase/modbase_encoder.h
dorado/modbase/modbase_scaler.cpp
dorado/modbase/modbase_scaler.h
dorado/modbase/ModBaseContext.cpp
dorado/modbase/ModBaseContext.h
dorado/modbase/MotifMatcher.cpp
dorado/modbase/MotifMatcher.h
)

if (DORADO_GPU_BUILD)
if(APPLE)
list(APPEND LIB_SOURCE_FILES
dorado/nn/MetalCRFModel.h
dorado/nn/MetalCRFModel.cpp
)
else()
list(APPEND LIB_SOURCE_FILES
dorado/decode/GPUDecoder.cpp
dorado/decode/GPUDecoder.h
dorado/nn/CudaCRFModel.h
dorado/nn/CudaCRFModel.cpp
)
endif()
endif()

add_library(dorado_lib ${LIB_SOURCE_FILES})

enable_warnings_as_errors(dorado_lib)
Expand All @@ -312,7 +275,6 @@ target_include_directories(dorado_lib
PUBLIC
${CMAKE_CURRENT_BINARY_DIR}/dorado
${CMAKE_CURRENT_SOURCE_DIR}/dorado
${KOI_INCLUDE}
${POD5_INCLUDE}
)
# 3rdparty libs should be considered SYSTEM headers
Expand Down Expand Up @@ -348,8 +310,9 @@ target_link_libraries(dorado_lib
vbz_hdf_plugin
edlib
dorado_utils
dorado_basecall
dorado_modbase
PRIVATE
${KOI_LIBRARIES}
minimap2
)

Expand Down
2 changes: 1 addition & 1 deletion cmake/Metal.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ find_library(APPLE_FWK_METAL Metal REQUIRED)
find_library(IOKIT IOKit REQUIRED)

set(AIR_FILES)
set(METAL_SOURCES dorado/nn/metal/nn.metal)
set(METAL_SOURCES dorado/basecall/metal/nn.metal)

if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
set(XCRUN_SDK iphoneos)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include "Pipelines.h"

#include "BasecallerNode.h"
#include "ModBaseCallerNode.h"
#include "PairingNode.h"
#include "ReadSplitNode.h"
#include "ScalerNode.h"
#include "StereoDuplexEncoderNode.h"
#include "nn/CRFModelConfig.h"
#include "nn/ModBaseRunner.h"
#include "nn/ModelRunner.h"
#include "pipeline_creation.h"

#include "basecall/CRFModelConfig.h"
#include "basecall/ModelRunner.h"
#include "modbase/ModBaseRunner.h"
#include "read_pipeline/BasecallerNode.h"
#include "read_pipeline/ModBaseCallerNode.h"
#include "read_pipeline/PairingNode.h"
#include "read_pipeline/ReadSplitNode.h"
#include "read_pipeline/ScalerNode.h"
#include "read_pipeline/StereoDuplexEncoderNode.h"
#include "splitter/DuplexReadSplitter.h"
#include "splitter/RNAReadSplitter.h"

Expand All @@ -17,8 +17,8 @@
namespace dorado::pipelines {

void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
std::vector<dorado::Runner>&& runners,
std::vector<std::unique_ptr<dorado::ModBaseRunner>>&& modbase_runners,
std::vector<basecall::RunnerPtr>&& runners,
std::vector<modbase::RunnerPtr>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
int scaler_node_threads,
Expand Down Expand Up @@ -76,7 +76,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
// For DNA, read splitting happens after basecall.
if (enable_read_splitter && !is_rna) {
splitter::DuplexSplitSettings splitter_settings(model_config.signal_norm_params.strategy ==
ScalingStrategy::PA);
basecall::ScalingStrategy::PA);
splitter_settings.simplex_mode = true;
auto dna_splitter = std::make_unique<const splitter::DuplexReadSplitter>(splitter_settings);
auto dna_splitter_node = pipeline_desc.add_node<ReadSplitNode>({}, std::move(dna_splitter),
Expand Down Expand Up @@ -105,19 +105,18 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
}
}

void create_stereo_duplex_pipeline(
PipelineDescriptor& pipeline_desc,
std::vector<dorado::Runner>&& runners,
std::vector<dorado::Runner>&& stereo_runners,
std::vector<std::unique_ptr<dorado::ModBaseRunner>>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
int scaler_node_threads,
int splitter_node_threads,
int modbase_node_threads,
PairingParameters pairing_parameters,
NodeHandle sink_node_handle,
NodeHandle source_node_handle) {
void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc,
std::vector<basecall::RunnerPtr>&& runners,
std::vector<basecall::RunnerPtr>&& stereo_runners,
std::vector<modbase::RunnerPtr>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
int scaler_node_threads,
int splitter_node_threads,
int modbase_node_threads,
PairingParameters pairing_parameters,
NodeHandle sink_node_handle,
NodeHandle source_node_handle) {
const auto& model_config = runners.front()->config();
const auto& stereo_model_config = stereo_runners.front()->config();
std::string model_name =
Expand Down Expand Up @@ -162,7 +161,7 @@ void create_stereo_duplex_pipeline(
// as a passthrough, meaning it won't perform any splitting operations and
// will just pass data through.
splitter::DuplexSplitSettings splitter_settings(model_config.signal_norm_params.strategy ==
ScalingStrategy::PA);
basecall::ScalingStrategy::PA);
auto duplex_splitter = std::make_unique<const splitter::DuplexReadSplitter>(splitter_settings);
auto splitter_node = pipeline_desc.add_node<ReadSplitNode>(
{pairing_node}, std::move(duplex_splitter), splitter_node_threads, 1000);
Expand All @@ -174,9 +173,9 @@ void create_stereo_duplex_pipeline(
{splitter_node}, std::move(runners), adjusted_simplex_overlap, kSimplexBatchTimeoutMS,
model_name, 1000, "BasecallerNode", mean_qscore_start_pos);

auto scaler_node =
pipeline_desc.add_node<ScalerNode>({basecaller_node}, model_config.signal_norm_params,
SampleType::DNA, scaler_node_threads, 1000);
auto scaler_node = pipeline_desc.add_node<ScalerNode>(
{basecaller_node}, model_config.signal_norm_params, basecall::SampleType::DNA,
scaler_node_threads, 1000);

// if we've been provided a source node, connect it to the start of our pipeline
if (source_node_handle != PipelineDescriptor::InvalidNodeHandle) {
Expand Down
41 changes: 23 additions & 18 deletions dorado/read_pipeline/Pipelines.h → dorado/api/pipeline_creation.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "ReadPipeline.h"
#include "read_pipeline/ReadPipeline.h"

#include <cstdint>
#include <map>
Expand All @@ -11,10 +11,16 @@

namespace dorado {

class ModBaseRunner;
namespace basecall {
class ModelRunnerBase;
using RunnerPtr = std::unique_ptr<ModelRunnerBase>;
} // namespace basecall

namespace modbase {
class ModBaseRunner;
using RunnerPtr = std::unique_ptr<ModBaseRunner>;
} // namespace modbase

using Runner = std::shared_ptr<ModelRunnerBase>;
using PairingParameters = std::variant<DuplexPairingParameters, std::map<std::string, std::string>>;

namespace pipelines {
Expand All @@ -23,8 +29,8 @@ namespace pipelines {
/// If source_node_handle is valid, set this to be the source of the simplex pipeline
/// If sink_node_handle is valid, set this to be the sink of the simplex pipeline
void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
std::vector<dorado::Runner>&& runners,
std::vector<std::unique_ptr<dorado::ModBaseRunner>>&& modbase_runners,
std::vector<basecall::RunnerPtr>&& runners,
std::vector<modbase::RunnerPtr>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
int scaler_node_threads,
Expand All @@ -37,19 +43,18 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
/// Create a duplex basecall pipeline description
/// If source_node_handle is valid, set this to be the source of the simplex pipeline
/// If sink_node_handle is valid, set this to be the sink of the simplex pipeline
void create_stereo_duplex_pipeline(
PipelineDescriptor& pipeline_desc,
std::vector<dorado::Runner>&& runners,
std::vector<dorado::Runner>&& stereo_runners,
std::vector<std::unique_ptr<dorado::ModBaseRunner>>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
int scaler_node_threads,
int splitter_node_threads,
int modbase_node_threads,
PairingParameters pairing_parameters,
NodeHandle sink_node_handle,
NodeHandle source_node_handle);
void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc,
std::vector<basecall::RunnerPtr>&& runners,
std::vector<basecall::RunnerPtr>&& stereo_runners,
std::vector<modbase::RunnerPtr>&& modbase_runners,
size_t overlap,
uint32_t mean_qscore_start_pos,
int scaler_node_threads,
int splitter_node_threads,
int modbase_node_threads,
PairingParameters pairing_parameters,
NodeHandle sink_node_handle,
NodeHandle source_node_handle);

} // namespace pipelines

Expand Down

0 comments on commit aae47b1

Please sign in to comment.