Skip to content

Commit 12c75e7

Browse files
committed
Add h265 support
1 parent 70873bf commit 12c75e7

File tree

7 files changed

+136
-31
lines changed

7 files changed

+136
-31
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
138138
return UniqueCUvideodecoder(decoder, CUvideoDecoderDeleter{});
139139
}
140140

141+
// Maps an FFmpeg codec id to the cudaVideoCodec value that is handed to the
// CUVID parser as its CodecType. Only H264 and HEVC are mapped; any other
// codec id fails a TORCH_CHECK whose message carries the FFmpeg codec name.
cudaVideoCodec validateCodecSupport(AVCodecID codecId) {
  if (codecId == AV_CODEC_ID_H264) {
    return cudaVideoCodec_H264;
  }
  if (codecId == AV_CODEC_ID_HEVC) {
    return cudaVideoCodec_HEVC;
  }

  // TODONVDEC P0: support more codecs. Candidate mappings:
  //   AV_CODEC_ID_AV1   -> cudaVideoCodec_AV1
  //   AV_CODEC_ID_MPEG4 -> cudaVideoCodec_MPEG4
  //   AV_CODEC_ID_VP8   -> cudaVideoCodec_VP8
  //   AV_CODEC_ID_VP9   -> cudaVideoCodec_VP9
  //   AV_CODEC_ID_MJPEG -> cudaVideoCodec_JPEG
  TORCH_CHECK(false, "Unsupported codec type: ", avcodec_get_name(codecId));
}
158+
141159
} // namespace
142160

143161
BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
@@ -163,29 +181,62 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
163181
}
164182
}
165183

166-
void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
167-
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
168-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
184+
void BetaCudaDeviceInterface::initializeBSF(
185+
const AVCodecParameters* codecPar,
186+
const UniqueDecodingAVFormatContext& avFormatCtx) {
187+
// Setup bit stream filters (BSF):
188+
// https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
189+
// This is only needed for some formats, like H264 or HEVC.
169190

170-
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
171-
timeBase_ = avStream->time_base;
191+
TORCH_CHECK(codecPar != nullptr, "codecPar cannot be null");
192+
TORCH_CHECK(avFormatCtx != nullptr, "AVFormatContext cannot be null");
193+
TORCH_CHECK(
194+
avFormatCtx->iformat != nullptr,
195+
"AVFormatContext->iformat cannot be null");
196+
std::string filterName;
197+
198+
// Matching logic is taken from DALI
199+
switch (codecPar->codec_id) {
200+
case AV_CODEC_ID_H264: {
201+
const std::string formatName = avFormatCtx->iformat->long_name
202+
? avFormatCtx->iformat->long_name
203+
: "";
204+
205+
if (formatName == "QuickTime / MOV" ||
206+
formatName == "FLV (Flash Video)" ||
207+
formatName == "Matroska / WebM" || formatName == "raw H.264 video") {
208+
filterName = "h264_mp4toannexb";
209+
}
210+
break;
211+
}
172212

173-
const AVCodecParameters* codecpar = avStream->codecpar;
174-
TORCH_CHECK(codecpar != nullptr, "CodecParameters cannot be null");
213+
case AV_CODEC_ID_HEVC: {
214+
const std::string formatName = avFormatCtx->iformat->long_name
215+
? avFormatCtx->iformat->long_name
216+
: "";
175217

176-
TORCH_CHECK(
177-
// TODONVDEC P0 support more
178-
avStream->codecpar->codec_id == AV_CODEC_ID_H264,
179-
"Can only do H264 for now");
218+
if (formatName == "QuickTime / MOV" ||
219+
formatName == "FLV (Flash Video)" ||
220+
formatName == "Matroska / WebM" || formatName == "raw HEVC video") {
221+
filterName = "hevc_mp4toannexb";
222+
}
223+
break;
224+
}
180225

181-
// Setup bit stream filters (BSF):
182-
// https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
183-
// This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
184-
// now we apply BSF unconditionally, but it should be optional and dependent
185-
// on codec and container.
186-
const AVBitStreamFilter* avBSF = av_bsf_get_by_name("h264_mp4toannexb");
226+
default:
227+
// No bitstream filter needed for other codecs
228+
// TODONVDEC P1 MPEG4 will need one!
229+
break;
230+
}
231+
232+
if (filterName.empty()) {
233+
// Only initialize BSF if we actually need one
234+
return;
235+
}
236+
237+
const AVBitStreamFilter* avBSF = av_bsf_get_by_name(filterName.c_str());
187238
TORCH_CHECK(
188-
avBSF != nullptr, "Failed to find h264_mp4toannexb bitstream filter");
239+
avBSF != nullptr, "Failed to find bitstream filter: ", filterName);
189240

