In [171]:
from model.LongTermAccompaniment import LongTermAccompanimentHierarchical
from model import load_model
import torch
from hvo_sequence.hvo_seq import HVO_Sequence
from hvo_sequence.drum_mappings import ROLAND_REDUCED_MAPPING
from bokeh.plotting import show, output_notebook
import os 

output_notebook()

mdl_ = 'misc/LTA/good-darkness-58_8c9fdnwy/085.pth'
model = load_model(
    model_path=mdl_,
    model_class=LongTermAccompanimentHierarchical,
    device='cpu',
    is_evaluating=True
)
model.serialize(save_folder=os.path.dirname(mdl_), filename=mdl_.split('/')[-1].replace('.pth', '.pt'))
model.eval()

GrooveEncoderModel = model.BeatEncoder

In [172]:
# load data

from data import PairedLTADataset

test_dataset = PairedLTADataset(
        input_inst_dataset_bz2_filepath="data/lmd/data_bass_groove_test.bz2",
        output_inst_dataset_bz2_filepath="data/lmd/data_drums_full_unsplit.bz2",
        shift_tgt_by_n_steps=1,
        max_input_bars=32,
        continuation_bars=2,
        hop_n_bars=1,
        input_has_velocity=True,
        input_has_offsets=True
    )

INFO:data.Base.dataLoaders:PairedLTADatasetV2 Constructor --> Loading Cached Version from: cached/TorchDatasets/PairedLTADataset_data_bass_groove_test.bz2_data_drums_full_unsplit.bz2_32_1_True_True_1.bz2pickle


In [173]:
def batch_data_extractor(data_, in_step_start, in_n_steps, out_step_start, out_n_steps, device='cpu'):

    inst_1 = data_[0].to(device) if data_[0].device.type != device else data_[0]
    inst_2 = data_[1].to(device) if data_[1].device.type != device else data_[1]
    stacked_inst_12 = data_[2].to(device) if data_[2].device.type != device else data_[2]

    input_solo = inst_1[:, in_step_start:in_step_start + in_n_steps]
    input_stacked = stacked_inst_12[:, in_step_start:in_step_start + in_n_steps]
    next_output = inst_2[:, out_step_start:out_step_start + out_n_steps]
    shifted_output = next_output.clone()
    shifted_output[:, 1:, :] = next_output[:, :-1, :]
    shifted_output[:, 0, :] = 0
    return input_solo, input_stacked, next_output, shifted_output

def create_src_mask(n_bars, max_n_bars):
    # masked items are the ones noted as True

    batch_size = n_bars.shape[0]
    mask = torch.zeros((batch_size, max_n_bars)).bool()
    for i in range(batch_size):
        mask[i, n_bars[i]:] = 1
    return mask

def predict_using_batch_data(batch_data, num_input_bars=None, model_=model, device='cpu'):
    model_.eval()

    in_len = 32 * 16
    out_len = 32
    in_step_start = 0
    in_n_steps = in_len
    out_step_start = in_len
    out_n_steps = out_len

    input_solo, input_stacked, output, shifted_output = batch_data_extractor(
        data_=batch_data,
        in_step_start=in_step_start,
        in_n_steps=in_n_steps,
        out_step_start=out_step_start,
        out_n_steps=out_n_steps,
        device=device
    )

    enc_src = input_solo 
    dec_src = shifted_output 

    if num_input_bars is None:
        num_input_bars = torch.ones((enc_src.shape[0], 1), dtype=torch.long).to(device) * 32

    with torch.no_grad():
        h, v, o, hvo = model_.sample(
            src=enc_src,
            src_key_padding_and_memory_mask=create_src_mask(num_input_bars, 32).to(device),
            tgt=dec_src
        )
    return hvo

# Auto-regressive prediction
def predict_using_batch_data_auto_Reg(batch_data, num_input_bars=None, model_=model, device='cpu'):
    model_.eval()

    in_len = 32 * 16
    out_len = 32
    in_step_start = 0
    in_n_steps = in_len
    out_step_start = in_len
    out_n_steps = out_len

    input_solo, input_stacked, output, shifted_output = batch_data_extractor(
        data_=batch_data,
        in_step_start=in_step_start,
        in_n_steps=in_n_steps,
        out_step_start=out_step_start,
        out_n_steps=out_n_steps,
        device=device
    )

    enc_src = input_solo 
    dec_src = shifted_output 
    print(shifted_output.shape)
    if num_input_bars is None:
        num_input_bars = torch.ones((enc_src.shape[0], 1), dtype=torch.long).to(device) * 32

    hvo_in = torch.zeros((1, 32, 27))

    start = timeit.default_timer()


    for i in range(32):
        
        with torch.no_grad():
            
            h, v, o, hvo_ = model_.sample(
                src=enc_src,
                src_key_padding_and_memory_mask=create_src_mask(num_input_bars, 32).to(device),
                tgt=hvo_in,
                scale_vel=1.2
            )
            
        if i < 31:
            hvo_in[:, i+1, :] = hvo_[:, i, :]
    
    print('Time: ', (timeit.default_timer() - start) * 1000, 'ms')
        
    return hvo_

