Skip to content

Commit

Permalink
Merge pull request #85 from bmcfee/beatposition-backfill
Browse files Browse the repository at this point in the history
Beatposition backfill
  • Loading branch information
bmcfee committed Jul 19, 2017
2 parents 6445351 + 91600a5 commit 47ff307
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pumpp/task/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ def transform_annotation(self, ann, duration):
# In this case, the subdivision is well-defined
subdivision = downbeats[next_idx] - downbeats[prev_idx]
elif prev_idx < 0 and next_idx < len(downbeats):
subdivision = downbeats[0]
subdivision = np.max(values[:downbeats[0]+1])
elif next_idx >= len(downbeats):
subdivision = len(values) - downbeats[prev_idx] - 1
subdivision = len(values) - downbeats[prev_idx]

if subdivision > self.max_divisions or subdivision < 1:
position.extend(self.encoder.transform(['X']))
Expand Down
45 changes: 44 additions & 1 deletion tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def test_task_beatpos_present(SR, HOP_LENGTH, MAX_DIVISIONS, SPARSE):

ann = jams.Annotation(namespace='beat')

Y_true = [0, 0, 0, # 0:3 = X
Y_true = [0, 2, 3, # 0 = X, 1:3 = 2/3
1, 2, 3, # 3:6 = 1/2/3
1, 2, 3, 4, # 6:10 = 1/2/3/4
1, 2, 3, 4, 5, # 10:15 = X or 1/2/3/4/5
Expand All @@ -933,6 +933,7 @@ def test_task_beatpos_present(SR, HOP_LENGTH, MAX_DIVISIONS, SPARSE):

Y_true_out = ['X'] * len(Y_true)
if MAX_DIVISIONS >= 3:
Y_true_out[1:3] = ['{:02d}/{:02d}'.format(3, i+1) for i in range(1, 3)]
Y_true_out[3:6] = ['{:02d}/{:02d}'.format(3, i+1) for i in range(3)]
if MAX_DIVISIONS >= 4:
Y_true_out[6:10] = ['{:02d}/{:02d}'.format(4, i+1) for i in range(4)]
Expand Down Expand Up @@ -966,3 +967,45 @@ def test_task_beatpos_present(SR, HOP_LENGTH, MAX_DIVISIONS, SPARSE):
axis=0).astype(Y_pred.dtype)
for i, (y1, y2) in enumerate(zip(Y_pred, Y_expected)):
assert y1 == y2

def test_task_beatpos_tail(SR, HOP_LENGTH, SPARSE):
# This test checks for implicit end-of-bar encodings
jam = jams.JAMS(file_metadata=dict(duration=10.0))

ann = jams.Annotation(namespace='beat')

Y_true = [0, 2, 3,
1, 2, 3,
1, 2, 3]

Y_true_out = ['X', '03/02', '03/03',
'03/01', '03/02', '03/03',
'03/01', '03/02', '03/03']

for i, y in enumerate(Y_true):
ann.append(time=i, duration=0, value=y)

jam.annotations.append(ann)

trans = pumpp.task.BeatPositionTransformer(name='beat',
max_divisions=4,
sr=SR, hop_length=HOP_LENGTH,
sparse=SPARSE)

output = trans.transform(jam)

assert np.all(output['beat/_valid'] == [0, 10 *
trans.sr // trans.hop_length])

Y_pred = trans.encoder.inverse_transform(output['beat/position'][0])

if SPARSE:
Y_pred = Y_pred[:, 0]

# This trimming is here because duration is inferred from the track,
# not the ytrue_out
Y_expected = np.repeat(Y_true_out,
(SR // HOP_LENGTH),
axis=0).astype(Y_pred.dtype)
for i, (y1, y2) in enumerate(zip(Y_pred, Y_expected)):
assert y1 == y2

0 comments on commit 47ff307

Please sign in to comment.