-
Notifications
You must be signed in to change notification settings - Fork 64
Add CUDA decoding support #242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
72bdd25
707cff3
0c916b3
3600eee
45df9dc
05d02a4
9f169c9
461a2ff
41a1ba2
57818c5
1611245
58624d0
e576929
dca3540
8ff05ee
a52ba5c
a98aaa3
ec160a9
f096a16
0c78564
32cdb37
27bb2b2
e65ddc3
0abc173
21d8c1a
e4e02b3
4624f5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
import argparse | ||
import os | ||
import pathlib | ||
import time | ||
from concurrent.futures import ThreadPoolExecutor | ||
|
||
import torch | ||
|
||
import torch.utils.benchmark as benchmark | ||
|
||
import torchcodec | ||
import torchvision.transforms.v2.functional as F | ||
|
||
RESIZED_WIDTH = 256 | ||
RESIZED_HEIGHT = 256 | ||
|
||
|
||
def transfer_and_resize_frame(frame, resize_device_string): | ||
# This should be a no-op if the frame is already on the target device. | ||
frame = frame.to(resize_device_string) | ||
frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH)) | ||
return frame | ||
|
||
|
||
def decode_full_video(video_path, decode_device_string, resize_device_string): | ||
# We use the core API instead of SimpleVideoDecoder because the core API | ||
# allows us to natively resize as part of the decode step. | ||
print(f"{decode_device_string=} {resize_device_string=}") | ||
decoder = torchcodec.decoders._core.create_from_file(video_path) | ||
num_threads = None | ||
if "cuda" in decode_device_string: | ||
num_threads = 1 | ||
width = None | ||
height = None | ||
if "native" in resize_device_string: | ||
width = RESIZED_WIDTH | ||
height = RESIZED_HEIGHT | ||
torchcodec.decoders._core._add_video_stream( | ||
decoder, | ||
stream_index=-1, | ||
device=decode_device_string, | ||
num_threads=num_threads, | ||
width=width, | ||
height=height, | ||
) | ||
|
||
start_time = time.time() | ||
frame_count = 0 | ||
while True: | ||
try: | ||
frame, *_ = torchcodec.decoders._core.get_next_frame(decoder) | ||
if resize_device_string != "none" and "native" not in resize_device_string: | ||
frame = transfer_and_resize_frame(frame, resize_device_string) | ||
|
||
frame_count += 1 | ||
except Exception as e: | ||
print("EXCEPTION", e) | ||
break | ||
|
||
end_time = time.time() | ||
elapsed = end_time - start_time | ||
fps = frame_count / (end_time - start_time) | ||
print( | ||
f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}" | ||
) | ||
return frame_count, end_time - start_time | ||
|
||
|
||
def decode_videos_using_threads( | ||
video_path, | ||
decode_device_string, | ||
resize_device_string, | ||
num_videos, | ||
num_threads, | ||
use_multiple_gpus, | ||
): | ||
executor = ThreadPoolExecutor(max_workers=num_threads) | ||
for i in range(num_videos): | ||
actual_decode_device = decode_device_string | ||
if "cuda" in decode_device_string and use_multiple_gpus: | ||
actual_decode_device = f"cuda:{i % torch.cuda.device_count()}" | ||
executor.submit( | ||
decode_full_video, video_path, actual_decode_device, resize_device_string | ||
) | ||
executor.shutdown(wait=True) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--devices", | ||
default="cuda:0,cpu", | ||
type=str, | ||
help="Comma-separated devices to test decoding on.", | ||
) | ||
parser.add_argument( | ||
"--resize_devices", | ||
default="cuda:0,cpu,native,none", | ||
type=str, | ||
help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.", | ||
) | ||
parser.add_argument( | ||
"--video", | ||
type=str, | ||
default=str( | ||
pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4" | ||
), | ||
) | ||
parser.add_argument( | ||
"--use_torch_benchmark", | ||
action=argparse.BooleanOptionalAction, | ||
default=True, | ||
help=( | ||
"Use pytorch benchmark to measure decode time with warmup and " | ||
"autorange. Without this we just run one iteration without warmup " | ||
"to measure the cold start time." | ||
), | ||
) | ||
parser.add_argument( | ||
"--num_threads", | ||
type=int, | ||
default=1, | ||
help="Number of threads to use for decoding. Only used when --use_torch_benchmark is set.", | ||
) | ||
parser.add_argument( | ||
"--num_videos", | ||
type=int, | ||
default=50, | ||
help="Number of videos to decode in parallel. Only used when --num_threads is set.", | ||
) | ||
parser.add_argument( | ||
"--use_multiple_gpus", | ||
action=argparse.BooleanOptionalAction, | ||
default=True, | ||
help=("Use multiple GPUs to decode multiple videos in multi-threaded mode."), | ||
) | ||
args = parser.parse_args() | ||
video_path = args.video | ||
|
||
if not args.use_torch_benchmark: | ||
for device in args.devices.split(","): | ||
print("Testing on", device) | ||
decode_full_video(video_path, device) | ||
return | ||
|
||
resize_devices = args.resize_devices.split(",") | ||
resize_devices = [d for d in resize_devices if d != ""] | ||
if len(resize_devices) == 0: | ||
resize_devices.append("none") | ||
|
||
label = "Decode+Resize Time" | ||
|
||
results = [] | ||
for decode_device_string in args.devices.split(","): | ||
for resize_device_string in resize_devices: | ||
decode_label = decode_device_string | ||
if "cuda" in decode_label: | ||
# Shorten "cuda:0" to "cuda" | ||
decode_label = "cuda" | ||
resize_label = resize_device_string | ||
if "cuda" in resize_device_string: | ||
# Shorten "cuda:0" to "cuda" | ||
resize_label = "cuda" | ||
print("decode_device", decode_device_string) | ||
print("resize_device", resize_device_string) | ||
if args.num_threads > 1: | ||
t = benchmark.Timer( | ||
stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads, use_multiple_gpus)", | ||
globals={ | ||
"decode_device_string": decode_device_string, | ||
"video_path": video_path, | ||
"decode_full_video": decode_full_video, | ||
"decode_videos_using_threads": decode_videos_using_threads, | ||
"resize_device_string": resize_device_string, | ||
"num_videos": args.num_videos, | ||
"num_threads": args.num_threads, | ||
"use_multiple_gpus": args.use_multiple_gpus, | ||
}, | ||
label=label, | ||
description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}", | ||
sub_label=f"D={decode_label} R={resize_label} T={args.num_threads} W={args.num_videos}", | ||
).blocked_autorange() | ||
results.append(t) | ||
else: | ||
t = benchmark.Timer( | ||
stmt="decode_full_video(video_path, decode_device_string, resize_device_string)", | ||
globals={ | ||
"decode_device_string": decode_device_string, | ||
"video_path": video_path, | ||
"decode_full_video": decode_full_video, | ||
"resize_device_string": resize_device_string, | ||
}, | ||
label=label, | ||
description=f"video={os.path.basename(video_path)}", | ||
sub_label=f"D={decode_label} R={resize_label}", | ||
).blocked_autorange() | ||
results.append(t) | ||
compare = benchmark.Compare(results) | ||
compare.print() | ||
print("Key: D=Decode, R=Resize T=threads W=work (number of videos to decode)") | ||
print("Native resize is done as part of the decode step") | ||
print("none resize means there is no resize step -- native or otherwise") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,52 @@ | ||
#include <ATen/cuda/CUDAEvent.h> | ||
#include <c10/cuda/CUDAStream.h> | ||
#include <npp.h> | ||
#include <torch/types.h> | ||
#include "src/torchcodec/decoders/_core/DeviceInterface.h" | ||
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h" | ||
#include "src/torchcodec/decoders/_core/VideoDecoder.h" | ||
|
||
extern "C" { | ||
#include <libavcodec/avcodec.h> | ||
#include <libavutil/hwcontext_cuda.h> | ||
#include <libavutil/pixdesc.h> | ||
} | ||
|
||
namespace facebook::torchcodec { | ||
namespace { | ||
AVBufferRef* getCudaContext(const torch::Device& device) { | ||
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); | ||
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); | ||
torch::DeviceIndex deviceIndex = device.index(); | ||
// FFMPEG cannot handle negative device indices. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand now what instigated this code, but I still can't evaluate if it's correct. Looking at the docs, a negative value indicates the "current device": https://pytorch.org/cppdocs/api/structc10_1_1_device.html#_CPPv4N3c106DeviceE Is it safe to map all values of "current device" to 0? Is this a mapping we need to track? What happens when we are on a system with multiple GPUs? I'm assuming we don't fully understand the answers to these questions, and I don't want to block progress. So I think we should have a meatier comment both explaining what we do know, and indicating this may be a problem in the future. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added a longer comment with a TODO to investigate that it works properly with multi-GPU setup. I am sure once users start using it, we will hit more edge cases. |
||
// For single GPU- machines libtorch returns -1 for the device index. So for | ||
// that case we set the device index to 0. | ||
// TODO: Double check if this works for multi-GPU machines correctly. | ||
deviceIndex = std::max<at::DeviceIndex>(deviceIndex, 0); | ||
std::string deviceOrdinal = std::to_string(deviceIndex); | ||
AVBufferRef* hw_device_ctx; | ||
int err = av_hwdevice_ctx_create( | ||
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0); | ||
if (err < 0) { | ||
TORCH_CHECK( | ||
false, | ||
"Failed to create specified HW device", | ||
getFFMPEGErrorStringFromErrorCode(err)); | ||
} | ||
return hw_device_ctx; | ||
} | ||
|
||
torch::Tensor allocateDeviceTensor( | ||
at::IntArrayRef shape, | ||
torch::Device device, | ||
const torch::Dtype dtype = torch::kUInt8) { | ||
return torch::empty( | ||
shape, | ||
torch::TensorOptions() | ||
.dtype(dtype) | ||
.layout(torch::kStrided) | ||
.device(device)); | ||
} | ||
|
||
void throwErrorIfNonCudaDevice(const torch::Device& device) { | ||
TORCH_CHECK( | ||
|
@@ -10,13 +56,70 @@ void throwErrorIfNonCudaDevice(const torch::Device& device) { | |
throw std::runtime_error("Unsupported device: " + device.str()); | ||
} | ||
} | ||
} // namespace | ||
|
||
void initializeDeviceContext(const torch::Device& device) { | ||
void initializeContextOnCuda( | ||
const torch::Device& device, | ||
AVCodecContext* codecContext) { | ||
throwErrorIfNonCudaDevice(device); | ||
// TODO: https://github.com/pytorch/torchcodec/issues/238: Implement CUDA | ||
// device. | ||
throw std::runtime_error( | ||
"CUDA device is unimplemented. Follow this issue for tracking progress: https://github.com/pytorch/torchcodec/issues/238"); | ||
// It is important for pytorch itself to create the cuda context. If ffmpeg | ||
// creates the context it may not be compatible with pytorch. | ||
// This is a dummy tensor to initialize the cuda context. | ||
torch::Tensor dummyTensorForCudaInitialization = torch::empty( | ||
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); | ||
ahmadsharif1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
codecContext->hw_device_ctx = getCudaContext(device); | ||
return; | ||
} | ||
|
||
void convertAVFrameToDecodedOutputOnCuda( | ||
const torch::Device& device, | ||
const VideoDecoder::VideoStreamDecoderOptions& options, | ||
AVCodecContext* codecContext, | ||
VideoDecoder::RawDecodedOutput& rawOutput, | ||
VideoDecoder::DecodedOutput& output) { | ||
AVFrame* src = rawOutput.frame.get(); | ||
|
||
TORCH_CHECK( | ||
src->format == AV_PIX_FMT_CUDA, | ||
"Expected format to be AV_PIX_FMT_CUDA, got " + | ||
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); | ||
int width = options.width.value_or(codecContext->width); | ||
int height = options.height.value_or(codecContext->height); | ||
NppiSize oSizeROI = {width, height}; | ||
Npp8u* input[2] = {src->data[0], src->data[1]}; | ||
torch::Tensor& dst = output.frame; | ||
dst = allocateDeviceTensor({height, width, 3}, options.device); | ||
|
||
// Use the user-requested GPU for running the NPP kernel. | ||
c10::cuda::CUDAGuard deviceGuard(device); | ||
|
||
auto start = std::chrono::high_resolution_clock::now(); | ||
|
||
NppStatus status = nppiNV12ToRGB_8u_P2C3R( | ||
input, | ||
src->linesize[0], | ||
static_cast<Npp8u*>(dst.data_ptr()), | ||
dst.stride(0), | ||
oSizeROI); | ||
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); | ||
// Make the pytorch stream wait for the npp kernel to finish before using the | ||
// output. | ||
ahmadsharif1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
at::cuda::CUDAEvent nppDoneEvent; | ||
at::cuda::CUDAStream nppStreamWrapper = | ||
c10::cuda::getStreamFromExternal(nppGetStream(), device.index()); | ||
nppDoneEvent.record(nppStreamWrapper); | ||
nppDoneEvent.block(at::cuda::getCurrentCUDAStream()); | ||
|
||
auto end = std::chrono::high_resolution_clock::now(); | ||
|
||
std::chrono::duration<double, std::micro> duration = end - start; | ||
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width | ||
<< " took: " << duration.count() << "us" << std::endl; | ||
if (options.dimensionOrder == "NCHW") { | ||
// The docs guaranty this to return a view: | ||
// https://pytorch.org/docs/stable/generated/torch.permute.html | ||
dst = dst.permute({2, 0, 1}); | ||
} | ||
} | ||
|
||
} // namespace facebook::torchcodec |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this function and the function below always throw? If yes, then we should just do something like
TORCH_CHECK(false, "Unsupported device.");
. In order avoid the need for a return value, mark the function as[[noreturn]]
: https://en.cppreference.com/w/cpp/language/attributes/noreturn. We should rely on a TORCH macro to do the throwing for us rather than doing the throw ourselves, and we should make it obviously one that will always fail its check.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good suggestion. Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we maybe should also annotate
convertAVFrameToDecodedOutputOnDevice()
andinitializeDeviceContext()
with[[noreturn]]
. Let's also avoid twoTORCH_CHECK
calls. Whatever message we want to put on stderr, we can do it in one check.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The two checks are there because one is a programming/logic error on our part -- we should never pass in a CPU device for device functions.
The other is the check for passing in a non-compiled device.