<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/ADReSSo21.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install SpeechRecognition pydub

Collecting SpeechRecognition
  Downloading speechrecognition-3.14.3-py3-none-any.whl.metadata (30 kB)
Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading speechrecognition-3.14.3-py3-none-any.whl (32.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m32.9/32.9 MB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub, SpeechRecognition
Successfully installed SpeechRecognition-3.14.3 pydub-0.25.1


In [2]:
# Step-by-Step Audio Transcript Extractor for ADReSSo21 Dataset
# This script will:
# 1. Mount Google Drive
# 2. Extract dataset files
# 3. Find all WAV files
# 4. Extract transcripts from audio using speech recognition
# 5. Save organized transcripts

import os
import tarfile
import pandas as pd
import numpy as np
from pathlib import Path
import librosa
import speech_recognition as sr
import soundfile as sf
from pydub import AudioSegment
import warnings
warnings.filterwarnings('ignore')

print("="*60)
print("ADReSSo21 AUDIO TRANSCRIPT EXTRACTOR")
print("="*60)

# STEP 1: MOUNT GOOGLE DRIVE
print("\nSTEP 1: Mounting Google Drive...")
try:
    # Imported here (not at the top of the cell) so the rest of the script can
    # still run when the google.colab package is unavailable.
    from google.colab import drive
    drive.mount('/content/drive')
    print("✓ Google Drive mounted successfully!")
except Exception as mount_error:
    # `except Exception` instead of a bare `except:` lets KeyboardInterrupt and
    # SystemExit propagate, and the actual reason is reported rather than hidden.
    print(f"⚠ Not running in Colab or Drive already mounted ({mount_error})")

# STEP 2: INSTALL REQUIRED PACKAGES
print("\nSTEP 2: Installing required packages...")
print("Installing speech recognition and audio processing libraries...")

# Install packages (run once)
!pip install SpeechRecognition
!pip install pydub
!pip install librosa
!pip install soundfile
!apt-get install -y ffmpeg

print("✓ Packages ready (make sure to install them first)")

# STEP 3: SET UP PATHS AND CONFIGURATION
print("\nSTEP 3: Setting up paths and configuration...")

# Root folder holding the dataset archives; the two working folders live inside it.
BASE_PATH = "/content/drive/MyDrive/Voice/"
EXTRACT_PATH = os.path.join(BASE_PATH, "extracted/")      # where archives are unpacked
OUTPUT_PATH = os.path.join(BASE_PATH, "transcripts/")     # where result CSVs are written

# Create directories (no error if they already exist)
Path(EXTRACT_PATH).mkdir(parents=True, exist_ok=True)
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

# Logical dataset name -> archive file name looked up under BASE_PATH.
datasets = {
    'progression_train': 'ADReSSo21-progression-train.tgz',
    'progression_test': 'ADReSSo21-progression-test.tgz',
    'diagnosis_train': 'ADReSSo21-diagnosis-train.tgz',
}

print(f"✓ Base path: {BASE_PATH}")
print(f"✓ Extract path: {EXTRACT_PATH}")
print(f"✓ Output path: {OUTPUT_PATH}")

# STEP 4: EXTRACT DATASET FILES
print("\nSTEP 4: Extracting dataset files...")

def extract_datasets():
    """Extract all tgz files"""
    for dataset_name, filename in datasets.items():
        file_path = os.path.join(BASE_PATH, filename)

        if os.path.exists(file_path):
            print(f"  Extracting {filename}...")
            try:
                with tarfile.open(file_path, 'r:gz') as tar:
                    tar.extractall(path=EXTRACT_PATH)
                print(f"  ✓ {filename} extracted successfully")
            except Exception as e:
                print(f"  ⚠ Error extracting {filename}: {e}")
        else:
            print(f"  ⚠ {filename} not found at {file_path}")

extract_datasets()

# STEP 5: FIND ALL WAV FILES
print("\nSTEP 5: Finding all WAV files...")

def find_wav_files(extract_path=None):
    """Find all WAV files and organize them by dataset and label.

    Args:
        extract_path: Folder the archives were unpacked into. Defaults to the
            module-level ``EXTRACT_PATH``.

    Returns:
        dict with three entries:
            'progression_train': {'decline': [...], 'no_decline': [...]}
            'progression_test':  [...]
            'diagnosis_train':   {'ad': [...], 'cn': [...]}
        Every list holds full paths of files whose name ends in '.wav', sorted
        by file name. A folder that does not exist contributes an empty list
        and prints nothing.
    """
    root = EXTRACT_PATH if extract_path is None else extract_path

    def _collect(folder, description):
        """List the '.wav' files directly inside `folder` and report the count."""
        if not os.path.exists(folder):
            return []
        # os.listdir returns names in arbitrary order; sorting keeps the
        # processing order (and the rows of the output CSVs) stable across runs.
        names = sorted(f for f in os.listdir(folder) if f.endswith('.wav'))
        print(f"  Found {len(names)} {description} WAV files")
        return [os.path.join(folder, f) for f in names]

    prog_train_base = os.path.join(root, "ADReSSo21/progression/train/audio/")
    prog_test_path = os.path.join(root, "ADReSSo21/progression/test-dist/audio/")
    diag_train_base = os.path.join(root, "ADReSSo21/diagnosis/train/audio/")

    # The dict literal is evaluated top to bottom, so the "Found ..." lines are
    # printed in the order decline, no_decline, test, AD, CN.
    return {
        'progression_train': {
            'decline': _collect(os.path.join(prog_train_base, "decline/"), "decline"),
            'no_decline': _collect(os.path.join(prog_train_base, "no_decline/"), "no_decline"),
        },
        'progression_test': _collect(prog_test_path, "test"),
        'diagnosis_train': {
            'ad': _collect(os.path.join(diag_train_base, "ad/"), "AD"),
            'cn': _collect(os.path.join(diag_train_base, "cn/"), "CN"),
        },
    }

wav_files = find_wav_files()

# STEP 6: AUDIO PREPROCESSING FUNCTIONS
print("\nSTEP 6: Setting up audio preprocessing...")

def preprocess_audio(audio_path, target_sr=16000):
    """Load an audio file and prepare its samples for speech recognition.

    The file is loaded with ``librosa.load`` at ``target_sr``, passed through
    ``librosa.util.normalize`` and then ``librosa.effects.trim`` with
    ``top_db=20``.

    Args:
        audio_path: Path of the audio file to read.
        target_sr: Sample rate requested from ``librosa.load``.

    Returns:
        ``(samples, target_sr)`` on success, ``(None, None)`` if any step
        raised (the error is printed).
    """
    try:
        waveform, _ = librosa.load(audio_path, sr=target_sr)
        scaled = librosa.util.normalize(waveform)
        # trim() returns (trimmed_signal, interval); only the signal is needed.
        trimmed = librosa.effects.trim(scaled, top_db=20)[0]
    except Exception as e:
        print(f"    Error preprocessing {audio_path}: {e}")
        return None, None

    return trimmed, target_sr

def convert_to_wav_if_needed(audio_path):
    """Return a path to a WAV version of `audio_path`, converting if necessary.

    Args:
        audio_path: Path of the input audio file (string).

    Returns:
        `audio_path` itself when its name already ends in '.wav' (any letter
        case); otherwise the path of a newly exported
        '<name without extension>_converted.wav' written next to the input.
        If anything fails the error is printed and `audio_path` is returned
        unchanged.
    """
    try:
        # Case-insensitive so "clip.WAV" is not re-encoded for no reason.
        if audio_path.lower().endswith('.wav'):
            return audio_path

        # Convert using pydub
        audio = AudioSegment.from_file(audio_path)
        # os.path.splitext only strips an extension from the last path
        # component, so a dot in a directory name ("/rec.v1/sample") is never
        # mistaken for the file extension.
        wav_path = os.path.splitext(audio_path)[0] + '_converted.wav'
        audio.export(wav_path, format="wav")
        return wav_path
    except Exception as e:
        print(f"    Error converting {audio_path}: {e}")
        return audio_path

# STEP 7: SPEECH RECOGNITION FUNCTION
print("\nSTEP 7: Setting up speech recognition...")

def extract_transcript_from_audio(audio_path, method='google'):
    """Extract a transcript from an audio file using speech recognition.

    The audio is converted to WAV if needed, preprocessed, written to a
    temporary WAV file and read back in full for recognition. Google Speech
    Recognition is tried when ``method == 'google'``; Sphinx is tried whenever
    no transcript has been obtained yet.

    Args:
        audio_path: Path of the audio file to transcribe.
        method: Preferred recognizer; only 'google' selects the Google attempt.

    Returns:
        ``(transcript, method_used)`` on success, where ``method_used`` is
        ``method`` or 'sphinx'; ``(None, error_message)`` on failure.
    """
    import tempfile  # only needed here; kept local so the block is self-contained

    recognizer = sr.Recognizer()
    temp_wav = None

    try:
        # Convert to WAV if needed
        wav_path = convert_to_wav_if_needed(audio_path)

        # Preprocess audio
        audio_data, sr_rate = preprocess_audio(wav_path, target_sr=16000)

        if audio_data is None:
            return None, "Preprocessing failed"

        # Save preprocessed audio temporarily. A fresh file in the system temp
        # directory is used instead of deriving a name from `audio_path`: a
        # derived name can coincide with the source path (e.g. when the path
        # contains no '.wav'), which would overwrite and then delete the source.
        fd, temp_wav = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        sf.write(temp_wav, audio_data, sr_rate)

        # Read the WHOLE file. `record()` captures everything up to the end of
        # the source, whereas `listen()` returns after the first phrase, which
        # truncates the transcript. No ambient-noise calibration is done: it is
        # meant for live input and would read past the first 0.5 s of the file.
        # NOTE(review): very long recordings may be rejected by the online
        # recognizer — confirm against the dataset's clip lengths.
        with sr.AudioFile(temp_wav) as source:
            audio = recognizer.record(source)

        # Try different recognition methods
        transcript = None
        errors = []

        if method == 'google':
            try:
                transcript = recognizer.recognize_google(audio)
            except sr.UnknownValueError:
                errors.append("Google Speech Recognition could not understand audio")
            except sr.RequestError as e:
                errors.append(f"Google Speech Recognition error: {e}")

        # Fallback to other methods if Google fails
        if transcript is None:
            try:
                transcript = recognizer.recognize_sphinx(audio)
                method = 'sphinx'
            except sr.UnknownValueError:
                errors.append("Sphinx could not understand audio")
            except sr.RequestError as e:
                errors.append(f"Sphinx error: {e}")

        if transcript:
            return transcript.strip(), method
        # An empty transcript with no recorded error still needs a reason.
        return None, "; ".join(errors) or "No speech recognized"

    except Exception as e:
        return None, f"Error processing audio: {str(e)}"

    finally:
        # Clean up the temporary file on every path, including errors.
        if temp_wav is not None and os.path.exists(temp_wav):
            try:
                os.remove(temp_wav)
            except OSError:
                pass

# STEP 8: PROCESS ALL AUDIO FILES AND EXTRACT TRANSCRIPTS
print("\nSTEP 8: Processing audio files and extracting transcripts...")
print("This may take a while depending on the number and length of audio files...")

def process_audio_files(wav_files, transcribe=None):
    """Process all audio files and extract transcripts.

    Args:
        wav_files: Mapping shaped like the result of ``find_wav_files``:
            ``{'progression_train': {'decline': [...], 'no_decline': [...]},
               'progression_test': [...],
               'diagnosis_train': {'ad': [...], 'cn': [...]}}``.
        transcribe: Callable taking an audio path and returning
            ``(transcript, method_or_error)``. Defaults to
            ``extract_transcript_from_audio``.

    Returns:
        List of dicts, one per audio file, with the keys 'file_id',
        'file_path', 'dataset', 'label', 'transcript', 'recognition_method',
        'error' and 'success'. Order: progression train (decline, no_decline),
        progression test, diagnosis train (ad, cn).
    """
    if transcribe is None:
        transcribe = extract_transcript_from_audio

    all_transcripts = []

    def _process_group(files, dataset, label):
        """Transcribe one list of files and append one record per file."""
        print(f"    Processing {len(files)} {label} files...")

        for i, audio_path in enumerate(files):
            file_name = os.path.basename(audio_path)
            print(f"      Processing {i+1}/{len(files)}: {file_name}")

            transcript, method_or_error = transcribe(audio_path)

            # The second value is the recognizer name on success and the error
            # text on failure, so it is routed to exactly one of two columns.
            all_transcripts.append({
                'file_id': os.path.splitext(file_name)[0],
                'file_path': audio_path,
                'dataset': dataset,
                'label': label,
                'transcript': transcript,
                'recognition_method': method_or_error if transcript else None,
                'error': None if transcript else method_or_error,
                'success': transcript is not None
            })

    # Process progression training data
    print("\n  Processing progression training data...")
    for label in ['decline', 'no_decline']:
        _process_group(wav_files['progression_train'][label], 'progression_train', label)

    # Process progression test data (unlabelled, recorded with label 'test')
    print("\n  Processing progression test data...")
    _process_group(wav_files['progression_test'], 'progression_test', 'test')

    # Process diagnosis training data
    print("\n  Processing diagnosis training data...")
    for label in ['ad', 'cn']:
        _process_group(wav_files['diagnosis_train'][label], 'diagnosis_train', label)

    return all_transcripts

# Process all files
transcripts = process_audio_files(wav_files)

# STEP 9: SAVE RESULTS
print("\nSTEP 9: Saving transcription results...")

# Convert to DataFrame. The columns are listed explicitly (same order as the
# record dicts) so that a run which found no audio files still produces a frame
# with the expected columns instead of failing on df['success'] below.
df = pd.DataFrame(transcripts, columns=[
    'file_id', 'file_path', 'dataset', 'label',
    'transcript', 'recognition_method', 'error', 'success'
])
success_mask = df['success'].astype(bool)

# Save complete results
complete_output = os.path.join(OUTPUT_PATH, "all_transcripts.csv")
df.to_csv(complete_output, index=False)
print(f"✓ Saved complete results to: {complete_output}")

# Save successful transcripts only
successful_df = df[success_mask].copy()
success_output = os.path.join(OUTPUT_PATH, "successful_transcripts.csv")
successful_df.to_csv(success_output, index=False)
print(f"✓ Saved successful transcripts to: {success_output}")

# Save by dataset
datasets_to_save = df['dataset'].unique()
for dataset in datasets_to_save:
    dataset_df = df[df['dataset'] == dataset].copy()
    dataset_output = os.path.join(OUTPUT_PATH, f"{dataset}_transcripts.csv")
    dataset_df.to_csv(dataset_output, index=False)
    print(f"✓ Saved {dataset} transcripts to: {dataset_output}")

# STEP 10: DISPLAY SUMMARY STATISTICS
print("\nSTEP 10: Summary Statistics")
print("="*50)


def _percent(part, whole):
    """Return `part` as a percentage of `whole`, or 0.0 when `whole` is 0."""
    return part / whole * 100 if whole else 0.0


total_files = len(df)
successful = len(successful_df)
failed = total_files - successful

print(f"Total audio files processed: {total_files}")
print(f"Successful transcriptions: {successful} ({_percent(successful, total_files):.1f}%)")
print(f"Failed transcriptions: {failed} ({_percent(failed, total_files):.1f}%)")

print(f"\nDataset breakdown:")
for dataset in df['dataset'].unique():
    in_dataset = df['dataset'] == dataset
    dataset_total = int(in_dataset.sum())
    dataset_success = int((in_dataset & success_mask).sum())
    print(f"  {dataset}: {dataset_success}/{dataset_total} successful ({_percent(dataset_success, dataset_total):.1f}%)")

print(f"\nLabel distribution (successful transcripts only):")
if not successful_df.empty:
    print(successful_df['label'].value_counts())

print(f"\nRecognition methods used:")
if not successful_df.empty:
    print(successful_df['recognition_method'].value_counts())

# Show sample transcripts (first 200 characters of up to three)
print(f"\nSample successful transcripts:")
sample_transcripts = successful_df['transcript'].dropna().head(3)
for i, transcript in enumerate(sample_transcripts):
    print(f"  Sample {i+1}: {transcript[:200]}...")

# Show common errors (five most frequent messages)
print(f"\nMost common errors:")
error_df = df[~success_mask]
if not error_df.empty:
    error_counts = error_df['error'].value_counts().head(5)
    for error, count in error_counts.items():
        print(f"  {error}: {count} files")

print("\n" + "="*60)
print("TRANSCRIPT EXTRACTION COMPLETE!")
print(f"All results saved in: {OUTPUT_PATH}")
print("="*60)

ADReSSo21 AUDIO TRANSCRIPT EXTRACTOR

STEP 1: Mounting Google Drive...
Mounted at /content/drive
✓ Google Drive mounted successfully!

STEP 2: Installing required packages...
Installing speech recognition and audio processing libraries...
✓ Packages ready (make sure to install them first)

STEP 3: Setting up paths and configuration...
✓ Base path: /content/drive/MyDrive/Voice/
✓ Extract path: /content/drive/MyDrive/Voice/extracted/
✓ Output path: /content/drive/MyDrive/Voice/transcripts/

STEP 4: Extracting dataset files...
  Extracting ADReSSo21-progression-train.tgz...
  ✓ ADReSSo21-progression-train.tgz extracted successfully
  Extracting ADReSSo21-progression-test.tgz...
  ✓ ADReSSo21-progression-test.tgz extracted successfully
  Extracting ADReSSo21-diagnosis-train.tgz...
  ✓ ADReSSo21-diagnosis-train.tgz extracted successfully

STEP 5: Finding all WAV files...
  Found 15 decline WAV files
  Found 58 no_decline WAV files
  Found 32 test WAV files
  Found 87 AD WAV files
  Found 79

In [5]:
# Complete Audio-Text AD Classification Pipeline
# This script combines audio feature extraction, BERT processing, and DARTS classification

import os
import pandas as pd
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# For BERT processing
from transformers import AutoTokenizer, AutoModel
import re

print("="*80)
print("COMPREHENSIVE AD CLASSIFICATION PIPELINE")
print("Audio Features + BERT + DARTS Architecture")
print("="*80)

COMPREHENSIVE AD CLASSIFICATION PIPELINE
Audio Features + BERT + DARTS Architecture


In [6]:
# ============================================================================
# PART 1: ADVANCED AUDIO FEATURE EXTRACTION
# ============================================================================

class AudioFeatureExtractor:
    """Extracts MFCC, spectral and prosodic summary features from an audio file.

    Each ``extract_*`` method can be called on its own with just a path. When
    several are needed for the same file, ``extract_all_features`` decodes the
    file once and hands the samples to each of them.
    """

    def __init__(self, sr=16000, n_mfcc=13, n_fft=2048, hop_length=512):
        """Store the analysis settings.

        Args:
            sr: Sample rate requested when loading audio.
            n_mfcc: Number of MFCC coefficients.
            n_fft: FFT window size used for the MFCC computation.
            hop_length: Hop length used for the MFCC computation.
        """
        self.sr = sr
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length

    def _resolve_audio(self, audio_path, audio):
        """Return ``(samples, sample_rate)``, loading `audio_path` only if `audio` is None."""
        if audio is not None:
            return audio
        return librosa.load(audio_path, sr=self.sr)

    def extract_mfcc_features(self, audio_path, audio=None):
        """Extract comprehensive MFCC features including deltas.

        Args:
            audio_path: Path of the audio file (also used in error messages).
            audio: Optional pre-loaded ``(samples, sample_rate)`` pair; when
                given, the file is not read again.

        Returns:
            ``(features, combined_mfccs)`` where ``combined_mfccs`` stacks the
            MFCCs, their deltas and delta-deltas along axis 0, and ``features``
            maps 'mfcc_mean', 'mfcc_std', 'mfcc_min', 'mfcc_max',
            'mfcc_median', 'mfcc_skew' and 'mfcc_kurtosis' to one value per
            stacked row. ``(None, None)`` if anything raised.
        """
        try:
            # Load audio
            y, sr = self._resolve_audio(audio_path, audio)

            # Extract MFCC features
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc,
                                       n_fft=self.n_fft, hop_length=self.hop_length)

            # Extract delta features (first derivative)
            delta_mfccs = librosa.feature.delta(mfccs)

            # Extract delta-delta features (second derivative)
            delta2_mfccs = librosa.feature.delta(mfccs, order=2)

            # Combine all MFCC features
            combined_mfccs = np.concatenate([mfccs, delta_mfccs, delta2_mfccs], axis=0)

            # Statistics over axis 1 for each stacked coefficient row
            features = {}
            features['mfcc_mean'] = np.mean(combined_mfccs, axis=1)
            features['mfcc_std'] = np.std(combined_mfccs, axis=1)
            features['mfcc_min'] = np.min(combined_mfccs, axis=1)
            features['mfcc_max'] = np.max(combined_mfccs, axis=1)
            features['mfcc_median'] = np.median(combined_mfccs, axis=1)
            features['mfcc_skew'] = self._calculate_skewness(combined_mfccs)
            features['mfcc_kurtosis'] = self._calculate_kurtosis(combined_mfccs)

            return features, combined_mfccs

        except Exception as e:
            print(f"Error extracting MFCC from {audio_path}: {e}")
            return None, None

    def extract_spectral_features(self, audio_path, audio=None):
        """Extract spectral features.

        Note that these librosa calls are made without ``n_fft``/``hop_length``,
        so they do not use ``self.n_fft`` / ``self.hop_length``.

        Args:
            audio_path: Path of the audio file (also used in error messages).
            audio: Optional pre-loaded ``(samples, sample_rate)`` pair.

        Returns:
            dict with mean/std of spectral centroid, bandwidth, rolloff and
            zero-crossing rate, per-row mean/std of the chroma matrix, and
            'tempo'. An empty dict if anything raised.
        """
        try:
            y, sr = self._resolve_audio(audio_path, audio)

            features = {}

            # Spectral centroid
            spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
            features['spectral_centroid_mean'] = np.mean(spectral_centroid)
            features['spectral_centroid_std'] = np.std(spectral_centroid)

            # Spectral bandwidth
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=y, sr=sr)[0]
            features['spectral_bandwidth_mean'] = np.mean(spectral_bandwidth)
            features['spectral_bandwidth_std'] = np.std(spectral_bandwidth)

            # Spectral rolloff
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
            features['spectral_rolloff_mean'] = np.mean(spectral_rolloff)
            features['spectral_rolloff_std'] = np.std(spectral_rolloff)

            # Zero crossing rate
            zcr = librosa.feature.zero_crossing_rate(y)[0]
            features['zcr_mean'] = np.mean(zcr)
            features['zcr_std'] = np.std(zcr)

            # Chroma features
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            features['chroma_mean'] = np.mean(chroma, axis=1)
            features['chroma_std'] = np.std(chroma, axis=1)

            # Tempo
            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
            features['tempo'] = tempo

            return features

        except Exception as e:
            print(f"Error extracting spectral features from {audio_path}: {e}")
            return {}

    def extract_prosodic_features(self, audio_path, audio=None):
        """Extract prosodic features (pitch, energy, etc.).

        Args:
            audio_path: Path of the audio file (also used in error messages).
            audio: Optional pre-loaded ``(samples, sample_rate)`` pair.

        Returns:
            dict with f0 statistics (all 0 when no frame has a pitch estimate),
            mean/std of RMS energy, and per-row mean/std of spectral contrast.
            An empty dict if anything raised.
        """
        try:
            y, sr = self._resolve_audio(audio_path, audio)

            features = {}

            # Fundamental frequency (pitch)
            f0, voiced_flag, voiced_probs = librosa.pyin(y, fmin=librosa.note_to_hz('C2'),
                                                       fmax=librosa.note_to_hz('C7'))

            # Remove NaN values before computing statistics
            f0_clean = f0[~np.isnan(f0)]
            if len(f0_clean) > 0:
                features['f0_mean'] = np.mean(f0_clean)
                features['f0_std'] = np.std(f0_clean)
                features['f0_min'] = np.min(f0_clean)
                features['f0_max'] = np.max(f0_clean)
                features['f0_range'] = np.max(f0_clean) - np.min(f0_clean)
            else:
                features.update({
                    'f0_mean': 0, 'f0_std': 0, 'f0_min': 0,
                    'f0_max': 0, 'f0_range': 0
                })

            # RMS energy
            rms = librosa.feature.rms(y=y)[0]
            features['rms_mean'] = np.mean(rms)
            features['rms_std'] = np.std(rms)

            # Spectral contrast
            contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
            features['contrast_mean'] = np.mean(contrast, axis=1)
            features['contrast_std'] = np.std(contrast, axis=1)

            return features

        except Exception as e:
            print(f"Error extracting prosodic features from {audio_path}: {e}")
            return {}

    @staticmethod
    def _standardize_rows(data):
        """Subtract each row's mean and divide by its standard deviation."""
        mean = np.mean(data, axis=1, keepdims=True)
        std = np.std(data, axis=1, keepdims=True)
        std[std == 0] = 1  # Avoid division by zero for constant rows
        return (data - mean) / std

    def _calculate_skewness(self, data):
        """Calculate skewness for each row (mean of the cubed standardized values)."""
        return np.mean(self._standardize_rows(data)**3, axis=1)

    def _calculate_kurtosis(self, data):
        """Calculate excess kurtosis for each row (mean fourth standardized power minus 3)."""
        return np.mean(self._standardize_rows(data)**4, axis=1) - 3

    def extract_all_features(self, audio_path):
        """Extract all audio features for one file.

        The file is decoded once and the samples are shared by the three
        extractors, instead of each extractor reading the file again.

        Args:
            audio_path: Path of the audio file.

        Returns:
            ``(flattened_features, mfcc_matrix)``. 1-D array features are
            expanded into scalar entries named '<key>_<index>'; arrays of any
            other dimensionality are replaced by their mean; scalars are kept.
            ``mfcc_matrix`` is the stacked MFCC matrix, or None if MFCC
            extraction failed. If the file cannot be loaded the result is
            ``({}, None)``.
        """
        try:
            audio = librosa.load(audio_path, sr=self.sr)
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            return {}, None

        all_features = {}

        # MFCC features
        mfcc_features, mfcc_matrix = self.extract_mfcc_features(audio_path, audio=audio)
        if mfcc_features:
            all_features.update(mfcc_features)

        # Spectral features
        all_features.update(self.extract_spectral_features(audio_path, audio=audio))

        # Prosodic features
        all_features.update(self.extract_prosodic_features(audio_path, audio=audio))

        # Flatten nested arrays
        flattened_features = {}
        for key, value in all_features.items():
            if isinstance(value, np.ndarray):
                if value.ndim == 1:
                    for i, v in enumerate(value):
                        flattened_features[f"{key}_{i}"] = v
                else:
                    flattened_features[key] = np.mean(value)
            else:
                flattened_features[key] = value

        return flattened_features, mfcc_matrix


