<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May243.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import os
import tarfile
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import librosa.display
from transformers import (BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor,
                         Wav2Vec2ForCTC, Wav2Vec2Processor, WhisperProcessor,
                         WhisperForConditionalGeneration)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import json
import glob
import re
from pathlib import Path
import soundfile as sf
from collections import defaultdict
from scipy import ndimage
import pickle
warnings.filterwarnings('ignore')

class SpeechTranscriber:
    """Automatic Speech Recognition for generating transcripts from audio.

    Wraps a Hugging Face Whisper (seq2seq) or Wav2Vec2 (CTC) checkpoint and
    keeps a per-file transcript cache on disk so repeated runs are cheap.
    """

    SAMPLE_RATE = 16000  # audio is always resampled to 16 kHz before ASR
    # Whisper's feature extractor pads/truncates every input to a 30 s window,
    # so longer recordings are split into windows of this length.
    WHISPER_CHUNK_SECONDS = 30
    MIN_SAMPLES = 1600  # 0.1 s at 16 kHz; shorter audio is treated as silence
    ERROR_MARKER = "[Transcription error]"

    def __init__(self, model_name="openai/whisper-base", cache_dir="./asr_cache"):
        """Load the ASR model and processor.

        Args:
            model_name: Hugging Face model id. Ids containing "whisper" are
                loaded as Whisper, everything else as Wav2Vec2 CTC. Options:
                - openai/whisper-base: Good balance of speed/accuracy
                - facebook/wav2vec2-base-960h: Faster but less accurate
                - openai/whisper-small: More accurate but slower
            cache_dir: Directory for pickled per-file transcripts (created if
                missing).
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        print(f"Loading ASR model: {model_name}")

        if "whisper" in model_name.lower():
            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
            self.asr_type = "whisper"
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
            self.asr_type = "wav2vec2"

        # Move to GPU if available
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"ASR model loaded on {self.device}")

    # ------------------------------------------------------------------ cache

    def _cache_file_for(self, audio_path):
        """Return the cache file for *audio_path* under the current model.

        The file stem is kept for readability; a short hash of the resolved
        path and the model name is appended so same-named files in different
        directories, or the same file under another model, never collide.
        """
        import hashlib  # only needed to derive cache keys

        key = f"{Path(audio_path).resolve()}|{self.model_name}"
        digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:12]
        return self.cache_dir / f"{Path(audio_path).stem}_{digest}_transcript.pkl"

    def _load_cached_transcript(self, cache_file, audio_path):
        """Return the cached transcript, or None if missing, corrupt or stale."""
        # NOTE: pickle is acceptable only because these files are written by
        # this class; never point cache_dir at untrusted data.
        try:
            with open(cache_file, 'rb') as f:
                cached = pickle.load(f)
        except Exception:
            return None  # missing/corrupted cache: fall back to transcription
        if not isinstance(cached, dict):
            return None
        if cached.get('model') != self.model_name:
            return None
        if cached.get('audio_path') != str(Path(audio_path).resolve()):
            return None
        return cached.get('transcript')

    def _store_cached_transcript(self, cache_file, audio_path, transcript):
        """Best-effort write of *transcript* to *cache_file*."""
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump({
                    'audio_path': str(Path(audio_path).resolve()),
                    'transcript': transcript,
                    'model': self.model_name,
                }, f)
        except OSError as e:
            print(f"Warning: could not cache transcript for {audio_path}: {e}")

    # ---------------------------------------------------------- transcription

    def transcribe_audio_file(self, audio_path, use_cache=True):
        """Transcribe a single audio file, using the on-disk cache when possible.

        Args:
            audio_path: Path (str or Path) to an audio file readable by librosa.
            use_cache: If True, read/write a pickled transcript in ``cache_dir``.

        Returns:
            The transcript string; ``""`` for clips shorter than 0.1 s, or
            ``"[Transcription failed for <name>]"`` if loading/decoding fails.
        """
        audio_path = Path(audio_path)
        cache_file = self._cache_file_for(audio_path)

        if use_cache and cache_file.exists():
            cached = self._load_cached_transcript(cache_file, audio_path)
            if cached is not None:
                return cached

        try:
            audio, _ = librosa.load(str(audio_path), sr=self.SAMPLE_RATE)
            if len(audio) < self.MIN_SAMPLES:
                transcript = ""
            else:
                transcript = self._transcribe_audio_array(audio)
        except Exception as e:
            print(f"Error transcribing {audio_path}: {e}")
            return f"[Transcription failed for {audio_path.name}]"

        # Do not persist failures: a transient error (e.g. GPU OOM) would
        # otherwise be returned from the cache forever.
        if use_cache and transcript != self.ERROR_MARKER:
            self._store_cached_transcript(cache_file, audio_path, transcript)

        return transcript

    def _transcribe_audio_array(self, audio_array):
        """Dispatch to the loaded model; returns ERROR_MARKER on failure."""
        try:
            if self.asr_type == "whisper":
                return self._whisper_transcribe(audio_array)
            return self._wav2vec2_transcribe(audio_array)
        except Exception as e:
            print(f"Transcription error: {e}")
            return self.ERROR_MARKER

    @staticmethod
    def _split_into_chunks(audio_array, chunk_len, min_len):
        """Split *audio_array* into consecutive windows of *chunk_len* samples.

        A trailing window shorter than *min_len* is dropped (it carries no
        usable speech and tends to make Whisper hallucinate); the first window
        is always kept, so the result is never empty.
        """
        if len(audio_array) == 0:
            return [audio_array]
        chunks = [audio_array[start:start + chunk_len]
                  for start in range(0, len(audio_array), chunk_len)]
        return [c for i, c in enumerate(chunks) if i == 0 or len(c) >= min_len]

    def _whisper_transcribe(self, audio_array):
        """Transcribe using Whisper, windowing audio longer than 30 s.

        All windows are decoded in one batched ``generate`` call and the
        non-empty pieces are joined with single spaces.
        """
        chunks = self._split_into_chunks(
            audio_array,
            self.WHISPER_CHUNK_SECONDS * self.SAMPLE_RATE,
            self.MIN_SAMPLES,
        )

        inputs = self.processor(
            chunks,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt"
        ).input_features.to(self.device)

        with torch.no_grad():
            predicted_ids = self.model.generate(inputs, max_length=448)

        pieces = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return " ".join(piece.strip() for piece in pieces if piece.strip())

    def _wav2vec2_transcribe(self, audio_array):
        """Transcribe using Wav2Vec2 with greedy CTC decoding (lower-cased)."""
        inputs = self.processor(
            audio_array,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcript = self.processor.batch_decode(predicted_ids)[0]

        return transcript.strip().lower()

    def batch_transcribe(self, audio_paths, batch_size=8):
        """Transcribe multiple files sequentially, keyed by participant ID.

        Args:
            audio_paths: Iterable of audio file paths.
            batch_size: Unused; kept for API compatibility (files are
                processed one at a time).

        Returns:
            dict mapping participant ID -> transcript. When two files map to
            the same ID the later file wins and a warning is printed.
        """
        audio_paths = list(audio_paths)
        transcripts = {}
        collisions = 0

        print(f"Transcribing {len(audio_paths)} audio files...")

        for audio_path in tqdm(audio_paths, desc="Transcribing"):
            participant_id = self._extract_participant_id(Path(audio_path).name)
            if participant_id in transcripts:
                collisions += 1
            transcripts[participant_id] = self.transcribe_audio_file(audio_path)

        if collisions:
            print(f"Warning: {collisions} file(s) shared a participant ID with an "
                  "earlier file and replaced its transcript")

        return transcripts

    def _extract_participant_id(self, filename):
        """Extract participant ID from filename.

        Must stay identical to the data processor's ID extraction: transcripts
        and labels are joined on this value. Patterns are tried in order.
        """
        patterns = [
            r'adrso?(\d{3})',         # adrs123 / adrso123 -> whole match
            r'adrsp?(\d{3})',         # adrsp123 -> whole match
            r'adrspt?(\d{1,3})',      # adrspt1, adrspt12 -> whole match
            r'(\d{3})',               # first 3-digit run -> digits only
            r'([A-Z]\d{2,3})',        # e.g. S01 (only when no 3-digit run exists)
            r'(S\d{3})',              # unreachable: 'S123' already matches (\d{3})
        ]

        for pattern in patterns:
            match = re.search(pattern, filename)
            if match:
                return match.group(1) if pattern.startswith(r'(\d') else match.group(0)

        return Path(filename).stem

In [12]:
class EnhancedADReSSDataProcessor:
    """Enhanced ADReSS data processor with automatic speech recognition.

    Pipeline: extract an archive, locate its audio files, transcribe them with
    ``SpeechTranscriber``, infer AD/CN labels from the directory layout and
    pair everything into per-participant sample dicts.
    """

    AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac')

    # Directory names (compared lower-cased and exactly) that encode a label.
    # ad/cn/dementia/control/normal/decline/no_decline are the names the
    # earlier substring heuristics targeted; cc/cd are added for layouts that
    # use them for control/dementia folders (confirm against your data).
    AD_DIR_NAMES = frozenset({'ad', 'cd', 'dementia', 'alzheimer', 'alzheimers', 'decline'})
    CN_DIR_NAMES = frozenset({'cn', 'cc', 'control', 'controls', 'normal',
                              'no_decline', 'no-decline'})

    def __init__(self, output_dir='./extracted_data', asr_model="openai/whisper-base"):
        """Create the output directory and load the tokenizer and ASR model.

        Args:
            output_dir: Root directory archives are extracted into.
            asr_model: Hugging Face model id passed to SpeechTranscriber.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Initialize ASR
        self.transcriber = SpeechTranscriber(model_name=asr_model)

    def extract_adress_dataset(self, tar_path, dataset_name):
        """Extract a (optionally compressed) tar archive under ``output_dir``.

        Args:
            tar_path: Path to the archive; compression is auto-detected.
            dataset_name: Sub-directory of ``output_dir`` to extract into.

        Returns:
            The extraction directory, or None if extraction failed.
        """
        extract_path = self.output_dir / dataset_name
        extract_path.mkdir(parents=True, exist_ok=True)

        print(f"Extracting {tar_path} to {extract_path}")

        try:
            # 'r:*' transparently handles gzip, bz2, xz and plain tar files.
            with tarfile.open(tar_path, 'r:*') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # Reject absolute paths, '..' members and links escaping
                    # extract_path (downloaded archives are untrusted input).
                    tar.extractall(path=extract_path, filter='data')
                else:
                    tar.extractall(path=extract_path)
            print(f"Successfully extracted {dataset_name}")

            # Find the actual dataset directory structure
            self._explore_directory_structure(extract_path)
            return extract_path
        except Exception as e:
            print(f"Error extracting {tar_path}: {e}")
            return None

    def _explore_directory_structure(self, base_path):
        """Print the directory tree under *base_path* (first 5 files per dir)."""
        print(f"\nDirectory structure for {base_path.name}:")
        for root, dirs, files in os.walk(base_path):
            level = root.replace(str(base_path), '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files[:5]:
                print(f"{subindent}{file}")
            if len(files) > 5:
                print(f"{subindent}... and {len(files) - 5} more files")

    def process_adress_dataset_with_asr(self, extract_path):
        """Transcribe, label and pair every audio file under *extract_path*.

        Returns:
            dict with 'audio_files' (sorted list of Paths), 'transcript_files'
            and 'metadata_files' (left empty here), 'labels' (participant ID ->
            label info), 'generated_transcripts' (participant ID -> text) and
            'paired_data' (list of sample dicts).
        """
        extract_path = Path(extract_path)
        dataset_info = {
            'audio_files': [],
            'transcript_files': [],
            'metadata_files': [],
            'labels': {},
            'paired_data': [],
            'generated_transcripts': {}
        }

        main_dir = self._find_main_dir(extract_path)
        print(f"Processing from directory: {main_dir}")

        dataset_info['audio_files'] = self._find_audio_files(main_dir)
        print(f"Found {len(dataset_info['audio_files'])} audio files")

        # Generate transcripts using ASR (keyed by participant ID)
        print("Generating transcripts using ASR...")
        audio_paths = [str(path) for path in dataset_info['audio_files']]
        generated_transcripts = self.transcriber.batch_transcribe(audio_paths)
        dataset_info['generated_transcripts'] = generated_transcripts

        print(f"Generated {len(generated_transcripts)} transcripts")

        # Process labels from directory structure
        dataset_info['labels'] = self._extract_labels_from_structure(dataset_info['audio_files'])

        # Create paired dataset with generated transcripts
        dataset_info['paired_data'] = self._create_paired_data_with_asr(dataset_info)

        return dataset_info

    @staticmethod
    def _find_main_dir(extract_path):
        """Return the shallowest directory whose name contains 'ADReSS'.

        Falls back to *extract_path* itself. Files matching the pattern (e.g.
        a README) are ignored, and ties are broken by path for determinism.
        """
        candidates = [p for p in extract_path.rglob("*ADReSS*") if p.is_dir()]
        if not candidates:
            return extract_path
        return min(candidates, key=lambda p: (len(p.parts), str(p)))

    def _find_audio_files(self, root):
        """Return audio files under *root* (case-insensitive extension), sorted."""
        return sorted(
            p for p in root.rglob('*')
            if p.is_file() and p.suffix.lower() in self.AUDIO_EXTENSIONS
        )

    def _label_from_path(self, audio_file):
        """Return 1 (AD), 0 (CN) or None when the path carries no label hint."""
        # Innermost directory first, so a label folder wins over unrelated
        # ancestors (e.g. 'Normalised_audio-chunks' or a project folder).
        for part in reversed(audio_file.parent.parts):
            name = part.lower()
            if name in self.CN_DIR_NAMES:
                return 0
            if name in self.AD_DIR_NAMES:
                return 1

        # Fall back to whole alphabetic tokens of the filename stem; a plain
        # substring test would treat every 'adrso...' file as AD.
        tokens = re.split(r'[^a-z]+', audio_file.stem.lower())
        if any(tok == 'ad' or tok.startswith(('dem', 'alz')) for tok in tokens):
            return 1
        return None

    def _extract_labels_from_structure(self, audio_files):
        """Map participant ID -> {'label', 'class_name', 'audio_path'}.

        Files without any recognizable label hint default to CN (0) and are
        counted in a printed warning. When several files share a participant
        ID the last one wins, which is also reported.
        """
        labels = {}
        unlabeled = 0
        duplicates = 0

        for audio_file in audio_files:
            audio_file = Path(audio_file)
            participant_id = self._extract_participant_id(audio_file.name)

            label = self._label_from_path(audio_file)
            if label is None:
                unlabeled += 1
                label = 0  # Default to control

            if participant_id in labels:
                duplicates += 1

            labels[participant_id] = {
                'label': label,
                'class_name': 'AD' if label == 1 else 'CN',
                'audio_path': audio_file
            }

        if unlabeled:
            print(f"Warning: {unlabeled} audio file(s) had no recognizable label; "
                  "defaulted to CN")
        if duplicates:
            print(f"Warning: {duplicates} audio file(s) reused an existing participant "
                  "ID; the later file replaced the earlier entry")

        return labels

    def _extract_participant_id(self, filename):
        """Extract participant ID from filename.

        Must stay identical to SpeechTranscriber's ID extraction: labels and
        generated transcripts are joined on this value.
        """
        patterns = [
            r'adrso?(\d{3})',         # adrs123 / adrso123 -> whole match
            r'adrsp?(\d{3})',         # adrsp123 -> whole match
            r'adrspt?(\d{1,3})',      # adrspt1, adrspt12 -> whole match
            r'(\d{3})',               # first 3-digit run -> digits only
            r'([A-Z]\d{2,3})',        # e.g. S01 (only when no 3-digit run exists)
            r'(S\d{3})',              # unreachable: 'S123' already matches (\d{3})
        ]

        for pattern in patterns:
            match = re.search(pattern, filename)
            if match:
                return match.group(1) if pattern.startswith(r'(\d') else match.group(0)

        return Path(filename).stem

    def _create_paired_data_with_asr(self, dataset_info):
        """Create paired audio-transcript samples from ASR-generated transcripts.

        Participants whose cleaned transcript is missing or shorter than 10
        characters get a fixed placeholder sentence instead.
        """
        paired_data = []

        for participant_id, label_info in dataset_info['labels'].items():
            transcript = dataset_info['generated_transcripts'].get(participant_id, "")
            transcript = self._clean_and_validate_transcript(transcript)

            if not transcript or len(transcript.strip()) < 10:
                transcript = f"Audio sample from participant {participant_id}. Speech content unclear or silent."

            paired_data.append({
                'participant_id': participant_id,
                'audio_path': str(label_info['audio_path']),
                'transcript': transcript,
                'label': label_info['label'],
                'class_name': label_info['class_name'],
                'transcript_source': 'ASR_generated'
            })

        print(f"Created {len(paired_data)} paired samples with ASR transcripts")

        # Print class distribution
        labels = [item['label'] for item in paired_data]
        unique, counts = np.unique(labels, return_counts=True)
        for cls, count in zip(unique, counts):
            class_name = 'CN' if cls == 0 else 'AD'
            print(f"  {class_name}: {count} samples ({count/len(labels)*100:.1f}%)")

        # Print sample transcripts for verification
        print("\nSample generated transcripts:")
        print("-" * 50)
        for sample in paired_data[:3]:
            print(f"Participant {sample['participant_id']} ({sample['class_name']}):")
            print(f"Transcript: {sample['transcript'][:100]}...")
            print()

        return paired_data

    def _clean_and_validate_transcript(self, transcript):
        """Strip ASR artifacts; return "" for too-short or degenerate output.

        Bracketed spans (including the transcriber's own
        "[Transcription failed ...]" / "[Transcription error]" markers) and
        <...> tokens are removed, and whitespace is collapsed.
        """
        if not transcript:
            return ""

        transcript = re.sub(r'\[.*?\]', '', transcript)  # Remove [NOISE], [MUSIC], etc.
        transcript = re.sub(r'<.*?>', '', transcript)    # Remove <unk>, <pad>, etc.
        transcript = re.sub(r'\s+', ' ', transcript).strip()

        if len(transcript) < 5:
            return ""

        # Degenerate repetition (common ASR failure): fewer than 30% of the
        # words are distinct.
        words = transcript.split()
        if len(words) > 1 and len(set(words)) / len(words) < 0.3:
            return ""

        return transcript

In [13]:
class EnhancedMultiModalDataset(Dataset):
    """Dataset yielding tokenized text, spectrogram images and linguistic stats.

    Each element of ``data_samples`` must provide the keys 'transcript',
    'audio_path', 'label', 'participant_id' and 'class_name'. On construction
    a 'linguistic_features' dict is added to every sample in place.
    """

    # Order of the values in the 'linguistic_features' tensor.
    LINGUISTIC_FEATURE_KEYS = (
        'word_count', 'sentence_count', 'avg_word_length', 'unique_words',
        'lexical_diversity', 'pause_markers', 'filler_words', 'transcript_length',
    )
    FILLER_WORDS = frozenset({'um', 'uh', 'er', 'ah'})
    # Stripped from token edges before filler matching, so ASR output such as
    # "Um," or "uh..." is still counted as a filler.
    _EDGE_PUNCTUATION = '.,!?;:"\'()'

    def __init__(self, data_samples, audio_processor, tokenizer,
                 max_text_length=512, audio_max_length=16*16000,
                 image_size=(224, 224)):
        """
        Args:
            data_samples: List of sample dicts (mutated: features are added).
            audio_processor: Object providing load_audio,
                extract_mel_spectrogram and resize_spectrogram_to_image.
            tokenizer: Hugging Face tokenizer used for the transcript.
            max_text_length: Token length after padding/truncation.
            audio_max_length: Audio length in samples passed to load_audio.
            image_size: (height, width) of the spectrogram image.
        """
        self.data_samples = data_samples
        self.audio_processor = audio_processor
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.audio_max_length = audio_max_length
        self.image_size = image_size

        # Precompute linguistic features once instead of per __getitem__ call
        self._precompute_linguistic_features()

    @classmethod
    def _compute_linguistic_features(cls, transcript):
        """Return simple lexical statistics for one transcript string."""
        transcript = transcript or ""
        words = transcript.split()
        sentences = re.split(r'[.!?]+', transcript)
        fillers = sum(
            1 for word in words
            if word.lower().strip(cls._EDGE_PUNCTUATION) in cls.FILLER_WORDS
        )

        return {
            'word_count': len(words),
            'sentence_count': len([s for s in sentences if s.strip()]),
            'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
            'unique_words': len(set(words)),
            'lexical_diversity': len(set(words)) / len(words) if words else 0,
            'pause_markers': transcript.count('[pause]') + transcript.count('...'),
            'filler_words': fillers,
            'transcript_length': len(transcript)
        }

    def _precompute_linguistic_features(self):
        """Attach 'linguistic_features' to every sample (in place)."""
        print("Precomputing linguistic features...")

        for sample in tqdm(self.data_samples, desc="Computing linguistic features"):
            sample['linguistic_features'] = self._compute_linguistic_features(sample['transcript'])

    def __len__(self):
        return len(self.data_samples)

    def __getitem__(self, idx):
        """Return model inputs, label and metadata for sample *idx*."""
        sample = self.data_samples[idx]

        # Process text
        text = sample['transcript']
        if not text or not text.strip():
            text = "No speech content detected in audio sample"

        try:
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=self.max_text_length,
                return_tensors='pt'
            )
        except Exception as e:
            print(f"Error tokenizing text for {sample['participant_id']}: {e}")
            # Dummy encoding with an all-zero attention mask
            encoding = {
                'input_ids': torch.zeros(self.max_text_length, dtype=torch.long),
                'attention_mask': torch.zeros(self.max_text_length, dtype=torch.long),
                'token_type_ids': torch.zeros(self.max_text_length, dtype=torch.long)
            }

        # Process audio
        try:
            audio = self.audio_processor.load_audio(
                sample['audio_path'],
                max_length=self.audio_max_length
            )
            spectrogram = self.audio_processor.extract_mel_spectrogram(audio)
            audio_features = self.audio_processor.resize_spectrogram_to_image(
                spectrogram, self.image_size
            )
        except Exception as e:
            print(f"Error processing audio for {sample['participant_id']}: {e}")
            # Deterministic all-zero image: random noise would give a failed
            # file a different input every epoch.
            audio_features = np.zeros(
                (3, self.image_size[0], self.image_size[1]), dtype=np.float32
            )

        ling_features = sample.get('linguistic_features', {})
        linguistic_vector = np.array(
            [ling_features.get(key, 0) for key in self.LINGUISTIC_FEATURE_KEYS],
            dtype=np.float32,
        )

        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        return {
            'input_ids': input_ids.squeeze() if hasattr(input_ids, 'squeeze') else input_ids,
            'attention_mask': attention_mask.squeeze() if hasattr(attention_mask, 'squeeze') else attention_mask,
            'audio_features': torch.FloatTensor(audio_features),
            'linguistic_features': torch.FloatTensor(linguistic_vector),
            'label': torch.LongTensor([sample['label']]).squeeze(),
            'participant_id': sample['participant_id'],
            'class_name': sample['class_name'],
            'transcript_preview': (text[:100] + "...") if len(text) > 100 else text
        }
        }

