In [1]:
%load_ext autoreload
%autoreload 2
    
import tensorflow as tf
from tensorflow  import keras
from keras import metrics
import numpy as np

import librosa 
import matplotlib.pyplot as plt
from tqdm import tqdm
import h5py
import soundfile as sf
from scipy.signal import resample_poly
import librosa
from modAL.models import ActiveLearner
from scikeras.wrappers import KerasClassifier
import random as rand

from config import *
from util import DEFAULT_TOKENS
from models import build_resnet16
from preprocessing import AL_split

2024-09-30 12:45:24.313541: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-30 12:45:24.313577: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-30 12:45:24.314641: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-30 12:45:24.321465: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Feature and label arrays loaded from the intermediate directory
# (presumably log-mel spectrograms and their labels — TODO confirm).
X = np.load(INTERMEDIATE / 'logmel.npy')
Y = np.load(INTERMEDIATE / 'logmel_labels.npy')

# Read-only handles to the HDF5 stores. They stay open for the lifetime of
# the kernel; nothing in the cells shown here closes them.
samples_f = h5py.File(INTERMEDIATE / '22sr_samples.hdf5', mode='r')
logmel_f = h5py.File(INTERMEDIATE / 'logmel.hdf5', mode='r')

In [3]:
# Mini-batch size shared by training and prediction.
batch = 32

# ResNet-16 taking single-channel inputs of shape (40, 107, 1).
X_shape = (40, 107, 1)
model = build_resnet16(input_shape=X_shape)

# Metrics reported by Keras during fitting; precision/recall are thresholded
# at 0.5 and the AUC is computed on the precision-recall curve.
training_metrics = [
    metrics.Recall(thresholds=0.5),
    metrics.Precision(thresholds=0.5),
    metrics.AUC(curve='pr', name='auc_pr'),
]
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=training_metrics,
)

# scikit-learn style wrapper so the Keras model can be used as a modAL estimator.
classifier = KerasClassifier(model, batch_size=batch, verbose=2)

# AL_split yields three (X, Y) pairs: initial labelled set, unlabelled pool, test set.
init, pool, test = AL_split(X, Y)
(initial_X, initial_Y), (pool_X, pool_Y), (test_X, test_Y) = init, pool, test
pool_X.shape, pool_Y.shape

((135462, 40, 107), (135462, 4))

In [4]:
# Total number of samples that can ever be labelled: initial set plus pool.
initial_ds_size = len(initial_X) + len(pool_X)
# Running count of labelled samples; starts at the size of the initial set.
currently_labelled = len(initial_X)

def random_sampling(classifier, X_unlabelled, n_instances=1):
    """Query strategy that draws pool instances uniformly at random.

    Parameters
    ----------
    classifier
        Unused; accepted so the function matches the query-strategy call
        signature (estimator first, then the pool).
    X_unlabelled
        Array-like pool to sample from; rows are indexed along axis 0.
    n_instances : int, default 1
        Number of distinct rows to draw.

    Returns
    -------
    tuple
        ``(indices, instances)`` where ``indices`` are the sampled row
        positions in ``X_unlabelled`` and ``instances`` are those rows.

    Raises
    ------
    ValueError
        Propagated from ``np.random.choice`` when ``n_instances`` exceeds the
        number of rows, because sampling is done without replacement.
    """
    n_pool = X_unlabelled.shape[0]
    chosen = np.random.choice(n_pool, size=n_instances, replace=False)
    # Index the pool that was passed in. The indices are positions in
    # X_unlabelled, so looking them up in any other array returns wrong rows.
    return chosen, X_unlabelled[chosen]

# Constructing the learner fits the estimator on the initial labelled set
# right away; `verbose=2` is passed through as a fit argument.
learner = ActiveLearner(
    estimator=classifier,
    query_strategy=random_sampling,
    X_training=initial_X,  # !!! TODO: use a much larger initial labelled set (?)
    y_training=initial_Y,
    verbose=2,
)

1059/1059 - 65s - 62ms/step - auc_pr: 0.0198 - loss: 0.0333 - precision: 0.0472 - recall: 0.0645


In [None]:
import pickle

from preprocessing import evaluation_dict

n_queries = 200
# Equal-sized queries; any remainder of the pool is left unqueried.
query_size = pool_X.shape[0] // n_queries

# One (labelling budget, evaluation dict) pair per completed query round.
LB_metrics = []
trained_X = initial_X
trained_Y = initial_Y

