<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/ADReSSo21.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install SpeechRecognition pydub

Collecting SpeechRecognition
  Downloading speechrecognition-3.14.3-py3-none-any.whl.metadata (30 kB)
Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading speechrecognition-3.14.3-py3-none-any.whl (32.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m32.9/32.9 MB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub, SpeechRecognition
Successfully installed SpeechRecognition-3.14.3 pydub-0.25.1


In [2]:
# Step-by-Step Audio Transcript Extractor for ADReSSo21 Dataset
# This script will:
# 1. Mount Google Drive
# 2. Extract dataset files
# 3. Find all WAV files
# 4. Extract transcripts from audio using speech recognition
# 5. Save organized transcripts

import os
import tarfile
import pandas as pd
import numpy as np
from pathlib import Path
import librosa
import speech_recognition as sr
import soundfile as sf
from pydub import AudioSegment
import warnings
warnings.filterwarnings('ignore')

print("="*60)
print("ADReSSo21 AUDIO TRANSCRIPT EXTRACTOR")
print("="*60)

# STEP 1: MOUNT GOOGLE DRIVE
print("\nSTEP 1: Mounting Google Drive...")
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("✓ Google Drive mounted successfully!")
except:
    print("⚠ Not running in Colab or Drive already mounted")

# STEP 2: INSTALL REQUIRED PACKAGES
print("\nSTEP 2: Installing required packages...")
print("Installing speech recognition and audio processing libraries...")

# Install packages (run once)
!pip install SpeechRecognition
!pip install pydub
!pip install librosa
!pip install soundfile
!apt-get install -y ffmpeg

print("✓ Packages ready (make sure to install them first)")

# STEP 3: SET UP PATHS AND CONFIGURATION
print("\nSTEP 3: Setting up paths and configuration...")

BASE_PATH = "/content/drive/MyDrive/Voice/"
EXTRACT_PATH = "/content/drive/MyDrive/Voice/extracted/"
OUTPUT_PATH = "/content/drive/MyDrive/Voice/transcripts/"

# Create directories
os.makedirs(EXTRACT_PATH, exist_ok=True)
os.makedirs(OUTPUT_PATH, exist_ok=True)

datasets = {
    'progression_train': 'ADReSSo21-progression-train.tgz',
    'progression_test': 'ADReSSo21-progression-test.tgz',
    'diagnosis_train': 'ADReSSo21-diagnosis-train.tgz'
}

print(f"✓ Base path: {BASE_PATH}")
print(f"✓ Extract path: {EXTRACT_PATH}")
print(f"✓ Output path: {OUTPUT_PATH}")

# STEP 4: EXTRACT DATASET FILES
print("\nSTEP 4: Extracting dataset files...")

def extract_datasets():
    """Extract all tgz files"""
    for dataset_name, filename in datasets.items():
        file_path = os.path.join(BASE_PATH, filename)

        if os.path.exists(file_path):
            print(f"  Extracting {filename}...")
            try:
                with tarfile.open(file_path, 'r:gz') as tar:
                    tar.extractall(path=EXTRACT_PATH)
                print(f"  ✓ {filename} extracted successfully")
            except Exception as e:
                print(f"  ⚠ Error extracting {filename}: {e}")
        else:
            print(f"  ⚠ {filename} not found at {file_path}")

extract_datasets()

# STEP 5: FIND ALL WAV FILES
print("\nSTEP 5: Finding all WAV files...")

def find_wav_files():
    """Find all WAV files and organize by dataset and label"""
    wav_files = {
        'progression_train': {'decline': [], 'no_decline': []},
        'progression_test': [],
        'diagnosis_train': {'ad': [], 'cn': []}
    }

    # Progression training files
    prog_train_base = os.path.join(EXTRACT_PATH, "ADReSSo21/progression/train/audio/")

    # Decline cases
    decline_path = os.path.join(prog_train_base, "decline/")
    if os.path.exists(decline_path):
        decline_wavs = [f for f in os.listdir(decline_path) if f.endswith('.wav')]
        wav_files['progression_train']['decline'] = [os.path.join(decline_path, f) for f in decline_wavs]
        print(f"  Found {len(decline_wavs)} decline WAV files")

    # No decline cases
    no_decline_path = os.path.join(prog_train_base, "no_decline/")
    if os.path.exists(no_decline_path):
        no_decline_wavs = [f for f in os.listdir(no_decline_path) if f.endswith('.wav')]
        wav_files['progression_train']['no_decline'] = [os.path.join(no_decline_path, f) for f in no_decline_wavs]
        print(f"  Found {len(no_decline_wavs)} no_decline WAV files")

    # Progression test files
    prog_test_path = os.path.join(EXTRACT_PATH, "ADReSSo21/progression/test-dist/audio/")
    if os.path.exists(prog_test_path):
        test_wavs = [f for f in os.listdir(prog_test_path) if f.endswith('.wav')]
        wav_files['progression_test'] = [os.path.join(prog_test_path, f) for f in test_wavs]
        print(f"  Found {len(test_wavs)} test WAV files")

    # Diagnosis training files
    diag_train_base = os.path.join(EXTRACT_PATH, "ADReSSo21/diagnosis/train/audio/")

    # AD cases
    ad_path = os.path.join(diag_train_base, "ad/")
    if os.path.exists(ad_path):
        ad_wavs = [f for f in os.listdir(ad_path) if f.endswith('.wav')]
        wav_files['diagnosis_train']['ad'] = [os.path.join(ad_path, f) for f in ad_wavs]
        print(f"  Found {len(ad_wavs)} AD WAV files")

    # CN cases
    cn_path = os.path.join(diag_train_base, "cn/")
    if os.path.exists(cn_path):
        cn_wavs = [f for f in os.listdir(cn_path) if f.endswith('.wav')]
        wav_files['diagnosis_train']['cn'] = [os.path.join(cn_path, f) for f in cn_wavs]
        print(f"  Found {len(cn_wavs)} CN WAV files")

    return wav_files

wav_files = find_wav_files()

# STEP 6: AUDIO PREPROCESSING FUNCTIONS
print("\nSTEP 6: Setting up audio preprocessing...")

def preprocess_audio(audio_path, target_sr=16000):
    """Preprocess audio file for speech recognition"""
    try:
        # Load audio with librosa
        audio, sr = librosa.load(audio_path, sr=target_sr)

        # Normalize audio
        audio = librosa.util.normalize(audio)

        # Remove silence
        audio_trimmed, _ = librosa.effects.trim(audio, top_db=20)

        return audio_trimmed, target_sr
    except Exception as e:
        print(f"    Error preprocessing {audio_path}: {e}")
        return None, None

def convert_to_wav_if_needed(audio_path):
    """Convert audio to WAV format if needed"""
    try:
        if not audio_path.endswith('.wav'):
            # Convert using pydub
            audio = AudioSegment.from_file(audio_path)
            wav_path = audio_path.rsplit('.', 1)[0] + '_converted.wav'
            audio.export(wav_path, format="wav")
            return wav_path
        return audio_path
    except Exception as e:
        print(f"    Error converting {audio_path}: {e}")
        return audio_path

# STEP 7: SPEECH RECOGNITION FUNCTION
print("\nSTEP 7: Setting up speech recognition...")

def extract_transcript_from_audio(audio_path, method='google'):
    """Extract transcript from audio file using speech recognition"""
    recognizer = sr.Recognizer()

    try:
        # Convert to WAV if needed
        wav_path = convert_to_wav_if_needed(audio_path)

        # Preprocess audio
        audio_data, sr_rate = preprocess_audio(wav_path, target_sr=16000)

        if audio_data is None:
            return None, "Preprocessing failed"

        # Save preprocessed audio temporarily
        temp_wav = audio_path.replace('.wav', '_temp.wav')
        sf.write(temp_wav, audio_data, sr_rate)

        # Use speech recognition
        with sr.AudioFile(temp_wav) as source:
            # Adjust for ambient noise
            recognizer.adjust_for_ambient_noise(source, duration=0.5)
            audio = recognizer.listen(source)

        # Try different recognition methods
        transcript = None
        error_msg = ""

        if method == 'google':
            try:
                transcript = recognizer.recognize_google(audio)
            except sr.UnknownValueError:
                error_msg = "Google Speech Recognition could not understand audio"
            except sr.RequestError as e:
                error_msg = f"Google Speech Recognition error: {e}"

        # Fallback to other methods if Google fails
        if transcript is None:
            try:
                transcript = recognizer.recognize_sphinx(audio)
                method = 'sphinx'
            except sr.UnknownValueError:
                error_msg += "; Sphinx could not understand audio"
            except sr.RequestError as e:
                error_msg += f"; Sphinx error: {e}"

        # Clean up temporary file
        if os.path.exists(temp_wav):
            os.remove(temp_wav)

        if transcript:
            return transcript.strip(), method
        else:
            return None, error_msg

    except Exception as e:
        return None, f"Error processing audio: {str(e)}"

# STEP 8: PROCESS ALL AUDIO FILES AND EXTRACT TRANSCRIPTS
print("\nSTEP 8: Processing audio files and extracting transcripts...")
print("This may take a while depending on the number and length of audio files...")

def process_audio_files(wav_files):
    """Process all audio files and extract transcripts"""
    all_transcripts = []

    # Process progression training data
    print("\n  Processing progression training data...")
    for label in ['decline', 'no_decline']:
        files = wav_files['progression_train'][label]
        print(f"    Processing {len(files)} {label} files...")

        for i, audio_path in enumerate(files):
            print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")

            transcript, method_or_error = extract_transcript_from_audio(audio_path)

            all_transcripts.append({
                'file_id': os.path.splitext(os.path.basename(audio_path))[0],
                'file_path': audio_path,
                'dataset': 'progression_train',
                'label': label,
                'transcript': transcript,
                'recognition_method': method_or_error if transcript else None,
                'error': None if transcript else method_or_error,
                'success': transcript is not None
            })

    # Process progression test data
    print("\n  Processing progression test data...")
    files = wav_files['progression_test']
    print(f"    Processing {len(files)} test files...")

    for i, audio_path in enumerate(files):
        print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")

        transcript, method_or_error = extract_transcript_from_audio(audio_path)

        all_transcripts.append({
            'file_id': os.path.splitext(os.path.basename(audio_path))[0],
            'file_path': audio_path,
            'dataset': 'progression_test',
            'label': 'test',
            'transcript': transcript,
            'recognition_method': method_or_error if transcript else None,
            'error': None if transcript else method_or_error,
            'success': transcript is not None
        })

    # Process diagnosis training data
    print("\n  Processing diagnosis training data...")
    for label in ['ad', 'cn']:
        files = wav_files['diagnosis_train'][label]
        print(f"    Processing {len(files)} {label} files...")

        for i, audio_path in enumerate(files):
            print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")

            transcript, method_or_error = extract_transcript_from_audio(audio_path)

            all_transcripts.append({
                'file_id': os.path.splitext(os.path.basename(audio_path))[0],
                'file_path': audio_path,
                'dataset': 'diagnosis_train',
                'label': label,
                'transcript': transcript,
                'recognition_method': method_or_error if transcript else None,
                'error': None if transcript else method_or_error,
                'success': transcript is not None
            })

    return all_transcripts

# Process all files
transcripts = process_audio_files(wav_files)

# STEP 9: SAVE RESULTS
print("\nSTEP 9: Saving transcription results...")

# Convert to DataFrame
df = pd.DataFrame(transcripts)

# Save complete results
complete_output = os.path.join(OUTPUT_PATH, "all_transcripts.csv")
df.to_csv(complete_output, index=False)
print(f"✓ Saved complete results to: {complete_output}")

# Save successful transcripts only
successful_df = df[df['success'] == True].copy()
success_output = os.path.join(OUTPUT_PATH, "successful_transcripts.csv")
successful_df.to_csv(success_output, index=False)
print(f"✓ Saved successful transcripts to: {success_output}")

# Save by dataset
datasets_to_save = df['dataset'].unique()
for dataset in datasets_to_save:
    dataset_df = df[df['dataset'] == dataset].copy()
    dataset_output = os.path.join(OUTPUT_PATH, f"{dataset}_transcripts.csv")
    dataset_df.to_csv(dataset_output, index=False)
    print(f"✓ Saved {dataset} transcripts to: {dataset_output}")

# STEP 10: DISPLAY SUMMARY STATISTICS
print("\nSTEP 10: Summary Statistics")
print("="*50)

total_files = len(df)
successful = len(successful_df)
failed = total_files - successful

print(f"Total audio files processed: {total_files}")
print(f"Successful transcriptions: {successful} ({successful/total_files*100:.1f}%)")
print(f"Failed transcriptions: {failed} ({failed/total_files*100:.1f}%)")

print(f"\nDataset breakdown:")
for dataset in df['dataset'].unique():
    dataset_total = len(df[df['dataset'] == dataset])
    dataset_success = len(df[(df['dataset'] == dataset) & (df['success'] == True)])
    print(f"  {dataset}: {dataset_success}/{dataset_total} successful ({dataset_success/dataset_total*100:.1f}%)")

print(f"\nLabel distribution (successful transcripts only):")
if not successful_df.empty:
    print(successful_df['label'].value_counts())

print(f"\nRecognition methods used:")
if not successful_df.empty:
    print(successful_df['recognition_method'].value_counts())

# Show sample transcripts
print(f"\nSample successful transcripts:")
sample_transcripts = successful_df['transcript'].dropna().head(3)
for i, transcript in enumerate(sample_transcripts):
    print(f"  Sample {i+1}: {transcript[:200]}...")

# Show common errors
print(f"\nMost common errors:")
error_df = df[df['success'] == False]
if not error_df.empty:
    error_counts = error_df['error'].value_counts().head(5)
    for error, count in error_counts.items():
        print(f"  {error}: {count} files")

print("\n" + "="*60)
print("TRANSCRIPT EXTRACTION COMPLETE!")
print(f"All results saved in: {OUTPUT_PATH}")
print("="*60)

ADReSSo21 AUDIO TRANSCRIPT EXTRACTOR

STEP 1: Mounting Google Drive...
Mounted at /content/drive
✓ Google Drive mounted successfully!

STEP 2: Installing required packages...
Installing speech recognition and audio processing libraries...
✓ Packages ready (make sure to install them first)

STEP 3: Setting up paths and configuration...
✓ Base path: /content/drive/MyDrive/Voice/
✓ Extract path: /content/drive/MyDrive/Voice/extracted/
✓ Output path: /content/drive/MyDrive/Voice/transcripts/

STEP 4: Extracting dataset files...
  Extracting ADReSSo21-progression-train.tgz...
  ✓ ADReSSo21-progression-train.tgz extracted successfully
  Extracting ADReSSo21-progression-test.tgz...
  ✓ ADReSSo21-progression-test.tgz extracted successfully
  Extracting ADReSSo21-diagnosis-train.tgz...
  ✓ ADReSSo21-diagnosis-train.tgz extracted successfully

STEP 5: Finding all WAV files...
  Found 15 decline WAV files
  Found 58 no_decline WAV files
  Found 32 test WAV files
  Found 87 AD WAV files
  Found 79

In [5]:
# Complete Audio-Text AD Classification Pipeline
# This script combines audio feature extraction, BERT processing, and DARTS classification

import os
import pandas as pd
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# For BERT processing
from transformers import AutoTokenizer, AutoModel
import re

print("="*80)
print("COMPREHENSIVE AD CLASSIFICATION PIPELINE")
print("Audio Features + BERT + DARTS Architecture")
print("="*80)

COMPREHENSIVE AD CLASSIFICATION PIPELINE
Audio Features + BERT + DARTS Architecture


In [6]:
# ============================================================================
# PART 1: ADVANCED AUDIO FEATURE EXTRACTION
# ============================================================================

class AudioFeatureExtractor:
    def __init__(self, sr=16000, n_mfcc=13, n_fft=2048, hop_length=512):
        self.sr = sr
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length

    def extract_mfcc_features(self, audio_path):
        """Extract comprehensive MFCC features including deltas"""
        try:
            # Load audio
            y, sr = librosa.load(audio_path, sr=self.sr)

            # Extract MFCC features
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc,
                                       n_fft=self.n_fft, hop_length=self.hop_length)

            # Extract delta features (first derivative)
            delta_mfccs = librosa.feature.delta(mfccs)

            # Extract delta-delta features (second derivative)
            delta2_mfccs = librosa.feature.delta(mfccs, order=2)

            # Combine all MFCC features
            combined_mfccs = np.concatenate([mfccs, delta_mfccs, delta2_mfccs], axis=0)

            # Statistical features for each coefficient
            features = {}

            # Mean, std, min, max for each coefficient
            features['mfcc_mean'] = np.mean(combined_mfccs, axis=1)
            features['mfcc_std'] = np.std(combined_mfccs, axis=1)
            features['mfcc_min'] = np.min(combined_mfccs, axis=1)
            features['mfcc_max'] = np.max(combined_mfccs, axis=1)
            features['mfcc_median'] = np.median(combined_mfccs, axis=1)
            features['mfcc_skew'] = self._calculate_skewness(combined_mfccs)
            features['mfcc_kurtosis'] = self._calculate_kurtosis(combined_mfccs)

            return features, combined_mfccs

        except Exception as e:
            print(f"Error extracting MFCC from {audio_path}: {e}")
            return None, None

    def extract_spectral_features(self, audio_path):
        """Extract spectral features"""
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)

            features = {}

            # Spectral centroid
            spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
            features['spectral_centroid_mean'] = np.mean(spectral_centroid)
            features['spectral_centroid_std'] = np.std(spectral_centroid)

            # Spectral bandwidth
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=y, sr=sr)[0]
            features['spectral_bandwidth_mean'] = np.mean(spectral_bandwidth)
            features['spectral_bandwidth_std'] = np.std(spectral_bandwidth)

            # Spectral rolloff
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
            features['spectral_rolloff_mean'] = np.mean(spectral_rolloff)
            features['spectral_rolloff_std'] = np.std(spectral_rolloff)

            # Zero crossing rate
            zcr = librosa.feature.zero_crossing_rate(y)[0]
            features['zcr_mean'] = np.mean(zcr)
            features['zcr_std'] = np.std(zcr)

            # Chroma features
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            features['chroma_mean'] = np.mean(chroma, axis=1)
            features['chroma_std'] = np.std(chroma, axis=1)

            # Tempo
            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
            features['tempo'] = tempo

            return features

        except Exception as e:
            print(f"Error extracting spectral features from {audio_path}: {e}")
            return {}

    def extract_prosodic_features(self, audio_path):
        """Extract prosodic features (pitch, energy, etc.)"""
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)

            features = {}

            # Fundamental frequency (pitch)
            f0, voiced_flag, voiced_probs = librosa.pyin(y, fmin=librosa.note_to_hz('C2'),
                                                       fmax=librosa.note_to_hz('C7'))

            # Remove NaN values
            f0_clean = f0[~np.isnan(f0)]
            if len(f0_clean) > 0:
                features['f0_mean'] = np.mean(f0_clean)
                features['f0_std'] = np.std(f0_clean)
                features['f0_min'] = np.min(f0_clean)
                features['f0_max'] = np.max(f0_clean)
                features['f0_range'] = np.max(f0_clean) - np.min(f0_clean)
            else:
                features.update({
                    'f0_mean': 0, 'f0_std': 0, 'f0_min': 0,
                    'f0_max': 0, 'f0_range': 0
                })

            # RMS energy
            rms = librosa.feature.rms(y=y)[0]
            features['rms_mean'] = np.mean(rms)
            features['rms_std'] = np.std(rms)

            # Spectral contrast
            contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
            features['contrast_mean'] = np.mean(contrast, axis=1)
            features['contrast_std'] = np.std(contrast, axis=1)

            return features

        except Exception as e:
            print(f"Error extracting prosodic features from {audio_path}: {e}")
            return {}

    def _calculate_skewness(self, data):
        """Calculate skewness for each row"""
        mean = np.mean(data, axis=1, keepdims=True)
        std = np.std(data, axis=1, keepdims=True)
        std[std == 0] = 1  # Avoid division by zero
        normalized = (data - mean) / std
        skewness = np.mean(normalized**3, axis=1)
        return skewness

    def _calculate_kurtosis(self, data):
        """Calculate kurtosis for each row"""
        mean = np.mean(data, axis=1, keepdims=True)
        std = np.std(data, axis=1, keepdims=True)
        std[std == 0] = 1  # Avoid division by zero
        normalized = (data - mean) / std
        kurtosis = np.mean(normalized**4, axis=1) - 3
        return kurtosis

    def extract_all_features(self, audio_path):
        """Extract all audio features"""
        all_features = {}

        # MFCC features
        mfcc_features, mfcc_matrix = self.extract_mfcc_features(audio_path)
        if mfcc_features:
            all_features.update(mfcc_features)

        # Spectral features
        spectral_features = self.extract_spectral_features(audio_path)
        all_features.update(spectral_features)

        # Prosodic features
        prosodic_features = self.extract_prosodic_features(audio_path)
        all_features.update(prosodic_features)

        # Flatten nested arrays
        flattened_features = {}
        for key, value in all_features.items():
            if isinstance(value, np.ndarray):
                if value.ndim == 1:
                    for i, v in enumerate(value):
                        flattened_features[f"{key}_{i}"] = v
                else:
                    flattened_features[key] = np.mean(value)
            else:
                flattened_features[key] = value

        return flattened_features, mfcc_matrix


