<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/DetectAlzFromSpeechGBA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
from google.colab import drive

# Mount Google Drive. Both attempts call mount() on the same mount point; the
# second call is a plain retry after a failed first attempt.
try:
    drive.mount('/content/drive', force_remount=False)
except Exception:
    # `except Exception` (not a bare `except`) so that KeyboardInterrupt /
    # SystemExit still propagate instead of triggering a second mount.
    drive.mount('/content/drive')

# Define the base directory
base_directory = '/content/drive/MyDrive/Speech'

# Walk the tree once and report every directory and file found below it.
if os.path.exists(base_directory):
    # Separate lists for files and directories
    file_paths = []
    directory_paths = []

    for root, directories, files in os.walk(base_directory):
        directory_paths.extend(os.path.join(root, name) for name in directories)
        file_paths.extend(os.path.join(root, name) for name in files)

    print("Directories:")
    for directory in directory_paths:
        print(f"  {directory}")

    print("\nFiles:")
    for file in file_paths:
        print(f"  {file}")

    print("\nSummary:")
    print(f"  Total directories: {len(directory_paths)}")
    print(f"  Total files: {len(file_paths)}")
    print(f"  Total paths: {len(file_paths) + len(directory_paths)}")

else:
    print(f"Directory {base_directory} does not exist!")

Mounted at /content/drive
Directories:
  /content/drive/MyDrive/Speech/extracted_diagnosis_train
  /content/drive/MyDrive/Speech/extracted_progression_train
  /content/drive/MyDrive/Speech/extracted_progression_test
  /content/drive/MyDrive/Speech/linguistic_features
  /content/drive/MyDrive/Speech/lightweight_features
  /content/drive/MyDrive/Speech/transcripts
  /content/drive/MyDrive/Speech/processed_datasets
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/cn
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21

In [5]:
import os
import pandas as pd
import numpy as np
import librosa
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import pickle
import json
from sklearn.preprocessing import StandardScaler, LabelEncoder
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

def extract_labels_from_directory_structure(base_paths=None):
    """Build a {file_id: info} mapping by scanning the audio directories.

    Args:
        base_paths: Optional mapping of dataset name -> audio directory.
            The key 'progression_test' is scanned as a flat directory of
            .wav files and labelled 'unknown'. The key 'diagnosis_train' is
            scanned one sub-directory deep with labels 'ad'/'cn'; any other
            key is scanned the same way with labels 'decline'/'no_decline'.
            Defaults to the three ADReSSo21 locations on the mounted Drive.

    Returns:
        dict mapping the file name without its '.wav' extension to a dict
        with keys 'dataset', 'label' and 'file_path'. If the same file id
        occurs in more than one dataset, the later entry overwrites the
        earlier one.

    Raises:
        FileNotFoundError: if one of the base directories does not exist
            (propagated from os.listdir).
    """
    if base_paths is None:
        base_paths = {
            'diagnosis_train': '/content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio',
            'progression_train': '/content/drive/MyDrive/Speech/extracted_progression_train/ADReSSo21/progression/train/audio',
            'progression_test': '/content/drive/MyDrive/Speech/extracted_progression_test/ADReSSo21/progression/test-dist/audio'
        }

    extension = '.wav'
    labels = {}

    for dataset_type, base_path in base_paths.items():
        if dataset_type == 'progression_test':
            # Test audio sits directly in the directory and has no label.
            for wav_file in os.listdir(base_path):
                if not wav_file.endswith(extension):
                    continue
                # Slice off only the trailing extension; str.replace would
                # also remove '.wav' occurring in the middle of the name.
                file_id = wav_file[:-len(extension)]
                labels[file_id] = {
                    'dataset': 'progression_test',
                    'label': 'unknown',
                    'file_path': os.path.join(base_path, wav_file)
                }
            continue

        for subdir in os.listdir(base_path):
            subdir_path = os.path.join(base_path, subdir)
            if not os.path.isdir(subdir_path):
                continue

            # The label comes from the sub-directory name. Any name other
            # than the positive class falls back to the negative class.
            if dataset_type == 'diagnosis_train':
                label = 'ad' if subdir == 'ad' else 'cn'
            else:
                label = 'decline' if subdir == 'decline' else 'no_decline'

            for wav_file in os.listdir(subdir_path):
                if not wav_file.endswith(extension):
                    continue
                file_id = wav_file[:-len(extension)]
                labels[file_id] = {
                    'dataset': dataset_type,
                    'label': label,
                    'file_path': os.path.join(subdir_path, wav_file)
                }

    return labels

def extract_mel_spectrogram(audio_path, n_mels=128, hop_length=512, n_fft=2048):
    """Compute mel-spectrogram features for one audio file.

    The audio is loaded at 22050 Hz. The result holds the power mel
    spectrogram, its dB version (relative to the spectrogram maximum), and
    the mean / standard deviation of each along axis 1.

    Returns:
        dict with keys 'mel_spectrogram', 'log_mel_spectrogram', 'mel_mean',
        'mel_std', 'log_mel_mean', 'log_mel_std', or None if anything raised
        (the error is printed).
    """
    try:
        waveform, sample_rate = librosa.load(audio_path, sr=22050)
        power_spec = librosa.feature.melspectrogram(
            y=waveform,
            sr=sample_rate,
            n_mels=n_mels,
            hop_length=hop_length,
            n_fft=n_fft,
        )
        db_spec = librosa.power_to_db(power_spec, ref=np.max)

        result = {
            'mel_spectrogram': power_spec,
            'log_mel_spectrogram': db_spec,
        }
        # Per-row summary statistics of both representations.
        result['mel_mean'] = np.mean(power_spec, axis=1)
        result['mel_std'] = np.std(power_spec, axis=1)
        result['log_mel_mean'] = np.mean(db_spec, axis=1)
        result['log_mel_std'] = np.std(db_spec, axis=1)
        return result
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None

# Loaded (processor, model) pairs keyed by checkpoint name, so the pretrained
# weights are read once per checkpoint rather than once per audio file.
_WAV2VEC2_CACHE = {}


def _get_wav2vec2(model_name):
    """Return the (processor, model) pair for `model_name`, loading it on first use."""
    if model_name not in _WAV2VEC2_CACHE:
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2Model.from_pretrained(model_name)
        model.eval()  # inference only; keep dropout disabled
        _WAV2VEC2_CACHE[model_name] = (processor, model)
    return _WAV2VEC2_CACHE[model_name]


