In [5]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
from tensorflow  import keras
import tensorflow_hub as hub
import numpy as np

import librosa 
from util import WavDataset
import matplotlib.pyplot as plt
from tqdm import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
sr = 16_000  # sample rate (Hz) used by the defaults below and by later cells

def compute_frame_labels(label_tensor, frame_length=int(sr*0.96), step_size=int(sr*0.48), threshold=0.15):
    """Collapse sample-level annotations into per-frame binary labels.

    Frames are `frame_length` samples long and start every `step_size` samples.
    The defaults are 0.96 s windows with a 0.48 s hop at `sr`. Presumably these
    were chosen to match YAMNet's patch window and hop; confirm.

    Args:
        label_tensor: 2-D array-like of shape (n_labels, total_samples). The
            fraction of a frame covered by a label is the mean over that frame.
        frame_length: number of samples per frame.
        step_size: number of samples between consecutive frame starts. This
            also fixes the frame count at total_samples // step_size.
        threshold: minimum mean activation for a label to count as present
            in a frame. The default is 15 %.

    Returns:
        An int array of shape (n_labels, total_samples // step_size) holding
        0/1 entries. Frames that would run past the end are truncated, and
        their mean is taken over the samples that remain.
    """
    n_labels, total_samples = label_tensor.shape
    frame_starts = range(0, (total_samples // step_size) * step_size, step_size)

    labels = np.zeros((n_labels, len(frame_starts)), dtype=int)
    for column, offset in enumerate(frame_starts):
        window = label_tensor[:, offset:offset + frame_length]
        # A label is present when at least `threshold` of the window is annotated.
        # Bools are stored as 0/1 by the int array.
        labels[:, column] = np.mean(window, axis=1) >= threshold
    return labels

# Sanity check: 5 s of synthetic 16 kHz annotations for 4 classes.
#   row 0 is active from 2 s to the end, row 2 for the first 2 s,
#   row 3 for the last 0.5 s, and row 1 is never active.
Y = np.zeros((4, 16_000 * 5))
Y[0, 32_000:] = 1
Y[2, :32_000] = 1
Y[3, -8_000:] = 1

# Pin the expected frame labels so a change to compute_frame_labels fails loudly
# instead of silently altering the displayed output.
expected_frames = np.array([
    [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
])
frames = compute_frame_labels(Y)
np.testing.assert_array_equal(frames, expected_frames)
frames

array([[0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]])

In [3]:
# Pretrained YAMNet from TF-Hub, loaded frozen (trainable=False) with a
# variable-length input (input_shape=(None,)). It is used below only to
# compute embeddings.
yamnet_url = 'https://tfhub.dev/google/yamnet/1'
yamnet_layer = hub.KerasLayer(
    yamnet_url,
    trainable=False,
    dtype=tf.float32,
    input_shape=(None,),
)

21:44:36 INFO Using /tmp/tfhub_modules to cache modules.
2024-09-08 21:44:37.244922: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [16]:
# Class name -> row index in the label arrays (e.g. label_frames[index, :]).
DEFAULT_TOKENS = {
    token: row
    for row, token in enumerate((
        "fast_trill_6khz",
        "nr_syllable_3khz",
        "triangle_3khz",
        "upsweep_500hz",
    ))
}
def summarize(arr, name=""):
    """Return descriptive statistics of `arr` as a flat dict.

    Keys have the form "<name> <stat>" (e.g. "embedding_summary mean"). When
    `name` is empty, the keys are the bare stat names ("mean", "median", ...)
    with no leading space. Statistics are computed over all elements, because
    the numpy reductions are called without an axis.

    Args:
        arr: any array-like accepted by np.asarray.
        name: prefix that namespaces the keys, e.g. so several summaries can be
            stored side by side as attributes of one HDF5 group.

    Returns:
        A dict, in this order, with entries for mean, median, min, max, sum,
        std_dev (population standard deviation, ddof=0) and var (population
        variance).
    """
    values = np.asarray(arr)  # convert once instead of once per statistic
    prefix = f"{name} " if name else ""
    stats = {
        "mean": np.mean(values),
        "median": np.median(values),
        "min": np.min(values),
        "max": np.max(values),
        "sum": np.sum(values),
        "std_dev": np.std(values),
        "var": np.var(values),
    }
    return {f"{prefix}{stat}": value for stat, value in stats.items()}

summarize([0, 3.3, 2, 5, 62])

{' mean': 14.459999999999999,
 ' median': 3.3,
 ' min': 0.0,
 ' max': 62.0,
 ' sum': 72.3,
 ' std_dev': 23.826170485413723,
 ' var': 567.6863999999999}

In [28]:
from config import *
import h5py

hdf5_file = INTERMEDIATE / 'samples_train.hdf5'  # input: one group per recording with datasets X and Y
new_ds = INTERMEDIATE / 'train.hdf5'             # output: one group per chunk (kept as a path so the cell can be re-run)

chunk = 5 * sr        # 5-second chunks
hop = 1 * sr          # NOTE: despite its name, this is the OVERLAP between consecutive chunks...
stride = chunk - hop  # ...so a new chunk starts every 4 seconds


def process_samples(chunk):
    """Embed one waveform chunk with YAMNet.

    `yamnet_layer` returns three outputs. Only the middle one (the embeddings)
    is returned; the other two are discarded.
    """
    _, embedds, _ = yamnet_layer(chunk)
    return embedds


def _chunk_starts(n_samples, chunk_len, step):
    """Return the start offsets of every full-length chunk that fits in `n_samples` samples."""
    # `+ 1` because range() excludes its stop. A chunk that ends exactly on the
    # last sample (start == n_samples - chunk_len) is valid and must be kept.
    # A trailing remainder shorter than one chunk is skipped.
    return range(0, n_samples - chunk_len + 1, step)


# Open both files with `with` so they are closed even if embedding fails midway.
# The output handle gets its own name so `new_ds` stays a path.
with h5py.File(hdf5_file, 'r') as old_ds, h5py.File(new_ds, 'w') as out_ds:
    recordings = list(old_ds)
    # Wrap the list (not the enumerate object) so tqdm knows the total.
    for i, rec in enumerate(tqdm(recordings, desc="recordings")):
        # These are h5py datasets; slicing below reads only the needed window.
        # Y is indexed as (n_labels, n_samples). X is sliced on axis 0,
        # presumably as a 1-D waveform (verify against the writer of samples_train.hdf5).
        Y = old_ds[rec]['Y']
        X = old_ds[rec]['X']
        n_samples = Y.shape[1]

        for start in _chunk_starts(n_samples, chunk, stride):
            # Convert to numpy once; the array is reused for the dataset, the shapes and the summary.
            embedds = np.asarray(process_samples(X[start:start + chunk]))
            label_frames = compute_frame_labels(Y[:, start:start + chunk])

            # Create the group only after both computations succeed, so a failure leaves no empty group.
            group = out_ds.create_group(f"chunk_{i}_{start}")
            group.create_dataset("X", data=embedds, dtype=np.float32)
            group.create_dataset("Y", data=label_frames, dtype=bool)

            # chunk metadata
            group.attrs['recording'] = rec
            group.attrs['start_time'] = start // sr
            group.attrs['end_time'] = (start + chunk) // sr
            # NOTE(review): the embedding frame count is expected to equal the
            # number of label frames. This attribute records both so they can be checked.
            group.attrs['shapes'] = (embedds.shape, label_frames.shape)
            # Positive-frame count per class, in DEFAULT_TOKENS row order.
            group.attrs['classwize_labeled_frame_counts'] = [
                int(label_frames[index, :].sum()) for index in DEFAULT_TOKENS.values()
            ]

            for key, val in summarize(embedds, "embedding_summary").items():
                group.attrs[key] = val
            for key, val in summarize(label_frames, "label_summary").items():
                group.attrs[key] = val
 

900it [3:28:53, 13.93s/it]
