# Audio-to-audio generation 

<div style="text-align:center;">
<img src="../images/method.png" alt="Example Image" width="800" />
</div>


This notebook implements the inference for audio-to-audio generation. We demonstrate using the demo samples from the [webpage](https://nilsdem.github.io/control-transfer-diffusion/), but you can load your own structure and timbre targets. 
Please note that although any structure input can be used, the model requires samples from the training datasets (or quite similar ones) for the timbre target.


Make sure to [download]() the pretrained models and place them in `./pretrained`. Two pretrained models are available, one trained on [SLAKH 2100](http://www.slakh.com/), and one trained on multiple real-world instrumental recordings (Maestro, URMP, Filobass, GuitarSet...).

In [24]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell execution, so edits to the project sources are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import gin

# Switch gin to interactive (notebook) mode before any configurable is imported.
# NOTE(review): presumably needed so that re-importing configurables under
# autoreload does not raise — confirm against the gin-config documentation.
gin.enter_interactive_mode()

from IPython.display import display, Audio
import torch
import numpy as np
import librosa

import sys

# Make the parent directory importable so project packages located there can
# be imported from this notebook.
sys.path.append('..')

# Inference only: disable autograd globally for the whole notebook.
torch.set_grad_enabled(False)

import os

# Expose only physical GPU 1 to this process. This has to be set before CUDA
# is initialised; afterwards the selected GPU is addressed as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


### Checkpoint setup

In [5]:
# Training run to load. The run directory holds the gin configuration and the
# EMA checkpoints, which are named "checkpoint<step>_EMA.pt".
folder = "../runs/test_distill5_1past_comp"
step = 1300000
config = f"{folder}/config.gin"
checkpoint_path = f"{folder}/checkpoint{step}_EMA.pt"

# Serialized autoencoder (loaded with torch.jit.load further down).
autoencoder_path = "../pretrained/slakh.ts"

# Device used for the model and all tensors.
device = "cuda:0"

### Instantiate the model and load the checkpoint

In [24]:
# Build the generative model from the gin configuration, restore its weights
# from the checkpoint and attach the autoencoder. Statement order matters here.
from diffusion.model import RectifiedFlow

import cached_conv as cc

# Turn cached convolutions off.
# NOTE(review): presumably this must happen before the model is instantiated — confirm.
cc.use_cached_conv(False)

# Parse the run's gin config and read the "SR" macro; SR is used below as the
# playback rate of every Audio widget.
gin.parse_config_file(config)
SR = gin.query_parameter("%SR")

# Instantiate model (only `device` is passed explicitly; the remaining
# constructor arguments presumably come from the gin bindings — confirm).
blender = RectifiedFlow(device=device)

# Load checkpoints: read on CPU first, then keep only the "model_state" entry.
# SECURITY: torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state"]
# strict=False: keys that are missing or unexpected are ignored silently, so a
# mismatched checkpoint will not raise here.
blender.load_state_dict(state_dict, strict=False)

# Emb model: serialized autoencoder, set to eval mode and attached to the model;
# its decode() is used later to turn latents back into audio.
emb_model = torch.jit.load(autoencoder_path).eval()
blender.emb_model = emb_model

# Switch to eval mode and send to device
blender = blender.eval().to(device)

  WeightNorm.apply(module, name, dim)
  state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state"]


### Load the dataset

In [7]:
from acids_datasets import SimpleDataset

# Absolute path of the dataset on the original author's machine — change this
# to the location of your own copy before running.
db_path = "/data/nils/datasets/instruments/slakh/slakh2100_flac_redux/slakh_2048/"
# Request the "waveform" and "z" entries of each item; the cells below only
# read "z".
dataset = SimpleDataset(path=db_path, keys=["waveform", "z"])

In [8]:
# Fetch the "z" entry of two dataset items.
# (Alternative pair that was tried: items 390 and 38079.)
z1 = dataset[65889]["z"]  # guitar
z2 = dataset[3850]["z"]

# Convert to tensors on the target device and prepend a batch axis.
z1 = torch.tensor(z1).to(device).unsqueeze(0)
z2 = torch.tensor(z2).to(device).unsqueeze(0)

# Decode both latents with the autoencoder so they can be listened to.
x1 = blender.emb_model.decode(z1).cpu().squeeze()
x2 = blender.emb_model.decode(z2).cpu().squeeze()

display(Audio(x1, rate=SR), Audio(x2, rate=SR))

#### Generation

In [16]:
# Chunk-by-chunk generation where the structure condition is computed once on
# the full latent sequence and then sliced per chunk.
nb_steps = 1  # Number of diffusion steps
n_signal_stream = 8  # Length (last axis) of each generated chunk

# Compute structure representation
time_cond1, time_cond2 = blender.encoder_time(z1), blender.encoder_time(z2)

# Compute timbre representation from the first 64 positions of each latent.
# If the post_encoder path fails, fall back to the raw encoder output
# (presumably for checkpoints without a post_encoder — confirm).
# `except Exception` keeps that fallback but no longer swallows
# KeyboardInterrupt / SystemExit the way a bare `except:` does.
try:
    zsem1, zsem2 = blender.post_encoder(blender.encoder(
        z1[..., :64])), blender.post_encoder(blender.encoder(z2[..., :64]))
except Exception:
    zsem1, zsem2 = blender.encoder(z1[..., :64]), blender.encoder(z2[..., :64])

# Both timbre and structure are taken from sample 1 in this cell.
zsem = zsem1
time_cond = time_cond1

# Guidance settings
print("Normal")
total_guidance = 1.
guidance_joint_factor = 1.0
guidance_cond_factor = 1.0

# Context passed as "past" for the first generated chunk: a constant -2 tensor.
# NOTE(review): -2 is presumably the "no past" placeholder the model was
# trained with — confirm.
xS = -2. * torch.ones_like(z1[..., :n_signal_stream])
xtot = []

# NOTE(review): the range skips chunk 0 and the final chunk — confirm this is intended.
for i in range(1, z1.shape[-1] // n_signal_stream - 1):
    # Sample initial noise for this chunk
    x0 = torch.randn_like(z1[..., :n_signal_stream])
    time_cond_cur = time_cond[...,
                              i * n_signal_stream:(i + 1) * n_signal_stream]

    # The previously generated chunk is the past context of the current one.
    x_past = xS.clone()

    xS = blender.sample(
        x0,
        time_cond=time_cond_cur,
        time_cond_add=x_past,
        cond=zsem,
        nb_steps=nb_steps,
        guidance_cond_factor=guidance_cond_factor,
        guidance_joint_factor=guidance_joint_factor,
        total_guidance=total_guidance,
    )

    xtot.append(xS.clone())

# Concatenate the chunks along the last axis and decode them to audio.
xtot = torch.cat(xtot, dim=-1)
audio = blender.emb_model.decode(xtot).cpu().squeeze()
display(Audio(audio, rate=SR))


Normal


In [25]:
# Chunk-by-chunk generation where the structure condition is re-encoded from
# each latent chunk inside the loop.
nb_steps = 1  # Number of diffusion steps
n_signal_stream = 8  # Length (last axis) of each generated chunk

# Compute structure representation on the full sequences (the loop below
# recomputes it per chunk and does not read these).
time_cond1, time_cond2 = blender.encoder_time(z1), blender.encoder_time(z2)

# Compute timbre representation from the first 64 positions of each latent.
# If the post_encoder path fails, fall back to the raw encoder output
# (presumably for checkpoints without a post_encoder — confirm).
# `except Exception` keeps that fallback but no longer swallows
# KeyboardInterrupt / SystemExit the way a bare `except:` does.
try:
    zsem1, zsem2 = blender.post_encoder(blender.encoder(
        z1[..., :64])), blender.post_encoder(blender.encoder(z2[..., :64]))
except Exception:
    zsem1, zsem2 = blender.encoder(z1[..., :64]), blender.encoder(z2[..., :64])

# Both timbre and structure are taken from sample 1 in this cell.
z = z1
zsem = zsem1
time_cond = time_cond1

# Guidance settings
print("Normal")
total_guidance = 1.
guidance_joint_factor = 1.0
guidance_cond_factor = 1.0

# Context passed as "past" for the first generated chunk: a constant -2 tensor.
# NOTE(review): -2 is presumably the "no past" placeholder the model was
# trained with — confirm.
xS = -2. * torch.ones_like(z1[..., :n_signal_stream])
xtot = []

# NOTE(review): the range skips chunk 0 and the final chunk — confirm this is intended.
for i in range(1, z1.shape[-1] // n_signal_stream - 1):
    # Sample initial noise for this chunk
    x0 = torch.randn_like(z1[..., :n_signal_stream])

    # Encode the structure condition from the current latent chunk only.
    x_time_cond_cur = z[..., i * n_signal_stream:(i + 1) * n_signal_stream]
    time_cond_cur = blender.encoder_time(x_time_cond_cur)

    # Debug output: dump the structure condition of the first processed chunk.
    if i == 1:
        print(time_cond_cur)

    # The previously generated chunk is the past context of the current one.
    x_past = xS.clone()

    xS = blender.sample(
        x0,
        time_cond=time_cond_cur,
        time_cond_add=x_past,
        cond=zsem,
        nb_steps=nb_steps,
        guidance_cond_factor=guidance_cond_factor,
        guidance_joint_factor=guidance_joint_factor,
        total_guidance=total_guidance,
    )

    xtot.append(xS.clone())

# Concatenate the chunks along the last axis and decode them to audio.
xtot = torch.cat(xtot, dim=-1)
audio = blender.emb_model.decode(xtot).cpu().squeeze()
display(Audio(audio, rate=SR))


Normal
tensor([[[ 0.0469,  0.3547,  0.0735,  0.4009, -0.2207,  0.0660, -0.1304,
          -0.4334],
         [ 0.3100,  0.0180, -0.1174, -0.1071,  0.4585,  0.6238, -0.0148,
          -0.0882],
         [ 0.1583,  0.3158,  0.4900,  0.0179,  0.2929,  0.4704, -0.0020,
           0.3526],
         [ 0.2453, -0.2693, -0.0043, -0.1929, -0.5953,  0.3330,  0.1204,
           0.0092],
         [ 0.1601, -0.0486,  0.0921, -0.0687,  0.0824, -0.0188, -0.2228,
           0.1101],
         [ 0.4236, -0.0402, -0.4012,  0.1729,  0.3241,  0.2615,  0.0332,
           0.1646],
         [ 0.1811,  0.3415,  0.0219,  0.0041,  0.1906,  0.2388, -0.0347,
          -0.0230],
         [ 0.1890,  0.2055,  0.1165,  0.1523,  0.0480,  0.1023,  0.2722,
          -0.0896],
         [ 0.1187,  0.3088, -0.0657,  0.4004, -0.0137,  0.0275,  0.1602,
           0.4123],
         [ 0.1589, -0.0750,  0.1476,  0.4700,  0.0231,  0.0404, -0.1764,
          -0.0867],
         [ 0.3598,  0.2125,  0.1534,  0.2434,  0.0279,  0.1569,

In [23]:
# Chunk-by-chunk generation where the structure condition is re-encoded from
# each latent chunk inside the loop.
nb_steps = 1  # Number of diffusion steps
n_signal_stream = 8  # Length (last axis) of each generated chunk

# Compute structure representation on the full sequences (the loop below
# recomputes it per chunk and does not read these).
time_cond1, time_cond2 = blender.encoder_time(z1), blender.encoder_time(z2)

# Compute timbre representation from the first 64 positions of each latent.
# If the post_encoder path fails, fall back to the raw encoder output
# (presumably for checkpoints without a post_encoder — confirm).
# `except Exception` keeps that fallback but no longer swallows
# KeyboardInterrupt / SystemExit the way a bare `except:` does.
try:
    zsem1, zsem2 = blender.post_encoder(blender.encoder(
        z1[..., :64])), blender.post_encoder(blender.encoder(z2[..., :64]))
except Exception:
    zsem1, zsem2 = blender.encoder(z1[..., :64]), blender.encoder(z2[..., :64])

# Both timbre and structure are taken from sample 1 in this cell.
z = z1
zsem = zsem1
time_cond = time_cond1

# Guidance settings
print("Normal")
total_guidance = 1.
guidance_joint_factor = 1.0
guidance_cond_factor = 1.0

# Context passed as "past" for the first generated chunk: a constant -2 tensor.
# NOTE(review): -2 is presumably the "no past" placeholder the model was
# trained with — confirm.
xS = -2. * torch.ones_like(z1[..., :n_signal_stream])
xtot = []

# NOTE(review): the range skips chunk 0 and the final chunk — confirm this is intended.
for i in range(1, z1.shape[-1] // n_signal_stream - 1):
    # Sample initial noise for this chunk
    x0 = torch.randn_like(z1[..., :n_signal_stream])

    # Encode the structure condition from the current latent chunk only.
    x_time_cond_cur = z[..., i * n_signal_stream:(i + 1) * n_signal_stream]
    time_cond_cur = blender.encoder_time(x_time_cond_cur)

    # Debug output: dump the structure condition of the first processed chunk.
    if i == 1:
        print(time_cond_cur)

    # The previously generated chunk is the past context of the current one.
    x_past = xS.clone()

    xS = blender.sample(
        x0,
        time_cond=time_cond_cur,
        time_cond_add=x_past,
        cond=zsem,
        nb_steps=nb_steps,
        guidance_cond_factor=guidance_cond_factor,
        guidance_joint_factor=guidance_joint_factor,
        total_guidance=total_guidance,
    )

    xtot.append(xS.clone())

# Concatenate the chunks along the last axis and decode them to audio.
xtot = torch.cat(xtot, dim=-1)
audio = blender.emb_model.decode(xtot).cpu().squeeze()
display(Audio(audio, rate=SR))


Normal
tensor([[[ 0.2438,  0.2652,  0.0337,  0.3950, -0.2051,  0.0757, -0.1282,
          -0.4339],
         [ 0.2844, -0.0401, -0.1649, -0.1364,  0.4540,  0.6215, -0.0236,
          -0.0953],
         [ 0.0483,  0.2939,  0.4554, -0.0033,  0.2863,  0.4716,  0.0008,
           0.3516],
         [ 0.2141, -0.3383, -0.0266, -0.1984, -0.6072,  0.3315,  0.1234,
           0.0110],
         [-0.0117, -0.0380,  0.1418, -0.0525,  0.0964, -0.0081, -0.2143,
           0.1133],
         [ 0.3709, -0.1572, -0.4681,  0.1474,  0.3230,  0.2610,  0.0330,
           0.1614],
         [ 0.2141,  0.3489,  0.0231,  0.0038,  0.2014,  0.2452, -0.0324,
          -0.0214],
         [ 0.1194,  0.2238,  0.1033,  0.1442,  0.0537,  0.1099,  0.2743,
          -0.0894],
         [ 0.0340,  0.3576, -0.1026,  0.3789, -0.0056,  0.0310,  0.1677,
           0.4139],
         [ 0.2168, -0.1069,  0.1055,  0.4429,  0.0226,  0.0458, -0.1771,
          -0.0878],
         [ 0.4541,  0.2876,  0.1678,  0.2459,  0.0451,  0.1654,

In [None]:
# Disabled scratch cell: builds one batch with a crop-based collate function
# and plays back each of its components. Flip the condition to run it.
if False:
    n_signal = 8  # length of the target / past crops
    n_signal_timbre = 64  # length of the timbre crop
    bsize = 8


    def crop(arrays, length, idxs):
        """Crop every batched array along its last axis.

        For each array in `arrays`, item k is sliced to
        `[..., idxs[k]:idxs[k] + length]` and the slices are stacked back into
        one tensor. Returns a list with one stacked tensor per input array.
        """
        return [
            torch.stack([xc[..., i:i + length] for i, xc in zip(idxs, array)])
            for array in arrays
        ]


    def collate_fn(batch):
        """Assemble target, past, structure and timbre crops from the "z" entries.

        Returns a dict with:
          "x"                      -- target crop of length n_signal starting at i0
          "x_cond"                 -- timbre crop of length n_signal_timbre at a
                                      second, independent random offset
          "x_time_cond"            -- target crop preceded by `margin` extra frames
          "x_time_cond_additional" -- the n_signal frames right before the target
          "margin"                 -- number of extra leading frames in "x_time_cond"
        """
        margin = n_signal_timbre // 2

        x = torch.from_numpy(np.stack([b["z"] for b in batch], axis=0))

        # One random start per item, leaving room for the margin on the left
        # and for the target crop on the right.
        i0 = np.random.randint(margin, x.shape[-1] - n_signal, x.shape[0])
        x_target_margin = crop([x], n_signal + margin, i0 - margin)[0]
        x_target = crop([x], n_signal, i0)[0]
        x_past = crop([x], n_signal, i0 - n_signal)[0]

        i1 = np.random.randint(0, x.shape[-1] - n_signal_timbre, x.shape[0])
        x_timbre = crop([x], n_signal_timbre, i1)[0]

        return {
            "x": x_target,
            "x_cond": x_timbre,
            "x_time_cond": x_target_margin,
            "x_time_cond_additional": x_past,
            "margin": margin,
        }


    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=bsize,
                                               shuffle=True,
                                               num_workers=0,
                                               drop_last=True,
                                               collate_fn=collate_fn)

    d = next(iter(train_loader))

    x_target = d["x"]
    # Drop the leading margin so only the target-aligned frames remain.
    x_time_cond = d["x_time_cond"][..., d["margin"]:]
    x_timbre = d["x_cond"]
    x_past = d["x_time_cond_additional"]

    # Reference tensor used only to cast the crops to the right dtype/device.
    # It gets its own name so the notebook-level latent `z1` used by the
    # generation cells is not overwritten.
    ref = torch.tensor(0.).to(device)
    audio_target = emb_model.decode(x_target.to(ref)).cpu().numpy()
    audio_time_cond = emb_model.decode(x_time_cond.to(ref)).cpu().numpy()
    audio_timbre = emb_model.decode(x_timbre.to(ref)).cpu().numpy()
    audio_past = emb_model.decode(x_past.to(ref)).cpu().numpy()

    # Despite its name this holds latents (past followed by target); it is
    # decoded to audio in the final display call.
    audio_past_and_cond = torch.cat((x_past, x_time_cond), dim=-1)

    display(Audio(audio_timbre[0], rate=SR))
    display(Audio(audio_target[0], rate=SR))
    display(Audio(audio_time_cond[0], rate=SR))
    display(Audio(audio_past[0], rate=SR))
    display(
        Audio(emb_model.decode(audio_past_and_cond.to(ref)).cpu().numpy()[0],
              rate=SR))