@@ -138,6 +138,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
138
138
return UniqueCUvideodecoder (decoder, CUvideoDecoderDeleter{});
139
139
}
140
140
141
+ cudaVideoCodec validateCodecSupport (AVCodecID codecId) {
142
+ switch (codecId) {
143
+ case AV_CODEC_ID_H264:
144
+ return cudaVideoCodec_H264;
145
+ case AV_CODEC_ID_HEVC:
146
+ return cudaVideoCodec_HEVC;
147
+ // TODONVDEC P0: support more codecs
148
+ // case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
149
+ // case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
150
+ // case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
151
+ // case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
152
+ // case AV_CODEC_ID_MJPEG: return cudaVideoCodec_JPEG;
153
+ default : {
154
+ TORCH_CHECK (false , " Unsupported codec type: " , avcodec_get_name (codecId));
155
+ }
156
+ }
157
+ }
158
+
141
159
} // namespace
142
160
143
161
BetaCudaDeviceInterface::BetaCudaDeviceInterface (const torch::Device& device)
@@ -163,29 +181,62 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
163
181
}
164
182
}
165
183
166
- void BetaCudaDeviceInterface::initializeInterface (AVStream* avStream) {
167
- torch::Tensor dummyTensorForCudaInitialization = torch::empty (
168
- {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
184
+ void BetaCudaDeviceInterface::initializeBSF (
185
+ const AVCodecParameters* codecPar,
186
+ const UniqueDecodingAVFormatContext& avFormatCtx) {
187
+ // Setup bit stream filters (BSF):
188
+ // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
189
+ // This is only needed for some formats, like H264 or HEVC.
169
190
170
- TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
171
- timeBase_ = avStream->time_base ;
191
+ TORCH_CHECK (codecPar != nullptr , " codecPar cannot be null" );
192
+ TORCH_CHECK (avFormatCtx != nullptr , " AVFormatContext cannot be null" );
193
+ TORCH_CHECK (
194
+ avFormatCtx->iformat != nullptr ,
195
+ " AVFormatContext->iformat cannot be null" );
196
+ std::string filterName;
197
+
198
+ // Matching logic is taken from DALI
199
+ switch (codecPar->codec_id ) {
200
+ case AV_CODEC_ID_H264: {
201
+ const std::string formatName = avFormatCtx->iformat ->long_name
202
+ ? avFormatCtx->iformat ->long_name
203
+ : " " ;
204
+
205
+ if (formatName == " QuickTime / MOV" ||
206
+ formatName == " FLV (Flash Video)" ||
207
+ formatName == " Matroska / WebM" || formatName == " raw H.264 video" ) {
208
+ filterName = " h264_mp4toannexb" ;
209
+ }
210
+ break ;
211
+ }
172
212
173
- const AVCodecParameters* codecpar = avStream->codecpar ;
174
- TORCH_CHECK (codecpar != nullptr , " CodecParameters cannot be null" );
213
+ case AV_CODEC_ID_HEVC: {
214
+ const std::string formatName = avFormatCtx->iformat ->long_name
215
+ ? avFormatCtx->iformat ->long_name
216
+ : " " ;
175
217
176
- TORCH_CHECK (
177
- // TODONVDEC P0 support more
178
- avStream->codecpar ->codec_id == AV_CODEC_ID_H264,
179
- " Can only do H264 for now" );
218
+ if (formatName == " QuickTime / MOV" ||
219
+ formatName == " FLV (Flash Video)" ||
220
+ formatName == " Matroska / WebM" || formatName == " raw HEVC video" ) {
221
+ filterName = " hevc_mp4toannexb" ;
222
+ }
223
+ break ;
224
+ }
180
225
181
- // Setup bit stream filters (BSF):
182
- // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
183
- // This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
184
- // now we apply BSF unconditionally, but it should be optional and dependent
185
- // on codec and container.
186
- const AVBitStreamFilter* avBSF = av_bsf_get_by_name (" h264_mp4toannexb" );
226
+ default :
227
+ // No bitstream filter needed for other codecs
228
+ // TODONVDEC P1 MPEG4 will need one!
229
+ break ;
230
+ }
231
+
232
+ if (filterName.empty ()) {
233
+ // Only initialize BSF if we actually need one
234
+ return ;
235
+ }
236
+
237
+ const AVBitStreamFilter* avBSF = av_bsf_get_by_name (filterName.c_str ());
187
238
TORCH_CHECK (
188
- avBSF != nullptr , " Failed to find h264_mp4toannexb bitstream filter" );
239
+ avBSF != nullptr , " Failed to find bitstream filter: " , filterName );
189
240
190
241
AVBSFContext* avBSFContext = nullptr ;
191
242
int retVal = av_bsf_alloc (avBSF, &avBSFContext);
@@ -196,7 +247,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
196
247
197
248
bitstreamFilter_.reset (avBSFContext);
198
249
199
- retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecpar );
250
+ retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecPar );
200
251
TORCH_CHECK (
201
252
retVal >= AVSUCCESS,
202
253
" Failed to copy codec parameters: " ,
@@ -207,10 +258,25 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
207
258
retVal == AVSUCCESS,
208
259
" Failed to initialize bitstream filter: " ,
209
260
getFFMPEGErrorStringFromErrorCode (retVal));
261
+ }
262
+
263
+ void BetaCudaDeviceInterface::initializeInterface (
264
+ const AVStream* avStream,
265
+ const UniqueDecodingAVFormatContext& avFormatCtx) {
266
+ torch::Tensor dummyTensorForCudaInitialization = torch::empty (
267
+ {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
268
+
269
+ TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
270
+ timeBase_ = avStream->time_base ;
271
+
272
+ const AVCodecParameters* codecPar = avStream->codecpar ;
273
+ TORCH_CHECK (codecPar != nullptr , " CodecParameters cannot be null" );
274
+
275
+ initializeBSF (codecPar, avFormatCtx);
210
276
211
277
// Create parser. Default values that aren't obvious are taken from DALI.
212
278
CUVIDPARSERPARAMS parserParams = {};
213
- parserParams.CodecType = cudaVideoCodec_H264 ;
279
+ parserParams.CodecType = validateCodecSupport (codecPar-> codec_id ) ;
214
280
parserParams.ulMaxNumDecodeSurfaces = 8 ;
215
281
parserParams.ulMaxDisplayDelay = 0 ;
216
282
// Callback setup, all are triggered by the parser within a call
0 commit comments