# MIDI-to-audio generation 


In [6]:
import gin

# Enable gin's interactive mode, which is meant for notebook/REPL sessions
# where cells may be executed more than once.
# NOTE(review): exact semantics are defined by gin-config — confirm in its docs.
gin.enter_interactive_mode()

from IPython.display import display, Audio
import torch
import numpy as np
import librosa
import matplotlib.pyplot as plt
import sys
# Make the parent directory importable; presumably that is where the `after`
# package imported by later cells lives — verify against the repo layout.
sys.path.append('..')
# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

import os
# Expose only the GPU with index 1 to CUDA.
# NOTE(review): `device` is set to "cpu" in the checkpoint-setup cell, so this
# only matters once `device` is changed to a CUDA device.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

### Checkpoint setup

In [2]:
# Directory of the trained model; the loading cell reads "config.gin" and
# "checkpoint<step>_EMA.pt" from it.
model_path = ""
# Number embedded in the checkpoint file name (presumably the training step).
step = 0
# File handed to torch.jit.load and attached to the model as `emb_model`.
autoencoder_path = ""
# Device given to RectifiedFlow and used for the model and the input tensors.
device = "cpu"

### Instantiate the model and load the checkpoint

In [None]:
from after.diffusion import RectifiedFlow

# Build both paths with os.path.join so the checkpoint and the config are
# resolved relative to `model_path` in the same way (plain concatenation
# turned an empty `model_path` into a path at the filesystem root).
checkpoint_path = os.path.join(model_path, f"checkpoint{step}_EMA.pt")
config = os.path.join(model_path, "config.gin")

# Parse config and read the two macros needed by the rest of the notebook.
gin.parse_config_file(config)
SR = gin.query_parameter("%SR")
n_signal = gin.query_parameter("%N_SIGNAL")

# Instantiate model
blender = RectifiedFlow(device=device)

# Load checkpoints
# NOTE: torch.load unpickles the file — only open checkpoints you trust.
# The weights are first mapped to CPU; the model is moved to `device` below.
state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state"]
# strict=False: key mismatches between checkpoint and model do not raise
# (torch.nn.Module.load_state_dict semantics, assuming RectifiedFlow is one).
blender.load_state_dict(state_dict, strict=False)

# Emb model: TorchScript module, attached to the model after the state dict
# has been loaded.
emb_model = torch.jit.load(autoencoder_path).eval()
blender.emb_model = emb_model

# Send to device
blender = blender.eval().to(device)

### Load the dataset

In [4]:
from after.dataset import SimpleDataset
from IPython.display import display, Audio

# Location of the dataset to open; fill in before running.
db_path = ""
# Request the "z" and "midi" entries; the following cell reads both of them
# from each dataset item.
dataset = SimpleDataset(path=db_path, keys=["z", "midi"])


In [None]:
# Use the first two dataset entries as the two examples.
d1, d2 = dataset[0], dataset[1]

# Keep at most n_signal positions along the last axis of each latent.
z1 = d1["z"][..., :n_signal]  # guitar
z2 = d2["z"][..., :n_signal]

# Convert to tensors on the target device and prepend a size-1 dimension.
z1 = torch.tensor(z1).to(device).unsqueeze(0)
z2 = torch.tensor(z2).to(device).unsqueeze(0)


def normalize(array):
    """Min-max rescale ``array`` so its values lie in roughly ``[0, 1]``.

    The minimum is mapped to 0. A small constant (1e-6) is added to the
    value range in the denominator, so a constant input yields zeros instead
    of dividing by zero, and the maximum lands just below 1.

    Args:
        array: object providing ``min()``/``max()`` and elementwise
            arithmetic (e.g. a NumPy array).

    Returns:
        The rescaled values, same shape as ``array``.
    """
    lowest = array.min()
    span = array.max() - lowest + 1e-6
    return (array - lowest) / span


# `full_length * ae_ratio / SR` is used below as the clip duration in seconds
# (presumably ae_ratio is the number of audio samples per latent frame —
# TODO confirm against utils.collate_fn).
ae_ratio = gin.query_parameter("utils.collate_fn.ae_ratio")
# Length of the un-cropped latent of the first example; `d1` already holds
# that item, so the dataset is not indexed a second time.
full_length = d1["z"].shape[-1]
# One time stamp per latent position, spanning the whole clip.
times = np.linspace(0, full_length * ae_ratio / SR, full_length)

# Piano roll of each example sampled at `times`, min-max normalised, stacked
# along a new leading axis and cropped like the latents.
midis = [d1["midi"], d2["midi"]]
pr = np.stack([normalize(m.get_piano_roll(times=times)) for m in midis])
pr = pr[..., :n_signal]

# float32 tensor on the target device, with a singleton axis inserted after
# the example axis.
pr = torch.from_numpy(pr).float().unsqueeze(1).to(device)

pr1, pr2 = pr

# Decode both latents back to audio for listening.
x1 = blender.emb_model.decode(z1).cpu().squeeze()
x2 = blender.emb_model.decode(z2).cpu().squeeze()

# Example 1: audio player and piano roll.
display(Audio(x1, rate=SR))
plt.imshow(pr1[0].cpu().numpy(), aspect="auto", origin="lower")
plt.show()

# Example 2: audio player and piano roll.
display(Audio(x2, rate=SR))
plt.imshow(pr2[0].cpu().numpy(), aspect="auto", origin="lower")
plt.show()

#### Generation

In [8]:
nb_steps = 20  # Number of diffusion steps
# NOTE(review): `guidance` is not referenced by the sampling call in the
# generation cell — confirm whether it should be passed on.
guidance = 1.0  # Classifier free guidance strength

In [None]:
# Compute timbre representation
zsem1 = blender.encoder(z1)
zsem2 = blender.encoder(z2)

# Conditioning: piano roll of example 1, encoder output of example 2.
time_cond = pr1
zsem = zsem2

# Sample initial noise
x0 = torch.randn_like(z1)

print("Normal")
# NOTE(review): the three values below are assigned but not forwarded to
# blender.sample in this cell — confirm whether they should be.
total_guidance = 1.0
guidance_joint_factor = 1.0
guidance_cond_factor = 0

xS = blender.sample(x0, time_cond=time_cond, cond=zsem, nb_steps=nb_steps)

# Decode the sampled latent to audio and play it.
audio_out = blender.emb_model.decode(xS)
audio_out = audio_out.cpu().numpy().squeeze()
display(Audio(audio_out, rate=SR))