def extract_wav2vec2_features(audio_path, model_name="facebook/wav2vec2-base-960h"):
    """Extract wav2vec2 hidden-state features for one audio file.

    The audio is loaded at 16 kHz and passed through the pretrained model
    without gradient tracking.

    Args:
        audio_path: Path of the audio file to load.
        model_name: Checkpoint name passed to `from_pretrained`.

    Returns:
        dict with the last hidden states ('wav2vec2_features') and their
        mean / std / max / min over axis 0, or None if anything raised
        (the error is printed).
    """
    try:
        processor, model = _get_wav2vec2(model_name)

        y, sr = librosa.load(audio_path, sr=16000)

        inputs = processor(y, sampling_rate=16000, return_tensors="pt", padding=True)

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_states = outputs.last_hidden_state

        # Drop only the leading (batch) axis. A bare squeeze() would also
        # remove any other axis of size 1 and corrupt the axis-0 statistics.
        features = last_hidden_states.squeeze(0).numpy()

        return {
            'wav2vec2_features': features,
            'wav2vec2_mean': np.mean(features, axis=0),
            'wav2vec2_std': np.std(features, axis=0),
            'wav2vec2_max': np.max(features, axis=0),
            'wav2vec2_min': np.min(features, axis=0)
        }
    except Exception as e:
        print(f"Error processing wav2vec2 for {audio_path}: {e}")
        return None

def extract_acoustic_features(audio_path):
    """Compute summary acoustic descriptors for one audio file.

    The audio is loaded at 22050 Hz. MFCC (13 coefficients), chroma,
    spectral contrast and tonnetz are averaged along axis 1; zero-crossing
    rate, spectral centroid, spectral rolloff and RMS energy are averaged
    over all values. 'tempo' is the first value returned by
    librosa.beat.beat_track.

    Returns:
        dict of features keyed by name, or None if anything raised (the
        error is printed).
    """
    try:
        signal, rate = librosa.load(audio_path, sr=22050)

        def row_mean(matrix):
            # Average each row of a 2-D feature matrix along axis 1.
            return np.mean(matrix, axis=1)

        # The dict literal evaluates its values top to bottom, so the
        # extractors run in the order listed.
        return {
            'mfcc': row_mean(librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=13)),
            'chroma': row_mean(librosa.feature.chroma_stft(y=signal, sr=rate)),
            'spectral_contrast': row_mean(librosa.feature.spectral_contrast(y=signal, sr=rate)),
            'tonnetz': row_mean(librosa.feature.tonnetz(y=signal, sr=rate)),
            'zero_crossing_rate': np.mean(librosa.feature.zero_crossing_rate(signal)),
            'spectral_centroid': np.mean(librosa.feature.spectral_centroid(y=signal, sr=rate)),
            'spectral_rolloff': np.mean(librosa.feature.spectral_rolloff(y=signal, sr=rate)),
            'rms_energy': np.mean(librosa.feature.rms(y=signal)),
            'tempo': librosa.beat.beat_track(y=signal, sr=rate)[0],
        }
    except Exception as e:
        print(f"Error extracting acoustic features for {audio_path}: {e}")
        return None

def load_linguistic_features(ling_path='/content/drive/MyDrive/Speech/linguistic_features'):
    """Load previously saved linguistic features.

    Tries 'linguistic_features.pkl' first and falls back to
    'linguistic_features.json' if the pickle cannot be read.

    Args:
        ling_path: Directory containing the feature file(s).

    Returns:
        The loaded object, or an empty dict when the directory does not
        exist or neither file could be loaded.
    """
    if not os.path.exists(ling_path):
        return {}

    try:
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # this on feature files you produced yourself.
        with open(os.path.join(ling_path, 'linguistic_features.pkl'), 'rb') as f:
            return pickle.load(f)
    except Exception:
        # Best-effort: a missing or unreadable pickle falls through to JSON.
        # `except Exception` keeps KeyboardInterrupt/SystemExit propagating,
        # which the previous bare `except` swallowed.
        pass

    try:
        with open(os.path.join(ling_path, 'linguistic_features.json'), 'r') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading linguistic features: {e}")

    return {}
def load_transcripts(transcript_files=None, individual_transcript_path=None):
    """Collect transcripts from JSON result files and per-file .txt files.

    Args:
        transcript_files: Optional list of JSON files to read. A file
            holding a dict is merged as-is. A file holding a list is read
            record by record. Defaults to the two result files under the
            Drive transcripts directory.
        individual_transcript_path: Optional directory of '<id>.wav.txt'
            (or '<id>.txt') files. Defaults to the 'individual_transcripts'
            directory on Drive.

    Returns:
        dict mapping file id -> transcript. Later sources overwrite earlier
        ones for the same id. Missing files and directories are skipped, and
        read errors are printed rather than raised.
    """
    if transcript_files is None:
        transcript_files = [
            '/content/drive/MyDrive/Speech/transcripts/all_categories_results.json',
            '/content/drive/MyDrive/Speech/transcripts/transcription_results.json'
        ]
    if individual_transcript_path is None:
        individual_transcript_path = '/content/drive/MyDrive/Speech/transcripts/individual_transcripts'

    transcripts = {}

    for transcript_file in transcript_files:
        if not os.path.exists(transcript_file):
            continue
        try:
            with open(transcript_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict):
                transcripts.update(data)
            elif isinstance(data, list):
                for item in data:
                    if isinstance(item, dict) and len(item) >= 2:
                        keys = list(item.keys())
                        # Assumes the first field holds the file id and the
                        # second the transcript -- TODO confirm against the
                        # real file schema. The id is the field's *value*;
                        # keying by the field name made every record
                        # overwrite the same entry.
                        transcripts[str(item[keys[0]])] = item[keys[1]]
        except json.JSONDecodeError as e:
            print(f"JSON decode error in {transcript_file}: {e}")
        except Exception as e:
            print(f"Error loading {transcript_file}: {e}")

    if os.path.exists(individual_transcript_path):
        for txt_file in os.listdir(individual_transcript_path):
            if not txt_file.endswith('.txt'):
                continue
            try:
                # Strip the trailing '.txt', then an optional '.wav', so both
                # 'id.wav.txt' and 'id.txt' map to 'id'.
                file_id = txt_file[:-len('.txt')]
                if file_id.endswith('.wav'):
                    file_id = file_id[:-len('.wav')]
                with open(os.path.join(individual_transcript_path, txt_file), 'r', encoding='utf-8') as f:
                    transcripts[file_id] = f.read().strip()
            except Exception as e:
                print(f"Error loading {txt_file}: {e}")

    return transcripts

