Implement and use load_valued_intervals
justinsalamon committed Feb 22, 2016
1 parent d4aedfe commit 4e63635
Showing 2 changed files with 72 additions and 61 deletions.
38 changes: 38 additions & 0 deletions mir_eval/io.py
@@ -399,3 +399,41 @@ def load_wav(path, mono=True):
if mono and audio_data.ndim != 1:
audio_data = audio_data.mean(axis=1)
return audio_data, fs


def load_valued_intervals(filename, delimiter=r'\s+'):
r"""Import valued intervals from an annotation file. The file should consist
of three columns: Two consisting of numeric values corresponding to start
and end time of each interval and a third, also of numeric values,
corresponding to the value of each interval. This is primarily useful for
processing events which span a duration and have a numeric value, such as
piano-roll notes which have an onset, offset, and a pitch value.
Parameters
----------
filename : str
Path to the annotation file
delimiter : str
Separator regular expression.
By default, lines will be split by any amount of whitespace.
Returns
-------
intervals : np.ndarray, shape=(n_events, 2)
array of event start and end time
values : list of float
list of values
"""
# Use our universal function to load in the events
starts, ends, values = load_delimited(filename, [float, float, float],
delimiter)
# Stack into an interval matrix
intervals = np.array([starts, ends]).T
# Validate them, but throw a warning in place of an error
try:
util.validate_intervals(intervals)
except ValueError as error:
warnings.warn(error.args[0])

return intervals, values
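
A quick sketch of how the new loader might be used. The file name, contents, and temporary-file plumbing below are invented for illustration; only load_valued_intervals itself comes from this commit.

import tempfile
import mir_eval.io

# Hypothetical three-column annotation: onset, offset, value (e.g. pitch in Hz)
annotation = "0.50 1.00 440.0\n1.25 2.00 220.0\n"
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write(annotation)
    path = f.name

intervals, values = mir_eval.io.load_valued_intervals(path)
# intervals -> array([[0.5, 1.0], [1.25, 2.0]]); values -> [440.0, 220.0]

Note that if an interval's end precedes its start, util.validate_intervals would raise, but load_valued_intervals downgrades that to a warning and still returns the data.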
95 changes: 34 additions & 61 deletions mir_eval/transcription.py
Expand Up @@ -59,10 +59,9 @@
import warnings


def validate(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
est_pitches):
"""Checks that the input annotations to a metric look like note onsets,
offsets and pitch arrays, and throws helpful errors if not.
def validate(ref_intervals, ref_pitches, est_intervals, est_pitches):
"""Checks that the input annotations to a metric look like time intervals
and a pitch list, and throws helpful errors if not.
Parameters
----------
@@ -79,43 +78,21 @@ def validate(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
est_pitches: np.ndarray
estimated note pitch values, in Hertz
"""
# If reference or estimated beats are empty, warn
if ref_onsets.size == 0:
warnings.warn("Reference note onsets are empty.")
if ref_offsets.size == 0:
warnings.warn("Reference note offsets are empty.")
if ref_pitches.size == 0:
# If reference or estimated notes are empty, warn
if ref_intervals.size == 0:
warnings.warn("Reference note intervals are empty.")
if len(ref_pitches) == 0:
warnings.warn("Reference note pitches are empty.")
if est_onsets.size == 0:
warnings.warn("Estimated note onsets are empty.")
if est_offsets.size == 0:
warnings.warn("Estimated note offsets are empty.")
if est_pitches.size == 0:
warnings.warn("Estimated note pitches are empty.")

# Make sure all three arrays of each transcription match in length
if not (len(ref_onsets)==len(ref_offsets) and
len(ref_onsets)==len(ref_pitches)):
warnings.warn("Reference arrays have different lengths.")
if not (len(est_onsets)==len(est_offsets) and
len(est_onsets)==len(est_pitches)):
warnings.warn("Estimate arrays have different lengths.")

# Check for notes with negative duration
negative_duration_note = False
for onset, offset in zip(ref_onsets, ref_offsets):
if offset < onset:
negative_duration_note = True
if negative_duration_note:
warnings.warn("Reference contains at least one note with negative "
"duration")
negative_duration_note = False
for onset, offset in zip(est_onsets, est_offsets):
if offset < onset:
negative_duration_note = True
if negative_duration_note:
warnings.warn("Estimate contains at least one note with negative "
"duration")
if est_intervals.size == 0:
warnings.warn("Estimated note intervals are empty.")
if len(est_pitches) == 0:
warnings.warn("Estimated note pitches are empty.")

# Make sure intervals and pitches match in length
if not len(ref_intervals) == len(ref_pitches):
warnings.warn("Reference intervals and pitches have different lengths.")
if not len(est_intervals) == len(est_pitches):
warnings.warn("Estimate intervals and pitches have different lengths.")

# Make sure all pitch values are positive
if np.min(ref_pitches) <= 0:
@@ -126,20 +103,19 @@ def validate(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
"value")


def prf(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
est_pitches):
def prf(ref_intervals, ref_pitches, est_intervals, est_pitches):
"""Compute the Precision, Recall and F-measure of correct vs incorrectly
transcribed notes. "Correctness" is determined based on note onset, offset
and pitch as detailed at the top of this document.
Examples
--------
>>> ref_onsets, ref_offsets, ref_pitches = mir_eval.io.load_delimited(
... 'reference.txt', [float, float, float], '\t')
>>> est_onsets, est_offsets, est_pitches = mir_eval.io.load_delimited(
... 'estimated.txt', [float, float, float], '\t')
>>> precision, recall, f_measure = mir_eval.transcription.prf(ref_onests,
... ref_offsets, ref_pitches, est_onsets, est_offsets, est_pitches)
>>> ref_intervals, ref_pitches = mir_eval.io.load_valued_intervals(
... 'reference.txt')
>>> est_intervals, est_pitches = mir_eval.io.load_valued_intervals(
... 'estimated.txt')
>>> precision, recall, f_measure = mir_eval.transcription.prf(ref_intervals,
... ref_pitches, est_intervals, est_pitches)
Parameters (TODO)
----------
@@ -157,8 +133,7 @@ def prf(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
The computed F-measure score
"""
validate(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
est_pitches)
validate(ref_intervals, ref_pitches, est_intervals, est_pitches)
# # When estimated beats are empty, no beats are correct; metric is 0
# if ref_onsets.size == 0 or reference_beats.size == 0:
# return 0.
@@ -172,18 +147,17 @@ def prf(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
# return util.f_measure(precision, recall)
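
The metric body above is still commented out in this commit. Purely as a sketch of the intended arithmetic, and assuming some note-matching step (not implemented here) has already produced a count of correctly transcribed notes, precision, recall, and F-measure would combine like this:

def prf_from_counts(n_correct, n_ref, n_est):
    # n_correct: hypothetical number of matched notes; n_ref / n_est:
    # number of reference / estimated notes. Matching is NOT done here.
    precision = n_correct / float(n_est) if n_est > 0 else 0.
    recall = n_correct / float(n_ref) if n_ref > 0 else 0.
    if precision + recall == 0:
        return precision, recall, 0.
    return precision, recall, 2 * precision * recall / (precision + recall)

# e.g. 8 matched notes against 10 reference and 12 estimated notes:
# prf_from_counts(8, 10, 12) -> (0.666..., 0.8, 0.727...)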


def evaluate(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
est_pitches, **kwargs):
def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.
Examples
--------
>>> ref_onsets, ref_offsets, ref_pitches = mir_eval.io.load_delimited(
... 'reference.txt', [float, float, float], '\t')
>>> est_onsets, est_offsets, est_pitches = mir_eval.io.load_delimited(
... 'estimate.txt', [float, float, float], '\t')
>>> scores = mir_eval.transcription.evaluate(ref_onsets, ref_offsets,
... ref_pitches, est_onsets, est_offsets, est_pitches)
>>> ref_intervals, ref_pitches = mir_eval.io.load_valued_intervals(
... 'reference.txt')
>>> est_intervals, est_pitches = mir_eval.io.load_valued_intervals(
... 'estimate.txt')
>>> scores = mir_eval.transcription.evaluate(ref_intervals, ref_pitches,
... est_intervals, est_pitches)
Parameters (TODO)
----------
Expand All @@ -210,9 +184,8 @@ def evaluate(ref_onsets, ref_offsets, ref_pitches, est_onsets, est_offsets,
# All metrics
(scores['Precision'],
scores['Recall'],
scores['F-measure']) = util.filter_kwargs(prf, ref_onsets, ref_offsets,
ref_pitches, est_onsets,
est_offsets, est_pitches,
scores['F-measure']) = util.filter_kwargs(prf, ref_intervals, ref_pitches,
est_intervals, est_pitches,
**kwargs)

return scores
