In [1]:
import wandb
import os
os.environ["WANDB_SILENT"] = "true"

import numpy as np
import pandas as pd

# import sys
# sys.path.insert(0, "../..")

%cd ../..

from helpers import load_GenreGlobalDensityWithVoiceMutesVAE_model
from model import GenreDensityTempoVAE
from data.src.dataLoaders import Groove2Drum2BarDataset

import torch

from umap import UMAP

from bokeh.palettes import inferno, Category20b
from bokeh.core.enums import MarkerType
from bokeh.plotting import figure, show, save
from bokeh.io import output_notebook, reset_output
# output_notebook()

from helpers.VAE.eval_utils import UMapper

/home/dtic/Github/GrooveTransformerV2


In [2]:

# Dataset settings (relative to the repo root, where the first cell `%cd`s to)
dataset_json_dir = "data/dataset_json_settings"
dataset_json_fname = "Balanced_5000_per_genre_performed_4_4.json"
down_sampled_ratio = None  # NOTE(review): not passed to the loader below — confirm whether it is still needed

# Load the test split as a torch.utils.data.Dataset
# NOTE(review): num_global_density_bins=7 here, yet the global-density Counter further down
# shows bins 0..9 and the interactive widgets use (0, 9) — confirm which is intended.
dataset = Groove2Drum2BarDataset(
    dataset_setting_json_path=os.path.join(dataset_json_dir, dataset_json_fname),
    subset_tag="test",
    max_len=32,
    tapped_voice_idx=2,
    collapse_tapped_sequence=True,
    num_voice_density_bins=3,
    num_tempo_bins=6,
    num_global_density_bins=7,
    augment_dataset=False,
    force_regenerate=False
)

dataset.get_complexities_per_genre()

INFO:data.Base.dataLoaders:Groove2Drum2BarDataset Constructor --> Loading Cached Version from: cached/TorchDatasets/Balanced_5000_per_genre_performed_4_4.json_test_32_2_True_None_False_3_6_7.bz2pickle
INFO:data.Base.dataLoaders:Loaded 4957 sequences