# Placeholder hook for previously computed feature sets.
def load_existing_features():
    """
    Stub loader for pre-existing features; currently yields an empty mapping.
    Swap in real loading logic when such features are available.
    """
    return dict()


def create_comprehensive_dataset():
    """Join labels, transcripts and any saved features into one table.

    One row is produced per entry of the label mapping. Each row carries
    'file_id', 'dataset', 'label' and 'file_path', plus 'transcript' when
    one exists, '<category>_<key>' / '<category>_features' columns from the
    existing features, and 'linguistic_<key>' columns from the linguistic
    features.

    Returns:
        (DataFrame of all rows, the label mapping it was built from).
    """
    print("Extracting labels from directory structure...")
    labels_dict = extract_labels_from_directory_structure()

    print("Loading existing features...")
    existing_features = load_existing_features()

    print("Loading linguistic features...")
    linguistic_features = load_linguistic_features()

    print("Loading transcripts...")
    transcripts = load_transcripts()

    def assemble_row(file_id, label_info):
        # Fixed columns first, so they lead the DataFrame column order.
        record = {'file_id': file_id}
        for field in ('dataset', 'label', 'file_path'):
            record[field] = label_info[field]

        if file_id in transcripts:
            record['transcript'] = transcripts[file_id]

        for category, per_file in existing_features.items():
            if file_id not in per_file:
                continue
            payload = per_file[file_id]
            if isinstance(payload, dict):
                record.update({f'{category}_{key}': value for key, value in payload.items()})
            else:
                record[f'{category}_features'] = payload

        if file_id in linguistic_features:
            ling_payload = linguistic_features[file_id]
            if isinstance(ling_payload, dict):
                record.update({f'linguistic_{key}': value for key, value in ling_payload.items()})

        return record

    df = pd.DataFrame([assemble_row(file_id, info) for file_id, info in labels_dict.items()])

    print(f"Dataset created with {len(df)} samples")
    print(f"Columns: {df.columns.tolist()}")
    print(f"Label distribution:")
    print(df['label'].value_counts())

    return df, labels_dict

def extract_audio_features_batch(labels_dict, sample_limit=None):
    """Extract mel-spectrogram and acoustic features for each labelled file.

    Args:
        labels_dict: Mapping of file id -> info dict holding 'file_path'.
        sample_limit: If truthy, stop once this many files have been
            processed successfully.

    Returns:
        dict mapping file id -> merged feature dict. Files that are missing
        on disk, or for which either extractor returned nothing, are left
        out.
    """
    print("Extracting mel-spectrograms and acoustic features...")

    extracted = {}

    for file_id, label_info in labels_dict.items():
        # Every success adds exactly one new key, so len() is the success count.
        if sample_limit and len(extracted) >= sample_limit:
            break

        audio_path = label_info['file_path']
        if not os.path.exists(audio_path):
            print(f"File not found: {audio_path}")
            continue

        print(f"Processing {file_id}...")
        mel_part = extract_mel_spectrogram(audio_path)
        acoustic_part = extract_acoustic_features(audio_path)

        if not (mel_part and acoustic_part):
            continue

        merged = dict(mel_part)
        merged.update(acoustic_part)
        extracted[file_id] = merged

    return extracted

def create_training_datasets():
    """Split the combined table by dataset and encode the training labels.

    Returns:
        dict with the three per-dataset DataFrames ('diagnosis_train',
        'progression_train', 'progression_test'), the full table
        ('all_data'), the label mapping ('labels_dict') and one
        LabelEncoder per training task. An encoder is fitted, and a
        'label_encoded' column added, only when its split is non-empty.
    """
    df, labels_dict = create_comprehensive_dataset()

    def rows_for(split_name):
        # Independent copy so a new column can be added without touching df.
        return df[df['dataset'] == split_name].copy()

    def count_label(frame, label_value):
        return int((frame['label'] == label_value).sum())

    diagnosis_train = rows_for('diagnosis_train')
    progression_train = rows_for('progression_train')
    progression_test = rows_for('progression_test')

    diagnosis_le = LabelEncoder()
    progression_le = LabelEncoder()
    for frame, encoder in ((diagnosis_train, diagnosis_le), (progression_train, progression_le)):
        if len(frame) > 0:
            frame['label_encoded'] = encoder.fit_transform(frame['label'])

    datasets = dict(
        diagnosis_train=diagnosis_train,
        progression_train=progression_train,
        progression_test=progression_test,
        all_data=df,
        labels_dict=labels_dict,
        diagnosis_label_encoder=diagnosis_le,
        progression_label_encoder=progression_le,
    )

    print("\nDataset Summary:")
    print(f"Diagnosis Training: {len(diagnosis_train)} samples")
    if len(diagnosis_train) > 0:
        print(f"  - AD: {count_label(diagnosis_train, 'ad')}")
        print(f"  - CN: {count_label(diagnosis_train, 'cn')}")

    print(f"Progression Training: {len(progression_train)} samples")
    if len(progression_train) > 0:
        print(f"  - Decline: {count_label(progression_train, 'decline')}")
        print(f"  - No Decline: {count_label(progression_train, 'no_decline')}")

    print(f"Progression Test: {len(progression_test)} samples")

    return datasets

def save_datasets(datasets, output_path='/content/drive/MyDrive/Speech/processed_datasets'):
    """Write every entry of `datasets` into `output_path`.

    DataFrames are written as both '<name>.csv' (without the index) and
    '<name>.pkl'. The 'labels_dict' entry is written as 'labels_dict.json'.
    Anything else is pickled to '<name>.pkl'. The directory is created if
    needed.
    """
    os.makedirs(output_path, exist_ok=True)

    def target(filename):
        return os.path.join(output_path, filename)

    for name, data in datasets.items():
        if isinstance(data, pd.DataFrame):
            data.to_csv(target(f'{name}.csv'), index=False)
            data.to_pickle(target(f'{name}.pkl'))
            continue

        if name == 'labels_dict':
            with open(target('labels_dict.json'), 'w') as handle:
                json.dump(data, handle, indent=2)
            continue

        with open(target(f'{name}.pkl'), 'wb') as handle:
            pickle.dump(data, handle)

    print(f"Datasets saved to {output_path}")

