In [None]:
import os
import re
from pydub import AudioSegment

RECORDINGS_DIR = "Recordings"
WAV_DIR = "wav"
os.makedirs(WAV_DIR, exist_ok=True)

# Recordings are named "s_<number>_<g|b>.m4a"; the trailing letter is the label
# ('g' is counted as Good and 'b' as Bad in the summary printed below).
RECORDING_NAME = re.compile(r's_(\d+)_(g|b)\.m4a')

# Filter first, sort second: sorting the raw directory listing by "first number
# in the name" crashes on any entry without digits (e.g. ".DS_Store").
# fullmatch also rejects names with trailing junk such as "s_1_g.m4a.bak".
recording_matches = [
    m for m in map(RECORDING_NAME.fullmatch, os.listdir(RECORDINGS_DIR)) if m
]
recording_matches.sort(key=lambda m: int(m.group(1)))

# recordings maps recording number -> {"label": 'g'|'b', "wav_path": str}.
recordings = {}
for match in recording_matches:
    fname = match.group(0)
    num = int(match.group(1))
    label = match.group(2)
    m4a_path = os.path.join(RECORDINGS_DIR, fname)
    wav_path = os.path.join(WAV_DIR, f"s_{num}_{label}.wav")
    # Decode the m4a and re-export it as wav for the later librosa.load calls.
    AudioSegment.from_file(m4a_path, format="m4a").export(wav_path, format="wav")
    recordings[num] = {"label": label, "wav_path": wav_path}

print(f"Converted {len(recordings)} files")
print(f"  Good: {sum(1 for v in recordings.values() if v['label'] == 'g')}")
print(f"  Bad:  {sum(1 for v in recordings.values() if v['label'] == 'b')}")


In [None]:
import librosa
import numpy as np
import noisereduce as nr
from scipy.signal import butter, sosfilt
import matplotlib.pyplot as plt

SR = 48000  # sample rate in Hz used when loading, denoising and filtering the recordings
NOISE_SAMPLE_DURATION = 0.5  # seconds — assumes recording starts with silence before first tap

def bandpass_filter(audio, sr, low_hz=200, high_hz=20000, order=5):
    """Band-limit ``audio`` to the range ``[low_hz, high_hz]``.

    Designs a Butterworth band-pass filter in second-order-section form and
    applies it to the signal with ``sosfilt``.

    Parameters
    ----------
    audio : array_like
        Signal to filter.
    sr : float
        Sampling rate of ``audio`` in Hz (passed to ``butter`` as ``fs``).
    low_hz, high_hz : float
        Lower and upper band edges in Hz.
    order : int
        Order passed to ``butter``.

    Returns
    -------
    numpy.ndarray
        The filtered signal.
    """
    band_edges = [low_hz, high_hz]
    sections = butter(N=order, Wn=band_edges, btype='bandpass', output='sos', fs=sr)
    filtered = sosfilt(sections, audio)
    return filtered

# Number of leading samples treated as the noise-only reference segment.
noise_len = int(NOISE_SAMPLE_DURATION * SR)

for entry in recordings.values():
    waveform, _ = librosa.load(entry['wav_path'], sr=SR, mono=True)

    # Step 1: spectral gating, using the start of the recording as the noise profile
    gated = nr.reduce_noise(y=waveform, sr=SR, y_noise=waveform[:noise_len])

    # Step 2: band-pass the gated signal and store it alongside the metadata
    entry['audio'] = bandpass_filter(gated, SR)

print(f"Denoised {len(recordings)} recordings")

# --- Before/after comparison on first recording ---
# Fail with a readable message instead of a bare StopIteration when the
# conversion step produced nothing.
if not recordings:
    raise RuntimeError("No recordings available to plot; check RECORDINGS_DIR")

sample_num = next(iter(recordings))
# Reload the untouched waveform so it can be compared against the cleaned one.
raw, _ = librosa.load(recordings[sample_num]['wav_path'], sr=SR, mono=True)
clean = recordings[sample_num]['audio']
t = np.arange(len(raw)) / SR  # time axis in seconds

