|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -from typing import List |
17 | | - |
18 | 16 | import numpy as np |
| 17 | +from typing import List |
19 | 18 |
|
20 | | -from ..base import BaseVideoPreprocessor |
21 | | -from ..helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr |
22 | | -from ..io_utils import video as video_util |
23 | | -from ...proto import gnes_pb2, array2blob |
| 19 | +from gnes.preprocessor.base import BaseVideoPreprocessor |
| 20 | +from gnes.proto import gnes_pb2, array2blob, blob2array |
| 21 | +from gnes.preprocessor.io_utils import video |
| 22 | +from gnes.preprocessor.helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr |
24 | 23 |
|
25 | 24 |
|
26 | 25 | class ShotDetectPreprocessor(BaseVideoPreprocessor): |
27 | 26 | store_args_kwargs = True |
28 | 27 |
|
def __init__(self,
             scale: str = None,
             descriptor: str = 'block_hsv_histogram',
             distance_metric: str = 'bhattacharya',
             detect_method: str = 'threshold',
             frame_rate: int = 10,
             frame_num: int = -1,
             drop_raw_data: bool = False,
             *args,
             **kwargs):
    """Store the shot-detection configuration on the instance.

    :param scale: forwarded as ``scale`` to ``video.capture_frames`` when
        frames are extracted from raw bytes; ``None`` by default.
    :param descriptor: name of the frame descriptor
        (presumably consumed by ``detect_shots`` -- not visible here).
    :param distance_metric: name of the descriptor distance
        (presumably consumed by ``detect_shots`` -- not visible here).
    :param detect_method: name of the boundary detection strategy
        (presumably consumed by ``detect_shots`` -- not visible here).
    :param frame_rate: forwarded as ``fps`` to ``video.capture_frames``.
    :param frame_num: forwarded as ``vframes`` to ``video.capture_frames``;
        when positive it also drives sub-sampling of pre-decoded frames.
    :param drop_raw_data: when true, ``apply`` clears the document's
        ``raw_data`` field after processing.
    :param args: positional arguments passed through to the base class.
    :param kwargs: keyword arguments passed through to the base class and
        also kept as ``self._detector_kwargs``.
    """
    super().__init__(*args, **kwargs)

    # Extra keyword arguments are retained for the detector.
    self._detector_kwargs = kwargs

    # Frame extraction settings.
    self.scale, self.frame_rate, self.frame_num = scale, frame_rate, frame_num

    # Shot boundary detection settings.
    self.descriptor, self.distance_metric, self.detect_method = (
        descriptor, distance_metric, detect_method)

    # Post-processing behaviour.
    self.drop_raw_data = drop_raw_data
46 | 47 |
|
47 | 48 | def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]: |
@@ -71,23 +72,38 @@ def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]: |
def apply(self, doc: 'gnes_pb2.Document') -> None:
    """Detect shots in the document's video and add one chunk per shot.

    Frames are obtained from ``doc.raw_bytes`` through
    ``video.capture_frames`` or, when the populated ``raw_data`` field is
    a ``gnes_pb2.NdArray``, from ``doc.raw_video`` through ``blob2array``.
    Each shot returned by ``detect_shots`` becomes a chunk whose blob holds
    the shot's frames, whose ``offset`` is the shot index and whose
    ``weight`` is the shot's share of all frames.

    Problems with the input are logged via ``self.logger.error`` rather
    than raised. When ``self.drop_raw_data`` is true the ``raw_data``
    field is cleared at the end, whether or not chunks were produced.

    :param doc: the document to process; modified in place.
    """
    super().apply(doc)

    video_frames = []

    raw_field = doc.WhichOneof('raw_data')
    if raw_field:
        if doc.raw_bytes:
            video_frames = video.capture_frames(
                input_data=doc.raw_bytes,
                scale=self.scale,
                fps=self.frame_rate,
                vframes=self.frame_num)
        elif isinstance(getattr(doc, raw_field), gnes_pb2.NdArray):
            video_frames = blob2array(doc.raw_video)
            if self.frame_num > 0:
                # A slice step must be a positive integer: true division
                # would yield a float (TypeError when slicing), and floor
                # division alone would yield 0 when there are fewer frames
                # than ``frame_num``.
                step = max(1, len(video_frames) // self.frame_num)
                video_frames = video_frames[::step]

        num_frames = len(video_frames)
        if num_frames > 0:
            shots = self.detect_shots(video_frames)
            for ci, frames in enumerate(shots):
                c = doc.chunks.add()
                c.doc_id = doc.doc_id
                c.blob.CopyFrom(array2blob(np.array(frames)))
                c.offset = ci
                # Weight each shot by the fraction of frames it covers.
                c.weight = len(frames) / num_frames
        else:
            self.logger.error(
                'bad document: "raw_bytes" or "raw_video" is empty!')
    else:
        self.logger.error('bad document: "raw_data" is empty!')

    if self.drop_raw_data:
        self.logger.info("document raw data will be cleaned!")
        doc.ClearField('raw_data')
0 commit comments