<a href="https://colab.research.google.com/github/chisomrutherford/HeAR_asthma_classification/blob/main/Generate_Audio_Embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; later cells read the dataset from and
# copy the generated embeddings to /content/drive/MyDrive.

from google.colab import drive
drive.mount('/content/drive')

In [None]:
#Download SPRS repo to Google Drive
%cd /content/drive/MyDrive/
!git clone https://github.com/SJTU-YONGFU-RESEARCH-GRP/SPRSound.git

In [None]:
# import dependencies

#from transformers import AutoProcessor, TFAutoModel
import tensorflow as tf
import numpy as np
import os
import librosa
import soundfile as sf
import json
import pandas as pd

In [None]:
# Authenticate with the Hugging Face Hub before the model download in the next cell.
# NOTE(review): new_session=False presumably reuses an already-saved token
# instead of prompting again -- confirm against the huggingface_hub docs.
from huggingface_hub import login
login(new_session=False)

In [None]:
# Load Google's HeAR model for generating embeddings
# NOTE: the embedding-generation cell further down loads the same checkpoint
# again and rebinds `model`.

from huggingface_hub import from_pretrained_keras

model = from_pretrained_keras("google/hear")

In [None]:
# Locations of the SPRSound classification data (.wav + .json) on Google Drive
# and of the local folders that will receive the extracted clips.
# NOTE: a later cell re-declares these paths in lowercase and points the
# validation split at the `2022` sub-folders; the uppercase constants are not
# referenced again in the cells visible here.

_SPRS_CLASSIFICATION_ROOT = "/content/drive/MyDrive/SPRSound/Classification"
_ASTHMA_CLIPS_ROOT = "/content/asthma_clips"

TRAIN_WAV_DIR = f"{_SPRS_CLASSIFICATION_ROOT}/train_classification_wav"
TRAIN_JSON_DIR = f"{_SPRS_CLASSIFICATION_ROOT}/train_classification_json"

VALID_WAV_DIR = f"{_SPRS_CLASSIFICATION_ROOT}/valid_classification_wav"
VALID_JSON_DIR = f"{_SPRS_CLASSIFICATION_ROOT}/valid_classification_json"

TRAIN_OUT_DIR = f"{_ASTHMA_CLIPS_ROOT}/train"
VALID_OUT_DIR = f"{_ASTHMA_CLIPS_ROOT}/valid"

# Create both clip folders up front (no-op when they already exist)
for _clip_dir in (TRAIN_OUT_DIR, VALID_OUT_DIR):
    os.makedirs(_clip_dir, exist_ok=True)


In [None]:
included_labels = ['Normal', 'Wheeze', 'Wheeze+Crackle', 'Rhonchi', 'Stridor']

def process_record(wav_path, json_path, output_dir, sr=16000, clip_len=None):
    """
    Cuts one respiratory sound recording into fixed-length clips, one per annotated event.

    Parameters:
    ----------
    wav_path : str
        Path to the input `.wav` file (original full-length respiratory sound).

    json_path : str
        Path to the corresponding `.json` file containing event-level annotations.

    output_dir : str
        Directory where the extracted and standardized audio clips will be saved.
        It must already exist.

    sr : int, optional
        Sample rate (Hz) the recording is resampled to and the clips are written at.
        Defaults to 16000.

    clip_len : int or None, optional
        Length of every saved clip in samples. Defaults to `2 * sr` (2 seconds).

    Returns:
    -------
    List[Tuple[str, int, str]]
        A list of tuples, each containing:
            - the saved clip filename,
            - the binary label (0 = Normal, 1 = Abnormal),
            - the original event label (e.g., "Wheeze", "Rhonchi").

    Description:
    -----------
    - The annotation file is read first; a recording whose `recording_annotation`
      is "Poor Quality" is skipped entirely (before any audio is decoded).
    - Only events whose `type` is listed in `included_labels` are kept.
    - For each kept event, the audio segment is:
        1. Extracted based on start and end timestamps (in milliseconds),
           clamped to the bounds of the recording,
        2. Taken from the recording resampled to `sr` Hz mono,
        3. Standardized to `clip_len` samples via truncation or zero-padding,
        4. Saved as a new `.wav` clip in the output directory.
    - Events that cover no samples (end <= start, or entirely outside the
      recording) are skipped instead of being saved as an all-zero clip.
    - The filename format is: `<original_id>_<event_index>_<binary_label>.wav`,
      where `<event_index>` is the position of the event in the annotation list.
    """
    # `sr` / `clip_len` used to be read from globals that were never defined,
    # so every call raised NameError; they are explicit parameters now.
    if clip_len is None:
        clip_len = 2 * sr

    with open(json_path) as f:
        meta = json.load(f)

    # Skip if recording is marked as "Poor Quality"
    if meta.get("recording_annotation") == "Poor Quality":
        return []

    events = meta.get("event_annotation", [])
    if not events:
        return []

    # Decode the audio only once we know there is something to extract
    y, _ = librosa.load(wav_path, sr=sr, mono=True)

    base = os.path.splitext(os.path.basename(wav_path))[0]
    saved = []

    for idx, event in enumerate(events):
        label = event.get('type')
        if label not in included_labels:
            continue

        # Convert start/end from ms to samples, clamped to the recording.
        # (A negative start would otherwise index from the end of the array.)
        start_sample = max(0, int((float(event['start']) / 1000.0) * sr))
        end_sample = min(len(y), int((float(event['end']) / 1000.0) * sr))
        if end_sample <= start_sample:
            continue

        # fix_length zero-pads short segments and truncates long ones
        segment = librosa.util.fix_length(y[start_sample:end_sample], size=clip_len)

        # Binary classification: 0 = Normal, 1 = Abnormal
        binary_label = 0 if label == "Normal" else 1

        # Save the clip
        outname = f"{base}_{idx}_{binary_label}.wav"
        outpath = os.path.join(output_dir, outname)
        sf.write(outpath, segment, sr)

        saved.append((outname, binary_label, label))

    return saved

