In [9]:
from collections import defaultdict
import pandas as pd
from typing import *
import math
import torchaudio
import torch
from pathlib import Path
import IPython.display as ipd
from pesq import pesq
import numpy as np
import os


def compute_pesq(ref, deg, fs):
    """
    Compute the PESQ score of a degraded signal against a reference.

    Args:
    ref : torch.Tensor
        Reference audio. Only the first row ``ref[0]`` is scored; it is
        converted with ``.numpy()``.
    deg : torch.Tensor
        Degraded audio. Only the first row ``deg[0]`` is scored.
    fs : int
        Sampling frequency of both signals, passed unchanged to ``pesq``.
        NOTE(review): the ``pesq`` package documents 16000 Hz for wideband
        mode; other rates are expected to end up in the NaN error path below.
        Confirm and resample in the caller if needed.

    Returns:
    float
        The wideband PESQ score, or NaN when either signal is shorter than
        300 ms, either signal is all zeros, or ``pesq`` raises.
    """
    ref = ref[0].numpy()
    deg = deg[0].numpy()

    # Both signals must be at least 300 ms long; checking only the reference
    # would hand a too-short degraded signal to pesq.
    min_samples = fs * 0.3
    if len(ref) < min_samples or len(deg) < min_samples:
        return float('nan')

    # An all-zero signal on either side cannot be scored.
    if not np.any(ref) or not np.any(deg):
        return float('nan')

    try:
        return pesq(fs, ref, deg, 'wb')  # 'wb' is for wideband, use 'nb' for narrowband
    except Exception as e:
        # Best effort: report the problem and keep the evaluation running.
        print(f"Error computing PESQ: {e}")
        return float('nan')

def compute_ssnr(ref, deg, frame_size=256, eps=1e-8):
    """
    Compute the Segmental Signal-to-Noise Ratio (SSNR).

    Args:
    ref : torch.Tensor
        Reference audio. Only the first row ``ref[0]`` is used; it is
        converted with ``.numpy()``.
    deg : torch.Tensor
        Degraded audio. Only the first row ``deg[0]`` is used.
    frame_size : int
        Number of samples per non-overlapping frame.
    eps : float
        Small number added to both energies to avoid division by zero.

    Returns:
    float
        Mean per-frame SNR in dB over all complete frames shared by both
        signals, or 0 when there is no complete frame.
    """
    ref = ref[0].numpy()
    deg = deg[0].numpy()

    # Only the span present in both signals can be compared frame by frame.
    num_samples = min(len(ref), len(deg))

    ssnr_values = []
    # The "+ 1" keeps the last complete frame when the length is an exact
    # multiple of frame_size (e.g. a 256-sample signal is one full frame).
    for start in range(0, num_samples - frame_size + 1, frame_size):
        ref_frame = ref[start:start + frame_size]
        deg_frame = deg[start:start + frame_size]
        noise = ref_frame - deg_frame
        signal_energy = np.sum(ref_frame ** 2) + eps
        noise_energy = np.sum(noise ** 2) + eps
        ssnr_values.append(10 * np.log10(signal_energy / noise_energy))

    # Drop non-finite values (e.g. produced by NaN/inf samples in the input).
    ssnr_values = [x for x in ssnr_values if np.isfinite(x)]

    return np.mean(ssnr_values) if ssnr_values else 0


############################




def assert_is_audio(*signal: torch.Tensor):
    """
    Assert that every given tensor is 2-D with 1 or 2 rows.

    Args:
    signal : torch.Tensor
        Any number of tensors to check.

    Raises:
    AssertionError
        If a tensor is not 2-D or its first dimension is neither 1 nor 2.
    """
    for tensor in signal:
        rank = len(tensor.shape)
        assert rank == 2
        num_rows = tensor.shape[0]
        assert num_rows in (1, 2)


