In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
# Imports and global session setup for the whole notebook.
import gin

# Put gin in interactive mode before any config is parsed.
# (presumably so configurables can be re-registered on autoreload — confirm)
gin.enter_interactive_mode()

from IPython.display import display, Audio
import torch
import numpy as np
import librosa
from tqdm import tqdm

import sys

# Make the parent directory importable.
sys.path.append('..')

import os

# NOTE(review): absolute, machine-specific working directory; every relative
# path used further down is resolved against it.
os.chdir("/data/nils/repos/AFTER")
# Inference-only notebook: autograd is switched off globally.
torch.set_grad_enabled(False)

# NOTE(review): `display` and `Audio` are already imported above.
from IPython.display import display, Audio
import ipywidgets as widgets
from IPython.display import HTML
from copy import copy
from music2latent import EncoderDecoder
import matplotlib.pyplot as plt

### Params

### Load model

In [3]:
def load_model(name, autoencoder_name, step, device, folder="./after_runs/"):
    """Build a RectifiedFlow model from a training run and attach its codec.

    Args:
        name: Run directory name inside ``folder``.
        autoencoder_name: ``"music2latent"`` to use ``EncoderDecoder``;
            otherwise the file name of a TorchScript export located under
            ``../AFTER/pretrained/``.
        step: Training step of the EMA checkpoint to load
            (``checkpoint<step>_EMA.pt``).
        device: Device the model is instantiated on and moved to.
        folder: Directory containing the run directories.

    Returns:
        Tuple ``(blender, SR, N_SIGNAL)``: the model in eval mode with
        ``emb_model`` set, plus the ``%SR`` and ``%N_SIGNAL`` values of the
        run's gin config (``N_SIGNAL`` falls back to 64 when not defined).
    """
    from after.diffusion import RectifiedFlow

    # os.path.join tolerates a `folder` given with or without trailing slash.
    run_dir = os.path.join(folder, name)
    autoencoder_path = os.path.join("../AFTER/pretrained", autoencoder_name)
    checkpoint_path = os.path.join(run_dir, "checkpoint" + str(step) + "_EMA.pt")
    config = os.path.join(run_dir, "config.gin")

    # Parse config
    gin.parse_config_file(config)
    SR = gin.query_parameter("%SR")

    try:
        N_SIGNAL = gin.query_parameter("%N_SIGNAL")
    except Exception:
        # Configs without the macro fall back to 64. `Exception` rather than
        # a bare `except` so KeyboardInterrupt / SystemExit still propagate.
        N_SIGNAL = 64

    print(N_SIGNAL)

    # Instantiate model (configured through the gin file parsed above)
    blender = RectifiedFlow(device=device)

    # Load checkpoint on CPU first. strict=False: key mismatches between the
    # checkpoint and the current model definition are ignored.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state"]
    blender.load_state_dict(state_dict, strict=False)

    # Emb model: audio <-> latent codec
    if autoencoder_name == "music2latent":
        emb_model = EncoderDecoder(device=device)
    else:
        emb_model = torch.jit.load(autoencoder_path).eval()
    blender.emb_model = emb_model

    # Send to device
    blender = blender.eval().to(device)
    return blender, SR, N_SIGNAL


import torchaudio


def load_and_encode_audio(filename, blender, SR):
    """Load an audio file, normalise it and encode it with ``blender.emb_model``.

    Relies on the module-level ``device`` and ``N_SIGNAL`` variables.

    Args:
        filename: Path of the audio file to read with torchaudio.
        blender: Model whose ``emb_model.encode`` produces the latents.
        SR: Target sample rate; the audio is resampled when it differs.

    Returns:
        Tuple ``(audio, z)``: ``audio`` is a numpy waveform cropped to
        ``N_SIGNAL * ae_ratio`` samples, ``z`` the latent tensor cropped to
        ``N_SIGNAL`` frames along its last axis.
    """
    # Load audio
    audio, sr = torchaudio.load(filename)

    # Peak-normalise to 0.8. An all-zero file is left untouched, since the
    # division would otherwise yield 0/0 = NaN.
    peak = audio.abs().max()
    if peak > 0:
        audio = audio / peak * 0.8

    if sr != SR:
        audio = torchaudio.transforms.Resample(orig_freq=sr,
                                               new_freq=SR)(audio)
    # NOTE(review): reshape(1, -1) lays the channels of a multi-channel file
    # end to end instead of mixing them down — confirm inputs are mono.
    audio = audio.reshape(1, -1).float().to(device)

    # Encode audio
    z = blender.emb_model.encode(audio).float()
    # Number of audio samples per latent frame.
    ae_ratio = audio.shape[-1] // z.shape[-1]

    # Loop short clips until they cover the requested length. A single
    # doubling is not enough for clips shorter than half the target.
    target_len = N_SIGNAL * ae_ratio
    while audio.shape[-1] < target_len:
        audio = torch.cat((audio, audio), -1)
        z = torch.cat((z, z), -1)
    z = z[..., :N_SIGNAL]
    audio = audio[..., :target_len]

    audio = audio.squeeze().cpu().numpy()

    return audio, z


def sample(x0,
           time_cond,
           cond,
           nb_steps=30,
           return_z=False,
           total_guidance=1.,
           guidance_cond_factor=1.,
           guidance_joint_factor=.3,
           model=None):
    """Run the model's sampler and optionally decode the result to audio.

    Args:
        x0: Starting tensor passed to ``sample`` (callers pass Gaussian noise).
        time_cond: Tensor forwarded as ``time_cond``.
        cond: Tensor forwarded as ``cond``.
        nb_steps: Number of sampling steps.
        return_z: When True, return the sampled latent instead of audio.
        total_guidance: Forwarded to the sampler.
        guidance_cond_factor: Forwarded to the sampler (previously fixed to 1).
        guidance_joint_factor: Forwarded to the sampler (previously fixed
            to 0.3).
        model: Model to sample from. Defaults to the module-level ``blender``,
            which is what the function always used before this parameter
            existed.

    Returns:
        The sampled latent tensor if ``return_z`` is True, otherwise the
        decoded audio as a squeezed numpy array.
    """
    if model is None:
        # Backward-compatible default: whichever model `blender` is bound to.
        model = blender

    # Inputs are moved to the module-level `device`.
    x0, time_cond, cond = x0.to(device), time_cond.to(device), cond.to(device)

    x_rec = model.sample(
        x0,
        time_cond=time_cond,
        cond=cond,
        nb_steps=nb_steps,
        guidance_cond_factor=guidance_cond_factor,
        guidance_joint_factor=guidance_joint_factor,
        total_guidance=total_guidance,
    )

    if return_z:
        return x_rec

    audio_rec = model.emb_model.decode(x_rec).cpu().numpy().squeeze()
    return audio_rec


In [4]:
# Runtime settings shared by every model loaded below.
device = "cuda:0"
autoencoder_name = "acapellav5.ts"

# (run name, checkpoint step) pairs to compare: here the same run at two
# different training steps.
checkpoints = [
    ("choir_newcycle_noadv2_margin_interpolant", 1500000),
    ("choir_newcycle_noadv2_margin_interpolant", 850000),
]
names = [run_name for run_name, _ in checkpoints]
steps = [run_step for _, run_step in checkpoints]

In [5]:
# Load one model per (run name, checkpoint step) pair.
# `blender`, `SR` and `N_SIGNAL` are rebound on every iteration, so after the
# loop they hold the values of the last checkpoint loaded; later cells read
# these leftover globals.
blenders = []
for name, step in zip(names, steps):
    blender, SR, N_SIGNAL = load_model(
        name,
        autoencoder_name,
        step,
        device,
        folder="/data/nils/repos/AFTER/after_runs/")
    blenders.append(blender)

In [6]:
from after.dataset import CombinedDataset

# Root folder and the sub-datasets to combine.
main_folder = "/data/nils/datasets/acapella/raw"
folders = ["homemade", "chorals", "holly_albums"]

# Extra keys requested from the dataset next to the plain "z" entry.
augmentation_keys = [
    "z_shift_stretch0",
    "z_shift_stretch1",
    "z_shift_stretch2",
    "z_shift_stretch3",
]

# One {"path": ...} entry per sub-dataset, each pointing at its
# "ae_44k_augs" directory.
dataset_paths = {
    subset: {"path": os.path.join(main_folder, subset, "ae_44k_augs")}
    for subset in folders
}

dataset = CombinedDataset(
    freqs="estimate",
    keys=["z", *augmentation_keys],
    init_cache=False,
    num_samples=None,
    path_dict=dataset_paths,
)

homemade  :  17188 examples 
chorals  :  7128 examples 
holly_albums  :  7112 examples 