In [174]:
import numpy as np
import timeit
sample_ix = np.random.randint(0, len(test_dataset))
print(sample_ix)

# select a random sample


hvo = predict_using_batch_data(test_dataset[sample_ix:sample_ix+1], model_=model, num_input_bars=torch.tensor([32]))

hvo_gt = test_dataset.instrument2_hvos[sample_ix][-64:, :]



total = torch.cat([hvo_gt, hvo[0]], dim=0)


hvo_seq = HVO_Sequence(
    beat_division_factors=[4],
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq.add_tempo(0, 120)
hvo_seq.add_time_signature(0, 4, 4)
print(total.cpu().numpy().shape)
hvo_seq.hvo = total.cpu().numpy()
hvo_seq.to_html_plot(show_figure=True, width=800, height=400)



8963
(96, 27)


In [175]:
hvo_seq2 = HVO_Sequence(
    beat_division_factors=[4], 
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq2.add_tempo(0, 120)
hvo_seq2.add_time_signature(0, 4, 4)
hvo_seq2.hvo = torch.cat([hvo_gt[:32, :], hvo[0]], dim=0).cpu().numpy()

# load audio player
from IPython.display import Audio
Audio(hvo_seq2.synthesize(sf_path='hvo_sequence/soundfonts/Standard_Drum_Kit.sf2'), rate=44100)


fluidsynth: error: Unknown integer parameter 'synth.sample-rate'


In [178]:
import numpy as np
import timeit



# select a random sample
sample_ix = np.random.randint(0, len(test_dataset))
print(sample_ix)
hvo = predict_using_batch_data_auto_Reg(test_dataset[sample_ix:sample_ix+1], model_=model, num_input_bars=torch.tensor([32]))

hvo_gt = test_dataset.instrument2_hvos[sample_ix][-64:, :]



total = torch.cat([hvo_gt, hvo[0]], dim=0)


hvo_seq = HVO_Sequence(
    beat_division_factors=[4],
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq.add_tempo(0, 120)
hvo_seq.add_time_signature(0, 4, 4)
print(total.cpu().numpy().shape)
hvo_seq.hvo = total.cpu().numpy()
hvo_seq.to_html_plot(show_figure=True, width=800, height=400)



8963
torch.Size([1, 32, 27])
Time:  964.2066690139472 ms
(96, 27)


In [179]:
hvo_seq2 = HVO_Sequence(
    beat_division_factors=[4], 
    drum_mapping=ROLAND_REDUCED_MAPPING
)

hvo_seq2.add_tempo(0, 120)
hvo_seq2.add_time_signature(0, 4, 4)
hvo_seq2.hvo = torch.cat([hvo_gt[:32, :], hvo[0]], dim=0).cpu().numpy()

# load audio player
from IPython.display import Audio
Audio(hvo_seq2.synthesize(sf_path='hvo_sequence/soundfonts/Standard_Drum_Kit.sf2'), rate=44100)


fluidsynth: error: Unknown integer parameter 'synth.sample-rate'


In [158]:
# mdl_ = 'misc/LTA/charmed-disco-56_jwwbslj7/160.pth'
# model = load_model(
#     model_path=mdl_,
#     model_class=LongTermAccompanimentHierarchical,
#     device='cpu',
#     is_evaluating=True
# )
# model.eval()

model.reset()
sample_ix = np.random.randint(0, len(test_dataset))
print(sample_ix)
i1, i2, i12 = test_dataset[sample_ix:sample_ix+1]

i1_rand = i1[:, :32*32, :].clone() * 0
i2_rand = i2[:, :32*32, :]
hvo_gt = test_dataset.instrument2_hvos[sample_ix][-64:, :]

model.encode_varying_length_performance(i1_rand, i2_rand)
print(model.num_bars_encoded_so_far)

h, v, o, hvo = model.get_next_2_bars(0.5)
total = torch.cat([hvo_gt[:32, :], hvo[0]], dim=0)

from hvo_sequence.hvo_seq import HVO_Sequence
from hvo_sequence.drum_mappings import ROLAND_REDUCED_MAPPING
hvo_seq2 = HVO_Sequence(
    beat_division_factors=[4],
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq2.add_tempo(0, 120)
hvo_seq2.add_time_signature(0, 4, 4)
hvo_seq2.hvo = total.cpu().numpy()

hvo_seq2.to_html_plot(show_figure=True, width=800, height=400)

15460
Error: The input instrumental groove and the drums must have a total of 3 features. The input has 30 features


RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x10 and 1x128)

In [None]:
2aq # load audio player
from IPython.display import Audio

Audio(hvo_seq2.synthesize(sf_path='hvo_sequence/soundfonts/Standard_Drum_Kit.sf2'), rate=44100)