In [None]:
import os
import torch
import musdb
import soundfile as sf
import numpy as np
from tqdm import tqdm


from linear_cae import Autoencoder 
from linear_cae.baselines import StableAudioVAE  # will only work if you have stable-audio-tools installed


  import pkg_resources


In [None]:

OUTPUT_ROOT = "./demo_output"
MUSDB_ROOT = "data/MUSDB18-HQ"
# Sample rate (Hz) requested from musdb below; also used to express MAX_CHUNK_SIZE in samples.
SAMPLE_RATE = 44100
MAX_CHUNK_SIZE = 10 * SAMPLE_RATE  # 10 seconds, in samples

MODEL_IDS = [
    "m2l",
    "lin-cae",
    "lin-cae-2",
    "stable-audio-vae",
]

# Dictionary of tracks and time segments (in seconds) to process.
TRACKS_TO_PROCESS = {
    "Sambasevam Shanmugam - Kaathaadi": {"start_s": 61, "end_s": 67},
    "Al James - Schoolboy Facination": {"start_s": 15, "end_s": 20},
    "Zeno - Signs": {"start_s": 40, "end_s": 45},
    "Cristina Vane - So Easy": {"start_s": 40, "end_s": 45},
}

# List of scalar values for the latent scaling
SCALARS_FOR_VOCALS = [0.1, 0.5, 2.0]

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

try:
    mus = musdb.DB(root=MUSDB_ROOT, is_wav=True, subsets="test", sample_rate=SAMPLE_RATE)
    print("MUSDB18 dataset loaded successfully.")
except Exception as e:
    print(f"Error loading MUSDB18 dataset from '{MUSDB_ROOT}'. Please check the path. Error: {e}")
    # Every later cell needs `mus`; stop here instead of failing later with a confusing NameError.
    raise

Using device: cuda
MUSDB18 dataset loaded successfully.


In [None]:
def load_track_segment(mus_db, track_name, start_s, end_s, target_sr):
    """Fetch a time window of one MUSDB track together with its four stems.

    Args:
        mus_db: database object whose ``tracks`` are searched by ``name``.
        track_name: exact name of the track to look up.
        start_s: window start, in seconds.
        end_s: window end, in seconds (duration is ``end_s - start_s``).
        target_sr: sample rate the caller expects. It is only compared against
            the track's rate (a warning is printed on mismatch); no resampling
            is performed.

    Returns:
        ``(mixture, stems)``: float32 arrays laid out as [channels, samples]
        (the source audio is transposed). When there is more than one channel,
        everything is averaged down to a single channel, giving [1, samples].
        ``stems`` maps 'vocals', 'bass', 'drums' and 'other' to their arrays.

    Raises:
        ValueError: if ``mus_db`` has no track named ``track_name``.
    """
    # First track whose name matches, or None when there is no match.
    track = next((candidate for candidate in mus_db.tracks if candidate.name == track_name), None)
    if track is None:
        raise ValueError(f"Track '{track_name}' not found in MUSDB database.")

    # musdb chunking attributes: restrict subsequent audio reads to this window.
    track.chunk_start = start_s
    track.chunk_duration = end_s - start_s

    if track.rate != target_sr:
        print(f"Warning: Track sample rate ({track.rate} Hz) differs from model SR ({target_sr} Hz). Resampling is required, which may not be implemented here.")

    mixture = track.audio.T.astype(np.float32)
    stems = {
        stem_name: track.targets[stem_name].audio.T.astype(np.float32)
        for stem_name in ('vocals', 'bass', 'drums', 'other')
    }

    # Already single-channel: nothing to downmix.
    if mixture.shape[0] <= 1:
        return mixture, stems

    # Downmix to mono, keeping the leading channel axis.
    mono_stems = {stem_name: audio.mean(axis=0, keepdims=True) for stem_name, audio in stems.items()}
    return mixture.mean(axis=0, keepdims=True), mono_stems

def save_audio(wav_tensor, path, sr):
    """Save a torch tensor as an audio file (format chosen by soundfile from ``path``).

    Args:
        wav_tensor: torch tensor holding the waveform. It is assumed to be
            channel-first ([channels, samples], possibly with extra singleton
            dims such as [1, channels, samples]). Singleton dims are squeezed.
        path: destination path (str or os.PathLike). A missing parent
            directory is created.
        sr: sample rate in Hz, passed to ``sf.write``.
    """
    parent_dir = os.path.dirname(os.fspath(path))
    # dirname() is '' for a bare filename, and os.makedirs('') would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Move to CPU and drop singleton dims: mono becomes [samples].
    wav_numpy = wav_tensor.detach().cpu().numpy().squeeze()
    # soundfile expects [frames] or [frames, channels]; tensors here are channel-first,
    # so a remaining 2-D array must be transposed.
    if wav_numpy.ndim == 2:
        wav_numpy = wav_numpy.T
    sf.write(path, wav_numpy, sr)

print("Helper functions defined.")

Helper functions defined.


In [None]:
from pathlib import Path

# Normalise to a Path (idempotent if this cell is re-run).
OUTPUT_ROOT = Path(OUTPUT_ROOT)

