In [None]:
import sys
from pathlib import Path
import torch

# Parent of the notebook's working directory; appended to sys.path below so the
# project-local `main` package can be imported by later cells.
ROOT_PATH = Path("..").resolve()

# Prefer the GPU but fall back to the CPU so this cell does not pin the
# notebook to CUDA-only machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Guarded so that re-running the cell does not keep adding duplicate entries.
if str(ROOT_PATH) not in sys.path:
    sys.path.append(str(ROOT_PATH))

%load_ext autoreload
%autoreload 2

In [None]:
import main.module_base
from audio_diffusion_pytorch import LogNormalDistribution
import torch

# Sigma distribution object passed to the model as `diffusion_sigma_distribution`.
diffusion_sigma_distribution = LogNormalDistribution(mean=-3.0, std=1.0)

# Constructor arguments of the first model, collected in one mapping so the
# whole configuration can be read (and compared with the second model) at a glance.
# NOTE(review): presumably these must match the values the checkpoint was
# trained with — confirm before changing any of them.
model_1_hparams = {
    # optimizer settings
    "learning_rate": 1e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    # network layout
    "in_channels": 1,
    "channels": 256,
    "patch_factor": 16,
    "patch_blocks": 1,
    "resnet_groups": 8,
    "kernel_multiplier_downsample": 2,
    "kernel_sizes_init": [1, 3, 7],
    "multipliers": [1, 2, 4, 4, 4, 4, 4],
    "factors": [4, 4, 4, 2, 2, 2],
    "num_blocks": [2, 2, 2, 2, 2, 2],
    "attentions": [False, False, False, True, True, True],
    "attention_heads": 8,
    "attention_features": 128,
    "attention_multiplier": 2,
    "use_nearest_upsample": False,
    "use_skip_scale": True,
    "use_attention_bottleneck": True,
    # diffusion settings
    "diffusion_sigma_distribution": diffusion_sigma_distribution,
    "diffusion_sigma_data": 0.2,
    "diffusion_dynamic_threshold": 0.0,
}
model_1 = main.module_base.Model(**model_1_hparams)

# Load the piano checkpoint onto DEVICE and copy its weights into the model.
ckpt = torch.load(ROOT_PATH / "data/checkpoints/piano.ckpt", map_location=DEVICE)
model_1.load_state_dict(ckpt["state_dict"])

In [None]:
import main.module_base
from audio_diffusion_pytorch import LogNormalDistribution
import torch

# Sigma distribution object passed to the model as `diffusion_sigma_distribution`.
diffusion_sigma_distribution = LogNormalDistribution(mean=-3.0, std=1.0)

# Constructor arguments of the second model, collected in one mapping so the
# whole configuration can be read at a glance.
# NOTE(review): presumably these must match the values the checkpoint was
# trained with — confirm before changing any of them.
model_2_hparams = {
    # optimizer settings
    "learning_rate": 1e-4,
    "beta1": 0.9,
    "beta2": 0.99,
    # network layout
    "in_channels": 1,
    "channels": 256,
    "patch_factor": 16,
    "patch_blocks": 1,
    "resnet_groups": 8,
    "kernel_multiplier_downsample": 2,
    "kernel_sizes_init": [1, 3, 7],
    "multipliers": [1, 2, 4, 4, 4, 4, 4],
    "factors": [4, 4, 4, 2, 2, 2],
    "num_blocks": [2, 2, 2, 2, 2, 2],
    "attentions": [False, False, False, True, True, True],
    "attention_heads": 8,
    "attention_features": 128,
    "attention_multiplier": 2,
    "use_nearest_upsample": False,
    "use_skip_scale": True,
    "use_attention_bottleneck": True,
    # diffusion settings
    "diffusion_sigma_distribution": diffusion_sigma_distribution,
    "diffusion_sigma_data": 0.2,
    "diffusion_dynamic_threshold": 0.0,
}
model_2 = main.module_base.Model(**model_2_hparams)

# Load the bass checkpoint onto DEVICE and copy its weights into the model.
ckpt = torch.load(ROOT_PATH / "data/checkpoints/bass.ckpt", map_location=DEVICE)
model_2.load_state_dict(ckpt["state_dict"])

In [None]:
import torchaudio
from IPython.display import Audio
import math
import matplotlib.pyplot as plt

sampling_rate = 22050

# @markdown Generation length in seconds (the sample count is rounded up to the next power of two)
length = 10
length_samples = 2**math.ceil(math.log2(length * sampling_rate))

