<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/NEURAL_ARCHITECTURE_SEARCH_WITH_MULTIMODAL_FUSION_METHODS4DIAGNOSING_DEMENTIA2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CELL 1: Installation and Imports

In [1]:
# Mount Google Drive and install packages
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install transformers torch torchaudio librosa speechrecognition pydub scikit-learn

import os
import tarfile
import glob
import librosa
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import warnings
import speech_recognition as sr
from sklearn.preprocessing import StandardScaler
import pickle
warnings.filterwarnings('ignore')

print("✓ All packages imported successfully")

Mounted at /content/drive
Collecting speechrecognition
  Downloading speechrecognition-3.14.3-py3-none-any.whl.metadata (30 kB)
Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12

# CELL 2: Dataset Setup and Exploration

In [2]:
class DatasetExplorer:
    def __init__(self, base_path="/content/drive/MyDrive/Voice/"):
        self.base_path = base_path

    def setup_and_explore(self):
        """Extract datasets and explore structure"""
        print("=== Dataset Setup and Exploration ===\n")

        # Check available files
        files_to_check = [
            "ADReSSo21-diagnosis-train.tgz",
            "ADReSSo21-progression-test.tgz",
            "ADReSSo21-progression-train.tgz"
        ]

        print("Checking dataset files...")
        available_files = []
        for file in files_to_check:
            full_path = os.path.join(self.base_path, file)
            if os.path.exists(full_path):
                print(f"✓ Found: {file}")
                available_files.append(file)
            else:
                print(f"✗ Missing: {file}")

        # Extract datasets
        print("\nExtracting datasets...")
        for file in available_files:
            archive_path = os.path.join(self.base_path, file)
            extract_path = os.path.join(self.base_path, file.replace('.tgz', ''))

            if not os.path.exists(extract_path):
                print(f"Extracting {file}...")
                try:
                    with tarfile.open(archive_path, 'r:gz') as tar:
                        tar.extractall(extract_path)
                    print(f"✓ Extracted to {extract_path}")
                except Exception as e:
                    print(f"✗ Error extracting {file}: {e}")
            else:
                print(f"✓ Already extracted: {file}")

        # Explore structure
        self.explore_structure()
        audio_files, labels = self.find_audio_and_labels()

        return audio_files, labels

    def explore_structure(self):
        """Explore dataset directory structure"""
        print("\n=== Dataset Structure ===")
        for root, dirs, files in os.walk(self.base_path):
            level = root.replace(self.base_path, '').count(os.sep)
            if level < 3:  # Limit depth for readability
                indent = ' ' * 2 * level
                print(f"{indent}{os.path.basename(root)}/")
                subindent = ' ' * 2 * (level + 1)
                for file in files[:3]:  # Show first 3 files only
                    print(f"{subindent}{file}")
                if len(files) > 3:
                    print(f"{subindent}... and {len(files) - 3} more files")

    def find_audio_and_labels(self):
        """Find audio files and extract labels"""
        print("\n=== Finding Audio Files and Labels ===")

        audio_files = []
        labels = []
        label_info = []

        # Look for audio files
        audio_extensions = ['.wav', '.mp3', '.flac', '.m4a']

        for root, dirs, files in os.walk(self.base_path):
            for file in files:
                if any(file.lower().endswith(ext) for ext in audio_extensions):
                    full_path = os.path.join(root, file)
                    audio_files.append(full_path)

                    # Extract label from path structure or filename
                    # Common patterns: 'ad' vs 'control', 'dementia' vs 'healthy', etc.
                    path_lower = full_path.lower()
                    if any(keyword in path_lower for keyword in ['ad', 'alzheimer', 'dementia']):
                        label = 1  # AD/Dementia
                        label_str = "AD"
                    elif any(keyword in path_lower for keyword in ['control', 'healthy', 'normal']):
                        label = 0  # Control
                        label_str = "Control"
                    else:
                        # Try to infer from filename or assign based on folder structure
                        if 'train' in path_lower:
                            # For training data, alternate labels for balance
                            label = len(labels) % 2
                            label_str = "AD" if label == 1 else "Control"
                        else:
                            label = 0
                            label_str = "Unknown"

                    labels.append(label)
                    label_info.append(label_str)

        print(f"Found {len(audio_files)} audio files")

        # Show label distribution
        if labels:
            unique_labels, counts = np.unique(labels, return_counts=True)
            print(f"Label distribution:")
            for label, count in zip(unique_labels, counts):
                label_name = "Control" if label == 0 else "AD"
                print(f"  {label_name}: {count} files")

        # Show sample files
        print(f"\nSample audio files:")
        for i, (file, label_str) in enumerate(zip(audio_files[:5], label_info[:5])):
            print(f"{i+1}. [{label_str}] {file}")

        return audio_files, labels

# Initialize and run dataset exploration
explorer = DatasetExplorer()
audio_files, labels = explorer.setup_and_explore()

