<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Jul16_Speech.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install required packages
!pip install librosa soundfile opensmile speechbrain transformers torch openai-whisper
!pip install torch-geometric

Collecting opensmile
  Downloading opensmile-2.5.1-py3-none-manylinux_2_17_x86_64.whl.metadata (15 kB)
Collecting speechbrain
  Downloading speechbrain-1.0.3-py3-none-any.whl.metadata (24 kB)
Collecting openai-whisper
  Downloading openai_whisper-20250625.tar.gz (803 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m803.2/803.2 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting audobject>=0.6.1 (from opensmile)
  Downloading audobject-0.7.12-py3-none-any.whl.metadata (2.7 kB)
Collecting audinterface>=0.7.0 (from opensmile)
  Downloading audinterface-1.3.1-py3-none-any.whl.metadata (4.3 kB)
Collecting hyperpyyaml (from speechbrain)
  Downloading HyperPyYAML-1.2.2-py3-none-any.whl.metadata (7.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-

In [5]:
import os
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Any, Tuple
import warnings
warnings.filterwarnings('ignore')

# Audio processing
import librosa
import opensmile
import whisper

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Transformers
from transformers import (
    Wav2Vec2Processor, Wav2Vec2Model,
    BertTokenizer, BertModel,
    ViTModel, ViTFeatureExtractor
)

# Graph networks
import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATConv, global_mean_pool

# ML utilities
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split

# Visualization
import networkx as nx
from tqdm import tqdm

class ADReSSoAnalyzer:
    """Complete ADReSSo analysis pipeline with error handling and checkpoints.

    Pipeline steps (each one caches its result in ``<output_path>/checkpoints``
    and reloads that file on later runs instead of recomputing):

    1. collect ``.wav`` paths per dataset category,
    2. extract acoustic features (eGeMAPS, MFCC, log-mel, Wav2Vec2, prosody),
    3. transcribe the audio with Whisper,
    4. compute simple text statistics and BERT token ids from the transcripts,

    followed by a JSON summary of the run.  A model that fails to load is set
    to ``None``; features depending on it fall back to zero vectors, or the
    step that needs it is skipped.
    """

    def __init__(self, base_path="/content/drive/MyDrive/Voice/extracted/ADReSSo21"):
        self.base_path = base_path
        self.output_path = "/content/drive/MyDrive/ADReSSo_Results"
        self.checkpoint_path = f"{self.output_path}/checkpoints"

        # Create output directories
        os.makedirs(self.output_path, exist_ok=True)
        os.makedirs(self.checkpoint_path, exist_ok=True)
        os.makedirs(f"{self.output_path}/visualizations", exist_ok=True)

        # Initialize containers
        self.audio_files = {}
        self.features = {}
        self.transcripts = {}
        self.linguistic_features = {}
        # Paths whose acoustic extraction failed during the latest step-2 run
        # (stays empty when step 2 is restored from its checkpoint).
        self.acoustic_failures = []

        # Initialize models
        self.initialize_models()

    def initialize_models(self):
        """Load openSMILE, Whisper, Wav2Vec2 and BERT.

        Each model is loaded independently; on failure the corresponding
        attribute(s) are set to ``None`` and a warning is printed.
        """
        print("Initializing models...")

        try:
            # Initialize openSMILE
            self.smile = opensmile.Smile(
                feature_set=opensmile.FeatureSet.eGeMAPSv02,
                feature_level=opensmile.FeatureLevel.Functionals,
            )
            print("✓ OpenSMILE initialized")
        except Exception as e:
            print(f"⚠ OpenSMILE initialization failed: {e}")
            self.smile = None

        try:
            # Initialize Whisper
            self.whisper_model = whisper.load_model("base")
            print("✓ Whisper model loaded")
        except Exception as e:
            print(f"⚠ Whisper initialization failed: {e}")
            self.whisper_model = None

        try:
            # Initialize Wav2Vec2
            self.wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
            self.wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
            print("✓ Wav2Vec2 models loaded")
        except Exception as e:
            print(f"⚠ Wav2Vec2 initialization failed: {e}")
            self.wav2vec_processor = None
            self.wav2vec_model = None

        try:
            # Initialize BERT
            self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            print("✓ BERT models loaded")
        except Exception as e:
            print(f"⚠ BERT initialization failed: {e}")
            self.bert_tokenizer = None
            self.bert_model = None

    @staticmethod
    def _json_default(obj):
        """Convert numpy values that the ``json`` module cannot serialize natively."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_checkpoint(self, data: Any, filename: str, step: str):
        """Save ``data`` to ``<checkpoint_path>/<filename>``.

        The serializer is chosen from the extension of ``filename``: ``.pkl``
        (pickle), ``.json`` (UTF-8, numpy values converted to Python types) or
        ``.csv`` (via a pandas DataFrame).  ``step`` is a descriptive label
        only and is not used.

        Returns:
            True if the file was written, False otherwise (the reason is printed).
        """
        filepath = f"{self.checkpoint_path}/{filename}"

        try:
            if filename.endswith('.pkl'):
                with open(filepath, 'wb') as f:
                    pickle.dump(data, f)
            elif filename.endswith('.json'):
                # Explicit UTF-8: ensure_ascii=False writes non-ASCII transcript text verbatim.
                with open(filepath, 'w', encoding='utf-8') as f:
                    json.dump(data, f, indent=2, ensure_ascii=False, default=self._json_default)
            elif filename.endswith('.csv'):
                frame = data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)
                frame.to_csv(filepath, index=False)
            else:
                # Previously this path reported success without writing anything.
                print(f"⚠ Unsupported checkpoint format: {filename}")
                return False

            print(f"✓ Checkpoint saved: {filename}")
            return True
        except Exception as e:
            print(f"⚠ Failed to save checkpoint {filename}: {e}")
            return False

    def load_checkpoint(self, filename: str):
        """Load ``<checkpoint_path>/<filename>``, or return None.

        None is returned when the file does not exist, cannot be read, or has
        an unsupported extension.  Pickle files are only ever expected to be
        checkpoints written by :meth:`save_checkpoint`; unpickling untrusted
        files is unsafe.
        """
        filepath = f"{self.checkpoint_path}/{filename}"

        if not os.path.exists(filepath):
            return None

        try:
            if filename.endswith('.pkl'):
                with open(filepath, 'rb') as f:
                    return pickle.load(f)
            if filename.endswith('.json'):
                with open(filepath, 'r', encoding='utf-8') as f:
                    return json.load(f)
            if filename.endswith('.csv'):
                return pd.read_csv(filepath)
        except Exception as e:
            print(f"⚠ Failed to load checkpoint {filename}: {e}")
            return None

        print(f"⚠ Unsupported checkpoint format: {filename}")
        return None

    def step_1_get_audio_files(self) -> Dict[str, List[str]]:
        """Step 1: Get all audio files from the dataset.

        Returns a mapping ``category -> list of .wav paths``.  Missing
        directories yield an empty list for their category.
        """
        print("\n" + "="*60)
        print("STEP 1: GETTING AUDIO FILES")
        print("="*60)

        # Check if checkpoint exists
        checkpoint_file = "step1_audio_files.json"
        audio_files = self.load_checkpoint(checkpoint_file)

        if audio_files is not None:
            print("✓ Loaded audio files from checkpoint")
            self.audio_files = audio_files
            return audio_files

        audio_files = {
            'diagnosis_ad': [],
            'diagnosis_cn': [],
            'progression_decline': [],
            'progression_no_decline': [],
            'progression_test': []
        }

        # Define paths
        paths = {
            'diagnosis_ad': f"{self.base_path}/diagnosis/train/audio/ad",
            'diagnosis_cn': f"{self.base_path}/diagnosis/train/audio/cn",
            'progression_decline': f"{self.base_path}/progression/train/audio/decline",
            'progression_no_decline': f"{self.base_path}/progression/train/audio/no_decline",
            'progression_test': f"{self.base_path}/progression/test-dist/audio"
        }

        # Collect files
        for category, path in paths.items():
            if os.path.exists(path):
                # os.listdir order is arbitrary; sorting makes the file order (and
                # any limit_per_category subset taken later) reproducible.
                files = [
                    f"{path}/{f}" for f in sorted(os.listdir(path))
                    if f.lower().endswith('.wav')
                ]
                audio_files[category] = files
                print(f"✓ Found {len(files)} files in {category}")
            else:
                print(f"⚠ Path not found: {path}")

        total_files = sum(len(files) for files in audio_files.values())
        print(f"\nTotal audio files found: {total_files}")

        # Save checkpoint
        self.save_checkpoint(audio_files, checkpoint_file, "step1")
        self.audio_files = audio_files

        # Visualize file distribution
        self.visualize_file_distribution(audio_files)

        return audio_files

    def visualize_file_distribution(self, audio_files: Dict[str, List[str]]):
        """Plot the number of audio files per category and save it as a PNG."""
        categories = list(audio_files.keys())
        counts = [len(files) for files in audio_files.values()]

        plt.figure(figsize=(12, 6))
        bars = plt.bar(categories, counts, color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7'])
        plt.title('Audio File Distribution by Category', fontsize=16, fontweight='bold')
        plt.xlabel('Category', fontsize=12)
        plt.ylabel('Number of Files', fontsize=12)
        plt.xticks(rotation=45, ha='right')

        # Add value labels on bars
        for bar, count in zip(bars, counts):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    str(count), ha='center', va='bottom', fontweight='bold')

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/file_distribution.png", dpi=300, bbox_inches='tight')
        plt.show()

    def step_2_extract_acoustic_features(self, limit_per_category: int = None):
        """Step 2: Extract acoustic features from audio files.

        Args:
            limit_per_category: if truthy, only the first N files of each
                category are processed.

        Returns a mapping ``"<category>_<filename>" -> feature dict``.  Paths
        that fail are recorded in ``self.acoustic_failures``.
        """
        print("\n" + "="*60)
        print("STEP 2: EXTRACTING ACOUSTIC FEATURES")
        print("="*60)

        # Check if checkpoint exists
        checkpoint_file = "step2_acoustic_features.pkl"
        features = self.load_checkpoint(checkpoint_file)

        if features is not None:
            print("✓ Loaded acoustic features from checkpoint")
            self.features = features
            return features

        features = {}
        self.acoustic_failures = []

        for category, files in self.audio_files.items():
            if not files:
                continue

            print(f"\nProcessing {category}...")

            # Limit files if specified
            if limit_per_category:
                files = files[:limit_per_category]

            for file_path in tqdm(files, desc=f"Extracting features for {category}"):
                # Assigned before the try so the except branch can always use them.
                filename = os.path.basename(file_path)
                file_key = f"{category}_{filename}"

                try:
                    file_features = self.extract_acoustic_features_from_file(file_path)
                except Exception as e:
                    print(f"⚠ Error processing {filename}: {e}")
                    self.acoustic_failures.append(file_path)
                    continue

                if file_features is None:
                    # Empty or unreadable audio (the extractor already printed why).
                    self.acoustic_failures.append(file_path)
                    continue

                features[file_key] = {
                    'file_path': file_path,
                    'category': category,
                    'filename': filename,
                    **file_features
                }

        print(f"\n✓ Extracted features from {len(features)} files")

        # Save checkpoint
        self.save_checkpoint(features, checkpoint_file, "step2")
        self.features = features

        # Visualize features
        self.visualize_acoustic_features(features)

        return features

    def extract_acoustic_features_from_file(self, audio_path: str) -> Dict[str, Any]:
        """Extract acoustic features from a single audio file.

        The audio is loaded at 16 kHz.  Each feature group is computed
        independently; if one fails (or its model is unavailable) it is
        replaced by zeros of the fallback size.

        Returns None when the file cannot be loaded or contains no samples.
        """
        features = {}

        try:
            # Load audio
            y, sr = librosa.load(audio_path, sr=16000)

            if len(y) == 0:
                return None

            # 1. eGeMAPS features
            if self.smile is not None:
                try:
                    egemaps = self.smile.process_file(audio_path).values.flatten()
                    features['egemaps'] = egemaps
                except Exception:
                    features['egemaps'] = np.zeros(88)
            else:
                features['egemaps'] = np.zeros(88)

            # 2. MFCC features
            try:
                mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
                features['mfccs'] = {
                    'mean': np.mean(mfccs, axis=1),
                    'std': np.std(mfccs, axis=1),
                    'delta': np.mean(librosa.feature.delta(mfccs), axis=1),
                    'delta2': np.mean(librosa.feature.delta(mfccs, order=2), axis=1)
                }
            except Exception:
                features['mfccs'] = {
                    'mean': np.zeros(13), 'std': np.zeros(13),
                    'delta': np.zeros(13), 'delta2': np.zeros(13)
                }

            # 3. Mel spectrogram
            try:
                mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80)
                log_mel = librosa.power_to_db(mel_spec)
                features['log_mel'] = {
                    'mean': np.mean(log_mel, axis=1),
                    'std': np.std(log_mel, axis=1)
                }
            except Exception:
                features['log_mel'] = {
                    'mean': np.zeros(80), 'std': np.zeros(80)
                }

            # 4. Wav2Vec2 features (time-averaged last hidden state)
            if self.wav2vec_processor is not None and self.wav2vec_model is not None:
                try:
                    input_values = self.wav2vec_processor(
                        y, sampling_rate=16000, return_tensors="pt"
                    ).input_values

                    with torch.no_grad():
                        wav2vec_features = self.wav2vec_model(input_values).last_hidden_state
                    features['wav2vec2'] = torch.mean(wav2vec_features, dim=1).squeeze().numpy()
                except Exception:
                    features['wav2vec2'] = np.zeros(768)
            else:
                features['wav2vec2'] = np.zeros(768)

            # 5. Prosodic features
            try:
                f0 = librosa.yin(y, fmin=50, fmax=300, sr=sr)
                # NOTE(review): librosa.yin makes no voiced/unvoiced decision, so this
                # filter most likely keeps every frame; librosa.pyin (voiced_flag)
                # would be needed to exclude unvoiced frames from the f0 stats.
                f0_clean = f0[f0 > 0]
                rms = librosa.feature.rms(y=y)  # computed once, used for mean and std

                features['prosodic'] = {
                    'f0_mean': np.mean(f0_clean) if len(f0_clean) > 0 else 0.0,
                    'f0_std': np.std(f0_clean) if len(f0_clean) > 0 else 0.0,
                    'energy_mean': np.mean(rms),
                    'energy_std': np.std(rms),
                    'zero_crossing_rate': np.mean(librosa.feature.zero_crossing_rate(y)),
                    'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)),
                    'spectral_rolloff': np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr)),
                    'duration': len(y) / sr
                }
            except Exception:
                features['prosodic'] = {
                    'f0_mean': 0.0, 'f0_std': 0.0, 'energy_mean': 0.0, 'energy_std': 0.0,
                    'zero_crossing_rate': 0.0, 'spectral_centroid': 0.0, 'spectral_rolloff': 0.0,
                    'duration': 0.0
                }

        except Exception as e:
            print(f"Error processing audio file: {e}")
            return None

        return features

    def visualize_acoustic_features(self, features: Dict[str, Any]):
        """Plot the feature groups of the first file plus per-category durations."""
        if not features:
            return

        # Sample file for visualization
        sample_key = list(features.keys())[0]
        sample_features = features[sample_key]

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle(f'Acoustic Features Visualization - {sample_key}', fontsize=16, fontweight='bold')

        # eGeMAPS
        axes[0, 0].plot(sample_features['egemaps'][:20])
        axes[0, 0].set_title('eGeMAPS Features (first 20)')
        axes[0, 0].set_xlabel('Feature Index')
        axes[0, 0].set_ylabel('Value')

        # MFCC
        mfcc_mean = sample_features['mfccs']['mean']
        axes[0, 1].plot(mfcc_mean, marker='o')
        axes[0, 1].set_title('MFCC Mean')
        axes[0, 1].set_xlabel('MFCC Coefficient')
        axes[0, 1].set_ylabel('Value')

        # Mel spectrogram
        mel_mean = sample_features['log_mel']['mean']
        axes[0, 2].plot(mel_mean)
        axes[0, 2].set_title('Log-Mel Spectrogram Mean')
        axes[0, 2].set_xlabel('Mel Bin')
        axes[0, 2].set_ylabel('Value')

        # Wav2Vec2
        axes[1, 0].plot(sample_features['wav2vec2'][:50])
        axes[1, 0].set_title('Wav2Vec2 Features (first 50)')
        axes[1, 0].set_xlabel('Feature Index')
        axes[1, 0].set_ylabel('Value')

        # Prosodic features
        prosodic = sample_features['prosodic']
        prosodic_names = list(prosodic.keys())
        prosodic_values = list(prosodic.values())

        axes[1, 1].bar(prosodic_names, prosodic_values)
        axes[1, 1].set_title('Prosodic Features')
        axes[1, 1].set_ylabel('Value')
        axes[1, 1].tick_params(axis='x', rotation=45)

        # Feature distribution by category
        categories = {}
        for feature_data in features.values():
            categories.setdefault(feature_data['category'], []).append(
                feature_data['prosodic']['duration']
            )

        for category, durations in categories.items():
            axes[1, 2].hist(durations, alpha=0.7, label=category, bins=20)

        axes[1, 2].set_title('Duration Distribution by Category')
        axes[1, 2].set_xlabel('Duration (seconds)')
        axes[1, 2].set_ylabel('Frequency')
        axes[1, 2].legend()

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/acoustic_features.png", dpi=300, bbox_inches='tight')
        plt.show()

    def step_3_extract_transcripts(self, limit_per_category: int = None):
        """Step 3: Extract transcripts using Whisper.

        Returns a mapping ``"<category>_<filename>" -> transcript record``.
        Failed files get an empty transcript and an ``'error'`` entry.
        Returns {} when Whisper is unavailable.
        """
        print("\n" + "="*60)
        print("STEP 3: EXTRACTING TRANSCRIPTS")
        print("="*60)

        # Check if checkpoint exists
        checkpoint_file = "step3_transcripts.json"
        transcripts = self.load_checkpoint(checkpoint_file)

        if transcripts is not None:
            print("✓ Loaded transcripts from checkpoint")
            self.transcripts = transcripts
            return transcripts

        if self.whisper_model is None:
            print("⚠ Whisper model not available, skipping transcript extraction")
            return {}

        transcripts = {}

        for category, files in self.audio_files.items():
            if not files:
                continue

            print(f"\nProcessing {category}...")

            # Limit files if specified
            if limit_per_category:
                files = files[:limit_per_category]

            for file_path in tqdm(files, desc=f"Transcribing {category}"):
                # Assigned before the try so the except branch can always use them.
                filename = os.path.basename(file_path)
                file_key = f"{category}_{filename}"

                try:
                    result = self.whisper_model.transcribe(file_path)

                    transcripts[file_key] = {
                        'file_path': file_path,
                        'category': category,
                        'filename': filename,
                        'transcript': result["text"].strip(),
                        'language': result.get('language', 'en'),
                        'segments': len(result.get('segments', []))
                    }

                except Exception as e:
                    print(f"⚠ Error transcribing {filename}: {e}")
                    transcripts[file_key] = {
                        'file_path': file_path,
                        'category': category,
                        'filename': filename,
                        'transcript': "",
                        'error': str(e)
                    }

        print(f"\n✓ Extracted transcripts from {len(transcripts)} files")

        # Save checkpoint
        self.save_checkpoint(transcripts, checkpoint_file, "step3")
        self.transcripts = transcripts

        # Visualize transcripts
        self.visualize_transcripts(transcripts)

        return transcripts

    def visualize_transcripts(self, transcripts: Dict[str, Any]):
        """Plot word/character counts and error rates of transcripts per category."""
        if not transcripts:
            return

        # Prepare data
        data = []
        for info in transcripts.values():
            transcript = info.get('transcript', '')
            data.append({
                'category': info['category'],
                'word_count': len(transcript.split()) if transcript else 0,
                'char_count': len(transcript),
                'has_error': 'error' in info
            })

        df = pd.DataFrame(data)

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Word count distribution
        df.boxplot(column='word_count', by='category', ax=axes[0, 0])
        axes[0, 0].set_title('Word Count Distribution by Category')
        axes[0, 0].set_ylabel('Word Count')

        # Character count distribution
        df.boxplot(column='char_count', by='category', ax=axes[0, 1])
        axes[0, 1].set_title('Character Count Distribution by Category')
        axes[0, 1].set_ylabel('Character Count')

        # DataFrame.boxplot(by=...) replaces the figure suptitle, so set ours afterwards.
        fig.suptitle('Transcript Analysis', fontsize=16, fontweight='bold')

        # Error rate by category
        error_rate = df.groupby('category')['has_error'].mean()
        axes[1, 0].bar(error_rate.index, error_rate.values)
        axes[1, 0].set_title('Error Rate by Category')
        axes[1, 0].set_ylabel('Error Rate')
        axes[1, 0].tick_params(axis='x', rotation=45)

        # Average metrics by category
        avg_metrics = df.groupby('category')[['word_count', 'char_count']].mean()
        avg_metrics.plot(kind='bar', ax=axes[1, 1])
        axes[1, 1].set_title('Average Metrics by Category')
        axes[1, 1].set_ylabel('Count')
        axes[1, 1].tick_params(axis='x', rotation=45)
        axes[1, 1].legend()

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/transcript_analysis.png", dpi=300, bbox_inches='tight')
        plt.show()

    @staticmethod
    def _text_statistics(transcript: str) -> Dict[str, Any]:
        """Compute surface statistics of a transcript.

        ``word_count`` counts whitespace-separated tokens.  ``avg_word_length``,
        ``unique_words`` and ``lexical_diversity`` (type/token ratio) use those
        tokens lower-cased with surrounding punctuation stripped, so "The" and
        "the," are the same word.  Sentences are split on runs of ``.``, ``!``
        and ``?`` (Whisper emits all three).
        """
        import re
        import string

        words = transcript.split()
        tokens = [w.strip(string.punctuation).lower() for w in words]
        tokens = [t for t in tokens if t]
        sentences = [s.strip() for s in re.split(r'[.!?]+', transcript) if s.strip()]
        unique = len(set(tokens))

        return {
            'word_count': len(words),
            'sentence_count': len(sentences),
            'avg_word_length': float(np.mean([len(t) for t in tokens])) if tokens else 0,
            'unique_words': unique,
            'lexical_diversity': unique / len(tokens) if tokens else 0,
        }

    def step_4_extract_linguistic_features(self):
        """Step 4: Extract linguistic features for BERT.

        Each transcript gets text statistics plus BERT input ids / attention
        mask (truncated or padded to 512 tokens).  Empty transcripts get a
        zero placeholder; failures get the placeholder plus an ``'error'``
        entry.  Both keep the transcript's category.  Returns {} when the BERT
        tokenizer is unavailable.
        """
        print("\n" + "="*60)
        print("STEP 4: EXTRACTING LINGUISTIC FEATURES")
        print("="*60)

        # Check if checkpoint exists
        checkpoint_file = "step4_linguistic_features.pkl"
        linguistic_features = self.load_checkpoint(checkpoint_file)

        if linguistic_features is not None:
            print("✓ Loaded linguistic features from checkpoint")
            self.linguistic_features = linguistic_features
            return linguistic_features

        if self.bert_tokenizer is None:
            print("⚠ BERT tokenizer not available, skipping linguistic feature extraction")
            return {}

        linguistic_features = {}

        print("Processing transcripts for linguistic features...")

        for key, data in tqdm(self.transcripts.items(), desc="Extracting linguistic features"):
            transcript = data.get('transcript', '')
            # Keep the real label even for placeholders so per-category stats stay correct.
            category = data.get('category', 'unknown')

            if not transcript:
                linguistic_features[key] = self.create_empty_linguistic_features(category)
                continue

            try:
                # BERT tokenization
                bert_encoding = self.bert_tokenizer(
                    transcript,
                    truncation=True,
                    padding='max_length',
                    max_length=512,
                    return_tensors='pt'
                )

                linguistic_features[key] = {
                    'raw_text': transcript,
                    **self._text_statistics(transcript),
                    'bert_input_ids': bert_encoding['input_ids'].squeeze().tolist(),
                    'bert_attention_mask': bert_encoding['attention_mask'].squeeze().tolist(),
                    'category': category
                }

            except Exception as e:
                print(f"⚠ Error processing {key}: {e}")
                linguistic_features[key] = self.create_empty_linguistic_features(category)
                linguistic_features[key]['error'] = str(e)

        print(f"\n✓ Extracted linguistic features from {len(linguistic_features)} files")

        # Save checkpoint
        self.save_checkpoint(linguistic_features, checkpoint_file, "step4")
        self.linguistic_features = linguistic_features

        # Visualize linguistic features
        self.visualize_linguistic_features(linguistic_features)

        return linguistic_features

    def create_empty_linguistic_features(self, category: str = 'unknown'):
        """Return a zero-valued linguistic feature record.

        Args:
            category: label to store in the record (defaults to ``'unknown'``).
        """
        return {
            'raw_text': '',
            'word_count': 0,
            'sentence_count': 0,
            'avg_word_length': 0,
            'unique_words': 0,
            'lexical_diversity': 0,
            'bert_input_ids': [0] * 512,
            'bert_attention_mask': [0] * 512,
            'category': category
        }

    def visualize_linguistic_features(self, linguistic_features: Dict[str, Any]):
        """Box-plot each text statistic per category and show their correlations."""
        if not linguistic_features:
            return

        # Prepare data
        data = []
        for key, features in linguistic_features.items():
            data.append({
                'key': key,
                'category': features['category'],
                'word_count': features['word_count'],
                'sentence_count': features['sentence_count'],
                'avg_word_length': features['avg_word_length'],
                'unique_words': features['unique_words'],
                'lexical_diversity': features['lexical_diversity']
            })

        df = pd.DataFrame(data)

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        # Metrics by category: five metrics fill the first five cells of the 2x3 grid.
        metrics = ['word_count', 'sentence_count', 'avg_word_length', 'unique_words', 'lexical_diversity']

        for i, metric in enumerate(metrics):
            ax = axes[i // 3, i % 3]
            df.boxplot(column=metric, by='category', ax=ax)
            ax.set_title(f'{metric.replace("_", " ").title()} by Category')
            ax.set_ylabel(metric.replace("_", " ").title())

        # DataFrame.boxplot(by=...) replaces the figure suptitle, so set ours afterwards.
        fig.suptitle('Linguistic Features Analysis', fontsize=16, fontweight='bold')

        # Correlation heatmap (last grid cell)
        numeric_df = df.select_dtypes(include=[np.number])
        correlation_matrix = numeric_df.corr()

        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0, ax=axes[1, 2])
        axes[1, 2].set_title('Feature Correlation Matrix')

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/linguistic_features.png", dpi=300, bbox_inches='tight')
        plt.show()

    def generate_final_summary(self, results: Dict[str, Any]):
        """Build, save (``final_summary.json``) and print a summary of a run.

        Args:
            results: mapping with ``audio_files``, ``acoustic_features``,
                ``transcripts`` and ``linguistic_features`` entries, as built by
                :meth:`run_complete_pipeline`.

        Returns the summary dict.
        """
        print("\n" + "="*60)
        print("GENERATING FINAL SUMMARY")
        print("="*60)

        summary = {
            'total_audio_files': sum(len(files) for files in results['audio_files'].values()),
            'processed_features': len(results['acoustic_features']),
            'transcripts_extracted': len(results['transcripts']),
            'linguistic_features_extracted': len(results['linguistic_features']),
            'categories': list(results['audio_files'].keys()),
            'feature_dimensions': self.get_feature_dimensions(),
            'processing_errors': self.count_processing_errors(results)
        }

        # Save summary
        self.save_checkpoint(summary, "final_summary.json", "summary")

        # Print summary
        print("\n📊 FINAL SUMMARY REPORT:")
        print("-" * 40)
        for key, value in summary.items():
            print(f"{key.replace('_', ' ').title()}: {value}")

        return summary

    def get_feature_dimensions(self):
        """Return the length of each acoustic feature group of the first file.

        Returns {} when no acoustic features have been extracted.
        """
        if not self.features:
            return {}

        sample_features = next(iter(self.features.values()))

        return {
            'egemaps': len(sample_features.get('egemaps', [])),
            'mfcc_mean': len(sample_features.get('mfccs', {}).get('mean', [])),
            'log_mel_mean': len(sample_features.get('log_mel', {}).get('mean', [])),
            'wav2vec2': len(sample_features.get('wav2vec2', [])),
            'prosodic': len(sample_features.get('prosodic', {}))
        }

    def count_processing_errors(self, results: Dict[str, Any]):
        """Count failures per step.

        * acoustic: paths in ``self.acoustic_failures`` (0 after a checkpoint load),
        * transcript / linguistic: records carrying an ``'error'`` entry.
        """
        transcripts = results.get('transcripts') or {}
        linguistic = results.get('linguistic_features') or {}

        return {
            'acoustic_errors': len(getattr(self, 'acoustic_failures', [])),
            'transcript_errors': sum(1 for info in transcripts.values() if 'error' in info),
            'linguistic_errors': sum(1 for info in linguistic.values() if 'error' in info)
        }

    def run_complete_pipeline(self, limit_per_category: int = None):
        """Run steps 1-4 and the final summary; return all step results."""
        print("="*80)
        print("ADRESSO21 COMPLETE ANALYSIS PIPELINE")
        print("="*80)

        results = {}

        # Step 1: Get audio files
        results['audio_files'] = self.step_1_get_audio_files()

        # Step 2: Extract acoustic features
        results['acoustic_features'] = self.step_2_extract_acoustic_features(limit_per_category)

        # Step 3: Extract transcripts
        results['transcripts'] = self.step_3_extract_transcripts(limit_per_category)

        # Step 4: Extract linguistic features
        results['linguistic_features'] = self.step_4_extract_linguistic_features()

        # Generate final summary
        self.generate_final_summary(results)

        print("\n" + "="*80)
        print("PIPELINE COMPLETED SUCCESSFULLY!")
        print("="*80)
        print(f"Results saved to: {self.output_path}")
        print(f"Checkpoints saved to: {self.checkpoint_path}")
        print(f"Visualizations saved to: {self.output_path}/visualizations")

        return results


In [7]:
def generate_final_summary(self, results: Dict[str, Any]):
        """Build, checkpoint and print a summary of a pipeline run.

        Args:
            results: dict holding 'audio_files' (category -> list of files),
                'acoustic_features', 'transcripts' and 'linguistic_features'.

        Returns:
            The summary dict (also saved via ``save_checkpoint`` as
            "final_summary.json" and printed one field per line).
        """
        # NOTE(review): this ``def`` starts at column 0 while the other methods
        # in this notebook cell are indented by 4 spaces; that mismatch is what
        # raises "unindent does not match any outer indentation level" for the
        # cell. It takes ``self`` and is called as ``self.generate_final_summary``,
        # so it belongs inside the pipeline class, indented like its siblings.
        rule = "=" * 60
        print("\n" + rule)
        print("GENERATING FINAL SUMMARY")
        print(rule)

        audio_by_category = results['audio_files']
        # Keyword arguments are evaluated left to right, so the helper calls
        # run in the same order as the report fields appear.
        summary = dict(
            total_audio_files=sum(map(len, audio_by_category.values())),
            processed_features=len(results['acoustic_features']),
            transcripts_extracted=len(results['transcripts']),
            linguistic_features_extracted=len(results['linguistic_features']),
            categories=list(audio_by_category),
            feature_dimensions=self.get_feature_dimensions(),
            processing_errors=self.count_processing_errors(results),
        )

        self.save_checkpoint(summary, "final_summary.json", "summary")

        print("\n📊 FINAL SUMMARY REPORT:")
        print("-" * 40)
        for field, value in summary.items():
            pretty = field.replace('_', ' ').title()
            print(f"{pretty}: {value}")

        return summary

    def get_feature_dimensions(self):
        """Report the length of each feature group for one representative sample.

        The first entry of ``self.features`` is used as the reference; other
        samples are not inspected. Only the ``mean`` vectors of the MFCC and
        log-Mel groups are measured. Missing groups count as 0.

        Returns:
            dict mapping group name to its length, or {} when there are no
            features.
        """
        if not self.features:
            return {}

        first_key = next(iter(self.features))
        reference = self.features[first_key]
        mfcc_stats = reference.get('mfccs', {})
        mel_stats = reference.get('log_mel', {})

        return {
            'egemaps': len(reference.get('egemaps', [])),
            'mfcc_mean': len(mfcc_stats.get('mean', [])),
            'log_mel_mean': len(mel_stats.get('mean', [])),
            'wav2vec2': len(reference.get('wav2vec2', [])),
            'prosodic': len(reference.get('prosodic', {})),
        }

    def count_processing_errors(self, results: Dict[str, Any]):
        """Tally failed items from the transcript and linguistic steps.

        Args:
            results: pipeline results holding 'transcripts' and
                'linguistic_features' dicts of per-sample entries.

        Returns:
            dict with 'acoustic_errors', 'transcript_errors' and
            'linguistic_errors'. ``acoustic_errors`` is always 0: nothing here
            inspects ``results['acoustic_features']``.
        """
        # A transcript entry counts as failed when it carries an 'error' field.
        transcript_errors = sum(
            1 for entry in results['transcripts'].values() if 'error' in entry
        )

        # Linguistic extraction counts as failed when text was present but no
        # words were counted.
        linguistic_errors = sum(
            1
            for entry in results['linguistic_features'].values()
            if entry.get('word_count', 0) == 0 and entry.get('raw_text', '')
        )

        return {
            'acoustic_errors': 0,
            'transcript_errors': transcript_errors,
            'linguistic_errors': linguistic_errors,
        }

    def step_5_create_multimodal_dataset(self):
        """Step 5: Create unified multimodal dataset.

        Joins ``self.features`` (acoustic), ``self.transcripts`` and
        ``self.linguistic_features`` on the keys present in all three. Each
        joined sample gets a flattened acoustic vector, its transcript text,
        its linguistic feature dict and binary labels derived from the
        acoustic entry's 'category'. Samples that fail to assemble are
        reported and skipped.

        A non-None checkpoint ("step5_multimodal_dataset.pkl") is returned
        as-is; otherwise the dataset is built, checkpointed and visualised.

        Returns:
            list of sample dicts, ordered by sample key.
        """
        print("\n" + "="*60)
        print("STEP 5: CREATING MULTIMODAL DATASET")
        print("="*60)

        # Check if checkpoint exists
        checkpoint_file = "step5_multimodal_dataset.pkl"
        dataset = self.load_checkpoint(checkpoint_file)

        if dataset is not None:
            print("✓ Loaded multimodal dataset from checkpoint")
            return dataset

        dataset = []

        print("Creating unified multimodal dataset...")

        # Sort the join keys. Iterating a raw set yields an order that can
        # change between interpreter runs (string hash randomisation), which
        # would silently change the later seeded train/test split.
        common_keys = sorted(
            set(self.features.keys())
            & set(self.transcripts.keys())
            & set(self.linguistic_features.keys()),
            key=str,
        )

        for key in tqdm(common_keys, desc="Creating multimodal samples"):
            try:
                acoustic_features = self.features[key]
                transcript_data = self.transcripts[key]
                linguistic_features = self.linguistic_features[key]
                category = acoustic_features['category']

                # NOTE(review): a plain substring test also matches negated
                # category names such as 'no_decline' (it contains 'decline'),
                # which would label every progression sample as declining.
                # Negated forms are excluded explicitly; confirm against the
                # category names produced by step 1.
                normalized = category.lower().replace('-', '_').replace(' ', '_')
                is_decline = 'decline' in category and not any(
                    negated in normalized for negated in ('no_decline', 'nodecline')
                )

                sample = {
                    'id': key,
                    'category': category,
                    'file_path': acoustic_features['file_path'],
                    'filename': acoustic_features['filename'],

                    # Acoustic features
                    'acoustic': self.flatten_acoustic_features(acoustic_features),

                    # Text features
                    'transcript': transcript_data.get('transcript', ''),
                    'linguistic': linguistic_features,

                    # Labels
                    'diagnosis_label': 1 if 'ad' in category else 0,
                    'progression_label': 1 if is_decline else 0,
                }

                dataset.append(sample)

            except Exception as e:
                # Best-effort: one malformed sample should not abort the step.
                print(f"⚠ Error creating sample for {key}: {e}")
                continue

        print(f"\n✓ Created multimodal dataset with {len(dataset)} samples")

        self.save_checkpoint(dataset, checkpoint_file, "step5")

        self.visualize_multimodal_dataset(dataset)

        return dataset

    def flatten_acoustic_features(self, features: Dict[str, Any]):
        """Concatenate all acoustic feature groups into one float32 vector.

        Layout, in order (absent groups/statistics are skipped):
        eGeMAPS, MFCC mean/std/delta/delta2, log-Mel mean/std, wav2vec2,
        then the prosodic dict's values in insertion order.

        Returns:
            1-D ``np.float32`` array (empty when no group is present).
        """
        vector = []

        if 'egemaps' in features:
            vector.extend(features['egemaps'])

        # Groups stored as {statistic_name: values}; only these statistics
        # are taken, in this order.
        statistic_groups = (
            ('mfccs', ('mean', 'std', 'delta', 'delta2')),
            ('log_mel', ('mean', 'std')),
        )
        for group_name, statistic_names in statistic_groups:
            if group_name in features:
                stats = features[group_name]
                for statistic in statistic_names:
                    vector.extend(stats.get(statistic, []))

        if 'wav2vec2' in features:
            vector.extend(features['wav2vec2'])

        if 'prosodic' in features:
            vector.extend(features['prosodic'].values())

        return np.array(vector, dtype=np.float32)

    def visualize_multimodal_dataset(self, dataset: List[Dict[str, Any]]):
        """Plot a 2x2 overview of the multimodal dataset and save it as PNG.

        Panels: category pie chart, diagnosis label counts, histogram of
        acoustic vector lengths, and the first 50 acoustic values of the
        first sample. The figure is written to
        ``<output_path>/visualizations/multimodal_dataset.png`` and shown.
        Does nothing for an empty dataset.
        """
        if not dataset:
            return

        categories = [item['category'] for item in dataset]
        diagnosis_labels = [item['diagnosis_label'] for item in dataset]
        # Collected alongside the other labels but not plotted below.
        progression_labels = [item['progression_label'] for item in dataset]
        acoustic_dims = [len(item['acoustic']) for item in dataset]

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Multimodal Dataset Overview', fontsize=16, fontweight='bold')
        pie_ax, label_ax, dims_ax, sample_ax = axes.ravel()

        category_counts = pd.Series(categories).value_counts()
        pie_ax.pie(category_counts.values, labels=category_counts.index, autopct='%1.1f%%')
        pie_ax.set_title('Category Distribution')

        control_count = sum(1 - np.array(diagnosis_labels))
        ad_count = sum(diagnosis_labels)
        label_ax.bar(['Control', 'AD'], [control_count, ad_count])
        label_ax.set_title('Diagnosis Label Distribution')
        label_ax.set_ylabel('Count')

        dims_ax.hist(acoustic_dims, bins=20, edgecolor='black')
        dims_ax.set_title('Acoustic Feature Dimensions')
        dims_ax.set_xlabel('Dimension')
        dims_ax.set_ylabel('Count')

        # dataset is non-empty here, so the first sample always exists.
        first_values = dataset[0]['acoustic'][:50]
        sample_ax.plot(first_values, marker='o', markersize=3)
        sample_ax.set_title('Sample Acoustic Features (first 50)')
        sample_ax.set_xlabel('Feature Index')
        sample_ax.set_ylabel('Value')

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/multimodal_dataset.png", dpi=300, bbox_inches='tight')
        plt.show()

    def step_6_create_graph_structures(self, dataset: List[Dict[str, Any]]):
        """Step 6: Create graph structures for graph neural networks.

        Each sample becomes one graph. Its nodes are the first (up to) 100
        values of the sample's flattened acoustic vector, each node carrying
        that single value as its feature. Edges come from
        ``create_feature_similarity_edges``. The graph label ``y`` is the
        sample's diagnosis label.

        A non-None checkpoint ("step6_graph_structures.pkl") is returned
        as-is; otherwise the graphs are built, checkpointed, and the first
        ones are visualised.

        Args:
            dataset: samples with 'acoustic', 'diagnosis_label' and 'id'.

        Returns:
            list of ``Data`` graph objects.
        """
        print("\n" + "="*60)
        print("STEP 6: CREATING GRAPH STRUCTURES")
        print("="*60)

        checkpoint_file = "step6_graph_structures.pkl"
        graph_data = self.load_checkpoint(checkpoint_file)

        if graph_data is not None:
            print("✓ Loaded graph structures from checkpoint")
            return graph_data

        graph_data = []

        print("Creating graph structures...")

        for sample in tqdm(dataset, desc="Creating graphs"):
            try:
                acoustic_features = sample['acoustic']

                # Nodes = feature values; capped for computational efficiency.
                num_nodes = min(len(acoustic_features), 100)
                node_values = acoustic_features[:num_nodes]
                node_features = node_values.reshape(-1, 1)

                edges = self.create_feature_similarity_edges(node_values)

                # reshape(-1, 2) keeps edge_index at shape (2, num_edges) even
                # when no edges exist (a graph with <= 1 node). Otherwise the
                # empty list becomes a 1-D tensor of shape (0,), and code that
                # reads edge_index.shape[1] fails.
                edge_index = (
                    torch.tensor(edges, dtype=torch.long)
                    .reshape(-1, 2)
                    .t()
                    .contiguous()
                )

                graph = Data(
                    x=torch.tensor(node_features, dtype=torch.float),
                    edge_index=edge_index,
                    y=torch.tensor([sample['diagnosis_label']], dtype=torch.long),
                    sample_id=sample['id']
                )

                graph_data.append(graph)

            except Exception as e:
                # Best-effort per sample; .get avoids a second KeyError here
                # when the failure was a missing 'id'.
                print(f"⚠ Error creating graph for {sample.get('id')}: {e}")
                continue

        print(f"\n✓ Created {len(graph_data)} graph structures")

        self.save_checkpoint(graph_data, checkpoint_file, "step6")

        self.visualize_graph_structures(graph_data[:5])  # Visualize first 5 graphs

        return graph_data

    def create_feature_similarity_edges(self, features: np.ndarray, threshold: float = 0.7):
        """Connect positions whose scalar values lie within ``threshold`` of each other.

        Despite the name, "similarity" here is absolute value distance.
        Positions ``i < j`` are linked when
        ``abs(features[i] - features[j]) < threshold``. Each link is emitted
        in both directions (``[i, j]`` then ``[j, i]``), ordered by ``i`` and
        then ``j``. If no pair qualifies, a bidirectional chain
        ``0-1-...-(n-1)`` is returned so the graph is not edgeless. For
        ``n <= 1`` the result is empty.

        Args:
            features: 1-D array of node values (an ``(n, 1)`` column is
                flattened).
            threshold: strict upper bound on the absolute difference.

        Returns:
            list of ``[src, dst]`` integer pairs.
        """
        values = np.asarray(features).ravel()
        n = values.shape[0]

        edges = []
        if n >= 2:
            # Vectorised pairwise distances. The strict upper triangle keeps
            # each unordered pair once, and np.nonzero walks it row-major,
            # which reproduces the i-then-j order of a nested loop.
            close = np.abs(values[:, None] - values[None, :]) < threshold
            rows, cols = np.nonzero(np.triu(close, k=1))
            for i, j in zip(rows.tolist(), cols.tolist()):
                edges.append([i, j])
                edges.append([j, i])  # undirected graph

        # Fallback: a simple chain when nothing was close enough.
        if not edges:
            for i in range(n - 1):
                edges.append([i, i + 1])
                edges.append([i + 1, i])

        return edges

    def visualize_graph_structures(self, graph_data: List[Data]):
        """Draw up to three graphs side by side and save the figure as PNG.

        Each graph is converted to a NetworkX graph (nodes 0..n-1 plus the
        columns of ``edge_index``) and drawn with a spring layout. The figure
        is written to ``<output_path>/visualizations/graph_structures.png``
        and shown. Does nothing for an empty list.
        """
        if not graph_data:
            return

        fig, axes = plt.subplots(1, min(3, len(graph_data)), figsize=(15, 5))
        # plt.subplots returns a bare Axes (not an array) for a single column.
        if len(graph_data) == 1:
            axes = [axes]

        fig.suptitle('Graph Structure Visualization', fontsize=16, fontweight='bold')

        for idx, (ax, graph) in enumerate(zip(axes, graph_data[:3])):
            node_count = graph.x.shape[0]

            nx_graph = nx.Graph()
            nx_graph.add_nodes_from(range(node_count))
            for src, dst in graph.edge_index.numpy().T:
                nx_graph.add_edge(src, dst)

            layout = nx.spring_layout(nx_graph, k=0.5, iterations=50)
            nx.draw(nx_graph, layout, ax=ax, node_size=30, node_color='lightblue',
                    edge_color='gray', alpha=0.7, with_labels=False)

            ax.set_title(f'Graph {idx+1}\nNodes: {graph.x.shape[0]}, Edges: {graph.edge_index.shape[1]}')
            ax.axis('off')

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/graph_structures.png", dpi=300, bbox_inches='tight')
        plt.show()

    def step_7_prepare_training_data(self, dataset: List[Dict[str, Any]]):
        """Step 7: Prepare data for training.

        Each sample is routed to a task by substring of its 'category':
        'diagnosis' takes precedence, then 'progression'. Samples matching
        neither are left out of the per-task data. Each task's samples go
        through ``prepare_task_data``. The result is checkpointed to
        "step7_training_data.pkl" and visualised.

        Returns:
            {'diagnosis': ..., 'progression': ..., 'combined': dataset}, or a
            loaded non-None checkpoint unchanged.
        """
        print("\n" + "=" * 60)
        print("STEP 7: PREPARING TRAINING DATA")
        print("=" * 60)

        checkpoint_file = "step7_training_data.pkl"
        cached = self.load_checkpoint(checkpoint_file)
        if cached is not None:
            print("✓ Loaded training data from checkpoint")
            return cached

        print("Preparing training data...")

        per_task = {'diagnosis': [], 'progression': []}
        for sample in dataset:
            category = sample['category']
            # First matching task name wins, mirroring diagnosis-before-progression.
            task = next((name for name in ('diagnosis', 'progression') if name in category), None)
            if task is not None:
                per_task[task].append(sample)

        training_data = {
            'diagnosis': self.prepare_task_data(per_task['diagnosis'], 'diagnosis'),
            'progression': self.prepare_task_data(per_task['progression'], 'progression'),
            'combined': dataset,
        }

        self.save_checkpoint(training_data, checkpoint_file, "step7")
        self.visualize_training_data(training_data)

        return training_data

    def prepare_task_data(self, data: List[Dict[str, Any]], task: str):
        """Build a scaled 80/20 train/test split for one task.

        Args:
            data: samples for the task. Each needs 'acoustic',
                'linguistic'['bert_input_ids'], its label key and 'id'.
            task: 'diagnosis' uses 'diagnosis_label'; any other value uses
                'progression_label'.

        Returns:
            None for empty ``data``. Otherwise a dict with:
            - StandardScaler-scaled acoustic train/test matrices (the scaler
              is fit on train only);
            - linguistic train/test arrays;
            - labels and the fitted 'scaler';
            - 'sample_ids' (all ids, input order);
            - 'train_ids' / 'test_ids' (ids aligned row-for-row with the
              split arrays).
        """
        if not data:
            return None

        test_size = 0.2

        X_acoustic = np.array([sample['acoustic'] for sample in data])
        X_linguistic = np.array([sample['linguistic']['bert_input_ids'] for sample in data])
        label_key = 'diagnosis_label' if task == 'diagnosis' else 'progression_label'
        y = np.array([sample[label_key] for sample in data])
        sample_ids = [sample['id'] for sample in data]

        # train_test_split(stratify=y) raises ValueError when some class has
        # fewer than 2 members, or when either split would hold fewer rows
        # than there are classes. This is easy to hit on small capped runs.
        # Fall back to an unstratified split instead of aborting the pipeline.
        # The split sizes mirror scikit-learn's ceil/floor rule.
        classes, class_counts = np.unique(y, return_counts=True)
        n_test = int(np.ceil(test_size * len(y)))
        n_train = int(np.floor((1.0 - test_size) * len(y)))
        can_stratify = class_counts.min() >= 2 and min(n_train, n_test) >= len(classes)
        if not can_stratify:
            print(f"⚠ Too few samples per class to stratify the {task} split; using an unstratified split")

        # Splitting the ids alongside the arrays uses the same permutation, so
        # the feature/label split is unchanged and ids stay aligned with rows.
        (X_acoustic_train, X_acoustic_test, X_ling_train, X_ling_test,
         y_train, y_test, train_ids, test_ids) = train_test_split(
            X_acoustic, X_linguistic, y, sample_ids,
            test_size=test_size, random_state=42,
            stratify=y if can_stratify else None,
        )

        # Fit the scaler on training rows only to avoid test-set leakage.
        scaler = StandardScaler()
        X_acoustic_train_scaled = scaler.fit_transform(X_acoustic_train)
        X_acoustic_test_scaled = scaler.transform(X_acoustic_test)

        return {
            'X_acoustic_train': X_acoustic_train_scaled,
            'X_acoustic_test': X_acoustic_test_scaled,
            'X_linguistic_train': X_ling_train,
            'X_linguistic_test': X_ling_test,
            'y_train': y_train,
            'y_test': y_test,
            'scaler': scaler,
            'sample_ids': sample_ids,
            'train_ids': list(train_ids),
            'test_ids': list(test_ids),
        }

    def visualize_training_data(self, training_data: Dict[str, Any]):
        """Plot per-task class balance, two diagnosis features and task sizes.

        Panels:
        - Top row: train vs. test class counts per task. The panel position
          follows the task's order in ``training_data``.
        - Bottom-left: histograms of the first (up to) two scaled acoustic
          features of the diagnosis training split.
        - Bottom-right: total samples per task.

        The figure is saved to
        ``<output_path>/visualizations/training_data.png`` and shown.

        Args:
            training_data: output of step 7. Task entries may be None, and
                'combined' is skipped.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Training Data Visualization', fontsize=16, fontweight='bold')

        for i, (task, data) in enumerate(training_data.items()):
            if task == 'combined' or data is None:
                continue

            ax = axes[i // 2, i % 2]

            # value_counts() orders by frequency, and a class may be absent
            # from the (small) test split. Align both series on one sorted
            # class list so bars match their tick labels and both bar() calls
            # get the same number of heights.
            train_dist = pd.Series(data['y_train']).value_counts()
            test_dist = pd.Series(data['y_test']).value_counts()
            classes = sorted(set(train_dist.index) | set(test_dist.index))
            train_counts = train_dist.reindex(classes, fill_value=0)
            test_counts = test_dist.reindex(classes, fill_value=0)

            x = np.arange(len(classes))
            width = 0.35

            ax.bar(x - width/2, train_counts.values, width, label='Train', alpha=0.8)
            ax.bar(x + width/2, test_counts.values, width, label='Test', alpha=0.8)

            ax.set_title(f'{task.title()} Task - Class Distribution')
            ax.set_xlabel('Class')
            ax.set_ylabel('Count')
            ax.set_xticks(x)
            ax.set_xticklabels(classes)
            ax.legend()

        # Feature distribution (only as many columns as actually exist, up to two)
        diagnosis = training_data.get('diagnosis')
        if diagnosis is not None:
            diagnosis_features = diagnosis['X_acoustic_train']
            for col in range(min(2, diagnosis_features.shape[1])):
                axes[1, 0].hist(diagnosis_features[:, col], bins=30, alpha=0.7, label=f'Feature {col + 1}')
            axes[1, 0].set_title('Diagnosis: Feature Distribution')
            axes[1, 0].set_xlabel('Feature Value')
            axes[1, 0].set_ylabel('Frequency')
            axes[1, 0].legend()

        # Dataset size comparison
        sizes = []
        labels = []
        for task, data in training_data.items():
            if task != 'combined' and data is not None:
                sizes.append(len(data['y_train']) + len(data['y_test']))
                labels.append(task.title())

        if sizes:
            axes[1, 1].bar(labels, sizes, color=['skyblue', 'lightgreen'])
            axes[1, 1].set_title('Dataset Sizes by Task')
            axes[1, 1].set_ylabel('Number of Samples')

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/training_data.png", dpi=300, bbox_inches='tight')
        plt.show()

    def step_8_train_models(self, training_data: Dict[str, Any]):
        """Step 8: Train machine learning models.

        For every task entry in ``training_data`` other than 'combined' and
        not None, trains an acoustic-only neural network and a multimodal
        fusion model. The results are checkpointed to
        "step8_trained_models.pkl" and visualised.

        Returns:
            {task: {'neural_network': ..., 'multimodal_fusion': ...}}, or a
            loaded non-None checkpoint unchanged.
        """
        print("\n" + "=" * 60)
        print("STEP 8: TRAINING MODELS")
        print("=" * 60)

        checkpoint_file = "step8_trained_models.pkl"
        cached_models = self.load_checkpoint(checkpoint_file)
        if cached_models is not None:
            print("✓ Loaded trained models from checkpoint")
            return cached_models

        trained_models = {}
        runnable_tasks = (
            (name, task_data)
            for name, task_data in training_data.items()
            if name != 'combined' and task_data is not None
        )
        for task_name, task_data in runnable_tasks:
            print(f"\nTraining models for {task_name} task...")
            # Dict values are evaluated in order: the plain network first,
            # then the fusion model.
            trained_models[task_name] = {
                'neural_network': self.train_neural_network(task_data, task_name),
                'multimodal_fusion': self.train_multimodal_fusion(task_data, task_name),
            }

        self.save_checkpoint(trained_models, checkpoint_file, "step8")
        self.visualize_model_performance(trained_models)

        return trained_models

    def train_neural_network(self, data: Dict[str, Any], task_name: str):
        """Train and evaluate the acoustic-only feed-forward classifier.

        Training is full-batch: 100 Adam steps (lr=0.001) of cross-entropy
        over the entire scaled acoustic training matrix, with no
        mini-batching or early stopping. The model is then evaluated once on
        the test split.

        Returns:
            dict with the trained 'model', test 'accuracy', test
            'predictions' and 'true_labels' (numpy arrays).
        """
        print(f"  Training Neural Network for {task_name}...")

        model = self.create_neural_network(data['X_acoustic_train'].shape[1])

        def to_tensor(key, dtype):
            return torch.tensor(data[key], dtype=dtype)

        X_train = to_tensor('X_acoustic_train', torch.float32)
        y_train = to_tensor('y_train', torch.long)
        X_test = to_tensor('X_acoustic_test', torch.float32)
        y_test = to_tensor('y_test', torch.long)

        loss_fn = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=0.001)

        model.train()
        for _ in range(100):
            optimizer.zero_grad()
            loss_fn(model(X_train), y_train).backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            predictions = model(X_test).argmax(dim=1)

        y_true = y_test.numpy()
        y_pred = predictions.numpy()
        return {
            'model': model,
            'accuracy': accuracy_score(y_true, y_pred),
            'predictions': y_pred,
            'true_labels': y_true,
        }

    def create_neural_network(self, input_dim: int):
        """Build the acoustic MLP: input -> 256 -> 128 -> 64 -> 2 logits.

        Every hidden Linear layer is followed by ReLU. The first two hidden
        layers also get Dropout(0.3). The final Linear emits two logits for
        binary classification.
        """
        widths = [input_dim, 256, 128, 64]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
            if depth < 2:  # no dropout before the output head
                layers.append(nn.Dropout(0.3))
        layers.append(nn.Linear(64, 2))
        return nn.Sequential(*layers)

    def train_multimodal_fusion(self, data: Dict[str, Any], task_name: str):
        """Train and evaluate the acoustic + linguistic fusion classifier.

        Training is full-batch: 100 Adam steps (lr=0.001) of cross-entropy
        over both training matrices. The model is then evaluated once on the
        test split.

        Returns:
            dict with the trained 'model', test 'accuracy', test
            'predictions' and 'true_labels' (numpy arrays).
        """
        print(f"  Training Multimodal Fusion for {task_name}...")

        model = self.create_multimodal_fusion_model(
            data['X_acoustic_train'].shape[1],
            data['X_linguistic_train'].shape[1],
        )

        # NOTE(review): the linguistic matrices are built from
        # 'bert_input_ids' in prepare_task_data, so integer token ids are fed
        # to a Linear layer as raw float magnitudes. An embedding or a BERT
        # pooled output is probably what was intended; confirm.
        train_inputs = (
            torch.tensor(data['X_acoustic_train'], dtype=torch.float32),
            torch.tensor(data['X_linguistic_train'], dtype=torch.float32),
        )
        y_train = torch.tensor(data['y_train'], dtype=torch.long)

        test_inputs = (
            torch.tensor(data['X_acoustic_test'], dtype=torch.float32),
            torch.tensor(data['X_linguistic_test'], dtype=torch.float32),
        )
        y_test = torch.tensor(data['y_test'], dtype=torch.long)

        loss_fn = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=0.001)

        model.train()
        for _ in range(100):
            optimizer.zero_grad()
            loss_fn(model(*train_inputs), y_train).backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            predictions = model(*test_inputs).argmax(dim=1)

        y_true = y_test.numpy()
        y_pred = predictions.numpy()
        return {
            'model': model,
            'accuracy': accuracy_score(y_true, y_pred),
            'predictions': y_pred,
            'true_labels': y_true,
        }

    def create_multimodal_fusion_model(self, acoustic_dim: int, linguistic_dim: int):
        """Build a two-branch late-fusion classifier.

        Each modality passes through Linear(d, 256) -> ReLU -> Dropout(0.3)
        -> Linear(256, 128). The two 128-d outputs are concatenated and fed
        to a head: 256 -> 128 -> 64 -> 2 logits.

        NOTE(review): the model class is defined locally inside this method,
        and standard pickle cannot serialise instances of locally defined
        classes. Pickling trained models (e.g. a .pkl checkpoint of step 8)
        may therefore fail. Consider moving the class to module level;
        confirm how save_checkpoint serialises.
        """

        def projection_head(in_features):
            # Per-modality encoder ending in a 128-d embedding.
            return nn.Sequential(
                nn.Linear(in_features, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, 128),
            )

        class MultimodalFusionModel(nn.Module):
            def __init__(self, acoustic_dim, linguistic_dim):
                super().__init__()
                self.acoustic_branch = projection_head(acoustic_dim)
                self.linguistic_branch = projection_head(linguistic_dim)
                self.fusion = nn.Sequential(
                    nn.Linear(2 * 128, 128),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 2),
                )

            def forward(self, acoustic, linguistic):
                fused = torch.cat(
                    (self.acoustic_branch(acoustic), self.linguistic_branch(linguistic)),
                    dim=1,
                )
                return self.fusion(fused)

        return MultimodalFusionModel(acoustic_dim, linguistic_dim)

    def visualize_model_performance(self, trained_models: Dict[str, Any]):
        """Plot accuracy comparisons and the best model's confusion matrix.

        Args:
            trained_models: ``{task: {model_name: info}}`` as produced by
                step 8, where each ``info`` has 'accuracy', 'predictions' and
                'true_labels'.

        Panels:
        - accuracy per task/model;
        - confusion matrix of the most accurate model (first one on ties);
        - mean accuracy per model;
        - mean accuracy per task.

        Saved to ``<output_path>/visualizations/model_performance.png``.
        Returns without plotting when there are no models.
        """
        if not trained_models:
            return

        # Flatten to (task, model_name, info) rows in insertion order. Every
        # panel below derives its ordering from this list, so the plots are
        # deterministic; list(set(...)) ordered bars by string hash.
        rows = [
            (task_name, model_name, model_info)
            for task_name, task_models in trained_models.items()
            for model_name, model_info in task_models.items()
        ]
        if not rows:
            return

        performance_df = pd.DataFrame({
            'Task': [task for task, _, _ in rows],
            'Model': [model for _, model, _ in rows],
            'Accuracy': [info['accuracy'] for _, _, info in rows],
        })

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Model Performance Visualization', fontsize=16, fontweight='bold')

        # Accuracy comparison
        pivot_df = performance_df.pivot(index='Task', columns='Model', values='Accuracy')
        pivot_df.plot(kind='bar', ax=axes[0, 0])
        axes[0, 0].set_title('Model Accuracy Comparison')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].set_xlabel('Task')
        axes[0, 0].legend(title='Model')
        axes[0, 0].tick_params(axis='x', rotation=45)

        # Confusion matrix for the best model. max() keeps the first maximal
        # row, i.e. the first task containing the top accuracy and the first
        # such model within it. Tasks with no models are simply absent
        # instead of crashing max() on an empty sequence.
        best_task, best_model_name, best_model = max(rows, key=lambda row: row[2]['accuracy'])
        # Both model heads emit two logits, so pin the label set to keep the
        # matrix 2x2 even when one class is missing from a small test split.
        cm = confusion_matrix(best_model['true_labels'], best_model['predictions'], labels=[0, 1])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0, 1])
        axes[0, 1].set_title(f'Confusion Matrix - {best_task} ({best_model_name})')
        axes[0, 1].set_ylabel('True Label')
        axes[0, 1].set_xlabel('Predicted Label')

        # Model comparison (first-appearance order)
        model_means = performance_df.groupby('Model', sort=False)['Accuracy'].mean()
        axes[1, 0].bar(model_means.index.tolist(), model_means.values, color=['skyblue', 'lightgreen'])
        axes[1, 0].set_title('Average Model Performance')
        axes[1, 0].set_ylabel('Average Accuracy')
        axes[1, 0].set_xlabel('Model Type')
        axes[1, 0].tick_params(axis='x', rotation=45)

        # Task difficulty comparison (first-appearance order)
        task_means = performance_df.groupby('Task', sort=False)['Accuracy'].mean()
        axes[1, 1].bar(task_means.index.tolist(), task_means.values, color=['coral', 'lightblue'])
        axes[1, 1].set_title('Task Difficulty Comparison')
        axes[1, 1].set_ylabel('Average Accuracy')
        axes[1, 1].set_xlabel('Task')
        axes[1, 1].tick_params(axis='x', rotation=45)

        plt.tight_layout()
        plt.savefig(f"{self.output_path}/visualizations/model_performance.png", dpi=300, bbox_inches='tight')
        plt.show()

    def run_extended_pipeline(self, limit_per_category: int = None):
        """Run the basic pipeline (steps 1-4) followed by extended steps 5-8.

        Previously this method printed the "RUNNING EXTENDED STEPS" banner
        and returned None without running anything further. It now runs:
        - step 5: multimodal dataset;
        - step 6: graph structures;
        - step 7: train/test preparation;
        - step 8: model training.

        Args:
            limit_per_category: forwarded to ``run_complete_pipeline``.

        Returns:
            dict containing every key from ``run_complete_pipeline``, plus
            'multimodal_dataset', 'graph_structures', 'training_data' and
            'trained_models'.
        """
        print("="*80)
        print("ADRESSO21 EXTENDED ANALYSIS PIPELINE")
        print("="*80)

        # Run basic pipeline
        basic_results = self.run_complete_pipeline(limit_per_category)

        # Extended steps
        print("\n" + "="*60)
        print("RUNNING EXTENDED STEPS")
        print("="*60)

        results = dict(basic_results or {})

        dataset = self.step_5_create_multimodal_dataset()
        results['multimodal_dataset'] = dataset

        results['graph_structures'] = self.step_6_create_graph_structures(dataset)

        training_data = self.step_7_prepare_training_data(dataset)
        results['training_data'] = training_data

        results['trained_models'] = self.step_8_train_models(training_data)

        print("\n" + "="*80)
        print("EXTENDED PIPELINE COMPLETED SUCCESSFULLY!")
        print("="*80)

        return results


IndentationError: unindent does not match any outer indentation level (<tokenize>, line 28)