# Audio Embedding Extraction

This notebook extracts embeddings from speech/audio data for Alzheimer's detection.

**Supported audio sources:**
- ADReSS / ADReSSo challenge datasets
- DementiaBank Pitt Corpus
- Any audio files (`.wav`, `.mp3`, `.flac`, `.ogg`) organized into one folder per class

**Pipeline:** Audio file → 16 kHz mono, cropped/padded to a fixed length → Log-Mel Spectrogram → ResNet-18 → 256-D Embedding → Compressed .npz

---

**Instructions:**
1. Select GPU runtime: Runtime → Change runtime type → T4 GPU
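2. Upload or download your audio dataset into `AUDIO_DATA_DIR` (default `/tmp/audio_data`, set in the configuration cell). The last cell removes this directory to free disk space, so treat it as a scratch copy and keep the original elsewhere.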
3. Run all cells
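4. Embeddings (`audio_embeddings.npz`) and labels (`audio_labels.csv`) are saved to `MyDrive/alzheimer-research/data_embeddings/` on Google Drive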

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
PROJECT_DIR = '/content/drive/MyDrive/alzheimer-research'
os.makedirs(PROJECT_DIR, exist_ok=True)

REPO_DIR = '/content/alzheimer-research'
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/YOUR_USERNAME/alzheimer-research.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

os.chdir(REPO_DIR)
!pip install -q -r requirements.txt

In [None]:
import sys
from pathlib import Path

# Put the cloned repo first on the import path so `models.*` resolves to the project code.
sys.path.insert(0, REPO_DIR)

import numpy as np
import torch
import torchaudio
from tqdm import tqdm

# Use the GPU when the runtime provides one; the rest of the notebook also runs on CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f'Device: {device}')
if has_cuda:
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## Configure Data Source

Organize your audio files into class folders:
```
audio_data/
  NonDemented/
    file1.wav
    file2.wav
  VeryMildDemented/
    ...
  MildDemented/
    ...
  ModerateDemented/
    ...
```

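If your dataset uses different class names (e.g., `cc` and `cd` for ADReSS),
update the `LABEL_MAP` below. Its values are ordinal class indices, and only the
folders listed in it are processed; any other folders are ignored.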

In [None]:
# Root directory containing one sub-folder per class (see the layout above).
AUDIO_DATA_DIR = '/tmp/audio_data'  # Update this path

# Label mapping: folder name -> ordinal class index
# Default: 4-class CDR staging
# Only the folders listed here are processed; other folders under AUDIO_DATA_DIR are ignored.
LABEL_MAP = {
    'NonDemented': 0,
    'VeryMildDemented': 1,
    'MildDemented': 2,
    'ModerateDemented': 3,
}

# Alternative for ADReSS (binary classification → adapt for ordinal):
# LABEL_MAP = {
#     'cc': 0,  # Control (non-demented)
#     'cd': 2,  # Dementia (map to MildDemented)
# }

# Audio parameters
TARGET_SAMPLE_RATE = 16000  # Hz; every file is resampled to this rate

# Each clip is cropped (or zero-padded) to this many seconds before embedding.
# Speech recordings are often longer than this, in which case only the opening
# segment contributes to the embedding; raise this value if later speech matters.
MAX_AUDIO_SECONDS = 10
MAX_AUDIO_LENGTH = TARGET_SAMPLE_RATE * MAX_AUDIO_SECONDS  # in samples

EMBED_DIM = 256

In [None]:
from models.audio_cnn import AudioCNN

# The spectrogram settings below are passed straight to AudioCNN. With
# from_spectrogram=False the model is fed raw waveforms (the extraction loop
# passes (1, 1, T) tensors), so the mel transform presumably happens inside
# the model — confirm in models/audio_cnn.py.
model = AudioCNN(
    sample_rate=TARGET_SAMPLE_RATE,
    n_fft=1024,
    hop_length=512,
    n_mels=128,
    from_spectrogram=False,
    embed_dim=EMBED_DIM,
    pretrained=True,  # backbone initialisation is defined by AudioCNN; verify what is loaded
).to(device)

# Inference only: switch off training-mode behaviour before extracting embeddings.
model.eval()
print(f'AudioCNN loaded (embed_dim={EMBED_DIM})')

In [None]:
def load_and_preprocess_audio(filepath, target_sr=16000, max_length=None):
    """Read an audio file and return it as a mono waveform tensor.

    Args:
        filepath: Path to any file that ``torchaudio.load`` can decode.
        target_sr: Sample rate (Hz) to resample to when the file's native
            rate differs.
        max_length: Optional target length in samples. Longer signals are
            cropped to their first ``max_length`` samples; shorter ones are
            right-padded with zeros. ``None`` leaves the length unchanged.

    Returns:
        torch.Tensor of shape (1, T): multi-channel audio is averaged into a
        single channel, and T == max_length whenever max_length is given.
    """
    signal, native_sr = torchaudio.load(filepath)

    # Downmix: average all channels into one, keeping the channel axis.
    if signal.size(0) > 1:
        signal = signal.mean(dim=0, keepdim=True)

    if native_sr != target_sr:
        signal = torchaudio.transforms.Resample(native_sr, target_sr)(signal)

    if max_length is None:
        return signal

    n_samples = signal.size(1)
    if n_samples > max_length:
        return signal[:, :max_length]
    if n_samples < max_length:
        # Zero-pad on the right (end of the clip) only.
        return torch.nn.functional.pad(signal, (0, max_length - n_samples))
    return signal

In [None]:
# Extract one embedding per audio file by walking the class folders listed in LABEL_MAP.
# Each file is loaded as a (1, T) mono waveform at TARGET_SAMPLE_RATE, cropped/padded to
# MAX_AUDIO_LENGTH samples, and passed to the model with a batch dimension of 1.
# Files that fail to load or embed are reported and skipped (best-effort), and a
# failure count is printed at the end so skipped files are not missed in the log.
AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac', '.ogg')

all_embeddings = []
all_labels = []
all_filenames = []
failed_files = []

for class_name, label in LABEL_MAP.items():
    class_dir = os.path.join(AUDIO_DATA_DIR, class_name)
    if not os.path.isdir(class_dir):
        print(f'Skipping {class_name}: directory not found')
        continue

    # Case-insensitive extension match so files such as "S001.WAV" are not silently skipped.
    audio_files = [f for f in sorted(os.listdir(class_dir))
                   if f.lower().endswith(AUDIO_EXTENSIONS)]
    print(f'Processing {class_name} ({len(audio_files)} files)...')

    for fname in tqdm(audio_files, desc=class_name):
        fpath = os.path.join(class_dir, fname)
        try:
            waveform = load_and_preprocess_audio(
                fpath, TARGET_SAMPLE_RATE, MAX_AUDIO_LENGTH
            )
            # Add batch dim: (1, T) -> (1, 1, T)
            waveform = waveform.unsqueeze(0).to(device)

            with torch.no_grad():
                emb = model.extract_embedding(waveform)

            all_embeddings.append(emb.cpu().numpy())
            all_labels.append(label)
            all_filenames.append(fname)
        except Exception as e:
            # Best-effort: one unreadable/corrupt file should not abort the whole run.
            failed_files.append(fpath)
            print(f'Error processing {fname}: {e}')

if failed_files:
    print(f'{len(failed_files)} file(s) failed and were skipped.')

if all_embeddings:
    # Assumes extract_embedding returns a (1, D) array per file — TODO confirm in
    # models/audio_cnn.py. The assert catches any other shape before rows misalign.
    embeddings = np.concatenate(all_embeddings, axis=0)
    labels = np.array(all_labels)
    assert embeddings.ndim == 2 and len(embeddings) == len(labels) == len(all_filenames), (
        f'unexpected embedding array shape {embeddings.shape} for {len(labels)} files'
    )
    print(f'\nExtracted {len(embeddings)} audio embeddings of dimension {embeddings.shape[1]}')
else:
    print('No audio files processed. Check your data directory and LABEL_MAP.')

In [None]:
# Save compressed embeddings and their labels to Google Drive.
# Row i of `embeddings` in the .npz corresponds to row i of the labels CSV.
import pandas as pd

SAVE_DIR = os.path.join(PROJECT_DIR, 'data_embeddings')
os.makedirs(SAVE_DIR, exist_ok=True)

if all_embeddings:
    # float16 halves the file size at the cost of precision (about 3 significant digits).
    emb_path = os.path.join(SAVE_DIR, 'audio_embeddings.npz')
    np.savez_compressed(emb_path, embeddings=embeddings.astype(np.float16))
    print(f'Saved audio embeddings to {emb_path}')
    print(f'File size: {os.path.getsize(emb_path) / 1024 / 1024:.2f} MB')

    # Map each label back to its folder name by inverting LABEL_MAP. Indexing
    # list(LABEL_MAP.keys()) by label value only works when labels are exactly
    # 0..K-1 in dict order; a non-contiguous map such as ADReSS {'cc': 0, 'cd': 2}
    # would raise IndexError. If several folders share one label, the last folder
    # listed in LABEL_MAP supplies the class_name.
    label_to_class = {value: name for name, value in LABEL_MAP.items()}

    # Overwrites any existing labels file, matching the embeddings file above.
    labels_path = os.path.join(SAVE_DIR, 'audio_labels.csv')
    df = pd.DataFrame({
        'filename': all_filenames,
        'label': labels,
        'class_name': [label_to_class[int(l)] for l in labels],
    })
    assert len(df) == len(embeddings), 'labels CSV and embeddings are misaligned'
    df.to_csv(labels_path, index=False)
    print(f'Saved labels to {labels_path}')

In [None]:
# Cleanup: delete the local scratch copy of the raw audio to free VM disk space.
# Safety guards:
#   * never delete a directory that is on, or contains, the mounted Google Drive,
#     since that would permanently remove source data from Drive;
#   * keep the data when no embeddings were extracted, so the run can be fixed and repeated.
import shutil

audio_dir_real = os.path.realpath(AUDIO_DATA_DIR)
drive_root_real = os.path.realpath('/content/drive')
# commonpath equals one of the two paths exactly when one contains the other.
overlaps_drive = os.path.commonpath([audio_dir_real, drive_root_real]) in (
    audio_dir_real, drive_root_real)

if os.path.exists(AUDIO_DATA_DIR):
    if overlaps_drive:
        print(f'Not deleting {AUDIO_DATA_DIR}: it overlaps the mounted Google Drive.')
    elif not all_embeddings:
        print(f'Not deleting {AUDIO_DATA_DIR}: no embeddings were extracted.')
    else:
        shutil.rmtree(AUDIO_DATA_DIR)
        print(f'Cleaned up {AUDIO_DATA_DIR}')

if all_embeddings:
    print('Done! Audio embeddings saved to Google Drive.')
else:
    print('Done, but no audio embeddings were saved.')