In [1]:
import wandb
import os
os.environ["WANDB_SILENT"] = "true"

import numpy as np

%cd ..

from ISMIR_2024_EVALs import load_model
from model import BaseVAE, MuteVAE, MuteGenreLatentVAE, MuteLatentGenreInputVAE, GenreClassifier
from data.src.dataLoaders import Groove2Drum2BarDataset

import torch


In [2]:

# NOTE(review): `down_sampled_ratio` is not referenced anywhere else in the visible notebook.
down_sampled_ratio=None
# load dataset as torch.utils.data.Dataset
# Only the "test" subset is loaded, without augmentation; the remaining keyword
# arguments are forwarded unchanged to the project's Groove2Drum2BarDataset.
dataset = Groove2Drum2BarDataset(
    dataset_setting_json_path="data/dataset_json_settings/Balanced_6000_performed.json",
    subset_tag="test",
    max_len=32,
    tapped_voice_idx=2,
    collapse_tapped_sequence=True,
    num_voice_density_bins=3,
    num_tempo_bins=6,
    num_global_density_bins=7,
    augment_dataset=False,
    force_regenerate=False
)


In [3]:

# Load the trained checkpoints used below. Only the genre classifier and one
# MuteGenreLatentVAE checkpoint are active; the other variants are kept commented out
# (presumably to limit load time/memory -- confirm before re-enabling them).
model_classifier = load_model("./trained_models/genre_classifier.pth", GenreClassifier)
# model_BaseVAE_0_2 = load_model("./trained_models/base_vae_beta_0_2.pth", BaseVAE)
# model_BaseVAE_0_5 = load_model("./trained_models/base_vae_beta_0_5.pth", BaseVAE)
# model_BaseVAE_1_0 = load_model("./trained_models/base_vae_beta_1_0.pth", BaseVAE)
# model_MuteVAE_0_2 = load_model("./trained_models/mute_vae_beta_0_2.pth", MuteVAE)
# model_MuteVAE_0_5 = load_model("./trained_models/mute_vae_beta_0_5.pth", MuteVAE)
# model_MuteVAE_1_0 = load_model("./trained_models/mute_vae_beta_1_0.pth", MuteVAE)
# model_MuteGenreLatentVAE_0_2 = load_model("./trained_models/mute_genre_latent_vae_beta_0_2.pth", MuteGenreLatentVAE)
# The only VAE used by the cells below.
model_MuteGenreLatentVAE_0_5 = load_model("./trained_models/mute_genre_latent_vae_beta_0_5.pth", MuteGenreLatentVAE)
# model_MuteGenreLatentVAE_1_0 = load_model("./trained_models/mute_genre_latent_vae_beta_1_0.pth", MuteGenreLatentVAE)
# model_MuteLatentGenreInputVAE_0_2 = load_model("./trained_models/mute_latent_genre_input_vae_beta_0_2.pth", MuteLatentGenreInputVAE)
# model_MuteLatentGenreInputVAE_0_5 = load_model("./trained_models/mute_latent_genre_input_vae_beta_0_5.pth", MuteLatentGenreInputVAE)
# model_MuteLatentGenreInputVAE_1_0 = load_model("./trained_models/mute_latent_genre_input_vae_beta_1_0.pth", MuteLatentGenreInputVAE)


# Generate Random Styles

In [4]:
# Draw two random latent vectors of shape (1, latent_dim) from a standard normal.
# NOTE(review): latent_z_A / latent_z_B are not used by any later cell in the visible
# notebook, and the prints below only emit fixed text (the tensors are not printed).
latent_dim = 128
latent_z_A = torch.randn(1, latent_dim)
latent_z_B = torch.randn(1, latent_dim)
print("Generated latent_z: ")
print("-"*50)
print("re-run this cell to generate a new latent_z")
import ipywidgets as widgets
from ipywidgets import HBox, VBox
import IPython.display as ipd
from bokeh.embed import file_html
from bokeh.resources import CDN


In [5]:
import matplotlib.pyplot as plt
from matplotlib import gridspec