{'Afro': [0.7742831707000732,
  0.49747270345687866,
  0.6026937961578369,
  0.7698428630828857,
  0.7045333981513977,
  0.5224411487579346,
  0.721150815486908,
  0.4304780662059784,
  0.5458945035934448,
  0.7262718081474304,
  0.7082489728927612,
  0.6574642062187195,
  0.5649880170822144,
  0.6465933918952942,
  0.8562602400779724,
  0.6798524856567383,
  0.5842910408973694,
  0.6669055223464966,
  0.5433181524276733,
  0.5854308009147644,
  0.6070335507392883,
  0.7116632461547852,
  0.8326142430305481,
  0.7356992363929749,
  0.838104784488678,
  0.8095126152038574,
  0.6658778786659241,
  0.7793093919754028,
  0.2940894365310669,
  0.6569135785102844,
  0.7351283431053162,
  0.5917717218399048,
  0.6185813546180725,
  0.45829999446868896,
  0.5963935256004333,
  0.4114059805870056,
  0.6663544178009033,
  0.7661973834037781,
  0.5471656918525696,
  0.6320554614067078,
  0.6748027205467224,
  0.6818219423294067,
  0.8323485255241394,
  0.5498719215393066,
  0.4599396884441376,
  

In [3]:
from collections import Counter

# Genre vocabulary, its size, and the number of test samples per global-density bin
density_bin_histogram = Counter(dataset.global_density_bins.numpy())
(dataset.genre_tags, len(dataset.genre_tags), density_bin_histogram)

(['Afro',
  'Blues',
  'Disco',
  'Funk',
  'Hip-Hop/R&B/Soul',
  'Jazz',
  'Latin',
  'Pop',
  'Reggae',
  'Rock',
  'unknown'],
 11,
 Counter({8: 535,
          4: 444,
          7: 489,
          3: 650,
          2: 326,
          0: 492,
          6: 491,
          5: 506,
          1: 527,
          9: 497}))

In [4]:
# Voice-group activity per test sample.
# hits: the first 9 channels of each output groove, treated as one hit channel per voice (indices 0-8 below)
hits = dataset.output_grooves[:, :, :9]

# A voice is active if it has at least one hit summed over dim=1
# (assumed to be the time-step axis — TODO confirm)
is_active = hits.sum(dim=1) > 0

voice_labels = ["K", "S", "H", "T", "C"]  # --> Kick:0, snare:1, hihat:2,3 , tom:4,5,6, crashandRide:7,8

# Voice channel index -> voice-group label
voice_map = {
    0: "K",
    1: "S",
    2: "H",
    3: "H",
    4: "T",
    5: "T",
    6: "T",
    7: "C",
    8: "C"
}

# Label each sample by its sorted, de-duplicated set of active voice groups, e.g. "HKS"
labeled_samples = [
    "".join(sorted({voice_map[ix] for ix, act in enumerate(activity) if act}))
    for activity in is_active
]

unique_groups = set(labeled_samples)
print(unique_groups)

# Histogram of voice-group combinations, most frequent first.
# rename/rename_axis make the column names "active_voices"/"count" on both pandas 1.x and 2.x.
df = (
    pd.Series(labeled_samples, name="active_voices")
    .value_counts()
    .rename("count")
    .rename_axis("active_voices")
    .reset_index()
)
df


{'KS', 'KST', 'CHK', 'CKT', 'CHS', 'CKS', 'K', 'H', 'CHKS', 'T', 'HKST', 'CHST', 'HKT', 'KT', 'S', 'HKS', 'C', 'HT', 'CK', 'CH', 'CT', 'CHKT', 'CST', 'HS', 'CHT', 'HK', 'HST', 'ST', 'CHKST', 'CS', 'CKST'}


Unnamed: 0,active_voices,count
18,HKS,1929
4,CHKST,805
3,CHKS,679
19,HKST,668
10,CKS,193
26,KST,86
11,CKST,79
25,KS,62
21,HS,50
17,HK,46


In [5]:
def generate_umap(GenTempoDensityModel, test_dataset=None):
    """
    Fit a UMAP projection of the model's latent codes and return the resulting plot.

    NOTE(review): this helper targets the ``GenreDensityTempoVAE`` interface (tempo bins plus
    per-voice density bins). The model loaded later in this notebook is queried with
    ``global_density_bins`` and ``*_is_muted`` instead, and this function is not called in the
    notebook, so it may be stale — confirm before reuse.

    Args:
        GenTempoDensityModel: The model to be used for evaluation; must be a GenreDensityTempoVAE.
        test_dataset: Dataset providing the input grooves and conditioning labels.
            Defaults to the notebook-level ``dataset`` (the previous implicit behaviour).

    Returns:
        The plot returned by ``UMapper.plot(show_plot=False, prepare_for_wandb=False)``.

    Raises:
        AssertionError: if the model is not a GenreDensityTempoVAE.
    """
    assert isinstance(GenTempoDensityModel, GenreDensityTempoVAE)

    if test_dataset is None:
        test_dataset = dataset

    tags = test_dataset.get_genre_labels_for_all()

    _, latents_z = GenTempoDensityModel.predict(
        flat_hvo_groove=test_dataset.input_grooves,
        genre_tags=tags,
        tempo_bins=test_dataset.tempo_bins,
        kick_density_bins=test_dataset.kick_density_bins,
        snare_density_bins=test_dataset.snare_density_bins,
        hats_density_bins=test_dataset.hat_density_bins,
        toms_density_bins=test_dataset.tom_density_bins,
        cymbals_density_bins=test_dataset.cymbal_density_bins)

    umapper = UMapper("-")
    umapper.fit(latents_z.detach().cpu().numpy(), tags_=tags)
    return umapper.plot(show_plot=False, prepare_for_wandb=False)


# Download, load, and serialize the model

In [6]:
import datetime

# Which trained checkpoint to use: W&B artifact "model_epoch_{epoch}:v{version}"
epoch = "380"  # alternative: "605"
version = 1
run_name = "stilted-cherry-8"  # alternatives: "estive-fish-12", "glittering-dragon-9"

artifact_path = f"behzadhaki/GenreGlobalDensityWithVoiceMutesVAE_Balanced5000PerGen/model_epoch_{epoch}:v{version}"

# Download only once; afterwards reuse the locally renamed checkpoint
artifact_dir = os.path.join("artifacts", f"model_epoch_{epoch}:v{version}")
local_path = os.path.join(artifact_dir, f"{run_name}.pth")
if not os.path.exists(local_path):
    print("Downloading artifact")
    run = wandb.init()
    try:
        artifact = run.use_artifact(artifact_path, type='model')
        artifact_dir = artifact.download()
    finally:
        # Close the run so it is not left open for the rest of the session
        run.finish()
    # The artifact stores the checkpoint as {epoch}.pth; rename it to {run_name}.pth
    os.rename(os.path.join(artifact_dir, f"{epoch}.pth"), os.path.join(artifact_dir, f"{run_name}.pth"))
    print("Artifact downloaded to: ", artifact_dir)
else:
    print("Artifact already downloaded")

model = load_GenreGlobalDensityWithVoiceMutesVAE_model(os.path.join(artifact_dir, f"{run_name}.pth"))

# Serialize the loaded model with a timestamped filename (save_folder=run_name)
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
model.serialize(save_folder=f"{run_name}", filename=f"Gen_{run_name}_{epoch}_serialized__{timestamp}.pt")


Downloading artifact
Artifact downloaded to:  ./artifacts/model_epoch_380:v1


# Run Inference

In [7]:
def plot_and_synthesize(hvo_seq_sample, sf_path="hvo_sequence/soundfonts/Standard_Drum_Kit.sf2"):
    """Show the piano roll of an hvo_sequence and return its synthesized audio.

    Args:
        hvo_seq_sample: object providing ``piano_roll(show_figure=...)`` and
            ``synthesize(sf_path=...)`` (an hvo_sequence in this notebook).
        sf_path: soundfont used for synthesis; defaults to the drum kit used
            throughout this notebook.

    Returns:
        The value returned by ``hvo_seq_sample.synthesize`` (the rendered audio).
    """
    hvo_seq_sample.piano_roll(show_figure=True)
    return hvo_seq_sample.synthesize(sf_path=sf_path)

In [8]:
sample = dataset[0]

# Run inference in evaluation mode, as the reconstruction cell further down also does
model.eval()

# Condition on a fixed global density bin instead of the sample's own bin (sample[6])
density_override = torch.tensor([3])

model.predict(
    flat_hvo_groove=sample[0].unsqueeze(0),
    genre_tags=sample[4].unsqueeze(0),
    global_density_bins=density_override,
    kick_is_muted=sample[13].unsqueeze(0),
    snare_is_muted=sample[14].unsqueeze(0),
    hat_is_muted=sample[15].unsqueeze(0),
    tom_is_muted=sample[16].unsqueeze(0),
    cymbal_is_muted=sample[17].unsqueeze(0),
)

(tensor([[[ 1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.2379e-01,
            1.0229e-01,  1.3443e-01,  6.2065e-03,  1.5713e-04,  1.7626e-03,
            9.0005e-04,  2.9777e-03,  7.1008e-02,  5.3233e-02,  7.0303e-04,
            4.0279e-02, -3.9590e-03,  9.9714e-03,  4.0711e-03,  4.6029e-03,
            1.9124e-04,  1.2089e-02],
          [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.6986e-04,
            3.5132e-01,  1.4147e-04,  2.4327e-06,  8.9283e-07,  2.3772e-05,
            2.1324e-06,  7.1794e-08,  9.4662e-05, -5.9525e-03,  5.9917e-02,
           -1.2101e-02, -1.6723e-03,  3.2673e-03,  1.0540e-03,  3.8274e-03,
            2.8171e-03, -3.5385e-03],
          [ 0.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  6.6277e-04,
            

In [9]:
from collections import Counter

# Number of test samples per global density bin
Counter(dataset.global_density_bins.numpy())

Counter({8: 535,
         4: 444,
         7: 489,
         3: 650,
         2: 326,
         0: 492,
         6: 491,
         5: 506,
         1: 527,
         9: 497})

In [17]:
import IPython  # later cells refer to the `IPython` name
import IPython.display as ipd
from bokeh.io import show
from bokeh.layouts import row

SF_PATH = "hvo_sequence/soundfonts/Standard_Drum_Kit.sf2"

# Pick a random test sample (re-run the cell to draw another one)
idx = np.random.randint(0, len(dataset))
sample = dataset[idx]

model.eval()

print("density =", sample[6], "genre =", dataset.genre_tags[sample[4]])

# Reconstruct the sample: feed its input groove together with its own genre,
# global-density and voice-mute conditioning
hvo_pred, _ = model.predict(
    flat_hvo_groove=sample[0].unsqueeze(0),
    genre_tags=sample[4].unsqueeze(0),
    global_density_bins=sample[6].unsqueeze(0),
    kick_is_muted=sample[13].unsqueeze(0),
    snare_is_muted=sample[14].unsqueeze(0),
    hat_is_muted=sample[15].unsqueeze(0),
    tom_is_muted=sample[16].unsqueeze(0),
    cymbal_is_muted=sample[17].unsqueeze(0))

target_hvo_seq = dataset.hvo_sequences[idx]
predicted_hvo_seq = target_hvo_seq.copy_empty()
predicted_hvo_seq.hvo = hvo_pred[0, :, :].squeeze().detach().cpu().numpy()

# Ground truth (left) next to the prediction (right)
pr_true = target_hvo_seq.piano_roll(width=600, height=300)
pr_pred = predicted_hvo_seq.piano_roll(width=600, height=300)
fig = row(pr_true, pr_pred)

ipd.display(
    ipd.Audio(target_hvo_seq.synthesize(sf_path=SF_PATH), rate=44100),
    ipd.Audio(predicted_hvo_seq.synthesize(sf_path=SF_PATH), rate=44100),
)
show(fig)



density = tensor(9) genre = Hip-Hop/R&B/Soul


fluidsynth: error: Unknown integer parameter 'synth.sample-rate'
fluidsynth: error: Unknown integer parameter 'synth.sample-rate'


"# Generate Random Styles

In [11]:
import IPython.display as ipd
import ipywidgets as widgets
from bokeh.embed import file_html
from bokeh.resources import CDN
from ipywidgets import HBox, VBox

# Draw an initial random latent vector sized to the model's latent space
latent_dim = model.config["latent_dim"]
latent_z = torch.randn(1, latent_dim)

print("Generated latent_z: ")
print("-" * 50)
print("re-run this cell to generate a new latent_z")


Generated latent_z: 
--------------------------------------------------
re-run this cell to generate a new latent_z


In [20]:
import IPython.display as ipd
import ipywidgets as widgets
from bokeh.embed import file_html
from bokeh.resources import CDN
from ipywidgets import interact

# Interactive generator: decodes the current `latent_z` under the selected genre, global
# density and voice mutes. Relies on `dataset`, `model`, `latent_dim` and `latent_z`
# from earlier cells.

# Checkbox state seen on the previous call; a new latent_z is drawn only when the
# "randomize" box goes from unchecked to checked.
prev_randomize = False


@interact(
    genre=dataset.genre_tags,
    global_density=(0, 9),
    kick_is_muted=widgets.Checkbox(value=False, description='Mute Kick'),
    snare_is_muted=widgets.Checkbox(value=False, description='Mute Snare'),
    hat_is_muted=widgets.Checkbox(value=False, description='Mute Hat'),
    tom_is_muted=widgets.Checkbox(value=False, description='Mute Tom'),
    cymbal_is_muted=widgets.Checkbox(value=False, description='Mute Cymbal'),
    randomize=widgets.Checkbox(value=False, description='Press to randomize'))
def generate(genre, global_density, kick_is_muted, snare_is_muted, hat_is_muted, tom_is_muted, cymbal_is_muted, randomize):
    """Decode a drum pattern from `latent_z` and display its audio and piano roll.

    Ticking "randomize" draws a fresh latent_z once; changing any other widget re-decodes
    the same latent_z under the new conditioning.
    """
    global prev_randomize, latent_z
    if randomize and not prev_randomize:
        latent_z = torch.randn(1, latent_dim)
    # Record the checkbox state on every call (including right after randomizing) so a
    # ticked box does not re-draw latent_z on each later widget change.
    prev_randomize = randomize

    genre_ix = torch.tensor([dataset.genre_tags.index(genre)], dtype=torch.long)

    # 0/1 mute flags in the order of the *_is_muted arguments below
    mutes = torch.tensor([
        kick_is_muted,
        snare_is_muted,
        hat_is_muted,
        tom_is_muted,
        cymbal_is_muted
    ]).long()

    # decode (slices keep each flag as a 1-element long tensor)
    h, v, o = model.sample(
        latent_z=latent_z,
        genre=genre_ix,
        global_density_bins=torch.tensor([global_density]),
        kick_is_muted=mutes[0:1],
        snare_is_muted=mutes[1:2],
        hat_is_muted=mutes[2:3],
        tom_is_muted=mutes[3:4],
        cymbal_is_muted=mutes[4:5],
        voice_thresholds=torch.tensor([0.5] * 9),
        voice_max_count_allowed=torch.tensor([32] * 9),
        sampling_mode=0
    )

    hvo = torch.cat([h, v, o], dim=2)

    # Use sample 0's sequence as an empty container for the generated pattern
    hvo_seq = dataset.hvo_sequences[0].copy_empty()
    hvo_seq.hvo = hvo[0, :, :].squeeze().detach().cpu().numpy()

    pr = hvo_seq.piano_roll(width=700, height=300)
    # Render the bokeh figure as standalone HTML so it displays in the widget output area
    html = file_html(pr, CDN, "my plot")

    audio = hvo_seq.synthesize(sf_path="hvo_sequence/soundfonts/Standard_Drum_Kit.sf2")
    ipd.display(ipd.Audio(audio, rate=44100))
    ipd.display(ipd.HTML(html))


interactive(children=(Dropdown(description='genre', options=('Afro', 'Blues', 'Disco', 'Funk', 'Hip-Hop/R&B/So…

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Prototype: 128 compact vertical integer sliders without readouts, packed into wrapping rows
slider_layout_kwargs = dict(height='100px', width='20px')  # Adjust height and width as necessary
sliders = []
for _ in range(128):
    sliders.append(
        widgets.IntSlider(
            orientation='vertical',
            readout=False,
            layout=widgets.Layout(**slider_layout_kwargs),
        )
    )

# Flex-wrap container so the sliders pack tightly and flow onto new lines
box_layout = widgets.Layout(
    flex_flow='row wrap',
    align_items='stretch',
    width='100%',  # Adjust the width to fit the container or desired layout
)
slider_box = widgets.Box(sliders, layout=box_layout)

display(slider_box)

In [63]:
import torch
import ipywidgets as widgets
from ipywidgets import interact_manual, VBox, HBox
import numpy as np
import IPython.display as ipd
from IPython.display import display
from bokeh.embed import file_html
from bokeh.resources import CDN

# Slider-driven generator: each latent dimension gets a vertical slider; "Run Interact"
# decodes the current latent_z. Relies on `model` and `dataset` from earlier cells.

# Latent size comes from the loaded model instead of a hard-coded value
latent_dim = model.config["latent_dim"]

# FloatSlider clamps values to [min, max]; clamp latent_z the same way up front so the
# tensor that gets decoded always matches what the sliders display.
SLIDER_MIN, SLIDER_MAX = -3.0, 3.0
latent_z = torch.randn(1, latent_dim).clamp_(SLIDER_MIN, SLIDER_MAX)

# One vertical slider per latent dimension
sliders = [widgets.FloatSlider(min=SLIDER_MIN, max=SLIDER_MAX, value=0, step=0.01, orientation='vertical', readout=False, layout={'width': '20px', 'height': '75px'}) for _ in range(latent_dim)]

# Initialise sliders with the latent_z values
for i, slider in enumerate(sliders):
    slider.value = latent_z[0, i].item()


def update_latent_z(change):
    """Copy all slider values into latent_z in place whenever any slider moves."""
    latent_z[:] = torch.tensor([[slider.value for slider in sliders]], dtype=torch.float)


# Attach the update function to every slider
for slider in sliders:
    slider.observe(update_latent_z, names='value')

# Group sliders into horizontal rows
num_sliders_per_row = 64  # Adjust this number based on your display preferences
hboxes = [HBox(sliders[i:i+num_sliders_per_row]) for i in range(0, latent_dim, num_sliders_per_row)]

vbox = VBox(hboxes)
display(vbox)


def generate(genre, global_density, kick_is_muted, snare_is_muted, hat_is_muted, tom_is_muted, cymbal_is_muted):
    """Decode a drum pattern from the slider-controlled latent_z and display audio + piano roll."""
    genre_ix = torch.tensor([dataset.genre_tags.index(genre)], dtype=torch.long)

    # 0/1 mute flags in the order of the *_is_muted arguments below
    mutes = torch.tensor([
        kick_is_muted,
        snare_is_muted,
        hat_is_muted,
        tom_is_muted,
        cymbal_is_muted
    ]).long()

    # decode (slices keep each flag as a 1-element long tensor)
    h, v, o = model.sample(
        latent_z=latent_z,
        genre=genre_ix,
        global_density_bins=torch.tensor([global_density]),
        kick_is_muted=mutes[0:1],
        snare_is_muted=mutes[1:2],
        hat_is_muted=mutes[2:3],
        tom_is_muted=mutes[3:4],
        cymbal_is_muted=mutes[4:5],
        voice_thresholds=torch.tensor([0.5] * 9),
        voice_max_count_allowed=torch.tensor([32] * 9),
        sampling_mode=0
    )

    hvo = torch.cat([h, v, o], dim=2)

    # Use sample 0's sequence as an empty container for the generated pattern
    hvo_seq = dataset.hvo_sequences[0].copy_empty()
    hvo_seq.hvo = hvo[0, :, :].squeeze().detach().cpu().numpy()

    pr = hvo_seq.piano_roll(width=700, height=300)
    # Render the bokeh figure as standalone HTML so it displays in the widget output area
    html = file_html(pr, CDN, "my plot")

    audio = hvo_seq.synthesize(sf_path="hvo_sequence/soundfonts/Standard_Drum_Kit.sf2")
    ipd.display(ipd.Audio(audio, rate=44100))
    ipd.display(ipd.HTML(html))


# Conditioning widgets; the trailing ';' hides interact_manual's returned function repr
interact_manual(generate,
    genre=dataset.genre_tags,
    global_density=(0, 9),
    kick_is_muted=widgets.Checkbox(value=False),
    snare_is_muted=widgets.Checkbox(value=False),
    hat_is_muted=widgets.Checkbox(value=False),
    tom_is_muted=widgets.Checkbox(value=False),
    cymbal_is_muted=widgets.Checkbox(value=False)
);


VBox(children=(HBox(children=(FloatSlider(value=0.43832647800445557, layout=Layout(height='75px', width='20px'…

interactive(children=(Dropdown(description='genre', options=('Afro', 'Blues', 'Disco', 'Funk', 'Hip-Hop/R&B/So…

<function __main__.generate(genre, global_density, kick_is_muted, snare_is_muted, hat_is_muted, tom_is_muted, cymbal_is_muted)>