<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Graph_Attention_and_Feature_Fusion_for_Robust_Alzheimer%E2%80%99s_Disease_Diagnosis_from_the_ADReSSo21_Speech_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google Drive, Create the Final_output Folder, and Install Dependencies

In [None]:
# Attach Google Drive so the dataset and the Final_output directory are
# reachable under /content/drive (google.colab only exists inside Colab).
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [1]:
import subprocess
import os

# Create Final_output directory and subdirectories
try:
    OUTPUT_DIR = "/content/drive/MyDrive/Final_output"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, "visualizations"), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, "models"), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, "logs"), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, "transcripts"), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, "features"), exist_ok=True)
    print(f"Created output directories at {OUTPUT_DIR}")
except Exception as e:
    print(f"Error creating directories: {str(e)}")

# Install required packages
try:
    packages = [
        "librosa", "soundfile", "opensmile", "speechbrain",
        "transformers", "torch", "openai-whisper",
        "pandas", "numpy", "matplotlib", "seaborn", "torch-geometric"
    ]
    for pkg in packages:
        subprocess.check_call(["pip", "install", pkg])
    print("All required packages installed successfully")
except Exception as e:
    print(f"Error installing packages: {str(e)}")


Created output directories at /content/drive/MyDrive/Final_output
All required packages installed successfully
Error mounting Google Drive: Mountpoint must not already contain files


# Import Libraries and Define Error Logging

In [2]:
import os
import pandas as pd
import numpy as np
import librosa
import soundfile as sf
from pathlib import Path
import pickle
import json
from typing import Dict, List, Tuple, Any
import warnings
import opensmile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch_geometric
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.data import Data, Batch
from transformers import Wav2Vec2Processor, Wav2Vec2Model, BertTokenizer, BertModel
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import networkx as nx
import whisper
from multiprocessing import Pool
import time

# Error logging function
def log_error(message: str, log_file: "str | None" = None) -> None:
    """Append a timestamped error line to the pipeline log and echo it.

    Args:
        message: Human-readable description of the error.
        log_file: Log file path. Defaults to
            ``/content/drive/MyDrive/Final_output/logs/pipeline_errors.log``.

    The log directory is created on demand. A failure to write the log is
    reported on stdout rather than raised, because this helper is called from
    inside ``except`` blocks where a new exception would hide the original.
    """
    if log_file is None:
        log_file = os.path.join("/content/drive/MyDrive/Final_output", "logs", "pipeline_errors.log")
    try:
        # dirname is "" for a bare filename; makedirs("") would raise.
        os.makedirs(os.path.dirname(log_file) or ".", exist_ok=True)
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {message}\n")
    except OSError as write_err:
        print(f"Could not write to log file {log_file}: {write_err}")
    print(f"Error logged: {message}")

KeyboardInterrupt: 

# ADReSSoAnalyzer Class

