In [31]:
import torchaudio
import math
from functools import partial
import sys
sys.path.append("/home/irene/Documents/audio-diffusion-pytorch-trainer")

from main.separation import ContextualSeparator, differential_with_gaussian
from pathlib import Path
from audio_diffusion_pytorch import KarrasSchedule
import torch
from script.misc import hparams
import main.module_base
from main.dataset import PitchShift
from evaluation.evaluate_separation import si_snr
from IPython.display import Audio
device = torch.device("cuda")
ROOT_PATH = Path(".")
%load_ext autoreload
%autoreload 2

# Sample rate (Hz) that every stem is resampled to before separation.
sampling_rate = 22050

# @markdown Generation length in seconds (the sample count is rounded up to the next power of 2)
length = 10
_length_exponent = math.ceil(math.log2(sampling_rate * length))
length_samples = 2 ** _length_exponent

# @markdown Number of samples to generate
num_samples = 1

# @markdown Number of diffusion steps (higher tends to be better but takes longer to generate)
num_steps = 150

# Number of stems handled per sample (one channel per stem in the inpainting tensors).
num_stems = 4

# Directory holding one sub-folder per stem, and the track evaluated in this notebook.
STEMS_DIR = Path("/home/irene/Documents/audio-diffusion-pytorch-trainer/data/Slack/test")
TRACK_FILE = "Track01881.wav"

# The excerpt starts 100 seconds into the (resampled) track.
start_sample = 100 * sampling_rate


def load_stem_excerpt(stem_name):
    """Load one stem of the track, resample it and cut out the working excerpt.

    Args:
        stem_name: Sub-directory of ``STEMS_DIR`` containing the stem (e.g. ``"bass"``).

    Returns:
        ``(excerpt, orig_sr)`` where ``excerpt`` has shape ``(1, 1, length_samples)``
        and ``orig_sr`` is the sample rate reported by ``torchaudio.load``.

    Raises:
        ValueError: If the resampled audio is too short to contain the full excerpt.
    """
    waveform, orig_sr = torchaudio.load(str(STEMS_DIR / stem_name / TRACK_FILE))
    waveform = torchaudio.functional.resample(waveform, orig_freq=orig_sr, new_freq=sampling_rate)
    # NOTE(review): reshape flattens every channel into a single sequence, which is only
    # meaningful for mono files -- confirm the stems are mono.
    excerpt = waveform.reshape(1, 1, -1)[:, :, start_sample:start_sample + length_samples]
    if excerpt.shape[-1] != length_samples:
        # Slicing past the end silently returns a shorter tensor; fail here with a clear message
        # instead of with a shape mismatch further down the notebook.
        raise ValueError(
            f"stem '{stem_name}' provides only {excerpt.shape[-1]} samples after start_sample="
            f"{start_sample}; {length_samples} are required"
        )
    return excerpt, orig_sr


(s1, sr1), (s2, sr2), (s3, sr3), (s4, sr4) = (
    load_stem_excerpt(stem) for stem in ("bass", "drums", "guitar", "piano")
)

# To listen to an individual stem:
# display(Audio(s1.squeeze(), rate=sampling_rate))

# The mixture to separate is the plain sum of the four stems.
m = s1 + s2 + s3 + s4

# Parameters of the sigma schedule that is handed to the separator below.
smin, smax, rho = 1e-4, 1.0, 7.0
sigma_schedule = KarrasSchedule(
    sigma_min=smin,
    sigma_max=smax,
    rho=rho,
)

def load_model(path):
    """Build the model and restore its weights from a checkpoint.

    Args:
        path: Checkpoint file whose ``"state_dict"`` entry holds the model weights.

    Returns:
        The model (built from ``hparams`` with ``in_channels`` overridden to 4),
        with the checkpoint weights loaded and moved to ``device``.
    """
    model = main.module_base.Model(**{**hparams, "in_channels": 4})
    # Deserialize onto the CPU first: a checkpoint saved from a CUDA device would otherwise
    # fail to load when that device is unavailable, and the weights are moved to `device`
    # right below anyway.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    # NOTE(review): the model is left in its default train/eval mode -- confirm whether
    # model.eval() is wanted for separation.
    model.to(device)
    return model

def gamma_fn(x):
    """Return ``x`` divided by 0.5 (i.e. doubled)."""
    divisor = 0.5
    return x / divisor

# Restore the trained checkpoint and wrap it in the separator used by the cells below.
model = load_model(
    "/home/irene/Documents/audio-diffusion-pytorch-trainer/logs/ckpts/avid-darkness-164_epoch=419-valid_loss=0.015.ckpt"
)
separator = ContextualSeparator(
    model=model,
    stems=["bass", "drums", "guitar", "piano"],
    sigma_schedule=sigma_schedule,
    num_resamples=1,
    # Alternative differential, currently disabled:
    # differential_fn=partial(differential_with_gaussian, gamma_fn=gamma_fn),
)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [194]:
# Number of rounds run by the iterative separation loop, and the probability
# bounds that schedule_prob uses as its defaults.
T = 2
p_max = (T - 1) / T
p_min = p_max * (1 - 0.95)


def schedule_prob(t, T, alpha=0.95, p_min=p_min, p_max=p_max):
    """Return the scheduled probability for step ``t`` of a ``T``-step schedule.

    The value is ``1 - max(0, p_max - (t - 1) * (p_max - p_min) / (alpha * (T - 1)))``.

    Note that the ``p_min``/``p_max`` defaults are captured from the module-level
    values when the function is defined; they do not follow the ``T`` argument.

    Args:
        t: Step index.
        T: Total number of steps (``T == 1`` divides by zero).
        alpha: Scales the denominator of the linear term.
        p_min: Lower probability bound used in the slope.
        p_max: Upper probability bound the linear term is subtracted from.

    Returns:
        The probability for step ``t``; exactly ``1`` once the linear term has
        consumed all of ``p_max``.
    """
    consumed = (t - 1) * (p_max - p_min) / (alpha * (T - 1))
    remaining = p_max - consumed
    # Clip at zero so the returned probability never exceeds 1.
    clipped = remaining if remaining > 0 else 0
    return 1 - clipped

# Show the scheduled probability of every step 0 .. T-1.
display([schedule_prob(step, T) for step in range(T)])

[0.0, 0.5]

In [13]:
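# GIBSS (presumably Gibbs-style iterative separation): intended to re-separate the mixture over T rounds, each round holding a random subset of the previously estimated stems fixed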

def generate_mask_and_sources(sources, fixed_sources_idx=()):
    """Build the inpainting tensor and the mask that pins the selected stems.

    Args:
        sources: Tensor indexed as ``[batch, stem, time]`` that supplies the audio
            of the stems to keep fixed.
        fixed_sources_idx: Indices of the stems to keep fixed. Accepts a list/tuple
            of ints or a 1-D integer tensor. Defaults to no fixed stems.

    Returns:
        ``(inpaint, inpaint_mask)``, both of shape
        ``(num_samples, num_stems, length_samples)``. ``inpaint`` is Gaussian noise
        with the fixed stems copied in from ``sources``; ``inpaint_mask`` is 1 on
        the fixed stems and 0 everywhere else.
    """
    # Normalize to plain ints: iterating a tensor yields 0-d tensors that hash by identity,
    # so comparing them against integer stem indices in a set would never match and every
    # stem would be treated as mobile.
    fixed_idx = sorted({int(idx) for idx in fixed_sources_idx})

    inpaint = torch.randn(num_samples, num_stems, length_samples).to(device)
    inpaint_mask = torch.zeros_like(inpaint)
    if fixed_idx:
        inpaint[:, fixed_idx, :] = sources[:, fixed_idx, :]
        inpaint_mask[:, fixed_idx, :] = 1.
    return inpaint, inpaint_mask

sources_idx = torch.arange(4, device=device)
separations_hint_list = []
sources_gt = torch.cat([s1, s2, s3, s4], axis=1).to(device)
sources = torch.cat([s1, s2, s3, s4], axis=1).to(device)
for i in range(T):
    # Each stem is independently held fixed with the probability scheduled for this round.
    p = schedule_prob(i, T)
    mask = torch.bernoulli(torch.ones(4, device=device) * p).bool()
    # Plain Python ints, so the set arithmetic on these indices (here and in the
    # evaluation cell) compares by value rather than by tensor identity.
    fixed_sources_idx = sources_idx[mask].tolist()
    inpaint, inpaint_mask = generate_mask_and_sources(sources=sources, fixed_sources_idx=fixed_sources_idx)
    # The hint and its mask must actually be passed, otherwise the fixed stems are ignored.
    separations_hint = separator.separate_with_hint(
        mixture=m,
        source_with_hint=inpaint,
        mask=inpaint_mask,
        num_steps=num_steps,
    )
    # Feed this round's estimate into the next round.
    sources = torch.cat([separations_hint["bass"], separations_hint["drums"],
                         separations_hint["guitar"], separations_hint["piano"]], axis=1).to(device)
    # Keep every round's result so it can be inspected or played back afterwards.
    separations_hint_list.append(separations_hint)

In [40]:
# NO GIBSS: a single separation pass with one stem supplied as a fixed hint (no iterative rounds)

def generate_mask_and_sources(sources, fixed_sources_idx=()):
    """Build the inpainting tensor and the mask that pins the selected stems.

    This variant passes the fixed stems through ``torchaudio.functional.pitch_shift``
    before copying them into the inpainting tensor.

    Args:
        sources: Tensor indexed as ``[batch, stem, time]`` that supplies the audio
            of the stems to keep fixed.
        fixed_sources_idx: Indices of the stems to keep fixed. Accepts a list/tuple
            of ints or a 1-D integer tensor. Defaults to no fixed stems.

    Returns:
        ``(inpaint, inpaint_mask)``, both of shape
        ``(num_samples, num_stems, length_samples)``. ``inpaint`` is Gaussian noise
        with the (pitch-shifted) fixed stems copied in; ``inpaint_mask`` is 1 on the
        fixed stems and 0 everywhere else.
    """
    # Normalize to plain ints: iterating a tensor yields 0-d tensors that hash by identity,
    # so comparing them against integer stem indices in a set would never match and every
    # stem would be treated as mobile.
    fixed_idx = sorted({int(idx) for idx in fixed_sources_idx})

    inpaint = torch.randn(num_samples, num_stems, length_samples).to(device)
    inpaint_mask = torch.zeros_like(inpaint)
    if fixed_idx:
        # n_steps=0 requests a shift of zero steps; presumably a placeholder for
        # experimenting with pitch-shifted hints -- confirm.
        inpaint[:, fixed_idx, :] = torchaudio.functional.pitch_shift(
            sources[:, fixed_idx, :], sample_rate=sampling_rate, n_steps=0)
        inpaint_mask[:, fixed_idx, :] = 1.
    return inpaint, inpaint_mask

sources_idx = torch.arange(4, device=device)
separations_hint_list = []

# Ground-truth stems stacked along the channel axis, plus a working copy built the same way.
sources_gt = torch.cat((s1, s2, s3, s4), axis=1).to(device)
sources = torch.cat((s1, s2, s3, s4), axis=1).to(device)

# Stem 0 is supplied as the hint and held fixed; the remaining stems are separated.
fixed_sources_idx = [0]
inpaint, inpaint_mask = generate_mask_and_sources(
    sources=sources,
    fixed_sources_idx=fixed_sources_idx,
)
separations_hint = separator.separate_with_hint(
    mixture=m,
    source_with_hint=inpaint,
    mask=inpaint_mask,
    num_steps=num_steps,
)

# Re-stack the separated stems in the same order as the ground truth.
sources = torch.cat(
    [separations_hint[stem] for stem in ("bass", "drums", "guitar", "piano")],
    axis=1,
).to(device)
# separations_hint_list.append(separations_hint)

In [41]:
sources_names = ["bass", "drums", "guitar", "piano"]
sources_gt = torch.cat([s1, s2, s3, s4], axis=1).to(device)

# SI-SNR of each separated stem against its ground truth, in stem order.
ss = []
for i, stem_name in enumerate(sources_names):
    s = si_snr(separations_hint[stem_name], sources_gt[:, i, :].unsqueeze(0))
    print(s)
    ss.append(s)

# Normalize the fixed indices to plain ints so the exclusion below also works when
# fixed_sources_idx is a tensor (whose elements hash by identity, not by value).
fixed_idx = {int(idx) for idx in fixed_sources_idx}
mobile_sources_idx = [idx for idx in range(len(sources_names)) if idx not in fixed_idx]
ss = torch.tensor(ss)

# "somma" is the mean SI-SNR over the stems that were actually separated (not given as hints).
if mobile_sources_idx:
    somma = sum(ss[mobile_sources_idx]) / len(mobile_sources_idx)
else:
    # Every stem was fixed, so there is nothing to average.
    somma = torch.tensor(float("nan"))
print(f"{somma = }")

98.99530029296875
14.623781204223633
-0.7570466995239258
6.833930015563965
somma = tensor(6.9002)


In [152]:
# Play back the separated bass stem. separations_hint_list is only filled when per-round
# results are appended to it, so fall back to the most recent separation when it is empty
# instead of raising IndexError.
bass_result = (separations_hint_list[0] if separations_hint_list else separations_hint)["bass"]
display(Audio(bass_result.squeeze().cpu().detach(), rate=sampling_rate))