In [127]:
from model.LongTermAccompaniment import LongTermAccompaniment
from model import load_model
import torch
from hvo_sequence.hvo_seq import HVO_Sequence
from hvo_sequence.drum_mappings import ROLAND_REDUCED_MAPPING
from bokeh.plotting import show, output_notebook
import os 

output_notebook()

mdl_ = 'misc/LTA/cool-universe-43_7org9tsf/007.pth'
model = load_model(
    model_path=mdl_,
    model_class=LongTermAccompaniment,
    device='cpu',
    is_evaluating=True
)
model.serialize(save_folder=os.path.dirname(mdl_), filename=mdl_.split('/')[-1].replace('.pth', '.pt'))
model.eval()

GrooveEncoderModel = model.GrooveRhythmEncoder

In [128]:
# load data

from data import PairedLTADataset

test_dataset = PairedLTADataset(
        input_inst_dataset_bz2_filepath="data/lmd/data_bass_groove_test.bz2",
        output_inst_dataset_bz2_filepath="data/lmd/data_drums_full_unsplit.bz2",
        shift_tgt_by_n_steps=1,
        max_input_bars=32,
        continuation_bars=2,
        hop_n_bars=2,
        input_has_velocity=True,
        input_has_offsets=True
    )

INFO:data.Base.dataLoaders:PairedLTADatasetV2 Constructor --> Loading Cached Version from: cached/TorchDatasets/PairedLTADataset_data_bass_groove_test.bz2_data_drums_full_unsplit.bz2_32_2_True_True_1.bz2pickle


In [129]:
def batch_data_extractor(data_, in_step_start, in_n_steps, out_step_start, out_n_steps, device='cpu'):

    inst_1 = data_[0].to(device) if data_[0].device.type != device else data_[0]
    inst_2 = data_[1].to(device) if data_[1].device.type != device else data_[1]
    stacked_inst_12 = data_[2].to(device) if data_[2].device.type != device else data_[2]

    input_solo = inst_1[:, in_step_start:in_step_start + in_n_steps]
    input_stacked = stacked_inst_12[:, in_step_start:in_step_start + in_n_steps]
    next_output = inst_2[:, out_step_start:out_step_start + out_n_steps]
    shifted_output = next_output.clone()
    shifted_output[:, 1:, :] = next_output[:, :-1, :]
    shifted_output[:, 0, :] = 0
    return input_solo, input_stacked, next_output, shifted_output

def create_src_mask(n_bars, max_n_bars):
    # masked items are the ones noted as True

    batch_size = n_bars.shape[0]
    mask = torch.zeros((batch_size, max_n_bars)).bool()
    for i in range(batch_size):
        mask[i, n_bars[i]:] = 1
    return mask

def predict_using_batch_data(batch_data, num_input_bars=None, model_=model, device='cpu'):
    model_.eval()

    in_len = 32 * 16
    out_len = 32
    in_step_start = 0
    in_n_steps = in_len
    out_step_start = in_len
    out_n_steps = out_len

    input_solo, input_stacked, output, shifted_output = batch_data_extractor(
        data_=batch_data,
        in_step_start=in_step_start,
        in_n_steps=in_n_steps,
        out_step_start=out_step_start,
        out_n_steps=out_n_steps,
        device=device
    )

    enc_src = input_stacked 
    dec_src = shifted_output 

    if num_input_bars is None:
        num_input_bars = torch.ones((enc_src.shape[0], 1), dtype=torch.long).to(device) * 32

    with torch.no_grad():
        h, v, o, hvo = model_.sample(
            src=enc_src,
            src_key_padding_and_memory_mask=create_src_mask(num_input_bars, 32).to(device),
            tgt=dec_src
        )
    return hvo

# Auto-regressive prediction
def predict_using_batch_data_auto_Reg(batch_data, num_input_bars=None, model_=model, device='cpu'):
    model_.eval()

    in_len = 32 * 16
    out_len = 32
    in_step_start = 0
    in_n_steps = in_len
    out_step_start = in_len
    out_n_steps = out_len

    input_solo, input_stacked, output, shifted_output = batch_data_extractor(
        data_=batch_data,
        in_step_start=in_step_start,
        in_n_steps=in_n_steps,
        out_step_start=out_step_start,
        out_n_steps=out_n_steps,
        device=device
    )

    enc_src = input_stacked 
    dec_src = shifted_output 
    print(shifted_output.shape)
    if num_input_bars is None:
        num_input_bars = torch.ones((enc_src.shape[0], 1), dtype=torch.long).to(device) * 32

    hvo_in = torch.zeros((1, 32, 27))
    
    start = timeit.default_timer()


    for i in range(32):
        
        with torch.no_grad():
            
            h, v, o, hvo_ = model_.sample(
                src=enc_src,
                src_key_padding_and_memory_mask=create_src_mask(num_input_bars, 32).to(device),
                tgt=hvo_in
            )
            
        if i < 31:
            hvo_in[:, i+1, :] = hvo_[:, i, :]
    
    print('Time: ', (timeit.default_timer() - start) * 1000, 'ms')
        
    return hvo_

