In [1]:
import wandb
import os
os.environ["WANDB_SILENT"] = "true"

import numpy as np

%cd ..

from helpers import load_model
from model import BaseVAE, MuteVAE, MuteGenreLatentVAE, GenreClassifier
from data.src.dataLoaders import Groove2Drum2BarDataset

import torch

from helpers.eval_utils import UMapper

In [2]:
down_sampled_ratio = None  # NOTE(review): not referenced by any visible cell

# Test split, exposed as a torch.utils.data.Dataset and configured from a JSON settings file.
dataset = Groove2Drum2BarDataset(
    # which data to load (no augmentation; force_regenerate=False presumably reuses cached data — TODO confirm)
    dataset_setting_json_path="data/dataset_json_settings/Balanced_6000_performed.json",
    subset_tag="test",
    augment_dataset=False,
    force_regenerate=False,
    # sequence length and tapped-input construction
    max_len=32,
    tapped_voice_idx=2,
    collapse_tapped_sequence=True,
    # number of bins for voice density, tempo and global density
    num_voice_density_bins=3,
    num_tempo_bins=6,
    num_global_density_bins=7,
)

In [3]:
from helpers import download_model_from_wandb, predict_using_model, load_model

# Pre-trained checkpoints are read from ./trained_models. If they are missing, fetch them from wandb first:
#   download_model_from_wandb("45", 3, "driven-frost-24", GenreClassifier, new_path="./trained_models/genre_classifier.pth")
#   download_model_from_wandb("155", 1, "lively-pond-9", BaseVAE, new_path="./trained_models/base_vae_beta_0_2.pth")
#   download_model_from_wandb("405", 0, "polished-pyramid-1", MuteVAE, new_path="./trained_models/mute_vae_beta_0_2.pth")
genre_classifier, model_BaseVAE, model_MuteVAE = (
    load_model(checkpoint_path, model_cls)
    for checkpoint_path, model_cls in (
        ("./trained_models/genre_classifier.pth", GenreClassifier),
        ("./trained_models/base_vae_beta_0_2.pth", BaseVAE),
        ("./trained_models/mute_vae_beta_0_2.pth", MuteVAE),
    )
)

In [4]:
def _genre_tag_of_sample(dataset, sample_ix):
    """Return the genre name of dataset sample `sample_ix`.

    `dataset.genre_tags` is the list of genre names (this notebook uses it as the dropdown
    options and searches it with `.index(genre)`), so per-sample labels have to come from
    `dataset.genre_targets`, which is indexed per sample.
    """
    target = dataset.genre_targets[sample_ix]
    # NOTE(review): assumes a target is a class index (int / 0-d tensor / numpy scalar);
    # a vector-valued target is treated as one-hot and reduced with argmax — confirm
    # against Groove2Drum2BarDataset.
    if getattr(target, "ndim", 0) > 0:
        target = target.argmax()
    return dataset.genre_tags[int(target)]


def generate_umap(model_, dataset, indices=None):
    """
    Fit a UMAP projection of the model's latent codes, coloured by genre, and return its plot.

    Args:
        model_: torch.nn.Module
            model passed to `predict_using_model` to obtain the latent codes
        dataset: torch.utils.data.Dataset
            must expose `genre_tags` (genre names) and `genre_targets` (one target per sample)
        indices: list of int, optional
            indices of the dataset to be used for umap generation; all samples when None

    Returns:
        the plot object returned by `UMapper.plot` (not wrapped for wandb,
        since `prepare_for_wandb=False`)
    """
    _, latents_z = predict_using_model(model_, dataset, indices=indices)

    # One genre name per encoded sample, in the same order as the latents.
    sample_ixs = indices if indices is not None else range(len(dataset.genre_targets))
    tags = [_genre_tag_of_sample(dataset, ix) for ix in sample_ixs]

    umapper = UMapper("-")  # NOTE(review): "-" is presumably a placeholder name/title
    umapper.fit(latents_z.detach().cpu().numpy(), tags_=tags)
    return umapper.plot(show_plot=False, prepare_for_wandb=False)

