Skip to content

Commit

Permalink
Merge branch 'smalton/DOR-535-compat-lib' into 'master'
Browse files Browse the repository at this point in the history
DOR-535: Compatibility lib

Closes DOR-535

See merge request machine-learning/dorado!818
  • Loading branch information
malton-ont committed Jan 19, 2024
2 parents c0aa751 + 52022b3 commit e48bfea
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 59 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ add_subdirectory(${DORADO_3RD_PARTY_SOURCE}/date EXCLUDE_FROM_ALL)

enable_testing()

add_subdirectory(dorado/compat)
add_subdirectory(dorado/utils)
add_subdirectory(dorado/models)
add_subdirectory(dorado/basecall)
Expand Down Expand Up @@ -368,6 +369,7 @@ if(NOT DORADO_DISABLE_DORADO)
dorado_lib
dorado_io_lib
dorado_models_lib
dorado_compat
minimap2
)

Expand Down
1 change: 1 addition & 0 deletions dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <htslib/sam.h>
#include <spdlog/spdlog.h>
#include <torch/utils.h>

#include <algorithm>
#include <cstdlib>
Expand Down
1 change: 1 addition & 0 deletions dorado/cli/duplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <htslib/sam.h>
#include <spdlog/spdlog.h>
#include <torch/utils.h>

#include <cstdlib>
#include <exception>
Expand Down
11 changes: 11 additions & 0 deletions dorado/compat/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Library target for the platform-compatibility helpers declared in compat_utils.h.
add_library(dorado_compat
compat_utils.cpp
compat_utils.h
)

# Export the parent directory to consumers only (INTERFACE), so targets that link
# dorado_compat can include the header as "compat/compat_utils.h".
target_include_directories(dorado_compat
INTERFACE
${CMAKE_CURRENT_SOURCE_DIR}/..
)

# Project-level helper defined outside this file; judging by its name it promotes
# compiler warnings to errors for this target (confirm against its definition).
enable_warnings_as_errors(dorado_compat)
11 changes: 11 additions & 0 deletions dorado/utils/compat_utils.cpp → dorado/compat/compat_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,16 @@ char* strptime(const char* s, const char* f, tm* tm) {
return (char*)(s + input.tellg());
}

// Windows stand-in for setenv(): assigns `value` to the environment variable `name`.
//
// When `overwrite` is zero the variable is looked up first. If that lookup fails, its
// error code is returned and nothing is written; if the variable already has a value,
// the environment is left untouched and the (zero) lookup status is returned.
// In every other case the result of _putenv_s is returned.
int setenv(const char* name, const char* value, int overwrite) {
    if (overwrite) {
        return _putenv_s(name, value);
    }

    // No destination buffer is supplied: only the reported size is of interest here,
    // and a non-zero size is treated as "already set".
    size_t required_size = 0;
    const int lookup_status = getenv_s(&required_size, nullptr, 0, name);
    if (lookup_status != 0 || required_size != 0) {
        return lookup_status;
    }
    return _putenv_s(name, value);
}

} // namespace dorado::utils
#endif // _WIN32
11 changes: 1 addition & 10 deletions dorado/utils/compat_utils.h → dorado/compat/compat_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,7 @@ namespace dorado::utils {
char* strptime(const char* s, const char* f, tm* tm);

// A simple replacement for setenv, since Windows doesn't provide it.
// Sets environment variable `name` to `value` using the Windows CRT.
// With `overwrite` == 0 an existing value is preserved: the lookup status is returned
// (zero when the variable was found, the getenv_s error code when the lookup failed)
// and no assignment is attempted. Otherwise the _putenv_s result is returned.
inline int setenv(const char* name, const char* value, int overwrite) {
    if (overwrite) {
        return _putenv_s(name, value);
    }

    // Query only the size of the current value; non-zero is treated as "already set".
    size_t current_size = 0;
    const int query_result = getenv_s(&current_size, nullptr, 0, name);
    const bool keep_existing = (query_result != 0) || (current_size != 0);
    return keep_existing ? query_result : _putenv_s(name, value);
}
int setenv(const char* name, const char* value, int overwrite);

} // namespace dorado::utils