def interp_between_two_rand_samples(model_, n_interp, sample1_i, sample2_i, genre_ix=None, genre_classifier_model=None):
    """
    Encode two dataset samples, interpolate linearly between their latents and decode each step.

    Relies on the notebook-level `dataset`: element 0 of a dataset item is passed to the
    model as `flat_hvo_groove` and element 4 is used as the genre tag.

    Args:
        model_: a BaseVAE, MuteVAE, MuteGenreLatentVAE or MuteLatentGenreInputVAE instance.
        n_interp: number of intermediate latents between the two encoded endpoints.
        sample1_i: dataset index of the start sample (A).
        sample2_i: dataset index of the end sample (B).
        genre_ix: genre id given to genre-conditioned models; when None, the genre tag of
            sample A is used.
        genre_classifier_model: optional classifier whose `predict` is run on every decoded
            pattern.

    Returns:
        Tuple `(hvo_seq_all, hvo_seqs, probs, genre_preds)`:
        the concatenation (via `+`) of all decoded sequences, the list of the
        `n_interp + 2` decoded sequences (A, intermediates, B), the per-step classifier
        probabilities as nested lists, and the per-step predicted genre ids.
        `probs` and `genre_preds` are None when no classifier is supplied.

    Raises:
        ValueError: if `model_` is none of the supported model classes.
    """
    sample1 = dataset[sample1_i]
    sample2 = dataset[sample2_i]

    if genre_ix is None:
        genre_tag = sample1[4].unsqueeze(0)
    else:
        genre_tag = torch.tensor([genre_ix])

    # All five voice groups stay un-muted, both while encoding and while decoding.
    mutes = dict(
        kick_is_muted=torch.tensor([0]),
        snare_is_muted=torch.tensor([0]),
        hat_is_muted=torch.tensor([0]),
        tom_is_muted=torch.tensor([0]),
        cymbal_is_muted=torch.tensor([0]),
    )

    # Model-specific keyword arguments for predict() (encode) and sample() (decode).
    # The isinstance checks keep their original order so that any subclass relationship
    # between the model classes resolves exactly as before.
    if isinstance(model_, BaseVAE):
        encode_kwargs = {}
        decode_kwargs = {}
    elif isinstance(model_, MuteVAE):
        encode_kwargs = dict(mutes)
        decode_kwargs = dict(mutes, sampling_mode=0)
    elif isinstance(model_, MuteGenreLatentVAE):
        encode_kwargs = dict(mutes, genre_tags=genre_tag)
        decode_kwargs = dict(mutes, genre=genre_tag, sampling_mode=0)
    elif isinstance(model_, MuteLatentGenreInputVAE):
        # This model receives the genre only while encoding; sample() gets no genre.
        encode_kwargs = dict(mutes, genre_tags=genre_tag)
        decode_kwargs = dict(mutes, sampling_mode=0)
    else:
        raise ValueError("Model not supported")

    # get the latent_z for the two samples
    _, latent_z1 = model_.predict(flat_hvo_groove=sample1[0].unsqueeze(0), **encode_kwargs)
    _, latent_z2 = model_.predict(flat_hvo_groove=sample2[0].unsqueeze(0), **encode_kwargs)

    # Evenly spaced latents: A, n_interp intermediates, B.
    z_step = (latent_z2 - latent_z1) / (n_interp + 1)
    z_s = [latent_z1]
    z_s.extend(latent_z1 + z_step * i for i in range(1, n_interp + 1))
    z_s.append(latent_z2)

    use_classifier = genre_classifier_model is not None
    genre_probs = [] if use_classifier else None
    genre_preds = [] if use_classifier else None

    # decode
    hvo_seqs = []
    for z in z_s:
        h, v, o = model_.sample(
            latent_z=z,
            voice_thresholds=torch.tensor([0.5] * 9),
            voice_max_count_allowed=torch.tensor([32] * 9),
            **decode_kwargs
        )

        hvo = torch.cat([h, v, o], dim=2)

        # Re-use the metadata of sample A's sequence as the container for the decoded pattern.
        hvo_seq = dataset.hvo_sequences[sample1_i].copy_empty()
        hvo_seq.hvo = hvo[0, :, :].squeeze().detach().cpu().numpy()
        hvo_seqs.append(hvo_seq)

        if use_classifier:
            genre_classifier_model.eval()
            with torch.no_grad():
                ix, prob = genre_classifier_model.predict(hvo)
                genre_probs.append(prob)
                genre_preds.append(ix)

    # append HVOS
    hvo_seq_all = hvo_seqs[0]
    for h_s in hvo_seqs[1:]:
        hvo_seq_all = hvo_seq_all + h_s

    # `probs` previously stayed undefined without a classifier, which made the return
    # statement raise UnboundLocalError; it is now None in that case.
    probs = None
    if use_classifier:
        genre_preds = [p.item() for p in genre_preds]
        probs = [[p.numpy().tolist() for p in g_p] for g_p in genre_probs]

    return hvo_seq_all, hvo_seqs, probs, genre_preds


