In [1]:
import os
import glob
from pathlib import Path
import json
import numpy as np
import nussl
import torch
from nussl.datasets import transforms as nussl_tfm
from common import utils, argbind
import matplotlib.pyplot as plt
from nussl.ml.networks.modules import AmplitudeToDB, BatchNorm, RecurrentStack, Embedding
from nussl.separation.base import MaskSeparationBase, DeepMixin, SeparationException
from torch import nn
# from torch.nn.utils import weight_norm
from ignite.engine import Events, Engine, EventEnum
from nussl.ml import SeparationModel
from nussl.ml.networks.modules import (
    Embedding, DualPath, DualPathBlock, STFT, Concatenate, 
    LearnedFilterBank, AmplitudeToDB, RecurrentStack,
    MelProjection, BatchNorm, InstanceNorm, ShiftAndScale
)
import pandas as pd

from setup_al3625 import *

In [2]:
import warnings

# Call the project's logger helper, then suppress every Python warning
# for the remainder of the session.
utils.logger()
warnings.filterwarnings(action='ignore')

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# STFT settings: 512-sample window, 128-sample hop.
stft_params = nussl.STFTParams(hop_length=128, window_length=512)
# Half the window length plus one (257 here) -- presumably the number of
# frequency bins of the one-sided spectrum; confirm against the model config.
nf = 1 + stft_params.window_length // 2

In [3]:
# Locations of the test data, the trained models and the evaluation results.
test_folder = "../1_create_alignments/results/test"
main_folder = "trained_models"
eval_folder = "eval_results"
oracle_folder = "eval_results/oracle"

# `names`, `key_of_keys` and `post_depths` are laid out as parallel lists:
# slot i of each one describes the same model, and a model is disabled by
# commenting out the SAME slot in all three lists. The trailing
# "audio_post_norm" slot used to be disabled only in `names`, which left the
# other two lists one entry longer; it is now disabled everywhere.
names = [
    "all_audio",
    # "all_post",
    "audio_post_aligns",
    "audio_post_no_norm",
    # "audio_post_no_norm_larger",
    # "audio_post_norm",  # good but had more training
]
key_of_keys = [
    [],
    # ["posterior"],
    ["posterior", "full_lyrics"],
    ["posterior"],
    # ["posterior"],
    # ["posterior"],
]
post_depths = [
    False,
    # True,
    False,
    False,
    # True,
    # False,
]

# Fail fast if the three lists ever drift out of sync again.
assert len(names) == len(key_of_keys) == len(post_depths), (
    "names, key_of_keys and post_depths must have one entry per model"
)

In [4]:
# Per-model bookkeeping, filled in the order of `names`.
input_models = []
output_paths = []
input_paths = []
sep_paths = []
for name in names:
    model_dir = os.path.join(main_folder, name)
    checkpoint_path = os.path.join(model_dir, "checkpoints", "best.model.pth")
    input_paths.append(checkpoint_path)
    # NOTE: a separator path is recorded only when the model has a
    # "separator" directory, so `sep_paths` can be shorter than `names`
    # and must not be indexed in parallel with the other lists.
    if os.path.exists(os.path.join(model_dir, "separator")):
        sep_paths.append(
            os.path.join(model_dir, "separator", "separator.model.pth"))
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from, which fails for GPU checkpoints on a CPU-only machine.
    input_models.append(torch.load(checkpoint_path, map_location=DEVICE))
    output_paths.append(os.path.join(eval_folder, name))

# Results folder of the LAST configured model. This used to rely on the loop
# variable `name` leaking out of the loop; the dependency is now explicit.
results_folder = os.path.join(eval_folder, names[-1])

In [5]:
# Oracle eval: every ground-truth source is scored against itself, and the
# scores of test item i are written to <oracle_folder>/<i>.json.

# open(..., 'w') does not create missing directories, so make sure the
# output folder exists before the first file is written.
os.makedirs(oracle_folder, exist_ok=True)

_, _, test_data = get_data("", keys=[], post_depth=False, use_corpus=False)
for i, item in enumerate(test_data):
    source_keys = list(item['sources'].keys())
    sources = [item['sources'][k] for k in source_keys]
    # The estimates ARE the references, so the scores are whatever the
    # evaluator reports for a perfect separation.
    estimates = sources
    evaluator = nussl.evaluation.BSSEvalScale(
        sources, estimates, source_labels=source_keys
    )
    scores = evaluator.evaluate()
    output_file = os.path.join(oracle_folder, f"{i}.json")
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)

In [6]:
# Empty placeholder for the combined per-model scores; it is populated by
# the aggregation loop further down.
aggregated = pd.DataFrame()

In [7]:
# Show the current contents of `aggregated` (an empty frame when the
# aggregation loop has not been run yet).
display(aggregated)