In [7]:

# ============================================================================
# PART 2: BERT TEXT PROCESSING
# ============================================================================

class BERTTextProcessor:
    """Turns transcript text into a fixed-size embedding plus simple text statistics."""

    # Values reported when there is no usable text (or feature extraction fails).
    _EMPTY_LINGUISTIC_FEATURES = {
        'word_count': 0, 'char_count': 0, 'avg_word_length': 0,
        'sentence_count': 0, 'question_count': 0, 'exclamation_count': 0
    }

    def __init__(self, model_name='bert-base-uncased', max_length=512):
        """Load the tokenizer and model and switch the model to eval mode.

        Args:
            model_name: Name passed to ``AutoTokenizer`` / ``AutoModel``.
            max_length: Token length used for padding and truncation.
        """
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def _embedding_dim(self):
        """Size of the fallback zero vector: the model's hidden size, or 768 if unknown."""
        return getattr(self.model.config, 'hidden_size', 768)

    def preprocess_text(self, text):
        """Clean and preprocess text.

        Lower-cases the text, drops every character that is not a letter,
        digit, whitespace or one of ``. , ! ?``, and collapses whitespace runs
        to single spaces. Missing values (None / NaN) become "".
        """
        if pd.isna(text) or text is None:
            return ""

        # Convert to string and clean
        text = str(text).lower()

        # Remove special characters but keep spaces and basic punctuation
        text = re.sub(r'[^a-zA-Z0-9\s\.\,\!\?]', '', text)

        # Remove extra whitespace
        text = ' '.join(text.split())

        return text

    def extract_bert_features(self, text):
        """Extract a single embedding vector for `text`.

        The vector is the average of the first-token ([CLS]) embedding and the
        attention-mask-weighted mean of all token embeddings from the model's
        last hidden state.

        Returns:
            1-D numpy array. For empty text, or if anything raises, a zero
            vector sized to the loaded model's hidden size (so it matches the
            real embeddings even when ``model_name`` is not a 768-wide model).
        """
        try:
            # Preprocess text
            clean_text = self.preprocess_text(text)

            if not clean_text:
                # Return zero vector for empty text
                return np.zeros(self._embedding_dim())

            # Tokenize
            inputs = self.tokenizer(
                clean_text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Extract features
            with torch.no_grad():
                outputs = self.model(**inputs)

                # Use [CLS] token embedding (first token)
                cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze()

                # Also compute mean pooling of all tokens; padded positions are
                # zeroed by the mask and excluded from the divisor.
                attention_mask = inputs['attention_mask']
                token_embeddings = outputs.last_hidden_state
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                mean_embedding = sum_embeddings / sum_mask
                mean_embedding = mean_embedding.squeeze()

                # Combine CLS and mean pooling
                combined_embedding = (cls_embedding + mean_embedding) / 2

                return combined_embedding.numpy()

        except Exception as e:
            print(f"Error extracting BERT features: {e}")
            return np.zeros(self._embedding_dim())

    def extract_linguistic_features(self, text):
        """Extract basic linguistic features from the cleaned text.

        Returns:
            dict with 'word_count' (whitespace-separated tokens), 'char_count',
            'avg_word_length', 'sentence_count' (non-blank pieces after
            splitting on '.'), 'question_count' and 'exclamation_count'.
            All zeros for empty text or if anything raises.
        """
        try:
            clean_text = self.preprocess_text(text)

            if not clean_text:
                return dict(self._EMPTY_LINGUISTIC_FEATURES)

            words = clean_text.split()
            sentences = clean_text.split('.')

            features = {
                'word_count': len(words),
                'char_count': len(clean_text),
                'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
                'sentence_count': len([s for s in sentences if s.strip()]),
                'question_count': clean_text.count('?'),
                'exclamation_count': clean_text.count('!')
            }

            return features

        except Exception as e:
            print(f"Error extracting linguistic features: {e}")
            return dict(self._EMPTY_LINGUISTIC_FEATURES)



In [9]:
# ============================================================================
# PART 3: IMPROVED DARTS ARCHITECTURE
# ============================================================================

class ImprovedDARTSCell(nn.Module):
    """A cell that outputs a learned weighted mix of eight candidate operations.

    The input is first projected to ``output_dim`` (identity when the sizes
    already match). Every candidate operation maps ``output_dim`` to
    ``output_dim``; their outputs are combined with weights derived from the
    learnable ``alpha`` logits and the learnable ``temperature``.
    """

    # Lower bound applied to the learnable temperature before it is used as a
    # divisor. Without it an optimizer step can drive the temperature to zero
    # or below, which yields NaN or inverted mixing weights.
    # NOTE: while the raw parameter sits below this bound the clamp passes no
    # gradient to it.
    MIN_TEMPERATURE = 1e-3

    def __init__(self, input_dim, output_dim, num_ops=8):
        """Build the projection, the candidate operations and the mixing parameters.

        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Size of the last dimension of the output.
            num_ops: Accepted for backward compatibility; not used — the
                candidate set below is fixed.
        """
        super(ImprovedDARTSCell, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Ensure dimensions match for operations
        if input_dim != output_dim:
            self.projection = nn.Linear(input_dim, output_dim)
        else:
            self.projection = nn.Identity()

        # Define possible operations with proper dimensionality
        self.operations = nn.ModuleList([
            nn.Identity(),  # Skip connection
            nn.ReLU(),      # Activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU()),  # Linear + ReLU
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.Tanh()),  # Linear + Tanh
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Dropout(0.1)),  # With dropout
            nn.Sequential(nn.Linear(output_dim, output_dim // 2), nn.ReLU(), nn.Linear(output_dim // 2, output_dim)),  # Bottleneck
            nn.Sequential(nn.LayerNorm(output_dim), nn.ReLU()),  # Layer norm + activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Linear(output_dim, output_dim))  # Deep linear
        ])

        # Architecture parameters (alpha) - learnable weights for each operation
        self.alpha = nn.Parameter(torch.randn(len(self.operations)))

        # Temperature parameter for gumbel softmax (learnable)
        self.temperature = nn.Parameter(torch.ones(1))

    def _effective_temperature(self):
        """Temperature actually used for mixing: the parameter clamped to MIN_TEMPERATURE."""
        return self.temperature.clamp(min=self.MIN_TEMPERATURE)

    def forward(self, x):
        """Project `x`, run every candidate operation and return their weighted sum.

        In training mode the weights are sampled with Gumbel-Softmax
        (``hard=False``); in eval mode they are a plain softmax of
        ``alpha / temperature``.
        """
        # Project input to correct dimension
        x = self.projection(x)
        tau = self._effective_temperature()

        if self.training:
            # Use Gumbel Softmax during training
            gumbel_weights = F.gumbel_softmax(self.alpha, tau=tau, hard=False)
        else:
            # Use regular softmax during evaluation
            gumbel_weights = F.softmax(self.alpha / tau, dim=0)

        # Apply operations
        outputs = []
        for op in self.operations:
            try:
                outputs.append(op(x))
            except Exception:
                # Fallback to identity if operation fails
                outputs.append(x)

        # Weighted combination of all operations
        result = sum(w * out for w, out in zip(gumbel_weights, outputs))

        return result

    def get_selected_operation(self):
        """Return ``(index, module)`` of the operation with the largest alpha."""
        selected_idx = torch.argmax(self.alpha).item()
        return selected_idx, self.operations[selected_idx]

class MultimodalDARTSClassifier(nn.Module):
    """Audio + text classifier whose two branches are stacks of DARTS cells.

    Each modality is projected to ``hidden_dim``, refined by residual
    ``ImprovedDARTSCell`` blocks (3 for audio, 2 for text), combined through
    cross-modal attention, fused and classified.

    Args:
        audio_dim: Size of the last dimension of the audio feature vectors.
        text_dim: Size of the last dimension of the text feature vectors.
        hidden_dim: Width shared by both branches. The attention layer uses
            8 heads, so this must be divisible by 8.
        num_classes: Number of output logits.
    """

    def __init__(self, audio_dim, text_dim, hidden_dim=256, num_classes=2):
        super().__init__()

        # Audio processing branch with DARTS
        self.audio_projection = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.audio_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(3)
        ])

        # Text processing branch with DARTS
        self.text_projection = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.text_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(2)
        ])

        # Cross-modal attention
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        # Fusion and classification
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4)
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-initialise every ``nn.Linear`` weight and zero its bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, audio_features, text_features):
        """Compute class logits for a batch of paired audio/text features.

        Args:
            audio_features: Tensor whose last dimension is ``audio_dim``.
            text_features: Tensor whose last dimension is ``text_dim``.

        Returns:
            Logits produced by ``self.classifier``.
        """
        # Process audio through DARTS (each cell wrapped in a residual connection)
        audio_x = self.audio_projection(audio_features)
        for cell in self.audio_darts_cells:
            audio_x = audio_x + cell(audio_x)

        # Process text through DARTS (each cell wrapped in a residual connection)
        text_x = self.text_projection(text_features)
        for cell in self.text_darts_cells:
            text_x = text_x + cell(text_x)

        # Cross-modal attention: audio is the query, text the key/value.
        attended, _ = self.cross_attention(
            audio_x.unsqueeze(1), text_x.unsqueeze(1), text_x.unsqueeze(1)
        )
        # Both sequences have length 1, so the softmax over the single key is
        # always 1 and `attended` is a function of the text value only. Without
        # the residual below the audio branch would never reach the classifier
        # (and would receive no gradient). Adding it keeps the parameter layout
        # unchanged while letting audio information flow into the fusion.
        audio_attended = audio_x + attended.squeeze(1)

        # Fusion
        fused = torch.cat([audio_attended, text_x], dim=1)
        fused = self.fusion(fused)

        # Classification
        return self.classifier(fused)

    def get_architecture_info(self):
        """Get information about the learned architecture.

        Returns:
            Dict keyed ``'audio_cell_<i>'`` / ``'text_cell_<i>'``; each value
            holds the arg-max operation index, the raw alpha values as a NumPy
            array, and the cell temperature as a float.
        """
        arch_info = {}
        branches = (
            ('audio', self.audio_darts_cells),
            ('text', self.text_darts_cells),
        )
        for prefix, cells in branches:
            for i, cell in enumerate(cells):
                selected_idx, _ = cell.get_selected_operation()
                arch_info[f'{prefix}_cell_{i}'] = {
                    'selected_operation_idx': selected_idx,
                    'operation_weights': cell.alpha.detach().cpu().numpy(),
                    'temperature': cell.temperature.item()
                }

        return arch_info