def interp_between_two_rand_samples_with_controls(model_, n_interp, sample1_i, sample2_i, genre_classifier_model=None, mutes1=None, mutes2=None):
    """
    Interpolate jointly between the latents and the mute controls of two dataset samples.

    Relies on the notebook-level `dataset`: element 0 of a dataset item is passed to the
    model as `flat_hvo_groove`, and elements 8-12 supply the kick/snare/hat/tom/cymbal
    mute flags when `mutes1` / `mutes2` are not given.

    Args:
        model_: the model to interpolate with; only MuteVAE instances are actually handled.
        n_interp: number of intermediate steps between the two endpoints.
        sample1_i: dataset index of the start sample (A).
        sample2_i: dataset index of the end sample (B).
        genre_classifier_model: optional classifier whose `predict` is run on every decoded
            pattern.
        mutes1: optional 5-element mute-control tensor for A (kick, snare, hat, tom, cymbal).
        mutes2: optional 5-element mute-control tensor for B (same order).

    Returns:
        Tuple `(hvo_seq_all, hvo_seqs, probs, genre_preds)` with the same meaning as in
        `interp_between_two_rand_samples`; `probs` and `genre_preds` are None when no
        classifier is supplied.

    Raises:
        AssertionError: if `model_` is neither a MuteGenreLatentVAE nor a MuteVAE.
        ValueError: if `model_` passes the assertion but is not a MuteVAE.
    """
    sample1 = dataset[sample1_i]
    sample2 = dataset[sample2_i]

    mutes1 = torch.tensor([sample1[8], sample1[9], sample1[10], sample1[11], sample1[12]]) if mutes1 is None else mutes1
    mutes2 = torch.tensor([sample2[8], sample2[9], sample2[10], sample2[11], sample2[12]]) if mutes2 is None else mutes2

    assert isinstance(model_, MuteGenreLatentVAE) or isinstance(model_, MuteVAE), "Model must be an instance of MuteGenreLatentVAE or MuteVAE"
    if not isinstance(model_, MuteVAE):
        # The old message named BaseVAE, which the assertion above never admits.
        raise ValueError("Only MuteVAE instances are currently supported for interpolation with controls")

    # get the latent_z for the two samples
    _, latent_z1 = model_.predict(
        flat_hvo_groove=sample1[0].unsqueeze(0),
        kick_is_muted=mutes1[0],
        snare_is_muted=mutes1[1],
        hat_is_muted=mutes1[2],
        tom_is_muted=mutes1[3],
        cymbal_is_muted=mutes1[4]
    )

    _, latent_z2 = model_.predict(
        flat_hvo_groove=sample2[0].unsqueeze(0),
        kick_is_muted=mutes2[0],
        snare_is_muted=mutes2[1],
        hat_is_muted=mutes2[2],
        tom_is_muted=mutes2[3],
        cymbal_is_muted=mutes2[4]
    )

    use_classifier = genre_classifier_model is not None
    genre_probs = [] if use_classifier else None
    genre_preds = [] if use_classifier else None

    # decode: step 0 is A, step n_interp + 1 is B
    hvo_seqs = []
    for i in range(0, n_interp + 2):
        h, v, o = model_.interpolate_2d(
            interpolation_factor=i / (n_interp + 1),
            latent_z_1=latent_z1,
            mute_controls_1=mutes1,
            latent_z_2=latent_z2,
            mute_controls_2=mutes2,
            voice_thresholds=torch.tensor([0.5] * 9),
            voice_max_count_allowed=torch.tensor([32] * 9),
            sampling_mode=0
        )

        hvo = torch.cat([h, v, o], dim=2)

        # Re-use the metadata of sample A's sequence as the container for the decoded pattern.
        hvo_seq = dataset.hvo_sequences[sample1_i].copy_empty()
        hvo_seq.hvo = hvo[0, :, :].squeeze().detach().cpu().numpy()
        hvo_seqs.append(hvo_seq)

        if use_classifier:
            genre_classifier_model.eval()
            with torch.no_grad():
                ix, prob = genre_classifier_model.predict(hvo)
                genre_probs.append(prob)
                genre_preds.append(ix)

    # append HVOS
    hvo_seq_all = hvo_seqs[0]
    for h_s in hvo_seqs[1:]:
        hvo_seq_all = hvo_seq_all + h_s

    # `probs` previously stayed undefined without a classifier, which made the return
    # statement raise UnboundLocalError; it is now None in that case.
    probs = None
    if use_classifier:
        genre_preds = [p.item() for p in genre_preds]
        probs = [[p.numpy().tolist() for p in g_p] for g_p in genre_probs]

    return hvo_seq_all, hvo_seqs, probs, genre_preds



def calculate_area_under_piecewise_lines(x, y):
    """
    Area under the piecewise-linear curve through the points (x[i], y[i]).

    Adds up one trapezoid per pair of consecutive points; fewer than two points give 0.
    """
    return sum(
        (x[right] - x[right - 1]) * (y[right] + y[right - 1]) / 2
        for right in range(1, len(x))
    )