fig, ax = plt.subplots(2, 1, figsize=(12, 5), sharex=True)
ax[0].plot(t, raw, linewidth=0.5)
# The title previously contained a mis-encoded em dash ("â€”").
ax[0].set_title(f"Raw — s_{sample_num} ({recordings[sample_num]['label']})")
ax[0].set_ylabel("Amplitude")
ax[0].grid()
ax[1].plot(t, clean, linewidth=0.5, color='tab:orange')
ax[1].set_title("After spectral gating + bandpass (200 Hz - 20 kHz)")
ax[1].set_xlabel("Time (s)")
ax[1].set_ylabel("Amplitude")
ax[1].grid()
plt.tight_layout()
plt.show()


In [None]:
from scipy.signal import find_peaks

def splice_audio(audio, sr, pre_peak=100, post_peak=3000,
                 threshold_ratio=0.3, min_distance=30000):
    """Cut fixed-length windows around the prominent peaks of ``audio``.

    Peaks are located with ``scipy.signal.find_peaks`` on the signed signal,
    using a height threshold of ``threshold_ratio`` times the largest absolute
    sample and a minimum spacing of ``min_distance`` samples. Because the
    search runs on the signed signal, only positive excursions are detected.

    Parameters
    ----------
    audio : array_like
        1-D signal to splice.
    sr : int
        Sampling rate. Accepted for interface compatibility; every length in
        this function is expressed in samples, so it is not used.
    pre_peak, post_peak : int
        Number of samples kept before / from the peak index onward.
    threshold_ratio : float
        Fraction of the absolute signal maximum a peak must reach.
    min_distance : int
        Minimum number of samples between two detected peaks.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_peaks, pre_peak + post_peak)``. When there is no
        usable peak (including empty input) the shape is
        ``(0, pre_peak + post_peak)``, so results can always be concatenated
        along axis 0.
    """
    audio = np.asarray(audio)
    window = pre_peak + post_peak
    no_splices = np.empty((0, window), dtype=audio.dtype)

    # np.max raises on an empty array, so short-circuit that case.
    if audio.size == 0:
        return no_splices

    # Relative threshold keeps detection comparable across recordings of
    # different loudness.
    threshold = threshold_ratio * np.max(np.abs(audio))
    peaks, _ = find_peaks(audio, height=threshold, distance=min_distance)

    # Drop peaks whose window would run past either end of the signal.
    valid = [p for p in peaks if p >= pre_peak and p + post_peak <= len(audio)]
    if not valid:
        return no_splices

    return np.stack([audio[p - pre_peak : p + post_peak] for p in valid])

# Per-recording splice arrays, collected by label before being stacked.
# Kept under separate names so healthy_hits / unhealthy_hits are only ever arrays.
healthy_splices = []
unhealthy_splices = []

for num, info in recordings.items():
    splices = splice_audio(info['audio'], SR)
    if splices.size == 0:
        print(f"  WARNING: no peaks found in s_{num} ({info['label']})")
        continue
    # 'g' recordings go to the healthy set; every other label to the unhealthy set.
    if info['label'] == 'g':
        healthy_splices.append(splices)
    else:
        unhealthy_splices.append(splices)

# np.concatenate on an empty list fails with an opaque message; report which
# class is missing instead.
if not healthy_splices or not unhealthy_splices:
    raise ValueError(
        "No splices found for class(es): "
        + ", ".join(
            name
            for name, group in (("healthy", healthy_splices), ("unhealthy", unhealthy_splices))
            if not group
        )
    )

# Stack all splices of a class into one (n_splices, splice_length) array.
healthy_hits = np.concatenate(healthy_splices, axis=0)
unhealthy_hits = np.concatenate(unhealthy_splices, axis=0)

print(f"Healthy hits:   {healthy_hits.shape}  ({healthy_hits.shape[0]} splices)")
print(f"Unhealthy hits: {unhealthy_hits.shape}  ({unhealthy_hits.shape[0]} splices)")
print(f"Splice length:  {healthy_hits.shape[1]} samples = {healthy_hits.shape[1]/SR*1000:.1f} ms")


