Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit e255bd4

Browse files
author
felix
committed
feat(shot-detector): limit number of frames in shots
1 parent dfe78c4 commit e255bd4

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

gnes/preprocessor/video/shotdetect.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self,
3131
detect_method: str = 'threshold',
3232
frame_size: str = None,
3333
frame_rate: int = 10,
34-
frame_num: int = -1,
34+
vframes: int = -1,
3535
sframes: int = -1,
3636
drop_raw_data: bool = False,
3737
*args,
@@ -42,7 +42,7 @@ def __init__(self,
4242
self.distance_metric = distance_metric
4343
self.detect_method = detect_method
4444
self.frame_rate = frame_rate
45-
self.frame_num = frame_num
45+
self.vframes = vframes
4646
self.sframes = sframes
4747
self.drop_raw_data = drop_raw_data
4848
self._detector_kwargs = kwargs
@@ -83,11 +83,11 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
8383
input_data=doc.raw_bytes,
8484
scale=self.frame_size,
8585
fps=self.frame_rate,
86-
vframes=self.frame_num)
86+
vframes=self.vframes)
8787
elif raw_type == gnes_pb2.NdArray:
8888
video_frames = blob2array(doc.raw_video)
89-
if self.frame_num > 0:
90-
video_frames = video_frames[0:self.frame_num, :]
89+
if self.vframes > 0:
90+
video_frames = video_frames[0:self.vframes, :]
9191

9292
num_frames = len(video_frames)
9393
if num_frames > 0:
@@ -99,9 +99,12 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
9999
shot_len = len(frames)
100100
c.weight = shot_len / num_frames
101101
if self.sframes > 0 and shot_len > self.sframes:
102-
start_id = int((shot_len - self.sframes) / 2)
103-
end_id = start_id + self.sframes
104-
frames = frames[start_id:end_id]
102+
begin = 0
103+
if self.sframes < 3:
104+
begin = (shot_len - self.sframes) // 2
105+
step = (shot_len) // self.sframes
106+
frames = [frames[_] for _ in range(begin, shot_len, step)]
107+
105108
chunk_data = np.array(frames)
106109
c.blob.CopyFrom(array2blob(chunk_data))
107110
else:

0 commit comments

Comments
 (0)