In [1]:
import sys
from pathlib import Path
import torch

ROOT_PATH=Path("..").resolve().absolute()
DEVICE=torch.device("cpu")
sys.path.append(str(ROOT_PATH))

In [2]:
import main.module_base
from audio_diffusion_pytorch import LogNormalDistribution
import torch

diffusion_sigma_distribution = LogNormalDistribution(mean=-3.0, std=1.0)
    
model_1 = main.module_base.Model(
    learning_rate=1e-4,
    beta1=0.9,
    beta2=0.99,
    in_channels=1,
    channels=256,
    patch_factor=16,
    patch_blocks=1,
    resnet_groups=8,
    kernel_multiplier_downsample=2,
    kernel_sizes_init=[1, 3, 7],
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    num_blocks= [2, 2, 2, 2, 2, 2],
    attentions= [False, False, False, True, True, True],
    attention_heads=8,
    attention_features=128,
    attention_multiplier=2,
    use_nearest_upsample=False,
    use_skip_scale=True,
    use_attention_bottleneck=True,
    diffusion_sigma_distribution=diffusion_sigma_distribution,
    diffusion_sigma_data=0.2,
    diffusion_dynamic_threshold=0.0,
)


ckpt = torch.load(ROOT_PATH / "data/checkpoints/drums-worst.ckpt", map_location=DEVICE)
model_1.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [3]:
import main.module_base
from audio_diffusion_pytorch import LogNormalDistribution
import torch

diffusion_sigma_distribution = LogNormalDistribution(mean=-3.0, std=1.0)
    
model_2 = main.module_base.Model(
    learning_rate=1e-4,
    beta1=0.9,
    beta2=0.99,
    in_channels=1,
    channels=256,
    patch_factor=16,
    patch_blocks=1,
    resnet_groups=8,
    kernel_multiplier_downsample=2,
    kernel_sizes_init=[1, 3, 7],
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    num_blocks= [2, 2, 2, 2, 2, 2],
    attentions= [False, False, False, True, True, True],
    attention_heads=8,
    attention_features=128,
    attention_multiplier=2,
    use_nearest_upsample=False,
    use_skip_scale=True,
    use_attention_bottleneck=True,
    diffusion_sigma_distribution=diffusion_sigma_distribution,
    diffusion_sigma_data=0.2,
    diffusion_dynamic_threshold=0.0,
)


ckpt = torch.load(ROOT_PATH / "data/checkpoints/guitar.ckpt", map_location=DEVICE)
model_2.load_state_dict(ckpt["state_dict"])

<All keys matched successfully>

In [21]:
import torchaudio
from IPython.display import Audio
import math

sampling_rate = 22050

# @markdown Generation length in seconds (will be rounded to be a power of 2 of sample_rate*length)
length = 10
length_samples = 2**math.ceil(math.log2(length * sampling_rate))

# @markdown Number of samples to generate
num_samples = 1

# @markdown Number of diffusion steps (higher tends to be better but takes longer to generate)
num_steps = 10

s1, sr1 = torchaudio.load('../data/SLAKH/guitar/test/Track01877.wav')
s2, sr2 = torchaudio.load('../data/SLAKH/drums/test/Track01877.wav')

s1 = torchaudio.functional.resample(s1, orig_freq=sr1, new_freq=sampling_rate)
s2 = torchaudio.functional.resample(s2, orig_freq=sr1, new_freq=sampling_rate)

# display(Audio(s1, rate = sampling_rate))
# display(Audio(s2, rate = sampling_rate))


s1 = s1.reshape(1, 1, -1)
s2 = s2.reshape(1, 1, -1)
m = s1+s2

start_sample = 100 * sampling_rate
m = m[:, :, start_sample:start_sample + length_samples]

display(Audio(m.reshape(1,-1), rate = sampling_rate))

