<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May243.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
from transformers import BertTokenizer, BertModel, ViTModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import re
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [2]:
def create_synthetic_dataset(num_ad=87, num_cn=79, seed=None):
    """Generate a synthetic dataset mimicking the layout of ADReSS.

    Every sample is a dict with the keys ``participant_id``, ``audio_path``,
    ``transcript``, ``label`` and ``class_name``. AD samples (label 1) come
    first, followed by CN samples (label 0). The audio paths are placeholder
    names only; this function writes no files.

    Args:
        num_ad: Number of AD samples to create.
        num_cn: Number of CN samples to create.
        seed: Optional seed. When given, transcripts are drawn from a private
            ``np.random.RandomState(seed)`` so the dataset is reproducible.
            When ``None`` the global NumPy random state is used.

    Returns:
        A list of ``num_ad + num_cn`` sample dicts.
    """
    # The ``np.random`` module and a ``RandomState`` both expose ``choice``,
    # so either can serve as the source of randomness below.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # AD-like transcripts: hesitant, repetitive, vague
    ad_patterns = [
        "Um... I see a kitchen, and uh... someone is there... washing something, maybe dishes...",
        "The boy is... uh... climbing to get... um... cookies or something... I think...",
        "There's water... uh... spilling and... uh... people are doing things..."
    ]

    # CN-like transcripts: coherent, detailed
    cn_patterns = [
        "In the kitchen, a woman is washing dishes while a boy reaches for a cookie jar.",
        "The scene shows a sink overflowing and a child on a stool grabbing cookies.",
        "A mother is cleaning dishes, and two children are nearby, one reaching for snacks."
    ]

    def build_samples(class_name, label, patterns, count):
        """Build ``count`` samples of one class from its transcript patterns."""
        samples = []
        for i in range(count):
            # One full pattern plus the first 20 characters of a second draw
            # from the SAME class: this adds within-class variety only, it
            # does not mix vocabulary between the two classes.
            head = str(rng.choice(patterns))
            tail = str(rng.choice(patterns, size=1)[0])[:20]
            samples.append({
                'participant_id': f'{class_name}_{i:03d}',
                'audio_path': f'synthetic_audio_{class_name}_{i:03d}.wav',
                'transcript': head + " " + tail,
                'label': label,
                'class_name': class_name
            })
        return samples

    data = build_samples('AD', 1, ad_patterns, num_ad)
    data += build_samples('CN', 0, cn_patterns, num_cn)

    print(f"Created synthetic dataset: {num_ad} AD, {num_cn} CN samples")
    return data

# Generate the synthetic dataset using the function's default class sizes.
dataset = create_synthetic_dataset()

Created synthetic dataset: 87 AD, 79 CN samples