def plot_and_synthesize(hvo_seq_sample):
    """Show the piano roll of an HVO sequence and return its synthesized audio.

    Args:
        hvo_seq_sample: hvo_sequence object exposing `piano_roll` and `synthesize`.

    Returns:
        Whatever `hvo_seq_sample.synthesize` returns when rendered with the
        Standard Drum Kit soundfont.
    """
    soundfont = "hvo_sequence/soundfonts/Standard_Drum_Kit.sf2"
    # show_figure=True asks piano_roll to display the figure; its return value is not needed here.
    hvo_seq_sample.piano_roll(show_figure=True)
    return hvo_seq_sample.synthesize(sf_path=soundfont)

def get_sample_with_filename(name, source_dataset=None):
    """Return the indices of samples whose MIDI filename contains `name`.

    Args:
        name: substring searched for in each sample's `metadata["full_midi_filename"]`.
        source_dataset: dataset to search. Defaults to the notebook-level `dataset`, so
            existing single-argument calls behave as before.

    Returns:
        list of int indices into `source_dataset.hvo_sequences`, in ascending order
        (empty when nothing matches).
    """
    search_in = dataset if source_dataset is None else source_dataset
    return [
        ix
        for ix, hvo_seq in enumerate(search_in.hvo_sequences)
        if name in hvo_seq.metadata["full_midi_filename"]
    ]

# Run Inference

In [5]:
from helpers import synthesize_visualize_using_models

import IPython.display as ipd
from bokeh.io import show

# Render one randomly chosen test pattern with both VAEs.
tabs, audios, gt_audio = synthesize_visualize_using_models(
    [model_BaseVAE, model_MuteVAE],
    dataset,
    np.random.randint(0, len(dataset)),
)

# Ground-truth player first, then one player per entry of `audios`.
audio_players = [ipd.Audio(clip, rate=44100) for clip in (gt_audio, *audios)]
ipd.display(*audio_players)
show(tabs)

    
    



# Generate Random Styles

In [6]:
# Default patterns for the interactive explorer in the next cell:
# pattern A = dataset sample 0, pattern B = dataset sample 1.
latent_dim = 128  # NOTE(review): not referenced by any active (uncommented) cell
z_a_dataset_ix = 0
z_b_dataset_ix = 1

# The previous message claimed a latent_z was generated here; this cell only picks default indices.
print(f"Pattern A: dataset index {z_a_dataset_ix} | Pattern B: dataset index {z_b_dataset_ix}")
print("-" * 50)
print("toggle the 'Press to randomize' checkboxes in the next cell to draw new patterns")

import ipywidgets as widgets
from ipywidgets import HBox, VBox
import IPython.display as ipd
from bokeh.embed import file_html
from bokeh.resources import CDN


In [7]:
import ipywidgets as widgets
from ipywidgets import interact

import IPython.display as ipd
from bokeh.embed import file_html
from bokeh.io import show
from bokeh.layouts import row
from bokeh.resources import CDN

# Last seen value of each "randomize" checkbox. The checkboxes act as buttons: any change of
# state (checked or unchecked) draws a new random dataset index for that pattern.
randomize_a_prev_state = False
randomize_b_prev_state = False

# Sampling settings shared by both models. Both lists have 9 entries, presumably one per drum
# voice of the HVO representation — TODO confirm against the model definitions.
VOICE_THRESHOLDS = [0.5] * 9
VOICE_MAX_COUNT_ALLOWED = [32] * 9
SOUNDFONT_PATH = "hvo_sequence/soundfonts/Standard_Drum_Kit.sf2"


def _sample_hvo(m_, latent_z, mutes):
    """Decode `latent_z` with model `m_` and return torch.cat([h, v, o], dim=2).

    `mutes` holds five flags in the order kick, snare, hat, tom, cymbal. Only MuteVAE's
    `sample` receives them; BaseVAE is sampled without mute arguments.
    """
    sample_kwargs = dict(
        latent_z=latent_z,
        voice_thresholds=torch.tensor(VOICE_THRESHOLDS),
        voice_max_count_allowed=torch.tensor(VOICE_MAX_COUNT_ALLOWED),
        sampling_mode=0,
    )
    if m_ is model_MuteVAE:
        sample_kwargs.update(
            kick_is_muted=torch.tensor([mutes[0]]),
            snare_is_muted=torch.tensor([mutes[1]]),
            hat_is_muted=torch.tensor([mutes[2]]),
            tom_is_muted=torch.tensor([mutes[3]]),
            cymbal_is_muted=torch.tensor([mutes[4]]),
        )
    h, v, o = m_.sample(**sample_kwargs)
    return torch.cat([h, v, o], dim=2)


