-
Notifications
You must be signed in to change notification settings - Fork 31
/
predict_new_vids.py
executable file
·161 lines (128 loc) · 5.87 KB
/
predict_new_vids.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Run inference on a list of models and videos."""
import os
import hydra
import lightning.pytorch as pl
import numpy as np
from moviepy.editor import VideoFileClip
from omegaconf import DictConfig, OmegaConf
from typeguard import typechecked
from lightning_pose.utils.io import (
check_if_semi_supervised,
ckpt_path_from_base_path,
get_videos_in_dir,
return_absolute_data_paths,
return_absolute_path,
)
from lightning_pose.utils.predictions import load_model_from_checkpoint
from lightning_pose.utils.scripts import (
compute_metrics,
export_predictions_and_labeled_video,
get_data_module,
get_dataset,
get_imgaug_transform,
)
"""This script takes two important args: the model to use and the video folder to process.
Hydra orchestrates both. The advantage is that in the future we could parallelize across
machines, with no need to loop over models or folders; we do still need to loop over the
videos within a folder. However, keeping cfg.eval.hydra_paths is useful for the fiftyone
image plotting, so we keep it."""
@typechecked
class VideoPredPathHandler:
    """Build the path of the predictions .csv file for a given video and model.

    The predictions file lives in ``save_preds_dir`` and is named after the video
    file (directory and final extension removed), optionally followed by
    ``extra_str``: ``<save_preds_dir>/<video_basename><extra_str>.csv``.

    Args:
        save_preds_dir: existing directory in which the predictions file will live.
        video_file: path to an existing video file.
        model_cfg: config of the model producing the predictions. It is stored on
            the instance but not currently used to build the filename.

    Raises:
        AssertionError: at construction time, if ``video_file`` is not a file or
            ``save_preds_dir`` is not a directory.
    """

    def __init__(
        self, save_preds_dir: str, video_file: str, model_cfg: DictConfig
    ) -> None:
        self.video_file = video_file
        self.save_preds_dir = save_preds_dir
        self.model_cfg = model_cfg
        self.check_input_paths()

    @property
    def video_basename(self) -> str:
        """Video filename without its directory and without its final extension.

        Only the last extension is stripped, so ``session.1.mp4`` becomes
        ``session.1``. Truncating at the first dot would make distinct videos
        collide on ``session`` and overwrite each other's predictions.
        """
        return os.path.splitext(os.path.basename(self.video_file))[0]

    def check_input_paths(self) -> None:
        """Check that the video file and the output directory both exist."""
        # `assert` is kept so callers still see an AssertionError as before;
        # note that these checks are skipped when Python runs with -O.
        assert os.path.isfile(self.video_file), (
            f"video file not found: {self.video_file}"
        )
        assert os.path.isdir(self.save_preds_dir), (
            f"predictions directory not found: {self.save_preds_dir}"
        )

    def build_pred_file_basename(self, extra_str: str = "") -> str:
        """Return the predictions filename ``<video_basename><extra_str>.csv``."""
        return f"{self.video_basename}{extra_str}.csv"

    def __call__(self, extra_str: str = "") -> str:
        """Return the full path of the predictions .csv file in ``save_preds_dir``."""
        pred_file_basename = self.build_pred_file_basename(extra_str=extra_str)
        return os.path.join(self.save_preds_dir, pred_file_basename)
@hydra.main(config_path="configs", config_name="config_mirror-mouse-example")
def predict_videos_in_dir(cfg: DictConfig):
    """Predict on every video in a directory with one or more trained models.

    Each entry of ``cfg.eval.hydra_paths`` is a relative path ("YYYY-MM-DD/HH-MM-SS")
    to a trained model's hydra folder. From that folder the model's config and
    checkpoint are read, and predictions are made on each video found in
    ``cfg.eval.test_videos_directory``. Metrics are then computed per video.

    If you need to predict multiple folders (each with one or more videos), use
    --multirun and pass these directories as
    cfg.eval.test_videos_directory='dir/1','dir/2'...

    NOTE: by decorating with hydra, the current working directory may become the new
    folder <launch dir>/outputs/YYYY-MM-DD/hour-info, which is why relative paths
    from the config are resolved to absolute paths below.

    Args:
        cfg: config for this run, supplied by ``hydra.main``.
    """
    # get pl trainer for prediction
    trainer = pl.Trainer(accelerator="gpu", devices=1)

    # The data module (keypoint names, etc.), the output directory and the video list
    # depend only on `cfg`, so build them once instead of once per model.
    data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data)
    print("getting imgaug transform...")
    imgaug_transform = get_imgaug_transform(cfg=cfg)
    print("getting dataset...")
    dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform)
    print("getting data module...")
    data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir)

    test_videos_dir = return_absolute_path(cfg.eval.test_videos_directory)
    # determine a directory in which to save video prediction csv files
    if cfg.eval.saved_vid_preds_dir is None:
        # Save to where the videos are. Reuse the same resolved directory the videos
        # are listed from, rather than the raw (possibly relative) config value.
        save_preds_dir = test_videos_dir
    else:
        save_preds_dir = return_absolute_path(cfg.eval.saved_vid_preds_dir, n_dirs_back=3)
    video_files = get_videos_in_dir(test_videos_dir)

    # loop over models
    for hydra_relative_path in cfg.eval.hydra_paths:
        # cfg.eval.hydra_paths defines a list of relative paths to hydra folders
        # "YYYY-MM-DD/HH-MM-SS", and we extract an absolute path below.
        # absolute_cfg_path is the path of the trained model used for predictions.
        absolute_cfg_path = return_absolute_path(hydra_relative_path, n_dirs_back=2)

        # load model
        model_cfg = OmegaConf.load(os.path.join(absolute_cfg_path, ".hydra/config.yaml"))
        ckpt_file = ckpt_path_from_base_path(
            base_path=absolute_cfg_path, model_name=model_cfg.model.model_name
        )
        # NOTE(review): the model is built from the current `cfg`, not from the
        # trained model's `model_cfg`; confirm the two agree on model settings.
        model = load_model_from_checkpoint(cfg=cfg, ckpt_file=ckpt_file, eval=True)

        # loop over videos in the provided directory
        for video_file in video_files:
            # NOTE(review): the csv name is built from the video name only, so with
            # several models in cfg.eval.hydra_paths each model overwrites the
            # previous model's predictions in `save_preds_dir`.
            video_pred_path_handler = VideoPredPathHandler(
                save_preds_dir=save_preds_dir,
                video_file=video_file,
                model_cfg=model_cfg,
            )
            prediction_csv_file = video_pred_path_handler()
            if cfg.eval.get("save_vids_after_training", False):
                # Swap only the trailing extension; str.replace would also rewrite
                # any ".csv" appearing earlier in the path.
                labeled_mp4_file = os.path.splitext(prediction_csv_file)[0] + "_labeled.mp4"
            else:
                labeled_mp4_file = None
            # debug
            print(f"\n\n{prediction_csv_file = }\n\n")
            export_predictions_and_labeled_video(
                video_file=video_file,
                cfg=cfg,
                ckpt_file=ckpt_file,
                prediction_csv_file=prediction_csv_file,
                labeled_mp4_file=labeled_mp4_file,
                trainer=trainer,
                model=model,
                data_module=data_module,
                save_heatmaps=cfg.eval.get("predict_vids_after_training_save_heatmaps", False),
            )
            # compute and save various metrics
            try:
                compute_metrics(
                    cfg=cfg,
                    preds_file=prediction_csv_file,
                    data_module=data_module,
                )
            except Exception as e:
                # Best-effort: a metrics failure on one video should not stop the
                # remaining videos/models from being processed.
                print(f"Error computing metrics on video {video_file}:\n{e}")
if __name__ == "__main__":
    # No arguments are passed here: the hydra.main decorator builds `cfg` from the
    # config files plus any command-line overrides and hands it to the function.
    predict_videos_in_dir()