<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/AddResso_Predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%pip install -q librosa soundfile opensmile==2.5.1 speechbrain==1.0.3 transformers torch openai-whisper==20250625


Collecting opensmile
  Downloading opensmile-2.5.1-py3-none-manylinux_2_17_x86_64.whl.metadata (15 kB)
Collecting speechbrain
  Downloading speechbrain-1.0.3-py3-none-any.whl.metadata (24 kB)
Collecting openai-whisper
  Downloading openai_whisper-20250625.tar.gz (803 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.2/803.2 kB[0m [31m43.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting audobject>=0.6.1 (from opensmile)
  Downloading audobject-0.7.12-py3-none-any.whl.metadata (2.7 kB)
Collecting audinterface>=0.7.0 (from opensmile)
  Downloading audinterface-1.3.1-py3-none-any.whl.metadata (4.3 kB)
Collecting hyperpyyaml (from speechbrain)
  Downloading HyperPyYAML-1.2.2-py3-none-any.whl.metadata (7.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-

In [3]:
import os
import json
import pickle
import re
import warnings
from typing import Dict, List, Any, Optional

import numpy as np
import pandas as pd
import librosa
import torch
import whisper
import opensmile
from transformers import Wav2Vec2Processor, Wav2Vec2Model, BertTokenizer, BertModel
from tqdm import tqdm
warnings.filterwarnings('ignore')

class ADReSSoAnalyzer:
    """Feature-extraction pipeline for the ADReSSo21 speech dataset.

    For every ``.wav`` file under ``base_path`` the pipeline extracts acoustic
    features (eGeMAPS, MFCC, log-mel, Wav2Vec2, prosody), a Whisper transcript,
    and simple linguistic statistics plus BERT token ids. The two slow stages
    (acoustic features and transcription) are checkpointed with pickle so an
    interrupted run can resume where it stopped.

    Args:
        base_path: Root of the extracted ADReSSo21 data; must contain the
            ``diagnosis/`` and ``progression/`` trees read by ``get_audio_files``.
        output_path: Directory receiving ``features/``, ``transcripts/`` and
            ``processing_summary.csv``.
        checkpoint_path: Directory for resumable checkpoints; defaults to
            ``<output_path>/checkpoints``. On Colab, ``/content`` is the VM's
            local disk (Drive is mounted under ``/content/drive``), so point this
            at Drive if checkpoints must survive a runtime reset.
        whisper_language: Language code passed to Whisper's ``transcribe``.
            ``None`` lets Whisper auto-detect per file. The default ``"en"``
            matches the English-only text/speech models loaded here and avoids
            Whisper misidentifying the language of short or noisy clips.
    """

    def __init__(self, base_path="/content/drive/MyDrive/Voice/extracted/ADReSSo21",
                 output_path: str = "/content",
                 checkpoint_path: Optional[str] = None,
                 whisper_language: Optional[str] = "en"):
        self.base_path = base_path
        self.output_path = output_path
        self.checkpoint_path = checkpoint_path or os.path.join(output_path, "checkpoints")
        self.whisper_language = whisper_language

        # Initialize models
        self._initialize_models()

        # Create output directories
        os.makedirs(os.path.join(self.output_path, "transcripts"), exist_ok=True)
        os.makedirs(os.path.join(self.output_path, "features"), exist_ok=True)
        os.makedirs(self.checkpoint_path, exist_ok=True)

    def _initialize_models(self):
        """Load openSMILE, Whisper, Wav2Vec2 and BERT models onto ``self``."""
        print("Initializing models...")

        # OpenSMILE for eGeMAPS functionals (one feature vector per file)
        self.smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.eGeMAPSv02,
            feature_level=opensmile.FeatureLevel.Functionals,
        )

        # Whisper for transcription
        self.whisper_model = whisper.load_model("base")

        # Wav2Vec2 for deep acoustic features
        self.wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

        # BERT for linguistic features.
        # NOTE(review): only bert_tokenizer is used inside this class; bert_model is
        # presumably consumed by later cells -- confirm before removing (it is ~440 MB).
        self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

        print("Models initialized successfully!")

    def get_audio_files(self) -> Dict[str, List[str]]:
        """Collect the ``.wav`` files of every dataset split.

        Returns:
            Mapping from category name (e.g. ``'diagnosis_ad'``) to a sorted list
            of file paths joined onto ``base_path``. A category whose directory
            does not exist maps to an empty list and a warning is printed, so a
            wrong ``base_path`` is easy to diagnose.
        """
        paths = {
            'diagnosis_ad': os.path.join(self.base_path, 'diagnosis', 'train', 'audio', 'ad'),
            'diagnosis_cn': os.path.join(self.base_path, 'diagnosis', 'train', 'audio', 'cn'),
            'progression_decline': os.path.join(self.base_path, 'progression', 'train', 'audio', 'decline'),
            'progression_no_decline': os.path.join(self.base_path, 'progression', 'train', 'audio', 'no_decline'),
            'progression_test': os.path.join(self.base_path, 'progression', 'test-dist', 'audio'),
        }

        audio_files: Dict[str, List[str]] = {}
        for category, directory in paths.items():
            if not os.path.isdir(directory):
                print(f"Warning: directory not found for {category}: {directory}")
                audio_files[category] = []
                continue
            # Sort for a deterministic processing order (os.listdir order is
            # arbitrary) and match the extension case-insensitively.
            audio_files[category] = [
                os.path.join(directory, name)
                for name in sorted(os.listdir(directory))
                if name.lower().endswith('.wav')
            ]

        return audio_files

    def extract_acoustic_features(self, audio_path: str) -> Optional[Dict[str, Any]]:
        """Extract acoustic features from one audio file.

        The audio is loaded resampled to 16 kHz. Each feature family is computed
        independently; if one fails, a warning is printed and that family is
        zero-filled so the returned dict always has the same keys.

        Args:
            audio_path: Path to a ``.wav`` file.

        Returns:
            Dict with keys ``'egemaps'`` (88 values, zero-filled on failure),
            ``'mfccs'`` (13-dim mean/std/delta/delta2), ``'log_mel'`` (80-dim
            mean/std), ``'wav2vec2'`` (768-dim mean-pooled last hidden state) and
            ``'prosodic'`` (scalar statistics); or ``None`` if the audio is empty
            or an unexpected error occurs.
        """
        filename = os.path.basename(audio_path)
        try:
            y, sr = librosa.load(audio_path, sr=16000)

            if len(y) == 0:
                print(f"Warning: Empty audio file {filename}")
                return None

            features = {}

            # 1. eGeMAPS functionals (openSMILE reads the file itself)
            try:
                features['egemaps'] = self.smile.process_file(audio_path).values.flatten()
            except Exception as e:
                print(f"Warning: eGeMAPS failed for {filename}: {str(e)}")
                features['egemaps'] = np.zeros(88)

            # 2. MFCCs summarised over time
            try:
                mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
                features['mfccs'] = {
                    'mean': np.mean(mfccs, axis=1),
                    'std': np.std(mfccs, axis=1),
                    'delta': np.mean(librosa.feature.delta(mfccs), axis=1),
                    'delta2': np.mean(librosa.feature.delta(mfccs, order=2), axis=1)
                }
            except Exception as e:
                print(f"Warning: MFCC failed for {filename}: {str(e)}")
                features['mfccs'] = {
                    'mean': np.zeros(13), 'std': np.zeros(13),
                    'delta': np.zeros(13), 'delta2': np.zeros(13)
                }

            # 3. Log-mel spectrogram summarised over time
            try:
                mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80)
                log_mel = librosa.power_to_db(mel_spec)
                features['log_mel'] = {
                    'mean': np.mean(log_mel, axis=1),
                    'std': np.std(log_mel, axis=1)
                }
            except Exception as e:
                print(f"Warning: Log-mel failed for {filename}: {str(e)}")
                features['log_mel'] = {'mean': np.zeros(80), 'std': np.zeros(80)}

            # 4. Wav2Vec2 embedding: mean of last_hidden_state over the time axis
            try:
                input_values = self.wav2vec_processor(
                    y, sampling_rate=16000, return_tensors="pt"
                ).input_values

                with torch.no_grad():
                    hidden_states = self.wav2vec_model(input_values).last_hidden_state
                features['wav2vec2'] = torch.mean(hidden_states, dim=1).squeeze(0).numpy()

            except Exception as e:
                print(f"Warning: Wav2Vec2 failed for {filename}: {str(e)}")
                features['wav2vec2'] = np.zeros(768)

            # 5. Prosodic features
            try:
                # librosa.yin gives an estimate for every frame, so filtering on
                # f0 > 0 never removed unvoiced frames. pyin reports unvoiced
                # frames as NaN, which are dropped before computing statistics.
                f0, _, _ = librosa.pyin(y, fmin=50, fmax=300, sr=sr)
                f0_voiced = f0[np.isfinite(f0)]
                rms = librosa.feature.rms(y=y)

                features['prosodic'] = {
                    'f0_mean': float(np.mean(f0_voiced)) if f0_voiced.size > 0 else 0.0,
                    'f0_std': float(np.std(f0_voiced)) if f0_voiced.size > 0 else 0.0,
                    'energy_mean': np.mean(rms),
                    'energy_std': np.std(rms),
                    'zero_crossing_rate': np.mean(librosa.feature.zero_crossing_rate(y)),
                    'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)),
                    'spectral_rolloff': np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr)),
                    'duration': len(y) / sr
                }
            except Exception as e:
                print(f"Warning: Prosodic features failed for {filename}: {str(e)}")
                features['prosodic'] = {
                    'f0_mean': 0.0, 'f0_std': 0.0, 'energy_mean': 0.0, 'energy_std': 0.0,
                    'zero_crossing_rate': 0.0, 'spectral_centroid': 0.0, 'spectral_rolloff': 0.0,
                    'duration': 0.0
                }

            return features

        except Exception as e:
            print(f"Error processing {audio_path}: {str(e)}")
            return None

    @staticmethod
    def _load_checkpoint(checkpoint_file: str, label: str) -> Dict[str, Dict[str, Any]]:
        """Load a checkpoint dict written by ``_save_checkpoint``.

        Returns ``{}`` when no checkpoint exists. An unreadable (e.g. truncated)
        checkpoint is moved aside to ``<file>.corrupt`` and ``{}`` is returned,
        so the stage restarts instead of crashing.

        Security: ``pickle.load`` can execute arbitrary code; only load
        checkpoints this pipeline wrote itself.
        """
        if not os.path.exists(checkpoint_file):
            return {}
        print(f"Loading existing {label} checkpoint...")
        try:
            with open(checkpoint_file, 'rb') as f:
                data = pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as e:
            corrupt_file = f"{checkpoint_file}.corrupt"
            os.replace(checkpoint_file, corrupt_file)
            print(f"Warning: unreadable checkpoint moved to {corrupt_file} ({e}); starting {label} from scratch")
            return {}
        print(f"Loaded {len(data)} existing {label}")
        return data

    @staticmethod
    def _save_checkpoint(data: Any, checkpoint_file: str) -> None:
        """Pickle ``data`` atomically: write a temp file, then rename it over the target.

        A crash or runtime disconnect mid-write leaves the previous checkpoint
        intact instead of a truncated, unloadable pickle.
        """
        tmp_file = f"{checkpoint_file}.tmp"
        with open(tmp_file, 'wb') as f:
            pickle.dump(data, f)
        os.replace(tmp_file, checkpoint_file)

    @staticmethod
    def _list_files(audio_files: Dict[str, List[str]]) -> List[tuple]:
        """Flatten ``audio_files`` into ``(key, category, file_path, filename)`` tuples.

        ``key`` is ``f"{category}_{filename}"``, the identifier used by every
        checkpoint and output dict in this class.
        """
        return [
            (f"{category}_{os.path.basename(path)}", category, path, os.path.basename(path))
            for category, files in audio_files.items()
            for path in files
        ]

    def extract_acoustic_features_with_checkpoint(self, audio_files: Dict[str, List[str]],
                                                  save_every: int = 10) -> Dict[str, Dict[str, Any]]:
        """Extract acoustic features for all files, resuming from a checkpoint.

        Files whose key is already in the checkpoint are skipped. Files whose
        extraction returns ``None`` are not recorded and are retried on the
        next run.

        Args:
            audio_files: Output of ``get_audio_files``.
            save_every: Write the checkpoint after this many newly processed files.

        Returns:
            Mapping ``key -> {'file_path', 'category', 'filename', 'features'}``,
            including any entries loaded from the checkpoint.
        """
        checkpoint_file = os.path.join(self.checkpoint_path, "acoustic_features.pkl")
        acoustic_features = self._load_checkpoint(checkpoint_file, "acoustic features")

        all_files = self._list_files(audio_files)
        pending = [item for item in all_files if item[0] not in acoustic_features]
        done = len(all_files) - len(pending)
        print(f"Extracting acoustic features: {done}/{len(all_files)} completed")

        save_every = max(1, save_every)
        unsaved = 0
        # Iterate only pending files; `initial` accounts for finished ones exactly once.
        for key, category, file_path, filename in tqdm(
            pending, total=len(all_files), initial=done, desc="Extracting acoustic features"
        ):
            features = self.extract_acoustic_features(file_path)
            if features is None:
                print(f"Failed to extract features for {filename}")
                continue

            acoustic_features[key] = {
                'file_path': file_path,
                'category': category,
                'filename': filename,
                'features': features
            }
            unsaved += 1
            if unsaved >= save_every:
                self._save_checkpoint(acoustic_features, checkpoint_file)
                unsaved = 0

        # Final save
        self._save_checkpoint(acoustic_features, checkpoint_file)

        print(f"Acoustic features extraction completed: {len(acoustic_features)} files processed")
        return acoustic_features

    def extract_transcripts_with_checkpoint(self, audio_files: Dict[str, List[str]],
                                            save_every: int = 5) -> Dict[str, Dict[str, Any]]:
        """Transcribe all files with Whisper, resuming from a checkpoint.

        Entries that errored on an earlier run (they carry an ``'error'`` key)
        are retried rather than being treated as finished.

        Args:
            audio_files: Output of ``get_audio_files``.
            save_every: Write the checkpoint after this many newly processed
                files (transcription is slower than feature extraction).

        Returns:
            Mapping ``key -> {'file_path', 'category', 'filename', 'transcript',
            'language', 'segments'}``; failed files instead carry an empty
            ``'transcript'`` and an ``'error'`` message.
        """
        checkpoint_file = os.path.join(self.checkpoint_path, "transcripts.pkl")
        transcripts = self._load_checkpoint(checkpoint_file, "transcripts")

        all_files = self._list_files(audio_files)
        pending = [
            item for item in all_files
            if item[0] not in transcripts or 'error' in transcripts[item[0]]
        ]
        done = len(all_files) - len(pending)
        print(f"Extracting transcripts: {done}/{len(all_files)} completed")

        save_every = max(1, save_every)
        unsaved = 0
        for key, category, file_path, filename in tqdm(
            pending, total=len(all_files), initial=done, desc="Extracting transcripts"
        ):
            try:
                result = self.whisper_model.transcribe(file_path, language=self.whisper_language)
                transcripts[key] = {
                    'file_path': file_path,
                    'category': category,
                    'filename': filename,
                    'transcript': result["text"].strip(),
                    'language': result.get('language', 'en'),
                    'segments': len(result.get('segments', []))
                }
            except Exception as e:
                print(f"Error transcribing {filename}: {str(e)}")
                transcripts[key] = {
                    'file_path': file_path,
                    'category': category,
                    'filename': filename,
                    'transcript': "",
                    'error': str(e)
                }

            unsaved += 1
            if unsaved >= save_every:
                self._save_checkpoint(transcripts, checkpoint_file)
                unsaved = 0

        # Final save
        self._save_checkpoint(transcripts, checkpoint_file)

        print(f"Transcript extraction completed: {len(transcripts)} files processed")
        return transcripts

    def extract_linguistic_features(self, transcripts: Dict[str, Dict[str, Any]],
                                    max_length: int = 512) -> Dict[str, Dict[str, Any]]:
        """Compute simple lexical statistics and BERT token ids for each transcript.

        Word statistics use lower-cased word tokens with punctuation removed, so
        ``"Cookie,"`` and ``"cookie"`` count as the same word type. Every entry,
        including empty transcripts, gets ``max_length``-long
        ``bert_input_ids`` / ``bert_attention_mask`` lists so they can be stacked.
        The result is also pickled to ``<output_path>/features/linguistic_features.pkl``.

        Args:
            transcripts: Output of ``extract_transcripts_with_checkpoint``.
            max_length: BERT sequence length (truncate/pad target).

        Returns:
            Mapping ``key -> feature dict``.
        """
        linguistic_features = {}

        print("Extracting linguistic features...")

        for key, data in tqdm(transcripts.items(), desc="Processing transcripts"):
            transcript = data.get('transcript', '') or ''

            words = re.findall(r"\w+(?:'\w+)*", transcript.lower())
            unique_words = set(words)
            # Split on '.', '?' and '!'; splitting on '.' alone merged questions
            # and exclamations into the following sentence.
            sentences = [s for s in re.split(r'[.!?]+', transcript) if s.strip()]

            # Tokenize even empty text so all entries share the same id/mask length.
            bert_encoding = self.bert_tokenizer(
                transcript,
                truncation=True,
                padding='max_length',
                max_length=max_length
            )

            linguistic_features[key] = {
                'raw_text': transcript,
                'word_count': len(words),
                'sentence_count': len(sentences),
                'avg_word_length': float(np.mean([len(word) for word in words])) if words else 0.0,
                'unique_words': len(unique_words),
                'lexical_diversity': len(unique_words) / len(words) if words else 0.0,
                'bert_input_ids': list(bert_encoding['input_ids']),
                'bert_attention_mask': list(bert_encoding['attention_mask'])
            }

        # Save linguistic features
        features_dir = os.path.join(self.output_path, "features")
        os.makedirs(features_dir, exist_ok=True)
        with open(os.path.join(features_dir, "linguistic_features.pkl"), 'wb') as f:
            pickle.dump(linguistic_features, f)

        return linguistic_features

    def save_results(self, acoustic_features: Dict, transcripts: Dict, linguistic_features: Dict):
        """Write features, transcripts and a per-file summary CSV under ``output_path``.

        The summary has one row per file seen by either the acoustic or the
        transcription stage, so files whose acoustic extraction failed still
        appear (with ``Has_Acoustic_Features == False``).
        """
        print("Saving results...")

        features_dir = os.path.join(self.output_path, "features")
        transcripts_dir = os.path.join(self.output_path, "transcripts")
        os.makedirs(features_dir, exist_ok=True)
        os.makedirs(transcripts_dir, exist_ok=True)

        # Save acoustic features
        with open(os.path.join(features_dir, "acoustic_features.pkl"), 'wb') as f:
            pickle.dump(acoustic_features, f)

        # Save transcripts
        with open(os.path.join(transcripts_dir, "transcripts.pkl"), 'wb') as f:
            pickle.dump(transcripts, f)

        # Save transcripts as JSON
        with open(os.path.join(transcripts_dir, "transcripts.json"), 'w', encoding='utf-8') as f:
            json.dump(transcripts, f, indent=2, ensure_ascii=False)

        # Create summary DataFrame over the union of both stages' keys
        columns = [
            'File_ID', 'Category', 'Filename', 'Has_Acoustic_Features', 'Has_Transcript',
            'Word_Count', 'Transcript_Length', 'Language', 'Has_Error'
        ]
        summary_data = []
        for key in sorted(set(acoustic_features) | set(transcripts)):
            acoustic_data = acoustic_features.get(key, {})
            transcript_data = transcripts.get(key, {})
            linguistic_data = linguistic_features.get(key, {})
            source = acoustic_data or transcript_data

            summary_data.append({
                'File_ID': key,
                'Category': source.get('category', 'N/A'),
                'Filename': source.get('filename', 'N/A'),
                'Has_Acoustic_Features': 'features' in acoustic_data,
                'Has_Transcript': bool(transcript_data.get('transcript', '')),
                'Word_Count': linguistic_data.get('word_count', 0),
                'Transcript_Length': len(transcript_data.get('transcript', '')),
                'Language': transcript_data.get('language', 'N/A'),
                'Has_Error': 'error' in transcript_data
            })

        summary_df = pd.DataFrame(summary_data, columns=columns)
        summary_df.to_csv(os.path.join(self.output_path, "processing_summary.csv"), index=False)

        print(f"Results saved to {self.output_path}")

    def run_complete_pipeline(self) -> Optional[Dict[str, Any]]:
        """Run file discovery, acoustic extraction, transcription, linguistic
        features and saving, in that order.

        Returns:
            Dict with ``'audio_files'``, ``'acoustic_features'``, ``'transcripts'``
            and ``'linguistic_features'``; ``None`` if no audio files were found.
        """
        print("=== ADReSSo21 Speech Analysis Pipeline ===\n")

        # Step 1: Get audio files
        print("Step 1: Getting audio files...")
        audio_files = self.get_audio_files()

        total_files = sum(len(files) for files in audio_files.values())
        print(f"Found {total_files} audio files")

        for category, files in audio_files.items():
            print(f"  {category}: {len(files)} files")

        if total_files == 0:
            print("No audio files found. Please check the dataset path.")
            return None

        # Step 2: Extract acoustic features with checkpoint
        print("\nStep 2: Extracting acoustic features...")
        acoustic_features = self.extract_acoustic_features_with_checkpoint(audio_files)

        # Step 3: Extract transcripts with checkpoint
        print("\nStep 3: Extracting transcripts...")
        transcripts = self.extract_transcripts_with_checkpoint(audio_files)

        # Step 4: Extract linguistic features
        print("\nStep 4: Extracting linguistic features...")
        linguistic_features = self.extract_linguistic_features(transcripts)

        # Step 5: Save all results
        print("\nStep 5: Saving results...")
        self.save_results(acoustic_features, transcripts, linguistic_features)

        print("\n=== Pipeline completed successfully! ===")
        print(f"Processed {len(acoustic_features)} files")
        print(f"Results saved to: {self.output_path}")
        print(f"Checkpoints saved to: {self.checkpoint_path}")

        return {
            'audio_files': audio_files,
            'acoustic_features': acoustic_features,
            'transcripts': transcripts,
            'linguistic_features': linguistic_features
        }

# Usage example. Constructing the analyzer loads every model up front; the
# pipeline then runs all stages and returns the collected results (None when
# no audio files are found). `analyzer` and `results` stay available to later cells.
if __name__ == "__main__":
    analyzer = ADReSSoAnalyzer()
    results = analyzer.run_complete_pipeline()

Initializing models...


100%|████████████████████████████████████████| 139M/139M [00:00<00:00, 157MiB/s]


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Models initialized successfully!
=== ADReSSo21 Speech Analysis Pipeline ===

Step 1: Getting audio files...
Found 0 audio files
  diagnosis_ad: 0 files
  diagnosis_cn: 0 files
  progression_decline: 0 files
  progression_no_decline: 0 files
  progression_test: 0 files
No audio files found. Please check the dataset path.