# @markdown Number of samples to generate
num_samples = 1

# @markdown Number of diffusion steps (higher tends to be better but takes longer to generate)
num_steps = 10

# Reference recordings of the same track, one file per instrument.
s1, sr1 = torchaudio.load('/data/5B/Piano_22050/test/Track01881.wav')
s2, sr2 = torchaudio.load('/data/5B/Bass_22050/test/Track01881.wav')

# Bring both sources to the working rate. Each file is resampled from ITS OWN
# native rate (the second source previously reused `sr1`, which is wrong
# whenever the two files were stored at different rates).
s1 = torchaudio.functional.resample(s1, orig_freq=sr1, new_freq=sampling_rate)
s2 = torchaudio.functional.resample(s2, orig_freq=sr2, new_freq=sampling_rate)

# Cut the same window (starting 100 s into the track) out of both sources and
# reshape to (1, 1, samples).
# NOTE(review): reshape(1, 1, -1) flattens all channels into one axis, so this
# assumes mono files — confirm.
start_sample = 100 * sampling_rate
s1 = s1.reshape(1, 1, -1)[:, :, start_sample:start_sample + length_samples]
s2 = s2.reshape(1, 1, -1)[:, :, start_sample:start_sample + length_samples]

# Slicing past the end of a short file silently truncates; fail with a clear
# message instead of an opaque shape error (or silent broadcasting) in the sum.
assert s1.shape == s2.shape, f"source windows differ in shape: {tuple(s1.shape)} vs {tuple(s2.shape)}"

# The mixture that the separation cells try to split back into s1 and s2.
m = s1+s2


def plot_waves(s1, s2):
    """Show two signals as stacked spectrograms, then offer them for listening.

    The first signal is drawn in the upper panel and the second in the lower
    one (both with NFFT=1024). After the figure, three audio players are
    displayed in order: `s1`, `s2` and their sum, all at the notebook-level
    `sampling_rate`.

    Args:
        s1: first signal; it is flattened for plotting and reshaped to
            (1, samples) for playback.
        s2: second signal, handled the same way. It must be addable to `s1`.
    """
    fig, (upper_ax, lower_ax) = plt.subplots(2, 1)
    upper_ax.specgram(s1.reshape(-1), NFFT=1024)
    lower_ax.specgram(s2.reshape(-1), NFFT=1024)
    plt.show()

    # Players for each source and for their mixture.
    for signal in (s1, s2, s1 + s2):
        display(Audio(signal.reshape(1, -1), rate=sampling_rate))


plot_waves(s1, s2)

In [None]:
from main.separation import KarrasSeparator, AEulerSeparator
from audio_diffusion_pytorch.diffusion import KarrasSchedule

@torch.no_grad()
def separate(
    model1,
    model2,
    mixture,
    device: torch.device = torch.device("cuda"),
    num_steps: int = 100,
):
    """Separate `mixture` into two sources with `AEulerSeparator`.

    Both models' inner `.model` modules are moved to `device`, a Karras sigma
    schedule is evaluated for `num_steps`, and the separator is run with one
    Gaussian noise tensor (shaped like the mixture) per model.

    Args:
        model1: wrapper whose `.model.diffusion.denoise_fn` is the first denoiser.
        model2: wrapper whose `.model.diffusion.denoise_fn` is the second denoiser.
        mixture: tensor holding the mixed signal; it is moved to `device`.
        device: device used for the mixture, the models and the schedule.
        num_steps: number of steps requested from the schedule and the separator.

    Returns:
        Whatever `AEulerSeparator.forward` returns; the notebook unpacks it
        into two tensors.
    """
    m = mixture.to(device)
    models = [model1.model, model2.model]

    for model in models:
        model.to(device)

    sigma_sched = KarrasSchedule(sigma_min=1e-4, sigma_max=1.0, rho=9.0)
    diffusion_separator = AEulerSeparator(mixture=m, delta=1.0)

    sigma = sigma_sched(num_steps, device)
    fns = [model.diffusion.denoise_fn for model in models]
    # randn_like already allocates on m's device, so no extra .to(device) is needed.
    noises = [torch.randn_like(m), torch.randn_like(m)]
    return diffusion_separator.forward(noises=noises, fns=fns, sigmas=sigma, num_steps=num_steps)


y1, y2 = separate(model_1, model_2, m)

In [None]:
# Player for the mixture that was separated.
display(Audio((m).detach().cpu().view(1,-1), rate = sampling_rate))

