# Audio-to-audio generation 

<div style="text-align:center;">
<img src="../images/method.png" alt="Overview of the method" width="800" />
</div>


This notebook implements inference for audio-to-audio generation. We demonstrate it with the demo samples from the [webpage](https://nilsdem.github.io/control-transfer-diffusion/), but you can also load your own structure and timbre targets.
Please note that although any audio can be used as the structure input, the timbre target must come from the training datasets (or be closely similar to them).


Make sure to [download]() the pretrained models and place them in `./pretrained`. Two pretrained models are available, one trained on [SLAKH 2100](http://www.slakh.com/), and one trained on multiple real-world instrumental recordings (Maestro, URMP, Filobass, GuitarSet...).

In [26]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import os

# Expose only physical GPU 1 to this process. This must happen before torch is
# imported / CUDA is initialised, otherwise the mask is ignored. Inside this
# process that GPU is then addressed as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import sys

import gin
import librosa
import numpy as np
import torch
from IPython.display import display, Audio

# Let gin tolerate re-registration of configurables when notebook cells are re-run.
gin.enter_interactive_mode()

# Make the repository root importable (the model is imported from `diffusion` later).
sys.path.append('..')

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)


### Checkpoint setup

In [28]:
# Training run to restore: its EMA checkpoint at a given step and the gin
# config it was trained with.
folder = "../runs/final_configv3"
step = 800000
config = f"{folder}/config.gin"
checkpoint_path = f"{folder}/checkpoint{step}_EMA.pt"

# Pretrained autoencoder, stored as a TorchScript file.
autoencoder_path = "../pretrained/slakh.ts"

# First *visible* GPU (see CUDA_VISIBLE_DEVICES in the imports cell).
device = "cuda:0"

### Instantiate the model and load the checkpoint

In [29]:
from diffusion.model import RectifiedFlow

# Parse the training config; SR (the sample rate used for audio playback) is read from it.
gin.parse_config_file(config)
SR = gin.query_parameter("%SR")
# Crop length applied to the last axis of dataset latents.
n_signal = 64

# Instantiate model (its hyper-parameters presumably come from the gin config above).
blender = RectifiedFlow(device=device)

# Load checkpoint. weights_only=False makes the full-unpickling behaviour explicit
# (torch is changing this default); only load checkpoints from trusted sources.
state_dict = torch.load(checkpoint_path, map_location="cpu",
                        weights_only=False)["model_state"]
# strict=False tolerates key mismatches silently; report them so a wrong or
# only partially compatible checkpoint does not go unnoticed.
load_result = blender.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Checkpoint key mismatch: {len(load_result.missing_keys)} missing, "
          f"{len(load_result.unexpected_keys)} unexpected")

# Pretrained TorchScript autoencoder. Assigning it as an attribute presumably
# registers it as a submodule, so the `.to(device)` below moves it as well.
emb_model = torch.jit.load(autoencoder_path).eval()
blender.emb_model = emb_model

# Send to device
blender = blender.eval().to(device)

  WeightNorm.apply(module, name, dim)
  state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state"]


### Load the dataset

In [32]:
from acids_datasets import SimpleDataset

# Pre-processed SLAKH dataset providing waveforms and autoencoder latents ("z").
# Set SLAKH_DB_PATH to point at your own copy instead of editing this
# machine-specific default.
db_path = os.environ.get(
    "SLAKH_DB_PATH",
    "/data/nils/datasets/instruments/slakh/slakh2100_flac_redux/slakh_2048/",
)
dataset = SimpleDataset(path=db_path, keys=["waveform", "z"])

In [46]:
def _load_latent(index):
    """Fetch dataset item `index`'s latent, crop its last axis to `n_signal`
    and return it as a batch of one on `device`."""
    latent = dataset[index]["z"][..., :n_signal]
    return torch.tensor(latent).to(device).unsqueeze(0)


# z1 serves as the structure target and z2 as the timbre target in the generation cell.
z1 = _load_latent(65889)  # guitar
z2 = _load_latent(3850)

# Decode both latents back to audio to listen to the inputs.
x1 = blender.emb_model.decode(z1).cpu().squeeze()
x2 = blender.emb_model.decode(z2).cpu().squeeze()

for waveform in (x1, x2):
    display(Audio(waveform, rate=SR))

#### Generation

In [47]:
# Number of diffusion steps used by the blender.sample calls below.
nb_steps = 40
# Classifier-free guidance strength.
# NOTE(review): the generation cell sets `total_guidance` per run and does not read this value.
guidance = 1.0

In [48]:
def shuffle(z, chunk_size=4, generator=None):
    """Randomly permute fixed-size chunks along the last axis, independently per batch item.

    `z` is modified in place and also returned.

    Args:
        z: tensor whose first dimension is iterated as the batch; each item is
            split into chunks along its last dimension with `Tensor.split`, so a
            trailing chunk shorter than `chunk_size` is permuted like any other.
        chunk_size: length of each chunk along the last axis (default 4).
        generator: optional `torch.Generator` used to draw the permutations,
            for reproducible shuffles. Defaults to torch's global RNG.

    Returns:
        `z`, with the chunks of every batch item reordered.
    """
    for n in range(z.shape[0]):
        chunks = z[n].split(chunk_size, dim=-1)
        order = torch.randperm(len(chunks), generator=generator)
        # torch.cat builds a new tensor before the assignment copies it back
        # into z[n], so reading from views of z[n] here is safe.
        z[n] = torch.cat([chunks[i] for i in order], dim=-1)
    return z

