In [2]:
import sys
from pathlib import Path
from typing import *
import torch

# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still executes end-to-end (only slower) on GPU-less hosts.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE = 22050 # < IMPORTANT: do not change
STEMS = ["bass","drums","guitar","piano"] # < IMPORTANT: do not change

# Project root = parent of the current working directory.
# `resolve()` already yields an absolute path, so no extra `.absolute()` is needed.
ROOT_PATH = Path("..").resolve()
CKPT_PATH = ROOT_PATH / "ckpts"   # model checkpoints
DATA_PATH = ROOT_PATH / "data"    # audio data

# Make the project's packages importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
if str(ROOT_PATH) not in sys.path:
    sys.path.append(str(ROOT_PATH))
# Enable IPython's autoreload extension in mode 2, so edits to the project's
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Utility functions

In [None]:
import torch
import torchaudio
import numpy as np

def load_track(track_folder: Path, stems: List[str]):
    """Load one mono wav file per stem and stack them into a single tensor.

    Args:
        track_folder: Directory containing a ``<stem>.wav`` file for each stem.
        stems: Stem names, in the order they should appear in the output.

    Returns:
        Tensor of shape ``[1, len(stems), num_samples]``: the stems are
        concatenated along the first axis and a leading axis of size 1 is added.

    Raises:
        AssertionError: If a file's sample rate differs from ``SAMPLE_RATE``
            or the file is not single-channel.
    """
    channels = []
    for stem_name in stems:
        waveform, sample_rate = torchaudio.load(track_folder / f"{stem_name}.wav")
        assert sample_rate == SAMPLE_RATE
        assert waveform.shape[0] == 1 # < single channel
        channels.append(waveform)

    stacked = torch.cat(channels, dim=0)
    return stacked.unsqueeze(0)

def to_audio_widget(wav: torch.Tensor, normalize: bool = False):
    """Mix a 2-D tensor down to one channel and wrap it in an ``Audio`` player.

    Args:
        wav: 2-D tensor; its rows are summed into a single row before playback.
        normalize: Forwarded unchanged to ``Audio``.

    Returns:
        An ``Audio`` object built at ``SAMPLE_RATE``.

    Note:
        ``Audio`` is imported from ``IPython.display`` in a later cell, so that
        cell has to run before this helper is called.
    """
    assert wav.ndim == 2, f"shape: {wav.shape}"
    # Sum over the first axis (keeping it as size 1) and move the data to the CPU.
    mixdown = wav.sum(dim=0, keepdims=True).cpu()
    return Audio(mixdown, rate=SAMPLE_RATE, normalize=normalize)
    
def wrap_in_out(*obj):
    """Display the given objects inside a new ``widgets.Output`` and return it.

    NOTE(review): ``widgets`` is not imported in any cell visible here --
    presumably ``import ipywidgets as widgets`` is expected; confirm it is in
    scope before this helper is called.
    """
    output_area = widgets.Output()
    # Anything displayed inside the context is captured by the output widget.
    with output_area:
        display(*obj)
    return output_area

def grid_widget(grid_of_objs):
    """Arrange a list of rows of widgets as a vertical stack of horizontal boxes.

    Args:
        grid_of_objs: Iterable of rows; each row is an iterable of widgets.

    Returns:
        A ``widgets.VBox`` holding one ``widgets.HBox`` per row.

    NOTE(review): relies on a ``widgets`` name that is not imported in the
    visible cells (presumably ``ipywidgets``) -- confirm.
    """
    row_boxes = [widgets.HBox(list(row_of_objs)) for row_of_objs in grid_of_objs]
    return widgets.VBox(row_boxes)

In [None]:
import tqdm
from audio_diffusion_pytorch import KarrasSchedule
from main.separation import MSDMSeparator


def score_differential(x, sigma, denoise_fn):
    """Return ``(x - denoise_fn(x, sigma=sigma)) / sigma``.

    This is the per-step derivative that ``generate_track`` integrates by default.

    Args:
        x: Current (noisy) sample.
        sigma: Noise level at which ``x`` is evaluated.
        denoise_fn: Callable invoked as ``denoise_fn(x, sigma=sigma)``.
    """
    d = (x - denoise_fn(x, sigma=sigma)) / sigma
    return d