In [7]:

# ============================================================================
# PART 2: BERT TEXT PROCESSING
# ============================================================================

class BERTTextProcessor:
    def __init__(self, model_name='bert-base-uncased', max_length=512):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def preprocess_text(self, text):
        """Clean and preprocess text"""
        if pd.isna(text) or text is None:
            return ""

        # Convert to string and clean
        text = str(text).lower()

        # Remove special characters but keep spaces and basic punctuation
        text = re.sub(r'[^a-zA-Z0-9\s\.\,\!\?]', '', text)

        # Remove extra whitespace
        text = ' '.join(text.split())

        return text

    def extract_bert_features(self, text):
        """Extract BERT embeddings from text"""
        try:
            # Preprocess text
            clean_text = self.preprocess_text(text)

            if not clean_text:
                # Return zero vector for empty text
                return np.zeros(768)

            # Tokenize
            inputs = self.tokenizer(
                clean_text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Extract features
            with torch.no_grad():
                outputs = self.model(**inputs)

                # Use [CLS] token embedding (first token)
                cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze()

                # Also compute mean pooling of all tokens
                attention_mask = inputs['attention_mask']
                token_embeddings = outputs.last_hidden_state
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                mean_embedding = sum_embeddings / sum_mask
                mean_embedding = mean_embedding.squeeze()

                # Combine CLS and mean pooling
                combined_embedding = (cls_embedding + mean_embedding) / 2

                return combined_embedding.numpy()

        except Exception as e:
            print(f"Error extracting BERT features: {e}")
            return np.zeros(768)

    def extract_linguistic_features(self, text):
        """Extract basic linguistic features"""
        try:
            clean_text = self.preprocess_text(text)

            if not clean_text:
                return {
                    'word_count': 0, 'char_count': 0, 'avg_word_length': 0,
                    'sentence_count': 0, 'question_count': 0, 'exclamation_count': 0
                }

            words = clean_text.split()
            sentences = clean_text.split('.')

            features = {
                'word_count': len(words),
                'char_count': len(clean_text),
                'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
                'sentence_count': len([s for s in sentences if s.strip()]),
                'question_count': clean_text.count('?'),
                'exclamation_count': clean_text.count('!')
            }

            return features

        except Exception as e:
            print(f"Error extracting linguistic features: {e}")
            return {'word_count': 0, 'char_count': 0, 'avg_word_length': 0,
                   'sentence_count': 0, 'question_count': 0, 'exclamation_count': 0}



In [9]:
# ============================================================================
# PART 3: IMPROVED DARTS ARCHITECTURE
# ============================================================================

class ImprovedDARTSCell(nn.Module):
    def __init__(self, input_dim, output_dim, num_ops=8):
        super(ImprovedDARTSCell, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Ensure dimensions match for operations
        if input_dim != output_dim:
            self.projection = nn.Linear(input_dim, output_dim)
        else:
            self.projection = nn.Identity()

        # Define possible operations with proper dimensionality
        self.operations = nn.ModuleList([
            nn.Identity(),  # Skip connection
            nn.ReLU(),      # Activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU()),  # Linear + ReLU
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.Tanh()),  # Linear + Tanh
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Dropout(0.1)),  # With dropout
            nn.Sequential(nn.Linear(output_dim, output_dim // 2), nn.ReLU(), nn.Linear(output_dim // 2, output_dim)),  # Bottleneck
            nn.Sequential(nn.LayerNorm(output_dim), nn.ReLU()),  # Layer norm + activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Linear(output_dim, output_dim))  # Deep linear
        ])

        # Architecture parameters (alpha) - learnable weights for each operation
        self.alpha = nn.Parameter(torch.randn(len(self.operations)))

        # Temperature parameter for gumbel softmax (learnable)
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Project input to correct dimension
        x = self.projection(x)

        # Apply Gumbel Softmax for differentiable architecture search
        if self.training:
            # Use Gumbel Softmax during training
            gumbel_weights = F.gumbel_softmax(self.alpha, tau=self.temperature, hard=False)
        else:
            # Use regular softmax during evaluation
            gumbel_weights = F.softmax(self.alpha / self.temperature, dim=0)

        # Apply operations
        outputs = []
        for op in self.operations:
            try:
                out = op(x)
                outputs.append(out)
            except Exception as e:
                # Fallback to identity if operation fails
                outputs.append(x)

        # Weighted combination of all operations
        result = sum(w * out for w, out in zip(gumbel_weights, outputs))

        return result

    def get_selected_operation(self):
        """Get the operation with highest weight (for inference)"""
        selected_idx = torch.argmax(self.alpha).item()
        return selected_idx, self.operations[selected_idx]

class MultimodalDARTSClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256, num_classes=2):
        super(MultimodalDARTSClassifier, self).__init__()

        # Audio processing branch with DARTS
        self.audio_projection = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.audio_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(3)
        ])

        # Text processing branch with DARTS
        self.text_projection = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.text_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(2)
        ])

        # Cross-modal attention
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        # Fusion and classification
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4)
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, audio_features, text_features):
        # Process audio through DARTS
        audio_x = self.audio_projection(audio_features)
        for cell in self.audio_darts_cells:
            residual = audio_x
            audio_x = cell(audio_x)
            audio_x = audio_x + residual  # Residual connection

        # Process text through DARTS
        text_x = self.text_projection(text_features)
        for cell in self.text_darts_cells:
            residual = text_x
            text_x = cell(text_x)
            text_x = text_x + residual  # Residual connection

        # Cross-modal attention
        audio_attended, _ = self.cross_attention(
            audio_x.unsqueeze(1), text_x.unsqueeze(1), text_x.unsqueeze(1)
        )
        audio_attended = audio_attended.squeeze(1)

        # Fusion
        fused = torch.cat([audio_attended, text_x], dim=1)
        fused = self.fusion(fused)

        # Classification
        output = self.classifier(fused)

        return output

    def get_architecture_info(self):
        """Get information about the learned architecture"""
        arch_info = {}

        # Audio DARTS info
        for i, cell in enumerate(self.audio_darts_cells):
            selected_idx, _ = cell.get_selected_operation()
            arch_info[f'audio_cell_{i}'] = {
                'selected_operation_idx': selected_idx,
                'operation_weights': cell.alpha.detach().cpu().numpy(),
                'temperature': cell.temperature.item()
            }

        # Text DARTS info
        for i, cell in enumerate(self.text_darts_cells):
            selected_idx, _ = cell.get_selected_operation()
            arch_info[f'text_cell_{i}'] = {
                'selected_operation_idx': selected_idx,
                'operation_weights': cell.alpha.detach().cpu().numpy(),
                'temperature': cell.temperature.item()
            }

        return arch_info


In [11]:
# ============================================================================
# PART 4: DATASET AND TRAINING UTILITIES
# ============================================================================

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler, LabelEncoder
import warnings
warnings.filterwarnings('ignore')

