<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Extractor_Jul10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Mount Google Drive at /content/drive; the analyzer's default dataset path
# (/content/drive/MyDrive/Speech) lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!pip install librosa soundfile opensmile speechbrain transformers torch openai-whisper
!pip install pandas numpy matplotlib seaborn

import os
import json
import pickle
import numpy as np
import pandas as pd
import librosa
import torch
import whisper
import opensmile
from transformers import Wav2Vec2Processor, Wav2Vec2Model, BertTokenizer, BertModel
from typing import Dict, List, Any

class ADReSSoAnalyzer:
    def __init__(self, base_path="/content/drive/MyDrive/Speech"):
        """Set up paths, empty caches and every model used by the pipeline.

        Args:
            base_path: Root directory that contains the extracted ADReSSo21
                archives. Outputs are always written under ``/content``.
        """
        wav2vec_checkpoint = "facebook/wav2vec2-base-960h"
        bert_checkpoint = 'bert-base-uncased'

        # Locations and in-memory stores
        self.base_path, self.output_path = base_path, "/content"
        self.features, self.transcripts = {}, {}

        # openSMILE extractor configured for eGeMAPSv02 functionals
        self.smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.eGeMAPSv02,
            feature_level=opensmile.FeatureLevel.Functionals,
        )

        # Speech-to-text model
        self.whisper_model = whisper.load_model("base")

        # Wav2Vec2 processor + encoder
        self.wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_checkpoint)
        self.wav2vec_model = Wav2Vec2Model.from_pretrained(wav2vec_checkpoint)

        # BERT tokenizer + encoder
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_checkpoint)
        self.bert_model = BertModel.from_pretrained(bert_checkpoint)

    def get_audio_files(self) -> Dict[str, List[str]]:
        """Get all audio files from the dataset"""
        audio_files = {
            'diagnosis_ad': [],
            'diagnosis_cn': [],
            'progression_decline': [],
            'progression_no_decline': [],
            'progression_test': []
        }

        # Diagnosis files - extracted_diagnosis_train
        diag_ad_path = f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio/ad"
        diag_cn_path = f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio/cn"

        if os.path.exists(diag_ad_path):
            audio_files['diagnosis_ad'] = [f"{diag_ad_path}/{f}" for f in os.listdir(diag_ad_path) if f.endswith('.wav')]
        if os.path.exists(diag_cn_path):
            audio_files['diagnosis_cn'] = [f"{diag_cn_path}/{f}" for f in os.listdir(diag_cn_path) if f.endswith('.wav')]

        # Progression files - extracted_progression_train
        prog_decline_path = f"{self.base_path}/extracted_progression_train/ADReSSo21/progression/train/audio/decline"
        prog_no_decline_path = f"{self.base_path}/extracted_progression_train/ADReSSo21/progression/train/audio/no_decline"

        if os.path.exists(prog_decline_path):
            audio_files['progression_decline'] = [f"{prog_decline_path}/{f}" for f in os.listdir(prog_decline_path) if f.endswith('.wav')]
        if os.path.exists(prog_no_decline_path):
            audio_files['progression_no_decline'] = [f"{prog_no_decline_path}/{f}" for f in os.listdir(prog_no_decline_path) if f.endswith('.wav')]

        # Progression test files - extracted_progression_test
        prog_test_path = f"{self.base_path}/extracted_progression_test/ADReSSo21/progression/test-dist/audio"

        if os.path.exists(prog_test_path):
            audio_files['progression_test'] = [f"{prog_test_path}/{f}" for f in os.listdir(prog_test_path) if f.endswith('.wav')]

        return audio_files

    def extract_acoustic_features(self, audio_path: str) -> Dict[str, Any]:
        """Extract all acoustic features from audio file"""
        features = {}

        try:
            # Load audio - resample to 16kHz for Wav2Vec2 compatibility
            y, sr = librosa.load(audio_path, sr=16000)  # Force 16kHz sampling rate

            # 1. eGeMAPS features using openSMILE
            try:
                features['egemaps'] = self.smile.process_file(audio_path).values.flatten()
            except Exception as e:
                print(f"  Warning: eGeMAPS extraction failed for {os.path.basename(audio_path)}: {str(e)}")
                features['egemaps'] = np.zeros(88)  # Default eGeMAPS feature size

            # 2. MFCC features
            try:
                mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
                features['mfccs'] = {
                    'mean': np.mean(mfccs, axis=1),
                    'std': np.std(mfccs, axis=1),
                    'delta': np.mean(librosa.feature.delta(mfccs), axis=1),
                    'delta2': np.mean(librosa.feature.delta(mfccs, order=2), axis=1)
                }
            except Exception as e:
                print(f"  Warning: MFCC extraction failed for {os.path.basename(audio_path)}: {str(e)}")
                features['mfccs'] = {
                    'mean': np.zeros(13),
                    'std': np.zeros(13),
                    'delta': np.zeros(13),
                    'delta2': np.zeros(13)
                }

            # 3. Log-mel spectrogram
            try:
                mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80)
                log_mel = librosa.power_to_db(mel_spec)
                features['log_mel'] = {
                    'mean': np.mean(log_mel, axis=1),
                    'std': np.std(log_mel, axis=1)
                }
            except Exception as e:
                print(f"  Warning: Log-mel extraction failed for {os.path.basename(audio_path)}: {str(e)}")
                features['log_mel'] = {
                    'mean': np.zeros(80),
                    'std': np.zeros(80)
                }

            # 4. Wav2Vec2 features - with proper sampling rate handling
            try:
                # Ensure sampling rate is exactly 16000 Hz for Wav2Vec2
                if len(y) == 0:
                    raise ValueError("Empty audio signal")

                input_values = self.wav2vec_processor(
                    y,
                    sampling_rate=16000,  # Explicitly set to 16000
                    return_tensors="pt"
                ).input_values

                with torch.no_grad():
                    wav2vec_features = self.wav2vec_model(input_values).last_hidden_state
                features['wav2vec2'] = torch.mean(wav2vec_features, dim=1).squeeze().numpy()

            except Exception as e:
                print(f"  Warning: Wav2Vec2 extraction failed for {os.path.basename(audio_path)}: {str(e)}")
                features['wav2vec2'] = np.zeros(768)  # Default Wav2Vec2 feature size

            # 5. Additional prosodic features
            try:
                # Handle potential issues with F0 extraction
                f0 = librosa.yin(y, fmin=50, fmax=300, sr=sr)
                f0_clean = f0[f0 > 0]  # Remove unvoiced frames

                features['prosodic'] = {
                    'f0_mean': np.mean(f0_clean) if len(f0_clean) > 0 else 0.0,
                    'f0_std': np.std(f0_clean) if len(f0_clean) > 0 else 0.0,
                    'energy_mean': np.mean(librosa.feature.rms(y=y)),
                    'energy_std': np.std(librosa.feature.rms(y=y)),
                    'zero_crossing_rate': np.mean(librosa.feature.zero_crossing_rate(y)),
                    'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)),
                    'spectral_rolloff': np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr)),
                    'duration': len(y) / sr
                }
            except Exception as e:
                print(f"  Warning: Prosodic feature extraction failed for {os.path.basename(audio_path)}: {str(e)}")
                features['prosodic'] = {
                    'f0_mean': 0.0, 'f0_std': 0.0, 'energy_mean': 0.0, 'energy_std': 0.0,
                    'zero_crossing_rate': 0.0, 'spectral_centroid': 0.0, 'spectral_rolloff': 0.0,
                    'duration': 0.0
                }

        except Exception as e:
            print(f"Error processing {audio_path}: {str(e)}")
            features = None

        return features

    def show_acoustic_features(self, sample_file: str):
        """Display acoustic features for a sample file"""
        features = self.extract_acoustic_features(sample_file)

        if features is None:
            print(f"Could not extract features from {sample_file}")
            return

        print(f"=== Acoustic Features for {os.path.basename(sample_file)} ===\n")

        # eGeMAPS
        print(f"1. eGeMAPS Features: {len(features['egemaps'])} features")
        print(f"   Shape: {features['egemaps'].shape}")
        print(f"   Sample values: {features['egemaps'][:5]}")
        print()

        # MFCCs
        print("2. MFCC Features:")
        print(f"   Mean: {features['mfccs']['mean'].shape} - {features['mfccs']['mean'][:5]}")
        print(f"   Std: {features['mfccs']['std'].shape} - {features['mfccs']['std'][:5]}")
        print(f"   Delta: {features['mfccs']['delta'].shape} - {features['mfccs']['delta'][:5]}")
        print(f"   Delta-Delta: {features['mfccs']['delta2'].shape} - {features['mfccs']['delta2'][:5]}")
        print()

        # Log-mel
        print("3. Log-Mel Spectrogram Features:")
        print(f"   Mean: {features['log_mel']['mean'].shape} - {features['log_mel']['mean'][:5]}")
        print(f"   Std: {features['log_mel']['std'].shape} - {features['log_mel']['std'][:5]}")
        print()

        # Wav2Vec2
        print(f"4. Wav2Vec2 Features: {features['wav2vec2'].shape}")
        print(f"   Sample values: {features['wav2vec2'][:5]}")
        print()

        # Prosodic
        print("5. Prosodic Features:")
        for key, value in features['prosodic'].items():
            print(f"   {key}: {value:.4f}")
        print()

    def extract_transcripts(self, audio_files: Dict[str, List[str]]) -> Dict[str, str]:
        """Extract transcripts using Whisper"""
        transcripts = {}

        print("Extracting transcripts...")

        for category, files in audio_files.items():
            print(f"\nProcessing {category}...")
            for file_path in files:
                try:
                    filename = os.path.basename(file_path)
                    print(f"  Transcribing {filename}...")

                    result = self.whisper_model.transcribe(file_path)
                    transcript_text = result["text"].strip()

                    transcripts[f"{category}_{filename}"] = {
                        'file_path': file_path,
                        'category': category,
                        'filename': filename,
                        'transcript': transcript_text,
                        'language': result.get('language', 'en'),
                        'segments': len(result.get('segments', []))
                    }

                except Exception as e:
                    print(f"    Error transcribing {filename}: {str(e)}")
                    transcripts[f"{category}_{filename}"] = {
                        'file_path': file_path,
                        'category': category,
                        'filename': filename,
                        'transcript': "",
                        'error': str(e)
                    }

        return transcripts

    def save_transcripts(self, transcripts: Dict[str, str]):
        """Save transcripts to files"""
        os.makedirs(f"{self.output_path}/transcripts", exist_ok=True)

        # Save individual transcript files
        for key, data in transcripts.items():
            filename = f"{key}_transcript.txt"
            filepath = f"{self.output_path}/transcripts/{filename}"

            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(data['transcript'])

        # Save consolidated JSON
        with open(f"{self.output_path}/transcripts/all_transcripts.json", 'w', encoding='utf-8') as f:
            json.dump(transcripts, f, indent=2, ensure_ascii=False)

        # Save as pickle for easy loading
        with open(f"{self.output_path}/transcripts/transcripts.pkl", 'wb') as f:
            pickle.dump(transcripts, f)

        print(f"Transcripts saved to {self.output_path}/transcripts/")

    def create_transcript_table(self, transcripts: Dict[str, str]) -> pd.DataFrame:
        """Create a DataFrame with transcript information"""
        data = []

        for key, info in transcripts.items():
            data.append({
                'File_ID': key,
                'Category': info['category'],
                'Filename': info['filename'],
                'Transcript_Length': len(info['transcript']),
                'Word_Count': len(info['transcript'].split()) if info['transcript'] else 0,
                'Language': info.get('language', 'N/A'),
                'Segments': info.get('segments', 'N/A'),
                'Has_Error': 'error' in info,
                'Transcript_Preview': info['transcript'][:100] + "..." if len(info['transcript']) > 100 else info['transcript']
            })

        df = pd.DataFrame(data)

        # Save the table
        df.to_csv(f"{self.output_path}/transcript_summary.csv", index=False)

        return df

    def extract_linguistic_features(self, transcripts: Dict[str, str]) -> Dict[str, Any]:
        """Extract linguistic features for BERT preparation"""
        linguistic_features = {}

        print("Extracting linguistic features...")

        for key, data in transcripts.items():
            transcript = data['transcript']

            if not transcript:
                linguistic_features[key] = {
                    'raw_text': '',
                    'word_count': 0,
                    'sentence_count': 0,
                    'avg_word_length': 0,
                    'bert_tokens': [],
                    'bert_input_ids': [],
                    'bert_attention_mask': []
                }
                continue

            # Basic linguistic features
            words = transcript.split()
            sentences = transcript.split('.')

            # BERT tokenization
            bert_encoding = self.bert_tokenizer(
                transcript,
                truncation=True,
                padding='max_length',
                max_length=512,
                return_tensors='pt'
            )

            linguistic_features[key] = {
                'raw_text': transcript,
                'word_count': len(words),
                'sentence_count': len([s for s in sentences if s.strip()]),
                'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
                'unique_words': len(set(words)),
                'lexical_diversity': len(set(words)) / len(words) if words else 0,
                'bert_tokens': self.bert_tokenizer.tokenize(transcript),
                'bert_input_ids': bert_encoding['input_ids'].squeeze().tolist(),
                'bert_attention_mask': bert_encoding['attention_mask'].squeeze().tolist(),
                'bert_encoding': bert_encoding
            }

        # Save linguistic features
        with open(f"{self.output_path}/linguistic_features.pkl", 'wb') as f:
            pickle.dump(linguistic_features, f)

    def get_segmentation_files(self) -> Dict[str, List[str]]:
        """Get all segmentation CSV files from the dataset"""
        segmentation_files = {
            'diagnosis_ad': [],
            'diagnosis_cn': [],
            'progression_test': []
        }

        # Diagnosis segmentation files
        diag_ad_seg_path = f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/ad"
        diag_cn_seg_path = f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/cn"

        if os.path.exists(diag_ad_seg_path):
            segmentation_files['diagnosis_ad'] = [f"{diag_ad_seg_path}/{f}" for f in os.listdir(diag_ad_seg_path) if f.endswith('.csv')]
        if os.path.exists(diag_cn_seg_path):
            segmentation_files['diagnosis_cn'] = [f"{diag_cn_seg_path}/{f}" for f in os.listdir(diag_cn_seg_path) if f.endswith('.csv')]

        # Progression test segmentation files
        prog_test_seg_path = f"{self.base_path}/extracted_progression_test/ADReSSo21/progression/test-dist/segmentation"

        if os.path.exists(prog_test_seg_path):
            segmentation_files['progression_test'] = [f"{prog_test_seg_path}/{f}" for f in os.listdir(prog_test_seg_path) if f.endswith('.csv')]

        return segmentation_files

    def verify_dataset_structure(self):
        """Verify that the dataset structure matches expectations"""
        print("=== Dataset Structure Verification ===\n")

        expected_paths = [
            f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio/ad",
            f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio/cn",
            f"{self.base_path}/extracted_progression_train/ADReSSo21/progression/train/audio/decline",
            f"{self.base_path}/extracted_progression_train/ADReSSo21/progression/train/audio/no_decline",
            f"{self.base_path}/extracted_progression_test/ADReSSo21/progression/test-dist/audio"
        ]

        # Also check for segmentation directories
        segmentation_paths = [
            f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/ad",
            f"{self.base_path}/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/cn",
            f"{self.base_path}/extracted_progression_test/ADReSSo21/progression/test-dist/segmentation"
        ]

        print("Audio directories:")
        for path in expected_paths:
            if os.path.exists(path):
                wav_files = [f for f in os.listdir(path) if f.endswith('.wav')]
                print(f"✓ {path}: {len(wav_files)} .wav files found")
            else:
                print(f"✗ {path}: Directory not found")

        print("\nSegmentation directories:")
        for path in segmentation_paths:
            if os.path.exists(path):
                csv_files = [f for f in os.listdir(path) if f.endswith('.csv')]
                print(f"✓ {path}: {len(csv_files)} .csv files found")
            else:
                print(f"✗ {path}: Directory not found")

        print()

    def run_complete_pipeline(self):
        """Run the complete analysis pipeline"""
        print("=== ADReSSo21 Speech Analysis Pipeline ===\n")

        # Step 0: Verify dataset structure
        self.verify_dataset_structure()

        # Step 1: Get audio files
        print("Step 1: Getting audio files...")
        audio_files = self.get_audio_files()

        # Also get segmentation files
        print("Step 1b: Getting segmentation files...")
        segmentation_files = self.get_segmentation_files()

        total_files = sum(len(files) for files in audio_files.values())
        total_seg_files = sum(len(files) for files in segmentation_files.values())

        print(f"Found {total_files} audio files across all categories")
        print(f"Found {total_seg_files} segmentation files across all categories")

        for category, files in audio_files.items():
            seg_count = len(segmentation_files.get(category, []))
            print(f"  {category}: {len(files)} audio files, {seg_count} segmentation files")

        if total_files == 0:
            print("No audio files found. Please check the dataset path.")
            return

        # Step 2: Show acoustic features for sample files
        print("\n" + "="*50)
        print("Step 2: Demonstrating acoustic features...")

        # Show features for one file from each category that has files
        for category, files in audio_files.items():
            if files:
                print(f"\nShowing features for {category}:")
                self.show_acoustic_features(files[0])
                break  # Just show one example to avoid too much output

        # Step 3: Extract transcripts
        print("\n" + "="*50)
        print("Step 3: Extracting transcripts...")
        transcripts = self.extract_transcripts(audio_files)

        # Step 4: Save transcripts
        print("\n" + "="*50)
        print("Step 4: Saving transcripts...")
        self.save_transcripts(transcripts)

        # Step 5: Create transcript table
        print("\n" + "="*50)
        print("Step 5: Creating transcript table...")
        transcript_df = self.create_transcript_table(transcripts)

        print("Transcript Summary Table:")
        print(transcript_df.to_string(index=False))

        # Step 6: Extract linguistic features for BERT
        print("\n" + "="*50)
        print("Step 6: Extracting linguistic features for BERT...")
        linguistic_features = self.extract_linguistic_features(transcripts)

        print("\nPipeline completed successfully!")
        print(f"Results saved to: {self.output_path}")
        print("\nOutput files:")
        print(f"  - Transcripts: {self.output_path}/transcripts/")
        print(f"  - Transcript summary: {self.output_path}/transcript_summary.csv")
        print(f"  - Linguistic features: {self.output_path}/linguistic_features.pkl")

        return {
            'audio_files': audio_files,
            'segmentation_files': segmentation_files,
            'transcripts': transcripts,
            'transcript_df': transcript_df,
            'linguistic_features': linguistic_features
        }