In [10]:
# ============================================================================
# PART 4: DATASET AND TRAINING UTILITIES
# ============================================================================

class ADDataset(Dataset):
    """In-memory dataset of aligned audio features, text features and labels.

    Args:
        audio_features: Per-sample audio features; converted to a float tensor.
        text_features: Per-sample text features; converted to a float tensor.
        labels: Per-sample class ids; converted to a long tensor.

    Raises:
        ValueError: If the three inputs do not contain the same number of samples.
    """

    def __init__(self, audio_features, text_features, labels):
        self.audio_features = torch.FloatTensor(audio_features)
        self.text_features = torch.FloatTensor(text_features)
        self.labels = torch.LongTensor(labels)

        # __len__ is driven by the labels only, so a length mismatch would
        # otherwise go unnoticed until indexing fails or samples are misaligned.
        n_audio = len(self.audio_features)
        n_text = len(self.text_features)
        n_labels = len(self.labels)
        if not (n_audio == n_text == n_labels):
            raise ValueError(
                "audio_features, text_features and labels must have the same length "
                f"(got {n_audio}, {n_text}, {n_labels})"
            )

    def __len__(self):
        """Number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(audio, text, label)`` for sample ``idx``."""
        return self.audio_features[idx], self.text_features[idx], self.labels[idx]

