## 1. *Band-pass Filter Parameters*
- `BAND_LO` and `BAND_HI`: Adjust frequency range (currently 12-18 kHz)
- Target different frequency bands for different battery states

    - This is essentially how we filter only the frequencies within our range of interest (the ultrasonic waves)

## 2. *Spectral Subtraction Parameters*
- `QUIET_PCT`: Percentage of quietest frames for noise estimation (currently 20%)
- `OVERSUB`: Over-subtraction factor (currently 1.2)
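    - Spectral subtraction removes background noise from our recordings: the noise spectrum is estimated from the quietest frames, scaled by `OVERSUB`, and subtracted from every frame.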

## 3. **STFT Parameters**
- `NFFT`: FFT size affects frequency resolution (currently 2048)
- `HOP`: Hop length affects time resolution (currently 512)

    - An STFT (short-time Fourier transform) is a time-frequency representation of the audio waveform: it shows how the signal's frequency content changes over time, instead of showing only amplitude over time.

## 4. **Mel-spectrogram Parameters**
- `N_MELS`: Number of mel bins (currently 64)
- `n_fft`: FFT size for mel transform (currently 1024)

    - Mel-spectrograms represent audio signals in a way that aligns with how we perceive sound, giving more resolution to the frequencies that we can distinguish more easily.

In [None]:
import os, glob, subprocess, math
import numpy as np
from PIL import Image

import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models

import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as AF

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from collections import Counter

# ---- SETTINGS ----
# Directory layout
RAW_VIDEOS = "raw_videos"       # where the input videos are read from
AUDIO_DIR = "audio_segments"    # where the 2 s wav chunks are written
IMAGE_DIR = "recorded_images"   # where the extracted jpg frames are written

# Audio / training configuration
SAMPLE_RATE = 48000
CHUNK_SECONDS = 2
N_MELS = 64
BATCH_SIZE = 20
EPOCHS = 15
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Class name -> integer id, plus the reverse mapping with capitalized names.
label_map = {"full": 0, "half": 1, "empty": 2}
LABELS = {class_id: class_name.capitalize() for class_name, class_id in label_map.items()}

# Make sure both output directories exist before anything is written to them.
os.makedirs(AUDIO_DIR, exist_ok=True)
os.makedirs(IMAGE_DIR, exist_ok=True)

In [None]:
# EXTRACT 2s AUDIO + FRAMES FROM VIDEO
def extract_audio(video_path, wav_out, sample_rate=SAMPLE_RATE):
    """Write the audio track of ``video_path`` to ``wav_out`` via ffmpeg.

    The output is forced to ``sample_rate`` Hz (``-ar``) and a single
    channel (``-ac 1``); an existing output file is overwritten (``-y``).
    ffmpeg's stdout and stderr are discarded.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    ffmpeg_args = ["ffmpeg", "-y", "-i", video_path]             # input video
    ffmpeg_args += ["-ar", str(sample_rate), "-ac", "1", wav_out]  # mono output at the target rate
    subprocess.run(
        ffmpeg_args,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

def extract_frame(video_path, image_out, timestamp):
    """Save the single video frame at ``timestamp`` seconds to ``image_out``.

    ``-ss`` is given *before* ``-i`` so that ffmpeg seeks in the input
    instead of decoding (and discarding) everything from the start of the
    video up to ``timestamp``. When this is called once per chunk, output-side
    seeking makes the total work grow quadratically with the video length.

    An existing output file is overwritten (``-y``) and ffmpeg's stdout and
    stderr are discarded.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    cmd = [
        "ffmpeg", "-y",
        "-ss", f"{timestamp:.3f}",  # input option: seek before decoding
        "-i", video_path,
        "-vframes", "1",            # emit exactly one frame
        image_out,
    ]
    subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)

def split_audio_and_frames(video_path, audio_dir=AUDIO_DIR, image_dir=IMAGE_DIR,
                           chunk_sec=CHUNK_SECONDS, sample_rate=SAMPLE_RATE):
    """Cut one video into fixed-length audio chunks with one frame per chunk.

    The whole audio track is first extracted to ``<audio_dir>/<base>.wav``,
    where ``<base>`` is the video file name without its extension. It is then
    split into consecutive chunks of ``chunk_sec`` seconds; trailing audio
    shorter than one chunk is dropped. For chunk ``i`` this writes
    ``<audio_dir>/<base>_seg<i>.wav`` and ``<image_dir>/<base>_frame<i>.jpg``,
    the frame being taken at the temporal midpoint of the chunk.
    """
    base, _ext = os.path.splitext(os.path.basename(video_path))

    # 1) Full-length audio track, written next to the segments.
    full_audio = os.path.join(audio_dir, base + ".wav")
    extract_audio(video_path, full_audio, sample_rate=sample_rate)

    # 2) Load it back and make sure it is at the requested rate.
    waveform, sr = torchaudio.load(full_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
        sr = sample_rate

    chunk_samples = int(chunk_sec * sr)            # samples per chunk
    n_chunks = waveform.shape[-1] // chunk_samples  # whole chunks only
    half_chunk = chunk_sec / 2.0                    # offset of the chunk midpoint, in seconds

    for i in range(n_chunks):
        # Audio segment i covers samples [offset, offset + chunk_samples).
        offset = i * chunk_samples
        seg_path = os.path.join(audio_dir, f"{base}_seg{i}.wav")
        torchaudio.save(seg_path, waveform[:, offset:offset + chunk_samples], sr)

        # Matching frame, grabbed at the middle of the same chunk.
        frame_path = os.path.join(image_dir, f"{base}_frame{i}.jpg")
        extract_frame(video_path, frame_path, (i * chunk_sec) + half_chunk)

    print(f"[INFO] {base}: wrote {n_chunks} segments")

def prepare_dataset_from_videos(raw_videos=RAW_VIDEOS):
    """Run ``split_audio_and_frames`` on every video found in ``raw_videos``.

    Files matching ``*.MOV``, ``*.mov`` or ``*.mp4`` are collected. The
    matches are de-duplicated, because where glob matching is case-insensitive
    the ``*.MOV`` and ``*.mov`` patterns return the same files and each video
    would otherwise be processed twice. They are then sorted so the processing
    order is deterministic.

    Prints the list of videos found, and a warning when there are none.
    """
    patterns = ("*.MOV", "*.mov", "*.mp4")
    found = set()
    for pattern in patterns:
        found.update(glob.glob(os.path.join(raw_videos, pattern)))
    video_files = sorted(found)

    print(video_files)  # Print the list of video files found
    if not video_files:
        print(f"[WARN] No videos in {raw_videos}")  # Warn if no videos are found
        return

    # Process each video file to extract audio and frames
    for vf in video_files:
        split_audio_and_frames(vf)

In [None]:
# ---- IMAGE TRANSFORM ----
# Frames are resized to 224x224, converted to grayscale replicated over three
# channels, turned into tensors and normalized per channel.
gray = transforms.Grayscale(num_output_channels=3)
transform_image = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        gray,
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ---- AUDIO CLEANING (band-pass + quiet-frame spectral subtraction) ----
# Parameters applied to each 2 s chunk.
SR = SAMPLE_RATE    # working sample rate for the cleaning functions
BAND_LO = 12000     # lower edge of the pass band (Hz)
BAND_HI = 18000     # upper edge of the pass band (Hz)
NFFT = 2048         # STFT size (also used as the window length)
HOP = 512           # STFT hop length in samples
OVERSUB = 1.2       # over-subtraction factor applied to the noise estimate
QUIET_PCT = 0.20    # fraction of quietest frames used for the noise estimate

def _stft(x: torch.Tensor) -> torch.Tensor:
    """Return the complex STFT of ``x`` using the module-level NFFT / HOP.

    A 2-D input has its leading dimension squeezed when that dimension has
    size 1. The Hann window is created on the same device as ``x``, matching
    what ``_istft`` does for the inverse transform, so the forward transform
    also works for tensors that do not live on the CPU.
    """
    if x.dim() == 2:
        x = x.squeeze(0)
    window = torch.hann_window(NFFT, device=x.device)
    return torch.stft(x, n_fft=NFFT, hop_length=HOP, win_length=NFFT,
                      window=window, return_complex=True, center=True)

def _istft(S: torch.Tensor, length: int) -> torch.Tensor:
    """
    Invert an STFT produced with the module-level NFFT / HOP settings.

    The complex tensor is passed to ``torch.istft`` directly first; if that
    raises ``TypeError`` or ``RuntimeError`` the call is retried with the
    ``torch.view_as_real`` representation of ``S``. The window is a float32
    Hann window created on ``S``'s device, and the output is requested with
    exactly ``length`` samples.
    """
    istft_kwargs = dict(
        n_fft=NFFT,
        hop_length=HOP,
        win_length=NFFT,
        window=torch.hann_window(NFFT, device=S.device, dtype=torch.float32),
        length=length,
        center=True,
    )
    try:
        return torch.istft(S, **istft_kwargs)
    except (TypeError, RuntimeError):
        # Fallback for builds that expect the real-valued (..., 2) layout.
        return torch.istft(torch.view_as_real(S), **istft_kwargs)

def bandpass_chunk(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """
    Band-pass in the STFT domain by zeroing every bin outside the pass band.

    Bins whose centre frequency f satisfies BAND_LO <= f < BAND_HI are kept
    (half-open range, because both edges are located with a left-sided
    ``searchsorted``); all other bins are set to zero before inversion.
    Version-agnostic; avoids torchaudio biquad/lfilter kernels.

    Input:  waveform [1, T] or [T] (converted to float32), mono
    Output: [1, T], DC-removed, with NaN/inf replaced by 0

    The analysis window is created on the waveform's device so the forward
    transform agrees with ``_istft`` and does not fail for non-CPU tensors.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    x = waveform.squeeze(0).to(torch.float32)          # [T]

    # STFT
    window = torch.hann_window(NFFT, dtype=torch.float32, device=x.device)
    S = torch.stft(x, n_fft=NFFT, hop_length=HOP, win_length=NFFT,
                   window=window, return_complex=True, center=True)   # [F, frames], complex

    # Pass-band bin range [lo, hi); searchsorted never returns a negative index.
    freqs = np.fft.rfftfreq(NFFT, d=1.0 / sr)
    lo = int(np.searchsorted(freqs, BAND_LO))
    hi = min(int(np.searchsorted(freqs, BAND_HI)), S.shape[0])

    # Copy only the in-band rows into an all-zero spectrum.
    S_bp = torch.zeros_like(S)
    S_bp[lo:hi, :] = S[lo:hi, :]

    # iSTFT back to time (compat wrapper)
    y = _istft(S_bp, length=x.numel())                 # [T]
    y = y - y.mean()
    y = torch.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
    return y.unsqueeze(0)

In [None]:
def spectral_subtract_quiet_frames(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Denoise a chunk by subtracting a noise spectrum taken from its quietest frames.

    Steps:
      1. STFT of the (squeezed) waveform; power = |STFT|^2.
      2. Mean in-band power per frame, over the bins between BAND_LO and BAND_HI.
      3. The QUIET_PCT fraction of frames with the lowest in-band power (at
         least one frame) is treated as noise-only; their mean power per
         frequency bin is the noise estimate.
      4. OVERSUB times that estimate is subtracted from every frame's power,
         clamped at zero, and the magnitude is recombined with the original phase.
      5. Inverse STFT to the original length, DC removal, NaN/inf replaced by 0.

    Returns the cleaned waveform with a leading dimension added ([1, T]).
    """
    signal = waveform.squeeze(0)
    n_samples = signal.shape[-1]

    # Time-frequency analysis
    spec = _stft(signal)          # complex STFT
    magnitude = spec.abs()
    power = magnitude ** 2

    # Bin range covering the band of interest
    bin_freqs = np.fft.rfftfreq(NFFT, d=1.0/sr)
    band_start = max(int(np.searchsorted(bin_freqs, BAND_LO)), 0)
    band_stop = min(int(np.searchsorted(bin_freqs, BAND_HI)), magnitude.shape[0])

    # In-band energy of each frame, used to rank frames from quiet to loud
    frame_energy = power[band_start:band_stop].mean(dim=0)
    n_quiet = max(1, int(round(QUIET_PCT * frame_energy.numel())))

    # Largest values of the negated energy == quietest frames
    quiet_idx = torch.topk(-frame_energy, n_quiet).indices
    is_quiet = torch.zeros_like(frame_energy, dtype=torch.bool)
    is_quiet[quiet_idx] = True

    # Per-bin noise power averaged over the quiet frames only
    noise_psd = power[:, is_quiet].mean(dim=1, keepdim=True)

    # Over-subtract the noise floor; the small epsilon keeps sqrt away from exactly 0
    cleaned_power = torch.clamp(power - OVERSUB * noise_psd, min=0.0)
    cleaned_mag = torch.sqrt(cleaned_power + 1e-12)

    # Re-attach the original phase and go back to the time domain
    cleaned_spec = cleaned_mag * torch.exp(1j * spec.angle())
    restored = _istft(cleaned_spec, length=n_samples)

    restored = restored - restored.mean()
    restored = torch.nan_to_num(restored, nan=0.0, posinf=0.0, neginf=0.0)
    return restored.unsqueeze(0)

def clean_chunk(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Full cleaning pipeline for one chunk: mono -> resample -> band-pass -> denoise.

    A 1-D waveform gets a leading channel dimension; a 2-D waveform with more
    than one channel is averaged down to one. Audio that is not already at the
    module-level ``SR`` is resampled to it. The result of
    ``spectral_subtract_quiet_frames(bandpass_chunk(...))`` is returned.
    """
    # Normalize the layout to a single channel, [1, T]
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.dim() == 2 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Bring the audio to the working sample rate
    if sr != SR:
        waveform = torchaudio.functional.resample(waveform, sr, SR)
        sr = SR

    # Band-pass first, then subtract the noise estimated from the quiet frames
    band_limited = bandpass_chunk(waveform, sr)
    return spectral_subtract_quiet_frames(band_limited, sr)

In [None]:
# DATASET
class AudioImageDataset(Dataset):
    """Paired (mel-spectrogram, video frame) samples built from segment files.

    ``audio_dir`` is scanned for ``<base>_seg<i>.wav`` files. A segment is kept
    when the matching ``<base>_frame<i>.jpg`` exists in ``image_dir`` and the
    text before the first underscore of the file name (lower-cased) is a key
    of ``label_map``.

    Each item is ``(spec, img, label, seg_id)``:
      * ``spec``: normalized dB mel-spectrogram resized to 224x224, shape [1, 224, 224]
      * ``img``: the frame as RGB, passed through ``transform_image`` when given
      * ``label``: integer from ``label_map``
      * ``seg_id``: file name of the audio segment
    """

    def __init__(self, audio_dir, image_dir, label_map, transform_image=None,
                 sample_rate=SAMPLE_RATE, n_mels=N_MELS, use_filters=True):
        # Initialize dataset parameters
        self.audio_dir = audio_dir  # Directory containing audio files
        self.image_dir = image_dir  # Directory containing image files
        self.label_map = label_map  # Mapping of label strings to integers
        self.transform_image = transform_image  # Image transformation pipeline
        self.sample_rate = sample_rate  # Sample rate the mel transform is built for
        self.n_mels = n_mels  # Number of mel bins for spectrogram
        self.use_filters = use_filters  # Whether to apply audio cleaning filters

        # Parallel lists describing every usable segment
        self.audio_files, self.labels, self.video_ids = [], [], []

        # os.listdir order is arbitrary; sorting keeps sample indices (and any
        # seeded split built on them) reproducible across runs and machines.
        for file in sorted(os.listdir(audio_dir)):
            if not (file.endswith(".wav") and "_seg" in file):
                continue  # Not a segment file (e.g. the full-length wav)
            frame_file = self._frame_name_for(file)
            if not os.path.exists(os.path.join(image_dir, frame_file)):
                continue  # Skip if corresponding image file does not exist
            base_name = file.rpartition("_seg")[0]  # Everything before the segment suffix
            label_str = base_name.split("_")[0].lower()  # Label is the first "_"-separated token
            if label_str in label_map:
                self.audio_files.append(file)
                self.labels.append(label_map[label_str])
                self.video_ids.append(file)  # Identifier returned with each sample

        # Define audio transformations: Mel-spectrogram and amplitude-to-decibel conversion
        self.mel = T.MelSpectrogram(sample_rate=self.sample_rate, n_fft=1024,
                                    hop_length=512, n_mels=self.n_mels)
        self.db  = T.AmplitudeToDB()  # Convert amplitude to decibels

    @staticmethod
    def _frame_name_for(audio_file):
        """Map ``<base>_seg<i>.wav`` to ``<base>_frame<i>.jpg``.

        Only the final ``_seg`` marker and the ``.wav`` suffix are rewritten,
        so a base name that itself contains "seg" is left untouched.
        """
        stem = audio_file[:-len(".wav")]
        base, _, index = stem.rpartition("_seg")
        return f"{base}_frame{index}.jpg"

    def __len__(self):
        # Return the total number of samples in the dataset
        return len(self.audio_files)

    def __getitem__(self, idx):
        # Retrieve the audio file, label, and segment ID for the given index
        audio_file = self.audio_files[idx]  # Audio file name
        label      = self.labels[idx]  # Corresponding label
        seg_id     = self.video_ids[idx]  # Segment ID

        # ---- AUDIO ----
        a_path = os.path.join(self.audio_dir, audio_file)  # Full path to audio file
        waveform, sr = torchaudio.load(a_path)  # Load audio waveform and sample rate
        if self.use_filters:
            waveform = clean_chunk(waveform, sr)  # Apply audio cleaning filters
            sr = SR  # clean_chunk returns audio at the module-level SR
        # The mel transform was built for self.sample_rate, so feed it that rate
        # whichever path produced the waveform.
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        # Convert waveform to mel-spectrogram
        spec = self.mel(waveform)  # [channels, n_mels, time]
        spec = self.db(spec)  # Convert to decibel scale
        spec = (spec - spec.mean()) / (spec.std() + 1e-6)  # Normalize spectrogram
        spec = F.interpolate(spec.unsqueeze(0), size=(224,224), mode="bilinear", align_corners=False)  # Resize to 224x224
        spec = spec.mean(dim=1)  # Average over channels -> [1,224,224]

        # ---- IMAGE ----
        frame_file = self._frame_name_for(audio_file)  # Corresponding image file name
        img_path = os.path.join(self.image_dir, frame_file)  # Full path to image file
        img = Image.open(img_path).convert("RGB")  # Load image and convert to RGB
        if self.transform_image:
            img = self.transform_image(img)  # Apply image transformations

        # Return the processed spectrogram, image, label, and segment ID
        return spec, img, label, seg_id