In [1]:
import torch
import torchaudio
import numpy as np
import sys
from pathlib import Path
from typing import *
import sys

sys.path.append('./msldm')
sys.path.append('./SourceVAE')

import audio_diffusion_pytorch
from audio_diffusion_pytorch import AudioDiffusionModel

from main.module_base_latent import Model
import main
from models.model.dac_vae import DACVAE
from audio_diffusion_pytorch import KarrasSchedule

import soundfile as sf
from tqdm import tqdm


In [2]:
def score_differential(x, sigma, denoise_fn):
    """Return the sampler derivative ``(x - denoised) / sigma``.

    Args:
        x: Current noisy sample.
        sigma: Noise level associated with ``x``.
        denoise_fn: Callable invoked as ``denoise_fn(x, sigma=sigma)`` that
            returns the denoised estimate of ``x``.

    Returns:
        The difference between ``x`` and its denoised estimate, divided by
        ``sigma``.
    """
    denoised = denoise_fn(x, sigma=sigma)
    residual = x - denoised
    return residual / sigma

@torch.no_grad()
def generate_track(
    denoise_fn: Callable,
    sigmas: torch.Tensor,
    noises: torch.Tensor,
    source: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    num_resamples: int = 1,
    s_churn: float = 0.0,
    differential_fn: Callable = score_differential,
) -> torch.Tensor:
    """Integrate the sampler over ``sigmas``, optionally keeping parts of ``source`` fixed.

    Args:
        denoise_fn: Denoiser forwarded to ``differential_fn``.
        sigmas: 1-D tensor of noise levels; step ``i`` goes from ``sigmas[i]``
            to ``sigmas[i + 1]``.
        noises: Initial noise; it is scaled by ``sigmas[0]`` to form the
            starting point and fixes the shape of the result.
        source: Reference signal with the same shape as ``noises``. Defaults to
            zeros.
        mask: Tensor with the same shape as ``noises``. Where it is 1 the result
            is taken from ``source``; where it is 0 the result is generated.
            Defaults to zeros (everything generated, ``source`` ignored).
        num_resamples: Number of times every step is repeated, with re-noising
            between repetitions.
        s_churn: Amount of extra noise injected per step; the per-step factor
            is capped at ``sqrt(2) - 1``.
        differential_fn: Callable invoked as
            ``differential_fn(x=..., sigma=..., denoise_fn=...)`` returning the
            derivative used for the update.

    Returns:
        ``mask * source + (1 - mask) * x`` where ``x`` is the final sample.
    """
    x = sigmas[0] * noises

    # Defaults mean "no conditioning". Caller-supplied tensors are aligned to
    # x's device so the element-wise blends below cannot mix devices.
    source = torch.zeros_like(x) if source is None else source.to(x.device)
    mask = torch.zeros_like(x) if mask is None else mask.to(x.device)

    sigmas = sigmas.to(x.device)
    num_steps = len(sigmas) - 1
    # A one-element schedule has no steps; skip the division instead of
    # raising ZeroDivisionError (gamma is unused in that case).
    gamma = min(s_churn / num_steps, 2**0.5 - 1) if num_steps > 0 else 0.0

    # Iterate over all timesteps
    for i in tqdm(range(num_steps)):
        sigma, sigma_next = sigmas[i], sigmas[i+1]

        # Noise source to current noise level
        noisy_source = source + sigma*torch.randn_like(source)

        for r in range(num_resamples):
            # Merge noisy source and current x
            x = mask*noisy_source + (1.0 - mask)*x

            # Inject randomness: raise the noise level from sigma to sigma_hat
            sigma_hat = sigma * (gamma + 1)
            x_hat = x + torch.randn_like(x) * (sigma_hat**2 - sigma**2)**0.5

            # Compute conditioned derivative
            d = differential_fn(x=x_hat, sigma=sigma_hat, denoise_fn=denoise_fn)

            # Update integral
            x = x_hat + d*(sigma_next - sigma_hat)

            # Renoise back to the current level if not last resample step
            if r < num_resamples - 1:
                x = x + torch.randn_like(x) * (sigma**2 - sigma_next**2)**0.5

    return mask*source + (1.0 - mask)*x