In [14]:
class SimpleAudioProcessor:
    """Turn audio files into fixed-size, 3-channel spectrogram "images".

    The channels are the log-mel spectrogram and its first and second deltas.
    They are resized to the requested size and min-max scaled to [0, 1]
    jointly across all three channels.
    """

    def __init__(self, sample_rate=16000, n_mels=128):
        self.sample_rate = sample_rate
        self.n_mels = n_mels

    @staticmethod
    def _to_mono(audio):
        """Average a (channels, samples) array to 1-D; 1-D input is returned as-is."""
        # librosa lays multi-channel audio out as (channels, samples), so the
        # channel axis is 0 (averaging axis=1 would collapse time instead).
        return np.mean(audio, axis=0) if audio.ndim > 1 else audio

    @staticmethod
    def _fit_length(audio, max_length):
        """Center-crop or right-zero-pad *audio* to exactly *max_length* samples."""
        if len(audio) > max_length:
            start = (len(audio) - max_length) // 2
            return audio[start:start + max_length]
        if len(audio) < max_length:
            return np.pad(audio, (0, max_length - len(audio)), mode='constant')
        return audio

    def load_audio(self, audio_path, max_length=None):
        """Load, mono-mix, silence-trim and peak-normalize an audio file.

        Args:
            audio_path: File path readable by librosa.
            max_length: If given, output is cropped/padded to this many samples.

        Returns:
            1-D float array at ``self.sample_rate``. If loading fails, a
            warning is printed and low-amplitude Gaussian noise is returned
            (``max_length`` samples, or 10 s when unset).
        """
        try:
            audio, _ = librosa.load(audio_path, sr=self.sample_rate)
            audio = self._to_mono(audio)
            audio, _ = librosa.effects.trim(audio, top_db=20)
            # np.max on an empty array raises, so guard the fully-trimmed case.
            if audio.size and np.max(np.abs(audio)) > 0:
                audio = librosa.util.normalize(audio)
            if max_length is not None:
                audio = self._fit_length(audio, max_length)
            return audio
        except Exception as e:
            print(f"Warning: could not load audio {audio_path}: {e}; using noise placeholder")
            length = max_length if max_length else self.sample_rate * 10
            return np.random.randn(length) * 0.01

    def extract_mel_spectrogram(self, audio):
        """Return a (3, n_mels, frames) stack: log-mel, delta and delta-delta.

        If extraction fails, a warning is printed and random data of shape
        (3, n_mels, 100) is returned.
        """
        try:
            mel_spec = librosa.feature.melspectrogram(
                y=audio, sr=self.sample_rate, n_mels=self.n_mels
            )
            log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
            delta = librosa.feature.delta(log_mel_spec)
            delta2 = librosa.feature.delta(log_mel_spec, order=2)
            return np.stack([log_mel_spec, delta, delta2], axis=0)
        except Exception as e:
            print(f"Warning: mel-spectrogram extraction failed: {e}; using random placeholder")
            return np.random.randn(3, self.n_mels, 100)

    @staticmethod
    def _zoom_2d(channel, target_size):
        """Bilinearly resize a 2-D array to *target_size*."""
        zoom_factors = [target_size[j] / channel.shape[j] for j in range(2)]
        return ndimage.zoom(channel, zoom_factors, order=1)

    def resize_spectrogram_to_image(self, spectrogram, target_size=(224, 224)):
        """Resize to (3, H, W) and min-max scale to [0, 1].

        A 3-D input is resized per channel; a 2-D input is resized once and
        replicated into 3 channels. If resizing fails, a warning is printed
        and uniform random data of shape (3, H, W) is returned.
        """
        try:
            if spectrogram.ndim == 3:
                resized = np.stack(
                    [self._zoom_2d(channel, target_size) for channel in spectrogram], axis=0
                )
            else:
                single = self._zoom_2d(spectrogram, target_size)
                resized = np.stack([single] * 3, axis=0)

            low = resized.min()
            return (resized - low) / (resized.max() - low + 1e-8)
        except Exception as e:
            print(f"Warning: spectrogram resize failed: {e}; using random placeholder")
            return np.random.rand(3, target_size[0], target_size[1])