# jaccard similarity
from sklearn.metrics import jaccard_score


def jaccard_distance(h1, h2):
    """
    Jaccard distance (1 - binary Jaccard score) between two flattened hit arrays.

    The distance is rounded to two decimals; a NaN score is reported as distance 1.
    """
    similarity = jaccard_score(h1.flatten(), h2.flatten(), average="binary")
    return 1 if np.isnan(similarity) else round(1 - similarity, 2)
    

def calculate_relative_jackards(hvo_seqs_, rhythm_flat=False):
    """
    Jaccard distance of every interpolation step to A, normalised by the A-to-B distance.

    Args:
        hvo_seqs_: decoded sequences ordered A, intermediates, B; each must expose `.hits`
            (and `.copy()` / `.flatten_voices()` when `rhythm_flat` is True).
        rhythm_flat: when True, work on copies whose `hvo` is replaced by the result of
            `flatten_voices()`, leaving the input sequences untouched.

    Returns:
        `(interp_factors, relative_jaccards, interp_labels)`: evenly spaced factors in
        [0, 1], the normalised distances (rounded to two decimals) and matching tick
        labels ("A", "$I_1$", ..., "B").
        `(None, None, None)` when A and B have Jaccard distance 0, because the
        normalisation is undefined in that case.
    """
    if rhythm_flat:
        hvo_seqs = [hvo_seq.copy() for hvo_seq in hvo_seqs_]
        for hvo_seq in hvo_seqs:
            hvo_seq.hvo = hvo_seq.flatten_voices()
    else:
        hvo_seqs = hvo_seqs_

    AB_Jaccard = jaccard_distance(hvo_seqs[0].hits, hvo_seqs[-1].hits)

    if AB_Jaccard == 0:
        return None, None, None

    relative_jaccards = []
    for hvo_s in hvo_seqs:
        # A step whose hit array length differs from A's cannot be compared; as before,
        # the value of the previous step is carried forward for it. The first step is A
        # itself, so A_Here_Jaccard is always defined before its first use.
        if len(hvo_s.hits) == len(hvo_seqs[0].hits):
            A_Here_Jaccard = jaccard_distance(hvo_s.hits, hvo_seqs[0].hits) / AB_Jaccard

        # normalize
        relative_jaccards.append(np.round(A_Here_Jaccard, 2))

    # Labels are generated for any number of steps; the former fixed 12-entry list raised
    # IndexError as soon as more than 12 sequences were passed.
    interp_labels = [
        (r"$I_%d$" if step < 10 else r"$I_{%d}$") % step
        for step in range(len(hvo_seqs))
    ]
    interp_labels[0] = "A"
    interp_labels[-1] = "B"

    interp_factors = np.linspace(0, 1, len(hvo_seqs))

    return interp_factors, relative_jaccards, interp_labels

from sklearn.metrics.pairwise import cosine_similarity

def cosine_sim_hvo(h1, h2):
    """Cosine similarity between two arrays, each flattened into a single row vector."""
    row_1 = torch.tensor(h1).flatten().unsqueeze(0)
    row_2 = torch.tensor(h2).flatten().unsqueeze(0)
    return cosine_similarity(row_1, row_2)

def calculate_relative_cosine_distance(hvo_seqs):
    """
    Cosine distance of every step's velocities to A's, normalised by the A-to-B distance.

    Args:
        hvo_seqs: decoded sequences ordered A, intermediates, B; each must expose
            `.get("v")`.

    Returns:
        `(interp_factors, relative_cosines, interp_labels)`: evenly spaced factors in
        [0, 1], the normalised distances (rounded to two decimals) and matching tick
        labels ("A", "$I_1$", ..., "B").

    Note:
        There is no guard for an A-to-B distance of 0; the division is performed as is.
    """
    reference_v = hvo_seqs[0].get("v")
    AB_Cosine = 1 - cosine_sim_hvo(reference_v, hvo_seqs[-1].get("v"))

    relative_cosines = []
    for hvo_s in hvo_seqs:
        A_Here_Cosine = (1 - cosine_sim_hvo(hvo_s.get("v"), reference_v)) / AB_Cosine
        # normalize
        relative_cosines.append(np.round(A_Here_Cosine.item(), 2))

    # Labels are generated for any number of steps; the former fixed 12-entry list raised
    # IndexError as soon as more than 12 sequences were passed.
    interp_labels = [
        (r"$I_%d$" if step < 10 else r"$I_{%d}$") % step
        for step in range(len(hvo_seqs))
    ]
    interp_labels[0] = "A"
    interp_labels[-1] = "B"

    interp_factors = np.linspace(0, 1, len(hvo_seqs))

    return interp_factors, relative_cosines, interp_labels