class ADDataset(Dataset):
    def __init__(self, csv_path=None, audio_features=None, text_features=None, labels=None,
                 dataset_type='diagnosis', normalize_audio=True, normalize_text=True):
        """
        Enhanced AD Dataset class for multimodal classification

        Args:
            csv_path: Path to CSV file with transcripts (from transcript extraction)
            audio_features: Pre-extracted audio features array
            text_features: Pre-extracted text features array
            labels: Labels array
            dataset_type: 'diagnosis' (AD/CN) or 'progression' (decline/no_decline)
            normalize_audio: Whether to normalize audio features
            normalize_text: Whether to normalize text features
        """
        self.dataset_type = dataset_type
        self.normalize_audio = normalize_audio
        self.normalize_text = normalize_text

        # Initialize scalers
        self.audio_scaler = StandardScaler() if normalize_audio else None
        self.text_scaler = StandardScaler() if normalize_text else None
        self.label_encoder = LabelEncoder()

        if csv_path is not None:
            # Load from CSV file (transcript extraction results)
            self._load_from_csv(csv_path)
        else:
            # Load from pre-processed arrays
            self._load_from_arrays(audio_features, text_features, labels)

    def _load_from_csv(self, csv_path):
        """Load dataset from CSV file with transcripts"""
        print(f"Loading dataset from {csv_path}...")

        # Load CSV
        df = pd.read_csv(csv_path)

        # Filter successful transcripts only
        df = df[df['success'] == True].copy()
        print(f"Found {len(df)} successful transcripts")

        # Filter by dataset type if specified
        if self.dataset_type == 'diagnosis':
            df = df[df['dataset'] == 'diagnosis_train'].copy()
            valid_labels = ['ad', 'cn']
        elif self.dataset_type == 'progression':
            df = df[df['dataset'] == 'progression_train'].copy()
            valid_labels = ['decline', 'no_decline']
        else:
            # Keep all data
            valid_labels = df['label'].unique()

        # Filter valid labels
        df = df[df['label'].isin(valid_labels)].copy()
        print(f"After filtering: {len(df)} samples")

        # Extract information
        self.file_ids = df['file_id'].tolist()
        self.file_paths = df['file_path'].tolist()
        self.transcripts = df['transcript'].tolist()
        self.raw_labels = df['label'].tolist()

        # Encode labels
        self.labels = torch.LongTensor(self.label_encoder.fit_transform(self.raw_labels))
        self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))

        print(f"Label mapping: {self.label_mapping}")
        print(f"Label distribution: {pd.Series(self.raw_labels).value_counts().to_dict()}")

        # Initialize feature placeholders (will be filled by feature extractors)
        self.audio_features = None
        self.text_features = None

    def _load_from_arrays(self, audio_features, text_features, labels):
        """Load dataset from pre-processed feature arrays"""
        print("Loading dataset from pre-processed arrays...")

        if audio_features is None or text_features is None or labels is None:
            raise ValueError("All feature arrays must be provided")

        # Convert to numpy arrays if needed
        audio_features = np.array(audio_features)
        text_features = np.array(text_features)
        labels = np.array(labels)

        # Ensure same number of samples
        assert len(audio_features) == len(text_features) == len(labels), \
            "All arrays must have the same number of samples"

        # Normalize features if requested
        if self.normalize_audio and self.audio_scaler:
            audio_features = self.audio_scaler.fit_transform(audio_features)

        if self.normalize_text and self.text_scaler:
            text_features = self.text_scaler.fit_transform(text_features)

        # Convert to tensors
        self.audio_features = torch.FloatTensor(audio_features)
        self.text_features = torch.FloatTensor(text_features)

        # Handle labels
        if labels.dtype == 'object' or isinstance(labels[0], str):
            # String labels - encode them
            self.raw_labels = labels.tolist()
            self.labels = torch.LongTensor(self.label_encoder.fit_transform(labels))
            self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))
        else:
            # Numeric labels
            self.labels = torch.LongTensor(labels)
            self.raw_labels = labels.tolist()
            self.label_mapping = {i: i for i in range(len(np.unique(labels)))}

        print(f"Dataset loaded: {len(self.labels)} samples")
        print(f"Audio features shape: {self.audio_features.shape}")
        print(f"Text features shape: {self.text_features.shape}")
        print(f"Label mapping: {self.label_mapping}")

    def set_audio_features(self, audio_features):
        """Set audio features after extraction"""
        audio_features = np.array(audio_features)

        if self.normalize_audio and self.audio_scaler:
            audio_features = self.audio_scaler.fit_transform(audio_features)

        self.audio_features = torch.FloatTensor(audio_features)
        print(f"Audio features set: {self.audio_features.shape}")

    def set_text_features(self, text_features):
        """Set text features after extraction"""
        text_features = np.array(text_features)

        if self.normalize_text and self.text_scaler:
            text_features = self.text_scaler.fit_transform(text_features)

        self.text_features = torch.FloatTensor(text_features)
        print(f"Text features set: {self.text_features.shape}")

    def extract_features_from_transcripts(self, audio_extractor, text_processor):
        """
        Extract features from audio files and transcripts

        Args:
            audio_extractor: AudioFeatureExtractor instance
            text_processor: BERTTextProcessor instance
        """
        if not hasattr(self, 'file_paths') or not hasattr(self, 'transcripts'):
            raise ValueError("Dataset must be loaded from CSV to extract features")

        print("Extracting audio and text features...")

        # Extract audio features
        print("Extracting audio features...")
        audio_features_list = []
        failed_audio = 0

        for i, audio_path in enumerate(self.file_paths):
            if i % 10 == 0:
                print(f"  Processing audio {i+1}/{len(self.file_paths)}")

            features, _ = audio_extractor.extract_all_features(audio_path)
            if features:
                # Convert to list of values in consistent order
                feature_vector = [features.get(key, 0) for key in sorted(features.keys())]
                audio_features_list.append(feature_vector)
            else:
                # Use zero vector for failed extractions
                failed_audio += 1
                if audio_features_list:
                    audio_features_list.append([0] * len(audio_features_list[0]))
                else:
                    audio_features_list.append([0] * 100)  # Default size

        if failed_audio > 0:
            print(f"  Warning: {failed_audio} audio files failed feature extraction")

        # Extract text features
        print("Extracting text features...")
        text_features_list = []

        for i, transcript in enumerate(self.transcripts):
            if i % 20 == 0:
                print(f"  Processing text {i+1}/{len(self.transcripts)}")

            # Extract BERT features
            bert_features = text_processor.extract_bert_features(transcript)

            # Extract linguistic features
            ling_features = text_processor.extract_linguistic_features(transcript)

            # Combine features
            combined_features = np.concatenate([
                bert_features,
                [ling_features[key] for key in sorted(ling_features.keys())]
            ])

            text_features_list.append(combined_features)

        # Set features
        self.set_audio_features(audio_features_list)
        self.set_text_features(text_features_list)

        print(f"Feature extraction complete!")
        print(f"  Audio features: {self.audio_features.shape}")
        print(f"  Text features: {self.text_features.shape}")

    def get_feature_info(self):
        """Get information about the features"""
        info = {
            'num_samples': len(self.labels),
            'num_classes': len(self.label_mapping),
            'label_mapping': self.label_mapping
        }

        if self.audio_features is not None:
            info['audio_feature_dim'] = self.audio_features.shape[1]

        if self.text_features is not None:
            info['text_feature_dim'] = self.text_features.shape[1]

        return info

    def get_class_weights(self):
        """Calculate class weights for imbalanced datasets"""
        if hasattr(self, 'raw_labels'):
            label_counts = pd.Series(self.raw_labels).value_counts()
            total_samples = len(self.raw_labels)

            # Calculate inverse frequency weights
            weights = {}
            for label, count in label_counts.items():
                weights[self.label_mapping[label]] = total_samples / (len(label_counts) * count)

            # Convert to tensor
            weight_tensor = torch.zeros(len(self.label_mapping))
            for class_idx, weight in weights.items():
                weight_tensor[class_idx] = weight

            return weight_tensor

        return None

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.audio_features is None or self.text_features is None:
            raise ValueError("Features not set. Call set_audio_features() and set_text_features() first, or extract_features_from_transcripts()")

        return self.audio_features[idx], self.text_features[idx], self.labels[idx]

    def get_sample_info(self, idx):
        """Get detailed information about a specific sample"""
        info = {
            'index': idx,
            'label': self.labels[idx].item(),
            'raw_label': self.raw_labels[idx] if hasattr(self, 'raw_labels') else None
        }

        if hasattr(self, 'file_ids'):
            info['file_id'] = self.file_ids[idx]

        if hasattr(self, 'transcripts'):
            info['transcript'] = self.transcripts[idx]

        if hasattr(self, 'file_paths'):
            info['file_path'] = self.file_paths[idx]

        return info

# Example usage function
def create_dataset_from_transcripts(transcript_csv_path, audio_extractor, text_processor,
                                  dataset_type='diagnosis', test_size=0.2, val_size=0.1):
    """
    Create train/val/test datasets from transcript CSV

    Args:
        transcript_csv_path: Path to successful_transcripts.csv
        audio_extractor: AudioFeatureExtractor instance
        text_processor: BERTTextProcessor instance
        dataset_type: 'diagnosis' or 'progression'
        test_size: Proportion for test set
        val_size: Proportion for validation set

    Returns:
        train_dataset, val_dataset, test_dataset
    """
    from sklearn.model_selection import train_test_split

    # Load full dataset
    full_dataset = ADDataset(csv_path=transcript_csv_path, dataset_type=dataset_type)

    # Extract features
    full_dataset.extract_features_from_transcripts(audio_extractor, text_processor)

    # Get indices for splitting
    indices = list(range(len(full_dataset)))
    labels = [full_dataset.raw_labels[i] for i in indices]

    # First split: separate test set
    train_val_idx, test_idx = train_test_split(
        indices, test_size=test_size, stratify=labels, random_state=42
    )

    # Second split: separate train and validation
    train_labels = [labels[i] for i in train_val_idx]
    train_idx, val_idx = train_test_split(
        train_val_idx, test_size=val_size/(1-test_size), stratify=train_labels, random_state=42
    )

    # Create datasets
    def create_subset(indices):
        audio_subset = full_dataset.audio_features[indices]
        text_subset = full_dataset.text_features[indices]
        label_subset = full_dataset.labels[indices]

        return ADDataset(
            audio_features=audio_subset,
            text_features=text_subset,
            labels=label_subset,
            dataset_type=dataset_type,
            normalize_audio=False,  # Already normalized
            normalize_text=False
        )

    train_dataset = create_subset(train_idx)
    val_dataset = create_subset(val_idx)
    test_dataset = create_subset(test_idx)

    print(f"\nDataset split completed:")
    print(f"  Training: {len(train_dataset)} samples")
    print(f"  Validation: {len(val_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")

    return train_dataset, val_dataset, test_dataset

# Gr

In [2]:
!pip install SpeechRecognition
!pip install pydub
!pip install librosa
!pip install soundfile
!apt-get install -y ffmpeg

!echo "✓ Packages ready (make sure to install them first)"

Collecting SpeechRecognition
  Downloading speechrecognition-3.14.3-py3-none-any.whl.metadata (30 kB)