In [165]:
import numpy as np
import timeit
sample_ix = np.random.randint(0, len(test_dataset))
print(sample_ix)

# select a random sample


hvo = predict_using_batch_data_auto_Reg(test_dataset[sample_ix:sample_ix+1], model_=model, num_input_bars=torch.tensor([32]))

hvo_gt = test_dataset.instrument2_hvos[sample_ix][-64:, :]



total = torch.cat([hvo_gt, hvo[0]], dim=0)


hvo_seq = HVO_Sequence(
    beat_division_factors=[4], 
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq.add_tempo(0, 120)
hvo_seq.add_time_signature(0, 4, 4)
print(total.cpu().numpy().shape)
hvo_seq.hvo = total.cpu().numpy()
hvo_seq.to_html_plot(show_figure=True, width=800, height=400)


13215
torch.Size([1, 32, 27])
Time:  1761.9392779888585 ms
(96, 27)


In [166]:
hvo_seq2 = HVO_Sequence(
    beat_division_factors=[4], 
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq2.add_tempo(0, 120)
hvo_seq2.add_time_signature(0, 4, 4)
hvo_seq2.hvo = torch.cat([hvo_gt[:32, :], hvo[0]], dim=0).cpu().numpy()

# load audio player
from IPython.display import Audio
Audio(hvo_seq2.synthesize(sf_path='hvo_sequence/soundfonts/Standard_Drum_Kit.sf2'), rate=44100)


fluidsynth: error: Unknown integer parameter 'synth.sample-rate'


In [132]:
# mask bass groove
import numpy as np 
import timeit

start = timeit.default_timer()

# select a random sample
sample_ix = sample_ix

sample = test_dataset[sample_ix:sample_ix+1]

hvo = predict_using_batch_data_auto_Reg(sample, model_=model, num_input_bars=torch.tensor([12]))

hvo_gt = test_dataset.instrument2_hvos[sample_ix][-64:, :]

print('Time: ', (timeit.default_timer() - start) * 1000, 'ms')

total = torch.cat([hvo_gt, hvo[0]], dim=0)


hvo_seq = HVO_Sequence(
    beat_division_factors=[4], 
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq.add_tempo(0, 120)
hvo_seq.add_time_signature(0, 4, 4)
print(total.cpu().numpy().shape)
hvo_seq.hvo = total.cpu().numpy()
hvo_seq.to_html_plot(show_figure=True, width=800, height=400)

torch.Size([1, 32, 27])
Time:  1638.0083239637315 ms
Time:  1703.8672149647027 ms
(96, 27)


In [133]:
hvo_seq2 = HVO_Sequence(
    beat_division_factors=[4], 
    drum_mapping=ROLAND_REDUCED_MAPPING
)

hvo_seq2.add_tempo(0, 120)
hvo_seq2.add_time_signature(0, 4, 4)
hvo_seq2.hvo = torch.cat([hvo_gt[:32, :], hvo[0]], dim=0).cpu().numpy()

# load audio player
from IPython.display import Audio
Audio(hvo_seq2.synthesize(sf_path='hvo_sequence/soundfonts/Standard_Drum_Kit.sf2'), rate=44100)


fluidsynth: error: Unknown integer parameter 'synth.sample-rate'


In [134]:
sample_ix = np.random.randint(0, len(test_dataset))
inst1 = test_dataset.instrument1and2_hvos[sample_ix]
bar1 = GrooveEncoderModel(inst1[0:16, :].unsqueeze(0).to('cpu'))
bar2 = GrooveEncoderModel(inst1[16:32, :].unsqueeze(0).to('cpu'))
bar3 = GrooveEncoderModel(inst1[32:48, :].unsqueeze(0).to('cpu'))
bar4 = GrooveEncoderModel(inst1[-16:, :].unsqueeze(0).to('cpu'))


bar1_hvo = test_dataset.instrument1and2_hvos[sample_ix][0:16, :]
bar1_hvo_no_bass = bar1_hvo.clone()
bar1_hvo_no_bass[::2, :1] = 1
bar1_hvo_no_bass[:, 10:11] = 0.125
bar1_hvo_no_bass[:, 20:21] = -0.24
bar1_hvo_no_bass_embed = GrooveEncoderModel(bar1_hvo_no_bass.unsqueeze(0).to('cpu'))

In [118]:
torch.abs(bar1 - bar1_hvo_no_bass_embed).sum(), torch.abs(bar1 - bar2).sum()

(tensor(4.6955, grad_fn=<SumBackward0>),
 tensor(169.8127, grad_fn=<SumBackward0>))

In [11]:
bar1_hvo_no_bass[:, :1]

tensor([[1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.],
        [1.],
        [0.]])