# Create the output directory up front so a missing folder cannot abort the
# run after a full train/evaluate round has already been paid for.
metrics_path = INTERMEDIATE / 'AL' / 'RS_metrics.pkl'
metrics_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: this cell mutates `currently_labelled`, `pool_X`, `pool_Y` and the
# learner, so re-running it continues from the current state rather than
# restarting the experiment. Re-run the setup cells first for a clean run.
print("=== RANDOM SAMPLING ===")
for idx in tqdm(range(n_queries)):
    print(f'Query no. {idx + 1}/{n_queries}')

    # query for instances (only the indices are needed here)
    query_indices, _ = learner.query(pool_X, n_instances=query_size)
    queried_X = pool_X[query_indices]
    queried_Y = pool_Y[query_indices]

    # train on the newly labelled batch only
    learner.teach(X=queried_X, y=queried_Y, only_new=True, verbose=2)

    # get evaluation metrics on the held-out test split
    print("evaluating ..")
    currently_labelled += query_size
    labelling_budget = currently_labelled / initial_ds_size
    pred_Y = classifier.predict_proba(test_X, batch_size=batch, verbose=2)
    LB_metrics.append(
        (labelling_budget, evaluation_dict(pred_Y, test_Y)))

    # store trained-on samples
    trained_X = np.vstack((trained_X, queried_X))
    trained_Y = np.vstack((trained_Y, queried_Y))

    # remove queried instances from the pool
    pool_X = np.delete(pool_X, query_indices, axis=0)
    pool_Y = np.delete(pool_Y, query_indices, axis=0)

    # checkpoint after every round so an interrupted run keeps its metrics
    with open(metrics_path, 'wb') as f:
        pickle.dump(LB_metrics, f)


=== RANDOM SAMPLING ===


  0%|                                                                                           | 0/200 [00:00<?, ?it/s]

Query no. 1/200
22/22 - 1s - 54ms/step - auc_pr: 0.4706 - loss: 0.0153 - precision: 0.5000 - recall: 0.2857
evaluating ..


2024-09-30 12:46:54.576724: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 724723840 exceeds 10% of free system memory.