def visualize_features(audio_features, text_features, labels, feature_names=None):
    """Visualize feature distributions in a 2x2 figure.

    Panels: per-class histogram of the first audio feature, per-class histogram
    of the first text feature, correlation heatmap of up to the first 10 audio
    features, and the label counts.

    Args:
        audio_features: 2-D array-like, one row per sample.
        text_features: 2-D array-like, one row per sample.
        labels: 1-D array-like of class ids; 0 is shown as 'Control', 1 as 'AD'.
        feature_names: Currently unused; kept so existing callers keep working.
    """
    # Boolean masks such as `labels == 0` only work element-wise on arrays,
    # so normalise list inputs up front.
    audio_features = np.asarray(audio_features)
    text_features = np.asarray(text_features)
    labels = np.asarray(labels)

    class_names = {0: 'Control', 1: 'AD'}
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Per-class histograms of the first feature of each modality
    histogram_panels = (
        (axes[0, 0], audio_features, 'Audio Feature Distribution (First Feature)'),
        (axes[0, 1], text_features, 'Text Feature Distribution (First Feature)'),
    )
    for ax, features, title in histogram_panels:
        for class_id, class_name in class_names.items():
            ax.hist(features[labels == class_id][:, 0], alpha=0.7, label=class_name, bins=30)
        ax.set_title(title)
        ax.legend()

    # Feature correlation heatmap (subset)
    subset_audio = audio_features[:, :min(10, audio_features.shape[1])]
    # np.corrcoef collapses to a 0-d value for a single feature column;
    # the heatmap needs a 2-D matrix.
    corr_matrix = np.atleast_2d(np.corrcoef(subset_audio.T))
    sns.heatmap(corr_matrix, ax=axes[1, 0], cmap='coolwarm', center=0)
    axes[1, 0].set_title('Audio Feature Correlation (Subset)')

    # Label distribution: derive tick labels from the classes actually present
    # so a single-class (or unexpected-class) input does not break the bar plot.
    unique, counts = np.unique(labels, return_counts=True)
    tick_labels = [class_names.get(value, str(value)) for value in unique.tolist()]
    axes[1, 1].bar(tick_labels, counts)
    axes[1, 1].set_title('Label Distribution')

    plt.tight_layout()
    plt.show()