@torch.no_grad()
def generate_track(
    denoise_fn: Callable,
    sigmas: torch.Tensor,
    noises: torch.Tensor,
    source: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    num_resamples: int = 1,
    s_churn: float = 0.0,
    differential_fn: Callable = score_differential,
) -> torch.Tensor:
    """Sample tracks by stepping through a noise schedule, optionally inpainting.

    Starting from ``sigmas[0] * noises``, each step moves the sample from
    ``sigmas[i]`` to ``sigmas[i + 1]`` with a first-order update using the
    derivative returned by ``differential_fn``. Where ``mask`` is 1 the sample
    is overwritten with a noised copy of ``source`` at every step, so only the
    regions where ``mask`` is 0 are actually generated.

    Args:
        denoise_fn: Denoiser, passed through to ``differential_fn``.
        sigmas: 1-D tensor of noise levels with at least two entries.
        noises: Initial noise; the notebook passes
            ``[batch, num_sources, num_samples]`` tensors.
        source: Known signal used where ``mask`` is 1. Defaults to zeros.
        mask: 1 where ``source`` is kept, 0 where content is generated.
            Defaults to zeros (everything is generated).
        num_resamples: Number of merge/update iterations per noise level.
            Between iterations the sample is re-noised back up.
        s_churn: Amount of extra noise injected per step; converted to a
            per-step factor ``gamma`` capped at ``sqrt(2) - 1``.
        differential_fn: Called as
            ``differential_fn(x=..., sigma=..., denoise_fn=...)``.

    Returns:
        ``mask * source + (1 - mask) * x`` for the final sample ``x``.

    Raises:
        ValueError: If ``sigmas`` has fewer than two entries (there would be
            no step to take and ``gamma`` would divide by zero).
    """
    if len(sigmas) < 2:
        raise ValueError(
            f"sigmas must contain at least two noise levels, got {len(sigmas)}"
        )

    # Put the schedule on the noise tensor's device *before* its first use, so
    # both operands of the multiplication below already agree on device.
    sigmas = sigmas.to(noises.device)
    x = sigmas[0] * noises

    # Initialize default values
    source = torch.zeros_like(x) if source is None else source
    mask = torch.zeros_like(x) if mask is None else mask

    num_steps = len(sigmas) - 1
    gamma = min(s_churn / num_steps, 2**0.5 - 1)

    # Iterate over all timesteps
    for i in tqdm.tqdm(range(num_steps)):
        sigma, sigma_next = sigmas[i], sigmas[i+1]

        # Noise source to current noise level (drawn once per step and reused
        # by every resample iteration of that step)
        noisy_source = source + sigma*torch.randn_like(source)

        for r in range(num_resamples):
            # Merge noisy source and current x
            x = mask*noisy_source + (1.0 - mask)*x

            # Inject randomness: raise the noise level from sigma to sigma_hat
            sigma_hat = sigma * (gamma + 1)
            x_hat = x + torch.randn_like(x) * (sigma_hat**2 - sigma**2)**0.5

            # Compute conditioned derivative
            d = differential_fn(x=x_hat, sigma=sigma_hat, denoise_fn=denoise_fn)

            # Update integral: first-order step from sigma_hat down to sigma_next
            x = x_hat + d*(sigma_next - sigma_hat)

            # Renoise if not last resample step (adds back the variance
            # difference between sigma_next and sigma)
            if r < num_resamples - 1:
                x = x + torch.randn_like(x) * (sigma**2 - sigma_next**2)**0.5

    return mask*source + (1.0 - mask)*x

### Load Model

In [13]:
from main.module_base import Model

# Restore the trained model from its checkpoint, move it to DEVICE, and keep a
# handle on the diffusion denoiser that the sampling code calls.
checkpoint_file = CKPT_PATH / "glorious-star-335/epoch=729-valid_loss=0.014.ckpt"
model = Model.load_from_checkpoint(checkpoint_file).to(DEVICE)
denoise_fn = model.model.diffusion.denoise_fn

## Generation

In [None]:
# Generation hyper-parameters
s_churn = 20.0       # extra noise injected by the sampler
batch_size = 9       # number of tracks generated in one call
num_steps = 150      # number of steps requested from the schedule
num_resamples = 1    # resample iterations per noise level

# Define timestep schedule
karras_schedule = KarrasSchedule(sigma_min=1e-4, sigma_max=20.0, rho=7)
schedule = karras_schedule(num_steps, DEVICE)

# Starting noise: batch_size items x 4 channels x 262144 samples
# (4 presumably matches the four entries of STEMS -- confirm against the model).
initial_noise = torch.randn(batch_size, 4, 262144).to(DEVICE)

# Unconditionally sample from diffusion model
generated_tracks = generate_track(
    denoise_fn,
    sigmas=schedule,
    noises=initial_noise,
    num_resamples=num_resamples,
    s_churn=s_churn,
)

In [None]:
from IPython.display import Audio, HTML, Markdown