In [8]:
# Aggregate the JSON score files of every model (plus the oracle) and print
# one report card per group.
#
# The per-group frames are collected in a list and concatenated once, so
# `aggregated` is rebuilt from scratch on every run: re-executing this cell
# no longer appends duplicate rows, and it no longer depends on `aggregated`
# having been initialised by an earlier cell.
group_frames = []
for output, name in zip(output_paths + [oracle_folder], names + ["Oracle"]):
    print(f"\n\n\t\t   ***  ***   {name}   ***   ***")
    json_files = glob.glob(os.path.join(output, "*.json"))
    if not json_files:
        # Nothing was evaluated for this group; say so instead of failing
        # somewhere inside the aggregation helpers.
        print(f"No score files found in {output!r}; skipping.")
        continue
    df = nussl.evaluation.aggregate_score_files(
        json_files, aggregator=np.nanmedian)
    # Tag every row with its group so the groups can be told apart later.
    df["Group_Name"] = name
    group_frames.append(df)
    # nussl.evaluation.associate_metrics(separator.model, df, test_data)
    report_card = nussl.evaluation.report_card(
        df, report_each_source=True)
    print(report_card)

aggregated = pd.concat(group_frames) if group_frames else pd.DataFrame()



		   ***  ***   all_audio   ***   ***
                                                                      
                        MEAN +/- STD OF METRICS                         
                                                                      
┌────────────┬──────────────────┬──────────────────┬──────────────────┐
│ METRIC     │     OVERALL      │    NON-VOCALS    │      VOCALS      │
╞════════════╪══════════════════╪══════════════════╪══════════════════╡
│ #          │        90        │        45        │        45        │
├────────────┼──────────────────┼──────────────────┼──────────────────┤
│ SI-SDR     │   4.95 +/-  6.30 │   9.50 +/-  3.11 │   0.40 +/-  5.33 │
├────────────┼──────────────────┼──────────────────┼──────────────────┤
│ SI-SIR     │  12.50 +/-  5.81 │  15.46 +/-  4.95 │   9.54 +/-  5.09 │
├────────────┼──────────────────┼──────────────────┼──────────────────┤
│ SI-SAR     │   6.66 +/-  7.97 │  12.14 +/-  6.09 │   1.18 +/-  5.47 │
├────────────┼───────────

In [13]:
# Per-model mean of every numeric metric, computed separately per source.
#
# numeric_only=True is required on pandas >= 2.0: `aggregated` carries
# non-numeric columns (at least `source`), and averaging those now raises a
# TypeError instead of silently dropping them as older pandas versions did.
vocals = (
    aggregated[aggregated.source == 'vocals']
    .groupby('Group_Name', as_index=False)
    .mean(numeric_only=True)
)
non_vocals = (
    aggregated[aggregated.source == 'non-vocals']
    .groupby('Group_Name', as_index=False)
    .mean(numeric_only=True)
)
vocals["Source"] = "Vocals"
non_vocals["Source"] = "Non_vocals"

# Stack both sources, round to one decimal and keep only the report columns.
means = pd.concat((vocals, non_vocals))
means = means.round(decimals=1)
means = means.reindex(
    columns=["Group_Name", "Source", "SI-SDR", "SI-SIR", "SI-SAR", "SNR", "SRR"])
# means = means.reindex(columns=["Group_Name"])
display(means)

Unnamed: 0,Group_Name,Source,SI-SDR,SI-SIR,SI-SAR,SNR,SRR
0,Oracle,Vocals,78.0,78.0,151.9,,78.0
1,all_audio,Vocals,0.4,9.5,1.2,3.3,4.5
2,audio_post_aligns,Vocals,0.6,9.0,1.4,3.5,3.0
3,audio_post_no_norm,Vocals,1.2,8.7,2.2,3.3,9.5
0,Oracle,Non_vocals,76.4,76.4,151.8,,76.4
1,all_audio,Non_vocals,9.5,15.5,12.1,10.0,20.6
2,audio_post_aligns,Non_vocals,9.7,13.9,12.3,10.2,19.9
3,audio_post_no_norm,Non_vocals,9.5,17.8,10.5,10.0,17.0


In [None]:
def visualize_and_embed(sources, y_axis='mel'):
    """Draw `sources` as masks on a single 10x4 figure and show it.

    Args:
        sources: Forwarded unchanged to
            `nussl.utils.visualize_sources_as_masks`.
        y_axis: Forwarded as the `y_axis` argument of the nussl plotting
            helper. Defaults to 'mel'.

    Despite the name, nothing is embedded: the audio-player call at the
    end of the function is commented out.
    """
    mask_options = {'db_cutoff': -60, 'y_axis': y_axis}

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 1, 1)
    nussl.utils.visualize_sources_as_masks(sources, **mask_options)
    plt.tight_layout()
    plt.show()

    # Audio embedding is currently switched off:
    # nussl.play_utils.multitrack(sources, ext='.wav')