Downloading speechrecognition-3.14.3-py3-none-any.whl (32.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m32.9/32.9 MB[0m [31m38.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SpeechRecognition
Successfully installed SpeechRecognition-3.14.3
Collecting pydub
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
Installing collected packages: pydub
Successfully installed pydub-0.25.1
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.
✓ Packages ready (make sure to install them first)


In [3]:
import os
import tarfile
import pandas as pd
import numpy as np
from pathlib import Path
import librosa
import speech_recognition as sr
import soundfile as sf
from pydub import AudioSegment
import warnings
warnings.filterwarnings('ignore')

print("="*60)
print("ADReSSo21 AUDIO TRANSCRIPT EXTRACTOR")
print("="*60)

ADReSSo21 AUDIO TRANSCRIPT EXTRACTOR


In [6]:
print("\nSTEP 1: Mounting Google Drive...")
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("✓ Google Drive mounted successfully!")
except:
    print("⚠ Not running in Colab or Drive already mounted")


STEP 1: Mounting Google Drive...
Mounted at /content/drive
✓ Google Drive mounted successfully!


In [7]:
print("\nSTEP 3: Setting up paths and configuration...")

BASE_PATH = "/content/drive/MyDrive/Voice/"
EXTRACT_PATH = "/content/drive/MyDrive/Voice/extracted/"
OUTPUT_PATH = "/content/drive/MyDrive/Voice/transcripts/"

# Create directories
os.makedirs(EXTRACT_PATH, exist_ok=True)
os.makedirs(OUTPUT_PATH, exist_ok=True)

datasets = {
    'progression_train': 'ADReSSo21-progression-train.tgz',
    'progression_test': 'ADReSSo21-progression-test.tgz',
    'diagnosis_train': 'ADReSSo21-diagnosis-train.tgz'
}

print(f"✓ Base path: {BASE_PATH}")
print(f"✓ Extract path: {EXTRACT_PATH}")
print(f"✓ Output path: {OUTPUT_PATH}")


STEP 3: Setting up paths and configuration...
✓ Base path: /content/drive/MyDrive/Voice/
✓ Extract path: /content/drive/MyDrive/Voice/extracted/
✓ Output path: /content/drive/MyDrive/Voice/transcripts/


In [8]:
print("\nSTEP 4: Extracting dataset files...")

def extract_datasets():
    """Extract all tgz files"""
    for dataset_name, filename in datasets.items():
        file_path = os.path.join(BASE_PATH, filename)

        if os.path.exists(file_path):
            print(f"  Extracting {filename}...")
            try:
                with tarfile.open(file_path, 'r:gz') as tar:
                    tar.extractall(path=EXTRACT_PATH)
                print(f"  ✓ {filename} extracted successfully")
            except Exception as e:
                print(f"  ⚠ Error extracting {filename}: {e}")
        else:
            print(f"  ⚠ {filename} not found at {file_path}")

extract_datasets()


STEP 4: Extracting dataset files...
  Extracting ADReSSo21-progression-train.tgz...
  ✓ ADReSSo21-progression-train.tgz extracted successfully
  Extracting ADReSSo21-progression-test.tgz...
  ✓ ADReSSo21-progression-test.tgz extracted successfully
  Extracting ADReSSo21-diagnosis-train.tgz...
  ✓ ADReSSo21-diagnosis-train.tgz extracted successfully


In [9]:
print("\nSTEP 5: Finding all WAV files...")

def find_wav_files():
    """Find all WAV files and organize by dataset and label"""
    wav_files = {
        'progression_train': {'decline': [], 'no_decline': []},
        'progression_test': [],
        'diagnosis_train': {'ad': [], 'cn': []}
    }

    # Progression training files
    prog_train_base = os.path.join(EXTRACT_PATH, "ADReSSo21/progression/train/audio/")

    # Decline cases
    decline_path = os.path.join(prog_train_base, "decline/")
    if os.path.exists(decline_path):
        decline_wavs = [f for f in os.listdir(decline_path) if f.endswith('.wav')]
        wav_files['progression_train']['decline'] = [os.path.join(decline_path, f) for f in decline_wavs]
        print(f"  Found {len(decline_wavs)} decline WAV files")

    # No decline cases
    no_decline_path = os.path.join(prog_train_base, "no_decline/")
    if os.path.exists(no_decline_path):
        no_decline_wavs = [f for f in os.listdir(no_decline_path) if f.endswith('.wav')]
        wav_files['progression_train']['no_decline'] = [os.path.join(no_decline_path, f) for f in no_decline_wavs]
        print(f"  Found {len(no_decline_wavs)} no_decline WAV files")

    # Progression test files
    prog_test_path = os.path.join(EXTRACT_PATH, "ADReSSo21/progression/test-dist/audio/")
    if os.path.exists(prog_test_path):
        test_wavs = [f for f in os.listdir(prog_test_path) if f.endswith('.wav')]
        wav_files['progression_test'] = [os.path.join(prog_test_path, f) for f in test_wavs]
        print(f"  Found {len(test_wavs)} test WAV files")

    # Diagnosis training files
    diag_train_base = os.path.join(EXTRACT_PATH, "ADReSSo21/diagnosis/train/audio/")

    # AD cases
    ad_path = os.path.join(diag_train_base, "ad/")
    if os.path.exists(ad_path):
        ad_wavs = [f for f in os.listdir(ad_path) if f.endswith('.wav')]
        wav_files['diagnosis_train']['ad'] = [os.path.join(ad_path, f) for f in ad_wavs]
        print(f"  Found {len(ad_wavs)} AD WAV files")

    # CN cases
    cn_path = os.path.join(diag_train_base, "cn/")
    if os.path.exists(cn_path):
        cn_wavs = [f for f in os.listdir(cn_path) if f.endswith('.wav')]
        wav_files['diagnosis_train']['cn'] = [os.path.join(cn_path, f) for f in cn_wavs]
        print(f"  Found {len(cn_wavs)} CN WAV files")

    return wav_files

wav_files = find_wav_files()


STEP 5: Finding all WAV files...
  Found 15 decline WAV files
  Found 58 no_decline WAV files
  Found 32 test WAV files
  Found 87 AD WAV files
  Found 79 CN WAV files


In [10]:
print("\nSTEP 6: Setting up audio preprocessing...")

def preprocess_audio(audio_path, target_sr=16000):
    """Preprocess audio file for speech recognition"""
    try:
        # Load audio with librosa
        audio, sr = librosa.load(audio_path, sr=target_sr)

        # Normalize audio
        audio = librosa.util.normalize(audio)

        # Remove silence
        audio_trimmed, _ = librosa.effects.trim(audio, top_db=20)

        return audio_trimmed, target_sr
    except Exception as e:
        print(f"    Error preprocessing {audio_path}: {e}")
        return None, None

def convert_to_wav_if_needed(audio_path):
    """Convert audio to WAV format if needed"""
    try:
        if not audio_path.endswith('.wav'):
            # Convert using pydub
            audio = AudioSegment.from_file(audio_path)
            wav_path = audio_path.rsplit('.', 1)[0] + '_converted.wav'
            audio.export(wav_path, format="wav")
            return wav_path
        return audio_path
    except Exception as e:
        print(f"    Error converting {audio_path}: {e}")
        return audio_path


STEP 6: Setting up audio preprocessing...


In [11]:
print("\nSTEP 7: Setting up speech recognition...")

def extract_transcript_from_audio(audio_path, method='google'):
    """Extract transcript from audio file using speech recognition"""
    recognizer = sr.Recognizer()

    try:
        # Convert to WAV if needed
        wav_path = convert_to_wav_if_needed(audio_path)

        # Preprocess audio
        audio_data, sr_rate = preprocess_audio(wav_path, target_sr=16000)

        if audio_data is None:
            return None, "Preprocessing failed"

        # Save preprocessed audio temporarily
        temp_wav = audio_path.replace('.wav', '_temp.wav')
        sf.write(temp_wav, audio_data, sr_rate)

        # Use speech recognition
        with sr.AudioFile(temp_wav) as source:
            # Adjust for ambient noise
            recognizer.adjust_for_ambient_noise(source, duration=0.5)
            audio = recognizer.listen(source)

        # Try different recognition methods
        transcript = None
        error_msg = ""

        if method == 'google':
            try:
                transcript = recognizer.recognize_google(audio)
            except sr.UnknownValueError:
                error_msg = "Google Speech Recognition could not understand audio"
            except sr.RequestError as e:
                error_msg = f"Google Speech Recognition error: {e}"

        # Fallback to other methods if Google fails
        if transcript is None:
            try:
                transcript = recognizer.recognize_sphinx(audio)
                method = 'sphinx'
            except sr.UnknownValueError:
                error_msg += "; Sphinx could not understand audio"
            except sr.RequestError as e:
                error_msg += f"; Sphinx error: {e}"

        # Clean up temporary file
        if os.path.exists(temp_wav):
            os.remove(temp_wav)

        if transcript:
            return transcript.strip(), method
        else:
            return None, error_msg

    except Exception as e:
        return None, f"Error processing audio: {str(e)}"


STEP 7: Setting up speech recognition...


In [12]:
print("\nSTEP 8: Processing audio files and extracting transcripts...")
print("This may take a while depending on the number and length of audio files...")

def process_audio_files(wav_files):
    """Process all audio files and extract transcripts"""
    all_transcripts = []

    # Process progression training data
    print("\n  Processing progression training data...")
    for label in ['decline', 'no_decline']:
        files = wav_files['progression_train'][label]
        print(f"    Processing {len(files)} {label} files...")

        for i, audio_path in enumerate(files):
            print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")

            transcript, method_or_error = extract_transcript_from_audio(audio_path)

            all_transcripts.append({
                'file_id': os.path.splitext(os.path.basename(audio_path))[0],
                'file_path': audio_path,
                'dataset': 'progression_train',
                'label': label,
                'transcript': transcript,
                'recognition_method': method_or_error if transcript else None,
                'error': None if transcript else method_or_error,
                'success': transcript is not None
            })

    # Process progression test data
    print("\n  Processing progression test data...")
    files = wav_files['progression_test']
    print(f"    Processing {len(files)} test files...")

    for i, audio_path in enumerate(files):
        print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")

        transcript, method_or_error = extract_transcript_from_audio(audio_path)

        all_transcripts.append({
            'file_id': os.path.splitext(os.path.basename(audio_path))[0],
            'file_path': audio_path,
            'dataset': 'progression_test',
            'label': 'test',
            'transcript': transcript,
            'recognition_method': method_or_error if transcript else None,
            'error': None if transcript else method_or_error,
            'success': transcript is not None
        })

    # Process diagnosis training data
    print("\n  Processing diagnosis training data...")
    for label in ['ad', 'cn']:
        files = wav_files['diagnosis_train'][label]
        print(f"    Processing {len(files)} {label} files...")

        for i, audio_path in enumerate(files):
            print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")

            transcript, method_or_error = extract_transcript_from_audio(audio_path)

            all_transcripts.append({
                'file_id': os.path.splitext(os.path.basename(audio_path))[0],
                'file_path': audio_path,
                'dataset': 'diagnosis_train',
                'label': label,
                'transcript': transcript,
                'recognition_method': method_or_error if transcript else None,
                'error': None if transcript else method_or_error,
                'success': transcript is not None
            })

    return all_transcripts

transcripts = process_audio_files(wav_files)


STEP 8: Processing audio files and extracting transcripts...
This may take a while depending on the number and length of audio files...

  Processing progression training data...
    Processing 15 decline files...
      Processing 1/15: adrsp055.wav
      Processing 2/15: adrsp003.wav
      Processing 3/15: adrsp266.wav
      Processing 4/15: adrsp300.wav
      Processing 5/15: adrsp320.wav
      Processing 6/15: adrsp313.wav
      Processing 7/15: adrsp179.wav
      Processing 8/15: adrsp357.wav
      Processing 9/15: adrsp051.wav
      Processing 10/15: adrsp101.wav
      Processing 11/15: adrsp326.wav
      Processing 12/15: adrsp127.wav
      Processing 13/15: adrsp276.wav
      Processing 14/15: adrsp209.wav
      Processing 15/15: adrsp318.wav
    Processing 58 no_decline files...
      Processing 1/58: adrsp196.wav
      Processing 2/58: adrsp137.wav
      Processing 3/58: adrsp130.wav
      Processing 4/58: adrsp349.wav
      Processing 5/58: adrsp198.wav
      Processing 6/58

In [13]:
print("\nSTEP 9: Saving transcription results...")

# Convert to DataFrame
df = pd.DataFrame(transcripts)

# Save complete results
complete_output = os.path.join(OUTPUT_PATH, "all_transcripts.csv")
df.to_csv(complete_output, index=False)
print(f"✓ Saved complete results to: {complete_output}")

# Save successful transcripts only
successful_df = df[df['success'] == True].copy()
success_output = os.path.join(OUTPUT_PATH, "successful_transcripts.csv")
successful_df.to_csv(success_output, index=False)
print(f"✓ Saved successful transcripts to: {success_output}")

# Save by dataset
datasets_to_save = df['dataset'].unique()
for dataset in datasets_to_save:
    dataset_df = df[df['dataset'] == dataset].copy()
    dataset_output = os.path.join(OUTPUT_PATH, f"{dataset}_transcripts.csv")
    dataset_df.to_csv(dataset_output, index=False)
    print(f"✓ Saved {dataset} transcripts to: {dataset_output}")


STEP 9: Saving transcription results...
✓ Saved complete results to: /content/drive/MyDrive/Voice/transcripts/all_transcripts.csv
✓ Saved successful transcripts to: /content/drive/MyDrive/Voice/transcripts/successful_transcripts.csv
✓ Saved progression_train transcripts to: /content/drive/MyDrive/Voice/transcripts/progression_train_transcripts.csv
✓ Saved progression_test transcripts to: /content/drive/MyDrive/Voice/transcripts/progression_test_transcripts.csv
✓ Saved diagnosis_train transcripts to: /content/drive/MyDrive/Voice/transcripts/diagnosis_train_transcripts.csv


In [14]:
print("\nSTEP 10: Summary Statistics")
print("="*50)

total_files = len(df)
successful = len(successful_df)
failed = total_files - successful

print(f"Total audio files processed: {total_files}")
print(f"Successful transcriptions: {successful} ({successful/total_files*100:.1f}%)")
print(f"Failed transcriptions: {failed} ({failed/total_files*100:.1f}%)")

print(f"\nDataset breakdown:")
for dataset in df['dataset'].unique():
    dataset_total = len(df[df['dataset'] == dataset])
    dataset_success = len(df[(df['dataset'] == dataset) & (df['success'] == True)])
    print(f"  {dataset}: {dataset_success}/{dataset_total} successful ({dataset_success/dataset_total*100:.1f}%)")

print(f"\nLabel distribution (successful transcripts only):")
if not successful_df.empty:
    print(successful_df['label'].value_counts())

print(f"\nRecognition methods used:")
if not successful_df.empty:
    print(successful_df['recognition_method'].value_counts())

print(f"\nSample successful transcripts:")
sample_transcripts = successful_df['transcript'].dropna().head(3)
for i, transcript in enumerate(sample_transcripts):
    print(f"  Sample {i+1}: {transcript[:200]}...")

print(f"\nMost common errors:")
error_df = df[df['success'] == False]
if not error_df.empty:
    error_counts = error_df['error'].value_counts().head(5)
    for error, count in error_counts.items():
        print(f"  {error}: {count} files")

print("\n" + "="*60)
print("TRANSCRIPT EXTRACTION COMPLETE!")
print(f"All results saved in: {OUTPUT_PATH}")
print("="*60)


STEP 10: Summary Statistics
Total audio files processed: 271
Successful transcriptions: 156 (57.6%)
Failed transcriptions: 115 (42.4%)

Dataset breakdown:
  progression_train: 42/73 successful (57.5%)
  progression_test: 16/32 successful (50.0%)
  diagnosis_train: 98/166 successful (59.0%)

Label distribution (successful transcripts only):
label
ad            52
cn            46
no_decline    34
test          16
decline        8
Name: count, dtype: int64

Recognition methods used:
recognition_method
google    156
Name: count, dtype: int64

Sample successful transcripts:
  Sample 1: you can start now...
  Sample 2: is in 1 minute time I want you to name as many...
  Sample 3: cat dog giraffe...

Most common errors:
  Google Speech Recognition could not understand audio; Sphinx error: missing PocketSphinx module: ensure that PocketSphinx is set up correctly.: 115 files

TRANSCRIPT EXTRACTION COMPLETE!
All results saved in: /content/drive/MyDrive/Voice/transcripts/


In [15]:
import os
import pandas as pd
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# For BERT processing
from transformers import AutoTokenizer, AutoModel
import re

print("="*80)
print("COMPREHENSIVE AD CLASSIFICATION PIPELINE")
print("Audio Features + BERT + DARTS Architecture")
print("="*80)

COMPREHENSIVE AD CLASSIFICATION PIPELINE
Audio Features + BERT + DARTS Architecture


In [16]:
class AudioFeatureExtractor:
    def __init__(self, sr=16000, n_mfcc=13, n_fft=2048, hop_length=512):
        self.sr = sr
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length

    def extract_mfcc_features(self, audio_path):
        """Extract comprehensive MFCC features including deltas"""
        try:
            # Load audio
            y, sr = librosa.load(audio_path, sr=self.sr)

            # Extract MFCC features
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc,
                                       n_fft=self.n_fft, hop_length=self.hop_length)

            # Extract delta features (first derivative)
            delta_mfccs = librosa.feature.delta(mfccs)

            # Extract delta-delta features (second derivative)
            delta2_mfccs = librosa.feature.delta(mfccs, order=2)

            # Combine all MFCC features
            combined_mfccs = np.concatenate([mfccs, delta_mfccs, delta2_mfccs], axis=0)

            # Statistical features for each coefficient
            features = {}

            # Mean, std, min, max for each coefficient
            features['mfcc_mean'] = np.mean(combined_mfccs, axis=1)
            features['mfcc_std'] = np.std(combined_mfccs, axis=1)
            features['mfcc_min'] = np.min(combined_mfccs, axis=1)
            features['mfcc_max'] = np.max(combined_mfccs, axis=1)
            features['mfcc_median'] = np.median(combined_mfccs, axis=1)
            features['mfcc_skew'] = self._calculate_skewness(combined_mfccs)
            features['mfcc_kurtosis'] = self._calculate_kurtosis(combined_mfccs)

            return features, combined_mfccs

        except Exception as e:
            print(f"Error extracting MFCC from {audio_path}: {e}")
            return None, None

    def extract_spectral_features(self, audio_path):
        """Extract spectral features"""
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)

            features = {}

            # Spectral centroid
            spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
            features['spectral_centroid_mean'] = np.mean(spectral_centroid)
            features['spectral_centroid_std'] = np.std(spectral_centroid)

            # Spectral bandwidth
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=y, sr=sr)[0]
            features['spectral_bandwidth_mean'] = np.mean(spectral_bandwidth)
            features['spectral_bandwidth_std'] = np.std(spectral_bandwidth)

            # Spectral rolloff
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
            features['spectral_rolloff_mean'] = np.mean(spectral_rolloff)
            features['spectral_rolloff_std'] = np.std(spectral_rolloff)

            # Zero crossing rate
            zcr = librosa.feature.zero_crossing_rate(y)[0]
            features['zcr_mean'] = np.mean(zcr)
            features['zcr_std'] = np.std(zcr)

            # Chroma features
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            features['chroma_mean'] = np.mean(chroma, axis=1)
            features['chroma_std'] = np.std(chroma, axis=1)

            # Tempo
            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
            features['tempo'] = tempo

            return features

        except Exception as e:
            print(f"Error extracting spectral features from {audio_path}: {e}")
            return {}

    def extract_prosodic_features(self, audio_path):
        """Extract prosodic features (pitch, energy, etc.)"""
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)

            features = {}

            # Fundamental frequency (pitch)
            f0, voiced_flag, voiced_probs = librosa.pyin(y, fmin=librosa.note_to_hz('C2'),
                                                       fmax=librosa.note_to_hz('C7'))

            # Remove NaN values
            f0_clean = f0[~np.isnan(f0)]
            if len(f0_clean) > 0:
                features['f0_mean'] = np.mean(f0_clean)
                features['f0_std'] = np.std(f0_clean)
                features['f0_min'] = np.min(f0_clean)
                features['f0_max'] = np.max(f0_clean)
                features['f0_range'] = np.max(f0_clean) - np.min(f0_clean)
            else:
                features.update({
                    'f0_mean': 0, 'f0_std': 0, 'f0_min': 0,
                    'f0_max': 0, 'f0_range': 0
                })

            # RMS energy
            rms = librosa.feature.rms(y=y)[0]
            features['rms_mean'] = np.mean(rms)
            features['rms_std'] = np.std(rms)

            # Spectral contrast
            contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
            features['contrast_mean'] = np.mean(contrast, axis=1)
            features['contrast_std'] = np.std(contrast, axis=1)

            return features

        except Exception as e:
            print(f"Error extracting prosodic features from {audio_path}: {e}")
            return {}

    def _calculate_skewness(self, data):
        """Calculate skewness for each row"""
        mean = np.mean(data, axis=1, keepdims=True)
        std = np.std(data, axis=1, keepdims=True)
        std[std == 0] = 1  # Avoid division by zero
        normalized = (data - mean) / std
        skewness = np.mean(normalized**3, axis=1)
        return skewness

    def _calculate_kurtosis(self, data):
        """Calculate kurtosis for each row"""
        mean = np.mean(data, axis=1, keepdims=True)
        std = np.std(data, axis=1, keepdims=True)
        std[std == 0] = 1  # Avoid division by zero
        normalized = (data - mean) / std
        kurtosis = np.mean(normalized**4, axis=1) - 3
        return kurtosis

    def extract_all_features(self, audio_path):
        """Extract all audio features"""
        all_features = {}

        # MFCC features
        mfcc_features, mfcc_matrix = self.extract_mfcc_features(audio_path)
        if mfcc_features:
            all_features.update(mfcc_features)

        # Spectral features
        spectral_features = self.extract_spectral_features(audio_path)
        all_features.update(spectral_features)

        # Prosodic features
        prosodic_features = self.extract_prosodic_features(audio_path)
        all_features.update(prosodic_features)

        # Flatten nested arrays
        flattened_features = {}
        for key, value in all_features.items():
            if isinstance(value, np.ndarray):
                if value.ndim == 1:
                    for i, v in enumerate(value):
                        flattened_features[f"{key}_{i}"] = v
                else:
                    flattened_features[key] = np.mean(value)
            else:
                flattened_features[key] = value

        return flattened_features, mfcc_matrix