In [10]:
split = 0.7  # fraction of each class that goes to the training set

# Split each class by position: the first `split` share is train, the rest is test.
n_train_healthy = int(split * len(healthy_hits))
n_train_unhealthy = int(split * len(unhealthy_hits))

train_healthy   = healthy_hits[:n_train_healthy]
test_healthy    = healthy_hits[n_train_healthy:]

train_unhealthy = unhealthy_hits[:n_train_unhealthy]
test_unhealthy  = unhealthy_hits[n_train_unhealthy:]

# X_train / X_test were printed below but never defined (NameError on a fresh
# kernel). Build them as healthy rows followed by unhealthy rows.
X_train = np.concatenate([train_healthy, train_unhealthy], axis=0)
X_test = np.concatenate([test_healthy, test_unhealthy], axis=0)

# Every splice must land in exactly one of the two sets.
assert len(X_train) + len(X_test) == len(healthy_hits) + len(unhealthy_hits)

print(f"Train — healthy: {len(train_healthy)}, unhealthy: {len(train_unhealthy)}, total: {len(X_train)}")
print(f"Test  — healthy: {len(test_healthy)}, unhealthy: {len(test_unhealthy)}, total: {len(X_test)}")


Train — healthy: 104, unhealthy: 89, total: 193
Test  — healthy: 45, unhealthy: 39, total: 84


In [None]:
# FFT plots

from scipy.fft import fft, ifft 