190241
AVBSFContext* avBSFContext = nullptr;
191242
int retVal = av_bsf_alloc(avBSF, &avBSFContext);
@@ -196,7 +247,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
196247

197248
bitstreamFilter_.reset(avBSFContext);
198249

199-
retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecpar);
250+
retVal = avcodec_parameters_copy(bitstreamFilter_->par_in, codecPar);
200251
TORCH_CHECK(
201252
retVal >= AVSUCCESS,
202253
"Failed to copy codec parameters: ",
@@ -207,10 +258,25 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
207258
retVal == AVSUCCESS,
208259
"Failed to initialize bitstream filter: ",
209260
getFFMPEGErrorStringFromErrorCode(retVal));
261+
}
262+
263+
void BetaCudaDeviceInterface::initializeInterface(
264+
const AVStream* avStream,
265+
const UniqueDecodingAVFormatContext& avFormatCtx) {
266+
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
267+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
268+
269+
TORCH_CHECK(avStream != nullptr, "AVStream cannot be null");
270+
timeBase_ = avStream->time_base;
271+
272+
const AVCodecParameters* codecPar = avStream->codecpar;
273+
TORCH_CHECK(codecPar != nullptr, "CodecParameters cannot be null");
274+
275+
initializeBSF(codecPar, avFormatCtx);
210276

211277
// Create parser. Default values that aren't obvious are taken from DALI.
212278
CUVIDPARSERPARAMS parserParams = {};
213-
parserParams.CodecType = cudaVideoCodec_H264;
279+
parserParams.CodecType = validateCodecSupport(codecPar->codec_id);
214280
parserParams.ulMaxNumDecodeSurfaces = 8;
215281
parserParams.ulMaxDisplayDelay = 0;
216282
// Callback setup, all are triggered by the parser within a call

src/torchcodec/_core/BetaCudaDeviceInterface.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
3737
explicit BetaCudaDeviceInterface(const torch::Device& device);
3838
virtual ~BetaCudaDeviceInterface();
3939

40-
void initializeInterface(AVStream* stream) override;
40+
void initializeInterface(
41+
const AVStream* stream,
42+
const UniqueDecodingAVFormatContext& avFormatCtx) override;
4143

4244
void convertAVFrameToFrameOutput(
4345
const VideoStreamOptions& videoStreamOptions,
@@ -63,6 +65,9 @@ class BetaCudaDeviceInterface : public DeviceInterface {
6365
private:
6466
// Apply bitstream filter, modifies packet in-place
6567
void applyBSF(ReferenceAVPacket& packet);
68+
void initializeBSF(
69+
const AVCodecParameters* codecPar,
70+
const UniqueDecodingAVFormatContext& avFormatCtx);
6671

6772
UniqueAVFrame convertCudaFrameToAVFrame(
6873
CUdeviceptr framePtr,

src/torchcodec/_core/DeviceInterface.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ class DeviceInterface {
5555
virtual void initializeContext(
5656
[[maybe_unused]] AVCodecContext* codecContext) {}
5757

58-
virtual void initializeInterface([[maybe_unused]] AVStream* stream) {}
58+
virtual void initializeInterface(
59+
[[maybe_unused]] const AVStream* stream,
60+
[[maybe_unused]] const UniqueDecodingAVFormatContext& avFormatCtx) {}
5961

6062
virtual void convertAVFrameToFrameOutput(
6163
const VideoStreamOptions& videoStreamOptions,

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ void SingleStreamDecoder::addStream(
462462
if (mediaType == AVMEDIA_TYPE_VIDEO) {
463463
if (deviceInterface_) {
464464
deviceInterface_->initializeContext(codecContext);
465-
deviceInterface_->initializeInterface(streamInfo.stream);
465+
deviceInterface_->initializeInterface(streamInfo.stream, formatContext_);
466466
}
467467
}
468468

test/resources/testsrc2_h265.mp4

890 KB
Binary file not shown.

test/test_decoders.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
SINE_MONO_S32_44100,
4545
SINE_MONO_S32_8000,
4646
TEST_SRC_2_720P,
47+
TEST_SRC_2_720P_H265,
4748
)
4849

4950

@@ -1415,7 +1416,9 @@ def test_get_frames_at_tensor_indices(self):
14151416
# assert_tensor_close_on_at_least or something like that.
14161417

14171418
@needs_cuda
1418-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1419+
@pytest.mark.parametrize(
1420+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1421+
)
14191422
@pytest.mark.parametrize("contiguous_indices", (True, False))
14201423
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14211424
def test_beta_cuda_interface_get_frame_at(
@@ -1445,7 +1448,9 @@ def test_beta_cuda_interface_get_frame_at(
14451448
assert beta_frame.duration_seconds == ref_frame.duration_seconds
14461449

14471450
@needs_cuda
1448-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1451+
@pytest.mark.parametrize(
1452+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1453+
)
14491454
@pytest.mark.parametrize("contiguous_indices", (True, False))
14501455
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14511456
def test_beta_cuda_interface_get_frames_at(
@@ -1476,7 +1481,9 @@ def test_beta_cuda_interface_get_frames_at(
14761481
)
14771482

14781483
@needs_cuda
1479-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1484+
@pytest.mark.parametrize(
1485+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1486+
)
14801487
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
14811488
def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
14821489
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
@@ -1498,7 +1505,9 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
14981505
assert beta_frame.duration_seconds == ref_frame.duration_seconds
14991506

15001507
@needs_cuda
1501-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1508+
@pytest.mark.parametrize(
1509+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1510+
)
15021511
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15031512
def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
15041513
ref_decoder = VideoDecoder(asset.path, device="cuda", seek_mode=seek_mode)
@@ -1521,7 +1530,9 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
15211530
)
15221531

15231532
@needs_cuda
1524-
@pytest.mark.parametrize("asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE))
1533+
@pytest.mark.parametrize(
1534+
"asset", (NASA_VIDEO, TEST_SRC_2_720P, BT709_FULL_RANGE, TEST_SRC_2_720P_H265)
1535+
)
15251536
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15261537
def test_beta_cuda_interface_backwards(self, asset, seek_mode):
15271538

@@ -1541,12 +1552,24 @@ def test_beta_cuda_interface_backwards(self, asset, seek_mode):
15411552
assert beta_frame.pts_seconds == ref_frame.pts_seconds
15421553
assert beta_frame.duration_seconds == ref_frame.duration_seconds
15431554

1555+
@needs_cuda
def test_beta_cuda_interface_small_h265(self):
    # TODONVDEC P2: investigate why/how the default interface can decode
    # this video.

    # Decoding the first frame works on the default CUDA interface - why?
    VideoDecoder(H265_VIDEO.path, device="cuda").get_frame_at(0)

    # The beta interface rejects the same video because of its input
    # validation checks, which were taken from DALI.
    # `match` is applied as a regular expression, so the literal "." in the
    # message is escaped to keep the assertion exact.
    with pytest.raises(
        RuntimeError,
        match=r"Video is too small in at least one dimension\. Provided: 128x128 vs supported:144x144",
    ):
        VideoDecoder(H265_VIDEO.path, device="cuda:0:beta").get_frame_at(0)
1568+
15441569
@needs_cuda
15451570
def test_beta_cuda_interface_error(self):
1546-
with pytest.raises(RuntimeError, match="Can only do H264 for now"):
1571+
with pytest.raises(RuntimeError, match="Unsupported codec type: av1"):
15471572
VideoDecoder(AV1_VIDEO.path, device="cuda:0:beta")
1548-
with pytest.raises(RuntimeError, match="Can only do H264 for now"):
1549-
VideoDecoder(H265_VIDEO.path, device="cuda:0:beta")
15501573
with pytest.raises(RuntimeError, match="Unsupported device"):
15511574
VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")
15521575

test/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,3 +688,12 @@ def sample_format(self) -> str:
688688
},
689689
frames={0: {}}, # Not needed for now
690690
)
691+
# Generated with: ffmpeg -f lavfi -i testsrc2=duration=10:size=1280x720:rate=30 -c:v libx265 -crf 23 -preset medium output.mp4
692+
TEST_SRC_2_720P_H265 = TestVideo(
693+
filename="testsrc2_h265.mp4",
694+
default_stream_index=0,
695+
stream_infos={
696+
0: TestVideoStreamInfo(width=1280, height=720, num_color_channels=3),
697+
},
698+
frames={0: {}}, # Not needed for now
699+
)

0 commit comments

Comments
 (0)