@interact(
    genre=dataset.genre_tags,
    kick_is_muted=widgets.Checkbox(value=False, description='Mute Kick'),
    snare_is_muted=widgets.Checkbox(value=False, description='Mute Snare'),
    hat_is_muted=widgets.Checkbox(value=False, description='Mute Hat'),
    tom_is_muted=widgets.Checkbox(value=False, description='Mute Tom'),
    cymbal_is_muted=widgets.Checkbox(value=False, description='Mute Cymbal'),
    hsliderInterpolation=widgets.FloatSlider(value=0, min=0, max=1, step=0.01, description='Interpolation'),
    randomizeA=widgets.Checkbox(value=False, description='Press to randomize'),
    randomizeB=widgets.Checkbox(value=False, description='Press to randomize'))
def generate(genre, kick_is_muted, snare_is_muted, hat_is_muted, tom_is_muted, cymbal_is_muted, hsliderInterpolation, randomizeA, randomizeB):
    """Widget callback: decode an interpolation of two dataset patterns with BaseVAE and MuteVAE.

    Args:
        genre: selected genre name. NOTE(review): currently unused, because neither model's
            `sample` call takes a genre. It is kept because `interact` builds the dropdown
            from this parameter.
        kick_is_muted, snare_is_muted, hat_is_muted, tom_is_muted, cymbal_is_muted:
            mute flags, forwarded to MuteVAE only.
        hsliderInterpolation: t in [0, 1]; latent_z = (1 - t) * z_A + t * z_B.
        randomizeA, randomizeB: checkbox values; any change draws a new random index
            for pattern A / B respectively.

    Side effects: updates the module globals z_a_dataset_ix / z_b_dataset_ix, latent_z_A /
    latent_z_B (last model's encodings) and the randomize_*_prev_state trackers. It prints the
    ground-truth and classified genres, and displays one audio player plus the piano roll
    per model.
    """
    global latent_z_A, latent_z_B, randomize_a_prev_state, randomize_b_prev_state
    global z_a_dataset_ix, z_b_dataset_ix

    # Independent checks, so a change of either checkbox is always handled.
    if randomizeA != randomize_a_prev_state:
        z_a_dataset_ix = np.random.randint(0, len(dataset))
        randomize_a_prev_state = randomizeA
    if randomizeB != randomize_b_prev_state:
        z_b_dataset_ix = np.random.randint(0, len(dataset))
        randomize_b_prev_state = randomizeB

    mutes = torch.tensor([
        kick_is_muted,
        snare_is_muted,
        hat_is_muted,
        tom_is_muted,
        cymbal_is_muted
    ]).long()

    audios = []
    plots = []
    genres_classified = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for m_ in [model_BaseVAE, model_MuteVAE]:
            # Each model has its own latent space, so both patterns are re-encoded per model.
            _, latent_z_A = predict_using_model(m_, dataset, indices=[z_a_dataset_ix])
            _, latent_z_B = predict_using_model(m_, dataset, indices=[z_b_dataset_ix])
            latent_z = latent_z_A * (1 - hsliderInterpolation) + latent_z_B * hsliderInterpolation

            hvo = _sample_hvo(m_, latent_z, mutes)

            gen_, _ = genre_classifier.predict(hvo)
            genres_classified.append(gen_)

            # Wrap the first generated sequence in an empty copy of a dataset sequence to plot/synthesize it.
            hvo_seq = dataset.hvo_sequences[0].copy_empty()
            hvo_seq.hvo = hvo[0, :, :].squeeze().detach().cpu().numpy()

            plots.append(hvo_seq.piano_roll(width=700, height=300, filename=f"{m_.__class__.__name__}"))
            audios.append(hvo_seq.synthesize(sf_path=SOUNDFONT_PATH))

    print("Pattern A (gt) genre: ", dataset.genre_targets[z_a_dataset_ix])
    print("Pattern B (gt) genre: ", dataset.genre_targets[z_b_dataset_ix])
    print(f"Genre classified by BaseVAE: {genres_classified[0]}")
    print(f"Genre classified by MuteVAE: {genres_classified[1]}")

    # Bokeh figures are embedded as standalone HTML so they render inside the widget output.
    html = file_html(row(*plots), CDN, "my plot")
    for audio in audios:
        ipd.display(ipd.Audio(audio, rate=44100))

    ipd.display(ipd.HTML(html))
    



