In [1]:
import sys
from pathlib import Path
import torch

ROOT_PATH=Path("..").resolve().absolute()
DEVICE=torch.device("cpu")
sys.path.append(str(ROOT_PATH))

In [2]:
import main.module_base
from audio_diffusion_pytorch import LogNormalDistribution
import torch

diffusion_sigma_distribution = LogNormalDistribution(mean=-3.0, std=1.0)
    
model_1 = main.module_base.Model(
    learning_rate=1e-4,
    beta1=0.9,
    beta2=0.99,
    in_channels=1,
    channels=256,
    patch_factor=16,
    patch_blocks=1,
    resnet_groups=8,
    kernel_multiplier_downsample=2,
    kernel_sizes_init=[1, 3, 7],
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    num_blocks= [2, 2, 2, 2, 2, 2],
    attentions= [False, False, False, True, True, True],
    attention_heads=8,
    attention_features=128,
    attention_multiplier=2,
    use_nearest_upsample=False,
    use_skip_scale=True,
    use_attention_bottleneck=True,
    diffusion_sigma_distribution=diffusion_sigma_distribution,
    diffusion_sigma_data=0.2,
    diffusion_dynamic_threshold=0.0,
)


ckpt = torch.load(ROOT_PATH / "data/checkpoints/drums-worst.ckpt", map_location=DEVICE)
model_1.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [3]:
import main.module_base
from audio_diffusion_pytorch import LogNormalDistribution
import torch

diffusion_sigma_distribution = LogNormalDistribution(mean=-3.0, std=1.0)
    
model_2 = main.module_base.Model(
    learning_rate=1e-4,
    beta1=0.9,
    beta2=0.99,
    in_channels=1,
    channels=256,
    patch_factor=16,
    patch_blocks=1,
    resnet_groups=8,
    kernel_multiplier_downsample=2,
    kernel_sizes_init=[1, 3, 7],
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    num_blocks= [2, 2, 2, 2, 2, 2],
    attentions= [False, False, False, True, True, True],
    attention_heads=8,
    attention_features=128,
    attention_multiplier=2,
    use_nearest_upsample=False,
    use_skip_scale=True,
    use_attention_bottleneck=True,
    diffusion_sigma_distribution=diffusion_sigma_distribution,
    diffusion_sigma_data=0.2,
    diffusion_dynamic_threshold=0.0,
)


ckpt = torch.load(ROOT_PATH / "data/checkpoints/guitar.ckpt", map_location=DEVICE)
model_2.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [4]:
import torchaudio
from IPython.display import Audio
import math

sampling_rate = 22050

# @markdown Generation length in seconds (will be rounded to be a power of 2 of sample_rate*length)
length = 10
length_samples = 2**math.ceil(math.log2(length * sampling_rate))

# @markdown Number of samples to generate
num_samples = 1

# @markdown Number of diffusion steps (higher tends to be better but takes longer to generate)
num_steps = 10

s1, sr1 = torchaudio.load('../data/SLAKH/guitar/test/Track01877.wav')
s2, sr2 = torchaudio.load('../data/SLAKH/drums/test/Track01877.wav')

s1 = torchaudio.functional.resample(s1, orig_freq=sr1, new_freq=sampling_rate)
s2 = torchaudio.functional.resample(s2, orig_freq=sr1, new_freq=sampling_rate)

# display(Audio(s1, rate = sampling_rate))
# display(Audio(s2, rate = sampling_rate))


s1 = s1.reshape(1, 1, -1)
s2 = s2.reshape(1, 1, -1)
m = s1+s2

start_sample = 100 * sampling_rate
m = m[:, :, start_sample:start_sample + length_samples]

display(Audio(m.reshape(1,-1), rate = sampling_rate))

In [5]:
from audio_diffusion_pytorch.diffusion import AEulerSampler, ADPM2Sampler, Diffusion, KarrasSchedule, Sampler, Schedule
DEVICE="cuda"
with torch.no_grad():
     samples = model_1.model.sample(
         noise=torch.randn((num_samples, 1, length_samples), device=DEVICE),
         num_steps=num_steps,
         sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
         sampler=AEulerSampler(),
     )

# Log audio samples
for i, sample in enumerate(samples):
     display(Audio(sample.cpu(), rate = sampling_rate))

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [None]:
from typing import List, Optional

import torch
from torch import Tensor