print('process_record defined')


In [None]:

def batch_process(wav_dir, json_dir, out_dir, metadata_path):
    """
    Runs `process_record()` over every annotated recording in a flat directory.

    Each `.wav` file directly inside `wav_dir` is paired with the `.json` file of
    the same base name in `json_dir`; recordings without an annotation file are
    ignored. The clips are written to `out_dir` by `process_record()`, and one
    metadata row per clip is collected and written to `metadata_path` as CSV.

    Note: a later cell defines `batch_process` again (with a progress bar); when
    the cells run in order, that later definition replaces this one.

    Parameters:
    -----------
    wav_dir : str
        Directory holding the .wav recordings.

    json_dir : str
        Directory holding the .json annotations (same base names as the .wav files).

    out_dir : str
        Destination directory for the extracted clips.

    metadata_path : str
        CSV file that receives the metadata of all saved clips.

    Returns:
    --------
    pd.DataFrame
        One row per saved clip, with columns
        ["filename", "binary_label", "original_label"].
    """
    rows = []

    for name in os.listdir(wav_dir):
        # Anything that is not a .wav recording is irrelevant here
        if not name.endswith(".wav"):
            continue

        annotation = os.path.join(json_dir, name.replace(".wav", ".json"))
        # Recordings without an annotation file cannot be labelled
        if not os.path.exists(annotation):
            continue

        recording = os.path.join(wav_dir, name)
        rows += process_record(recording, annotation, out_dir)

    # Materialise the collected rows and persist them next to the clips
    df = pd.DataFrame(rows, columns=["filename", "binary_label", "original_label"])
    df.to_csv(metadata_path, index=False)

    print(f"Done. Saved {len(df)} clips to {out_dir}")
    return df


In [None]:
# Input / output locations consumed by the processing cells further down
_SPRS_ROOT = "/content/drive/MyDrive/SPRSound/Classification"
_CLIPS_ROOT = "/content/asthma_clips"

# Training split: flat folders of recordings and annotations
train_wav_dir = f"{_SPRS_ROOT}/train_classification_wav"
train_json_dir = f"{_SPRS_ROOT}/train_classification_json"
train_out_dir = f"{_CLIPS_ROOT}/train"
train_metadata_csv = f"{_CLIPS_ROOT}/train_metadata.csv"

# Validation split: the 2022 recordings and their intra-test annotations
valid_wav_dir = f"{_SPRS_ROOT}/valid_classification_wav/2022"
valid_json_dir = f"{_SPRS_ROOT}/valid_classification_json/2022/intra_test_json"
valid_out_dir = f"{_CLIPS_ROOT}/valid"
valid_metadata_csv = f"{_CLIPS_ROOT}/valid_metadata.csv"

# Make sure output folders exist
for _out_dir in (train_out_dir, valid_out_dir):
    os.makedirs(_out_dir, exist_ok=True)

