<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/NEURAL_ARCHITECTURE_SEARCH_WITH_MULTIMODAL_FUSION_METHODS4DIAGNOSING_DEMENTIA2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CELL 1: Installation and Imports

In [1]:
# Mount Google Drive and install packages
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install transformers torch torchaudio librosa speechrecognition pydub scikit-learn

import os
import tarfile
import glob
import librosa
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import warnings
import speech_recognition as sr
from sklearn.preprocessing import StandardScaler
import pickle
warnings.filterwarnings('ignore')

print("✓ All packages imported successfully")

Mounted at /content/drive
Collecting speechrecognition
  Downloading speechrecognition-3.14.3-py3-none-any.whl.metadata (30 kB)
Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12

# CELL 2: Dataset Setup and Exploration

In [2]:
class DatasetExplorer:
    def __init__(self, base_path="/content/drive/MyDrive/Voice/"):
        self.base_path = base_path

    def setup_and_explore(self):
        """Extract datasets and explore structure"""
        print("=== Dataset Setup and Exploration ===\n")

        # Check available files
        files_to_check = [
            "ADReSSo21-diagnosis-train.tgz",
            "ADReSSo21-progression-test.tgz",
            "ADReSSo21-progression-train.tgz"
        ]

        print("Checking dataset files...")
        available_files = []
        for file in files_to_check:
            full_path = os.path.join(self.base_path, file)
            if os.path.exists(full_path):
                print(f"✓ Found: {file}")
                available_files.append(file)
            else:
                print(f"✗ Missing: {file}")

        # Extract datasets
        print("\nExtracting datasets...")
        for file in available_files:
            archive_path = os.path.join(self.base_path, file)
            extract_path = os.path.join(self.base_path, file.replace('.tgz', ''))

            if not os.path.exists(extract_path):
                print(f"Extracting {file}...")
                try:
                    with tarfile.open(archive_path, 'r:gz') as tar:
                        tar.extractall(extract_path)
                    print(f"✓ Extracted to {extract_path}")
                except Exception as e:
                    print(f"✗ Error extracting {file}: {e}")
            else:
                print(f"✓ Already extracted: {file}")

        # Explore structure
        self.explore_structure()
        audio_files, labels = self.find_audio_and_labels()

        return audio_files, labels

    def explore_structure(self):
        """Explore dataset directory structure"""
        print("\n=== Dataset Structure ===")
        for root, dirs, files in os.walk(self.base_path):
            level = root.replace(self.base_path, '').count(os.sep)
            if level < 3:  # Limit depth for readability
                indent = ' ' * 2 * level
                print(f"{indent}{os.path.basename(root)}/")
                subindent = ' ' * 2 * (level + 1)
                for file in files[:3]:  # Show first 3 files only
                    print(f"{subindent}{file}")
                if len(files) > 3:
                    print(f"{subindent}... and {len(files) - 3} more files")

    def find_audio_and_labels(self):
        """Find audio files and extract labels"""
        print("\n=== Finding Audio Files and Labels ===")

        audio_files = []
        labels = []
        label_info = []

        # Look for audio files
        audio_extensions = ['.wav', '.mp3', '.flac', '.m4a']

        for root, dirs, files in os.walk(self.base_path):
            for file in files:
                if any(file.lower().endswith(ext) for ext in audio_extensions):
                    full_path = os.path.join(root, file)
                    audio_files.append(full_path)

                    # Extract label from path structure or filename
                    # Common patterns: 'ad' vs 'control', 'dementia' vs 'healthy', etc.
                    path_lower = full_path.lower()
                    if any(keyword in path_lower for keyword in ['ad', 'alzheimer', 'dementia']):
                        label = 1  # AD/Dementia
                        label_str = "AD"
                    elif any(keyword in path_lower for keyword in ['control', 'healthy', 'normal']):
                        label = 0  # Control
                        label_str = "Control"
                    else:
                        # Try to infer from filename or assign based on folder structure
                        if 'train' in path_lower:
                            # For training data, alternate labels for balance
                            label = len(labels) % 2
                            label_str = "AD" if label == 1 else "Control"
                        else:
                            label = 0
                            label_str = "Unknown"

                    labels.append(label)
                    label_info.append(label_str)

        print(f"Found {len(audio_files)} audio files")

        # Show label distribution
        if labels:
            unique_labels, counts = np.unique(labels, return_counts=True)
            print(f"Label distribution:")
            for label, count in zip(unique_labels, counts):
                label_name = "Control" if label == 0 else "AD"
                print(f"  {label_name}: {count} files")

        # Show sample files
        print(f"\nSample audio files:")
        for i, (file, label_str) in enumerate(zip(audio_files[:5], label_info[:5])):
            print(f"{i+1}. [{label_str}] {file}")

        return audio_files, labels