Expand Down
1 change: 0 additions & 1 deletion dorado/data_loader/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "models/models.h"
#include "read_pipeline/ReadPipeline.h"
#include "utils/PostCondition.h"
#include "utils/compat_utils.h"
#include "utils/time_utils.h"
#include "utils/types.h"
#include "vbz_plugin_user_utils.h"
Expand Down
2 changes: 1 addition & 1 deletion dorado/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "Version.h"
#include "cli/cli.h"
#include "utils/compat_utils.h"
#include "compat/compat_utils.h"

#include <minimap.h>
#include <spdlog/cfg/env.h>
Expand Down
4 changes: 2 additions & 2 deletions dorado/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ add_library(dorado_utils
barcode_kits.h
basecaller_utils.cpp
basecaller_utils.h
compat_utils.cpp
compat_utils.h
dev_utils.cpp
dev_utils.h
driver_query.cpp
Expand Down Expand Up @@ -41,6 +39,7 @@ add_library(dorado_utils
thread_utils.h
time_utils.cpp
time_utils.h
torch_utils.cpp
torch_utils.h
trim.cpp
trim.h
Expand Down Expand Up @@ -95,6 +94,7 @@ target_link_libraries(dorado_utils
edlib
spdlog::spdlog
PRIVATE
dorado_compat
minimap2
OpenSSL::SSL
htslib
Expand Down
49 changes: 49 additions & 0 deletions dorado/utils/torch_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "torch_utils.h"

#include "compat/compat_utils.h"

#if DORADO_CUDA_BUILD
#include <c10/cuda/CUDACachingAllocator.h>
#endif
#include <torch/torch.h>

namespace dorado::utils {

// Switches torch's global context to deterministic behaviour.
//
// CUDA builds additionally set CUBLAS_WORKSPACE_CONFIG to ":4096:8" (replacing any
// existing value), enable deterministic cuDNN and disable cuDNN benchmarking.
// All builds then enable torch's deterministic-algorithms flag.
void make_torch_deterministic() {
#if DORADO_CUDA_BUILD
    // The environment variable is written before the torch context is touched.
    setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
    auto &cuda_context = torch::globalContext();
    cuda_context.setDeterministicCuDNN(true);
    cuda_context.setBenchmarkCuDNN(false);
#endif

#if TORCH_VERSION_MAJOR <= 1 && TORCH_VERSION_MINOR < 11
    // Torch older than 1.11: only the single-argument form is used.
    torch::globalContext().setDeterministicAlgorithms(true);
#else
    // Torch 1.11 and later: two-argument form (second argument is presumably the
    // "warn only" switch - confirm against the torch headers).
    torch::globalContext().setDeterministicAlgorithms(true, false);
#endif
}

// Applies a max_split_size_mb limit of 25 to torch's CUDA caching allocator.
//
// Only active on CUDA builds with torch major version 2 or newer; otherwise a no-op.
// If PYTORCH_CUDA_ALLOC_CONF already mentions max_split_size_mb the function returns
// without changing anything, leaving torch to parse the user's value. Any other
// user-supplied settings are appended after ours, separated by a comma.
void set_torch_allocator_max_split_size() {
#if DORADO_CUDA_BUILD && TORCH_VERSION_MAJOR >= 2

    const char *user_conf = std::getenv("PYTORCH_CUDA_ALLOC_CONF");
    const bool has_user_conf = user_conf != nullptr;
    if (has_user_conf &&
        std::string_view(user_conf).find("max_split_size_mb") != std::string_view::npos) {
        // The user chose a split size via the environment - respect it.
        return;
    }

    // Stop small allocations from carving pieces out of large cached buffers: doing so
    // can cause out-of-memory errors when the original large allocation is needed again.
    constexpr int split_limit_mb = 25;
    std::string allocator_conf = "max_split_size_mb:" + std::to_string(split_limit_mb);
    if (has_user_conf) {
        allocator_conf += ",";
        allocator_conf += user_conf;
    }

    c10::cuda::CUDACachingAllocator::setAllocatorSettings(allocator_conf);
#endif
}

} // namespace dorado::utils
45 changes: 2 additions & 43 deletions dorado/utils/torch_utils.h
Original file line number Diff line number Diff line change
@@ -1,49 +1,8 @@
#pragma once

#include "compat_utils.h"

#if DORADO_CUDA_BUILD
#include <c10/cuda/CUDACachingAllocator.h>
#endif
#include <torch/torch.h>

namespace dorado::utils {

// Puts torch's global context into deterministic mode.
// On CUDA builds it first overwrites CUBLAS_WORKSPACE_CONFIG with ":4096:8", then turns
// on deterministic cuDNN and turns off cuDNN benchmarking. Deterministic algorithms are
// enabled on every build.
inline void make_torch_deterministic() {
#if DORADO_CUDA_BUILD
    // Environment first, then the torch context - same order as the calls below rely on.
    setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
    auto &cuda_context = torch::globalContext();
    cuda_context.setDeterministicCuDNN(true);
    cuda_context.setBenchmarkCuDNN(false);
#endif

#if TORCH_VERSION_MAJOR <= 1 && TORCH_VERSION_MINOR < 11
    // Pre-1.11 torch: single-argument form.
    torch::globalContext().setDeterministicAlgorithms(true);
#else
    // 1.11+ torch: two-argument form (meaning of the second argument not shown here;
    // presumably "warn only" - confirm against the torch headers).
    torch::globalContext().setDeterministicAlgorithms(true, false);
#endif
}

// Configures torch's CUDA caching allocator with max_split_size_mb = 25.
// Compiled to a no-op unless this is a CUDA build against torch major version >= 2.
// A max_split_size_mb entry already present in PYTORCH_CUDA_ALLOC_CONF wins: the function
// returns early and torch uses the user's value. Other user settings are kept by
// appending them, comma-separated, after ours.
inline void set_torch_allocator_max_split_size() {
#if DORADO_CUDA_BUILD && TORCH_VERSION_MAJOR >= 2

    const char *env_conf = std::getenv("PYTORCH_CUDA_ALLOC_CONF");
    const bool env_conf_present = env_conf != nullptr;
    if (env_conf_present &&
        std::string_view(env_conf).find("max_split_size_mb") != std::string_view::npos) {
        // Set explicitly by the user - let torch parse and apply it.
        return;
    }

    // Keep small allocations from reusing slices of large cached blocks, which can lead
    // to out-of-memory errors once the original large allocation is requested again.
    constexpr int split_limit_mb = 25;
    std::string allocator_conf = "max_split_size_mb:" + std::to_string(split_limit_mb);
    if (env_conf_present) {
        allocator_conf += ",";
        allocator_conf += env_conf;
    }

    c10::cuda::CUDACachingAllocator::setAllocatorSettings(allocator_conf);
#endif
}
// Enables torch's deterministic-algorithms mode; on CUDA builds also sets
// CUBLAS_WORKSPACE_CONFIG, enables deterministic cuDNN and disables cuDNN benchmarking.
void make_torch_deterministic();
// On CUDA builds with torch 2 or newer, sets the caching allocator's max_split_size_mb
// to 25 unless PYTORCH_CUDA_ALLOC_CONF already specifies it. Does nothing on other builds.
void set_torch_allocator_max_split_size();

} // namespace dorado::utils
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ foreach(TEST_BIN dorado_tests dorado_smoke_tests)
dorado_models_lib
dorado_basecall
dorado_modbase
dorado_compat
minimap2
${ZLIB_LIBRARIES}
${POD5_LIBRARIES}
Expand Down
4 changes: 2 additions & 2 deletions tests/main.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#define CATCH_CONFIG_RUNNER

#include "utils/compat_utils.h"
#include "compat/compat_utils.h"
#include "utils/torch_utils.h"

#include <catch2/catch.hpp>
#include <nvtx3/nvtx3.hpp>
#include <torch/torch.h>
#include <torch/utils.h>

#include <clocale>

Expand Down

0 comments on commit e48bfea

Please sign in to comment.