In [1]:
import wandb
import os
os.environ["WANDB_SILENT"] = "true"

import numpy as np
import pandas as pd

import sys
sys.path.insert(0, "../..")

from helpers import load_variational_mgt_model
from model import GrooveTransformerEncoderVAE
import torch

from umap import UMAP

from bokeh.palettes import inferno, Category20b
from bokeh.core.enums import MarkerType
from bokeh.plotting import figure, show, save
from bokeh.io import output_notebook
output_notebook()

run_name = "noble-field-7"


Holoviews not installed, will not be able to generate violin plots




In [2]:
run = wandb.init()
links = {
    "GOOD_AVERAGE_glamorous-sweep-62": "mmil_vae_g2d/voice_distribution_and_genre_distribution_imbalance/model_epoch_100:v283",
    "GOOD_azure-sweep-54": "mmil_vae_g2d/voice_distribution_and_genre_distribution_imbalance/model_epoch_100:v279",
    "GOOD_apricot-sweep-17": "mmil_vae_g2d/voice_distribution_and_genre_distribution_imbalance/model_epoch_100:v242",
    "GOOD_hearty-sweep-60": "mmil_vae_g2d/voice_distribution_and_genre_distribution_imbalance/model_epoch_100:v280",
    "GOOD_worldly-sweep-22": "mmil_vae_g2d/voice_distribution_and_genre_distribution_imbalance/model_epoch_100:v245",
    "GOOD_legendary-sweep-5": "mmil_vae_g2d/voice_distribution_and_genre_distribution_imbalance/model_epoch_100:v230",
    "drawn_river_6": "mmil_vae_g2d/beta_annealing_study/model_epoch_100:v2",
    "worldly-firebrand-5": "mmil_vae_g2d/beta_annealing_study/model_epoch_100:v1",
    "noble-field-7": "mmil_vae_g2d/beta_annealing_study/model_epoch_100:v3",
    "young-violet-12": "mmil_vae_g2d/beta_annealing_study/model_epoch_200:v0",
    "kind-gorge-14": "mmil_vae_g2d/beta_annealing_study/model_epoch_500:v1"
}
artifact = run.use_artifact(links[run_name], type='model')
artifact_dir = artifact.download()
model = load_variational_mgt_model(os.path.join(artifact_dir, "100.pth"))

Offset activation is sigmoid, bias is initialized to 0.5


In [3]:
from data import load_gmd_hvo_sequences

train_set = load_gmd_hvo_sequences(
    dataset_setting_json_path = "../../data/dataset_json_settings/4_4_Beats_gmd.json", 
    subset_tag = "test", 
    force_regenerate=False)

INFO:data.Base.dataLoaders:Loading gmd dataset
INFO:data.Base.dataLoaders:Loading Cached Version from: data/gmd/resources/cached/beat_division_factor_[4]/drum_mapping_label_ROLAND_REDUCED_MAPPING/beat_type_['beat']_time_signature_['4-4']


In [4]:
gt_sample = train_set[0]
groove = torch.tensor([gt_sample.flatten_voices(reduce_dim=True)], dtype=torch.float32)
gt_sample.metadata

{'Source': 'Groove MIDI Dataset',
 'drummer': 'drummer1',
 'session': 'eval_session',
 'loop_id': 'drummer1/eval_session/10:000',
 'master_id': 'drummer1/eval_session/10',
 'style_primary': 'soul',
 'style_secondary': 'groove10',
 'bpm': '102',
 'beat_type': 'beat',
 'time_signature': '4-4',
 'full_midi_filename': 'drummer1/eval_session/10_soul-groove10_102_beat_4-4.mid',
 'full_audio_filename': 'drummer1/eval_session/10_soul-groove10_102_beat_4-4.wav'}

In [5]:
mu, logvar = model.encode_to_mu_logvar(groove)
latent_z = model.reparametrize(mu, logvar)
latent_z

