diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 1f09159e6..0793c8061 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -93,6 +93,7 @@ function(make_torchcodec_libraries CpuDeviceInterface.cpp SingleStreamDecoder.cpp Encoder.cpp + ValidationUtils.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/ValidationUtils.cpp b/src/torchcodec/_core/ValidationUtils.cpp new file mode 100644 index 000000000..fae3dd940 --- /dev/null +++ b/src/torchcodec/_core/ValidationUtils.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "src/torchcodec/_core/ValidationUtils.h" +#include +#include "c10/util/Exception.h" + +namespace facebook::torchcodec { + +int validateInt64ToInt(int64_t value, const std::string& parameterName) { + TORCH_CHECK( + value >= std::numeric_limits::min() && + value <= std::numeric_limits::max(), + parameterName, + "=", + value, + " is out of range for int type."); + + return static_cast(value); +} + +std::optional validateOptionalInt64ToInt( + const std::optional& value, + const std::string& parameterName) { + if (value.has_value()) { + return validateInt64ToInt(value.value(), parameterName); + } else { + return std::nullopt; + } +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/ValidationUtils.h b/src/torchcodec/_core/ValidationUtils.h new file mode 100644 index 000000000..ce2d11256 --- /dev/null +++ b/src/torchcodec/_core/ValidationUtils.h @@ -0,0 +1,21 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace facebook::torchcodec { + +int validateInt64ToInt(int64_t value, const std::string& parameterName); + +std::optional validateOptionalInt64ToInt( + const std::optional& value, + const std::string& parameterName); + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index a5127e73f..c646ed54a 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -13,6 +13,7 @@ #include "src/torchcodec/_core/AVIOTensorContext.h" #include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" +#include "src/torchcodec/_core/ValidationUtils.h" namespace facebook::torchcodec { @@ -164,16 +165,6 @@ std::string mapToJson(const std::map& metadataMap) { return ss.str(); } -int validateSampleRate(int64_t sampleRate) { - TORCH_CHECK( - sampleRate <= std::numeric_limits::max(), - "sample_rate=", - sampleRate, - " is too large to be cast to an int."); - - return static_cast(sampleRate); -} - } // namespace // ============================== @@ -413,14 +404,17 @@ void encode_audio_to_file( std::optional bit_rate = std::nullopt, std::optional num_channels = std::nullopt, std::optional desired_sample_rate = std::nullopt) { - // TODO Fix implicit int conversion: - // https://github.com/pytorch/torchcodec/issues/679 AudioStreamOptions audioStreamOptions; - audioStreamOptions.bitRate = bit_rate; - audioStreamOptions.numChannels = num_channels; - audioStreamOptions.sampleRate = desired_sample_rate; + audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate"); + audioStreamOptions.numChannels = + validateOptionalInt64ToInt(num_channels, "num_channels"); + audioStreamOptions.sampleRate = + validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate"); AudioEncoder( - samples, validateSampleRate(sample_rate), file_name, audioStreamOptions) + samples, + validateInt64ToInt(sample_rate, "sample_rate"), + file_name, + audioStreamOptions) .encode(); } @@ -432,15 +426,15 @@ at::Tensor encode_audio_to_tensor( std::optional num_channels = std::nullopt, std::optional desired_sample_rate = std::nullopt) { auto avioContextHolder = std::make_unique(); - // TODO Fix implicit int conversion: - // https://github.com/pytorch/torchcodec/issues/679 AudioStreamOptions audioStreamOptions; - audioStreamOptions.bitRate = bit_rate; - audioStreamOptions.numChannels = num_channels; - audioStreamOptions.sampleRate = desired_sample_rate; + audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate"); + audioStreamOptions.numChannels = + validateOptionalInt64ToInt(num_channels, "num_channels"); + audioStreamOptions.sampleRate = + validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate"); return AudioEncoder( samples, - validateSampleRate(sample_rate), + validateInt64ToInt(sample_rate, "sample_rate"), format, std::move(avioContextHolder), audioStreamOptions) diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index e74e5574f..72969f7a9 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -13,6 +13,7 @@ #include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" #include "src/torchcodec/_core/StreamOptions.h" +#include "src/torchcodec/_core/ValidationUtils.h" namespace py = pybind11; @@ -55,20 +56,19 @@ void encode_audio_to_file_like( auto samples = torch::from_blob( reinterpret_cast(data_ptr), shape, tensor_options); - // TODO Fix implicit int conversion: - // https://github.com/pytorch/torchcodec/issues/679 - // same for sample_rate parameter below AudioStreamOptions audioStreamOptions; - audioStreamOptions.bitRate = bit_rate; - audioStreamOptions.numChannels = num_channels; - audioStreamOptions.sampleRate = desired_sample_rate; + audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate"); + audioStreamOptions.numChannels = + validateOptionalInt64ToInt(num_channels, "num_channels"); + audioStreamOptions.sampleRate = + validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate"); auto avioContextHolder = std::make_unique(file_like, /*isForWriting=*/true); AudioEncoder encoder( samples, - static_cast(sample_rate), + validateInt64ToInt(sample_rate, "sample_rate"), format, std::move(avioContextHolder), audioStreamOptions);