<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May242.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Mount Google Drive so the ADReSSo21 archives stored under MyDrive/Voice
# (see DATASET_PATHS in the next cell) are readable from this Colab runtime.
# NOTE: this only mounts Drive; the .tgz files must already be uploaded there.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [6]:
# Dementia Detection from Speech and Transcripts using Transformers
# Complete Implementation with Multi-Visualizations for Google Colab

# ============================================================================
# STEP 1: INSTALL REQUIRED PACKAGES AND SETUP
# ============================================================================

!pip install transformers torch torchvision torchaudio librosa pandas scikit-learn matplotlib seaborn numpy plotly umap-learn

import os
import tarfile
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import librosa.display
from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
warnings.filterwarnings('ignore')

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Dataset paths
DATASET_PATHS = {
    'diagnosis_train': '/content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train.tgz',
    'progression_train': '/content/drive/MyDrive/Voice/ADReSSo21-progression-train.tgz',
    'progression_test': '/content/drive/MyDrive/Voice/ADReSSo21-progression-test.tgz'
}

print("Available datasets:")
for name, path in DATASET_PATHS.items():
    exists = "✓" if os.path.exists(path) else "✗"
    print(f"{exists} {name}: {path}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
Available datasets:
✓ diagnosis_train: /content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train.tgz
✓ progression_train: /content/drive/MyDrive/Voice/ADReSSo21-progression-train.tgz
✓ progression_test: /content/drive/MyDrive/Voice/ADReSSo21-progression-test.tgz


In [None]:
# ============================================================================
# STEP 2: DATA PROCESSING AND DATASET LOADING
# ============================================================================

import json
import glob
import re
from pathlib import Path
import soundfile as sf
from collections import defaultdict

class AudioProcessor:
    """Speech/audio helpers built on librosa.

    Stores the resampling and spectrogram configuration (sample rate, number of
    mel bands, hop size, FFT size) and exposes waveform loading plus several
    feature extractors that all share that configuration.
    """

    def __init__(self, sample_rate=16000, n_mels=128, hop_length=512, n_fft=2048):
        # Configuration shared by every method below.
        self.sample_rate, self.n_mels = sample_rate, n_mels
        self.hop_length, self.n_fft = hop_length, n_fft

    def load_audio(self, audio_path, max_length=None):
        """Read an audio file, trim edge silence, peak-normalise, optionally fix length.

        Args:
            audio_path: path to a file librosa can decode.
            max_length: when given, the signal is cut to, or right-padded with
                zeros up to, exactly this many samples.

        Returns:
            1-D numpy array resampled to ``self.sample_rate``, or ``None`` if any
            step raised (the error is printed, not re-raised).
        """
        try:
            signal, _sr = librosa.load(audio_path, sr=self.sample_rate)
            # Drop leading/trailing parts quieter than 20 dB below the peak.
            signal = librosa.effects.trim(signal, top_db=20)[0]
            signal = librosa.util.normalize(signal)

            if max_length is None:
                return signal
            shortfall = max_length - len(signal)
            if shortfall < 0:
                return signal[:max_length]
            if shortfall > 0:
                return np.pad(signal, (0, shortfall), mode='constant')
            return signal
        except Exception as err:
            print(f"Error loading audio {audio_path}: {err}")
            return None

    def extract_mel_spectrogram(self, audio):
        """Return the mel spectrogram of ``audio`` in dB, relative to its maximum."""
        power_mel = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_mels=self.n_mels,
            hop_length=self.hop_length,
            n_fft=self.n_fft,
        )
        return librosa.power_to_db(power_mel, ref=np.max)

    def extract_audio_features(self, audio):
        """Compute a dictionary of acoustic descriptors for ``audio``.

        Keys, in insertion order: mel_spectrogram, mfcc (13 coefficients),
        spectral_centroid, spectral_rolloff, spectral_bandwidth, tempo, beats,
        zcr (zero-crossing rate) and rms (root-mean-square energy).
        """
        sr = self.sample_rate
        features = {
            'mel_spectrogram': self.extract_mel_spectrogram(audio),
            'mfcc': librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13),
            'spectral_centroid': librosa.feature.spectral_centroid(y=audio, sr=sr),
            'spectral_rolloff': librosa.feature.spectral_rolloff(y=audio, sr=sr),
            'spectral_bandwidth': librosa.feature.spectral_bandwidth(y=audio, sr=sr),
        }
        # beat_track returns (tempo, beat frame indices).
        features['tempo'], features['beats'] = librosa.beat.beat_track(y=audio, sr=sr)
        features['zcr'] = librosa.feature.zero_crossing_rate(audio)
        features['rms'] = librosa.feature.rms(y=audio)
        return features