import torchmetrics.functional.audio as tma
def si_snr(preds, targets):
    """Return the scale-invariant SNR between `preds` and `targets` as a Python float.

    Both tensors are moved to the CPU before the torchmetrics call.
    """
    score = tma.scale_invariant_signal_noise_ratio(preds.cpu(), targets.cpu())
    return score.item()

# Score each estimate once and reuse the value for the improvement lines.
snr_est_1 = si_snr(y1, s1)
snr_est_2 = si_snr(y2, s2)
# Baseline: how close each source already is to the unprocessed mixture.
snr_mix_1 = si_snr(s1, m)
snr_mix_2 = si_snr(s2, m)

print(f"SI-SNR (1): {snr_est_1}")
print(f"SI-SNR (2): {snr_est_2}")
print(f"SI-SNRi(1): {snr_est_1 - snr_mix_1}")
print(f"SI-SNRi (2): {snr_est_2 - snr_mix_2}")
print(f"SI-SNR (mix): {si_snr(y1+y2, m)}")

# Estimates first, then the residuals (reference minus estimate).
y1_cpu, y2_cpu = y1.cpu(), y2.cpu()
plot_waves(y1_cpu, y2_cpu)
plot_waves(s1.cpu() - y1_cpu, s2.cpu() - y2_cpu)

In [None]:
from typing import List, Optional

import torch
from torch import Tensor

from audio_diffusion_pytorch.diffusion import AEulerSampler, ADPM2Sampler, Diffusion, KarrasSchedule, Sampler, Schedule
from audio_diffusion_pytorch.model import AudioDiffusionModel
from audio_diffusion_pytorch.utils import default, exists

class DiffusionSeparator:
    """Sample one signal per diffusion model while pulling their sum toward a mixture.

    Every step advances each source with its own sampler and then adds the same
    correction, `nu * (m - sum(sources))`, to all of them.
    """

    def __init__(
        self,
        diffusions: List[Diffusion],
        *,
        samplers: List[Sampler],
        sigma_schedules: List[Schedule],
        num_steps: Optional[int] = None,
        nu = 2e-1
    ):
        """Store the per-source denoisers, samplers and schedules.

        Args:
            diffusions: one diffusion object per source; only `denoise_fn` is kept.
            samplers: one sampler per source, used through its `step` method.
            sigma_schedules: one schedule per source, called as `schedule(num_steps, device)`.
            num_steps: default step count used when `forward` is not given one.
            nu: constant factor applied to the mixture residual at every step.
        """
        self.denoise_fns = [diffusion.denoise_fn for diffusion in diffusions]
        self.samplers = samplers
        self.sigma_schedules = sigma_schedules
        self.num_steps = num_steps
        self.nu = nu

    def forward(self, m: Tensor, noises: List[Tensor], num_steps: Optional[int] = None) -> List[Tensor]:
        """Run the separation loop.

        Args:
            m: mixture tensor; must be broadcast-compatible with the sum of the sources.
            noises: one initial noise tensor per source (their device is used for the schedules).
            num_steps: overrides the step count given at construction time.

        Returns:
            One tensor per source, clamped to [-1, 1].

        Raises:
            AssertionError: if no step count was given here or at construction.
        """
        device = noises[0].device
        num_steps = default(num_steps, self.num_steps)
        assert exists(num_steps), "Parameter `num_steps` must be provided"

        # Compute sigmas using schedule
        sigmas_list = [sigma_schedule(num_steps, device) for sigma_schedule in self.sigma_schedules]
        fns = self.denoise_fns

        # Start every source from noise scaled by its first sigma.
        xs = [sigmas[0] * noise for sigmas, noise in zip(sigmas_list, noises)]

        # Stays None if the loop below never reaches a logging step (e.g. num_steps <= 1),
        # so the summary print after the loop cannot hit an unbound name.
        likelihood_norm = None

        # Denoise to sample
        for step in range(num_steps - 1):
            # Prior step: each source is advanced independently by its sampler.
            with torch.no_grad():
                us = [
                    self.samplers[j].step(
                        x, fn=fns[j], sigma=sigmas_list[j][step], sigma_next=sigmas_list[j][step + 1])
                    for j, x in enumerate(xs)
                ]

            # Likelihood step: residual between the mixture and the current sum of sources.
            residual = m - torch.stack(xs).sum(dim=0)
            step_size = self.nu
            likelihood = step_size * residual

            if step % 20 == 0:
                likelihood_norm = torch.norm(residual)
                print("likelihood norm:", likelihood_norm.item())
                print("step-size:", step_size)
                # One line per source, whatever the number of sources.
                for idx, (u, x) in enumerate(zip(us, xs), start=1):
                    print(f"prior {idx} norm:", torch.norm(u - x).item())
                print("")

            # The same correction is added to every source.
            xs = [u + likelihood for u in us]

        xs = [x.clamp(-1.0, 1.0) for x in xs]
        # Residual norm from the most recent logging step, when there was one.
        if likelihood_norm is not None:
            print(likelihood_norm.item())
        return xs


