Skip to content

Commit

Permalink
Created load and save model run functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mbsantiago committed Jul 15, 2023
1 parent 0508765 commit a4cd550
Show file tree
Hide file tree
Showing 8 changed files with 690 additions and 149 deletions.
8 changes: 8 additions & 0 deletions src/soundevent/data/processed_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
into the predicted events and aid in subsequent analysis and interpretation.
"""
from typing import List
from uuid import UUID, uuid4

from pydantic import BaseModel, Field

Expand All @@ -47,6 +48,9 @@
class ProcessedClip(BaseModel):
"""Processed clip."""

uuid: UUID = Field(default_factory=uuid4, repr=False)
"""Unique identifier for the processed clip."""

clip: Clip
"""The clip that was processed."""

Expand All @@ -58,3 +62,7 @@ class ProcessedClip(BaseModel):

features: List[Feature] = Field(default_factory=list)
"""List of features associated with the clip."""

def __hash__(self):
"""Hash function for the processed clip."""
return hash(self.uuid)
9 changes: 9 additions & 0 deletions src/soundevent/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,18 @@
sound event data.
"""

from soundevent.io.annotation_projects import (
load_annotation_project,
save_annotation_project,
)
from soundevent.io.datasets import load_dataset, save_dataset
from soundevent.io.model_runs import load_model_run, save_model_run

__all__ = [
"load_annotation_project",
"load_dataset",
"load_model_run",
"save_annotation_project",
"save_dataset",
"save_model_run",
]
123 changes: 61 additions & 62 deletions src/soundevent/io/annotation_projects.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,109 @@
"""Save and loading functions for annotation projects."""

import os
import sys
from pathlib import Path
from typing import Callable, Dict, Union
from typing import Dict

from soundevent import data
from soundevent.io.format import AnnotationProjectObject, is_json

if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol


PathLike = Union[str, os.PathLike]


class Saver(Protocol):
"""Protocol for saving annotation projects."""
from soundevent.io.formats import aoef, infer_format
from soundevent.io.types import Loader, PathLike, Saver

def __call__(
self,
project: data.AnnotationProject,
path: PathLike,
audio_dir: PathLike = ".",
) -> None:
"""Save annotation project to path."""
...
SAVE_FORMATS: Dict[str, Saver[data.AnnotationProject]] = {}
LOAD_FORMATS: Dict[str, Loader[data.AnnotationProject]] = {}


class Loader(Protocol):
"""Protocol for loading annotation projects."""
def load_annotation_project(
path: PathLike,
audio_dir: PathLike = ".",
) -> data.AnnotationProject:
"""Load annotation project from path.
def __call__(
self, path: PathLike, audio_dir: PathLike = "."
) -> data.AnnotationProject:
"""Load annotation project from path."""
...
Parameters
----------
path: PathLike
Path to the file with the annotation project.
audio_dir: PathLike, optional
Path to the directory containing the audio files, by default ".". The
audio file paths in the annotation project will be relative to this
directory.
SAVE_FORMATS: Dict[str, Saver] = {}
LOAD_FORMATS: Dict[str, Loader] = {}
FORMATS: Dict[str, Callable[[PathLike], bool]] = {}
Returns
-------
annotation_project: data.AnnotationProject
The loaded annotation project.
Raises
------
FileNotFoundError
If the path does not exist.
def load_annotation_project(path: PathLike) -> data.AnnotationProject:
"""Load annotation project from path."""
NotImplementedError
If the format of the file is not supported.
"""
path = Path(path)

if not path.exists():
raise FileNotFoundError(f"Path {path} does not exist.")

for format_name, is_format in FORMATS.items():
if not is_format(path):
continue

return LOAD_FORMATS[format_name](path)
try:
format_ = infer_format(path)
except ValueError as e:
raise NotImplementedError(f"File {path} format not supported.") from e

raise NotImplementedError(
f"Could not find a loader for {path}. "
f"Supported formats are: {list(FORMATS.keys())}"
)
loader = LOAD_FORMATS[format_]
return loader(path, audio_dir=audio_dir)


def save_annotation_project(
project: data.AnnotationProject,
path: PathLike,
audio_dir: PathLike = ".",
format: str = "aoef",
) -> None:
"""Save annotation project to path.
Parameters
----------
project: data.AnnotationProject
Annotation project to save.
path: PathLike
Path to save annotation project to.
audio_dir: PathLike, optional
Path to the directory containing the audio files, by default ".". The
audio file paths in the annotation project will be relative to this
directory.
format: str, optional
Format to save the annotation project in, by default "aoef".
Raises
------
NotImplementedError
If the format is not supported.
"""
path = Path(path)

for format_name, is_format in FORMATS.items():
if not is_format(path):
continue

SAVE_FORMATS[format_name](project, path, audio_dir=audio_dir)
return
try:
saver = SAVE_FORMATS[format]
except KeyError as e:
raise NotImplementedError(f"Format {format} not supported.") from e

raise NotImplementedError(
f"Could not find a saver for {path}. "
f"Supported formats are: {list(FORMATS.keys())}"
)
saver(project, path, audio_dir=audio_dir)


def save_annotation_project_in_aoef_format(
project: data.AnnotationProject,
obj: data.AnnotationProject,
path: PathLike,
audio_dir: PathLike = ".",
) -> None:
"""Save annotation project to path in AOEF format."""
path = Path(path)
audio_dir = Path(audio_dir).resolve()
annotation_project_object = (
AnnotationProjectObject.from_annotation_project(
project,
aoef.AnnotationProjectObject.from_annotation_project(
obj,
audio_dir=audio_dir,
)
)
Expand All @@ -122,12 +122,11 @@ def load_annotation_project_in_aoef_format(
"""Load annotation project from path in AOEF format."""
path = Path(path)
audio_dir = Path(audio_dir).resolve()
annotation_project_object = AnnotationProjectObject.model_validate_json(
path.read_text()
annotation_project_object = (
aoef.AnnotationProjectObject.model_validate_json(path.read_text())
)
return annotation_project_object.to_annotation_project(audio_dir=audio_dir)


SAVE_FORMATS["aoef"] = save_annotation_project_in_aoef_format
LOAD_FORMATS["aoef"] = load_annotation_project_in_aoef_format
FORMATS["aoef"] = is_json
Loading

0 comments on commit a4cd550

Please sign in to comment.