### set gpu device

In [3]:
# Prefer the GPU index this notebook was written for, but fall back so the
# notebook still runs on machines with fewer GPUs or without CUDA at all.
if torch.cuda.is_available():
    device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cuda:0'
else:
    device = 'cpu'

### instantiate models (SourceVAE and Latent Diffusion)

In [7]:
import torch.serialization
from main.data import MultiSourceLatentDatasetOld

# Allow-list the dataset class for torch's safe unpickling; presumably the
# diffusion checkpoint loaded further down references it — TODO confirm.
torch.serialization.add_safe_globals([MultiSourceLatentDatasetOld])

In [9]:
# Checkpoint path for the SourceVAE; it is read when the VAE weights are loaded.
sourcevae_ckpt_path = './ckpt/sourcevae_ckpt'
from main.module_base_latent import Model
# model = Model.load_from_checkpoint('./ckpt/msldm_large.ckpt').to(device)
# Allow-list the distribution class before the checkpoint is unpickled.
torch.serialization.add_safe_globals([main.diffusion.UniformDistribution])
# NOTE(review): unlike the commented-out variant above, this model is not moved
# to `device` here, so tensors passed to `denoise_fn` must live wherever the
# checkpoint was loaded — confirm against the cells that put inputs on `device`.
model = Model.load_from_checkpoint('./ckpt/msldm.ckpt') # use the small model
model.eval()
# Denoiser callable handed to generate_track().
denoise_fn = model.model.diffusion.denoise_fn


In [12]:
# Build the SourceVAE on the CPU from an explicit configuration.
vae_config = dict(
    encoder_dim=64,
    encoder_rates=[2, 4, 5, 8],
    latent_dim=80,
    decoder_dim=1536,
    decoder_rates=[8, 5, 4, 2],
    sample_rate=22050,
)
cpu_device = torch.device('cpu')
vae = DACVAE(**vae_config).to(cpu_device)

# Restore the weights stored under the checkpoint's 'generator' key and
# switch the VAE to inference mode.
model_ckpt = torch.load(sourcevae_ckpt_path, map_location=cpu_device)
generator_state = model_ckpt['generator']
vae.load_state_dict(generator_state)
vae.eval()
print('finish loading ckpts from: ', sourcevae_ckpt_path)

  WeightNorm.apply(module, name, dim)


finish loading ckpts from:  ./ckpt/sourcevae_ckpt


# Total Generation

In [15]:
# Sampler settings
s_churn = 20.
batch_size = 1
num_steps = 150
num_resamples = 1

latent_dim=80

# Noise schedule and starting noise, both kept on the CPU
cpu_device = torch.device('cpu')
schedule = KarrasSchedule(sigma_min=1e-2, sigma_max=3, rho=7)(num_steps, cpu_device)
initial_noise = torch.randn(batch_size, latent_dim*4, 1024).to(cpu_device)

# Unconditional sampling: no source and no mask are given
generated_tracks = generate_track(
    denoise_fn,
    sigmas=schedule,
    noises=initial_noise,
    s_churn=s_churn,
    num_resamples=num_resamples,
)

# Split the stacked channels into 4 groups of latent_dim channels and fold the
# groups into the batch axis, decode them, then regroup the waveforms in fours.
bs = generated_tracks.shape[0]
generated_tracks = (
    generated_tracks
    .reshape(bs, 4, latent_dim, -1)
    .reshape(bs*4, latent_dim, -1)
)
with torch.no_grad():
    waves = vae.decode(generated_tracks)
waves = waves.reshape(bs, 4, -1)

100%|██████████| 150/150 [00:39<00:00,  3.78it/s]


In [16]:
# Inspect the shape of the decoded waveforms.
waves.shape

torch.Size([1, 4, 327672])

In [18]:
import numpy as np
from IPython.display import Audio, display

# Convert explicitly instead of np.array(tensor): the implicit conversion is
# the line flagged by the warning in this cell's recorded output, and it cannot
# handle tensors that live on a GPU. `waves` itself is left untouched so the
# cell can be re-run.
waves_np = waves.detach().cpu().numpy() if torch.is_tensor(waves) else np.asarray(waves)

