<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May243.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import os
import tarfile
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import librosa.display
from transformers import (BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor,
                         Wav2Vec2ForCTC, Wav2Vec2Processor, WhisperProcessor,
                         WhisperForConditionalGeneration)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import json
import glob
import re
from pathlib import Path
import soundfile as sf
from collections import defaultdict
from scipy import ndimage
import pickle
# Suppress every warning category for the whole process.
# NOTE(review): this also hides warnings emitted by torch / transformers /
# librosa that may point at real problems; consider filtering by category or
# module instead of ignoring everything.
warnings.filterwarnings('ignore')

class SpeechTranscriber:
    """Automatic Speech Recognition for generating transcripts from audio.

    Wraps a Hugging Face Whisper (seq2seq) or Wav2Vec2 (CTC) checkpoint and
    caches one pickled transcript per audio file under ``cache_dir``.
    """

    # Every clip is resampled to this rate before it reaches the model.
    SAMPLE_RATE = 16000
    # Whisper's feature extractor pads/truncates each input to 30 s, so longer
    # recordings are cut into 30 s windows and transcribed window by window.
    WHISPER_CHUNK_SECONDS = 30
    # Number of 30 s windows passed to Whisper's generate() per call.
    WHISPER_BATCH_SIZE = 8
    # Clips shorter than this (0.1 s at 16 kHz) are treated as empty; a trailing
    # Whisper window this short is dropped.
    MIN_AUDIO_SAMPLES = 1600
    # Returned by _transcribe_audio_array when the model call fails; never cached.
    _ERROR_TRANSCRIPT = "[Transcription error]"
    # (compiled regex, group to return), tried in order by _extract_participant_id.
    # EnhancedADReSSDataProcessor._extract_participant_id applies the same rules and
    # the two results are used as join keys, so the rules must stay in sync.
    _ID_PATTERNS = (
        (re.compile(r'adrso?(\d{3})'), 0),     # 'adrso024' -> 'adrso024'
        (re.compile(r'adrsp?(\d{3})'), 0),     # 'adrsp123' -> 'adrsp123'
        (re.compile(r'adrspt?(\d{1,3})'), 0),  # 'adrspt12' -> 'adrspt12'
        (re.compile(r'(\d{3})'), 1),           # 'S001'     -> '001'
        (re.compile(r'([A-Z]\d{2,3})'), 0),    # 'A12'      -> 'A12' (only if no 3-digit run)
        # No separate 'S' + 3-digit rule is needed: r'(\d{3})' above already
        # matches such names (returning only the digits).
    )

    def __init__(self, model_name="openai/whisper-base", cache_dir="./asr_cache"):
        """
        Initialize ASR model.

        Args:
            model_name: Hugging Face checkpoint. Names containing "whisper" are
                loaded as Whisper, anything else as Wav2Vec2 CTC. Options:
                - openai/whisper-base: Good balance of speed/accuracy
                - facebook/wav2vec2-base-960h: Faster but less accurate
                - openai/whisper-small: More accurate but slower
            cache_dir: directory for pickled transcripts (created with any
                missing parents).
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        print(f"Loading ASR model: {model_name}")

        if "whisper" in model_name.lower():
            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
            self.asr_type = "whisper"
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
            self.asr_type = "wav2vec2"

        # Move to GPU if available
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"ASR model loaded on {self.device}")

    def transcribe_audio_file(self, audio_path, use_cache=True):
        """Transcribe a single audio file.

        Args:
            audio_path: path to any format librosa can load.
            use_cache: read/write ``<cache_dir>/<stem>_transcript.pkl``.

        Returns:
            The transcript. Returns "" for clips shorter than MIN_AUDIO_SAMPLES.
            Returns "[Transcription error]" if the model call fails.
            Returns "[Transcription failed for <name>]" if the file cannot be loaded.
            Failures are never written to the cache, so a later call retries.
        """
        audio_path = Path(audio_path)
        cache_file = self.cache_dir / f"{audio_path.stem}_transcript.pkl"

        if use_cache and cache_file.exists():
            cached = self._load_cached_transcript(cache_file, audio_path)
            if cached is not None:
                return cached

        try:
            audio, _ = librosa.load(str(audio_path), sr=self.SAMPLE_RATE)
            if len(audio) < self.MIN_AUDIO_SAMPLES:
                transcript = ""
            else:
                transcript = self._transcribe_audio_array(audio)
        except Exception as e:
            print(f"Error transcribing {audio_path}: {e}")
            return f"[Transcription failed for {audio_path.name}]"

        # Caching the error sentinel would make the failure permanent.
        if use_cache and transcript != self._ERROR_TRANSCRIPT:
            self._save_cached_transcript(cache_file, audio_path, transcript)
        return transcript

    def _load_cached_transcript(self, cache_file, audio_path):
        """Return the cached transcript, or None if the entry is unusable.

        An entry is ignored when it cannot be read or unpickled, or when it was
        produced by a different model. It is also ignored when it belongs to a
        different file with the same stem, or when it holds the error sentinel.
        """
        try:
            with open(cache_file, 'rb') as f:
                # NOTE: pickle is only acceptable because this class writes these
                # files itself; never point cache_dir at untrusted data.
                cached = pickle.load(f)
            same_model = cached.get('model') == self.model_name
            same_file = (Path(cached.get('audio_path', '')).resolve()
                         == Path(audio_path).resolve())
            transcript = cached['transcript']
        except Exception:
            return None  # corrupt / foreign cache entry -> treat as a miss
        if not (same_model and same_file) or transcript == self._ERROR_TRANSCRIPT:
            return None
        return transcript

    def _save_cached_transcript(self, cache_file, audio_path, transcript):
        """Best-effort write of a transcript cache entry (failures only warn)."""
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump({
                    'audio_path': str(audio_path),
                    'transcript': transcript,
                    'model': self.model_name
                }, f)
        except (OSError, pickle.PicklingError) as e:
            print(f"Warning: could not cache transcript for {Path(audio_path).name}: {e}")

    def _transcribe_audio_array(self, audio_array):
        """Transcribe a mono audio array at SAMPLE_RATE with the loaded model.

        Returns "[Transcription error]" (printed, not raised) on failure.
        """
        try:
            if self.asr_type == "whisper":
                return self._whisper_transcribe(audio_array)
            return self._wav2vec2_transcribe(audio_array)
        except Exception as e:
            print(f"Transcription error: {e}")
            return self._ERROR_TRANSCRIPT

    @staticmethod
    def _chunk_audio(audio_array, chunk_samples, min_samples=0):
        """Split audio into consecutive windows of at most ``chunk_samples``.

        A trailing window shorter than ``min_samples`` is dropped unless it is
        the only window. Empty input yields an empty list.
        """
        if chunk_samples <= 0:
            raise ValueError("chunk_samples must be positive")
        segments = [audio_array[start:start + chunk_samples]
                    for start in range(0, len(audio_array), chunk_samples)]
        if len(segments) > 1 and len(segments[-1]) < min_samples:
            segments.pop()
        return segments

    def _whisper_transcribe(self, audio_array):
        """Transcribe using Whisper, 30 s window by 30 s window.

        Whisper only sees the first 30 s of any single input. Splitting avoids
        silently truncating longer recordings. Words straddling a window
        boundary may be cut. Window texts are joined with single spaces.
        """
        chunk_samples = self.WHISPER_CHUNK_SECONDS * self.SAMPLE_RATE
        segments = self._chunk_audio(audio_array, chunk_samples, self.MIN_AUDIO_SAMPLES)

        texts = []
        for start in range(0, len(segments), self.WHISPER_BATCH_SIZE):
            batch = segments[start:start + self.WHISPER_BATCH_SIZE]
            features = self.processor(
                batch,
                sampling_rate=self.SAMPLE_RATE,
                return_tensors="pt"
            ).input_features.to(self.device)

            with torch.no_grad():
                predicted_ids = self.model.generate(features, max_length=448)

            texts.extend(self.processor.batch_decode(predicted_ids, skip_special_tokens=True))

        return " ".join(text.strip() for text in texts if text.strip())

    def _wav2vec2_transcribe(self, audio_array):
        """Transcribe using Wav2Vec2 (greedy CTC decoding), lower-cased."""
        inputs = self.processor(
            audio_array,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcript = self.processor.batch_decode(predicted_ids)[0]

        return transcript.strip().lower()

    def batch_transcribe(self, audio_paths, batch_size=8):
        """Transcribe multiple audio files with a progress bar.

        Args:
            audio_paths: iterable of file paths (a sized sequence; its len() is printed).
            batch_size: accepted for API compatibility but unused; files are
                transcribed one at a time.

        Returns:
            {participant_id: transcript}. When several files map to the same
            participant ID, the last one wins. A warning reports how many
            collided.
        """
        transcripts = {}
        collisions = 0

        print(f"Transcribing {len(audio_paths)} audio files...")

        for audio_path in tqdm(audio_paths, desc="Transcribing"):
            participant_id = self._extract_participant_id(Path(audio_path).name)
            if participant_id in transcripts:
                collisions += 1
            transcripts[participant_id] = self.transcribe_audio_file(audio_path)

        if collisions:
            print(f"Warning: {collisions} audio file(s) shared a participant ID with an "
                  f"earlier file; only the last transcript per ID was kept.")
        return transcripts

    def _extract_participant_id(self, filename):
        """Extract a participant ID from a file name (falls back to the stem).

        Patterns in _ID_PATTERNS are tried in order and the first hit wins.
        """
        for pattern, group in self._ID_PATTERNS:
            match = pattern.search(filename)
            if match:
                return match.group(group)
        return Path(filename).stem

In [12]:
class EnhancedADReSSDataProcessor:
    """Enhanced ADReSS data processor with automatic speech recognition.

    Extracts ADReSS archives and transcribes every audio file with a
    SpeechTranscriber. It infers AD/CN labels from the directory layout and
    pairs each audio path with its transcript.
    """

    def __init__(self, output_dir='./extracted_data', asr_model="openai/whisper-base"):
        """
        Args:
            output_dir: directory that archives are extracted into (created,
                including missing parents, if needed).
            asr_model: Hugging Face checkpoint name passed to SpeechTranscriber.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Not used by the methods visible here; presumably consumed by later
        # text-feature stages (TODO confirm).
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Initialize ASR
        self.transcriber = SpeechTranscriber(model_name=asr_model)

    def extract_adress_dataset(self, tar_path, dataset_name):
        """Extract ADReSS dataset and organize files properly"""
        extract_path = self.output_dir / dataset_name
        extract_path.mkdir(exist_ok=True)

        print(f"Extracting {tar_path} to {extract_path}")

        try:
            with tarfile.open(tar_path, 'r:gz') as tar:
                tar.extractall(path=extract_path)
            print(f"Successfully extracted {dataset_name}")

            # Find the actual dataset directory structure
            self._explore_directory_structure(extract_path)
            return extract_path
        except Exception as e:
            print(f"Error extracting {tar_path}: {e}")
            return None

    def _explore_directory_structure(self, base_path):
        """Explore and print directory structure"""
        print(f"\nDirectory structure for {base_path.name}:")
        for root, dirs, files in os.walk(base_path):
            level = root.replace(str(base_path), '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files[:5]:
                print(f"{subindent}{file}")
            if len(files) > 5:
                print(f"{subindent}... and {len(files) - 5} more files")

    def process_adress_dataset_with_asr(self, extract_path):
        """Process ADReSS dataset with automatic speech recognition"""
        dataset_info = {
            'audio_files': [],
            'transcript_files': [],
            'metadata_files': [],
            'labels': {},
            'paired_data': [],
            'generated_transcripts': {}
        }

        # Look for ADReSS structure
        adress_dirs = list(extract_path.rglob("*ADReSS*"))
        if adress_dirs:
            main_dir = adress_dirs[0]
        else:
            main_dir = extract_path

        print(f"Processing from directory: {main_dir}")

        # Find audio files
        audio_patterns = ['**/*.wav', '**/*.mp3', '**/*.flac']
        for pattern in audio_patterns:
            dataset_info['audio_files'].extend(list(main_dir.glob(pattern)))

        print(f"Found {len(dataset_info['audio_files'])} audio files")

        # Generate transcripts using ASR
        print("Generating transcripts using ASR...")
        audio_paths = [str(path) for path in dataset_info['audio_files']]
        generated_transcripts = self.transcriber.batch_transcribe(audio_paths)
        dataset_info['generated_transcripts'] = generated_transcripts

        print(f"Generated {len(generated_transcripts)} transcripts")

        # Process labels from directory structure
        labels = self._extract_labels_from_structure(dataset_info['audio_files'])
        dataset_info['labels'] = labels

        # Create paired dataset with generated transcripts
        paired_data = self._create_paired_data_with_asr(dataset_info)
        dataset_info['paired_data'] = paired_data

        return dataset_info

    def _extract_labels_from_structure(self, audio_files):
        """Extract labels from file paths or directory structure"""
        labels = {}

        for audio_file in audio_files:
            # Extract participant ID
            participant_id = self._extract_participant_id(audio_file.name)

            # Determine label from path
            path_str = str(audio_file).lower()
            if '/ad/' in path_str or 'dementia' in path_str or 'alzheimer' in path_str:
                label = 1  # AD/Dementia
                class_name = 'AD'
            elif '/cn/' in path_str or 'control' in path_str or 'normal' in path_str:
                label = 0  # Control/Normal
                class_name = 'CN'
            elif 'decline' in path_str:
                label = 1  # Decline/progression
                class_name = 'AD'
            elif 'no_decline' in path_str or 'no-decline' in path_str:
                label = 0  # No decline
                class_name = 'CN'
            else:
                # Default classification based on filename patterns
                if any(marker in audio_file.name.lower() for marker in ['ad', 'dem', 'alz']):
                    label = 1
                    class_name = 'AD'
                else:
                    label = 0  # Default to control
                    class_name = 'CN'

            labels[participant_id] = {
                'label': label,
                'class_name': class_name,
                'audio_path': audio_file
            }

        return labels

    def _extract_participant_id(self, filename):
        """Extract participant ID from filename"""
        patterns = [
            r'adrso?(\d{3})',
            r'adrsp?(\d{3})',
            r'adrspt?(\d{1,3})',
            r'(\d{3})',
            r'([A-Z]\d{2,3})',
            r'(S\d{3})',
        ]

        for pattern in patterns:
            match = re.search(pattern, filename)
            if match:
                return match.group(1) if pattern.startswith(r'(\d') else match.group(0)

        return Path(filename).stem

    def _create_paired_data_with_asr(self, dataset_info):
        """Create paired audio-transcript dataset using ASR-generated transcripts"""
        paired_data = []

        # Create paired dataset using generated transcripts
        for participant_id, label_info in dataset_info['labels'].items():
            # Get generated transcript
            transcript = dataset_info['generated_transcripts'].get(participant_id, "")

            # Clean and validate transcript
            transcript = self._clean_and_validate_transcript(transcript)

            # If transcript is still empty or invalid, create a meaningful placeholder
            if not transcript or len(transcript.strip()) < 10:
                transcript = f"Audio sample from participant {participant_id}. Speech content unclear or silent."

            paired_data.append({
                'participant_id': participant_id,
                'audio_path': str(label_info['audio_path']),
                'transcript': transcript,
                'label': label_info['label'],
                'class_name': label_info['class_name'],
                'transcript_source': 'ASR_generated'
            })

        print(f"Created {len(paired_data)} paired samples with ASR transcripts")

        # Print class distribution
        labels = [item['label'] for item in paired_data]
        unique, counts = np.unique(labels, return_counts=True)
        for cls, count in zip(unique, counts):
            class_name = 'CN' if cls == 0 else 'AD'
            print(f"  {class_name}: {count} samples ({count/len(labels)*100:.1f}%)")

        # Print sample transcripts for verification
        print("\nSample generated transcripts:")
        print("-" * 50)
        for i, sample in enumerate(paired_data[:3]):
            print(f"Participant {sample['participant_id']} ({sample['class_name']}):")
            print(f"Transcript: {sample['transcript'][:100]}...")
            print()

        return paired_data

    def _clean_and_validate_transcript(self, transcript):
        """Clean and validate ASR-generated transcript"""
        if not transcript:
            return ""

        # Remove common ASR artifacts
        transcript = transcript.strip()
        transcript = re.sub(r'\[.*?\]', '', transcript)  # Remove [NOISE], [MUSIC], etc.
        transcript = re.sub(r'<.*?>', '', transcript)    # Remove <unk>, <pad>, etc.
        transcript = re.sub(r'\s+', ' ', transcript)     # Normalize whitespace

        # Remove very short or repetitive transcripts
        if len(transcript) < 5:
            return ""

        # Check for repetitive patterns (common ASR error)
        words = transcript.split()
        if len(words) > 1:
            # If more than 70% of words are the same, likely an error
            unique_words = set(words)
            if len(unique_words) / len(words) < 0.3:
                return ""

        return transcript