# Audio-to-audio generation 


In [4]:
import os
import sys

# Expose only physical GPU 1 to this kernel. This must happen before torch is
# imported so no library can initialise CUDA with the full device list first.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import gin
import librosa
import numpy as np
import torch
from IPython.display import Audio, display

# Relax gin's duplicate-registration checks so notebook cells can be re-run.
gin.enter_interactive_mode()

# Inference only: disable autograd globally for the whole notebook.
torch.set_grad_enabled(False)

# Make modules in the parent directory (e.g. `diffusion`, `dataset`) importable.
sys.path.append('..')


### Checkpoint setup

In [5]:
# Which training run to load (folder under ../diffusion/runs) and which
# checkpoint number appears in its "checkpoint<step>_EMA.pt" filename.
name, step = "test_diffusion", 0

# TorchScript autoencoder used to decode latents to audio -- fill in before running.
autoencoder_path = ""

# Device used for inference.
device = "cpu"

### Instantiate the model and load the checkpoint

In [None]:
from diffusion.model import RectifiedFlow

# Fail early with a clear message instead of an obscure loader error later on.
if not autoencoder_path:
    raise ValueError(
        "Set `autoencoder_path` in the checkpoint setup cell before loading the model."
    )

folder = os.path.join("../diffusion/runs", name)
checkpoint_path = os.path.join(folder, f"checkpoint{step}_EMA.pt")
config = os.path.join(folder, "config.gin")

# Parse the run's gin config; sample rate and latent length are read from it.
gin.parse_config_file(config)
SR = gin.query_parameter("%SR")
n_signal = gin.query_parameter("%N_SIGNAL")

# Instantiate model
blender = RectifiedFlow(device=device)

# Load the EMA weights. strict=False tolerates key mismatches, so report them
# explicitly instead of letting a wrong checkpoint load silently.
state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state"]
incompatible = blender.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Missing keys: {incompatible.missing_keys}")
    print(f"Unexpected keys: {incompatible.unexpected_keys}")

# Autoencoder used to decode latents into audio. map_location lets a
# TorchScript module exported on GPU be loaded on the configured device.
emb_model = torch.jit.load(autoencoder_path, map_location=device).eval()
blender.emb_model = emb_model

# Switch to inference mode and send everything to the target device
blender = blender.eval().to(device)

### Load the dataset

In [7]:
from IPython.display import Audio, display
from dataset import CachedSimpleDataset

# Latent database to read from (fill in before running); only the "z" key is requested.
db_path = ""
dataset = CachedSimpleDataset(keys=["z"], path=db_path)


In [8]:
def _load_latent(index):
    """Return item `index`'s "z" latent, last axis cropped to `n_signal`,
    as a tensor on `device` with a leading batch dimension of size 1."""
    z = dataset[index]["z"][..., :n_signal]
    return torch.tensor(z).to(device).unsqueeze(0)


def _decode_to_numpy(z):
    """Decode a latent batch with the autoencoder into a squeezed CPU numpy array."""
    return blender.emb_model.decode(z).cpu().numpy().squeeze()


z1 = _load_latent(0)  # guitar
z2 = _load_latent(1)

# Decoded source audio, converted to numpy like the generated outputs.
x1, x2 = _decode_to_numpy(z1), _decode_to_numpy(z2)

display(Audio(x1, rate=SR))
display(Audio(x2, rate=SR))

#### Generation

In [9]:
# Sampling hyper-parameters.
nb_steps = 10    # number of diffusion (sampling) steps
guidance = 1.0   # classifier-free guidance strength
# NOTE(review): `guidance` is not read by the generation cell, which sets
# `total_guidance` per experiment instead -- confirm whether it is still needed.

In [10]:
def _sample_and_display(noise, time_cond, cond, total_guidance,
                        guidance_joint_factor, guidance_cond_factor):
    """Sample a latent with `blender.sample`, decode it, and play the audio.

    Args:
        noise: initial noise tensor (same shape as the source latents).
        time_cond: structure conditioning (output of `blender.encoder_time`).
        cond: timbre conditioning (output of `blender.encoder`).
        total_guidance, guidance_joint_factor, guidance_cond_factor:
            guidance weights forwarded unchanged to `blender.sample`.

    Returns:
        The sampled latent, so it can be re-encoded by the caller.
    """
    x_sampled = blender.sample(
        noise,
        time_cond=time_cond,
        cond=cond,
        nb_steps=nb_steps,
        guidance_cond_factor=guidance_cond_factor,
        guidance_joint_factor=guidance_joint_factor,
        total_guidance=total_guidance,
    )
    audio = blender.emb_model.decode(x_sampled).cpu().numpy().squeeze()
    display(Audio(audio, rate=SR))
    return x_sampled


# Seed the RNG so the noise draws below -- and hence the audio -- are reproducible.
GENERATION_SEED = 0
torch.manual_seed(GENERATION_SEED)

# Compute structure representation
time_cond1, time_cond2 = blender.encoder_time(z1), blender.encoder_time(z2)

# Compute timbre representation
zsem1, zsem2 = blender.encoder(z1), blender.encoder(z2)

# Structure of example 1 combined with timbre of example 2
time_cond = time_cond1
zsem = zsem2

# Sample initial noise; the first two runs share it, the loops draw fresh noise.
x0 = torch.randn_like(z1)

print("Normal")
xS = _sample_and_display(x0, time_cond, zsem,
                         total_guidance=1.0,
                         guidance_joint_factor=1.0,
                         guidance_cond_factor=0)

print("More guidance on timbre")
xS = _sample_and_display(x0, time_cond, zsem,
                         total_guidance=3.,
                         guidance_joint_factor=.5,
                         guidance_cond_factor=0.8)

# Re-encode the generated latent (kept for interactive inspection).
time_cond_rec = blender.encoder_time(xS)
zsem_rec = blender.encoder(xS)

# Number of random variations per ablation below.
N_VARIATIONS = 5

print("no zsem")
for _ in range(N_VARIATIONS):
    _sample_and_display(torch.randn_like(x0), time_cond, zsem,
                        total_guidance=1.0,
                        guidance_joint_factor=0.,
                        guidance_cond_factor=0.)

print("no time_cond")
# NOTE(review): a constant -2 appears to act as the "dropped" structure
# conditioning value -- confirm against how the model was trained.
null_time_cond = -2. * torch.ones_like(time_cond)
for _ in range(N_VARIATIONS):
    _sample_and_display(torch.randn_like(x0), null_time_cond, zsem,
                        total_guidance=1.0,
                        guidance_joint_factor=0.,
                        guidance_cond_factor=1.)


Normal


More guidance on timbre


no zsem


no time_cond