In [7]:
# Source example: latent cropped to N_SIGNAL frames, decoded for listening.
# `blender` is whatever model the name is currently bound to (the codec is
# only used for decoding here).
idx = 19909
data = dataset[idx]
z = torch.from_numpy(data["z"]).unsqueeze(0).to(device)

print("source")
z = z[..., :N_SIGNAL]
audio_in = blender.emb_model.decode(z).cpu().numpy().squeeze()
display(Audio(audio_in, rate=SR))

# Second example, used as the timbre target.
idx = 78
data = dataset[idx]
z2 = torch.from_numpy(data["z"]).unsqueeze(0).to(device)

print("timbre target")
z2 = z2[..., :N_SIGNAL]
audio_in = blender.emb_model.decode(z2).cpu().numpy().squeeze()
display(Audio(audio_in, rate=SR))

total_guidance = 2.
nb_steps = 10

# Each experiment is (label, latent fed to `encoder`, latent fed to
# `encoder_time`). The loop variable must stay named `blender`, because
# `sample` reads the module-level `blender`.
experiments = (
    ("reconstruction", z, z),
    ("transfer", z2, z),
    ("inverse transfer", z, z2),
)

for label, z_for_cond, z_for_time in experiments:
    print(label)
    for blender in blenders:
        zsem = blender.encoder(z_for_cond)
        time_cond = blender.encoder_time(z_for_time)
        z0 = torch.randn_like(z)
        z_out = sample(z0,
                       time_cond,
                       zsem,
                       nb_steps=nb_steps,
                       return_z=True,
                       total_guidance=total_guidance)

        # Re-encode the generated latent and report how far its embedding is
        # from the one that was requested.
        zsem_rec = blender.encoder(z_out)
        print(torch.nn.functional.mse_loss(zsem, zsem_rec))
        audio_out = blender.emb_model.decode(z_out).cpu().numpy().squeeze()
        display(Audio(audio_out, rate=SR))

source


timbre target


reconstruction
tensor(0.0112, device='cuda:0')


tensor(0.0151, device='cuda:0')


transfer
tensor(0.0275, device='cuda:0')


tensor(0.0188, device='cuda:0')


inverse transfer
tensor(0.0580, device='cuda:0')


tensor(0.0391, device='cuda:0')


In [None]:
# Ablation: integrate the flow with Euler steps, but replace `cond` by the
# model's drop value for every step whose time exceeds `tstop`. Tracks how
# well the encoder recovers `cond` from the one-step estimate of the final
# sample. `time_cond` is kept for the whole trajectory.
blender = blenders[0]
total_guidance = 1
guidance_cond_factor = 1
guidance_joint_factor = 1
nb_steps = 10
tstop_values = [0.2, 0.5, 0.7, 0.95]

z0 = torch.randn_like(z)
cond = blender.encoder(z2)
time_cond = blender.encoder_time(z)

dt = 1 / nb_steps
# Step start times 0, dt, ..., 1 - dt (the end point 1 is excluded).
t_values = torch.linspace(0, 1, nb_steps + 1).to(blender.device)[:-1]
cond_error_final_list = []

for tstop in tstop_values:
    x = z0.to(blender.device)
    cond_error_list = []
    for t_scalar in t_values:
        # Compare on the scalar time: .item() on the batch-shaped tensor
        # below would fail for batch sizes other than 1.
        if t_scalar.item() > tstop:
            cond_cur = torch.ones_like(cond) * blender.drop_value
        else:
            cond_cur = cond.clone()
        time_cond_cur = time_cond.clone()

        t = t_scalar.reshape(1, 1, 1).repeat(x.shape[0], 1, 1)
        prediction = blender.model_forward(x, t, cond_cur, time_cond_cur,
                                           guidance_cond_factor,
                                           guidance_joint_factor,
                                           total_guidance)

        # Extrapolate the current velocity to t = 1, then take the Euler step.
        x_estimate = x + (1 - t) * prediction
        x = x + prediction * dt

        pred_cond = blender.encoder(x_estimate)
        cond_error = torch.nn.functional.mse_loss(pred_cond, cond)
        cond_error_list.append(cond_error.item())
        print("cond error", cond_error.item())

    plt.plot(t_values.cpu(), cond_error_list, label=f"tstop = {tstop}")

    # Error of the embedding re-encoded from the final sample.
    final_error = torch.nn.functional.mse_loss(blender.encoder(x), cond)
    cond_error_final_list.append(final_error.item())

    print("Cond error finale", final_error)
    print(f"tstop = {tstop}")
    audio_out = blender.emb_model.decode(x).cpu().numpy().squeeze()
    display(Audio(audio_out, rate=SR))
plt.legend()
plt.show()

plt.plot(tstop_values, cond_error_final_list)


In [None]:
# Variant of the cut-off ablation: same Euler integration with `cond` dropped
# after `tstop`, but the per-step curve shows the MSE between the prediction
# actually used and a prediction made with `cond` replaced by a constant.
blender = blenders[0]
total_guidance = 1
guidance_cond_factor = 1
guidance_joint_factor = 1
nb_steps = 10
tstop_values = [0.2, 0.5, 0.7, 0.95]
# NOTE(review): the reference "unconditional" prediction uses a literal -4,
# whereas the cut-off branch uses blender.drop_value — confirm they agree.
uncond_fill_value = -4

z0 = torch.randn_like(z)
cond = blender.encoder(z2)
time_cond = blender.encoder_time(z)

dt = 1 / nb_steps
# Step start times 0, dt, ..., 1 - dt (the end point 1 is excluded).
t_values = torch.linspace(0, 1, nb_steps + 1).to(blender.device)[:-1]
cond_error_final_list = []

for tstop in tstop_values:
    x = z0.to(blender.device)
    guidance_gap_list = []
    for t_scalar in t_values:
        # Compare on the scalar time: .item() on the batch-shaped tensor
        # below would fail for batch sizes other than 1.
        if t_scalar.item() > tstop:
            cond_cur = torch.ones_like(cond) * blender.drop_value
        else:
            cond_cur = cond.clone()
        time_cond_cur = time_cond.clone()

        t = t_scalar.reshape(1, 1, 1).repeat(x.shape[0], 1, 1)
        prediction = blender.model_forward(x, t, cond_cur, time_cond_cur,
                                           guidance_cond_factor,
                                           guidance_joint_factor,
                                           total_guidance)

        prediction_uncond = blender.model_forward(
            x, t, uncond_fill_value * torch.ones_like(cond_cur),
            time_cond_cur, guidance_cond_factor, guidance_joint_factor,
            total_guidance)

        # Extrapolate the current velocity to t = 1, then take the Euler step.
        x_estimate = x + (1 - t) * prediction
        x = x + prediction * dt

        pred_cond = blender.encoder(x_estimate)
        cond_error = torch.nn.functional.mse_loss(pred_cond, cond)
        # Plotted quantity: gap between the two predictions, not cond_error.
        guidance_gap_list.append(
            torch.nn.functional.mse_loss(prediction_uncond, prediction).item())
        print("cond error", cond_error.item())

    plt.plot(t_values.cpu(), guidance_gap_list, label=f"tstop = {tstop}")

    # Error of the embedding re-encoded from the final sample.
    final_error = torch.nn.functional.mse_loss(blender.encoder(x), cond)
    cond_error_final_list.append(final_error.item())

    print("Cond error finale", final_error)
    print(f"tstop = {tstop}")
    audio_out = blender.emb_model.decode(x).cpu().numpy().squeeze()
    display(Audio(audio_out, rate=SR))
plt.legend()
plt.show()

plt.plot(tstop_values, cond_error_final_list)


In [None]:
# Scratch check: element-wise torch.maximum between an integer tensor and a
# 0-dim float tensor. Not used by any other cell.
a = torch.tensor([0, 1, 2])
torch.maximum(a, torch.tensor(1.))

In [None]:
# Value range and spread of the two latents currently bound to z and z2.
for latent in (z, z2):
    print(latent.max(), latent.min(), latent.std())

### Stream

