In [66]:
from model.LongTermAccompaniment import LongTermAccompaniment
from model import load_model
import torch
from hvo_sequence.hvo_seq import HVO_Sequence
from hvo_sequence.drum_mappings import ROLAND_REDUCED_MAPPING
from bokeh.plotting import show, output_notebook

# Send Bokeh figures to the notebook's inline output.
output_notebook()

# Everything load_model needs, collected in one place: the checkpoint file,
# the model class to instantiate, and the device / evaluation flags.
checkpoint_kwargs = dict(
    model_path='misc/LTA/wandering-night-36_y44qfrf5/099.pth',
    model_class=LongTermAccompaniment,
    device='cpu',
    is_evaluating=True,
)
model = load_model(**checkpoint_kwargs)

# Put the network in evaluation mode. Kept as the final expression of the
# cell so that its return value (the module itself) is displayed.
model.eval()

# 

LongTermAccompaniment(
  (GrooveRhythmEncoder): GrooveRhythmEncoder(
    (InputLayerEncoder): InputGrooveRhythmLayer(
      (velocity_dropout): Dropout(p=0.1, inplace=False)
      (offset_dropout): Dropout(p=0.1, inplace=False)
      (HitsLinear): Linear(in_features=10, out_features=128, bias=True)
      (VelocitiesLinear): Linear(in_features=10, out_features=128, bias=True)
      (OffsetsLinear): Linear(in_features=10, out_features=128, bias=True)
      (HitsReLU): ReLU()
      (VelocitiesReLU): ReLU()
      (OffsetsReLU): ReLU()
      (PositionalEncoding): PositionalEncoding(
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (Encoder): TransformerEncoder(
      (Encoder): TransformerEncoder(
        (layers): ModuleList(
          (0-3): 4 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
            )
            (linear1): Linear(in_features

In [67]:
# load data

from data import PairedLTADataset

# Settings for the paired test set. The file names suggest a bass-groove
# input paired with a drum output -- confirm against the data pipeline.
test_dataset_kwargs = {
    "input_inst_dataset_bz2_filepath": "data/lmd/data_bass_groove_test.bz2",
    "output_inst_dataset_bz2_filepath": "data/lmd/data_drums_full_unsplit.bz2",
    "shift_tgt_by_n_steps": 1,
    "max_input_bars": 32,
    "continuation_bars": 2,
    "hop_n_bars": 2,
    "input_has_velocity": True,
    "input_has_offsets": True,
}
test_dataset = PairedLTADataset(**test_dataset_kwargs)

INFO:data.Base.dataLoaders:PairedLTADatasetV2 Constructor --> Loading Cached Version from: cached/TorchDatasets/PairedLTADataset_data_bass_groove_test.bz2_data_drums_full_unsplit.bz2_32_2_True_True_1.bz2pickle


In [68]:
def batch_data_extractor(data_, in_step_start, in_n_steps, out_step_start, out_n_steps, device='cpu'):
    """Slice encoder inputs and decoder sequences out of one dataset batch.

    Args:
        data_: Indexable whose items 0, 1 and 2 are tensors: instrument 1,
            instrument 2 and the stacked instrument-1/2 sequence. Every
            tensor is sliced along dim 1 (presumably the time-step axis of a
            (batch, steps, features) tensor -- TODO confirm).
        in_step_start: First step of the input window.
        in_n_steps: Number of steps in the input window.
        out_step_start: First step of the output window to be generated.
        out_n_steps: Number of steps in the output window (and in the
            "previous output" window that precedes it).
        device: Device the three tensors are moved to before slicing.

    Returns:
        Tuple ``(input_solo, input_stacked, next_output, previous_output)``:
        the input windows of instrument 1 and of the stacked sequence, the
        instrument-2 window starting at ``out_step_start``, and the
        instrument-2 window that ends right before ``out_step_start``.
    """
    # Tensor.to() already returns the tensor itself when it lives on the
    # requested device, so no explicit device check is required.
    inst_1 = data_[0].to(device)
    inst_2 = data_[1].to(device)
    stacked_inst_12 = data_[2].to(device)

    in_step_end = in_step_start + in_n_steps
    input_solo = inst_1[:, in_step_start:in_step_end]
    input_stacked = stacked_inst_12[:, in_step_start:in_step_end]

    next_output = inst_2[:, out_step_start:out_step_start + out_n_steps]

    # Clamp at the beginning of the sequence: a negative start would be
    # interpreted as an index counted from the END of the tensor, yielding a
    # wrapped-around (or empty) window instead of the steps that actually
    # precede out_step_start.
    previous_start = max(out_step_start - out_n_steps, 0)
    previous_output = inst_2[:, previous_start:out_step_start]

    return input_solo, input_stacked, next_output, previous_output

def create_src_mask(n_bars, max_n_bars):
    """Build a boolean mask over bar positions, one row per batch item.

    Entries set to True are the masked ones: for row ``i`` every position
    from ``n_bars[i]`` up to ``max_n_bars`` is True, everything before it is
    False.

    Args:
        n_bars: Per-item bar counts; its first dimension is the batch size.
        max_n_bars: Number of columns in the mask.

    Returns:
        A ``torch.bool`` tensor of shape ``(batch_size, max_n_bars)``.
    """
    n_rows = n_bars.shape[0]
    padding_mask = torch.zeros((n_rows, max_n_bars), dtype=torch.bool)
    for row_ix, visible_bars in enumerate(n_bars):
        # Everything past the visible bars of this item is masked out.
        padding_mask[row_ix, visible_bars:] = True
    return padding_mask

def predict_using_batch_data(batch_data, num_input_bars=None, model_=model, device='cpu',
                             max_input_bars=32, steps_per_bar=16, out_n_steps=32):
    """Generate an output continuation for a batch with ``model_.sample``.

    The encoder receives the first ``max_input_bars * steps_per_bar`` steps
    of the stacked sequence; the decoder receives the ``out_n_steps`` steps
    of instrument 2 that end right where the input window ends.

    Args:
        batch_data: Batch as accepted by ``batch_data_extractor``.
        num_input_bars: Per-item number of input bars left unmasked. Bars
            from that index onwards are masked. Defaults to
            ``max_input_bars`` for every item (nothing masked).
        model_: Model exposing ``eval()`` and ``sample(src=...,
            src_key_padding_and_memory_mask=..., tgt=...)``. Note that the
            default is bound when this function is defined.
        device: Device used for the inputs and the mask.
        max_input_bars: Number of bar positions in the encoder mask and
            number of bars in the input window.
        steps_per_bar: Time steps per bar used to size the input window.
        out_n_steps: Length, in steps, of the decoder window.

    Returns:
        The fourth element of the tuple returned by ``model_.sample``
        (named ``hvo`` by the caller).
    """
    model_.eval()

    # The input window starts at step 0; the output window starts right
    # after it.
    in_n_steps = max_input_bars * steps_per_bar

    _, input_stacked, _, previous_output = batch_data_extractor(
        data_=batch_data,
        in_step_start=0,
        in_n_steps=in_n_steps,
        out_step_start=in_n_steps,
        out_n_steps=out_n_steps,
        device=device
    )

    if num_input_bars is None:
        # By default every bar of every item is visible to the encoder.
        num_input_bars = torch.ones(
            (input_stacked.shape[0], 1), dtype=torch.long
        ).to(device) * max_input_bars

    src_mask = create_src_mask(num_input_bars, max_input_bars).to(device)

    with torch.no_grad():
        _, _, _, hvo = model_.sample(
            src=input_stacked,
            src_key_padding_and_memory_mask=src_mask,
            tgt=previous_output
        )
    return hvo

In [94]:
import numpy as np
import timeit
# Draw one random test example; the index is printed so the same example can
# be looked up again.
sample_ix = np.random.randint(0, len(test_dataset))
print(sample_ix)

start = timeit.default_timer()

# Prediction with all 32 input bars left unmasked.
hvo = predict_using_batch_data(
    test_dataset[sample_ix:sample_ix+1],
    model_=model,
    num_input_bars=torch.tensor([32]),
)

# Last 64 steps of the stored instrument-2 sequence, shown ahead of the
# generated steps in the plot below.
hvo_gt = test_dataset.instrument2_hvos[sample_ix][-64:, :]

elapsed_ms = (timeit.default_timer() - start) * 1000
print('Time: ', elapsed_ms, 'ms')

# Reference steps followed by the generated steps of the single batch item.
total = torch.cat([hvo_gt, hvo[0]], dim=0)
total_np = total.cpu().numpy()

hvo_seq = HVO_Sequence(beat_division_factors=[4], drum_mapping=ROLAND_REDUCED_MAPPING)
hvo_seq.add_tempo(0, 120)
hvo_seq.add_time_signature(0, 4, 4)
print(total_np.shape)
hvo_seq.hvo = total_np
hvo_seq.to_html_plot(show_figure=True, width=800, height=400)



6197
Time:  36.86212300090119 ms
(96, 27)


In [72]:
import numpy as np
import timeit

start = timeit.default_timer()

# Deliberately reuses `sample_ix` from the earlier random-sample cell so the
# two generations can be compared on the same example. That cell must have
# been run first.
# Here only the first 2 input bars are left unmasked for the encoder.
hvo = predict_using_batch_data(test_dataset[sample_ix:sample_ix+1], model_=model, num_input_bars=torch.tensor([2]))

# Last 64 steps of the stored instrument-2 sequence, shown ahead of the
# generated steps in the plot below.
hvo_gt = test_dataset.instrument2_hvos[sample_ix][-64:, :]

print('Time: ', (timeit.default_timer() - start) * 1000, 'ms')

# Reference steps followed by the generated steps of the single batch item.
total = torch.cat([hvo_gt, hvo[0]], dim=0)


hvo_seq = HVO_Sequence(
    beat_division_factors=[4], 
    drum_mapping=ROLAND_REDUCED_MAPPING
)
hvo_seq.add_tempo(0, 120)
hvo_seq.add_time_signature(0, 4, 4)
print(total.cpu().numpy().shape)
hvo_seq.hvo = total.cpu().numpy()
hvo_seq.to_html_plot(show_figure=True, width=800, height=400)

Time:  32.58346900111064 ms
(96, 27)
