Refac: Straightforward output shape permutation #317
Changes from all commits: 7c37e9a, f273b52, 639952a, 3ae34c7, 8019219, dc40ad6, 6bd363d, 0b6590d
Changes to the VideoDecoder implementation:

```diff
@@ -34,31 +34,6 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
   return ptsToSeconds(pts, timeBase.den);
 }
 
-// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
-// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
-// or 4D.
-// Calling permute() is guaranteed to return a view as per the docs:
-// https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor MaybePermuteHWC2CHW(
-    const VideoDecoder::VideoStreamDecoderOptions& options,
-    torch::Tensor& hwcTensor) {
-  if (options.dimensionOrder == "NHWC") {
-    return hwcTensor;
-  }
-  auto numDimensions = hwcTensor.dim();
-  auto shape = hwcTensor.sizes();
-  if (numDimensions == 3) {
-    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
-    return hwcTensor.permute({2, 0, 1});
-  } else if (numDimensions == 4) {
-    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
-    return hwcTensor.permute({0, 3, 1, 2});
-  } else {
-    TORCH_CHECK(
-        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
-  }
-}
-
 struct AVInput {
   UniqueAVFormatContext formatContext;
   std::unique_ptr<AVIOBytesContext> ioBytesContext;
```
```diff
@@ -136,6 +111,31 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
 
 } // namespace
 
+// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
+// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
+// or 4D.
+// Calling permute() is guaranteed to return a view as per the docs:
+// https://pytorch.org/docs/stable/generated/torch.permute.html
+torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
+    int streamIndex,
+    torch::Tensor& hwcTensor) {
+  if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
+    return hwcTensor;
+  }
+  auto numDimensions = hwcTensor.dim();
+  auto shape = hwcTensor.sizes();
+  if (numDimensions == 3) {
+    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
+    return hwcTensor.permute({2, 0, 1});
+  } else if (numDimensions == 4) {
+    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
+    return hwcTensor.permute({0, 3, 1, 2});
+  } else {
+    TORCH_CHECK(
+        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
+  }
+}
+
 VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
     const std::string& optionsString) {
   std::vector<std::string> tokens =
```
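As the comment above notes, permute() always returns a view. Here is a minimal standalone sketch (not part of the PR; shapes are arbitrary) confirming that the [N]HWC to [N]CHW permutation shares storage with its input:

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // Arbitrary NHWC batch: 2 frames of 480x640 RGB.
  torch::Tensor nhwc = torch::rand({2, 480, 640, 3});
  torch::Tensor nchw = nhwc.permute({0, 3, 1, 2}); // NCHW view

  // permute() only rewrites sizes and strides; the buffer is shared.
  std::cout << (nhwc.data_ptr() == nchw.data_ptr()) << std::endl; // 1
  // The view is non-contiguous; a copy happens only if the caller
  // explicitly asks for one, e.g. via nchw.contiguous().
  std::cout << nchw.is_contiguous() << std::endl; // 0
  return 0;
}
```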
```diff
@@ -929,14 +929,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
         "Invalid color conversion library: " +
         std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
-  if (!preAllocatedOutputTensor.has_value()) {
-    // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
-    // pre-allocated tensor is passed, it's up to the caller (typically a
-    // batch API) to do the conversion. This is more efficient as it allows
-    // batch NHWC tensors to be permuted only once, instead of permuting HWC
-    // tensors N times.
-    output.frame = MaybePermuteHWC2CHW(streamInfo.options, output.frame);
-  }
 } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
   // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
```

Comment on lines -932 to -939 (Contributor, Author): This removal is actually the key change. Everything else is just (sensible) patching until tests work.
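The rationale in the removed comment still holds at the batch call sites: each frame is decoded into a slot of a pre-allocated NHWC batch, and the batch is permuted to NCHW once at the end, instead of once per frame. Below is a sketch of that pattern; the shapes are assumed and decodeFrameInto is a hypothetical stand-in for the decoder's internal write path, not a torchcodec API:

```cpp
#include <torch/torch.h>

// Hypothetical stand-in for the decoder's internal write path: fills one
// HWC slot of the batch in place.
void decodeFrameInto(torch::Tensor hwcSlot, int64_t frameIndex) {
  hwcSlot.fill_(static_cast<uint8_t>(frameIndex)); // fake pixel data
}

int main() {
  // Decode N frames into a pre-allocated NHWC batch, then permute the
  // whole batch to NCHW once, instead of permuting each frame.
  int64_t N = 8, height = 270, width = 480;
  torch::Tensor frames = torch::empty({N, height, width, 3}, torch::kUInt8);
  for (int64_t i = 0; i < N; ++i) {
    decodeFrameInto(frames[i], i);
  }
  frames = frames.permute({0, 3, 1, 2}); // single NHWC -> NCHW view
  return 0;
}
```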
```diff
@@ -978,7 +970,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
         return seconds >= frameStartTime && seconds < frameEndTime;
       });
   // Convert the frame to tensor.
-  return convertAVFrameToDecodedOutput(rawOutput);
+  auto output = convertAVFrameToDecodedOutput(rawOutput);
+  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  return output;
 }
 
 void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) {
```
```diff
@@ -1015,6 +1009,16 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto output = getFrameAtIndexInternal(
+      streamIndex, frameIndex, preAllocatedOutputTensor);
+  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
+  return output;
+}
+
+VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
+    int streamIndex,
+    int64_t frameIndex,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFrameAtIndex");
```
```diff
@@ -1023,7 +1027,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
 
   int64_t pts = stream.allFrames[frameIndex].pts;
   setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
-  return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
+  return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor);
 }
 
 VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
```
```diff
@@ -1073,14 +1077,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
       output.durationSeconds[indexInOutput] =
           output.durationSeconds[previousIndexInOutput];
     } else {
-      DecodedOutput singleOut = getFrameAtIndex(
+      DecodedOutput singleOut = getFrameAtIndexInternal(
          streamIndex, indexInVideo, output.frames[indexInOutput]);
       output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
       output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
     }
     previousIndexInVideo = indexInVideo;
   }
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
```
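Passing output.frames[indexInOutput] as the pre-allocated output works because indexing a tensor with operator[] yields a view into the batch, so the decoder writes pixels straight into the NHWC batch with no per-frame copy. A small standalone illustration (shapes arbitrary, not from the PR):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  // frames is an NHWC batch; frames[2] is a HWC view into it.
  torch::Tensor frames = torch::zeros({5, 270, 480, 3}, torch::kUInt8);
  torch::Tensor slot = frames[2]; // [270, 480, 3] view, shares storage
  slot.fill_(1);                  // writes land directly in the batch
  std::cout << frames.sum().item<int64_t>() << std::endl; // 388800 = 270*480*3
  return 0;
}
```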
```diff
@@ -1150,11 +1154,12 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
   BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
 
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    DecodedOutput singleOut =
+        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
```
```diff
@@ -1207,7 +1212,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
-    output.frames = MaybePermuteHWC2CHW(options, output.frames);
+    output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
```
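Note that the permute is still applied to the empty batch: permuting a zero-frame NHWC tensor is cheap, but it fixes the reported sizes so callers always see the same dimension order. A quick illustration with assumed dimensions:

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  torch::Tensor empty = torch::empty({0, 270, 480, 3}); // NHWC with N == 0
  std::cout << empty.permute({0, 3, 1, 2}).sizes() << std::endl;
  // Prints [0, 3, 270, 480]: zero frames, but the expected NCHW layout.
  return 0;
}
```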
```diff
@@ -1243,11 +1248,12 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   int64_t numFrames = stopFrameIndex - startFrameIndex;
   BatchDecodedOutput output(numFrames, options, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    DecodedOutput singleOut =
+        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
 
   return output;
 }
```
```diff
@@ -1261,7 +1267,13 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
   return rawOutput;
 }
 
-VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
+VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
+  auto output = getNextFrameOutputNoDemuxInternal();
+  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  return output;
+}
+
+VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   auto rawOutput = getNextRawDecodedOutputNoDemux();
   return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
```
```diff
@@ -1283,7 +1295,7 @@ double VideoDecoder::getPtsSecondsForFrame(
     int streamIndex,
     int64_t frameIndex) {
   validateUserProvidedStreamIndex(streamIndex);
-  validateScannedAllStreams("getFrameAtIndex");
+  validateScannedAllStreams("getPtsSecondsForFrame");
 
   const auto& stream = streams_[streamIndex];
   validateFrameIndex(stream, frameIndex);
```
Changes to the VideoDecoder header:

```diff
@@ -157,10 +157,12 @@ class VideoDecoder {
       int streamIndex,
       const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
 
+  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+
   // ---- SINGLE FRAME SEEK AND DECODING API ----
   // Places the cursor at the first frame on or after the position in seconds.
-  // Calling getNextDecodedOutputNoDemux() will return the first frame at or
-  // after this position.
+  // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at
+  // or after this position.
   void setCursorPtsInSeconds(double seconds);
   // This is an internal structure that is used to store the decoded output
   // from decoding a frame through color conversion. Example usage is:
```
```diff
@@ -214,8 +216,7 @@ class VideoDecoder {
   };
   // Decodes the frame where the current cursor position is. It also advances
   // the cursor to the next frame.
-  DecodedOutput getNextDecodedOutputNoDemux(
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  DecodedOutput getNextFrameNoDemux();
   // Decodes the first frame in any added stream that is visible at a given
   // timestamp. Frames in the video have a presentation timestamp and a
   // duration. For example, if a frame has presentation timestamp of 5.0s and a
```
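For reference, a hypothetical usage sketch of this seek-and-decode pair. It assumes `decoder` is a fully configured VideoDecoder with a video stream added, and that the caller has access to these members (in practice another member function or the ops layer); names follow the declarations above:

```cpp
// Hypothetical sketch: seek to 5.0s, then decode the frame found there.
decoder.setCursorPtsInSeconds(5.0);
VideoDecoder::DecodedOutput out = decoder.getNextFrameNoDemux();
// out.frame is now CHW by default, or HWC if the stream was added with
// dimensionOrder="NHWC"; the permute happens inside getNextFrameNoDemux().
```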
```diff
@@ -386,6 +387,13 @@ class VideoDecoder {
       DecodedOutput& output,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
+  DecodedOutput getFrameAtIndexInternal(
+      int streamIndex,
+      int64_t frameIndex,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  DecodedOutput getNextFrameOutputNoDemuxInternal(
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
   DecoderOptions options_;
   ContainerMetadata containerMetadata_;
   UniqueAVFormatContext formatContext_;
```

Comment on the getFrameAtIndexInternal declaration (Contributor): Should you have a comment somewhere saying these are always returned in HWC? You could have a convention that Internal-suffixed functions always return in HWC.

Reply (Contributor, Author): Yes, definitely. Ultimately I want to label all functions in the decoding stack with expected input and output shapes. I'll follow up with that; I think it'll be part of a sequence of PRs, after the one about up-leveling the tensor allocation.
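Following up on the thread above: one way the Internal-suffix convention could be made self-checking is a small assertion helper that such functions call on their return value. This is only a sketch of the idea discussed in the review, not code from the PR:

```cpp
#include <torch/torch.h>

// Hypothetical helper: asserts the [N]HWC shape convention that
// Internal-suffixed functions would be expected to uphold.
void assertHWC(const torch::Tensor& t) {
  auto dims = t.dim();
  TORCH_CHECK(dims == 3 || dims == 4, "Expected 3D or 4D tensor, got ", dims);
  TORCH_CHECK(t.size(dims - 1) == 3, "Not a [N]HWC tensor: ", t.sizes());
}
```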
Comment: Not sure why, but my linter started wanting to format this. If the linter on our CI job is OK with it, can we let this in? Otherwise I have to manually revert it on all my commits.

Reply: Linters gonna lint.