In [1]:
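# Notebook: load a trained LTA_Stacked checkpoint, generate drums autoregressively for a test groove, and inspect encoder self-attention.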

In [None]:
from model.LTA_Stacked import LTA_Stacked
from model import load_model
import torch
from hvo_sequence.hvo_seq import HVO_Sequence
from hvo_sequence.drum_mappings import ROLAND_REDUCED_MAPPING
from bokeh.plotting import show, output_notebook
import os
import timeit
from hvo_sequence.hvo_seq import HVO_Sequence

# Render bokeh output inline in the notebook.
output_notebook()

# Checkpoint under inspection. The previous checkpoint is kept for quick switching:
# mdl_ = 'misc/LTA_Stacked/Bass (LTA_Stacked) Predict 1 bar ahead, no velocity at inputs_ojodqmhh/060.pth'
mdl_ = 'misc/LTA_Stacked/[bassGuiSynLeadPian] (LTA_Stacked) 1bar no in vel_m62skfv9/015.pth'

model = load_model(model_path=mdl_, model_class=LTA_Stacked, device='cpu', is_evaluating=True)

# Serialize next to the checkpoint, using the same file name with '.pth' swapped for '.pt'.
# NOTE: this runs on every execution of the cell.
ckpt_dir, ckpt_file = os.path.split(mdl_)
model.serialize(save_folder=ckpt_dir, filename=ckpt_file.replace('.pth', '.pt'))

# Last expression of the cell: switches to eval mode and shows the module repr as output.
model.eval()

In [None]:
# load data
from data import StackedLTADatasetV2
# Maximum input context (in bars) handed to the dataset; also reused for the
# tick spacing in plot_head_n_attention.
max_n_bars = 32

# Test-split groove files, keyed by input instrument.
dataset_dict = dict(
    bass="data/lmd/data_bass_groove_test.bz2",
    guitar="data/lmd/data_guitar_groove_test.bz2",
    synth="data/lmd/data_synth_groove_test.bz2",
    lead="data/lmd/data_lead_groove_test.bz2",
    piano="data/lmd/data_piano_groove_test.bz2",
)

# Guitar grooves as input, paired with the (unsplit) drum dataset as output.
# The shifted target is offset by one step and the input carries no velocity.
test_dataset = StackedLTADatasetV2(
    input_inst_dataset_bz2_filepath=dataset_dict['guitar'],
    output_inst_dataset_bz2_filepath="data/lmd/data_drums_full_unsplit.bz2",
    shift_tgt_by_n_steps=1,
    max_input_bars=max_n_bars,
    hop_n_bars=4,
    input_has_velocity=False,
)

In [None]:
# Default device for the helpers below. They bind it as a default argument
# (`device=device`) at definition time, so re-assigning `device` later has no
# effect unless those function definitions are re-run.
device = 'cpu'

def patch_attention(m):
    """Monkey-patch ``m.forward`` so every call asks for per-head attention weights.

    The replacement forwards all positional and keyword arguments unchanged,
    except that ``need_weights`` is forced to ``True`` and
    ``average_attn_weights`` to ``False`` (overriding anything the caller passed).

    Args:
        m: object exposing a ``forward`` callable that accepts those two keyword
            arguments (used here on an encoder layer's ``self_attn``).

    Returns:
        None. ``m`` is modified in place; calling this twice simply nests the
        wrappers, which has the same effect.
    """
    original_forward = m.forward

    def forward_with_weights(*args, **kwargs):
        kwargs.update(need_weights=True, average_attn_weights=False)
        return original_forward(*args, **kwargs)

    m.forward = forward_with_weights
class SaveOutput:
    """Forward-hook callable that records element ``[1]`` of a module's output.

    Intended to be passed to ``register_forward_hook``; every hooked forward
    call appends one entry to ``self.outputs``.
    """

    def __init__(self):
        # Start from an empty record list.
        self.clear()

    def __call__(self, module, module_in, module_out):
        # Only the second element of the output is kept (presumably the
        # attention weights when hooked on an attention module - confirm).
        captured = module_out[1]
        self.outputs.append(captured)

    def clear(self):
        """Drop all recorded outputs by rebinding ``outputs`` to a fresh list."""
        self.outputs = []



# plot heatmap
import seaborn as sns
import matplotlib.pyplot as plt




def get_hvo_seqs(groove_hvo__, drum_hvo__):
    """Wrap groove and drum HVO tensors in ``HVO_Sequence`` objects and synthesize audio.

    Args:
        groove_hvo__: torch tensor holding the groove HVO. If it is 3-D, only
            the first item along dim 0 is used.
        drum_hvo__: torch tensor holding the drum HVO. Same 3-D handling.

    Returns:
        Tuple ``(hvo_seq_bass, hvo_seq, audio_bass, audio_drums, audio_mixed)``:
        the groove sequence (single 'BASS Rhythm' voice), the drum sequence
        (``ROLAND_REDUCED_MAPPING``), the groove audio scaled by 0.5, the drum
        audio, and their sum truncated to the shorter of the two.
    """
    def _first_item(hvo_tensor):
        # Strip the leading (batch) dim when present.
        return hvo_tensor[0] if len(hvo_tensor.shape) == 3 else hvo_tensor

    def _to_sequence(hvo_tensor, drum_mapping):
        # Build a 4/4, 120 bpm sequence on a 16th-note grid holding `hvo_tensor`.
        seq = HVO_Sequence(
            beat_division_factors=[4],
            drum_mapping=drum_mapping
        )
        seq.add_tempo(0, 120)
        seq.add_time_signature(0, 4, 4)
        # .cpu() keeps this working for tensors living on another device and
        # matches the .detach().cpu().numpy() used in plot_head_n_attention.
        seq.hvo = hvo_tensor.detach().cpu().numpy()
        return seq

    hvo_seq_bass = _to_sequence(_first_item(groove_hvo__), {'BASS Rhythm': [80]})
    hvo_seq = _to_sequence(_first_item(drum_hvo__), ROLAND_REDUCED_MAPPING)

    sf_path = 'hvo_sequence/soundfonts/Standard_Drum_Kit.sf2'
    audio_bass = hvo_seq_bass.synthesize(sf_path=sf_path) * 0.5
    audio_drums = hvo_seq.synthesize(sf_path=sf_path)

    # The two renders may differ in length; mix over the common prefix only.
    n_common = min(audio_bass.shape[0], audio_drums.shape[0])
    audio_mixed = audio_bass[:n_common] + audio_drums[:n_common]

    return hvo_seq_bass, hvo_seq, audio_bass, audio_drums, audio_mixed


def batch_data_extractor(data_, device=device):
    """Move the first two entries of a batch onto ``device``.

    Args:
        data_: indexable batch; ``data_[0]`` is the shifted stacked target and
            ``data_[1]`` the stacked target.
        device: destination passed to ``.to()``. The default is whatever the
            module-level ``device`` was when this function was defined.

    Returns:
        Tuple ``(stacked_target_shifted, stacked_target)`` on ``device``.
    """
    return tuple(data_[ix].to(device) for ix in (0, 1))

def forward_using_batch_data(batch_data, scope_end_step=None, model_=model, device=device):
    """Run one teacher-forced forward pass of ``model_`` on a batch.

    Args:
        batch_data: batch accepted by ``batch_data_extractor``.
        scope_end_step: optional cap on the number of time steps (dim 1) fed to
            the model; values beyond the sequence length are clipped to it.
        model_: model exposing ``train()`` and ``forward(shifted_tgt=...)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(h_logits, v_logits, o_logits, stacked_target)``, with the
        target truncated to the same number of steps as the input.
    """
    # NOTE(review): this puts the model in *train* mode and leaves it there,
    # presumably carried over from the training loop - confirm that is wanted
    # when this helper is used for analysis.
    model_.train()

    shifted_in, target = batch_data_extractor(data_=batch_data, device=device)

    if scope_end_step is not None:
        n_steps = min(scope_end_step, shifted_in.shape[1])
        shifted_in = shifted_in[:, :n_steps, :]
        target = target[:, :n_steps, :]

    h_logits, v_logits, o_logits = model_.forward(shifted_tgt=shifted_in)

    return h_logits, v_logits, o_logits, target.to(device)


# Auto-regressive prediction
def predict_using_batch_auto_reg(batch_data, model_=model, device='cpu'):
    """Generate a sequence step by step, feeding back the model's own drum output.

    The shifted input is primed with the ground-truth groove voice only (every
    10th channel) and generation starts at step 16. At each step the model is
    run on the whole shifted input built so far; hits are sampled with
    ``torch.bernoulli`` after zeroing probabilities below 0.2, and velocities
    and offsets are masked by the sampled hits. The next input position then
    receives the ground-truth groove channels (0 and 10) plus the freshly
    generated hits/offsets of voices 1..9.

    Args:
        batch_data: batch accepted by ``batch_data_extractor``.
        model_: model exposing ``eval()`` and ``forward(shifted_tgt=...)``
            returning ``(h_logits, v_logits, o_logits)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(groove_hvo, drum_hvo, grove_hvo_tgt, drum_hvo_tgt)``: channels
        0/10/20 of the generated tensor, the 27 remaining (drum) channels of the
        generated tensor, and the same two selections taken from the target.
    """
    model_.eval()

    stacked_target_shifted, stacked_target = batch_data_extractor(
        data_=batch_data,
        device=device
    )

    prime_steps = 16  # generation starts after this many steps (the first bar)
    n_steps = stacked_target_shifted.shape[1]

    # `generated` is laid out as [hits 0:10 | velocities 10:20 | offsets 20:],
    # as fixed by the three writes inside the loop.
    generated = torch.zeros_like(stacked_target)
    # The shifted input presumably carries [hits 0:10 | offsets 10:20] only
    # (dataset built with input_has_velocity=False) - confirm.
    generated_shifted = torch.zeros_like(stacked_target_shifted)

    # Prime with the groove voice only. The slice runs through position
    # `prime_steps` inclusive: the loop below only fills position i+1 after
    # predicting step i, so without the +1 position 16 was never written and
    # the first generated step saw an all-zero current input.
    generated_shifted[:, :prime_steps + 1, ::10] = stacked_target_shifted[:, :prime_steps + 1, ::10]

    # NOTE(review): gradients are tracked across all iterations here. Wrapping
    # the loop in torch.no_grad() would save memory, but on an
    # nn.TransformerEncoderLayer in eval mode it may enable the fused fast path,
    # which skips the self_attn forward hook used for attention capture - verify
    # against the installed PyTorch version before changing.
    for i in range(prime_steps, n_steps):
        h_logits, v_logits, o_logits = model_.forward(shifted_tgt=generated_shifted)

        h = torch.sigmoid(h_logits)
        # Separate thresholds for the groove voice and the drum voices so they
        # can be tuned independently (both are 0.2 at the moment).
        h[:, :, 0][h[:, :, 0] < 0.2] = 0.0   # groove voice must be confident
        h[:, :, 1:][h[:, :, 1:] < 0.2] = 0.0  # more improvisation allowed for drums
        # h[h >= 0.4] = 1.0
        # Sample binary hits from the thresholded probabilities.
        h = torch.bernoulli(h)
        v = torch.clamp(((torch.tanh(v_logits) + 1.0) / 2), 0.0, 1.0) * h
        o = torch.tanh(o_logits) * h

        # Only the prediction for the current step is kept.
        generated[:, i, :10] = h[:, i, :]
        generated[:, i, 10:20] = v[:, i, :]
        generated[:, i, 20:] = o[:, i, :]

        if i < n_steps - 1:
            # Next input position: ground-truth groove channels ...
            generated_shifted[:, i+1, 0] = stacked_target_shifted[:, i+1, 0]
            generated_shifted[:, i+1, 10] = stacked_target_shifted[:, i+1, 10]
            # ... plus the hits and offsets just generated for voices 1..9.
            generated_shifted[:, i+1, 1:10] = h[:, i, 1:]
            generated_shifted[:, i+1, 11:20] = o[:, i, 1:]

    # Split the stacked tensors into groove (voice 0) and drums (voices 1..9).
    drum_hvo = torch.zeros((generated.shape[0], generated.shape[1], 27))

    groove_hvo = generated[:, :, ::10]
    grove_hvo_tgt = stacked_target[:, :, ::10]

    drum_hvo[:, :, :9] = generated[:, :, 1:10]
    drum_hvo[:, :, 9:18] = generated[:, :, 11:20]
    drum_hvo[:, :, 18:] = generated[:, :, 21:]

    drum_hvo_tgt = torch.zeros_like(drum_hvo)
    drum_hvo_tgt[:, :, :9] = stacked_target[:, :, 1:10]
    drum_hvo_tgt[:, :, 9:18] = stacked_target[:, :, 11:20]
    drum_hvo_tgt[:, :, 18:] = stacked_target[:, :, 21:]

    return groove_hvo, drum_hvo, grove_hvo_tgt, drum_hvo_tgt


# Pick one random test example. No seed is set, so every run picks a different one.
sample_ix = int(torch.randint(0, len(test_dataset), (1,)))

# A length-1 slice (rather than a single index) is handed to the generator.
sample_batch = test_dataset[sample_ix:sample_ix+1]
groove_hvo, drum_hvo, grove_hvo_tgt, drum_hvo_tgt = predict_using_batch_auto_reg(
    sample_batch, model_=model, device='cpu'
)

# Sequences and audio for the generated output and for the ground-truth target.
hvo_seq_bass, hvo_seq_drum, audio_bass, audio_drums, audio_mixed = get_hvo_seqs(
    groove_hvo__=groove_hvo, drum_hvo__=drum_hvo
)
hvo_seq_bass_tgt, hvo_seq_drum_tgt, _, _, audio_mixed_tgt = get_hvo_seqs(
    groove_hvo__=grove_hvo_tgt, drum_hvo__=drum_hvo_tgt
)

In [None]:
from IPython.display import Audio, display
# Header, audio player for the generated mix, then the generated groove plot.
banner = "---------------------------------"
print(banner, "Generated", banner, sep="\n")
display(Audio(audio_mixed, rate=44100))
hvo_seq_bass.to_html_plot(show_figure=True, width=1400, height=200)

In [None]:
# Plot the generated drum sequence (built from `drum_hvo` by get_hvo_seqs).
hvo_seq_drum.to_html_plot(show_figure=True, width=1400, height=200)

In [None]:
# Header, audio player for the ground-truth mix, then the ground-truth groove plot.
print("\n".join([
    "---------------------------------",
    "Target",
    "---------------------------------",
]))
display(Audio(audio_mixed_tgt, rate=44100))
hvo_seq_bass_tgt.to_html_plot(show_figure=True, width=1400, height=200)

In [None]:
# Plot the ground-truth drum sequence (built from `drum_hvo_tgt` by get_hvo_seqs).
hvo_seq_drum_tgt.to_html_plot(show_figure=True, width=1400, height=200)

In [None]:
def plot_head_n_attention(head_n, attn_weights_list, binary_vis=False):
    """Draw one attention map (a single head, or the sum over heads) as a heatmap.

    Args:
        head_n: index of the head to show, or ``None`` to sum over dim 1 (heads).
        attn_weights_list: a 4-D attention tensor indexed as
            ``[item, head, :, :]``, or a list whose first element is such a
            tensor. Only item 0 is plotted.
        binary_vis: when true, plot the boolean mask ``weights > 0.01`` instead
            of the raw weights.

    Returns:
        None. Prints the tensor shape and shows the figure.
    """
    if isinstance(attn_weights_list, list):
        attn_weights = attn_weights_list[0]
    else:
        attn_weights = attn_weights_list

    print(attn_weights.shape)

    if head_n is None:
        # Collapse the heads into one map; keepdim keeps the array 4-D so the
        # same indexing works below.
        weights_np = attn_weights.sum(dim=1, keepdim=True).detach().cpu().numpy()
        head_ix = 0
    else:
        weights_np = attn_weights.detach().cpu().numpy()
        head_ix = head_n

    if binary_vis:
        weights_np = weights_np > 0.01

    # NOTE(review): two panels are created but only the left one is drawn into;
    # the right panel stays empty apart from ticks.
    fig, ax = plt.subplots(1, 2, figsize=(20, 5))

    # Greyscale heatmap of the selected map, transposed.
    # NOTE(review): the title says "Cross-Attention", yet this notebook hooks
    # an encoder layer's self_attn - confirm the wording.
    sns.heatmap(weights_np[0, head_ix, :, :].transpose(), cmap='Greys', ax=ax[0])
    ax[0].set_title('Cross-Attention Weights - With instrumental')

    # One tick every 16 steps (one per bar), flipped y axis, square cells.
    bar_ticks = list(range(0, max_n_bars*16, 16))
    for a in ax:
        a.set_xticks(bar_ticks)
        a.set_xticklabels(bar_ticks)
        a.set_yticks(bar_ticks)
        a.set_yticklabels(bar_ticks)
        # Second call only resets the y-label rotation to horizontal.
        a.set_yticklabels(a.get_yticklabels(), rotation=0)
        a.invert_yaxis()
        a.set_aspect('equal')

    plt.show()
    
# Capture the attention weights of one encoder layer during a fresh
# autoregressive run, then plot head 0 of the last forward pass.
save_output = SaveOutput()

layer_ix = 0

attn_module = model.TransformerEncoder.layers[layer_ix].self_attn
# Force the layer to return per-head weights. Re-applying this on a cell re-run
# only nests the wrapper, which has the same effect.
patch_attention(attn_module)
hook_handle = attn_module.register_forward_hook(save_output)

try:
    # NOTE: this re-runs the (stochastic) generation and overwrites the
    # groove/drum tensors produced by the earlier cell. The hook fires once per
    # generated step, so `save_output.outputs` grows by one tensor per step.
    groove_hvo, drum_hvo, grove_hvo_tgt, drum_hvo_tgt = predict_using_batch_auto_reg(test_dataset[sample_ix:sample_ix+1], model_=model, device='cpu')
finally:
    # Always detach the hook: otherwise every re-run of this cell registers
    # another one and keeps the previous SaveOutput (and its tensors) alive.
    hook_handle.remove()


plot_head_n_attention(0, save_output.outputs[-1], binary_vis=True)


In [None]:
# Number of captured attention tensors (the hook appends one per forward call).
len(save_output.outputs)

In [None]:
# Shape of the attention tensor captured on the last forward call.
save_output.outputs[-1].shape

In [None]:
# # alternative positional embedding
#
# import numpy as np
#
# def modified_positional_encoding(period, d_model, max_len=512):
#     pe = np.zeros((max_len, d_model))
#     position = np.arange(0, max_len).reshape(-1, 1)
#
#     # Sinusoidal part
#     div_term = 2 * np.pi / period
#     pe[:, 0::2] = np.sin(position * div_term)
#     pe[:, 1::2] = np.cos(position * div_term)
#
#     # Linear term for distinguishability
#     div_term_linear = 10000 ** (np.arange(0, d_model) / d_model)
#     pe += position / div_term_linear
#
#     return pe
#
# def normal_positional_encoding(d_model, max_len=512):
#     pe = np.zeros((max_len, d_model))
#     position = np.arange(0, max_len).reshape(-1, 1)
#
#     # Sinusoidal part
#     div_term = 10000 ** (np.arange(0, d_model, 2) / d_model)
#     pe[:, 0::2] = np.sin(position / div_term)
#     pe[:, 1::2] = np.cos(position / div_term)
#
#     return pe
#
# # Example usage:
# period = 16  # For example, a period of 16 positions
# d_model = 16
# pos_enc_periodic = modified_positional_encoding(period, d_model)
# pos_enc_normal = normal_positional_encoding(d_model)
#
# # To use this positional encoding in a transformer model
# import torch
# pos_enc_tensor_periodic = torch.tensor(pos_enc_periodic, dtype=torch.float32)
# pos_enc_tensor_normal = torch.tensor(pos_enc_normal, dtype=torch.float32)
#
#
#
# # visualize similar to attention is all you need paper
# # show both
# import seaborn as sns
# import matplotlib.pyplot as plt
#
# fig, ax = plt.subplots(1, 2, figsize=(20, 5))
#
# # use a greyscale color map
# sns.heatmap(pos_enc_tensor_periodic, cmap='Greys', ax=ax[0])
# ax[0].set_title('Modified Positional Encoding')
#
# sns.heatmap(pos_enc_tensor_normal, cmap='Greys', ax=ax[1])
# ax[1].set_title('Normal Positional Encoding')
#
# plt.show()
