Skip to content

Commit

Permalink
adding pandas dataframe writer, also works with multiprocessing if ma…
Browse files Browse the repository at this point in the history
…tplotlib is not loaded
  • Loading branch information
faroit committed Jul 28, 2016
1 parent d338316 commit 0a5283e
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 50 deletions.
21 changes: 14 additions & 7 deletions dsdtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,16 @@ def _process_function(self, track, user_function, estimates_dir, evaluate):
else:
# call the user provided function
user_results = user_function(track)

if estimates_dir and not evaluate and user_function is not None:
self._save_estimates(user_results, track, estimates_dir)
if evaluate:
self.evaluator.evaluate_track(

if evaluate and user_results:
return self.evaluator.evaluate_track(
track, user_results, estimates_dir
)
else:
return None

def _save_estimates(self, user_estimates, track, estimates_dir):
track_estimate_dir = op.join(
Expand Down Expand Up @@ -343,12 +347,15 @@ def evaluate(
estimates_dirs = [estimates_dirs]

for estimates_dir in estimates_dirs:
self.run(
results = self.run(
user_function=None,
estimates_dir=estimates_dir,
evaluate=True,
*args, **kwargs
)
for result in results:
if result is not None:
self.evaluator.data.append(result)

def run(
self,
Expand Down Expand Up @@ -402,12 +409,12 @@ def run(
pass

# list of tracks to be processed
# TODO: adjust the number of tracks if not all estimates are provided
tracks = self.load_dsd_tracks(subsets=subsets, ids=ids)

success = False
if parallel:
pool = multiprocessing.Pool(cpus, initializer=init_worker)
success = list(
results = list(
tqdm.tqdm(
pool.imap_unordered(
func=functools.partial(
Expand All @@ -428,7 +435,7 @@ def run(
pool.join()

else:
success = list(
results = list(
tqdm.tqdm(
map(
lambda x: self._process_function(
Expand All @@ -442,7 +449,7 @@ def run(
total=len(tracks)
)
)
return success
return results


def process_function_alias(obj, *args, **kwargs):
Expand Down
84 changes: 64 additions & 20 deletions dsdtools/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from __future__ import print_function
import numpy as np
import mir_eval
from . import utils
import pandas as pd


class DF_writer(object):
def __init__(self, columns):
self.df = pd.DataFrame(columns=columns)
self.columns = columns

def row2series(self, **row_data):
if set(self.columns) == set(row_data):
return pd.Series(row_data)

def append(self, series):
self.df = self.df.append(series, ignore_index=True)

def to_pickle(self, filename):
self.df.to_pickle(filename)


class BSSeval(object):
def __init__(
self,
collect=True,
window=30*44100,
hop=15*44100,
):
self.data = utils.DF_writer([
self.data = DF_writer([
'track_id',
'track_name',
'target_name',
Expand All @@ -22,13 +37,38 @@ def __init__(
'SAR',
'sample'
])

self.window = window
self.hop = hop

def evaluate_track(self, track, user_estimates, estimates_dir=None):
# def plot_results(self, measures=['SDR', 'ISR', 'SIR', 'SAR']):
# figure, ax = plt.subplots(1, len(measures))
# for i, measure in enumerate(measures):
# sns.boxplot(
# "target_name",
# measure,
# hue='estimate_dir',
# data=self.data.df,
# showmeans=True,
# showfliers=False,
# palette=sns.color_palette('muted'),
# ax=ax[i],
# meanline=True,
# )
# return figure

def evaluate_track(
self,
track,
user_estimates,
estimates_dir=None,
verbose=False
):
audio_estimates = []
audio_reference = []

rows = []

# make sure to always build the list in the same order
# therefore track.targets is an OrderedDict
targets = [] # save the list of targets to be evaluated
Expand All @@ -49,15 +89,23 @@ def evaluate_track(self, track, user_estimates, estimates_dir=None):
if audio_estimates and audio_reference:
audio_estimates = np.array(audio_estimates)
audio_reference = np.array(audio_reference)
if audio_estimates.shape == audio_reference.shape:
SDR, ISR, SIR, SAR = self.evaluate(
audio_estimates, audio_reference
)
# iterate over all targets
for i, target in enumerate(targets):
# iterate over all frames
for k in range(len(SDR[i])):
self.data.append(

SDR, ISR, SIR, SAR = self.evaluate(
audio_estimates, audio_reference
)
# iterate over all targets
for i, target in enumerate(targets):
# iterate over all frames
if verbose:
print(target)
print("SDR: ", str(SDR[i]))
print("ISR: ", str(ISR[i]))
print("SIR: ", str(SIR[i]))
print("SAR: ", str(SAR[i]))

for k in range(len(SDR[i])):
rows.append(
self.data.row2series(
track_id=int(track.id),
track_name=track.filename,
target_name=target.name,
Expand All @@ -68,8 +116,10 @@ def evaluate_track(self, track, user_estimates, estimates_dir=None):
SAR=SAR[i, k],
sample=k * self.hop
)
)
return rows

def evaluate(self, estimates, references, verbose=True):
def evaluate(self, estimates, references):
"""BSS_EVAL images evaluation using mir_eval.separation module
Parameters
Expand Down Expand Up @@ -99,10 +149,4 @@ def evaluate(self, estimates, references, verbose=True):
hop=self.hop,
)

if verbose:
print("SDR: ", str(sdr))
print("ISR: ", str(isr))
print("SIR: ", str(sir))
print("SAR: ", str(sar))

return sdr, isr, sir, sar
16 changes: 0 additions & 16 deletions dsdtools/utils.py

This file was deleted.

21 changes: 16 additions & 5 deletions examples/run_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from __future__ import print_function
import dsdtools
import argparse

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Source Separation based on Common Fate Model')

# initiate dsdtools
dsd = dsdtools.DB(evaluation=True)
parser.add_argument(
'estimate_dirs',
type=str,
nargs='+',
help='Estimate folders'
)

dsd.evaluate(estimates_dirs='./Estimates')
args = parser.parse_args()

print(dsd.evaluator.df.df)
import ipdb; ipdb.set_trace()
# initiate dsdtools
dsd = dsdtools.DB(evaluation=True)

dsd.evaluate(estimates_dirs=args.estimate_dirs, parallel=True)
dsd.evaluator.data.to_pickle("results.pandas")
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ sphinx_rtd_theme
numpydoc
git+https://github.com/faroit/mir_eval.git@bss_eval_images_framewise
pandas
seaborn
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
name='dsdtools',

# Version
version="0.1.3",
version="0.2.0",

# Description
description='Python tools for the Demixing Secrets Dataset (DSD)',
Expand All @@ -32,7 +32,8 @@
'pyaml',
'PySoundFile>=0.8',
'mir_eval',
'pandas'
'pandas',
'seaborn'
],

package_data={
Expand Down

0 comments on commit 0a5283e

Please sign in to comment.