diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 918520ec4..df8aaae38 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -203,6 +203,18 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( frames = allocateEmptyHWCTensor(height, width, options.device, numFrames); } +bool VideoDecoder::SwsContextKey::operator==( + const VideoDecoder::SwsContextKey& other) { + return decodedWidth == other.decodedWidth && decodedHeight == decodedHeight && + decodedFormat == other.decodedFormat && + outputWidth == other.outputWidth && outputHeight == other.outputHeight; +} + +bool VideoDecoder::SwsContextKey::operator!=( + const VideoDecoder::SwsContextKey& other) { + return !(*this == other); +} + VideoDecoder::VideoDecoder() {} void VideoDecoder::initializeDecoder() { @@ -1340,7 +1352,11 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale( int expectedOutputHeight = outputTensor.sizes()[0]; int expectedOutputWidth = outputTensor.sizes()[1]; auto curFrameSwsContextKey = SwsContextKey{ - frame->width, frame->height, frameFormat, expectedOutputWidth, expectedOutputHeight}; + frame->width, + frame->height, + frameFormat, + expectedOutputWidth, + expectedOutputHeight}; if (activeStream.swsContext.get() == nullptr || activeStream.swsContextKey != curFrameSwsContextKey) { SwsContext* swsContext = sws_getContext( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index aefeb1fc3..8ae7cc176 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -323,8 +323,8 @@ class VideoDecoder { AVPixelFormat decodedFormat; int outputWidth; int outputHeight; - bool operator==(const SwsContextKey&) const = default; - bool operator!=(const SwsContextKey&) const = default; + bool operator==(const SwsContextKey&); + bool operator!=(const SwsContextKey&); }; // Stores information for each stream. struct StreamInfo {