@torch.no_grad()
def separate(
    model1,
    model2,
    mixture,
    device: torch.device = torch.device("cuda"),
    num_steps: int = 202,
):
    """Separate `mixture` into two sources with `DiffusionSeparator`.

    Both models' inner `.model` modules are moved to `device`; each source gets
    an `AEulerSampler` and the same Karras sigma schedule.

    Args:
        model1: wrapper whose `.model.diffusion` drives the first source.
        model2: wrapper whose `.model.diffusion` drives the second source.
        mixture: mixed signal as a tensor or anything `torch.as_tensor` accepts.
        device: device used for the mixture, the models and the noise.
        num_steps: step count handed to the separator.

    Returns:
        The list returned by `DiffusionSeparator.forward` (one tensor per source).
    """
    # as_tensor accepts tensors and array-likes alike, without the extra copy and
    # the UserWarning that torch.tensor() produces when handed a tensor.
    m = torch.as_tensor(mixture).to(device)
    models = [model1.model, model2.model]

    for model in models:
        model.to(device)

    schedule = KarrasSchedule(sigma_min=1e-4, sigma_max=1.0, rho=8.0)

    diffusion_separator = DiffusionSeparator(
        [model.diffusion for model in models],
        samplers=[AEulerSampler(), AEulerSampler()],
        sigma_schedules=[schedule, schedule],
        num_steps=num_steps
    )

    # randn_like already allocates on m's device, so no extra .to(device) is needed.
    noises = [torch.randn_like(m), torch.randn_like(m)]
    return diffusion_separator.forward(m, noises)


def compute_likelihood_steps(sigmas_rev: torch.tensor, gamma:float=1.0) -> torch.tensor:
    """Turn a sigma sequence into one step size per position.

    The sequence is reversed, and for every position the values `2 * sigma**2`
    of all strictly earlier positions (in the reversed order) are summed. The
    step size is `1 / (that sum + gamma**2)`, and the result is reversed again
    so it lines up with `sigmas_rev`.

    Args:
        sigmas_rev: 1-D tensor of sigmas.
        gamma: its square is added to every denominator.

    Returns:
        1-D tensor with the same length as `sigmas_rev`.
    """
    ascending = torch.flip(sigmas_rev, dims=[0])
    doubled_squares = 2 * ascending ** 2

    # Exclusive running sums: entry i covers positions 0 .. i-1.
    running_sums = []
    for count in range(len(ascending)):
        running_sums.append(doubled_squares[:count].sum())

    steps = 1.0 / (torch.tensor(running_sums) + gamma ** 2)
    return torch.flip(steps, dims=[0])


# Separate the mixture `m` with the `separate` defined earlier in this cell;
# this rebinds the notebook-level y1 and y2.
y1, y2 = separate(model_1, model_2, m)

In [None]:
# Write the two reference sources and their mixture into the working directory.
# The files use `sampling_rate`, the rate the signals were resampled to, instead
# of a separately hard-coded number that could drift from it.
for stem, signal in (("s1", s1), ("s2", s2), ("m", m)):
    torchaudio.save(f"{stem}.wav", signal.reshape(1, -1).cpu(), sampling_rate)

# Data Separation

In [None]:
from audio_data_pytorch import WAVDataset, AllTransform

def _make_test_dataset(folder):
    """Build a WAVDataset over `folder` with its own AllTransform instance.

    The transform is configured with source and target rate 22050, a random
    crop of 262144 samples and loudness -20.
    """
    crop_and_normalize = AllTransform(
        source_rate=22050,
        target_rate=22050,
        random_crop_size=262144,
        loudness=-20,
    )
    return WAVDataset(folder, transforms=crop_and_normalize)


# Test splits of the two instruments.
dataset_1 = _make_test_dataset("/data/5B/Piano_22050/test")
dataset_2 = _make_test_dataset("/data/5B/Bass_22050/test")

In [None]:
from pathlib import Path
import tqdm