from audio_diffusion_pytorch.diffusion import AEulerSampler, ADPM2Sampler, Diffusion, KarrasSchedule, Sampler, Schedule
from audio_diffusion_pytorch.model import AudioDiffusionModel
from audio_diffusion_pytorch.utils import default, exists

class DiffusionSeparator:
    def __init__(
        self,
        diffusions: List[Diffusion],
        *,
        samplers: List[Sampler],
        sigma_schedules: List[Schedule],
        num_steps: Optional[int] = None,
        nu = 1
    ):
        super().__init__()
        self.denoise_fns = [diffusion.denoise_fn for diffusion in diffusions]
        self.samplers = samplers
        self.sigma_schedules = sigma_schedules
        self.num_steps = num_steps
        self.nu = nu

    def forward(self, m: Tensor, noises: List[Tensor], num_steps: Optional[int] = None) -> List[Tensor]:
        device = noises[0].device
        num_steps = default(num_steps, self.num_steps)
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        
        # Compute sigmas using schedule
        sigmas_list = [sigma_schedule(num_steps, device) for sigma_schedule in self.sigma_schedules]
        
        # Append additional kwargs to denoise function (used e.g. for conditional unet)
        fns = self.denoise_fns
        
        # Separation procedure
        xs = [sigmas[0] * noise for sigmas, noise in zip(sigmas_list, noises)]
        
        @torch.no_grad()
        def perform_sample_step(xs:list, step:int):
            for j, x in enumerate(xs):
                    yield self.samplers[j].step(
                        x, fn=fns[j], sigma=sigmas_list[j][step], sigma_next=sigmas_list[j][step + 1])
        
        # Denoise to sample
        for step in range(num_steps - 1):
            us = list(perform_sample_step(xs, step))
            g_x = torch.stack(xs).sum(dim=0)
            
            step_size = self.nu * (sigmas_list[0][step]**2) / sigmas_list[0].max()**2
            likelihood = step_size*(m - g_x)
            
            if step % 10 == 0:
                likelihood_norm = torch.sum(m - g_x, dim=[1, 2])**2
                print(likelihood_norm.item())
            
            for i in range(len(xs)):
                xs[i] = us[i] + likelihood
                
        xs = [x.clamp(-1.0, 1.0) for x in xs]
        return xs


@torch.no_grad()
def separate(
    model1,
    model2,
    mixture, 
    device: torch.device = torch.device("cuda"), 
    num_steps:int = 400,
):
    
    batch, in_channels = 1, 1
    samples = mixture.shape[-1]

    m = torch.tensor(mixture).to(device)
    models = [model1.model, model2.model]
    
    for model in models:
        model.to(device)
        
    diffusion_separator = DiffusionSeparator(
        [model.diffusion for model in models],
        samplers=[ADPM2Sampler(), ADPM2Sampler()],
        sigma_schedules=[
            KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
            KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
        ],
        num_steps=num_steps
    )
    noises = [torch.randn_like(m).to(device), torch.randn_like(m).to(device)]
    return diffusion_separator.forward(m, noises)

y1, y2 = separate(model_1, model_2, m)

  m = torch.tensor(mixture).to(device)


6538049.0
15188.7158203125
77200.390625
1085.3975830078125
109738.609375
9.71138858795166
106651.515625
1449.7803955078125
1089.94384765625
134.91065979003906
14132.9091796875
30132.310546875
74376.578125
214995.0625
335152.4375
394983.09375
296823.3125
216214.21875
125996.6171875
61970.70703125
57771.6171875
48452.1953125
31367.033203125
26031.61328125
18785.267578125
11649.0146484375
10186.8623046875
5807.359375
3966.04931640625
3004.455322265625
2073.977783203125
1682.55810546875
1843.791748046875
1745.4512939453125
1479.4041748046875
1077.8255615234375
844.3760375976562
46.05760955810547
744.8848266601562


In [29]:
display(Audio(y1.detach().cpu().view(1,-1), rate = sampling_rate))
display(Audio(y2.detach().cpu().view(1,-1), rate = sampling_rate))
display(Audio((y1+y2).detach().cpu().view(1,-1), rate = sampling_rate))

In [30]:
k = KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0)
k(num_steps=100, device="cpu")

import matplotlib.pyplot as plt
k(num_steps=100, device="cpu")[-1]

tensor(0.)