In [12]:
import sys, os
sys.path.insert(0, os.path.dirname(os.getcwd()))

In [13]:
from model import WhisperSegmenterFast, WhisperSegmenter
import librosa
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from train import evaluate
from datautils import get_audio_and_label_paths
import os
from audio_utils import SpecViewer
import subprocess
from glob import glob
import json

In [14]:
def evaluate_dataset( dataset_folder, model_path, num_trials, consolidation_method = "clustering",
                      max_length = 448, num_beams = 4, batch_size = 8, tolerance = None, spec_time_step = None, eps = None ):
    audio_list, label_list = [], []
    audio_paths, label_paths = get_audio_and_label_paths(dataset_folder)

    # ## debugging
    # audio_paths = [item[0] for item in audio_paths ] 
    
    for audio_mc_paths, label_path in zip(audio_paths, label_paths):
        label = json.load( open( label_path ) )

        audio_mc = []
        for audio_path in audio_mc_paths:
            audio, _ = librosa.load( audio_path, sr = label["sr"] )
            audio_mc.append( audio )
        min_len_list = min( [ len(audio) for audio in audio_mc ] )
        for c_idx in range(len(audio_mc)):
            audio_mc[c_idx] = audio_mc[c_idx][:min_len_list]
        audio = np.asarray(audio_mc)
        
        audio_list.append(audio)
        label_list.append(label) 
        
        if tolerance is not None:
            label["tolerance"] = tolerance
        if spec_time_step is not None:
            label["spec_time_step"] = spec_time_step
        if eps is not None:
            label["eps"] = eps 


    segmenter = WhisperSegmenterFast(  model_path = model_path,  device = "cuda")
    res = evaluate( audio_list, label_list, segmenter, batch_size, max_length, num_trials, consolidation_method, num_beams, 
                    target_cluster = None
                  )

    all_res = {
        "segment_wise_scores": {"N-true-positive": res["segment_wise"][0],
                                "N-positive-in-prediction": res["segment_wise"][1],
                                "N-positive-in-ground-truth": res["segment_wise"][2],
                                "precision": res["segment_wise"][3],
                                "recall": res["segment_wise"][4],
                                "F1": res["segment_wise"][5]
                                },
        "frame_wise_scores": {"N-true-positive": res["frame_wise"][0],
                                "N-positive-in-prediction": res["frame_wise"][1],
                                "N-positive-in-ground-truth": res["frame_wise"][2],
                                "precision": res["frame_wise"][3],
                                "recall": res["frame_wise"][4],
                                "F1": res["frame_wise"][5]
                                }
    }
    return all_res

In [34]:
evaluate_dataset( "../data/example_subset/Zebra_finch/test/", 
                 "../model/mc-whisperseg-zebra-finch-from-animal-vad/final_checkpoint_ct2/", 
                  num_trials = 3, tolerance = 0.02 )

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:37<00:00,  9.31s/it]


{'segment_wise_scores': {'N-true-positive': 275,
  'N-positive-in-prediction': 283,
  'N-positive-in-ground-truth': 296,
  'precision': 0.9717314487632509,
  'recall': 0.9290540540540541,
  'F1': 0.9499136442141624},
 'frame_wise_scores': {'N-true-positive': 26679,
  'N-positive-in-prediction': 27730,
  'N-positive-in-ground-truth': 28430,
  'precision': 0.9620988099531194,
  'recall': 0.9384101301442138,
  'F1': 0.9501068376068375}}

In [36]:
import requests,json,base64

## define a function for segmentation
def call_segment_service( service_address, 
                          audio_file_path,
                          channel_id,
                          sr,
                          min_frequency,
                          spec_time_step,
                          min_segment_length,
                          eps,
                          num_trials
                        ):
    audio_file_base64_string = base64.b64encode( open(audio_file_path, 'rb').read()).decode('ASCII')
    response = requests.post( service_address,
                              data = json.dumps( {
                                  "audio_file_base64_string":audio_file_base64_string,
                                  "channel_id":channel_id,
                                  "sr":sr,
                                  "min_frequency":min_frequency,
                                  "spec_time_step":spec_time_step,
                                  "min_segment_length":min_segment_length,
                                  "eps":eps,
                                  "num_trials":num_trials
                              } ),
                              headers = {"Content-Type": "application/json"}
                            )
    return response.json()

In [37]:
from model import WhisperSegmenterFast, WhisperSegmenter
from audio_utils import SpecViewer
segmenter = WhisperSegmenterFast( "../model/mc-whisperseg-zebra-finch-from-animal-vad/final_checkpoint_ct2/", device="cuda" )
# segmenter = WhisperSegmenter( "nccratliri/whisperseg-large-vad-v1.0", device="cuda" )
spec_viewer = SpecViewer()

In [41]:
audio_file_pattern = "../data/example_subset/Zebra_finch/test/audio_count_1_channel_*.wav"
audio_fname_list = sorted(glob( audio_file_pattern ), key = lambda x:int(x.split("_")[-1].split(".wav")[0]))
label_fname = audio_fname_list[0][:-4] + ".json"
label = json.load(open(label_fname))

In [42]:
sr = label["sr"]  
min_frequency = label["min_frequency"]
spec_time_step = label["spec_time_step"]
min_segment_length = label["min_segment_length"]
eps = label["eps"]
num_trials = 3

In [43]:
audio_list = []
for audio_fname in audio_fname_list:
    audio, _ = librosa.load( audio_fname, sr = sr )
    audio_list.append(  audio )

audio = np.asarray(audio_list)

prediction = segmenter.segment(  audio, sr = sr, min_frequency = min_frequency, spec_time_step = spec_time_step,
                       min_segment_length = min_segment_length, eps = eps,num_trials = num_trials, batch_size=4 )

In [44]:
spec_viewer.visualize( audio = audio_list[0], sr = sr, min_frequency= min_frequency, prediction = prediction, label=label, 
                       window_size=10, precision_bits=1 )

interactive(children=(FloatSlider(value=47.0, description='offset', max=94.857625, step=0.5), Output()), _dom_…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [45]:
spec_viewer.visualize( audio = audio_list[1], sr = sr, min_frequency= min_frequency, prediction = None, label=label, 
                       window_size=10, precision_bits=1 )

interactive(children=(FloatSlider(value=47.0, description='offset', max=94.857625, step=0.5), Output()), _dom_…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [46]:
spec_viewer.visualize( audio = audio_list[2], sr = sr, min_frequency= min_frequency, prediction = None, label=label, 
                       window_size=10, precision_bits=1 )

interactive(children=(FloatSlider(value=47.0, description='offset', max=94.857625, step=0.5), Output()), _dom_…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [25]:
prediction2 = call_segment_service( "http://localhost:8050/segment", 
                          audio_file,   
                          channel_id = 0,
                          sr = sr,
                          min_frequency = min_frequency,
                          spec_time_step = spec_time_step,
                          min_segment_length = min_segment_length,
                          eps = eps,
                          num_trials = num_trials
                        )
spec_viewer.visualize( audio = audio, sr = sr, min_frequency= min_frequency, prediction = prediction2, label=label, 
                       window_size=5, precision_bits=1 )

interactive(children=(FloatSlider(value=13.5, description='offset', max=27.0345625, step=0.25), Output()), _do…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [23]:
prediction == prediction2

True