if __name__ == "__main__":
    # Build the per-task tables and persist them before touching any audio.
    datasets = create_training_datasets()
    save_datasets(datasets)

    # Smoke-test the audio feature extraction on at most five files.
    print("\nExtracting audio features for a sample...")
    sample_audio_features = extract_audio_features_batch(
        datasets['labels_dict'],
        sample_limit=5,
    )

    if sample_audio_features:
        print(f"Successfully extracted features for {len(sample_audio_features)} audio files")
        sample_id = next(iter(sample_audio_features))
        sample_features = sample_audio_features[sample_id]
        print(f"Feature keys for {sample_id}: {list(sample_features.keys())}")

    print("\nProcessing complete!")
    print("Available datasets:")
    for name, entry in datasets.items():
        # Only tabular entries have a meaningful sample count.
        if isinstance(entry, pd.DataFrame):
            print(f"  - {name}: {len(entry)} samples")

Extracting labels from directory structure...
Loading existing features...
Loading linguistic features...
Loading transcripts...
Dataset created with 271 samples
Columns: ['file_id', 'dataset', 'label', 'file_path', 'transcript']
Label distribution:
label
ad            87
cn            79
no_decline    58
unknown       32
decline       15
Name: count, dtype: int64

Dataset Summary:
Diagnosis Training: 166 samples
  - AD: 87
  - CN: 79
Progression Training: 73 samples
  - Decline: 15
  - No Decline: 58
Progression Test: 32 samples
Datasets saved to /content/drive/MyDrive/Speech/processed_datasets

Extracting audio features for a sample...
Extracting mel-spectrograms and acoustic features...
Processing adrso003...
Processing adrso014...
Processing adrso012...
Processing adrso017...
Processing adrso010...
Successfully extracted features for 5 audio files
Feature keys for adrso003: ['mel_spectrogram', 'log_mel_spectrogram', 'mel_mean', 'mel_std', 'log_mel_mean', 'log_mel_std', 'mfcc', 'chr

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import numpy as np
from transformers import BertModel, BertTokenizer
import math

class XceptionBlock(nn.Module):
    """Residual block made of two depthwise-separable 3x3 convolutions.

    Each separable convolution is a grouped 3x3 conv (one group per channel)
    followed by a 1x1 conv and batch norm; only the first one has a ReLU.
    When `stride` > 1 the main path is downsampled with max pooling. The
    shortcut is a strided 1x1 conv + batch norm whenever the channel count
    or stride changes, otherwise an identity.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.separable_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.separable_conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, groups=out_channels),
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        )

        needs_projection = in_channels != out_channels or stride != 1
        if needs_projection:
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.skip_connection = nn.Identity()

        self.stride = stride

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        shortcut = self.skip_connection(x)
        main = self.separable_conv2(self.separable_conv1(x))
        if self.stride > 1:
            # Downsample the main path to match the strided shortcut.
            main = F.max_pool2d(main, kernel_size=3, stride=self.stride, padding=1)
        return F.relu(main + shortcut)

class XceptionNet(nn.Module):
    """Xception-style CNN that maps an image-like input to a feature vector.

    An entry stem of two conv/BN/ReLU stages is followed by a stack of
    XceptionBlocks, then global average pooling and a linear layer producing
    `num_classes` outputs.
    """

    def __init__(self, input_channels=1, num_classes=512):
        super().__init__()

        stem_layers = [
            nn.Conv2d(input_channels, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.entry_flow = nn.Sequential(*stem_layers)

        # Three downsampling blocks widen the features to 728 channels,
        # followed by three shape-preserving blocks.
        downsampling = [
            XceptionBlock(64, 128, stride=2),
            XceptionBlock(128, 256, stride=2),
            XceptionBlock(256, 728, stride=2),
        ]
        constant_width = [XceptionBlock(728, 728) for _ in range(3)]
        self.middle_flow = nn.Sequential(*downsampling, *constant_width)

        self.exit_flow = nn.Sequential(
            XceptionBlock(728, 1024, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Run the entry, middle and exit stages in sequence."""
        for stage in (self.entry_flow, self.middle_flow, self.exit_flow):
            x = stage(x)
        return x

class ViTBlock(nn.Module):
    """Pre-norm transformer block: self-attention then an MLP, each with a residual.

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP as a multiple of `embed_dim`.
        dropout: Dropout probability used in the attention and the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(ViTBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the block to `x`; batch_first=True means input is (batch, tokens, embed_dim)."""
        # Normalize once and reuse it as query, key and value, instead of
        # running the same LayerNorm three times.
        normed = self.norm1(x)
        # Only the attention output is used, so skip computing the weights.
        attn_out = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class ViTEncoder(nn.Module):
    """Vision-transformer encoder for single-channel inputs.

    The input is cut into `patch_size` x `patch_size` patches by a strided
    convolution, a class token is prepended, learned position embeddings are
    added, and the sequence passes through `depth` ViTBlocks. The normalized
    class token is projected to `num_classes` outputs.

    The position table has (img_size // patch_size) ** 2 + 1 entries.
    NOTE(review): forward slices it to the token count, so inputs yielding
    more patches than that look unsupported -- confirm expected input size.
    """

    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, num_classes=512):
        super().__init__()
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.embed_dim = embed_dim

        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.dropout = nn.Dropout(0.1)

        self.blocks = nn.ModuleList(ViTBlock(embed_dim, num_heads) for _ in range(depth))

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Encode a 4-D input and return the projected class-token feature."""
        batch_size, _, _, _ = x.shape

        # Patch grid -> token sequence: flatten the spatial dims and put the
        # embedding dim last.
        tokens = self.patch_embed(x)
        tokens = tokens.flatten(2).transpose(1, 2)

        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = tokens + self.pos_embed[:, :tokens.size(1)]
        tokens = self.dropout(tokens)

        for block in self.blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)
        cls_feature = tokens[:, 0]
        return self.head(cls_feature)

class SpectrogramEncoder(nn.Module):
    """Spectrogram encoder backed by either an XceptionNet or a ViTEncoder.

    Only the selected backbone is built: `xception` when `use_xception` is
    true, otherwise `vit` (depth 6, 8 heads). Both produce `feature_dim`
    outputs.
    """

    def __init__(self, use_xception=True, feature_dim=512):
        super().__init__()
        self.use_xception = use_xception
        if not use_xception:
            self.vit = ViTEncoder(
                img_size=224,
                patch_size=16,
                embed_dim=768,
                depth=6,
                num_heads=8,
                num_classes=feature_dim,
            )
        else:
            self.xception = XceptionNet(input_channels=1, num_classes=feature_dim)

    def forward(self, x):
        """Encode `x` with whichever backbone was built."""
        backbone = self.xception if self.use_xception else self.vit
        return backbone(x)

class AcousticFeatureEncoder(nn.Module):
    """Three-layer MLP projecting an acoustic feature vector to `output_dim`.

    The two hidden layers use ReLU followed by dropout (p=0.3); the output
    layer is linear.
    """

    def __init__(self, input_dim, hidden_dim=256, output_dim=512):
        super().__init__()
        layers = []
        # Two hidden stages of Linear -> ReLU -> Dropout.
        for fan_in in (input_dim, hidden_dim):
            layers += [nn.Linear(fan_in, hidden_dim), nn.ReLU(), nn.Dropout(0.3)]
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the encoded features for `x`."""
        encoded = self.encoder(x)
        return encoded