def calculate_vel_MSE_distance(hvo_seqs):
    """
    Summed squared velocity error of every step to A, normalised by the A-to-B error.

    Args:
        hvo_seqs: decoded sequences ordered A, intermediates, B; each must expose
            `.get("v")` returning something `torch.tensor` accepts.

    Returns:
        `(interp_factors, relative_MSEs, interp_labels)`: evenly spaced factors in [0, 1],
        the normalised errors (rounded to two decimals) and matching tick labels
        ("A", "$I_1$", ..., "B").

    Note:
        There is no guard for an A-to-B error of 0; the division is performed as is.
    """
    squared_error = torch.nn.MSELoss(reduction='none')
    reference_v = torch.tensor(hvo_seqs[0].get("v"))

    AB_MSE = squared_error(reference_v, torch.tensor(hvo_seqs[-1].get("v"))).sum()

    relative_MSEs = []
    for hvo_s in hvo_seqs:
        A_Here_MSE = squared_error(torch.tensor(hvo_s.get("v")), reference_v).sum() / AB_MSE
        # normalize
        relative_MSEs.append(np.round(A_Here_MSE.item(), 2))

    # Labels are generated for any number of steps; the former fixed 12-entry list raised
    # IndexError as soon as more than 12 sequences were passed.
    interp_labels = [
        (r"$I_%d$" if step < 10 else r"$I_{%d}$") % step
        for step in range(len(hvo_seqs))
    ]
    interp_labels[0] = "A"
    interp_labels[-1] = "B"

    interp_factors = np.linspace(0, 1, len(hvo_seqs))

    return interp_factors, relative_MSEs, interp_labels

def plot_distance(interp_factors, relative_distances, interp_labels):
    """
    Plot the observed relative distances against the ideal diagonal in a new 2x2 figure.

    The region between the diagonal (y = interp_factors) and the observed curve is shaded,
    the x ticks carry `interp_labels` and the y ticks show the factors rounded to two
    decimals.
    """
    plt.figure(figsize=(2, 2))

    # observed curve
    plt.plot(interp_factors, relative_distances, 'o-')

    # dashed reference line from the first factor at 0 to the last factor at 1
    diagonal_x = [interp_factors[0], interp_factors[-1]]
    plt.plot(diagonal_x, [0, 1], 'k--')

    # shade the deviation from the diagonal
    plt.fill_between(interp_factors, interp_factors, relative_distances, alpha=0.1)

    # labelled x ticks, numeric y ticks
    plt.xticks(interp_factors, interp_labels)
    y_tick_positions = [*interp_factors, 1]
    plt.yticks(y_tick_positions, np.round(y_tick_positions, 2))

    # keep the x tick labels horizontal
    plt.xticks(rotation=0)
    
def calculate_deviation_area(interp_factors, vals, normalize=False):
    """
    Area under the ideal straight line (0 at the first factor, 1 at the last) minus the
    area under the observed curve `vals`, both integrated over `interp_factors`.

    With `normalize=True` the difference is divided by the ideal area.
    """
    def _trapezoid_area(xs, ys):
        # Sum of the trapezoids spanned by consecutive points.
        return sum(
            (xs[k] - xs[k - 1]) * (ys[k] + ys[k - 1]) / 2
            for k in range(1, len(xs))
        )

    ideal_area = _trapezoid_area([interp_factors[0], interp_factors[-1]], [0, 1])
    deviation = ideal_area - _trapezoid_area(interp_factors, vals)

    if not normalize:
        return deviation
    return deviation / ideal_area

def calculate_total_error(vals, normalize=True):
    """
    Summed absolute deviation of `vals` from evenly spaced targets between 0 and 1.

    With `normalize=True` the sum is divided by the worst-case error, i.e. the error
    obtained when every value equals 1.
    """
    ideal = np.linspace(0, 1, len(vals))
    observed_error = np.sum(np.abs(ideal - vals))

    if not normalize:
        return observed_error

    # worst case scenario --> all values are 1
    worst_case_error = np.sum(np.abs(ideal - 1))
    return observed_error / worst_case_error




In [6]:
import IPython.display
# heatmap of the genre probabilities
import seaborn as sns
import pandas as pd

# Number of intermediate interpolation steps between A and B (also read by later cells).
n_interp = 10