# Usage example:
# Constructing ADReSSoAnalyzer loads the openSMILE, Whisper, Wav2Vec2 and BERT
# components in its constructor, then the full pipeline is executed.
if __name__ == "__main__":
    # Initialize the analyzer with your dataset path
    analyzer = ADReSSoAnalyzer(base_path="/content/drive/MyDrive/Speech")

    # Run the complete pipeline
    # (returns None when no audio files are found, otherwise a dict of results)
    results = analyzer.run_complete_pipeline()

    # You can also run individual steps if needed:
    # analyzer.verify_dataset_structure()
    # audio_files = analyzer.get_audio_files()
    # segmentation_files = analyzer.get_segmentation_files()
    # transcripts = analyzer.extract_transcripts(audio_files)



Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


=== ADReSSo21 Speech Analysis Pipeline ===

=== Dataset Structure Verification ===

Audio directories:
✓ /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio/ad: 87 .wav files found
✓ /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio/cn: 79 .wav files found
✓ /content/drive/MyDrive/Speech/extracted_progression_train/ADReSSo21/progression/train/audio/decline: 15 .wav files found
✓ /content/drive/MyDrive/Speech/extracted_progression_train/ADReSSo21/progression/train/audio/no_decline: 58 .wav files found
✓ /content/drive/MyDrive/Speech/extracted_progression_test/ADReSSo21/progression/test-dist/audio: 32 .wav files found