In [8]:
def sample_step(model,
                t_starts,
                t_ends,
                z_start,
                cond,
                time_cond=None,
                nb_steps=10,
                return_trajectory=False,
                total_guidance=1.,
                guidance_joint_factor=1.,
                guidance_cond_factor=1.):
    """Euler-integrate ``model.model_forward`` with a per-position time.

    Every position along the last axis moves from its own start time
    ``t_starts[k]`` to its own end time ``t_ends[k]`` in ``nb_steps`` equal
    sub-steps.

    Args:
        model: Object exposing ``device`` and ``model_forward``.
        t_starts: 1-D tensor of start times, one per position.
        t_ends: 1-D tensor of end times, same length as ``t_starts``.
        z_start: Initial state; first axis is the batch, last axis must match
            the length of ``t_starts``.
        cond: Forwarded to ``model_forward`` as ``cond``.
        time_cond: Forwarded to ``model_forward`` as ``time_cond``.
        nb_steps: Number of Euler sub-steps.
        return_trajectory: When True, also return all intermediate states.
        total_guidance: Forwarded to ``model_forward``.
        guidance_joint_factor: Forwarded to ``model_forward``.
        guidance_cond_factor: Forwarded to ``model_forward``.

    Returns:
        The final state, or ``(final_state, trajectory)`` when
        ``return_trajectory`` is True; ``trajectory`` holds the CPU copies of
        the state after each sub-step stacked along a new last axis.
    """
    t_starts = t_starts.to(model.device)
    t_ends = t_ends.to(model.device)

    # Row s of `t_values` holds the per-position time at sub-step boundary s.
    alpha = torch.linspace(0, 1, nb_steps + 1).to(model.device)
    t_values = (1 - alpha[:, None]) * t_starts + alpha[:, None] * t_ends

    # Intermediate states are only copied to CPU when they are requested.
    trajectory = [] if return_trajectory else None

    x = z_start
    for t_cur, t_next in zip(t_values[:-1], t_values[1:]):
        # Repeat the time vectors over the batch axis.
        t_cur = t_cur.reshape(1, -1).repeat(x.shape[0], 1)
        t_next = t_next.reshape(1, -1).repeat(x.shape[0], 1)

        dt = t_next - t_cur

        velocity = model.model_forward(
            x,
            time=t_cur,
            cond=cond,
            time_cond=time_cond,
            total_guidance=total_guidance,
            guidance_joint_factor=guidance_joint_factor,
            guidance_cond_factor=guidance_cond_factor)
        # dt gets a singleton middle axis so it broadcasts against x.
        x = x + velocity * dt[:, None, :]

        if trajectory is not None:
            trajectory.append(x.cpu())

    if return_trajectory:
        return x, torch.stack(trajectory, dim=-1)
    return x

In [9]:
# Source example: kept at full length in `z_full` for the streaming cells.
idx = -90
data = dataset[idx]
z = torch.from_numpy(data["z"]).unsqueeze(0).to(device)

print("source")
z_full = z
audio_in = blender.emb_model.decode(z_full).cpu().numpy().squeeze()
display(Audio(audio_in, rate=SR))

# Timbre target: only its first N_SIGNAL frames are used.
idx = 329
data = dataset[idx]
z2 = torch.from_numpy(data["z"]).unsqueeze(0).to(device)

print("timbre target")
z_timbre = z2[..., :N_SIGNAL]
audio_in = blender.emb_model.decode(z_timbre).cpu().numpy().squeeze()
display(Audio(audio_in, rate=SR))

# Transfer: `cond` comes from the target, `time_cond` from the first
# N_SIGNAL frames of the source. Both encoders return a tuple here; only the
# first element is used.
print("transfer")
z_source_crop = z[..., :N_SIGNAL]
zsem, _ = blender.encoder(z_timbre)
time_cond, _ = blender.encoder_time(z_source_crop)
z0 = torch.randn_like(z_timbre)
z_out = sample(z0, time_cond, zsem, nb_steps=20, return_z=True)
audio_out = blender.emb_model.decode(z_out).cpu().numpy().squeeze()
display(Audio(audio_out, rate=SR))


In [10]:
# Streaming generation: a window of `nseq` latent frames is denoised a
# little at every iteration, emits `chunk_size` frames, then shifts left by
# `chunk_size` while fresh noise is appended on the right.
nseq = z_timbre.shape[-1]
noise_start = 32
chunk_size = 8
nb_steps = 5
total_guidance = 1.
guidance_joint_factor = 1.
guidance_cond_factor = 1.
N_STEPS = (z_full.shape[-1] - nseq) // chunk_size

# Per-position times before and after one iteration. The "after" schedule is
# built one chunk longer, starting one chunk later, and cropped to `nseq`.
scheme_kwargs = dict(random=False,
                     random_power=1,
                     chunk_size=chunk_size,
                     randomise_start_value=False,
                     randomise_step_values=False)

t_starts = blender.get_noising_scheme(len_seq=nseq,
                                      noise_start=noise_start,
                                      **scheme_kwargs)
plt.plot(t_starts.squeeze())

t_ends = blender.get_noising_scheme(len_seq=nseq + chunk_size,
                                    noise_start=noise_start + chunk_size,
                                    **scheme_kwargs)[:-chunk_size]
plt.plot(t_ends.squeeze())
plt.show()

z_start = z_timbre
# One block of fresh noise per iteration.
noise_buffer = [
    torch.randn_like(z_start[..., :chunk_size]) for _ in range(N_STEPS)
]

cond, _ = blender.encoder(z_timbre)
time_cond, _ = blender.encoder_time(z_full)

# Initial window: blend of data and noise weighted by the start times
# (weight t on the data, 1 - t on the noise).
t0 = t_starts[None, None, :].to(device)
last_zcur = t0 * z_start + (1 - t0) * torch.randn_like(z_start)

z_rec = []
for i in range(N_STEPS):
    # NOTE(review): indexes time_cond with latent-frame offsets, i.e. assumes
    # encoder_time keeps the temporal resolution of its input — confirm.
    offset = i * chunk_size
    time_cond_cur = time_cond[..., offset:offset + nseq]
    zcur = sample_step(blender,
                       t_starts,
                       t_ends,
                       last_zcur,
                       cond,
                       time_cond=time_cond_cur,
                       nb_steps=nb_steps,
                       return_trajectory=False,
                       total_guidance=total_guidance,
                       guidance_joint_factor=guidance_joint_factor,
                       guidance_cond_factor=guidance_cond_factor)

    # Emit the chunk located right after the first `noise_start` frames.
    z_rec.append(zcur[..., noise_start:noise_start + chunk_size])
    # Restore the first `noise_start` frames to their pre-step values, then
    # slide the window and append new noise.
    zcur[..., :noise_start] = last_zcur[..., :noise_start]
    last_zcur = torch.cat([zcur[..., chunk_size:], noise_buffer[i]], dim=-1)

z_rec = torch.cat(z_rec, dim=-1)
audio_rec = blender.emb_model.decode(z_rec)
# Play back at the configured sample rate, like every other cell.
display(Audio(audio_rec.squeeze().cpu().numpy(), rate=SR))

In [11]:
# Repeat of the streaming experiment (new noise is drawn on every run): a
# window of `nseq` latent frames is denoised a little at every iteration,
# emits `chunk_size` frames, then shifts left by `chunk_size` while fresh
# noise is appended on the right.
nseq = z_timbre.shape[-1]
noise_start = 32
chunk_size = 8
nb_steps = 5
total_guidance = 1.
guidance_joint_factor = 1.
guidance_cond_factor = 1.
N_STEPS = (z_full.shape[-1] - nseq) // chunk_size

# Per-position times before and after one iteration. The "after" schedule is
# built one chunk longer, starting one chunk later, and cropped to `nseq`.
scheme_kwargs = dict(random=False,
                     random_power=1,
                     chunk_size=chunk_size,
                     randomise_start_value=False,
                     randomise_step_values=False)

t_starts = blender.get_noising_scheme(len_seq=nseq,
                                      noise_start=noise_start,
                                      **scheme_kwargs)
plt.plot(t_starts.squeeze())

t_ends = blender.get_noising_scheme(len_seq=nseq + chunk_size,
                                    noise_start=noise_start + chunk_size,
                                    **scheme_kwargs)[:-chunk_size]
plt.plot(t_ends.squeeze())
plt.show()

z_start = z_timbre
# One block of fresh noise per iteration.
noise_buffer = [
    torch.randn_like(z_start[..., :chunk_size]) for _ in range(N_STEPS)
]

cond, _ = blender.encoder(z_timbre)
time_cond, _ = blender.encoder_time(z_full)

# Initial window: blend of data and noise weighted by the start times
# (weight t on the data, 1 - t on the noise).
t0 = t_starts[None, None, :].to(device)
last_zcur = t0 * z_start + (1 - t0) * torch.randn_like(z_start)

z_rec = []
for i in range(N_STEPS):
    # NOTE(review): indexes time_cond with latent-frame offsets, i.e. assumes
    # encoder_time keeps the temporal resolution of its input — confirm.
    offset = i * chunk_size
    time_cond_cur = time_cond[..., offset:offset + nseq]
    zcur = sample_step(blender,
                       t_starts,
                       t_ends,
                       last_zcur,
                       cond,
                       time_cond=time_cond_cur,
                       nb_steps=nb_steps,
                       return_trajectory=False,
                       total_guidance=total_guidance,
                       guidance_joint_factor=guidance_joint_factor,
                       guidance_cond_factor=guidance_cond_factor)

    # Emit the chunk located right after the first `noise_start` frames.
    z_rec.append(zcur[..., noise_start:noise_start + chunk_size])
    # Restore the first `noise_start` frames to their pre-step values, then
    # slide the window and append new noise.
    zcur[..., :noise_start] = last_zcur[..., :noise_start]
    last_zcur = torch.cat([zcur[..., chunk_size:], noise_buffer[i]], dim=-1)