# Interpolate between two randomly drawn test samples and classify every decoded step.
hvo_seq_all, hvo_seqs, genre_probs, genre_preds = interp_between_two_rand_samples(model_MuteGenreLatentVAE_0_5, n_interp, int(np.random.randint(0, len(dataset))), int(np.random.randint(0, len(dataset))), genre_ix=None, genre_classifier_model=model_classifier)

# mutes1 = torch.tensor([1, 1, 0, 0, 0])
# mutes2 = torch.tensor([0, 0, 0, 0, 0])
# hvo_seq_all, hvo_seqs, genre_probs, genre_preds = interp_between_two_rand_samples_with_controls(model_MuteVAE_0_5, n_interp, int(np.random.randint(0, len(dataset))), int(np.random.randint(0, len(dataset))), genre_classifier_model=model_classifier, mutes1=None, mutes2=None)

interp_factors, relative_jaccards, interp_labels = calculate_relative_jackards(hvo_seqs)
interp_factors_cosine, relative_cosines, interp_labels_cosine = calculate_relative_cosine_distance(hvo_seqs)
interp_factors_MSE, relative_MSEs, interp_labels_MSE = calculate_vel_MSE_distance(hvo_seqs)

# calculate_relative_jackards returns (None, None, None) when A and B share the same hits;
# the Jaccard summaries are then reported as None instead of crashing the cell.
jaccard_available = relative_jaccards is not None

# calculate_deviation_area takes (interp_factors, vals); the two arguments used to be
# passed in the opposite order, which integrated the inverse curve.
print(
    calculate_deviation_area(interp_factors, relative_jaccards, normalize=True) if jaccard_available else None,
    calculate_deviation_area(interp_factors_cosine, relative_cosines, normalize=True),
    calculate_deviation_area(interp_factors_MSE, relative_MSEs, normalize=True),
)
print(
    calculate_total_error(relative_jaccards) if jaccard_available else None,
    calculate_total_error(relative_cosines),
    calculate_total_error(relative_MSEs),
)

genre_probs = np.array(genre_probs)
genre_probs = genre_probs.squeeze()

genre_probs_df = pd.DataFrame(genre_probs)

# axis labels
genre_probs_df.columns = dataset.genre_tags

# interp labels (the MSE labels are generated the same way as the Jaccard ones but are
# always available, even when the Jaccard result is None)
genre_probs_df.index = interp_labels_MSE


# transpose
genre_probs_df = genre_probs_df.T
sns.heatmap(genre_probs_df, annot=True, fmt=".1f", cmap="YlGnBu")

# plot
pr = hvo_seq_all.piano_roll(width=1400, height=300)
# convert bokeh figure to html
html = file_html(pr, CDN, "my plot")
audio = hvo_seq_all.synthesize(sf_path="hvo_sequence/soundfonts/Standard_Drum_Kit.sf2")
IPython.display.display(ipd.Audio(audio, rate=44100))
# show html
IPython.display.display(IPython.display.HTML(html))
# plot_distance(interp_factors, relative_jaccards, interp_labels)
# plot_distance(interp_factors_cosine, relative_cosines, interp_labels_cosine)
plot_distance(interp_factors_MSE, relative_MSEs, interp_labels_MSE)



In [7]:
# Collect, for each genre id (0-8), the dataset indices of every sample with that genre target.
genres = dataset.genre_targets.numpy()

# Kept as a plain list of index arrays: genres need not have the same number of samples,
# and np.array() on ragged input raises ValueError on NumPy >= 1.24. Indexing with
# genre_indices[genre] works the same way on the list.
genre_indices = [np.where(genres == genre)[0] for genre in range(9)]


In [8]:
from itertools import product

# Maps each genre id (0-8) to a list of (index_a, index_b) sample pairs from that genre.
genre_index_pairs = {i : [] for i in range(9)}

# sample non-repeating pairs per genre
for genre in range(9):
    # remove equal pairs
    # NOTE(review): this materialises every ordered pair of the genre's indices before
    # sampling, so memory grows quadratically with the number of samples per genre.
    pairs = [p for p in list(product(genre_indices[genre], genre_indices[genre])) if p[0] != p[1]]
    # shuffle and keep the first 300 (fewer if the genre has fewer pairs)
    np.random.shuffle(pairs)
    genre_index_pairs[genre] = pairs[:300]

# Displays the number of genres (dictionary keys), not the number of pairs.
len(genre_index_pairs)

In [9]:
# Genre id and position within genre_index_pairs[genre_ix] of the single pair inspected in this cell.
genre_ix = 0
pair_ix = 95

def flatten_hvoseqs(hvo_seqs):
    """
    Return copies of the given sequences whose `hvo` is replaced by the result of calling
    `flatten_voices()` on the corresponding original; the inputs are not modified.
    """
    def _flattened_copy(original):
        duplicate = original.copy()
        duplicate.hvo = original.flatten_voices()
        return duplicate

    return [_flattened_copy(seq) for seq in hvo_seqs]

