In [2]:
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
from encodec.seanet import SEANetEncoder, SEANetDecoder

from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import os

import sklearn
from sklearn.cluster import SpectralClustering
from sklearn.preprocessing import normalize
from sklearn.metrics import silhouette_score
from scipy.sparse import csgraph
from scipy.sparse.linalg import eigsh

import matplotlib.pyplot as plt

In [3]:
class BabySlakhDataset(Dataset):
    """Dataset over every ``*.wav`` file located directly inside ``data_dir``.

    Each item is ``(waveform, filename)``: ``waveform`` is the tensor returned by
    ``torchaudio.load`` (channels x samples), resampled to ``sample_rate`` when the
    file's native rate differs and, if ``normalize`` is set, peak-normalized so its
    largest absolute sample is 1. ``filename`` is the file's base name.
    """

    def __init__(self, data_dir, sample_rate=44100, normalize=True):
        self.data_dir = Path(data_dir)
        # Sorted so the iteration order (and hence output ordering) is reproducible;
        # Path.glob yields files in filesystem-dependent order.
        self.filepaths = sorted(self.data_dir.glob("*.wav"))
        self.sample_rate = sample_rate
        self.normalize = normalize

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        fpath = self.filepaths[idx]
        waveform, sr = torchaudio.load(fpath)

        if sr != self.sample_rate:
            # Bug fix: the resampler used to be constructed but never applied, so
            # files at other rates were returned at their native rate.
            waveform = T.Resample(sr, self.sample_rate)(waveform)

        if self.normalize:
            max_val = waveform.abs().max()
            if max_val > 0:  # leave all-silent files untouched (avoid 0/0)
                waveform = waveform / max_val
        return waveform, fpath.name
    
# Every .wav in this folder becomes one item; iterate them one file at a time, unshuffled.
data_dir = "musdb18_others"
dataset = BabySlakhDataset(
    data_dir,
    normalize=True,
    sample_rate=44100,
)
dataloader = DataLoader(dataset, shuffle=False, batch_size=1)

In [2]:
# Load model: rebuild the SEANet encoder/decoder with the training hyper-parameters,
# then restore their weights from the saved checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ratios = [4, 4, 4, 2]

encoder = SEANetEncoder(
    channels=2,
    dimension=512,
    n_filters=48,
    ratios=ratios,
).to(device)
decoder = SEANetDecoder(
    channels=2,
    dimension=512,
    n_filters=48,
    ratios=ratios,
    final_activation='Tanh',
).to(device)

# Alternative checkpoint location used during development: "best_model_long2.pth" (relative).
# NOTE(review): machine-specific absolute path; consider reading it from a config cell / env var.
checkpoint = torch.load(
    r"C:\Users\barbo\Desktop\AML-\Music-Source-Separation-for-an-Additional-Instrument-encodec\Music-Source-Separation-for-an-Additional-Instrument-encodec\best_model_long2.pth\best_model_long2.pth",
    map_location=device,
)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Switch both networks to eval mode for inference.
encoder.eval()
decoder.eval()
print(f"Loaded best model from epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")


  WeightNorm.apply(module, name, dim)
  checkpoint = torch.load(r"C:\Users\barbo\Desktop\AML-\Music-Source-Separation-for-an-Additional-Instrument-encodec\Music-Source-Separation-for-an-Additional-Instrument-encodec\best_model_long2.pth\best_model_long2.pth", map_location=device)


Loaded best model from epoch 24 with loss 0.1882


In [10]:
# Find ground truth, and mix them together to create a new track
# Try separation on the new track and compare with the ground truth
import torchaudio
from torchaudio.transforms import Resample
import torch

import soundfile as sf

# Inputs: two single-instrument recordings that are mixed into one test track.
file1 = r"C:\Users\barbo\Desktop\AML-\Music-Source-Separation-for-an-Additional-Instrument-encodec\Music-Source-Separation-for-an-Additional-Instrument-encodec\110_F_AcousticGuitar_03_724.wav"
file2 = r"C:\Users\barbo\Desktop\AML-\Music-Source-Separation-for-an-Additional-Instrument-encodec\Music-Source-Separation-for-an-Additional-Instrument-encodec\110_F_Lofipiano_SP_10_406.wav"

# Both inputs are brought to `sample_rate` and cut/padded to `target_duration` seconds.
sample_rate = 44100
target_duration = 6.8
output_file = "processed_combined_song.wav"

def preprocess_audio(file_path, target_duration, sample_rate):
    """Load ``file_path``, resample it to ``sample_rate`` and force it to exactly
    ``int(target_duration * sample_rate)`` samples.

    Longer audio is truncated; shorter audio is right-padded with zeros.
    Returns a ``(channels, samples)`` tensor as produced by ``torchaudio.load``.
    """
    audio, native_sr = torchaudio.load(file_path)
    if native_sr != sample_rate:
        audio = Resample(native_sr, sample_rate)(audio)

    wanted = int(target_duration * sample_rate)
    have = audio.size(1)
    if have > wanted:
        return audio[:, :wanted]
    return torch.nn.functional.pad(audio, (0, wanted - have))

# Bring both stems to the same rate/length, then sum them into one mixture.
waveform1, waveform2 = (
    preprocess_audio(path, target_duration, sample_rate) for path in (file1, file2)
)
combined_waveform = waveform1 + waveform2

# Peak-normalise the mix so it lies in [-1, 1] and cannot clip on quantisation.
max_val = combined_waveform.abs().max()
if max_val > 0:
    combined_waveform = combined_waveform / max_val

# Quantise to signed 16-bit PCM; soundfile expects (frames, channels), hence the transpose.
combined_waveform = (combined_waveform * 32767).short()
sf.write(output_file, combined_waveform.numpy().T, sample_rate, format="WAV", subtype="PCM_16")
print(f"Processed combined track saved as {output_file}")


Processed combined track saved as processed_combined_song.wav


In [None]:
# SEANet downsamples time by the product of its ratios [4, 4, 4, 2].
sample_rate = 44100
TOTAL_DOWNSAMPLE = 4 * 4 * 4 * 2

