In [22]:
from pathlib import Path 
import sys
from audio_diffusion_pytorch import AudioDiffusionModel, DiffusionInpainter
from audio_diffusion_pytorch.model import get_default_sampling_kwargs
import torch
import torchaudio
import torch.nn as nn
from ipywidgets import widgets
from IPython.display import Audio, HTML, display
import math

from torch import Tensor
root_path=Path("..").resolve().absolute()
device=torch.device("cuda")#torch.device("cpu")#
sys.path.append(str(root_path))

SAMPLING_RATE=22050
LENGTH_SAMPLES = 2**math.ceil(math.log2(6 * SAMPLING_RATE))


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import hydra
hydra.initialize(config_path="..", job_name="test_app")

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path="..", job_name="test_app")


hydra.initialize()

## Initialize model

In [3]:
device = 'cuda'
cfg = hydra.compose(config_name="config", overrides=["exp=base_slakh_supervised_unconditional_2stem"])
model = hydra.utils.instantiate(cfg.model)

#ckpt = torch.load('logs/ckpts/drums_piano_slakh/epoch=606-valid_loss=0.017.ckpt', map_location=device)["state_dict"]
state_dict = torch.load('/data/drums_piano_epoch=358.pth.tar', map_location=device)["state_dict"]
model.load_state_dict(state_dict)
model : AudioDiffusionModel = model.model.to(device)

In [10]:
cfg = hydra.compose(config_name="config", overrides=["exp=base_slakh_supervised_unconditional"])
model = hydra.utils.instantiate(cfg.model)

ckpt = torch.load(root_path / 'logs/ckpts/all_slakh_epoch=419.ckpt', map_location=device)
model.load_state_dict(ckpt['state_dict'])
model : AudioDiffusionModel = model.model.to(device)

#model = load_context('/home/emilian/Projects/audio-diffusion-pytorch-trainer/logs/ckpts/all_slakh_epoch=419.ckpt', device, 4)
#model : AudioDiffusionModel = model.model.to(device)

# Forward Inference

In [44]:
def check_results_to_save(results: dict, stems: list, sampling_rate: int):
    checkbox_widgets=[]
    num_samples = results[stems[0]].shape[0]
    
    for i in range(num_samples):
        audio_widgets = []
        for stem in cfg.stems:
            out = widgets.Output()
            with out:
                display(f'Sample {i} - Stem {stem}')
                display(Audio(results[stem][i, :], rate = sampling_rate, normalize=False))
            audio_widgets.append(out)

        out = widgets.Output()
        with out:
            display(f'Sample {i} - Mixture')
            display(Audio(results['mix'][i, :], rate = sampling_rate, normalize=False))
        audio_widgets.append(out)

        out = widgets.Output()
        with out:
            #display(f"Save sample {i}?")
            checkbox_widgets.append(widgets.Checkbox(value=False, description='Save sample', disabled=False))
            display(checkbox_widgets[-1])
        audio_widgets.append(out)

        display(widgets.HBox(audio_widgets))
    return checkbox_widgets

def store_results(checkbox_widgets: list, results: dict, stems: list, output_path: str, sampling_rate: int):
    # Store choosen mixtures
    for i, c in enumerate(checkbox_widgets):
        if c.value:
            for s in stems:
                torchaudio.save(root_path/f"{output_path}/sample-{s}-{i}.mp3", src=results[s][i:i+1], sample_rate=sampling_rate)
            torchaudio.save(root_path/f"{output_path}/sample-mix-{i}.mp3", src=results["mix"][i:i+1], sample_rate=sampling_rate)

            #if len(stems) == 2:
            #    t = torch.stack([results[s][i] for s in stems],dim=0)
            #    torchaudio.save(root_path/f"logs/tmp/sample-{i}-stereo.mp3", src=t, sample_rate=sampling_rate)

## Total generation

In [56]:
from audio_diffusion_pytorch import KarrasSchedule, AEulerSampler, ADPM2Sampler
 
noise = torch.randn(12, len(cfg.stems), LENGTH_SAMPLES).to(device)
samples = model.sample(
    noise=noise, 
    num_steps=200, 
    sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=1.0, rho=7.0), 
    sampler=ADPM2Sampler(),
)

res = {}
for i, stem in enumerate(cfg.stems):
    res[stem] = samples[:, i, :].detach().cpu()
res['mix'] = samples.sum(dim=1).detach().cpu()

In [57]:
checkbox = check_results_to_save(res, stems=["drums", "piano"], sampling_rate=SAMPLING_RATE);

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

In [58]:
store_results(
    checkbox_widgets=checkbox,
    results=res, 
    stems=["drums", "piano"],
    output_path="logs/tmp/generation",
    sampling_rate=22050,
)

## Partial generation

Load data

In [52]:
num_stems = len(cfg.stems)

#drums_track, loaded_sr = torchaudio.load('/home/giorgio/drums_conditional.wav')

drums_track, loaded_sr = torchaudio.load('/data/Slakh_supervised/test/Track01881/drums.wav')
piano_track, loaded_sr_2 = torchaudio.load('/data/Slakh_supervised/test/Track01881/piano.wav')

drums_track = torchaudio.functional.resample(drums_track, orig_freq=loaded_sr, new_freq=SAMPLING_RATE)
piano_track = torchaudio.functional.resample(piano_track, orig_freq=loaded_sr_2, new_freq=SAMPLING_RATE)

