In [1]:
%cd '/home/giorgio_mariani/Documents/audio-diffusion-pytorch-trainer/'

/home/giorgio_mariani/Documents/audio-diffusion-pytorch-trainer


In [11]:
pwd

'/home/giorgio_mariani/Documents/audio-diffusion-pytorch-trainer'

In [2]:
from collections import defaultdict
import json
import os
from pathlib import Path
from pathlib import Path
import re
from typing import List, Mapping, Optional, Tuple, Union
import main.module_base
from script.misc import hparams
import math
import numpy as np
#import museval
import pandas as pd
import torch
import torchaudio
import torchmetrics.functional.audio as tma
#from evaluation.evaluate_separation import evaluate_data
from tqdm import tqdm
from torchaudio.transforms import Resample

from main.dataset import is_silent
from main.likelihood import log_likelihood_song

In [3]:
def sdr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    s_target = torch.norm(target, dim=-1)**2 + eps
    s_error = torch.norm(target - preds, dim=-1)**2 + eps
    return 10 * torch.log10(s_target/s_error)


In [4]:
def sisnr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    alpha = (torch.sum(preds * target, dim=-1, keepdim=True) + eps) / (torch.sum(target**2, dim=-1, keepdim=True) + eps)
    target_scaled = alpha * target
    noise = target_scaled - preds
    s_target = torch.sum(target_scaled**2, dim=-1) + eps
    s_error = torch.sum(noise**2, dim=-1) + eps
    return 10 * torch.log10(s_target / s_error)

In [5]:
def get_rms(source_waveforms):
  """Return shape (source,) weights for signals that are nonzero."""
  return torch.sqrt(torch.mean(source_waveforms ** 2, dim=-1))
  #return source_norms <= 1e-8

In [6]:
def load_chunks(chunk_folder: Path) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], int]:
    original_tracks_and_rate = {ori.name.split(".")[0][3:]: torchaudio.load(ori) for ori in sorted(list(chunk_folder.glob("ori*.wav")))}
    separated_tracks_and_rate = {sep.name.split(".")[0][3:]: torchaudio.load(sep) for sep in sorted(list(chunk_folder.glob("sep*.wav")))}
    assert tuple(original_tracks_and_rate.keys()) == tuple(separated_tracks_and_rate.keys())

    original_tracks = {k:t for k, (t,_) in original_tracks_and_rate.items()}
    sample_rates_ori = [s for (_,s) in original_tracks_and_rate.values()]

    separated_tracks = {k:t for k, (t,_) in separated_tracks_and_rate.items()}
    sample_rates_sep = [s for (_,s) in separated_tracks_and_rate.values()]

    assert len({*sample_rates_ori, *sample_rates_sep}) == 1, print({*sample_rates_ori, *sample_rates_sep})
    assert len(original_tracks) == len(separated_tracks)
    sr = sample_rates_ori[0]

    return original_tracks, separated_tracks, sr