for i, stems in enumerate(waves_np):
    print(f'sample {i}:')

    # Accumulate the mixture in float32, starting from silence.
    mixture = np.zeros_like(stems[0], dtype=np.float32)

    for stem in stems:
        # `stem_wave` rather than `waveform`, which a later cell uses for the
        # audio loaded from disk.
        stem_wave = np.asarray(stem, dtype=np.float32)
        display(Audio(data=stem_wave, rate=22050))
        mixture += stem_wave

    print('mixture')
    display(Audio(data=mixture, rate=22050))

sample 0:


  waves = np.array(waves)  # ou .cpu().numpy() se vier do PyTorch


mixture


## Own generation test

In [19]:
import torchaudio
import torch

# Load the user-provided clip; torchaudio.load returns (waveform, sample_rate).
audio_path = "sample.wav"
waveform, sample_rate = torchaudio.load(audio_path)

In [22]:
# Inspect the loaded clip's tensor shape and sample rate.
waveform.shape, sample_rate

(torch.Size([2, 444416]), 44100)

In [23]:
# NOTE(review): this builds `source` directly from raw audio samples, whereas
# impaint() feeds generate_track() with VAE latents — confirm this is intended.
# NOTE(review): the recorded output shows the clip at 44100 Hz while the VAE is
# configured for 22050 Hz and no resampling happens here — confirm.
target_shape = (1, 4 * latent_dim, 1024)
n_channels = target_shape[1]

# Downmix to mono -> (1, samples), then copy that row once per target channel
mono = waveform.mean(dim=0, keepdim=True)
repeated = mono.repeat(n_channels, 1)

# Force the time axis to exactly 1024: zero-pad on the right or truncate
pad_size = 1024 - repeated.shape[1]
if pad_size > 0:
    repeated = torch.nn.functional.pad(repeated, (0, pad_size))
else:
    repeated = repeated[:, :1024]

# Prepend the batch axis -> (1, 4*latent_dim, 1024)
source = repeated.unsqueeze(0)

In [29]:
# NOTE(review): no `mask` is passed, so generate_track() falls back to an
# all-zero mask and `source` is multiplied by zero in every blend — the result
# is not conditioned on `source`. Confirm whether a mask was intended.
generated_tracks = generate_track(
    denoise_fn,
    sigmas=schedule,
    noises=torch.randn_like(source),
    s_churn=s_churn,
    num_resamples=num_resamples,
    source=source,
)

# Fold the 4 channel groups into the batch axis, decode, then regroup in fours.
bs = generated_tracks.shape[0]
generated_tracks = generated_tracks.reshape(bs, 4, latent_dim, -1)
generated_tracks = generated_tracks.reshape(bs*4, latent_dim, -1)
with torch.no_grad():
    waves = vae.decode(generated_tracks)
    
waves = waves.reshape(bs, 4, -1)

100%|██████████| 150/150 [00:41<00:00,  3.65it/s]


In [30]:
# Inspect the shape of the decoded waveforms.
waves.shape

torch.Size([1, 4, 327672])

In [None]:
# Convert explicitly instead of np.array(tensor): the implicit conversion is
# the line flagged by the warning in this cell's recorded output, and it cannot
# handle tensors that live on a GPU. `waves` itself is left untouched so the
# cell can be re-run.
waves_np = waves.detach().cpu().numpy() if torch.is_tensor(waves) else np.asarray(waves)

for i, stems in enumerate(waves_np):
    print(f'sample {i}:')

    # Accumulate the mixture in float32, starting from silence.
    mixture = np.zeros_like(stems[0], dtype=np.float32)

    for stem in stems:
        # `stem_wave` rather than `waveform`, so the clip loaded from disk
        # under that name is not overwritten.
        stem_wave = np.asarray(stem, dtype=np.float32)
        display(Audio(data=stem_wave, rate=22050))
        mixture += stem_wave

    print('mixture')
    display(Audio(data=mixture, rate=22050))