Segmentation directories:
✓ /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/ad: 87 .csv files found
✓ /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/cn: 79 .csv files found
✓ /content/drive/MyDrive/Speech/e

In [2]:
import os
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.stats import pearsonr, spearmanr
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import warnings
warnings.filterwarnings('ignore')

class FeatureCorrelationAnalyzer:
    def __init__(self, base_path="/content", analyzer=None):
        self.base_path = base_path
        self.analyzer = analyzer
        self.combined_features = {}
        self.correlation_results = {}

    def load_existing_features(self):
        """Load previously extracted features"""
        try:
            # Load linguistic features
            with open(f"{self.base_path}/linguistic_features.pkl", 'rb') as f:
                self.linguistic_features = pickle.load(f)

            # Load transcripts for additional processing
            with open(f"{self.base_path}/transcripts/transcripts.pkl", 'rb') as f:
                self.transcripts = pickle.load(f)

            print("✓ Existing features loaded successfully")
            return True

        except FileNotFoundError:
            print("✗ Feature files not found. Please run the main pipeline first.")
            return False

    def extract_combined_features(self, audio_files):
        """Extract both acoustic and linguistic features for correlation analysis"""
        if not self.load_existing_features():
            return None

        print("Extracting combined acoustic and linguistic features...")

        combined_data = {}

        for category, files in audio_files.items():
            print(f"\nProcessing {category}...")

            for file_path in files:
                filename = os.path.basename(file_path)
                key = f"{category}_{filename}"

                print(f"  Processing {filename}...")

                # Extract acoustic features
                acoustic_features = self.analyzer.extract_acoustic_features(file_path)

                if acoustic_features is None:
                    print(f"    Skipping {filename} due to acoustic feature extraction failure")
                    continue

                # Get linguistic features
                linguistic_data = self.linguistic_features.get(key, {})

                if not linguistic_data.get('raw_text'):
                    print(f"    Skipping {filename} due to missing transcript")
                    continue

                # Combine features
                combined_data[key] = {
                    'filename': filename,
                    'category': category,
                    'file_path': file_path,
                    'acoustic': acoustic_features,
                    'linguistic': linguistic_data
                }

        self.combined_features = combined_data
        print(f"\nCombined features extracted for {len(combined_data)} files")
        return combined_data

    def create_feature_matrix(self):
        """Create matrices for correlation analysis"""
        if not self.combined_features:
            print("No combined features available. Run extract_combined_features first.")
            return None, None

        # Lists to store flattened features
        acoustic_matrix = []
        linguistic_matrix = []
        file_info = []

        for key, data in self.combined_features.items():
            acoustic_feat = data['acoustic']
            linguistic_feat = data['linguistic']

            # Flatten acoustic features
            acoustic_vector = []

            # eGeMAPS features
            if 'egemaps' in acoustic_feat:
                acoustic_vector.extend(acoustic_feat['egemaps'].flatten())

            # MFCC features
            if 'mfccs' in acoustic_feat:
                mfccs = acoustic_feat['mfccs']
                acoustic_vector.extend(mfccs['mean'])
                acoustic_vector.extend(mfccs['std'])
                acoustic_vector.extend(mfccs['delta'])
                acoustic_vector.extend(mfccs['delta2'])

            # Log-mel features
            if 'log_mel' in acoustic_feat:
                log_mel = acoustic_feat['log_mel']
                acoustic_vector.extend(log_mel['mean'])
                acoustic_vector.extend(log_mel['std'])

            # Prosodic features
            if 'prosodic' in acoustic_feat:
                prosodic = acoustic_feat['prosodic']
                acoustic_vector.extend([
                    prosodic['f0_mean'], prosodic['f0_std'],
                    prosodic['energy_mean'], prosodic['energy_std'],
                    prosodic['zero_crossing_rate'],
                    prosodic['spectral_centroid'], prosodic['spectral_rolloff'],
                    prosodic['duration']
                ])

            # Wav2Vec2 features (sample first 50 dimensions to avoid too many features)
            if 'wav2vec2' in acoustic_feat:
                wav2vec_feat = acoustic_feat['wav2vec2']
                if len(wav2vec_feat.shape) > 0:
                    acoustic_vector.extend(wav2vec_feat[:50])

            # Linguistic features
            linguistic_vector = [
                linguistic_feat['word_count'],
                linguistic_feat['sentence_count'],
                linguistic_feat['avg_word_length'],
                linguistic_feat['unique_words'],
                linguistic_feat['lexical_diversity']
            ]

            acoustic_matrix.append(acoustic_vector)
            linguistic_matrix.append(linguistic_vector)
            file_info.append({
                'key': key,
                'filename': data['filename'],
                'category': data['category']
            })

        # Convert to numpy arrays
        acoustic_matrix = np.array(acoustic_matrix)
        linguistic_matrix = np.array(linguistic_matrix)

        # Handle any NaN or infinite values
        acoustic_matrix = np.nan_to_num(acoustic_matrix, nan=0.0, posinf=0.0, neginf=0.0)
        linguistic_matrix = np.nan_to_num(linguistic_matrix, nan=0.0, posinf=0.0, neginf=0.0)

        print(f"Feature matrices created:")
        print(f"  Acoustic: {acoustic_matrix.shape}")
        print(f"  Linguistic: {linguistic_matrix.shape}")

        return acoustic_matrix, linguistic_matrix, file_info

    def calculate_correlations(self, acoustic_matrix, linguistic_matrix):
        """Calculate correlations between acoustic and linguistic features"""

        linguistic_feature_names = [
            'Word Count', 'Sentence Count', 'Avg Word Length',
            'Unique Words', 'Lexical Diversity'
        ]

        # Create comprehensive acoustic feature names
        acoustic_feature_names = []

        # eGeMAPS (88 features)
        acoustic_feature_names.extend([f'eGeMAPS_{i}' for i in range(88)])

        # MFCC (13 x 4 = 52 features)
        for stat in ['mean', 'std', 'delta', 'delta2']:
            acoustic_feature_names.extend([f'MFCC_{stat}_{i}' for i in range(13)])

        # Log-mel (80 x 2 = 160 features)
        for stat in ['mean', 'std']:
            acoustic_feature_names.extend([f'LogMel_{stat}_{i}' for i in range(80)])

        # Prosodic (8 features)
        prosodic_names = ['F0_Mean', 'F0_Std', 'Energy_Mean', 'Energy_Std',
                         'ZCR', 'SpectralCentroid', 'SpectralRolloff', 'Duration']
        acoustic_feature_names.extend(prosodic_names)

        # Wav2Vec2 (50 features)
        acoustic_feature_names.extend([f'Wav2Vec2_{i}' for i in range(50)])

        # Adjust feature names to match actual matrix size
        acoustic_feature_names = acoustic_feature_names[:acoustic_matrix.shape[1]]

        # Calculate correlation matrix
        print("Calculating feature correlations...")

        correlations = {}

        for i, ling_name in enumerate(linguistic_feature_names):
            correlations[ling_name] = {}

            for j, acoustic_name in enumerate(acoustic_feature_names):
                # Calculate Pearson correlation
                try:
                    pearson_r, pearson_p = pearsonr(linguistic_matrix[:, i], acoustic_matrix[:, j])
                    spearman_r, spearman_p = spearmanr(linguistic_matrix[:, i], acoustic_matrix[:, j])

                    correlations[ling_name][acoustic_name] = {
                        'pearson_r': pearson_r,
                        'pearson_p': pearson_p,
                        'spearman_r': spearman_r,
                        'spearman_p': spearman_p
                    }
                except:
                    correlations[ling_name][acoustic_name] = {
                        'pearson_r': 0.0,
                        'pearson_p': 1.0,
                        'spearman_r': 0.0,
                        'spearman_p': 1.0
                    }

        return correlations, acoustic_feature_names, linguistic_feature_names

    def plot_correlation_heatmap(self, correlations, acoustic_names, linguistic_names,
                                correlation_type='pearson_r', top_n=50):
        """Plot a heatmap of the strongest linguistic-acoustic correlations.

        For every linguistic feature, the acoustic features with the largest
        absolute ``correlation_type`` value are selected.  The union of those
        selections (in first-seen order) forms the heatmap columns and every
        linguistic feature forms one row.  The figure is written to
        ``{self.base_path}/correlation_heatmap_{correlation_type}.png`` and
        then shown.

        Args:
            correlations: Nested mapping
                ``{linguistic_name: {acoustic_name: {correlation_type: value}}}``.
            acoustic_names: Not used by the selection; kept so existing
                callers keep working.
            linguistic_names: Linguistic feature names, one heatmap row each.
            correlation_type: Key of the statistic to plot.
            top_n: Approximate total number of acoustic columns.  Each
                linguistic feature contributes ``top_n // len(linguistic_names)``
                candidates, but never fewer than one.

        Returns:
            None.  Nothing is drawn when there are no linguistic features or
            no acoustic features to select from.
        """
        if not linguistic_names:
            # The per-feature quota below would otherwise divide by zero.
            print("No linguistic features available; skipping correlation heatmap")
            return

        def _strength(item):
            # NaN keys make the sort order undefined, so rank them as zero.
            magnitude = abs(item[1][correlation_type])
            return 0.0 if np.isnan(magnitude) else magnitude

        # Integer division can give 0 when top_n < number of linguistic
        # features; keep at least one column per feature so the plot is
        # never empty for that reason.
        per_feature = max(1, top_n // len(linguistic_names))

        selected_acoustic_names = []
        already_selected = set()  # O(1) membership; the list keeps the order
        for ling_name in linguistic_names:
            ranked = sorted(correlations[ling_name].items(),
                            key=_strength, reverse=True)
            for acoustic_name, _ in ranked[:per_feature]:
                if acoustic_name not in already_selected:
                    already_selected.add(acoustic_name)
                    selected_acoustic_names.append(acoustic_name)

        if not selected_acoustic_names:
            print("No acoustic features available; skipping correlation heatmap")
            return

        # Rows follow linguistic_names, columns follow the selection order.
        # A pair that was never computed is shown as zero correlation.
        correlation_matrix = np.array([
            [
                correlations[ling_name][acoustic_name][correlation_type]
                if acoustic_name in correlations[ling_name] else 0.0
                for acoustic_name in selected_acoustic_names
            ]
            for ling_name in linguistic_names
        ])

        # Create heatmap
        plt.figure(figsize=(20, 8))
        sns.heatmap(correlation_matrix,
                   xticklabels=selected_acoustic_names,
                   yticklabels=linguistic_names,
                   annot=False,
                   cmap='RdBu_r',
                   center=0,
                   vmin=-1, vmax=1)

        plt.title(f'Linguistic-Acoustic Feature Correlations ({correlation_type.replace("_", " ").title()})')
        plt.xlabel('Acoustic Features')
        plt.ylabel('Linguistic Features')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(f'{self.base_path}/correlation_heatmap_{correlation_type}.png', dpi=300, bbox_inches='tight')
        plt.show()

    def find_strongest_correlations(self, correlations, threshold=0.3, *,
                                    alpha=0.05, max_display=20):
        """Collect, rank and print the strong, significant Pearson correlations.

        A feature pair is kept when ``|pearson_r| >= threshold`` and
        ``pearson_p < alpha``.  NaN coefficients never satisfy the first
        comparison and are therefore dropped.

        Args:
            correlations: Nested mapping
                ``{linguistic_name: {acoustic_name: {'pearson_r': r, 'pearson_p': p, ...}}}``.
            threshold: Minimum absolute Pearson coefficient.
            alpha: Significance level the p-value must stay below.
            max_display: How many of the top pairs to print; ``None`` prints
                all of them.  The returned list is never truncated.

        Returns:
            List of dicts with keys ``linguistic_feature``,
            ``acoustic_feature``, ``correlation``, ``p_value`` and
            ``abs_correlation``, sorted by ``abs_correlation`` descending
            (ties keep their input order).
        """
        strong_correlations = []

        for ling_name, acoustic_correlations in correlations.items():
            for acoustic_name, corr_data in acoustic_correlations.items():
                pearson_r = corr_data['pearson_r']
                pearson_p = corr_data['pearson_p']
                strength = abs(pearson_r)

                if strength < threshold or pearson_p >= alpha:
                    continue
                # NaN slips past the guard above (comparisons with NaN are
                # False), so test the positive condition before accepting.
                if not (strength >= threshold and pearson_p < alpha):
                    continue

                strong_correlations.append({
                    'linguistic_feature': ling_name,
                    'acoustic_feature': acoustic_name,
                    'correlation': pearson_r,
                    'p_value': pearson_p,
                    'abs_correlation': strength
                })

        # Strongest first; list.sort is stable so ties keep insertion order.
        strong_correlations.sort(key=lambda x: x['abs_correlation'], reverse=True)

        print(f"\n=== Strongest Correlations (|r| >= {threshold}, p < {alpha}) ===")
        print(f"Found {len(strong_correlations)} significant correlations\n")

        for rank, corr in enumerate(strong_correlations[:max_display], start=1):
            print(f"{rank:2d}. {corr['linguistic_feature']} ↔ {corr['acoustic_feature']}")
            print(f"    Correlation: {corr['correlation']:6.3f}, p-value: {corr['p_value']:.4f}")
            print()

        return strong_correlations

    def plot_feature_relationships(self, acoustic_matrix, linguistic_matrix,
                                 acoustic_names, linguistic_names, file_info):
        """Scatter-plot six predefined linguistic/acoustic feature pairs.

        Each pair gets one panel of a 2x3 grid.  Points are coloured by the
        ``'category'`` entry of the matching ``file_info`` item, a
        least-squares trend line is drawn over the finite values, and the
        Pearson coefficient with its p-value is annotated.  The figure is
        written to ``{self.base_path}/feature_relationships.png`` and shown.

        Args:
            acoustic_matrix: 2-D array indexed ``[sample, acoustic_feature]``.
            linguistic_matrix: 2-D array indexed ``[sample, linguistic_feature]``.
            acoustic_names: Column names of ``acoustic_matrix``; a pair uses
                the first name that contains the requested acoustic feature.
            linguistic_names: Column names of ``linguistic_matrix`` (list).
            file_info: One dict per sample, each with a ``'category'`` key.

        Returns:
            None.  A pair whose linguistic or acoustic feature cannot be
            found is reported and its panel is left empty.
        """
        # Feature pairs to plot, one per panel
        feature_pairs = [
            ('Word Count', 'Duration'),
            ('Lexical Diversity', 'F0_Std'),
            ('Avg Word Length', 'SpectralCentroid'),
            ('Sentence Count', 'Energy_Mean'),
            ('Unique Words', 'SpectralRolloff'),
            ('Word Count', 'F0_Mean')
        ]

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()

        categories = [info['category'] for info in file_info]
        category_array = np.array(categories)  # built once, reused for every mask
        # Sorted so each category gets the same colour on every run; plain
        # set iteration order is not stable across interpreter sessions.
        unique_categories = sorted(set(categories))
        colors = plt.cm.Set3(np.linspace(0, 1, len(unique_categories)))

        # zip() stops at the shorter of panels/pairs, so no index guard is needed.
        for ax, (ling_feat, acoustic_feat) in zip(axes, feature_pairs):
            try:
                ling_idx = linguistic_names.index(ling_feat)
            except ValueError:
                # Treat a missing linguistic column like a missing acoustic one.
                print(f"Skipping '{ling_feat}' vs '{acoustic_feat}': linguistic feature not found")
                continue

            # First acoustic column whose name contains the requested feature
            acoustic_idx = next(
                (j for j, name in enumerate(acoustic_names) if acoustic_feat in name),
                None)
            if acoustic_idx is None:
                print(f"Skipping '{ling_feat}' vs '{acoustic_feat}': acoustic feature not found")
                continue

            # Plot scatter for each category
            for cat_idx, category in enumerate(unique_categories):
                cat_mask = category_array == category
                ax.scatter(linguistic_matrix[cat_mask, ling_idx],
                          acoustic_matrix[cat_mask, acoustic_idx],
                          c=[colors[cat_idx]], label=category.replace('_', ' ').title(),
                          alpha=0.7)

            x_vals = linguistic_matrix[:, ling_idx]
            y_vals = acoustic_matrix[:, acoustic_idx]

            # Remove any invalid values before fitting
            valid_mask = ~(np.isnan(x_vals) | np.isnan(y_vals) | np.isinf(x_vals) | np.isinf(y_vals))
            x_vals = x_vals[valid_mask]
            y_vals = y_vals[valid_mask]

            if len(x_vals) > 1:
                trend = np.poly1d(np.polyfit(x_vals, y_vals, 1))
                # Draw the line between the extremes only: plotting it through
                # the unsorted samples makes the dashed stroke overdraw itself.
                x_ends = np.array([x_vals.min(), x_vals.max()])
                ax.plot(x_ends, trend(x_ends), "r--", alpha=0.8)

                # Calculate correlation
                corr, p_val = pearsonr(x_vals, y_vals)
                ax.text(0.05, 0.95, f'r = {corr:.3f}\np = {p_val:.3f}',
                       transform=ax.transAxes, verticalalignment='top',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

            ax.set_xlabel(ling_feat)
            ax.set_ylabel(acoustic_feat)
            ax.set_title(f'{ling_feat} vs {acoustic_feat}')
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{self.base_path}/feature_relationships.png', dpi=300, bbox_inches='tight')
        plt.show()

    def perform_pca_analysis(self, acoustic_matrix, linguistic_matrix, file_info,
                             n_components=10):
        """Run PCA on the standardized, column-wise concatenated feature matrices.

        Prints the explained variance, draws a PC1/PC2 scatter coloured by the
        ``'category'`` entry of ``file_info`` next to a bar chart of the
        explained variance ratios, and writes the figure to
        ``{self.base_path}/pca_analysis.png`` before showing it.

        Args:
            acoustic_matrix: 2-D array indexed ``[sample, acoustic_feature]``.
            linguistic_matrix: 2-D array with the same number of rows.
            file_info: One dict per sample, each with a ``'category'`` key.
            n_components: Upper bound on the number of principal components.
                It is lowered to ``min(n_samples, n_features)`` when the data
                cannot support that many.

        Returns:
            Tuple ``(pca_result, pca)``: the projected samples and the fitted
            ``PCA`` object.

        Raises:
            ValueError: If fewer than two components can be computed, since
                the scatter plot needs both PC1 and PC2.
        """
        # Combine acoustic and linguistic features
        combined_matrix = np.hstack([acoustic_matrix, linguistic_matrix])

        # PCA cannot extract more components than min(n_samples, n_features),
        # so clamp the request instead of failing on small datasets.
        n_samples, n_features = combined_matrix.shape
        n_components = min(n_components, n_samples, n_features)
        if n_components < 2:
            raise ValueError(
                "PCA needs at least 2 samples and 2 features, got "
                f"{n_samples} samples and {n_features} features"
            )

        # Standardize features
        combined_scaled = StandardScaler().fit_transform(combined_matrix)

        # Perform PCA
        pca = PCA(n_components=n_components)
        pca_result = pca.fit_transform(combined_scaled)
        variance_ratio = pca.explained_variance_ratio_

        print("PCA Analysis:")
        print(f"Explained variance ratio: {variance_ratio}")
        print(f"Cumulative variance: {np.cumsum(variance_ratio)}")

        # Plot PCA results
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        # PCA scatter plot
        categories = [info['category'] for info in file_info]
        category_array = np.array(categories)  # built once, reused for every mask
        # Sorted so each category gets the same colour on every run.
        unique_categories = sorted(set(categories))
        colors = plt.cm.Set3(np.linspace(0, 1, len(unique_categories)))

        for color, category in zip(colors, unique_categories):
            mask = category_array == category
            ax1.scatter(pca_result[mask, 0], pca_result[mask, 1],
                       c=[color], label=category.replace('_', ' ').title(),
                       alpha=0.7)

        ax1.set_xlabel(f'PC1 ({variance_ratio[0]:.2%} variance)')
        ax1.set_ylabel(f'PC2 ({variance_ratio[1]:.2%} variance)')
        ax1.set_title('PCA: Combined Acoustic-Linguistic Features')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Explained variance plot; the x positions follow the actual number
        # of components rather than a fixed 1..10.
        ax2.bar(range(1, len(variance_ratio) + 1), variance_ratio)
        ax2.set_xlabel('Principal Component')
        ax2.set_ylabel('Explained Variance Ratio')
        ax2.set_title('PCA Explained Variance')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'{self.base_path}/pca_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()

        return pca_result, pca

    def run_correlation_analysis(self, audio_files):
        """Drive the full linguistic-acoustic correlation workflow.

        Steps, in order: extract combined features for ``audio_files``, build
        the feature matrices, compute correlations, list the strongest ones,
        draw the heatmap and the feature-relationship plots, run PCA, and
        pickle everything to
        ``{self.base_path}/correlation_analysis_results.pkl``.

        Args:
            audio_files: Passed unchanged to ``extract_combined_features``.

        Returns:
            Dict with keys ``correlations``, ``strong_correlations``,
            ``acoustic_names``, ``linguistic_names``, ``acoustic_matrix``,
            ``linguistic_matrix``, ``file_info`` and ``pca_result``; or
            ``None`` when feature extraction returns a falsy value or the
            acoustic matrix comes back as ``None``.
        """
        print("=== Linguistic-Acoustic Feature Correlation Analysis ===\n")

        # Stage 1: feature extraction (a falsy result aborts the run)
        if not self.extract_combined_features(audio_files):
            print("Failed to extract combined features")
            return None

        # Stage 2: matrices built from the features extracted above
        acoustic_matrix, linguistic_matrix, file_info = self.create_feature_matrix()
        if acoustic_matrix is None:
            print("Failed to create feature matrices")
            return None

        # Stage 3: statistics
        correlations, acoustic_names, linguistic_names = self.calculate_correlations(
            acoustic_matrix, linguistic_matrix
        )
        strong_correlations = self.find_strongest_correlations(correlations)

        # Stage 4: figures
        print("\nCreating correlation heatmap...")
        self.plot_correlation_heatmap(correlations, acoustic_names, linguistic_names)

        print("\nCreating feature relationship plots...")
        self.plot_feature_relationships(
            acoustic_matrix, linguistic_matrix, acoustic_names, linguistic_names, file_info
        )

        print("\nPerforming PCA analysis...")
        # Only the projected samples are stored; the fitted model is dropped.
        pca_result, _fitted_pca = self.perform_pca_analysis(
            acoustic_matrix, linguistic_matrix, file_info
        )

        # Stage 5: persist everything in one pickle
        results = dict(
            correlations=correlations,
            strong_correlations=strong_correlations,
            acoustic_names=acoustic_names,
            linguistic_names=linguistic_names,
            acoustic_matrix=acoustic_matrix,
            linguistic_matrix=linguistic_matrix,
            file_info=file_info,
            pca_result=pca_result,
        )

        results_file = f'{self.base_path}/correlation_analysis_results.pkl'
        with open(results_file, 'wb') as handle:
            pickle.dump(results, handle)

        print(f"\nAnalysis complete! Results saved to {self.base_path}/")
        summary_lines = (
            "Generated files:",
            "  - correlation_heatmap_pearson_r.png",
            "  - feature_relationships.png",
            "  - pca_analysis.png",
            "  - correlation_analysis_results.pkl",
        )
        for line in summary_lines:
            print(line)

        return results

# Usage example for integration with your existing code:
def run_correlation_analysis_pipeline(main_results=None):
    """Run the correlation analysis on top of the main analyzer.

    Args:
        main_results: Optional dict from an earlier main-pipeline run; it
            must contain an ``'audio_files'`` entry.  When omitted, the dict
            is built here from ``ADReSSoAnalyzer.get_audio_files()`` so the
            correlation step always has its input.

    Returns:
        Tuple ``(main_results, correlation_results)``.  The second element is
        whatever ``FeatureCorrelationAnalyzer.run_correlation_analysis``
        returns.
    """
    # Initialize your existing analyzer
    analyzer = ADReSSoAnalyzer(base_path="/content/drive/MyDrive/Speech")

    if main_results is None:
        # Only the audio file listing is needed by the correlation step, so
        # collect just that instead of re-running the whole main pipeline.
        print("Collecting audio files for correlation analysis...")
        main_results = {'audio_files': analyzer.get_audio_files()}

    # Initialize correlation analyzer
    corr_analyzer = FeatureCorrelationAnalyzer(base_path="/content", analyzer=analyzer)

    # Run correlation analysis
    print("\n" + "="*60)
    print("CORRELATION ANALYSIS")
    print("="*60)

    correlation_results = corr_analyzer.run_correlation_analysis(main_results['audio_files'])

    return main_results, correlation_results

# If running as standalone script
if __name__ == "__main__":
    # Run complete pipeline
    main_results, correlation_results = run_correlation_analysis_pipeline()

    print("\n" + "="*60)
    print("ANALYSIS SUMMARY")
    print("="*60)

    if correlation_results is None:
        # run_correlation_analysis returns None when feature extraction or
        # matrix creation fails; there is nothing to summarize in that case.
        print("Correlation analysis produced no results")
    else:
        print(f"Total files processed: {len(correlation_results['file_info'])}")
        print(f"Acoustic features: {len(correlation_results['acoustic_names'])}")
        print(f"Linguistic features: {len(correlation_results['linguistic_names'])}")
        print(f"Strong correlations found: {len(correlation_results['strong_correlations'])}")

NameError: name 'ADReSSoAnalyzer' is not defined