In [2]:
# @title Cell 0: Imports and Google Drive Mount
# All necessary libraries are imported here at the beginning.
from google.colab import drive
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import librosa
import soundfile as sf
from scipy.signal import butter, lfilter
import random # For data synthesis
from tqdm.notebook import tqdm # For progress bars
from IPython.display import Audio, display # For playing audio in Colab
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Mount Google Drive at /content/drive. The project paths built in the next
# cell are all rooted under /content/drive/MyDrive, so this must succeed first.
drive.mount('/content/drive')
print("Google Drive mounted.")

Mounted at /content/drive
Google Drive mounted.


In [3]:
# @title Cell 1: Initial Project Setup and Folder Creation (Final Corrected Version)
# Declares every project path under Google Drive, creates the folder tree,
# and initialises the global Data_State flag and the torch device.

# Project root inside Google Drive; every other path hangs off it.
base_project_path = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1'
dl_data_path = os.path.join(base_project_path, 'DeepLearning_Data')

# --- 0_Raw_Audio_Data: raw sources and their processed copies ---
raw_audio_path = os.path.join(dl_data_path, '0_Raw_Audio_Data')
RAW_SPEECH_DIR, RAW_NOISE_DIR, RAW_WHISTLE_DIR = [
    os.path.join(raw_audio_path, kind)
    for kind in ('speech', 'noise', 'whistle')
]
PROCESSED_SPEECH_DIR, PROCESSED_NOISE_DIR, PROCESSED_WHISTLE_DIR = [
    os.path.join(raw_audio_path, f'{kind}_processed')
    for kind in ('speech', 'noise', 'whistle')
]

# --- 1_Synthesized_Mixtures: <split>/<content> WAV folders ---
synthesized_mixtures_path = os.path.join(dl_data_path, '1_Synthesized_Mixtures')
TRAIN_MIXED_DIR, TRAIN_WHISTLE_DIR, VAL_MIXED_DIR, VAL_WHISTLE_DIR = [
    os.path.join(synthesized_mixtures_path, split, content)
    for split in ('Train', 'Validation')
    for content in ('Mixed', 'WhistleOnly')
]

# --- 2_Spectrograms: pre-computed spectrogram tensors ---
spectrograms_path = os.path.join(dl_data_path, '2_Spectrograms')
TRAIN_MIXED_SPECS_DIR, TRAIN_WHISTLE_SPECS_DIR, VAL_MIXED_SPECS_DIR, VAL_WHISTLE_SPECS_DIR = [
    os.path.join(spectrograms_path, f'{split}_{content}_specs')
    for split in ('train', 'val')
    for content in ('mixed', 'whistle')
]

# --- Checkpoints ---
CHECKPOINT_DIR = os.path.join(base_project_path, 'checkpoints')
BEST_MODEL_PATH = os.path.join(CHECKPOINT_DIR, 'best_whistle_enhancer_model.pth')
LAST_MODEL_PATH = os.path.join(CHECKPOINT_DIR, 'last_whistle_enhancer_model.pth')

# --- Create the whole tree (idempotent thanks to exist_ok=True) ---
for _dir in (
    # Base directories
    raw_audio_path, synthesized_mixtures_path, spectrograms_path, CHECKPOINT_DIR,
    # Raw audio sub-directories
    RAW_SPEECH_DIR, RAW_NOISE_DIR, RAW_WHISTLE_DIR,
    # Processed raw audio sub-directories
    PROCESSED_SPEECH_DIR, PROCESSED_NOISE_DIR, PROCESSED_WHISTLE_DIR,
    # Synthesized mixtures sub-directories
    TRAIN_MIXED_DIR, TRAIN_WHISTLE_DIR, VAL_MIXED_DIR, VAL_WHISTLE_DIR,
    # Spectrogram sub-directories
    TRAIN_MIXED_SPECS_DIR, TRAIN_WHISTLE_SPECS_DIR, VAL_MIXED_SPECS_DIR, VAL_WHISTLE_SPECS_DIR,
):
    os.makedirs(_dir, exist_ok=True)
del _dir  # keep the notebook namespace free of the loop variable

print("All project directories created/ensured, including new spectrogram paths.")

# Pipeline progress flag shared by the later cells:
# 0=Initial, 1=Raw Processed, 2=Mixtures Synthesized, 3=Spectrograms Pre-processed
Data_State = 3
print(f"Initial Data_State: {Data_State}")

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

All project directories created/ensured, including new spectrogram paths.
Initial Data_State: 3
Using device: cpu


In [4]:
# @title Cell 2: Configuration Parameters (Cleaned Version)
# All global constants and parameters are defined here for easy modification.
# NOTE: All directory paths (e.g., TRAIN_MIXED_DIR, PROCESSED_WHISTLE_DIR),
# 'device' and 'Data_State' are defined in Cell 1 (project setup), which must
# be run before any cell that uses them.

# Audio Processing Parameters
TARGET_SAMPLE_RATE = 16000 # All audio will be resampled to this rate (Hz)
SEGMENT_DURATION_SECONDS = 4 # Duration of generated mixed/whistle-only audio files

# Synthesis Parameters
NUM_TRAIN_MIX = 10000 # Number of training mixtures to generate
NUM_VAL_MIX = 1000    # Number of validation mixtures to generate
SNR_RANGE_DB = [-5, 5] # [min, max] in dB; one value is drawn uniformly from this range per mixture

# Real-time DSP Processing Parameters (for DL model input)
FRAME_LENGTH_MS = 25
HOP_LENGTH_MS = 10
FRAME_LENGTH_SAMPLES = int(TARGET_SAMPLE_RATE * FRAME_LENGTH_MS / 1000) # 400 samples at 16 kHz
HOP_LENGTH_SAMPLES = int(TARGET_SAMPLE_RATE * HOP_LENGTH_MS / 1000) # 160 samples at 16 kHz

# Band-Pass Filter Parameters (cut-offs in Hz, Butterworth order)
LOWCUT_FREQ = 500
HIGHCUT_FREQ = 5000
FILTER_ORDER = 6

# Spectrogram Parameters (Mel Spectrograms)
STFT_N_FFT = FRAME_LENGTH_SAMPLES # 400. Used as window size for Mel Spectrogram calculation
STFT_HOP_LENGTH = STFT_N_FFT      # 400. Hop equals the window size, so consecutive STFT windows do not overlap
N_MELS = 64                       # Number of Mel bands/filters

# Deep Learning Training Parameters
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_EPOCHS = 30
EARLY_STOP_PATIENCE = 5 # Number of epochs to wait for improvement before stopping
# NOTE(review): scheduler_patience equals EARLY_STOP_PATIENCE. If both watch the
# same metric, early stopping may trigger before the learning rate is ever
# reduced -- confirm against the training loop.
scheduler_patience = 5
scheduler_factor = 0.5

print("Configuration parameters loaded.")

Configuration parameters loaded.


In [5]:
# @title Cell 3: DSP Helper Functions
# These functions encapsulate the digital signal processing logic.