z_rec = torch.cat(z_rec, dim=-1)
audio_rec = blender.emb_model.decode(z_rec)
# Play back at the configured sample rate, like every other cell.
display(Audio(audio_rec.squeeze().cpu().numpy(), rate=SR))

### Load audios

In [12]:
# Pick six files from the synthetic-instrument folder with a fixed seed.
base_folder = "/data/nils/datasets/instruments/syntheticv1/audio/"
base_files = sorted(os.listdir(base_folder))

print(base_files)

np.random.seed(8)
# NOTE(review): np.random.choice samples with replacement by default, so the
# same file can be drawn more than once — confirm this is intended.
base_files = np.random.choice(base_files, 6)
print(base_files)

# Waveform and latent of every selected file.
audios, zlist = [], []
for file in base_files:
    audio, z = load_and_encode_audio(os.path.join(base_folder, file), blender,
                                     SR)
    audios.append(audio)
    zlist.append(z)

# Embedding of every file at once; the mean output is used as `cond` below.
z_cond = torch.cat(zlist, dim=0)
cond, cond_mean, _ = blender.encoder(z_cond, return_mean=True)

# all_transfers[i][j]: time_cond taken from file i, cond taken from file j.
all_transfers = []
for z_time_cond in zlist:
    x0 = torch.randn_like(z_cond)
    time_cond, _ = blender.encoder_time(z_time_cond)
    # Same time_cond for every cond in the batch.
    time_cond = time_cond.repeat((z_cond.shape[0], 1, 1))
    transfers = sample(x0, time_cond, cond_mean, nb_steps=10)
    all_transfers.append(list(transfers))

In [13]:
# Grid of audio players: a header row of file names, a row with the original
# clips, then one row per file holding its name, its original clip and the
# entries of all_transfers for that file.
audio_size = str(150)

audio_style = "<style>audio { margin-left: 0px; width: " + audio_size + "px; }</style>"
display(HTML(audio_style))

blank_out = widgets.Output(layout={'width': '80px'})
lines, text_widgets, true_audios = [], [blank_out,
                                        blank_out], [blank_out, blank_out]

for i, (name, audio) in enumerate(zip(base_files, audios)):
    # Strip the extension whatever its length (was a fixed name[:-4]).
    text = os.path.splitext(name)[0]
    out0 = widgets.Output()
    with out0:
        # The <h2> was previously closed with </h1>.
        display(HTML(f"<h2 style='text-align:center'>{text}</h2>"))
    text_widgets.append(out0)

    out = widgets.Output(
        layout={
            'border': '0.5px solid black',
            'width': f'{audio_size}px',
            'background_color': 'black'
        })
    with out:
        display(Audio(data=audio, rate=SR))
    true_audios.append(out)

    # Row i: label, original clip, then one player per transfer.
    audio_widgets = [out0, out]
    for transfer in all_transfers[i]:
        out = widgets.Output(layout={
            'border': '0.5px solid black',
            'width': f'{audio_size}px',
        })
        with out:
            display(Audio(data=transfer, rate=SR))
        audio_widgets.append(out)

    lines.extend(audio_widgets)

outputs = widgets.GridBox(
    text_widgets + true_audios + lines,
    layout=widgets.Layout(
        grid_template_columns=f"80px repeat({len(audios) + 1}, {audio_size}px)"
    ))

box = widgets.VBox(children=[outputs])
display(box)

#widgets.HBox(audio_widgets)

### Improvise

In [14]:
# For every clip, draw `nimpros` samples with `cond` replaced by the drop
# value, and `nimpros` samples with `time_cond` replaced by the drop value.
# Both sets start from the same noise.
nimpros = 4
all_transfers = []

for z_time_cond in zlist:
    x0 = torch.randn(nimpros,
                     z_time_cond.shape[1],
                     z_time_cond.shape[2],
                     device=device)
    cond, _ = blender.encoder(z_time_cond)
    time_cond, _ = blender.encoder_time(z_time_cond)

    # Repeat both conditionings over the batch of improvisations.
    time_cond = time_cond.repeat((nimpros, 1, 1))
    cond = cond.repeat((nimpros, 1))

    # Only time_cond is given; cond is dropped.
    dropped_cond = blender.drop_value * torch.ones_like(cond)
    cond_impros = sample(x0, time_cond, dropped_cond, nb_steps=20)

    # Only cond is given; time_cond is dropped.
    dropped_time_cond = blender.drop_value * torch.ones_like(time_cond)
    time_cond_impros = sample(x0, dropped_time_cond, cond, nb_steps=20)

    all_transfers.append(list(cond_impros) + list(time_cond_impros))


In [15]:
# Grid of improvisations: a header row with two section titles, then one row
# per file holding its name, its original clip and the 2 * nimpros samples
# stored in all_transfers for that file.
cell_width = 150
# Each section title spans the nimpros players below it.
section_width = nimpros * cell_width

blank_out = widgets.Output(layout={'width': '80px'})
big_blank_out = widgets.Output(layout={'width': '80px'})

# Header row: two blanks over the label / original columns, then the titles.
first_row = [blank_out, big_blank_out]
for title in (" Timbre Improvisation ", "Structure Improvisation"):
    out = widgets.Output(layout={
        'border': '0.5px solid black',
        'width': f'{section_width}px'
    })
    with out:
        display(HTML(f"<h2 style='text-align:center'>{title}</h2>"))
    first_row.append(out)

lines = []
for i, (name, audio) in enumerate(zip(base_files, audios)):
    text = name[:-4]
    out0 = widgets.Output()
    with out0:
        # The <h2> was previously closed with </h1>.
        display(HTML(f"<h2 style='text-align:center'>{text}</h2>"))

    out = widgets.Output(
        layout={
            'border': '0.5px solid black',
            'width': f'{cell_width}px',
            'background_color': 'black'
        })
    with out:
        display(Audio(data=audio, rate=SR))

    # Row i: label, original clip, then one player per improvisation.
    audio_widgets = [out0, out]
    for transfer in all_transfers[i]:
        out = widgets.Output(layout={
            'border': '0.5px solid black',
            'width': f'{cell_width}px',
        })
        with out:
            display(Audio(data=transfer, rate=SR))
        audio_widgets.append(out)

    lines.extend(audio_widgets)

outputs = widgets.GridBox(
    lines,
    layout=widgets.Layout(
        grid_template_columns=
        f"80px repeat({2 * nimpros + 1}, {cell_width}px)"))

top_outputs = widgets.GridBox(
    first_row,
    layout=widgets.Layout(
        grid_template_columns=
        f"80px {cell_width}px {section_width}px {section_width}px"))

box = widgets.VBox(children=[top_outputs, outputs])
display(box)

#widgets.HBox(audio_widgets)

### Conditioning influence

In [16]:
import matplotlib.pyplot as plt

z1 = zlist[5]
z2 = zlist[1]
tlist = np.linspace(0, 1, 20)
plot = False


def smooth_function(x, slope=9, center=0.2):
    """Smooth step from 1 to 0: equals 0.5 at ``center``; ``slope`` sets steepness."""
    return 0.5 * (1 + torch.tanh(slope * (center - x)))


def _titled_audio(title, audio):
    """Return a (title widget, audio-player widget) pair for one grid column."""
    title_out = widgets.Output()
    with title_out:
        # The <h2> was previously closed with </h1>.
        display(HTML(f"<h2 style='text-align:center'>{title}</h2>"))
    player_out = widgets.Output(
        layout={
            'border': '0.5px solid black',
            'width': '150px',
            'background_color': 'black'
        })
    with player_out:
        display(Audio(data=audio, rate=SR))
    return title_out, player_out


curve = smooth_function(torch.tensor(tlist))

