In [1]:
import os
import sys

# Uncomment to restrict the notebook to a single GPU:
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Prepend the parent of the current working directory to the module search path
# (presumably where the project-local `model` / `train` / ... modules live).
project_root = os.path.dirname(os.getcwd())
sys.path.insert(0, project_root)

In [2]:
from model import WhisperSegmenterFast, WhisperSegmenter
import librosa
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from train import evaluate
from datautils import get_audio_and_label_paths
import os
from audio_utils import SpecViewer
import subprocess
from glob import glob
import json

In [3]:
def _score_tuple_to_dict(scores):
    """Attach names to the first six entries of one score sequence from ``evaluate``.

    Assumes ``scores`` is ordered (true positives, predicted positives,
    ground-truth positives, precision, recall, F1) -- the labelling this notebook
    has always applied to ``res["segment_wise"]`` and ``res["frame_wise"]``.
    """
    keys = ("N-true-positive", "N-positive-in-prediction", "N-positive-in-ground-truth",
            "precision", "recall", "F1")
    return {key: scores[idx] for idx, key in enumerate(keys)}


def evaluate_dataset( dataset_folder, model_path, num_trials, consolidation_method = "clustering",
                      max_length = 448, num_beams = 4, batch_size = 8, tolerance = None, spec_time_step = None, eps = None,
                      device = "cuda" ):
    """Segment every recording in ``dataset_folder`` with a WhisperSegmenterFast model and score it.

    Each recording is a group of per-channel audio files plus a JSON label file
    (as returned by ``get_audio_and_label_paths``). All channels are loaded at the
    label's ``"sr"``, cropped to a common length and stacked into one array.

    Args:
        dataset_folder: folder passed to ``get_audio_and_label_paths``.
        model_path: checkpoint path / model id passed to ``WhisperSegmenterFast``.
        num_trials, consolidation_method, max_length, num_beams, batch_size:
            forwarded unchanged to ``evaluate``.
        tolerance, spec_time_step, eps: when not None, override the same-named key
            in every loaded label dict.
        device: device for the segmenter (defaults to "cuda", the previous hard-coded value).

    Returns:
        dict with "segment_wise_scores" and "frame_wise_scores", each mapping
        N-true-positive / N-positive-in-prediction / N-positive-in-ground-truth /
        precision / recall / F1 to the corresponding entry of ``evaluate``'s result.

    Raises:
        ValueError: if a recording has no audio channel files.
    """
    audio_list, label_list = [], []
    audio_paths, label_paths = get_audio_and_label_paths(dataset_folder)

    # Only hyper-parameters that were explicitly given override the label files.
    label_overrides = {
        key: value
        for key, value in (("tolerance", tolerance), ("spec_time_step", spec_time_step), ("eps", eps))
        if value is not None
    }

    for audio_mc_paths, label_path in zip(audio_paths, label_paths):
        # `with` closes the label file instead of leaking the handle.
        with open(label_path) as label_file:
            label = json.load(label_file)
        label.update(label_overrides)

        if len(audio_mc_paths) == 0:
            raise ValueError(f"no audio channel files found for label {label_path}")
        channels = [librosa.load(audio_path, sr = label["sr"])[0] for audio_path in audio_mc_paths]
        # Channel files can differ by a few samples; crop to the shortest so they
        # stack into one regular array.
        common_len = min(len(channel) for channel in channels)
        audio = np.asarray([channel[:common_len] for channel in channels])

        audio_list.append(audio)
        label_list.append(label)

    segmenter = WhisperSegmenterFast(model_path = model_path, device = device)
    res = evaluate(audio_list, label_list, segmenter, batch_size, max_length, num_trials,
                   consolidation_method, num_beams, target_cluster = None)

    return {
        "segment_wise_scores": _score_tuple_to_dict(res["segment_wise"]),
        "frame_wise_scores": _score_tuple_to_dict(res["frame_wise"]),
    }

In [10]:
# Score the local multi-channel checkpoint on the full zebra finch test split.
zebra_finch_test_folder = "../data/datasets/zebra_finch_full/test/"
local_mc_checkpoint = "../model/mc-whisperseg-mc-zebra-finch-mask-08-2_5-5epochs/final_checkpoint_ct2/"
evaluate_dataset(zebra_finch_test_folder, local_mc_checkpoint, num_trials=3, tolerance=0.02)

100%|████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [05:07<00:00, 30.75s/it]