In [3]:
class AudioProcessor:
    """Turn audio files into fixed-size, 3-channel log-Mel inputs.

    The pipeline is ``load_audio`` -> ``extract_mel_spectrogram`` ->
    ``resize_spectrogram``. Each step is best-effort: on failure it returns
    random data of the expected shape instead of raising, so a training loop
    keeps running. The two feature steps emit a ``RuntimeWarning`` when they
    fall back, so the substitution is not silent (it is only visible if the
    active warnings filter lets ``RuntimeWarning`` through).
    """

    def __init__(self, sample_rate=16000, n_mels=224, win_length=2048, hop_length=1024):
        """Store the spectrogram parameters.

        Args:
            sample_rate: Sampling rate audio is loaded at.
            n_mels: Number of Mel bands.
            win_length: Passed to librosa as ``n_fft``.
            hop_length: Hop between successive frames, in samples.
        """
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop_length = hop_length

    def load_audio(self, audio_path, max_length=16000*16):
        """Load audio, or return a synthetic signal if loading fails.

        Args:
            audio_path: Path handed to ``librosa.load``.
            max_length: Exact number of samples in the returned signal.

        Returns:
            A 1-D array of exactly ``max_length`` samples. Longer signals are
            centre-cropped, shorter ones are zero-padded at the end.
        """
        try:
            audio, _ = librosa.load(audio_path, sr=self.sample_rate)
        except Exception:
            # Deliberate fallback for the synthetic dataset, whose audio paths
            # do not exist: low-amplitude Gaussian noise stands in. Catching
            # Exception (not a bare except) lets Ctrl-C / SystemExit through.
            audio = np.random.randn(max_length) * 0.01

        n_samples = len(audio)
        if n_samples > max_length:
            start = (n_samples - max_length) // 2
            audio = audio[start:start + max_length]
        elif n_samples < max_length:
            audio = np.pad(audio, (0, max_length - n_samples), mode='constant')
        return audio

    def extract_mel_spectrogram(self, audio):
        """Extract a 3-channel log-Mel spectrogram (log-Mel, delta, delta-delta).

        Args:
            audio: 1-D audio signal.

        Returns:
            Array of shape ``(3, n_mels, time)``. If extraction fails, a
            random array of shape ``(3, n_mels, 100)`` is returned and a
            ``RuntimeWarning`` is issued.
        """
        try:
            mel_spec = librosa.feature.melspectrogram(
                y=audio, sr=self.sample_rate, n_mels=self.n_mels,
                n_fft=self.win_length, hop_length=self.hop_length
            )
            log_mel = librosa.power_to_db(mel_spec, ref=np.max)
            delta = librosa.feature.delta(log_mel)
            delta2 = librosa.feature.delta(log_mel, order=2)
            return np.stack([log_mel, delta, delta2], axis=0)  # Shape: (3, n_mels, time)
        except Exception as exc:
            warnings.warn(
                f"Mel-spectrogram extraction failed ({exc!r}); using random features",
                RuntimeWarning, stacklevel=2
            )
            return np.random.randn(3, self.n_mels, 100)

    def resize_spectrogram(self, spectrogram, target_size=(224, 224)):
        """Resize a spectrogram to the ViT input size and scale it to [0, 1].

        Each channel is resized with linear interpolation. The min-max
        scaling is then applied over all channels together, not per channel.

        Args:
            spectrogram: Array of shape ``(channels, height, width)``.
            target_size: ``(height, width)`` of the output.

        Returns:
            Array of shape ``(channels, *target_size)``. If resizing fails, a
            random ``(3, *target_size)`` array is returned and a
            ``RuntimeWarning`` is issued.
        """
        from scipy.ndimage import zoom
        try:
            resized_channels = []
            for channel in spectrogram:
                zoom_factors = [target_size[i] / channel.shape[i] for i in range(2)]
                resized_channels.append(zoom(channel, zoom_factors, order=1))
            resized = np.stack(resized_channels, axis=0)
            # The epsilon keeps a constant input from dividing by zero.
            resized = (resized - resized.min()) / (resized.max() - resized.min() + 1e-8)
            return resized
        except Exception as exc:
            warnings.warn(
                f"Spectrogram resize failed ({exc!r}); using random features",
                RuntimeWarning, stacklevel=2
            )
            return np.random.rand(3, target_size[0], target_size[1])