# For random pairs of clips, compare the network output with and without
# `cond` along the straight path between noise and z2, and overlay the
# resulting curves on one figure.
for i in range(10):
    idx = np.random.randint(0, len(zlist), 2)
    z1 = zlist[idx[0]]
    z2 = zlist[idx[1]]
    x0 = torch.randn_like(z1)

    # Per-time-step diagnostics; only cond_uncond_dists is plotted below.
    cond_dists_target = []
    out_cond_norm = []
    out_uncond_norm = []
    cond_dists_orig = []
    cond_uncond_dists = []
    mse_dists = []

    cond_target, cond_target_mean, _ = blender.encoder(z2, return_mean=True)
    # NOTE(review): the "orig" embedding is computed from z2 as well, so it
    # matches the target one up to the encoder's sampling — was z1 intended?
    cond_orig, cond_orig_mean, _ = blender.encoder(z2, return_mean=True)
    time_cond, _ = blender.encoder_time(z1)

    audio_source = blender.emb_model.decode(z2).cpu().squeeze()
    audio_transfer = sample(x0, time_cond, cond_target, nb_steps=50)
    audio_target = blender.emb_model.decode(z1).cpu().squeeze()

    texts, audios_widg = [], []

    if plot:
        for title, clip in ((" orignal audio ", audio_source),
                            (" target transfer ", audio_target),
                            (" full transfer audio ", audio_transfer)):
            title_widget, player_widget = _titled_audio(title, clip)
            texts.append(title_widget)
            audios_widg.append(player_widget)

    for j, t in enumerate(tlist):
        # Point at time t on the straight line from the noise x0 to z2.
        interpolant = (1 - t) * x0 + t * z2
        t_tensor = torch.tensor(t).view(1).to(z1)

        model_output_transfer = blender.net(interpolant,
                                            time=t_tensor,
                                            time_cond=time_cond,
                                            cond=cond_target)

        model_output_unconditionnal = blender.net(
            interpolant,
            time=t_tensor,
            time_cond=time_cond,
            cond=blender.drop_value * torch.ones_like(cond_target))

        # One-step extrapolation of the conditional output to t = 1.
        out = interpolant + (1 - t) * model_output_transfer

        # mse_loss and linalg.norm already return scalars here.
        mse_dists.append(torch.nn.functional.mse_loss(out, z2).item())

        cond_rec, cond_rec_mean, _ = blender.encoder(out, return_mean=True)

        cond_dists_target.append(
            torch.nn.functional.mse_loss(cond_target_mean,
                                         cond_rec_mean).item())
        cond_dists_orig.append(
            torch.nn.functional.mse_loss(cond_orig_mean,
                                         cond_rec_mean).item())
        cond_uncond_dists.append(
            torch.nn.functional.mse_loss(model_output_transfer,
                                         model_output_unconditionnal).item())
        out_cond_norm.append(torch.linalg.norm(model_output_transfer).item())
        out_uncond_norm.append(
            torch.linalg.norm(model_output_unconditionnal).item())

        # Every other step, add the decoded estimate to the grid.
        if plot and j % 2 == 0:
            audio = blender.emb_model.decode(out).cpu().squeeze()
            title_widget, player_widget = _titled_audio(
                f"t = {np.round(t, 2)}", audio)
            texts.append(title_widget)
            audios_widg.append(player_widget)

    if plot:
        outputs = widgets.GridBox(
            texts + audios_widg,
            layout=widgets.Layout(
                grid_template_columns=f"repeat({len(audios_widg) }, 150px)"))

        box = widgets.VBox(children=[outputs])
        display(box)

    plt.plot(tlist, cond_uncond_dists, label="mse dist", color="k")

plt.show()


In [17]:
import matplotlib.pyplot as plt

z1 = zlist[5]
z2 = zlist[1]
tlist = np.linspace(0, 1, 20)
plot = True