In [17]:
class BERTTextProcessor:
    def __init__(self, model_name='bert-base-uncased', max_length=512):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def preprocess_text(self, text):
        """Clean and preprocess text"""
        if pd.isna(text) or text is None:
            return ""

        # Convert to string and clean
        text = str(text).lower()

        # Remove special characters but keep spaces and basic punctuation
        text = re.sub(r'[^a-zA-Z0-9\s\.\,\!\?]', '', text)

        # Remove extra whitespace
        text = ' '.join(text.split())

        return text

    def extract_bert_features(self, text):
        """Extract BERT embeddings from text"""
        try:
            # Preprocess text
            clean_text = self.preprocess_text(text)

            if not clean_text:
                # Return zero vector for empty text
                return np.zeros(768)

            # Tokenize
            inputs = self.tokenizer(
                clean_text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Extract features
            with torch.no_grad():
                outputs = self.model(**inputs)

                # Use [CLS] token embedding (first token)
                cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze()

                # Also compute mean pooling of all tokens
                attention_mask = inputs['attention_mask']
                token_embeddings = outputs.last_hidden_state
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                mean_embedding = sum_embeddings / sum_mask
                mean_embedding = mean_embedding.squeeze()

                # Combine CLS and mean pooling
                combined_embedding = (cls_embedding + mean_embedding) / 2

                return combined_embedding.numpy()

        except Exception as e:
            print(f"Error extracting BERT features: {e}")
            return np.zeros(768)

    def extract_linguistic_features(self, text):
        """Extract basic linguistic features"""
        try:
            clean_text = self.preprocess_text(text)

            if not clean_text:
                return {
                    'word_count': 0, 'char_count': 0, 'avg_word_length': 0,
                    'sentence_count': 0, 'question_count': 0, 'exclamation_count': 0
                }

            words = clean_text.split()
            sentences = clean_text.split('.')

            features = {
                'word_count': len(words),
                'char_count': len(clean_text),
                'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
                'sentence_count': len([s for s in sentences if s.strip()]),
                'question_count': clean_text.count('?'),
                'exclamation_count': clean_text.count('!')
            }

            return features

        except Exception as e:
            print(f"Error extracting linguistic features: {e}")
            return {'word_count': 0, 'char_count': 0, 'avg_word_length': 0,
                   'sentence_count': 0, 'question_count': 0, 'exclamation_count': 0}

In [18]:
class ImprovedDARTSCell(nn.Module):
    def __init__(self, input_dim, output_dim, num_ops=8):
        super(ImprovedDARTSCell, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Ensure dimensions match for operations
        if input_dim != output_dim:
            self.projection = nn.Linear(input_dim, output_dim)
        else:
            self.projection = nn.Identity()

        # Define possible operations with proper dimensionality
        self.operations = nn.ModuleList([
            nn.Identity(),  # Skip connection
            nn.ReLU(),      # Activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU()),  # Linear + ReLU
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.Tanh()),  # Linear + Tanh
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Dropout(0.1)),  # With dropout
            nn.Sequential(nn.Linear(output_dim, output_dim // 2), nn.ReLU(), nn.Linear(output_dim // 2, output_dim)),  # Bottleneck
            nn.Sequential(nn.LayerNorm(output_dim), nn.ReLU()),  # Layer norm + activation
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Linear(output_dim, output_dim))  # Deep linear
        ])

        # Architecture parameters (alpha) - learnable weights for each operation
        self.alpha = nn.Parameter(torch.randn(len(self.operations)))

        # Temperature parameter for gumbel softmax (learnable)
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Project input to correct dimension
        x = self.projection(x)

        # Apply Gumbel Softmax for differentiable architecture search
        if self.training:
            # Use Gumbel Softmax during training
            gumbel_weights = F.gumbel_softmax(self.alpha, tau=self.temperature, hard=False)
        else:
            # Use regular softmax during evaluation
            gumbel_weights = F.softmax(self.alpha / self.temperature, dim=0)

        # Apply operations
        outputs = []
        for op in self.operations:
            try:
                out = op(x)
                outputs.append(out)
            except Exception as e:
                # Fallback to identity if operation fails
                outputs.append(x)

        # Weighted combination of all operations
        result = sum(w * out for w, out in zip(gumbel_weights, outputs))

        return result

    def get_selected_operation(self):
        """Get the operation with highest weight (for inference)"""
        selected_idx = torch.argmax(self.alpha).item()
        return selected_idx, self.operations[selected_idx]

class MultimodalDARTSClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256, num_classes=2):
        super(MultimodalDARTSClassifier, self).__init__()

        # Audio processing branch with DARTS
        self.audio_projection = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.audio_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(3)
        ])

        # Text processing branch with DARTS
        self.text_projection = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.text_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(2)
        ])

        # Cross-modal attention
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        # Fusion and classification
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4)
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, audio_features, text_features):
        # Process audio through DARTS
        audio_x = self.audio_projection(audio_features)
        for cell in self.audio_darts_cells:
            residual = audio_x
            audio_x = cell(audio_x)
            audio_x = audio_x + residual  # Residual connection

        # Process text through DARTS
        text_x = self.text_projection(text_features)
        for cell in self.text_darts_cells:
            residual = text_x
            text_x = cell(text_x)
            text_x = text_x + residual  # Residual connection

        # Cross-modal attention
        audio_attended, _ = self.cross_attention(
            audio_x.unsqueeze(1), text_x.unsqueeze(1), text_x.unsqueeze(1)
        )
        audio_attended = audio_attended.squeeze(1)

        # Fusion
        fused = torch.cat([audio_attended, text_x], dim=1)
        fused = self.fusion(fused)

        # Classification
        output = self.classifier(fused)

        return output

    def get_architecture_info(self):
        """Get information about the learned architecture"""
        arch_info = {}

        # Audio DARTS info
        for i, cell in enumerate(self.audio_darts_cells):
            selected_idx, _ = cell.get_selected_operation()
            arch_info[f'audio_cell_{i}'] = {
                'selected_operation_idx': selected_idx,
                'operation_weights': cell.alpha.detach().cpu().numpy(),
                'temperature': cell.temperature.item()
            }

        # Text DARTS info
        for i, cell in enumerate(self.text_darts_cells):
            selected_idx, _ = cell.get_selected_operation()
            arch_info[f'text_cell_{i}'] = {
                'selected_operation_idx': selected_idx,
                'operation_weights': cell.alpha.detach().cpu().numpy(),
                'temperature': cell.temperature.item()
            }

        return arch_info

In [19]:
class ADDataset(Dataset):
    def __init__(self, csv_path=None, audio_features=None, text_features=None, labels=None,
                 dataset_type='diagnosis', normalize_audio=True, normalize_text=True):
        """
        Enhanced AD Dataset class for multimodal classification

        Args:
            csv_path: Path to CSV file with transcripts (from transcript extraction)
            audio_features: Pre-extracted audio features array
            text_features: Pre-extracted text features array
            labels: Labels array
            dataset_type: 'diagnosis' (AD/CN) or 'progression' (decline/no_decline)
            normalize_audio: Whether to normalize audio features
            normalize_text: Whether to normalize text features
        """
        self.dataset_type = dataset_type
        self.normalize_audio = normalize_audio
        self.normalize_text = normalize_text

        # Initialize scalers
        self.audio_scaler = StandardScaler() if normalize_audio else None
        self.text_scaler = StandardScaler() if normalize_text else None
        self.label_encoder = LabelEncoder()

        if csv_path is not None:
            # Load from CSV file (transcript extraction results)
            self._load_from_csv(csv_path)
        else:
            # Load from pre-processed arrays
            self._load_from_arrays(audio_features, text_features, labels)

    def _load_from_csv(self, csv_path):
        """Load dataset from CSV file with transcripts"""
        print(f"Loading dataset from {csv_path}...")

        # Load CSV
        df = pd.read_csv(csv_path)

        # Filter successful transcripts only
        df = df[df['success'] == True].copy()
        print(f"Found {len(df)} successful transcripts")

        # Filter by dataset type if specified
        if self.dataset_type == 'diagnosis':
            df = df[df['dataset'] == 'diagnosis_train'].copy()
            valid_labels = ['ad', 'cn']
        elif self.dataset_type == 'progression':
            df = df[df['dataset'] == 'progression_train'].copy()
            valid_labels = ['decline', 'no_decline']
        else:
            # Keep all data
            valid_labels = df['label'].unique()

        # Filter valid labels
        df = df[df['label'].isin(valid_labels)].copy()
        print(f"After filtering: {len(df)} samples")

        # Extract information
        self.file_ids = df['file_id'].tolist()
        self.file_paths = df['file_path'].tolist()
        self.transcripts = df['transcript'].tolist()
        self.raw_labels = df['label'].tolist()

        # Encode labels
        self.labels = torch.LongTensor(self.label_encoder.fit_transform(self.raw_labels))
        self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))

        print(f"Label mapping: {self.label_mapping}")
        print(f"Label distribution: {pd.Series(self.raw_labels).value_counts().to_dict()}")

        # Initialize feature placeholders (will be filled by feature extractors)
        self.audio_features = None
        self.text_features = None

    def _load_from_arrays(self, audio_features, text_features, labels):
        """Load dataset from pre-processed feature arrays"""
        print("Loading dataset from pre-processed arrays...")

        if audio_features is None or text_features is None or labels is None:
            raise ValueError("All feature arrays must be provided")

        # Convert to numpy arrays if needed
        audio_features = np.array(audio_features)
        text_features = np.array(text_features)
        labels = np.array(labels)

        # Ensure same number of samples
        assert len(audio_features) == len(text_features) == len(labels), \
            "All arrays must have the same number of samples"

        # Normalize features if requested
        if self.normalize_audio and self.audio_scaler:
            audio_features = self.audio_scaler.fit_transform(audio_features)

        if self.normalize_text and self.text_scaler:
            text_features = self.text_scaler.fit_transform(text_features)

        # Convert to tensors
        self.audio_features = torch.FloatTensor(audio_features)
        self.text_features = torch.FloatTensor(text_features)

        # Handle labels
        if labels.dtype == 'object' or isinstance(labels[0], str):
            # String labels - encode them
            self.raw_labels = labels.tolist()
            self.labels = torch.LongTensor(self.label_encoder.fit_transform(labels))
            self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))
        else:
            # Numeric labels
            self.labels = torch.LongTensor(labels)
            self.raw_labels = labels.tolist()
            self.label_mapping = {i: i for i in range(len(np.unique(labels)))}

        print(f"Dataset loaded: {len(self.labels)} samples")
        print(f"Audio features shape: {self.audio_features.shape}")
        print(f"Text features shape: {self.text_features.shape}")
        print(f"Label mapping: {self.label_mapping}")

    def set_audio_features(self, audio_features):
        """Set audio features after extraction"""
        audio_features = np.array(audio_features)

        if self.normalize_audio and self.audio_scaler:
            audio_features = self.audio_scaler.fit_transform(audio_features)

        self.audio_features = torch.FloatTensor(audio_features)
        print(f"Audio features set: {self.audio_features.shape}")

    def set_text_features(self, text_features):
        """Set text features after extraction"""
        text_features = np.array(text_features)

        if self.normalize_text and self.text_scaler:
            text_features = self.text_scaler.fit_transform(text_features)

        self.text_features = torch.FloatTensor(text_features)
        print(f"Text features set: {self.text_features.shape}")

    def extract_features_from_transcripts(self, audio_extractor, text_processor):
        """
        Extract features from audio files and transcripts

        Args:
            audio_extractor: AudioFeatureExtractor instance
            text_processor: BERTTextProcessor instance
        """
        if not hasattr(self, 'file_paths') or not hasattr(self, 'transcripts'):
            raise ValueError("Dataset must be loaded from CSV to extract features")

        print("Extracting audio and text features...")

        # Extract audio features
        print("Extracting audio features...")
        audio_features_list = []
        failed_audio = 0

        for i, audio_path in enumerate(self.file_paths):
            if i % 10 == 0:
                print(f"  Processing audio {i+1}/{len(self.file_paths)}")

            features, _ = audio_extractor.extract_all_features(audio_path)
            if features:
                # Convert to list of values in consistent order
                feature_vector = [features.get(key, 0) for key in sorted(features.keys())]
                audio_features_list.append(feature_vector)
            else:
                # Use zero vector for failed extractions
                failed_audio += 1
                if audio_features_list:
                    audio_features_list.append([0] * len(audio_features_list[0]))
                else:
                    audio_features_list.append([0] * 100)  # Default size

        if failed_audio > 0:
            print(f"  Warning: {failed_audio} audio files failed feature extraction")

        # Extract text features
        print("Extracting text features...")
        text_features_list = []

        for i, transcript in enumerate(self.transcripts):
            if i % 20 == 0:
                print(f"  Processing text {i+1}/{len(self.transcripts)}")

            # Extract BERT features
            bert_features = text_processor.extract_bert_features(transcript)

            # Extract linguistic features
            ling_features = text_processor.extract_linguistic_features(transcript)

            # Combine features
            combined_features = np.concatenate([
                bert_features,
                [ling_features[key] for key in sorted(ling_features.keys())]
            ])

            text_features_list.append(combined_features)

        # Set features
        self.set_audio_features(audio_features_list)
        self.set_text_features(text_features_list)

        print(f"Feature extraction complete!")
        print(f"  Audio features: {self.audio_features.shape}")
        print(f"  Text features: {self.text_features.shape}")

    def get_feature_info(self):
        """Get information about the features"""
        info = {
            'num_samples': len(self.labels),
            'num_classes': len(self.label_mapping),
            'label_mapping': self.label_mapping
        }

        if self.audio_features is not None:
            info['audio_feature_dim'] = self.audio_features.shape[1]

        if self.text_features is not None:
            info['text_feature_dim'] = self.text_features.shape[1]

        return info

    def get_class_weights(self):
        """Calculate class weights for imbalanced datasets"""
        if hasattr(self, 'raw_labels'):
            label_counts = pd.Series(self.raw_labels).value_counts()
            total_samples = len(self.raw_labels)

            # Calculate inverse frequency weights
            weights = {}
            for label, count in label_counts.items():
                weights[self.label_mapping[label]] = total_samples / (len(label_counts) * count)

            # Convert to tensor
            weight_tensor = torch.zeros(len(self.label_mapping))
            for class_idx, weight in weights.items():
                weight_tensor[class_idx] = weight

            return weight_tensor

        return None

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.audio_features is None or self.text_features is None:
            raise ValueError("Features not set. Call set_audio_features() and set_text_features() first, or extract_features_from_transcripts()")

        return self.audio_features[idx], self.text_features[idx], self.labels[idx]

    def get_sample_info(self, idx):
        """Get detailed information about a specific sample"""
        info = {
            'index': idx,
            'label': self.labels[idx].item(),
            'raw_label': self.raw_labels[idx] if hasattr(self, 'raw_labels') else None
        }

        if hasattr(self, 'file_ids'):
            info['file_id'] = self.file_ids[idx]

        if hasattr(self, 'transcripts'):
            info['transcript'] = self.transcripts[idx]

        if hasattr(self, 'file_paths'):
            info['file_path'] = self.file_paths[idx]

        return info