{'segment_wise_scores': {'N-true-positive': 835,
  'N-positive-in-prediction': 901,
  'N-positive-in-ground-truth': 892,
  'precision': 0.9267480577136515,
  'recall': 0.9360986547085202,
  'F1': 0.9313998884551034},
 'frame_wise_scores': {'N-true-positive': 78341,
  'N-positive-in-prediction': 81184,
  'N-positive-in-ground-truth': 83000,
  'precision': 0.964980784391013,
  'recall': 0.9438674698795181,
  'F1': 0.9543073624713736}}

In [14]:
from model import WhisperSegmenterFast, WhisperSegmenter
from audio_utils import SpecViewer

# Released multi-channel zebra finch checkpoint. To use the local checkpoint instead,
# point this at "../model/mc-whisperseg-mc-zebra-finch-mask-08-2_5-5epochs/final_checkpoint_ct2/".
segmenter_checkpoint = "nccratliri/mc-whisperseg-zebra-finch-ct2-v1.0"
segmenter = WhisperSegmenterFast(segmenter_checkpoint, device="cuda")
spec_viewer = SpecViewer()

In [25]:
# All channel files of one test recording in which radio2 is the target channel.
audio_file_pattern = "../data/datasets/zebra_finch_full/test/BP_2021-05-23_09-22-46_918470_0580000_radio2_as_target_channel_*.wav"

# Sort by the trailing channel index ("..._channel_<k>.wav") numerically, so channel 10
# would come after channel 2.
audio_fname_list = sorted(glob(audio_file_pattern), key=lambda x: int(x.split("_")[-1].split(".wav")[0]))
if not audio_fname_list:
    raise FileNotFoundError(f"no audio files match {audio_file_pattern}")

# The ground-truth label uses the first channel's file name with a .json extension.
label_fname = os.path.splitext(audio_fname_list[0])[0] + ".json"
with open(label_fname) as label_file:
    label = json.load(label_file)

In [26]:
# Reuse the segmentation hyper-parameters stored with the ground-truth label.
sr, min_frequency, spec_time_step, min_segment_length, eps = (
    label[key] for key in ("sr", "min_frequency", "spec_time_step", "min_segment_length", "eps")
)
num_trials = 3

In [27]:
audio_list = []
for audio_fname in audio_fname_list:
    audio, _ = librosa.load(audio_fname, sr=sr)
    audio_list.append(audio)

# Channel files can differ by a few samples. np.asarray on a ragged list of arrays
# fails, so crop every channel to the shortest one. This matches the cropping
# evaluate_dataset applies to the full test set.
common_len = min(len(channel) for channel in audio_list)
audio = np.asarray([channel[:common_len] for channel in audio_list])

prediction = segmenter.segment(audio, sr=sr, min_frequency=min_frequency, spec_time_step=spec_time_step,
                               min_segment_length=min_segment_length, eps=eps, num_trials=num_trials, batch_size=4)

In [28]:
# The displayed tuple appears to be (TP, n_predicted, n_ground_truth, precision, recall, F1),
# judging by its values; confirm against WhisperSegmenterFast.segment_score.
single_recording_scores = segmenter.segment_score(prediction, label)
single_recording_scores

(78, 88, 86, 0.8863636363636364, 0.9069767441860465, 0.896551724137931)

In [29]:
# NOTE(review): in the radio1/radio2 cells further down, channel names are listed in
# reverse order of the stacked array. Presumably SpecViewer labels channels bottom-up
# and the same holds here. Confirm against SpecViewer.visualize.
channel_names = ["Mic", "non-target\nradio", "target\nradio"]
spec_viewer.visualize(
    audio=audio,
    sr=sr,
    min_frequency=min_frequency,
    prediction=prediction,
    label=label,
    window_size=5,
    precision_bits=1,
    audio_channel_names=channel_names,
)