# Interpolate between the selected same-genre pair, conditioning the model on that genre.
# `n_interp` is the value set in the earlier single-interpolation cell.
hvo_seq_all, hvo_seqs, genre_probs, genre_preds = interp_between_two_rand_samples(model_MuteGenreLatentVAE_0_5, n_interp, int(genre_index_pairs[genre_ix][pair_ix][0]), int(genre_index_pairs[genre_ix][pair_ix][1]), genre_ix=genre_ix, genre_classifier_model=model_classifier)
flat_hvo_seqs = flatten_hvoseqs(hvo_seqs)

# Jaccard is computed on the voice-flattened sequences; the velocity metrics on the originals.
interp_factors, relative_jaccards, interp_labels = calculate_relative_jackards(flat_hvo_seqs)
interp_factors_cosine, relative_cosines, interp_labels_cosine = calculate_relative_cosine_distance(hvo_seqs)
interp_factors_MSE, relative_MSEs, interp_labels_MSE = calculate_vel_MSE_distance(hvo_seqs)



In [10]:
from tqdm import tqdm

def calculate_all_distances(model_, n_interp, genre_index_pairs, rhythm_flat=False, n_pairs_per_genre=200):
    """
    Collect normalised Jaccard and velocity-MSE distances per interpolation step over many pairs.

    For each of the 9 genres, up to `n_pairs_per_genre` pairs from `genre_index_pairs` are
    interpolated with `model_`. Pairs whose endpoints decode to identical hit patterns
    (relative Jaccard undefined) are skipped.

    Args:
        model_: model passed on to `interp_between_two_rand_samples`.
        n_interp: number of intermediate interpolation steps per pair.
        genre_index_pairs: mapping genre id -> list of (index_a, index_b) dataset index pairs.
        rhythm_flat: forwarded to `calculate_relative_jackards`.
        n_pairs_per_genre: maximum number of pairs evaluated per genre. Genres with fewer
            pairs are evaluated in full instead of raising IndexError.

    Returns:
        Tuple with one array per intermediate step I_1 ... I_n_interp (ten arrays for
        `n_interp=10`). Each array holds one row `[relative_jaccard, None, relative_mse]`
        per evaluated pair; the middle slot is a placeholder for the unused cosine metric.
    """
    rows_per_step = [[] for _ in range(n_interp)]

    for genre_ix in tqdm(range(9)):
        pairs = genre_index_pairs[genre_ix]
        for pair_ix in range(min(n_pairs_per_genre, len(pairs))):
            hvo_seq_all, hvo_seqs, genre_probs, genre_preds = interp_between_two_rand_samples(model_, n_interp, int(pairs[pair_ix][0]), int(pairs[pair_ix][1]), genre_ix=genre_ix, genre_classifier_model=model_classifier)
            flat_hvo_seqs = flatten_hvoseqs(hvo_seqs)
            interp_factors, relative_jaccards, interp_labels = calculate_relative_jackards(flat_hvo_seqs, rhythm_flat=rhythm_flat)
            if relative_jaccards is None:
                continue

            interp_factors_MSE, relative_MSEs, interp_labels_MSE = calculate_vel_MSE_distance(hvo_seqs)
            # Index 0 is A and the last index is B; only the intermediate steps are stored.
            for step in range(1, n_interp + 1):
                rows_per_step[step - 1].append([relative_jaccards[step], None, relative_MSEs[step]])

    return tuple(np.array(rows) for rows in rows_per_step)


In [13]:

from tqdm import tqdm
import matplotlib.pyplot as plt