tensor([[ 0.0621, -1.2016,  0.3151,  0.3732,  0.3921, -0.4903, -0.9658,  0.0089,
         -0.9982, -1.6537, -1.8685,  0.5595, -2.2514,  0.0536,  1.9882, -0.2437,
          0.4604,  1.6035, -1.8894, -0.9298,  0.3221,  0.4742, -0.1683,  1.1657,
          1.0411, -1.0500,  1.0261, -0.0964, -0.6348, -0.3388,  0.5235,  0.9530,
          1.1004,  0.5260,  0.4467, -0.2389,  1.1592, -0.9395, -0.4795,  1.1276,
          1.1287, -0.2342,  1.1515,  0.8847, -0.8411, -1.9223, -1.3171,  0.7731,
          0.2403, -0.0706,  0.3758, -0.6757,  0.1247,  1.4370, -0.3001,  0.8976,
          1.4282,  1.0526,  0.3842, -1.0882,  0.9604,  1.7050,  1.1167,  0.5949,
          0.6746, -0.9384,  1.7821, -0.4862,  0.5111,  0.2560, -1.3443,  0.8826,
          1.8227, -0.2655, -0.4700, -1.0413,  1.9968,  0.4640,  0.7817,  0.0798,
          1.9723, -2.4854,  0.4906, -0.4445, -0.9001,  0.5344, -0.8370,  0.6113,
          1.3542, -0.5372, -0.7828,  1.1344, -0.0630, -0.9328, -0.3236,  1.9297,
         -0.2539,  0.6984,  

In [6]:
voice_thresholds = [0.5] * 9
voice_max_count_allowed = [32] * 9
h, v, o = model.sample(latent_z=latent_z,
                       voice_thresholds=voice_thresholds,
                       voice_max_count_allowed=voice_max_count_allowed,
                       return_concatenated=False,
                       sampling_mode=0)
print(h, v, o)

tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 1., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 1., 0., 0., 0., 0., 0., 0.],
         [0.,

In [7]:
latents = []
labels = []
metadatas = []
use_all_styles = True

for gt_sample in train_set:
    if (use_all_styles or (gt_sample.metadata["style_primary"] in ["rock", "funk", "afrobeat"])):
        metadatas.append(gt_sample.metadata)
        labels.append(gt_sample.metadata["style_primary"])
        groove = torch.zeros((1, 32, 3))
        flattened_ = torch.tensor([gt_sample.flatten_voices(reduce_dim=True)], dtype=torch.float32)[:,:32, :]
        t_steps = flattened_.shape[1]
        groove[:, :t_steps, :] = flattened_
        mu, logvar = model.encode_to_mu_logvar(groove)
        latent_z = model.reparametrize(mu, logvar)
        latents.append(latent_z.detach().cpu().numpy())

latents = np.array(latents).squeeze(1)
features = np.expand_dims(latents, -1) # we use each dimension of latent_z as a feature
feature_labels = [f"z_{dim}" for dim in range(features.shape[1])]

In [8]:
data = {"style_primary": [style for style in labels]}
data.update({f"z_{dim_i}": latents[:, dim_i] for dim_i in range(features.shape[1])})
df = pd.DataFrame(data)
df.head()

Unnamed: 0,style_primary,z_0,z_1,z_2,z_3,z_4,z_5,z_6,z_7,z_8,...,z_118,z_119,z_120,z_121,z_122,z_123,z_124,z_125,z_126,z_127
0,soul,-1.504023,-0.746362,0.216899,0.568736,-0.138677,-0.335858,-0.111339,-0.424367,-0.09328,...,-0.523937,0.629815,0.334195,1.699931,0.532024,-0.281949,-0.894541,0.147671,2.022294,-2.52062
1,soul,-0.865366,-1.307525,-0.744898,0.456623,0.619358,-0.690655,-0.954481,-0.289283,0.148741,...,1.145694,-0.717326,0.549443,0.653769,0.422832,0.149126,-0.889743,1.427327,1.79339,-0.784249
2,soul,-1.437173,-0.967061,-0.003054,1.953855,-0.410941,1.707905,1.049791,0.165301,1.359331,...,-0.780827,-1.866881,0.713646,1.610979,0.649926,-0.740276,-0.844334,1.181594,1.554497,-1.585975
3,soul,-1.041457,-0.633944,-0.426487,0.272681,0.125262,-0.976435,-1.26461,-0.446593,0.366538,...,-0.398414,-0.872021,0.45183,0.451666,-1.574938,0.676362,-0.586775,1.320147,1.437632,-1.047327
4,soul,0.635481,-0.645181,-0.623246,-1.592751,-0.371392,-0.053667,-0.001898,0.142487,-0.488362,...,0.511045,0.514552,-0.049322,1.69243,-0.639779,-0.26452,-0.909341,0.693686,1.711513,-1.058621


In [9]:
embedding_dims = 2
umap = UMAP(n_components=embedding_dims, metric="euclidean", n_neighbors=20)
embeddings = umap.fit_transform(df.drop("style_primary", axis=1))

In [10]:



styles = sorted(df.style_primary.unique())
hues = inferno(len(styles))
hues = Category20b[20]
style_hue_map = {f"{style}": hues[i] for i, style in enumerate(styles)}

bases = [
    "circle",
    "diamond",
    "inverted_triangle",
    "square",
    "star",
    "triangle",
] * 4

marker_map = {style: bases[ix] for ix, style in enumerate(styles)}

In [11]:



embedding_dict = dict()
styles = df.style_primary


for ix, style in enumerate(styles):
    if(not style in embedding_dict):
        embedding_dict.update({style: ([], [])})
    embedding_dict[style][0].append(embeddings[ix][0])
    embedding_dict[style][1].append(embeddings[ix][1])
    

TOOLS="hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,undo,redo,reset,tap,save,box_select,poly_select,lasso_select,"

p = figure(tools=TOOLS)

for key, values in embedding_dict.items():
    p.scatter(values[0], values[1],
              fill_color=style_hue_map[key], 
              line_color=None, legend_label=key, marker=marker_map[key], size=12)
p.legend.click_policy="hide"
show(p)

In [12]:
save(p, f"UMAP_{run_name}.html")

embedding_dict

{'soul': ([9.204173,
   10.022699,
   8.999604,
   9.0195675,
   9.514824,
   9.554974,
   9.872127,
   9.2568,
   9.436499,
   8.786773,
   8.58914,
   9.403558,
   9.20988,
   9.26847,
   9.23711,
   7.6388254,
   8.188692,
   8.26136,
   8.132684,
   8.012912,
   7.950381,
   7.9358406,
   7.923902,
   8.146721,
   8.11838,
   8.058627,
   7.8569765,
   8.105436,
   7.34623,
   8.030034,
   9.755134,
   9.848292,
   9.297227,
   9.893126,
   9.632999,
   9.721267,
   9.565367,
   9.880003,
   9.456684,
   9.904676,
   9.700471,
   9.87227,
   9.608926,
   10.013646,
   9.730095,
   10.339525,
   10.284109,
   10.309364,
   10.307837,
   10.117657,
   10.248244,
   10.036124,
   10.387235,
   10.437388,
   10.256025,
   10.363571,
   10.358503,
   10.208483,
   10.222732,
   10.499491,
   9.757482,
   9.759502,
   9.900804,
   9.628125,
   9.712936,
   9.610815,
   9.090016,
   9.690943,
   9.795812,
   9.71144,
   9.425017,
   9.514762,
   9.425902,
   9.510837,
   9.512904,
   7.86

In [13]:
# from sklearn.manifold import TSNE
# umap = TSNE(n_components=embedding_dims)
# embeddings = umap.fit_transform(df.drop("style_primary", axis=1))

In [14]:
# import numpy as np

# from bokeh.plotting import figure, show, save
# from bokeh.io import output_notebook

# TOOLS="hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,undo,redo,reset,tap,save,box_select,poly_select,lasso_select,"

# p = figure(tools=TOOLS)

# p.scatter(embeddings[:, 0], embeddings[:, 1],
         #  fill_color=colors, fill_alpha=0.6,
          # line_color=None)

# save(p, "TSNE.html") 