Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ VideoEncoder::~VideoEncoder() {

VideoEncoder::VideoEncoder(
const torch::Tensor& frames,
int frameRate,
double frameRate,
std::string_view fileName,
const VideoStreamOptions& videoStreamOptions)
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
Expand Down Expand Up @@ -694,7 +694,7 @@ VideoEncoder::VideoEncoder(

VideoEncoder::VideoEncoder(
const torch::Tensor& frames,
int frameRate,
double frameRate,
std::string_view formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const VideoStreamOptions& videoStreamOptions)
Expand Down Expand Up @@ -787,9 +787,9 @@ void VideoEncoder::initializeEncoder(
avCodecContext_->width = outWidth_;
avCodecContext_->height = outHeight_;
avCodecContext_->pix_fmt = outPixelFormat_;
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
avCodecContext_->time_base = {1, inFrameRate_};
avCodecContext_->framerate = {inFrameRate_, 1};
// TODO-VideoEncoder: Add and utilize output frame_rate option
avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX);
avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate);

// Set flag for containers that require extradata to be in the codec context
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
Expand Down Expand Up @@ -833,6 +833,10 @@ void VideoEncoder::initializeEncoder(

// Set the stream time base to encode correct frame timestamps
avStream_->time_base = avCodecContext_->time_base;
// Set the stream frame rate to store correct frame durations for some
// containers (webm, mkv)
avStream_->r_frame_rate = avCodecContext_->framerate;

status = avcodec_parameters_from_context(
avStream_->codecpar, avCodecContext_.get());
TORCH_CHECK(
Expand Down
6 changes: 3 additions & 3 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ class VideoEncoder {

VideoEncoder(
const torch::Tensor& frames,
int frameRate,
double frameRate,
std::string_view fileName,
const VideoStreamOptions& videoStreamOptions);

VideoEncoder(
const torch::Tensor& frames,
int frameRate,
double frameRate,
std::string_view formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const VideoStreamOptions& videoStreamOptions);
Expand All @@ -172,7 +172,7 @@ class VideoEncoder {
UniqueSwsContext swsContext_;

const torch::Tensor frames_;
int inFrameRate_;
double inFrameRate_;

int inWidth_ = -1;
int inHeight_ = -1;
Expand Down
23 changes: 9 additions & 14 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
"encode_video_to_file(Tensor frames, float frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
m.def(
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
"encode_video_to_tensor(Tensor frames, float frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
m.def(
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
"_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def(
Expand Down Expand Up @@ -611,7 +611,7 @@ void _encode_audio_to_file_like(

void encode_video_to_file(
const at::Tensor& frames,
int64_t frame_rate,
double frame_rate,
std::string_view file_name,
std::optional<std::string_view> codec = std::nullopt,
std::optional<std::string_view> pixel_format = std::nullopt,
Expand All @@ -629,17 +629,12 @@ void encode_video_to_file(
unflattenExtraOptions(extra_options.value());
}

VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
file_name,
videoStreamOptions)
.encode();
VideoEncoder(frames, frame_rate, file_name, videoStreamOptions).encode();
}

at::Tensor encode_video_to_tensor(
const at::Tensor& frames,
int64_t frame_rate,
double frame_rate,
std::string_view format,
std::optional<std::string_view> codec = std::nullopt,
std::optional<std::string_view> pixel_format = std::nullopt,
Expand All @@ -660,7 +655,7 @@ at::Tensor encode_video_to_tensor(

return VideoEncoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
frame_rate,
format,
std::move(avioContextHolder),
videoStreamOptions)
Expand All @@ -669,7 +664,7 @@ at::Tensor encode_video_to_tensor(

void _encode_video_to_file_like(
const at::Tensor& frames,
int64_t frame_rate,
double frame_rate,
std::string_view format,
int64_t file_like_context,
std::optional<std::string_view> codec = std::nullopt,
Expand All @@ -696,7 +691,7 @@ void _encode_video_to_file_like(

VideoEncoder encoder(
frames,
validateInt64ToInt(frame_rate, "frame_rate"),
frame_rate,
format,
std::move(avioContextHolder),
videoStreamOptions);
Expand Down
8 changes: 4 additions & 4 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def encode_audio_to_file_like(

def encode_video_to_file_like(
frames: torch.Tensor,
frame_rate: int,
frame_rate: float,
format: str,
file_like: Union[io.RawIOBase, io.BufferedIOBase],
codec: Optional[str] = None,
Expand Down Expand Up @@ -329,7 +329,7 @@ def _encode_audio_to_file_like_abstract(
@register_fake("torchcodec_ns::encode_video_to_file")
def encode_video_to_file_abstract(
frames: torch.Tensor,
frame_rate: int,
frame_rate: float,
filename: str,
codec: Optional[str] = None,
pixel_format: Optional[str] = None,
Expand All @@ -343,7 +343,7 @@ def encode_video_to_file_abstract(
@register_fake("torchcodec_ns::encode_video_to_tensor")
def encode_video_to_tensor_abstract(
frames: torch.Tensor,
frame_rate: int,
frame_rate: float,
format: str,
codec: Optional[str] = None,
pixel_format: Optional[str] = None,
Expand All @@ -357,7 +357,7 @@ def encode_video_to_tensor_abstract(
@register_fake("torchcodec_ns::_encode_video_to_file_like")
def _encode_video_to_file_like_abstract(
frames: torch.Tensor,
frame_rate: int,
frame_rate: float,
format: str,
file_like_context: int,
codec: Optional[str] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ class VideoEncoder:
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
C is 3 channels (RGB), H is height, and W is width.
Values must be uint8 in the range ``[0, 255]``.
frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
"""

def __init__(self, frames: Tensor, *, frame_rate: int):
def __init__(self, frames: Tensor, *, frame_rate: float):
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
if not isinstance(frames, Tensor):
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")
Expand Down
Loading
Loading