class DatasetExtractor:
    """Extract dataset archives and index the files they contain by type."""

    def __init__(self, output_dir='/content/extracted_data'):
        """Create (if needed) the root directory that archives are extracted into.

        Args:
            output_dir: destination root; each archive gets its own sub-directory.
        """
        self.output_dir = Path(output_dir)
        # parents=True so a nested output_dir does not fail on a fresh runtime.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _is_within_directory(directory, target):
        """Return True if ``target`` resolves to ``directory`` or a path inside it."""
        directory = Path(directory).resolve()
        target = Path(target).resolve()
        return target == directory or directory in target.parents

    def _safe_extractall(self, tar, extract_path):
        """Extract ``tar`` into ``extract_path``, refusing members that escape it.

        Uses the stdlib ``'data'`` extraction filter when this Python provides it;
        otherwise rejects absolute paths, ``..`` traversal and links that point
        outside the destination before extracting.

        Raises:
            tarfile.TarError / ValueError: if an unsafe member is found.
        """
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=extract_path, filter='data')
            return
        root = Path(extract_path)
        for member in tar.getmembers():
            member_path = root / member.name
            if not self._is_within_directory(root, member_path):
                raise ValueError(f"Unsafe path in archive: {member.name}")
            if member.issym() or member.islnk():
                # Symlink targets are relative to the member's directory,
                # hard-link targets to the archive root.
                link_target = (member_path.parent / member.linkname
                               if member.issym() else root / member.linkname)
                if not self._is_within_directory(root, link_target):
                    raise ValueError(f"Unsafe link in archive: {member.name}")
        tar.extractall(path=extract_path)

    def extract_tar_file(self, tar_path, extract_name):
        """Extract a (possibly compressed) tar archive into ``output_dir / extract_name``.

        Args:
            tar_path: path to the archive; compression is auto-detected.
            extract_name: name of the sub-directory to extract into.

        Returns:
            The extraction ``Path`` on success, or ``None`` if the archive could
            not be read or contains members that would land outside the
            destination (the error is printed, not raised).
        """
        extract_path = self.output_dir / extract_name
        extract_path.mkdir(parents=True, exist_ok=True)

        print(f"Extracting {tar_path} to {extract_path}")

        try:
            # 'r:*' auto-detects gzip/bz2/xz (and plain tar), so .tgz keeps working.
            with tarfile.open(tar_path, 'r:*') as tar:
                self._safe_extractall(tar, extract_path)
            print(f"Successfully extracted {extract_name}")
            return extract_path
        except Exception as e:
            print(f"Error extracting {tar_path}: {e}")
            return None

    def find_files(self, base_path, file_extension):
        """Recursively list files under ``base_path`` ending in ``.file_extension``.

        Matching is case-insensitive (``.WAV`` counts as ``wav``) and the result
        is sorted, so ID -> file maps built from it are deterministic.
        """
        suffix = f".{file_extension.lower().lstrip('.')}"
        return sorted(
            p for p in Path(base_path).rglob('*')
            if p.is_file() and p.name.lower().endswith(suffix)
        )

    def organize_dataset_files(self, extract_path, dataset_type):
        """Index extracted files into audio / transcript / metadata lists.

        Args:
            extract_path: directory returned by ``extract_tar_file``.
            dataset_type: label used only in the printed summary.

        Returns:
            dict with keys 'audio_files' (wav/mp3/flac), 'transcript_files'
            (txt/cha) and 'metadata_files' (csv/tsv/json), each a list of Paths.
        """
        extensions_by_kind = {
            'audio_files': ('wav', 'mp3', 'flac'),
            'transcript_files': ('txt', 'cha'),  # .cha = CHAT transcription format
            'metadata_files': ('csv', 'tsv', 'json'),
        }
        file_mapping = {
            kind: [path for ext in exts for path in self.find_files(extract_path, ext)]
            for kind, exts in extensions_by_kind.items()
        }

        print(f"\n{dataset_type} Dataset Organization:")
        print(f"Audio files: {len(file_mapping['audio_files'])}")
        print(f"Transcript files: {len(file_mapping['transcript_files'])}")
        print(f"Metadata files: {len(file_mapping['metadata_files'])}")

        return file_mapping

