Skip to content

Commit

Permalink
Added more tests for io functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mbsantiago committed Jul 16, 2023
1 parent ee83992 commit c79a52d
Show file tree
Hide file tree
Showing 12 changed files with 1,025 additions and 80 deletions.
1 change: 1 addition & 0 deletions docs/examples/nips4b_plus_aoef.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/soundevent/data/annotation_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class AnnotationProject(BaseModel):
name: str
"""Name of the annotation project."""

description: Optional[str] = Field(None, repr=False)
description: Optional[str] = Field(default=None, repr=False)
"""Description of the annotation collection."""

tasks: List[AnnotationTask] = Field(default_factory=list, repr=False)
Expand Down
59 changes: 36 additions & 23 deletions src/soundevent/data/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

from pydantic import BaseModel, Field, PrivateAttr, field_validator
from shapely import geometry
from shapely.geometry.base import BaseGeometry
import shapely.geometry.base as shapely

if sys.version_info >= (3, 8):
from typing import Literal
Expand Down Expand Up @@ -103,7 +103,7 @@
"""The absolute maximum frequency that can be used in a geometry."""


class Geometry(BaseModel, ABC):
class BaseGeometry(BaseModel, ABC):
"""Base class for geometries.
Notes
Expand All @@ -124,7 +124,7 @@ class Geometry(BaseModel, ABC):
include=True,
)

_geom: BaseGeometry = PrivateAttr()
_geom: shapely.BaseGeometry = PrivateAttr()
"""The Shapely geometry object representing the geometry."""

@classmethod
Expand All @@ -140,7 +140,7 @@ def geom_type(cls) -> str:
return type_field.default

@property
def geom(self) -> BaseGeometry:
def geom(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand All @@ -166,7 +166,7 @@ def __init__(self, **data):
self._geom = self._get_shapely_geometry()

@abstractmethod
def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand All @@ -176,7 +176,7 @@ def _get_shapely_geometry(self) -> BaseGeometry:
raise NotImplementedError


class TimeStamp(Geometry):
class TimeStamp(BaseGeometry):
"""TimeStamp geometry type.
This geometry type is used to locate a sound event with a single time stamp.
Expand All @@ -195,7 +195,7 @@ class TimeStamp(Geometry):
The time stamp is relative to the start of the recording.
"""

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand All @@ -210,7 +210,7 @@ def _get_shapely_geometry(self) -> BaseGeometry:
)


class TimeInterval(Geometry):
class TimeInterval(BaseGeometry):
"""TimeInterval geometry type.
This geometry type is used to locate a sound event with a time interval.
Expand Down Expand Up @@ -250,7 +250,7 @@ def validate_time_interval(cls, v: Tuple[Time, Time]) -> Tuple[Time, Time]:
raise ValueError("The start time must be before the end time.")
return v

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the sound event.
Returns
Expand All @@ -261,7 +261,7 @@ def _get_shapely_geometry(self) -> BaseGeometry:
return geometry.box(start_time, 0, end_time, MAX_FREQUENCY)


class Point(Geometry):
class Point(BaseGeometry):
"""Point geometry type.
This geometry type is used to locate a sound event with a single point in
Expand All @@ -280,7 +280,7 @@ class Point(Geometry):
"""

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand All @@ -290,7 +290,7 @@ def _get_shapely_geometry(self) -> BaseGeometry:
return geometry.Point(self.coordinates)