In [4]:
class MultiModalDataset(Dataset):
    """Dataset yielding tokenized text, spectrogram and linguistic features.

    ``data`` is a list of sample dicts with at least the keys ``transcript``,
    ``audio_path``, ``label``, ``participant_id`` and ``class_name``. On
    construction each dict gains a ``linguistic_features`` entry; the dicts
    are modified in place.
    """

    # Hesitation markers counted as filler words.
    _FILLER_WORDS = frozenset({'um', 'uh', 'er'})
    # Punctuation stripped from token edges before the filler comparison, so
    # that tokens such as "Um..." or "uh," are recognised.
    _EDGE_PUNCTUATION = '.,!?;:…"\'()[]{}-'

    def __init__(self, data, audio_processor, tokenizer, max_text_length=512, audio_max_length=16000*16):
        """Store the collaborators and precompute the linguistic features.

        Args:
            data: List of sample dicts (see class docstring).
            audio_processor: Object providing ``load_audio``,
                ``extract_mel_spectrogram`` and ``resize_spectrogram``.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_text_length: Token length every transcript is padded or
                truncated to.
            audio_max_length: Number of audio samples requested per item.
        """
        self.data = data
        self.audio_processor = audio_processor
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.audio_max_length = audio_max_length
        self._precompute_linguistic_features()

    @classmethod
    def _compute_linguistic_features(cls, transcript):
        """Return the linguistic feature dict for one transcript.

        Words are whitespace-separated tokens. ``word_count``,
        ``unique_words`` and ``lexical_diversity`` use the raw tokens; the
        filler count compares each token lower-cased and with surrounding
        punctuation removed.
        """
        words = transcript.split()
        unique_words = set(words)
        filler_count = sum(
            1 for word in words
            if word.strip(cls._EDGE_PUNCTUATION).lower() in cls._FILLER_WORDS
        )
        return {
            'word_count': len(words),
            'unique_words': len(unique_words),
            'lexical_diversity': len(unique_words) / len(words) if words else 0,
            'filler_words': filler_count
        }

    def _precompute_linguistic_features(self):
        """Compute linguistic features for each transcript."""
        for sample in tqdm(self.data, desc="Computing linguistic features"):
            # __getitem__ tolerates a missing transcript, so do the same here.
            transcript = sample['transcript'] or ""
            sample['linguistic_features'] = self._compute_linguistic_features(transcript)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the model inputs for sample ``idx``.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``audio_features``,
            ``linguistic_features`` (4 floats), ``label`` (scalar long
            tensor), ``participant_id`` and ``class_name``.
        """
        sample = self.data[idx]

        # Process text
        text = sample['transcript'] or "No speech content detected"
        encoding = self.tokenizer(
            text, truncation=True, padding='max_length', max_length=self.max_text_length,
            return_tensors='pt'
        )

        # Process audio
        audio = self.audio_processor.load_audio(sample['audio_path'], self.audio_max_length)
        spectrogram = self.audio_processor.extract_mel_spectrogram(audio)
        audio_features = self.audio_processor.resize_spectrogram(spectrogram)

        # Linguistic features, in a fixed order
        ling_features = sample['linguistic_features']
        ling_vector = np.array([
            ling_features['word_count'],
            ling_features['unique_words'],
            ling_features['lexical_diversity'],
            ling_features['filler_words']
        ], dtype=np.float32)

        return {
            # squeeze(0) drops only the leading batch axis added by
            # return_tensors='pt'; a bare squeeze() would also drop a
            # length-1 sequence axis.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'audio_features': torch.as_tensor(audio_features, dtype=torch.float32),
            'linguistic_features': torch.as_tensor(ling_vector, dtype=torch.float32),
            'label': torch.tensor(sample['label'], dtype=torch.long),
            'participant_id': sample['participant_id'],
            'class_name': sample['class_name']
        }

In [5]:
class MultiModalADClassifier(nn.Module):
    """Classifier fusing BERT text, ViT spectrogram and linguistic features.

    Each modality is projected to ``fusion_hidden_size``. The three
    projections attend to each other through one multi-head attention layer,
    and the concatenated result goes through an MLP and a linear classifier.
    """

    def __init__(self, text_hidden_size=768, audio_hidden_size=768, ling_hidden_size=4,
                 fusion_hidden_size=512, num_classes=2, dropout=0.3):
        """Build the encoders, projections, fusion MLP and classifier.

        Args:
            text_hidden_size: Width of the text encoder's pooled output.
            audio_hidden_size: Width of the audio encoder's pooled output.
            ling_hidden_size: Number of hand-crafted linguistic features.
            fusion_hidden_size: Common width the modalities are projected to;
                must be divisible by the 8 attention heads.
            num_classes: Number of output classes.
            dropout: Dropout probability used throughout.
        """
        super().__init__()
        self.fusion_hidden_size = fusion_hidden_size
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.audio_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.ling_processor = nn.Sequential(
            nn.Linear(ling_hidden_size, 32), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(32, 32), nn.ReLU()
        )
        self.attention = nn.MultiheadAttention(embed_dim=fusion_hidden_size, num_heads=8, dropout=dropout)
        self.text_projection = nn.Linear(text_hidden_size, fusion_hidden_size)
        self.audio_projection = nn.Linear(audio_hidden_size, fusion_hidden_size)
        self.ling_projection = nn.Linear(32, fusion_hidden_size)
        self.fusion = nn.Sequential(
            nn.Linear(fusion_hidden_size * 3, fusion_hidden_size), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(fusion_hidden_size, fusion_hidden_size // 2), nn.ReLU(), nn.Dropout(dropout)
        )
        self.classifier = nn.Linear(fusion_hidden_size // 2, num_classes)

        # Initialize weights of the listed heads; the pretrained encoders and
        # the projection layers keep their own initialisation.
        for module in (self.ling_processor, self.fusion, self.classifier):
            self._init_linear_weights(module)

    @staticmethod
    def _init_linear_weights(module):
        """Xavier-initialise every ``nn.Linear`` in ``module`` and zero its bias.

        ``module.modules()`` yields the module itself as well as its
        children, so this works for a bare ``nn.Linear`` (which cannot be
        iterated directly) and for an ``nn.Sequential`` alike.
        """
        for layer in module.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask, audio_features, linguistic_features):
        """Run the full model.

        Args:
            input_ids: Token ids passed to the text encoder.
            attention_mask: Attention mask passed to the text encoder.
            audio_features: Spectrogram images passed to the audio encoder as
                ``pixel_values``.
            linguistic_features: Hand-crafted features, last dimension
                ``ling_hidden_size``.

        Returns:
            Dict with ``logits`` of shape ``[batch_size, num_classes]`` and
            the ``attention_weights`` returned by the attention layer.
        """
        # Encode text
        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_output.pooler_output  # [batch_size, text_hidden_size]

        # Encode audio
        audio_output = self.audio_encoder(pixel_values=audio_features)
        audio_pooled = audio_output.pooler_output  # [batch_size, audio_hidden_size]

        # Process linguistic features
        ling_features = self.ling_processor(linguistic_features)  # [batch_size, 32]

        # Project to common dimension
        text_proj = self.text_projection(text_features)  # [batch_size, fusion]
        audio_proj = self.audio_projection(audio_pooled)  # [batch_size, fusion]
        ling_proj = self.ling_projection(ling_features)  # [batch_size, fusion]

        # Crossmodal attention: the three modalities form a length-3 sequence
        modality_features = torch.stack([text_proj, audio_proj, ling_proj], dim=0)  # [3, batch_size, fusion]
        attended_features, attention_weights = self.attention(modality_features, modality_features, modality_features)
        # Concatenate the three attended vectors per sample. Reshaping by the
        # batch size keeps this correct for any fusion_hidden_size.
        batch_size = attended_features.size(1)
        attended_features = attended_features.transpose(0, 1).reshape(batch_size, -1)  # [batch_size, 3 * fusion]

        # Fusion and classification
        fused = self.fusion(attended_features)
        logits = self.classifier(fused)

        return {'logits': logits, 'attention_weights': attention_weights}

In [6]:
class ModelTrainer:
    """Train, validate and evaluate a multimodal classifier.

    The model is called as ``model(input_ids, attention_mask, audio_features,
    linguistic_features)`` and must return a dict containing ``logits``.
    Batches are dicts with the keys ``input_ids``, ``attention_mask``,
    ``audio_features``, ``linguistic_features`` and ``label``.
    """

    def __init__(self, model, device, learning_rate=2e-5):
        """Move the model to ``device`` and set up optimiser, loss and scheduler.

        Args:
            model: The ``nn.Module`` to train.
            device: Device the model and every batch are moved to.
            learning_rate: Initial AdamW learning rate.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
        self.criterion = nn.CrossEntropyLoss()
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=3)
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []
        self.best_val_loss = float('inf')
        self.best_model_state = None

    def _forward_batch(self, batch):
        """Move one batch to the device and run the model.

        Returns:
            Tuple ``(outputs, labels)``: the model's output dict and the
            label tensor on ``self.device``.
        """
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        audio_features = batch['audio_features'].to(self.device)
        ling_features = batch['linguistic_features'].to(self.device)
        labels = batch['label'].to(self.device)
        outputs = self.model(input_ids, attention_mask, audio_features, ling_features)
        return outputs, labels

    def _snapshot_state(self):
        """Return an independent CPU copy of the model's current state.

        ``state_dict()`` returns references to the live parameter tensors, so
        a shallow ``.copy()`` would keep changing as training continues.
        Every tensor is copied here; keeping the copy on the CPU avoids
        holding a second set of weights in GPU memory.
        """
        return {
            name: tensor.detach().to('cpu', copy=True)
            for name, tensor in self.model.state_dict().items()
        }

    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            Tuple ``(mean batch loss, accuracy)``.

        Raises:
            ZeroDivisionError: If ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        for batch in tqdm(train_loader, desc="Training"):
            self.optimizer.zero_grad()
            outputs, labels = self._forward_batch(batch)
            loss = self.criterion(outputs['logits'], labels)
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            preds = torch.argmax(outputs['logits'], dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        return total_loss / len(train_loader), correct / total

    def validate_epoch(self, val_loader):
        """Evaluate on ``val_loader`` without updating the model.

        Returns:
            Tuple ``(mean batch loss, accuracy, predictions, labels)``; the
            last two are flat lists over all samples.

        Raises:
            ZeroDivisionError: If ``val_loader`` yields no batches.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        preds, labels = [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                outputs, batch_labels = self._forward_batch(batch)
                loss = self.criterion(outputs['logits'], batch_labels)

                total_loss += loss.item()
                batch_preds = torch.argmax(outputs['logits'], dim=1)
                correct += (batch_preds == batch_labels).sum().item()
                total += batch_labels.size(0)
                preds.extend(batch_preds.cpu().numpy())
                labels.extend(batch_labels.cpu().numpy())

        return total_loss / len(val_loader), correct / total, preds, labels

    def train(self, train_loader, val_loader, num_epochs=10, patience=6):
        """Train with early stopping and restore the best weights.

        After each epoch the validation loss drives the LR scheduler. The
        weights with the lowest validation loss seen by this trainer are
        kept and loaded back at the end.

        Args:
            train_loader: Iterable of training batches.
            val_loader: Iterable of validation batches.
            num_epochs: Maximum number of epochs.
            patience: Stop after this many consecutive epochs without a new
                best validation loss.
        """
        # Must exist before the loop: the first epoch may fail to improve
        # (NaN loss, or a best loss carried over from an earlier train() call).
        epochs_no_improve = 0
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            train_loss, train_acc = self.train_epoch(train_loader)
            val_loss, val_acc, _, _ = self.validate_epoch(val_loader)

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model_state = self._snapshot_state()
                print("✓ Saved best model")
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            self.scheduler.step(val_loss)
            if epochs_no_improve >= patience:
                print("Early stopping triggered")
                break

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print("✓ Loaded best model")

    def evaluate(self, test_loader):
        """Compute and print classification metrics on ``test_loader``.

        Precision, recall and F1 are weighted averages over the classes. The
        confusion matrix is always 2x2 with rows and columns ordered CN (0),
        AD (1).

        Returns:
            Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and
            ``cm``.
        """
        self.model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating"):
                outputs, batch_labels = self._forward_batch(batch)
                batch_preds = torch.argmax(outputs['logits'], dim=1)
                preds.extend(batch_preds.cpu().numpy())
                labels.extend(batch_labels.cpu().numpy())

        accuracy = accuracy_score(labels, preds)
        precision = precision_score(labels, preds, average='weighted', zero_division=0)
        recall = recall_score(labels, preds, average='weighted', zero_division=0)
        f1 = f1_score(labels, preds, average='weighted', zero_division=0)
        # Fixing the label set keeps the matrix 2x2 even when only one class
        # occurs in the test data, so the indexing below cannot fail.
        cm = confusion_matrix(labels, preds, labels=[0, 1])

        print("\nEvaluation Results:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1-Score: {f1:.4f}")
        print("Confusion Matrix:")
        print(f"        CN    AD")
        print(f"CN     {cm[0,0]:4d}  {cm[0,1]:4d}")
        print(f"AD     {cm[1,0]:4d}  {cm[1,1]:4d}")

        return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'cm': cm}