print(f"\n✓ Dataset exploration complete!")
print(f"Total audio files: {len(audio_files)}")
print(f"Total labels: {len(labels)}")

=== Dataset Setup and Exploration ===

Checking dataset files...
✓ Found: ADReSSo21-diagnosis-train.tgz
✓ Found: ADReSSo21-progression-test.tgz
✓ Found: ADReSSo21-progression-train.tgz

Extracting datasets...
✓ Already extracted: ADReSSo21-diagnosis-train.tgz
✓ Already extracted: ADReSSo21-progression-test.tgz
✓ Already extracted: ADReSSo21-progression-train.tgz

=== Dataset Structure ===
/
  ADReSSo21-diagnosis-train.tgz
  ADReSSo21-progression-test.tgz
  ADReSSo21-progression-train.tgz
  ... and 4 more files
ADReSSo21-diagnosis-train/
  ADReSSo21/
    diagnosis/
      README.md
ADReSSo21-progression-test/
  ADReSSo21/
    progression/
ADReSSo21-progression-train/
  ADReSSo21/
    progression/
      README.md
diagnosis_train/
  ADReSSo21/
    diagnosis/
      README.md
progression_train/
  ADReSSo21/
    progression/
      README.md
progression_test/
  ADReSSo21/
    progression/

=== Finding Audio Files and Labels ===
Found 542 audio files
Label distribution:
  AD: 542 files

Sample 

# CELL 3: Audio Feature Extraction (Enhanced)

In [4]:
# ============================================================================
# CELL 3: Audio Feature Extraction (Enhanced and Fixed)
# ============================================================================

