In [1]:
import os
import torch
import librosa
import soundfile as sf
import numpy as np
import sounddevice as sd
from unet_anc_model import UNetANC
from scipy.io.wavfile import write

# **Preprocessing Function**
def preprocess_audio(file_path, sr=16000, chunk_length=32000, overlap=1600):
    """
    Preprocess the audio into overlapping chunks for model input.
    
    Parameters:
        file_path (str): Path to the audio file.
        sr (int): Sampling rate.
        chunk_length (int): Length of each chunk in samples.
        overlap (int): Overlap between consecutive chunks in samples.
    
    Returns:
        chunks (list): List of torch tensors containing audio chunks.
        original_length (int): Original audio length in samples.
    """
    audio, _ = librosa.load(file_path, sr=sr)
    original_length = len(audio)  # Store the original audio length
    chunks = []

    # Create overlapping chunks
    for start in range(0, original_length, chunk_length - overlap):
        end = start + chunk_length
        chunk = audio[start:end]
        if len(chunk) < chunk_length:
            chunk = np.pad(chunk, (0, chunk_length - len(chunk)), mode='constant')
        chunks.append(torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).unsqueeze(0))  # Add batch and channel dims

    return chunks, original_length


# **Postprocessing Function**
def postprocess_audio(denoised_chunks, original_length):
    """
    Combine the denoised chunks back into a single audio waveform and truncate to the original length.
    
    Parameters:
        denoised_chunks (list): List of numpy arrays representing denoised audio chunks.
        original_length (int): The original length of the audio signal.
    
    Returns:
        np.ndarray: The reconstructed audio signal of the original length.
    """
    # Concatenate all chunks
    combined_audio = np.concatenate(denoised_chunks)
    
    # Truncate to the original length
    return combined_audio[:original_length].astype(np.float32)

# **Audio Recording**
def record_audio(filename="recorded_noisy.wav", duration=10, sr=16000):
    print("Recording...")
    audio = sd.rec(int(duration * sr), samplerate=sr, channels=1, dtype='float32')
    sd.wait()
    write(filename, sr, np.squeeze(audio))
    print(f"Recording saved as {filename}")

# **Denoising Function**
def denoise_audio(model, noisy_chunks, device):
    """
    Perform denoising on a list of audio chunks using the model.
    
    Parameters:
        model: The trained PyTorch model for denoising.
        noisy_chunks (list): List of audio chunks (tensors).
        device: The device to run the model on (CPU/GPU).
    
    Returns:
        list: List of denoised audio chunks (numpy arrays).
    """
    model.eval()
    denoised_chunks = []

    with torch.no_grad():
        for i, chunk in enumerate(noisy_chunks):
            print(f"Denoising chunk {i + 1}/{len(noisy_chunks)}")
            chunk = chunk.to(device)
            denoised_chunk = model(chunk).squeeze().cpu().numpy()
            denoised_chunks.append(denoised_chunk)

    return denoised_chunks


# **Main Script**
if __name__ == "__main__":
    # Define model and device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNetANC().to(device)

    # Load trained model weights
    model_path = "best_model_cpu.pth"
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model checkpoint not found at {model_path}. Ensure the file exists.")
    
    checkpoint = torch.load(model_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print("Model loaded successfully!")

    # Record or load noisy audio
    record_audio()  # Records a 10-second noisy audio file named "recorded_noisy.wav"

    # Preprocess the recorded noisy audio
    noisy_chunks, original_length = preprocess_audio("recorded_noisy.wav")
    print(f"Processed {len(noisy_chunks)} chunks from the noisy audio.")

    # Perform denoising
    denoised_chunks = denoise_audio(model, noisy_chunks, device)

    # Postprocess to reconstruct the full audio
    denoised_audio = postprocess_audio(denoised_chunks, original_length)

    if len(denoised_audio.shape) != 1 or not np.issubdtype(denoised_audio.dtype, np.floating):
        raise ValueError("Denoised audio must be a 1D float32 array.")

    # Save the denoised audio
    output_file = "denoised_audio.wav"
    sf.write(output_file, denoised_audio, 16000)
    print(f"Denoised audio saved as {output_file}")

    # Plot results
    noisy_audio, _ = librosa.load("recorded_noisy.wav", sr=16000)
    plot_waveform_spectrogram(noisy_audio, title="Noisy Audio")
    plot_waveform_spectrogram(denoised_audio, title="Denoised Audio")


  checkpoint = torch.load(model_path, map_location=device)


Model loaded successfully!
Recording...
Recording saved as recorded_noisy.wav
Processed 6 chunks from the noisy audio.
Denoising chunk 1/6
Denoising chunk 2/6
Denoising chunk 3/6
Denoising chunk 4/6
Denoising chunk 5/6
Denoising chunk 6/6
Denoised audio saved as denoised_audio.wav


NameError: name 'plot_waveform_spectrogram' is not defined