
Commit

Fix video reading bug (#803)
* feature: support temporal models for neural alignment by changing TemporalIgnore to TemporalAligned

* add example temporal submission

* complete new framework

* new module: temporal model helpers

* change the arch of temporal; add tutorials

* improve: better naming

* update: wrapper tutorial on brain model

* add feature: inferencer identifier tracked by extractor for result caching

* fix: video fps sampling; need more tests!

* fix bugs: video sampling based on fps was wrong.

* add mmaction2 models; add more features to the inferencers

* PR: temporal model helpers

* PR fix: not including gitmodules for now

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <mschrimpf@users.noreply.github.com>

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <mschrimpf@users.noreply.github.com>

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <mschrimpf@users.noreply.github.com>

* Update brainscore_vision/models/temporal_models/test.py

Co-authored-by: Martin Schrimpf <mschrimpf@users.noreply.github.com>

* add mae_st; add ding2012

* try new arch

* init ding2012

* add tests for temporal model helpers; add block inferencer

* Delete tests/test_model_helpers/temporal/test___init__.py

delete the old test

* add benchmark ding2012

* add multiple libs for temporal models

* change executor output format; add more inference tests; init load_weight in s3

* add openstl

* update backend for executor

* feat:load_weight_file and corresponding test

* change:resize strategy changed from bilinear to pooling

* change:resize strategy changed from bilinear to pooling

* fix mae_st submission

* minor

* fix:dtype in assembly time align

* minor

* update model submissions

* fix dependency

* refactor: simplify the inferencer methods

* fix:block inferencer, neuroid coord while merging

* fix:inferencer identifier

* fix: weight download

* change tests to have max_workers=1

* revert screen.py

* not submit region_layer_map

* remove torch dependency

* make fake modules in tests

* add torch to requirements; avoid torch in tests

* minor

* minor

* np.object changed to object

* remove return in tests

* fix insertion position bug

* Apply suggestions from code review

add: more type hints

Co-authored-by: Martin Schrimpf <mschrimpf@users.noreply.github.com>

* add: more type hints and comments

* minor

* pr:only commit temporal model helpers

* pr: add one model for example

* undo whole_brain in BrainModel.RecordingTarget

* use logger and fix newlines

* fix: video fps with copy was wrong

* feat:fractional max_spatial_size

* downsample layers in VideoMAE

* fix:video sampling wrong duration

* add more tests

* fix merge

* fix merge

---------

Co-authored-by: Yingtian Tang <ytang@jst285.jed.cluster>
Co-authored-by: Martin Schrimpf <mschrimpf@users.noreply.github.com>
Co-authored-by: Martin Schrimpf <m4rtinsch@gmail.com>
4 people committed May 1, 2024
1 parent 6051f72 commit fd1ee35
Showing 6 changed files with 101 additions and 29 deletions.
@@ -41,9 +41,10 @@ class Inferencer:
         For example, {"temp_conv": "TCHW", "spatial_conv": "CHW", "fc": "C"}.
     visual_degrees: float
         the visual degrees of the stimuli.
-    max_spatial_size: int
+    max_spatial_size: int/float
         the maximum spatial size of the activations. If the spatial size of the activations is larger than this value,
         the activations will be downsampled to this size. This is used to avoid the large memory consumption by the first layers of some model.
+        If float, resize the image based on this factor.
     dtype: np.dtype
         data type of the activations.
     batch_size: int
@@ -77,7 +78,7 @@ def __init__(
             layer_activation_format : dict,
             stimulus_type : Stimulus,
             visual_degrees : float = 8.,
-            max_spatial_size : int = None,
+            max_spatial_size : Union[int, float] = None,
             dtype : np.dtype = np.float16,
             batch_size : int = 64,
             batch_grouper : Callable[[Stimulus], Hashable] = None,
@@ -89,6 +90,8 @@
 
         self.stimulus_type = stimulus_type
         self.layer_activation_format = layer_activation_format
+        if isinstance(max_spatial_size, float):
+            assert max_spatial_size < 1, "a proportional max_spatial_size should be < 1."
         self.max_spatial_size = max_spatial_size
         self.visual_degrees = visual_degrees
         self.dtype = dtype
@@ -254,10 +257,18 @@ def _package(activation: np.array, dims):
     return ret
 
 def _compute_new_size(w, h, max_spatial_size):
-    if h > w:
-        new_h = max_spatial_size
-        new_w = int(w * new_h / h)
-    else:
-        new_w = max_spatial_size
-        new_h = int(h * new_w / w)
+    if isinstance(max_spatial_size, int):
+        if h > w:
+            new_h = max_spatial_size
+            new_w = int(w * new_h / h)
+        else:
+            new_w = max_spatial_size
+            new_h = int(h * new_w / w)
+    else:
+        new_h = int(h * max_spatial_size)
+        new_w = int(w * max_spatial_size)
+
+    new_h = max(1, new_h)
+    new_w = max(1, new_w)
+
     return new_h, new_w
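
A minimal, self-contained sketch of the two max_spatial_size modes handled above; the helper mirrors the diff rather than importing the package, and the numbers are illustrative:

def compute_new_size(w, h, max_spatial_size):
    # int: clamp the longer side to max_spatial_size, preserving aspect ratio
    # float (< 1): scale both sides by the given factor
    if isinstance(max_spatial_size, int):
        if h > w:
            new_h = max_spatial_size
            new_w = int(w * new_h / h)
        else:
            new_w = max_spatial_size
            new_h = int(h * new_w / w)
    else:
        new_h = int(h * max_spatial_size)
        new_w = int(w * max_spatial_size)
    return max(1, new_h), max(1, new_w)

print(compute_new_size(224, 448, 64))    # (64, 32): longer side clamped to 64
print(compute_new_size(224, 448, 0.25))  # (112, 56): both sides scaled by 0.25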
@@ -63,9 +63,9 @@ def __init__(
             duration : Union[float, Tuple[float, float]] = None,
             time_alignment : str = "evenly_spaced",
             convert_img_to_video : bool = True,
-            img_duration : float = 1000.,
+            img_duration : float = 1000.0,
             batch_size : int = 32,
-            batch_grouper : Callable[[Video], Hashable] = lambda video: (video.duration, video.fps),  # not including video.frame_size because most preprocessors will change the frame size to be the same
+            batch_grouper : Callable[[Video], Hashable] = lambda video: (round(video.duration, 6), video.fps),  # not including video.frame_size because most preprocessors will change the frame size to be the same
             **kwargs,
     ):
         super().__init__(*args, stimulus_type=Video, batch_size=batch_size,
@@ -83,9 +83,9 @@ def __init__(
 
     @property
     def identifier(self) -> str:
-        id = f"{super().identifier}.{self.time_aligner.__name__}.fps={self.fps}"
+        id = f"{super().identifier}.{self.time_aligner.__name__}.fps={float(self.fps)}"
         if self.convert_to_video:
-            id += f".img_dur={self.img_duration}"
+            id += f".img_dur={float(self.img_duration)}"
         return id
 
     def load_stimulus(self, path: Union[str, Path]) -> Video:
@@ -129,6 +129,6 @@ def _make_range(self, num, type="num_frames"):
     def _check_video(self, video: Video):
         if self.num_frames is not None:
             estimated_num_frames = int(self.fps * video.duration / 1000)
-            assert self.num_frames[0] <= estimated_num_frames <= self.num_frames[1]
+            assert self.num_frames[0] <= estimated_num_frames <= self.num_frames[1], f"The number of frames must be within {self.num_frames}, but got {estimated_num_frames}"
         if self.duration is not None:
-            assert self.duration[0] <= video.duration <= self.duration[1]
+            assert self.duration[0] <= video.duration <= self.duration[1], f"The duration must be within {self.duration}, but got {video.duration}"
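
The round(video.duration, 6) and float(...) normalizations above guard against floating-point jitter: two logically identical videos can report durations that differ in the last bits and would then batch separately, and an int vs. float fps would format to different cache identifiers. A sketch with assumed numbers:

d1 = 0.3 * 1000          # 300.0
d2 = (0.1 + 0.2) * 1000  # 300.00000000000006 with IEEE-754 doubles

print(d1 == d2)                      # False -> the two videos would land in different batch groups
print(round(d1, 6) == round(d2, 6))  # True  -> grouped together after rounding

print(f"fps={10}", f"fps={float(10)}")  # 'fps=10' vs 'fps=10.0' -> float() keeps identifiers uniform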
@@ -10,6 +10,8 @@
 from brainscore_vision.model_helpers.activations.temporal.utils import batch_2d_resize
 
 
+EPS = 1e-9
+
 def get_video_stats(video_path):
     cap = cv2.VideoCapture(video_path)
     length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -31,31 +33,53 @@ def get_image_stats(image_path):
 class Video(Stimulus):
     """Video object that represents a video clip."""
 
-    def __init__(self, path: Union[str, Path], fps: float, start: float, end: float, size: Tuple[int, int]):
+    def __init__(
+            self,
+            path: Union[str, Path],
+            fps: float,
+            start: float,
+            end: float,
+            size: Tuple[int, int]
+    ):
         self._path = path
         self._fps = fps
         self._size = size
-        self._original_fps = self._fps
         self._start = start
         self._end = end
+        self._original_fps = None
+        self._original_duration = None
+        self._original_size = None
+
+    def __getattribute__(self, key):
+        if key.startswith("_original_"):
+            if super().__getattribute__(key) is None:
+                self._original_fps, self._original_duration, self._original_size = get_video_stats(self._path)
+        return super().__getattribute__(key)
 
     def copy(self):
         # return view
         video = self.__class__(self._path, self._fps, self._start, self._end, self._size)
+        video._original_fps = self._original_fps
+        video._original_duration = self._original_duration
+        video._original_size = self._original_size
         return video
 
     @property
     def duration(self):
         # in ms
         return self._end - self._start
 
     @property
     def fps(self):
         return self._fps
 
     @property
     def num_frames(self):
-        return int(self.duration * self.fps/1000)
+        return int(self.duration * self.fps/1000 + EPS)
+
+    @property
+    def original_num_frames(self):
+        return int(self._original_duration * self._original_fps/1000 + EPS)
 
     @property
     def frame_size(self):
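
The +EPS guard in num_frames exists because duration * fps / 1000 can land a hair below a whole number after upstream float arithmetic, and int() truncates. A sketch with assumed values:

duration = 299.99999999999994  # ms; logically 300, off by one ulp from prior float math
fps = 10
EPS = 1e-9

print(duration * fps / 1000)             # 2.9999999999999996
print(int(duration * fps / 1000))        # 2 -> one frame short
print(int(duration * fps / 1000 + EPS))  # 3 -> the intended count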
@@ -110,12 +134,13 @@ def to_numpy(self):
         # get the time stamps of frame samples
         start_frame = self._start * self._original_fps / 1000
         end_frame = self._end * self._original_fps / 1000
-        EPS = 1e-9  # avoid taking the last extra frame
-        samples = np.arange(start_frame, end_frame - EPS, self._original_fps/self._fps)
+        # avoid taking the last extra frame
+        samples = np.arange(start_frame, end_frame - EPS, self._original_fps/self.fps)
         sample_indices = samples.astype(int)
 
         # padding: repeat the first/last frame
-        sample_indices = np.clip(sample_indices, 0, self.num_frames-1)
+        original_num_frames = int(self._original_duration * self._original_fps/1000 + EPS)
+        sample_indices = np.clip(sample_indices, 0, original_num_frames-1)
 
         # actual sampling
         frames = self.get_frames(sample_indices)
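
The sampler maps the [start, end) window to source-frame indices by stepping at original_fps / fps; the fix clips those indices against the source's frame count rather than the resampled one. A sketch, assuming a 1000 ms source at 30 fps resampled to 10 fps:

import numpy as np

original_fps, fps = 30, 10
start, end = 0.0, 1000.0  # window in ms
EPS = 1e-9

start_frame = start * original_fps / 1000  # 0.0
end_frame = end * original_fps / 1000      # 30.0
sample_indices = np.arange(start_frame, end_frame - EPS, original_fps / fps).astype(int)
print(sample_indices)  # [ 0  3  6  9 12 15 18 21 24 27] -> 10 frames

# The pre-fix clip used the resampled frame count (10), capping every index at 9
# and silently repeating frame 9; clipping to the source count (30) is a no-op here.
print(np.clip(sample_indices, 0, 30 - 1))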
@@ -137,6 +162,21 @@ def to_path(self):
         path = None  # make a temporal file
         raise NotImplementedError()
         return path
+
+    def store_to_path(self, path):
+        # pick format based on path filename
+        if path.endswith(".avi"):
+            fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        elif path.endswith(".mp4"):
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        else:
+            raise ValueError("Unsupported video format.")
+
+        out = cv2.VideoWriter(path, fourcc, self._fps, self._size)
+        for frame in self.to_frames():
+            out.write(frame[..., ::-1])  # RGB -> BGR, since OpenCV writes BGR
+        out.release()
+        return path
 
 
 class VideoFromImage(Video):
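
A hypothetical use of the new store_to_path (the stimulus path, window, and size here are made up; the container is chosen from the file suffix):

video = Video("stimulus.mp4", fps=25, start=0, end=500, size=(224, 224))
video.store_to_path("clip.avi")  # XVID-encoded AVI
video.store_to_path("clip.mp4")  # mp4v-encoded MP4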
4 changes: 3 additions & 1 deletion brainscore_vision/models/temporal_model_VideoMAE/model.py
@@ -7,6 +7,8 @@
 from torchvision import transforms
 
 
+LAYER_SELECTION_STEP = 2
+
 class VideoMAEv1Wrapper(PytorchWrapper):
     def forward(self, inputs):
         tensor = th.stack(inputs)
@@ -78,7 +80,7 @@ def get_model(identifier, num_frames=16):
         "fps": 6.25,
         "layer_activation_format": {
             "encoder.patch_embed": "THWC",
-            **{f"encoder.blocks.{i}": "THWC" for i in range(num_blocks)},
+            **{f"encoder.blocks.{i}": "THWC" for i in range(0, num_blocks, LAYER_SELECTION_STEP)},
         },
         "num_frames": num_frames,
     }
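
The effect of LAYER_SELECTION_STEP is to record every second transformer block, halving the stored activations; a sketch assuming num_blocks = 12:

LAYER_SELECTION_STEP = 2
num_blocks = 12
layers = [f"encoder.blocks.{i}" for i in range(0, num_blocks, LAYER_SELECTION_STEP)]
print(layers)  # ['encoder.blocks.0', 'encoder.blocks.2', ..., 'encoder.blocks.10']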
10 changes: 5 additions & 5 deletions tests/test_model_helpers/temporal/activations/test_inferencer.py
@@ -129,19 +129,19 @@ def test_compute_temporal_context():
 
 @pytest.mark.memory_intense
 @pytest.mark.parametrize("preprocess", ["normal", "downsample"])
-def test_causal_inferencer(preprocess):
+@pytest.mark.parametrize("fps", [1, 40])
+def test_causal_inferencer(preprocess, fps):
     if preprocess == "normal":
         preprocess = dummy_preprocess
     else:
         preprocess = time_down_sample_preprocess
-    fps = 10
     inferencer = CausalInferencer(dummy_get_features, dummy_preprocess,
                                   dummy_layer_activation_format,
                                   fps=fps, max_workers=1)
     model_assembly = inferencer(video_paths, layers=dummy_layers)
     assert model_assembly.sizes["time_bin"] == 6 * fps
     assert np.isclose(model_assembly['time_bin_end'].values[0] - model_assembly['time_bin_start'].values[0], 1000/fps)
-    assert inferencer._compute_temporal_context() == (100, np.inf)
+    assert inferencer._compute_temporal_context() == (1000/fps, np.inf)
 
     # manual computation check
     output_values = model_assembly.sel(stimulus_path=video_paths[1])\
@@ -159,12 +159,12 @@ def test_causal_inferencer(preprocess):
 
 @pytest.mark.memory_intense
 @pytest.mark.parametrize("preprocess", ["normal", "downsample"])
-def test_block_inferencer(preprocess):
+@pytest.mark.parametrize("fps", [1, 40])
+def test_block_inferencer(preprocess, fps):
     if preprocess == "normal":
         preprocessing = dummy_preprocess
     else:
         preprocessing = time_down_sample_preprocess
-    fps = 10
     inferencer = BlockInferencer(dummy_get_features, preprocessing, dummy_layer_activation_format, fps=fps,
                                  duration=(200, 4000), temporal_context_strategy="greedy", max_workers=1)
     model_assembly = inferencer(video_paths, layers=dummy_layers)
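
The assertions encode simple bin arithmetic: at a given fps the causal inferencer emits one time bin per frame, each 1000/fps ms wide, so the 6-second test videos yield 6 * fps bins. Checking the parametrized values by hand (video length assumed from the test):

for fps in [1, 40]:
    video_duration_ms = 6000      # the test videos are 6 s long
    bin_width_ms = 1000 / fps     # one bin per frame
    num_bins = int(video_duration_ms / bin_width_ms)
    print(fps, bin_width_ms, num_bins)  # 1 -> 6 bins of 1000 ms; 40 -> 240 bins of 25 ms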
23 changes: 21 additions & 2 deletions tests/test_model_helpers/temporal/activations/test_inputs.py
@@ -63,9 +63,28 @@ def test_video():
     assert video8.duration == 100
     assert (video8.to_numpy() == video1.set_window(300, 400).to_numpy()).all()
 
+    # test copy
+    video9 = video1.set_fps(5).copy().set_fps(30).copy()
+    assert (video9.to_numpy()[1] == video1.to_numpy()[2]).all()
+    assert (video9.to_numpy()[2] == video1.to_numpy()[4]).all()
+
+    for frame in [10, 50, 100]:
+        time_start = 1000 / video1.fps * frame
+        video10 = video1.set_window(time_start, time_start+1000/video1.fps)
+        assert video10.to_numpy().shape[0] == 1
+        assert (video10.to_numpy()[0] == video1.to_numpy()[frame]).all()
+
+        video10 = video1.set_window(0, time_start+1000/video1.fps)
+        assert video10.to_numpy().shape[0] == frame+1
+        assert (video10.to_numpy()[frame] == video1.to_numpy()[frame]).all()
+
+        video10 = video1.set_window(time_start, video1.duration)
+        assert video10.to_numpy().shape[0] == video1.to_numpy().shape[0] - frame
+        assert (video10.to_numpy()[0] == video1.to_numpy()[frame]).all()
+
     for fps in [7.5, 9, 1, 43, 1000/video1.duration, 1001/video1.duration]:
-        video9 = video1.set_fps(fps)
-        assert video9.to_numpy().shape[0] == np.ceil(video1.duration * fps / 1000)
+        video11 = video1.set_fps(fps)
+        assert video11.to_numpy().shape[0] == np.ceil(video1.duration * fps / 1000)
 
     for v in [video1, video2]:
         target_num_frames = 7
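
The new window tests lean on the sampling rule sketched earlier: a window spanning exactly one frame period, [t, t + 1000/fps), yields exactly one frame because of the EPS guard at the window's end. A worked check with assumed numbers (frame 10 of a 60 fps source, read at the source fps):

import numpy as np

fps = 60
frame = 10
EPS = 1e-9

t0 = 1000 / fps * frame   # window start in ms
t1 = t0 + 1000 / fps      # window end, one frame period later
samples = np.arange(t0 * fps / 1000, t1 * fps / 1000 - EPS, 1.0)
print(samples.astype(int))  # [10] -> a single frame, as the test asserts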