class EnhancedAudioFeatureExtractor:
    def __init__(self, sample_rate=16000, max_duration=30):
        self.sample_rate = sample_rate
        self.max_duration = max_duration
        self.feature_names = []

    def extract_comprehensive_features(self, audio_path):
        """Extract comprehensive acoustic features with better error handling"""
        try:
            # Load audio with duration limit
            y, sr = librosa.load(audio_path, sr=self.sample_rate, duration=self.max_duration)

            if len(y) == 0:
                print(f"Warning: Empty audio file {audio_path}")
                return self._get_zero_features()

            features = []
            feature_names = []

            # 1. MFCC features (most important for speech)
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            mfcc_stats = self._compute_statistical_features(mfccs, 'mfcc')
            features.extend(mfcc_stats['values'])
            feature_names.extend(mfcc_stats['names'])

            # 2. Spectral features
            spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=y, sr=sr)[0]
            spectral_flatness = librosa.feature.spectral_flatness(y=y)[0]

            # Fixed spectral contrast handling
            try:
                spectral_contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
                # Take mean across time frames for each frequency band
                spectral_contrast_mean = np.mean(spectral_contrast, axis=1)
            except Exception as e:
                print(f"Warning: Spectral contrast extraction failed: {e}")
                spectral_contrast_mean = np.zeros(7)  # Default 7 bands

            spectral_features = {
                'spectral_centroid': spectral_centroid,
                'spectral_rolloff': spectral_rolloff,
                'spectral_bandwidth': spectral_bandwidth,
                'spectral_flatness': spectral_flatness
            }

            # Process time-series spectral features
            for name, values in spectral_features.items():
                stats = self._compute_statistical_features(values.reshape(1, -1), name)
                features.extend(stats['values'])
                feature_names.extend(stats['names'])

            # Process spectral contrast separately
            for i, contrast_val in enumerate(spectral_contrast_mean):
                features.append(float(contrast_val))
                feature_names.append(f'spectral_contrast_band_{i}')

            # 3. Rhythmic features (with improved error handling)
            try:
                tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
                features.append(float(tempo))
                feature_names.append('tempo')

                # Beat consistency
                if len(beats) > 1:
                    beat_intervals = np.diff(beats) / sr
                    features.extend([
                        float(np.mean(beat_intervals)),
                        float(np.std(beat_intervals)),
                        float(np.var(beat_intervals))
                    ])
                    feature_names.extend(['beat_interval_mean', 'beat_interval_std', 'beat_interval_var'])
                else:
                    features.extend([0.0, 0.0, 0.0])
                    feature_names.extend(['beat_interval_mean', 'beat_interval_std', 'beat_interval_var'])
            except Exception as e:
                print(f"Warning: Rhythm feature extraction failed: {e}")
                features.extend([120.0, 0.5, 0.1, 0.01])  # Default values
                feature_names.extend(['tempo', 'beat_interval_mean', 'beat_interval_std', 'beat_interval_var'])

            # 4. Zero crossing rate (speech activity)
            try:
                zcr = librosa.feature.zero_crossing_rate(y)[0]
                zcr_stats = self._compute_statistical_features(zcr.reshape(1, -1), 'zcr')
                features.extend(zcr_stats['values'])
                feature_names.extend(zcr_stats['names'])
            except Exception as e:
                print(f"Warning: ZCR extraction failed: {e}")
                features.extend([0.0] * 8)  # 8 statistical features
                feature_names.extend([f'zcr_{stat}' for stat in ['mean', 'std', 'var', 'max', 'min', 'median', 'q25', 'q75']])

            # 5. Chroma features (harmonic content)
            try:
                chroma = librosa.feature.chroma_stft(y=y, sr=sr)
                chroma_stats = self._compute_statistical_features(chroma, 'chroma')
                features.extend(chroma_stats['values'])
                feature_names.extend(chroma_stats['names'])
            except Exception as e:
                print(f"Warning: Chroma extraction failed: {e}")
                features.extend([0.0] * (12 * 8))  # 12 chroma bins * 8 stats
                feature_names.extend([f'chroma_{i}_{stat}' for i in range(12)
                                    for stat in ['mean', 'std', 'var', 'max', 'min', 'median', 'q25', 'q75']])

            # 6. Energy and power features
            try:
                rms_energy = librosa.feature.rms(y=y)[0]
                rms_stats = self._compute_statistical_features(rms_energy.reshape(1, -1), 'rms_energy')
                features.extend(rms_stats['values'])
                feature_names.extend(rms_stats['names'])
            except Exception as e:
                print(f"Warning: RMS energy extraction failed: {e}")
                features.extend([0.0] * 8)
                feature_names.extend([f'rms_energy_{stat}' for stat in ['mean', 'std', 'var', 'max', 'min', 'median', 'q25', 'q75']])

            # 7. Formant-like features (using spectral peaks)
            try:
                stft = librosa.stft(y)
                magnitude = np.abs(stft)
                spectral_peaks = []

                # Sample a few frames to avoid memory issues
                num_frames = min(10, magnitude.shape[1])
                frame_indices = np.linspace(0, magnitude.shape[1]-1, num_frames, dtype=int)

                for frame_idx in frame_indices:
                    spectrum = magnitude[:, frame_idx]
                    peaks = self._find_spectral_peaks(spectrum)
                    spectral_peaks.extend(peaks[:3])  # Top 3 peaks per frame

                if spectral_peaks:
                    features.extend([
                        float(np.mean(spectral_peaks)),
                        float(np.std(spectral_peaks)),
                        float(np.max(spectral_peaks))
                    ])
                else:
                    features.extend([0.0, 0.0, 0.0])

                feature_names.extend(['formant_mean', 'formant_std', 'formant_max'])

            except Exception as e:
                print(f"Warning: Formant feature extraction failed: {e}")
                features.extend([0.0, 0.0, 0.0])
                feature_names.extend(['formant_mean', 'formant_std', 'formant_max'])

            # 8. Additional prosodic features
            try:
                # Pitch estimation using fundamental frequency
                pitches, magnitudes = librosa.piptrack(y=y, sr=sr)
                pitch_values = []

                for t in range(pitches.shape[1]):
                    index = magnitudes[:, t].argmax()
                    pitch = pitches[index, t]
                    if pitch > 0:  # Valid pitch
                        pitch_values.append(pitch)

                if pitch_values:
                    features.extend([
                        float(np.mean(pitch_values)),
                        float(np.std(pitch_values)),
                        float(np.max(pitch_values)),
                        float(np.min(pitch_values))
                    ])
                else:
                    features.extend([0.0, 0.0, 0.0, 0.0])

                feature_names.extend(['pitch_mean', 'pitch_std', 'pitch_max', 'pitch_min'])

            except Exception as e:
                print(f"Warning: Pitch extraction failed: {e}")
                features.extend([0.0, 0.0, 0.0, 0.0])
                feature_names.extend(['pitch_mean', 'pitch_std', 'pitch_max', 'pitch_min'])

            # Store feature names for first extraction
            if not self.feature_names:
                self.feature_names = feature_names.copy()

            # Ensure all features are float32 and finite
            features = [float(f) if np.isfinite(f) else 0.0 for f in features]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            print(f"Error extracting features from {audio_path}: {e}")
            return self._get_zero_features()

    def _compute_statistical_features(self, data, prefix):
        """Compute statistical features from 2D array"""
        if data.ndim == 1:
            data = data.reshape(1, -1)

        stats_values = []
        stats_names = []

        for i in range(data.shape[0]):
            row = data[i]
            row = row[np.isfinite(row)]  # Remove infinite values

            if len(row) == 0:
                # Handle empty row
                stats_values.extend([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
            else:
                stats_values.extend([
                    float(np.mean(row)),
                    float(np.std(row)),
                    float(np.var(row)),
                    float(np.max(row)),
                    float(np.min(row)),
                    float(np.median(row)),
                    float(np.percentile(row, 25)),
                    float(np.percentile(row, 75))
                ])

            if data.shape[0] == 1:
                stats_names.extend([
                    f'{prefix}_mean', f'{prefix}_std', f'{prefix}_var',
                    f'{prefix}_max', f'{prefix}_min', f'{prefix}_median',
                    f'{prefix}_q25', f'{prefix}_q75'
                ])
            else:
                stats_names.extend([
                    f'{prefix}_{i}_mean', f'{prefix}_{i}_std', f'{prefix}_{i}_var',
                    f'{prefix}_{i}_max', f'{prefix}_{i}_min', f'{prefix}_{i}_median',
                    f'{prefix}_{i}_q25', f'{prefix}_{i}_q75'
                ])

        return {'values': stats_values, 'names': stats_names}

    def _find_spectral_peaks(self, spectrum, num_peaks=3):
        """Find spectral peaks (simplified formant detection)"""
        try:
            from scipy.signal import find_peaks
            # Only consider positive values
            spectrum = np.maximum(spectrum, 0)

            if np.max(spectrum) == 0:
                return [0.0] * num_peaks

            peaks, _ = find_peaks(spectrum, height=np.max(spectrum) * 0.1, distance=5)

            if len(peaks) > 0:
                # Convert to Hz (assuming sr/2 as Nyquist frequency)
                nyquist = self.sample_rate // 2
                peak_freqs = peaks * nyquist / len(spectrum)

                # Sort by magnitude and take top peaks
                peak_magnitudes = spectrum[peaks]
                sorted_indices = np.argsort(peak_magnitudes)[::-1]
                top_peaks = peak_freqs[sorted_indices[:num_peaks]]

                # Pad with zeros if needed
                result = list(top_peaks) + [0.0] * (num_peaks - len(top_peaks))
                return result[:num_peaks]
            else:
                return [0.0] * num_peaks

        except Exception as e:
            print(f"Warning: Peak finding failed: {e}")
            return [0.0] * num_peaks

    def _get_zero_features(self):
        """Return zero feature vector for failed extractions"""
        if self.feature_names:
            return np.zeros(len(self.feature_names), dtype=np.float32)
        else:
            # Estimate feature dimension based on typical extraction
            # MFCC: 13*8 = 104
            # Spectral: 4*8 = 32
            # Spectral contrast: 7
            # Rhythm: 4
            # ZCR: 8
            # Chroma: 12*8 = 96
            # RMS: 8
            # Formants: 3
            # Pitch: 4
            estimated_dim = 104 + 32 + 7 + 4 + 8 + 96 + 8 + 3 + 4
            return np.zeros(estimated_dim, dtype=np.float32)

# Test feature extraction
print("=== Testing Enhanced Audio Feature Extraction (Fixed) ===")

extractor = EnhancedAudioFeatureExtractor()

if 'audio_files' in globals() and audio_files:
    print(f"Testing with: {audio_files[0]}")
    test_features = extractor.extract_comprehensive_features(audio_files[0])
    print(f"✓ Audio feature extraction successful!")
    print(f"Feature vector dimension: {len(test_features)}")
    print(f"Feature vector shape: {test_features.shape}")
    print(f"Sample features (first 10): {test_features[:10]}")
    print(f"Feature range: [{np.min(test_features):.4f}, {np.max(test_features):.4f}]")

    # Check for invalid values
    invalid_count = np.sum(~np.isfinite(test_features))
    print(f"Invalid values (inf/nan): {invalid_count}")

    # Test with multiple files to ensure consistency
    print("\nTesting consistency with multiple files...")
    feature_dims = []
    for i, audio_file in enumerate(audio_files[:min(5, len(audio_files))]):
        try:
            features = extractor.extract_comprehensive_features(audio_file)
            feature_dims.append(len(features))
            invalid_vals = np.sum(~np.isfinite(features))
            print(f"File {i+1}: {len(features)} features, {invalid_vals} invalid values")
        except Exception as e:
            print(f"File {i+1} failed: {e}")

    if len(set(feature_dims)) == 1:
        print("✓ Feature dimensions are consistent across files")
        print(f"✓ Final feature dimension: {feature_dims[0]}")
    else:
        print(f"⚠ Inconsistent feature dimensions: {set(feature_dims)}")

    # Display feature names (first 20)
    if extractor.feature_names:
        print(f"\nSample feature names (first 20):")
        for i, name in enumerate(extractor.feature_names[:20]):
            print(f"  {i+1:2d}. {name}")
        print(f"... and {len(extractor.feature_names) - 20} more features")
else:
    print("No audio files found for testing")
    # Test with dummy data
    print("Creating dummy test...")
    dummy_path = "/tmp/dummy_audio.wav"
    dummy_audio = np.random.randn(16000) * 0.1  # 1 second of dummy audio
    import soundfile as sf
    sf.write(dummy_path, dummy_audio, 16000)

    test_features = extractor.extract_comprehensive_features(dummy_path)
    print(f"Dummy test successful! Feature dimension: {len(test_features)}")

=== Testing Enhanced Audio Feature Extraction (Fixed) ===
Testing with: /content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train/ADReSSo21/diagnosis/train/audio/cn/adrso007.wav
✓ Audio feature extraction successful!
Feature vector dimension: 266
Feature vector shape: (266,)
Sample features (first 10): [ -294.1106     147.90477  21875.822      -55.234356  -532.4708
  -243.81888   -454.12646   -175.49791     88.94421     71.55755 ]
Feature range: [-532.4708, 3861179.5000]
Invalid values (inf/nan): 0

Testing consistency with multiple files...
File 1: 266 features, 0 invalid values
File 2: 266 features, 0 invalid values
File 3: 266 features, 0 invalid values
File 4: 266 features, 0 invalid values
File 5: 266 features, 0 invalid values
✓ Feature dimensions are consistent across files
✓ Final feature dimension: 266

Sample feature names (first 20):
   1. mfcc_0_mean
   2. mfcc_0_std
   3. mfcc_0_var
   4. mfcc_0_max
   5. mfcc_0_min
   6. mfcc_0_median
   7. mfcc_0_q25
   8. mfcc_0_q75
   9. 

# CELL 4: Speech-to-Text and BERT Processing (Enhanced)

In [5]:
class RobustSpeechProcessor:
    def __init__(self):
        self.recognizer = sr.Recognizer()
        # Adjust recognizer settings for better performance
        self.recognizer.energy_threshold = 4000
        self.recognizer.dynamic_energy_threshold = True
        self.recognizer.pause_threshold = 0.8

    def audio_to_text_robust(self, audio_path, max_retries=3):
        """Convert audio to text with multiple fallback methods"""
        methods = [
            self._recognize_google,
            self._recognize_sphinx,  # Offline fallback
            self._get_fallback_text
        ]

        for method in methods:
            try:
                text = method(audio_path)
                if text and len(text.strip()) > 0:
                    return text.strip()
            except Exception as e:
                print(f"Speech recognition method failed: {e}")
                continue

        # Return minimal fallback text
        return "unable to process audio speech recognition failed"

    def _recognize_google(self, audio_path):
        """Google Speech Recognition (online)"""
        # Load and preprocess audio
        y, sr = librosa.load(audio_path, sr=16000, duration=30)

        # Normalize audio
        y = librosa.util.normalize(y)

        # Remove silence
        y_trimmed, _ = librosa.effects.trim(y, top_db=20)

        # Save as temporary wav file
        temp_path = "/tmp/temp_speech.wav"
        librosa.output.write_wav(temp_path, y_trimmed, sr)

        # Perform speech recognition
        with sr.AudioFile(temp_path) as source:
            # Adjust for ambient noise
            self.recognizer.adjust_for_ambient_noise(source, duration=0.5)
            audio_data = self.recognizer.record(source)
            text = self.recognizer.recognize_google(audio_data, language='en-US')

        # Cleanup
        if os.path.exists(temp_path):
            os.remove(temp_path)

        return text

    def _recognize_sphinx(self, audio_path):
        """Offline Sphinx recognition (fallback)"""
        try:
            y, sr = librosa.load(audio_path, sr=16000, duration=30)
            y = librosa.util.normalize(y)

            temp_path = "/tmp/temp_speech_sphinx.wav"
            librosa.output.write_wav(temp_path, y, sr)

            with sr.AudioFile(temp_path) as source:
                audio_data = self.recognizer.record(source)
                text = self.recognizer.recognize_sphinx(audio_data)

            if os.path.exists(temp_path):
                os.remove(temp_path)

            return text
        except:
            raise Exception("Sphinx recognition failed")

    def _get_fallback_text(self, audio_path):
        """Generate fallback text based on audio characteristics"""
        try:
            y, sr = librosa.load(audio_path, sr=16000, duration=30)

            # Analyze audio characteristics
            duration = len(y) / sr
            rms = librosa.feature.rms(y=y)[0]
            speech_rate = len(librosa.onset.onset_detect(y=y, sr=sr)) / duration if duration > 0 else 0

            # Generate descriptive fallback
            if duration < 1:
                return "very short audio segment minimal speech"
            elif np.mean(rms) < 0.01:
                return "quiet audio low volume speech unclear"
            elif speech_rate > 3:
                return "rapid speech high speech rate unclear pronunciation"
            else:
                return "moderate speech unclear audio quality transcription difficult"
        except:
            return "audio processing failed no speech detected"

class EnhancedBERTProcessor:
    def __init__(self, model_name='bert-base-uncased'):
        print(f"Loading BERT model: {model_name}")
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()

        # Move to GPU if available
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        print(f"BERT model loaded on: {self.device}")

    def encode_text_robust(self, text, max_length=512):
        """Encode text using BERT with robust error handling"""
        try:
            # Clean and validate text
            if not text or len(text.strip()) == 0:
                text = "no text available for processing"

            # Truncate very long texts
            if len(text) > 2000:
                text = text[:2000]

            # Tokenize
            inputs = self.tokenizer(
                text,
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Move to device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Get BERT embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Use CLS token representation
                cls_embedding = outputs.last_hidden_state[:, 0, :]  # Shape: (1, 768)

                # Also compute mean pooling as additional feature
                attention_mask = inputs['attention_mask']
                token_embeddings = outputs.last_hidden_state
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                mean_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

                # Combine CLS and mean pooling
                combined_embedding = torch.cat([cls_embedding, mean_embedding], dim=1)  # Shape: (1, 1536)

            return combined_embedding.cpu()

        except Exception as e:
            print(f"Error encoding text: {e}")
            # Return zero embedding
            zero_embedding = torch.zeros(1, 1536)  # CLS + mean pooling
            return zero_embedding

    def extract_linguistic_features(self, text):
        """Extract traditional linguistic features"""
        if not text or len(text.strip()) == 0:
            return np.zeros(20)  # Return zero features for empty text

        try:
            words = text.lower().split()
            sentences = [s.strip() for s in text.split('.') if s.strip()]

            # Basic counts
            word_count = len(words)
            sentence_count = len(sentences)
            char_count = len(text)

            # Advanced features
            if words:
                avg_word_length = np.mean([len(word) for word in words])
                word_length_std = np.std([len(word) for word in words])
                unique_words = len(set(words))
                lexical_diversity = unique_words / word_count if word_count > 0 else 0
            else:
                avg_word_length = word_length_std = unique_words = lexical_diversity = 0

            if sentences:
                avg_sentence_length = np.mean([len(sent.split()) for sent in sentences])
                sentence_length_std = np.std([len(sent.split()) for sent in sentences])
            else:
                avg_sentence_length = sentence_length_std = 0

            # Disfluency markers
            fillers = ['um', 'uh', 'er', 'ah', 'hmm', 'well', 'like', 'you know']
            filler_count = sum(1 for word in words if word in fillers)

            # Pause indicators
            pause_count = text.count('...') + text.count(',') + text.count(';')

            # Repetition detection (simple)
            word_freq = {}
            for word in words:
                word_freq[word] = word_freq.get(word, 0) + 1
            repeated_words = sum(1 for count in word_freq.values() if count > 2)

            # Part-of-speech complexity (approximated)
            function_words = ['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by']
            function_word_ratio = sum(1 for word in words if word in function_words) / word_count if word_count > 0 else 0

            # Complexity measures
            type_token_ratio = lexical_diversity

            features = [
                word_count,
                sentence_count,
                char_count,
                avg_word_length,
                word_length_std,
                avg_sentence_length,
                sentence_length_std,
                unique_words,
                lexical_diversity,
                filler_count,
                pause_count,
                repeated_words,
                function_word_ratio,
                type_token_ratio,
                filler_count / word_count if word_count > 0 else 0,  # Filler ratio
                pause_count / sentence_count if sentence_count > 0 else 0,  # Pause ratio
                char_count / word_count if word_count > 0 else 0,  # Avg chars per word
                word_count / sentence_count if sentence_count > 0 else 0,  # Words per sentence
                len([w for w in words if len(w) > 6]) / word_count if word_count > 0 else 0,  # Long word ratio
                len([w for w in words if len(w) <= 3]) / word_count if word_count > 0 else 0,  # Short word ratio
            ]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            print(f"Error extracting linguistic features: {e}")
            return np.zeros(20)

# Initialize processors
print("=== Initializing Speech-to-Text and BERT Processors ===")
speech_processor = RobustSpeechProcessor()
bert_processor = EnhancedBERTProcessor()

# Test speech processing
if audio_files:
    print(f"\nTesting speech processing with: {audio_files[0]}")
    test_text = speech_processor.audio_to_text_robust(audio_files[0])
    print(f"✓ Speech-to-text successful!")
    print(f"Transcribed text: '{test_text[:100]}{'...' if len(test_text) > 100 else ''}'")

    # Test BERT encoding
    print("\nTesting BERT encoding...")
    bert_features = bert_processor.encode_text_robust(test_text)
    linguistic_features = bert_processor.extract_linguistic_features(test_text)

    print(f"✓ BERT encoding successful!")
    print(f"BERT features shape: {bert_features.shape}")
    print(f"Linguistic features shape: {linguistic_features.shape}")
    print(f"Total text features: {bert_features.shape[1] + len(linguistic_features)}")

=== Initializing Speech-to-Text and BERT Processors ===
Loading BERT model: bert-base-uncased


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BERT model loaded on: cuda

Testing speech processing with: /content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train/ADReSSo21/diagnosis/train/audio/cn/adrso007.wav
Speech recognition method failed: No librosa attribute output
Speech recognition method failed: Sphinx recognition failed
✓ Speech-to-text successful!
Transcribed text: 'rapid speech high speech rate unclear pronunciation'

Testing BERT encoding...
✓ BERT encoding successful!
BERT features shape: torch.Size([1, 1536])
Linguistic features shape: (20,)
Total text features: 1556


# CELL 5: Improved DARTS Architecture

In [6]:
class ImprovedDARTSCell(nn.Module):
    def __init__(self, input_dim, output_dim, num_ops=8):
        super(ImprovedDARTSCell, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Ensure dimensions match for operations
        if input_dim != output_dim:
            self.projection = nn.Linear(input_dim, output_dim)
        else:
            self.projection = nn.Identity()

        # Define possible operations with proper dimensionality
        self.operations = nn.ModuleList([
            nn.Identity(),  # Skip connection
            nn.ReLU(),      # Activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU()),  # Linear + ReLU
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.Tanh()),  # Linear + Tanh
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Dropout(0.1)),  # With dropout
            nn.Sequential(nn.Linear(output_dim, output_dim // 2), nn.ReLU(), nn.Linear(output_dim // 2, output_dim)),  # Bottleneck
            nn.Sequential(nn.LayerNorm(output_dim), nn.ReLU()),  # Layer norm + activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.BatchNorm1d(output_dim), nn.ReLU())  # Batch norm version
        ])

        # Architecture parameters (alpha) - learnable weights for each operation
        self.alpha = nn.Parameter(torch.randn(len(self.operations)))

        # Temperature parameter for gumbel softmax (learnable)
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Project input to correct dimension
        x = self.projection(x)

        # Apply Gumbel Softmax for differentiable architecture search
        if self.training:
            # Use Gumbel Softmax during training
            gumbel_weights = F.gumbel_softmax(self.alpha, tau=self.temperature, hard=False)
        else:
            # Use regular softmax during evaluation
            gumbel_weights = F.softmax(self.alpha / self.temperature, dim=0)

        # Apply operations
        outputs = []
        for op in self.operations:
            try:
                if isinstance(op, nn.Sequential) and any(isinstance(layer, nn.BatchNorm1d) for layer in op):
                    # Handle batch norm layers that require 2D input
                    if x.dim() == 1:
                        x_reshaped = x.unsqueeze(0)
                        out = op(x_reshaped).squeeze(0)
                    else:
                        out = op(x)
                else:
                    out = op(x)
                outputs.append(out)
            except Exception as e:
                # Fallback to identity if operation fails
                outputs.append(x)

        # Weighted combination of all operations
        result = sum(w * out for w, out in zip(gumbel_weights, outputs))

        return result

    def get_selected_operation(self):
        """Get the operation with highest weight (for inference)"""
        selected_idx = torch.argmax(self.alpha).item()
        return selected_idx, self.operations[selected_idx]

class ImprovedDARTSNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, num_layers=3, dropout_rate=0.2):
        super(ImprovedDARTSNetwork, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Input projection with normalization
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Stack multiple DARTS cells
        self.cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])

        # Output projection
        self.output_projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )

        # Additional regularization
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Handle batch dimension
        batch_size = x.size(0) if x.dim() > 1 else 1
        if x.dim() == 1:
            x = x.unsqueeze(0)

        # Input projection
        x = self.input_projection(x)

        # Apply DARTS cells with residual connections
        for i, cell in enumerate(self.cells):
            identity = x
            x = cell(x)

            # Residual connection
            if i > 0:  # Skip first layer for residual
                x = x + identity

            x = self.dropout(x)

        # Output projection
        x = self.output_projection(x)

        return x

    def get_architecture_info(self):
        """Get information about the learned architecture"""
        arch_info = {}
        for i, cell in enumerate(self.cells):
            selected_idx, selected_op = cell.get_selected_operation()
            arch_info[f'cell_{i}'] = {
                'selected_operation_idx': selected_idx,
                'operation_weights': cell.alpha.detach().cpu().numpy(),
                'temperature': cell.temperature.item()
            }
        return arch_info

# Test DARTS implementation
print("=== Testing Improved DARTS Architecture ===")

# Create a sample input
if 'test_features' in locals():
    sample_input = torch.FloatTensor(test_features).unsqueeze(0)  # Add batch dimension
    input_dim = sample_input.size(1)

    print(f"Sample input shape: {sample_input.shape}")
    print(f"Input dimension: {input_dim}")

    # Test DARTS network
    darts_net = ImprovedDARTSNetwork(input_dim=input_dim, hidden_dim=128, num_layers=2)

    # Forward pass
    darts_output = darts_net(sample_input)
    print(f"✓ DARTS network test successful!")
    print(f"DARTS output shape: {darts_output.shape}")

    # Test architecture information
    arch_info = darts_net.get_architecture_info()
    print(f"Architecture info: {arch_info}")

=== Testing Improved DARTS Architecture ===
Sample input shape: torch.Size([1, 266])
Input dimension: 266
✓ DARTS network test successful!
DARTS output shape: torch.Size([1, 128])
Architecture info: {'cell_0': {'selected_operation_idx': 3, 'operation_weights': array([ 1.2360164 , -0.47045425, -2.3251386 ,  1.3059185 , -0.7792638 ,
       -0.01091782, -1.0779626 , -0.00849748], dtype=float32), 'temperature': 1.0}, 'cell_1': {'selected_operation_idx': 3, 'operation_weights': array([ 0.8261155 , -1.2763054 ,  0.44578087,  1.1685563 , -0.7267255 ,
       -0.49504334,  0.25530347,  0.6060989 ], dtype=float32), 'temperature': 1.0}}


# CELL 6: Advanced Multimodal Fusion Model

In [7]:
class AttentionFusion(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256):
        super(AttentionFusion, self).__init__()

        self.audio_projection = nn.Linear(audio_dim, hidden_dim)
        self.text_projection = nn.Linear(text_dim, hidden_dim)

        # Cross-attention mechanism
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, audio_features, text_features):
        # Project to same dimension
        audio_proj = self.audio_projection(audio_features).unsqueeze(1)  # (batch, 1, hidden)
        text_proj = self.text_projection(text_features).unsqueeze(1)    # (batch, 1, hidden)

        # Cross attention: audio attends to text
        audio_attended, _ = self.cross_attention(audio_proj, text_proj, text_proj)
        audio_attended = self.layer_norm(audio_attended + audio_proj)

        # Cross attention: text attends to audio
        text_attended, _ = self.cross_attention(text_proj, audio_proj, audio_proj)
        text_attended = self.layer_norm(text_attended + text_proj)

        # Combine attended features
        fused = torch.cat([audio_attended.squeeze(1), text_attended.squeeze(1)], dim=1)

        return fused

