In [6]:
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import csv
import io
from scipy.io import wavfile
import pandas as pd
import scipy.signal
from IPython.display import Audio

In [2]:
# Load the model.
model = hub.load('https://tfhub.dev/google/yamnet/1')

# Input: 3 seconds of silence as mono 16 kHz waveform samples.
waveform = np.zeros(3 * 16000, dtype=np.float32)

# Run the model, check the output.
scores, embeddings, log_mel_spectrogram = model(waveform)
scores.shape.assert_is_compatible_with([None, 521])
embeddings.shape.assert_is_compatible_with([None, 1024])
log_mel_spectrogram.shape.assert_is_compatible_with([None, 64])

# Find the name of the class with the top score when mean-aggregated across frames.
def class_names_from_csv(class_map_csv_text):
  """Returns list of class names corresponding to score vector."""
  class_map_csv = io.StringIO(class_map_csv_text)
  class_names = [display_name for (class_index, mid, display_name) in csv.reader(class_map_csv)]
  class_names = class_names[1:]  # Skip CSV header
  return class_names
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(tf.io.read_file(class_map_path).numpy().decode('utf-8'))
print(class_names[scores.numpy().mean(axis=0).argmax()])  # Should print 'Silence'.

Silence


In [3]:
!wget https://github.com/audio-samples/audio-samples.github.io/raw/master/samples/wav/music/sample-0.wav

--2021-08-31 15:42:06--  https://github.com/audio-samples/audio-samples.github.io/raw/master/samples/wav/music/sample-0.wav
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/audio-samples/audio-samples.github.io/master/samples/wav/music/sample-0.wav [following]
--2021-08-31 15:42:06--  https://raw.githubusercontent.com/audio-samples/audio-samples.github.io/master/samples/wav/music/sample-0.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 321386 (314K) [audio/wav]
Saving to: ‘sample-0.wav’


2021-08-31 15:42:07 (9.09 MB/s) - ‘sample-0.wav’ saved [321386/321386]



In [4]:
def ensure_sample_rate(original_sample_rate, waveform,
                       desired_sample_rate=16000):
  """Resample waveform if required."""
  if original_sample_rate != desired_sample_rate:
    desired_length = int(round(float(len(waveform)) /
                               original_sample_rate * desired_sample_rate))
    waveform = scipy.signal.resample(waveform, desired_length)
  return desired_sample_rate, waveform

In [7]:
wav_file_name = "sample-0.wav"

sample_rate, wav_data = wavfile.read(wav_file_name, 'rb')
sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)

# Show some basic information about the audio.
duration = len(wav_data)/sample_rate
print(f'Sample rate: {sample_rate} Hz')
print(f'Total duration: {duration:.2f}s')
print(f'Size of the input: {len(wav_data)}')

# Listening to the wav file.
Audio(wav_data, rate=sample_rate)

Sample rate: 16000 Hz
Total duration: 10.04s
Size of the input: 160654


In [8]:
waveform = wav_data / tf.int16.max

In [9]:
scores, embeddings, log_mel_spectrogram = model(waveform)
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(tf.io.read_file(class_map_path).numpy().decode('utf-8'))
print(class_names[scores.numpy().mean(axis=0).argmax()]) 

Music


In [10]:
df = pd.DataFrame(index=class_names, data=scores.numpy()[0], columns=['score'])
df.sort_values(by='score', ascending=False).head()

Unnamed: 0,score
Music,0.999032
Musical instrument,0.256434
Keyboard (musical),0.215603
Piano,0.183519
Harp,0.068866


It works! Now let us wrap it in function predicting for any file

In [11]:
def predict_for_musicfile(filename):
  wav_file_name = filename
  sample_rate, wav_data = wavfile.read(wav_file_name, 'rb')
  sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)
  duration = len(wav_data)/sample_rate
  print(f'Sample rate: {sample_rate} Hz')
  print(f'Total duration: {duration:.2f}s')
  print(f'Size of the input: {len(wav_data)}')
  waveform = wav_data / tf.int16.max
  scores, embeddings, log_mel_spectrogram = model(waveform)
  class_map_path = model.class_map_path().numpy()
  class_names = class_names_from_csv(tf.io.read_file(class_map_path).numpy().decode('utf-8'))
  df = pd.DataFrame(index=class_names, data=scores.numpy()[0], columns=['score'])
  return df.sort_values(by='score', ascending=False).head()


predict_for_musicfile("./sample-0.wav")

Sample rate: 16000 Hz
Total duration: 10.04s
Size of the input: 160654


Unnamed: 0,score
Music,0.999032
Musical instrument,0.256434
Keyboard (musical),0.215603
Piano,0.183519
Harp,0.068866


In [12]:
!gsutil cp -r gs://audioset .

Copying gs://audioset/golden_whistle.wav...
Copying gs://audioset/miaow_16k.wav...
Copying gs://audioset/speech_whistling2.wav...
Copying gs://audioset/vggish_model.ckpt...
/ [4 files][278.1 MiB/278.1 MiB]                                                
==> NOTE: You are performing a sequence of gsutil operations that may
run significantly faster if you instead use gsutil -m cp ... Please
see the -m section under "gsutil help options" for further information
about when gsutil -m can be advantageous.

Copying gs://audioset/vggish_pca_params.npz...
Copying gs://audioset/yamnet.h5...
Copying gs://audioset/yamnet.tflite...
Copying gs://audioset/yamalyzer/audio/accordion.wav...
Copying gs://audioset/yamalyzer/audio/acoustic-guitar.wav...
Copying gs://audioset/yamalyzer/audio/applause.wav...
Copying gs://audioset/yamalyzer/audio/bark.wav...
Copying gs://audioset/yamalyzer/audio/chewing.wav...
Copying gs://audioset/yamalyzer/audio/chime.wav...
Copying gs://audioset/yamalyzer/audio/cough.wav..

In [13]:
Audio("audioset/yamalyzer/audio/zipper.wav")

In [14]:
path = "audioset/yamalyzer/audio/zipper.wav"
# Audio(path)
predict_for_musicfile(path)

Sample rate: 16000 Hz
Total duration: 2.50s
Size of the input: 40000


Unnamed: 0,score
Scrape,0.540508
Camera,0.493014
Single-lens reflex camera,0.418859
Mechanisms,0.379557
Zipper (clothing),0.20076
