Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clarify caching #131

Merged
merged 28 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
61592b7
set ability to turn off caching for prediction
ejm714 Oct 13, 2021
d623912
remove cache dir from cli
ejm714 Oct 13, 2021
4e61adf
test cache dir is set but not used
ejm714 Oct 13, 2021
bc96a27
rename to MODEL_CACHE_DIR for consistency
ejm714 Oct 13, 2021
240ba24
rename cache_dir to model_cache_dir
ejm714 Oct 13, 2021
e9322b5
move video cache dir into configs
ejm714 Oct 13, 2021
59fe563
remove unneeded code as caching is off by default
ejm714 Oct 13, 2021
10d27ae
rename to video_cache_dir for clarity
ejm714 Oct 13, 2021
838d874
remove
ejm714 Oct 13, 2021
531810f
do not support setting in configs
ejm714 Oct 13, 2021
6ded38e
put back in settings
ejm714 Oct 13, 2021
980f98d
lint and such
ejm714 Oct 13, 2021
529cfd2
rebase fix
ejm714 Oct 14, 2021
ce74a90
put cache_dir and cleanup option on video loader config
ejm714 Oct 14, 2021
530e27b
add tests for caching
ejm714 Oct 14, 2021
ebf42ef
get empty vlc if none is passed
ejm714 Oct 14, 2021
98f9eda
put within func to avoid writing to real path
ejm714 Oct 14, 2021
252da67
add logging
ejm714 Oct 14, 2021
b49d374
bug fix
ejm714 Oct 14, 2021
5fecc96
remove old change
ejm714 Oct 16, 2021
4be02c5
loguru uses warning not warn
ejm714 Oct 16, 2021
6fa7378
load dotenv in init
ejm714 Oct 17, 2021
f6174f2
setup logger in init and rename to log_level for simplicity
ejm714 Oct 17, 2021
7093b44
cleanup does not change hash
ejm714 Oct 17, 2021
7f43799
fix test
ejm714 Oct 17, 2021
50c8806
lint
ejm714 Oct 18, 2021
8bad7d0
add regression test
ejm714 Oct 18, 2021
e1b5729
do not use load_dotenv
ejm714 Oct 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def test_actual_prediction_on_single_video(tmp_path): # noqa: F811

save_path = tmp_path / "zamba" / "my_preds.csv"

# Prior to mocking, run one real prediction using config
result = runner.invoke(
app,
[
Expand Down
12 changes: 6 additions & 6 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,16 @@ def test_predict_filepaths_with_duplicates(labels_absolute_path, tmp_path, caplo
assert "Found 1 duplicate row(s) in filepaths csv. Dropping duplicates" in caplog.text


def test_cache_dir(labels_absolute_path, tmp_path):
def test_model_cache_dir(labels_absolute_path, tmp_path):
config = TrainConfig(labels=labels_absolute_path)
assert config.cache_dir == Path(appdirs.user_cache_dir()) / "zamba"
assert config.model_cache_dir == Path(appdirs.user_cache_dir()) / "zamba"

os.environ["ZAMBA_CACHE_DIR"] = str(tmp_path)
os.environ["MODEL_CACHE_DIR"] = str(tmp_path)
config = TrainConfig(labels=labels_absolute_path)
assert config.cache_dir == tmp_path
assert config.model_cache_dir == tmp_path

config = PredictConfig(filepaths=labels_absolute_path, cache_dir=tmp_path / "my_cache")
assert config.cache_dir == tmp_path / "my_cache"
config = PredictConfig(filepaths=labels_absolute_path, model_cache_dir=tmp_path / "my_cache")
assert config.model_cache_dir == tmp_path / "my_cache"


def test_predict_save(labels_absolute_path, tmp_path, dummy_trained_model_checkpoint):
Expand Down
20 changes: 10 additions & 10 deletions tests/test_instantiate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_scheduler_ignored_for_prediction(dummy_checkpoint, tmp_path):
weight_download_region="us",
scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params=None),
labels=None,
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)
# since labels is None, we are predicting. as a result, hparams are not updated
assert model.hparams["scheduler"] is None
Expand All @@ -33,7 +33,7 @@ def test_default_scheduler_used(time_distributed_checkpoint, tmp_path):
weight_download_region="us",
scheduler_config="default",
labels=pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)

# with "default" scheduler_config, hparams from training are used
Expand All @@ -50,7 +50,7 @@ def test_scheduler_used_if_passed(time_distributed_checkpoint, tmp_path):
weight_download_region="us",
scheduler_config=SchedulerConfig(scheduler="StepLR"),
labels=pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)

# hparams reflect user specified scheduler config
Expand All @@ -64,7 +64,7 @@ def test_scheduler_used_if_passed(time_distributed_checkpoint, tmp_path):
weight_download_region="us",
scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params={"gamma": 0.3}),
labels=pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)
assert scheduler_params_passed_model.hparams["scheduler_params"] == {"gamma": 0.3}

