Skip to content

Commit

Permalink
Merge pull request #13 from bmcfee/tag-vocab
Browse files Browse the repository at this point in the history
added automatic tag vocab inference
  • Loading branch information
bmcfee committed Jul 28, 2016
2 parents b068f12 + 83cf2de commit b3432f5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
12 changes: 10 additions & 2 deletions pumpp/task/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

import jams

from .base import BaseTaskTransformer

__all__ = ['DynamicLabelTransformer', 'StaticLabelTransformer']


class DynamicLabelTransformer(BaseTaskTransformer):

def __init__(self, namespace, name, labels, sr=22050, hop_length=512):
def __init__(self, namespace, name, labels=None, sr=22050, hop_length=512):
'''Initialize a time-series label transformer
Parameters
Expand All @@ -32,6 +34,9 @@ def __init__(self, namespace, name, labels, sr=22050, hop_length=512):
sr=sr,
hop_length=hop_length)

if labels is None:
labels = jams.schema.values(namespace)

self.encoder = MultiLabelBinarizer()
self.encoder.fit([labels])
self._classes = set(self.encoder.classes_)
Expand Down Expand Up @@ -63,7 +68,7 @@ def transform_annotation(self, ann, duration):

class StaticLabelTransformer(BaseTaskTransformer):

def __init__(self, namespace, name, labels):
def __init__(self, namespace, name, labels=None):
'''Initialize a global label transformer
Parameters
Expand All @@ -76,6 +81,9 @@ def __init__(self, namespace, name, labels):
namespace=namespace,
sr=1, hop_length=1)

if labels is None:
labels = jams.schema.values(namespace)

self.encoder = MultiLabelBinarizer()
self.encoder.fit([labels])
self._classes = set(self.encoder.classes_)
Expand Down
48 changes: 45 additions & 3 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def test_task_dlabel_present(SR, HOP_LENGTH):

ann.append(time=0, duration=1.0, value='alpha')
ann.append(time=0, duration=1.0, value='beta')
ann.append(time=1, duration=1.0, value='23')
ann.append(time=1, duration=1.0, value='some nonsense')
ann.append(time=3, duration=1.0, value='disco')

jam.annotations.append(ann)
Expand Down Expand Up @@ -259,6 +259,27 @@ def test_task_dlabel_absent(SR, HOP_LENGTH):
assert type_match(output[key].dtype, trans.fields[key].dtype)


def test_task_dlabel_auto(SR, HOP_LENGTH):
    '''Check DynamicLabelTransformer when no `labels` are given.

    The transformer is built for the `tag_gtzan` namespace without an
    explicit vocabulary, and is applied to a JAMS object that contains
    no annotations.  The output must be flagged as invalid, be all-zero,
    and have the frame count implied by `SR` and `HOP_LENGTH`.

    Parameters
    ----------
    SR : int
        Sampling rate supplied by the test harness
    HOP_LENGTH : int
        Hop length supplied by the test harness
    '''
    jam = jams.JAMS(file_metadata=dict(duration=4.0))

    # Forward SR / HOP_LENGTH explicitly: the shape asserted below is
    # derived from them, so the transformer must not silently fall back
    # to its constructor defaults (sr=22050, hop_length=512).
    trans = pumpp.task.DynamicLabelTransformer(namespace='tag_gtzan',
                                               name='genre',
                                               sr=SR,
                                               hop_length=HOP_LENGTH)

    output = trans.transform(jam)

    # Mask should be false since we have no matching namespace
    assert not np.any(output['genre/_valid'])

    y = output['genre/tags']

    # Check the shape: 4 seconds of frames, and the test expects
    # 10 classes for the inferred vocabulary
    assert y.shape == (1, 4 * (SR // HOP_LENGTH), 10)

    # Make sure it's empty
    assert not np.any(y)

    # Every output must agree with the transformer's declared fields
    for key in trans.fields:
        assert shape_match(output[key].shape[1:], trans.fields[key].shape)
        assert type_match(output[key].dtype, trans.fields[key].dtype)


def test_task_slabel_absent():
labels = ['alpha', 'beta', 'psycho', 'aqua', 'disco']
Expand Down Expand Up @@ -294,7 +315,7 @@ def test_task_slabel_present():

ann.append(time=0, duration=1.0, value='alpha')
ann.append(time=0, duration=1.0, value='beta')
ann.append(time=1, duration=1.0, value='23')
ann.append(time=1, duration=1.0, value='some nonsense')
ann.append(time=3, duration=1.0, value='disco')

jam.annotations.append(ann)
Expand Down Expand Up @@ -323,6 +344,27 @@ def test_task_slabel_present():
assert type_match(output[key].dtype, trans.fields[key].dtype)


def test_task_slabel_auto():
    '''Check StaticLabelTransformer when no `labels` are given.

    The transformer is built for the `tag_gtzan` namespace without an
    explicit vocabulary, and is applied to a JAMS object that contains
    no annotations.  The output must be flagged as invalid and the tag
    matrix must be two-dimensional, 10 columns wide, and all-zero.
    '''
    empty_jam = jams.JAMS(file_metadata=dict(duration=4.0))
    transformer = pumpp.task.StaticLabelTransformer(namespace='tag_gtzan',
                                                    name='genre')

    result = transformer.transform(empty_jam)

    # With no matching namespace in the JAMS object, nothing is valid
    valid_mask = result['genre/_valid']
    assert not np.any(valid_mask)

    # The tag matrix is 2-d with one column per expected class
    tags = result['genre/tags']
    assert tags.ndim == 2
    assert tags.shape[1] == 10

    # No tag may be active
    assert not np.any(tags)

    # Every output must agree with the transformer's declared fields
    for key in transformer.fields:
        expected = transformer.fields[key]
        produced = result[key]
        assert shape_match(produced.shape[1:], expected.shape)
        assert type_match(produced.dtype, expected.dtype)


@pytest.mark.parametrize('dimension', [1, 2, 4])
@pytest.mark.parametrize('name', ['collab', 'vec'])
Expand Down Expand Up @@ -532,7 +574,7 @@ def test_transform_query():

ann.append(time=0, duration=1.0, value='alpha')
ann.append(time=0, duration=1.0, value='beta')
ann.append(time=1, duration=1.0, value='23')
ann.append(time=1, duration=1.0, value='some nonsense')
ann.append(time=3, duration=1.0, value='disco')

jam.annotations.append(ann)
Expand Down

0 comments on commit b3432f5

Please sign in to comment.