class LinguisticEncoder(nn.Module):
    """Text encoder: a pretrained BERT followed by a linear projection.

    BERT's pooled output is projected to `output_dim`. `max_length` is
    stored on the module but not used by `forward`.
    """

    def __init__(self, bert_model_name='bert-base-uncased', output_dim=512, max_length=512):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.max_length = max_length
        bert_width = self.bert.config.hidden_size
        self.projection = nn.Linear(bert_width, output_dim)

    def forward(self, input_ids, attention_mask):
        """Return the projected pooled BERT output for the tokenized batch."""
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.projection(bert_out.pooler_output)

class GraphAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a set of nodes.

    Every node attends to every node (no mask is applied), so the nodes
    effectively form a fully connected graph.

    Args:
        feature_dim: Size of each node feature; must be divisible by
            `num_heads`.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: if `feature_dim` is not divisible by `num_heads`.
    """

    def __init__(self, feature_dim=512, num_heads=8, dropout=0.1):
        super(GraphAttention, self).__init__()
        # Fail early with a clear message. Otherwise the mismatch only
        # surfaces as an opaque shape error inside forward()'s view().
        if feature_dim % num_heads != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        self.scale = math.sqrt(self.head_dim)

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(feature_dim, feature_dim)

    def forward(self, nodes):
        """Attend over `nodes` of shape (B, N, D) and return a tensor of the same shape."""
        B, N, D = nodes.shape

        # (B, N, D) -> (B, heads, N, head_dim)
        q = self.query(nodes).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(nodes).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(nodes).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) / self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Merge the heads back into the feature dimension.
        out = (attn @ v).transpose(1, 2).contiguous().view(B, N, D)
        return self.proj(out)

class MultiModalFusionModule(nn.Module):
    """Fuse per-modality feature vectors with stacked graph-attention layers.

    The modality vectors are stacked as nodes along dim 1. Each layer applies
    GraphAttention with a residual connection followed by LayerNorm. The
    fused vector is the mean over the node dimension.
    """

    def __init__(self, feature_dim=512, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList(GraphAttention(feature_dim) for _ in range(num_layers))
        self.norm_layers = nn.ModuleList(nn.LayerNorm(feature_dim) for _ in range(num_layers))

    def forward(self, features):
        """Fuse a list of equally shaped feature tensors into one tensor."""
        nodes = torch.stack(features, dim=1)

        for attention, layer_norm in zip(self.layers, self.norm_layers):
            attended = attention(nodes)
            nodes = layer_norm(attended + nodes)

        return nodes.mean(dim=1)

class MultiModalSpeechModel(nn.Module):
    """Multi-modal classifier over spectrogram, acoustic and text inputs.

    Each provided modality is encoded to a `feature_dim` vector. Two or more
    vectors are merged by the fusion module; a single vector is used as-is.
    Separate classification heads serve the diagnosis and progression tasks.
    """

    def __init__(self,
                 acoustic_input_dim=50,
                 bert_model_name='bert-base-uncased',
                 feature_dim=512,
                 num_classes_diagnosis=2,
                 num_classes_progression=2,
                 use_xception_for_spec=True):
        super().__init__()

        self.spectrogram_encoder = SpectrogramEncoder(use_xception=use_xception_for_spec, feature_dim=feature_dim)
        self.acoustic_encoder = AcousticFeatureEncoder(acoustic_input_dim, output_dim=feature_dim)
        self.linguistic_encoder = LinguisticEncoder(bert_model_name, output_dim=feature_dim)

        self.fusion_module = MultiModalFusionModule(feature_dim)

        def make_head(num_classes):
            # Both task heads share this layout and differ only in output size.
            return nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(feature_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, num_classes),
            )

        self.diagnosis_classifier = make_head(num_classes_diagnosis)
        self.progression_classifier = make_head(num_classes_progression)

    def forward(self, spectrogram=None, acoustic_features=None, input_ids=None, attention_mask=None, task='diagnosis'):
        """Classify a batch from whichever modalities are supplied.

        Text is only used when both `input_ids` and `attention_mask` are
        given. `task` selects the head: 'diagnosis' or 'progression' return
        that head's logits; any other value returns a dict with both.

        Raises:
            ValueError: if no modality is provided.
        """
        encoded = []

        if spectrogram is not None:
            encoded.append(self.spectrogram_encoder(spectrogram))

        if acoustic_features is not None:
            encoded.append(self.acoustic_encoder(acoustic_features))

        has_text = input_ids is not None and attention_mask is not None
        if has_text:
            encoded.append(self.linguistic_encoder(input_ids, attention_mask))

        if not encoded:
            raise ValueError("At least one modality must be provided")

        # A single modality bypasses the fusion module entirely.
        fused = encoded[0] if len(encoded) == 1 else self.fusion_module(encoded)

        if task == 'diagnosis':
            return self.diagnosis_classifier(fused)
        if task == 'progression':
            return self.progression_classifier(fused)
        return {
            'diagnosis': self.diagnosis_classifier(fused),
            'progression': self.progression_classifier(fused),
        }

class MultiModalDataset(torch.utils.data.Dataset):
    """Dataset yielding per-sample dicts of whichever modalities are available.

    Args:
        spectrograms: Optional sequence of 2-D or 3-D arrays/tensors.
        acoustic_features: Optional sequence of feature vectors.
        texts: Optional sequence of transcripts.
        labels: Sequence of integer class labels; defines the dataset length.
        tokenizer: Callable used to tokenize the text. It is called with
            truncation, 'max_length' padding and return_tensors='pt'.
        max_length: Token length passed to the tokenizer.

    A modality is left out of an item when its sequence is None or shorter
    than the requested index.
    """

    def __init__(self, spectrograms, acoustic_features, texts, labels, tokenizer, max_length=512):
        self.spectrograms = spectrograms
        self.acoustic_features = acoustic_features
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict with the available modality tensors and a 'label' of shape (1,)."""
        item = {}

        if self.spectrograms is not None and idx < len(self.spectrograms):
            spec = self.spectrograms[idx]
            if isinstance(spec, np.ndarray):
                spec = torch.FloatTensor(spec)
            if len(spec.shape) == 2:
                # Add a leading channel axis to a 2-D spectrogram.
                spec = spec.unsqueeze(0)
            item['spectrogram'] = spec

        if self.acoustic_features is not None and idx < len(self.acoustic_features):
            acoustic = self.acoustic_features[idx]
            if isinstance(acoustic, np.ndarray):
                acoustic = torch.FloatTensor(acoustic)
            item['acoustic_features'] = acoustic

        if self.texts is not None and idx < len(self.texts):
            text = str(self.texts[idx])
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            # Drop only the leading batch axis. A bare squeeze() would also
            # collapse the sequence axis when max_length == 1.
            item['input_ids'] = encoding['input_ids'].squeeze(0)
            item['attention_mask'] = encoding['attention_mask'].squeeze(0)

        item['label'] = torch.LongTensor([self.labels[idx]])
        return item