In [20]:
def create_dataset_from_transcripts(transcript_csv_path, audio_extractor, text_processor,
                                  dataset_type='diagnosis', test_size=0.2, val_size=0.1):
    """
    Create train/val/test datasets from transcript CSV

    Args:
        transcript_csv_path: Path to successful_transcripts.csv
        audio_extractor: AudioFeatureExtractor instance
        text_processor: BERTTextProcessor instance
        dataset_type: 'diagnosis' or 'progression'
        test_size: Proportion for test set
        val_size: Proportion for validation set

    Returns:
        train_dataset, val_dataset, test_dataset
    """
    from sklearn.model_selection import train_test_split

    # Load full dataset
    full_dataset = ADDataset(csv_path=transcript_csv_path, dataset_type=dataset_type)

    # Extract features
    full_dataset.extract_features_from_transcripts(audio_extractor, text_processor)

    # Get indices for splitting
    indices = list(range(len(full_dataset)))
    labels = [full_dataset.raw_labels[i] for i in indices]

    # First split: separate test set
    train_val_idx, test_idx = train_test_split(
        indices, test_size=test_size, stratify=labels, random_state=42
    )

    # Second split: separate train and validation
    train_labels = [labels[i] for i in train_val_idx]
    train_idx, val_idx = train_test_split(
        train_val_idx, test_size=val_size/(1-test_size), stratify=train_labels, random_state=42
    )

    # Create datasets
    def create_subset(indices):
        audio_subset = full_dataset.audio_features[indices]
        text_subset = full_dataset.text_features[indices]
        label_subset = full_dataset.labels[indices]

        return ADDataset(
            audio_features=audio_subset,
            text_features=text_subset,
            labels=label_subset,
            dataset_type=dataset_type,
            normalize_audio=False,  # Already normalized
            normalize_text=False
        )

    train_dataset = create_subset(train_idx)
    val_dataset = create_subset(val_idx)
    test_dataset = create_subset(test_idx)

    print(f"\nDataset split completed:")
    print(f"  Training: {len(train_dataset)} samples")
    print(f"  Validation: {len(val_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")

    return train_dataset, val_dataset, test_dataset

In [23]:
# Integrated ADReSSo21 Audio Transcript Extractor and AD Classification Pipeline
# Combines transcript extraction with audio and text feature extraction and DARTS classification

import os
import tarfile
import pandas as pd
import numpy as np
from pathlib import Path
import librosa
import speech_recognition as sr
import soundfile as sf
from pydub import AudioSegment
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel
import re
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("ADReSSo21 INTEGRATED AUDIO TRANSCRIPT AND AD CLASSIFICATION PIPELINE")
print("="*80)

# STEP 1: Mount Google Drive
print("\nMounting Google Drive...")
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("✓ Google Drive mounted successfully!")
except:
    print("⚠ Not running in Colab or Drive already mounted")

# STEP 2: Set Up Paths and Configuration
print("\nSetting up paths and configuration...")
BASE_PATH = "/content/drive/MyDrive/Voice/"
EXTRACT_PATH = "/content/drive/MyDrive/Voice/extracted/"
OUTPUT_PATH = "/content/drive/MyDrive/Voice/transcripts/"

os.makedirs(EXTRACT_PATH, exist_ok=True)
os.makedirs(OUTPUT_PATH, exist_ok=True)

datasets = {
    'progression_train': 'ADReSSo21-progression-train.tgz',
    'progression_test': 'ADReSSo21-progression-test.tgz',
    'diagnosis_train': 'ADReSSo21-diagnosis-train.tgz'
}

print(f"✓ Base path: {BASE_PATH}")
print(f"✓ Extract path: {EXTRACT_PATH}")
print(f"✓ Output path: {OUTPUT_PATH}")

# # STEP 3: Extract Dataset Files
# print("\nExtracting dataset files...")
# def extract_datasets():
#     for dataset_name, filename in datasets.items():
#         file_path = os.path.join(BASE_PATH, filename)
#         if os.path.exists(file_path):
#             print(f"  Extracting {filename}...")
#             try:
#                 with tarfile.open(file_path, 'r:gz') as tar:
#                     tar.extractall(path=EXTRACT_PATH)
#                 print(f"  ✓ {filename} extracted successfully")
#             except Exception as e:
#                 print(f"  ⚠ Error extracting {filename}: {e}")
#         else:
#             print(f"  ⚠ {filename} not found at {file_path}")

# extract_datasets()

# # STEP 4: Find All WAV Files
# print("\nFinding all WAV files...")
# def find_wav_files():
#     wav_files = {
#         'progression_train': {'decline': [], 'no_decline': []},
#         'progression_test': [],
#         'diagnosis_train': {'ad': [], 'cn': []}
#     }
#     diag_train_base = os.path.join(EXTRACT_PATH, "ADReSSo21/diagnosis/train/audio/")
#     ad_path = os.path.join(diag_train_base, "ad/")
#     if os.path.exists(ad_path):
#         ad_wavs = [f for f in os.listdir(ad_path) if f.endswith('.wav')]
#         wav_files['diagnosis_train']['ad'] = [os.path.join(ad_path, f) for f in ad_wavs]
#         print(f"  Found {len(ad_wavs)} AD WAV files")
#     cn_path = os.path.join(diag_train_base, "cn/")
#     if os.path.exists(cn_path):
#         cn_wavs = [f for f in os.listdir(cn_path) if f.endswith('.wav')]
#         wav_files['diagnosis_train']['cn'] = [os.path.join(cn_path, f) for f in cn_wavs]
#         print(f"  Found {len(cn_wavs)} CN WAV files")
#     return wav_files

# wav_files = find_wav_files()

# STEP 5: Audio Preprocessing Functions
print("\nSetting up audio preprocessing...")
def preprocess_audio(audio_path, target_sr=16000):
    try:
        audio, sr = librosa.load(audio_path, sr=target_sr)
        audio = librosa.util.normalize(audio)
        audio_trimmed, _ = librosa.effects.trim(audio, top_db=20)
        return audio_trimmed, target_sr
    except Exception as e:
        print(f"    Error preprocessing {audio_path}: {e}")
        return None, None

def convert_to_wav_if_needed(audio_path):
    try:
        if not audio_path.endswith('.wav'):
            audio = AudioSegment.from_file(audio_path)
            wav_path = audio_path.rsplit('.', 1)[0] + '_converted.wav'
            audio.export(wav_path, format="wav")
            return wav_path
        return audio_path
    except Exception as e:
        print(f"    Error converting {audio_path}: {e}")
        return audio_path

# STEP 6: Speech Recognition Function
print("\nSetting up speech recognition...")
def extract_transcript_from_audio(audio_path, method='google'):
    recognizer = sr.Recognizer()
    try:
        wav_path = convert_to_wav_if_needed(audio_path)
        audio_data, sr_rate = preprocess_audio(wav_path, target_sr=16000)
        if audio_data is None:
            return None, "Preprocessing failed"
        temp_wav = audio_path.replace('.wav', '_temp.wav')
        sf.write(temp_wav, audio_data, sr_rate)
        with sr.AudioFile(temp_wav) as source:
            recognizer.adjust_for_ambient_noise(source, duration=0.5)
            audio = recognizer.listen(source)
        transcript = None
        error_msg = ""
        if method == 'google':
            try:
                transcript = recognizer.recognize_google(audio)
            except sr.UnknownValueError:
                error_msg = "Google Speech Recognition could not understand audio"
            except sr.RequestError as e:
                error_msg = f"Google Speech Recognition error: {e}"
        if transcript is None:
            try:
                transcript = recognizer.recognize_sphinx(audio)
                method = 'sphinx'
            except sr.UnknownValueError:
                error_msg += "; Sphinx could not understand audio"
            except sr.RequestError as e:
                error_msg += f"; Sphinx error: {e}"
        if os.path.exists(temp_wav):
            os.remove(temp_wav)
        if transcript:
            return transcript.strip(), method
        else:
            return None, error_msg
    except Exception as e:
        return None, f"Error processing audio: {str(e)}"

# STEP 7: Process Audio Files and Extract Transcripts
print("\nProcessing audio files and extracting transcripts...")
def process_audio_files(wav_files):
    all_transcripts = []
    print("\n  Processing diagnosis training data...")
    for label in ['ad', 'cn']:
        files = wav_files['diagnosis_train'][label]
        print(f"    Processing {len(files)} {label} files...")
        for i, audio_path in enumerate(files):
            print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")
            transcript, method_or_error = extract_transcript_from_audio(audio_path)
            all_transcripts.append({
                'file_id': os.path.splitext(os.path.basename(audio_path))[0],
                'file_path': audio_path,
                'dataset': 'diagnosis_train',
                'label': label,
                'transcript': transcript,
                'recognition_method': method_or_error if transcript else None,
                'error': None if transcript else method_or_error,
                'success': transcript is not None
            })
    return all_transcripts

transcripts = process_audio_files(wav_files)

# STEP 8: Save Transcription Results
print("\nSaving transcription results...")
df = pd.DataFrame(transcripts)
complete_output = os.path.join(OUTPUT_PATH, "all_transcripts.csv")
df.to_csv(complete_output, index=False)
print(f"✓ Saved complete results to: {complete_output}")

successful_df = df[df['success'] == True].copy()
success_output = os.path.join(OUTPUT_PATH, "successful_transcripts.csv")
successful_df.to_csv(success_output, index=False)
print(f"✓ Saved successful transcripts to: {success_output}")

# STEP 9: Audio Feature Extraction
print("\nSetting up audio feature extraction...")
class AudioFeatureExtractor:
    def __init__(self, sr=16000, n_mfcc=13, n_fft=2048, hop_length=512):
        self.sr = sr
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length

    def extract_mfcc_features(self, audio_path):
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc,
                                       n_fft=self.n_fft, hop_length=self.hop_length)
            delta_mfccs = librosa.feature.delta(mfccs)
            delta2_mfccs = librosa.feature.delta(mfccs, order=2)
            combined_mfccs = np.concatenate([mfccs, delta_mfccs, delta2_mfccs], axis=0)
            features = {
                'mfcc_mean': np.mean(combined_mfccs, axis=1),
                'mfcc_std': np.std(combined_mfccs, axis=1),
                'mfcc_min': np.min(combined_mfccs, axis=1),
                'mfcc_max': np.max(combined_mfccs, axis=1),
                'mfcc_median': np.median(combined_mfccs, axis=1),
                'mfcc_skew': self._calculate_skewness(combined_mfccs),
                'mfcc_kurtosis': self._calculate_kurtosis(combined_mfccs)
            }
            return features, combined_mfccs
        except Exception as e:
            print(f"Error extracting MFCC from {audio_path}: {e}")
            return None, None

    def extract_spectral_features(self, audio_path):
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)
            features = {}
            spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
            features['spectral_centroid_mean'] = np.mean(spectral_centroid)
            features['spectral_centroid_std'] = np.std(spectral_centroid)
            spectral_bandwidth = librosa.feature.spectral_bandwidth(y=y, sr=sr)[0]
            features['spectral_bandwidth_mean'] = np.mean(spectral_bandwidth)
            features['spectral_bandwidth_std'] = np.std(spectral_bandwidth)
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
            features['spectral_rolloff_mean'] = np.mean(spectral_rolloff)
            features['spectral_rolloff_std'] = np.std(spectral_rolloff)
            zcr = librosa.feature.zero_crossing_rate(y)[0]
            features['zcr_mean'] = np.mean(zcr)
            features['zcr_std'] = np.std(zcr)
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            features['chroma_mean'] = np.mean(chroma, axis=1)
            features['chroma_std'] = np.std(chroma, axis=1)
            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
            features['tempo'] = tempo
            return features
        except Exception as e:
            print(f"Error extracting spectral features from {audio_path}: {e}")
            return {}

    def extract_prosodic_features(self, audio_path):
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)
            features = {}
            f0, voiced_flag, voiced_probs = librosa.pyin(y, fmin=librosa.note_to_hz('C2'),
                                                       fmax=librosa.note_to_hz('C7'))
            f0_clean = f0[~np.isnan(f0)]
            if len(f0_clean) > 0:
                features['f0_mean'] = np.mean(f0_clean)
                features['f0_std'] = np.std(f0_clean)
                features['f0_min'] = np.min(f0_clean)
                features['f0_max'] = np.max(f0_clean)
                features['f0_range'] = np.max(f0_clean) - np.min(f0_clean)
            else:
                features.update({
                    'f0_mean': 0, 'f0_std': 0, 'f0_min': 0,
                    'f0_max': 0, 'f0_range': 0
                })
            rms = librosa.feature.rms(y=y)[0]
            features['rms_mean'] = np.mean(rms)
            features['rms_std'] = np.std(rms)
            contrast = librosa.feature.spectral_contrast(y=y, sr=sr)
            features['contrast_mean'] = np.mean(contrast, axis=1)
            features['contrast_std'] = np.std(contrast, axis=1)
            return features
        except Exception as e:
            print(f"Error extracting prosodic features from {audio_path}: {e}")
            return {}

    def _calculate_skewness(self, data):
        mean = np.mean(data, axis=1, keepdims=True)
        std = np.std(data, axis=1, keepdims=True)
        std[std == 0] = 1
        normalized = (data - mean) / std
        skewness = np.mean(normalized**3, axis=1)
        return skewness

    def _calculate_kurtosis(self, data):
        mean = np.mean(data, axis=1, keepdims=True)
        std = np.std(data, axis=1, keepdims=True)
        std[std == 0] = 1
        normalized = (data - mean) / std
        kurtosis = np.mean(normalized**4, axis=1) - 3
        return kurtosis

    def extract_all_features(self, audio_path):
        all_features = {}
        mfcc_features, mfcc_matrix = self.extract_mfcc_features(audio_path)
        if mfcc_features:
            all_features.update(mfcc_features)
        spectral_features = self.extract_spectral_features(audio_path)
        all_features.update(spectral_features)
        prosodic_features = self.extract_prosodic_features(audio_path)
        all_features.update(prosodic_features)
        flattened_features = {}
        for key, value in all_features.items():
            if isinstance(value, np.ndarray):
                if value.ndim == 1:
                    for i, v in enumerate(value):
                        flattened_features[f"{key}_{i}"] = v
                else:
                    flattened_features[key] = np.mean(value)
            else:
                flattened_features[key] = value
        return flattened_features, mfcc_matrix