num_generations = generated_tracks.shape[0]
w = 3                           # players per grid row
h = -(-num_generations // w)    # number of rows (ceiling division)

# Organize results into a grid: one labelled audio player per generated track,
# filled row by row; the last row may hold fewer than `w` players.
grid = [
    [
        wrap_in_out(
            Markdown(f"**Sample at index** [{index}]:"),
            to_audio_widget(generated_tracks[index, :, :]),
        )
        for index in range(row_start, min(row_start + w, num_generations))
    ]
    for row_start in range(0, h * w, w)
]

# Show results
display(Markdown("## **Generations:**"))
display(grid_widget(grid))

## Partial Generation

In [None]:
# Load audio track. Shape = [1, num_sources, num_samples]
track_folder = DATA_PATH / "dummy_slakh2100/test/Track01888"
sources = load_track(track_folder, STEMS).to(DEVICE)

In [None]:
# Listen to the track sources, one player per stem.
# Iterates over the STEMS constant (the lowercase name `stems` is never defined
# in this notebook) and moves each slice to the CPU before handing it to Audio,
# as to_audio_widget does, since `sources` lives on DEVICE.
for i,s in enumerate(STEMS):
    print(f"{s}:")
    display(Audio(sources[:,i,:].cpu(), rate=SAMPLE_RATE, normalize=True))

In [None]:
# Partial generation hyper-parameters
stems_to_inpaint = {"bass", "guitar", "piano"}  # stems to generate; the others are kept
start_second = 20.0   # start of the conditioning window in the track, in seconds
num_steps = 256       # number of steps requested from the schedule
num_resamples = 1     # resample iterations per noise level
batch_size = 4        # number of inpainted variations produced in one call
s_churn = 10.0        # extra noise injected by the sampler

@torch.no_grad()
def get_rms(source_waveforms):
    """Root mean square of the waveforms, reduced over the last dimension.

    Args:
        source_waveforms: Tensor whose last axis is reduced.

    Returns:
        Tensor with the last axis removed, holding ``sqrt(mean(x ** 2))``.
    """
    mean_power = (source_waveforms ** 2).mean(dim=-1)
    return mean_power ** 0.5

@torch.no_grad()
def generate_inpaint_mask(sources, stem_to_inpaint: List[int]):
    """Build a mask that is 0 on the stems to inpaint and 1 everywhere else.

    Args:
        sources: Tensor whose shape, dtype and device the mask copies; the
            second axis indexes the stems.
        stem_to_inpaint: Indices along the second axis to zero out.

    Returns:
        Tensor shaped like ``sources`` with ones for kept stems and zeros for
        the stems listed in ``stem_to_inpaint``.
    """
    keep_mask = torch.ones_like(sources)
    for idx in stem_to_inpaint:
        keep_mask[:, idx, :] = 0.0
    return keep_mask
        
# Select window from track: 262144 samples starting at `start_second`
chunk_length = 262144
start_sample = int(start_second * SAMPLE_RATE)
source_chunk = sources[:, :, start_sample:start_sample + chunk_length]

# Generate inpainting mask
unknown_stems = [s for s in stems_to_inpaint if s not in STEMS]
assert len(unknown_stems) == 0 # < stems_to_inpaint must be a subset of STEMS
stemidx_to_inpaint = [idx for idx, stem in enumerate(STEMS) if stem in stems_to_inpaint]
inpaint_mask = generate_inpaint_mask(source_chunk, stem_to_inpaint=stemidx_to_inpaint)

# Define timestep schedule
karras_schedule = KarrasSchedule(sigma_min=1e-4, sigma_max=20.0, rho=7)
schedule = karras_schedule(num_steps, DEVICE)

# Starting noise for the whole batch.
# NOTE(review): a single noise draw is repeated `batch_size` times, so every
# batch item starts from the same noise and the outputs can only diverge via
# the noise injected inside the sampler -- confirm this is intended.
initial_noise = torch.randn_like(source_chunk).repeat(batch_size, 1, 1)

# Inpaint tracks together with the original sources
inpainted_tracks = generate_track(
    denoise_fn,
    sigmas=schedule,
    noises=initial_noise,
    source=source_chunk,
    mask=inpaint_mask,
    num_resamples=num_resamples,
    s_churn=s_churn,
)

In [None]:
from IPython.display import Audio, HTML, Markdown

num_inpaints = inpainted_tracks.shape[0]
mask = inpaint_mask.squeeze(0)  # drop the leading size-1 axis -> one row per stem

# Organize results into a grid: one row per inpainted variation, holding the
# full mix on the left and only the generated stems on the right.
grid = []
for i, wav in enumerate(inpainted_tracks):
    # (1 - mask) is 1 exactly on the stems that were inpainted
    inpaint_widget = to_audio_widget((1.0 - mask) * wav)
    mix_widget = to_audio_widget(wav)
    grid.append([
        wrap_in_out(Markdown(f"**Inpainted + Original track** [{i}]:"), mix_widget),
        wrap_in_out(Markdown(f"**Inpainted track** [{i}]:"), inpaint_widget),
    ])

# Show results
display(Markdown("## **Inpainting results:**"))
# `mask` keeps only the stems that were NOT inpainted, i.e. the conditioning stems
kept_stems = mask * source_chunk.squeeze(0)
display(wrap_in_out(Markdown("**Original track**:"), to_audio_widget(kept_stems)))
display(grid_widget(grid))