In [None]:
class ADReSSoAnalyzer:
    """Feature-extraction front end for the ADReSSo21 speech dataset.

    On construction it loads openSMILE (eGeMAPSv02 functionals), Whisper
    ("base"), Wav2Vec2 and BERT. Its methods do the following:

    * list the dataset's audio files;
    * extract acoustic, transcript and linguistic features;
    * plot those features;
    * run a small EDA.

    Intermediate results are pickled under ``output_path`` so an interrupted
    run can resume.
    """

    # Categories returned by get_audio_files(). File ids used throughout the
    # class have the form "<category>_<wav filename>".
    CATEGORIES = (
        'diagnosis_ad', 'diagnosis_cn',
        'progression_decline', 'progression_no_decline',
        'progression_test',
    )
    # Sample rate expected by the Wav2Vec2 feature extractor; it raises a
    # ValueError when called with any other sampling_rate.
    WAV2VEC_SAMPLE_RATE = 16000

    def __init__(self, base_path="/content/drive/MyDrive/Voice/extracted/ADReSSo21"):
        """Load all pretrained models; errors are logged and re-raised."""
        self.base_path = base_path
        self.output_path = "/content/drive/MyDrive/Final_output"
        self.features = {}
        self.transcripts = {}
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            self.smile = opensmile.Smile(
                feature_set=opensmile.FeatureSet.eGeMAPSv02,
                feature_level=opensmile.FeatureLevel.Functionals,
            )
            self.whisper_model = whisper.load_model("base")
            self.wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
            self.wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").to(device)
            self.bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
        except Exception as e:
            log_error(f"Initialization error: {str(e)}")
            raise

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _category_of(cls, file_id: str) -> str:
        """Return the category prefix of a ``"<category>_<filename>"`` file id.

        Falls back to the text before the first underscore for unknown ids.
        """
        # Longest prefix first, so a short category can never shadow a longer one.
        for category in sorted(cls.CATEGORIES, key=len, reverse=True):
            if file_id.startswith(category + '_'):
                return category
        return file_id.split('_')[0]

    @staticmethod
    def _sample_files(files: List[str], sample_fraction: float) -> List[str]:
        """Return the leading ``sample_fraction`` share of ``files`` (rounded down)."""
        return files[:int(len(files) * sample_fraction)]

    @staticmethod
    def _save_pickle(obj: Any, path: str) -> None:
        """Pickle ``obj`` to ``path`` atomically (write a temp file, then replace).

        Replacing atomically means an interruption mid-write cannot leave a
        truncated checkpoint behind.
        """
        tmp_path = path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)

    @staticmethod
    def _empty_acoustic_features() -> Dict[str, Any]:
        """Zero-filled acoustic feature record used when extraction fails."""
        return {
            'egemaps': np.zeros(88),
            'mfccs': {'mean': np.zeros(13), 'std': np.zeros(13), 'delta': np.zeros(13), 'delta2': np.zeros(13)},
            'log_mel': {'mean': np.zeros(80), 'std': np.zeros(80)},
            'wav2vec2': np.zeros(768),
            'prosodic': {'f0_mean': 0.0, 'f0_std': 0.0, 'energy_mean': 0.0, 'energy_std': 0.0,
                         'zero_crossing_rate': 0.0, 'spectral_centroid': 0.0, 'spectral_rolloff': 0.0, 'duration': 0.0}
        }

    @staticmethod
    def _empty_linguistic_features() -> Dict[str, Any]:
        """Empty linguistic feature record (same keys as a successful one)."""
        return {
            'raw_text': '', 'word_count': 0, 'sentence_count': 0, 'avg_word_length': 0,
            'unique_words': 0, 'lexical_diversity': 0, 'bert_tokens': [],
            'bert_input_ids': [], 'bert_attention_mask': [], 'bert_encoding': None
        }

    @staticmethod
    def _text_statistics(transcript: str) -> Dict[str, Any]:
        """Surface lexical statistics of ``transcript``.

        Words are whitespace-separated tokens (punctuation stays attached).
        Sentences are non-blank pieces between runs of ``.``, ``!`` or ``?``.
        """
        import re  # local: only this helper needs it

        words = transcript.split()
        sentences = [s for s in re.split(r'[.!?]+', transcript) if s.strip()]
        unique = len(set(words))
        return {
            'word_count': len(words),
            'sentence_count': len(sentences),
            'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
            'unique_words': unique,
            'lexical_diversity': unique / len(words) if words else 0,
        }

    # --------------------------------------------------------------- file list

    def get_audio_files(self) -> Dict[str, List[str]]:
        """List the ``.wav`` files of every dataset category.

        Returns:
            Mapping category -> sorted list of wav paths. A missing directory
            is logged and yields an empty list; any other error yields ``{}``.
        """
        try:
            paths = {
                'diagnosis_ad': f"{self.base_path}/diagnosis/train/audio/ad",
                'diagnosis_cn': f"{self.base_path}/diagnosis/train/audio/cn",
                'progression_decline': f"{self.base_path}/progression/train/audio/decline",
                'progression_no_decline': f"{self.base_path}/progression/train/audio/no_decline",
                'progression_test': f"{self.base_path}/progression/test-dist/audio"
            }
            audio_files = {category: [] for category in paths}
            for category, path in paths.items():
                if not os.path.exists(path):
                    log_error(f"Path not found: {path}")
                    continue
                # Sorted so that the "first N%" sampling done downstream picks
                # the same files on every run (os.listdir order is arbitrary).
                audio_files[category] = [f"{path}/{f}" for f in sorted(os.listdir(path)) if f.endswith('.wav')]
            return audio_files
        except Exception as e:
            log_error(f"Error in get_audio_files: {str(e)}")
            return {}

    # --------------------------------------------------------- acoustic features

    def _wav2vec_embedding(self, y: np.ndarray, sr: int) -> np.ndarray:
        """Mean-pooled Wav2Vec2 last hidden state of waveform ``y`` sampled at ``sr``."""
        # The feature extractor rejects rates other than 16 kHz, so resample
        # instead of passing the (e.g. 8 kHz) analysis rate through.
        if sr != self.WAV2VEC_SAMPLE_RATE:
            y = librosa.resample(y, orig_sr=sr, target_sr=self.WAV2VEC_SAMPLE_RATE)
        input_values = self.wav2vec_processor(
            y, sampling_rate=self.WAV2VEC_SAMPLE_RATE, return_tensors="pt"
        ).input_values.to(self.wav2vec_model.device)
        with torch.no_grad():
            hidden = self.wav2vec_model(input_values).last_hidden_state
        return torch.mean(hidden, dim=1).squeeze().cpu().numpy()

    def extract_acoustic_features_single(self, audio_path: str, sr=8000, extract_wav2vec=True) -> Tuple[str, Dict[str, Any]]:
        """Extract acoustic features of one audio file.

        Args:
            audio_path: Path of the audio file.
            sr: Rate librosa resamples the audio to for analysis.
            extract_wav2vec: When False, ``'wav2vec2'`` is a zero vector of length 768.

        Returns:
            ``(audio_path, features)``. ``features`` has the keys ``'egemaps'``,
            ``'mfccs'``, ``'log_mel'``, ``'wav2vec2'`` and ``'prosodic'``. On any
            error the error is logged and a zero-filled record is returned.
        """
        try:
            y, sr = librosa.load(audio_path, sr=sr)
            features = {'egemaps': self.smile.process_file(audio_path).values.flatten()}
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            features['mfccs'] = {
                'mean': np.mean(mfccs, axis=1),
                'std': np.std(mfccs, axis=1),
                'delta': np.mean(librosa.feature.delta(mfccs), axis=1),
                'delta2': np.mean(librosa.feature.delta(mfccs, order=2), axis=1)
            }
            log_mel = librosa.power_to_db(librosa.feature.melspectrogram(y=y, sr=sr, n_mels=80))
            features['log_mel'] = {
                'mean': np.mean(log_mel, axis=1),
                'std': np.std(log_mel, axis=1)
            }
            features['wav2vec2'] = self._wav2vec_embedding(y, sr) if extract_wav2vec else np.zeros(768)
            f0 = librosa.yin(y, fmin=50, fmax=300, sr=sr)
            # NOTE(review): librosa.yin does not mark unvoiced frames, so this
            # filter keeps (almost) every frame; librosa.pyin would be needed
            # for a voiced-only pitch estimate.
            f0_clean = f0[f0 > 0]
            rms = librosa.feature.rms(y=y)
            features['prosodic'] = {
                'f0_mean': np.mean(f0_clean) if len(f0_clean) > 0 else 0.0,
                'f0_std': np.std(f0_clean) if len(f0_clean) > 0 else 0.0,
                'energy_mean': np.mean(rms),
                'energy_std': np.std(rms),
                'zero_crossing_rate': np.mean(librosa.feature.zero_crossing_rate(y)),
                'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)),
                'spectral_rolloff': np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr)),
                'duration': len(y) / sr
            }
        except Exception as e:
            log_error(f"Error extracting features for {audio_path}: {str(e)}")
            features = self._empty_acoustic_features()
        return audio_path, features

    def extract_acoustic_features(self, audio_files: Dict[str, List[str]], sample_fraction=0.5, extract_wav2vec=True, checkpoint_every=20):
        """Extract acoustic features for a sample of every category, with resumable checkpoints.

        Args:
            audio_files: Mapping category -> audio paths (see get_audio_files).
            sample_fraction: Leading share of each category's files to process.
            extract_wav2vec: Forwarded to extract_acoustic_features_single.
            checkpoint_every: Save the checkpoint after this many newly processed
                files (0/None = only at the end).

        Returns:
            Mapping ``"<category>_<filename>"`` -> feature record. The mapping is
            also stored on ``self.features``.
        """
        print("Extracting acoustic features...")
        feature_dict = {}
        checkpoint_file = os.path.join(self.output_path, "features", "acoustic_features_checkpoint.pkl")

        # Resume from a previous (possibly partial) run. The pickle is written
        # by this class itself and is therefore trusted.
        if os.path.exists(checkpoint_file):
            with open(checkpoint_file, 'rb') as f:
                feature_dict = pickle.load(f)
            print(f"Loaded {len(feature_dict)} features from checkpoint")

        # Keep (file_id, path) together so the category never has to be
        # recovered afterwards by searching every file list.
        pending = []
        for category, files in audio_files.items():
            for path in self._sample_files(files, sample_fraction):
                file_id = f"{category}_{os.path.basename(path)}"
                if file_id not in feature_dict:
                    pending.append((file_id, path))

        # Sequential on purpose: a multiprocessing.Pool would pickle ``self``
        # (Whisper/Wav2Vec2/BERT weights, possibly on CUDA, which cannot be
        # used from forked workers) for the worker tasks.
        for done, (file_id, path) in enumerate(pending, start=1):
            _, feature_dict[file_id] = self.extract_acoustic_features_single(path, 8000, extract_wav2vec)
            if checkpoint_every and done % checkpoint_every == 0:
                self._save_pickle(feature_dict, checkpoint_file)

        self._save_pickle(feature_dict, checkpoint_file)
        print(f"Saved {len(feature_dict)} features to {checkpoint_file}")
        self.features = feature_dict
        return feature_dict

    # ------------------------------------------------------------ visualization

    def visualize_features(self, features: Dict, file_id: str):
        """Save MFCC, eGeMAPS, Wav2Vec2 (first 100 dims) and prosodic plots for one file."""
        def save_current(suffix: str) -> None:
            plt.savefig(os.path.join(self.output_path, 'visualizations', f'{file_id}_{suffix}.png'))
            plt.close()

        try:
            plt.figure(figsize=(10, 6))
            mfcc_data = np.vstack([features['mfccs']['mean'], features['mfccs']['std'], features['mfccs']['delta'], features['mfccs']['delta2']])
            sns.heatmap(mfcc_data, cmap='viridis')
            plt.title(f'MFCC Features - {file_id} (Early AD Detection)')
            plt.xlabel('Feature Index')
            plt.ylabel('MFCC Type (Mean, Std, Delta, Delta2)')
            save_current('mfcc')

            plt.figure(figsize=(10, 6))
            plt.plot(features['egemaps'], label='eGeMAPS')
            plt.title(f'eGeMAPS Features - {file_id} (Early AD Detection)')
            plt.xlabel('Feature Index')
            plt.ylabel('Value')
            plt.legend()
            save_current('egemaps')

            plt.figure(figsize=(10, 6))
            wav2vec_head = features['wav2vec2'][:100]
            plt.scatter(range(len(wav2vec_head)), wav2vec_head)
            plt.title(f'Wav2Vec2 Features (First 100 dims) - {file_id}')
            plt.xlabel('Feature Index')
            plt.ylabel('Value')
            save_current('wav2vec2')

            plt.figure(figsize=(10, 6))
            plt.bar(list(features['prosodic'].keys()), list(features['prosodic'].values()))
            plt.title(f'Prosodic Features - {file_id} (Progression Tracking)')
            plt.xticks(rotation=45)
            plt.tight_layout()
            save_current('prosodic')
        except Exception as e:
            plt.close()  # don't leak the half-drawn figure
            log_error(f"Error visualizing features for {file_id}: {str(e)}")

    def perform_eda(self, features_dict: Dict, transcripts: Dict):
        """Write eda_summary.csv and category-level EDA plots.

        Args:
            features_dict: Acoustic features keyed by file id.
            transcripts: Per-file dicts providing ``'lexical_diversity'`` and
                ``'word_count'`` (the pipeline passes the linguistic features).
        """
        try:
            print("\nPerforming Exploratory Data Analysis...")
            eda_data = []
            for file_id, features in features_dict.items():
                transcript = transcripts.get(file_id, {})
                eda_data.append({
                    'File_ID': file_id,
                    # Full category (e.g. "diagnosis_ad"); splitting on the first
                    # "_" would merge AD and CN into "diagnosis".
                    'Category': self._category_of(file_id),
                    'MFCC_Mean_0': features['mfccs']['mean'][0],
                    'F0_Mean': features['prosodic']['f0_mean'],
                    'F0_Std': features['prosodic']['f0_std'],
                    'Lexical_Diversity': transcript.get('lexical_diversity', 0),
                    'Word_Count': transcript.get('word_count', 0)
                })
            eda_df = pd.DataFrame(eda_data)
            eda_df.to_csv(os.path.join(self.output_path, "eda_summary.csv"), index=False)

            plt.figure(figsize=(10, 6))
            sns.boxplot(x='Category', y='F0_Mean', data=eda_df)
            plt.title('F0 Mean Distribution by Category (Early AD Detection)')
            plt.savefig(os.path.join(self.output_path, 'visualizations', 'f0_mean_boxplot.png'))
            plt.close()

            plt.figure(figsize=(10, 6))
            sns.boxplot(x='Category', y='Lexical_Diversity', data=eda_df)
            plt.title('Lexical Diversity by Category (Early AD Detection)')
            plt.savefig(os.path.join(self.output_path, 'visualizations', 'lexical_diversity_boxplot.png'))
            plt.close()

            plt.figure(figsize=(10, 6))
            sns.scatterplot(x='Word_Count', y='F0_Std', hue='Category', data=eda_df)
            plt.title('Word Count vs F0 Std (Progression Tracking)')
            plt.savefig(os.path.join(self.output_path, 'visualizations', 'word_count_f0_std_scatter.png'))
            plt.close()

            print("EDA visualizations saved to", os.path.join(self.output_path, 'visualizations'))
        except Exception as e:
            log_error(f"Error in EDA: {str(e)}")

    # ------------------------------------------------------------- transcripts

    def _transcribe_file(self, file_path: str, category: str, file_id: str) -> Dict[str, Any]:
        """Transcribe one file with Whisper; on failure return an empty transcript plus the error."""
        record = {
            'file_path': file_path,
            'category': category,
            'filename': os.path.basename(file_path),
        }
        try:
            result = self.whisper_model.transcribe(file_path)
            record.update({
                'transcript': result["text"].strip(),
                'language': result.get('language', 'en'),
                'segments': len(result.get('segments', []))
            })
        except Exception as e:
            log_error(f"Error transcribing {file_id}: {str(e)}")
            record.update({'transcript': "", 'error': str(e)})
        return record

    def extract_transcripts(self, audio_files: Dict[str, List[str]], sample_fraction=0.5, checkpoint_every=10) -> Dict[str, Dict[str, Any]]:
        """Transcribe a sample of every category with Whisper, resuming from a checkpoint.

        Args:
            audio_files: Mapping category -> audio paths.
            sample_fraction: Leading share of each category's files to transcribe;
                use the same value as extract_acoustic_features so the file ids match.
            checkpoint_every: Save the checkpoint after this many newly transcribed
                files (0/None = only at the end).

        Returns:
            Mapping file id -> transcript record (see _transcribe_file).
        """
        transcripts = {}
        print("Extracting transcripts...")
        checkpoint_file = os.path.join(self.output_path, "transcripts", "transcripts_checkpoint.pkl")
        if os.path.exists(checkpoint_file):
            with open(checkpoint_file, 'rb') as f:
                transcripts = pickle.load(f)  # self-written, trusted
            print(f"Loaded {len(transcripts)} transcripts from checkpoint")

        processed = 0
        for category, files in audio_files.items():
            for file_path in self._sample_files(files, sample_fraction):
                file_id = f"{category}_{os.path.basename(file_path)}"
                if file_id in transcripts:
                    continue
                transcripts[file_id] = self._transcribe_file(file_path, category, file_id)
                processed += 1
                if checkpoint_every and processed % checkpoint_every == 0:
                    self._save_pickle(transcripts, checkpoint_file)
        self._save_pickle(transcripts, checkpoint_file)
        self.transcripts = transcripts
        return transcripts

    def save_transcripts(self, transcripts: Dict[str, Dict[str, Any]]):
        """Write one .txt per transcript plus all_transcripts.json and transcripts.pkl."""
        transcript_dir = os.path.join(self.output_path, "transcripts")
        for key, data in transcripts.items():
            with open(os.path.join(transcript_dir, f"{key}_transcript.txt"), 'w', encoding='utf-8') as f:
                f.write(data['transcript'])
        with open(os.path.join(transcript_dir, "all_transcripts.json"), 'w', encoding='utf-8') as f:
            json.dump(transcripts, f, indent=2, ensure_ascii=False)
        with open(os.path.join(transcript_dir, "transcripts.pkl"), 'wb') as f:
            pickle.dump(transcripts, f)
        print(f"Transcripts saved to {transcript_dir}")

    def create_transcript_table(self, transcripts: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
        """Summarise transcripts as a DataFrame and write transcript_summary.csv."""
        data = []
        for key, info in transcripts.items():
            text = info['transcript']
            data.append({
                'File_ID': key,
                'Category': info['category'],
                'Filename': info['filename'],
                'Transcript_Length': len(text),
                'Word_Count': len(text.split()) if text else 0,
                'Language': info.get('language', 'N/A'),
                'Segments': info.get('segments', 'N/A'),
                'Has_Error': 'error' in info,
                'Transcript_Preview': text[:100] + "..." if len(text) > 100 else text
            })
        df = pd.DataFrame(data)
        df.to_csv(os.path.join(self.output_path, "transcript_summary.csv"), index=False)
        return df

    # ----------------------------------------------------- linguistic features

    def extract_linguistic_features(self, transcripts: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Compute lexical statistics and BERT encodings for every transcript.

        Every record has the same keys. Empty or failed transcripts get zero
        statistics, empty token lists and ``'bert_encoding': None``. The result
        is also pickled to ``linguistic_features.pkl``.
        """
        linguistic_features = {}
        print("Extracting linguistic features...")
        for key, data in transcripts.items():
            transcript = data.get('transcript', '')
            if not transcript:
                linguistic_features[key] = self._empty_linguistic_features()
                continue
            try:
                bert_encoding = self.bert_tokenizer(
                    transcript, truncation=True, padding='max_length', max_length=512, return_tensors='pt'
                ).to(self.bert_model.device)
                with torch.no_grad():
                    bert_outputs = self.bert_model(**bert_encoding)
                linguistic_features[key] = {
                    'raw_text': transcript,
                    **self._text_statistics(transcript),
                    'bert_tokens': self.bert_tokenizer.tokenize(transcript),
                    'bert_input_ids': bert_encoding['input_ids'].squeeze().tolist(),
                    'bert_attention_mask': bert_encoding['attention_mask'].squeeze().tolist(),
                    # Full last hidden state for the padded 512-token input; this
                    # is large when pickled for many files.
                    'bert_encoding': bert_outputs.last_hidden_state.cpu()
                }
            except Exception as e:
                log_error(f"Error extracting linguistic features for {key}: {str(e)}")
                linguistic_features[key] = self._empty_linguistic_features()
        with open(os.path.join(self.output_path, "linguistic_features.pkl"), 'wb') as f:
            pickle.dump(linguistic_features, f)
        return linguistic_features

# Model Definitions

In [None]:
class GraphAttentionModule(nn.Module):
    """Stack of GAT layers, graph-level mean pooling and a linear projection.

    Args:
        input_dim: Node feature width of the input graph.
        hidden_dim: Output width of every GAT layer (all heads concatenated)
            and of the final projection. Must be divisible by ``num_heads``.
        num_heads: Attention heads per GAT layer.
        num_layers: Number of GAT layers.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim=768, hidden_dim=256, num_heads=8, num_layers=3):
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})")
        # GATConv concatenates its heads by default, so it emits
        # out_channels * heads features. Giving each head hidden_dim // num_heads
        # channels makes every layer output exactly hidden_dim. That matches the
        # next layer's input width and the projection below; with
        # out_channels=hidden_dim the layer would emit hidden_dim * num_heads.
        head_dim = hidden_dim // num_heads
        self.gat_layers = nn.ModuleList([
            GATConv(input_dim if i == 0 else hidden_dim, head_dim, heads=num_heads, dropout=0.2)
            for i in range(num_layers)
        ])
        self.projection = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index, batch=None):
        """Encode node features and pool them per graph.

        Args:
            x: Node features, presumably (num_nodes, input_dim).
            edge_index: Graph connectivity as expected by GATConv.
            batch: Optional node -> graph assignment. When given, nodes are
                mean-pooled per graph; otherwise all nodes are pooled into one row.

        Returns:
            Tensor of shape (num_graphs, hidden_dim), or (1, hidden_dim) when
            ``batch`` is None.
        """
        for gat_layer in self.gat_layers:
            x = self.dropout(F.relu(gat_layer(x, edge_index)))
        if batch is not None:
            x = global_mean_pool(x, batch)
        else:
            x = torch.mean(x, dim=0, keepdim=True)
        return self.projection(x)