# STEP 10: BERT Text Processing
print("\nSetting up BERT text processing...")
class BERTTextProcessor:
    def __init__(self, model_name='bert-base-uncased', max_length=512):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def preprocess_text(self, text):
        if pd.isna(text) or text is None:
            return ""
        text = str(text).lower()
        text = re.sub(r'[^a-zA-Z0-9\s\.\,\!\?]', '', text)
        text = ' '.join(text.split())
        return text

    def extract_bert_features(self, text):
        try:
            clean_text = self.preprocess_text(text)
            if not clean_text:
                return np.zeros(768)
            inputs = self.tokenizer(
                clean_text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                outputs = self.model(**inputs)
                cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze()
                attention_mask = inputs['attention_mask']
                token_embeddings = outputs.last_hidden_state
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                mean_embedding = sum_embeddings / sum_mask
                mean_embedding = mean_embedding.squeeze()
                combined_embedding = (cls_embedding + mean_embedding) / 2
                return combined_embedding.numpy()
        except Exception as e:
            print(f"Error extracting BERT features: {e}")
            return np.zeros(768)

    def extract_linguistic_features(self, text):
        try:
            clean_text = self.preprocess_text(text)
            if not clean_text:
                return {
                    'word_count': 0, 'char_count': 0, 'avg_word_length': 0,
                    'sentence_count': 0, 'question_count': 0, 'exclamation_count': 0
                }
            words = clean_text.split()
            sentences = clean_text.split('.')
            features = {
                'word_count': len(words),
                'char_count': len(clean_text),
                'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
                'sentence_count': len([s for s in sentences if s.strip()]),
                'question_count': clean_text.count('?'),
                'exclamation_count': clean_text.count('!')
            }
            return features
        except Exception as e:
            print(f"Error extracting linguistic features: {e}")
            return {'word_count': 0, 'char_count': 0, 'avg_word_length': 0,
                   'sentence_count': 0, 'question_count': 0, 'exclamation_count': 0}

# STEP 11: DARTS Architecture
print("\nSetting up DARTS architecture...")
class ImprovedDARTSCell(nn.Module):
    def __init__(self, input_dim, output_dim, num_ops=8):
        super(ImprovedDARTSCell, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        if input_dim != output_dim:
            self.projection = nn.Linear(input_dim, output_dim)
        else:
            self.projection = nn.Identity()
        self.operations = nn.ModuleList([
            nn.Identity(),
            nn.ReLU(),
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU()),
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.Tanh()),
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Dropout(0.1)),
            nn.Sequential(nn.Linear(output_dim, output_dim // 2), nn.ReLU(), nn.Linear(output_dim // 2, output_dim)),
            nn.Sequential(nn.LayerNorm(output_dim), nn.ReLU()),
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU(), nn.Linear(output_dim, output_dim))
        ])
        self.alpha = nn.Parameter(torch.randn(len(self.operations)))
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x):
        x = self.projection(x)
        if self.training:
            gumbel_weights = F.gumbel_softmax(self.alpha, tau=self.temperature, hard=False)
        else:
            gumbel_weights = F.softmax(self.alpha / self.temperature, dim=0)
        outputs = []
        for op in self.operations:
            try:
                out = op(x)
                outputs.append(out)
            except Exception as e:
                outputs.append(x)
        result = sum(w * out for w, out in zip(gumbel_weights, outputs))
        return result

class MultimodalDARTSClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256, num_classes=2):
        super(MultimodalDARTSClassifier, self).__init__()
        self.audio_projection = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.audio_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(3)
        ])
        self.text_projection = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.text_darts_cells = nn.ModuleList([
            ImprovedDARTSCell(hidden_dim, hidden_dim) for _ in range(2)
        ])
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4)
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, audio_features, text_features):
        audio_x = self.audio_projection(audio_features)
        for cell in self.audio_darts_cells:
            residual = audio_x
            audio_x = cell(audio_x)
            audio_x = audio_x + residual
        text_x = self.text_projection(text_features)
        for cell in self.text_darts_cells:
            residual = text_x
            text_x = cell(text_x)
            text_x = text_x + residual
        audio_attended, _ = self.cross_attention(
            audio_x.unsqueeze(1), text_x.unsqueeze(1), text_x.unsqueeze(1)
        )
        audio_attended = audio_attended.squeeze(1)
        fused = torch.cat([audio_attended, text_x], dim=1)
        fused = self.fusion(fused)
        output = self.classifier(fused)
        return output

# STEP 12: Dataset Class
print("\nSetting up dataset class...")
class ADDataset(Dataset):
    def __init__(self, csv_path=None, audio_features=None, text_features=None, labels=None,
                 dataset_type='diagnosis', normalize_audio=True, normalize_text=True):
        self.dataset_type = dataset_type
        self.normalize_audio = normalize_audio
        self.normalize_text = normalize_text
        self.audio_scaler = StandardScaler() if normalize_audio else None
        self.text_scaler = StandardScaler() if normalize_text else None
        self.label_encoder = LabelEncoder()
        if csv_path is not None:
            self._load_from_csv(csv_path)
        else:
            self._load_from_arrays(audio_features, text_features, labels)

    def _load_from_csv(self, csv_path):
        print(f"Loading dataset from {csv_path}...")
        df = pd.read_csv(csv_path)
        df = df[df['success'] == True].copy()
        print(f"Found {len(df)} successful transcripts")
        if self.dataset_type == 'diagnosis':
            df = df[df['dataset'] == 'diagnosis_train'].copy()
            valid_labels = ['ad', 'cn']
        elif self.dataset_type == 'progression':
            df = df[df['dataset'] == 'progression_train'].copy()
            valid_labels = ['decline', 'no_decline']
        else:
            valid_labels = df['label'].unique()
        df = df[df['label'].isin(valid_labels)].copy()
        print(f"After filtering: {len(df)} samples")
        self.file_ids = df['file_id'].tolist()
        self.file_paths = df['file_path'].tolist()
        self.transcripts = df['transcript'].tolist()
        self.raw_labels = df['label'].tolist()
        self.labels = torch.LongTensor(self.label_encoder.fit_transform(self.raw_labels))
        self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))
        print(f"Label mapping: {self.label_mapping}")
        self.audio_features = None
        self.text_features = None

    def _load_from_arrays(self, audio_features, text_features, labels):
        print("Loading dataset from pre-processed arrays...")
        if audio_features is None or text_features is None or labels is None:
            raise ValueError("All feature arrays must be provided")
        audio_features = np.array(audio_features)
        text_features = np.array(text_features)
        labels = np.array(labels)
        assert len(audio_features) == len(text_features) == len(labels)
        if self.normalize_audio and self.audio_scaler:
            audio_features = self.audio_scaler.fit_transform(audio_features)
        if self.normalize_text and self.text_scaler:
            text_features = self.text_scaler.fit_transform(text_features)
        self.audio_features = torch.FloatTensor(audio_features)
        self.text_features = torch.FloatTensor(text_features)
        if labels.dtype == 'object' or isinstance(labels[0], str):
            self.raw_labels = labels.tolist()
            self.labels = torch.LongTensor(self.label_encoder.fit_transform(labels))
            self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))
        else:
            self.labels = torch.LongTensor(labels)
            self.raw_labels = labels.tolist()
            self.label_mapping = {i: i for i in range(len(np.unique(labels)))}
        print(f"Dataset loaded: {len(self.labels)} samples")

    def set_audio_features(self, audio_features):
        audio_features = np.array(audio_features)
        if self.normalize_audio and self.audio_scaler:
            audio_features = self.audio_scaler.fit_transform(audio_features)
        self.audio_features = torch.FloatTensor(audio_features)

    def set_text_features(self, text_features):
        text_features = np.array(text_features)
        if self.normalize_text and self.text_scaler:
            text_features = self.text_scaler.fit_transform(text_features)
        self.text_features = torch.FloatTensor(text_features)

    def extract_features_from_transcripts(self, audio_extractor, text_processor):
        if not hasattr(self, 'file_paths') or not hasattr(self, 'transcripts'):
            raise ValueError("Dataset must be loaded from CSV to extract features")
        print("Extracting audio and text features...")
        audio_features_list = []
        failed_audio = 0
        for i, audio_path in enumerate(self.file_paths):
            if i % 10 == 0:
                print(f"  Processing audio {i+1}/{len(self.file_paths)}")
            features, _ = audio_extractor.extract_all_features(audio_path)
            if features:
                feature_vector = [features.get(key, 0) for key in sorted(features.keys())]
                audio_features_list.append(feature_vector)
            else:
                failed_audio += 1
                audio_features_list.append([0] * (len(audio_features_list[0]) if audio_features_list else 100))
        if failed_audio > 0:
            print(f"  Warning: {failed_audio} audio files failed feature extraction")
        text_features_list = []
        for i, transcript in enumerate(self.transcripts):
            if i % 20 == 0:
                print(f"  Processing text {i+1}/{len(self.transcripts)}")
            bert_features = text_processor.extract_bert_features(transcript)
            ling_features = text_processor.extract_linguistic_features(transcript)
            combined_features = np.concatenate([
                bert_features,
                [ling_features[key] for key in sorted(ling_features.keys())]
            ])
            text_features_list.append(combined_features)
        self.set_audio_features(audio_features_list)
        self.set_text_features(text_features_list)
        print(f"Feature extraction complete!")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.audio_features is None or self.text_features is None:
            raise ValueError("Features not set.")
        return self.audio_features[idx], self.text_features[idx], self.labels[idx]

# STEP 13: Create Datasets
print("\nCreating train/val/test datasets...")
def create_dataset_from_transcripts(transcript_csv_path, audio_extractor, text_processor,
                                  dataset_type='diagnosis', test_size=0.2, val_size=0.1):
    full_dataset = ADDataset(csv_path=transcript_csv_path, dataset_type=dataset_type)
    full_dataset.extract_features_from_transcripts(audio_extractor, text_processor)
    indices = list(range(len(full_dataset)))
    labels = [full_dataset.raw_labels[i] for i in indices]
    train_val_idx, test_idx = train_test_split(
        indices, test_size=test_size, stratify=labels, random_state=42
    )
    train_labels = [labels[i] for i in train_val_idx]
    train_idx, val_idx = train_test_split(
        train_val_idx, test_size=val_size/(1-test_size), stratify=train_labels, random_state=42
    )
    def create_subset(indices):
        audio_subset = full_dataset.audio_features[indices]
        text_subset = full_dataset.text_features[indices]
        label_subset = full_dataset.labels[indices]
        return ADDataset(
            audio_features=audio_subset,
            text_features=text_subset,
            labels=label_subset,
            dataset_type=dataset_type,
            normalize_audio=False,
            normalize_text=False
        )
    train_dataset = create_subset(train_idx)
    val_dataset = create_subset(val_idx)
    test_dataset = create_subset(test_idx)
    print(f"\nDataset split completed:")
    print(f"  Training: {len(train_dataset)} samples")
    print(f"  Validation: {len(val_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")
    return train_dataset, val_dataset, test_dataset

# STEP 14: Training Loop
print("\nSetting up training loop...")
def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for audio_features, text_features, labels in train_loader:
            audio_features, text_features, labels = audio_features.to(device), text_features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(audio_features, text_features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}")
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for audio_features, text_features, labels in val_loader:
                audio_features, text_features, labels = audio_features.to(device), text_features.to(device), labels.to(device)
                outputs = model(audio_features, text_features)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f"Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct/total:.2f}%")
    return model

# STEP 15: Main Execution
print("\nStarting main execution...")
audio_extractor = AudioFeatureExtractor()
text_processor = BERTTextProcessor()
transcript_csv_path = os.path.join(OUTPUT_PATH, "successful_transcripts.csv")
train_dataset, val_dataset, test_dataset = create_dataset_from_transcripts(
    transcript_csv_path, audio_extractor, text_processor, dataset_type='diagnosis'
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=16)

# Initialize model
audio_dim = train_dataset.audio_features.shape[1]
text_dim = train_dataset.text_features.shape[1]
model = MultimodalDARTSClassifier(audio_dim=audio_dim, text_dim=text_dim, num_classes=2)

# Train model
print("\nTraining model...")
model = train_model(model, train_loader, val_loader, num_epochs=10)

# STEP 16: Evaluate on Test Set
print("\nEvaluating on test set...")
def evaluate_model(model, test_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for audio_features, text_features, labels in test_loader:
            audio_features, text_features, labels = audio_features.to(device), text_features.to(device), labels.to(device)
            outputs = model(audio_features, text_features)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=['ad', 'cn']))

evaluate_model(model, test_loader)

print("\n" + "="*80)
print("PIPELINE EXECUTION COMPLETE!")
print("="*80)


ADReSSo21 INTEGRATED AUDIO TRANSCRIPT AND AD CLASSIFICATION PIPELINE

Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Google Drive mounted successfully!

Setting up paths and configuration...
✓ Base path: /content/drive/MyDrive/Voice/
✓ Extract path: /content/drive/MyDrive/Voice/extracted/
✓ Output path: /content/drive/MyDrive/Voice/transcripts/

Setting up audio preprocessing...

Setting up speech recognition...

Processing audio files and extracting transcripts...

  Processing diagnosis training data...
    Processing 87 ad files...
      Processing 1/87: adrso024.wav
      Processing 2/87: adrso045.wav
      Processing 3/87: adrso043.wav
      Processing 4/87: adrso036.wav
      Processing 5/87: adrso060.wav
      Processing 6/87: adrso074.wav
      Processing 7/87: adrso070.wav
      Processing 8/87: adrso071.wav
      Processing 9/87: adrso072.wav
      Processing 10/87: ad

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Loading dataset from /content/drive/MyDrive/Voice/transcripts/successful_transcripts.csv...
Found 98 successful transcripts
After filtering: 98 samples
Label mapping: {np.str_('ad'): 0, np.str_('cn'): 1}
Extracting audio and text features...
  Processing audio 1/98
  Processing audio 11/98
  Processing audio 21/98
  Processing audio 31/98
  Processing audio 41/98
  Processing audio 51/98
  Processing audio 61/98
  Processing audio 71/98
  Processing audio 81/98
  Processing audio 91/98
  Processing text 1/98
  Processing text 21/98
  Processing text 41/98
  Processing text 61/98
  Processing text 81/98
Feature extraction complete!
Loading dataset from pre-processed arrays...
Dataset loaded: 68 samples
Loading dataset from pre-processed arrays...
Dataset loaded: 10 samples
Loading dataset from pre-processed arrays...
Dataset loaded: 20 samples

Dataset split completed:
  Training: 68 samples
  Validation: 10 samples
  Test: 20 samples

Training model...
Epoch 1/10, Train Loss: 0.7809
Va

In [27]:
# Complete ADReSSo21 Audio Transcript and AD Classification Pipeline
import os
import tarfile
import pandas as pd
import numpy as np
from pathlib import Path
import librosa
import speech_recognition as sr
import soundfile as sf
from pydub import AudioSegment
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModel
import re
import warnings
from sklearn.metrics import classification_report
warnings.filterwarnings('ignore')