class LineString(Geometry):
class LineString(BaseGeometry):
"""LineString geometry type.
This geometry type is used to locate a sound event with a line in time and
Expand All @@ -310,7 +310,7 @@ class LineString(Geometry):
All times are relative to the start of the recording."""

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand Down Expand Up @@ -338,7 +338,7 @@ def is_ordered_by_time(
return v


class Polygon(Geometry):
class Polygon(BaseGeometry):
"""Polygon geometry type.
This geometry type is used to locate a sound event with a polygon in time
Expand All @@ -365,7 +365,7 @@ def has_at_least_one_ring(
raise ValueError("The polygon must have at least one ring.")
return v

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand All @@ -377,7 +377,7 @@ def _get_shapely_geometry(self) -> BaseGeometry:
return geometry.Polygon(shell, holes)


class BoundingBox(Geometry):
class BoundingBox(BaseGeometry):
"""BoundingBox geometry type.
This geometry type is used to locate a sound event with a bounding box in
Expand Down Expand Up @@ -413,7 +413,7 @@ def validate_bounding_box(

return v

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry."""
start_time, start_frequency, end_time, end_frequency = self.coordinates
return geometry.box(
Expand All @@ -424,7 +424,7 @@ def _get_shapely_geometry(self) -> BaseGeometry:
)


class MultiPoint(Geometry):
class MultiPoint(BaseGeometry):
"""MultiPoint geometry type.
This geometry type is used to locate a sound event with multiple points in
Expand All @@ -442,7 +442,7 @@ class MultiPoint(Geometry):
All times are relative to the start of the recording."""

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand All @@ -452,7 +452,7 @@ def _get_shapely_geometry(self) -> BaseGeometry:
return geometry.MultiPoint(self.coordinates)


class MultiLineString(Geometry):
class MultiLineString(BaseGeometry):
"""MultiLineString geometry type.
This geometry type is used to locate a sound event with multiple lines in
Expand All @@ -471,7 +471,7 @@ class MultiLineString(Geometry):
All times are relative to the start of the recording."""

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand Down Expand Up @@ -511,7 +511,7 @@ def each_line_is_ordered_by_time(
return v


class MultiPolygon(Geometry):
class MultiPolygon(BaseGeometry):
"""MultiPolygon geometry type.
This geometry type is used to locate a sound event with multiple polygons in
Expand Down Expand Up @@ -552,7 +552,7 @@ def each_polygon_has_at_least_one_ring(
raise ValueError("Each polygon must have at least one ring.")
return v

def _get_shapely_geometry(self) -> BaseGeometry:
def _get_shapely_geometry(self) -> shapely.BaseGeometry:
"""Get the Shapely geometry object representing the geometry.
Returns
Expand All @@ -566,3 +566,16 @@ def _get_shapely_geometry(self) -> BaseGeometry:
polygon = geometry.Polygon(shell, holes)
polgons.append(polygon)
return geometry.MultiPolygon(polgons)


Geometry = Union[
TimeStamp,
TimeInterval,
Point,
LineString,
Polygon,
BoundingBox,
MultiPoint,
MultiLineString,
MultiPolygon,
]
4 changes: 4 additions & 0 deletions src/soundevent/data/predicted_sound_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,7 @@ class PredictedSoundEvent(BaseModel):

features: List[Feature] = Field(default_factory=list)
"""List of features associated with the prediction."""

def __hash__(self) -> int:
"""Return hash value of the predicted sound event."""
return hash(self.id)
28 changes: 17 additions & 11 deletions src/soundevent/io/annotation_projects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Save and loading functions for annotation projects."""
from pathlib import Path
from typing import Dict
from typing import Dict, Optional

from soundevent import data
from soundevent.io.formats import aoef, infer_format
Expand All @@ -12,7 +12,7 @@

def load_annotation_project(
path: PathLike,
audio_dir: PathLike = ".",
audio_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
"""Load annotation project from path.
Expand All @@ -22,9 +22,9 @@ def load_annotation_project(
Path to the file with the annotation project.
audio_dir: PathLike, optional
Path to the directory containing the audio files, by default ".". The
Path to the directory containing the audio files. If provided, the
audio file paths in the annotation project will be relative to this
directory.
directory. By default None.
Returns
-------
Expand Down Expand Up @@ -56,7 +56,7 @@ def load_annotation_project(
def save_annotation_project(
project: data.AnnotationProject,
path: PathLike,
audio_dir: PathLike = ".",
audio_dir: Optional[PathLike] = None,
format: str = "aoef",
) -> None:
"""Save annotation project to path.
Expand All @@ -70,9 +70,9 @@ def save_annotation_project(
Path to save annotation project to.
audio_dir: PathLike, optional
Path to the directory containing the audio files, by default ".". The
Path to the directory containing the audio files. If provided, the
audio file paths in the annotation project will be relative to this
directory.
directory. By default None.
format: str, optional
Format to save the annotation project in, by default "aoef".
Expand All @@ -96,11 +96,14 @@ def save_annotation_project(
def save_annotation_project_in_aoef_format(
obj: data.AnnotationProject,
path: PathLike,
audio_dir: PathLike = ".",
audio_dir: Optional[PathLike] = None,
) -> None:
"""Save annotation project to path in AOEF format."""
path = Path(path)
audio_dir = Path(audio_dir).resolve()

if audio_dir is not None:
audio_dir = Path(audio_dir).resolve()

annotation_project_object = (
aoef.AnnotationProjectObject.from_annotation_project(
obj,
Expand All @@ -117,11 +120,14 @@ def save_annotation_project_in_aoef_format(

def load_annotation_project_in_aoef_format(
path: PathLike,
audio_dir: PathLike = ".",
audio_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
"""Load annotation project from path in AOEF format."""
path = Path(path)
audio_dir = Path(audio_dir).resolve()

if audio_dir is not None:
audio_dir = Path(audio_dir).resolve()

annotation_project_object = (
aoef.AnnotationProjectObject.model_validate_json(path.read_text())
)
Expand Down
Loading

0 comments on commit c79a52d

Please sign in to comment.