class VisionTransformerModule(nn.Module):
    """ViT-style encoder over a single-channel 2-D input (e.g. a spectrogram).

    The forward pass:

    1. splits the input into non-overlapping ``patch_size`` x ``patch_size``
       patches;
    2. adds learned positional embeddings;
    3. runs a Transformer encoder;
    4. mean-pools the patch tokens and projects them to 256 features.

    Args:
        input_dim: Unused; kept so existing callers keep working.
        patch_size: Side length of the square patches (also the conv stride).
        embed_dim: Token width; must be divisible by ``num_heads``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of Transformer encoder layers.
        max_patches: Size of the positional-embedding table, i.e. the largest
            number of patches an input may produce. With the defaults and an
            80-row input, inputs can be at most 800 columns wide.
    """

    def __init__(self, input_dim=80, patch_size=8, embed_dim=768, num_heads=12, num_layers=6, max_patches=1000):
        super().__init__()
        self.max_patches = max_patches
        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, max_patches, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4, dropout=0.1, activation='gelu')
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(embed_dim, 256)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to (batch, 256) features.

        Raises:
            ValueError: If the input yields more patches than ``max_patches``.
        """
        # (B, 1, H, W) -> (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        num_patches = x.shape[1]
        if num_patches > self.max_patches:
            raise ValueError(
                f"Input produces {num_patches} patches but the positional embedding "
                f"only covers {self.max_patches}; crop the input or raise max_patches."
            )
        x = x + self.pos_embed[:, :num_patches, :]
        # The encoder layer is not batch_first, so it expects (seq, batch, embed).
        x = self.transformer(x.transpose(0, 1)).transpose(0, 1)
        return self.classifier(torch.mean(x, dim=1))

class UNetModule(nn.Module):
    """1-D encoder/decoder conv network that pools its output to one vector per sample.

    The encoder has four conv levels and a bottleneck, each preceded (except
    the first) by a 2x max-pool. The decoder has four transposed-conv stages
    that each double the length. A 1x1 conv and global average pooling finish
    the network.

    NOTE(review): despite the name, the decoder has no skip connections from
    the encoder levels.

    Args:
        in_channels: Channels of the input signal.
        out_channels: Width of the returned feature vector.
    """

    def __init__(self, in_channels=1, out_channels=128):
        super().__init__()
        # Encoder levels and bottleneck (channel width doubles at each level).
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.bottleneck = self.conv_block(512, 1024)
        # Decoder stages (channel width halves, length doubles).
        self.dec4 = self.upconv_block(1024, 512)
        self.dec3 = self.upconv_block(512, 256)
        self.dec2 = self.upconv_block(256, 128)
        self.dec1 = self.upconv_block(128, 64)
        # Head: 1x1 projection followed by global average pooling over length.
        self.final = nn.Conv1d(64, out_channels, kernel_size=1)
        self.pool = nn.AdaptiveAvgPool1d(1)

    def conv_block(self, in_ch, out_ch):
        """Two (Conv1d k=3, padding 1 -> BatchNorm1d -> ReLU) stages; length is preserved."""
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv1d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*stages)

    def upconv_block(self, in_ch, out_ch):
        """ConvTranspose1d (k=2, stride 2) -> BatchNorm1d -> ReLU; doubles the length."""
        upsample = nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2, stride=2)
        return nn.Sequential(upsample, nn.BatchNorm1d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):
        """Map (batch, in_channels, length) to (batch, out_channels).

        The length must survive four 2x max-pools, i.e. be at least 16.
        """
        # Encoder path: the first level sees the raw input, every deeper
        # level (including the bottleneck) a 2x-downsampled map.
        h = self.enc1(x)
        for encoder in (self.enc2, self.enc3, self.enc4, self.bottleneck):
            h = encoder(F.max_pool1d(h, 2))
        # Decoder path, applied strictly in sequence.
        for decoder in (self.dec4, self.dec3, self.dec2, self.dec1):
            h = decoder(h)
        return self.pool(self.final(h)).squeeze(-1)

class AlexNetModule(nn.Module):
    """Fully connected head with AlexNet-classifier-style widths over a feature vector.

    Despite the name there are no convolutions. The network is a stack of
    Linear -> ReLU -> Dropout blocks followed by a final Linear layer. The
    defaults reproduce the original 4096-4096-2048 stack, and the state_dict
    keys are unchanged.

    Args:
        input_dim: Width of the input features (last dimension of ``x``).
        num_classes: Output width of the final Linear layer.
        hidden_dims: Widths of the hidden Linear layers, in order. An empty
            sequence gives a single Linear layer.
        dropout: Dropout probability after each hidden layer.
    """

    def __init__(self, input_dim=768, num_classes=256, hidden_dims=(4096, 4096, 2048), dropout=0.5):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.ReLU(inplace=True), nn.Dropout(dropout)]
            in_features = width
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Map (..., input_dim) to (..., num_classes)."""
        return self.classifier(self.features(x))