class DataProcessor:
    """Load metadata/transcripts and pair them with audio files and labels.

    Participant IDs from audio file names, transcript file names and metadata
    rows are all normalised through ``extract_participant_id`` so the three
    sources can be joined on the same key.
    """

    # Metadata columns that may hold the label, checked in this order.
    # NOTE(review): 'dx' is assumed to be the diagnosis column of the ADReSSo21
    # metadata CSV (values like 'ad'/'cn') — confirm against the extracted files.
    LABEL_COLUMNS = ('diagnosis', 'dx', 'label', 'class')
    # Label columns holding class names (mapped via POSITIVE_CLASS_NAMES) instead of ints.
    TEXT_LABEL_COLUMNS = ('diagnosis', 'dx')
    # Case-insensitive class names mapped to the positive class (label 1).
    POSITIVE_CLASS_NAMES = frozenset({'ad', 'dementia', 'decline'})
    # Metadata columns that may hold the participant ID, checked in this order.
    # NOTE(review): 'adressfname' is assumed to be ADReSSo's file-name column — confirm.
    ID_COLUMNS = ('participant_id', 'adressfname')

    def __init__(self, audio_processor=None):
        """Store an audio processor (default ``AudioProcessor()``) and load the BERT tokenizer."""
        self.audio_processor = audio_processor or AudioProcessor()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def load_metadata(self, metadata_files):
        """Load CSV/TSV files as DataFrames and JSON files as parsed objects.

        Returns:
            dict mapping each file's stem to its loaded content; files that fail
            to load are reported and skipped.
        """
        metadata = {}

        for file_path in metadata_files:
            try:
                if file_path.suffix.lower() == '.csv':
                    df = pd.read_csv(file_path)
                    metadata[file_path.stem] = df
                elif file_path.suffix.lower() == '.tsv':
                    df = pd.read_csv(file_path, sep='\t')
                    metadata[file_path.stem] = df
                elif file_path.suffix.lower() == '.json':
                    with open(file_path, 'r') as f:
                        data = json.load(f)
                    metadata[file_path.stem] = data

                print(f"Loaded metadata from {file_path.name}")
            except Exception as e:
                print(f"Error loading metadata from {file_path}: {e}")

        return metadata

    def load_transcripts(self, transcript_files):
        """Read and clean transcript files.

        Returns:
            dict mapping participant ID (from the file name) to cleaned text;
            unreadable files are reported and skipped.
        """
        transcripts = {}

        for file_path in transcript_files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                content = self.clean_transcript(content)

                participant_id = self.extract_participant_id(file_path.name)
                transcripts[participant_id] = content

            except Exception as e:
                print(f"Error loading transcript {file_path}: {e}")

        return transcripts

    def clean_transcript(self, text):
        """Strip CHAT-style annotation markers and normalise whitespace."""
        text = re.sub(r'\*[A-Z]{3}:', '', text)  # Remove speaker markers such as *PAR:
        text = re.sub(r'&[a-z]+', '', text)      # Remove &-prefixed tokens (hesitations/fillers)
        text = re.sub(r'\[.*?\]', '', text)      # Remove square-bracketed annotations
        text = re.sub(r'<.*?>', '', text)        # Remove angle-bracketed spans, including content
        text = re.sub(r'\+\+', '', text)         # Remove '++' markers
        text = re.sub(r'xxx', '', text)          # Remove 'xxx' (unintelligible speech)

        text = re.sub(r'\s+', ' ', text)         # Normalize whitespace
        text = text.strip()

        return text

    def extract_participant_id(self, filename):
        """Extract a participant ID from a file name or raw ID string.

        Returns the first regex match below, else the name without extension.
        NOTE: patterns 3 and 4 can never win — any string they match also
        contains a match for pattern 1 or 2, which are tried first.
        """
        patterns = [
            r'(\d{3})',           # 3-digit numbers
            r'([A-Z]\d{2})',      # Letter followed by 2 digits
            r'(S\d{3})',          # S followed by 3 digits
            r'([A-Z]{2}\d{2})'    # 2 letters followed by 2 digits
        ]

        for pattern in patterns:
            match = re.search(pattern, filename)
            if match:
                return match.group(1)

        return Path(filename).stem

    def _find_labels_df(self, metadata):
        """Return the first DataFrame in ``metadata`` that has a label column, else None."""
        for df in metadata.values():
            if isinstance(df, pd.DataFrame) and any(col in df.columns for col in self.LABEL_COLUMNS):
                return df
        return None

    def _row_label(self, row):
        """Return the 0/1 label of a metadata row, or None if it is missing."""
        for col in self.LABEL_COLUMNS:
            if col not in row.index:
                continue
            value = row[col]
            if pd.isna(value):
                return None
            if col in self.TEXT_LABEL_COLUMNS:
                return int(str(value).strip().lower() in self.POSITIVE_CLASS_NAMES)
            return int(value)
        return None

    def _row_participant_id(self, row, known_ids):
        """Return the participant ID of a metadata row, normalised like file IDs.

        Uses the first present column of ID_COLUMNS, else the first column whose
        name does not start with 'Unnamed' (pandas' name for a header-less index
        column), else the first column. IDs that are not already known file IDs
        go through ``extract_participant_id`` so e.g. 'adrso024' joins with the
        file 'adrso024.wav' (both map to '024').
        """
        id_col = next((c for c in self.ID_COLUMNS if c in row.index), None)
        if id_col is not None:
            raw_id = row[id_col]
        else:
            named = [c for c in row.index if not str(c).startswith('Unnamed')]
            raw_id = row[named[0]] if named else row.iloc[0]
        pid = str(raw_id)
        return pid if pid in known_ids else self.extract_participant_id(pid)

    @staticmethod
    def _class_name(label, task_type):
        """Human-readable class name: AD/CN for 'diagnosis', else Decline/Stable."""
        if task_type == 'diagnosis':
            return 'AD' if label == 1 else 'CN'
        return 'Decline' if label == 1 else 'Stable'

    def create_paired_dataset(self, file_mapping, metadata, task_type='diagnosis'):
        """Pair each labelled participant with their audio file and transcript.

        Args:
            file_mapping: dict with 'audio_files' and 'transcript_files' lists of Paths.
            metadata: dict from ``load_metadata``; the first DataFrame containing
                one of LABEL_COLUMNS supplies the labels. If none does, random
                dummy labels are generated (demo only — meaningless for training).
            task_type: 'diagnosis' yields class names AD/CN, any other value
                Decline/Stable.

        Returns:
            List of dicts (participant_id, audio_path, transcript, label,
            class_name), one per participant with a label, an audio file and a
            non-empty transcript.
        """
        audio_map = {self.extract_participant_id(p.name): p for p in file_mapping['audio_files']}
        transcript_map = {self.extract_participant_id(p.name): p for p in file_mapping['transcript_files']}
        transcripts = self.load_transcripts(file_mapping['transcript_files'])

        labels_df = self._find_labels_df(metadata)
        if labels_df is None:
            print("Warning: No labels found in metadata. Creating dummy labels.")
            # Sorted: iteration order of a set of str changes between interpreter
            # runs, which would reshuffle ID->label even under a fixed NumPy seed.
            all_participants = sorted(set(audio_map) | set(transcript_map))
            labels_df = pd.DataFrame({
                'participant_id': all_participants,
                'label': np.random.choice([0, 1], len(all_participants))
            })

        known_ids = set(audio_map) | set(transcripts)
        paired_data = []
        for _, row in labels_df.iterrows():
            label = self._row_label(row)
            if label is None:
                continue
            pid = self._row_participant_id(row, known_ids)

            audio_file = audio_map.get(pid)
            transcript_text = transcripts.get(pid, "")
            if audio_file and transcript_text:
                paired_data.append({
                    'participant_id': pid,
                    'audio_path': str(audio_file),
                    'transcript': transcript_text,
                    'label': label,
                    'class_name': self._class_name(label, task_type),
                })

        print(f"Created paired dataset with {len(paired_data)} samples")
        return paired_data