def plot_training_history(train_losses, val_losses, val_accuracies):
    """Show loss curves, validation accuracy and the val-minus-train loss gap.

    Args:
        train_losses: Per-epoch training loss values.
        val_losses: Per-epoch validation loss values (same length as above).
        val_accuracies: Per-epoch validation accuracy values.
    """
    _, (loss_ax, acc_ax, gap_ax) = plt.subplots(1, 3, figsize=(18, 5))

    # Left panel: both loss curves on the same axes
    loss_ax.plot(train_losses, label='Training Loss')
    loss_ax.plot(val_losses, label='Validation Loss')

    # Middle panel: validation accuracy
    acc_ax.plot(val_accuracies, label='Validation Accuracy', color='green')

    # Right panel: a growing positive gap hints at overfitting
    loss_gap = np.asarray(val_losses) - np.asarray(train_losses)
    gap_ax.plot(loss_gap, label='Val - Train Loss', color='red')

    # Shared decoration for all three panels
    panel_text = (
        (loss_ax, 'Training and Validation Loss', 'Loss'),
        (acc_ax, 'Validation Accuracy', 'Accuracy (%)'),
        (gap_ax, 'Overfitting Monitor', 'Loss Difference'),
    )
    for ax, title, y_label in panel_text:
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

def train_model(model, train_loader, val_loader, num_epochs=50, lr=0.001, patience=15):
    """Train the DARTS model with early stopping and restore the best weights.

    Uses AdamW (weight decay 0.01), ReduceLROnPlateau on the validation loss,
    cross-entropy loss and gradient-norm clipping at 1.0.

    Args:
        model: Module called as ``model(audio_batch, text_batch)`` returning logits.
        train_loader: Sized iterable of ``(audio, text, labels)`` batches.
        val_loader: Sized iterable of ``(audio, text, labels)`` batches.
        num_epochs: Maximum number of epochs.
        lr: Initial learning rate.
        patience: Epochs without validation-loss improvement before stopping.

    Returns:
        Tuple ``(train_losses, val_losses, val_accuracies)`` with one entry per
        completed epoch; accuracies are percentages.

    Raises:
        ValueError: If either loader contains no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        # The per-epoch averages below divide by the number of batches.
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    val_losses = []
    val_accuracies = []

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    print(f"Training on device: {device}")
    print(f"Training batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for audio_batch, text_batch, labels_batch in train_loader:
            audio_batch, text_batch, labels_batch = audio_batch.to(device), text_batch.to(device), labels_batch.to(device)

            optimizer.zero_grad()
            outputs = model(audio_batch, text_batch)
            loss = criterion(outputs, labels_batch)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            train_loss += loss.item()

            # Calculate training accuracy
            predicted = outputs.detach().argmax(dim=1)
            train_total += labels_batch.size(0)
            train_correct += (predicted == labels_batch).sum().item()

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for audio_batch, text_batch, labels_batch in val_loader:
                audio_batch, text_batch, labels_batch = audio_batch.to(device), text_batch.to(device), labels_batch.to(device)

                outputs = model(audio_batch, text_batch)
                loss = criterion(outputs, labels_batch)
                val_loss += loss.item()

                predicted = outputs.argmax(dim=1)
                val_total += labels_batch.size(0)
                val_correct += (predicted == labels_batch).sum().item()

        # Calculate metrics
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        val_accuracy = 100 * val_correct / val_total
        train_accuracy = 100 * train_correct / train_total

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Early stopping bookkeeping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() hands out references to the live parameter tensors,
            # and dict.copy() is shallow, so the tensors must be cloned;
            # otherwise the snapshot keeps changing as training continues and
            # the final restore silently loads the last-epoch weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        # Print progress
        if epoch % 5 == 0 or epoch == num_epochs - 1:
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')
            print(f'  LR: {current_lr:.6f}, Patience: {patience_counter}/{patience}')
            print()

        # Early stopping condition
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with validation loss: {best_val_loss:.4f}")

    return train_losses, val_losses, val_accuracies

def evaluate_model(model, test_loader, class_names=['Control', 'AD']):
    """Evaluate the trained model and return comprehensive metrics.

    Prints accuracy, AUC, a classification report and the confusion matrix,
    and shows the confusion matrix as a heatmap.

    Args:
        model: Module called as ``model(audio_batch, text_batch)`` returning logits.
        test_loader: Iterable of ``(audio, text, labels)`` batches.
        class_names: Display names; position ``i`` is used for class id ``i``
            (assumes labels are the integer ids ``0..len(class_names)-1``).
            The default list is never mutated.

    Returns:
        Dict with keys ``'accuracy'``, ``'auc_score'``, ``'predictions'``,
        ``'labels'``, ``'probabilities'``, ``'confusion_matrix'`` and
        ``'classification_report'``. ``'auc_score'`` is NaN when the labels
        contain a single class and 0 when the model is not binary.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Batches are moved to `device` below, so the model has to live there too;
    # otherwise a CPU-resident model fails on a CUDA machine.
    model.to(device)
    model.eval()

    all_predictions = []
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        for audio_batch, text_batch, labels_batch in test_loader:
            audio_batch, text_batch, labels_batch = audio_batch.to(device), text_batch.to(device), labels_batch.to(device)

            outputs = model(audio_batch, text_batch)
            probabilities = F.softmax(outputs, dim=1)
            predicted = outputs.argmax(dim=1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels_batch.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    if len(all_labels) == 0:
        raise ValueError("test_loader produced no samples to evaluate")

    # Fix the class set explicitly so the report and the confusion matrix keep
    # one row/column per class name even when a class is absent from this split.
    class_ids = list(range(len(class_names)))

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)

    # AUC score
    all_probabilities = np.array(all_probabilities)
    if all_probabilities.shape[1] != 2:
        auc_score = 0
    elif len(np.unique(all_labels)) < 2:
        # ROC AUC is undefined when only one class is present in the labels.
        auc_score = float('nan')
    else:
        auc_score = roc_auc_score(all_labels, all_probabilities[:, 1])

    # Classification report
    report = classification_report(
        all_labels, all_predictions,
        labels=class_ids, target_names=class_names, zero_division=0
    )

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)

    # Print results
    print("="*50)
    print("MODEL EVALUATION RESULTS")
    print("="*50)
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"AUC Score: {auc_score:.4f}")
    print("\nClassification Report:")
    print(report)
    print("\nConfusion Matrix:")
    print(cm)

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

    # Return metrics for further analysis
    return {
        'accuracy': accuracy,
        'auc_score': auc_score,
        'predictions': all_predictions,
        'labels': all_labels,
        'probabilities': all_probabilities,
        'confusion_matrix': cm,
        'classification_report': report
    }