# The actual processing calls live in their own cells below.


In [None]:
from tqdm import tqdm

def batch_process(wav_dir, json_dir, out_dir, metadata_path):
    """
    Batch processes respiratory sound recordings and their annotations.

    For each `.wav` file directly inside `wav_dir` (non-recursive), the function:
      - Locates the `.json` annotation with the same base name in `json_dir`
      - Uses `process_record()` to extract labeled audio segments
      - Saves each segment to `out_dir` (created if it does not exist)
      - Records metadata (filename, binary label, and original label)

    Files are visited in sorted order so the metadata CSV is reproducible across
    runs (`os.listdir` gives no ordering guarantee). A progress bar is displayed
    during processing, and a metadata CSV is written to `metadata_path`.

    Parameters:
    -----------
    wav_dir : str
        Directory containing .wav audio files.

    json_dir : str
        Directory containing corresponding .json annotation files.

    out_dir : str
        Directory to save the processed audio clips.

    metadata_path : str
        File path to save the resulting metadata CSV.

    Returns:
    --------
    pd.DataFrame
        DataFrame containing metadata with columns:
        ["filename", "binary_label", "original_label"]
    """
    all_metadata = []  # List to collect metadata for all audio segments

    # process_record() writes into out_dir and expects it to exist
    os.makedirs(out_dir, exist_ok=True)

    # Get list of all .wav files (non-recursive), in a deterministic order
    wav_files = sorted(f for f in os.listdir(wav_dir) if f.endswith(".wav"))

    # Loop through each .wav file with a progress bar
    for f in tqdm(wav_files, desc=f"Processing {os.path.basename(out_dir)}", unit="file"):
        wav_path = os.path.join(wav_dir, f)
        # Swap only the extension: str.replace(".wav", ".json") would also
        # rewrite a ".wav" that appears in the middle of the file name.
        json_path = os.path.join(json_dir, os.path.splitext(f)[0] + ".json")

        # Proceed only if the corresponding .json file exists
        if os.path.exists(json_path):
            # Extract segments using process_record
            segments = process_record(wav_path, json_path, out_dir)
            all_metadata.extend(segments)

    # Create a DataFrame from collected metadata
    df = pd.DataFrame(all_metadata, columns=["filename", "binary_label", "original_label"])

    # Save metadata to CSV
    df.to_csv(metadata_path, index=False)

    print(f"Done. Saved {len(df)} clips to {out_dir}")
    return df


In [None]:


def batch_process_recursive(wav_dir, json_dir, out_dir, metadata_path):
    """
    Recursively processes respiratory audio files and corresponding annotation files,
    extracting labeled clips using `process_record()` and saving them with metadata.

    This function:
      - Walks through all subdirectories in `wav_dir` to find `.wav` files.
      - For each `.wav`, computes the expected relative path to its `.json` annotation in `json_dir`
        (same relative folder, same base name, `.json` extension).
      - Uses `process_record()` to extract and save labeled audio segments into `out_dir`
        (created if it does not exist).
      - Records metadata (filename, binary label, original label) for each saved clip.
      - Saves the metadata to a CSV at `metadata_path`.

    Recordings are visited in sorted path order so the metadata CSV is
    reproducible across runs.

    Parameters:
    -----------
    wav_dir : str
        Root directory containing .wav audio files (can include nested folders).

    json_dir : str
        Root directory containing corresponding .json annotation files (same folder structure as wav_dir).

    out_dir : str
        Directory to save the extracted audio clips. All clips land in this one
        folder regardless of how deeply the source recording was nested.

    metadata_path : str
        Path to save the resulting metadata CSV file.

    Returns:
    --------
    pd.DataFrame
        A DataFrame containing metadata for all saved audio segments, with columns:
        ["filename", "binary_label", "original_label"]
    """

    all_metadata = []  # To collect metadata from all processed segments
    all_wav_paths = []  # To collect all .wav files recursively

    # process_record() writes into out_dir and expects it to exist
    os.makedirs(out_dir, exist_ok=True)

    # Step 1: Walk through wav_dir recursively and collect all .wav file paths
    for root, _, files in os.walk(wav_dir):
        for f in files:
            if f.endswith(".wav"):
                all_wav_paths.append(os.path.join(root, f))
    all_wav_paths.sort()  # os.walk order is filesystem-dependent

    # Step 2: Process each wav file using a progress bar
    for wav_path in tqdm(all_wav_paths, desc=f"Processing {os.path.basename(out_dir)}", unit="file"):
        # Calculate the relative path from wav_dir to the current .wav and swap
        # only its extension. str.replace(".wav", ".json") would also rewrite
        # ".wav" inside folder names or in the middle of the file name.
        rel_stem = os.path.splitext(os.path.relpath(wav_path, wav_dir))[0]
        json_path = os.path.join(json_dir, rel_stem + ".json")

        # Step 3: Process the file if the corresponding JSON exists
        if os.path.exists(json_path):
            segments = process_record(wav_path, json_path, out_dir)
            all_metadata.extend(segments)

    # Step 4: Save all metadata to CSV
    df = pd.DataFrame(all_metadata, columns=["filename", "binary_label", "original_label"])
    df.to_csv(metadata_path, index=False)

    print(f"Done. Saved {len(df)} clips to {out_dir}")
    return df


