<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Extractor_Jul10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install librosa soundfile opensmile speechbrain transformers torch openai-whisper
!pip install pandas numpy matplotlib seaborn

import os
import json
import pickle
import numpy as np
import pandas as pd
import librosa
import torch
import whisper
import opensmile
from transformers import Wav2Vec2Processor, Wav2Vec2Model, BertTokenizer, BertModel
from typing import Dict, List, Any

class ADReSSoAnalyzer:
    """Feature-extraction pipeline for the ADReSSo21 speech dataset.

    Discovers ``.wav`` recordings under ``base_path``, extracts acoustic
    features (eGeMAPS, MFCC, log-mel, Wav2Vec2, prosodic), transcribes the
    audio with Whisper, and derives simple linguistic statistics plus BERT
    encodings from the transcripts. Outputs are written under ``output_path``.
    """

    # Dataset sub-directories (relative to ``base_path``) keyed by category.
    # Shared by ``get_audio_files`` and ``verify_dataset_structure`` so the two
    # cannot drift apart. Insertion order defines the category order.
    _CATEGORY_DIRS = {
        'diagnosis_ad': 'extracted-diagnosis-train/audio/ad',
        'diagnosis_cn': 'extracted-diagnosis-train/audio/cn',
        'progression_decline': 'extracted-progression-train/audio/decline',
        'progression_no_decline': 'extracted-progression-train/audio/no-decline',
        'progression_test': 'extracted-progression-test/audio',
    }

    def __init__(self, base_path="/content/drive/MyDrive/Speech"):
        """Load every feature extractor / model used by the pipeline.

        Args:
            base_path: Root directory containing the ``extracted-*`` dataset
                folders.

        Construction is slow: it loads Whisper ``base``,
        ``facebook/wav2vec2-base-960h`` and ``bert-base-uncased`` (weights may
        be downloaded on first use).
        """
        self.base_path = base_path
        self.output_path = "/content"
        # Not read by any method in this class; kept for callers that may use them.
        self.features = {}
        self.transcripts = {}

        # Initialize feature extractors
        self.smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.eGeMAPSv02,
            feature_level=opensmile.FeatureLevel.Functionals,
        )

        # Initialize Whisper for transcription
        self.whisper_model = whisper.load_model("base")

        # Initialize Wav2Vec2
        self.wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

        # Initialize BERT
        self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

    @staticmethod
    def _list_wav_files(directory: str) -> List[str]:
        """Return sorted full paths of the ``.wav`` files in *directory* ([] if missing)."""
        if not os.path.exists(directory):
            return []
        # os.listdir order is arbitrary; sorting makes processing order and every
        # derived output (transcripts, tables, pickles) reproducible across runs.
        return [f"{directory}/{name}" for name in sorted(os.listdir(directory)) if name.endswith('.wav')]

    def get_audio_files(self) -> Dict[str, List[str]]:
        """Get all audio files from the dataset.

        Returns:
            Mapping of category name (``diagnosis_ad``, ``diagnosis_cn``,
            ``progression_decline``, ``progression_no_decline``,
            ``progression_test``) to a sorted list of ``.wav`` paths. A category
            whose directory does not exist maps to an empty list.
        """
        return {
            category: self._list_wav_files(f"{self.base_path}/{rel_dir}")
            for category, rel_dir in self._CATEGORY_DIRS.items()
        }

    def extract_acoustic_features(self, audio_path: str) -> Dict[str, Any]:
        """Extract all acoustic feature groups for one audio file.

        The audio is loaded resampled to 16 kHz (the rate Wav2Vec2 expects).
        Each group is computed independently. If one extractor fails, a warning
        is printed and that group is filled with zeros of its usual size, so the
        returned dict keeps the same keys.

        Args:
            audio_path: Path to an audio file.

        Returns:
            Dict with keys ``'egemaps'``, ``'mfccs'``, ``'log_mel'``,
            ``'wav2vec2'`` and ``'prosodic'``. Returns ``None`` if the audio
            could not be loaded at all.
        """
        features = {}
        name = os.path.basename(audio_path)

        try:
            # Load audio - resample to 16kHz for Wav2Vec2 compatibility
            y, sr = librosa.load(audio_path, sr=16000)

            # 1. eGeMAPS functionals via openSMILE. These are computed from the
            #    file itself (native sample rate), not from the resampled signal.
            try:
                features['egemaps'] = self.smile.process_file(audio_path).values.flatten()
            except Exception as e:
                print(f"  Warning: eGeMAPS extraction failed for {name}: {str(e)}")
                features['egemaps'] = np.zeros(88)  # Default eGeMAPS feature size

            # 2. MFCCs summarised over time (mean/std of coefficients, mean of deltas)
            try:
                mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
                features['mfccs'] = {
                    'mean': np.mean(mfccs, axis=1),
                    'std': np.std(mfccs, axis=1),
                    'delta': np.mean(librosa.feature.delta(mfccs), axis=1),
                    'delta2': np.mean(librosa.feature.delta(mfccs, order=2), axis=1)
                }
            except Exception as e:
                print(f"  Warning: MFCC extraction failed for {name}: {str(e)}")
                features['mfccs'] = {
                    'mean': np.zeros(13),
                    'std': np.zeros(13),
                    'delta': np.zeros(13),
                    'delta2': np.zeros(13)
                }

            # 3. Log-mel spectrogram (80 bands) summarised over time
            try:
                mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80)
                log_mel = librosa.power_to_db(mel_spec)
                features['log_mel'] = {
                    'mean': np.mean(log_mel, axis=1),
                    'std': np.std(log_mel, axis=1)
                }
            except Exception as e:
                print(f"  Warning: Log-mel extraction failed for {name}: {str(e)}")
                features['log_mel'] = {
                    'mean': np.zeros(80),
                    'std': np.zeros(80)
                }

            # 4. Wav2Vec2 last hidden state, mean-pooled over time
            try:
                if len(y) == 0:
                    raise ValueError("Empty audio signal")

                input_values = self.wav2vec_processor(
                    y,
                    sampling_rate=16000,  # y was resampled to 16 kHz above
                    return_tensors="pt"
                ).input_values

                with torch.no_grad():
                    wav2vec_features = self.wav2vec_model(input_values).last_hidden_state
                features['wav2vec2'] = torch.mean(wav2vec_features, dim=1).squeeze().numpy()

            except Exception as e:
                print(f"  Warning: Wav2Vec2 extraction failed for {name}: {str(e)}")
                features['wav2vec2'] = np.zeros(768)  # Default Wav2Vec2 feature size

            # 5. Prosodic / spectral summary statistics
            try:
                # NOTE: librosa.yin returns an F0 estimate for every frame and
                # does not flag unvoiced frames. The f0 > 0 filter is therefore
                # only a guard, and these F0 statistics include unvoiced/silent
                # frames. Use librosa.pyin (slower) for voiced-only statistics.
                f0 = librosa.yin(y, fmin=50, fmax=300, sr=sr)
                f0_clean = f0[f0 > 0]
                rms = librosa.feature.rms(y=y)  # computed once, used for mean and std

                features['prosodic'] = {
                    'f0_mean': np.mean(f0_clean) if len(f0_clean) > 0 else 0.0,
                    'f0_std': np.std(f0_clean) if len(f0_clean) > 0 else 0.0,
                    'energy_mean': np.mean(rms),
                    'energy_std': np.std(rms),
                    'zero_crossing_rate': np.mean(librosa.feature.zero_crossing_rate(y)),
                    'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)),
                    'spectral_rolloff': np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr)),
                    'duration': len(y) / sr
                }
            except Exception as e:
                print(f"  Warning: Prosodic feature extraction failed for {name}: {str(e)}")
                features['prosodic'] = {
                    'f0_mean': 0.0, 'f0_std': 0.0, 'energy_mean': 0.0, 'energy_std': 0.0,
                    'zero_crossing_rate': 0.0, 'spectral_centroid': 0.0, 'spectral_rolloff': 0.0,
                    'duration': 0.0
                }

        except Exception as e:
            print(f"Error processing {audio_path}: {str(e)}")
            features = None

        return features

    def show_acoustic_features(self, sample_file: str):
        """Extract and print a short summary of every acoustic feature group for one file."""
        features = self.extract_acoustic_features(sample_file)

        if features is None:
            print(f"Could not extract features from {sample_file}")
            return

        print(f"=== Acoustic Features for {os.path.basename(sample_file)} ===\n")

        # eGeMAPS
        print(f"1. eGeMAPS Features: {len(features['egemaps'])} features")
        print(f"   Shape: {features['egemaps'].shape}")
        print(f"   Sample values: {features['egemaps'][:5]}")
        print()

        # MFCCs
        print("2. MFCC Features:")
        print(f"   Mean: {features['mfccs']['mean'].shape} - {features['mfccs']['mean'][:5]}")
        print(f"   Std: {features['mfccs']['std'].shape} - {features['mfccs']['std'][:5]}")
        print(f"   Delta: {features['mfccs']['delta'].shape} - {features['mfccs']['delta'][:5]}")
        print(f"   Delta-Delta: {features['mfccs']['delta2'].shape} - {features['mfccs']['delta2'][:5]}")
        print()

        # Log-mel
        print("3. Log-Mel Spectrogram Features:")
        print(f"   Mean: {features['log_mel']['mean'].shape} - {features['log_mel']['mean'][:5]}")
        print(f"   Std: {features['log_mel']['std'].shape} - {features['log_mel']['std'][:5]}")
        print()

        # Wav2Vec2
        print(f"4. Wav2Vec2 Features: {features['wav2vec2'].shape}")
        print(f"   Sample values: {features['wav2vec2'][:5]}")
        print()

        # Prosodic
        print("5. Prosodic Features:")
        for key, value in features['prosodic'].items():
            print(f"   {key}: {value:.4f}")
        print()

    def extract_transcripts(self, audio_files: Dict[str, List[str]]) -> Dict[str, Dict[str, Any]]:
        """Transcribe every audio file with Whisper.

        Args:
            audio_files: Mapping of category -> list of audio paths, as returned
                by ``get_audio_files``.

        Returns:
            Mapping keyed by ``"{category}_{filename}"``. Each value holds
            ``file_path``, ``category``, ``filename`` and ``transcript``.
            Successful entries also carry ``language`` and ``segments`` (the
            number of Whisper segments). Failed entries carry ``error`` and an
            empty transcript.
        """
        transcripts = {}

        print("Extracting transcripts...")

        for category, files in audio_files.items():
            print(f"\nProcessing {category}...")
            for file_path in files:
                # Computed outside the try so the error branch can always use it.
                filename = os.path.basename(file_path)
                key = f"{category}_{filename}"
                try:
                    print(f"  Transcribing {filename}...")

                    result = self.whisper_model.transcribe(file_path)
                    transcript_text = result["text"].strip()

                    transcripts[key] = {
                        'file_path': file_path,
                        'category': category,
                        'filename': filename,
                        'transcript': transcript_text,
                        'language': result.get('language', 'en'),
                        'segments': len(result.get('segments', []))
                    }

                except Exception as e:
                    print(f"    Error transcribing {filename}: {str(e)}")
                    transcripts[key] = {
                        'file_path': file_path,
                        'category': category,
                        'filename': filename,
                        'transcript': "",
                        'error': str(e)
                    }

        return transcripts

    def save_transcripts(self, transcripts: Dict[str, Dict[str, Any]]):
        """Write transcripts to ``{output_path}/transcripts/``.

        Produces one ``<key>_transcript.txt`` per entry, plus
        ``all_transcripts.json`` and ``transcripts.pkl`` containing the whole
        mapping.
        """
        os.makedirs(f"{self.output_path}/transcripts", exist_ok=True)

        # Save individual transcript files
        for key, data in transcripts.items():
            filepath = f"{self.output_path}/transcripts/{key}_transcript.txt"
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(data['transcript'])

        # Save consolidated JSON
        with open(f"{self.output_path}/transcripts/all_transcripts.json", 'w', encoding='utf-8') as f:
            json.dump(transcripts, f, indent=2, ensure_ascii=False)

        # Save as pickle for easy loading
        with open(f"{self.output_path}/transcripts/transcripts.pkl", 'wb') as f:
            pickle.dump(transcripts, f)

        print(f"Transcripts saved to {self.output_path}/transcripts/")

    def create_transcript_table(self, transcripts: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
        """Build a one-row-per-file summary DataFrame and save it as ``transcript_summary.csv``.

        ``Language`` and ``Segments`` are ``'N/A'`` for entries that failed
        transcription. ``Transcript_Preview`` is truncated to 100 characters
        plus ``"..."``.
        """
        data = []

        for key, info in transcripts.items():
            transcript = info['transcript']
            preview = (transcript[:100] + "...") if len(transcript) > 100 else transcript
            data.append({
                'File_ID': key,
                'Category': info['category'],
                'Filename': info['filename'],
                'Transcript_Length': len(transcript),
                'Word_Count': len(transcript.split()) if transcript else 0,
                'Language': info.get('language', 'N/A'),
                'Segments': info.get('segments', 'N/A'),
                'Has_Error': 'error' in info,
                'Transcript_Preview': preview
            })

        df = pd.DataFrame(data)

        # Save the table
        df.to_csv(f"{self.output_path}/transcript_summary.csv", index=False)

        return df

    @staticmethod
    def _count_sentences(text: str) -> int:
        """Count non-empty sentences in *text*, treating '.', '!' and '?' as terminators."""
        # Whisper output ends questions/exclamations with '?'/'!', so splitting on
        # '.' alone would undercount sentences.
        normalized = text.translate(str.maketrans('!?', '..'))
        return sum(1 for sentence in normalized.split('.') if sentence.strip())

    def extract_linguistic_features(self, transcripts: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Compute simple text statistics and BERT encodings for every transcript.

        Every entry has the same keys. Empty transcripts get zero counts, empty
        token lists and ``bert_encoding=None``. ``bert_tokens`` is the full,
        untruncated WordPiece tokenization. ``bert_input_ids`` and
        ``bert_attention_mask`` are truncated/padded to 512 positions.
        ``unique_words`` is case-sensitive over whitespace-split tokens.
        The result is also pickled to ``{output_path}/linguistic_features.pkl``.
        """
        linguistic_features = {}

        print("Extracting linguistic features...")

        for key, data in transcripts.items():
            transcript = data['transcript']

            if not transcript:
                # Same schema as the non-empty case so downstream code can index uniformly.
                linguistic_features[key] = {
                    'raw_text': '',
                    'word_count': 0,
                    'sentence_count': 0,
                    'avg_word_length': 0,
                    'unique_words': 0,
                    'lexical_diversity': 0,
                    'bert_tokens': [],
                    'bert_input_ids': [],
                    'bert_attention_mask': [],
                    'bert_encoding': None
                }
                continue

            # Basic linguistic features
            words = transcript.split()
            unique_count = len(set(words))

            # BERT tokenization
            bert_encoding = self.bert_tokenizer(
                transcript,
                truncation=True,
                padding='max_length',
                max_length=512,
                return_tensors='pt'
            )

            linguistic_features[key] = {
                'raw_text': transcript,
                'word_count': len(words),
                'sentence_count': self._count_sentences(transcript),
                'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
                'unique_words': unique_count,
                'lexical_diversity': unique_count / len(words) if words else 0,
                'bert_tokens': self.bert_tokenizer.tokenize(transcript),
                'bert_input_ids': bert_encoding['input_ids'].squeeze().tolist(),
                'bert_attention_mask': bert_encoding['attention_mask'].squeeze().tolist(),
                'bert_encoding': bert_encoding
            }

        # Save linguistic features
        with open(f"{self.output_path}/linguistic_features.pkl", 'wb') as f:
            pickle.dump(linguistic_features, f)

        return linguistic_features

    def verify_dataset_structure(self):
        """Print, for each expected dataset directory, whether it exists and how many ``.wav`` files it holds."""
        print("=== Dataset Structure Verification ===\n")

        for rel_dir in self._CATEGORY_DIRS.values():
            path = f"{self.base_path}/{rel_dir}"
            if os.path.exists(path):
                wav_count = len(self._list_wav_files(path))
                print(f"✓ {path}: {wav_count} .wav files found")
            else:
                print(f"✗ {path}: Directory not found")

        print()

    def run_complete_pipeline(self):
        """Run every pipeline stage in order and return the intermediate results.

        Returns:
            Dict with ``audio_files``, ``transcripts``, ``transcript_df`` and
            ``linguistic_features``. Returns ``None`` if no audio files were
            found.
        """
        print("=== ADReSSo21 Speech Analysis Pipeline ===\n")

        # Step 0: Verify dataset structure
        self.verify_dataset_structure()

        # Step 1: Get audio files
        print("Step 1: Getting audio files...")
        audio_files = self.get_audio_files()

        total_files = sum(len(files) for files in audio_files.values())
        print(f"Found {total_files} audio files across all categories")

        for category, files in audio_files.items():
            print(f"  {category}: {len(files)} files")

        if total_files == 0:
            print("No audio files found. Please check the dataset path.")
            return

        # Step 2: Show acoustic features for sample files
        print("\n" + "="*50)
        print("Step 2: Demonstrating acoustic features...")

        # Show features for the first file of the first non-empty category only,
        # to keep the output short.
        for category, files in audio_files.items():
            if files:
                print(f"\nShowing features for {category}:")
                self.show_acoustic_features(files[0])
                break

        # Step 3: Extract transcripts
        print("\n" + "="*50)
        print("Step 3: Extracting transcripts...")
        transcripts = self.extract_transcripts(audio_files)

        # Step 4: Save transcripts
        print("\n" + "="*50)
        print("Step 4: Saving transcripts...")
        self.save_transcripts(transcripts)

        # Step 5: Create transcript table
        print("\n" + "="*50)
        print("Step 5: Creating transcript table...")
        transcript_df = self.create_transcript_table(transcripts)

        print("Transcript Summary Table:")
        print(transcript_df.to_string(index=False))

        # Step 6: Extract linguistic features for BERT
        print("\n" + "="*50)
        print("Step 6: Extracting linguistic features for BERT...")
        linguistic_features = self.extract_linguistic_features(transcripts)

        print("\nPipeline completed successfully!")
        print(f"Results saved to: {self.output_path}")
        print("\nOutput files:")
        print(f"  - Transcripts: {self.output_path}/transcripts/")
        print(f"  - Transcript summary: {self.output_path}/transcript_summary.csv")
        print(f"  - Linguistic features: {self.output_path}/linguistic_features.pkl")

        return {
            'audio_files': audio_files,
            'transcripts': transcripts,
            'transcript_df': transcript_df,
            'linguistic_features': linguistic_features
        }

# Usage example:
if __name__ == "__main__":
    # Root of the extracted ADReSSo21 dataset (the folder that holds the
    # extracted-diagnosis-train / extracted-progression-* directories).
    dataset_root = "/content/drive/MyDrive/Speech"
    analyzer = ADReSSoAnalyzer(base_path=dataset_root)

    # Full run: structure check, feature demo, transcription, summary table,
    # and BERT-ready linguistic features.
    results = analyzer.run_complete_pipeline()

    # Stages can also be invoked individually, e.g.:
    #   analyzer.verify_dataset_structure()
    #   audio_files = analyzer.get_audio_files()
    #   transcripts = analyzer.extract_transcripts(audio_files)

Collecting opensmile
  Downloading opensmile-2.5.1-py3-none-manylinux_2_17_x86_64.whl.metadata (15 kB)
Collecting speechbrain
  Downloading speechbrain-1.0.3-py3-none-any.whl.metadata (24 kB)
Collecting openai-whisper
  Downloading openai_whisper-20250625.tar.gz (803 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.2/803.2 kB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[?25h