start_sample = SAMPLING_RATE * 40
drums_track = drums_track[:, start_sample:start_sample + LENGTH_SAMPLES]
piano_track = piano_track[:, start_sample:start_sample + LENGTH_SAMPLES]

assert drums_track.shape[-1] == piano_track.shape[-1] == length_samples, (piano_track.shape[-1], drums_track.shape[-1], length_samples)

display(Audio(drums_track, rate = SAMPLING_RATE))
display(Audio(piano_track, rate = SAMPLING_RATE))

Load model and prepair inpainter

In [53]:
from audio_diffusion_pytorch import KarrasSchedule, AEulerSampler, ADPM2Sampler

#sigma_schedule = get_default_sampling_kwargs()['sigma_schedule']

inpainter = DiffusionInpainter(
    diffusion=model.diffusion,
    num_steps=200, 
    num_resamples=1,
    sampler= ADPM2Sampler(), 
    sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=1.0, rho=7.0),
)

Piano inpainting:

In [7]:
inpaint = torch.randn(32, len(cfg.stems), length_samples).to(device)
inpaint[:, 0, :] = drums_track
inpaint_mask = torch.ones_like(inpaint)
inpaint_mask[:, 1, :] = 0.

samples = inpainter(inpaint, inpaint_mask.bool())

piano_inpaint_results = {stem: samples[:, i, :].detach().cpu() for i, stem in enumerate(cfg.stems)}
piano_inpaint_results['mix'] = samples.sum(dim=1).detach().cpu()

checkbox = check_results_to_save(piano_inpaint_results, stems=["drums", "piano"], sampling_rate=SAMPLING_RATE);

In [None]:
store_results(
    checkbox_widgets=checkbox,
    results=piano_inpaint_results, 
    stems=["drums", "piano"],
    output_path=ROOT_PATH / "logs/tmp/inpainting",
    sampling_rate=SAMPLING_RATE,
)

Drums inpainting:

In [54]:
inpaint = torch.randn(12, len(cfg.stems), length_samples).to(device)
inpaint[:, 1, :] = piano_track
inpaint_mask = torch.ones_like(inpaint)
inpaint_mask[:, 0, :] = 0.

samples = inpainter(inpaint, inpaint_mask.bool())

drums_inpaint_results = {stem: samples[:, i, :].detach().cpu() for i, stem in enumerate(cfg.stems)}
drums_inpaint_results['mix'] = samples.sum(dim=1).detach().cpu()
checkbox = check_results_to_save(drums_inpaint_results, stems=["drums", "piano"], sampling_rate=SAMPLING_RATE);

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

HBox(children=(Output(), Output(), Output(), Output()))

In [55]:
store_results(
    checkbox_widgets=checkbox,
    results=drums_inpaint_results, 
    stems=["drums", "piano"],
    output_path=root_path / "logs/tmp/inpainting",
    sampling_rate=SAMPLING_RATE,
)

### Likelihood on the samples

In [76]:
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]

def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)

@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0
    
    def ode_fn(sigma, x):
        nonlocal fevals
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    
    latent, delta_ll = sol[0][-1], sol[1][-1]
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}

## Style Transfer

In [71]:
sampling_rate = 22050
length = 6
length_samples = 2 ** math.ceil(math.log2(length * sampling_rate))
num_samples = 32
num_steps = 50
num_resamples = 1
num_stems = len(cfg.stems)

start_sample = 22050 * 15

drums_track, loaded_sr = torchaudio.load('data/SLAKH2000_super/test/Track01876/drums.wav')
assert sampling_rate == loaded_sr
drums_track = drums_track[:, start_sample:start_sample + length_samples]

bass_track, loaded_sr = torchaudio.load('data/SLAKH2000_super/test/Track01876/bass.wav')
assert sampling_rate == loaded_sr
bass_track = bass_track[:, start_sample:start_sample + length_samples]

sampler = get_default_sampling_kwargs()['sampler']
sigma_schedule = get_default_sampling_kwargs()['sigma_schedule']

inpainter = DiffusionInpainter(diffusion=model.diffusion, num_steps=num_steps, num_resamples=num_resamples,
                               sampler=sampler, sigma_schedule=sigma_schedule)
inpaint = torch.randn(num_samples, num_stems, length_samples).to(device)

inpaint[:, 0, :] = drums_track
inpaint[:, 1, :] = bass_track

In [None]:
inpaint_mask = torch.ones_like(inpaint)
inpaint_mask[:, 1, :] = 0.
display(Audio(drums_track, rate=sampling_rate))
samples = inpainter(inpaint, inpaint_mask.bool())
res = {}
for i, stem in enumerate(cfg.stems):
    res[stem] = samples[:, i, :].detach().cpu()
res['mix'] = samples.sum(dim=1).detach().cpu()

for i in range(num_samples):
    audio_widgets = []
    for stem in cfg.stems:
        out = widgets.Output()
        with out:
            display(f'Sample {i} - Stem {stem}')
            display(Audio(res[stem][i, :], rate=sampling_rate))
        audio_widgets.append(out)
    out = widgets.Output()
    with out:
        display(f'Sample {i} - Mixture')
        display(Audio(res['mix'][i, :], rate=sampling_rate))
    audio_widgets.append(out)
    display(widgets.HBox(audio_widgets))

In [72]:
# display(Audio(bass_track, rate=sampling_rate))