In [None]:
# Extract the training clips and their metadata (flat folder layout)
train_df = batch_process(train_wav_dir, train_json_dir, train_out_dir, train_metadata_csv)

In [None]:
# Extract the validation clips with the recursive variant.
# The paths come from the directory-configuration cell above instead of being
# repeated as literals here, so the two places cannot drift apart.
valid_df = batch_process_recursive(
    wav_dir=valid_wav_dir,
    json_dir=valid_json_dir,
    out_dir=valid_out_dir,
    metadata_path=valid_metadata_csv
)


In [None]:
# Reload the metadata CSVs written by the processing cells and preview each split
valid_df = pd.read_csv("/content/asthma_clips/valid_metadata.csv")
print("Loaded valid_metadata CSV")
print(valid_df.head())

train_df = pd.read_csv("/content/asthma_clips/train_metadata.csv")
print("Loaded train_metadata CSV")
# Previously printed valid_df.head(), so the "train" preview showed validation rows
print(train_df.head())


In [None]:
# Number of validation clips per original event label
valid_df['original_label'].value_counts()

In [None]:
import librosa

def extract_hear_embedding(wav_path, model, sr=16000, clip_len=None):
    """
    Extracts the embedding of a single .wav file with Google's HeAR
    (Health Acoustic Representations) model.

    This function:
      - Loads a mono waveform from `wav_path`, resampled to `sr`
      - Trims or zero-pads the waveform to exactly `clip_len` samples
        (2 seconds at 16 kHz = 32000 samples by default)
      - Feeds it into the model via its 'serving_default' signature, using the
        same calling convention as `batched_generate()` in this notebook
        (keyword input `x`, output key 'output_0')
      - Returns the resulting audio embedding as a NumPy array

    Parameters:
    -----------
    wav_path : str
        Path to the audio file to process (.wav format).

    model : object
        The HeAR model as loaded in this notebook with
        `from_pretrained_keras("google/hear")`; it must expose
        `model.signatures['serving_default']`.

    sr : int, optional
        Target sample rate in Hz. Defaults to 16000.

    clip_len : int or None, optional
        Number of samples fed to the model. Defaults to `2 * sr`.

    Returns:
    --------
    np.ndarray
        The embedding with the batch axis squeezed out.
        NOTE(review): expected shape is (512,) for HeAR -- confirm against the model card.
    """
    # These used to be read from the globals SR / CLIP_LEN, which are only
    # defined in a later cell; explicit defaults keep the function usable on
    # its own.
    if clip_len is None:
        clip_len = 2 * sr

    # Load audio file and resample to the target rate
    y, _ = librosa.load(wav_path, sr=sr, mono=True)

    # Ensure the clip is exactly clip_len samples long
    y = y[:clip_len]
    if len(y) < clip_len:
        y = np.pad(y, (0, clip_len - len(y)))

    # Prepare as a batch of size 1 (shape: [1, clip_len])
    batch = np.expand_dims(y, 0).astype(np.float32)

    # Same signature usage as batched_generate(): keyword input `x`, result
    # under 'output_0' (the previous 'audio_embedding' key did not match it).
    outputs = model.signatures['serving_default'](x=tf.constant(batch))

    # Convert to NumPy and drop the batch axis
    return outputs['output_0'].numpy().squeeze()


In [None]:
# Generate embeddings: configuration and model handle used by the helpers below

from tqdm.notebook import tqdm

SR = 16000         # sample rate (Hz) every clip is resampled to on load
CLIP_LEN = 2 * SR  # samples per clip (2 seconds at SR)
BATCH_SIZE = 8  # tune based on GPU memory

