benchmarks/decoders/gpu_benchmark.py (206 additions, 0 deletions)
@@ -0,0 +1,206 @@
import argparse
import os
import pathlib
import time
from concurrent.futures import ThreadPoolExecutor

import torch

import torch.utils.benchmark as benchmark

import torchcodec
import torchvision.transforms.v2.functional as F

RESIZED_WIDTH = 256
RESIZED_HEIGHT = 256


def transfer_and_resize_frame(frame, resize_device_string):
# This should be a no-op if the frame is already on the target device.
frame = frame.to(resize_device_string)
frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH))
return frame


def decode_full_video(video_path, decode_device_string, resize_device_string):
# We use the core API instead of SimpleVideoDecoder because the core API
# allows us to natively resize as part of the decode step.
print(f"{decode_device_string=} {resize_device_string=}")
decoder = torchcodec.decoders._core.create_from_file(video_path)
num_threads = None
if "cuda" in decode_device_string:
num_threads = 1
width = None
height = None
if "native" in resize_device_string:
width = RESIZED_WIDTH
height = RESIZED_HEIGHT
torchcodec.decoders._core._add_video_stream(
decoder,
stream_index=-1,
device=decode_device_string,
num_threads=num_threads,
width=width,
height=height,
)

start_time = time.time()
frame_count = 0
    while True:
        try:
            frame, *_ = torchcodec.decoders._core.get_next_frame(decoder)
            if resize_device_string != "none" and "native" not in resize_device_string:
                frame = transfer_and_resize_frame(frame, resize_device_string)

            frame_count += 1
        except Exception as e:
            # get_next_frame() raises once the stream is exhausted; we rely on
            # that exception to end the decode loop.
            print("EXCEPTION", e)
            break

    end_time = time.time()
    elapsed = end_time - start_time
    fps = frame_count / elapsed
    print(
        f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}"
    )
    return frame_count, elapsed


def decode_videos_using_threads(
video_path,
decode_device_string,
resize_device_string,
num_videos,
num_threads,
use_multiple_gpus,
):
executor = ThreadPoolExecutor(max_workers=num_threads)
for i in range(num_videos):
actual_decode_device = decode_device_string
if "cuda" in decode_device_string and use_multiple_gpus:
actual_decode_device = f"cuda:{i % torch.cuda.device_count()}"
executor.submit(
decode_full_video, video_path, actual_decode_device, resize_device_string
)
executor.shutdown(wait=True)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--devices",
default="cuda:0,cpu",
type=str,
help="Comma-separated devices to test decoding on.",
)
parser.add_argument(
"--resize_devices",
default="cuda:0,cpu,native,none",
type=str,
help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.",
)
parser.add_argument(
"--video",
type=str,
default=str(
pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
),
)
parser.add_argument(
"--use_torch_benchmark",
action=argparse.BooleanOptionalAction,
default=True,
help=(
"Use pytorch benchmark to measure decode time with warmup and "
"autorange. Without this we just run one iteration without warmup "
"to measure the cold start time."
),
)
parser.add_argument(
"--num_threads",
type=int,
default=1,
help="Number of threads to use for decoding. Only used when --use_torch_benchmark is set.",
)
parser.add_argument(
"--num_videos",
type=int,
default=50,
help="Number of videos to decode in parallel. Only used when --num_threads is set.",
)
parser.add_argument(
"--use_multiple_gpus",
action=argparse.BooleanOptionalAction,
default=True,
help=("Use multiple GPUs to decode multiple videos in multi-threaded mode."),
)
args = parser.parse_args()
video_path = args.video

if not args.use_torch_benchmark:
for device in args.devices.split(","):
print("Testing on", device)
            decode_full_video(video_path, device, "none")
return

resize_devices = args.resize_devices.split(",")
resize_devices = [d for d in resize_devices if d != ""]
if len(resize_devices) == 0:
resize_devices.append("none")

label = "Decode+Resize Time"

results = []
for decode_device_string in args.devices.split(","):
for resize_device_string in resize_devices:
decode_label = decode_device_string
if "cuda" in decode_label:
# Shorten "cuda:0" to "cuda"
decode_label = "cuda"
resize_label = resize_device_string
if "cuda" in resize_device_string:
# Shorten "cuda:0" to "cuda"
resize_label = "cuda"
print("decode_device", decode_device_string)
print("resize_device", resize_device_string)
if args.num_threads > 1:
t = benchmark.Timer(
stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads, use_multiple_gpus)",
globals={
"decode_device_string": decode_device_string,
"video_path": video_path,
"decode_full_video": decode_full_video,
"decode_videos_using_threads": decode_videos_using_threads,
"resize_device_string": resize_device_string,
"num_videos": args.num_videos,
"num_threads": args.num_threads,
"use_multiple_gpus": args.use_multiple_gpus,
},
label=label,
description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}",
sub_label=f"D={decode_label} R={resize_label} T={args.num_threads} W={args.num_videos}",
).blocked_autorange()
results.append(t)
else:
t = benchmark.Timer(
stmt="decode_full_video(video_path, decode_device_string, resize_device_string)",
globals={
"decode_device_string": decode_device_string,
"video_path": video_path,
"decode_full_video": decode_full_video,
"resize_device_string": resize_device_string,
},
label=label,
description=f"video={os.path.basename(video_path)}",
sub_label=f"D={decode_label} R={resize_label}",
).blocked_autorange()
results.append(t)
compare = benchmark.Compare(results)
compare.print()
print("Key: D=Decode, R=Resize T=threads W=work (number of videos to decode)")
print("Native resize is done as part of the decode step")
print("none resize means there is no resize step -- native or otherwise")


if __name__ == "__main__":
main()
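
For reference, here is a minimal self-contained sketch (not part of the PR) of the two measurement modes that --use_torch_benchmark toggles between, with a trivial statement standing in for the decode call:

import time

import torch.utils.benchmark as benchmark

# With --use_torch_benchmark (the default): Timer warms the statement up and
# repeats it until enough samples accumulate, reporting a robust estimate
# rather than a single noisy sample.
t = benchmark.Timer(stmt="sum(range(1000))")
print(t.blocked_autorange())

# Without the flag: a single cold iteration, as decode_full_video() measures.
start = time.time()
sum(range(1000))
print(time.time() - start)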
src/torchcodec/decoders/_core/CMakeLists.txt (7 additions, 3 deletions)
@@ -34,12 +34,16 @@ function(make_torchcodec_library library_name ffmpeg_target)
${Python3_INCLUDE_DIRS}
)

  set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES}
      ${Python3_LIBRARIES})
  if(ENABLE_CUDA)
    list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY}
        ${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY})
  endif()
  target_link_libraries(
      ${library_name}
      PUBLIC
      ${NEEDED_LIBRARIES}
  )

# We already set the library_name to be libtorchcodecN, so we don't want
src/torchcodec/decoders/_core/CPUOnlyDevice.cpp (15 additions, 3 deletions)
@@ -1,19 +1,31 @@
#include <torch/types.h>
#include "src/torchcodec/decoders/_core/DeviceInterface.h"

namespace facebook::torchcodec {

// This file is linked with the CPU-only version of torchcodec.
// So all functions will throw an error because they should only be called if
// the device is not CPU.

[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) {
  TORCH_CHECK(
      device.type() != torch::kCPU,
      "Device functions should only be called if the device is not CPU.");
  TORCH_CHECK(false, "Unsupported device: " + device.str());
}

void convertAVFrameToDecodedOutputOnCuda(
const torch::Device& device,
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output) {
throwUnsupportedDeviceError(device);
[Review comment from Contributor]
Should this function and the function below always throw? If yes, then we
should just do something like TORCH_CHECK(false, "Unsupported device.");. In
order to avoid the need for a return value, mark the function as [[noreturn]]:
https://en.cppreference.com/w/cpp/language/attributes/noreturn. We should rely
on a TORCH macro to do the throwing for us rather than doing the throw
ourselves, and we should make it obviously one that will always fail its check.

[Reply from Contributor Author]
Good suggestion. Done.

[Review comment from Contributor]
I think maybe we should also annotate convertAVFrameToDecodedOutputOnDevice()
and initializeDeviceContext() with [[noreturn]]. Let's also avoid two
TORCH_CHECK calls. Whatever message we want to put on stderr, we can do it in
one check.

[Reply from Contributor Author]
The two checks are there because one is a programming/logic error on our part:
we should never pass in a CPU device for device functions. The other is the
check for passing in a non-compiled device.

}

void initializeContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext) {
throwUnsupportedDeviceError(device);
}
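
To illustrate the pattern settled on in the review thread above, here is a
minimal sketch (not part of the PR; the function name is hypothetical): an
always-throwing helper is annotated with [[noreturn]] so callers need no dummy
return value, and a TORCH_CHECK whose condition is literally false makes the
unconditional throw obvious.

#include <torch/types.h>

// Hypothetical example: always throws, and says so to the compiler.
[[noreturn]] void throwNotImplementedForDevice(const torch::Device& device) {
  // TORCH_CHECK(false, ...) always fails its check, so this line both builds
  // the message and throws; [[noreturn]] documents that nothing follows.
  TORCH_CHECK(false, "Not implemented for device: ", device.str());
}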

src/torchcodec/decoders/_core/CudaDevice.cpp (108 additions, 5 deletions)
@@ -1,6 +1,52 @@
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <npp.h>
#include <torch/types.h>
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h"
#include "src/torchcodec/decoders/_core/VideoDecoder.h"

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/hwcontext_cuda.h>
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {
namespace {
AVBufferRef* getCudaContext(const torch::Device& device) {
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
torch::DeviceIndex deviceIndex = device.index();
// FFMPEG cannot handle negative device indices.
[Review comment from Contributor]
I understand now what instigated this code, but I still can't evaluate whether
it's correct. Looking at the docs, a negative value indicates the "current
device": https://pytorch.org/cppdocs/api/structc10_1_1_device.html#_CPPv4N3c106DeviceE
Is it safe to map all values of "current device" to 0? Is this a mapping we
need to track? What happens when we are on a system with multiple GPUs? I'm
assuming we don't fully understand the answers to these questions, and I don't
want to block progress, so I think we should have a meatier comment both
explaining what we do know and indicating this may be a problem in the future.

[Reply from Contributor Author]
I have added a longer comment with a TODO to investigate that it works
properly with a multi-GPU setup. I am sure once users start using it, we will
hit more edge cases.

  // For single-GPU machines libtorch returns -1 for the device index. So for
// that case we set the device index to 0.
// TODO: Double check if this works for multi-GPU machines correctly.
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0);
std::string deviceOrdinal = std::to_string(deviceIndex);
AVBufferRef* hw_device_ctx;
int err = av_hwdevice_ctx_create(
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
  if (err < 0) {
    TORCH_CHECK(
        false,
        "Failed to create specified HW device: ",
        getFFMPEGErrorStringFromErrorCode(err));
  }
return hw_device_ctx;
}
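
As background for the device-index thread above, a small sketch (not from the
PR) showing why the clamp is needed: libtorch reports -1, meaning "current
device", for a CUDA device constructed without an ordinal, while FFmpeg's
av_hwdevice_ctx_create() expects a non-negative ordinal string.

#include <cstdio>
#include <torch/types.h>

void printDeviceIndices() {
  torch::Device implicitDevice(torch::kCUDA); // no ordinal: index() == -1
  torch::Device explicitDevice("cuda:0");     // explicit ordinal: index() == 0
  std::printf(
      "%d %d\n", (int)implicitDevice.index(), (int)explicitDevice.index());
}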

torch::Tensor allocateDeviceTensor(
at::IntArrayRef shape,
torch::Device device,
const torch::Dtype dtype = torch::kUInt8) {
return torch::empty(
shape,
torch::TensorOptions()
.dtype(dtype)
.layout(torch::kStrided)
.device(device));
}

void throwErrorIfNonCudaDevice(const torch::Device& device) {
TORCH_CHECK(
@@ -10,13 +56,70 @@ void throwErrorIfNonCudaDevice(const torch::Device& device) {
throw std::runtime_error("Unsupported device: " + device.str());
}
}
} // namespace

void initializeContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext) {
throwErrorIfNonCudaDevice(device);
// It is important for pytorch itself to create the cuda context. If ffmpeg
// creates the context it may not be compatible with pytorch.
// This is a dummy tensor to initialize the cuda context.
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
codecContext->hw_device_ctx = getCudaContext(device);
return;
}

void convertAVFrameToDecodedOutputOnCuda(
const torch::Device& device,
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output) {
AVFrame* src = rawOutput.frame.get();

TORCH_CHECK(
src->format == AV_PIX_FMT_CUDA,
"Expected format to be AV_PIX_FMT_CUDA, got " +
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
int width = options.width.value_or(codecContext->width);
int height = options.height.value_or(codecContext->height);
NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {src->data[0], src->data[1]};
torch::Tensor& dst = output.frame;
dst = allocateDeviceTensor({height, width, 3}, options.device);

// Use the user-requested GPU for running the NPP kernel.
c10::cuda::CUDAGuard deviceGuard(device);

auto start = std::chrono::high_resolution_clock::now();

NppStatus status = nppiNV12ToRGB_8u_P2C3R(
input,
src->linesize[0],
static_cast<Npp8u*>(dst.data_ptr()),
dst.stride(0),
oSizeROI);
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
// Make the pytorch stream wait for the npp kernel to finish before using the
// output.
at::cuda::CUDAEvent nppDoneEvent;
at::cuda::CUDAStream nppStreamWrapper =
c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
nppDoneEvent.record(nppStreamWrapper);
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());

auto end = std::chrono::high_resolution_clock::now();

std::chrono::duration<double, std::micro> duration = end - start;
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
<< " took: " << duration.count() << "us" << std::endl;
if (options.dimensionOrder == "NCHW") {
    // The docs guarantee this to return a view:
// https://pytorch.org/docs/stable/generated/torch.permute.html
dst = dst.permute({2, 0, 1});
}
}

} // namespace facebook::torchcodec
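
The NPP-to-PyTorch handoff above uses the standard CUDA event pattern for
ordering work across streams. A minimal sketch of just that pattern (not from
the PR; assumes a libtorch CUDA build):

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

// Make the current PyTorch stream wait, on the GPU and without blocking the
// CPU, for all work already queued on `producer` (e.g. the NPP stream).
void orderCurrentStreamAfter(const at::cuda::CUDAStream& producer) {
  at::cuda::CUDAEvent done;
  done.record(producer);                        // snapshot producer progress
  done.block(at::cuda::getCurrentCUDAStream()); // consumer waits on the event
}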