def smooth_function(x, slope=6, center=0.45):
    """Smooth, decreasing step between 1 and 0 built from a tanh.

    Evaluates ``0.5 * (1 + tanh(slope * (center - x)))``: the result is 0.5 at
    ``x == center``, tends to 1 for ``x`` well below it and to 0 well above it.

    Args:
        x: ``torch.Tensor`` of positions at which to evaluate the curve.
        slope: Steepness of the transition (larger means sharper).
        center: Position of the midpoint. Defaults to 0.45, the value that
            used to be hard-coded.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return 0.5 * (1 + torch.tanh(slope * (center - x)))


# Reference curve drawn (scaled by 0.1) on top of the conditioning distances.
curve = smooth_function(torch.tensor(tlist))


def _titled_player(title, audio_data):
    """Return the (heading, audio player) Output widgets for one grid column."""
    heading = widgets.Output()
    with heading:
        display(HTML(f"<h2 style='text-align:center'>{title}</h2>"))
    player = widgets.Output(
        layout={
            'border': '0.5px solid black',
            'width': '150px',
            'background_color': 'black'
        })
    with player:
        display(Audio(data=audio_data, rate=SR))
    return heading, player


# Five random (target, source) pairs. For each pair the network is probed at
# every t in `tlist` on the noise -> source interpolant, and the one-step
# extrapolation to t = 1 is re-encoded to see which clip's time conditioning
# it is closer to.
for i in range(5):
    idx = np.random.randint(0, len(zlist), 2)
    z1 = zlist[idx[0]]  # target clip
    z2 = zlist[idx[1]]  # source ("original") clip
    x0 = torch.randn_like(z1)

    cond_dists_target = []
    cond_dists_orig = []
    cond_uncond_dists = []
    mse_dists = []  # only used by the optional plot at the end of the loop

    # The "target" encodings come from z1 and the "orig" ones from z2. Both
    # used to be computed from z2, which made the two distance curves below
    # identical by construction.
    _, cond_target, _ = blender.encoder(z1, return_mean=True)
    _, cond_orig, _ = blender.encoder(z2, return_mean=True)

    time_cond_orig, _ = blender.encoder_time(z2)
    time_cond_target, _ = blender.encoder_time(z1)

    audio_source = blender.emb_model.decode(z2).cpu().squeeze()
    # NOTE(review): this call used to pass `time_cond`, a name this cell never
    # assigns (it silently picked up whatever an earlier cell left behind).
    # The source clip's time conditioning is used with the target `cond`;
    # confirm this is the intended pairing for the "full transfer" reference.
    audio_transfer = sample(x0, time_cond_orig, cond_target, nb_steps=50)
    audio_target = blender.emb_model.decode(z1).cpu().squeeze()

    texts, audios_widg = [], []
    for title, clip in ((" orignal audio ", audio_source),
                        (" target transfer ", audio_target),
                        (" full transfer audio ", audio_transfer)):
        heading, player = _titled_player(title, clip)
        texts.append(heading)
        audios_widg.append(player)

    for j, t in enumerate(tlist):
        t_tensor = torch.tensor(t).view(1).to(z1)
        interpolant = (1 - t) * x0 + t * z2

        model_output_transfer = blender.net(interpolant,
                                            time=t_tensor,
                                            time_cond=time_cond_target,
                                            cond=cond_orig)

        # Same query with the time conditioning replaced by the drop value.
        model_output_normal = blender.net(
            interpolant,
            time=t_tensor,
            time_cond=blender.drop_value * torch.ones_like(time_cond_target),
            cond=cond_orig)

        # Extrapolate from time t to t = 1 along the network output.
        z_pred = interpolant + (1 - t) * model_output_transfer

        mse_dists.append(
            torch.nn.functional.mse_loss(z_pred, z2).mean().item())

        _, time_cond_rec, _ = blender.encoder_time(z_pred, return_mean=True)

        cond_dists_target.append(
            torch.nn.functional.mse_loss(time_cond_target,
                                         time_cond_rec).mean().item())
        cond_dists_orig.append(
            torch.nn.functional.mse_loss(time_cond_orig,
                                         time_cond_rec).mean().item())
        cond_uncond_dists.append(
            torch.nn.functional.mse_loss(model_output_transfer,
                                         model_output_normal).mean().item())

        # Render every other step to keep the grid readable.
        if plot and j % 2 == 0:
            audio = blender.emb_model.decode(z_pred).cpu().squeeze()
            heading, player = _titled_player(f"t = {np.round(t, 2)}", audio)
            texts.append(heading)
            audios_widg.append(player)

    if plot:
        # One row of headings followed by one row of players.
        outputs = widgets.GridBox(
            texts + audios_widg,
            layout=widgets.Layout(
                grid_template_columns=f"repeat({len(audios_widg) }, 150px)"))

        box = widgets.VBox(children=[outputs])
        display(box)

    plt.plot(tlist, cond_dists_target, label="cond target", color="g")
    plt.plot(tlist, cond_dists_orig, label="cond orig", color="b")
    plt.plot(tlist, 0.1 * curve, label="curve", color="k", linestyle="--")
    plt.legend()  # the labels above were never shown without this
    plt.show()

    plt.plot(tlist, cond_uncond_dists, label="mse dist", color="g")
    plt.legend()
    plt.show()

    # Optional: reconstruction error of the one-step estimate.
    #plt.plot(tlist, mse_dists, label="mse", color="k")
    #plt.plot(tlist, curve, label="curve", color = "k", linestyle="--")
    #plt.legend()
    #plt.show()


In [18]:
# Histogram of all values in `time_cond` (flattened and moved to the CPU).
# NOTE(review): `time_cond` is not assigned in this cell; it is whatever an
# earlier cell left in the kernel - confirm which tensor is meant to be shown.
plt.hist(time_cond.flatten().cpu())

### Interpolation

In [19]:
# Interpolation grid: for each pair (A, B) the time conditioning of A is kept
# fixed while the `cond` embeddings of A and B are blended linearly; one clip
# is sampled per blend weight. Needs at least six entries in `zlist`.
ninterp = 6


def _interp_heading(title):
    """Return an Output widget showing `title` as a centred heading."""
    heading = widgets.Output()
    with heading:
        display(HTML(f"<h2 style='text-align:center'>{title}</h2>"))
    return heading


def _interp_player(audio_data):
    """Return a 150px-wide bordered Output widget holding an audio player."""
    player = widgets.Output(
        layout={
            'border': '0.5px solid black',
            'width': '150px',
            'background_color': 'black'
        })
    with player:
        display(Audio(data=audio_data, rate=SR))
    return player


texts = [_interp_heading(" Audio A "), _interp_heading(" Audio B ")]
audios_widg = []

for j in range(4):
    z1 = zlist[j]
    z2 = zlist[j + 2]
    cond1, _ = blender.encoder(z1)
    cond2, _ = blender.encoder(z2)
    time_cond, _ = blender.encoder_time(z1)

    # Reference players for the two endpoints of this row.
    audios_widg.append(
        _interp_player(blender.emb_model.decode(z1).cpu().squeeze()))
    audios_widg.append(
        _interp_player(blender.emb_model.decode(z2).cpu().squeeze()))

    for t in np.linspace(0, 1, ninterp):
        cond = (1 - t) * cond1 + t * cond2
        # Fresh noise per sample. The two extra draws that used to precede
        # this loop were overwritten here before ever being used.
        x0 = torch.randn_like(z1)
        audio_out = sample(x0, time_cond, cond, nb_steps=30)

        # Column headings are shared by all rows, so emit them only once.
        if j == 0:
            texts.append(_interp_heading(f"t = {np.round(t, 2)}"))

        audios_widg.append(_interp_player(audio_out))

# First grid row: headings; following rows: A, B, then the ninterp blends.
outputs = widgets.GridBox(
    texts + audios_widg,
    layout=widgets.Layout(
        grid_template_columns=f"repeat({ninterp + 2}, 150px)"))

box = widgets.VBox(children=[outputs])
display(box)

### Out of Domain 

In [20]:
def _encode_clips(folder, names):
    """Load and encode each file `names[i]` found in `folder`.

    Returns:
        (clips, latents): two lists, in the order of `names`, holding the two
        values returned by `load_and_encode_audio` for every file.
    """
    clips, latents = [], []
    for name in names:
        clip, latent = load_and_encode_audio(os.path.join(folder, name),
                                             blender, SR)
        clips.append(clip)
        latents.append(latent)
    return clips, latents


# Clips providing the `cond` embedding (one grid column each).
base_folder = "/data/nils/repos/AFTER/notebooks/samples/base2/"
base_files = os.listdir(base_folder)  #[:6]
base_files.sort()
print(base_files)

base_files = base_files[:6]
audios, zlist = _encode_clips(base_folder, base_files)

# Clips providing the time conditioning (one grid row each).
ood_folder = "/data/nils/repos/AFTER/notebooks/samples/out_of_domain/"
ood_files = os.listdir(ood_folder)  #[:6]
ood_files.sort()
ood_files = ood_files[:6]
ood_audios, ood_zlist = _encode_clips(ood_folder, ood_files)

# Batch every base clip together; torch.cat needs all latents to share the
# same trailing shape.
z_cond = torch.cat(zlist, dim=0)
cond, _ = blender.encoder(z_cond)

# all_transfers[i][j]: sample generated with the time conditioning of
# ood clip i and the `cond` embedding of base clip j.
all_transfers = []
for z_time_cond in ood_zlist:
    x0 = torch.randn_like(z_cond)
    time_cond, _ = blender.encoder_time(z_time_cond)
    # Repeat the single time conditioning once per base clip in the batch.
    time_cond = time_cond.repeat((z_cond.shape[0], 1, 1))
    transfers = sample(x0, time_cond, cond, nb_steps=30)
    all_transfers.append(list(transfers))

In [21]:
audio_size = str(150)

# Shrink the browser's audio players so that they fit in a grid column.
audio_style = "<style>audio { margin-left: 0px; width: " + audio_size + "px; }</style>"
display(HTML(audio_style))


def _grid_heading(title):
    """Return an Output widget showing `title` as a centred heading."""
    heading = widgets.Output()
    with heading:
        display(HTML(f"<h2 style='text-align:center'>{title}</h2>"))
    return heading


def _grid_player(audio_data, shaded=True):
    """Return a bordered Output widget holding an audio player.

    `shaded` adds the 'background_color' layout entry that the reference
    players carry and the transfer players do not.
    """
    layout = {
        'border': '0.5px solid black',
        'width': f'{audio_size}px',
    }
    if shaded:
        layout['background_color'] = 'black'
    player = widgets.Output(layout=layout)
    with player:
        display(Audio(data=audio_data, rate=SR))
    return player


# Grid layout: one 80px column for row names, one column for the row's own
# clip, then one column per clip in `audios`.
blank_out = widgets.Output(layout={'width': '80px'})

# Two header rows (column names, column reference clips). They are built
# straight from the column clips, no longer as a side effect of the first row.
text_widgets = [blank_out, blank_out]
true_audios = [blank_out, blank_out]
for base_name, base_audio in zip(base_files, audios):
    # splitext instead of name[:-4], which assumed a three-letter extension.
    text_widgets.append(_grid_heading(os.path.splitext(base_name)[0]))
    true_audios.append(_grid_player(base_audio))

# One row per clip in `ood_audios`: its name, the clip itself, then its
# transfer against every column clip.
lines = []
for name, audio, transfers in zip(ood_files, ood_audios, all_transfers):
    lines.append(_grid_heading(os.path.splitext(name)[0]))
    lines.append(_grid_player(audio))
    lines.extend(
        _grid_player(transfer, shaded=False) for transfer in transfers)

outputs = widgets.GridBox(
    text_widgets + true_audios + lines,
    layout=widgets.Layout(
        grid_template_columns=f"80px repeat({len(audios) + 1}, {audio_size}px)"
    ))

box = widgets.VBox(children=[outputs])
display(box)

In [22]:
def _encode_clips(folder, names):
    """Load and encode each file `names[i]` found in `folder`.

    Returns:
        (clips, latents): two lists, in the order of `names`, holding the two
        values returned by `load_and_encode_audio` for every file.
    """
    clips, latents = [], []
    for name in names:
        clip, latent = load_and_encode_audio(os.path.join(folder, name),
                                             blender, SR)
        clips.append(clip)
        latents.append(latent)
    return clips, latents


# Clips providing the `cond` embedding (one grid column each). Here they are
# read from the out_of_domain folder although the variables are named `base_*`.
base_folder = "/data/nils/repos/AFTER/notebooks/samples/out_of_domain/"
base_files = os.listdir(base_folder)  #[:6]
base_files.sort()
print(base_files)

base_files = base_files[:6]
audios, zlist = _encode_clips(base_folder, base_files)

# Clips providing the time conditioning (one grid row each). Here they are
# read from the base2 folder although the variables are named `ood_*`.
ood_folder = "/data/nils/repos/AFTER/notebooks/samples/base2/"
ood_files = os.listdir(ood_folder)  #[:6]
ood_files.sort()
ood_files = ood_files[:6]
ood_audios, ood_zlist = _encode_clips(ood_folder, ood_files)

# Batch every column clip together; torch.cat needs all latents to share the
# same trailing shape.
z_cond = torch.cat(zlist, dim=0)
cond, _ = blender.encoder(z_cond)

# all_transfers[i][j]: sample generated with the time conditioning of
# row clip i and the `cond` embedding of column clip j.
all_transfers = []
for z_time_cond in ood_zlist:
    x0 = torch.randn_like(z_cond)
    time_cond, _ = blender.encoder_time(z_time_cond)
    # Repeat the single time conditioning once per column clip in the batch.
    time_cond = time_cond.repeat((z_cond.shape[0], 1, 1))
    transfers = sample(x0, time_cond, cond, nb_steps=30)
    all_transfers.append(list(transfers))

In [23]:
audio_size = str(150)

# Shrink the browser's audio players so that they fit in a grid column.
audio_style = "<style>audio { margin-left: 0px; width: " + audio_size + "px; }</style>"
display(HTML(audio_style))


def _grid_heading(title):
    """Return an Output widget showing `title` as a centred heading."""
    heading = widgets.Output()
    with heading:
        display(HTML(f"<h2 style='text-align:center'>{title}</h2>"))
    return heading


def _grid_player(audio_data, shaded=True):
    """Return a bordered Output widget holding an audio player.

    `shaded` adds the 'background_color' layout entry that the reference
    players carry and the transfer players do not.
    """
    layout = {
        'border': '0.5px solid black',
        'width': f'{audio_size}px',
    }
    if shaded:
        layout['background_color'] = 'black'
    player = widgets.Output(layout=layout)
    with player:
        display(Audio(data=audio_data, rate=SR))
    return player


# Grid layout: one 80px column for row names, one column for the row's own
# clip, then one column per clip in `audios`.
blank_out = widgets.Output(layout={'width': '80px'})

# Two header rows (column names, column reference clips). They are built
# straight from the column clips, no longer as a side effect of the first row.
text_widgets = [blank_out, blank_out]
true_audios = [blank_out, blank_out]
for base_name, base_audio in zip(base_files, audios):
    # splitext instead of name[:-4], which assumed a three-letter extension.
    text_widgets.append(_grid_heading(os.path.splitext(base_name)[0]))
    true_audios.append(_grid_player(base_audio))

# One row per clip in `ood_audios`: its name, the clip itself, then its
# transfer against every column clip.
lines = []
for name, audio, transfers in zip(ood_files, ood_audios, all_transfers):
    lines.append(_grid_heading(os.path.splitext(name)[0]))
    lines.append(_grid_player(audio))
    lines.extend(
        _grid_player(transfer, shaded=False) for transfer in transfers)

outputs = widgets.GridBox(
    text_widgets + true_audios + lines,
    layout=widgets.Layout(
        grid_template_columns=f"80px repeat({len(audios) + 1}, {audio_size}px)"
    ))

box = widgets.VBox(children=[outputs])
display(box)

### Timbre UMAP

In [None]:
from acids_datasets import CachedSimpleDataset

# Batch size of the loader built below, and the number of dataset items to
# embed (the collection loop requests nfiles // bsize batches).
bsize = 16
nfiles = 5000
# Crop length used by collate_fn, read from the currently parsed gin config.
n_signal = gin.query_parameter("%N_SIGNAL")

# Location of the cached dataset; the commented path is an alternative one.
db_path = "/data/nils/datasets/instruments/syntheticv1/m2l_2"
#db_path = "/data/nils/datasets/instruments/slakh/slakh2100_flac_redux/slakh_2048/"

# NOTE(review): the variable is called `valset` but is built with
# validation=False - confirm which split this actually loads.
valset = CachedSimpleDataset(path=db_path,
                             keys=["z", "metadata", "midi"],
                             validation=False)

# Sanity check: shape of the first item's "z" entry.
print(valset[0]["z"].shape)


def crop(arrays, length, idxs):
    """Cut one window of `length` steps out of every item of every batch.

    Args:
        arrays: Iterable of batches; iterating a batch yields its items.
        length: Number of steps to keep along the last axis.
        idxs: Window start for each item (paired with the items in order and
            shared by all batches).

    Returns:
        List with one stacked tensor per batch in `arrays`.
    """
    cropped = []
    for array in arrays:
        windows = []
        for start, item in zip(idxs, array):
            windows.append(item[..., start:start + length])
        cropped.append(torch.stack(windows))
    return cropped


def collate_fn(batch):
    """Collate dataset items into two independent random crops per item.

    Args:
        batch: List of dataset items; each provides a "z" array and a
            "metadata" mapping with an "instrument" entry.

    Returns:
        Dict with
            "x": one crop of `n_signal` steps per item,
            "x_cond": a second, independently positioned crop of the same items,
            "instrument": list of the items' instrument names.

    Raises:
        ValueError: from `np.random.randint` when the items are shorter than
            `n_signal` steps.
    """
    x = torch.from_numpy(np.stack([b["z"] for b in batch], axis=0))

    # randint's upper bound is exclusive: the +1 makes the last valid start
    # (length - n_signal) reachable and lets items that are exactly n_signal
    # long through, where the bound used to be 0 and randint raised.
    n_starts = x.shape[-1] - n_signal + 1

    i0 = np.random.randint(0, n_starts, x.shape[0])
    x_target = crop([x], n_signal, i0)[0]

    i1 = np.random.randint(0, n_starts, x.shape[0])
    x_timbre = crop([x], n_signal, i1)[0]

    meta = [b["metadata"]["instrument"] for b in batch]

    return {"x": x_target, "x_cond": x_timbre, "instrument": meta}


valid_loader = torch.utils.data.DataLoader(valset,
                                           batch_size=bsize,
                                           shuffle=True,
                                           num_workers=0,
                                           drop_last=False,
                                           collate_fn=collate_fn)
nbatches = nfiles // bsize

# Per-batch results, concatenated after the loop. Note that "x" holds the
# crop that was actually encoded (the batch's "x_cond").
data = {
    "x": [],
    "cond": [],
    "cond_mean": [],
    "time_cond": [],
    "instrument": [],
}

pbar = tqdm(valid_loader, total=nbatches)

# zip with range() stops after exactly `nbatches` batches. The previous
# `if j > nbatches: break` at the end of the body let two extra batches in.
for j, batch in zip(range(nbatches), pbar):
    x_cond = batch["x_cond"].to(device)

    cond, cond_mean, _ = blender.encoder(x_cond, return_mean=True)
    time_cond, _, _ = blender.encoder_time(x_cond, return_mean=True)

    data["cond"].append(cond.cpu())
    data["time_cond"].append(time_cond.cpu())
    data["x"].append(x_cond.cpu())
    data["cond_mean"].append(cond_mean.cpu())

    data["instrument"].extend(batch["instrument"])

# Concatenate the tensor-valued entries along the batch axis. "instrument" is
# a flat list of names and is left untouched (it used to be skipped through a
# caught exception that printed an error).
for k, v in data.items():
    if v and isinstance(v[0], torch.Tensor):
        data[k] = torch.cat(v)

# Merge the "Strings (continued)" category into "Strings".
data["instrument"] = [
    'Strings' if 'Strings (continued)' in f else f
    for f in data["instrument"]
]


#### Timbre embedding

In [8]:
import umap

import matplotlib.pyplot as plt

# Maximum number of points shown in the scatter plot.
nexamples = 10000

cond_data = data["cond_mean"][:nexamples]

# Sorting the names keeps the index (and therefore colour) of each instrument
# stable across kernel restarts; plain set order is not.
instrument_to_idx = {
    instr: idx
    for idx, instr in enumerate(sorted(set(data["instrument"])))
}
idx_to_instrument = {idx: instr for instr, idx in instrument_to_idx.items()}

labels = torch.tensor(
    [instrument_to_idx[instr] for instr in data["instrument"]])
labels = labels[:nexamples]

if cond_data.shape[-1] > 2:
    # Perform UMAP dimensionality reduction
    reducer = umap.UMAP(n_components=2, n_neighbors=30, min_dist=0.01)
    embedding = reducer.fit_transform(cond_data)
else:
    # Already (at most) 2-D: plot the embedding directly. This assignment used
    # to run unconditionally, discarding the UMAP result computed above, and
    # took the untruncated tensor, which could disagree in length with `labels`.
    embedding = cond_data

# Create a scatter plot
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embedding[:, 0],
                      embedding[:, 1],
                      c=labels,
                      cmap='Spectral',
                      s=5)

plt.title('2D UMAP of data["cond"]')
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
#plt.xlim(2, 15)
#plt.ylim(-2, 15)

# Convert the legend to instrument names instead of indices. num=None yields
# one handle per distinct label value, in ascending order, so the names are
# looked up from the labels that are actually present.
handles, _ = scatter.legend_elements(num=None)
instrument_names = [
    idx_to_instrument[i] for i in torch.unique(labels).tolist()
]
plt.legend(handles,
           instrument_names,
           title="Instruments",
           loc='center left',
           bbox_to_anchor=(1, 0.5))
plt.show()

#### Time cond embedding

In [64]:
# Disabled experiment (`if False`): 2-D embedding of the per-step time
# conditioning vectors, one point per (example, step) pair.
if False:
    import umap
    from einops import rearrange

    import matplotlib.pyplot as plt

    nexamples = 1000

    time_cond_data = data["time_cond"][:nexamples]
    nseq = time_cond_data.shape[-1]

    # Fold the last axis into the first one: every step becomes its own row.
    time_cond_data = rearrange(time_cond_data, "b c s -> (b s) c")
    # Equivalent without einops:
    #time_cond_data = time_cond_data.permute(0, 2,1).reshape(-1, time_cond_data.shape[1])
    data_cur = time_cond_data

    # Repeat each example's instrument name once per step so that the labels
    # line up with the rows of `data_cur`.
    instruments_cur = []
    for instr in data["instrument"][:nexamples]:
        instruments_cur += [instr] * nseq

    instrument_to_idx = {}
    for idx, instr in enumerate(set(instruments_cur)):
        instrument_to_idx[instr] = idx

    labels = torch.tensor(
        [instrument_to_idx[instr] for instr in instruments_cur])

    if not data_cur.shape[-1] > 2:
        # Nothing to reduce.
        embedding = data_cur
    else:
        # Perform UMAP dimensionality reduction
        reducer = umap.UMAP(n_components=2)  #, n_neighbors=30, min_dist=0.01)
        embedding = reducer.fit_transform(data_cur)

In [65]:
# Disabled (`if False`): scatter plot of the embedding produced by the
# disabled cell above, coloured by instrument.
if False:
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        embedding[:, 0], embedding[:, 1], c=labels, cmap='Spectral', s=5)

    plt.title('2D UMAP of data["cond"]')

    # Replace the numeric legend entries with instrument names.
    handles, _ = scatter.legend_elements()
    known_names = list(instrument_to_idx.keys())
    known_indices = list(instrument_to_idx.values())
    instrument_names = []
    for i in range(len(handles)):
        instrument_names.append(known_names[known_indices.index(i)])

    plt.legend(handles, instrument_names, title="Instruments")
    plt.ylabel('UMAP 2')
    #plt.xlim(6, 15)
    #plt.ylim(0, 10)

    # Final legend, anchored outside the axes on the right.
    legend_options = dict(title="Instruments",
                          loc='center left',
                          bbox_to_anchor=(1, 0.5))
    plt.legend(handles, instrument_names, **legend_options)
    plt.show()

### Transfer

In [9]:
# Distinct instrument names, and the per-example names as an array so that
# `all_instruments == name` yields a boolean mask for indexing `data[...]`.
instruments = set(data["instrument"])
all_instruments = np.array(data["instrument"])

# Background points of the plots below: the raw conditioning means.
# NOTE(review): this rebinds `embedding` without the `[:nexamples]` truncation
# applied to `labels` in the embedding cell - confirm the lengths still match.
embedding = data["cond_mean"]

# Maximum number of examples per instrument used by generate().
n_max = 1000


def generate(instrument, instrument2, source=True):
    """Sample latents for one instrument and re-encode their conditioning.

    Two batches are sampled (10 steps each, returned as latents):
      1. with the time conditioning and `cond` of the `instrument` examples;
      2. with the time conditioning of the `instrument2` examples and either
         `source` (when it is a tensor) or the `cond` of the `instrument`
         examples.

    Reads the module-level `data`, `all_instruments`, `n_max`, `blender`,
    `sample` and, when `source` is a tensor, `device`.

    Args:
        instrument: Name selecting the examples that provide `cond`.
        instrument2: Name selecting the examples that provide the time
            conditioning of the second batch.
        source: If a ``torch.Tensor``, it is repeated for every sample of the
            second batch and used as `cond`; any other value selects the
            per-example `cond` of `instrument`.

    Returns:
        (cond_rec, cond_rec_transfer): NumPy arrays with the re-encoded
        conditioning of the first and of the second batch.
    """
    instrument_mask = all_instruments == instrument
    x_instrument = data["x"][instrument_mask][:n_max]
    cond_data_instrument = data["cond"][instrument_mask][:n_max]
    time_cond_instrument = data["time_cond"][instrument_mask][:n_max]

    x0 = torch.randn_like(x_instrument)
    model_out = sample(x0,
                       time_cond_instrument,
                       cond_data_instrument,
                       nb_steps=10,
                       return_z=True)

    # NOTE(review): with return_mean=True other cells read the mean from
    # index 1 of the encoder output; index 0 is kept here as before - confirm
    # which one should be compared against the `cond_mean` scatter.
    cond_rec = blender.encoder(model_out, return_mean=True)[0].cpu().numpy()

    ## Generation with other instrument

    time_cond_instrument = data["time_cond"][all_instruments ==
                                             instrument2][:n_max]

    # Both selections can differ in size; use as many pairs as both provide.
    n_samples = min(time_cond_instrument.shape[0],
                    cond_data_instrument.shape[0])

    x0 = torch.randn_like(x_instrument)

    # isinstance (not an exact type comparison) so that tensor subclasses are
    # also accepted as a fixed source embedding.
    if isinstance(source, torch.Tensor):
        cond = source.to(device)
        cond = cond.repeat(n_samples, 1)
    else:
        cond = cond_data_instrument[:n_samples]

    model_out = sample(x0[:n_samples],
                       time_cond_instrument[:n_samples],
                       cond=cond,
                       nb_steps=10,
                       return_z=True)

    cond_rec_transfer = blender.encoder(model_out,
                                        return_mean=True)[0].cpu().numpy()
    return cond_rec, cond_rec_transfer

In [None]:
instrument = '071_Clarinet'

# Conditioning used for the transfers. Set it to a tensor holding one
# conditioning vector (e.g. torch.tensor([-0.25, -0.10])) to use that fixed
# point for every sample; False uses the per-example `cond` of `instrument`.
# A tensor used to be defined here while `source=False` was what generate()
# received, so the plot showed a "Source" marker that had not been used.
source = False

label = "Generated with same instrument"
label2 = "Transfer"
label3 = "Source"

idx_to_instrument = {idx: name for name, idx in instrument_to_idx.items()}

for instrument2 in instruments:

    cond_rec, cond_rec_transfer = generate(instrument,
                                           instrument2,
                                           source=source)

    # Create a scatter plot
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(embedding[:, 0],
                          embedding[:, 1],
                          c=labels,
                          cmap='Spectral',
                          s=5)

    plt.title(label)

    # num=None gives one handle per distinct label value, in ascending order,
    # so names are looked up from the labels that are actually plotted.
    handles, _ = scatter.legend_elements(num=None)
    instrument_names = [
        idx_to_instrument[i] for i in np.unique(np.asarray(labels)).tolist()
    ]

    ##
    #scatter2 = plt.scatter(cond_rec[:, 0],
    #                       cond_rec[:, 1],
    #                       c="k",
    #                       s=3,
    #                       label="Same instrument")

    scatter3 = plt.scatter(cond_rec_transfer[:, 0],
                           cond_rec_transfer[:, 1],
                           c="b",
                           s=3,
                           label="Transfer")
    handles.append(scatter3)
    instrument_names.append(label2)

    if isinstance(source, torch.Tensor):
        scatter4 = plt.scatter(source[0], source[1], c="k", s=25, label=label3)
        handles.append(scatter4)
        instrument_names.append(label3)

    # Tag the two instruments involved in the legend. Raw strings keep the
    # mathtext backslashes intact ("\m" is not a valid string escape).
    if instrument in instrument_names:
        instrument_names[instrument_names.index(
            instrument)] = instrument + r" - $\mathbf{original}$"

    if instrument2 != instrument and instrument2 in instrument_names:
        instrument_names[instrument_names.index(
            instrument2)] = instrument2 + r" - $\bf{transfer \ target}$"

    plt.legend(handles,
               instrument_names,
               title="Instruments",
               loc='center left',
               bbox_to_anchor=(1, 0.5))

    plt.show()
    #break

In [11]:
# Number of points on the interpolation path (and of generated clips).
n_interp = 10

# Straight line in the 2-D conditioning space, from u1 (t = 0) to u0 (t = 1).
u0 = np.array([0.3, -.75])
u1 = np.array([0., 0.])

t = np.linspace(0, 1, n_interp)
u = t[:, None] * u0 + (1 - t[:, None]) * u1

label2 = "Interpolation"

# Create a scatter plot
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embedding[:, 0],
                      embedding[:, 1],
                      c=labels,
                      cmap='Spectral',
                      s=5)
# The title and the path label used to come from `label` / `label3`, names
# that are only defined in the previous cell.
plt.title("Interpolation path in the conditioning space")

# num=None gives one handle per distinct label value, in ascending order.
handles, _ = scatter.legend_elements(num=None)
idx_to_instrument = {idx: name for name, idx in instrument_to_idx.items()}
instrument_names = [
    idx_to_instrument[i] for i in np.unique(np.asarray(labels)).tolist()
]

scatter2 = plt.scatter(u[:, 0], u[:, 1], c="k", s=15, label=label2)
handles.append(scatter2)
instrument_names.append(label2)

plt.legend(handles,
           instrument_names,
           title="Instruments",
           loc='center left',
           bbox_to_anchor=(1, 0.5))
plt.show()

cond = torch.from_numpy(u).float().to(device)

# Example providing the time conditioning of every generated clip.
i0 = 190
print(data["instrument"][i0])
x_source = data["x"][i0:i0 + 1]
time_cond = data["time_cond"][i0:i0 + 1].repeat(n_interp, 1, 1)
# One noise draw repeated for all clips, so only `cond` varies between them.
x = torch.randn_like(data["x"][:1]).repeat(n_interp, 1, 1)

model_out = sample(x, time_cond, cond, nb_steps=10, return_z=False)

print("source time_cond")
audio_source = blender.emb_model.decode(x_source).cpu().squeeze()

display(Audio(audio_source, rate=SR))
for i, a in enumerate(model_out):
    print(cond[i])
    display(Audio(data=a, rate=SR))