def fftmagnitude(signal, sr, normalize=False):
    """Return the one-sided FFT magnitude of ``signal`` and its frequency axis.

    Parameters
    ----------
    signal : array_like
        1-D time-domain signal of length ``L``.
    sr : float
        Sampling rate in Hz, used to scale the frequency axis.
    normalize : bool, optional
        When True, divide the magnitudes by ``L``. Defaults to False, which
        keeps the raw (unnormalized) FFT scale.

    Returns
    -------
    y : numpy.ndarray
        Magnitudes of the first ``L // 2`` bins. Every bin except DC is
        doubled to account for the discarded negative-frequency half.
    f : numpy.ndarray
        Frequencies of those bins, ``k * sr / L`` for ``k = 0 .. L//2 - 1``.
    """
    L = len(signal) # data length
    Y = abs(fft(signal)) # two-sided magnitude spectrum
    if normalize:
        Y = Y / L
    y = Y[0:L//2] # take positive frequencies (DC .. just below Nyquist)
    # Double every retained bin except DC. The previous slice [2:-1] was a
    # 1-based-indexing carry-over that skipped bin 1 and the last bin.
    y[1:] = 2*y[1:]
    f = np.arange(0, L//2) * sr / L # frequency axis
    return y, f

# `healthy`, `unhealthy` and `sr` were used here without being defined anywhere
# in the notebook. Use the first splice of each class as the example signal.
# NOTE(review): the "(C4)" / "(B3)" titles appear to come from an earlier
# version that loaded specific recordings — confirm they still describe these examples.
healthy = healthy_hits[0]
unhealthy = unhealthy_hits[0]
sr = SR

healthy_mag, healthy_freq = fftmagnitude(healthy, sr)
unhealthy_mag, unhealthy_freq = fftmagnitude(unhealthy, sr)

fft_panels = [
    (healthy_freq, healthy_mag, "FFT: Healthy Cell (C4) "),
    (unhealthy_freq, unhealthy_mag, "FFT: Unhealthy Cell (B3)"),
]

fig, ax = plt.subplots(2, 1, figsize=(10, 6))
for panel, (freq, mag, title) in zip(ax, fft_panels):
    panel.plot(freq/1000, mag)
    panel.set_title(title)
    panel.set_xlabel("Frequency (kHz)")
    panel.set_ylabel("Magnitude")
    panel.grid()
    panel.set_xlim(0, 20)  # Focus on frequencies up to 20 kHz
plt.tight_layout()
plt.show()

In [None]:
# PSD analysis

from scipy import signal
# `healthy` / `unhealthy` were never defined in the notebook; derive them here
# (first splice of each class) so this cell runs on a fresh kernel by itself.
# NOTE(review): confirm the "(C4)" / "(B3)" titles still describe these examples.
healthy = healthy_hits[0]
unhealthy = unhealthy_hits[0]

f_healthy, Pxx_healthy = signal.periodogram(healthy, fs=1) # set fs=1 for normalized frequency
f_unhealthy, Pxx_unhealthy = signal.periodogram(unhealthy, fs=1)

psd_panels = [
    (f_healthy, Pxx_healthy, "PSD: Healthy Cell (C4)"),
    (f_unhealthy, Pxx_unhealthy, "PSD: Unhealthy Cell (B3)"),
]

fig, ax = plt.subplots(2, 1, figsize=(10, 6))
for panel, (freq, psd, title) in zip(ax, psd_panels):
    panel.plot(freq, psd)
    panel.set_title(title)
    panel.set_xlabel("Normalized Frequency")
    panel.set_ylabel("PSD")
    panel.grid()
    panel.set_xlim(0, 0.5)  # Focus on normalized frequencies up to Nyquist (0.5)
plt.tight_layout()
plt.show()


In [None]:
# STFT Analysis
# `healthy`, `unhealthy` and `sr` were never defined in the notebook; derive
# them here (first splice of each class) so this cell runs on a fresh kernel.
# NOTE(review): confirm the "(C4)" / "(B3)" titles still describe these examples.
healthy = healthy_hits[0]
unhealthy = unhealthy_hits[0]
sr = SR

f_healthy_stft, t_healthy_stft, Zxx_healthy = signal.stft(healthy, fs=sr, nperseg=256)
f_unhealthy_stft, t_unhealthy_stft, Zxx_unhealthy = signal.stft(unhealthy, fs=sr, nperseg=256)

# Floor the magnitude before the log so exact zeros give a large negative dB
# value instead of -inf (and a RuntimeWarning).
magnitude_floor = np.finfo(float).eps
Zxx_healthy_db = 20 * np.log10(np.maximum(np.abs(Zxx_healthy), magnitude_floor))
Zxx_unhealthy_db = 20 * np.log10(np.maximum(np.abs(Zxx_unhealthy), magnitude_floor))
Zxx_healthy_db_normalized = Zxx_healthy_db - np.max(Zxx_healthy_db)  # peak becomes 0 dB, everything else is negative
Zxx_unhealthy_db_normalized = Zxx_unhealthy_db - np.max(Zxx_unhealthy_db)  # peak becomes 0 dB, everything else is negative

stft_panels = [
    (t_healthy_stft, f_healthy_stft, Zxx_healthy_db_normalized, "STFT: Healthy (C4)"),
    (t_unhealthy_stft, f_unhealthy_stft, Zxx_unhealthy_db_normalized, "STFT: Unhealthy (B3)"),
]

fig, ax = plt.subplots(2, 1, figsize=(10, 6))
for panel, (times, freqs, level_db, title) in zip(ax, stft_panels):
    # Limit color range to -40 dB to 0 dB
    panel.pcolormesh(1000*times, freqs/1000, level_db, shading='nearest', vmin=-40, vmax=0)
    panel.set_title(title)
    panel.set_xlabel("Time (ms)")
    panel.set_ylabel("Frequency (kHz)")
    panel.set_xlim(0, 50)  # Focus on the first 50 ms
plt.tight_layout()
plt.colorbar(ax[0].collections[0], ax=ax[0], label='Magnitude (dB)')
plt.colorbar(ax[1].collections[0], ax=ax[1], label='Magnitude (dB)')
plt.show()

In [None]:
# CWT Analysis

import pywt
from matplotlib.ticker import ScalarFormatter

# `healthy`, `unhealthy`, `sr`, `t_healthy` and `t_unhealthy` were never defined
# in the notebook; derive them here (first splice of each class) so this cell
# runs on a fresh kernel.
# NOTE(review): confirm the "(C4)" / "(B3)" titles still describe these examples.
healthy = healthy_hits[0]
unhealthy = unhealthy_hits[0]
sr = SR
t_healthy = np.arange(len(healthy)) / sr      # sample times in seconds
t_unhealthy = np.arange(len(unhealthy)) / sr

wavelet = pywt.ContinuousWavelet('morl')
fc = pywt.central_frequency(wavelet)

# Target analysis frequencies (Hz) and the wavelet scale for each one:
# scale = fc * sr / frequency.
frequencies = np.linspace(1, sr/2, 200)
scales = fc / (frequencies * (1/sr))
coeffs_healthy, _ = pywt.cwt(healthy, scales, wavelet, sampling_period=1/sr)
coeffs_unhealthy, _ = pywt.cwt(unhealthy, scales, wavelet, sampling_period=1/sr)

cwt_panels = [
    (t_healthy, coeffs_healthy, "CWT: Healthy (C4)"),
    (t_unhealthy, coeffs_unhealthy, "CWT: Unhealthy (B3)"),
]

fig, ax = plt.subplots(2, 1, figsize=(10, 6))
for panel, (times, coeffs, title) in zip(ax, cwt_panels):
    panel.pcolormesh(1000*times, frequencies/1000, np.abs(coeffs), shading='nearest')
    panel.set_title(title)
    panel.set_xlabel("Time (ms)")
    panel.set_ylabel("Frequency (kHz)")
    panel.set_yscale('log')
    panel.set_ylim(1, sr/2000)  # 1 kHz up to Nyquist (sr/2, expressed in kHz)
    # Plain decimal tick labels on the log axis instead of scientific notation.
    panel.yaxis.set_major_formatter(ScalarFormatter())
    panel.yaxis.get_major_formatter().set_scientific(False)

plt.tight_layout()
plt.colorbar(ax[0].collections[0], ax=ax[0], label='Magnitude')
plt.colorbar(ax[1].collections[0], ax=ax[1], label='Magnitude')
plt.show()





In [None]:
# MFCC Analysis
# Import the plotting submodule explicitly; the notebook only does
# `import librosa`, which may not expose `librosa.display` in every version.
import librosa.display
from matplotlib.ticker import FuncFormatter

# `healthy`, `unhealthy` and `sr` were never defined in the notebook; derive
# them here (first splice of each class) so this cell runs on a fresh kernel.
# NOTE(review): confirm the "(C4)" / "(B3)" titles still describe these examples.
healthy = healthy_hits[0]
unhealthy = unhealthy_hits[0]
sr = SR

mfcc_healthy = librosa.feature.mfcc(y=healthy, sr=sr) # feature matrices
mfcc_unhealthy = librosa.feature.mfcc(y=unhealthy, sr=sr)

mfcc_panels = [
    (mfcc_healthy, "MFCC: Healthy (C4)"),
    (mfcc_unhealthy, "MFCC: Unhealthy (B3)"),
]

# The x axis is in seconds; show the tick labels in milliseconds.
seconds_to_ms = FuncFormatter(lambda seconds, _pos: f"{seconds*1000:.0f}")

fig, ax = plt.subplots(2, 1, figsize=(10, 6))
for panel, (mfcc, title) in zip(ax, mfcc_panels):
    img = librosa.display.specshow(mfcc, sr=sr, x_axis='time', ax=panel, vmin=-50, vmax=50)
    panel.set_title(title)
    plt.colorbar(img, ax=panel, label='MFCC Coefficient Value')

    # Set the view first, then format the ticks. Freezing tick positions
    # before set_xlim left the labels tied to the full-length view.
    panel.set_xlim(0, 50/1000)  # Focus on the first 50 ms
    panel.xaxis.set_major_formatter(seconds_to_ms)
    panel.set_xlabel("Time (ms)")

    panel.set_ylabel("MFCC Coefficient Index")
    panel.set_yticks(range(0, 20, 2))  # 0, 2, 4, ... 18
    panel.set_yticklabels(range(0, 20, 2))

plt.tight_layout()
plt.show()
