Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit de5b336

Browse files
author
felix
committed
refactor(shot-detector): merge code from hub
1 parent 981085a commit de5b336

File tree

1 file changed

+42
-26
lines changed

1 file changed

+42
-26
lines changed

gnes/preprocessor/video/shotdetect.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,36 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import List
17-
1816
import numpy as np
17+
from typing import List
1918

20-
from ..base import BaseVideoPreprocessor
21-
from ..helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr
22-
from ..io_utils import video as video_util
23-
from ...proto import gnes_pb2, array2blob
19+
from gnes.preprocessor.base import BaseVideoPreprocessor
20+
from gnes.proto import gnes_pb2, array2blob, blob2array
21+
from gnes.preprocessor.io_utils import video
22+
from gnes.preprocessor.helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr
2423

2524

2625
class ShotDetectPreprocessor(BaseVideoPreprocessor):
2726
store_args_kwargs = True
2827

2928
def __init__(self,
30-
frame_size: str = '192:168',
29+
scale: str = None,
3130
descriptor: str = 'block_hsv_histogram',
3231
distance_metric: str = 'bhattacharya',
3332
detect_method: str = 'threshold',
3433
frame_rate: int = 10,
3534
frame_num: int = -1,
35+
drop_raw_data: bool = False,
3636
*args,
3737
**kwargs):
3838
super().__init__(*args, **kwargs)
39-
self.frame_size = frame_size
39+
self.scale = scale
4040
self.descriptor = descriptor
4141
self.distance_metric = distance_metric
4242
self.detect_method = detect_method
4343
self.frame_rate = frame_rate
4444
self.frame_num = frame_num
45+
self.drop_raw_data = drop_raw_data
4546
self._detector_kwargs = kwargs
4647

4748
def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]:
@@ -71,23 +72,38 @@ def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]:
7172
def apply(self, doc: 'gnes_pb2.Document') -> None:
7273
super().apply(doc)
7374

74-
if doc.raw_bytes:
75-
all_frames = video_util.capture_frames(
76-
input_data=doc.raw_bytes,
77-
scale=self.frame_size,
78-
fps=self.frame_rate,
79-
vframes=self.frame_num)
80-
num_frames = len(all_frames)
81-
assert num_frames > 0
82-
shots = self.detect_shots(all_frames)
75+
video_frames = []
8376

84-
for ci, frames in enumerate(shots):
85-
c = doc.chunks.add()
86-
c.doc_id = doc.doc_id
87-
# chunk_data = np.concatenate(frames, axis=0)
88-
chunk_data = np.array(frames)
89-
c.blob.CopyFrom(array2blob(chunk_data))
90-
c.offset = ci
91-
c.weight = len(frames) / num_frames
77+
if doc.WhichOneof('raw_data'):
78+
raw_type = type(getattr(doc, doc.WhichOneof('raw_data')))
79+
if doc.raw_bytes:
80+
video_frames = video.capture_frames(
81+
input_data=doc.raw_bytes,
82+
scale=self.scale,
83+
fps=self.frame_rate,
84+
vframes=self.frame_num)
85+
elif raw_type == gnes_pb2.NdArray:
86+
video_frames = blob2array(doc.raw_video)
87+
if self.frame_num > 0:
88+
stepwise = len(video_frames) / self.frame_num
89+
video_frames = video_frames[0::stepwise, :]
90+
91+
num_frames = len(video_frames)
92+
if num_frames > 0:
93+
shots = self.detect_shots(video_frames)
94+
for ci, frames in enumerate(shots):
95+
c = doc.chunks.add()
96+
c.doc_id = doc.doc_id
97+
chunk_data = np.array(frames)
98+
c.blob.CopyFrom(array2blob(chunk_data))
99+
c.offset = ci
100+
c.weight = len(frames) / num_frames
101+
else:
102+
self.logger.error(
103+
'bad document: "raw_bytes" or "raw_video" is empty!')
92104
else:
93-
self.logger.error('bad document: "raw_bytes" is empty!')
105+
self.logger.error('bad document: "raw_data" is empty!')
106+
107+
if self.drop_raw_data:
108+
self.logger.info("document raw data will be cleaned!")
109+
doc.ClearField('raw_data')

0 commit comments

Comments
 (0)