def create_model_and_dataloaders(train_data, val_data=None, batch_size=16):
    """Build the tokenizer, data loaders and model for training.

    Args:
        train_data: dict with 'labels' and optionally 'spectrograms',
            'acoustic_features' and 'texts'.
        val_data: Optional dict of the same layout for validation.
        batch_size: Batch size for both loaders.

    Returns:
        (model, train_loader, val_loader, tokenizer). `val_loader` is None
        when no validation data is given. The model's acoustic input width
        is the length of the first training acoustic vector, or 50 when
        there are none.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def make_loader(split, shuffle):
        # Same dataset/loader wiring for both splits; only shuffling differs.
        dataset = MultiModalDataset(
            spectrograms=split.get('spectrograms'),
            acoustic_features=split.get('acoustic_features'),
            texts=split.get('texts'),
            labels=split['labels'],
            tokenizer=tokenizer,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    train_loader = make_loader(train_data, shuffle=True)
    val_loader = make_loader(val_data, shuffle=False) if val_data is not None else None

    # Use explicit None/length checks: truth-testing a numpy array is ambiguous.
    acoustic_features = train_data.get('acoustic_features')
    has_acoustic = acoustic_features is not None and len(acoustic_features) > 0
    acoustic_dim = len(acoustic_features[0]) if has_acoustic else 50

    model = MultiModalSpeechModel(acoustic_input_dim=acoustic_dim)

    return model, train_loader, val_loader, tokenizer

# Batch keys that are forwarded to the model as keyword arguments.
_MODEL_INPUT_KEYS = ('spectrogram', 'acoustic_features', 'input_ids', 'attention_mask')


def _unpack_batch(batch, device):
    """Move the model inputs present in `batch`, and its flattened labels, to `device`."""
    inputs = {key: batch[key].to(device) for key in _MODEL_INPUT_KEYS if key in batch}
    # reshape(-1), not squeeze(): squeeze() turns a single-sample batch into
    # a 0-dim tensor, which breaks both the loss call and labels.size(0).
    labels = batch['label'].reshape(-1).to(device)
    return inputs, labels


def train_model(model, train_loader, val_loader=None, epochs=10, learning_rate=1e-4, device='cuda'):
    """Train `model` with AdamW and cross-entropy, optionally validating each epoch.

    Args:
        model: Module called as model(**inputs), where inputs holds the
            modality tensors present in each batch.
        train_loader: Sized iterable of batch dicts with a 'label' entry.
        val_loader: Optional iterable of the same form. When given, the
            validation loss drives a ReduceLROnPlateau scheduler.
        epochs: Number of passes over `train_loader`.
        learning_rate: AdamW learning rate (weight decay is 0.01).
        device: Device the model and batches are moved to.

    Returns:
        The trained model, on `device`.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for batch in train_loader:
            optimizer.zero_grad()

            inputs, labels = _unpack_batch(batch, device)
            outputs = model(**inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Guard the averages so an empty loader reports zeros instead of
        # raising ZeroDivisionError.
        train_accuracy = 100 * train_correct / train_total if train_total else 0.0
        num_train_batches = len(train_loader)
        avg_train_loss = train_loss / num_train_batches if num_train_batches else 0.0

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.2f}%')

        if val_loader:
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for batch in val_loader:
                    inputs, labels = _unpack_batch(batch, device)
                    outputs = model(**inputs)

                    loss = criterion(outputs, labels)
                    val_loss += loss.item()

                    predicted = outputs.argmax(dim=1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

            val_accuracy = 100 * val_correct / val_total if val_total else 0.0
            avg_val_loss = val_loss / len(val_loader)
            scheduler.step(avg_val_loss)

            print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

    return model

In [7]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from transformers import BertTokenizer
import os

def collate_fn(batch):
    """Merge a list of per-sample dicts into a single batch dict of tensors.

    Only keys present in the first sample are considered. The model-input keys
    ('spectrogram', 'acoustic_features', 'input_ids', 'attention_mask') are
    stacked along a new leading dimension. If stacking fails because the
    per-sample shapes differ, 1-D tensors are promoted to a single row and all
    tensors are concatenated along dim 0 instead. Labels are concatenated into
    one tensor. Every other key is dropped.

    Args:
        batch: Non-empty list of dicts mapping key names to tensors.

    Returns:
        Dict holding the collated tensor for each recognised key.
    """
    stack_keys = ('spectrogram', 'acoustic_features', 'input_ids', 'attention_mask')
    collated = {}
    for key in batch[0]:
        if key in stack_keys:
            tensors = [item[key] for item in batch if key in item]
            try:
                collated[key] = torch.stack(tensors)
            except RuntimeError:
                # torch.stack raises RuntimeError when sample shapes differ.
                # Only that case is handled, so unrelated failures (and
                # KeyboardInterrupt) are no longer silently swallowed.
                rows = [t.unsqueeze(0) if t.dim() == 1 else t for t in tensors]
                collated[key] = torch.cat(rows, dim=0)
        elif key == 'label':
            # torch.cat rejects zero-dimensional tensors, so scalar labels are
            # promoted to shape (1,) before concatenation.
            labels = [item[key] for item in batch]
            collated[key] = torch.cat(
                [t.unsqueeze(0) if t.dim() == 0 else t for t in labels]
            )
    return collated

def load_processed_data(data_path='/content/drive/MyDrive/Speech/processed_datasets'):
    """Load the pickled dataset tables and their label encoders.

    Args:
        data_path: Directory containing 'diagnosis_train.pkl',
            'progression_train.pkl', 'progression_test.pkl',
            'diagnosis_label_encoder.pkl' and 'progression_label_encoder.pkl'.
            Defaults to the Google Drive location used by this notebook.

    Returns:
        Tuple ``(diagnosis_train, progression_train, progression_test,
        diagnosis_le, progression_le)``: the three objects read with
        ``pd.read_pickle`` followed by the two unpickled label encoders.

    Raises:
        FileNotFoundError: If any of the five files is missing.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    diagnosis_train, progression_train, progression_test = (
        pd.read_pickle(os.path.join(data_path, f'{name}.pkl'))
        for name in ('diagnosis_train', 'progression_train', 'progression_test')
    )

    encoders = []
    for name in ('diagnosis_label_encoder.pkl', 'progression_label_encoder.pkl'):
        with open(os.path.join(data_path, name), 'rb') as f:
            encoders.append(pickle.load(f))
    diagnosis_le, progression_le = encoders

    return diagnosis_train, progression_train, progression_test, diagnosis_le, progression_le

def load_lightweight_features(features_path='/content/drive/MyDrive/Speech/lightweight_features'):
    """Load and merge the per-group pickled feature dictionaries.

    Each known feature file is expected to hold a dict; the dicts are merged
    with ``dict.update`` in the order listed below, so a key that appears in
    more than one file keeps the value from the later file. Files that do not
    exist are skipped without error.

    Args:
        features_path: Directory containing the ``*_features.pkl`` files.
            Defaults to the Google Drive location used by this notebook.

    Returns:
        Dict with the merged contents of every feature file that was found
        (empty if none exist).
    """
    feature_files = (
        'diagnosis_ad_features.pkl',
        'diagnosis_cn_features.pkl',
        'progression_decline_features.pkl',
        'progression_no_decline_features.pkl',
        'progression_test_features.pkl',
    )

    features = {}
    for file_name in feature_files:
        file_path = os.path.join(features_path, file_name)
        if not os.path.exists(file_path):
            continue  # best-effort: a missing group is simply left out
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(file_path, 'rb') as f:
            features.update(pickle.load(f))

    return features

def _fit_spectrogram(mel_spec, size=224):
    """Crop or zero-pad a 2-D spectrogram so it is exactly ``size`` x ``size``."""
    mel_spec = mel_spec[:size, :size]
    pad_rows = size - mel_spec.shape[0]
    pad_cols = size - mel_spec.shape[1]
    if pad_rows or pad_cols:
        mel_spec = np.pad(mel_spec, ((0, pad_rows), (0, pad_cols)), mode='constant')
    return mel_spec


def _flatten_feature(value):
    """Return ``value`` (scalar, sequence or array of any rank) as a flat list."""
    return list(np.ravel(value))


def prepare_data_for_training(df, features_dict, tokenizer, task='diagnosis'):
    """Assemble model-ready arrays for every row of ``df`` that has features.

    Rows whose 'file_id' is not a key of ``features_dict`` are skipped. For
    each remaining row:

    * the 'mel_spectrogram' feature is cropped / zero-padded to 224 x 224
      (an all-zero spectrogram is used when the feature is absent);
    * 'mfcc', 'chroma', 'spectral_contrast', 'tonnetz' and then
      'zero_crossing_rate', 'spectral_centroid', 'spectral_rolloff',
      'rms_energy', 'tempo' are flattened and concatenated in that order, then
      zero-padded / truncated to exactly 50 values;
    * the 'transcript' column (empty string when missing, None or NaN) and the
      'label_encoded' column are collected.

    Args:
        df: DataFrame with 'file_id' and 'label_encoded' columns and an
            optional 'transcript' column.
        features_dict: Mapping from file id to a dict of extracted features.
        tokenizer: Unused here; kept so existing callers keep working.
        task: Unused here; kept so existing callers keep working.

    Returns:
        Dict with 'spectrograms' (array), 'acoustic_features' (array),
        'texts' (list) and 'labels' (array), all aligned by position.
    """
    spectrogram_size = 224
    acoustic_dim = 50
    vector_keys = ('mfcc', 'chroma', 'spectral_contrast', 'tonnetz')
    scalar_keys = ('zero_crossing_rate', 'spectral_centroid', 'spectral_rolloff',
                   'rms_energy', 'tempo')

    spectrograms = []
    acoustic_features_list = []
    texts = []
    labels = []

    for _, row in df.iterrows():
        file_id = row['file_id']
        if file_id not in features_dict:
            continue
        feature_data = features_dict[file_id]

        if 'mel_spectrogram' in feature_data:
            spectrograms.append(
                _fit_spectrogram(feature_data['mel_spectrogram'], spectrogram_size)
            )
        else:
            spectrograms.append(np.zeros((spectrogram_size, spectrogram_size)))

        acoustic_feat = []
        for key in vector_keys:
            if key in feature_data:
                # Flatten so 2-D feature matrices cannot create ragged rows.
                acoustic_feat.extend(_flatten_feature(feature_data[key]))

        for key in scalar_keys:
            if key not in feature_data:
                continue
            val = feature_data[key]
            if isinstance(val, (list, np.ndarray)):
                # np.ravel also copes with 0-d arrays, where len() would raise.
                acoustic_feat.extend(_flatten_feature(val) or [0])
            else:
                acoustic_feat.append(val)

        # Force a fixed-length vector: zero-pad short ones, truncate long ones.
        acoustic_feat = (acoustic_feat + [0] * acoustic_dim)[:acoustic_dim]

        text = row.get('transcript', '')
        if text is None or (isinstance(text, float) and np.isnan(text)):
            # Missing transcripts become empty strings so tokenisation works.
            text = ''

        acoustic_features_list.append(acoustic_feat)
        texts.append(text)
        labels.append(row['label_encoded'])

    return {
        'spectrograms': np.array(spectrograms),
        'acoustic_features': np.array(acoustic_features_list),
        'texts': texts,
        'labels': np.array(labels)
    }

def evaluate_model(model, test_loader, device='cuda', task='diagnosis'):
    """Run the model over a loader and compute classification accuracy.

    Args:
        model: Module called as ``model(**inputs, task=task)`` and returning
            per-class scores with classes on dimension 1.
        test_loader: Iterable of batch dicts holding a 'label' tensor and any
            of 'spectrogram', 'acoustic_features', 'input_ids',
            'attention_mask'.
        device: Device the batch tensors are moved to. The model itself is not
            moved here.
        task: Forwarded to the model as the ``task`` keyword argument.

    Returns:
        Tuple ``(accuracy, all_predictions, all_labels)`` where ``accuracy``
        comes from ``accuracy_score`` and the two lists hold one entry per
        evaluated sample.
    """
    input_keys = ('spectrogram', 'acoustic_features', 'input_ids', 'attention_mask')
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            inputs = {key: batch[key].to(device) for key in input_keys if key in batch}

            # reshape(-1) instead of squeeze(): a final batch holding a single
            # sample would otherwise collapse to a 0-d tensor, which cannot be
            # iterated by list.extend below.
            labels = batch['label'].reshape(-1).to(device)
            outputs = model(**inputs, task=task)

            _, predicted = torch.max(outputs, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    return accuracy, all_predictions, all_labels

def _split_train_val(data, test_size=0.2, random_state=42):
    """Stratified split of a prepared-data dict into (train, validation) dicts."""
    train_idx, val_idx, y_train, y_val = train_test_split(
        range(len(data['labels'])), data['labels'],
        test_size=test_size, random_state=random_state, stratify=data['labels']
    )

    def subset(indices, subset_labels):
        return {
            'spectrograms': data['spectrograms'][indices],
            'acoustic_features': data['acoustic_features'][indices],
            'texts': [data['texts'][i] for i in indices],
            'labels': subset_labels
        }

    return subset(train_idx, y_train), subset(val_idx, y_val)


def _train_and_report(data, label_encoder, task, device, epochs=10):
    """Train one model for ``task``, print its validation report, return it."""
    train_data, val_data = _split_train_val(data)
    model, train_loader, val_loader, _ = create_model_and_dataloaders(train_data, val_data)

    print(f"Training {task} model...")
    model = train_model(model, train_loader, val_loader, epochs=epochs, device=device)

    # NOTE(review): metrics are computed on the same validation split that is
    # passed to train_model, so they are not a held-out test estimate.
    print(f"Evaluating {task} model...")
    accuracy, preds, labels = evaluate_model(model, val_loader, device, task)
    title = task.capitalize()
    print(f"{title} Accuracy: {accuracy:.4f}")
    print(f"{title} Classification Report:")
    print(classification_report(labels, preds, target_names=label_encoder.classes_))

    return model


def main():
    """Train, evaluate and save the diagnosis and progression models.

    Loads the processed tables and lightweight features, builds the training
    arrays for both tasks, then for each task performs an 80/20 stratified
    split, trains for 10 epochs, and prints accuracy plus a classification
    report on the validation split. Both state dicts are finally written to
    '/content/drive/MyDrive/Speech/'.

    Returns:
        Tuple ``(model_diag, model_prog)`` of the trained models.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The progression test table is loaded but not used by this pipeline.
    diagnosis_train, progression_train, _progression_test, diagnosis_le, progression_le = (
        load_processed_data()
    )

    features_dict = load_lightweight_features()

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    diagnosis_data = prepare_data_for_training(diagnosis_train, features_dict, tokenizer, 'diagnosis')
    progression_data = prepare_data_for_training(progression_train, features_dict, tokenizer, 'progression')

    model_diag = _train_and_report(diagnosis_data, diagnosis_le, 'diagnosis', device)
    model_prog = _train_and_report(progression_data, progression_le, 'progression', device)

    torch.save(model_diag.state_dict(), '/content/drive/MyDrive/Speech/diagnosis_model.pth')
    torch.save(model_prog.state_dict(), '/content/drive/MyDrive/Speech/progression_model.pth')

    return model_diag, model_prog

# Entry point: run the full training pipeline only when this module is executed
# directly (which includes running it as a notebook cell), not when imported.
if __name__ == "__main__":
    main()

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Training diagnosis model...
Epoch [1/10], Loss: 0.7065, Accuracy: 49.24%
Validation Loss: 0.6364, Validation Accuracy: 52.94%
Epoch [2/10], Loss: 0.7136, Accuracy: 53.03%
Validation Loss: 0.6730, Validation Accuracy: 61.76%
Epoch [3/10], Loss: 0.7252, Accuracy: 50.00%
Validation Loss: 0.9617, Validation Accuracy: 52.94%
Epoch [4/10], Loss: 0.7539, Accuracy: 62.12%
Validation Loss: 0.6361, Validation Accuracy: 52.94%
Epoch [5/10], Loss: 0.7288, Accuracy: 48.48%
Validation Loss: 0.6411, Validation Accuracy: 52.94%
Epoch [6/10], Loss: 0.7237, Accuracy: 50.76%
Validation Loss: 0.6644, Validation Accuracy: 52.94%
Epoch [7/10], Loss: 0.7138, Accuracy: 50.76%
Validation Loss: 0.6643, Validation Accuracy: 52.94%
Epoch [8/10], Loss: 0.6756, Accuracy: 57.58%
Validation Loss: 0.6729, Validation Accuracy: 52.94%
Epoch [9/10], Loss: 0.6911, Accuracy: 53.79%
Validation Loss: 0.6802, Validation Accuracy: 52.94%
Epoch [10/10], Loss: 0.7114, Accuracy: 53.79%
Validation Loss: 0.6961, Validation Accuracy