1323/1323 - 18s - 14ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.1569690342889799,
                           'f1': 0.0,
                           'precision': 0.0,
                           'recall': 0.0},
    'nr_syllable_3khz': {   'auc_pr': 0.6563203867409203,
                            'f1': 0.5553719008264463,
                            'precision': 0.45901639344262296,
                            'recall': 0.702928870292887},
    'triangle_3khz': {   'auc_pr': 0.05499791170054827,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.14546355322096463,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  0%|▍                                                                                | 1/200 [00:21<1:11:40, 21.61s/it]

Query no. 2/200
22/22 - 1s - 54ms/step - auc_pr: 0.3917 - loss: 0.0136 - precision: 0.5000 - recall: 0.2000
evaluating ..


2024-09-30 12:47:16.177539: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 724723840 exceeds 10% of free system memory.


1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.5126939533699157,
                           'f1': 0.26618705035971224,
                           'precision': 0.8809523809523809,
                           'recall': 0.15677966101694915},
    'nr_syllable_3khz': {   'auc_pr': 0.658913104276558,
                            'f1': 0.5645161290322581,
                            'precision': 0.7894736842105263,
                            'recall': 0.4393305439330544},
    'triangle_3khz': {   'auc_pr': 0.0554653381924374,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.14795104344879995,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  1%|▊                                                                                | 2/200 [00:42<1:09:47, 21.15s/it]

Query no. 3/200
22/22 - 1s - 54ms/step - auc_pr: 0.5068 - loss: 0.0110 - precision: 0.7500 - recall: 0.3750
evaluating ..


2024-09-30 12:47:37.006414: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 724723840 exceeds 10% of free system memory.


1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.16329653254892196,
                           'f1': 0.19941348973607037,
                           'precision': 0.3238095238095238,
                           'recall': 0.1440677966101695},
    'nr_syllable_3khz': {   'auc_pr': 0.6953582212206083,
                            'f1': 0.5250737463126843,
                            'precision': 0.89,
                            'recall': 0.3723849372384937},
    'triangle_3khz': {   'auc_pr': 0.015615248610342698,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.1596875673810729,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  2%|█▏                                                                               | 3/200 [01:03<1:09:03, 21.03s/it]

Query no. 4/200
22/22 - 1s - 54ms/step - auc_pr: 0.1967 - loss: 0.0148 - precision: 1.0000 - recall: 0.1111
evaluating ..


2024-09-30 12:47:57.899491: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 724723840 exceeds 10% of free system memory.


1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.28637863851213935,
                           'f1': 0.19011406844106463,
                           'precision': 0.9259259259259259,
                           'recall': 0.1059322033898305},
    'nr_syllable_3khz': {   'auc_pr': 0.6024633624975044,
                            'f1': 0.3157894736842105,
                            'precision': 0.9782608695652174,
                            'recall': 0.18828451882845187},
    'triangle_3khz': {   'auc_pr': 0.020810571385699674,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.13436245926716217,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  2%|█▌                                                                               | 4/200 [01:24<1:08:34, 20.99s/it]

Query no. 5/200
22/22 - 1s - 56ms/step - auc_pr: 0.4493 - loss: 0.0122 - precision: 0.5000 - recall: 0.2000
evaluating ..


2024-09-30 12:48:18.903077: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 724723840 exceeds 10% of free system memory.


1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.4636095165197648,
                           'f1': 0.379746835443038,
                           'precision': 0.75,
                           'recall': 0.2542372881355932},
    'nr_syllable_3khz': {   'auc_pr': 0.7189474603920167,
                            'f1': 0.6600496277915633,
                            'precision': 0.8109756097560976,
                            'recall': 0.5564853556485355},
    'triangle_3khz': {   'auc_pr': 0.053617121467133905,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.12352010064342972,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  2%|██                                                                               | 5/200 [01:45<1:08:11, 20.98s/it]

Query no. 6/200
22/22 - 1s - 57ms/step - auc_pr: 0.1901 - loss: 0.0161 - precision: 1.0000 - recall: 0.1000
evaluating ..
1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.33220874110616544,
                           'f1': 0.17391304347826086,
                           'precision': 0.6,
                           'recall': 0.1016949152542373},
    'nr_syllable_3khz': {   'auc_pr': 0.7162429552111889,
                            'f1': 0.5329341317365269,
                            'precision': 0.9368421052631579,
                            'recall': 0.3723849372384937},
    'triangle_3khz': {   'auc_pr': 0.059972378633929936,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.11135827442768505,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  3%|██▍                                                                              | 6/200 [02:06<1:07:44, 20.95s/it]

Query no. 7/200
22/22 - 1s - 55ms/step - auc_pr: 0.3314 - loss: 0.0202 - precision: 1.0000 - recall: 0.1250
evaluating ..
1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.5715965414262991,
                           'f1': 0.5597826086956522,
                           'precision': 0.7803030303030303,
                           'recall': 0.4364406779661017},
    'nr_syllable_3khz': {   'auc_pr': 0.7188303464515329,
                            'f1': 0.5400593471810089,
                            'precision': 0.9285714285714286,
                            'recall': 0.3807531380753138},
    'triangle_3khz': {   'auc_pr': 0.08878042510463598,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.08789726638734081,
                         'f1': 0.012422360248447204,
                         'precision': 0.25,
                         'recall': 0.006369426751592357}}


  4%|██▊                                                                              | 7/200 [02:27<1:07:27, 20.97s/it]

