# In [None]:
import os
import queue
import time

import librosa
import matplotlib.pyplot as plt
import numpy as np
import sounddevice as sd
import tensorflow as tf
from scipy.signal import stft, istft

# Step 1: Train and save model if not already present
model_path = "models/fcnn_model.keras"
if not os.path.exists(model_path):
    print("Model not found. Training a new FCNN model...")
    # Create the directory that model_path points into (here "models").
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)

    # Dummy training data (replace with real framed audio later)
    X = np.random.rand(100, 1024, 1)
    Y = np.random.rand(100, 1024, 1)

    # Every layer is a kernel_size=1 Conv1D, so the network works on any
    # sequence length. The length axis is left as None because the STFT
    # denoiser feeds it magnitude frames of N_FFT // 2 + 1 frequency bins;
    # a fixed shape=(1024, 1) makes Keras reject those inputs with a
    # shape-incompatibility error. (`shape=` replaces the deprecated
    # `input_shape=` argument.)
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(shape=(None, 1)),
        tf.keras.layers.Conv1D(64, kernel_size=1, activation='relu', padding='same'),
        tf.keras.layers.Conv1D(64, kernel_size=1, activation='relu', padding='same'),
        tf.keras.layers.Conv1D(1, kernel_size=1, activation='linear', padding='same')
    ])

    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss='mse')
    model.fit(X, Y, epochs=5, batch_size=16)
    model.save(model_path)
    print(f"Model trained and saved to {model_path}")
else:
    model = tf.keras.models.load_model(model_path)
    print("Model loaded.")
    # A model saved with a fixed input length (e.g. shape=(1024, 1)) cannot
    # accept STFT magnitude frames of a different length; warn instead of
    # failing later inside the audio callback.
    saved_shape = getattr(model, "input_shape", None)
    if (
        isinstance(saved_shape, tuple)
        and len(saved_shape) > 1
        and saved_shape[1] is not None
    ):
        print(
            f"Warning: {model_path} expects a fixed input length of "
            f"{saved_shape[1]}; it must equal the number of STFT frequency "
            "bins. Delete the file to retrain a length-agnostic model."
        )

# Step 2: Real-time audio denoising setup
# Audio I/O runs at 16 kHz in blocks of FRAME_SIZE samples; each block is
# analysed with an N_FFT-point STFT whose frames advance by HOP_LENGTH.
SAMPLING_RATE = 16_000
FRAME_SIZE = 1_024
HOP_LENGTH = 512
N_FFT = 1_024
window = np.hanning(N_FFT)  # analysis/synthesis window of length N_FFT

# Set up live spectrogram plot (interactive mode so updates show without
# blocking). Rows are the N_FFT // 2 + 1 one-sided frequency bins; the
# placeholder is 10 columns wide. Note: imshow() derives its colour limits
# from this all-zero placeholder, so later set_data() calls must rescale
# them (set_clim/autoscale) for the image to change colour.
plt.ion()
fig, ax = plt.subplots(figsize=(8, 4))
placeholder = np.zeros((N_FFT // 2 + 1, 10))
spec_img = ax.imshow(
    placeholder,
    aspect='auto',
    origin='lower',
    cmap='magma',
)
ax.set(
    title="Live Denoised Spectrogram",
    xlabel="Time",
    ylabel="Frequency",
)
fig.tight_layout()

# Real-time denoising with STFT
# STFT settings shared by analysis and synthesis so the two always agree.
_STFT_KWARGS = dict(
    fs=SAMPLING_RATE,
    window=window,
    nperseg=N_FFT,
    noverlap=N_FFT - HOP_LENGTH,
)

# Hand-off from the audio thread to the main thread. Matplotlib is not
# thread-safe and drawing takes far longer than one audio block, so the
# sounddevice callback only enqueues spectra and the main loop draws them.
spec_queue = queue.Queue(maxsize=8)


def denoise_stft(audio_frame, spec_sink=None):
    """Denoise one block of audio by filtering its STFT magnitude.

    The magnitude of every STFT frame is passed through ``model``; the
    original phase is reused to resynthesize the signal.

    Args:
        audio_frame: 1-D array of samples for one block.
        spec_sink: optional queue that receives the denoised magnitude
            spectrogram, shaped (frequency bins, frames), for display. If the
            queue is full the snapshot is dropped so the caller never blocks.

    Returns:
        The resynthesized (denoised) audio as a 1-D array.
    """
    _, _, Zxx = stft(audio_frame, **_STFT_KWARGS)
    magnitude = np.abs(Zxx)
    phase = np.angle(Zxx)

    # One batch item per STFT frame: (frames, freq_bins, 1).
    mag_input = magnitude.T[..., np.newaxis].astype(np.float32)
    # Call the model directly: predict() is meant for large batched inputs
    # and adds per-call overhead that is too slow for an audio callback.
    denoised_mag = np.asarray(model(mag_input, training=False))
    # Drop only the channel axis (squeeze() would also drop the frame axis
    # for a single-frame input), then transpose to (freq_bins, frames).
    denoised_mag = denoised_mag[..., 0].T

    if spec_sink is not None:
        try:
            spec_sink.put_nowait(denoised_mag)
        except queue.Full:
            pass  # display is lagging; never block the audio thread

    Zxx_denoised = denoised_mag * np.exp(1j * phase)
    _, denoised_audio = istft(Zxx_denoised, **_STFT_KWARGS)
    return denoised_audio


# Callback for real-time audio stream
def callback(indata, outdata, frames, time_info, status):
    """sounddevice stream callback: write denoised ``indata`` to ``outdata``.

    Runs on the audio thread, so it never touches matplotlib; the denoised
    spectrogram is forwarded to the main thread through ``spec_queue``.
    """
    if status:
        print(status)
    mono = indata[:, 0]
    if len(mono) < FRAME_SIZE:
        mono = np.pad(mono, (0, FRAME_SIZE - len(mono)))

    start = time.perf_counter()
    denoised = denoise_stft(mono, spec_sink=spec_queue)
    latency = (time.perf_counter() - start) * 1000
    print(f"Frame latency: {latency:.2f} ms")

    outdata[:, 0] = denoised[:frames]


def _draw_pending_spectra(history):
    """Drain ``spec_queue`` into the scrolling ``history`` and redraw.

    Must run on the main thread. Returns the updated history array, which
    keeps the same shape as the one passed in.
    """
    width = history.shape[1]
    updated = False
    while True:
        try:
            mag = spec_queue.get_nowait()
        except queue.Empty:
            break
        # Append the newest frames on the right, keep a fixed-width window.
        history = np.concatenate([history, mag], axis=1)[:, -width:]
        updated = True
    if updated:
        spec_img.set_data(history)
        # imshow() fixed the colour limits from the all-zero placeholder
        # (vmin == vmax == 0); rescale them to the data now shown.
        spec_img.autoscale()
        fig.canvas.draw()
    fig.canvas.flush_events()
    return history


# Start stream
print("Starting real-time denoising with visualization...")
with sd.Stream(channels=1, callback=callback, samplerate=SAMPLING_RATE, blocksize=FRAME_SIZE):
    print("Press Ctrl+C to stop.")
    spec_history = np.zeros(spec_img.get_array().shape)
    try:
        while True:
            # All plotting happens here, on the main thread.
            spec_history = _draw_pending_spectra(spec_history)
            sd.sleep(50)
    except KeyboardInterrupt:
        print("\nStopped.")