def is_silent(signal: torch.Tensor, silence_threshold: float = 1.5e-5) -> bool:
    """
    Tell whether a signal counts as silent.

    The norm of the whole tensor, divided by the size of its last dimension,
    is compared against ``silence_threshold``.

    Args:
    signal : torch.Tensor
        Tensor accepted by ``assert_is_audio`` (2-D with 1 or 2 rows).
    silence_threshold : float
        Upper bound (exclusive) on the length-normalised norm.

    Returns:
    bool
        True when the signal is empty or its normalised norm is below the
        threshold.

    Raises:
    AssertionError
        If ``signal`` is rejected by ``assert_is_audio``.
    """
    assert_is_audio(signal)
    num_samples = signal.shape[-1]
    if num_samples == 0:
        # 0 / 0 would be NaN, which compares False against any threshold and
        # would report an empty signal as active.
        return True
    # bool(...) honours the annotated return type instead of leaking a 0-dim tensor.
    return bool(torch.linalg.norm(signal) / num_samples < silence_threshold)


def sdr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Compute the signal-to-distortion ratio in dB along the last dimension.

    Energies are plain sums of squares (the same form ``sisnr`` uses), which
    avoids the deprecated ``torch.norm`` and a square root that was squared
    again right away. Inputs are expected to be real-valued.

    Args:
    preds : torch.Tensor
        Estimated signal, broadcastable against ``target``.
    target : torch.Tensor
        Reference signal.
    eps : float
        Added to both energies so the ratio and the logarithm stay finite.

    Returns:
    torch.Tensor
        ``10 * log10(target_energy / error_energy)`` with the last dimension
        reduced.
    """
    s_target = torch.sum(target ** 2, dim=-1) + eps
    s_error = torch.sum((target - preds) ** 2, dim=-1) + eps
    return 10 * torch.log10(s_target / s_error)




def sisnr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Compute the scale-invariant SNR in dB along the last dimension.

    The target is rescaled by the (eps-stabilised) projection coefficient of
    ``preds`` onto ``target``; the residual between that rescaled target and
    ``preds`` is treated as noise. Signals are not mean-centred here.

    Args:
    preds : torch.Tensor
        Estimated signal.
    target : torch.Tensor
        Reference signal.
    eps : float
        Added to the projection numerator and denominator and to both
        energies.

    Returns:
    torch.Tensor
        ``10 * log10(scaled_target_energy / noise_energy)`` with the last
        dimension reduced.
    """
    cross_term = (preds * target).sum(dim=-1, keepdim=True)
    target_energy = target.pow(2).sum(dim=-1, keepdim=True)
    scale = (cross_term + eps) / (target_energy + eps)

    projection = scale * target
    residual = projection - preds

    projection_energy = projection.pow(2).sum(dim=-1) + eps
    residual_energy = residual.pow(2).sum(dim=-1) + eps
    return 10 * torch.log10(projection_energy / residual_energy)