Query no. 8/200
22/22 - 1s - 56ms/step - auc_pr: 0.2363 - loss: 0.0190 - precision: 0.5000 - recall: 0.2000
evaluating ..
1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.35430820152340925,
                           'f1': 0.15444015444015444,
                           'precision': 0.8695652173913043,
                           'recall': 0.0847457627118644},
    'nr_syllable_3khz': {   'auc_pr': 0.6034514840117523,
                            'f1': 0.5917808219178082,
                            'precision': 0.8571428571428571,
                            'recall': 0.45188284518828453},
    'triangle_3khz': {   'auc_pr': 0.054175520294959915,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.02043358038435099,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  4%|███▏                                                                             | 8/200 [02:48<1:07:00, 20.94s/it]

Query no. 9/200
22/22 - 1s - 55ms/step - auc_pr: 0.6361 - loss: 0.0144 - precision: 1.0000 - recall: 0.4615
evaluating ..
1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.1441147235876034,
                           'f1': 0.2776470588235294,
                           'precision': 0.31216931216931215,
                           'recall': 0.25},
    'nr_syllable_3khz': {   'auc_pr': 0.7061018683281344,
                            'f1': 0.6815144766146993,
                            'precision': 0.7285714285714285,
                            'recall': 0.6401673640167364},
    'triangle_3khz': {   'auc_pr': 0.019798168886493314,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.02613902793473634,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  4%|███▋                                                                             | 9/200 [03:08<1:06:36, 20.93s/it]

Query no. 10/200
22/22 - 1s - 55ms/step - auc_pr: 0.2308 - loss: 0.0132 - precision: 0.5000 - recall: 0.2500
evaluating ..
1323/1323 - 17s - 13ms/step


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{   'fast_trill_6khz': {   'auc_pr': 0.19728811636154503,
                           'f1': 0.0859375,
                           'precision': 0.55,
                           'recall': 0.046610169491525424},
    'nr_syllable_3khz': {   'auc_pr': 0.5708622150949085,
                            'f1': 0.4144736842105263,
                            'precision': 0.9692307692307692,
                            'recall': 0.26359832635983266},
    'triangle_3khz': {   'auc_pr': 0.04183716715008642,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0},
    'upsweep_500hz': {   'auc_pr': 0.013575111413344911,
                         'f1': 0.0,
                         'precision': 0.0,
                         'recall': 0.0}}


  5%|████                                                                            | 10/200 [03:29<1:06:13, 20.91s/it]

Query no. 11/200
22/22 - 1s - 55ms/step - auc_pr: 0.3720 - loss: 0.0128 - precision: 1.0000 - recall: 0.2222
evaluating ..


In [None]:
# Predicted probabilities for the test split from the classifier in its
# current (post-training) state; one column per label is assumed by the
# cells below, which index pred_Y[:, i].
pred_Y = classifier.predict_proba(test_X)

In [None]:
from sklearn.metrics import average_precision_score, precision_recall_curve

# Average precision for each label column of the test split. The number of
# labels is taken from the data instead of being hard-coded.
n_labels = test_Y.shape[1]
for i in range(n_labels):
    m = average_precision_score(test_Y[:, i], pred_Y[:, i])
    print(m)

def plot_pr(name, labels, predictions, **kwargs):
    """Add one precision-recall trace to the current matplotlib axes.

    Precision is drawn on the horizontal axis and recall on the vertical
    axis; both axes are clamped to [0, 1] and a grid is enabled.

    Parameters
    ----------
    name
        Legend label attached to the trace.
    labels
        Ground-truth binary labels passed to ``precision_recall_curve``.
    predictions
        Predicted scores passed to ``precision_recall_curve``.
    **kwargs
        Extra keyword arguments forwarded to the ``plot`` call.
    """
    precision, recall, _thresholds = precision_recall_curve(labels, predictions)

    ax = plt.gca()
    ax.plot(precision, recall, label=name, linewidth=2, **kwargs)
    ax.set_xlabel('Precision')
    ax.set_ylabel('Recall')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(True)

# One precision-recall trace per label on a shared figure.
# Assumes the key order of DEFAULT_TOKENS matches the column order of
# test_Y / pred_Y — TODO confirm against the label encoding.
label_order = DEFAULT_TOKENS.keys()
for i, name in enumerate(label_order):
    plot_pr(name, test_Y[:, i], pred_Y[:, i])
# plot_pr attaches a label to every trace; without a legend the curves
# cannot be told apart.
plt.legend()
plt.show()

In [None]:
# Evaluate the test predictions against the test labels. `evaluation_dict`
# is imported from `preprocessing` in the active-learning loop cell, so that
# cell must have been run first. Judging by the loop output it reports
# auc_pr / f1 / precision / recall per label — confirm in `preprocessing`.
m = evaluation_dict(pred_Y, test_Y)

In [None]:
from util import LABELS

# Learning curves: average precision of each label as a function of the
# labelling budget recorded after every query round.
x = [lb for lb, _ in LB_metrics]
for l in LABELS:
    y = [m[l]['auc_pr'] for _, m in LB_metrics]
    # Budget on the horizontal axis, AP on the vertical axis, so that the
    # (0, 1) y-limit applies to the metric being plotted.
    plt.plot(x, y)
    plt.xlabel('Labelling budget')
    plt.ylabel('Average precision (auc_pr)')
    plt.ylim(0, 1)
    plt.title(f'AP of {l} for Random Sampling')
    plt.show()