print("="*80)
print("COMPLETE ADReSSo21 AUDIO TRANSCRIPT AND AD CLASSIFICATION PIPELINE")
print("="*80)

# Set Up Paths (from your code)
print("\nSetting up paths...")
BASE_PATH = "/content/drive/MyDrive/Voice/"
EXTRACT_PATH = "/content/drive/MyDrive/Voice/extracted/"
OUTPUT_PATH = "/content/drive/MyDrive/Voice/transcripts/"
os.makedirs(EXTRACT_PATH, exist_ok=True)
os.makedirs(OUTPUT_PATH, exist_ok=True)
datasets = {'diagnosis_train': 'ADReSSo21-diagnosis-train.tgz'}
print(f"✓ Base path: {BASE_PATH}")
print(f"✓ Extract path: {EXTRACT_PATH}")
print(f"✓ Output path: {OUTPUT_PATH}")

# Extract Dataset Files (from your code)
print("\nExtracting dataset files...")
def extract_datasets():
    for dataset_name, filename in datasets.items():
        file_path = os.path.join(BASE_PATH, filename)
        if os.path.exists(file_path):
            print(f"  Extracting {filename}...")
            try:
                with tarfile.open(file_path, 'r:gz') as tar:
                    tar.extractall(path=EXTRACT_PATH)
                print(f"  ✓ {filename} extracted successfully")
            except Exception as e:
                print(f"  ⚠ Error extracting {filename}: {e}")
        else:
            print(f"  ⚠ {filename} not found at {file_path}")
extract_datasets()

# Find WAV Files (from your code)
print("\nFinding WAV files...")
def find_wav_files():
    wav_files = {'diagnosis_train': {'ad': [], 'cn': []}}
    diag_train_base = os.path.join(EXTRACT_PATH, "ADReSSo21/diagnosis/train/audio/")
    for label in ['ad', 'cn']:
        path = os.path.join(diag_train_base, label + "/")
        if os.path.exists(path):
            wavs = [f for f in os.listdir(path) if f.endswith('.wav')]
            wav_files['diagnosis_train'][label] = [os.path.join(path, f) for f in wavs]
            print(f"  Found {len(wavs)} {label.upper()} WAV files")
    return wav_files
wav_files = find_wav_files()

# Audio Preprocessing (from your code)
print("\nSetting up audio preprocessing...")
def preprocess_audio(audio_path, target_sr=16000):
    try:
        audio, sr = librosa.load(audio_path, sr=target_sr)
        audio = librosa.util.normalize(audio)
        audio, _ = librosa.effects.trim(audio, top_db=20)
        return audio, target_sr
    except Exception as e:
        print(f"    Error preprocessing {audio_path}: {e}")
        return None, None

# Additional Audio Conversion
def convert_to_wav_if_needed(audio_path):
    try:
        if not audio_path.endswith('.wav'):
            audio = AudioSegment.from_file(audio_path)
            wav_path = audio_path.rsplit('.', 1)[0] + '_converted.wav'
            audio.export(wav_path, format="wav")
            return wav_path
        return audio_path
    except Exception as e:
        print(f"    Error converting {audio_path}: {e}")
        return audio_path

# Speech Recognition with Google/Sphinx
print("\nSetting up speech recognition...")
def extract_transcript_from_audio(audio_path):
    recognizer = sr.Recognizer()
    try:
        wav_path = convert_to_wav_if_needed(audio_path)
        audio_data, sr_rate = preprocess_audio(wav_path)
        if audio_data is None:
            return None, None, "Preprocessing failed"
        temp_wav = audio_path.replace('.wav', '_temp.wav')
        sf.write(temp_wav, audio_data, sr_rate)
        with sr.AudioFile(temp_wav) as source:
            recognizer.adjust_for_ambient_noise(source, duration=0.5)
            audio = recognizer.listen(source)
        transcript = None
        error_msg = ""
        try:
            transcript = recognizer.recognize_google(audio)
            method = "google"
        except sr.UnknownValueError:
            error_msg = "Google Speech Recognition could not understand audio"
        except sr.RequestError as e:
            error_msg = f"Google Speech Recognition error: {e}"
        if transcript is None:
            try:
                transcript = recognizer.recognize_sphinx(audio)
                method = "sphinx"
            except sr.UnknownValueError:
                error_msg += "; Sphinx could not understand audio"
            except sr.RequestError as e:
                error_msg += f"; Sphinx error: {e}"
        if os.path.exists(temp_wav):
            os.remove(temp_wav)
        return transcript, method, None if transcript else error_msg
    except Exception as e:
        return None, None, f"Error processing audio: {str(e)}"

# Process Audio Files
print("\nProcessing audio files...")
def process_audio_files(wav_files):
    all_transcripts = []
    print("\n  Processing diagnosis training data...")
    for label in ['ad', 'cn']:
        files = wav_files['diagnosis_train'][label]
        print(f"    Processing {len(files)} {label} files...")
        for i, audio_path in enumerate(files):
            print(f"      Processing {i+1}/{len(files)}: {os.path.basename(audio_path)}")
            transcript, method, error = extract_transcript_from_audio(audio_path)
            all_transcripts.append({
                'file_id': os.path.splitext(os.path.basename(audio_path))[0],
                'file_path': audio_path,
                'dataset': 'diagnosis_train',
                'label': label,
                'transcript': transcript,
                'recognition_method': method,
                'error': error,
                'success': transcript is not None and len(transcript.split()) > 5  # Filter short transcripts
            })
    return all_transcripts
transcripts = process_audio_files(wav_files)

# Save Transcription Results
print("\nSaving transcription results...")
df = pd.DataFrame(transcripts)
complete_output = os.path.join(OUTPUT_PATH, "all_transcripts.csv")
df.to_csv(complete_output, index=False)
print(f"✓ Saved complete results to: {complete_output}")
successful_df = df[df['success'] == True].copy()
success_output = os.path.join(OUTPUT_PATH, "successful_transcripts.csv")
successful_df.to_csv(success_output, index=False)
print(f"✓ Saved successful transcripts to: {success_output}")

# Audio Feature Extraction
print("\nSetting up audio feature extraction...")
class AudioFeatureExtractor:
    def __init__(self, sr=16000, n_mfcc=13):
        self.sr = sr
        self.n_mfcc = n_mfcc
    def extract_mfcc_features(self, audio_path):
        try:
            y, sr = librosa.load(audio_path, sr=self.sr)
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=self.n_mfcc)
            return {'mfcc_mean': np.mean(mfccs, axis=1)}
        except:
            return None
    def extract_all_features(self, audio_path):
        features = self.extract_mfcc_features(audio_path)
        return {k: v for k, v in features.items() if isinstance(v, np.ndarray)} if features else {}

# BERT Text Processing
print("\nSetting up BERT text processing...")
class BERTTextProcessor:
    def __init__(self, model_name='bert-base-uncased', max_length=512):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).eval()
    def preprocess_text(self, text):
        if pd.isna(text) or text is None:
            return ""
        text = str(text).lower()
        text = re.sub(r'[^a-zA-Z0-9\s\.\,\!\?]', '', text)
        return ' '.join(text.split())
    def extract_bert_features(self, text):
        try:
            clean_text = self.preprocess_text(text)
            if not clean_text:
                return np.zeros(768)
            inputs = self.tokenizer(clean_text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
            with torch.no_grad():
                outputs = self.model(**inputs)
                return outputs.last_hidden_state[:, 0, :].squeeze().numpy()
        except:
            return np.zeros(768)

# Simplified DARTS Architecture
print("\nSetting up DARTS architecture...")
class ImprovedDARTSCell(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.projection = nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
        self.operations = nn.ModuleList([
            nn.Identity(),
            nn.Sequential(nn.Linear(output_dim, output_dim), nn.ReLU())
        ])
        self.alpha = nn.Parameter(torch.randn(len(self.operations)))
        self.temperature = nn.Parameter(torch.ones(1))
    def forward(self, x):
        x = self.projection(x)
        weights = F.gumbel_softmax(self.alpha, tau=self.temperature, hard=False) if self.training else F.softmax(self.alpha / self.temperature, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.operations))

class MultimodalDARTSClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=64, num_classes=2):
        super().__init__()
        self.audio_projection = nn.Sequential(nn.Linear(audio_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.3))
        self.audio_darts_cell = ImprovedDARTSCell(hidden_dim, hidden_dim)
        self.text_projection = nn.Sequential(nn.Linear(text_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.3))
        self.text_darts_cell = ImprovedDARTSCell(hidden_dim, hidden_dim)
        self.fusion = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Dropout(0.3))
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.apply(lambda m: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) else None)
    def forward(self, audio_features, text_features):
        audio_x = self.audio_projection(audio_features)
        audio_x = self.audio_darts_cell(audio_x) + audio_x
        text_x = self.text_projection(text_features)
        text_x = self.text_darts_cell(text_x) + text_x
        fused = torch.cat([audio_x, text_x], dim=1)
        fused = self.fusion(fused)
        return self.classifier(fused)

# Dataset Class
print("\nSetting up dataset class...")
class ADDataset(Dataset):
    def __init__(self, csv_path=None, audio_features=None, text_features=None, labels=None, dataset_type='diagnosis'):
        self.dataset_type = dataset_type
        self.audio_scaler = StandardScaler()
        self.text_scaler = StandardScaler()
        self.label_encoder = LabelEncoder()
        if csv_path:
            self._load_from_csv(csv_path)
        else:
            self._load_from_arrays(audio_features, text_features, labels)
    def _load_from_csv(self, csv_path):
        df = pd.read_csv(csv_path)
        df = df[df['success'] == True].copy()
        if self.dataset_type == 'diagnosis':
            df = df[df['dataset'] == 'diagnosis_train'].copy()
        self.file_ids = df['file_id'].tolist()
        self.file_paths = df['file_path'].tolist()
        self.transcripts = df['transcript'].tolist()
        self.raw_labels = df['label'].tolist()
        self.labels = torch.LongTensor(self.label_encoder.fit_transform(self.raw_labels))
        self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))
        print(f"Loaded {len(df)} samples, Label mapping: {self.label_mapping}")
    def _load_from_arrays(self, audio_features, text_features, labels):
        audio_features = self.audio_scaler.fit_transform(np.array(audio_features))
        text_features = self.text_scaler.fit_transform(np.array(text_features))
        self.audio_features = torch.FloatTensor(audio_features)
        self.text_features = torch.FloatTensor(text_features)
        self.labels = torch.LongTensor(self.label_encoder.fit_transform(labels)) if isinstance(labels[0], str) else torch.LongTensor(labels)
        self.label_mapping = dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))
    def set_features(self, audio_features, text_features):
        audio_features = self.audio_scaler.fit_transform(np.array(audio_features))
        text_features = self.text_scaler.fit_transform(np.array(text_features))
        self.audio_features = torch.FloatTensor(audio_features)
        self.text_features = torch.FloatTensor(text_features)
    def extract_features(self, audio_extractor, text_processor):
        audio_features_list = []
        for i, audio_path in enumerate(self.file_paths):
            if i % 10 == 0:
                print(f"  Audio {i+1}/{len(self.file_paths)}")
            features = audio_extractor.extract_all_features(audio_path)
            audio_features_list.append([features.get(k, [0]*13)[i%13] for k in sorted(features.keys()) for i in range(13)] if features else [0] * 13)
        text_features_list = []
        for i, transcript in enumerate(self.transcripts):
            if i % 20 == 0:
                print(f"  Text {i+1}/{len(self.transcripts)}")
            text_features_list.append(text_processor.extract_bert_features(transcript))
        audio_features = PCA(n_components=10).fit_transform(audio_features_list)
        text_features = PCA(n_components=50).fit_transform(text_features_list)
        self.set_features(audio_features, text_features)
    def get_class_weights(self):
        counts = np.bincount(self.labels)
        weights = len(self.labels) / (len(counts) * counts)
        return torch.FloatTensor(weights)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return self.audio_features[idx], self.text_features[idx], self.labels[idx]

# Create Datasets
print("\nCreating datasets...")
def create_dataset_from_transcripts(transcript_csv_path, audio_extractor, text_processor, dataset_type='diagnosis'):
    full_dataset = ADDataset(csv_path=transcript_csv_path, dataset_type=dataset_type)
    full_dataset.extract_features(audio_extractor, text_processor)
    indices = list(range(len(full_dataset)))
    labels = [full_dataset.raw_labels[i] for i in indices]
    train_val_idx, test_idx = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42)
    train_labels = [labels[i] for i in train_val_idx]
    train_idx, val_idx = train_test_split(train_val_idx, test_size=0.15/0.8, stratify=train_labels, random_state=42)
    def create_subset(indices):
        audio_subset = full_dataset.audio_features[indices]
        text_subset = full_dataset.text_features[indices]
        label_subset = full_dataset.labels[indices]
        return ADDataset(audio_features=audio_subset, text_features=text_subset, labels=label_subset, dataset_type=dataset_type)
    train_dataset = create_subset(train_idx)
    val_dataset = create_subset(val_idx)
    test_dataset = create_subset(test_idx)
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    return train_dataset, val_dataset, test_dataset

# Training Loop
print("\nSetting up training loop...")
def train_model(model, train_loader, val_loader, num_epochs=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
    criterion = nn.CrossEntropyLoss(weight=train_loader.dataset.get_class_weights().to(device))
    best_val_loss = float('inf')
    patience = 3
    counter = 0
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for audio_features, text_features, labels in train_loader:
            audio_features, text_features, labels = audio_features.to(device), text_features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(audio_features, text_features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for audio_features, text_features, labels in val_loader:
                audio_features, text_features, labels = audio_features.to(device), text_features.to(device), labels.to(device)
                outputs = model(audio_features, text_features)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}, Val Acc: {100 * correct/total:.2f}%")
        scheduler.step(val_loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(OUTPUT_PATH, "best_model.pt"))
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered")
                break
    model.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, "best_model.pt")))
    return model

# Evaluation
print("\nSetting up evaluation...")
def evaluate_model(model, test_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for audio_features, text_features, labels in test_loader:
            audio_features, text_features, labels = audio_features.to(device), text_features.to(device), labels.to(device)
            outputs = model(audio_features, text_features)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=['ad', 'cn']))

# Main Execution
print("\nStarting main execution...")
audio_extractor = AudioFeatureExtractor()
text_processor = BERTTextProcessor()
transcript_csv_path = os.path.join(OUTPUT_PATH, "successful_transcripts.csv")
train_dataset, val_dataset, test_dataset = create_dataset_from_transcripts(
    transcript_csv_path, audio_extractor, text_processor, dataset_type='diagnosis'
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)
audio_dim = train_dataset.audio_features.shape[1]
text_dim = train_dataset.text_features.shape[1]
model = MultimodalDARTSClassifier(audio_dim=audio_dim, text_dim=text_dim, hidden_dim=64, num_classes=2)
model = train_model(model, train_loader, val_loader, num_epochs=20)
evaluate_model(model, test_loader)
print("\n" + "="*80)
print("PIPELINE EXECUTION COMPLETE!")
print("="*80)

SyntaxError: invalid syntax (<ipython-input-27-44d0242f4fac>, line 1)