for model_id in tqdm(MODEL_IDS, desc="Models"):
    # tqdm.write keeps log lines from breaking the nested progress bars.
    tqdm.write(f"\nLoading model: {model_id}...")
    try:
        if model_id == "stable-audio-vae":
            model = StableAudioVAE()
        else:
            model = Autoencoder.from_pretrained(model_id, max_chunk_size=MAX_CHUNK_SIZE, overlap_percentage=0.5)
        model.to(DEVICE).eval()
        sample_rate = model.sample_rate
        tqdm.write("Model loaded successfully.")
    except Exception as e:
        # Best-effort: a model that cannot be loaded (e.g. missing optional dependency) is skipped.
        tqdm.write(f"Could not load model {model_id}. Skipping. Error: {e}")
        continue

    for track_name, times in tqdm(TRACKS_TO_PROCESS.items(), desc=f"Tracks for {model_id}", leave=False):
        tqdm.write(f"\n--- Processing '{track_name}' for model '{model_id}' ---")

        # load_track_segment raises ValueError (it never returns None) for an unknown
        # track; skip that track instead of aborting the whole run.
        try:
            mixture, stems = load_track_segment(
                mus, track_name, times['start_s'], times['end_s'], sample_rate
            )
        except ValueError as e:
            tqdm.write(f"  Skipping track: {e}")
            continue

        # Layout: OUTPUT_ROOT/<track>/<model>/...; model-independent files go one level up.
        output_dir = OUTPUT_ROOT / track_name / model_id
        output_dir.mkdir(parents=True, exist_ok=True)

        mix_tensor = torch.from_numpy(mixture).to(DEVICE).float()
        stem_tensors = {k: torch.from_numpy(v).to(DEVICE).float() for k, v in stems.items()}
        audio_len = mix_tensor.shape[-1]

        with torch.no_grad():
            tqdm.write("  Running: Autoencoding...")
            z_mix = model.encode(mix_tensor)
            recon_mix = model.decode(z_mix, full_length=audio_len)
            save_audio(recon_mix, output_dir / "ae_mix.wav", sample_rate)
            save_audio(mix_tensor, output_dir.parent / "gt_mix.wav", sample_rate)

            # Encode every stem exactly once. The vocals latent is reused for reconstruction,
            # additivity and scaling, so all three experiments see the same latent.
            stem_latents = {name: model.encode(audio) for name, audio in stem_tensors.items()}
            z_vocals = stem_latents['vocals']
            recon_vocals = model.decode(z_vocals, full_length=audio_len)
            save_audio(recon_vocals, output_dir / "ae_vocals.wav", sample_rate)

            tqdm.write("  Running: Additivity...")
            # Decode the sum of the per-stem latents.
            z_sum = sum(stem_latents.values())
            recon_from_sum = model.decode(z_sum, full_length=audio_len)
            save_audio(recon_from_sum, output_dir / "additivity.wav", sample_rate)

            tqdm.write("  Running: Separation...")
            # Latent-space separation: z(mix) - z(mix - vocals), decoded and saved next to
            # the ground-truth vocals for comparison.
            accompaniment = mix_tensor - stem_tensors['vocals']
            z_accomp = model.encode(accompaniment)
            z_sep_vocals = z_mix - z_accomp
            sep_vocals = model.decode(z_sep_vocals, full_length=audio_len)
            save_audio(sep_vocals, output_dir / "sep_vocals.wav", sample_rate)
            save_audio(stem_tensors['vocals'], output_dir / "gt_vocals.wav", sample_rate)

            tqdm.write("  Running: Latent Scaling...")
            for scalar in SCALARS_FOR_VOCALS:
                z_scaled_vocals = z_vocals * scalar
                scaled_vocals = model.decode(z_scaled_vocals, full_length=audio_len)
                # e.g. 0.5 -> scale_0_5.wav
                filename = f"scale_{str(scalar).replace('.', '_')}.wav"
                save_audio(scaled_vocals, output_dir / filename, sample_rate)

print(f"\n✅ All processing complete. Check the '{OUTPUT_ROOT}' directory for results.")

Models:   0%|          | 0/4 [00:00<?, ?it/s]


Loading model: m2l...
Model loaded successfully.





--- Processing 'Sambasevam Shanmugam - Kaathaadi' for model 'm2l' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Al James - Schoolboy Facination' for model 'm2l' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Zeno - Signs' for model 'm2l' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Cristina Vane - So Easy' for model 'm2l' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...


Models:  25%|██▌       | 1/4 [00:36<01:50, 36.87s/it]


Loading model: lin-cae...
Model loaded successfully.





--- Processing 'Sambasevam Shanmugam - Kaathaadi' for model 'lin-cae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Al James - Schoolboy Facination' for model 'lin-cae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Zeno - Signs' for model 'lin-cae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Cristina Vane - So Easy' for model 'lin-cae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...


Models:  50%|█████     | 2/4 [01:13<01:13, 36.59s/it]


Loading model: lin-cae-2...
Model loaded successfully.





--- Processing 'Sambasevam Shanmugam - Kaathaadi' for model 'lin-cae-2' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Al James - Schoolboy Facination' for model 'lin-cae-2' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Zeno - Signs' for model 'lin-cae-2' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Cristina Vane - So Easy' for model 'lin-cae-2' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...


Models:  75%|███████▌  | 3/4 [01:49<00:36, 36.61s/it]


Loading model: stable-audio-vae...
No module named 'flash_attn'
flash_attn not installed, disabling Flash Attention


  WeightNorm.apply(module, name, dim)


Model loaded successfully.





--- Processing 'Sambasevam Shanmugam - Kaathaadi' for model 'stable-audio-vae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Al James - Schoolboy Facination' for model 'stable-audio-vae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Zeno - Signs' for model 'stable-audio-vae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...





--- Processing 'Cristina Vane - So Easy' for model 'stable-audio-vae' ---
  Running: Autoencoding...
  Running: Additivity...
  Running: Separation...
  Running: Latent Scaling...


Models: 100%|██████████| 4/4 [02:22<00:00, 35.68s/it]


✅ All processing complete. Check the '/home/bernardo/bernardo-torres.github.io/documents/audio/linear-cae' directory for results.