In [7]:
def evaluate_tracks_chunks_mike(separation_path: Union[str, Path], chunk_prop: int, 
                           orig_sr: int = 44100, resample_sr: Optional[int] = None, 
                           filter_single_source: bool = True, eps: float = 10-8):

    separation_folder = Path(separation_path)
    assert separation_folder.exists(), separation_folder
    assert (separation_folder.parent / "chunk_data.json").exists(), separation_folder

    with open(separation_folder.parent / "chunk_data.json") as f:
        chunk_data = json.load(f)
        
    def load_model(path):
        model = main.module_base.Model(**{**hparams, "in_channels": 4})
        model.load_state_dict(torch.load(path)["state_dict"])
        model.to("cuda:0")
        return model
    
    ckpts_path = Path("/home/irene/Documents/audio-diffusion-pytorch-trainer/logs/ckpts")
    model = load_model(ckpts_path / "avid-darkness-164_epoch=419-valid_loss=0.015.ckpt")
    denoise_fn = model.model.diffusion.denoise_fn
    
    resample_fn = Resample(orig_freq=orig_sr, new_freq=resample_sr) if resample_sr is not None else lambda x: x

    track_to_chunks = defaultdict(list)
    for chunk_data in chunk_data:
        track = chunk_data["track"]
        chunk_idx = chunk_data["chunk_index"]
        start_sample = chunk_data["start_chunk_sample"]
        track_to_chunks[track].append( (start_sample, chunk_idx) )

    # reorder chunks into ascending order and compute sdr
    results = defaultdict(list)
    for track, chunks in tqdm(track_to_chunks.items()):
        sorted_chunks = sorted(chunks)

        separated_wavs, original_wavs = defaultdict(list), defaultdict(list)
        for _, chunk_idx in sorted_chunks:
                        
            chunk_folder = separation_folder / str(chunk_idx)
            original_tracks, separated_tracks, sr = load_chunks(chunk_folder)
           
            assert sr == orig_sr, f"chunk [{chunk_folder.name}]: expected freq={orig_sr}, track freq={sr}"
            
            
        mixture = sum([owav for owav in original_tracks.values()])
        
        chunk_size = int(separated_tracks["1"].shape[-1] * chunk_prop)
        
        for k in separated_tracks:
            o = original_tracks[k]
            
            s = separated_tracks[k]
            
            m = mixture
            
            
        for i in range(mixture.shape[-1] // chunk_size):
            
            num_silent_signals = 0
            for k in separated_tracks:
                
                o = original_tracks[k][:,i*chunk_size:(i+1)*chunk_size]
               
                if is_silent(o) and filter_single_source:
                    num_silent_signals += 1
            if num_silent_signals > 3:
                continue
            else:
                for k in separated_tracks:
                    o = original_tracks[k][:,i*chunk_size:(i+1)*chunk_size]
                    s = separated_tracks[k][:,i*chunk_size: (i+1)*chunk_size]
                    m = mixture[:,i*chunk_size: (i+1)*chunk_size]
                    results[k].append((sisnr(s, o, eps) - sisnr(m, o, eps)).item())
                    
    return pd.DataFrame(results)

In [18]:
def evaluate_tracks_chunks_mike_with_overlap(separation_path: Union[str, Path],
                           orig_sr: int = 44100, resample_sr: Optional[int] = None, 
                           filter_single_source: bool = True, eps: float = 10-8, chunk_duration: float = 4.0, overlap_duration: float = 2.0):

    separation_folder = Path(separation_path)
    assert separation_folder.exists(), separation_folder
    assert (separation_folder.parent / "chunk_data.json").exists(), separation_folder

    with open(separation_folder.parent / "chunk_data.json") as f:
        chunk_data = json.load(f) 
        
    def load_model(path):
        model = main.module_base.Model(**{**hparams, "in_channels": 4})
        model.load_state_dict(torch.load(path)["state_dict"])
        model.to("cuda:0")
        return model
    
    ckpts_path = Path("/home/irene/Documents/audio-diffusion-pytorch-trainer/logs/ckpts")
    model = load_model(ckpts_path / "avid-darkness-164_epoch=419-valid_loss=0.015.ckpt")
    denoise_fn = model.model.diffusion.denoise_fn
    
    resample_fn = Resample(orig_freq=orig_sr, new_freq=resample_sr) if resample_sr is not None else lambda x: x

    track_to_chunks = defaultdict(list)
    for chunk_data in chunk_data:
        track = chunk_data["track"]
        chunk_idx = chunk_data["chunk_index"]
        start_sample = chunk_data["start_chunk_sample"]
        track_to_chunks[track].append( (start_sample, chunk_idx) )

    # reorder chunks into ascending order and compute sdr
    results = defaultdict(list)
    for track, chunks in tqdm(track_to_chunks.items()):
        sorted_chunks = sorted(chunks)

        separated_wavs, original_wavs = defaultdict(list), defaultdict(list)
        for _, chunk_idx in sorted_chunks:
                        
            chunk_folder = separation_folder / str(chunk_idx)
            original_tracks, separated_tracks, sr = load_chunks(chunk_folder)
            #print(separated_tracks['1'].shape)
           
            assert sr == orig_sr, f"chunk [{chunk_folder.name}]: expected freq={orig_sr}, track freq={sr}"
            
            
        mixture = sum([owav for owav in original_tracks.values()])
        
        chunk_samples = chunk_duration * orig_sr
        overlap_samples = overlap_duration * orig_sr
        
        chunk_samples = int(chunk_duration * orig_sr)
        
        overlap_samples = int(overlap_duration * orig_sr)

        # Calculate the step size between consecutive chunks
        step_size = chunk_samples - overlap_samples

        # Determine the number of chunks based on step_size
        num_chunks = math.ceil((mixture.shape[-1] - overlap_samples) / step_size)
        #print(mixture.shape)
        

        for i in range(num_chunks):
            start_sample = i * step_size
            end_sample = start_sample + chunk_samples

            num_silent_signals = 0
            for k in separated_tracks:
                o = original_tracks[k][:,start_sample:end_sample]
                #print(o.shape)
                if is_silent(o) and filter_single_source:
                    num_silent_signals += 1
            if num_silent_signals > 3:
                continue
            else:
                for k in separated_tracks:
                    o = original_tracks[k][:,start_sample:end_sample]
                    s = separated_tracks[k][:,start_sample:end_sample]
                    m = mixture[:,start_sample:end_sample]
                    results[k].append((sisnr(s, o, eps) - sisnr(m, o, eps)).item())


    return pd.DataFrame(results)

In [8]:
results = evaluate_tracks_chunks_mike("separations/debug/sep_round_0", chunk_prop=1/3, orig_sr=22050, eps=1e-8)
results.shape

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:01<00:00, 46.45it/s]