class MultiModalADReSSoModel(nn.Module):
    """Fusion classifier over text (BERT), audio-vector and spectrogram branches.

    Branches and their output widths (concatenated before the fusion MLP):

    * graph attention over a per-sample {text, audio} graph (256);
    * vision transformer over the spectrogram (256);
    * 1-D conv encoder/decoder over the audio vector (128);
    * MLP over the audio vector (256).

    Args:
        audio_feature_dim: Width of ``audio_features``. The graph branch stacks
            it with the 768-wide mean-pooled BERT output, so the two must match.
        text_feature_dim: Node width of the graph branch.
        spectrogram_height: Passed to the ViT branch (which ignores it).
        num_classes: Number of output logits.
        similarity_threshold: Text/audio cosine similarity above which the two
            graph nodes are connected.
    """

    def __init__(self, audio_feature_dim=768, text_feature_dim=768, spectrogram_height=80, num_classes=2, similarity_threshold=0.1):
        super().__init__()
        self.similarity_threshold = similarity_threshold
        self.graph_attention = GraphAttentionModule(input_dim=text_feature_dim)
        self.vision_transformer = VisionTransformerModule(input_dim=spectrogram_height)
        self.unet = UNetModule()
        self.alexnet = AlexNetModule(input_dim=audio_feature_dim)
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fusion_layer = nn.Sequential(
            nn.Linear(256 + 256 + 128 + 256, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.classifier = nn.Linear(256, num_classes)

    def create_semantic_graph(self, text_features, audio_features):
        """Build one two-node graph per sample and batch them.

        Node 0 is the sample's text embedding and node 1 its audio embedding.
        If their cosine similarity exceeds ``similarity_threshold`` the nodes
        are linked in both directions. Otherwise each node only gets a
        self-loop, so no information flows between the modalities.
        """
        device = text_features.device
        # edge_index is (2, num_edges): row 0 holds sources, row 1 targets.
        linked = torch.tensor([[0, 1], [1, 0]], dtype=torch.long, device=device)      # 0->1, 1->0
        self_loops = torch.tensor([[0, 1], [0, 1]], dtype=torch.long, device=device)  # 0->0, 1->1
        # One batched call (single host sync) instead of one .item() per sample.
        similarities = F.cosine_similarity(text_features, audio_features, dim=1).tolist()
        graphs = []
        for i, similarity in enumerate(similarities):
            node_features = torch.stack([text_features[i], audio_features[i]], dim=0)
            edge_index = linked if similarity > self.similarity_threshold else self_loops
            graphs.append(Data(x=node_features, edge_index=edge_index))
        return Batch.from_data_list(graphs)

    def forward(self, audio_features, text_input_ids, text_attention_mask, spectrograms):
        """Return class logits of shape (batch, num_classes).

        Args:
            audio_features: (batch, audio_feature_dim) audio embedding.
            text_input_ids: BERT input ids, presumably (batch, seq_len).
            text_attention_mask: BERT attention mask matching ``text_input_ids``.
            spectrograms: (batch, 1, H, W) input of the ViT branch.
        """
        bert_outputs = self.bert(input_ids=text_input_ids, attention_mask=text_attention_mask)
        # NOTE(review): this mean includes padding tokens (the mask is not applied).
        text_features = bert_outputs.last_hidden_state.mean(dim=1)
        graph_batch = self.create_semantic_graph(text_features, audio_features)
        graph_out = self.graph_attention(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
        vit_out = self.vision_transformer(spectrograms)
        unet_out = self.unet(audio_features.unsqueeze(1))
        alexnet_out = self.alexnet(audio_features)
        fused_features = self.fusion_layer(torch.cat([graph_out, vit_out, unet_out, alexnet_out], dim=1))
        return self.classifier(fused_features)

# Extended Analyzer with Model Training

In [None]:
class ADReSSoAnalyzerExtended(ADReSSoAnalyzer):
    """ADReSSoAnalyzer plus model definition, training and evaluation steps.

    NOTE(review): relies on ``ADReSSoTrainer`` and ``ADReSSoDataset``, which are
    not defined in this part of the notebook; confirm they exist before running.
    """

    # Files of this category come from the ADReSSo test split and carry no
    # label here, so they are excluded from supervised training and evaluation.
    UNLABELED_CATEGORY = 'progression_test'

    def __init__(self, base_path="/content/drive/MyDrive/Voice/extracted/ADReSSo21"):
        super().__init__(base_path)
        self.model = None
        self.trainer = None
        self.test_loader = None
        self.scaler = StandardScaler()

    @staticmethod
    def _label_of(file_id: str) -> int:
        """1 for AD / cognitive-decline files, 0 otherwise."""
        return 1 if 'diagnosis_ad' in file_id or 'progression_decline' in file_id else 0

    def train_individual_model(self, model, train_loader, val_loader, model_name, num_epochs=5):
        """Train ``model`` with its own ADReSSoTrainer and save it to models/<model_name>.pth."""
        trainer = ADReSSoTrainer(model)
        trainer.train(train_loader, val_loader, num_epochs=num_epochs)
        model_path = os.path.join(self.output_path, 'models', f'{model_name}.pth')
        torch.save(model.state_dict(), model_path)
        print(f"Saved {model_name} to {model_path}")
        return trainer

    def step_6_define_model_architecture(self):
        """Instantiate the multimodal model and its trainer."""
        print("\n" + "="*60)
        print("STEP 6: DEFINING MODEL ARCHITECTURE")
        print("="*60)
        self.model = MultiModalADReSSoModel(audio_feature_dim=768, text_feature_dim=768, spectrogram_height=80, num_classes=2)
        self.trainer = ADReSSoTrainer(self.model)
        print(f"Model initialized with {sum(p.numel() for p in self.model.parameters()):,} parameters")
        return self.model

    def step_7_train_model(self, features_dict, linguistic_features, batch_size=4, num_epochs=5):
        """Split labelled files 64/16/20 (train/val/test, stratified) and train.

        Only files that have linguistic features and a known label are used.
        The test loader is kept on ``self.test_loader`` for step 8.
        """
        print("\n" + "="*60)
        print("STEP 7: TRAINING MODEL")
        print("="*60)
        if self.model is None:
            self.step_6_define_model_architecture()

        unlabeled_prefix = f"{self.UNLABELED_CATEGORY}_"
        file_ids = [fid for fid in features_dict
                    if fid in linguistic_features and not fid.startswith(unlabeled_prefix)]
        skipped = len(features_dict) - len(file_ids)
        if skipped:
            print(f"Skipping {skipped} files that are unlabeled or lack linguistic features")
        labels = {fid: self._label_of(fid) for fid in file_ids}

        train_ids, test_ids = train_test_split(file_ids, test_size=0.2, stratify=[labels[f] for f in file_ids], random_state=42)
        train_ids, val_ids = train_test_split(train_ids, test_size=0.2, stratify=[labels[f] for f in train_ids], random_state=42)

        def make_loader(ids, shuffle):
            dataset = ADReSSoDataset(
                {fid: features_dict[fid] for fid in ids},
                {fid: linguistic_features[fid] for fid in ids},
                {fid: labels[fid] for fid in ids},
            )
            return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

        train_loader = make_loader(train_ids, shuffle=True)
        val_loader = make_loader(val_ids, shuffle=False)
        test_loader = make_loader(test_ids, shuffle=False)

        # NOTE(review): each sub-module is trained with the same loaders as the
        # full model; confirm ADReSSoTrainer can drive modules whose forward
        # signatures differ from MultiModalADReSSoModel.forward.
        print("Training individual models...")
        for name in ('graph_attention', 'vision_transformer', 'unet', 'alexnet'):
            self.train_individual_model(getattr(self.model, name), train_loader, val_loader, name, num_epochs=num_epochs)

        print("\nTraining fused model...")
        self.trainer.train(train_loader, val_loader, num_epochs=num_epochs)
        self.test_loader = test_loader
        return self.trainer

    def step_8_evaluate_model(self, visualize_graphs=True, num_graph_samples=5):
        """Load the best checkpoint and evaluate it on the held-out split.

        Returns:
            The trainer's evaluation dict, or None if anything fails (the error
            is logged).
        """
        print("\n" + "="*60)
        print("STEP 8: MODEL EVALUATION AND SEMANTIC ANALYSIS")
        print("="*60)
        try:
            if self.test_loader is None:
                raise RuntimeError("No test split available; run step_7_train_model first")
            state_path = os.path.join(self.output_path, 'models', 'best_adresso_model.pth')
            # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
            # runtime; load_state_dict then copies it onto the model's device.
            self.model.load_state_dict(torch.load(state_path, map_location='cpu'))
            evaluation_results = self.trainer.evaluate_detailed(self.test_loader, class_names=['CN', 'AD'])
            if visualize_graphs:
                print(f"\nVisualizing semantic relationships for {num_graph_samples} samples...")
                self.trainer.visualize_semantic_relationships(self.test_loader, num_samples=num_graph_samples)
            self.trainer.analyze_feature_importance()
            self.generate_evaluation_report(evaluation_results)
            return evaluation_results
        except Exception as e:
            log_error(f"Error in model evaluation: {str(e)}")
            return None

    def generate_evaluation_report(self, evaluation_results):
        """Write evaluation_report.txt and print it."""
        # A conditional cannot live inside an f-string format spec, so the AUC
        # (which may be None) is formatted up front.
        auc = evaluation_results.get('auc')
        auc_text = f"{auc:.4f}" if auc is not None else 'N/A'
        report = f"""
=== EVALUATION REPORT ===
Accuracy: {evaluation_results['accuracy']:.4f}
Precision: {evaluation_results['precision']:.4f}
Recall: {evaluation_results['recall']:.4f}
F1-Score: {evaluation_results['f1']:.4f}
ROC AUC: {auc_text}
Confusion Matrix:
{evaluation_results['confusion_matrix']}
"""
        with open(os.path.join(self.output_path, 'evaluation_report.txt'), 'w', encoding='utf-8') as f:
            f.write(report)
        print(report)

# Run the Pipeline

In [None]:
# End-to-end run: list files, extract features/transcripts, EDA, train, evaluate.
analyzer = ADReSSoAnalyzerExtended()
print("=== ADReSSo21 Speech Analysis Pipeline ===\n")
audio_files = analyzer.get_audio_files()
total_files = sum(len(files) for files in audio_files.values())
print(f"Found {total_files} audio files across all categories")
for category, files in audio_files.items():
    print(f"  {category}: {len(files)} files")

if total_files == 0:
    log_error("No audio files found. Please check the dataset path.")
else:
    print("\nStep 1: Extracting acoustic features...")
    features_dict = analyzer.extract_acoustic_features(audio_files, sample_fraction=0.5, extract_wav2vec=False)

    print("\nStep 2: Visualizing acoustic features for a sample...")
    # Pick a file that was actually extracted: a category's first file is not
    # sampled when int(len(files) * 0.5) == 0, and indexing it would KeyError.
    sample_id = next(iter(features_dict), None)
    if sample_id is None:
        print("No acoustic features available to visualize")
    else:
        analyzer.visualize_features(features_dict[sample_id], sample_id)

    print("\nStep 3: Extracting transcripts...")
    transcripts = analyzer.extract_transcripts(audio_files)

    print("\nStep 4: Saving transcripts...")
    analyzer.save_transcripts(transcripts)

    print("\nStep 5: Creating transcript table...")
    transcript_df = analyzer.create_transcript_table(transcripts)
    print("Transcript Summary Table:")
    print(transcript_df.to_string(index=False))

    print("\nStep 6: Extracting linguistic features...")
    linguistic_features = analyzer.extract_linguistic_features(transcripts)

    print("\nStep 7: Performing EDA...")
    analyzer.perform_eda(features_dict, linguistic_features)

    print("\nStep 8: Defining and training model...")
    trainer = analyzer.step_7_train_model(features_dict, linguistic_features, batch_size=4, num_epochs=5)

    print("\nStep 9: Evaluating model...")
    evaluation_results = analyzer.step_8_evaluate_model(visualize_graphs=True, num_graph_samples=5)

    # step_8_evaluate_model returns None (after logging) when evaluation fails.
    if evaluation_results is None:
        print("\nPipeline finished, but model evaluation failed; see logs/pipeline_errors.log.")
    else:
        print("\nPipeline completed successfully!")