Merged
4 changes: 2 additions & 2 deletions benchmarks/decoders/BenchmarkDecodersMain.cpp
@@ -63,7 +63,7 @@ void runNDecodeIterations(
decoder->addVideoStreamDecoder(-1);
for (double pts : ptsList) {
decoder->setCursorPtsInSeconds(pts);
-      torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+      torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
}
if (i + 1 == warmupIterations) {
start = std::chrono::high_resolution_clock::now();
@@ -95,7 +95,7 @@ void runNdecodeIterationsGrabbingConsecutiveFrames(
VideoDecoder::createFromFilePath(videoPath);
decoder->addVideoStreamDecoder(-1);
for (int j = 0; j < consecutiveFrameCount; ++j) {
-      torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+      torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
}
if (i + 1 == warmupIterations) {
start = std::chrono::high_resolution_clock::now();
4 changes: 3 additions & 1 deletion packaging/check_glibcxx.py
@@ -46,7 +46,9 @@
all_symbols.add(match.group(0))

if not all_symbols:
raise ValueError(f"No GLIBCXX symbols found in {symbol_matches}. Something is wrong.")
raise ValueError(
f"No GLIBCXX symbols found in {symbol_matches}. Something is wrong."
)
> **Contributor Author:** Not sure why, my linter started to want to format this. If the linter on our CI job is OK with it, can we let this in? Otherwise I have to manually revert it on all my commits.

> **Contributor:** Linters gonna lint.

all_versions = (symbol.split("_")[1].split(".") for symbol in all_symbols)
all_versions = (tuple(int(v) for v in version) for version in all_versions)
5 changes: 0 additions & 5 deletions src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -240,11 +240,6 @@ void convertAVFrameToDecodedOutputOnCuda(
std::chrono::duration<double, std::micro> duration = end - start;
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
<< " took: " << duration.count() << "us" << std::endl;
-  if (options.dimensionOrder == "NCHW") {
-    // The docs guaranty this to return a view:
-    // https://pytorch.org/docs/stable/generated/torch.permute.html
-    dst = dst.permute({2, 0, 1});
-  }
}

} // namespace facebook::torchcodec
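The deleted branch, like the CPU path changed below, relies on `permute()` returning a view rather than a copy. A minimal standalone check of that guarantee with plain libtorch (a sketch, not part of this PR):

```cpp
#include <torch/torch.h>

int main() {
  // HWC frame -> CHW via permute(): the PyTorch docs guarantee a view, so
  // both tensors share the same storage and no pixel data is copied.
  torch::Tensor hwc = torch::rand({270, 480, 3});
  torch::Tensor chw = hwc.permute({2, 0, 1});

  TORCH_CHECK(chw.data_ptr() == hwc.data_ptr(), "permute must return a view");
  TORCH_CHECK(chw.sizes() == std::vector<int64_t>({3, 270, 480}));
  return 0;
}
```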
100 changes: 56 additions & 44 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -34,31 +34,6 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
return ptsToSeconds(pts, timeBase.den);
}

-// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
-// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
-// or 4D.
-// Calling permute() is guaranteed to return a view as per the docs:
-// https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor MaybePermuteHWC2CHW(
-    const VideoDecoder::VideoStreamDecoderOptions& options,
-    torch::Tensor& hwcTensor) {
-  if (options.dimensionOrder == "NHWC") {
-    return hwcTensor;
-  }
-  auto numDimensions = hwcTensor.dim();
-  auto shape = hwcTensor.sizes();
-  if (numDimensions == 3) {
-    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
-    return hwcTensor.permute({2, 0, 1});
-  } else if (numDimensions == 4) {
-    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
-    return hwcTensor.permute({0, 3, 1, 2});
-  } else {
-    TORCH_CHECK(
-        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
-  }
-}

struct AVInput {
UniqueAVFormatContext formatContext;
std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -136,6 +111,31 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(

} // namespace

+// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
+// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
+// or 4D.
+// Calling permute() is guaranteed to return a view as per the docs:
+// https://pytorch.org/docs/stable/generated/torch.permute.html
+torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
+    int streamIndex,
+    torch::Tensor& hwcTensor) {
+  if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
+    return hwcTensor;
+  }
+  auto numDimensions = hwcTensor.dim();
+  auto shape = hwcTensor.sizes();
+  if (numDimensions == 3) {
+    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
+    return hwcTensor.permute({2, 0, 1});
+  } else if (numDimensions == 4) {
+    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
+    return hwcTensor.permute({0, 3, 1, 2});
+  } else {
+    TORCH_CHECK(
+        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
+  }
+}

VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
const std::string& optionsString) {
std::vector<std::string> tokens =
@@ -929,14 +929,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
"Invalid color conversion library: " +
std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
}
-    if (!preAllocatedOutputTensor.has_value()) {
-      // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
-      // pre-allocated tensor is passed, it's up to the caller (typically a
-      // batch API) to do the conversion. This is more efficient as it allows
-      // batch NHWC tensors to be permuted only once, instead of permuting HWC
-      // tensors N times.
-      output.frame = MaybePermuteHWC2CHW(streamInfo.options, output.frame);
-    }
Comment on lines -932 to -939:

> **@NicolasHug (Contributor, Author), Oct 29, 2024:** This removal is actually the key change. Everything else is just (sensible) patching until tests work.
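A self-contained sketch of the reasoning behind that change, using plain libtorch (nothing below is TorchCodec code): permuting a pre-allocated NHWC batch once yields the same NCHW data as permuting each HWC frame separately, but creates a single view instead of N.

```cpp
#include <torch/torch.h>

int main() {
  const int64_t N = 4, H = 270, W = 480;
  torch::Tensor batchNHWC = torch::rand({N, H, W, 3});

  // Batch API after this PR: one permute over the whole batch (a view).
  torch::Tensor batchNCHW = batchNHWC.permute({0, 3, 1, 2});

  // Old behavior, roughly: each decoded HWC frame permuted on its own.
  for (int64_t i = 0; i < N; ++i) {
    torch::Tensor frameCHW = batchNHWC[i].permute({2, 0, 1});
    TORCH_CHECK(torch::equal(frameCHW, batchNCHW[i]), "views must agree");
  }
  return 0;
}
```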
} else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
// TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -978,7 +970,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
return seconds >= frameStartTime && seconds < frameEndTime;
});
// Convert the frame to tensor.
-  return convertAVFrameToDecodedOutput(rawOutput);
+  auto output = convertAVFrameToDecodedOutput(rawOutput);
+  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  return output;
}

void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) {
@@ -1015,6 +1009,16 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
int streamIndex,
int64_t frameIndex,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto output = getFrameAtIndexInternal(
+      streamIndex, frameIndex, preAllocatedOutputTensor);
+  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
+  return output;
+}
+
+VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
+    int streamIndex,
+    int64_t frameIndex,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
validateUserProvidedStreamIndex(streamIndex);
validateScannedAllStreams("getFrameAtIndex");

@@ -1023,7 +1027,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(

int64_t pts = stream.allFrames[frameIndex].pts;
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
-  return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
+  return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor);
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1073,14 +1077,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
output.durationSeconds[indexInOutput] =
output.durationSeconds[previousIndexInOutput];
} else {
-      DecodedOutput singleOut = getFrameAtIndex(
+      DecodedOutput singleOut = getFrameAtIndexInternal(
streamIndex, indexInVideo, output.frames[indexInOutput]);
output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
}
previousIndexInVideo = indexInVideo;
}
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
return output;
}

@@ -1150,11 +1154,12 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
BatchDecodedOutput output(numOutputFrames, options, streamMetadata);

for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    DecodedOutput singleOut =
+        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
return output;
}

@@ -1207,7 +1212,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
// need this special case below.
if (startSeconds == stopSeconds) {
BatchDecodedOutput output(0, options, streamMetadata);
-    output.frames = MaybePermuteHWC2CHW(options, output.frames);
+    output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
return output;
}

@@ -1243,11 +1248,12 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
int64_t numFrames = stopFrameIndex - startFrameIndex;
BatchDecodedOutput output(numFrames, options, streamMetadata);
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    DecodedOutput singleOut =
+        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);

return output;
}
@@ -1261,7 +1267,13 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
return rawOutput;
}

-VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
+VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
+  auto output = getNextFrameOutputNoDemuxInternal();
+  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  return output;
+}
+
+VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
auto rawOutput = getNextRawDecodedOutputNoDemux();
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
@@ -1283,7 +1295,7 @@ double VideoDecoder::getPtsSecondsForFrame(
int streamIndex,
int64_t frameIndex) {
validateUserProvidedStreamIndex(streamIndex);
validateScannedAllStreams("getFrameAtIndex");
validateScannedAllStreams("getPtsSecondsForFrame");

const auto& stream = streams_[streamIndex];
validateFrameIndex(stream, frameIndex);
16 changes: 12 additions & 4 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -157,10 +157,12 @@ class VideoDecoder {
int streamIndex,
const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());

+  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);

// ---- SINGLE FRAME SEEK AND DECODING API ----
// Places the cursor at the first frame on or after the position in seconds.
-  // Calling getNextDecodedOutputNoDemux() will return the first frame at or
-  // after this position.
+  // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at
+  // or after this position.
void setCursorPtsInSeconds(double seconds);
// This is an internal structure that is used to store the decoded output
// from decoding a frame through color conversion. Example usage is:
@@ -214,8 +216,7 @@
};
// Decodes the frame where the current cursor position is. It also advances
// the cursor to the next frame.
-  DecodedOutput getNextDecodedOutputNoDemux(
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  DecodedOutput getNextFrameNoDemux();
// Decodes the first frame in any added stream that is visible at a given
// timestamp. Frames in the video have a presentation timestamp and a
// duration. For example, if a frame has presentation timestamp of 5.0s and a
@@ -386,6 +387,13 @@ class VideoDecoder {
DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

+  DecodedOutput getFrameAtIndexInternal(
+      int streamIndex,
+      int64_t frameIndex,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  DecodedOutput getNextFrameOutputNoDemuxInternal(
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

Comment on the `getFrameAtIndexInternal` declaration:

> **Contributor:** Should you have a comment somewhere saying these are always returned in HWC? You could have a convention that `Internal`-suffixed functions always return in HWC.

> **Contributor Author:** Yes, definitely. Ultimately I want to label all functions in the decoding stack with expected input and output shapes. I'll follow up with that; I think it'll be part of a sequence of PRs, after the one about up-leveling the tensor allocation.

DecoderOptions options_;
ContainerMetadata containerMetadata_;
UniqueAVFormatContext formatContext_;
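A sketch of the convention proposed in the thread above, with illustrative names only (none of these helpers exist in TorchCodec): internal helpers keep frames in HWC, and a single boundary function applies the optional HWC-to-CHW permute exactly once.

```cpp
#include <string>

#include <torch/torch.h>

// Stand-in for a decoded frame; *Internal helpers always stay in HWC.
torch::Tensor decodeFrameHWCInternal() {
  return torch::rand({270, 480, 3});
}

// Public boundary: permute to CHW unless the caller asked for channels last.
torch::Tensor getFrame(const std::string& dimensionOrder) {
  torch::Tensor hwc = decodeFrameHWCInternal();
  if (dimensionOrder == "NHWC") {
    return hwc; // leave channels last
  }
  return hwc.permute({2, 0, 1}); // boundary permute: HWC -> CHW (a view)
}

int main() {
  TORCH_CHECK(getFrame("NCHW").sizes() == std::vector<int64_t>({3, 270, 480}));
  TORCH_CHECK(getFrame("NHWC").sizes() == std::vector<int64_t>({270, 480, 3}));
  return 0;
}
```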
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -193,7 +193,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
VideoDecoder::DecodedOutput result;
try {
-    result = videoDecoder->getNextDecodedOutputNoDemux();
+    result = videoDecoder->getNextFrameNoDemux();
} catch (const VideoDecoder::EndOfFileException& e) {
C10_THROW_ERROR(IndexError, e.what());
}
24 changes: 12 additions & 12 deletions test/decoders/VideoDecoderTest.cpp
@@ -148,7 +148,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) {
streamOptions.width = 100;
streamOptions.height = 120;
decoder->addVideoStreamDecoder(-1, streamOptions);
-  torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+  torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
EXPECT_EQ(tensor.sizes(), std::vector<long>({3, 120, 100}));
}

@@ -159,7 +159,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) {
VideoDecoder::VideoStreamDecoderOptions streamOptions;
streamOptions.dimensionOrder = "NHWC";
decoder->addVideoStreamDecoder(-1, streamOptions);
-  torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+  torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
EXPECT_EQ(tensor.sizes(), std::vector<long>({270, 480, 3}));
}

@@ -168,12 +168,12 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
std::unique_ptr<VideoDecoder> ourDecoder =
createDecoderFromPath(path, GetParam());
ourDecoder->addVideoStreamDecoder(-1);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextFrameNoDemux();
torch::Tensor tensor0FromOurDecoder = output.frame;
EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
EXPECT_EQ(output.ptsSeconds, 0.0);
EXPECT_EQ(output.pts, 0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
torch::Tensor tensor1FromOurDecoder = output.frame;
EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
@@ -254,11 +254,11 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) {
createDecoderFromPath(path, GetParam());
ourDecoder->addVideoStreamDecoder(-1);
ourDecoder->setCursorPtsInSeconds(388388. / 30'000);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextFrameNoDemux();
EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
-  EXPECT_THROW(ourDecoder->getNextDecodedOutputNoDemux(), std::exception);
+  EXPECT_THROW(ourDecoder->getNextFrameNoDemux(), std::exception);
}

TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) {
@@ -298,7 +298,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
createDecoderFromPath(path, GetParam());
ourDecoder->addVideoStreamDecoder(-1);
ourDecoder->setCursorPtsInSeconds(6.0);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextFrameNoDemux();
torch::Tensor tensor6FromOurDecoder = output.frame;
EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
torch::Tensor tensor6FromFFMPEG =
@@ -314,7 +314,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 180);

ourDecoder->setCursorPtsInSeconds(6.1);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
torch::Tensor tensor61FromOurDecoder = output.frame;
EXPECT_EQ(output.ptsSeconds, 183'183. / 30'000);
torch::Tensor tensor61FromFFMPEG =
@@ -334,7 +334,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
EXPECT_LT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 10);

ourDecoder->setCursorPtsInSeconds(10.0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
torch::Tensor tensor10FromOurDecoder = output.frame;
EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000);
torch::Tensor tensor10FromFFMPEG =
@@ -351,7 +351,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 60);

ourDecoder->setCursorPtsInSeconds(6.0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
tensor6FromOurDecoder = output.frame;
EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG));
@@ -366,7 +366,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {

constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9
ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
torch::Tensor tensor7FromOurDecoder = output.frame;
EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
torch::Tensor tensor7FromFFMPEG =
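For context, a hedged usage sketch of the renamed public API, mirroring the tests above; the include path and the video file name are assumptions, not taken from this PR:

```cpp
#include <iostream>
#include <memory>

#include "VideoDecoder.h" // include path assumed

using facebook::torchcodec::VideoDecoder;

int main() {
  // "video.mp4" is a placeholder; the tests use checked-in assets.
  std::unique_ptr<VideoDecoder> decoder =
      VideoDecoder::createFromFilePath("video.mp4");
  decoder->addVideoStreamDecoder(-1);

  // After this PR, the public call permutes to CHW (the default
  // dimensionOrder) before returning.
  auto output = decoder->getNextFrameNoDemux();
  std::cout << output.frame.sizes() << std::endl; // e.g. [3, 270, 480]
  return 0;
}
```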