class MultiModalDataset(Dataset):
    """PyTorch Dataset yielding tokenised transcripts plus audio features.

    Each item is a dict with the tokenizer outputs (input_ids, attention_mask,
    token_type_ids), ``audio_features`` (a 3x224x224 pseudo-RGB spectrogram
    when ``transform_audio_to_image`` is True, otherwise the log-mel
    spectrogram as returned by the audio processor), ``label`` and
    ``participant_id``.
    """

    def __init__(self, data_samples, audio_processor, tokenizer, max_length=512,
                 audio_max_length=16*16000, transform_audio_to_image=True):
        """
        Args:
            data_samples: list of dicts with 'transcript', 'audio_path', 'label'
                and 'participant_id' keys.
            audio_processor: object providing ``load_audio`` and
                ``extract_mel_spectrogram`` (e.g. AudioProcessor).
            tokenizer: HuggingFace-style tokenizer callable.
            max_length: token length every transcript is padded/truncated to.
            audio_max_length: number of samples each waveform is cut/padded to
                (default 16*16000, i.e. 16 s at a 16 kHz sample rate).
            transform_audio_to_image: if True, return a 3x224x224 image for ViT.
        """
        self.data_samples = data_samples
        self.audio_processor = audio_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.audio_max_length = audio_max_length
        self.transform_audio_to_image = transform_audio_to_image

    def __len__(self):
        return len(self.data_samples)

    def _silent_audio(self):
        """All-zero waveform used when an audio file cannot be loaded.

        Deterministic on purpose: substituting fresh random noise would give the
        same sample different features on every access (irreproducible training).
        """
        n_samples = self.audio_max_length
        if n_samples is None:
            # No fixed length configured: fall back to one second of audio.
            n_samples = getattr(self.audio_processor, 'sample_rate', 16000)
        return np.zeros(n_samples, dtype=np.float32)

    def __getitem__(self, idx):
        sample = self.data_samples[idx]

        encoding = self.tokenizer(
            sample['transcript'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        audio = self.audio_processor.load_audio(
            sample['audio_path'],
            max_length=self.audio_max_length
        )
        if audio is None:
            audio = self._silent_audio()

        mel_spec = self.audio_processor.extract_mel_spectrogram(audio)
        if self.transform_audio_to_image:
            # Resize to 224x224 for ViT and replicate into 3 "RGB" channels.
            mel_spec_resized = self._resize_spectrogram(mel_spec, (224, 224))
            audio_features = np.stack([mel_spec_resized] * 3, axis=0)
        else:
            audio_features = mel_spec

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'token_type_ids': encoding['token_type_ids'].squeeze(),
            'audio_features': torch.FloatTensor(audio_features),
            'label': torch.LongTensor([sample['label']]).squeeze(),
            'participant_id': sample['participant_id']
        }

    def _resize_spectrogram(self, spectrogram, target_size):
        """Linearly resample a 2-D spectrogram to ``target_size`` and min-max scale to [0, 1]."""
        from scipy import ndimage
        current_size = spectrogram.shape
        zoom_factors = [target_size[i] / current_size[i] for i in range(2)]
        resized = ndimage.zoom(spectrogram, zoom_factors, order=1)

        # The 1e-8 guards against division by zero for a constant spectrogram.
        resized = (resized - resized.min()) / (resized.max() - resized.min() + 1e-8)
        return resized

def setup_datasets(dataset_paths, task_type='diagnosis'):
    """Extract, index and pair every available dataset archive.

    Args:
        dataset_paths: mapping of dataset name -> archive path; missing paths are
            reported and skipped.
        task_type: forwarded to ``DataProcessor.create_paired_dataset`` and to the
            distribution plot.

    Returns:
        (all_datasets, audio_processor, data_processor), where all_datasets maps
        each dataset name that produced at least one pair to a dict with 'data',
        'file_mapping' and 'metadata'.
    """
    print("=" * 60)
    print("SETTING UP DATASETS")
    print("=" * 60)

    extractor = DatasetExtractor()
    audio_processor = AudioProcessor()
    data_processor = DataProcessor(audio_processor)

    # VisualizationUtils is defined outside this cell (if at all). Resolve it at call
    # time so a fresh "Restart & Run All" does not abort with NameError before that
    # cell has run.
    vis_cls = globals().get('VisualizationUtils')

    all_datasets = {}

    for dataset_name, tar_path in dataset_paths.items():
        if not os.path.exists(tar_path):
            print(f"Skipping {dataset_name}: File not found at {tar_path}")
            continue

        print(f"\nProcessing {dataset_name}...")

        extract_path = extractor.extract_tar_file(tar_path, dataset_name)
        if extract_path is None:
            continue

        file_mapping = extractor.organize_dataset_files(extract_path, dataset_name)
        metadata = data_processor.load_metadata(file_mapping['metadata_files'])
        paired_data = data_processor.create_paired_dataset(
            file_mapping, metadata, task_type
        )

        if paired_data:
            all_datasets[dataset_name] = {
                'data': paired_data,
                'file_mapping': file_mapping,
                'metadata': metadata
            }

            if vis_cls is not None:
                vis_cls().plot_data_distribution(paired_data, task_type)
            else:
                print("VisualizationUtils is not defined; skipping distribution plot.")

    return all_datasets, audio_processor, data_processor

def create_train_val_split(dataset_dict, test_size=0.2, random_state=42, task_type=None):
    """Pool the training datasets and split them into train/validation lists.

    Args:
        dataset_dict: mapping of dataset name -> {'data': [sample dicts with 'label']}.
            Only names containing 'train' (case-insensitive) are pooled.
        test_size: passed to ``train_test_split``.
        random_state: seed for the split.
        task_type: optional substring (e.g. 'diagnosis'). When given, only
            training datasets whose name also contains it are pooled. The default
            (None) pools every '*train*' dataset. NOTE: with DATASET_PATHS, that
            mixes diagnosis and progression samples, whose labels mean different
            things.

    Returns:
        (train_data, val_data) lists, or (None, None) if no training data matched.
    """
    print("\nCreating train/validation split...")

    all_data = []
    for dataset_name, dataset_info in dataset_dict.items():
        name = dataset_name.lower()
        if 'train' not in name:
            continue
        if task_type is not None and task_type.lower() not in name:
            continue
        all_data.extend(dataset_info['data'])

    if not all_data:
        print("No training data found!")
        return None, None

    labels = [sample['label'] for sample in all_data]
    try:
        train_data, val_data = train_test_split(
            all_data,
            test_size=test_size,
            random_state=random_state,
            stratify=labels
        )
    except ValueError as e:
        # Stratification fails when a class has < 2 members or the validation set
        # is smaller than the number of classes; degrade instead of crashing.
        print(f"Warning: stratified split failed ({e}); using an unstratified split.")
        train_data, val_data = train_test_split(
            all_data,
            test_size=test_size,
            random_state=random_state
        )

    print(f"Train samples: {len(train_data)}")
    print(f"Validation samples: {len(val_data)}")

    def _print_class_distribution(title, subset):
        """Print per-class counts and percentages for a list of samples."""
        subset_labels = [sample['label'] for sample in subset]
        print(f"\n{title} class distribution:")
        unique, counts = np.unique(subset_labels, return_counts=True)
        for cls, count in zip(unique, counts):
            print(f"  Class {cls}: {count} samples ({count/len(subset_labels)*100:.1f}%)")

    _print_class_distribution("Train", train_data)
    _print_class_distribution("Validation", val_data)

    return train_data, val_data

# Example usage and testing
def test_data_processing():
    """Smoke-test the data pipeline on the real archives, or on dummy data.

    Runs ``setup_datasets`` on DATASET_PATHS for the diagnosis task. If nothing
    could be loaded, it substitutes 20 synthetic samples. It then builds a
    train/validation split, wraps the first 5 training samples in a
    ``MultiModalDataset`` and prints the shapes of one item.

    Returns:
        (train_dataset, val_data) on success, otherwise (None, None).
    """
    print("Testing data processing pipeline...")

    datasets, audio_processor, data_processor = setup_datasets(DATASET_PATHS, 'diagnosis')

    if not datasets:
        print("No datasets found. Creating dummy data for testing...")

        rng = np.random.default_rng(42)
        dummy_data = []
        for i in range(20):
            # Alternating labels keep both classes large enough for a stratified
            # split; class_name is derived from the label so the two always agree.
            label = i % 2
            dummy_data.append({
                'participant_id': f'P{i:03d}',
                'audio_path': f'/dummy/path/audio_{i}.wav',
                'transcript': f'This is a test transcript for participant {i}. ' * int(rng.integers(5, 20)),
                'label': label,
                'class_name': 'AD' if label == 1 else 'CN'
            })

        datasets = {'dummy_train': {'data': dummy_data}}

    train_data, val_data = create_train_val_split(datasets)

    if train_data and val_data:
        # Reuse the tokenizer DataProcessor already loaded rather than fetching it again.
        tokenizer = data_processor.tokenizer

        train_dataset = MultiModalDataset(
            train_data[:5], audio_processor, tokenizer  # Use first 5 samples for testing
        )

        print(f"\nTesting dataset loading...")
        sample = train_dataset[0]

        print(f"Sample keys: {list(sample.keys())}")
        print(f"Input shape: {sample['input_ids'].shape}")
        print(f"Audio features shape: {sample['audio_features'].shape}")
        print(f"Label: {sample['label']}")
        print(f"Participant ID: {sample['participant_id']}")

        print("Data processing test completed successfully!")
        return train_dataset, val_data

    return None, None

# Execute the end-to-end data-processing smoke test; its results are kept as the
# notebook globals `train_dataset` and `val_data` for later cells.
print("Running data processing setup...")
(train_dataset, val_data) = test_data_processing()

Running data processing setup...
Testing data processing pipeline...
SETTING UP DATASETS


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]


Processing diagnosis_train...
Extracting /content/drive/MyDrive/Voice/ADReSSo21-diagnosis-train.tgz to /content/extracted_data/diagnosis_train