In [None]:
class DiffusionSeparator:
    def __init__(
        self,
        diffusions: List[Diffusion],
        *,
        samplers: List[Sampler],
        sigma_schedules: List[Schedule],
        num_steps: Optional[int] = None,
        eta: int = 0.0005
    ):
        super().__init__()
        self.denoise_fns = [diffusion.denoise_fn for diffusion in diffusions]
        self.samplers = samplers
        self.sigma_schedules = sigma_schedules
        self.num_steps = num_steps
        self.eta = eta

    def forward(self, m: Tensor, noises: List[Tensor], num_steps: Optional[int] = None, **kwargs) -> List[Tensor]:
        device = noises[0].device
        num_steps = default(num_steps, self.num_steps)
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        
        # Compute sigmas using schedule
        sigmas_list = [sigma_schedule(num_steps, device) for sigma_schedule in self.sigma_schedules]
        
        # Append additional kwargs to denoise function (used e.g. for conditional unet)
        fns = [lambda *a, **ka: denoise_fn(*a, **{**ka, **kwargs}) for denoise_fn in self.denoise_fns]
        
        # Separation procedure
        xs = [sigmas[0] * noise for sigmas, noise in zip(sigmas_list, noises)]
        
        @torch.no_grad()
        def perform_sample_step(xs:list, step:int):
            for j, x in enumerate(xs):
                    print(fns[j])
                    yield self.samplers[j].step(
                        x, fn=fns[j], sigma=sigmas_list[j][step], sigma_next=sigmas_list[j][step + 1])
        
        # Denoise to sample
        for i in range(num_steps - 1):
            
            # update with respect to the prior
            xs = list(perform_sample_step(xs, i))
            xs = [x.detach() for x in xs]
            
            # compute likelihood function
            # first perturb m with forward noise (works only if sigma scheduler is always the same)
            m_i = m # + torch.randn_like(m) * sigmas_list[0][i]
            #m_i = m[0:1,0:1]
            for x in xs:
                x.requires_grad = True
            
            sum_x = torch.stack(xs).sum(dim=0)
            likelihood = torch.mean(torch.norm(m_i - sum_x, dim=[1, 2]))
            if i%10 == 0: print(likelihood.item())
            likelihood.backward()
            
            for xi, x in enumerate(xs):
                xs[xi] = x - self.eta * x.grad 
                xs[xi].grad = None
        
        xs = [x.clamp(-1.0, 1.0) for x in xs]
        return xs


In [18]:
from audio_diffusion_pytorch.diffusion import AEulerSampler, ADPM2Sampler, Diffusion, KarrasSchedule, Sampler, Schedule
DEVICE="cuda"
with torch.no_grad():
     samples = model_1.model.sample(
         noise=torch.randn((num_samples, 1, length_samples), device=DEVICE),
         num_steps=num_steps,
         sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
         sampler=AEulerSampler(),
     )

# Log audio samples
for i, sample in enumerate(samples):
     display(Audio(sample.cpu(), rate = sampling_rate))

In [27]:
from typing import List, Optional

import torch
from torch import Tensor

from audio_diffusion_pytorch.diffusion import AEulerSampler, ADPM2Sampler, Diffusion, KarrasSchedule, Sampler, Schedule
from audio_diffusion_pytorch.model import AudioDiffusionModel
from audio_diffusion_pytorch.utils import default, exists

