diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 72fbd12e9..918520ec4 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1339,7 +1339,10 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(
 
   int expectedOutputHeight = outputTensor.sizes()[0];
   int expectedOutputWidth = outputTensor.sizes()[1];
-  if (activeStream.swsContext.get() == nullptr) {
+  auto curFrameSwsContextKey = SwsContextKey{
+      frame->width, frame->height, frameFormat, expectedOutputWidth, expectedOutputHeight};
+  if (activeStream.swsContext.get() == nullptr ||
+      activeStream.swsContextKey != curFrameSwsContextKey) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
         frame->height,
@@ -1373,6 +1376,7 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(
         brightness,
         contrast,
         saturation);
+    activeStream.swsContextKey = curFrameSwsContextKey;
     activeStream.swsContext.reset(swsContext);
   }
   SwsContext* swsContext = activeStream.swsContext.get();
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 637da359f..aefeb1fc3 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -317,6 +317,29 @@ class VideoDecoder {
     AVFilterContext* sourceContext = nullptr;
     AVFilterContext* sinkContext = nullptr;
   };
+  // Identifies the conversion a cached SwsContext was created for: the decoded
+  // frame's width, height and pixel format plus the output width and height.
+  // The cached context is rebuilt whenever the key computed for the current
+  // frame differs from the stored one.
+  struct SwsContextKey {
+    int decodedWidth = 0;
+    int decodedHeight = 0;
+    AVPixelFormat decodedFormat = AV_PIX_FMT_NONE;
+    int outputWidth = 0;
+    int outputHeight = 0;
+    // Spelled out member-wise because `= default` comparison operators
+    // require C++20.
+    bool operator==(const SwsContextKey& other) const {
+      return decodedWidth == other.decodedWidth &&
+          decodedHeight == other.decodedHeight &&
+          decodedFormat == other.decodedFormat &&
+          outputWidth == other.outputWidth &&
+          outputHeight == other.outputHeight;
+    }
+    bool operator!=(const SwsContextKey& other) const {
+      return !(*this == other);
+    }
+  };
   // Stores information for each stream.
   struct StreamInfo {
     int streamIndex = -1;
@@ -337,6 +360,7 @@ class VideoDecoder {
     ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
     std::vector keyFrames;
     std::vector allFrames;
+    SwsContextKey swsContextKey;
     UniqueSwsContext swsContext;
   };
   VideoDecoder();