# Import the progress bar explicitly: a bare `import tqdm` does not load the
# `tqdm.autonotebook` submodule by itself.
from tqdm.auto import tqdm as progress_bar

# Rate used for every file written below; equals the datasets' target_rate.
save_rate = 22050

output_folder = Path("separations")
output_folder.mkdir(exist_ok=True)

# Pair the i-th item of both datasets, separate their sum and store everything
# under separations/<i>/.
for i in progress_bar(range(min(len(dataset_1), len(dataset_2)))):
    x1, x2 = dataset_1[i], dataset_2[i]
    x1, x2 = x1.unsqueeze(0), x2.unsqueeze(0)

    # Both items must be (1, 1, samples) tensors of identical shape.
    assert len(x1.shape) == len(x2.shape) == 3
    assert x1.shape[0:2] == x2.shape[0:2] == (1,1)
    assert x1.shape == x2.shape

    # Mixture of THIS pair; it is what gets separated and what is saved as
    # mix.wav (the notebook-level `m` belongs to a different track).
    mixture = x1 + x2
    y1, y2 = separate(model_1, model_2, mixture)

    separation_folder = output_folder / f"{i}"
    separation_folder.mkdir(exist_ok=True)

    torchaudio.save(separation_folder / "ori1.wav", x1.reshape(1,-1).cpu(), save_rate)
    torchaudio.save(separation_folder / "ori2.wav", x2.reshape(1,-1).cpu(), save_rate)
    torchaudio.save(separation_folder / "sep1.wav", y1.reshape(1,-1).cpu(), save_rate)
    torchaudio.save(separation_folder / "sep2.wav", y2.reshape(1,-1).cpu(), save_rate)
    torchaudio.save(separation_folder / "mix.wav", mixture.reshape(1,-1).cpu(), save_rate)


# Other stuff

In [None]:
from audio_diffusion_pytorch.diffusion import KarrasSampler, AEulerSampler, ADPM2Sampler, Diffusion, KarrasSchedule, Sampler, Schedule
import matplotlib.pyplot as plt

# Unconditional sampling from the first model: 100 steps of the ADPM2 sampler
# on a Karras schedule, starting from Gaussian noise created on DEVICE.
num_steps = 100
initial_noise = torch.randn((num_samples, 1, length_samples), device=DEVICE)
sampling_schedule = KarrasSchedule(sigma_min=1e-4, sigma_max=1.0, rho=9.0)

with torch.no_grad():
    samples = model_1.model.sample(
        noise=initial_noise,
        num_steps=num_steps,
        sigma_schedule=sampling_schedule,
        sampler=ADPM2Sampler(),
    )

# Log audio samples
for i, sample in enumerate(samples):
    sample_cpu = sample.cpu()
    display(Audio(sample_cpu, rate = sampling_rate))
    spec, _, _, img = plt.specgram(sample_cpu.reshape(-1), NFFT=1024,)


In [None]:
import matplotlib.pyplot as plt

def normal_pdf(x, mean:float=0.0, std: float = 1.0):
    """Density of the normal distribution N(mean, std**2) evaluated at `x`.

    Args:
        x: tensor of evaluation points.
        mean: mean of the distribution.
        std: standard deviation of the distribution (must be positive).

    Returns:
        Tensor of densities with the same shape as `x`.
    """
    # Standardize with std (not std**2): the exponent is -z**2 / 2 with z = (x - mean) / std.
    z = (x - mean) / std
    return torch.exp(-0.5 * z**2) / (std * (2*torch.pi)**0.5)

def lognormal_pdf(x:torch.Tensor, mean:float=0.0, std: float=1.0):
    """Density of the log-normal distribution evaluated at `x`.

    Args:
        x: tensor of strictly positive evaluation points.
        mean: mean of log(x).
        std: standard deviation of log(x) (must be positive).

    Returns:
        Tensor of densities with the same shape as `x`.
    """
    # Standardize log(x) with std (not std**2).
    z = (torch.log(x) - mean) / std
    return torch.exp(-0.5 * z**2) / (x * std * (2*torch.pi)**0.5)
        
        
# Evaluation grid and the two analytic densities on it (kept for interactive
# inspection; nothing below plots them).
x = torch.arange(1e-4, 3, 0.001)
normal_y = normal_pdf(x, mean=0, std=1)
lognormal_y = lognormal_pdf(x, mean=-3, std=1)