def interpolation_plot(model_name, beta, axs):
    """
    Draw per-step box plots of the normalised Jaccard distance (on `axs[0]`) and the
    normalised velocity MSE (on `axs[1]`) over all pairs in `genre_index_pairs`.

    Uses the notebook-level `n_interp` and `genre_index_pairs`. Tick positions and labels
    are derived from the number of returned steps, so the function no longer depends on
    `interp_factors`, `interp_factors_MSE` or `interp_labels` left behind by an earlier cell.

    Args:
        model_name: requested model family; currently ignored (see `get_model`).
        beta: requested beta value; currently ignored (see `get_model`).
        axs: sequence of at least two matplotlib axes.

    Returns:
        `(d_jac, pos_jac, d_mse, pos_mse)`: the per-step value arrays and their x positions
        for the Jaccard and the velocity-MSE panel respectively.
    """
    def get_model(model_name, beta):
        # Only the MuteGenreLatentVAE checkpoint with beta 0.5 is loaded in this notebook,
        # so it is returned for every (model_name, beta). Restore a name/beta lookup here
        # once the other checkpoints (Base, Mute, MuteGenre1, MuteGenre2 at beta
        # 0.2 / 0.5 / 1.0) are loaded again.
        return model_MuteGenreLatentVAE_0_5

    steps_all = calculate_all_distances(get_model(model_name, beta), n_interp, genre_index_pairs, rhythm_flat=True)

    # x positions: A at 0, B at 1, intermediate steps evenly spaced in between
    n_points = len(steps_all) + 2
    factors = np.linspace(0, 1, n_points)
    tick_labels = [
        (r"$I_%d$" if step < 10 else r"$I_{%d}$") % step
        for step in range(n_points)
    ]
    tick_labels[0] = "A"
    tick_labels[-1] = "B"

    for feature_ix in [0, 2]:
        # column 0 of each row is the Jaccard distance, column 2 the velocity MSE
        ax = axs[0] if feature_ix == 0 else axs[1]

        # boxplot for each step
        boxprops = dict(linestyle='--', linewidth=0.5, color='darkgoldenrod')
        whiskerprops = dict(linestyle='-',linewidth=1.0, color='darkgoldenrod', alpha=0.5)

        step_data = [np.array([x[feature_ix] for x in step_all]) for step_all in steps_all]
        positions = factors[1:-1]

        if feature_ix == 0:
            d_jac, pos_jac = step_data, positions
        else:
            d_mse, pos_mse = step_data, positions

        ax.boxplot(step_data, positions=positions, showfliers=False, widths=0.03, boxprops=boxprops, whiskerprops=whiskerprops)

        # draw a dashed line (45 degree line)
        ax.plot([factors[0], factors[-1]], [0, 1], 'k--', alpha=0.5, label="Perfect Linear Interpolation")

        ax.set_xlim(-0.2, 1.2)
        ax.set_ylim(-0.2, 1.2)
        ax.set_xticks(factors, tick_labels, fontsize=10, rotation=0)
        ax.set_yticks([*factors, 1], np.round([*factors, 1], 2), fontsize=10)
        # remove grid
        ax.grid(False)

        # draw a line through the medians of the boxplots, pinned to 0 at A and 1 at B
        ax.plot(factors, np.median([np.zeros_like(step_data[0]), *step_data, np.ones_like(step_data[-1])], axis=-1),
                 color="darkgoldenrod", label="Observed Interpolation", alpha=0.5)

        # change label and title font size
        if feature_ix == 0:
            ax.set_title("Normalized Jaccard Distance", fontsize=12)
        else:
            ax.set_title("Normalized Velocity MSE", fontsize=12)

    # As before, the legend is only added to the last panel drawn.
    ax.legend(fontsize=10, loc="lower right")

    return d_jac, pos_jac, d_mse, pos_mse
# NOTE(review): this cell only defines `interpolation_plot` and creates no figure itself,
# so this call looks redundant -- confirm before removing.
plt.show()

In [14]:

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4.5), sharey=False)
d_jac, pos_jac, d_mse, pos_mse = interpolation_plot(model_name="MuteGenre1", beta=0.5, axs=axs)

from scipy import stats


def _pooled_spearman(step_values, step_positions):
    """
    Spearman rank correlation between interpolation position and observed distance,
    pooled over all steps (every value is paired with the position of its step).
    """
    x = np.array([step_positions[ix] for ix, set_ in enumerate(step_values) for _ in set_])
    y = np.array([val for set_ in step_values for val in set_])
    return stats.spearmanr(x, y).statistic


spearman_correlation_jac = _pooled_spearman(d_jac, pos_jac)
spearman_correlation_mse = _pooled_spearman(d_mse, pos_mse)

print(f"Spearmans Correlation Jaccard: {np.round(spearman_correlation_jac, 2)}")
# This line used to be labelled "Jaccard" as well although it reports the velocity-MSE value.
print(f"Spearmans Correlation Velocity MSE: {np.round(spearman_correlation_mse,2)}")
# remove y axis tick labels
axs[1].set_yticklabels([])
# make sure the output directory exists before writing the figure
os.makedirs("./results", exist_ok=True)
fig.savefig("./results/interp_plots_with_spears.png", dpi=500)



In [18]:
# make the gap between the axes smaller
# (operates on the `fig` object created by the preceding plotting cell)
fig.subplots_adjust(wspace=0)
# Writes to the same path as the preceding cell, replacing that image with the tightened layout.
fig.savefig("./results/interp_plots_with_spears.png", dpi=500)

In [None]:
# Scratch cell: evaluates the literal 0 and has no effect on the analysis.
0