<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May243.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
from transformers import BertTokenizer, BertModel, ViTModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import re
from pathlib import Path
import warnings
# Silence every warning category for the rest of the notebook run.
warnings.simplefilter('ignore')

# Prefer the GPU when CUDA is usable, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [2]:
def create_synthetic_dataset(num_ad=87, num_cn=79, seed=None):
    """Generate a synthetic dataset that mimics the ADReSS record layout.

    Each record is a dict with ``participant_id``, ``audio_path``,
    ``transcript``, ``label`` (1 = AD, 0 = CN) and ``class_name``. The audio
    paths are placeholders only; no audio files are created.

    Args:
        num_ad: Number of AD records (label 1); these come first.
        num_cn: Number of CN records (label 0); these follow the AD records.
        seed: Optional integer seed for reproducible transcripts. ``None``
            (the default) draws from NumPy's global random state, exactly as
            before, so ``np.random.seed(...)`` set by the caller still applies.

    Returns:
        A list of ``num_ad + num_cn`` record dicts.
    """
    # A legacy RandomState seeded with `seed` yields the same stream that
    # np.random.seed(seed) gives the global np.random functions.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # AD-like transcripts: hesitant, repetitive, vague
    ad_patterns = [
        "Um... I see a kitchen, and uh... someone is there... washing something, maybe dishes...",
        "The boy is... uh... climbing to get... um... cookies or something... I think...",
        "There's water... uh... spilling and... uh... people are doing things..."
    ]

    # CN-like transcripts: coherent, detailed
    cn_patterns = [
        "In the kitchen, a woman is washing dishes while a boy reaches for a cookie jar.",
        "The scene shows a sink overflowing and a child on a stool grabbing cookies.",
        "A mother is cleaning dishes, and two children are nearby, one reaching for snacks."
    ]

    # (class_name, label, count, pattern pool), emitted in this order.
    groups = (
        ('AD', 1, num_ad, ad_patterns),
        ('CN', 0, num_cn, cn_patterns),
    )

    data = []
    for class_name, label, count, patterns in groups:
        for i in range(count):
            # A full pattern plus the first 20 characters of another (possibly
            # identical) pattern. NOTE: both pieces come from this class's own
            # pool, so this only adds within-class variety; AD and CN stay
            # perfectly separable by their text.
            base = rng.choice(patterns)
            fragment = rng.choice(patterns, size=1)[0][:20]
            data.append({
                'participant_id': f'{class_name}_{i:03d}',
                'audio_path': f'synthetic_audio_{class_name}_{i:03d}.wav',
                'transcript': f'{base} {fragment}',
                'label': label,
                'class_name': class_name,
            })

    print(f"Created synthetic dataset: {num_ad} AD, {num_cn} CN samples")
    return data

# Build the synthetic cohort using the default class sizes (87 AD, 79 CN).
dataset = create_synthetic_dataset(num_ad=87, num_cn=79)

Created synthetic dataset: 87 AD, 79 CN samples


In [3]:
class AudioProcessor:
    """Turn audio into 3-channel log-Mel "images" sized for a ViT.

    Every step is best-effort. When loading, feature extraction or resizing
    raises an ``Exception``, a random placeholder is returned instead, so the
    synthetic (file-less) dataset can flow through the pipeline.
    ``KeyboardInterrupt``/``SystemExit`` are not caught and propagate normally.

    Args:
        sample_rate: Target sampling rate (Hz) passed to ``librosa.load``.
        n_mels: Number of Mel bands.
        win_length: FFT size, passed to librosa as ``n_fft``.
        hop_length: Hop between STFT frames, in samples.
    """

    def __init__(self, sample_rate=16000, n_mels=224, win_length=2048, hop_length=1024):
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.win_length = win_length  # used as n_fft in extract_mel_spectrogram
        self.hop_length = hop_length

    def load_audio(self, audio_path, max_length=16000*16):
        """Load audio (resampled to ``self.sample_rate``) and fix its length.

        If loading raises (e.g. the file does not exist, as for the synthetic
        dataset), low-amplitude Gaussian noise is used instead. The result is
        center-cropped, or zero-padded on the right, to exactly ``max_length``
        samples (the default corresponds to 16 s at 16 kHz).

        Args:
            audio_path: Path handed to ``librosa.load``.
            max_length: Output length in samples.

        Returns:
            1-D array of exactly ``max_length`` samples.
        """
        try:
            audio, _ = librosa.load(audio_path, sr=self.sample_rate)
        except Exception:
            # Synthetic dataset: no real file exists, so substitute noise.
            # NOTE(review): this also masks corrupted/unsupported real files.
            # float32 matches librosa.load's default output dtype.
            audio = (np.random.randn(max_length) * 0.01).astype(np.float32)

        if len(audio) > max_length:
            # Keep the central window.
            start = (len(audio) - max_length) // 2
            audio = audio[start:start + max_length]
        elif len(audio) < max_length:
            audio = np.pad(audio, (0, max_length - len(audio)), mode='constant')
        return audio

    def extract_mel_spectrogram(self, audio):
        """Return stacked (log-Mel, delta, delta-delta) features.

        Args:
            audio: 1-D waveform, e.g. the output of ``load_audio``.

        Returns:
            Array of shape ``(3, n_mels, frames)``. On failure, a standard-normal
            random array of shape ``(3, n_mels, 100)`` is returned instead.
        """
        try:
            mel_spec = librosa.feature.melspectrogram(
                y=audio, sr=self.sample_rate, n_mels=self.n_mels,
                n_fft=self.win_length, hop_length=self.hop_length
            )
            # dB relative to the loudest bin (ref=np.max).
            log_mel = librosa.power_to_db(mel_spec, ref=np.max)
            delta = librosa.feature.delta(log_mel)
            delta2 = librosa.feature.delta(log_mel, order=2)
        except Exception:
            # The placeholder's 100 frames need not match a real frame count;
            # resize_spectrogram rescales the time axis anyway.
            return np.random.randn(3, self.n_mels, 100)
        return np.stack([log_mel, delta, delta2], axis=0)  # Shape: (3, n_mels, time)

    def resize_spectrogram(self, spectrogram, target_size=(224, 224)):
        """Resize each channel to ``target_size`` and min-max scale the result.

        Args:
            spectrogram: Iterable of 2-D channels, e.g. the
                ``(3, n_mels, frames)`` output of ``extract_mel_spectrogram``.
            target_size: ``(height, width)`` of every output channel.

        Returns:
            Array of shape ``(channels, height, width)``, scaled into [0, 1]
            with a single min/max shared by all channels. On failure, a uniform
            random array of shape ``(3, height, width)`` is returned instead.
        """
        from scipy.ndimage import zoom

        try:
            resized_channels = []
            for channel in spectrogram:
                zoom_factors = [target_size[i] / channel.shape[i] for i in range(2)]
                # order=1: linear spline interpolation along each axis.
                resized_channels.append(zoom(channel, zoom_factors, order=1))
            resized = np.stack(resized_channels, axis=0)
        except Exception:
            return np.random.rand(3, target_size[0], target_size[1])
        # The epsilon avoids division by zero for a constant input (-> all 0).
        return (resized - resized.min()) / (resized.max() - resized.min() + 1e-8)