Skip to content

Commit

Permalink
Merge pull request #30 from bmcfee/chordtagunk
Browse files Browse the repository at this point in the history
added X for unknown chord symbols in chordtag
  • Loading branch information
bmcfee committed Feb 22, 2017
2 parents 2869c00 + b73989e commit 2dbceb2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 27 deletions.
15 changes: 6 additions & 9 deletions pumpp/task/chord.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,6 @@ class ChordTagTransformer(BaseTaskTransformer):
Note: 5 requires 3, 6 requires 5, 7 requires 6.
nochord : str
String to use for no-chord symbols
sr : number > 0
Sampling rate of audio
Expand All @@ -284,7 +281,7 @@ class ChordTagTransformer(BaseTaskTransformer):
ChordTransformer
SimpleChordTransformer
'''
def __init__(self, name='chord', vocab='3567s', nochord='N',
def __init__(self, name='chord', vocab='3567s',
sr=22050, hop_length=512):

super(ChordTagTransformer, self).__init__(name=name,
Expand All @@ -306,7 +303,6 @@ def __init__(self, name='chord', vocab='3567s', nochord='N',
raise ParameterError('Invalid vocabulary string: {}'.format(vocab))

self.vocab = vocab.lower()
self.nochord = nochord
labels = self.vocabulary()

self.encoder = LabelBinarizer()
Expand Down Expand Up @@ -345,7 +341,7 @@ def empty(self, duration):

ann.append(time=0,
duration=duration,
value='N', confidence=0)
value='X', confidence=0)

return ann

Expand All @@ -367,7 +363,7 @@ def vocabulary(self):
if 's' in self.vocab:
qualities.extend(['sus2', 'sus4'])

labels = [self.nochord]
labels = ['N', 'X']

for chord in product(PITCHES, qualities):
labels.append('{}:{}'.format(*chord))
Expand All @@ -392,9 +388,10 @@ def simplify(self, chord):
P = 2**np.arange(12, dtype=int)
query = self.mask_ & pitches[::-1].dot(P)

if root < 0 and chord[0].upper() == 'N':
return 'N'
if query not in QUALITIES:
# TODO: check for non-zero pitches here
return self.nochord
return 'X'

return '{}:{}'.format(PITCHES[root], QUALITIES[query])

Expand Down
31 changes: 13 additions & 18 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ def VOCAB(request):
yield request.param


@pytest.fixture(params=['N', 'X', 'no chord'])
def NOCHORD(request):
yield request.param


def shape_match(sh1, sh2):

for i, j in zip(sh1, sh2):
Expand Down Expand Up @@ -644,14 +639,14 @@ def test_transform_coerce():


@pytest.mark.parametrize('vocab, vocab_size',
[('3', 25),
('3s', 49),
('35', 49),
('35s', 73),
('356', 73),
('356s', 97),
('3567', 145),
('3567s', 169),
[('3', 26),
('3s', 50),
('35', 50),
('35s', 74),
('356', 74),
('356s', 98),
('3567', 146),
('3567s', 170),
pytest.mark.xfail(('bad vocab', 1),
raises=pumpp.ParameterError),
pytest.mark.xfail(('5', 1),
Expand Down Expand Up @@ -706,8 +701,8 @@ def test_task_chord_tag_present(SR, HOP_LENGTH, VOCAB):
'G:sus4']

if 's' not in VOCAB:
Y_true_out[11] = 'N' # sus2 -> N
Y_true_out[12] = 'N' # sus4 -> N
Y_true_out[11] = 'X' # sus2 -> X
Y_true_out[12] = 'X' # sus4 -> X
if '6' not in VOCAB:
Y_true_out[1] = 'C:min' # min6 -> maj
Y_true_out[2] = 'C:maj' # maj6 -> maj
Expand Down Expand Up @@ -747,11 +742,11 @@ def test_task_chord_tag_present(SR, HOP_LENGTH, VOCAB):
assert np.all(Y_pred == Y_expected)


def test_task_chord_tag_absent(SR, HOP_LENGTH, VOCAB, NOCHORD):
def test_task_chord_tag_absent(SR, HOP_LENGTH, VOCAB):

jam = jams.JAMS(file_metadata=dict(duration=4.0))
trans = pumpp.task.ChordTagTransformer(name='chord',
vocab=VOCAB, nochord=NOCHORD,
vocab=VOCAB,
sr=SR, hop_length=HOP_LENGTH)

output = trans.transform(jam)
Expand All @@ -762,7 +757,7 @@ def test_task_chord_tag_absent(SR, HOP_LENGTH, VOCAB, NOCHORD):
# Make sure it's all no-chord
Y_pred = trans.encoder.inverse_transform(output['chord/chord'][0])

assert all([_ == NOCHORD for _ in Y_pred])
assert all([_ == 'X' for _ in Y_pred])

# Check the shape
for key in trans.fields:
Expand Down

0 comments on commit 2dbceb2

Please sign in to comment.