In [50]:
# Seed torch's RNG so re-running this cell reproduces the same noise draws.
torch.manual_seed(0)

# Structure representation, quantized with the model's vector quantizer
# (first element of its output).
time_cond1, time_cond2 = blender.encoder_time(z1), blender.encoder_time(z2)
time_cond1 = blender.vector_quantizer(time_cond1)[0]
time_cond2 = blender.vector_quantizer(time_cond2)[0]
# Timbre representation.
zsem1, zsem2 = blender.encoder(z1), blender.encoder(z2)

# Transfer setting: structure of sample 1 combined with the timbre of sample 2.
zsem = zsem2
time_cond = time_cond1

# Initial noise shared by the "Normal", "More guidance" and guidance-sweep runs.
x0 = torch.randn_like(z1)


def _sample(noise, time_cond, cond, total_guidance, guidance_joint_factor,
            guidance_cond_factor):
    """Run `blender.sample` for `nb_steps` steps with the given conditions and guidance."""
    return blender.sample(
        noise,
        time_cond=time_cond,
        cond=cond,
        nb_steps=nb_steps,
        guidance_cond_factor=guidance_cond_factor,
        guidance_joint_factor=guidance_joint_factor,
        total_guidance=total_guidance,
    )


def _play(latent):
    """Decode `latent` with the autoencoder, display it as audio and return the waveform."""
    audio = blender.emb_model.decode(latent).cpu().numpy().squeeze()
    display(Audio(audio, rate=SR))
    return audio


def _print_distances(name, target, reconstruction):
    """Print the MSE and mean (1 - cosine similarity along dim 1) between two tensors."""
    mse = torch.nn.functional.mse_loss(target, reconstruction).mean().item()
    cosine = (1 - torch.nn.functional.cosine_similarity(
        target, reconstruction, dim=1, eps=1e-8)).mean().item()
    print(name)
    print("MSE", mse)
    print("Cosine", cosine)


def _report_conditioning(latent, time_cond, zsem):
    """Re-encode a generated latent and print how far its conditions are from the targets.

    Returns the re-encoded (time_cond_rec, zsem_rec).
    """
    # NOTE(review): `time_cond` was vector-quantized above, but the re-encoded
    # `time_cond_rec` is not (its quantization was commented out in the
    # original experiment) — confirm this comparison is intended.
    time_cond_rec = blender.encoder_time(latent)
    zsem_rec = blender.encoder(latent)
    _print_distances("time_cond", time_cond, time_cond_rec)
    _print_distances("zsem", zsem, zsem_rec)
    return time_cond_rec, zsem_rec


print("Normal")
total_guidance = 1.0
guidance_joint_factor = 1.0
guidance_cond_factor = 0
xS = _sample(x0, time_cond, zsem, total_guidance, guidance_joint_factor,
             guidance_cond_factor)
time_cond_rec, zsem_rec = _report_conditioning(xS, time_cond, zsem)
audio_out = _play(xS)

print("More guidance")
total_guidance = 3.
guidance_joint_factor = .7
guidance_cond_factor = 0.3
xS = _sample(x0, time_cond, zsem, total_guidance, guidance_joint_factor,
             guidance_cond_factor)
time_cond_rec, zsem_rec = _report_conditioning(xS, time_cond, zsem)
audio_out = _play(xS)

print("no zsem")
# NOTE(review): `zsem` is still passed as `cond`; this run presumably relies on
# both guidance factors being 0 to drop it — confirm against blender.sample.
total_guidance = 1.0
guidance_joint_factor = 0.
guidance_cond_factor = 0.
for _ in range(5):
    xS = _sample(torch.randn_like(x0), time_cond, zsem, total_guidance,
                 guidance_joint_factor, guidance_cond_factor)
    audio_out = _play(xS)

print("no time_cond")
total_guidance = 1.0
guidance_joint_factor = 1.
guidance_cond_factor = 1.
# A constant -2 tensor stands in for the structure condition; presumably the
# placeholder value the model uses for a dropped condition — confirm.
null_time_cond = -2. * torch.ones_like(time_cond)
for _ in range(3):
    xS = _sample(torch.randn_like(x0), null_time_cond, zsem, total_guidance,
                 guidance_joint_factor, guidance_cond_factor)
    audio_out = _play(xS)

# Sweep the conditional guidance factor with the shared noise x0.
total_guidance = 1.
guidance_joint_factor = 0.4
for guidance_cond_factor in [0.2, 0.8]:
    print(f"Guidance cond factor: {guidance_cond_factor}")
    xS = _sample(x0, time_cond, zsem, total_guidance, guidance_joint_factor,
                 guidance_cond_factor)
    audio_out = _play(xS)

Normal


time_cond
MSE 0.9777068495750427
Cosine 1.0507240295410156
zsem
MSE 0.035370949655771255
Cosine 0.10968649387359619


MOre guidance
time_cond
MSE 1.4116277694702148
Cosine 1.0405094623565674
zsem
MSE 0.02062067948281765
Cosine 0.06033533811569214


no zsem


no time_cond


Guidance cond factor: 0.2


Guidance cond factor: 0.8
