Skip to content

Commit

Permalink
Refactor AudioTfl class to accept the number of detection threads as …
Browse files Browse the repository at this point in the history
…a parameter in the constructor, and update the usage of the num_threads attribute accordingly
  • Loading branch information
skrashevich committed Aug 27, 2023
1 parent 7c629c1 commit de62112
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
1 change: 1 addition & 0 deletions frigate/config.py
Expand Up @@ -444,6 +444,7 @@ class AudioConfig(FrigateBaseModel):
enabled_in_config: Optional[bool] = Field(
title="Keep track of original state of audio detection."
)
num_threads: int = Field(default=2, title="Number of detection threads", ge=1)


class BirdseyeModeEnum(str, Enum):
Expand Down
11 changes: 6 additions & 5 deletions frigate/events/audio.py
Expand Up @@ -89,12 +89,13 @@ def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:


class AudioTfl:
def __init__(self, stop_event: mp.Event):
def __init__(self, stop_event: mp.Event, num_threads=2):
self.stop_event = stop_event
self.labels = load_labels("/audio-labelmap.txt")
self.num_threads = num_threads
self.labels = load_labels("/audio-labelmap.txt", prefill=521)
self.interpreter = Interpreter(
model_path="/cpu_audio_model.tflite",
num_threads=2,
num_threads=self.num_threads,
)

self.interpreter.allocate_tensors()
Expand All @@ -117,7 +118,7 @@ def _detect_raw(self, tensor_input):
count = len(scores)

for i in range(count):
if scores[i] < 0.4 or i == 20:
if scores[i] < AUDIO_MIN_CONFIDENCE or i == 20:
break
detections[i] = [
class_ids[i],
Expand Down Expand Up @@ -164,7 +165,7 @@ def __init__(
self.inter_process_communicator = inter_process_communicator
self.detections: dict[dict[str, any]] = feature_metrics
self.stop_event = stop_event
self.detector = AudioTfl(stop_event)
self.detector = AudioTfl(stop_event, self.config.audio.num_threads)
self.shape = (int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE)),)
self.chunk_size = int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE * 2))
self.logger = logging.getLogger(f"audio.{self.config.name}")
Expand Down
4 changes: 2 additions & 2 deletions frigate/util/builtin.py
Expand Up @@ -134,7 +134,7 @@ def get_ffmpeg_arg_list(arg: Any) -> list:
return arg if isinstance(arg, list) else shlex.split(arg)


def load_labels(path, encoding="utf-8"):
def load_labels(path, encoding="utf-8", prefill=91):
"""Loads labels from file (with or without index numbers).
Args:
path: path to label file.
Expand All @@ -143,7 +143,7 @@ def load_labels(path, encoding="utf-8"):
Dictionary mapping indices to labels.
"""
with open(path, "r", encoding=encoding) as f:
labels = {index: "unknown" for index in range(91)}
labels = {index: "unknown" for index in range(prefill)}
lines = f.readlines()
if not lines:
return {}
Expand Down

0 comments on commit de62112

Please sign in to comment.