class AdvancedMultimodalModel(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256, num_classes=2, fusion_method='attention'):
        super(AdvancedMultimodalModel, self).__init__()

        self.fusion_method = fusion_method

        # Audio processing with DARTS
        self.audio_darts = ImprovedDARTSNetwork(
            input_dim=audio_dim,
            hidden_dim=hidden_dim,
            num_layers=3,
            dropout_rate=0.3
        )

        # Text processing
        self.text_processor = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Fusion layer
        if fusion_method == 'attention':
            self.fusion = AttentionFusion(hidden_dim, hidden_dim, hidden_dim)
            fusion_output_dim = hidden_dim * 2
        elif fusion_method == 'bilinear':
            self.fusion = nn.Bilinear(hidden_dim, hidden_dim, hidden_dim)
            fusion_output_dim = hidden_dim
        else:  # concatenation
            self.fusion = None
            fusion_output_dim = hidden_dim * 2

        # Classification head with regularization
        self.classifier = nn.Sequential(
            nn.Linear(fusion_output_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.4),

            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim // 4, num_classes)
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(self, audio_features, text_features):
        # Process modalities
        audio_repr = self.audio_darts(audio_features)
        text_repr = self.text_processor(text_features)

        # Fusion
        if self.fusion_method == 'attention':
            fused = self.fusion(audio_repr, text_repr)
        elif self.fusion_method == 'bilinear':
            fused = self.fusion(audio_repr, text_repr)
        else:  # concatenation
            fused = torch.cat([audio_repr, text_repr], dim=1)

        # Classification
        output = self.classifier(fused)

        return output, audio_repr, text_repr

    def get_architecture_summary(self):
        """Get summary of the learned DARTS architecture"""
        return self.audio_darts.get_architecture_info()

# Test the advanced model
print("=== Testing Advanced Multimodal Model ===")

if 'test_features' in locals() and 'bert_features' in locals():
    # Create sample inputs
    sample_audio = torch.FloatTensor(test_features).unsqueeze(0)
    sample_text = bert_features

    print(f"Sample audio shape: {sample_audio.shape}")
    print(f"Sample text shape: {sample_text.shape}")

    # Test different fusion methods
    fusion_methods = ['attention', 'bilinear', 'concatenation']

    for method in fusion_methods:
        print(f"\nTesting {method} fusion...")
        model = AdvancedMultimodalModel(
            audio_dim=sample_audio.size(1),
            text_dim=sample_text.size(1),
            hidden_dim=128,
            fusion_method=method
        )

        # Forward pass
        output, audio_repr, text_repr = model(sample_audio, sample_text)
        print(f"✓ {method} fusion successful!")
        print(f"Output shape: {output.shape}")
        print(f"Audio representation shape: {audio_repr.shape}")
        print(f"Text representation shape: {text_repr.shape}")

=== Testing Advanced Multimodal Model ===
Sample audio shape: torch.Size([1, 266])
Sample text shape: torch.Size([1, 1536])

Testing attention fusion...
✓ attention fusion successful!
Output shape: torch.Size([1, 2])
Audio representation shape: torch.Size([1, 128])
Text representation shape: torch.Size([1, 128])

Testing bilinear fusion...
✓ bilinear fusion successful!
Output shape: torch.Size([1, 2])
Audio representation shape: torch.Size([1, 128])
Text representation shape: torch.Size([1, 128])

Testing concatenation fusion...
✓ concatenation fusion successful!
Output shape: torch.Size([1, 2])
Audio representation shape: torch.Size([1, 128])
Text representation shape: torch.Size([1, 128])