Expand All @@ -76,7 +76,7 @@ def test_remove_scheduler(time_distributed_checkpoint, tmp_path):
weight_download_region="us",
scheduler_config=SchedulerConfig(scheduler=None),
labels=pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)
# pretrained model has scheduler but this is overridden with SchedulerConfig
assert remove_scheduler_model.hparams["scheduler"] is None
Expand All @@ -94,7 +94,7 @@ def test_head_not_replaced_for_species_subset(dummy_trained_model_checkpoint, tm
weight_download_region="us",
scheduler_config="default",
labels=pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)

assert (model.head.weight == original_model.head.weight).all()
Expand All @@ -118,7 +118,7 @@ def test_not_predict_all_zamba_species(dummy_trained_model_checkpoint, tmp_path)
weight_download_region="us",
scheduler_config="default",
labels=pd.DataFrame([{"filepath": "gorilla.mp4", "species_gorilla": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
predict_all_zamba_species=False,
)

Expand All @@ -141,7 +141,7 @@ def test_head_replaced_for_new_species(dummy_trained_model_checkpoint, tmp_path)
weight_download_region="us",
scheduler_config="default",
labels=pd.DataFrame([{"filepath": "alien.mp4", "species_alien": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)

assert (model.head.weight != original_model.head.weight).all()
Expand All @@ -157,7 +157,7 @@ def test_finetune_new_labels(labels_absolute_path, model, tmp_path):
weight_download_region=config.weight_download_region,
scheduler_config="default",
labels=pd.DataFrame([{"filepath": "kangaroo.mp4", "species_kangaroo": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)
assert model.species == ["kangaroo"]

Expand All @@ -171,7 +171,7 @@ def test_resume_subset_labels(labels_absolute_path, model, tmp_path):
scheduler_config=SchedulerConfig(scheduler="StepLR", scheduler_params=None),
# pick species that is present in all models
labels=pd.DataFrame([{"filepath": "bird.mp4", "species_bird": 1}]),
cache_dir=tmp_path,
model_cache_dir=tmp_path,
)
assert model.hparams["scheduler"] == "StepLR"

Expand Down
90 changes: 77 additions & 13 deletions tests/test_load_video_frames.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
import pytest
import shutil
import subprocess
from typing import Any, Callable, Dict, Optional, Union
from unittest import mock

import numpy as np
from PIL import Image
from pydantic import BaseModel, ValidationError

from zamba.data.video import (
cached_load_video_frames,
load_video_frames,
VideoLoaderConfig,
MegadetectorLiteYoloXConfig,
Expand Down Expand Up @@ -380,26 +384,33 @@ def test_load_video_frames(case: Case, video_metadata: Dict[str, Any]):
assert video_shape[field] == value


def test_same_filename_new_kwargs():
def test_same_filename_new_kwargs(tmp_path):
"""Test that load_video_frames does not load the npz file if the params change."""
# use first test video
test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]
cache = tmp_path / "test_cache"

first_load = load_video_frames(test_vid, fps=1)
new_params_same_name = load_video_frames(test_vid, fps=2)
assert first_load != new_params_same_name
with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
# confirm cache is set in environment variable
assert os.environ["VIDEO_CACHE_DIR"] == str(cache)

# check no params
first_load = load_video_frames(test_vid)
assert first_load != new_params_same_name
first_load = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
new_params_same_name = cached_load_video_frames(
filepath=test_vid, config=VideoLoaderConfig(fps=2)
)
assert first_load != new_params_same_name

# check no params
first_load = cached_load_video_frames(filepath=test_vid)
assert first_load != new_params_same_name

# multiple params in config
c1 = VideoLoaderConfig(scene_threshold=0.2)
c2 = VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)
# multiple params in config
c1 = VideoLoaderConfig(scene_threshold=0.2)
c2 = VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)

first_load = load_video_frames(filepath=test_vid, config=c1)
new_params_same_name = load_video_frames(filepath=test_vid, config=c2)
assert first_load != new_params_same_name
first_load = cached_load_video_frames(filepath=test_vid, config=c1)
new_params_same_name = cached_load_video_frames(filepath=test_vid, config=c2)
assert first_load != new_params_same_name


def test_megadetector_lite_yolox_dog(tmp_path):
Expand Down Expand Up @@ -493,3 +504,56 @@ def test_validate_total_frames():
megadetector_lite_config=MegadetectorLiteYoloXConfig(confidence=0.01, n_frames=8),
)
assert config.total_frames == 8


def test_caching(tmp_path, caplog):
    """Exercise the video-frame cache: off by default, enabled via config or
    the VIDEO_CACHE_DIR environment variable, and keyed only on frame-selection
    params (not cache housekeeping options)."""
    cache = tmp_path / "video_cache"
    # use the first available test video as a fixture
    test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]

    # no caching by default
    _ = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
    assert not cache.exists()

    # caching can be specified in config
    _ = cached_load_video_frames(
        filepath=test_vid, config=VideoLoaderConfig(fps=1, cache_dir=cache)
    )
    # one file in cache
    assert len([f for f in cache.rglob("*") if f.is_file()]) == 1
    shutil.rmtree(cache)

    # or caching can be specified in environment variable
    with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
        _ = cached_load_video_frames(filepath=test_vid)
        assert len([f for f in cache.rglob("*") if f.is_file()]) == 1

        # changing cleanup in config does not prompt new hashing of videos
        # (DEBUG level needed so the cache-hit message is emitted to caplog)
        with mock.patch.dict(os.environ, {"LOG_LEVEL": "DEBUG"}):
            _ = cached_load_video_frames(
                filepath=test_vid, config=VideoLoaderConfig(cleanup_cache=True)
            )
            assert "Loading from cache" in caplog.text

        # if no config is passed, this is equivalent to specifying None/False in all non-cache related VLC params
        no_config = cached_load_video_frames(filepath=test_vid, config=None)
        config_with_nones = cached_load_video_frames(
            filepath=test_vid,
            config=VideoLoaderConfig(
                crop_bottom_pixels=None,
                i_frames=False,
                scene_threshold=None,
                megadetector_lite_config=None,
                frame_selection_height=None,
                frame_selection_width=None,
                total_frames=None,
                ensure_total_frames=False,
                fps=None,
                early_bias=False,
                frame_indices=None,
                evenly_sample_total_frames=False,
                pix_fmt="rgb24",
                model_input_height=None,
                model_input_width=None,
            ),
        )
        assert np.all(no_config == config_with_nones)
9 changes: 9 additions & 0 deletions zamba/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
import os
import sys

from loguru import logger

from zamba.version import __version__

# NOTE(review): bare expression statement — presumably marks the re-exported
# __version__ as "used" for linters; confirm before removing.
__version__

# Replace loguru's default sink with one at a configurable verbosity:
# honors the LOG_LEVEL environment variable, defaulting to INFO.
logger.remove()
log_level = os.getenv("LOG_LEVEL", "INFO")
logger.add(sys.stderr, level=log_level)
18 changes: 0 additions & 18 deletions zamba/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ def train(
weight_download_region: RegionEnum = typer.Option(
None, help="Server region for downloading weights."
),
cache_dir: Path = typer.Option(
None,
exists=False,
help="Path to directory for model weights. Alternatively, specify with environment variable `ZAMBA_CACHE_DIR`. If not specified, user's cache directory is used.",
),
skip_load_validation: bool = typer.Option(
None,
help="Skip check that verifies all videos can be loaded prior to training. Only use if you're very confident all your videos can be loaded.",
Expand Down Expand Up @@ -127,9 +122,6 @@ def train(
if weight_download_region is not None:
train_dict["weight_download_region"] = weight_download_region

if cache_dir is not None:
train_dict["cache_dir"] = cache_dir

if skip_load_validation is not None:
train_dict["skip_load_validation"] = skip_load_validation

Expand Down Expand Up @@ -171,7 +163,6 @@ def train(
GPUs: {config.train_config.gpus}
Dry run: {config.train_config.dry_run}
Save directory: {config.train_config.save_directory}
Cache directory: {config.train_config.cache_dir}
"""

if yes:
Expand Down Expand Up @@ -241,11 +232,6 @@ def predict(
weight_download_region: RegionEnum = typer.Option(
None, help="Server region for downloading weights."
),
cache_dir: Path = typer.Option(
None,
exists=False,
help="Path to directory for model weights. Alternatively, specify with environment variable `ZAMBA_CACHE_DIR`. If not specified, user's cache directory is used.",
),
skip_load_validation: bool = typer.Option(
None,
help="Skip check that verifies all videos can be loaded prior to inference. Only use if you're very confident all your videos can be loaded.",
Expand Down Expand Up @@ -322,9 +308,6 @@ def predict(
if weight_download_region is not None:
predict_dict["weight_download_region"] = weight_download_region

if cache_dir is not None:
predict_dict["cache_dir"] = cache_dir

if skip_load_validation is not None:
predict_dict["skip_load_validation"] = skip_load_validation

Expand Down Expand Up @@ -356,7 +339,6 @@ def predict(
Proba threshold: {config.predict_config.proba_threshold}
Output class names: {config.predict_config.output_class_names}
Weight download region: {config.predict_config.weight_download_region}
Cache directory: {config.predict_config.cache_dir}
"""

if yes:
Expand Down