interactive(children=(FloatSlider(value=207.0, description='offset', max=414.4304375, step=0.25), Output()), _…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [4]:
from audio_utils import SpecViewer
import librosa
import numpy as np
from glob import glob
import json
from model import MultiChannelWhisperSeg

In [5]:
# Multi-channel wrapper around the released zebra finch checkpoint, run on the GPU.
mc_model_id = "nccratliri/mc-whisperseg-zebra-finch-ct2-v1.0"
mc_segmenter = MultiChannelWhisperSeg(mc_model_id, device="cuda")

In [6]:
# Radio channels: at least one is required, and more than two are allowed.
radio_fname_list = ["../data/example_subset/Zebra_finch/test/BP_2021-05-23_09-22-46_918470_0580000_radio1.wav",
                    "../data/example_subset/Zebra_finch/test/BP_2021-05-23_09-22-46_918470_0580000_radio2.wav"]
# Microphone channels: at least one is required.
mic_fname_list = ["../data/example_subset/Zebra_finch/test/BP_2021-05-23_09-22-46_918470_0580000_daq1.wav"]
# Sample rate used when loading every channel.
sr = 16000

In [7]:
def load_channel(fname):
    """Load one audio file with librosa at sample rate `sr` and return only the waveform."""
    waveform, _ = librosa.load(fname, sr=sr)
    return waveform

radio_channels = list(map(load_channel, radio_fname_list))
mic_channels = list(map(load_channel, mic_fname_list))

In [8]:
# Segmentation hyper-parameters used for every radio channel.
segment_params = dict(
    min_frequency=0,
    spec_time_step=0.0025,
    min_segment_length=0.005,
    eps=0.02,
    num_trials=3,
)
predictions = mc_segmenter.segment(radio_channels, mic_channels, sr, **segment_params)

Segmenting radio channel 0 [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] 100.00%
Segmenting radio channel 1 [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] 100.00%


`predictions` is a list of predictions, one per radio channel: `predictions[i]` contains the segmentation results for `radio_channels[i]`. For example, `predictions[0]` contains the segmentation results of `radio_channels[0]`.

Each prediction (e.g., `predictions[0]`) is a dictionary with the usual keys `"onset"`, `"offset"`, and `"cluster"`.

Let's score the segmentation results against the ground truth for each target radio channel, then visualize them.

In [9]:
# A single viewer instance, reused by both target-channel visualizations that follow.
spec_viewer = SpecViewer()

In [10]:
# Suppose radio1 is the target channel being segmented:
# predictions[0] holds the segmentation results for radio_channels[0] (radio1).
## load the ground-truth segmentation (`with` closes the file handle)
with open("../data/example_subset/Zebra_finch/test/BP_2021-05-23_09-22-46_918470_0580000_radio1.json") as label_file:
    label_radio1 = json.load(label_file)
## compute score; the last entry of the returned tuple is the segment-level F1
segment_score = mc_segmenter.segmenter.segment_score( prediction = predictions[0], label = label_radio1 )
print("segment-F1:",segment_score[-1])
## visualize
# NOTE(review): channel names are listed in reverse order of the stacked array
# (radio1, radio2, daq1). Presumably SpecViewer labels channels bottom-up; confirm.
# The trailing `;` hides the returned function's repr; the widget is still shown.
spec_viewer.visualize(
    np.asarray([ radio_channels[0], radio_channels[1], mic_channels[0] ]),
    sr = sr,
    label = label_radio1,
    prediction= predictions[0],
    audio_channel_names=["Mic (daq1)", "radio 2", "radio 1\n(target)"]
);

segment-F1: 0.9692307692307693


interactive(children=(FloatSlider(value=207.0, description='offset', max=414.4304375, step=0.25), Output()), _…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [11]:
# Suppose radio2 is the target channel being segmented:
# predictions[1] holds the segmentation results for radio_channels[1] (radio2).
## load the ground-truth segmentation (`with` closes the file handle)
with open("../data/example_subset/Zebra_finch/test/BP_2021-05-23_09-22-46_918470_0580000_radio2.json") as label_file:
    label_radio2 = json.load(label_file)
## compute score; the last entry of the returned tuple is the segment-level F1
segment_score = mc_segmenter.segmenter.segment_score( prediction = predictions[1], label = label_radio2 )
print("segment-F1:",segment_score[-1])
## visualize
# NOTE(review): channel names are listed in reverse order of the stacked array
# (radio1, radio2, daq1). Presumably SpecViewer labels channels bottom-up; confirm.
# The trailing `;` hides the returned function's repr; the widget is still shown.
spec_viewer.visualize(
    np.asarray([ radio_channels[0], radio_channels[1], mic_channels[0] ]),
    sr = sr,
    label = label_radio2,
    prediction= predictions[1],
    audio_channel_names=["Mic (daq1)", "radio 2\n(target)", "radio 1"]
);

segment-F1: 0.896551724137931


interactive(children=(FloatSlider(value=207.0, description='offset', max=414.4304375, step=0.25), Output()), _…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>