sample 0:


  waves = np.array(waves)  # ou .cpu().numpy() se vier do PyTorch


mixture


# Partial Generation

In [1]:
import os
import torch
import librosa

### Inpainting function (`impaint`)

In [None]:
# The position of each name is its stem index: it selects the stem's 80-channel
# slice in generate_inpaint_mask() and its row in the tensors used by impaint().
STEMS = ["bass","drums","guitar","piano"] # < IMPORTANT: do not change

@torch.no_grad()
def generate_inpaint_mask(sources, stem_to_inpaint: List[int], latent_dim: int = 80):
    """Build a mask that is 0 on the channels of the stems to regenerate and 1 elsewhere.

    Args:
        sources: 3-D tensor whose second axis stacks ``latent_dim`` channels
            per stem; only its shape, dtype and device are used.
        stem_to_inpaint: Indices of the stems whose channels are zeroed.
        latent_dim: Number of consecutive channels belonging to one stem.

    Returns:
        A tensor shaped like ``sources``, filled with ones except for the
        channel ranges ``[idx * latent_dim, (idx + 1) * latent_dim)`` of every
        listed stem, which are zero.
    """
    mask = torch.ones_like(sources)
    for stem_idx in stem_to_inpaint:
        start = stem_idx * latent_dim
        mask[:, start:start + latent_dim, :] = 0.0
    return mask

def impaint(
        input, 
        schedule, 
        denoise_fn,
        vae,
        stems_to_inpaint=['drums']):
    """Regenerate the selected stems while conditioning on the remaining ones.

    Args:
        input: Waveforms shaped (bs, 4, n_samples), one row per entry of STEMS.
        schedule: Noise levels forwarded to ``generate_track`` as ``sigmas``.
        denoise_fn: Denoiser forwarded to ``generate_track``; it has to accept
            tensors on ``input``'s device.
        vae: Autoencoder providing ``encode(...).mode()`` and ``decode(...)``.
            It is moved to ``input``'s device and put in eval mode.
        stems_to_inpaint: Names of the stems to regenerate; must be a subset
            of STEMS.

    Returns:
        Tuple ``(waves, condition, inpaint, mixture)``: all decoded stems
        shaped (bs, 4, n), the sum of the untouched input stems, the sum of the
        regenerated stems, and the sum of the latter two. ``condition`` and
        ``inpaint`` must have equal lengths for the final sum to succeed.
    """
    # Validate before doing any expensive work.
    unknown_stems = [s for s in stems_to_inpaint if s not in STEMS]
    assert len(unknown_stems) == 0, f"stems_to_inpaint must be a subset of STEMS, got {unknown_stems}"

    num_stems = len(STEMS)
    latent_dim = 80  # channels per stem; must match the width used by generate_inpaint_mask (80)
    bs = input.shape[0] # bs, 4, n_samples

    # Encode and decode on the input's device so the VAE weights and the audio
    # never end up on different devices.
    vae = vae.to(input.device)
    vae.eval()

    input = input.reshape(bs*num_stems, -1)
    with torch.no_grad():
        source_chunk = vae.encode(input.unsqueeze(1)).mode()

    # Stack the per-stem latents along the channel axis: bs, 320, n_frames
    source_chunk = source_chunk.reshape(bs, num_stems*latent_dim, -1)

    stemidx_to_inpaint = [i for i, s in enumerate(STEMS) if s in stems_to_inpaint]
    stemidx_to_condition = [i for i in range(num_stems) if i not in stemidx_to_inpaint]
    # 1 keeps the encoded source, 0 marks channels to regenerate
    inpaint_mask = generate_inpaint_mask(source_chunk, stem_to_inpaint=stemidx_to_inpaint)

    inpainted_tracks = generate_track(
        source=source_chunk,
        mask=inpaint_mask,
        denoise_fn=denoise_fn,
        sigmas=schedule,
        noises=torch.randn_like(source_chunk),
        s_churn=20.0,
        num_resamples=1,
    )

    # Unstack the stems into the batch axis for decoding, then regroup them.
    bs = inpainted_tracks.shape[0]
    inpainted_tracks = inpainted_tracks.reshape(bs*num_stems, latent_dim, -1)
    with torch.no_grad():
        waves = vae.decode(inpainted_tracks)
    waves = waves.reshape(bs, num_stems, -1)

    # Conditioning stems come from the original audio, regenerated stems from
    # the decoder output.
    condition = input.reshape(bs, num_stems, -1)[:, stemidx_to_condition, :].sum(1)
    inpaint = waves[:, stemidx_to_inpaint, :].sum(1)
    mixture = inpaint + condition

    return waves, condition, inpaint, mixture