def evaluate_separations(
    separation_path: Union[str, Path],
    dataset_path: Union[str, Path],
    separation_sr: int,
    filter_single_source: bool = True,
    eps: float = 1e-8,
    chunk_duration: float = 4.0,
    overlap_duration: float = 2.0,
) -> pd.DataFrame:
    """
    Score every file of ``dataset_path`` against the same-named file in
    ``separation_path`` on overlapping fixed-length chunks.

    Args:
    separation_path : str or Path
        Folder with the audio files to evaluate.
    dataset_path : str or Path
        Folder with the reference audio; its listing drives the evaluation.
    separation_sr : int
        Sample rate used to turn the durations into sample counts. Files are
        not resampled, so they are expected to already be at this rate.
    filter_single_source : bool
        When True, chunks whose reference is silent (per ``is_silent``) are
        skipped.
    eps : float
        Stabiliser forwarded to ``sisnr``.
    chunk_duration : float
        Chunk length in seconds.
    overlap_duration : float
        Overlap between consecutive chunks in seconds; must be smaller than
        ``chunk_duration``.

    Returns:
    pandas.DataFrame
        One row per evaluated chunk with the columns ``snr`` (result of
        ``sisnr``, in dB), ``pesq`` and ``ssnr`` (placeholders, always 0.0),
        ``start_sample``, ``end_sample`` and ``file``. The columns exist even
        when no chunk was evaluated.

    Raises:
    ValueError
        If the overlap is not shorter than the chunk.
    """
    import warnings  # local import so the file-level import block stays untouched

    separation_path = Path(separation_path)
    dataset_path = Path(dataset_path)

    # Chunk geometry does not depend on the file, so compute it once.
    chunk_samples = int(chunk_duration * separation_sr)
    overlap_samples = int(overlap_duration * separation_sr)
    step_size = chunk_samples - overlap_samples
    if step_size <= 0:
        raise ValueError("overlap_duration must be smaller than chunk_duration")

    columns = ["snr", "pesq", "ssnr", "start_sample", "end_sample", "file"]
    df_entries = {column: [] for column in columns}
    skipped_files = []

    # Sorted so the row order does not depend on the file system.
    for file in sorted(os.listdir(dataset_path)):
        try:
            separated_track, _ = torchaudio.load(separation_path / file)
            original_track, _ = torchaudio.load(dataset_path / file)
        except Exception:
            # Best effort: a file that is missing or unreadable on either side
            # is skipped, but remembered so it can be reported below.
            skipped_files.append(file)
            continue

        # Keep only the first channel of each track so sisnr yields one value.
        separated_track = separated_track[:1]
        original_track = original_track[:1]

        # Trim both tracks to the shorter one so chunk slices always match.
        min_length = min(separated_track.shape[-1], original_track.shape[-1])
        separated_track = separated_track[:, :min_length]
        original_track = original_track[:, :min_length]

        # Number of complete chunks that fit. Floor division also keeps the
        # last chunk when the track length fits exactly.
        num_eval_chunks = max(0, (min_length - overlap_samples) // step_size)

        for i in range(num_eval_chunks):
            start_sample = i * step_size
            end_sample = start_sample + chunk_samples

            o = original_track[:, start_sample:end_sample]
            s = separated_track[:, start_sample:end_sample]

            # Skip chunks whose reference is silent.
            if filter_single_source and is_silent(o):
                continue

            df_entries["snr"].append(sisnr(s, o, eps).item())

            # PESQ and SSNR are currently disabled; the placeholders keep the
            # column layout. To enable: compute_pesq(o, s, separation_sr) and
            # compute_ssnr(o, s).
            df_entries["pesq"].append(0.0)
            df_entries["ssnr"].append(0.0)

            # Chunk position and source file of this row.
            df_entries["start_sample"].append(start_sample)
            df_entries["end_sample"].append(end_sample)
            df_entries["file"].append(file)

    if skipped_files:
        warnings.warn(
            f"evaluate_separations skipped {len(skipped_files)} file(s) that could not "
            f"be loaded from both folders, e.g. {skipped_files[:5]}"
        )

    return pd.DataFrame(df_entries, columns=columns)

base_dir = "/home/karchkhadze/testing_sisdr/samples/"

# Every sub-folder listed below is scored against the shared "gt" folder at
# 24 kHz, and the mean of the "snr" column is printed per folder.
dataset_path = base_dir + "gt"

for model_name in (
    "bighifigan",
    "bigvgan",
    "bigvgan-base",
    "hifigan",
    "hifigan_mrd",
    "mrd_snake",
    "sc-wavernn",
    "univnet",
):
    separation_dir = base_dir + model_name

    # Compute metrics
    results = evaluate_separations(separation_dir, dataset_path, 24000, eps=1e-8)

    print(f'\nResults {model_name}:{results["snr"].mean()}')


Results bighifigan:-26.898960584617523

Results bigvgan:-24.156062689172217

Results bigvgan-base:-28.501391025910895

Results hifigan:-28.573745922869946

Results hifigan_mrd:-28.35362121793959

Results mrd_snake:-17.38124105665419

Results sc-wavernn:-37.40817304404385

Results univnet:-28.953988155686712


In [31]:
from collections import defaultdict
import pandas as pd
from typing import *
import math
import torchaudio
import torch
from pathlib import Path
import IPython.display as ipd
from pesq import pesq
import numpy as np
import os


def compute_pesq(ref, deg, fs):
    """
    Compute the PESQ score of a degraded signal against a reference.

    Args:
    ref : torch.Tensor
        Reference audio. Only the first row ``ref[0]`` is scored; it is
        converted with ``.numpy()``.
    deg : torch.Tensor
        Degraded audio. Only the first row ``deg[0]`` is scored.
    fs : int
        Sampling frequency of both signals, passed unchanged to ``pesq``.
        NOTE(review): the ``pesq`` package documents 16000 Hz for wideband
        mode; other rates are expected to end up in the NaN error path below.
        Confirm and resample in the caller if needed.

    Returns:
    float
        The wideband PESQ score, or NaN when either signal is shorter than
        300 ms, either signal is all zeros, or ``pesq`` raises.
    """
    ref = ref[0].numpy()
    deg = deg[0].numpy()

    # Both signals must be at least 300 ms long; checking only the reference
    # would hand a too-short degraded signal to pesq.
    min_samples = fs * 0.3
    if len(ref) < min_samples or len(deg) < min_samples:
        return float('nan')

    # An all-zero signal on either side cannot be scored.
    if not np.any(ref) or not np.any(deg):
        return float('nan')

    try:
        return pesq(fs, ref, deg, 'wb')  # 'wb' is for wideband, use 'nb' for narrowband
    except Exception as e:
        # Best effort: report the problem and keep the evaluation running.
        print(f"Error computing PESQ: {e}")
        return float('nan')

def compute_ssnr(ref, deg, frame_size=256, eps=1e-8):
    """
    Compute the Segmental Signal-to-Noise Ratio (SSNR).

    Args:
    ref : torch.Tensor
        Reference audio. Only the first row ``ref[0]`` is used; it is
        converted with ``.numpy()``.
    deg : torch.Tensor
        Degraded audio. Only the first row ``deg[0]`` is used.
    frame_size : int
        Number of samples per non-overlapping frame.
    eps : float
        Small number added to both energies to avoid division by zero.

    Returns:
    float
        Mean per-frame SNR in dB over all complete frames shared by both
        signals, or 0 when there is no complete frame.
    """
    ref = ref[0].numpy()
    deg = deg[0].numpy()

    # Only the span present in both signals can be compared frame by frame.
    num_samples = min(len(ref), len(deg))

    ssnr_values = []
    # The "+ 1" keeps the last complete frame when the length is an exact
    # multiple of frame_size (e.g. a 256-sample signal is one full frame).
    for start in range(0, num_samples - frame_size + 1, frame_size):
        ref_frame = ref[start:start + frame_size]
        deg_frame = deg[start:start + frame_size]
        noise = ref_frame - deg_frame
        signal_energy = np.sum(ref_frame ** 2) + eps
        noise_energy = np.sum(noise ** 2) + eps
        ssnr_values.append(10 * np.log10(signal_energy / noise_energy))

    # Drop non-finite values (e.g. produced by NaN/inf samples in the input).
    ssnr_values = [x for x in ssnr_values if np.isfinite(x)]

    return np.mean(ssnr_values) if ssnr_values else 0


############################




def assert_is_audio(*signal: torch.Tensor):
    """
    Assert that every given tensor is 2-D with 1 or 2 rows.

    Args:
    signal : torch.Tensor
        Any number of tensors to check.

    Raises:
    AssertionError
        If a tensor is not 2-D or its first dimension is neither 1 nor 2.
    """
    for tensor in signal:
        rank = len(tensor.shape)
        assert rank == 2
        num_rows = tensor.shape[0]
        assert num_rows in (1, 2)


def is_silent(signal: torch.Tensor, silence_threshold: float = 1.5e-5) -> bool:
    """
    Tell whether a signal counts as silent.

    The norm of the whole tensor, divided by the size of its last dimension,
    is compared against ``silence_threshold``.

    Args:
    signal : torch.Tensor
        Tensor accepted by ``assert_is_audio`` (2-D with 1 or 2 rows).
    silence_threshold : float
        Upper bound (exclusive) on the length-normalised norm.

    Returns:
    bool
        True when the signal is empty or its normalised norm is below the
        threshold.

    Raises:
    AssertionError
        If ``signal`` is rejected by ``assert_is_audio``.
    """
    assert_is_audio(signal)
    num_samples = signal.shape[-1]
    if num_samples == 0:
        # 0 / 0 would be NaN, which compares False against any threshold and
        # would report an empty signal as active.
        return True
    # bool(...) honours the annotated return type instead of leaking a 0-dim tensor.
    return bool(torch.linalg.norm(signal) / num_samples < silence_threshold)


def sdr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Compute the signal-to-distortion ratio in dB along the last dimension.

    Energies are plain sums of squares (the same form ``sisnr`` uses), which
    avoids the deprecated ``torch.norm`` and a square root that was squared
    again right away. Inputs are expected to be real-valued.

    Args:
    preds : torch.Tensor
        Estimated signal, broadcastable against ``target``.
    target : torch.Tensor
        Reference signal.
    eps : float
        Added to both energies so the ratio and the logarithm stay finite.

    Returns:
    torch.Tensor
        ``10 * log10(target_energy / error_energy)`` with the last dimension
        reduced.
    """
    s_target = torch.sum(target ** 2, dim=-1) + eps
    s_error = torch.sum((target - preds) ** 2, dim=-1) + eps
    return 10 * torch.log10(s_target / s_error)




def sisnr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Compute the scale-invariant SNR in dB along the last dimension.

    The target is rescaled by the (eps-stabilised) projection coefficient of
    ``preds`` onto ``target``; the residual between that rescaled target and
    ``preds`` is treated as noise. Signals are not mean-centred here.

    Args:
    preds : torch.Tensor
        Estimated signal.
    target : torch.Tensor
        Reference signal.
    eps : float
        Added to the projection numerator and denominator and to both
        energies.

    Returns:
    torch.Tensor
        ``10 * log10(scaled_target_energy / noise_energy)`` with the last
        dimension reduced.
    """
    cross_term = (preds * target).sum(dim=-1, keepdim=True)
    target_energy = target.pow(2).sum(dim=-1, keepdim=True)
    scale = (cross_term + eps) / (target_energy + eps)

    projection = scale * target
    residual = projection - preds

    projection_energy = projection.pow(2).sum(dim=-1) + eps
    residual_energy = residual.pow(2).sum(dim=-1) + eps
    return 10 * torch.log10(projection_energy / residual_energy)


def evaluate_separations(
    separation_path: Union[str, Path],
    dataset_path: Union[str, Path],
    separation_sr: int,
    filter_single_source: bool = True,
    eps: float = 1e-8,
    chunk_duration: float = 4.0,
    overlap_duration: float = 2.0,
) -> pd.DataFrame:
    """
    Score every file of ``dataset_path`` against the same-named file in
    ``separation_path`` on overlapping fixed-length chunks.

    Args:
    separation_path : str or Path
        Folder with the audio files to evaluate.
    dataset_path : str or Path
        Folder with the reference audio; its listing drives the evaluation.
    separation_sr : int
        Sample rate used to turn the durations into sample counts. Files are
        not resampled, so they are expected to already be at this rate.
    filter_single_source : bool
        When True, chunks whose reference is silent (per ``is_silent``) are
        skipped.
    eps : float
        Stabiliser forwarded to ``sisnr``.
    chunk_duration : float
        Chunk length in seconds.
    overlap_duration : float
        Overlap between consecutive chunks in seconds; must be smaller than
        ``chunk_duration``.

    Returns:
    pandas.DataFrame
        One row per evaluated chunk with the columns ``snr`` (result of
        ``sisnr``, in dB), ``pesq`` and ``ssnr`` (placeholders, always 0.0),
        ``start_sample``, ``end_sample`` and ``file``. The columns exist even
        when no chunk was evaluated.

    Raises:
    ValueError
        If the overlap is not shorter than the chunk.
    """
    import warnings  # local import so the file-level import block stays untouched

    separation_path = Path(separation_path)
    dataset_path = Path(dataset_path)

    # Chunk geometry does not depend on the file, so compute it once.
    chunk_samples = int(chunk_duration * separation_sr)
    overlap_samples = int(overlap_duration * separation_sr)
    step_size = chunk_samples - overlap_samples
    if step_size <= 0:
        raise ValueError("overlap_duration must be smaller than chunk_duration")

    columns = ["snr", "pesq", "ssnr", "start_sample", "end_sample", "file"]
    df_entries = {column: [] for column in columns}
    skipped_files = []

    # Sorted so the row order does not depend on the file system.
    for file in sorted(os.listdir(dataset_path)):
        try:
            separated_track, _ = torchaudio.load(separation_path / file)
            original_track, _ = torchaudio.load(dataset_path / file)
        except Exception:
            # Best effort: a file that is missing or unreadable on either side
            # is skipped, but remembered so it can be reported below.
            skipped_files.append(file)
            continue

        # Keep only the first channel of each track so sisnr yields one value.
        separated_track = separated_track[:1]
        original_track = original_track[:1]

        # Trim both tracks to the shorter one so chunk slices always match.
        min_length = min(separated_track.shape[-1], original_track.shape[-1])
        separated_track = separated_track[:, :min_length]
        original_track = original_track[:, :min_length]

        # Number of complete chunks that fit. Floor division also keeps the
        # last chunk when the track length fits exactly.
        num_eval_chunks = max(0, (min_length - overlap_samples) // step_size)

        for i in range(num_eval_chunks):
            start_sample = i * step_size
            end_sample = start_sample + chunk_samples

            o = original_track[:, start_sample:end_sample]
            s = separated_track[:, start_sample:end_sample]

            # Skip chunks whose reference is silent.
            if filter_single_source and is_silent(o):
                continue

            df_entries["snr"].append(sisnr(s, o, eps).item())

            # PESQ and SSNR are currently disabled; the placeholders keep the
            # column layout. To enable: compute_pesq(o, s, separation_sr) and
            # compute_ssnr(o, s).
            df_entries["pesq"].append(0.0)
            df_entries["ssnr"].append(0.0)

            # Chunk position and source file of this row.
            df_entries["start_sample"].append(start_sample)
            df_entries["end_sample"].append(end_sample)
            df_entries["file"].append(file)

    if skipped_files:
        warnings.warn(
            f"evaluate_separations skipped {len(skipped_files)} file(s) that could not "
            f"be loaded from both folders, e.g. {skipped_files[:5]}"
        )

    return pd.DataFrame(df_entries, columns=columns)

base_dir = "/home/karchkhadze/testing_sisdr/assets/fma/"

# One entry per run:
# (printed label, folder to evaluate, reference folder, sample rate).
# NOTE(review): the printed labels do not match the folders being scored and
# look carried over from an earlier experiment - confirm before relying on them.
runs = (
    ("bighifigan", "bvg_bwe", "gt_bwe", 44100),
    ("bigvgan", "bvg_voc", "gt_voc", 22050),
    ("bigvgan-base", "bvg_m2s", "gt_stereo", 44100),
)

for label, separation_name, reference_name, sample_rate in runs:
    separation_dir = base_dir + separation_name
    dataset_path = base_dir + reference_name

    # Compute metrics
    results = evaluate_separations(separation_dir, dataset_path, sample_rate, eps=1e-8)

    print(f'\nResults {label}:{results["snr"].mean()}')




# #############################################################
# #############################################################

# separation_dir = "/home/karchkhadze/testing_sisdr/assets/generated/bvg_bwe"
# dataset_path = base_dir + "gt_bwe"

# # Compute metrics
# results = evaluate_separations(separation_dir, dataset_path, 44100, eps=1e-8 )

# print(f'\nResults bighifigan:{results["snr"].mean()}')

# #############################################################

# separation_dir = "/home/karchkhadze/testing_sisdr/assets/generated/bvg_voc"
# dataset_path = base_dir + "g_voct"

# # Compute metrics
# results = evaluate_separations(separation_dir, dataset_path, 22050, eps=1e-8 )

# print(f'\nResults bigvgan:{results["snr"].mean()}')

# #############################################################

# separation_dir = "/home/karchkhadze/testing_sisdr/assets/generated/bvg_m2s"
# dataset_path = base_dir + "gt_stereo"

# # Compute metrics
# results = evaluate_separations(separation_dir, dataset_path, 44100, eps=1e-8 )

# print(f'\nResults bigvgan-base:{results["snr"].mean()}')


# #############################################################

# separation_dir = base_dir + "sc-wavernn"
# dataset_path = base_dir + "gt"

# # Compute metrics
# results = evaluate_separations(separation_dir, dataset_path, 24000, eps=1e-8 )

# print(f'\nResults sc-wavernn:{results["snr"].mean()}')

# #############################################################

# separation_dir = base_dir + "univnet"
# dataset_path = base_dir + "gt"

# # Compute metrics
# results = evaluate_separations(separation_dir, dataset_path, 24000, eps=1e-8 )

# print(f'\nResults univnet:{results["snr"].mean()}')


Results bighifigan:-23.93779212550113

Results bigvgan:-23.895018778349225

Results bigvgan-base:-24.32210568377846
