In [None]:
import itertools
import os

import torch
import torchvision
import torchaudio
from torchvision.datasets.utils import download_url
from torchaudio.io import StreamWriter

In [None]:
# Report which FFmpeg libraries torchaudio sees, one "name: version" line each.
print("FFmpeg library versions")
ffmpeg_versions = torchaudio.utils.ffmpeg_utils.get_versions()
for k, v in ffmpeg_versions.items():
    print(f"  {k}: {v}")

In [None]:
# Use torchvision's "video_reader" backend for all video decoding below.
torchvision.set_video_backend("video_reader")
stream = "video"

# Sample clip from the torchvision test assets, stored next to the notebook.
video_path = "./WUzgd7C1pWA.mp4"

# Fetch the clip only when it is not already on disk. This keeps
# "Restart Kernel -> Run All" working on a fresh checkout without
# touching the network when the file is already present.
if not os.path.isfile(video_path):
    download_url(
        "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
        ".",
        "WUzgd7C1pWA.mp4",
    )


In [None]:
def read_video(video_object, start=0, end=None, read_video=True, read_audio=True):
    """Decode the video and/or audio stream of ``video_object`` in a pts window.

    Args:
        video_object: Reader exposing ``set_current_stream``, ``seek`` and
            ``get_metadata``. Iterating the value returned by ``seek`` must
            yield mappings with ``"data"`` and ``"pts"`` entries.
        start: Position passed to ``seek`` before each stream is read.
        end: Largest ``pts`` that is still kept (inclusive). ``None`` means
            no upper bound.
        read_video: Whether to read the ``"video"`` stream.
        read_audio: Whether to read the ``"audio"`` stream.

    Returns:
        A tuple ``(video_frames, audio_frames, (video_pts, audio_pts),
        metadata)``. Video frames are stacked along a new leading dimension
        and audio chunks are concatenated along dimension 0. A stream that is
        skipped, or that yields nothing, is represented by ``torch.empty(0)``
        and an empty pts list. ``metadata`` is ``video_object.get_metadata()``.

    Raises:
        ValueError: If ``end`` is smaller than ``start``.
    """
    end = float("inf") if end is None else end
    if end < start:
        raise ValueError(
            "end time should be larger than start time, got "
            f"start time={start} and end time={end}"
        )

    def _drain(stream_name):
        # Select the stream, reposition to `start`, then keep every packet
        # until the first one whose pts exceeds `end`.
        video_object.set_current_stream(stream_name)
        chunks, stamps = [], []
        for packet in video_object.seek(start):
            if not (packet["pts"] <= end):
                break
            chunks.append(packet["data"])
            stamps.append(packet["pts"])
        return chunks, stamps

    video_frames, video_pts = torch.empty(0), []
    if read_video:
        decoded, video_pts = _drain("video")
        if decoded:
            video_frames = torch.stack(decoded, 0)

    audio_frames, audio_pts = torch.empty(0), []
    if read_audio:
        decoded, audio_pts = _drain("audio")
        if decoded:
            audio_frames = torch.cat(decoded, 0)

    timestamps = (video_pts, audio_pts)
    return video_frames, audio_frames, timestamps, video_object.get_metadata()

# Decode the whole clip with read_video's defaults: both streams, start=0,
# no upper pts bound (end=None).
# Total number of frames should be 327 for video and 523264 datapoints for audio
video = torchvision.io.VideoReader(video_path)
# vf: stacked video frames, af: concatenated audio chunks,
# info: (video_pts, audio_pts), meta: result of video.get_metadata().
vf, af, info, meta = read_video(video)

In [None]:
# 1. Derive the sample rate.
# The reader's metadata reports stream rates as per-stream lists of floats
# (e.g. [48000.0]), whereas StreamWriter.add_audio_stream needs a plain
# integer. Unwrap the first entry if needed and cast.
audio_rate = meta['audio']['framerate']
if isinstance(audio_rate, (list, tuple)):
    audio_rate = audio_rate[0]
SAMPLE_RATE = int(audio_rate)

# 2. Configure the output file with a single audio stream.
path = "./test.wav"
s = StreamWriter(path)
s.add_audio_stream(
    sample_rate=SAMPLE_RATE,
    # NOTE(review): assumes `af` holds one channel, i.e. shape (samples, 1) — confirm.
    num_channels=1,
)

# 3. Write the data: the whole decoded audio tensor goes to stream 0.
with s.open():
    s.write_audio_chunk(0, af)