# Re-load the model (this rebinds `model` from the earlier loading cell)
# NOTE(review): the original comment says this is done "under GPU" -- presumably
# after switching the Colab runtime; confirm whether the reload is still needed.
from huggingface_hub import from_pretrained_keras
model = from_pretrained_keras("google/hear")
# `infer` is the serving signature called by batched_generate()
infer = model.signatures['serving_default']

def load_clip(path):
    """
    Reads one audio file and returns it as a fixed-length waveform.

    The file is loaded with librosa at the module-level sample rate `SR`, cut
    to at most `CLIP_LEN` samples, and right-padded with zeros when it is
    shorter than that.

    Args:
        path (str): Path to the .wav audio file.

    Returns:
        np.ndarray: A 1D NumPy array of length CLIP_LEN representing the audio signal.
    """
    samples, _ = librosa.load(path, sr=SR)

    # Keep only the leading CLIP_LEN samples
    samples = samples[:CLIP_LEN]

    # Whatever is still missing is filled with trailing zeros
    shortfall = CLIP_LEN - len(samples)
    if shortfall > 0:
        samples = np.pad(samples, (0, shortfall))
    return samples


def batched_generate(clip_dir, metadata_csv, out_prefix):
    """
    Embeds every clip listed in a metadata CSV and stores the results on disk.

    The clips are read from `clip_dir` in the order they appear in the CSV,
    grouped into batches of `BATCH_SIZE`, loaded with `load_clip()` and passed
    to the module-level `infer` signature. The 'output_0' entry of each result
    is kept as the batch of embeddings.

    Args:
        clip_dir (str): Directory containing the preprocessed .wav clips.
        metadata_csv (str): Path to a CSV file with columns ['filename', 'binary_label'].
        out_prefix (str): Prefix for output files (without extension).

    Outputs:
        - {out_prefix}_embeddings.npy: all embeddings stacked row-wise
          (one row per clip)
        - {out_prefix}_labels.npy: the binary labels, in the same order
        - {out_prefix}_filenames.txt: Text file with one filename per line
    """
    meta = pd.read_csv(metadata_csv)
    filenames = meta.filename.values
    targets = meta.binary_label.values
    clip_paths = [os.path.join(clip_dir, name) for name in filenames]

    emb_chunks, label_chunks, seen_names = [], [], []

    # Walk over the clips one batch-sized window at a time
    for start in tqdm(range(0, len(clip_paths), BATCH_SIZE), desc=out_prefix):
        window = slice(start, start + BATCH_SIZE)

        # Stack the fixed-length waveforms of this window into one float32 array
        waveforms = np.stack([load_clip(p) for p in clip_paths[window]]).astype(np.float32)

        # Run the model signature on the whole window at once
        result = infer(x=tf.constant(waveforms))

        emb_chunks.append(result['output_0'].numpy())
        label_chunks.append(targets[window])
        seen_names.extend(filenames[window])

    # Merge the per-batch pieces
    X = np.vstack(emb_chunks)
    y = np.concatenate(label_chunks)

    # Write the three artefacts next to each other
    np.save(f"{out_prefix}_embeddings.npy", X)
    np.save(f"{out_prefix}_labels.npy", y)
    with open(f"{out_prefix}_filenames.txt", "w") as handle:
        for name in seen_names:
            handle.write(name + "\n")

    print(f"Saved {X.shape[0]} embeddings as {out_prefix}_embeddings.npy")


# Generate embeddings for both splits (train first, then valid)
for split in ("train", "valid"):
    batched_generate(
        f"/content/asthma_clips/{split}",
        f"/content/asthma_clips/{split}_metadata.csv",
        f"/content/{split}",
    )


In [None]:
# Save embeddings to Google Drive

!mkdir -p "/content/drive/MyDrive/HeAR_Embeddings"

# Train files
!cp /content/train_embeddings.npy /content/drive/MyDrive/HeAR_Embeddings/
!cp /content/train_labels.npy /content/drive/MyDrive/HeAR_Embeddings/
!cp /content/train_filenames.txt /content/drive/MyDrive/HeAR_Embeddings/

# Validation files
!cp /content/valid_embeddings.npy /content/drive/MyDrive/HeAR_Embeddings/
!cp /content/valid_labels.npy /content/drive/MyDrive/HeAR_Embeddings/
!cp /content/valid_filenames.txt /content/drive/MyDrive/HeAR_Embeddings/
