Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(ffmpeg): refactor ffmpeg to read frames, vides and gif
Browse files Browse the repository at this point in the history
  • Loading branch information
felix committed Aug 23, 2019
1 parent 2344a6a commit dbc06a8
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 0 deletions.
42 changes: 42 additions & 0 deletions gnes/preprocessor/io_utils/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import ffmpeg
import numpy as np
from scipy.io import wavfile


def capture_audio(filename: str = 'pipe:',
video_data: bytes = None,
bits_per_raw_sample: int = 16,
sample_rate: int = 16000,
**kwargs) -> List['np.ndarray']:

stdout, err = ffmpeg.input(filename).output(
'pipe:',
format='wav',
bits_per_raw_sample=bits_per_raw_sample,
ac=1,
ar=16000).run(
input=video_data, capture_stdout=True, capture_stderr=True)

audio_stream = io.BytesIO(stdout)
audio_data, sample_rate = sf.read(audio_stream)
# has multiple channels, do average
if len(audio_data.shape) == 2:
audio_data = np.mean(audio_data, axis=1)

return audio_data
61 changes: 61 additions & 0 deletions gnes/preprocessor/io_utils/gif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import numpy as np
import subprocess as sp
import tempfile
from scipy import ndimage


def decode_gif(data: bytes) -> 'np.ndarray':
with tempfile.NamedTemporaryFile(suffix=".gif") as f:
f.write(data)
f.flush()
im_array = ndimage.imread(f.name)
return im_array


def encode_gif(images: List[np.ndarray],
scale: str,
# width: int,
# height: int,
fps: int,
pix_fmt: str = 'rgb24'):
"""
https://superuser.com/questions/556029/how-do-i-convert-a-video-to-gif-using-ffmpeg-with-reasonable-quality
https://gist.github.com/alexlee-gk/38916bf524dc75ca1b988d113aa30710
"""

cmd = [
'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
'-r', '%.02f' % fps,
'-s', '%dx%d' % (images[0].shape[1], images[0].shape[0]),
'-pix_fmt', 'rgb24',
'-i', '-',
'-filter_complex', '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
'-r', '%.02f' % fps,
'-s', scale,
'-f', 'gif',
'-']
proc = sp.Popen(cmd, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE)
for frame in frames:
proc.stdin.write(frame.tostring())
out, err = proc.communicate()
if proc.returncode:
err = '\n'.join([' '.join(cmd), err.decode('utf8')])
raise IOError(err)
del proc
return out
70 changes: 70 additions & 0 deletions gnes/preprocessor/io_utils/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import ffmpeg
import numpy as np


def _extract_frame_size(infos: str):
"""
The sollution is borrowed from:
http://concisionandconcinnity.blogspot.com/2008/04/getting-dimensions-of-video-file-in.html
"""
possible_patterns = [re.compile(r'Stream.*Video.*([0-9]{4,})x([0-9]{4,})'), \
re.compile(r'Stream.*Video.*([0-9]{4,})x([0-9]{3,})'), \
re.compile(r'Stream.*Video.*([0-9]{3,})x([0-9]{3,})')]

for pattern in possible_patterns:
match = pattern.search(err.decode())
if match is not None:
x, y = map(int, match.groups()[0:2])
break

if match is None:
raise ValueError("could not get video frame size")

return (x, y)


def capture_frames(filename: str = 'pipe:',
video_data: bytes = None,
pix_fmt: str = 'rgb24',
fps: int = -1,
scale: str = None,
**kwargs) -> List['np.ndarray']:
capture_stdin = (filename == 'pipe:')
if capture_stdin and video_data is None:
raise ValueError(
"the video data buffered from stdin should not be empty")

stream = ffmpeg.input(filename)
if self.fps > 0:
stream = stream.filter('fps', fps=self.fps, round='up')

if frame_size:
width, height = self.scale.split('*')
stream = stream.filter('scale', width, height)

stream = stream.output('pipe:', format='rawvideo', pix_fmt=self.pix_fmt)

out, err = stream.run(
input=video_data, capture_stdout=True, capture_stderr=True)

width, height = _extract_frame_size(err.decode())

frames = np.frombuffer(out,
np.uint8).reshape([-1, height, width, self.depth])
return list(frames)

0 comments on commit dbc06a8

Please sign in to comment.