class DiffusionSeparator:
    def __init__(
        self,
        diffusions: List[Diffusion],
        *,
        samplers: List[Sampler],
        sigma_schedules: List[Schedule],
        num_steps: Optional[int] = None,
        eta: int = 0.0005
    ):
        super().__init__()
        self.denoise_fns = [diffusion.denoise_fn for diffusion in diffusions]
        self.samplers = samplers
        self.sigma_schedules = sigma_schedules
        self.num_steps = num_steps
        self.eta = eta

    def forward(self, m: Tensor, noises: List[Tensor], num_steps: Optional[int] = None) -> List[Tensor]:
        device = noises[0].device
        num_steps = default(num_steps, self.num_steps)
        assert exists(num_steps), "Parameter `num_steps` must be provided"
        
        # Compute sigmas using schedule
        sigmas_list = [sigma_schedule(num_steps, device) for sigma_schedule in self.sigma_schedules]
        
        # Append additional kwargs to denoise function (used e.g. for conditional unet)
        fns = self.denoise_fns
        
        # Separation procedure
        xs = [sigmas[0] * noise for sigmas, noise in zip(sigmas_list, noises)]
        
        @torch.no_grad()
        def perform_sample_step(xs:list, step:int):
            for j, x in enumerate(xs):
                    yield self.samplers[j].step(
                        x, fn=fns[j], sigma=sigmas_list[j][step], sigma_next=sigmas_list[j][step + 1])
        
        # Denoise to sample
        for step in range(num_steps - 1):
            xs = list(perform_sample_step(xs, step))
            xs = [x.detach() for x in xs]
            
            # compute likelihood function
            # first perturb m with forward noise (works only if sigma scheduler is always the same)
            m_i = m + torch.randn_like(m) * sigmas_list[0][step]

            for x in xs:
                x.requires_grad = True
            
            sum_x = torch.stack(xs).sum(dim=0)
            likelihood = torch.mean(torch.norm(m_i - sum_x, dim=[1, 2]))
            if step%10 == 0: print(likelihood.item())
            likelihood.backward()
            
            for xi, x in enumerate(xs):
                xs[xi] = x - self.eta * x.grad 
                xs[xi].grad = None
                
        xs = [x.clamp(-1.0, 1.0) for x in xs]
        return xs



def separate(
    model1,
    model2,
    mixture, 
    device: torch.device = torch.device("cuda"), 
    num_steps:int = 1000,
):
    
    batch, in_channels = 1, 1
    samples = mixture.shape[-1]

    m = torch.tensor(mixture).to(device)
    models = [model1.model, model2.model]
    
    for model in models:
        model.to(device)
        
    diffusion_separator = DiffusionSeparator(
        [model.diffusion for model in models],
        samplers=[ADPM2Sampler(), ADPM2Sampler()],
        sigma_schedules=[
            KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
            KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
        ],
        eta=0.1,
        num_steps=num_steps,
    )
    noises = [torch.randn_like(m).to(device), torch.randn_like(m).to(device)]
    return diffusion_separator.forward(m, noises)

y1, y2 = separate(model_1, model_2, m)

  m = torch.tensor(mixture).to(device)


2649.098876953125
2488.556396484375
2337.90087890625
2194.560546875
2060.111572265625
1934.0379638671875
1812.6771240234375
1700.4677734375
1597.444091796875
1492.364501953125
1400.33056640625
1308.1129150390625
1223.509765625
1143.3760986328125
1064.0775146484375
994.07958984375
928.039794921875
867.1126708984375
810.0368041992188
753.3858642578125
701.4130859375
654.3328247070312
609.7549438476562
566.4274291992188
527.7482299804688
489.8131103515625
454.92059326171875
421.9647521972656
390.91314697265625
363.1814270019531
336.87310791015625
311.8438720703125
288.92352294921875
267.7077941894531
247.6041259765625
229.88800048828125
211.80308532714844
196.74282836914062
181.91659545898438
168.58517456054688
155.81898498535156
144.8299560546875
134.45921325683594
124.47846221923828
115.95970916748047
108.00172424316406
100.6392593383789
94.18273162841797
88.1115951538086
82.78924560546875
77.75357055664062
73.27642059326172
69.07811737060547
65.33245849609375
62.011817932128906
58.8201

In [29]:
display(Audio(y1.detach().cpu().view(1,-1), rate = sampling_rate))
display(Audio(y2.detach().cpu().view(1,-1), rate = sampling_rate))
display(Audio((y1+y2).detach().cpu().view(1,-1), rate = sampling_rate))