In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
from scipy.io import wavfile
from skimage.feature import peak_local_max

In [None]:
%matplotlib inline
plt.rcParams.update({'font.size': 18})

In [None]:
# Load both tracks. `wavfile.read` returns (sample_rate, samples); `samples` is
# 1-D for mono files and (n_samples, n_channels) for multi-channel files.
rate1, song_array1 = wavfile.read('Katy_Perry.wav')
rate2, song_array2 = wavfile.read('Daft_Punk.wav')

# mlab.specgram expects a 1-D signal, so average multi-channel audio to mono.
# Mono files are left untouched (same array, same dtype).
if song_array1.ndim > 1:
    song_array1 = song_array1.mean(axis=1)
if song_array2.ndim > 1:
    song_array2 = song_array2.mean(axis=1)

In [None]:
# Power spectrograms using 4096-sample FFT windows overlapping by 2048 samples.
# Pass each file's own sample rate so `freqs*` (Hz) and `t*` (s) are correct
# even when a file is not sampled at 44.1 kHz.
spec1, freqs1, t1 = mlab.specgram(song_array1, NFFT=4096, Fs=rate1, noverlap=2048)
spec2, freqs2, t2 = mlab.specgram(song_array2, NFFT=4096, Fs=rate2, noverlap=2048)

In [None]:
# Replace exact zeros in place with a tiny floor so the log10 taken later
# stays finite (10 * log10(1e-6) = -60 dB).
np.putmask(spec1, spec1 == 0, 1e-6)
np.putmask(spec2, spec2 == 0, 1e-6)

In [None]:
# Song 1 power in decibels, with rows reversed so the highest frequency comes
# first (it then lands at the top of an imshow with the default origin).
Z = (10.0 * np.log10(spec1))[::-1]

In [None]:
# Full-band spectrogram of song 1 (before the frequency cut below), in dB.
fig1, ax = plt.subplots(figsize=(10, 8), facecolor='white')
extent = 0, np.amax(t1), freqs1[0], freqs1[-1]
# Z is flipped (highest frequency in row 0), so with imshow's default
# origin='upper' the top edge of the image corresponds to freqs1[-1].
im = ax.imshow(Z, cmap='viridis', extent=extent, aspect='auto')
ax.set(title='Spectrogram: Katy_Perry.wav',
       xlabel='Time [s]',
       ylabel='Frequency [Hz]')
fig1.colorbar(im, ax=ax, label='Power [dB]')
plt.show()

In [None]:
# Restrict both spectrograms to the [min_freq, max_freq] Hz band, then convert
# the kept power values to decibels and flip them vertically so that row 0
# holds the highest kept frequency (shown at the top of an imshow plot).
min_freq = 0
max_freq = 15000

in_band1 = np.logical_and(freqs1 >= min_freq, freqs1 <= max_freq)
spec1, freqs1 = spec1[in_band1], freqs1[in_band1]

in_band2 = np.logical_and(freqs2 >= min_freq, freqs2 <= max_freq)
spec2, freqs2 = spec2[in_band2], freqs2[in_band2]

Z1 = np.flipud(10.0 * np.log10(spec1))
Z2 = np.flipud(10.0 * np.log10(spec2))

In [None]:
# Local maxima of the dB spectrograms. Each row of coordinates* is
# (row, col) = (frequency index, time index) into the *flipped* arrays Z1/Z2,
# so row 0 is the highest kept frequency, not the lowest.
PEAK_MIN_DISTANCE = 20   # minimum separation between peaks, in bins
PEAK_THRESHOLD_DB = 20   # ignore maxima below this absolute level (dB)
# NOTE: peak_local_max's default exclude_border=True reuses min_distance, so
# peaks within PEAK_MIN_DISTANCE bins of any edge are also discarded.
coordinates1 = peak_local_max(Z1, min_distance=PEAK_MIN_DISTANCE, threshold_abs=PEAK_THRESHOLD_DB)
coordinates2 = peak_local_max(Z2, min_distance=PEAK_MIN_DISTANCE, threshold_abs=PEAK_THRESHOLD_DB)

In [None]:
# Song 1: band-limited dB spectrogram with the detected peaks overlaid.
# Axes are in bin indices, so tick labels are hidden.
fig1, ax1 = plt.subplots(figsize=(10, 8), facecolor='white')
ax1.imshow(Z1, cmap='viridis', aspect='auto')
# Peaks are (row, col) = (y, x): column 1 goes on x, column 0 on y.
# Drawn once (the original drew the same peaks twice).
ax1.scatter(coordinates1[:, 1], coordinates1[:, 0])
ax1.set(title='Spectral peaks: Katy_Perry.wav', xlabel='Time', ylabel='Frequency')
ax1.set_xlim([0, len(t1)])
# Inverted y-limits keep row 0 (highest frequency, since Z1 is flipped) on top.
ax1.set_ylim([len(freqs1), 0])
ax1.xaxis.set_ticklabels([])
ax1.yaxis.set_ticklabels([])
plt.show()

In [None]:
# Song 2: band-limited dB spectrogram with the detected peaks overlaid.
# Axes are in bin indices, so tick labels are hidden. Label font size comes
# from rcParams (set to 18 at the top of the notebook).
fig2, ax2 = plt.subplots(figsize=(10, 8), facecolor='white')
ax2.imshow(Z2, cmap='viridis', aspect='auto')
# Peaks are (row, col) = (y, x): column 1 goes on x, column 0 on y.
ax2.scatter(coordinates2[:, 1], coordinates2[:, 0])
ax2.set(title='Spectral peaks: Daft_Punk.wav', xlabel='Time', ylabel='Frequency')
ax2.set_xlim([0, len(t2)])
# Inverted y-limits keep row 0 (highest frequency, since Z2 is flipped) on top.
ax2.set_ylim([len(freqs2), 0])
ax2.xaxis.set_ticklabels([])
ax2.yaxis.set_ticklabels([])
plt.show()