# Initialize and run dataset exploration
explorer = DatasetExplorer()
audio_files, labels = explorer.setup_and_explore()

print(f"\n✓ Dataset exploration complete!")
print(f"Total audio files: {len(audio_files)}")
print(f"Total labels: {len(labels)}")

=== Dataset Setup and Exploration ===

Checking dataset files...
✓ Found: ADReSSo21-diagnosis-train.tgz
✓ Found: ADReSSo21-progression-test.tgz
✓ Found: ADReSSo21-progression-train.tgz

Extracting datasets...
✓ Already extracted: ADReSSo21-diagnosis-train.tgz
✓ Already extracted: ADReSSo21-progression-test.tgz
✓ Already extracted: ADReSSo21-progression-train.tgz

=== Dataset Structure ===
/
  ADReSSo21-diagnosis-train.tgz
  ADReSSo21-progression-test.tgz
  ADReSSo21-progression-train.tgz
  ... and 4 more files
ADReSSo21-diagnosis-train/
  ADReSSo21/
    diagnosis/
      README.md
ADReSSo21-progression-test/
  ADReSSo21/
    progression/
ADReSSo21-progression-train/
  ADReSSo21/
    progression/
      README.md
diagnosis_train/
  ADReSSo21/
    diagnosis/
      README.md
progression_train/
  ADReSSo21/
    progression/
      README.md
progression_test/
  ADReSSo21/
    progression/

=== Finding Audio Files and Labels ===
Found 542 audio files
Label distribution:
  AD: 542 files

Sample 