def pad_to_multiple(waveform: torch.Tensor, multiple: int):
    """
    Right-pad ``waveform`` with zeros along its last (time) axis so that its length
    is a multiple of ``multiple``.

    Works for any rank (e.g. ``(B, C, T)`` or ``(C, T)``); previously the function
    unpacked exactly three dimensions and failed on anything else.

    Args:
        waveform: tensor whose last dimension is time.
        multiple: positive length granularity (e.g. the encoder's total stride).

    Returns:
        (padded_waveform, num_padded). When no padding is needed the input tensor
        itself is returned together with 0.

    Raises:
        ValueError: if ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")
    t = waveform.shape[-1]
    pad_amount = -t % multiple  # distance to the next multiple; 0 if already aligned
    if pad_amount == 0:
        return waveform, 0
    padded = nn.functional.pad(waveform, (0, pad_amount))
    return padded, pad_amount

# Per-file spectral clustering of latent frames, then decode one source per cluster.
output_root = Path("separated_sources")
output_root.mkdir(exist_ok=True, parents=True)

with torch.no_grad():
    for batch_data in dataloader:
        # BabySlakhDataset yields (waveform, filename). This used to unpack four values
        # (waveform, sample_rate, fname, _), which always raised ValueError. With
        # batch_size=1 the default collate wraps the filename in a list.
        waveform, fname = batch_data
        if isinstance(fname, (list, tuple)):
            fname = fname[0]
        waveform = waveform.to(device)  # (B=1, C, T)

        # If mono, duplicate to stereo (the model was built with channels=2).
        if waveform.shape[1] == 1:
            waveform = waveform.repeat(1, 2, 1)

        # 1) Pad so the encoder/decoder round trip reconstructs the full length.
        padded_waveform, pad_amount = pad_to_multiple(waveform, TOTAL_DOWNSAMPLE)

        # 2) Encode, then flatten latent frames to rows for clustering.
        z = encoder(padded_waveform)  # (B=1, 512, T_down)
        z_flat = z.permute(0, 2, 1).reshape(-1, z.shape[1])  # (B*T_down, 512)

        # 3) RBF affinity between L2-normalised latent frames.
        z_flat_norm = normalize(z_flat.cpu().numpy())
        gamma = 0.1
        dist_sq = sklearn.metrics.pairwise.euclidean_distances(z_flat_norm, squared=True)
        similarity_matrix = np.exp(-gamma * dist_sq)

        # scipy's symmetric normalised graph Laplacian of the affinity.
        laplacian = csgraph.laplacian(similarity_matrix, normed=True)

        # 4) Eigengap heuristic on the smallest eigenvalues. The smallest one is
        #    treated as trivial, so a gap between sorted eigenvalues i+1 and i+2
        #    (0-based) suggests i+2 clusters.
        max_clusters = 4
        w_sorted = np.sort(eigsh(laplacian, k=max_clusters, which='SM')[0])
        gaps = np.diff(w_sorted[1:])
        best_k_gap = int(np.argmax(gaps)) + 2
        best_k_gap = max(2, min(best_k_gap, max_clusters))
        # (Alternative: choose k in range(2, max_clusters + 1) by silhouette score.)

        num_clusters = best_k_gap
        print(f"[{fname}] - best_k_gap = {best_k_gap} (raw eigenvalues: {w_sorted})")

        # 5) Spectral clustering with the chosen k, reshaped back to (B=1, T_down).
        spectral = SpectralClustering(
            n_clusters=num_clusters,
            affinity="precomputed",
            random_state=42,
        )
        cluster_labels = spectral.fit_predict(similarity_matrix)
        cluster_labels = torch.tensor(cluster_labels, device=device).view(z.shape[0], z.shape[2])

        # 6) One binary latent mask per cluster; decode each masked latent.
        # NOTE(review): assumes the project's SEANetDecoder accepts `input_length`.
        separated_sources = []
        for cluster_id in range(num_clusters):
            mask = (cluster_labels == cluster_id).float().unsqueeze(1)  # (B=1, 1, T_down)
            separated_source = decoder(z * mask, input_length=padded_waveform.shape[-1])
            if pad_amount > 0:
                separated_source = separated_source[..., :-pad_amount]
            separated_sources.append(separated_source)

        # 7) Save as <file stem>_source_<i>.wav at the module-level sample_rate.
        stem = Path(fname).stem
        for i, source in enumerate(separated_sources):
            out_path = output_root / f"{stem}_source_{i}.wav"
            torchaudio.save(str(out_path), source.squeeze(0).cpu(), sample_rate=sample_rate)
            print(f"Saved: {out_path}")


In [21]:
# Cluster a single song:

import torch
import torch.nn as nn
import torchaudio
import numpy as np
from sklearn.cluster import SpectralClustering
from sklearn.preprocessing import normalize
from scipy.sparse.linalg import eigsh
from scipy.sparse import csgraph
from mir_eval.separation import bss_eval_sources
from pathlib import Path

TOTAL_DOWNSAMPLE = 128  # product of the SEANet ratios [4, 4, 4, 2]
sample_rate = 44100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mixture written by the "mix two stems" cell.
mixture_file = "processed_combined_song.wav"
# NOTE(review): not referenced below in this cell; presumably the ground-truth stems.
original_sources = ["original_stem1.wav", "original_stem2.wav"]

def pad_to_multiple(waveform: torch.Tensor, multiple: int):
    """
    Right-pad ``waveform`` with zeros along its last (time) axis so that its length
    is a multiple of ``multiple``.

    Works for any rank (e.g. ``(B, C, T)`` or ``(C, T)``); previously the function
    unpacked exactly three dimensions and failed on anything else.

    Args:
        waveform: tensor whose last dimension is time.
        multiple: positive length granularity (e.g. the encoder's total stride).

    Returns:
        (padded_waveform, num_padded). When no padding is needed the input tensor
        itself is returned together with 0.

    Raises:
        ValueError: if ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")
    t = waveform.shape[-1]
    pad_amount = -t % multiple  # distance to the next multiple; 0 if already aligned
    if pad_amount == 0:
        return waveform, 0
    padded = nn.functional.pad(waveform, (0, pad_amount))
    return padded, pad_amount

def preprocess_audio(file_path, target_sample_rate, duration=None):
    """Load ``file_path`` and resample it to ``target_sample_rate``.

    If ``duration`` (seconds) is given, the result is truncated or right-padded with
    zeros to exactly ``int(target_sample_rate * duration)`` samples.

    NOTE(review): this redefines the earlier ``preprocess_audio`` with a different
    positional signature (rate before duration); cells run after this one see this version.
    """
    waveform, sr = torchaudio.load(file_path)
    if sr != target_sample_rate:
        waveform = torchaudio.transforms.Resample(sr, target_sample_rate)(waveform)
    if duration is None:
        return waveform

    num_samples = int(target_sample_rate * duration)
    current = waveform.shape[1]
    if current > num_samples:
        return waveform[:, :num_samples]
    return nn.functional.pad(waveform, (0, num_samples - current))

# Cluster a single mixture: encode, spectral-cluster latent frames, decode one source per cluster.
# Everything model-related runs under no_grad; previously the graph was built and .detach() was needed.
with torch.no_grad():
    waveform = preprocess_audio(mixture_file, sample_rate).unsqueeze(0).to(device)  # (1, C, T)
    if waveform.shape[1] == 1:
        waveform = waveform.repeat(1, 2, 1)

    padded_waveform, pad_amount = pad_to_multiple(waveform, TOTAL_DOWNSAMPLE)

    z = encoder(padded_waveform)  # shape: (B=1, feat_dim=512, T_down)
    z_flat = z.permute(0, 2, 1).reshape(-1, z.shape[1])  # (B*T_down, 512)

    # RBF affinity between L2-normalised latent frames.
    z_flat_norm = normalize(z_flat.cpu().numpy())
    gamma = 0.1
    dist_sq = sklearn.metrics.pairwise.euclidean_distances(z_flat_norm, squared=True)
    similarity_matrix = np.exp(-gamma * dist_sq)

    laplacian, _ = csgraph.laplacian(similarity_matrix, normed=True, return_diag=True)

    # eigsh does not document the order of its eigenvalues, so sort before taking gaps.
    w = np.sort(eigsh(laplacian, k=4, which="SM")[0])
    num_clusters = int(max(2, np.argmax(np.diff(w[1:])) + 2))  # skip the trivial eigenvalue
    spectral = SpectralClustering(n_clusters=num_clusters, affinity="precomputed", random_state=42)
    cluster_labels = spectral.fit_predict(similarity_matrix)

    cluster_labels = torch.tensor(cluster_labels, device=device).view(z.shape[0], z.shape[2])
    separated_sources = []
    for cluster_id in range(num_clusters):
        mask = (cluster_labels == cluster_id).float().unsqueeze(1)  # (1, 1, T_down)
        # NOTE(review): assumes the project's SEANetDecoder accepts `input_length`.
        separated_source = decoder(z * mask, input_length=padded_waveform.shape[-1])
        if pad_amount > 0:
            separated_source = separated_source[..., :-pad_amount]
        separated_sources.append(separated_source)

output_root = Path("separated_sources")
output_root.mkdir(exist_ok=True, parents=True)
for i, source in enumerate(separated_sources):
    out_path = output_root / f"source_{i}.wav"
    # Clamp before scaling so any out-of-range sample saturates instead of wrapping in int16.
    source_int16 = (source.squeeze(0).cpu().clamp(-1.0, 1.0) * 32767).to(torch.int16)
    torchaudio.save(str(out_path), source_int16, sample_rate=sample_rate)

In [1]:
import librosa
import mir_eval
import numpy as np

from pathlib import Path

# BSS Eval of one separated estimate against its ground-truth stem.
# Both paths are relative to the working directory (the project root, as for the
# reference file); the estimate is the file written by the single-song clustering cell.
REF_PATH = "110_F_AcousticGuitar_03_724.wav"
EST_PATH = Path("separated_sources") / "source_0.wav"

# Load the estimate at its native rate, and the reference resampled to that same rate,
# so both are sampled identically (librosa.load downmixes to mono by default).
est, sr = librosa.load(EST_PATH, sr=None)
ref, _ = librosa.load(REF_PATH, sr=sr)

# bss_eval_sources requires reference and estimate arrays of identical shape.
n_samples = min(len(ref), len(est))
ref_sources = ref[np.newaxis, :n_samples]
est_sources = est[np.newaxis, :n_samples]

sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(ref_sources, est_sources)

# With a single reference there are no interfering sources, so SIR is typically inf.
print(f"SDR: {sdr}, SIR: {sir}, SAR: {sar}")

SDR: [-4.09080358], SIR: [inf], SAR: [-4.09080358]


In [None]:
import torch
import torchaudio
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
import os

from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import normalize

from scipy import linalg

import matplotlib.pyplot as plt
import soundfile as sf

SR = 44100  # sample rate used to crop inputs and to write the decoded clusters
MAX_SECONDS = 5  # only the first MAX_SECONDS of each file are processed

output_root = Path("latent_separation_eigengap")
output_root.mkdir(parents=True, exist_ok=True)

def build_latent_similarity(z_np, alpha=2, rbf_sigma=None):
    """
    Construct NxN similarity and distance matrices from latent frames.

    Args:
        z_np: array of shape (latent_channels, latent_length). Each time step is a
            vector in R^latent_channels.
        alpha: exponent applied to the RBF similarity (no-op when 1).
        rbf_sigma: RBF bandwidth. If None, it is set to sqrt(median squared distance),
            or 1.0 when that median is ~0.

    Returns:
        (similarity_matrix, distance_matrix), both (N, N) float64 with
        N = latent_length. Similarity is exp(-d^2 / (2 sigma^2)) ** alpha (<= 1).
        Distance is -log(similarity), clipped at 0, with a zero diagonal.

    Squared distances use ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a.b. The previous
    version materialised an (N, N, D) difference tensor, which needs several GB
    for a few thousand 512-dim frames. The math runs in float64 to limit the
    cancellation error of that identity.
    """
    # Transpose => shape (N, D) with N=latent_length, D=latent_channels
    z_T = np.asarray(z_np, dtype=np.float64).T

    sq_norms = np.einsum("ij,ij->i", z_T, z_T)
    dist_mat_sq = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (z_T @ z_T.T)
    np.maximum(dist_mat_sq, 0.0, out=dist_mat_sq)  # rounding can yield tiny negatives
    np.fill_diagonal(dist_mat_sq, 0.0)

    if rbf_sigma is None:
        median_dist = np.median(dist_mat_sq)
        if median_dist < 1e-12:
            median_dist = 1.0
        rbf_sigma = np.sqrt(median_dist)

    # RBF kernel => sim[i,j] = exp(-||z_i - z_j||^2 / (2*sigma^2))
    sim = np.exp(-dist_mat_sq / (2 * (rbf_sigma**2)))

    if alpha != 1:
        sim = sim**alpha

    # Defensive only: an RBF similarity is already <= 1.
    max_val = sim.max()
    if max_val > 1.0:
        sim = sim / (max_val + 1e-12)

    dist = -np.log(sim + 1e-12)
    dist = np.clip(dist, 0, None)
    np.fill_diagonal(dist, 0.0)

    return sim, dist


class spectral_clustering_eigengap:
    """
    Spectral clustering of an NxN similarity matrix, choosing k by the eigengap.

    The Laplacian is built directly from the adjacency ``A`` (the similarity matrix):
    L = D - A with D the diagonal degree matrix. The eigenproblem solved is the
    generalized one, L v = lambda D v (``scipy.linalg.eigh(L, D)``, eigenvalues ascending).
    """
    def __init__(self, similarity_matrix):
        self.A = similarity_matrix
        self.D = np.diag(np.sum(self.A, axis=1))  # degree matrix (row sums of A)
        self.L = self.D - self.A  # unnormalised graph Laplacian
        # Requires D positive definite, i.e. every row of A has a positive sum.
        self.eigen_vals, self.eigen_vecs = linalg.eigh(self.L, self.D)

    def get_k(self, max_k=None):
        """
        Pick the number of clusters as the position of the largest eigengap.

        Args:
            max_k: optional upper bound on the returned k. Only the first
                ``max_k + 1`` eigenvalues are inspected. The default (None) searches
                the whole spectrum as before, which can select gaps far up the
                spectrum and hence very large k.

        Returns:
            k = i + 1 for the first i maximising |lambda_{i+1} - lambda_i|,
            or 1 if all inspected gaps are zero / fewer than two eigenvalues.
        """
        eigen_vals = self.eigen_vals
        if max_k is not None:
            eigen_vals = eigen_vals[: max_k + 1]
        gaps = np.abs(np.diff(eigen_vals))
        k = 1
        if gaps.size and gaps.max() > 0.0:
            k = int(np.argmax(gaps)) + 1  # argmax keeps the first maximum, like a strict '>' scan
        print("eigengap heuristic suggests:", k)
        return k

    def fit(self, k=None, max_k=None):
        """
        Cluster via KMeans on the k eigenvectors with the smallest eigenvalues after
        index 0 (which is skipped as the trivial one), row-normalised to unit L2 norm.

        Args:
            k: number of clusters; if None it is chosen with ``get_k(max_k=max_k)``.
            max_k: forwarded to ``get_k`` when ``k`` is None.

        Returns:
            Integer label per row of the similarity matrix.
        """
        if k is None:
            k = self.get_k(max_k=max_k)

        X = self.eigen_vecs[:, 1 : (1 + k)]

        X_normed = normalize(X, norm='l2')

        pred_label = KMeans(n_clusters=k, random_state=42).fit_predict(X_normed)
        return pred_label


def cluster_latent_dbscan(dist_matrix, eps=0.5, min_samples=5, max_clusters=4):
    """
    Run DBSCAN on a precomputed distance matrix and cap the number of clusters.

    Args:
        dist_matrix: NxN distance matrix.
        eps, min_samples: DBSCAN hyper-parameters.
        max_clusters: if DBSCAN finds more non-noise clusters than this, the labels
            from position ``max_clusters - 1`` onward (in sorted order) are merged
            into one cluster, leaving exactly ``max_clusters`` clusters. Before this
            fix the excess was folded into the last kept cluster, leaving only
            ``max_clusters - 1``, and ``max_clusters=1`` raised IndexError.

    Returns:
        Label array of length N; noise points keep DBSCAN's label -1.

    Raises:
        ValueError: if merging is needed and ``max_clusters < 1``.
    """
    db = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed')
    labels = db.fit_predict(dist_matrix)

    positives = [u for u in np.unique(labels) if u >= 0]  # np.unique is sorted
    if len(positives) > max_clusters:
        if max_clusters < 1:
            raise ValueError(f"max_clusters must be >= 1, got {max_clusters}")
        print(f"DBSCAN found {len(positives)} clusters, merging to {max_clusters} ...")
        merge_rest = positives[max_clusters - 1:]
        target = merge_rest[0]  # lowest excess label absorbs all the others

        labels = labels.copy()
        labels[np.isin(labels, merge_rest)] = target

    return labels


# For each file: encode, cluster latent frames with the eigengap spectral clusterer,
# decode each cluster separately and write the results under output_root/<fname>/.
with torch.no_grad():
    for batch_data in dataloader:
        waveform, fname = batch_data
        if isinstance(fname, (list, tuple)):  # default collate wraps strings in a list
            fname = fname[0]

        # Keep the audio on the encoder/decoder's device. This used to force "cpu",
        # which fails with a device mismatch when the model was moved to CUDA.
        waveform = waveform.to(device)

        if waveform.shape[1] == 1:
            waveform = waveform.repeat(1, 2, 1)

        # Only the first MAX_SECONDS are used; the similarity matrix is N x N in latent frames.
        max_samples = SR * MAX_SECONDS
        if waveform.shape[-1] > max_samples:
            waveform = waveform[..., :max_samples]

        z = encoder(waveform)
        z_np = z.squeeze(0).cpu().numpy()  # (latent_channels, latent_length)

        similarity_matrix, distance_matrix = build_latent_similarity(z_np, alpha=1.0, rbf_sigma=None)

        print(f"\n=== Processing {fname} ===")
        print(f"latent shape: {z_np.shape}, similarity_matrix: {similarity_matrix.shape}")

        sc = spectral_clustering_eigengap(similarity_matrix)
        k_estimated = sc.get_k()
        labels = sc.fit(k=k_estimated)

        unique_labels = np.unique(labels)
        positives = [u for u in unique_labels if u >= 0]
        print(f"Clusters found: {len(positives)} (eigengap => k={k_estimated})")

        # Zero every latent frame outside the cluster, then decode.
        separated_signals = []
        for c_id in positives:
            in_cluster = labels == c_id
            z_masked = np.zeros_like(z_np)
            z_masked[:, in_cluster] = z_np[:, in_cluster]

            z_masked_torch = torch.from_numpy(z_masked).unsqueeze(0).float().to(device)
            separated_decoded = decoder(z_masked_torch)
            separated_signals.append(separated_decoded.squeeze(0).cpu().numpy())

        out_folder = output_root / fname
        out_folder.mkdir(parents=True, exist_ok=True)

        for i, audio_c in enumerate(separated_signals):
            out_path = out_folder / f"cluster_{i}.wav"
            sf.write(str(out_path), audio_c.T, SR)  # soundfile expects (frames, channels)

        print(f"Saved {len(separated_signals)} clusters to '{out_folder}'\n")

In [None]:
import torch
import torch.nn.functional as F

def compute_cluster_entropy(
    separated_signals,
    sample_rate=44100,
    n_fft=1024,
    hop_length=512
):
    """
    Compute the Shannon entropy (bits) of the frequency distribution of each signal.

    The STFT magnitude (rectangular window, i.e. ``window=None``) is summed over time
    for each frequency bin. The result is normalised into a distribution p_f, and
    H = -sum p_f log2 p_f. Low H means energy is concentrated in few bins.

    Arguments:
        separated_signals (Iterable[torch.Tensor | np.ndarray]):
            Each signal is shape [C, T] or [T]. For [C, T] only channel 0 is used.
            Numpy arrays are accepted (converted with ``torch.as_tensor``); other
            cells in this notebook produce decoded clusters as numpy arrays.
        sample_rate (int): unused; kept for interface consistency.
        n_fft (int): FFT size.
        hop_length (int): STFT hop length.

    Returns:
        entropies (List[float]): one value per signal; 0.0 for (near-)silent signals.
    """
    entropies = []

    for sig in separated_signals:
        sig = torch.as_tensor(sig)
        if sig.dim() == 2:
            sig = sig[0]

        stft_out = torch.stft(
            sig,
            n_fft=n_fft,
            hop_length=hop_length,
            window=None,
            center=True,
            normalized=False,
            onesided=True,
            return_complex=True
        )
        # stft_out is (freq_bins, frames); summing dim 1 aggregates over time.
        freq_sum = stft_out.abs().sum(dim=1)

        total_energy = freq_sum.sum()  # total spectral magnitude (not power)
        if total_energy < 1e-12:
            entropies.append(0.0)
            continue

        # Clamp so empty bins contribute ~0 instead of 0 * log(0) = nan.
        p_f = torch.clamp(freq_sum / total_energy, min=1e-12)
        H = -(p_f * torch.log2(p_f)).sum()

        entropies.append(H.item())

    return entropies

In [None]:
import torch
import torch.nn.functional as F
import math

def compute_cluster_sparsity(
    separated_signals,
    sample_rate=44100,
    n_fft=1024,
    hop_length=512
):
    """
    Compute the spectral sparsity of each signal with Hoyer's measure.

    The STFT magnitude (rectangular window) is summed over time per frequency bin,
    giving a vector x of n bins. Then
    sparsity = (sqrt(n) - ||x||_1 / ||x||_2) / (sqrt(n) - 1), clamped to [0, 1].
    The value is 1 when all energy is in one bin and about 0 for a flat spectrum.

    Arguments:
        separated_signals (Iterable[torch.Tensor | np.ndarray]):
            Each signal is shape [C, T] or [T]. For [C, T] only channel 0 is used.
            Numpy arrays are accepted (converted with ``torch.as_tensor``).
        sample_rate (int): unused; kept for consistency.
        n_fft (int): FFT size.
        hop_length (int): STFT hop length.

    Returns:
        sparsities (List[float]): one value per signal. The value is 0.0 for
        (near-)silent signals, or when fewer than 2 frequency bins make the
        measure undefined.
    """
    sparsities = []

    for sig in separated_signals:
        sig = torch.as_tensor(sig)
        if sig.dim() == 2:
            sig = sig[0]

        stft_out = torch.stft(
            sig,
            n_fft=n_fft,
            hop_length=hop_length,
            window=None,
            center=True,
            normalized=False,
            onesided=True,
            return_complex=True
        )
        freq_sum = stft_out.abs().sum(dim=1)  # (freq_bins,)
        n = freq_sum.shape[0]

        # Guard n < 2: the Hoyer denominator sqrt(n) - 1 would be zero.
        if freq_sum.sum() < 1e-12 or n < 2:
            sparsities.append(0.0)
            continue

        l1 = freq_sum.sum().item()
        l2 = freq_sum.norm(2).item()
        sqrt_n = math.sqrt(n)

        hoyer_num = sqrt_n - (l1 / l2)
        hoyer_den = sqrt_n - 1.0

        sparsities.append(max(0.0, min(hoyer_num / hoyer_den, 1.0)))

    return sparsities

In [None]:
import torch

def compute_energy_distribution(separated_signals, mixture):
    """
    Compute each separated signal's energy relative to the mixture's energy.

    Energy is the sum of squared samples over all channels and time steps, so
    [C, T] and [T] inputs are handled identically. Tensors and numpy arrays are both
    accepted (converted with ``torch.as_tensor``).

    Args:
        separated_signals (Sequence[torch.Tensor | np.ndarray]):
            Each is shape [channels, time] or [time].
        mixture (torch.Tensor | np.ndarray):
            The original mixture, shape [channels, time] or [time].

    Returns:
        energy_ratios (List[float]):
            Each ratio = (energy of cluster i) / (energy of mixture).
        sum_of_ratios (float):
            Sum of all cluster ratios, ideally ~1.0 if energy is preserved.
        total_mix_energy (float):
            Energy of the mixture, for reference.
        If the mixture is (near-)silent, returns ([0.0] * len(signals), 0.0, 0.0).
    """
    total_mix_energy = torch.as_tensor(mixture).pow(2).sum().item()

    if total_mix_energy < 1e-12:
        return [0.0] * len(separated_signals), 0.0, 0.0

    energy_ratios = [
        torch.as_tensor(sig).pow(2).sum().item() / total_mix_energy
        for sig in separated_signals
    ]
    return energy_ratios, sum(energy_ratios), total_mix_energy

In [None]:
import torch
import matplotlib.pyplot as plt

def visualize_spectrograms(mixture, separated_signals, sample_rate=44100, n_fft=1024, hop_length=512):
    """
    Plot the log-magnitude spectrogram (dB) of ``mixture`` and of each separated signal,
    one row per signal.

    Only channel 0 of 2-D ``[C, T]`` inputs is shown. Tensors (on any device) and numpy
    arrays are accepted. An empty ``separated_signals`` now works: with a single row,
    ``plt.subplots`` returns a bare Axes, and the old ``axs[0]`` indexing failed.

    Args:
        mixture (torch.Tensor | np.ndarray): shape [C, T] or [T].
        separated_signals (Iterable[torch.Tensor | np.ndarray]): each [C, T] or [T].
        sample_rate (int): kept for interface compatibility; axes are in STFT bins/frames.
        n_fft (int): STFT size.
        hop_length (int): hop length.
    """
    def _mono(sig):
        # First channel of a [C, T] signal, moved to CPU for numpy/matplotlib.
        sig = torch.as_tensor(sig).detach().cpu()
        return sig[0, :] if sig.dim() == 2 else sig

    def _log_magnitude_db(sig):
        spec = torch.stft(
            _mono(sig),
            n_fft=n_fft,
            hop_length=hop_length,
            window=None,
            center=True,
            onesided=True,
            return_complex=True
        )
        return 20 * np.log10(spec.abs().numpy() + 1e-6)

    signals = [mixture, *separated_signals]
    titles = ["Mixture Spectrogram"] + [f"Source {i} Spectrogram" for i in range(1, len(signals))]

    # squeeze=False always yields a 2-D array of Axes, even for a single row.
    fig, axs = plt.subplots(len(signals), 1, figsize=(8, 2 * len(signals)), squeeze=False)
    for ax, sig, title in zip(axs[:, 0], signals, titles):
        ax.imshow(_log_magnitude_db(sig), aspect='auto', origin='lower')
        ax.set_title(title)
        ax.set_ylabel("Frequency Bin")
    axs[-1, 0].set_xlabel("STFT Frame")

    plt.tight_layout()
    plt.show()

In [None]:
# KMeans clustering
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import pairwise_distances
import scipy.sparse as sp
import scipy.sparse.linalg as spla
from pathlib import Path
import torch
import torchaudio

def compute_laplacian(data, sigma=1.0, mode='rbf'):
    """
    Compute the symmetric normalised Laplacian L = I - D^-1/2 W D^-1/2 of an RBF graph.

    Args:
        data: (N, D) array of points.
        sigma: RBF bandwidth; W[i, j] = exp(-||x_i - x_j||^2 / (2 sigma^2)).
        mode: only 'rbf' is supported.

    Returns:
        Dense (N, N) float64 Laplacian.

    Raises:
        NotImplementedError: for any mode other than 'rbf'.

    Implementation notes:
    - Squared distances come from ||a||^2 + ||b||^2 - 2 a.b in float64, avoiding the
      sqrt-then-square round trip.
    - The D^-1/2 scaling is applied by broadcasting. The previous dense diagonal
      matrix products were O(N^3) for what is an O(N^2) row/column scaling.
    """
    if mode != 'rbf':
        raise NotImplementedError("Only RBF kernel is implemented.")

    X = np.asarray(data, dtype=np.float64)
    sq_norms = np.einsum("ij,ij->i", X, X)
    dist_sq = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (X @ X.T)
    np.maximum(dist_sq, 0.0, out=dist_sq)  # rounding can yield tiny negatives
    np.fill_diagonal(dist_sq, 0.0)

    W = np.exp(-dist_sq / (2.0 * sigma**2))
    d = np.sum(W, axis=1)
    inv_sqrt_d = 1.0 / np.sqrt(d + 1e-12)
    W_tilde = W * inv_sqrt_d[:, None] * inv_sqrt_d[None, :]
    return np.eye(W.shape[0]) - W_tilde

def spectral_embedding_eigengap(data, k_max=5, sigma=1.0, maxiter=3000, tol=1e-3):
    """
    Return a spectral embedding of ``data`` whose dimension is chosen by the largest eigengap.

    The ``k_max + 1`` smallest eigenpairs of the normalised RBF Laplacian are computed.
    Index 0 is treated as trivial and dropped. The embedding dimension is the position
    of the largest gap among the remaining eigenvalues, clamped to ``[2, k_max]``.

    Args:
        data: (N, D) array.
        k_max: maximum embedding dimension.
        sigma: RBF bandwidth passed to ``compute_laplacian``.
        maxiter, tol: ARPACK settings for ``eigsh``.

    Returns:
        (N, chosen_dim) eigenvector embedding. ``data`` itself is returned unchanged
        when N is too small for ARPACK (N <= k_max + 1) or ARPACK does not converge.
    """
    N = data.shape[0]
    # eigsh requires k < N with k = k_max + 1. The old guard (N <= k_max) let
    # N == k_max + 1 through, where eigsh raises ValueError instead of falling back.
    if N <= k_max + 1:
        return data

    L = compute_laplacian(data, sigma=sigma, mode='rbf')
    L_sparse = sp.csr_matrix(L)

    try:
        eigenvals, eigenvects = spla.eigsh(
            L_sparse, k=k_max + 1, which='SM',
            maxiter=maxiter, tol=tol,
            ncv=min(2 * (k_max + 1), N)
        )
    except spla.ArpackNoConvergence as e:
        print(f"[WARNING] ARPACK failed to converge: {e}")
        return data

    # eigsh does not document an ordering; sort eigenpairs ascending.
    idx_sorted = np.argsort(eigenvals)
    eigenvals = eigenvals[idx_sorted]
    eigenvects = eigenvects[:, idx_sorted]

    lam_nontrivial = eigenvals[1:]
    gaps = np.diff(lam_nontrivial)
    chosen_dim = int(np.argmax(gaps)) + 1
    chosen_dim = max(2, min(chosen_dim, k_max))

    return eigenvects[:, 1:1 + chosen_dim]

def auto_kmeans(X, k_min=2, k_max=5):
    """
    Try KMeans for k in [k_min..k_max] and keep the k with the highest silhouette score.

    k is additionally capped at ``len(X) - 1``. ``silhouette_score`` is only defined
    for 2 <= n_labels <= n_samples - 1, and KMeans needs n_samples >= k. Previously,
    small inputs (e.g. the raw-data fallback of ``spectral_embedding_eigengap``) raised
    ValueError.

    Returns:
        (best_labels, best_k). If no candidate k produced at least two distinct labels,
        falls back to a plain KMeans with ``k_min`` clusters.
    """
    n_samples = len(X)
    best_k = k_min
    best_score = -np.inf  # silhouette is in [-1, 1], so any valid score beats this
    best_labels = None

    for k in range(k_min, min(k_max, n_samples - 1) + 1):
        labels = KMeans(n_clusters=k, random_state=0).fit_predict(X)
        if len(np.unique(labels)) < 2:
            continue

        score = silhouette_score(X, labels)
        if score > best_score:
            best_score, best_k, best_labels = score, k, labels

    if best_labels is None:
        best_labels = KMeans(n_clusters=k_min, random_state=0).fit_predict(X)
        best_k = k_min

    return best_labels, best_k

def subsample_latent(z_t, max_frames=3000):
    """
    Keep at most ``max_frames`` rows of ``z_t``, evenly spaced along axis 0.

    Rows are picked at ``np.linspace(0, n - 1, max_frames, dtype=int)``, so the first and
    last frames are always kept. Inputs that are already small enough are returned as-is.
    """
    n_frames = z_t.shape[0]
    if n_frames <= max_frames:
        return z_t
    keep = np.linspace(0, n_frames - 1, max_frames, dtype=int)
    return z_t[keep]

output_root = Path("separated_sources")
output_root.mkdir(exist_ok=True, parents=True)


def _labels_for_all_frames(sub_labels, n_frames):
    """Map labels computed on a subsampled frame grid back onto every latent frame.

    ``subsample_latent`` keeps frames ``np.linspace(0, n_frames - 1, n_sub, dtype=int)``.
    Each full-resolution frame receives the label of the nearest kept frame. When no
    subsampling happened (``n_sub == n_frames``), the labels are returned unchanged.
    """
    sub_labels = np.asarray(sub_labels)
    n_sub = len(sub_labels)
    if n_sub == n_frames:
        return sub_labels
    kept = np.linspace(0, n_frames - 1, n_sub, dtype=int)
    frames = np.arange(n_frames)
    right = np.clip(np.searchsorted(kept, frames), 0, n_sub - 1)
    left = np.clip(right - 1, 0, n_sub - 1)
    take_left = (frames - kept[left]) <= (kept[right] - frames)
    return sub_labels[np.where(take_left, left, right)]


with torch.no_grad():
    for batch_data in dataloader:
        waveform, fname = batch_data
        if isinstance(fname, (list, tuple)):
            fname = fname[0]

        print("\n=============================================")
        print(f"[INFO] Processing file: {fname}")

        waveform = waveform.to(device)
        if waveform.shape[1] == 1:
            waveform = waveform.repeat(1, 2, 1)
        # Reference mixture for the energy ratios and the spectrogram plot
        # (previously `mix` was never defined in this cell).
        mix = waveform.squeeze(0).cpu()

        z = encoder(waveform)
        print(f"  [DEBUG] Encoded shape: {z.shape}")

        z_t = z.squeeze(0).permute(1, 0).cpu().numpy()  # (T_full, latent_dim)
        print(f"  [DEBUG] Flattened shape: {z_t.shape}")

        z_t = subsample_latent(z_t, max_frames=3000)
        print(f"  [DEBUG] After subsampling: {z_t.shape}")

        embedded = spectral_embedding_eigengap(
            z_t, k_max=5, sigma=1.0, maxiter=5000, tol=1e-2
        )
        print("  [DEBUG] Spectral embedding complete.")

        labels, chosen_k = auto_kmeans(embedded, k_min=2, k_max=5)
        print(f"[INFO] K-MEANS => final cluster count = {chosen_k}")

        out_dir = output_root / Path(fname).stem
        out_dir.mkdir(exist_ok=True, parents=True)

        # Labels refer to (possibly subsampled) rows of z_t, not latent frame indices;
        # expand them to every frame before masking. The old code used the subsample
        # positions directly as frame indices, which is wrong once subsampling kicks in.
        T_full = z.shape[-1]
        frame_labels = _labels_for_all_frames(labels, T_full)

        separated_signals = []
        for cluster_index, c_label in enumerate(np.unique(labels), start=1):
            frame_mask = torch.from_numpy(frame_labels == c_label).to(device=device, dtype=z.dtype)
            z_masked = z * frame_mask.view(1, 1, -1)
            separated_waveform = decoder(z_masked).squeeze(0).cpu()
            separated_signals.append(separated_waveform)

            out_path = out_dir / f"source_{cluster_index}.wav"
            torchaudio.save(str(out_path), separated_waveform, 44100)
            print(f"  [DEBUG] Saved cluster_{cluster_index} (label={c_label}) -> {out_path}")

        entropies = compute_cluster_entropy(separated_signals, sample_rate=44100)
        sparsities = compute_cluster_sparsity(separated_signals, sample_rate=44100)

        for i, (H, S) in enumerate(zip(entropies, sparsities), start=1):
            print(f"[INFO] Cluster {i} Entropy: {H:.4f} | Sparsity: {S:.4f}")

        energy_ratios, sum_of_ratios, total_mix_energy = compute_energy_distribution(separated_signals, mix)

        print(f"\n[INFO] Energy Distribution for {fname}:")
        print(f"  Mixture energy = {total_mix_energy:.2f}")
        for i, ratio in enumerate(energy_ratios, start=1):
            print(f"  Cluster {i} => {ratio*100:.2f}% of mixture energy")
        print(f"  Sum of ratios = {sum_of_ratios:.3f} (ideally ~1.0)")

        visualize_spectrograms(mix, separated_signals, sample_rate=44100)

        print(f"[INFO] Done with {fname}.")

In [None]:
# Agglomerative clustering
import numpy as np
import torch
import torchaudio

import scipy.sparse as sp
import scipy.sparse.linalg as spla

from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import pairwise_distances
from pathlib import Path

def compute_laplacian(data, sigma=1.0, mode='rbf'):
    """
    Compute the symmetric normalized graph Laplacian of `data`.

    Affinities use an RBF (Gaussian) kernel on pairwise Euclidean distances:
        W_ij = exp(-||x_i - x_j||^2 / (2 * sigma^2))
    and the Laplacian is
        L = I - D^(-1/2) W D^(-1/2),   D = diag(sum_j W_ij).

    Args:
        data: 2-D array, one row per sample ([N, F]).
        sigma: RBF kernel bandwidth.
        mode: affinity kernel; only 'rbf' is supported.

    Returns:
        Dense ndarray of shape [N, N].

    Raises:
        NotImplementedError: if `mode` is not 'rbf'.
    """
    if mode != 'rbf':
        raise NotImplementedError("Only RBF kernel is implemented.")

    dist = pairwise_distances(data, data, metric='euclidean')
    W = np.exp(-dist**2 / (2.0 * sigma**2))

    # Scale rows and columns by d^(-1/2) with broadcasting instead of multiplying
    # by dense diagonal matrices: O(N^2) work rather than two O(N^3) matmuls.
    # The small epsilon guards against division by zero for isolated points.
    d_inv_sqrt = 1.0 / np.sqrt(W.sum(axis=1) + 1e-12)
    W_tilde = W * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

    # Normalized Laplacian: L = I - D^(-1/2) * W * D^(-1/2)
    return np.eye(W.shape[0]) - W_tilde


def spectral_embedding_eigengap(data, k_max=5, sigma=1.0, maxiter=3000, tol=1e-3):
    """
    Compute a spectral embedding of `data`, choosing its dimension via the
    largest eigengap among the smallest (k_max + 1) Laplacian eigenvalues.

    The smallest (trivial) eigenpair is skipped. The dimension is the position of
    the largest gap between consecutive remaining eigenvalues, clamped with
    max(2, min(dim, k_max)).

    Args:
        data: 2-D array of shape [N, F].
        k_max: largest embedding dimension considered.
        sigma: RBF bandwidth forwarded to compute_laplacian.
        maxiter, tol: ARPACK (eigsh) iteration limit and convergence tolerance.

    Returns:
        Array of shape [N, chosen_dim]. Returns `data` unchanged when
        N <= k_max + 1 (too few samples for eigsh to compute k_max + 1
        eigenpairs) or when ARPACK fails to converge.
    """
    N = data.shape[0]
    # eigsh requires k < N; with k = k_max + 1 that means N must exceed k_max + 1.
    if N <= k_max + 1:
        return data

    L = compute_laplacian(data, sigma=sigma, mode='rbf')

    try:
        # The RBF affinity is fully dense, so L is handed to eigsh as a dense array;
        # converting it to CSR would only add index storage and slow each matvec.
        eigenvals, eigenvects = spla.eigsh(
            L, k=k_max + 1, which='SM',
            maxiter=maxiter, tol=tol,
            ncv=min(2 * (k_max + 1), N),
        )
    except spla.ArpackNoConvergence as e:
        print(f"[WARNING] ARPACK failed to converge: {e}")
        return data

    order = np.argsort(eigenvals)
    eigenvals = eigenvals[order]
    eigenvects = eigenvects[:, order]

    # Gaps between consecutive non-trivial eigenvalues; the largest one picks the dim.
    gaps = np.diff(eigenvals[1:])
    chosen_dim = max(2, min(int(np.argmax(gaps)) + 1, k_max))

    return eigenvects[:, 1:1 + chosen_dim]

def auto_agglomerative(X, k_min=2, k_max=5):
    """
    Choose the number of clusters for AgglomerativeClustering by silhouette score.

    Tries every k in [k_min, k_max] and keeps the labels with the highest
    silhouette score. The range is capped at n_samples - 1, because
    silhouette_score only accepts 2 <= n_labels <= n_samples - 1.

    Args:
        X: array of shape [n_samples, n_features].
        k_min, k_max: inclusive range of cluster counts to try.

    Returns:
        (best_labels, best_k): integer label array of length n_samples and the
        chosen cluster count. With fewer than 2 samples no clustering is possible,
        so all-zero labels and k = 1 are returned.
    """
    n_samples = len(X)
    if n_samples < 2:
        return np.zeros(n_samples, dtype=int), 1

    best_k = k_min
    best_score = -1
    best_labels = None

    for k in range(k_min, min(k_max, n_samples - 1) + 1):
        labels = AgglomerativeClustering(n_clusters=k).fit_predict(X)

        if len(np.unique(labels)) < 2:
            continue

        score = silhouette_score(X, labels)
        if score > best_score:
            best_score, best_k, best_labels = score, k, labels

    if best_labels is None:
        # No k yielded a valid silhouette score (e.g. only 2 samples): fall back to
        # k_min, capped so AgglomerativeClustering accepts it.
        best_k = min(k_min, n_samples)
        best_labels = AgglomerativeClustering(n_clusters=best_k).fit_predict(X)

    return best_labels, best_k

def subsample_latent(z_t, max_frames=3000):
    """
    Cap the number of time frames (rows of `z_t`) at `max_frames`.

    If `z_t` has at most `max_frames` rows it is returned unchanged. Otherwise
    `max_frames` rows are taken at evenly spaced positions from the first to the
    last frame (both included).
    """
    n_frames = z_t.shape[0]
    if n_frames <= max_frames:
        return z_t
    keep = np.linspace(0, n_frames - 1, max_frames, dtype=int)
    return z_t[keep]

output_root = Path("separated_sources")
output_root.mkdir(exist_ok=True, parents=True)

# Maximum number of latent frames fed to the spectral embedding / clustering.
MAX_FRAMES = 3000

with torch.no_grad():
    for batch_data in dataloader:
        waveform, fname = batch_data
        if isinstance(fname, (list, tuple)):
            fname = fname[0]

        print("\n=============================================")
        print(f"[INFO] Processing file: {fname}")

        waveform = waveform.to(device)
        # Mono input is duplicated to two channels before encoding.
        if waveform.shape[1] == 1:
            waveform = waveform.repeat(1, 2, 1)

        z = encoder(waveform)
        print(f"  [DEBUG] Encoded shape: {z.shape}")

        # One row per encoder frame: [T_full, latent_dim].
        z_t = z.squeeze(0).permute(1, 0).cpu().numpy()
        print(f"  [DEBUG] Flattened shape: {z_t.shape}")

        z_t = subsample_latent(z_t, max_frames=MAX_FRAMES)
        print(f"  [DEBUG] After subsampling: {z_t.shape}")

        embedded = spectral_embedding_eigengap(
            z_t, k_max=5, sigma=1.0, maxiter=5000, tol=1e-2
        )
        print("  [DEBUG] Spectral embedding complete.")

        labels, chosen_k = auto_agglomerative(embedded, k_min=2, k_max=5)
        print(f"[INFO] Agglomerative => final cluster count = {chosen_k}")

        out_dir = output_root / Path(fname).stem
        out_dir.mkdir(exist_ok=True, parents=True)

        T_full = z.shape[-1]
        labels = np.asarray(labels)

        # `labels` has one entry per *subsampled* frame. subsample_latent kept the
        # frames np.linspace(0, T_full - 1, len(labels), dtype=int), so give every
        # encoder frame the label of its nearest kept frame. Using subsampled
        # positions directly as frame indices would mask only the first frames.
        if labels.shape[0] != T_full:
            kept = np.linspace(0, T_full - 1, labels.shape[0], dtype=int)
            frames = np.arange(T_full)
            right = np.clip(np.searchsorted(kept, frames), 0, kept.size - 1)
            left = np.clip(right - 1, 0, kept.size - 1)
            nearest = np.where(frames - kept[left] <= kept[right] - frames, left, right)
            frame_labels = labels[nearest]
        else:
            frame_labels = labels

        mix = waveform.squeeze(0)
        separated_signals = []

        for cluster_index, c_label in enumerate(np.unique(labels), start=1):
            # Binary [1, 1, T_full] mask selecting this cluster's encoder frames.
            mask = torch.from_numpy((frame_labels == c_label).astype(np.float32))
            mask = mask.to(device).view(1, 1, T_full)

            separated_waveform = decoder(z * mask).squeeze(0)

            # The decoded length may differ from the input; align it with the mix.
            n_mix = mix.shape[-1]
            if separated_waveform.shape[-1] > n_mix:
                separated_waveform = separated_waveform[..., :n_mix]
            elif separated_waveform.shape[-1] < n_mix:
                separated_waveform = torch.nn.functional.pad(
                    separated_waveform, (0, n_mix - separated_waveform.shape[-1])
                )

            separated_signals.append(separated_waveform)

            out_path = out_dir / f"source_{cluster_index}.wav"
            torchaudio.save(str(out_path), separated_waveform.cpu(), 44100)
            print(f"  [DEBUG] Saved cluster_{cluster_index} (label={c_label}) -> {out_path}")

        entropies = compute_cluster_entropy(separated_signals, sample_rate=44100)
        sparsities = compute_cluster_sparsity(separated_signals, sample_rate=44100)

        for i, (H, S) in enumerate(zip(entropies, sparsities), start=1):
            print(f"[INFO] Cluster {i} Entropy: {H:.4f} | Sparsity: {S:.4f}")

        energy_ratios, sum_of_ratios, total_mix_energy = compute_energy_distribution(separated_signals, mix)

        print(f"\n[INFO] Energy Distribution for {fname}:")
        print(f"  Mixture energy = {total_mix_energy:.2f}")
        for i, ratio in enumerate(energy_ratios, start=1):
            print(f"  Cluster {i} => {ratio*100:.2f}% of mixture energy")
        print(f"  Sum of ratios = {sum_of_ratios:.3f} (ideally ~1.0)")

        print(f"[INFO] Done with {fname}.")

In [None]:
# DBSCAN clustering
import numpy as np
import torch
import torchaudio

import scipy.sparse as sp
import scipy.sparse.linalg as spla

from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import pairwise_distances
from pathlib import Path

import torch.nn.functional as F

def compute_laplacian(data, sigma=1.0, mode='rbf'):
    """
    Compute the symmetric normalized graph Laplacian of `data`.

    Affinities use an RBF (Gaussian) kernel on pairwise Euclidean distances:
        W_ij = exp(-||x_i - x_j||^2 / (2 * sigma^2))
    and the Laplacian is
        L = I - D^(-1/2) W D^(-1/2),   D = diag(sum_j W_ij).

    Args:
        data: 2-D array, one row per sample ([N, F]).
        sigma: RBF kernel bandwidth.
        mode: affinity kernel; only 'rbf' is supported.

    Returns:
        Dense ndarray of shape [N, N].

    Raises:
        NotImplementedError: if `mode` is not 'rbf'.
    """
    if mode != 'rbf':
        raise NotImplementedError("Only RBF kernel is implemented.")

    dist = pairwise_distances(data, data, metric='euclidean')
    W = np.exp(-dist**2 / (2.0 * sigma**2))

    # Row/column scaling by d^(-1/2) via broadcasting: O(N^2) instead of two
    # O(N^3) products with dense diagonal matrices. Epsilon avoids division by zero.
    d_inv_sqrt = 1.0 / np.sqrt(W.sum(axis=1) + 1e-12)
    W_tilde = W * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

    return np.eye(W.shape[0]) - W_tilde

def spectral_embedding_eigengap(data, k_max=5, sigma=1.0, maxiter=3000, tol=1e-3):
    """
    Return a spectral embedding of `data` whose dimension is chosen by the
    largest eigengap among the smallest (k_max + 1) Laplacian eigenvalues.

    The smallest (trivial) eigenpair is skipped. The dimension is the position of
    the largest gap between consecutive remaining eigenvalues, clamped with
    max(2, min(dim, k_max)).

    Args:
        data: 2-D array of shape [N, F].
        k_max: largest embedding dimension considered.
        sigma: RBF bandwidth forwarded to compute_laplacian.
        maxiter, tol: ARPACK (eigsh) iteration limit and convergence tolerance.

    Returns:
        Array of shape [N, chosen_dim]. Returns `data` unchanged when
        N <= k_max + 1 (too few samples for eigsh to compute k_max + 1
        eigenpairs) or when ARPACK fails to converge.
    """
    N = data.shape[0]
    # eigsh requires k < N; with k = k_max + 1 that means N must exceed k_max + 1.
    if N <= k_max + 1:
        return data

    L = compute_laplacian(data, sigma=sigma, mode='rbf')

    try:
        # L is fully dense (RBF affinity), so pass it as-is; a CSR copy would only
        # add index storage and slow down every matvec.
        eigenvals, eigenvects = spla.eigsh(
            L, k=k_max + 1, which='SM',
            maxiter=maxiter, tol=tol,
            ncv=min(2 * (k_max + 1), N),
        )
    except spla.ArpackNoConvergence as e:
        print(f"[WARNING] ARPACK failed to converge: {e}")
        return data

    order = np.argsort(eigenvals)
    eigenvals = eigenvals[order]
    eigenvects = eigenvects[:, order]

    # Gaps between consecutive non-trivial eigenvalues; the largest one picks the dim.
    gaps = np.diff(eigenvals[1:])
    chosen_dim = max(2, min(int(np.argmax(gaps)) + 1, k_max))

    return eigenvects[:, 1:1 + chosen_dim]

def dbscan_only(embedded_data, eps_list=None, min_samples_list=None):
    """
    Grid-search DBSCAN over (eps, min_samples) and keep the run that finds the
    most clusters (the noise label -1 is not counted as a cluster). On ties the
    earliest combination tried wins; eps is the outer loop, min_samples the inner.

    If no combination finds any cluster, every point gets label 0.

    Returns:
        best_labels: np.array of shape [N]
        best_n_clusters: int (number of non-noise clusters, or 1 for the fallback)
    """
    eps_values = [1.0, 0.7, 0.5, 0.3, 0.1] if eps_list is None else eps_list
    min_samples_values = [5, 10, 20] if min_samples_list is None else min_samples_list

    best_labels, best_n_clusters = None, 0

    candidates = ((eps, ms) for eps in eps_values for ms in min_samples_values)
    for eps, ms in candidates:
        print(f"[DEBUG] Trying DBSCAN with eps={eps}, min_samples={ms}")
        run_labels = DBSCAN(eps=eps, min_samples=ms).fit_predict(embedded_data)

        found = len(set(run_labels) - {-1})
        print(f"   => Found {found} clusters (excluding noise).")

        if found > best_n_clusters:
            best_labels, best_n_clusters = run_labels, found

    if best_n_clusters > 0:
        return best_labels, best_n_clusters

    print("[WARNING] DBSCAN found 0 clusters. Fallback to 1 cluster (everything).")
    return np.zeros(embedded_data.shape[0], dtype=int), 1

def subsample_latent(z_t, max_frames=3000):
    """Return `z_t` with at most `max_frames` rows, keeping evenly spaced frames (first and last included)."""
    total = z_t.shape[0]
    return z_t[np.linspace(0, total - 1, max_frames, dtype=int)] if total > max_frames else z_t

def reconstruction_loss(original_mixture, separated_sources):
    """
    Mean-squared error between a mixture and the sum of its separated sources.

    Args:
        original_mixture: Tensor, e.g. [channels, time] (or [time]).
        separated_sources: list of Tensors, each shaped like the mixture.

    Returns:
        Scalar tensor. When `separated_sources` is empty, a 0.0 tensor on the
        mixture's device.
    """
    if separated_sources:
        reconstruction = torch.stack(separated_sources, dim=0).sum(dim=0)
        return F.mse_loss(reconstruction, original_mixture)
    return torch.tensor(0.0, device=original_mixture.device)

output_root = Path("separated_sources_dbscan_only")
output_root.mkdir(exist_ok=True, parents=True)

# Maximum number of latent frames fed to the spectral embedding / DBSCAN.
MAX_FRAMES = 3000

with torch.no_grad():
    for batch_data in dataloader:
        waveform, fname = batch_data
        if isinstance(fname, (list, tuple)):
            fname = fname[0]

        print("\n=============================================")
        print(f"[INFO] Processing file: {fname}")

        waveform = waveform.to(device)
        # Mono input is duplicated to two channels before encoding.
        if waveform.shape[1] == 1:
            waveform = waveform.repeat(1, 2, 1)

        z = encoder(waveform)
        print(f"  [DEBUG] Encoded shape: {z.shape}")

        # One row per encoder frame: [T_full, latent_dim].
        z_t = z.squeeze(0).permute(1, 0).cpu().numpy()
        print(f"  [DEBUG] Flattened shape: {z_t.shape}")

        z_t = subsample_latent(z_t, max_frames=MAX_FRAMES)
        print(f"  [DEBUG] After subsampling: {z_t.shape}")

        embedded = spectral_embedding_eigengap(z_t, k_max=5, sigma=1.0, maxiter=5000, tol=1e-2)
        print("  [DEBUG] Spectral embedding complete.")

        labels, chosen_k = dbscan_only(embedded, eps_list=[1.0, 0.7, 0.5, 0.3, 0.1],
                                       min_samples_list=[5, 10, 20])
        print(f"[INFO] DBSCAN => final cluster count = {chosen_k}")

        out_dir = output_root / Path(fname).stem
        out_dir.mkdir(exist_ok=True, parents=True)

        T_full = z.shape[-1]

        # NOTE(review): noise frames (label -1) are dropped when DBSCAN found more
        # than one cluster, but are written out as their own source when it found
        # exactly one. Confirm this asymmetry is intended.
        if chosen_k > 1:
            unique_labels = sorted(set(labels) - {-1})
        else:
            unique_labels = sorted(set(labels))

        if not unique_labels:
            print("[WARNING] All points were noise. Forcing single cluster output.")
            unique_labels = [0]
            labels = np.zeros_like(labels)

        labels = np.asarray(labels)

        # `labels` has one entry per *subsampled* frame. subsample_latent kept the
        # frames np.linspace(0, T_full - 1, len(labels), dtype=int), so give every
        # encoder frame the label of its nearest kept frame. Using subsampled
        # positions directly as frame indices would mask only the first frames.
        if labels.shape[0] != T_full:
            kept = np.linspace(0, T_full - 1, labels.shape[0], dtype=int)
            frames = np.arange(T_full)
            right = np.clip(np.searchsorted(kept, frames), 0, kept.size - 1)
            left = np.clip(right - 1, 0, kept.size - 1)
            nearest = np.where(frames - kept[left] <= kept[right] - frames, left, right)
            frame_labels = labels[nearest]
        else:
            frame_labels = labels

        separated_signals = []
        mix = waveform.squeeze(0)

        for cluster_index, c_label in enumerate(unique_labels, start=1):
            # Binary [1, 1, T_full] mask selecting this cluster's encoder frames.
            mask = torch.from_numpy((frame_labels == c_label).astype(np.float32))
            mask = mask.to(device).view(1, 1, T_full)

            separated_waveform = decoder(z * mask).squeeze(0)

            # The decoded length may differ from the input; align it with the mix.
            n_mix = mix.shape[-1]
            if separated_waveform.shape[-1] > n_mix:
                separated_waveform = separated_waveform[..., :n_mix]
            elif separated_waveform.shape[-1] < n_mix:
                separated_waveform = F.pad(
                    separated_waveform, (0, n_mix - separated_waveform.shape[-1])
                )

            separated_signals.append(separated_waveform)

            out_path = out_dir / f"source_{cluster_index}.wav"
            torchaudio.save(str(out_path), separated_waveform.cpu(), 44100)
            print(f"  [DEBUG] Saved cluster_{cluster_index} (label={c_label}) -> {out_path}")

        entropies = compute_cluster_entropy(separated_signals, sample_rate=44100)
        sparsities = compute_cluster_sparsity(separated_signals, sample_rate=44100)

        for i, (H, S) in enumerate(zip(entropies, sparsities), start=1):
            print(f"[INFO] Cluster {i} Entropy: {H:.4f} | Sparsity: {S:.4f}")

        energy_ratios, sum_of_ratios, total_mix_energy = compute_energy_distribution(separated_signals, mix)

        print(f"\n[INFO] Energy Distribution for {fname}:")
        print(f"  Mixture energy = {total_mix_energy:.2f}")
        for i, ratio in enumerate(energy_ratios, start=1):
            print(f"  Cluster {i} => {ratio*100:.2f}% of mixture energy")
        print(f"  Sum of ratios = {sum_of_ratios:.3f} (ideally ~1.0)")

        if separated_signals:
            loss_val = reconstruction_loss(mix, separated_signals)
            print(f"[INFO] Reconstruction MSE for {fname}: {loss_val.item()}")

        print(f"[INFO] Done with {fname}.")

