<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Alz_voice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Complete ADReSSo Multi-Modal Analysis Pipeline Steps

## Project Overview
This project implements a comprehensive multi-modal machine learning pipeline for Alzheimer's Dementia Recognition through Spontaneous Speech (ADReSSo). The system combines audio processing, text analysis, and advanced deep learning architectures to classify speech samples as either cognitively normal (CN) or showing signs of Alzheimer's disease (AD).

## Pipeline Architecture Components

### Core Models Used:
- **Graph-based Attention Module**: For semantic relationship modeling
- **Vision Transformer (ViT)**: For spectrogram analysis
- **U-Net**: For audio feature processing
- **AlexNet**: For additional feature extraction
- **BERT**: For text processing and linguistic analysis
- **Wav2Vec2**: For audio feature extraction

---

## Step-by-Step Pipeline Process

### Step 0: Environment Setup and Data Preparation
**Purpose**: Initialize the analysis environment and prepare the dataset

**Actions**:
- Mount Google Drive
- Install required packages: `librosa`, `soundfile`, `opensmile`, `speechbrain`, `transformers`, `torch`, `openai-whisper`, `pandas`, `numpy`, `matplotlib`, `seaborn`, `torch-geometric`
- Set up base directory structure
- Initialize output directories for results

**Key Files**:
- Audio files organized by categories (diagnosis_ad, diagnosis_cn, progression_decline, progression_no_decline)
- Configuration files and model checkpoints

### Step 1: Audio File Discovery and Organization
**Purpose**: Scan and categorize all available audio files

**Actions**:
- Recursively search for audio files (.wav, .mp3, .m4a, .flac)
- Categorize files based on directory structure:
  - `diagnosis_ad/`: Alzheimer's disease diagnosis files
  - `diagnosis_cn/`: Cognitively normal diagnosis files
  - `progression_decline/`: Disease progression (decline) files
  - `progression_no_decline/`: Disease progression (stable) files
- Generate file inventory and statistics

**Output**: Dictionary of categorized audio file paths
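Here is a minimal sketch of the discovery step. It assumes, as the directory list above suggests, that a file's category is the name of the directory that directly contains it. `discover_audio_files` is an illustrative name, not the analyzer's own `get_audio_files()`.

```python
from pathlib import Path
from typing import Dict, List

# Extensions searched for, matching the list above
AUDIO_EXTENSIONS = {".wav", ".mp3", ".m4a", ".flac"}

def discover_audio_files(base_path: str) -> Dict[str, List[str]]:
    """Recursively collect audio files, grouped by their parent directory name."""
    categorized: Dict[str, List[str]] = {}
    for path in sorted(Path(base_path).rglob("*")):
        # Case-insensitive extension check so ".WAV" is also picked up
        if path.is_file() and path.suffix.lower() in AUDIO_EXTENSIONS:
            categorized.setdefault(path.parent.name, []).append(str(path))
    return categorized
```

Printing `{k: len(v) for k, v in discover_audio_files(base).items()}` gives the per-category inventory.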

### Step 2: Audio Processing and Feature Extraction
**Purpose**: Extract comprehensive acoustic features from audio files

**Feature Types Extracted**:
- **Wav2Vec2 Features**: Deep learning-based audio representations
- **Mel-frequency Cepstral Coefficients (MFCCs)**: Traditional audio features
- **Spectral Features**: Spectral centroid, bandwidth, rolloff
- **Prosodic Features**: Pitch, energy, rhythm patterns
- **OpenSMILE Features**: Comprehensive acoustic feature set

**Processing Steps**:
- Load and preprocess audio files
- Extract multi-dimensional feature vectors
- Apply feature normalization and scaling
- Generate mel-spectrograms for visual analysis
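Normalization statistics should come from the training split only. Otherwise validation and test data leak into preprocessing. Below is a minimal NumPy sketch. The class name is hypothetical, and scikit-learn's `StandardScaler` does the same job.

```python
import numpy as np

class FeatureStandardizer:
    """Z-score scaling whose mean/std are estimated on training data only."""

    def __init__(self, eps: float = 1e-8):
        self.eps = eps
        self.mean_ = None
        self.std_ = None

    def fit(self, X) -> "FeatureStandardizer":
        X = np.asarray(X, dtype=np.float64)
        self.mean_ = X.mean(axis=0)
        self.std_ = X.std(axis=0)
        return self

    def transform(self, X) -> np.ndarray:
        if self.mean_ is None:
            raise RuntimeError("call fit() before transform()")
        # eps keeps constant features (std == 0) from dividing by zero
        return (np.asarray(X, dtype=np.float64) - self.mean_) / (self.std_ + self.eps)
```

Fit once on the training features, then apply the same object to the validation and test features.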

### Step 3: Speech-to-Text Conversion
**Purpose**: Convert audio to text for linguistic analysis

**Tools Used**:
- **OpenAI Whisper**: For high-quality speech transcription
- **Alternative ASR systems**: Fallback options for different audio qualities

**Processing**:
- Transcribe all audio files to text
- Handle different audio qualities and accents
- Store transcriptions with confidence scores
- Generate metadata for each transcription
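Whisper does not return a single confidence score. In `openai-whisper`, `model.transcribe()` returns a dict with `text`, `segments`, and `language`, and each segment carries an `avg_logprob`. One heuristic, sketched below under that assumption, is the exponential of the duration-weighted mean log-probability. This works out to a geometric-mean token probability. The helper name is hypothetical.

```python
import math
from typing import Any, Dict

def transcript_confidence(result: Dict[str, Any]) -> float:
    """Heuristic confidence in [0, 1] from a Whisper-style transcription result."""
    segments = result.get("segments", [])
    total = sum(seg["end"] - seg["start"] for seg in segments)
    if not segments or total <= 0:
        return 0.0
    # Duration-weighted mean of the per-segment average token log-probabilities
    weighted = sum(seg["avg_logprob"] * (seg["end"] - seg["start"]) for seg in segments) / total
    return math.exp(weighted)
```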

### Step 4: Linguistic Feature Analysis
**Purpose**: Extract detailed linguistic and semantic features from transcribed text

**Feature Categories**:
- **Semantic Features**: Word embeddings, semantic density
- **Syntactic Features**: Part-of-speech patterns, sentence structure
- **Lexical Features**: Vocabulary diversity, word frequency
- **Discourse Features**: Coherence, topic transitions
- **Fluency Measures**: Pause patterns, disfluencies

**Processing**:
- Use BERT for semantic embeddings
- Apply NLP tools for syntactic analysis
- Calculate linguistic complexity metrics
- Extract discourse markers and patterns
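The simpler lexical metrics need no model. The sketch below uses a regex tokenizer and defines lexical diversity as the type-token ratio (distinct words divided by total words). The function name and both definitions are illustrative choices and may differ from what the notebook uses.

```python
import re
from typing import Dict

def lexical_features(text: str) -> Dict[str, float]:
    """Basic lexical metrics for one transcript."""
    words = re.findall(r"[a-z']+", text.lower())
    # Naive sentence split on terminal punctuation; empty fragments are dropped
    sentences = [s for s in re.split(r"[.!?]+", text) if s.strip()]
    n_words = len(words)
    n_unique = len(set(words))
    return {
        "word_count": n_words,
        "sentence_count": len(sentences),
        "avg_word_length": sum(map(len, words)) / n_words if n_words else 0.0,
        "unique_words": n_unique,
        # Type-token ratio: distinct words divided by total words
        "lexical_diversity": n_unique / n_words if n_words else 0.0,
    }
```

Type-token ratio falls as transcripts get longer. Compare it only across transcripts of similar length, or switch to a length-corrected measure.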

### Step 5: Comprehensive Analysis and Visualization
**Purpose**: Analyze extracted features and generate comprehensive reports

**Analysis Types**:
- **Statistical Analysis**: Feature distributions, correlations
- **Visualization**: Feature plots, spectrograms, linguistic patterns
- **Comparative Analysis**: AD vs CN differences
- **Quality Assessment**: Data quality metrics

**Outputs**:
- Feature correlation matrices
- Statistical summary reports
- Visualization plots and charts
- Data quality assessments
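One way to rank AD vs CN differences is a per-feature effect size. The sketch below computes Cohen's d with a pooled standard deviation. It assumes `ad` and `cn` are (samples, features) arrays. Features with zero variance in both groups produce inf or NaN.

```python
import numpy as np

def cohens_d(ad, cn) -> np.ndarray:
    """Per-feature standardized mean difference (AD minus CN), pooled SD."""
    ad = np.asarray(ad, dtype=np.float64)
    cn = np.asarray(cn, dtype=np.float64)
    n1, n2 = len(ad), len(cn)
    # Unbiased (ddof=1) variances, pooled across the two groups
    var1 = ad.var(axis=0, ddof=1)
    var2 = cn.var(axis=0, ddof=1)
    pooled_sd = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2))
    return (ad.mean(axis=0) - cn.mean(axis=0)) / pooled_sd
```

Sorting features by `np.abs(cohens_d(ad, cn))` produces a simple ranking of discriminative features.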

### Step 6: Multi-Modal Model Architecture Definition
**Purpose**: Define the complex neural network architecture for classification

**Architecture Components**:
```
MultiModalADReSSoModel:
├── Graph Attention Module
│   ├── Semantic graph construction
│   └── Graph attention networks
├── Vision Transformer Module
│   ├── Spectrogram patch embedding
│   └── Transformer encoder layers
├── U-Net Module
│   ├── Audio signal processing
│   └── Feature extraction layers
├── AlexNet Module
│   ├── Convolutional feature extraction
│   └── Classification layers
├── BERT Module
│   ├── Text embedding
│   └── Linguistic feature extraction
└── Fusion Layer
    ├── Multi-modal feature fusion
    └── Final classification
```

**Key Parameters**:
- Audio feature dimension: 768 (Wav2Vec2)
- Text feature dimension: 768 (BERT)
- Spectrogram height: 80 (Mel bins)
- Number of classes: 2 (AD vs CN)
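With 80 mel bins, the Vision Transformer's patch embedding begins by cutting the spectrogram into fixed-size patches. The NumPy sketch below assumes 16 x 16 patches, the ViT-Base/16 default; the document does not state the patch size. Under that assumption an 80 x 160 spectrogram yields 5 x 10 = 50 patches of 256 values each. A learned linear layer then projects each patch to the transformer width.

```python
import numpy as np

def spectrogram_to_patches(spec: np.ndarray, patch: int = 16) -> np.ndarray:
    """Split an (n_mels, n_frames) spectrogram into flattened, non-overlapping patches."""
    n_mels, n_frames = spec.shape
    h, w = n_mels // patch, n_frames // patch
    # Drop trailing bins/frames that do not fill a whole patch
    spec = spec[: h * patch, : w * patch]
    # (h, patch, w, patch) -> (h, w, patch, patch): group values by patch position
    patches = spec.reshape(h, patch, w, patch).transpose(0, 2, 1, 3)
    return patches.reshape(h * w, patch * patch)
```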

### Step 7: Model Training
**Purpose**: Train the multi-modal model on the processed dataset

**Training Configuration**:
- **Batch Size**: 8 (adjustable based on GPU memory)
- **Epochs**: 30 (with early stopping)
- **Learning Rate**: Adaptive with scheduler
- **Optimization**: Adam optimizer
- **Loss Function**: Cross-entropy loss

**Data Splitting**:
- Training: 64% of data (80% of the non-test portion)
- Validation: 16% of data
- Testing: 20% of data
- Stratified splitting to maintain class balance
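The training code in this notebook splits the data with `torch.utils.data.random_split`, which is not stratified. Below is a dependency-free sketch of a stratified split. The helper is hypothetical; calling scikit-learn's `train_test_split(..., stratify=labels)` twice is the usual alternative. The defaults mirror the 64/16/20 proportions used in that training code, and the returned indices can be wrapped with `torch.utils.data.Subset`.

```python
import numpy as np
from typing import Sequence, Tuple

def stratified_split(labels: Sequence[int], train_frac: float = 0.64, val_frac: float = 0.16,
                     seed: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return train/val/test index arrays that preserve each class's proportion."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train, val, test = [], [], []
    for cls in np.unique(labels):
        # Shuffle this class's indices, then cut them by the requested fractions
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_train = int(round(train_frac * len(idx)))
        n_val = int(round(val_frac * len(idx)))
        train.extend(idx[:n_train])
        val.extend(idx[n_train:n_train + n_val])
        test.extend(idx[n_train + n_val:])
    return np.array(train, dtype=int), np.array(val, dtype=int), np.array(test, dtype=int)
```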

**Training Process**:
- Initialize model with random weights
- Create data loaders for each split
- Implement training loop with validation
- Save best model based on validation performance
- Monitor training metrics and convergence
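Early stopping and best-model saving can share one bookkeeping object. Here is a minimal sketch. The class is hypothetical, and the `patience` and `min_delta` defaults are assumptions.

```python
from typing import Optional

class EarlyStopping:
    """Track validation loss; stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.bad_epochs = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch; return True when it is a new best (time to save weights)."""
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss, self.best_epoch, self.bad_epochs = val_loss, epoch, 0
            return True
        self.bad_epochs += 1
        return False

    @property
    def should_stop(self) -> bool:
        return self.bad_epochs >= self.patience
```

In a training loop, call `es.step(epoch, val_loss)` after each validation pass. Save `model.state_dict()` whenever it returns `True`, and `break` once `es.should_stop` is set.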

### Step 8: Model Evaluation and Semantic Analysis
**Purpose**: Evaluate model performance and analyze semantic relationships

**Evaluation Metrics**:
- **Accuracy**: Overall classification accuracy
- **Precision**: Positive predictive value, TP / (TP + FP)
- **Recall**: Sensitivity (true positive rate), TP / (TP + FN)
- **F1-Score**: Harmonic mean of precision and recall
- **ROC AUC**: Area under the ROC curve

**Semantic Analysis**:
- Visualize semantic relationships between audio and text features
- Generate semantic graphs showing feature correlations
- Analyze modality contributions to predictions
- Create interpretability visualizations

**Analysis Outputs**:
- Confusion matrices
- ROC curves
- Feature importance plots
- Semantic relationship graphs
- Detailed classification reports
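To make the metric definitions concrete, here is a dependency-free sketch for binary labels (1 = AD). It computes ROC AUC as the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counting one half. That probability equals the area under the ROC curve. In practice, `sklearn.metrics` (`precision_score`, `recall_score`, `f1_score`, `roc_auc_score`) covers the same ground. The 0.5 decision threshold is an assumption.

```python
import numpy as np
from typing import Dict

def binary_metrics(y_true, y_score, threshold: float = 0.5) -> Dict[str, float]:
    """Accuracy, precision, recall, F1 and ROC AUC for labels in {0, 1}."""
    y_true = np.asarray(y_true).astype(int)
    y_score = np.asarray(y_score, dtype=np.float64)
    y_pred = (y_score >= threshold).astype(int)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # Pairwise comparison of every positive score against every negative score
    pos, neg = y_score[y_true == 1], y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    auc = float(np.mean((diff > 0) + 0.5 * (diff == 0)))
    return {"accuracy": float(np.mean(y_pred == y_true)), "precision": precision,
            "recall": recall, "f1": f1, "roc_auc": auc}
```

The pairwise AUC costs O(positives x negatives), which is negligible for a few hundred recordings.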

### Step 9: Checkpointing and Incremental Processing
**Purpose**: Implement robust checkpointing system for large-scale processing

**Checkpointing Features**:
- **Incremental Processing**: Resume from last checkpoint
- **Individual Feature Saving**: Save features for each file separately
- **Progress Tracking**: Monitor processing status
- **Error Recovery**: Handle processing failures gracefully

**Checkpoint Structure**:
```
checkpoints/
├── checkpoint.pkl (main checkpoint file)
├── features/ (individual feature files)
│   ├── diagnosis_ad_file1_features.pkl
│   ├── diagnosis_cn_file1_features.pkl
│   └── ...
└── logs/ (processing logs)
```
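The incremental, per-file behavior described above can be sketched as follows. The helper name and the `{file_id}_features.pkl` naming are assumptions modeled on the tree above. Each pickle is first written to a temporary file and then renamed with `os.replace`, so an interrupted run never leaves a truncated pickle behind.

```python
import os
import pickle
import tempfile
from typing import Any, Callable, Dict, Iterable

def extract_with_checkpoints(file_ids: Iterable[str], extract: Callable[[str], Any],
                             feature_dir: str) -> Dict[str, Any]:
    """Extract features per file, reusing any feature pickle that already exists."""
    os.makedirs(feature_dir, exist_ok=True)
    results: Dict[str, Any] = {}
    for file_id in file_ids:
        path = os.path.join(feature_dir, f"{file_id}_features.pkl")
        if os.path.exists(path):  # resume: this file was finished in an earlier run
            with open(path, "rb") as f:
                results[file_id] = pickle.load(f)
            continue
        try:
            features = extract(file_id)
        except Exception as exc:  # error recovery: report and move on
            print(f"Skipping {file_id}: {exc}")
            continue
        # Atomic write: temp file in the same directory, then rename over the target
        fd, tmp_path = tempfile.mkstemp(dir=feature_dir, suffix=".tmp")
        with os.fdopen(fd, "wb") as f:
            pickle.dump(features, f)
        os.replace(tmp_path, path)
        results[file_id] = features
    return results
```

Failed files are not saved, so the next run retries them automatically.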

---

## Advanced Features

### Semantic Graph Visualization
- Create NetworkX graphs showing relationships between audio and text features
- Visualize cosine similarity between modalities
- Generate interactive relationship plots
- Analyze semantic coherence between speech and content
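Here is a sketch of the cross-modal similarity step. It assumes the audio and text embeddings have already been brought into a comparable space of equal dimension, for example by the model's projection layers. Raw Wav2Vec2 and BERT vectors come from unrelated spaces even though both are 768-dimensional. The threshold and node names are illustrative. The returned edges can be loaded with NetworkX's `Graph.add_weighted_edges_from`.

```python
import numpy as np
from typing import List, Tuple

def cross_modal_edges(audio: np.ndarray, text: np.ndarray,
                      threshold: float = 0.5) -> Tuple[np.ndarray, List[Tuple[str, str, float]]]:
    """Cosine-similarity matrix between audio and text embeddings, plus thresholded edges."""
    # L2-normalize rows so that a dot product equals cosine similarity
    a = audio / np.linalg.norm(audio, axis=1, keepdims=True)
    t = text / np.linalg.norm(text, axis=1, keepdims=True)
    sim = a @ t.T  # sim[i, j] = cos(audio_i, text_j)
    edges = [(f"audio_{i}", f"text_{j}", float(sim[i, j]))
             for i, j in zip(*np.nonzero(sim >= threshold))]
    return sim, edges
```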

### Feature Importance Analysis
- Calculate contribution of each modality to final predictions
- Analyze which features are most discriminative
- Generate feature importance rankings
- Create modality-specific performance metrics
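The evaluation plan estimates each module's contribution from the magnitude of its output. The sketch below implements that proxy. It assumes the per-module outputs were already collected, for example with `torch.nn.Module.register_forward_hook`, and converted to NumPy arrays. The helper name is hypothetical. Output norms depend on each module's scaling, so treat the shares as a rough indicator; ablating one modality at a time is a more direct test.

```python
import numpy as np
from typing import Dict

def modality_contributions(outputs: Dict[str, np.ndarray]) -> Dict[str, float]:
    """Share of the summed mean L2 norm contributed by each module's output batch."""
    # Mean over the batch of each sample's L2 norm (last axis = feature dimension)
    norms = {name: float(np.linalg.norm(np.asarray(out), axis=-1).mean())
             for name, out in outputs.items()}
    total = sum(norms.values())
    return {name: n / total for name, n in norms.items()}
```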

### Comprehensive Reporting
- Generate detailed evaluation reports
- Create performance summaries by category
- Analyze misclassification patterns
- Provide confidence-based analysis

---

## Usage Instructions

### Basic Usage:
```python
# Initialize the extended analyzer
ExtendedAnalyzer = extend_analyzer_with_model()
analyzer = ExtendedAnalyzer(base_path="/path/to/ADReSSo21")

# Create checkpointer
checkpointer = FeatureExtractionCheckpointer(analyzer)

# Run complete pipeline
results = checkpointer.run_pipeline_with_checkpoints(
    num_epochs=30,
    batch_size=8
)
```

### Advanced Configuration:
```python
# Custom training parameters
results = checkpointer.run_pipeline_with_checkpoints(
    num_epochs=50,
    batch_size=4  # Reduce for limited GPU memory
)

# Individual steps
analyzer.step_6_define_model_architecture()
analyzer.step_7_train_model(features_dict, linguistic_features)
analyzer.step_8_evaluate_model(visualize_graphs=True)
```

---

## Output Files and Results

### Generated Files:
- `detailed_evaluation_results.csv`: Per-sample predictions and confidence scores
- `evaluation_summary.json`: Overall performance metrics
- `semantic_graph_*.png`: Semantic relationship visualizations
- `best_adresso_model.pth`: Trained model weights
- `checkpoint.pkl`: Processing checkpoint data
- Individual feature files for each audio sample

### Key Metrics Tracked:
- Overall classification accuracy
- Per-category performance (AD, CN, decline, stable)
- Confidence distributions
- Misclassification analysis
- Feature importance by modality
- Semantic relationship strengths

This comprehensive pipeline provides a complete solution for Alzheimer's dementia recognition through multi-modal analysis of spontaneous speech, combining state-of-the-art deep learning techniques with robust feature extraction and evaluation methodologies.

# Comprehensive Workflow for ADReSSo21 Speech Analysis Project

This workflow details the steps to process the ADReSSo21 dataset, extract features, build a multi-modal model, and evaluate its performance for Alzheimer’s Disease (AD) detection through speech analysis. It is structured into eight main phases, each with specific sub-steps to ensure a thorough and systematic approach.

## 1. Setup and Initialization

- **Install Dependencies**: Set up the environment by installing the libraries and tools needed for audio processing, machine learning, and data handling. This includes tools for audio feature extraction, transcription, and neural network modeling.
- **Access Data Storage**: Connect to the storage system (e.g., Google Drive) where the ADReSSo21 dataset is located, so the audio files and related resources are accessible.
- **Initialize Analyzer**: Create a central analysis class that manages the entire pipeline, extending the basic functionality with model training and evaluation.


## 2. Data Loading

- **Retrieve Audio Files**: Scan the dataset directory to collect the paths of all audio files, organizing them into five categories:
  - Diagnosis AD (Alzheimer’s Disease group)
  - Diagnosis CN (control group)
  - Progression Decline (subjects showing cognitive decline)
  - Progression No Decline (subjects with stable cognition)
  - Progression Test (test set for progression analysis)
- **Understand Dataset Composition**: The dataset contains 271 audio files, distributed as follows:
  - Diagnosis AD: 87 files
  - Diagnosis CN: 79 files
  - Progression Decline: 15 files
  - Progression No Decline: 58 files
  - Progression Test: 32 files




## 3. Feature Extraction

- **Extract Acoustic Features**: Process each audio file to derive a rich set of acoustic features:
  - **eGeMAPS**: The extended Geneva Minimalistic Acoustic Parameter Set, a standardized set of 88 frequency, energy, spectral, and temporal parameters designed for voice and affective-computing research.
  - **MFCCs**: Mel-frequency cepstral coefficients, including the mean, standard deviation, delta, and delta-delta values for 13 coefficients.
  - **Log-Mel Spectrograms**: Mean and standard deviation across 80 mel frequency bands.
  - **Wav2Vec2 Embeddings**: 768-dimensional representations from a pre-trained speech model.
  - **Prosodic Features**: Fundamental frequency (mean and std), energy (mean and std), zero-crossing rate, spectral centroid, spectral rolloff, and audio duration.
- **Manage Progress with Checkpointing**: Implement a system that:
  - Checks for previously extracted features and loads them to avoid redundant processing.
  - Extracts features for unprocessed files and saves them individually.
  - Updates progress tracking after each file, so processing can resume if it is interrupted.
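The delta and delta-delta MFCC values are local time derivatives. The NumPy sketch below implements the standard regression formula, using a window of 2 frames on each side with the edge frames repeated. It assumes an (n_coeffs, n_frames) MFCC matrix. `librosa.feature.delta` provides a comparable estimate via Savitzky-Golay filtering. Summarizing the deltas by their standard deviation is an assumption. With 13 coefficients the summary has 13 x 4 = 52 values.

```python
import numpy as np

def delta(features: np.ndarray, n: int = 2) -> np.ndarray:
    """d_t = sum_k k * (c[t+k] - c[t-k]) / (2 * sum_k k^2), for k = 1..n, along time."""
    n_frames = features.shape[1]
    padded = np.pad(features, ((0, 0), (n, n)), mode="edge")  # repeat edge frames
    numerator = sum(k * (padded[:, n + k:n + k + n_frames] - padded[:, n - k:n - k + n_frames])
                    for k in range(1, n + 1))
    return numerator / (2 * sum(k * k for k in range(1, n + 1)))

def mfcc_summary(mfcc: np.ndarray) -> np.ndarray:
    """MFCC mean and std, plus std of delta and delta-delta (assumed summary statistic)."""
    d1 = delta(mfcc)
    d2 = delta(d1)
    return np.concatenate([mfcc.mean(axis=1), mfcc.std(axis=1), d1.std(axis=1), d2.std(axis=1)])
```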




## 4. Transcription

- **Transcribe Audio Files**: Use a speech-to-text model (e.g., Whisper) to convert the audio into text. The content varies by category:
  - Diagnosis AD and CN: descriptions of a specific picture (e.g., Cookie Theft).
  - Progression Decline, No Decline, and Test: verbal fluency tasks (e.g., naming animals).
- **Save Transcripts**: Store the transcribed text in a designated directory for later use.
- **Create Summary Table**: Compile a table with metadata for each transcript, including:
  - File ID
  - Category
  - Filename
  - Transcript length
  - Word count
  - Language
  - Number of segments
  - Error status
  - Preview of the transcript
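Here is a sketch of one table row. It assumes each transcript is stored as a Whisper-style result dict (`text`, `language`, `segments`), and that a failed transcription carries an `error` key, which is hypothetical. The `Category` and `Word_Count` column names match those plotted in step 4 of the code; the other names are illustrative. A list of these rows converts directly with `pd.DataFrame(rows)`.

```python
import os
from typing import Any, Dict

def transcript_summary_row(file_id: str, category: str, audio_path: str,
                           result: Dict[str, Any], preview_chars: int = 80) -> Dict[str, Any]:
    """Build one summary-table row from a Whisper-style transcription result."""
    text = result.get("text", "").strip()
    return {
        "File_ID": file_id,
        "Category": category,
        "Filename": os.path.basename(audio_path),
        "Transcript_Length": len(text),           # characters
        "Word_Count": len(text.split()),          # whitespace-delimited tokens
        "Language": result.get("language", "unknown"),
        "Num_Segments": len(result.get("segments", [])),
        "Has_Error": "error" in result,           # assumed error marker
        "Preview": text[:preview_chars] + ("..." if len(text) > preview_chars else ""),
    }
```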




## 5. Linguistic Feature Extraction

- **Extract Features from Transcripts**: Process the saved transcripts with a pre-trained language model (e.g., BERT) to obtain:
  - Token embeddings representing the text’s semantic content.
  - Additional metrics: word count, sentence count, average word length, unique words, and lexical diversity.
- **Store Features**: Save the extracted linguistic features in a structured format (e.g., a pickle file) for use by the model.


## 6. Model Architecture Definition

- **Design Multi-Modal Model**: Build a model that combines multiple data types and processing techniques:
  - **Graph Attention Module**: Models the relationships between audio and text features with a graph-based approach.
  - **Vision Transformer Module**: Processes spectrogram data with a transformer architecture.
  - **U-Net Module**: Processes audio features with a convolutional network designed for detailed feature extraction.
  - **AlexNet Module**: Extracts features from audio inputs with a modified deep convolutional network.
  - **BERT Module**: Processes linguistic features from the transcripts.
  - **Fusion and Classification Layers**: Combine the outputs of all modules and classify each sample as AD or CN.
- **Initialize Model**: Set the input dimensions for the audio and text features and specify two output classes.
- **Prepare Training Manager**: Set up a training system that oversees the model’s learning and evaluation phases.
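The document does not show how the fusion layer combines the module outputs. One common option, sketched here as an assumption, is late fusion: concatenate each module's embedding, then apply a linear layer and a softmax over the two classes. The module widths in any concrete use are hypothetical. In PyTorch the same step is `torch.cat` followed by `nn.Linear`.

```python
import numpy as np
from typing import Sequence

def fuse_and_classify(module_outputs: Sequence[np.ndarray], weight: np.ndarray,
                      bias: np.ndarray) -> np.ndarray:
    """Late fusion: concatenate per-module embeddings, then linear layer + softmax."""
    fused = np.concatenate(module_outputs, axis=1)       # (batch, sum of module widths)
    logits = fused @ weight + bias                       # (batch, n_classes)
    logits = logits - logits.max(axis=1, keepdims=True)  # subtract max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)          # class probabilities per sample
```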


## 7. Model Training

- **Assign Labels**: Map the categories to binary labels:
  - **AD/Decline = 1**: Diagnosis AD and Progression Decline
  - **CN/No Decline = 0**: Diagnosis CN, Progression No Decline, and Progression Test
- **Split Dataset**: Divide the dataset into three subsets, keeping the labels balanced across the splits:
  - Training: 64%
  - Validation: 16%
  - Testing: 20%
- **Prepare Data for Training**: Organize the data into dataset objects and configure batching for efficient processing.
- **Train the Model**: Train for a fixed number of epochs (e.g., 5):
  - Track loss and accuracy on the training and validation sets.
  - Save the model with the best validation performance.
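Applying the label mapping to the file counts from the Data Loading phase shows the class balance. The sketch uses the category keys from the training code, with `progression_test` as an assumed key for the test set. Under this mapping, all 32 progression-test files fall into class 0. The inverse-frequency weights are one optional way to counter the imbalance, for example via the `weight` argument of `torch.nn.CrossEntropyLoss`.

```python
from collections import Counter

# File counts per category, as listed in the Data Loading phase
# ("progression_test" is an assumed key name for the progression test set)
CATEGORY_COUNTS = {"diagnosis_ad": 87, "diagnosis_cn": 79, "progression_decline": 15,
                   "progression_no_decline": 58, "progression_test": 32}
POSITIVE = {"diagnosis_ad", "progression_decline"}

def label_for(category: str) -> int:
    """1 for AD/decline, 0 for every other category (same rule as the training code)."""
    return 1 if category in POSITIVE else 0

balance = Counter()
for category, n in CATEGORY_COUNTS.items():
    balance[label_for(category)] += n

# Inverse-frequency class weights: total / (n_classes * class_count)
total = sum(balance.values())
class_weights = {c: total / (len(balance) * n) for c, n in balance.items()}
```

This gives 102 files in class 1 and 169 in class 0, or about 38% positive. The imbalance is why stratified splitting, F1, and ROC AUC matter here.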




## 8. Model Evaluation and Analysis

- **Load Best Model**: Retrieve the top-performing model from the training phase.
- **Evaluate on Test Set**: Assess the model’s performance on the test data, calculating:
  - Accuracy
  - Precision
  - Recall
  - F1-score
  - ROC AUC
- **Analyze Semantic Relationships**: Examine a subset of samples to understand how audio and text features interact, using similarity measures such as cosine similarity.
- **Assess Feature Importance**: Determine the contribution of each model component (e.g., Graph Attention, Vision Transformer) by analyzing the magnitudes of its outputs.
- **Generate Evaluation Report**: Compile a detailed report including:
  - Overall performance metrics
  - Results broken down by category
  - Analysis of misclassified samples
  - Insights into high-confidence predictions
- **Save Results**: Store the detailed results and a summary in accessible formats (e.g., CSV and JSON).
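Here is a sketch of the per-category and misclassification parts of the report. It assumes each test prediction is a record with `file_id`, `category`, `label`, `pred`, and `confidence` keys; these names are hypothetical. The 0.9 high-confidence cutoff is also an assumption. The same records can be written out with `pd.DataFrame(records).to_csv(...)`.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def evaluation_breakdown(records: List[Dict], high_conf: float = 0.9
                         ) -> Tuple[Dict[str, float], List[str], List[str]]:
    """Per-category accuracy, misclassified file IDs, and high-confidence file IDs."""
    correct: Dict[str, int] = defaultdict(int)
    total: Dict[str, int] = defaultdict(int)
    misclassified, confident = [], []
    for r in records:
        total[r["category"]] += 1
        if r["pred"] == r["label"]:
            correct[r["category"]] += 1
        else:
            misclassified.append(r["file_id"])
        # High-confidence list includes confidently wrong predictions too
        if r["confidence"] >= high_conf:
            confident.append(r["file_id"])
    per_category = {c: correct[c] / total[c] for c in total}
    return per_category, misclassified, confident
```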


import os
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch_geometric.loader import DataLoader  # PyG's DataLoader now lives in torch_geometric.loader
from typing import Dict, List, Optional
import warnings
warnings.filterwarnings('ignore')

# Import from adress_analyzer (assumed to exist; ADReSSoDataset is used in step 7)
from adress_analyzer import ADReSSoAnalyzer, ADReSSoDataset, extend_analyzer_with_model

class ADReSSoPipelineCheckpointer:
    def __init__(self, base_path: str, output_dir: str = "/content/drive/MyDrive/Voice/extracted/ADReSSo21/checkpoints"):
        """
        Initialize the pipeline checkpointer for ADReSSo21 analysis.

        Args:
            base_path (str): Path to ADReSSo21 dataset.
            output_dir (str): Directory to store checkpoints and outputs.
        """
        self.base_path = base_path
        self.output_dir = output_dir
        self.checkpoint_dir = os.path.join(output_dir, "checkpoints")
        self.feature_dir = os.path.join(output_dir, "features")
        self.visualization_dir = os.path.join(output_dir, "visualizations")

        # Create directories
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.feature_dir, exist_ok=True)
        os.makedirs(self.visualization_dir, exist_ok=True)

        # Initialize analyzer
        ExtendedAnalyzer = extend_analyzer_with_model()
        self.analyzer = ExtendedAnalyzer(base_path=base_path)

        # Checkpoint tracking
        self.checkpoints = {}
        self.step_outputs = {}
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def load_checkpoint(self, step: str) -> Optional[Dict]:
        """Load checkpoint for a specific step."""
        checkpoint_file = os.path.join(self.checkpoint_dir, f"{step}_checkpoint.pkl")
        if os.path.exists(checkpoint_file):
            try:
                with open(checkpoint_file, 'rb') as f:
                    checkpoint = pickle.load(f)
                    print(f"Loaded {step} checkpoint with {len(checkpoint.get('data', {}))} items")
                    return checkpoint
            except Exception as e:
                print(f"Error loading {step} checkpoint: {str(e)}")
        return None

    def save_checkpoint(self, step: str, data: Dict):
        """Save checkpoint for a specific step."""
        checkpoint_file = os.path.join(self.checkpoint_dir, f"{step}_checkpoint.pkl")
        try:
            with open(checkpoint_file, 'wb') as f:
                pickle.dump({'data': data}, f)
            print(f"Saved {step} checkpoint to {checkpoint_file}")
        except Exception as e:
            print(f"Error saving {step} checkpoint: {str(e)}")

    def save_visualization(self, fig, step: str, filename: str):
        """Save visualization figure to disk."""
        output_path = os.path.join(self.visualization_dir, f"{step}_{filename}.png")
        fig.savefig(output_path, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved visualization to {output_path}")

    def step_0_load_data(self):
        """Step 0: Load audio files with checkpointing."""
        step = "step_0_load_data"
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'audio_files' in checkpoint['data']:
            print("Using checkpointed audio files")
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        print("Step 0: Getting audio files...")
        audio_files = self.analyzer.get_audio_files()
        self.step_outputs[step] = {'audio_files': audio_files}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: Bar plot of file counts per category
        fig, ax = plt.subplots(figsize=(10, 6))
        categories = list(audio_files.keys())
        counts = [len(files) for files in audio_files.values()]
        sns.barplot(x=categories, y=counts, ax=ax)
        ax.set_title("Number of Audio Files per Category")
        ax.set_xlabel("Category")
        ax.set_ylabel("File Count")
        self.save_visualization(fig, step, "file_counts")

        return self.step_outputs[step]

    def step_1_demonstrate_features(self):
        """Step 1: Demonstrate acoustic features with checkpointing."""
        step = "step_1_demonstrate_features"
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'features' in checkpoint['data']:
            print("Using checkpointed acoustic features demo")
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        print("Step 1: Demonstrating acoustic features...")
        audio_files = self.step_outputs['step_0_load_data']['audio_files']
        sample_file = audio_files['diagnosis_ad'][0]  # First file from diagnosis_ad
        features = self.analyzer.extract_acoustic_features(sample_file)
        self.analyzer.show_acoustic_features(features, sample_file)
        self.step_outputs[step] = {'features': features, 'sample_file': sample_file}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: Plot sample eGeMAPS features
        fig, ax = plt.subplots(figsize=(10, 6))
        eGeMAPS = features.get('eGeMAPS', np.zeros(88))
        ax.plot(eGeMAPS[:20])  # Plot first 20 features for clarity
        ax.set_title(f"eGeMAPS Features for {os.path.basename(sample_file)}")
        ax.set_xlabel("Feature Index")
        ax.set_ylabel("Feature Value")
        self.save_visualization(fig, step, "eGeMAPS_features")

        return self.step_outputs[step]

    def step_2_extract_transcripts(self):
        """Step 2: Extract transcripts with checkpointing."""
        step = "step_2_extract_transcripts"
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'transcripts' in checkpoint['data']:
            print("Using checkpointed transcripts")
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        print("Step 2: Extracting transcripts...")
        audio_files = self.step_outputs['step_0_load_data']['audio_files']
        transcripts = self.analyzer.extract_transcripts(audio_files)
        self.step_outputs[step] = {'transcripts': transcripts}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: Histogram of transcript lengths
        lengths = [len(t['text']) for t in transcripts.values()]
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(lengths, bins=30, ax=ax)
        ax.set_title("Distribution of Transcript Lengths")
        ax.set_xlabel("Transcript Length (Characters)")
        ax.set_ylabel("Count")
        self.save_visualization(fig, step, "transcript_lengths")

        return self.step_outputs[step]

    def step_3_save_transcripts(self):
        """Step 3: Save transcripts with checkpointing."""
        step = "step_3_save_transcripts"
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'transcript_dir' in checkpoint['data']:
            print("Using checkpointed transcript directory")
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        print("Step 3: Saving transcripts...")
        transcripts = self.step_outputs['step_2_extract_transcripts']['transcripts']
        transcript_dir = os.path.join(self.output_dir, "transcripts")
        os.makedirs(transcript_dir, exist_ok=True)
        self.analyzer.save_transcripts(transcripts, transcript_dir)
        self.step_outputs[step] = {'transcript_dir': transcript_dir}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: Pie chart of transcript error status
        has_errors = [t.get('has_error', False) for t in transcripts.values()]
        error_counts = [sum(has_errors), len(has_errors) - sum(has_errors)]
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.pie(error_counts, labels=['Errors', 'No Errors'], autopct='%1.1f%%')
        ax.set_title("Transcript Error Status")
        self.save_visualization(fig, step, "transcript_errors")

        return self.step_outputs[step]

    def step_4_create_transcript_table(self):
        """Step 4: Create transcript summary table with checkpointing."""
        step = "step_4_create_transcript_table"
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'transcript_table' in checkpoint['data']:
            print("Using checkpointed transcript table")
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        print("Step 4: Creating transcript table...")
        transcripts = self.step_outputs['step_2_extract_transcripts']['transcripts']
        transcript_table = self.analyzer.create_transcript_table(transcripts)
        table_path = os.path.join(self.output_dir, "transcript_summary.csv")
        transcript_table.to_csv(table_path, index=False)
        self.step_outputs[step] = {'transcript_table': transcript_table, 'table_path': table_path}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: Box plot of word counts by category
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.boxplot(x='Category', y='Word_Count', data=transcript_table, ax=ax)
        ax.set_title("Word Count Distribution by Category")
        ax.set_xlabel("Category")
        ax.set_ylabel("Word Count")
        plt.xticks(rotation=45)
        self.save_visualization(fig, step, "word_count_by_category")

        return self.step_outputs[step]

    def step_5_extract_linguistic_features(self):
        """Step 5: Extract linguistic features with checkpointing."""
        step = "step_5_extract_linguistic_features"
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'linguistic_features' in checkpoint['data']:
            print("Using checkpointed linguistic features")
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        print("Step 5: Extracting linguistic features...")
        transcripts = self.step_outputs['step_2_extract_transcripts']['transcripts']
        linguistic_features = self.analyzer.extract_linguistic_features(transcripts)
        feature_path = os.path.join(self.output_dir, "linguistic_features.pkl")
        with open(feature_path, 'wb') as f:
            pickle.dump(linguistic_features, f)
        self.step_outputs[step] = {'linguistic_features': linguistic_features, 'feature_path': feature_path}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: Heatmap of BERT embedding correlations
        sample_features = list(linguistic_features.values())[0]['bert_embeddings'][:10]
        corr_matrix = np.corrcoef(sample_features)
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', ax=ax)
        ax.set_title("Correlation of BERT Embeddings (Sample)")
        self.save_visualization(fig, step, "bert_embedding_correlations")

        return self.step_outputs[step]

    def step_6_define_model(self):
        """Step 6: Define model architecture with checkpointing."""
        step = "step_6_define_model"
        # Always build the architecture: the checkpoint stores only weights,
        # and later steps need self.analyzer.model to exist.
        print("Step 6: Defining model architecture...")
        self.analyzer.step_6_define_model_architecture()

        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'model_state_dict' in checkpoint['data']:
            print("Using checkpointed model")
            self.analyzer.model.load_state_dict(checkpoint['data']['model_state_dict'])
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        model_state_dict = self.analyzer.model.state_dict()
        self.step_outputs[step] = {'model_state_dict': model_state_dict}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: parameter counts of the model's top-level submodules
        param_counts = {name: sum(p.numel() for p in child.parameters())
                        for name, child in self.analyzer.model.named_children()}
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.barplot(x=list(param_counts.keys()), y=list(param_counts.values()), ax=ax)
        ax.set_title("Model Parameter Counts by Component")
        ax.set_xlabel("Component")
        ax.set_ylabel("Parameter Count")
        plt.xticks(rotation=45)
        self.save_visualization(fig, step, "model_parameters")

        return self.step_outputs[step]

    def step_7_train_model(self, num_epochs: int = 5, batch_size: int = 4, seed: int = 42):
        """Step 7: Train model with checkpointing."""
        step = "step_7_train_model"
        audio_files = self.step_outputs['step_0_load_data']['audio_files']
        linguistic_features = self.step_outputs['step_5_extract_linguistic_features']['linguistic_features']

        # Use checkpointed acoustic features if available
        acoustic_features = self.load_checkpoint('step_8_extract_features')
        if acoustic_features is None:
            acoustic_features = {'data': self.analyzer.extract_acoustic_features_for_model(audio_files)}
            self.save_checkpoint('step_8_extract_features', acoustic_features['data'])

        # Binary labels: 1 for AD/decline, 0 for every other category
        labels = {}
        for category, files in audio_files.items():
            for file in files:
                file_id = f"{category}_{os.path.basename(file)}"
                labels[file_id] = 1 if category in ['diagnosis_ad', 'progression_decline'] else 0
        dataset = ADReSSoDataset(acoustic_features['data'], linguistic_features, labels, device=self.device)

        # 64/16/20 split. A fixed seed lets a resumed run rebuild the same split, so the
        # loaders never have to be pickled. Note: random_split is not stratified.
        train_size = int(0.64 * len(dataset))
        val_size = int(0.16 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(seed))
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        self.analyzer.model.to(self.device)
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'model_state_dict' in checkpoint['data']:
            print("Using checkpointed trained model")
            self.analyzer.model.load_state_dict(checkpoint['data']['model_state_dict'])
        else:
            print("Step 7: Training model...")
            # The loaders built above are not passed to the analyzer's training routine
            self.analyzer.step_7_train_model(num_epochs=num_epochs, batch_size=batch_size)
            # Checkpoint only picklable data (weights and seed), not DataLoaders or datasets
            self.save_checkpoint(step, {'model_state_dict': self.analyzer.model.state_dict(), 'seed': seed})

        self.step_outputs[step] = {'model_state_dict': self.analyzer.model.state_dict(),
                                   'train_loader': train_loader, 'val_loader': val_loader,
                                   'test_dataset': test_dataset}

        # Visualization: training loss curve, drawn only if the analyzer exposes a per-epoch
        # loss history ('loss_history' is a hypothetical attribute name; adapt as needed)
        loss_history = getattr(self.analyzer, 'loss_history', None)
        if loss_history:
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.plot(range(1, len(loss_history) + 1), loss_history)
            ax.set_title("Training Loss Curve")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss")
            self.save_visualization(fig, step, "training_loss")
        else:
            print("No loss history found on the analyzer; skipping the loss-curve plot")

        return self.step_outputs[step]

    def step_8_evaluate_model(self):
        """Step 8: Evaluate model with checkpointing."""
        step = "step_8_evaluate_model"
        checkpoint = self.load_checkpoint(step)
        if checkpoint and 'evaluation_results' in checkpoint['data']:
            print("Using checkpointed evaluation results")
            self.step_outputs[step] = checkpoint['data']
            return checkpoint['data']

        print("Step 8: Evaluating model...")
        test_dataset = self.step_outputs['step_7_train_model']['test_dataset']
        # Kept for custom evaluation; analyzer.step_8_evaluate_model() takes no loader argument
        from torch_geometric.loader import DataLoader as PyGDataLoader
        test_loader = PyGDataLoader(test_dataset, batch_size=4, shuffle=False)
        evaluation_results = self.analyzer.step_8_evaluate_model()
        results_path = os.path.join(self.output_dir, "evaluation_results.csv")
        evaluation_results['metrics'].to_csv(results_path, index=False)
        self.step_outputs[step] = {'evaluation_results': evaluation_results, 'results_path': results_path}
        self.save_checkpoint(step, self.step_outputs[step])

        # Visualization: confusion matrix
        # PLACEHOLDER counts, not evaluation results: replace with the matrix computed from
        # the model's test-set predictions (rows = true label, columns = predicted label)
        conf_matrix = np.array([[30, 5], [3, 17]])
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax)
        ax.set_title("Confusion Matrix (placeholder data)")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        self.save_visualization(fig, step, "confusion_matrix")

        return self.step_outputs[step]

    def run_pipeline_with_checkpoints(self, num_epochs: int = 5, batch_size: int = 4):
        """Run the complete pipeline with checkpointing."""
        print("="*80)
        print("RUNNING ADReSSo PIPELINE WITH CHECKPOINTING")
        print("="*80)

        self.step_0_load_data()
        self.step_1_demonstrate_features()
        self.step_2_extract_transcripts()
        self.step_3_save_transcripts()
        self.step_4_create_transcript_table()
        self.step_5_extract_linguistic_features()
        self.step_6_define_model()
        self.step_7_train_model(num_epochs, batch_size)
        self.step_8_evaluate_model()

        print("\n" + "="*80)
        print("PIPELINE COMPLETED!")
        print("="*80)
        return self.step_outputs

# Placeholder ADReSSoDataset class (stand-in for the one assumed to exist in adress_analyzer)
class ADReSSoDataset(torch.utils.data.Dataset):
    """Wraps per-file acoustic and linguistic features as one PyG `Data` graph per recording."""

    def __init__(self, acoustic_features, linguistic_features, labels, device):
        self.acoustic_features = acoustic_features
        self.linguistic_features = linguistic_features
        self.labels = labels
        self.file_ids = list(labels.keys())
        self.device = device

    def __len__(self):
        return len(self.file_ids)

    def __getitem__(self, idx):
        file_id = self.file_ids[idx]
        acoustic = self.acoustic_features.get(file_id, {})
        linguistic = self.linguistic_features.get(file_id, {})
        label = self.labels[file_id]

        # Node features must be 2-D [num_nodes, num_features]; a 1-D vector becomes one feature per node
        x = torch.as_tensor(acoustic.get('node_features', np.zeros(100)), dtype=torch.float)
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        # Default to an empty edge set rather than 100 duplicate self-loops on node 0
        edge_index = torch.as_tensor(acoustic.get('edge_index', np.zeros((2, 0))), dtype=torch.long)
        # Graph-level vectors get a leading dimension of 1 so PyG batching stacks them to [batch, dim]
        acoustic_vec = torch.as_tensor(acoustic.get('acoustic', np.zeros(1000)), dtype=torch.float).reshape(1, -1)
        linguistic_vec = torch.as_tensor(linguistic.get('bert_embeddings', np.zeros(768)), dtype=torch.float).reshape(1, -1)

        data = Data(x=x, edge_index=edge_index, acoustic_features=acoustic_vec,
                    linguistic_features=linguistic_vec, y=torch.tensor(label, dtype=torch.long))
        return data.to(self.device)

# Usage Example
if __name__ == "__main__":
    try:
        from google.colab import drive
        drive.mount('/content/drive')
    except ImportError:
        print("Not running in Colab, skipping drive mount")

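    # Install required packages. Note: this block only runs after the file's top-level imports
    # have succeeded, so it cannot fix a missing import; in Colab, install in an earlier cell.
    import subprocess
    import sys
    packages = [
        "librosa", "soundfile", "opensmile", "speechbrain",
        "transformers", "torch", "openai-whisper",
        "pandas", "numpy", "matplotlib", "seaborn", "torch-geometric"
    ]
    try:
        # Use the running interpreter's pip so packages land in the active environment
        subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])
        print("All required packages installed successfully")
    except subprocess.CalledProcessError as e:
        print(f"Error installing packages: {e}")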

    # Run pipeline
    checkpointer = ADReSSoPipelineCheckpointer(base_path="/content/drive/MyDrive/Voice/extracted/ADReSSo21")
    results = checkpointer.run_pipeline_with_checkpoints(num_epochs=5, batch_size=4)

    print("\nPipeline completed successfully!")
    print("Check output directories:")
    print(f"- Features: {checkpointer.feature_dir}")
    print(f"- Checkpoints: {checkpointer.checkpoint_dir}")
    print(f"- Visualizations: {checkpointer.visualization_dir}")

ModuleNotFoundError: No module named 'adress_analyzer'
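
The run above stops at the first import that fails (`adress_analyzer`, which appears to be the project's own module rather than a PyPI package). A minimal pre-flight check can instead report every missing module before the pipeline starts. The sketch below assumes the pip-name-to-import-name mapping shown (for example, `openai-whisper` installs the `whisper` module and `torch-geometric` installs `torch_geometric`); `PIP_TO_MODULE` and `missing_modules` are hypothetical helpers, not part of the original notebook.

```python
import importlib.util

# Hypothetical mapping from the pip package names used in this notebook to the
# module names used in `import` statements (assumed; verify for your environment).
PIP_TO_MODULE = {
    "librosa": "librosa",
    "soundfile": "soundfile",
    "opensmile": "opensmile",
    "speechbrain": "speechbrain",
    "transformers": "transformers",
    "torch": "torch",
    "openai-whisper": "whisper",
    "pandas": "pandas",
    "numpy": "numpy",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "torch-geometric": "torch_geometric",
}


def missing_modules(module_names):
    """Return the top-level module names that cannot be found on sys.path.

    importlib.util.find_spec locates a module without executing it, so the
    check is cheap and has no import side effects.
    """
    return [name for name in module_names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # adress_analyzer is assumed to be a local module that must be on sys.path
    # (for example, uploaded next to the notebook) before the pipeline runs.
    required = list(PIP_TO_MODULE.values()) + ["adress_analyzer"]
    missing = missing_modules(required)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable")
```

Checking with `find_spec` rather than a bare `import` avoids loading heavy libraries such as `torch` just to confirm they are installed.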

In [3]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 3.1 MB/s eta 0:00:00
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 28.3 MB/s eta 0:00:00
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
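
Step 8 currently plots a hard-coded placeholder confusion matrix. Once the evaluation step exposes the true and predicted labels for the test set, the matrix and the usual binary metrics can be computed directly. The sketch below is stdlib-only and assumes labels are encoded as 0 = CN and 1 = AD; the helper names and that encoding are assumptions, since the structure of `evaluation_results` is not shown here. It uses the same layout as the heatmap: rows are true labels, columns are predicted labels.

```python
def binary_confusion_matrix(y_true, y_pred, positive=1):
    """Return [[TN, FP], [FN, TP]]: rows = true label, columns = predicted label.

    Assumes 0 = CN (negative) and 1 = AD (positive); pass `positive` to change this.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tn = fp = fn = tp = 0
    for t, p in zip(y_true, y_pred):
        if t == positive:
            if p == positive:
                tp += 1
            else:
                fn += 1
        elif p == positive:
            fp += 1
        else:
            tn += 1
    return [[tn, fp], [fn, tp]]


def binary_metrics(matrix):
    """Accuracy, precision, recall (sensitivity) and F1 for the positive class."""
    (tn, fp), (fn, tp) = matrix
    total = tn + fp + fn + tp
    # Guard each ratio so an empty denominator yields 0.0 instead of ZeroDivisionError
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


if __name__ == "__main__":
    # Toy labels for illustration only, not ADReSSo results
    y_true = [0, 0, 1, 1, 1, 0]
    y_pred = [0, 1, 1, 0, 1, 0]
    matrix = binary_confusion_matrix(y_true, y_pred)
    print(matrix)  # [[2, 1], [1, 2]]
    print(binary_metrics(matrix))
```

The returned list of lists can be passed to `sns.heatmap` in place of the placeholder array. If scikit-learn is available, `sklearn.metrics.confusion_matrix` produces the same rows-true, columns-predicted layout.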