In [15]:
class EnhancedMultiModalADClassifier(nn.Module):
    """Three-branch Alzheimer's classifier: BERT text, ViT spectrogram, MLP stats.

    Each branch is projected into a shared ``fusion_hidden_size`` space. The
    three vectors form a length-3 sequence for multi-head self-attention, and
    the result is flattened and passed through an MLP fusion stack and a
    classification head.
    """

    def __init__(self, text_hidden_size=768, audio_hidden_size=768,
                 linguistic_feature_size=8, fusion_hidden_size=512,
                 num_classes=2, dropout=0.3):
        super().__init__()
        half_fusion = fusion_hidden_size // 2

        # Pretrained modality encoders.
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.audio_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # Hand-crafted linguistic statistics -> 32-d embedding.
        self.linguistic_processor = nn.Sequential(
            nn.Linear(linguistic_feature_size, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
        )

        # Self-attention across the three modality "tokens"
        # (default sequence-first layout: (seq, batch, embed)).
        self.attention = nn.MultiheadAttention(
            embed_dim=fusion_hidden_size, num_heads=8, dropout=dropout,
        )

        # Per-modality linear maps into the shared fusion space.
        self.text_projection = nn.Linear(text_hidden_size, fusion_hidden_size)
        self.audio_projection = nn.Linear(audio_hidden_size, fusion_hidden_size)
        self.linguistic_projection = nn.Linear(32, fusion_hidden_size)

        # Concatenated (3 * fusion) vector -> fusion -> fusion // 2.
        self.fusion_layers = nn.Sequential(
            nn.Linear(3 * fusion_hidden_size, fusion_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden_size, half_fusion),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.classifier = nn.Sequential(
            nn.Linear(half_fusion, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for the freshly built heads.

        Covers every nn.Linear in linguistic_processor, fusion_layers and
        classifier, in that order. The pretrained encoders, the attention
        layer and the projections keep their default initialization.
        """
        heads = (self.linguistic_processor, self.fusion_layers, self.classifier)
        linear_layers = (
            layer for head in heads for layer in head if isinstance(layer, nn.Linear)
        )
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, input_ids, attention_mask, audio_features, linguistic_features):
        """Run all three branches, fuse them and classify.

        Returns:
            dict with 'logits', 'text_features' (BERT pooled output),
            'audio_features' (ViT pooled output), 'linguistic_features'
            (32-d MLP output), 'attention_weights' and 'fused_features'.
        """
        batch = input_ids.shape[0]

        text_pooled = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        audio_pooled = self.audio_encoder(pixel_values=audio_features).pooler_output
        ling_hidden = self.linguistic_processor(linguistic_features)

        # (3, batch, fusion_hidden_size): one "token" per modality.
        tokens = torch.stack(
            (
                self.text_projection(text_pooled),
                self.audio_projection(audio_pooled),
                self.linguistic_projection(ling_hidden),
            ),
            dim=0,
        )

        attended, attn_weights = self.attention(tokens, tokens, tokens)

        # Back to batch-first, then flatten the three tokens side by side.
        flat = attended.transpose(0, 1).reshape(batch, -1)
        fused = self.fusion_layers(flat)
        logits = self.classifier(fused)

        return {
            'logits': logits,
            'text_features': text_pooled,
            'audio_features': audio_pooled,
            'linguistic_features': ling_hidden,
            'attention_weights': attn_weights,
            'fused_features': fused
        }

In [16]:
class ModelTrainer:
    """Enhanced model trainer with comprehensive evaluation.

    Wraps a classifier with an AdamW optimizer, a cross-entropy loss and a
    ``ReduceLROnPlateau`` scheduler that is stepped with validation accuracy.

    The model is called as ``model(input_ids=..., attention_mask=...,
    audio_features=..., linguistic_features=...)`` and must return a dict
    containing a 2-D ``'logits'`` tensor (``[batch, num_classes]``). Loader
    batches are dicts holding those four tensors plus ``'label'`` and
    ``'participant_id'``; :meth:`evaluate_model` additionally reads
    ``'class_name'`` and ``'transcript_preview'``.
    """

    # Batch keys that are moved to the device and passed to the model as kwargs.
    _INPUT_KEYS = ('input_ids', 'attention_mask', 'audio_features', 'linguistic_features')

    def __init__(self, model, device, learning_rate=2e-5, weight_decay=0.01):
        """Move ``model`` to ``device`` and set up optimizer, loss and scheduler.

        Args:
            model: ``torch.nn.Module`` to train (moved to ``device``).
            device: device used for the model and for every batch.
            learning_rate: initial AdamW learning rate.
            weight_decay: AdamW weight decay.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        self.criterion = nn.CrossEntropyLoss()
        # mode='max' because the scheduler is stepped with validation *accuracy*.
        # LR reductions are printed by _step_scheduler(); the scheduler's own
        # `verbose` argument is deprecated in recent PyTorch releases.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', patience=3, factor=0.5
        )

        # Training history (one entry per epoch)
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.best_val_accuracy = 0.0
        self.best_model_state = None

    @staticmethod
    def _safe_div(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 when ``denominator`` is 0."""
        return numerator / denominator if denominator else 0.0

    def _batch_to_device(self, batch):
        """Return ``(model_kwargs, labels)`` for ``batch``, moved to ``self.device``."""
        inputs = {key: batch[key].to(self.device) for key in self._INPUT_KEYS}
        return inputs, batch['label'].to(self.device)

    def _snapshot_state(self):
        """Return an independent CPU copy of the model's current ``state_dict``.

        ``state_dict().copy()`` is only a shallow dict copy whose tensors still
        share storage with the live parameters, so later optimizer steps would
        silently overwrite the saved "best" weights. Each tensor is copied to
        CPU instead (``load_state_dict`` copies the values back onto the
        model's device), and the dict's ``_metadata`` attribute, which PyTorch
        uses for versioned loading, is carried over.
        """
        state = self.model.state_dict()
        snapshot = type(state)(
            (name, tensor.detach().to('cpu', copy=True)) for name, tensor in state.items()
        )
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            snapshot._metadata = metadata
        return snapshot

    def _step_scheduler(self, val_accuracy):
        """Step the plateau scheduler with ``val_accuracy`` and report LR reductions."""
        previous_lrs = [group['lr'] for group in self.optimizer.param_groups]
        self.scheduler.step(val_accuracy)
        for idx, (old_lr, group) in enumerate(zip(previous_lrs, self.optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f"Reducing learning rate of group {idx} to {group['lr']:.4e}.")

    def train_epoch(self, train_loader):
        """Run one training pass over ``train_loader``.

        A batch whose forward/backward pass raises is reported and skipped.

        Returns:
            ``(avg_loss, accuracy)`` computed over the batches that completed;
            both are 0.0 if no batch completed.
        """
        self.model.train()
        total_loss = 0.0
        completed_batches = 0
        correct_predictions = 0
        total_predictions = 0

        progress_bar = tqdm(train_loader, desc="Training", leave=False)

        for batch in progress_bar:
            inputs, labels = self._batch_to_device(batch)

            self.optimizer.zero_grad()

            try:
                logits = self.model(**inputs)['logits']
                loss = self.criterion(logits, labels)

                loss.backward()
                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                completed_batches += 1
                predictions = torch.argmax(logits, dim=1)
                correct_predictions += (predictions == labels).sum().item()
                total_predictions += labels.size(0)

                progress_bar.set_postfix({
                    'Loss': f'{batch_loss:.4f}',
                    'Acc': f'{self._safe_div(correct_predictions, total_predictions):.4f}'
                })

            except Exception as e:
                print(f"Error in training batch: {e}")
                continue

        return (
            self._safe_div(total_loss, completed_batches),
            self._safe_div(correct_predictions, total_predictions),
        )

    def validate_epoch(self, val_loader):
        """Evaluate on ``val_loader`` without gradient tracking.

        A batch whose forward pass raises is reported and skipped.

        Returns:
            ``(avg_loss, accuracy, predictions, labels, participant_ids)``;
            loss and accuracy cover only the batches that completed (0.0 if
            none did), and the lists hold per-sample values from those batches.
        """
        self.model.eval()
        total_loss = 0.0
        completed_batches = 0
        correct_predictions = 0
        total_predictions = 0
        all_predictions = []
        all_labels = []
        all_participant_ids = []

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc="Validation", leave=False)

            for batch in progress_bar:
                inputs, labels = self._batch_to_device(batch)

                try:
                    logits = self.model(**inputs)['logits']
                    loss = self.criterion(logits, labels)

                    batch_loss = loss.item()
                    total_loss += batch_loss
                    completed_batches += 1
                    predictions = torch.argmax(logits, dim=1)
                    correct_predictions += (predictions == labels).sum().item()
                    total_predictions += labels.size(0)

                    # Store for detailed analysis
                    all_predictions.extend(predictions.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())
                    all_participant_ids.extend(batch['participant_id'])

                    progress_bar.set_postfix({
                        'Loss': f'{batch_loss:.4f}',
                        'Acc': f'{self._safe_div(correct_predictions, total_predictions):.4f}'
                    })

                except Exception as e:
                    print(f"Error in validation batch: {e}")
                    continue

        avg_loss = self._safe_div(total_loss, completed_batches)
        accuracy = self._safe_div(correct_predictions, total_predictions)

        return avg_loss, accuracy, all_predictions, all_labels, all_participant_ids

    def train(self, train_loader, val_loader, num_epochs=10, save_path='best_model.pt'):
        """Full training loop with validation and best-checkpoint saving.

        After every epoch the LR scheduler is stepped with validation accuracy.
        Whenever validation accuracy strictly improves, an independent copy of
        the weights is kept in ``self.best_model_state`` and a checkpoint
        (weights, optimizer state, history) is written to ``save_path``. The
        best weights, if any were recorded, are loaded back into
        ``self.model`` at the end.
        """
        print(f"Starting training for {num_epochs} epochs...")
        print(f"Training on {len(train_loader.dataset)} samples")
        print(f"Validating on {len(val_loader.dataset)} samples")
        print(f"Device: {self.device}")

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 50)

            # Training
            train_loss, train_accuracy = self.train_epoch(train_loader)
            self.train_losses.append(train_loss)
            self.train_accuracies.append(train_accuracy)

            # Validation
            val_loss, val_accuracy, val_predictions, val_labels, _ = self.validate_epoch(val_loader)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_accuracy)

            # Learning rate scheduling
            self._step_scheduler(val_accuracy)

            # Save best model (strict improvement only)
            is_best = val_accuracy > self.best_val_accuracy
            if is_best:
                self.best_val_accuracy = val_accuracy
                self.best_model_state = self._snapshot_state()
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.best_model_state,
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_val_accuracy': self.best_val_accuracy,
                    'train_losses': self.train_losses,
                    'val_losses': self.val_losses,
                    'train_accuracies': self.train_accuracies,
                    'val_accuracies': self.val_accuracies
                }, save_path)
                print(f"✓ New best model saved with validation accuracy: {val_accuracy:.4f}")

            # Print epoch results
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

            # Detailed validation metrics only for epochs that set a new best
            if is_best:
                self._print_detailed_metrics(val_labels, val_predictions)

        print(f"\nTraining completed!")
        print(f"Best validation accuracy: {self.best_val_accuracy:.4f}")

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print("✓ Best model loaded for final evaluation")

    def _print_detailed_metrics(self, true_labels, predictions):
        """Print accuracy and support-weighted precision/recall/F1."""
        accuracy = accuracy_score(true_labels, predictions)
        precision = precision_score(true_labels, predictions, average='weighted', zero_division=0)
        recall = recall_score(true_labels, predictions, average='weighted', zero_division=0)
        f1 = f1_score(true_labels, predictions, average='weighted', zero_division=0)

        print(f"\nDetailed Metrics:")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall: {recall:.4f}")
        print(f"  F1-Score: {f1:.4f}")

    def evaluate_model(self, test_loader, class_names=('CN', 'AD')):
        """Comprehensive model evaluation on ``test_loader``.

        Prints overall and per-class metrics, a confusion matrix and up to
        five misclassified samples, and writes one row per sample to
        ``detailed_evaluation_results.csv`` in the working directory.

        Args:
            test_loader: iterable of batch dicts (see the class docstring).
            class_names: display name for each class index. Per-class metrics
                and the confusion matrix always cover every index in
                ``range(len(class_names))``, even if a class is absent from
                the test set. The probability columns take index 0 as CN and
                index 1 as AD.

        Returns:
            dict with ``accuracy``, per-class ``precision``/``recall``/
            ``f1_score`` arrays, ``confusion_matrix``, ``detailed_results``
            and the raw ``predictions``/``labels``/``probabilities`` lists.

        Raises:
            RuntimeError: if no test batch could be evaluated.
        """
        print("\n" + "="*60)
        print("COMPREHENSIVE MODEL EVALUATION")
        print("="*60)

        self.model.eval()
        all_predictions = []
        all_labels = []
        all_probabilities = []
        all_participant_ids = []
        detailed_results = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating"):
                inputs, labels = self._batch_to_device(batch)

                try:
                    logits = self.model(**inputs)['logits']
                    probabilities = torch.softmax(logits, dim=1)
                    predictions = torch.argmax(logits, dim=1)

                    # One host transfer per tensor per batch (not per element).
                    labels_np = labels.cpu().numpy()
                    predictions_np = predictions.cpu().numpy()
                    probabilities_np = probabilities.cpu().numpy()

                    # Store results
                    all_predictions.extend(predictions_np)
                    all_labels.extend(labels_np)
                    all_probabilities.extend(probabilities_np)
                    all_participant_ids.extend(batch['participant_id'])

                    # Store detailed results for analysis
                    for i, participant_id in enumerate(batch['participant_id']):
                        true_label = labels_np[i].item()
                        predicted_label = predictions_np[i].item()
                        detailed_results.append({
                            'participant_id': participant_id,
                            'true_label': true_label,
                            'predicted_label': predicted_label,
                            'cn_probability': probabilities_np[i][0].item(),
                            'ad_probability': probabilities_np[i][1].item(),
                            'correct': true_label == predicted_label,
                            'class_name': batch['class_name'][i],
                            'transcript_preview': batch['transcript_preview'][i]
                        })

                except Exception as e:
                    print(f"Error in evaluation batch: {e}")
                    continue

        if not all_labels:
            raise RuntimeError("No test batches could be evaluated; see the errors printed above.")

        # Pin the label set so the per-class arrays and the confusion matrix
        # are always len(class_names) wide, even when a class never occurs.
        label_ids = list(range(len(class_names)))
        accuracy = accuracy_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions, labels=label_ids, average=None, zero_division=0)
        recall = recall_score(all_labels, all_predictions, labels=label_ids, average=None, zero_division=0)
        f1 = f1_score(all_labels, all_predictions, labels=label_ids, average=None, zero_division=0)
        cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

        # Print comprehensive results (overall P/R/F1 are unweighted class means)
        print(f"\nOVERALL PERFORMANCE:")
        print(f"Total Samples: {len(all_labels)}")
        print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.1f}%)")
        print(f"Overall Precision: {np.mean(precision):.4f}")
        print(f"Overall Recall: {np.mean(recall):.4f}")
        print(f"Overall F1-Score: {np.mean(f1):.4f}")

        print(f"\nPER-CLASS PERFORMANCE:")
        for i, class_name in enumerate(class_names):
            print(f"{class_name}:")
            print(f"  Count: {int(cm[i].sum())}")  # row i = samples whose true label is i
            print(f"  Precision: {precision[i]:.4f}")
            print(f"  Recall: {recall[i]:.4f}")
            print(f"  F1-Score: {f1[i]:.4f}")

        # Confusion Matrix (rows = actual, columns = predicted)
        name_width = max(len(name) for name in class_names)
        print(f"\nCONFUSION MATRIX:")
        print(f"        Predicted")
        print(" " * (name_width + 8) + "  ".join(f"{name:>4}" for name in class_names))
        for i, class_name in enumerate(class_names):
            prefix = "Actual" if i == 0 else "      "
            counts = "  ".join(f"{value:4d}" for value in cm[i])
            print(f"{prefix} {class_name:<{name_width}} {counts}")

        # Error Analysis
        print(f"\nERROR ANALYSIS:")
        errors = [result for result in detailed_results if not result['correct']]
        print(f"Total Errors: {len(errors)}")

        if errors:
            print(f"\nSample Errors:")
            for i, error in enumerate(errors[:5]):  # Show first 5 errors
                true_class = class_names[error['true_label']]
                pred_class = class_names[error['predicted_label']]
                confidence = max(error['cn_probability'], error['ad_probability'])
                print(f"  {i+1}. Participant {error['participant_id']}: {true_class} → {pred_class} (conf: {confidence:.3f})")
                print(f"     Transcript: {error['transcript_preview']}")

        # Save detailed results
        results_df = pd.DataFrame(detailed_results)
        results_df.to_csv('detailed_evaluation_results.csv', index=False)
        print(f"\nDetailed results saved to 'detailed_evaluation_results.csv'")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'confusion_matrix': cm,
            'detailed_results': detailed_results,
            'predictions': all_predictions,
            'labels': all_labels,
            'probabilities': all_probabilities
        }

    def plot_training_history(self, save_path='training_history.png'):
        """Plot per-epoch loss and accuracy curves, save to ``save_path`` and show."""
        if not self.train_losses:
            print("No training history to plot")
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot losses
        epochs = range(1, len(self.train_losses) + 1)
        ax1.plot(epochs, self.train_losses, 'b-', label='Training Loss', linewidth=2)
        ax1.plot(epochs, self.val_losses, 'r-', label='Validation Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot accuracies
        ax2.plot(epochs, self.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
        ax2.plot(epochs, self.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Training and Validation Accuracy')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 1)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Training history plot saved to {save_path}")

In [17]:
def analyze_model_features(model, test_loader, device, max_batch_size=8, max_batches=10):
    """Plot the average cross-modal attention each modality receives.

    Runs ``model`` (in eval mode, without gradients) on the first
    ``max_batches`` batches of ``test_loader`` and collects
    ``outputs['attention_weights']``. The weights are averaged over every
    sample and every leading axis, which leaves one mean weight per attended
    (source) position along the last axis. The result is drawn as a bar
    chart, saved to ``modality_attention_analysis.png`` and shown.

    NOTE(review): assumes ``torch.nn.MultiheadAttention``-style weights,
    i.e. shape ``(batch, L, S)`` (heads averaged) or ``(batch, heads, L, S)``,
    with the source axis last and ordered Text, Audio, Linguistic as stacked
    in the classifier's forward pass. Confirm this against the model's
    attention module.

    Args:
        model: module returning a dict with an ``'attention_weights'`` tensor.
        test_loader: iterable of batch dicts with ``input_ids``,
            ``attention_mask``, ``audio_features`` and ``linguistic_features``.
        device: device the batch tensors are moved to.
        max_batch_size: kept for backward compatibility and ignored. Batches
            are concatenated instead of being zero-padded to this size;
            padding diluted the average and broke for larger batches.
        max_batches: number of leading batches to analyse.
    """
    print("Analyzing model features and attention patterns...")

    model.eval()
    collected = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            if batch_idx >= max_batches:
                break

            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio_features = batch['audio_features'].to(device)
            linguistic_features = batch['linguistic_features'].to(device)

            try:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    audio_features=audio_features,
                    linguistic_features=linguistic_features
                )
                weights = outputs['attention_weights'].detach().cpu().numpy()
            except Exception as e:
                print(f"Error in feature analysis: {e}")
                continue

            # Flatten every axis except the last (source modality). This
            # handles both head-averaged and per-head layouts and lets
            # batches of any size be concatenated without padding.
            collected.append(weights.reshape(-1, weights.shape[-1]))

    if not collected:
        print("⚠️  No attention weights collected for analysis.")
        return

    # Mean attention received by each source modality over all samples,
    # heads and query positions.
    attention_by_modality = np.concatenate(collected, axis=0).mean(axis=0)

    modalities = ['Text', 'Audio', 'Linguistic']
    if len(attention_by_modality) != len(modalities):
        # Unexpected source length: fall back to generic labels rather than crash.
        modalities = [f'Source {i}' for i in range(len(attention_by_modality))]

    # Plot attention patterns
    plt.figure(figsize=(10, 6))
    plt.bar(modalities, attention_by_modality)
    plt.title('Average Attention Weights by Modality')
    plt.ylabel('Attention Weight')
    plt.xlabel('Modality')
    plt.grid(True, alpha=0.3)

    for i, v in enumerate(attention_by_modality):
        plt.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig('modality_attention_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✅ Feature analysis completed. Attention analysis plot saved.")

In [20]:
# Mount Google Drive (Colab-only) so files under /content/drive are reachable.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [18]:
def main():
    """Main execution pipeline for Enhanced Alzheimer's Detection.

    The pipeline:
    1. Builds the ASR-enabled data processor.
    2. Loads the first ADReSS archive found in ``config['data_dir']``,
       falling back to a synthetic demo dataset.
    3. Makes stratified train/validation/test splits.
    4. Trains and evaluates the multimodal classifier.
    5. Writes plots and CSV reports.

    Any exception is printed with its traceback instead of propagating, so a
    failure does not abort the surrounding notebook.
    """

    print("="*80)
    print("ENHANCED ALZHEIMER'S DETECTION WITH AUTOMATIC SPEECH RECOGNITION")
    print("="*80)

    # Configuration
    config = {
        'data_dir': './data',
        'output_dir': './extracted_data',
        'model_save_path': './best_ad_model.pt',
        'batch_size': 8,
        'num_epochs': 15,
        'learning_rate': 2e-5,
        'max_text_length': 512,
        'audio_max_length': 16*16000,  # 16 seconds
        'test_size': 0.2,
        'val_size': 0.15,
        'random_state': 42,
        'asr_model': 'openai/whisper-base'  # Can change to 'facebook/wav2vec2-base-960h'
    }

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    try:
        # Step 1: Initialize data processor with ASR
        print("\n1. Initializing Enhanced Data Processor with ASR...")
        processor = EnhancedADReSSDataProcessor(
            output_dir=config['output_dir'],
            asr_model=config['asr_model']
        )

        # Step 2: Look for dataset files (or fall back to synthetic data)
        print("\n2. Looking for ADReSS dataset files...")
        dataset_info = _find_or_create_dataset_info(processor, Path(config['data_dir']))

        if not dataset_info['paired_data']:
            print("❌ No valid data found. Exiting...")
            return

        print(f"✅ Successfully processed {len(dataset_info['paired_data'])} samples")

        # Step 3: Prepare datasets
        print("\n3. Preparing datasets...")
        paired_data = dataset_info['paired_data']

        # Two-stage stratified split: carve off the val+test holdout, then
        # divide it so val/test keep their configured shares of the total.
        holdout_size = config['test_size'] + config['val_size']
        train_data, temp_data = train_test_split(
            paired_data,
            test_size=holdout_size,
            random_state=config['random_state'],
            stratify=[item['label'] for item in paired_data]
        )

        val_data, test_data = train_test_split(
            temp_data,
            test_size=config['test_size'] / holdout_size,
            random_state=config['random_state'],
            stratify=[item['label'] for item in temp_data]
        )

        print(f"Train samples: {len(train_data)}")
        print(f"Validation samples: {len(val_data)}")
        print(f"Test samples: {len(test_data)}")

        # Initialize components
        audio_processor = SimpleAudioProcessor()
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Create datasets and data loaders (only training data is shuffled)
        train_loader = _make_multimodal_loader(train_data, audio_processor, tokenizer, config, shuffle=True)
        val_loader = _make_multimodal_loader(val_data, audio_processor, tokenizer, config, shuffle=False)
        test_loader = _make_multimodal_loader(test_data, audio_processor, tokenizer, config, shuffle=False)

        # Step 4: Initialize model
        print("\n4. Initializing Enhanced Multimodal Model...")
        model = EnhancedMultiModalADClassifier()

        # Step 5: Initialize trainer
        print("\n5. Initializing Model Trainer...")
        trainer = ModelTrainer(
            model=model,
            device=device,
            learning_rate=config['learning_rate']
        )

        # Step 6: Train model
        print("\n6. Starting Model Training...")
        trainer.train(
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=config['num_epochs'],
            save_path=config['model_save_path']
        )

        # Step 7: Evaluate model
        print("\n7. Evaluating Model...")
        evaluation_results = trainer.evaluate_model(test_loader)

        # Step 8: Plot training history
        print("\n8. Generating Training History Plot...")
        trainer.plot_training_history()

        # Step 9: Feature Analysis
        print("\n9. Performing Feature Analysis...")
        analyze_model_features(trainer.model, test_loader, device)

        print("\n" + "="*80)
        print("✅ ENHANCED ALZHEIMER'S DETECTION PIPELINE COMPLETED SUCCESSFULLY!")
        print("="*80)
        print(f"Final Test Accuracy: {evaluation_results['accuracy']:.4f}")
        print(f"Model saved to: {config['model_save_path']}")
        print("Check the generated plots and CSV files for detailed analysis.")

    except Exception as e:
        print(f"❌ Error in main pipeline: {e}")
        import traceback
        traceback.print_exc()


def _find_or_create_dataset_info(processor, data_dir):
    """Return dataset info for the first ADReSS archive in ``data_dir``.

    Archives are searched in pattern-priority order (``*.tar.gz``, ``*.tgz``,
    ``*.zip``). Within each pattern they are sorted by path, so the chosen
    archive does not depend on the filesystem's directory-listing order.

    The synthetic demo dataset is returned instead when no archive is found
    or when ``processor.extract_adress_dataset`` returns ``None``.
    """
    dataset_files = [
        path
        for pattern in ('*.tar.gz', '*.tgz', '*.zip')
        for path in sorted(data_dir.glob(pattern))
    ]

    if not dataset_files:
        print("⚠️  No dataset files found. Creating synthetic data for demonstration...")
        return create_synthetic_dataset_with_asr(processor)

    print(f"Found {len(dataset_files)} dataset files")

    # Process first dataset file
    dataset_file = dataset_files[0]
    print(f"Processing: {dataset_file}")

    extract_path = processor.extract_adress_dataset(
        dataset_file,
        f"dataset_{dataset_file.stem}"
    )

    if extract_path is None:
        print("❌ Failed to extract dataset. Creating synthetic data...")
        return create_synthetic_dataset_with_asr(processor)

    # Process with ASR
    return processor.process_adress_dataset_with_asr(extract_path)


def _make_multimodal_loader(data, audio_processor, tokenizer, config, shuffle):
    """Wrap ``data`` in an ``EnhancedMultiModalDataset`` and a ``DataLoader``."""
    dataset = EnhancedMultiModalDataset(
        data, audio_processor, tokenizer,
        max_text_length=config['max_text_length'],
        audio_max_length=config['audio_max_length']
    )
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=0  # Set to 0 to avoid multiprocessing issues
    )

def create_synthetic_dataset_with_asr(processor):
    """Build a small toy AD/CN dataset for demos when no real data is found.

    Despite the name, no speech recognition is run. Transcripts come from
    ``generate_ad_like_transcript`` / ``generate_cn_like_transcript``, and
    the ``audio_path`` entries are placeholder names; this function does not
    create those files.

    Args:
        processor: accepted for interface compatibility; not used.

    Returns:
        dict with the same keys as a processed real dataset. Only
        ``'paired_data'`` is populated: 100 samples whose labels alternate
        0 (CN) and 1 (AD), starting with CN.
    """
    print("Creating synthetic dataset with ASR for demonstration...")

    def _build_sample(idx):
        # Odd indices are AD (label 1), even indices are CN (label 0).
        pid = f"SYNTH_{idx:03d}"
        is_ad = idx % 2 == 1
        return {
            'participant_id': pid,
            'audio_path': f'synthetic_audio_{pid}.wav',
            'transcript': generate_ad_like_transcript() if is_ad else generate_cn_like_transcript(),
            'label': 1 if is_ad else 0,
            'class_name': 'AD' if is_ad else 'CN',
            'transcript_source': 'synthetic'
        }

    samples = [_build_sample(idx) for idx in range(100)]

    return {
        'audio_files': [],
        'transcript_files': [],
        'metadata_files': [],
        'labels': {},
        'paired_data': samples,
        'generated_transcripts': {}
    }

def generate_ad_like_transcript(rng=None):
    """Return a random canned transcript imitating AD-like picture description.

    Every pattern contains hesitations and trailing ellipses ("...").

    Args:
        rng: optional object with a NumPy-style ``choice`` method (for
            example ``np.random.default_rng(seed)``) for reproducible
            sampling. When omitted, the global ``np.random`` state is used.

    Returns:
        One of the patterns as a built-in ``str``. ``choice`` alone would
        return a ``numpy.str_``.
    """
    ad_patterns = [
        "Um, let me see... the boy is... um... he's climbing on the... the thing there...",
        "There's a woman in the kitchen and she's... what is she doing... oh yes, washing dishes I think...",
        "The... the thing with water is overflowing and there's... there's problems happening...",
        "I see children playing and... um... something about cookies or... or food...",
        "The lady is trying to... to do something with the... with the sink and water is...",
        "There are people in the picture and they're... um... doing things but I can't... I can't remember..."
    ]
    chooser = np.random if rng is None else rng
    return str(chooser.choice(ad_patterns))

def generate_cn_like_transcript(rng=None):
    """Return a random canned transcript imitating a cognitively normal speaker.

    The patterns are fluent, complete sentences describing the kitchen scene.

    Args:
        rng: optional object with a NumPy-style ``choice`` method (for
            example ``np.random.default_rng(seed)``) for reproducible
            sampling. When omitted, the global ``np.random`` state is used.

    Returns:
        One of the patterns as a built-in ``str``. ``choice`` alone would
        return a ``numpy.str_``.
    """
    cn_patterns = [
        "In this picture, I can see a kitchen scene where a woman is washing dishes at the sink. The sink appears to be overflowing with water onto the floor.",
        "There's a boy who has climbed up on a stool to reach the cookie jar on the counter. His sister is asking him to give her a cookie.",
        "The scene shows a typical kitchen with a woman doing dishes while children are nearby. The boy is reaching for cookies while standing on a chair.",
        "I can observe a domestic scene with a mother washing dishes. There are two children in the kitchen, and one of them is trying to get cookies from a jar.",
        "The picture depicts a kitchen where a woman is at the sink with running water. There are children present, and one child is reaching up to get something from the counter.",
        "This shows a busy kitchen scene with a woman washing dishes while water overflows. Meanwhile, children are nearby, with one trying to access a cookie jar."
    ]
    chooser = np.random if rng is None else rng
    return str(chooser.choice(cn_patterns))