In [3]:
class EnhancedAudioFeatureExtractor:
    def __init__(self, sample_rate=16000, max_duration=30):
        self.sample_rate = sample_rate
        self.max_duration = max_duration
        self.feature_names = []

    def extract_comprehensive_features(self, audio_path):
        """Extract comprehensive acoustic features with better error handling"""
        try:
            # Load audio with duration limit
            y, sr = librosa.load(audio_path, sr=self.sample_rate, duration=self.max_duration)

            if len(y) == 0:
                print(f"Warning: Empty audio file {audio_path}")
                return self._get_zero_features()

            features = []
            feature_names = []

            # 1. MFCC features (most important for speech)
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            mfcc_stats = self._compute_statistical_features(mfccs, 'mfcc')
            features.extend(mfcc_stats['values'])
            feature_names.extend(mfcc_stats['names'])

            # 2. Spectral features
            spectral_features = {
                'spectral_centroid': librosa.feature.spectral_centroid(y=y, sr=sr)[0],
                'spectral_rolloff': librosa.feature.spectral_rolloff(y=y, sr=sr)[0],
                'spectral_bandwidth': librosa.feature.spectral_bandwidth(y=y, sr=sr)[0],
                'spectral_contrast': librosa.feature.spectral_contrast(y=y, sr=sr).mean(axis=0),
                'spectral_flatness': librosa.feature.spectral_flatness(y=y)[0]
            }

            for name, values in spectral_features.items():
                if values.ndim > 0:
                    stats = self._compute_statistical_features(values.reshape(1, -1), name)
                    features.extend(stats['values'])
                    feature_names.extend(stats['names'])
                else:
                    features.append(values)
                    feature_names.append(name)

            # 3. Rhythmic features
            tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
            features.append(tempo)
            feature_names.append('tempo')

            # Beat consistency
            if len(beats) > 1:
                beat_intervals = np.diff(beats) / sr
                features.extend([
                    np.mean(beat_intervals),
                    np.std(beat_intervals),
                    np.var(beat_intervals)
                ])
                feature_names.extend(['beat_interval_mean', 'beat_interval_std', 'beat_interval_var'])
            else:
                features.extend([0, 0, 0])
                feature_names.extend(['beat_interval_mean', 'beat_interval_std', 'beat_interval_var'])

            # 4. Zero crossing rate (speech activity)
            zcr = librosa.feature.zero_crossing_rate(y)[0]
            zcr_stats = self._compute_statistical_features(zcr.reshape(1, -1), 'zcr')
            features.extend(zcr_stats['values'])
            feature_names.extend(zcr_stats['names'])

            # 5. Chroma features (harmonic content)
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            chroma_stats = self._compute_statistical_features(chroma, 'chroma')
            features.extend(chroma_stats['values'])
            feature_names.extend(chroma_stats['names'])

            # 6. Energy and power features
            rms_energy = librosa.feature.rms(y=y)[0]
            rms_stats = self._compute_statistical_features(rms_energy.reshape(1, -1), 'rms_energy')
            features.extend(rms_stats['values'])
            feature_names.extend(rms_stats['names'])

            # 7. Formant-like features (using spectral peaks)
            stft = librosa.stft(y)
            magnitude = np.abs(stft)
            spectral_peaks = []
            for frame in range(min(10, magnitude.shape[1])):  # Sample few frames
                spectrum = magnitude[:, frame]
                peaks = self._find_spectral_peaks(spectrum)
                spectral_peaks.extend(peaks[:3])  # Top 3 peaks

            if spectral_peaks:
                features.extend([
                    np.mean(spectral_peaks),
                    np.std(spectral_peaks),
                    np.max(spectral_peaks) if spectral_peaks else 0
                ])
            else:
                features.extend([0, 0, 0])
            feature_names.extend(['formant_mean', 'formant_std', 'formant_max'])

            # Store feature names for first extraction
            if not self.feature_names:
                self.feature_names = feature_names.copy()

            return np.array(features, dtype=np.float32)

        except Exception as e:
            print(f"Error extracting features from {audio_path}: {e}")
            return self._get_zero_features()

    def _compute_statistical_features(self, data, prefix):
        """Compute statistical features from 2D array"""
        if data.ndim == 1:
            data = data.reshape(1, -1)

        stats_values = []
        stats_names = []

        for i in range(data.shape[0]):
            row = data[i]
            stats_values.extend([
                np.mean(row),
                np.std(row),
                np.var(row),
                np.max(row),
                np.min(row),
                np.median(row),
                np.percentile(row, 25),
                np.percentile(row, 75)
            ])

            if data.shape[0] == 1:
                stats_names.extend([
                    f'{prefix}_mean', f'{prefix}_std', f'{prefix}_var',
                    f'{prefix}_max', f'{prefix}_min', f'{prefix}_median',
                    f'{prefix}_q25', f'{prefix}_q75'
                ])
            else:
                stats_names.extend([
                    f'{prefix}_{i}_mean', f'{prefix}_{i}_std', f'{prefix}_{i}_var',
                    f'{prefix}_{i}_max', f'{prefix}_{i}_min', f'{prefix}_{i}_median',
                    f'{prefix}_{i}_q25', f'{prefix}_{i}_q75'
                ])

        return {'values': stats_values, 'names': stats_names}

    def _find_spectral_peaks(self, spectrum, num_peaks=3):
        """Find spectral peaks (simplified formant detection)"""
        try:
            from scipy.signal import find_peaks
            peaks, _ = find_peaks(spectrum, height=np.max(spectrum) * 0.1)
            if len(peaks) > 0:
                # Convert to Hz (assuming 22050 Hz max freq for simplicity)
                peak_freqs = peaks * (self.sample_rate // 2) / len(spectrum)
                return sorted(peak_freqs, reverse=True)[:num_peaks]
            else:
                return [0] * num_peaks
        except:
            return [0] * num_peaks

    def _get_zero_features(self):
        """Return zero feature vector for failed extractions"""
        # Estimate feature dimension based on typical extraction
        estimated_dim = 13*8 + 5*8 + 3 + 1*8 + 12*8 + 1*8 + 3  # Rough estimate
        return np.zeros(estimated_dim, dtype=np.float32)

# Test feature extraction
print("=== Testing Enhanced Audio Feature Extraction ===")

extractor = EnhancedAudioFeatureExtractor()

if audio_files:
    print(f"Testing with: {audio_files[0]}")
    test_features = extractor.extract_comprehensive_features(audio_files[0])
    print(f"✓ Audio feature extraction successful!")
    print(f"Feature vector dimension: {len(test_features)}")
    print(f"Feature vector shape: {test_features.shape}")
    print(f"Sample features: {test_features[:10]}")

    # Test with multiple files to ensure consistency
    print("\nTesting consistency with multiple files...")
    feature_dims = []
    for i, audio_file in enumerate(audio_files[:min(5, len(audio_files))]):
        try:
            features = extractor.extract_comprehensive_features(audio_file)
            feature_dims.append(len(features))
            print(f"File {i+1}: {len(features)} features")
        except Exception as e:
            print(f"File {i+1} failed: {e}")

    if len(set(feature_dims)) == 1:
        print("✓ Feature dimensions are consistent across files")
    else:
        print(f"⚠ Inconsistent feature dimensions: {set(feature_dims)}")
else:
    print("No audio files found for testing")

=== Testing Enhanced Audio Feature Extraction ===
Testing with: /content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train/ADReSSo21/diagnosis/train/audio/cn/adrso007.wav
Error extracting features from /content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train/ADReSSo21/diagnosis/train/audio/cn/adrso007.wav: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (263,) + inhomogeneous part.
✓ Audio feature extraction successful!
Feature vector dimension: 262
Feature vector shape: (262,)
Sample features: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Testing consistency with multiple files...
Error extracting features from /content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train/ADReSSo21/diagnosis/train/audio/cn/adrso007.wav: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (263,) + inhomogeneous part.
File 1: 262 features
Error extracting features fro