# 10000 draws from the sigma distribution, sorted ascending, every 10th value plotted.
sorted_sigma_draws = torch.sort(diffusion_sigma_distribution(10000)).values
plt.plot(sorted_sigma_draws[::10])

# Karras schedule for 1000 steps on the same axes, for comparison.
karras_sigmas = KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0)(1000, "cpu")
plt.plot(karras_sigmas)

In [None]:
# Karras sigma schedule, evaluated on the CPU for 99 steps.
# NOTE(review): the name `sigmas_rev` suggests the values run from large to
# small noise — confirm against KarrasSchedule.
k = KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0)
sigmas_rev = k(num_steps=99, device="cpu")

def compute_likelihood_steps(sigmas_rev: torch.tensor, gamma: Optional[float] = None) -> torch.tensor:
    """Turn a sigma sequence into one step size per position.

    The sequence is reversed, and for every position the values `2 * sigma**2`
    of all strictly earlier positions (in the reversed order) are summed. The
    step size is `1 / (that sum + offset)`, and the result is reversed again so
    it lines up with `sigmas_rev`.

    Args:
        sigmas_rev: 1-D tensor of sigmas.
        gamma: when given, the offset is `gamma**2`, which keeps this definition
            call-compatible with the `gamma`-taking version defined earlier in
            the notebook. When omitted, the offset is the constant 5.0.

    Returns:
        1-D tensor with the same length as `sigmas_rev`.
    """
    offset = 5.0 if gamma is None else gamma ** 2
    sigmas = sigmas_rev.flip(dims=[0])
    doubled_squares = 2 * sigmas ** 2
    # Exclusive running sums: entry i covers positions 0 .. i-1.
    running_sums = torch.tensor([doubled_squares[:step].sum() for step in range(len(sigmas))])
    likelihood_steps = 1.0 / (running_sums + offset)
    return likelihood_steps.flip(dims=[0])

# Constant step size drawn as a reference line, one value per schedule entry.
# It is defined here because no notebook-level `step_size` exists otherwise, so
# the plot below would raise NameError on a fresh kernel.
step_size = torch.full((len(sigmas_rev),), 3e-1)

import matplotlib.pyplot as plt
# Schedule-dependent step sizes against the constant reference.
plt.plot(compute_likelihood_steps(sigmas_rev))
plt.plot(step_size)

In [None]:
import numpy as np
# Two grids in steps of 0.01: `x` runs over 0.02 .. 1.00 and `x_next` is the
# same grid shifted down by one step (0.01 .. 0.99).
# NOTE(review): x_up / x_down presumably mirror the sigma_up / sigma_down split
# of an ancestral sampling step — confirm.
x_next = np.arange(1, 100) / 100
x = np.arange(2, 101) / 100

next_sq = x_next ** 2
cur_sq = x ** 2
x_up = np.sqrt(next_sq * (cur_sq - next_sq) / cur_sq)
x_down = np.sqrt(next_sq - x_up ** 2)

# Difference between the "down" level and the current level.
plt.plot(x_down - x)

In [None]:
# Round trip: complex spectrogram -> time-stretch by `speed_rate` -> stretch
# back by its inverse -> waveform, compared against the original first source.
n_fft = 512
hop_length = 64
speed_rate = 1.5
n_freq = n_fft // 2 + 1

spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
invspec = torchaudio.transforms.InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
ts1 = torchaudio.transforms.TimeStretch(hop_length, n_freq=n_freq, fixed_rate=speed_rate)
ts2 = torchaudio.transforms.TimeStretch(hop_length, n_freq=n_freq, fixed_rate=1/speed_rate)

s1_spectrogram = spec(s1)
print(s1_spectrogram.shape)

# Stretch, undo the stretch, invert, then cut to at most the reference length.
s1_rec = invspec(ts2(ts1(s1_spectrogram)))
s1_rec = s1_rec[:, :, :length_samples]
print(si_snr(s1_rec, s1))
plot_waves(s1_rec, s1)

# Zoom on 100 samples to compare the two waveforms directly.
zoom = slice(11000, 11100)
plt.plot(s1_rec.view(-1)[zoom])
plt.plot(s1.view(-1)[zoom])

# Largest absolute sample-wise deviation.
print((s1_rec - s1).abs().max())

In [None]:
# torch.rand needs a size; without one the call raises TypeError and stops a
# "Run All". NOTE(review): scratch cell — the intended shape is unknown, so a
# single-element tensor is used.
torch.rand(1)

In [None]:
# Show the clip length in samples computed in the data-loading cell.
length_samples