In [8]:
# import torch
# import ipywidgets as widgets
# from ipywidgets import interact_manual, VBox, HBox
# import numpy as np
# import IPython.display as ipd
# from IPython.display import display
# 
# # Assuming 'model', 'dataset', and other necessary components are defined elsewhere
# 
# latent_dim = 128  # Set this to match your model's latent dimension
# latent_z = torch.randn(1, latent_dim)  # Initialize latent_z
# 
# # Create sliders
# sliders = [widgets.FloatSlider(min=-3, max=3, value=0, step=0.01, orientation='vertical', readout=False, layout={'width': '20px', 'height': '75px'}) for _ in range(latent_dim)]
# 
# # initialize sliders with latent_z values
# for i, slider in enumerate(sliders):
#     slider.value = latent_z[0, i].item()
#     
# # Function to update latent_z based on sliders' values
# def update_latent_z(change):
#     global latent_z
#     latent_z[:] = torch.tensor([[slider.value for slider in sliders]], dtype=torch.float)
# 
# # Attach update function to sliders
# for slider in sliders:
#     slider.observe(update_latent_z, names='value')
# 
# # Group sliders into horizontal boxes
# num_sliders_per_row = 64  # Adjust this number based on your display preferences
# hboxes = [HBox(sliders[i:i+num_sliders_per_row]) for i in range(0, latent_dim, num_sliders_per_row)]
# 
# # Display the grouped sliders
# vbox = VBox(hboxes)
# display(vbox)
# 
# # Define the generate function to use the updated latent_z
# def generate(genre, global_density, kick_is_muted, snare_is_muted, hat_is_muted, tom_is_muted, cymbal_is_muted):
#     global latent_z
#     
#     genre_ix = dataset.genre_tags.index(genre)
#     genre_ix = torch.tensor([genre_ix], dtype=torch.long)
#     
#     # Convert binary mute selections to tensors
#     mutes = torch.tensor([
#         kick_is_muted,
#         snare_is_muted,
#         hat_is_muted,
#         tom_is_muted,
#         cymbal_is_muted
#     ]).long()
#     
#     # Sample from the model using the updated latent_z
#     # Note: Implement the model's sampling logic here
#     # decode
#     h, v, o = model.sample(
#         latent_z = latent_z,
#         genre= genre_ix,
#         kick_is_muted=torch.tensor([mutes[0]]),
#         snare_is_muted=torch.tensor([mutes[1]]),
#         hat_is_muted=torch.tensor([mutes[2]]),
#         tom_is_muted=torch.tensor([mutes[3]]),
#         cymbal_is_muted=torch.tensor([mutes[4]]),
#         voice_thresholds=torch.tensor([0.5] * 9),
#         voice_max_count_allowed=torch.tensor([32] * 9),
#         sampling_mode=0
#     )
#     
#     hvo = torch.cat([h, v, o], dim=2)
#     
#     hvo_seq = dataset.hvo_sequences[0].copy_empty()
#     hvo_seq.hvo = hvo[0, :, :].squeeze().detach().cpu().numpy()
#     
#     pr = hvo_seq.piano_roll(width=700, height=300)
#     # convert bokeh figure to html
#     
#     html = file_html(pr, CDN, "my plot")
#     
#     audio = hvo_seq.synthesize(sf_path="hvo_sequence/soundfonts/Standard_Drum_Kit.sf2")
#     IPython.display.display(ipd.Audio(audio, rate=44100))
#     
#     # show html
#     IPython.display.display(IPython.display.HTML(html))
#     
# # Setup interactive widgets for the generate function
# interact_manual(generate,
#     genre=dataset.genre_tags,
#     global_density=(0, 9),
#     kick_is_muted=widgets.Checkbox(value=False),
#     snare_is_muted=widgets.Checkbox(value=False),
#     hat_is_muted=widgets.Checkbox(value=False),
#     tom_is_muted=widgets.Checkbox(value=False),
#     cymbal_is_muted=widgets.Checkbox(value=False)
# )