In [None]:
import torch
import torchaudio
import numpy as np

def load_track(track_folder: Path, stems: List[str]):
    """Load ``<stem>.wav`` for every stem and stack the results.

    Every file must be sampled at 22050 Hz and have a single channel;
    otherwise an AssertionError is raised.

    Returns:
        The per-stem waveforms concatenated along the first axis, with a
        leading batch axis of size 1 added.
    """
    stem_waves = []
    for stem_name in stems:
        stem_path = os.path.join(track_folder, f"{stem_name}.wav")
        audio, rate = torchaudio.load(stem_path)
        assert rate == 22050
        assert audio.shape[0] == 1 # < single channel
        stem_waves.append(audio)
    stacked = torch.cat(stem_waves, dim=0)
    return stacked.unsqueeze(0)


### load audio

In [None]:
# Load audio track. Shape = [1, num_sources, num_samples]
SAMPLE_RATE = 22050
# Waveform length the VAE produced for 1024 latent frames in the generation
# cells above (see the recorded shape [1, 4, 327672]).
CHUNK_SAMPLES = 327672

sources = load_track("./msldm/data/dummy_slakh2100/test/Track01888", STEMS).to(device)
start_second = 20.0
start_sample = int(start_second*SAMPLE_RATE)
source_chunk = sources[:,:, start_sample:start_sample + CHUNK_SAMPLES]

# Slicing past the end of a short track silently yields a shorter chunk that
# only fails later inside impaint(); fail here with a clear message instead.
assert source_chunk.shape[-1] == CHUNK_SAMPLES, (
    f"track too short: got {source_chunk.shape[-1]} samples from second "
    f"{start_second}, need {CHUNK_SAMPLES}"
)


### setup what to partially generate

In [None]:
# Each entry lists the stems to regenerate in one run; names must come from STEMS.
combinations = [['drums'], ['bass', 'drums']]

### partial generation

In [None]:
# Rebuild the 150-step noise schedule on `device` for partial generation; this
# rebinds the `schedule` created for total generation.
schedule = KarrasSchedule(sigma_min=1e-2, sigma_max=3.0, rho=7)(150, device)

In [None]:
source_chunk = source_chunk.to(device)
# The diffusion model was loaded without being moved to `device`, while every
# tensor fed to it in this section lives there; move it next to its inputs.
# `denoise_fn` was taken from this model, so it presumably follows the move —
# verify that model.model.diffusion is a registered submodule.
model = model.to(device)

In [None]:
# Results keyed by the '_'-joined stem names of each combination.
condition_music = {}
inpaint_music = {}
mixture_music = {}
for combo in combinations:
    combo_name = '_'.join(combo)

    waves, condition, inpaint, mixture = impaint(
        source_chunk, schedule, denoise_fn, vae, stems_to_inpaint=combo
    )

    # Keep NumPy copies on the host for playback.
    condition_music[combo_name] = condition.cpu().numpy()
    inpaint_music[combo_name] = inpaint.cpu().numpy()
    mixture_music[combo_name] = mixture.cpu().numpy()
    print(combo, ' finished')



In [None]:
# Play back, for every combination, the conditioning stems, the regenerated
# stems and their sum.
for combo in combinations:
    print(f'---------------------------------------------Generate {combo}------------------------------------------')
    combo_name = '_'.join(combo)
    labelled_clips = (
        ('condition: ', condition_music),
        (f'generated {combo}: ', inpaint_music),
        ('mixture: ', mixture_music),
    )
    for label, clips in labelled_clips:
        print(label)
        display(Audio(data=clips[combo_name], rate=22050))
    