def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Args:
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling rate of the signal the filter will be applied to.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` numerator/denominator coefficient arrays.
    """
    half_rate = 0.5 * fs
    # scipy expects the band edges as fractions of the Nyquist frequency.
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(order, band_edges, btype='band')
    return numerator, denominator

def apply_filter(data, b, a):
    """Run ``data`` through the filter described by coefficients ``b`` and ``a``.

    Uses ``scipy.signal.lfilter`` (a single forward pass) and returns its result.
    """
    filtered = lfilter(b, a, data)
    return filtered

def get_mini_mel_spectrogram(audio_frame, sr, n_fft, hop_length, n_mels, top_db=80.0):
    """
    Computes a normalized log-Mel spectrogram for one short audio frame.

    Args:
        audio_frame: 1-D array of samples. Frames shorter than ``n_fft`` are
            zero-padded at the end.
        sr: Sample rate of ``audio_frame``.
        n_fft: STFT window size.
        hop_length: STFT hop length.
        n_mels: Number of Mel bands.
        top_db: Dynamic range kept below the frame's own peak; also the span
            used to map dB values onto 0-1.

    Returns:
        float32 PyTorch tensor of shape (1, n_mels, time_bins) with values in
        0-1, where 1 corresponds to the loudest bin of THIS frame
        (``ref=np.max``), so absolute level is not preserved across frames.
        A silent frame returns all zeros.
    """
    audio_frame = np.asarray(audio_frame)

    # Ensure audio frame is long enough for STFT window. Pad with zeros if necessary.
    if len(audio_frame) < n_fft:
        audio_frame = np.pad(audio_frame, (0, n_fft - len(audio_frame)), 'constant')

    # Silence guard: with ref=np.max an all-zero frame would come out as 0 dB
    # everywhere and therefore normalize to all ONES (i.e. "maximum energy").
    # Return zeros instead. The width matches librosa's default centered
    # framing: 1 + len(y) // hop_length.
    if not np.any(audio_frame):
        silent_time_bins = 1 + len(audio_frame) // hop_length
        return torch.zeros((1, n_mels, silent_time_bins), dtype=torch.float32)

    # Compute Mel Spectrogram (power spectrogram by default)
    mel_spec = librosa.feature.melspectrogram(y=audio_frame, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)

    # Same guard for frames whose Mel power is entirely below power_to_db's
    # default floor (amin=1e-10): they would also collapse to a flat 0 dB.
    if mel_spec.max() <= 1e-10:
        return torch.zeros((1,) + mel_spec.shape, dtype=torch.float32)

    # Convert to log-magnitude (dB scale), relative to this frame's maximum
    log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max, top_db=top_db)

    # Normalize to a 0-1 range: values lie in [-top_db, 0] dB after the call above
    min_db = -top_db
    max_db = 0.0
    normalized_spec = (log_mel_spec - min_db) / (max_db - min_db)

    # Convert to PyTorch tensor. Add channel dimension (1 for mono audio) and ensure float32.
    return torch.from_numpy(normalized_spec).unsqueeze(0).float()

print("DSP helper functions defined.")

DSP helper functions defined.


In [6]:
# @title Cell 4: RealtimeFrameDataset Class Definition
# This custom PyTorch Dataset handles loading audio, framing, filtering,
# and converting to Mel spectrograms for the DL model.

class RealtimeFrameDataset(Dataset):
    """Frame-level dataset of (mixed, whistle-only) Mel spectrogram pairs.

    Every ``.wav`` file in ``mixed_dir`` is cut into overlapping frames of
    ``frame_length_samples`` taken every ``hop_length_samples``. Item ``idx``
    is frame ``idx % num_frames_per_segment`` of file
    ``idx // num_frames_per_segment``. Both the mixed frame and the matching
    whistle frame are band-pass filtered and converted with
    ``get_mini_mel_spectrogram``.

    The number of frames per file is measured on the first file only and
    assumed for all files; frames that run past the end of a shorter file are
    zero-padded.

    NOTE: both WAV files are re-loaded from disk for every item, so this
    dataset is I/O heavy.
    """

    def __init__(self, mixed_dir, whistle_dir, sr, segment_duration,
                 frame_length_samples, hop_length_samples,
                 lowcut_freq, highcut_freq, filter_order,
                 stft_n_fft, stft_hop_length, n_mels):

        self.mixed_dir = mixed_dir
        self.whistle_dir = whistle_dir
        self.sr = sr
        self.segment_duration = segment_duration

        # Real-time framing parameters
        self.frame_length_samples = frame_length_samples
        self.hop_length_samples = hop_length_samples

        # Filter parameters (coefficients are designed once and reused per frame)
        self.lowcut_freq = lowcut_freq
        self.highcut_freq = highcut_freq
        self.filter_order = filter_order
        self.b_filter, self.a_filter = butter_bandpass(self.lowcut_freq, self.highcut_freq, self.sr, self.filter_order)

        # Spectrogram parameters
        self.stft_n_fft = stft_n_fft
        self.stft_hop_length = stft_hop_length
        self.n_mels = n_mels

        # Sorted so that the idx -> file mapping is the same on every run
        # (os.listdir order is arbitrary).
        self.base_filenames = sorted(f for f in os.listdir(mixed_dir) if f.endswith('.wav'))
        if not self.base_filenames:
            raise RuntimeError(f"No .wav files found in {mixed_dir}. Please check your synthesized data.")

        # Pre-calculate the total number of individual frames we'll extract,
        # using the first file as the reference length.
        sample_audio_path = os.path.join(mixed_dir, self.base_filenames[0])
        y_sample, _ = librosa.load(sample_audio_path, sr=self.sr)

        self.num_frames_per_segment = (len(y_sample) - self.frame_length_samples) // self.hop_length_samples + 1
        if self.num_frames_per_segment < 1:
            # Without this guard __len__ would be zero or negative and
            # __getitem__ would divide by zero / a negative count.
            raise RuntimeError(
                f"{sample_audio_path} has {len(y_sample)} samples, which is shorter than one "
                f"frame ({self.frame_length_samples} samples)."
            )

        self.total_frames = len(self.base_filenames) * self.num_frames_per_segment
        print(f"Dataset will provide {self.total_frames} individual real-time frames for training.")

    @staticmethod
    def _whistle_filename_for(mixed_filename):
        """Map a mixed-file name to the name of its whistle-only counterpart.

        Handles both naming schemes: ``<prefix>_mixed_<id>.wav`` ->
        ``<prefix>_whistle_<id>.wav`` (what the synthesis cell writes) and
        ``mixed_<id>.wav`` -> ``whistle_<id>.wav``.

        Raises:
            ValueError: if the name contains no ``_`` separator at all.
        """
        if mixed_filename.startswith('mixed_'):
            return 'whistle_' + mixed_filename[len('mixed_'):]
        if '_mixed_' in mixed_filename:
            return mixed_filename.replace('_mixed_', '_whistle_', 1)
        tokens = mixed_filename.split('_')
        if len(tokens) < 2:
            raise ValueError(f"Cannot derive whistle filename from '{mixed_filename}'.")
        # Legacy fallback: keep the original "whistle_<second token>" rule.
        return f"whistle_{tokens[1]}"

    def _extract_frame(self, audio, start_sample):
        """Return exactly ``frame_length_samples`` samples starting at ``start_sample``.

        Zero-pads at the end when the file is too short, so the filter never
        receives an empty or truncated frame.
        """
        frame = audio[start_sample:start_sample + self.frame_length_samples]
        missing = self.frame_length_samples - len(frame)
        if missing > 0:
            frame = np.pad(frame, (0, missing), 'constant')
        return frame

    def _frame_to_spec(self, frame_audio):
        """Band-pass filter one frame and convert it to a normalized Mel spectrogram."""
        filtered = apply_filter(frame_audio, self.b_filter, self.a_filter)
        return get_mini_mel_spectrogram(filtered, self.sr, self.stft_n_fft, self.stft_hop_length, self.n_mels)

    def __len__(self):
        return self.total_frames

    def __getitem__(self, idx):
        file_idx = idx // self.num_frames_per_segment
        frame_in_file_idx = idx % self.num_frames_per_segment

        base_filename = self.base_filenames[file_idx]
        mixed_file_path = os.path.join(self.mixed_dir, base_filename)
        whistle_file_path = os.path.join(self.whistle_dir, self._whistle_filename_for(base_filename))

        try:
            mixed_audio_full, _ = librosa.load(mixed_file_path, sr=self.sr)
            whistle_audio_full, _ = librosa.load(whistle_file_path, sr=self.sr)
        except Exception as e:
            # Best effort: keep the epoch running, but make the failure visible.
            print(f"Error loading full audio for {base_filename}: {e}. Returning dummy tensors.")
            # Width of a real spectrogram for one frame under librosa's
            # default centered framing (2 with the notebook's settings).
            time_bins = 1 + max(self.frame_length_samples, self.stft_n_fft) // self.stft_hop_length
            dummy_spec_shape = (1, self.n_mels, time_bins)
            return torch.zeros(dummy_spec_shape, dtype=torch.float32), torch.zeros(dummy_spec_shape, dtype=torch.float32)

        start_sample = frame_in_file_idx * self.hop_length_samples

        # Input to the model: the mixed frame. Target: the whistle-only frame
        # taken at the same position.
        mixed_frame_spec = self._frame_to_spec(self._extract_frame(mixed_audio_full, start_sample))
        whistle_frame_spec = self._frame_to_spec(self._extract_frame(whistle_audio_full, start_sample))

        return mixed_frame_spec, whistle_frame_spec

print("RealtimeFrameDataset class defined.")

RealtimeFrameDataset class defined.


In [7]:
# @title Cell 5: LightweightWhistleEnhancer Model Definition
# This defines the neural network architecture for the whistle enhancement.

class LightweightWhistleEnhancer(nn.Module):
    """Small fully-convolutional network mapping a spectrogram to an enhanced one.

    Two Conv-BatchNorm-ReLU blocks (1 -> 16 -> 32 channels) are followed by a
    1x1 convolution back to a single channel and a sigmoid. All kernels have
    width 1 along the last axis and are padded along the other one, so the
    output has the same shape as the input.

    ``input_freq_bins`` and ``input_time_bins`` are accepted for interface
    compatibility; the layers defined here do not depend on them.
    """

    def __init__(self, input_freq_bins, input_time_bins=2):
        super().__init__()

        def freq_conv_block(channels_in, channels_out, freq_kernel):
            # (freq_kernel, 1) kernel with "same" padding on the frequency axis:
            # mixes neighbouring frequency bins, one time bin at a time.
            return nn.Sequential(
                nn.Conv2d(in_channels=channels_in, out_channels=channels_out,
                          kernel_size=(freq_kernel, 1), padding=(freq_kernel // 2, 0)),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(),
            )

        # Attribute names (and Sequential indices) define the state_dict keys,
        # so they are kept as-is for checkpoint compatibility.
        self.conv_block1 = freq_conv_block(1, 16, 5)
        self.conv_block2 = freq_conv_block(16, 32, 3)

        # 1x1 convolution collapsing the 32 feature maps into one output channel.
        self.output_conv = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(1, 1))

        # Squashes the output into 0-1, the range of the normalized spectrograms.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (batch_size, 1, freq_bins, time_bins), e.g. (B, 1, 64, 2)
        features = self.conv_block2(self.conv_block1(x))  # -> (B, 32, 64, 2)
        return self.sigmoid(self.output_conv(features))   # -> (B, 1, 64, 2), values in 0-1

print("LightweightWhistleEnhancer model class defined.")

LightweightWhistleEnhancer model class defined.


In [8]:
# @title Cell 6: Download & Prepare Raw Audio Data
# This cell handles downloading (if necessary) and preparing raw audio files
# by resampling and normalizing them to a consistent format.
# Assume you have downloaded your raw datasets (speech, noise, whistle)
# into the respective RAW_*_DIR paths set in Cell 1.
# For example, you might have used !gdown commands or manually uploaded them.

print("--- Preparing Raw Audio Data ---")

def process_audio_files(input_dir, output_dir, target_sr):
    """Convert every audio file in input_dir to a mono 16-bit PCM WAV at target_sr.

    Files ending in .wav/.flac/.mp3 (case-insensitive) are loaded with
    librosa (resampled to ``target_sr`` and down-mixed to mono) and written to
    ``output_dir`` as ``<stem>.wav``. No amplitude scaling is applied.

    Files whose output already exists are skipped, so the function can be
    re-run to resume. Outputs are written to a temporary name and renamed
    only once complete, so an interrupted run never leaves a truncated
    ``.wav`` that a later run would mistake for a finished file.

    Note: two inputs with the same stem (e.g. ``a.wav`` and ``a.flac``) map to
    the same output; only the first one processed is kept.

    Returns:
        None. Per-file failures are printed and do not stop the loop.
    """
    os.makedirs(output_dir, exist_ok=True)
    audio_extensions = ('.wav', '.flac', '.mp3')
    file_list = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(audio_extensions)
        and os.path.isfile(os.path.join(input_dir, f))
    )

    if not file_list:
        print(f"Warning: No audio files found in {input_dir}. Please ensure raw data is in place.")
        return

    print(f"Processing {len(file_list)} files from {input_dir} to {output_dir}...")
    for filename in tqdm(file_list):
        input_filepath = os.path.join(input_dir, filename)
        output_filepath = os.path.join(output_dir, os.path.splitext(filename)[0] + '.wav') # Ensure .wav output

        # Resume support: anything already converted is left untouched.
        if os.path.exists(output_filepath):
            continue

        # The temporary name does not end in '.wav', so downstream cells
        # that list '*.wav' never pick up a half-written file.
        temp_filepath = output_filepath + '.partial'
        try:
            # Load with target sample rate, down-mixed to mono
            audio, _ = librosa.load(input_filepath, sr=target_sr, mono=True)
            sf.write(temp_filepath, audio, target_sr, format='WAV', subtype='PCM_16')
            os.replace(temp_filepath, output_filepath)
        except Exception as e:
            print(f"Error processing {filename}: {e}")
            if os.path.exists(temp_filepath):
                os.remove(temp_filepath)
if Data_State < 1:
    # (label, raw source folder, processed destination folder)
    raw_to_processed = [
        ("speech", RAW_SPEECH_DIR, PROCESSED_SPEECH_DIR),
        ("noise", RAW_NOISE_DIR, PROCESSED_NOISE_DIR),
        # Whistle files are processed too, if separate raw whistle files exist.
        ("whistle", RAW_WHISTLE_DIR, PROCESSED_WHISTLE_DIR),
    ]
    for label, raw_dir, processed_dir in raw_to_processed:
        print(f"Processing {label} from {raw_dir}...")
        process_audio_files(raw_dir, processed_dir, TARGET_SAMPLE_RATE)

    # Only advance the pipeline state when every category actually produced
    # output; process_audio_files merely warns when a raw folder is empty.
    categories_without_output = [
        label for label, _, processed_dir in raw_to_processed
        if not any(f.endswith('.wav') for f in os.listdir(processed_dir))
    ]

    if categories_without_output:
        print(f"\nNo processed .wav files for: {', '.join(categories_without_output)}. Data_State not updated to 1.")
        print("Make sure you have populated the raw audio directories before running this cell.")
    else:
        print("\nRaw audio data preparation complete (resampling and normalization).")
        Data_State = 1


--- Preparing Raw Audio Data ---


In [9]:
# @title Cell 7: Audio Mixture Synthesis
# This cell synthesizes training and validation audio mixtures.

import os
import random
import librosa
import soundfile as sf
import numpy as np
from tqdm import tqdm # For progress bars

print("--- Starting Audio Mixture Synthesis ---")

# --- Helper Function to Count Files ---
def count_files_in_dir(directory, expected_count, file_extension=".wav", desc=""):
    """Report whether a directory holds at least ``expected_count`` matching files.

    Args:
        directory: Folder to inspect.
        expected_count: Minimum number of files required.
        file_extension: Suffix a filename must end with to be counted.
        desc: Label used in the printed status line.

    Returns:
        ``(is_complete, actual_count)``; ``(False, 0)`` when the directory
        does not exist.
    """
    if not os.path.exists(directory):
        print(f"Warning: Directory does not exist: {directory}")
        return False, 0

    actual_count = sum(1 for name in os.listdir(directory) if name.endswith(file_extension))

    # ">=" rather than "==": extra files are tolerated.
    is_complete = actual_count >= expected_count
    status = "All good" if is_complete else "INCOMPLETE"
    print(f"  {desc}: Found {actual_count} {file_extension} files (Expected: {expected_count}) - {status}.")
    return is_complete, actual_count

# --- Helper Function for Audio Normalization ---
def normalize_audio(audio):
    """Peak-normalize ``audio`` so its largest absolute sample becomes 0.9.

    All-zero input is returned unchanged (there is no peak to scale by).
    """
    peak = np.max(np.abs(audio))
    if peak == 0:
        return audio
    # 0.9 leaves headroom below full scale.
    return audio / peak * 0.9


# --- Decide whether synthesis needs to run ---
# Synthesis is skipped only when BOTH conditions hold: every output folder
# already contains at least the expected number of .wav files, and Data_State
# says the pipeline has reached stage 2 or later.

# Count what is already on disk (each call also prints a status line).
train_mixed_present, count_train_mixed = count_files_in_dir(TRAIN_MIXED_DIR, NUM_TRAIN_MIX, desc="Train Mixed")
train_whistle_present, count_train_whistle = count_files_in_dir(TRAIN_WHISTLE_DIR, NUM_TRAIN_MIX, desc="Train Whistle")
val_mixed_present, count_val_mixed = count_files_in_dir(VAL_MIXED_DIR, NUM_VAL_MIX, desc="Validation Mixed")
val_whistle_present, count_val_whistle = count_files_in_dir(VAL_WHISTLE_DIR, NUM_VAL_MIX, desc="Validation Whistle")

# True only when all four folders are complete.
all_mixtures_exist = (
    train_mixed_present and train_whistle_present and
    val_mixed_present and val_whistle_present
)

# Skip, or (re)run synthesis. A re-run resumes: files that already exist are not regenerated.
if all_mixtures_exist and Data_State >= 2:
    print(f"\nSkipping all mixture synthesis: All {NUM_TRAIN_MIX*2 + NUM_VAL_MIX*2} expected mixed/whistle files already exist and Data_State is {Data_State}.")
else:
    print("\nStarting mixture synthesis (or completing previous run)...")

    # --- Collect the processed source files ---
    # Full paths of every .wav in the processed speech / noise / whistle
    # folders (the output folders of the raw-audio preparation cell).
    speech_files = [os.path.join(PROCESSED_SPEECH_DIR, f) for f in os.listdir(PROCESSED_SPEECH_DIR) if f.endswith('.wav')]
    noise_files = [os.path.join(PROCESSED_NOISE_DIR, f) for f in os.listdir(PROCESSED_NOISE_DIR) if f.endswith('.wav')]
    whistle_files = [os.path.join(PROCESSED_WHISTLE_DIR, f) for f in os.listdir(PROCESSED_WHISTLE_DIR) if f.endswith('.wav')]

    if not speech_files or not noise_files or not whistle_files:
        print("ERROR: Not enough processed raw audio files found for synthesis. Please ensure Cell 6 (Download/Prepare) ran successfully.")
        # NOTE: this only prints; the cell continues and Data_State is left
        # unchanged. Raise here instead if later cells must not run.
    else:
        print(f"Found {len(speech_files)} speech, {len(noise_files)} noise, and {len(whistle_files)} whistle processed files.")

        # --- Synthesis Function ---
        def synthesize_mixtures(num_mixtures, mixed_output_dir, whistle_output_dir, prefix):
            """Generate paired mixture / whistle-only WAV files.

            For each index ``i`` a whistle, a speech and a noise file are picked
            at random, cut/padded to SEGMENT_DURATION_SECONDS and peak-normalized.
            Speech + noise together form the interference, which is scaled so
            that the whistle-to-interference RMS ratio equals an SNR drawn
            uniformly from SNR_RANGE_DB. The mixture and the whistle-only
            target are then scaled by ONE common gain, so the saved target is
            exactly the whistle component contained in the saved mixture.

            Writes ``<prefix>_mixed_<i>.wav`` and ``<prefix>_whistle_<i>.wav``;
            pairs that already exist are skipped, which allows resuming.

            NOTE: each source file is always read from its start
            (no random offset), so only the first SEGMENT_DURATION_SECONDS of
            a long recording is ever used.
            """
            print(f"\nSynthesizing {num_mixtures} {prefix} mixtures...")
            sample_length = int(TARGET_SAMPLE_RATE * SEGMENT_DURATION_SECONDS)

            def load_component(path):
                # Mono segment at the target rate, exactly sample_length long
                # (zero-padded or trimmed), peak-normalized.
                segment, _ = librosa.load(path, sr=TARGET_SAMPLE_RATE, duration=SEGMENT_DURATION_SECONDS, mono=True)
                return normalize_audio(librosa.util.fix_length(segment, size=sample_length))

            for i in tqdm(range(num_mixtures), desc=f"Generating {prefix} Mixtures"):
                mixed_filename = os.path.join(mixed_output_dir, f'{prefix}_mixed_{i:05d}.wav')
                whistle_filename = os.path.join(whistle_output_dir, f'{prefix}_whistle_{i:05d}.wav')

                # Skip if both files already exist (useful for resuming)
                if os.path.exists(mixed_filename) and os.path.exists(whistle_filename):
                    continue

                # Randomly pick the source files
                speech_audio = load_component(random.choice(speech_files))
                noise_audio = load_component(random.choice(noise_files))
                whistle_audio = load_component(random.choice(whistle_files))

                # Randomly select SNR and convert dB to a linear power ratio
                current_snr_db = random.uniform(SNR_RANGE_DB[0], SNR_RANGE_DB[1])
                snr_linear = 10**(current_snr_db / 10)

                # Everything that is not the whistle counts as interference.
                interference = speech_audio + noise_audio
                rms_whistle = np.sqrt(np.mean(whistle_audio**2))
                rms_interference = np.sqrt(np.mean(interference**2))

                # Scale the COMBINED interference to the RMS implied by the
                # SNR. Scaling only the noise and adding speech at full level
                # would make the real SNR much lower than the one drawn.
                # Guards avoid division by zero for silent components.
                if rms_whistle > 1e-8 and rms_interference > 1e-8:
                    rms_interference_target = rms_whistle / np.sqrt(snr_linear)
                    interference = interference * (rms_interference_target / rms_interference)

                mixed_audio = whistle_audio + interference

                # One shared gain keeps mixture and target level-consistent and
                # keeps both peaks at or below 0.9 (no clipping on write).
                peak = max(np.max(np.abs(mixed_audio)), np.max(np.abs(whistle_audio)))
                gain = 0.9 / peak if peak > 0 else 1.0
                mixed_audio = mixed_audio * gain
                whistle_audio = whistle_audio * gain

                # Save mixed and whistle-only files
                sf.write(mixed_filename, mixed_audio, TARGET_SAMPLE_RATE)
                sf.write(whistle_filename, whistle_audio, TARGET_SAMPLE_RATE)

        # --- Execute Synthesis for Training and Validation ---
        # The last argument is the filename prefix of the generated files.
        synthesize_mixtures(NUM_TRAIN_MIX, TRAIN_MIXED_DIR, TRAIN_WHISTLE_DIR, "train")
        synthesize_mixtures(NUM_VAL_MIX, VAL_MIXED_DIR, VAL_WHISTLE_DIR, "val")

        # --- Final Data State Update ---
        # Re-count the files on disk after the synthesis attempt. For the
        # mixed folders the boolean result is kept; for the whistle folders
        # the raw count is kept and compared below.
        train_mixtures_final_check, _ = count_files_in_dir(TRAIN_MIXED_DIR, NUM_TRAIN_MIX, desc="Final Train Mixed Check")
        _, count_train_whistle_final = count_files_in_dir(TRAIN_WHISTLE_DIR, NUM_TRAIN_MIX, desc="Final Train Whistle Check")
        val_mixtures_final_check, _ = count_files_in_dir(VAL_MIXED_DIR, NUM_VAL_MIX, desc="Final Validation Mixed Check")
        _, count_val_whistle_final = count_files_in_dir(VAL_WHISTLE_DIR, NUM_VAL_MIX, desc="Final Validation Whistle Check")

        # Update Data_State only if all four folders are complete.
        # NOTE: this assigns 2 unconditionally, so a Data_State that was
        # already higher than 2 is lowered to 2 here.
        if train_mixtures_final_check and val_mixtures_final_check and \
           count_train_whistle_final >= NUM_TRAIN_MIX and count_val_whistle_final >= NUM_VAL_MIX:
            Data_State = 2
            print(f"\nData_State updated to {Data_State} (All Mixture Synthesis Complete).")
        else:
            print("\nData synthesis might be incomplete. Data_State not updated to 2.")

print("\nAudio mixture synthesis check complete.")

--- Starting Audio Mixture Synthesis ---
  Train Mixed: Found 10000 .wav files (Expected: 10000) - All good.
  Train Whistle: Found 10000 .wav files (Expected: 10000) - All good.
  Validation Mixed: Found 1000 .wav files (Expected: 1000) - All good.
  Validation Whistle: Found 1000 .wav files (Expected: 1000) - All good.

Skipping all mixture synthesis: All 22000 expected mixed/whistle files already exist and Data_State is 3.

Audio mixture synthesis check complete.


In [10]:
# @title Cell 7.5: Pre-process Audio to Mel Spectrogram Tensors
# This cell converts all synthesized WAV files into pre-computed Mel spectrogram tensors.

print("--- Pre-processing WAV files to Mel Spectrogram Tensors ---")

# Ensure necessary imports for this cell
import librosa
import numpy as np
import torch
import os
from scipy.signal import butter, filtfilt
from tqdm import tqdm
import random # Needed for get_audio_files if defined here

# --- Helper function for DSP and Mel Spectrogram conversion ---
# Ensure these parameters (TARGET_SAMPLE_RATE, STFT_N_FFT, N_MELS, etc.) are defined in Cell 2/3
def audio_to_mel_spectrogram(audio, sr, lowcut_freq, highcut_freq, filter_order, stft_n_fft, stft_hop_length, n_mels):
    """
    Applies a band-pass filter, performs an STFT, and converts to a normalized Mel spectrogram.

    Args:
        audio: 1-D array of samples.
        sr: Sample rate of ``audio``.
        lowcut_freq, highcut_freq: Band-pass edges, in the same unit as ``sr``.
        filter_order: Butterworth order; the filter is applied forward and
            backward (``filtfilt``).
        stft_n_fft, stft_hop_length: STFT window size and hop.
        n_mels: Number of Mel bands.

    Returns:
        float32 tensor of shape (1, n_mels, time_frames), values clipped to
        0-1, where 1 corresponds to the loudest bin of this signal
        (``ref=np.max``). Silent input returns all zeros.

    NOTE(review): ``amplitude_to_db`` is called with its default ``top_db``
    (80 in librosa) while ``min_db`` below is -100, so non-silent outputs
    effectively lie in [0.2, 1.0]. Left unchanged to stay consistent with
    spectrograms that are already cached on disk.
    """
    audio = np.asarray(audio)

    # Silence guard: with ref=np.max an all-zero signal would be reported as
    # 0 dB everywhere and normalize to all ONES. Return zeros instead; the
    # width follows librosa's default centered framing (1 + len // hop).
    if not np.any(audio):
        silent_time_frames = 1 + len(audio) // stft_hop_length
        return torch.zeros((1, n_mels, silent_time_frames), dtype=torch.float32)

    # 1. Band-pass filtering (band edges as fractions of the Nyquist frequency)
    nyquist = 0.5 * sr
    low = lowcut_freq / nyquist
    high = highcut_freq / nyquist
    b, a = butter(filter_order, [low, high], btype='band')
    y_filtered = filtfilt(b, a, audio)

    # 2. STFT
    D = librosa.stft(y_filtered, n_fft=stft_n_fft, hop_length=stft_hop_length)

    # 3. Mel Spectrogram of the STFT magnitude
    mel_spec = librosa.feature.melspectrogram(S=librosa.magphase(D)[0], sr=sr, n_fft=stft_n_fft, hop_length=stft_hop_length, n_mels=n_mels)

    # Same guard for signals with nothing left in the pass band: everything
    # below amplitude_to_db's default floor (amin=1e-5) would also come out
    # as a flat 0 dB.
    if mel_spec.max() <= 1e-5:
        return torch.zeros((1,) + mel_spec.shape, dtype=torch.float32)

    mel_spec_db = librosa.amplitude_to_db(mel_spec, ref=np.max)

    # 4. Normalization (min-max to 0-1) with fixed bounds, so the mapping
    #    from dB to 0-1 is the same for every spectrogram.
    min_db = -100.0
    max_db = 0.0    # ref=np.max puts the loudest bin at 0 dB

    normalized_spec = (mel_spec_db - min_db) / (max_db - min_db + 1e-8)
    normalized_spec = np.clip(normalized_spec, 0, 1)

    # Add the channel dimension -> (1, n_mels, time_frames)
    return torch.from_numpy(normalized_spec).float().unsqueeze(0)


# --- Function to process and save all files in a directory ---
def process_and_save_spectrograms(input_audio_dir, output_specs_dir, prefix):
    """Convert every ``.wav`` in ``input_audio_dir`` to a saved spectrogram tensor.

    Each ``<name>.wav`` becomes ``<name>.pt`` in ``output_specs_dir``, holding
    the tensor returned by ``audio_to_mel_spectrogram`` for the whole file.
    Files whose ``.pt`` already exists are skipped, so the function resumes an
    interrupted run. Tensors are saved under a temporary name and renamed
    once complete, so an interruption cannot leave a truncated ``.pt`` that
    would later be treated as finished.

    Args:
        input_audio_dir: Folder containing the WAV files.
        output_specs_dir: Folder receiving the ``.pt`` files (created if missing).
        prefix: Label used in progress messages.

    Returns:
        None.
    """
    print(f"\nProcessing {prefix} audio files from '{input_audio_dir}' to '{output_specs_dir}'...")
    audio_files = sorted(f for f in os.listdir(input_audio_dir) if f.endswith('.wav'))
    if not audio_files:
        print(f"Warning: No .wav files found in '{input_audio_dir}'; nothing to convert.")
        return

    os.makedirs(output_specs_dir, exist_ok=True)

    # One directory listing instead of one os.path.exists call per file.
    existing_specs = set(os.listdir(output_specs_dir))
    pending_files = [f for f in audio_files if os.path.splitext(f)[0] + '.pt' not in existing_specs]

    if not pending_files:
        print(f"Skipping {prefix} spectrogram generation: All {len(audio_files)} spectrograms already found in '{output_specs_dir}'.")
        return

    for audio_file in tqdm(pending_files, desc=f"Converting {prefix} WAV to Spectrograms"):
        audio_path = os.path.join(input_audio_dir, audio_file)
        spec_filename = os.path.join(output_specs_dir, os.path.splitext(audio_file)[0] + '.pt')

        audio, sr = librosa.load(audio_path, sr=TARGET_SAMPLE_RATE, mono=True)

        # The spectrogram covers the full file; any slicing down to the
        # model's time window is left to the dataset that loads these tensors.
        mel_spec_tensor = audio_to_mel_spectrogram(
            audio, sr, LOWCUT_FREQ, HIGHCUT_FREQ, FILTER_ORDER,
            STFT_N_FFT, STFT_HOP_LENGTH, N_MELS
        )

        # Write-then-rename: the temporary name does not end in '.pt', so it
        # is never mistaken for a finished spectrogram.
        temp_filename = spec_filename + '.partial'
        torch.save(mel_spec_tensor, temp_filename)
        os.replace(temp_filename, spec_filename)
if Data_State < 3:
    # --- Execution ---
    # Runs only while the pipeline has not yet reached stage 3.
    # Relies on TARGET_SAMPLE_RATE, LOWCUT_FREQ, HIGHCUT_FREQ, FILTER_ORDER,
    # STFT_N_FFT, STFT_HOP_LENGTH and N_MELS from the configuration cell (Cell 2).

    # Process training data
    process_and_save_spectrograms(TRAIN_MIXED_DIR, TRAIN_MIXED_SPECS_DIR, "Training Mixed")
    process_and_save_spectrograms(TRAIN_WHISTLE_DIR, TRAIN_WHISTLE_SPECS_DIR, "Training Whistle")

    # Process validation data
    process_and_save_spectrograms(VAL_MIXED_DIR, VAL_MIXED_SPECS_DIR, "Validation Mixed")
    process_and_save_spectrograms(VAL_WHISTLE_DIR, VAL_WHISTLE_SPECS_DIR, "Validation Whistle")

    # Update Data_State after spectrograms are pre-processed.
    # Data_State is the global flag initialised in Cell 1; it lives only in
    # the kernel, so it is reset whenever Cell 1 is re-run.
    # NOTE: the state is advanced without re-checking that the .pt files exist.
    # For robust persistence, you'd save/load this state to/from a file.
    Data_State = 3
    print(f"\nData_State updated to {Data_State} (Spectrograms Pre-processed).")
    print("\nAll audio files pre-processed to Mel Spectrogram Tensors.")

--- Pre-processing WAV files to Mel Spectrogram Tensors ---


In [11]:
# @title Cell 8: Dataset and DataLoader Initialization
# Initializes the datasets and data loaders for training and validation.

import torch
from torch.utils.data import Dataset, DataLoader
import os
import random # Needed for random slicing of time frames

print("\n--- Initializing PrecomputedSpectrogramDataset ---")

# Define the new dataset class to load pre-computed spectrograms
class PrecomputedSpectrogramDataset(Dataset):
    """Paired dataset of pre-computed mixed/whistle spectrogram tensors stored as ``.pt`` files.

    The ``.pt`` files of each directory are listed and sorted; item ``idx`` pairs the
    ``idx``-th mixed file with the ``idx``-th whistle file, so both directories must
    use a consistent naming/numbering scheme. Every item is a window of
    ``input_time_bins`` frames cut from the last (time) dimension of both tensors
    at the same position.

    Args:
        mixed_specs_dir (str): Directory holding the mixed-signal spectrogram tensors.
        whistle_specs_dir (str): Directory holding the target whistle spectrogram tensors.
        input_time_bins (int): Number of time frames per returned window.
        random_crop (bool): When True (default) the window start is drawn at random on
            every access. When False the window is centred in the segment, which makes
            repeated accesses (e.g. validation passes) reproducible.

    Raises:
        AssertionError: If the two directories contain a different number of ``.pt`` files.
    """

    def __init__(self, mixed_specs_dir, whistle_specs_dir, input_time_bins=2, random_crop=True):
        self.mixed_specs_dir = mixed_specs_dir
        self.whistle_specs_dir = whistle_specs_dir
        self.input_time_bins = input_time_bins
        self.random_crop = random_crop
        self.mixed_files = sorted([f for f in os.listdir(mixed_specs_dir) if f.endswith('.pt')])
        self.whistle_files = sorted([f for f in os.listdir(whistle_specs_dir) if f.endswith('.pt')])

        # Ensure that mixed and whistle file lists match by count.
        # Basenames are intentionally not compared: pairing relies on both lists
        # sorting into the same order.
        assert len(self.mixed_files) == len(self.whistle_files), \
            f"Mismatch between number of mixed ({len(self.mixed_files)}) and whistle ({len(self.whistle_files)}) spectrogram files."

        self.length = len(self.mixed_files)
        print(f"Found {self.length} pre-computed spectrograms in {mixed_specs_dir} and {whistle_specs_dir}.")

    def __len__(self):
        """Return the number of mixed/whistle spectrogram pairs."""
        return self.length

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(mixed_window, whistle_window)``.

        Both windows cover the same ``input_time_bins`` frames of the last dimension.

        Raises:
            ValueError: If the two tensors of the pair differ in shape, or if the
                segment has fewer than ``input_time_bins`` time frames.
        """
        mixed_spec_path = os.path.join(self.mixed_specs_dir, self.mixed_files[idx])
        whistle_spec_path = os.path.join(self.whistle_specs_dir, self.whistle_files[idx])

        # Load the pre-computed spectrogram tensors.
        # The slicing below indexes three dimensions; the intended layout is
        # (1, N_MELS, num_time_frames).
        mixed_spec_full = torch.load(mixed_spec_path)
        whistle_spec_full = torch.load(whistle_spec_path)

        # A mismatched pair would otherwise be sliced silently into windows of
        # different sizes and only fail later, inside batching or the loss.
        if mixed_spec_full.shape != whistle_spec_full.shape:
            raise ValueError(
                f"Spectrogram pair '{self.mixed_files[idx]}' / '{self.whistle_files[idx]}' has mismatched shapes "
                f"{tuple(mixed_spec_full.shape)} vs {tuple(whistle_spec_full.shape)}."
            )

        num_full_time_frames = mixed_spec_full.shape[-1]

        if num_full_time_frames < self.input_time_bins:
            # Should not happen when segment duration and STFT parameters are set
            # consistently; fail loudly rather than returning a short window.
            raise ValueError(f"Spectrogram '{self.mixed_files[idx]}' has only {num_full_time_frames} time frames, but input_time_bins is {self.input_time_bins}. Ensure segment_duration or STFT parameters yield enough time frames.")

        # Choose where the 'input_time_bins' window starts. Random starts expose the
        # model to different temporal contexts of the same segment; the centred
        # start keeps evaluation repeatable.
        last_valid_start = num_full_time_frames - self.input_time_bins
        if self.random_crop:
            start_frame = random.randint(0, last_valid_start)
        else:
            start_frame = last_valid_start // 2
        end_frame = start_frame + self.input_time_bins

        mixed_spec_window = mixed_spec_full[:, :, start_frame:end_frame]
        whistle_spec_window = whistle_spec_full[:, :, start_frame:end_frame]

        return mixed_spec_window, whistle_spec_window


# Build the train/validation datasets and loaders, then pull one batch as a smoke test.
# Failures are reported with a printed message (best-effort, notebook style) instead of
# a traceback; the success message is only printed when everything actually worked.
try:
    train_dataset = PrecomputedSpectrogramDataset(
        mixed_specs_dir=TRAIN_MIXED_SPECS_DIR,
        whistle_specs_dir=TRAIN_WHISTLE_SPECS_DIR,
        input_time_bins=2 # This should match your model's input_time_bins
    )
    val_dataset = PrecomputedSpectrogramDataset(
        mixed_specs_dir=VAL_MIXED_SPECS_DIR,
        whistle_specs_dir=VAL_WHISTLE_SPECS_DIR,
        input_time_bins=2
    )

    # Spectrograms are pre-computed, so each item is a plain tensor load. Two worker
    # processes are used here; raise num_workers if the runtime has spare CPU cores and RAM.
    # Only the training loader shuffles.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    # Test retrieving a batch to confirm shapes
    print(f"\nAttempting to retrieve a batch from Train DataLoader with batch size {BATCH_SIZE}...")
    for mixed_batch, whistle_batch in train_loader:
        print(f"   Mixed Mel-Spectrograms Batch Shape: {mixed_batch.shape}")   # Expected: (BATCH_SIZE, 1, N_MELS, 2)
        print(f"   Whistle Mel-Spectrograms Batch Shape: {whistle_batch.shape}") # Expected: (BATCH_SIZE, 1, N_MELS, 2)
        break # Just get one batch

except Exception as e:
    # Deliberately not re-raised: a printed error is considered sufficient for this notebook.
    # Note that later cells will fail if the loaders were not created.
    print(f"Error initializing Dataset or DataLoader: {e}")
else:
    # Reached only when no exception occurred above, so the message is truthful.
    print("Dataset and DataLoader initialized.")


--- Initializing PrecomputedSpectrogramDataset ---
Found 10000 pre-computed spectrograms in /content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/2_Spectrograms/train_mixed_specs and /content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/2_Spectrograms/train_whistle_specs.
Found 1000 pre-computed spectrograms in /content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/2_Spectrograms/val_mixed_specs and /content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/2_Spectrograms/val_whistle_specs.

Attempting to retrieve a batch from Train DataLoader with batch size 64...




   Mixed Mel-Spectrograms Batch Shape: torch.Size([64, 1, 64, 2])
   Whistle Mel-Spectrograms Batch Shape: torch.Size([64, 1, 64, 2])
Dataset and DataLoader initialized.


In [12]:
# @title Cell 9: Model, Loss, and Optimizer Setup
# Initializes the model, defines the loss function, and sets up the optimizer.

# Instantiate the enhancer with N_MELS frequency bins and move it to the selected device.
model = LightweightWhistleEnhancer(input_freq_bins=N_MELS).to(device) # Pass N_MELS to model
criterion = nn.MSELoss() # Mean Squared Error, common for spectrogram prediction
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Lower the learning rate when the monitored (validation) loss stops decreasing.
# The `verbose` argument is intentionally not passed: it is deprecated in recent PyTorch
# releases; read the current rate from optimizer.param_groups[0]['lr'] instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=scheduler_factor, patience=scheduler_patience)

print(f"\nModel architecture:\n{model}")
# Count only parameters that receive gradients.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}") # Should be much smaller than U-Net

print("Model, loss function, and optimizer set up.")


Model architecture:
LightweightWhistleEnhancer(
  (conv_block1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 1), stride=(1, 1), padding=(2, 0))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv_block2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (output_conv): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  (sigmoid): Sigmoid()
)
Total trainable parameters: 1793
Model, loss function, and optimizer set up.




In [17]:
# @title Cell 10: Training Loop
# The main training loop for the deep learning model.

import torch
import os # Import os for path operations
from tqdm import tqdm # Ensure tqdm is imported for progress bars
from torch.optim.lr_scheduler import ReduceLROnPlateau # Import the scheduler

print("\n--- Starting Training ---")

# --- Early Stopping Configuration ---
# Number of consecutive epochs without a new best validation loss (counted per run of this cell).
patience_counter = 0

# --- Resume Training Logic ---
start_epoch = 0
best_val_loss = float('inf') # Initialize with infinity to ensure the first model is saved

# Try to load the last saved checkpoint to resume training
if os.path.exists(LAST_MODEL_PATH):
    print(f"Attempting to resume training from {LAST_MODEL_PATH}...")
    try:
        checkpoint = torch.load(LAST_MODEL_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1 # Start from the next epoch
        last_val_loss = checkpoint['loss'] # Validation loss of the last completed epoch

        # Restore the BEST validation loss seen so far, not merely the last one.
        # Otherwise a resumed run could overwrite BEST_MODEL_PATH with a model that is
        # worse than the stored best.
        if 'best_val_loss' in checkpoint:
            best_val_loss = checkpoint['best_val_loss']
        elif os.path.exists(BEST_MODEL_PATH):
            # Older "last" checkpoints lack the key: read the loss stored with the best model.
            best_val_loss = torch.load(BEST_MODEL_PATH, map_location='cpu')['loss']
        else:
            best_val_loss = last_val_loss

        # If scheduler state was also saved, load it
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        print(f"Resumed training from Epoch {start_epoch}. Last recorded Val Loss: {last_val_loss:.4f}")
        print(f"Best Val Loss so far: {best_val_loss:.4f}")
        print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting training from scratch.")
        # Reset to defaults if loading fails
        start_epoch = 0
        best_val_loss = float('inf')
        patience_counter = 0 # Reset patience counter if starting from scratch

else:
    print(f"No checkpoint found at {LAST_MODEL_PATH}. Starting training from scratch.")

# Ensure NUM_EPOCHS is defined (e.g., in Cell 2)
# The training loop will now run from start_epoch up to NUM_EPOCHS
for epoch in range(start_epoch, NUM_EPOCHS):
    model.train() # Set model to training mode
    running_loss = 0.0

    for batch_idx, (mixed_specs, whistle_specs) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} (Train)")):
        mixed_specs = mixed_specs.to(device)
        whistle_specs = whistle_specs.to(device)

        optimizer.zero_grad() # Zero the gradients

        outputs = model(mixed_specs) # Forward pass
        loss = criterion(outputs, whistle_specs) # Calculate loss

        loss.backward() # Backward pass
        optimizer.step() # Update weights

        running_loss += loss.item()

    # Mean of the per-batch losses
    avg_train_loss = running_loss / len(train_loader)

    # --- Validation ---
    model.eval() # Set model to evaluation mode
    val_loss = 0.0
    with torch.no_grad(): # Disable gradient calculations for validation
        for mixed_specs_val, whistle_specs_val in tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} (Val)"):
            mixed_specs_val = mixed_specs_val.to(device)
            whistle_specs_val = whistle_specs_val.to(device)

            outputs_val = model(mixed_specs_val)
            loss_val = criterion(outputs_val, whistle_specs_val)
            val_loss += loss_val.item()

    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # --- Learning Rate Scheduler Step ---
    scheduler.step(avg_val_loss) # Adjust learning rate based on validation loss
    print(f"Current Learning Rate: {optimizer.param_groups[0]['lr']:.6f}") # Print current LR

    # --- Checkpointing Logic ---
    # Decide first whether this epoch is a new best, so the "last" checkpoint written
    # below already carries the up-to-date best loss for a later resume.
    improved = avg_val_loss < best_val_loss
    if improved:
        best_val_loss = avg_val_loss

    # Always save the latest model (useful for resuming after timeout)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(), # Save scheduler state
        'loss': avg_val_loss,
        'best_val_loss': best_val_loss, # Lets a resumed run keep comparing against the true best
    }, LAST_MODEL_PATH)
    # print(f"Saved latest model to {LAST_MODEL_PATH}") # Uncomment for verbose saving

    # Save the model if it's the best performing so far on validation loss
    if improved:
        patience_counter = 0 # Reset patience if validation loss improved
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': best_val_loss,
        }, BEST_MODEL_PATH)
        print(f"Saved best model with validation loss {best_val_loss:.4f} to {BEST_MODEL_PATH}")
    else:
        patience_counter += 1 # Increment patience if validation loss did not improve
        print(f"Validation loss did not improve. Patience: {patience_counter}/{EARLY_STOP_PATIENCE}")

    # --- Early Stopping Check ---
    if patience_counter >= EARLY_STOP_PATIENCE:
        print(f"Early stopping triggered after {epoch+1} epochs due to no improvement in validation loss for {EARLY_STOP_PATIENCE} epochs.")
        break # Exit the training loop


print("\n--- Training Complete! ---")
print("The LightweightWhistleEnhancer model has been trained using Mel spectrograms.")
print("Next steps will involve evaluating its performance and preparing for deployment on a drone.")


--- Processing Audio for Playback ---
Using sample audio file for listening: /content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/1_Synthesized_Mixtures/Train/Mixed/train_mixed_09000.wav
Loaded original mixed audio (Duration: 4.00s, SR: 16000Hz)
Designed Band-Pass Filter (500-5000 Hz, Order: 6)
Applied band-pass filter to the audio.
Saved original audio to: /content/audio_outputs/original_mixed_audio.wav
Saved filtered audio to: /content/audio_outputs/band_pass_filtered_audio.wav

--- Playback ---
1. Original Mixed Audio:



2. Band-Pass Filtered Audio (DSP Output):
  (Expected: Much less low-frequency rumble/speech, less high-frequency hiss, whistle should be clearer)



3. DSP + Model Enhanced Audio:
Loaded model from /content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/checkpoints/best_whistle_enhancer_model.pth


TypeError: InverseMelScale.__init__() got an unexpected keyword argument 'n_fft'

In [14]:
# @title Cell 11: DSP Audio Playback Demo
# This cell allows you to listen to the original and band-pass filtered audio.
# It uses the DSP functions defined in Cell 3.

print("\n--- Processing Audio for Playback ---")

# Demo input: the first '.wav' entry that os.listdir reports for TRAIN_MIXED_DIR
# (listing order is whatever the file system returns).
sample_file_path = next(
    (
        os.path.join(TRAIN_MIXED_DIR, entry)
        for entry in os.listdir(TRAIN_MIXED_DIR)
        if entry.endswith('.wav')
    ),
    None,
)
if sample_file_path is None:
    raise FileNotFoundError(f"No .wav files found in {TRAIN_MIXED_DIR}. Please ensure your synthesis completed successfully.")
print(f"Using sample audio file for listening: {sample_file_path}")

# Read the mixture at the project's target sample rate.
y_mixed_full, sr = librosa.load(sample_file_path, sr=TARGET_SAMPLE_RATE)
duration_seconds = y_mixed_full.shape[0] / sr
print(f"Loaded original mixed audio (Duration: {duration_seconds:.2f}s, SR: {sr}Hz)")

# Band-pass coefficients come from the butter_bandpass helper defined in an earlier cell.
b, a = butter_bandpass(LOWCUT_FREQ, HIGHCUT_FREQ, sr, order=FILTER_ORDER)
print(f"Designed Band-Pass Filter ({LOWCUT_FREQ}-{HIGHCUT_FREQ} Hz, Order: {FILTER_ORDER})")

# Filter the whole clip with the apply_filter helper from the same earlier cell.
y_filtered_full = apply_filter(y_mixed_full, b, a)
print("Applied band-pass filter to the audio.")

# Peak-normalise the filtered signal to 90% of full scale; an all-zero signal is left untouched.
filtered_peak = np.max(np.abs(y_filtered_full))
if filtered_peak > 0:
    y_filtered_full = y_filtered_full / filtered_peak * 0.9

# --- Save Processed Audio for Playback ---
output_dir = '/content/audio_outputs'
os.makedirs(output_dir, exist_ok=True)
original_output_path = os.path.join(output_dir, 'original_mixed_audio.wav')
filtered_output_path = os.path.join(output_dir, 'band_pass_filtered_audio.wav')

for wav_path, wav_samples in (
    (original_output_path, y_mixed_full),
    (filtered_output_path, y_filtered_full),
):
    sf.write(wav_path, wav_samples, sr)
print(f"Saved original audio to: {original_output_path}")
print(f"Saved filtered audio to: {filtered_output_path}")

# --- Play Audio in Colab ---
print("\n--- Playback ---")

print("1. Original Mixed Audio:")
original_player = Audio(data=y_mixed_full, rate=sr)
display(original_player)

print("\n2. Band-Pass Filtered Audio (DSP Output):")
print(f"  (Expected: Much less low-frequency rumble/speech, less high-frequency hiss, whistle should be clearer)")
filtered_player = Audio(data=y_filtered_full, rate=sr)
display(filtered_player)

print("\nPlayback complete. You can compare the two audio clips now.")


--- Processing Audio for Playback ---
Using sample audio file for listening: /content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/1_Synthesized_Mixtures/Train/Mixed/train_mixed_09000.wav
Loaded original mixed audio (Duration: 4.00s, SR: 16000Hz)
Designed Band-Pass Filter (500-5000 Hz, Order: 6)
Applied band-pass filter to the audio.
Saved original audio to: /content/audio_outputs/original_mixed_audio.wav
Saved filtered audio to: /content/audio_outputs/band_pass_filtered_audio.wav

--- Playback ---
1. Original Mixed Audio:



2. Band-Pass Filtered Audio (DSP Output):
  (Expected: Much less low-frequency rumble/speech, less high-frequency hiss, whistle should be clearer)



Playback complete. You can compare the two audio clips now.


In [None]:
import shutil
import os

def copy_folder_tree(source_path, destination_path):
    """
    Copies the entire folder tree from source_path to destination_path.

    If the destination already exists, the source tree is merged into it:
    files with the same relative path are overwritten, other existing files
    are left in place. Problems are reported with a printed message; no
    exception is propagated for a missing source or a failed copy.

    Args:
        source_path (str): The path to the root folder you want to copy.
        destination_path (str): The path to the destination folder where the
                                content will be copied.

    Returns:
        None
    """
    if not os.path.exists(source_path):
        print(f"Error: Source path '{source_path}' does not exist.")
        return

    if os.path.exists(destination_path):
        print(f"Warning: Destination path '{destination_path}' already exists. "
              "Existing content might be overwritten or merged.")
        # For a clean copy instead of a merge, delete the destination first:
        # shutil.rmtree(destination_path)
        # print(f"Removed existing destination folder: {destination_path}")

    try:
        # dirs_exist_ok=True is what makes the merge announced above possible;
        # without it copytree raises FileExistsError for any existing destination.
        shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
        print(f"Successfully copied '{source_path}' to '{destination_path}'.")
    except shutil.Error as e:
        # Raised by copytree with the collected per-file failures.
        print(f"Error: Could not copy folder tree. Details: {e}")
    except OSError as e:
        print(f"Error: Operating system error during copy. Details: {e}")

# --- How to use it in Google Colab ---

# 1. Mount Google Drive (recommended for persistent storage)
#    This allows you to easily access files in your Google Drive.
# from google.colab import drive
# drive.mount('/content/drive')

# 2. Define your source and destination paths
#    Examples:

#    a) Using paths within Colab's temporary file system (will be lost after session)
#       Let's create some dummy folders/files for demonstration
#       os.makedirs('/content/source_folder/subfolder1', exist_ok=True)
#       os.makedirs('/content/source_folder/subfolder2', exist_ok=True)
#       with open('/content/source_folder/file1.txt', 'w') as f:
#           f.write('This is file 1.')
#       with open('/content/source_folder/subfolder1/file_in_sub1.txt', 'w') as f:
#           f.write('This is a file in subfolder1.')
#       source_folder = '/content/source_folder'
#       destination_folder = '/content/copied_folder'

#    b) Using paths in Google Drive (recommended for your own files)
#       Make sure these folders actually exist in your Google Drive or create them.
#       Example:
#       source_folder = '/content/drive/MyDrive/MyProject/OriginalData'
#       destination_folder = '/content/drive/MyDrive/MyProject/BackupData'

#    IMPORTANT: Replace these with your actual desired paths!
source_folder = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 2' # Example source folder
destination_folder = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1' # Example destination folder

# --- Create dummy data for demonstration (optional, remove for your actual use) ---
# NOTE: the placeholder folders/files below are written into source_folder itself.
print("Creating dummy source data for demonstration...")
for demo_subdir in (('images',), ('documents', 'reports')):
    os.makedirs(os.path.join(source_folder, *demo_subdir), exist_ok=True)

# (path components relative to source_folder, text written into the file)
demo_files = (
    (('main.py',), 'print("Hello from main.py")'),
    (('images', 'image1.jpg'), 'dummy image content'),
    (('documents', 'notes.txt'), 'Some important notes.'),
    (('documents', 'reports', 'report_final.pdf'), 'dummy pdf content'),
)
for demo_parts, demo_text in demo_files:
    with open(os.path.join(source_folder, *demo_parts), 'w') as demo_handle:
        demo_handle.write(demo_text)
print("Dummy source data created.")
# --- End of dummy data creation ---

# Copy the source tree to the destination
copy_folder_tree(source_folder, destination_folder)

# Print the destination as an indented tree so the result can be checked by eye
print("\nContents of destination folder:")
if not os.path.exists(destination_folder):
    print(f"Destination folder '{destination_folder}' does not exist (copy failed).")
else:
    for current_dir, _child_dirs, file_names in os.walk(destination_folder):
        # Nesting depth = number of path separators left after removing the root prefix
        depth = current_dir.replace(destination_folder, '').count(os.sep)
        dir_pad = ' ' * (4 * depth)
        file_pad = ' ' * (4 * (depth + 1))
        print(f'{dir_pad}{os.path.basename(current_dir)}/')
        for file_name in file_names:
            print(f'{file_pad}{file_name}')

In [None]:
import shutil
import os

def copy_all_files_flat(source_root_path, destination_flat_path):
    """
    Flatten a directory tree: copy every file found under source_root_path
    (at any depth) directly into destination_flat_path.

    The original folder structure is not preserved. Files sharing a name
    overwrite each other in the destination. Progress, problems and a final
    summary are printed; nothing is returned.

    Args:
        source_root_path (str): Root folder that is walked for files.
        destination_flat_path (str): Existing directory that receives all files.
                                     It is not created by this function.
    """
    if not os.path.exists(source_root_path):
        print(f"Error: Source root path '{source_root_path}' does not exist.")
        return

    # The destination must already be a directory; tell "missing" apart from
    # "exists but is something else" for the error message.
    if not os.path.isdir(destination_flat_path):
        if not os.path.exists(destination_flat_path):
            print(f"Error: Destination flat path '{destination_flat_path}' does not exist. "
                  "Please create it before running the copy.")
        else:
            print(f"Error: Destination flat path '{destination_flat_path}' is not a directory.")
        return

    print(f"Starting to copy files from '{source_root_path}' to '{destination_flat_path}'...")
    tally = {'copied': 0, 'skipped': 0}

    for folder, _subfolders, names in os.walk(source_root_path):
        for name in names:
            origin = os.path.join(folder, name)
            # Same file name, but placed directly in the flat destination
            target = os.path.join(destination_flat_path, name)

            try:
                shutil.copy2(origin, target)
            except FileExistsError:
                print(f"  Skipped (already exists): '{origin}'")
                tally['skipped'] += 1
            except Exception as e:
                # Any other failure is reported and counted; the walk continues.
                print(f"  Error copying '{origin}': {e}")
                tally['skipped'] += 1
            else:
                print(f"  Copied: '{origin}' to '{target}'")
                tally['copied'] += 1

    print(f"\nCopying complete.")
    print(f"Total files copied: {tally['copied']}")
    print(f"Total files skipped/errored: {tally['skipped']}")


# Source trees and the flat destination folders they are copied into.
source_root_noise = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/0_Raw_Sources/CleanNoise/ESC-50-master/audio'
destination_flat_noise = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/0_Raw_Audio_Data/noise'

source_root_speak = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/0_Raw_Sources/CleanSpeech/LibriSpeech/dev-clean'
destination_flat_speak = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/0_Raw_Audio_Data/speech'

source_root_whistle = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/0_Raw_Sources/CleanWhistles'
destination_flat_whistle = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1/DeepLearning_Data/0_Raw_Audio_Data/whistle'

# Flatten each source tree in turn: noise, then speech, then whistle.
for flat_source, flat_destination in (
    (source_root_noise, destination_flat_noise),
    (source_root_speak, destination_flat_speak),
    (source_root_whistle, destination_flat_whistle),
):
    copy_all_files_flat(flat_source, flat_destination)


In [None]:
import os

def get_folder_size(root_path):
    """
    Sum the sizes of all regular (non-symlink) files below root_path,
    including files in nested subdirectories.

    Args:
        root_path (str): The path to the root folder.

    Returns:
        int: Total size in bytes. 0 when the path does not exist. When the
             path exists but is not a directory, an error line is printed and
             the size of that single path is returned (0 if it cannot be read).
    """
    if not os.path.exists(root_path):
        print(f"Error: The path '{root_path}' does not exist.")
        return 0

    if not os.path.isdir(root_path):
        print(f"Error: The path '{root_path}' is not a directory.")
        # Not a directory: fall back to the size of the path itself
        try:
            return os.path.getsize(root_path)
        except Exception as e:
            print(f"Error getting size of file {root_path}: {e}")
            return 0

    byte_total = 0
    for folder, _subfolders, names in os.walk(root_path):
        for candidate in (os.path.join(folder, name) for name in names):
            # Symbolic links are ignored; they might point outside the tree
            if os.path.islink(candidate):
                continue
            try:
                byte_total += os.path.getsize(candidate)
            except FileNotFoundError:
                print(f"File not found during size calculation: {candidate}")
            except PermissionError:
                # No permission to read this file's metadata
                print(f"Permission denied for file: {candidate}")
            except Exception as e:
                print(f"Error getting size of {candidate}: {e}")
    return byte_total

# --- Example Usage ---
# Set the root path. For Colab, you might use:
# '/content' for the Colab VM's local storage
# '/content/drive/MyDrive' after mounting Google Drive
root_path = '/content/drive/MyDrive/Colab Projects Data/Yondu Arrow 1' # Adjust this to your actual root folder

print(f"Calculating total size under: {root_path}")

# Total in bytes, then expressed in GB using 1 GB = 2**30 bytes (1024**3)
total_bytes = get_folder_size(root_path)
total_gb = total_bytes / (2 ** 30)

print(f"\nTotal size of '{root_path}' and its contents: {total_gb:.2f} GB")