(213, 4)

In [29]:
results_wo = evaluate_tracks_chunks_mike_with_overlap("separations/debug/sep_round_0",chunk_duration=4.0, overlap_duration=2.0 , orig_sr=22050, eps=1e-8)
results_wo.shape

100%|███████████████████████████████████████████| 71/71 [00:01<00:00, 66.26it/s]


(355, 4)

In [30]:
mean_results = results.mean()
print("Mean results:")
#print(mean_results)

# Access mean values by column name
mean_col1 = mean_results['1']
mean_col2 = mean_results['2']
mean_col3 = mean_results['3']
mean_col4 = mean_results['4']

print("\nMean values by column name:")
print("Column 1:", mean_col1)
print("Column 2:", mean_col2)
print("Column 3:", mean_col3)
print("Column 4:", mean_col4)

mean_of_means = mean_results.mean()

print("Mean of means:", mean_of_means)

Mean results:

Mean values by column name:
Column 1: 15.575636171958816
Column 2: 16.88300941919497
Column 3: 13.593393392965828
Column 4: 14.085952395564513
Mean of means: 15.034497844921031


In [31]:
mean_results = results_wo.mean()
print("Mean results:")
#print(mean_results)

# Access mean values by column name
mean_col1 = mean_results['1']
mean_col2 = mean_results['2']
mean_col3 = mean_results['3']
mean_col4 = mean_results['4']

print("\nMean values by column name:")
print("Column 1:", mean_col1)
print("Column 2:", mean_col2)
print("Column 3:", mean_col3)
print("Column 4:", mean_col4)

mean_of_means = mean_results.mean()

print("Mean of means:", mean_of_means)

Mean results:

Mean values by column name:
Column 1: 15.875413102163396
Column 2: 16.880386715203944
Column 3: 13.490192078200865
Column 4: 14.140144774947368
Mean of means: 15.096534167628892
