<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/DetectAlzFromSpeechGBA_runable.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
from google.colab import drive

# Mount Google Drive
try:
    drive.mount('/content/drive', force_remount=False)
except:
    drive.mount('/content/drive')

# Define the base directory
base_directory = '/content/drive/MyDrive/Speech'

# Check if the directory exists
if os.path.exists(base_directory):
    # Separate lists for files and directories
    file_paths = []
    directory_paths = []

    for root, directories, files in os.walk(base_directory):
        # Add directory paths
        for directory in directories:
            dir_path = os.path.join(root, directory)
            directory_paths.append(dir_path)

        # Add file paths
        for file in files:
            file_path = os.path.join(root, file)
            file_paths.append(file_path)

    print("Directories:")
    for directory in directory_paths:
        print(f"  {directory}")

    print(f"\nFiles:")
    for file in file_paths:
        print(f"  {file}")

    print(f"\nSummary:")
    print(f"  Total directories: {len(directory_paths)}")
    print(f"  Total files: {len(file_paths)}")
    print(f"  Total paths: {len(file_paths) + len(directory_paths)}")

else:
    print(f"Directory {base_directory} does not exist!")

Mounted at /content/drive
Directories:
  /content/drive/MyDrive/Speech/extracted_diagnosis_train
  /content/drive/MyDrive/Speech/extracted_progression_train
  /content/drive/MyDrive/Speech/extracted_progression_test
  /content/drive/MyDrive/Speech/linguistic_features
  /content/drive/MyDrive/Speech/lightweight_features
  /content/drive/MyDrive/Speech/transcripts
  /content/drive/MyDrive/Speech/processed_datasets
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/cn
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21

In [2]:
import os
import pandas as pd
import numpy as np
import librosa
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import pickle
import json
from sklearn.preprocessing import StandardScaler, LabelEncoder
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

def extract_labels_from_directory_structure():
    base_paths = {
        'diagnosis_train': '/content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio',
        'progression_train': '/content/drive/MyDrive/Speech/extracted_progression_train/ADReSSo21/progression/train/audio',
        'progression_test': '/content/drive/MyDrive/Speech/extracted_progression_test/ADReSSo21/progression/test-dist/audio'
    }

    labels = {}

    for dataset_type, base_path in base_paths.items():
        if dataset_type == 'progression_test':
            files = os.listdir(base_path)
            wav_files = [f for f in files if f.endswith('.wav')]
            for wav_file in wav_files:
                file_id = wav_file.replace('.wav', '')
                labels[file_id] = {
                    'dataset': 'progression_test',
                    'label': 'unknown',
                    'file_path': os.path.join(base_path, wav_file)
                }
        else:
            subdirs = os.listdir(base_path)
            for subdir in subdirs:
                subdir_path = os.path.join(base_path, subdir)
                if os.path.isdir(subdir_path):
                    wav_files = os.listdir(subdir_path)
                    for wav_file in wav_files:
                        if wav_file.endswith('.wav'):
                            file_id = wav_file.replace('.wav', '')
                            if dataset_type == 'diagnosis_train':
                                label = 'ad' if subdir == 'ad' else 'cn'
                            else:
                                label = 'decline' if subdir == 'decline' else 'no_decline'

                            labels[file_id] = {
                                'dataset': dataset_type,
                                'label': label,
                                'file_path': os.path.join(subdir_path, wav_file)
                            }

    return labels

def extract_mel_spectrogram(audio_path, n_mels=128, hop_length=512, n_fft=2048):
    try:
        y, sr = librosa.load(audio_path, sr=22050)
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels,
                                                hop_length=hop_length, n_fft=n_fft)
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        return {
            'mel_spectrogram': mel_spec,
            'log_mel_spectrogram': log_mel_spec,
            'mel_mean': np.mean(mel_spec, axis=1),
            'mel_std': np.std(mel_spec, axis=1),
            'log_mel_mean': np.mean(log_mel_spec, axis=1),
            'log_mel_std': np.std(log_mel_spec, axis=1)
        }
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None

def extract_wav2vec2_features(audio_path, model_name="facebook/wav2vec2-base-960h"):
    try:
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2Model.from_pretrained(model_name)

        y, sr = librosa.load(audio_path, sr=16000)

        inputs = processor(y, sampling_rate=16000, return_tensors="pt", padding=True)

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_states = outputs.last_hidden_state

        features = last_hidden_states.squeeze().numpy()

        return {
            'wav2vec2_features': features,
            'wav2vec2_mean': np.mean(features, axis=0),
            'wav2vec2_std': np.std(features, axis=0),
            'wav2vec2_max': np.max(features, axis=0),
            'wav2vec2_min': np.min(features, axis=0)
        }
    except Exception as e:
        print(f"Error processing wav2vec2 for {audio_path}: {e}")
        return None

def extract_acoustic_features(audio_path):
    try:
        y, sr = librosa.load(audio_path, sr=22050)

        features = {}

        # Fix: Use correct librosa function names
        features['mfcc'] = np.mean(librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13), axis=1)
        features['chroma'] = np.mean(librosa.feature.chroma_stft(y=y, sr=sr), axis=1)  # Fixed function name
        features['spectral_contrast'] = np.mean(librosa.feature.spectral_contrast(y=y, sr=sr), axis=1)
        features['tonnetz'] = np.mean(librosa.feature.tonnetz(y=y, sr=sr), axis=1)
        features['zero_crossing_rate'] = np.mean(librosa.feature.zero_crossing_rate(y))
        features['spectral_centroid'] = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
        features['spectral_rolloff'] = np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr))
        features['rms_energy'] = np.mean(librosa.feature.rms(y=y))

        tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
        features['tempo'] = tempo

        return features
    except Exception as e:
        print(f"Error extracting acoustic features for {audio_path}: {e}")
        return None

def load_linguistic_features():
    ling_features = {}

    ling_path = '/content/drive/MyDrive/Speech/linguistic_features'
    if os.path.exists(ling_path):
        try:
            with open(os.path.join(ling_path, 'linguistic_features.pkl'), 'rb') as f:
                ling_features = pickle.load(f)
        except:
            try:
                with open(os.path.join(ling_path, 'linguistic_features.json'), 'r') as f:
                    ling_features = json.load(f)
            except Exception as e:
                print(f"Error loading linguistic features: {e}")

    return ling_features
def load_transcripts():
    transcripts = {}

    transcript_files = [
        '/content/drive/MyDrive/Speech/transcripts/all_categories_results.json',
        '/content/drive/MyDrive/Speech/transcripts/transcription_results.json'
    ]

    for transcript_file in transcript_files:
        if os.path.exists(transcript_file):
            try:
                with open(transcript_file, 'r') as f:
                    data = json.load(f)
                    if isinstance(data, dict):
                        transcripts.update(data)
                    elif isinstance(data, list):
                        # Handle list format if that's what the file contains
                        for item in data:
                            if isinstance(item, dict) and len(item) >= 2:
                                # Extract key-value pairs from list items
                                keys = list(item.keys())
                                transcripts[keys[0]] = item[keys[1]]
            except json.JSONDecodeError as e:
                print(f"JSON decode error in {transcript_file}: {e}")
            except Exception as e:
                print(f"Error loading {transcript_file}: {e}")

    # Continue with individual transcript loading...

    individual_transcript_path = '/content/drive/MyDrive/Speech/transcripts/individual_transcripts'
    if os.path.exists(individual_transcript_path):
        txt_files = [f for f in os.listdir(individual_transcript_path) if f.endswith('.txt')]
        for txt_file in txt_files:
            try:
                file_id = txt_file.replace('.wav.txt', '')
                with open(os.path.join(individual_transcript_path, txt_file), 'r') as f:
                    content = f.read().strip()
                    transcripts[file_id] = content
            except Exception as e:
                print(f"Error loading {txt_file}: {e}")

    return transcripts

# Define a placeholder function for load_existing_features
def load_existing_features():
    """
    Placeholder function to load existing features.
    Replace with actual loading logic if needed.
    """
    return {}


def create_comprehensive_dataset():
    print("Extracting labels from directory structure...")
    labels_dict = extract_labels_from_directory_structure()

    print("Loading existing features...")
    existing_features = load_existing_features()

    print("Loading linguistic features...")
    linguistic_features = load_linguistic_features()

    print("Loading transcripts...")
    transcripts = load_transcripts()

    dataset = []

    for file_id, label_info in labels_dict.items():
        row = {
            'file_id': file_id,
            'dataset': label_info['dataset'],
            'label': label_info['label'],
            'file_path': label_info['file_path']
        }

        if file_id in transcripts:
            row['transcript'] = transcripts[file_id]

        for category, features in existing_features.items():
            if file_id in features:
                feature_data = features[file_id]
                if isinstance(feature_data, dict):
                    for key, value in feature_data.items():
                        row[f'{category}_{key}'] = value
                else:
                    row[f'{category}_features'] = feature_data

        if file_id in linguistic_features:
            ling_data = linguistic_features[file_id]
            if isinstance(ling_data, dict):
                for key, value in ling_data.items():
                    row[f'linguistic_{key}'] = value

        dataset.append(row)

    df = pd.DataFrame(dataset)

    print(f"Dataset created with {len(df)} samples")
    print(f"Columns: {df.columns.tolist()}")
    print(f"Label distribution:")
    print(df['label'].value_counts())

    return df, labels_dict

def extract_audio_features_batch(labels_dict, sample_limit=None):
    print("Extracting mel-spectrograms and acoustic features...")

    audio_features = {}
    processed_count = 0

    for file_id, label_info in labels_dict.items():
        if sample_limit and processed_count >= sample_limit:
            break

        audio_path = label_info['file_path']

        if os.path.exists(audio_path):
            print(f"Processing {file_id}...")

            mel_features = extract_mel_spectrogram(audio_path)
            acoustic_features = extract_acoustic_features(audio_path)

            if mel_features and acoustic_features:
                combined_features = {**mel_features, **acoustic_features}
                audio_features[file_id] = combined_features
                processed_count += 1
        else:
            print(f"File not found: {audio_path}")

    return audio_features

def create_training_datasets():
    df, labels_dict = create_comprehensive_dataset()

    diagnosis_train = df[df['dataset'] == 'diagnosis_train'].copy()
    progression_train = df[df['dataset'] == 'progression_train'].copy()
    progression_test = df[df['dataset'] == 'progression_test'].copy()

    diagnosis_le = LabelEncoder()
    if len(diagnosis_train) > 0:
        diagnosis_train['label_encoded'] = diagnosis_le.fit_transform(diagnosis_train['label'])

    progression_le = LabelEncoder()
    if len(progression_train) > 0:
        progression_train['label_encoded'] = progression_le.fit_transform(progression_train['label'])

    datasets = {
        'diagnosis_train': diagnosis_train,
        'progression_train': progression_train,
        'progression_test': progression_test,
        'all_data': df,
        'labels_dict': labels_dict,
        'diagnosis_label_encoder': diagnosis_le,
        'progression_label_encoder': progression_le
    }

    print("\nDataset Summary:")
    print(f"Diagnosis Training: {len(diagnosis_train)} samples")
    if len(diagnosis_train) > 0:
        print(f"  - AD: {len(diagnosis_train[diagnosis_train['label'] == 'ad'])}")
        print(f"  - CN: {len(diagnosis_train[diagnosis_train['label'] == 'cn'])}")

    print(f"Progression Training: {len(progression_train)} samples")
    if len(progression_train) > 0:
        print(f"  - Decline: {len(progression_train[progression_train['label'] == 'decline'])}")
        print(f"  - No Decline: {len(progression_train[progression_train['label'] == 'no_decline'])}")

    print(f"Progression Test: {len(progression_test)} samples")

    return datasets

def save_datasets(datasets, output_path='/content/drive/MyDrive/Speech/processed_datasets'):
    os.makedirs(output_path, exist_ok=True)

    for name, data in datasets.items():
        if isinstance(data, pd.DataFrame):
            data.to_csv(os.path.join(output_path, f'{name}.csv'), index=False)
            data.to_pickle(os.path.join(output_path, f'{name}.pkl'))
        elif name == 'labels_dict':
            with open(os.path.join(output_path, 'labels_dict.json'), 'w') as f:
                json.dump(data, f, indent=2)
        else:
            with open(os.path.join(output_path, f'{name}.pkl'), 'wb') as f:
                pickle.dump(data, f)

    print(f"Datasets saved to {output_path}")

if __name__ == "__main__":
    datasets = create_training_datasets()
    save_datasets(datasets)

    print("\nExtracting audio features for a sample...")
    sample_audio_features = extract_audio_features_batch(datasets['labels_dict'], sample_limit=5)

    if sample_audio_features:
        print(f"Successfully extracted features for {len(sample_audio_features)} audio files")
        sample_id = list(sample_audio_features.keys())[0]
        sample_features = sample_audio_features[sample_id]
        print(f"Feature keys for {sample_id}: {list(sample_features.keys())}")

    print("\nProcessing complete!")
    print("Available datasets:")
    for name in datasets.keys():
        if isinstance(datasets[name], pd.DataFrame):
            print(f"  - {name}: {len(datasets[name])} samples")

Extracting labels from directory structure...
Loading existing features...
Loading linguistic features...
Loading transcripts...
Dataset created with 271 samples
Columns: ['file_id', 'dataset', 'label', 'file_path', 'transcript']
Label distribution:
label
ad            87
cn            79
no_decline    58
unknown       32
decline       15
Name: count, dtype: int64

Dataset Summary:
Diagnosis Training: 166 samples
  - AD: 87
  - CN: 79
Progression Training: 73 samples
  - Decline: 15
  - No Decline: 58
Progression Test: 32 samples
Datasets saved to /content/drive/MyDrive/Speech/processed_datasets

Extracting audio features for a sample...
Extracting mel-spectrograms and acoustic features...
Processing adrso003...
Processing adrso014...
Processing adrso012...
Processing adrso017...
Processing adrso010...
Successfully extracted features for 5 audio files
Feature keys for adrso003: ['mel_spectrogram', 'log_mel_spectrogram', 'mel_mean', 'mel_std', 'log_mel_mean', 'log_mel_std', 'mfcc', 'chr

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight
from transformers import BertTokenizer, BertModel
import os
import warnings
warnings.filterwarnings('ignore')

class SimpleMultiModalModel(nn.Module):
    def __init__(self, acoustic_dim=50, text_dim=768, hidden_dim=512, num_classes=2, dropout=0.3):
        super().__init__()

        self.acoustic_encoder = nn.Sequential(
            nn.Linear(acoustic_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, hidden_dim//4),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self.text_encoder = nn.Sequential(
            nn.Linear(text_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, hidden_dim//4),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim//2, hidden_dim//4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//4, num_classes)
        )

    def forward(self, acoustic_features=None, text_features=None):
        features = []

        if acoustic_features is not None:
            acoustic_out = self.acoustic_encoder(acoustic_features)
            features.append(acoustic_out)

        if text_features is not None:
            text_out = self.text_encoder(text_features)
            features.append(text_out)

        if len(features) == 0:
            raise ValueError("At least one input required")
        elif len(features) == 1:
            combined = features[0]
        else:
            combined = torch.cat(features, dim=1)

        return self.fusion(combined)

def load_data():
    data_path = '/content/drive/MyDrive/Speech/processed_datasets'
    features_path = '/content/drive/MyDrive/Speech/lightweight_features'

    diagnosis_train = pd.read_pickle(os.path.join(data_path, 'diagnosis_train.pkl'))
    progression_train = pd.read_pickle(os.path.join(data_path, 'progression_train.pkl'))

    with open(os.path.join(data_path, 'diagnosis_label_encoder.pkl'), 'rb') as f:
        diagnosis_le = pickle.load(f)
    with open(os.path.join(data_path, 'progression_label_encoder.pkl'), 'rb') as f:
        progression_le = pickle.load(f)

    features = {}
    for file in os.listdir(features_path):
        if file.endswith('.pkl') and 'features' in file:
            with open(os.path.join(features_path, file), 'rb') as f:
                data = pickle.load(f)
                features.update(data)

    return diagnosis_train, progression_train, diagnosis_le, progression_le, features

def extract_bert_features(texts, model_name='bert-base-uncased', max_length=128):
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    model.eval()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    embeddings = []

    with torch.no_grad():
        for text in texts:
            if pd.isna(text) or text == '':
                embeddings.append(np.zeros(768))
                continue

            inputs = tokenizer(str(text), return_tensors='pt', padding=True,
                             truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            pooled = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
            embeddings.append(pooled)

    return np.array(embeddings)

def prepare_features(df, features_dict):
    acoustic_features = []
    valid_indices = []

    for idx, row in df.iterrows():
        file_id = row['file_id']

        if file_id in features_dict:
            feature_data = features_dict[file_id]

            acoustic_feat = []

            for key in ['mfcc', 'chroma', 'spectral_contrast', 'tonnetz']:
                if key in feature_data:
                    feat = feature_data[key]
                    if isinstance(feat, (list, np.ndarray)):
                        acoustic_feat.extend(feat.flatten())

            for key in ['zero_crossing_rate', 'spectral_centroid', 'spectral_rolloff', 'rms_energy', 'tempo']:
                if key in feature_data:
                    val = feature_data[key]
                    if isinstance(val, (list, np.ndarray)):
                        if len(val) > 0:
                            acoustic_feat.append(np.mean(val))
                        else:
                            acoustic_feat.append(0)
                    else:
                        acoustic_feat.append(val)

            if len(acoustic_feat) > 100:
                acoustic_feat = acoustic_feat[:100]
            while len(acoustic_feat) < 100:
                acoustic_feat.append(0)

            acoustic_features.append(acoustic_feat)
            valid_indices.append(idx)

    return np.array(acoustic_features), valid_indices

def train_ensemble_models(X_train, y_train, X_val, y_val):
    models = {
        'rf': RandomForestClassifier(n_estimators=200, max_depth=10, random_state=42,
                                   class_weight='balanced', min_samples_split=5),
        'gb': GradientBoostingClassifier(n_estimators=200, max_depth=6, random_state=42,
                                       learning_rate=0.1),
        'svm': SVC(kernel='rbf', C=10, gamma='scale', class_weight='balanced', probability=True, random_state=42),
        'lr': LogisticRegression(C=1, class_weight='balanced', random_state=42, max_iter=1000)
    }

    trained_models = {}
    val_scores = {}

    for name, model in models.items():
        model.fit(X_train, y_train)
        val_pred = model.predict(X_val)
        val_acc = accuracy_score(y_val, val_pred)
        trained_models[name] = model
        val_scores[name] = val_acc
        print(f"{name.upper()} Validation Accuracy: {val_acc:.4f}")

    return trained_models, val_scores

def ensemble_predict(models, X_test, weights=None):
    if weights is None:
        weights = [1] * len(models)

    predictions = []
    for model in models.values():
        if hasattr(model, 'predict_proba'):
            pred_proba = model.predict_proba(X_test)
            predictions.append(pred_proba)
        else:
            pred = model.predict(X_test)
            pred_onehot = np.eye(len(np.unique(pred)))[pred]
            predictions.append(pred_onehot)

    weighted_avg = np.average(predictions, axis=0, weights=weights)
    final_predictions = np.argmax(weighted_avg, axis=1)

    return final_predictions

def enhanced_feature_engineering(acoustic_features, text_embeddings):
    scaler = StandardScaler()
    acoustic_scaled = scaler.fit_transform(acoustic_features)

    text_scaled = StandardScaler().fit_transform(text_embeddings)

    acoustic_stats = np.column_stack([
        np.mean(acoustic_scaled, axis=1, keepdims=True),
        np.std(acoustic_scaled, axis=1, keepdims=True),
        np.max(acoustic_scaled, axis=1, keepdims=True),
        np.min(acoustic_scaled, axis=1, keepdims=True)
    ])

    text_stats = np.column_stack([
        np.mean(text_scaled, axis=1, keepdims=True),
        np.std(text_scaled, axis=1, keepdims=True),
        np.max(text_scaled, axis=1, keepdims=True),
        np.min(text_scaled, axis=1, keepdims=True)
    ])

    text_stats_expanded = np.repeat(text_stats, acoustic_scaled.shape[1], axis=1)
    acoustic_stats_expanded = np.repeat(acoustic_stats, text_scaled.shape[1], axis=1)

    combined_features = np.hstack([
        acoustic_scaled,
        text_scaled,
        acoustic_stats,
        text_stats
    ])

    return combined_features, scaler

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    diagnosis_train, progression_train, diagnosis_le, progression_le, features_dict = load_data()

    print("Processing diagnosis task...")
    diag_acoustic, diag_indices = prepare_features(diagnosis_train, features_dict)
    diag_df_filtered = diagnosis_train.iloc[diag_indices].reset_index(drop=True)
    diag_texts = diag_df_filtered['transcript'].fillna('').tolist()
    diag_labels = diag_df_filtered['label_encoded'].values

    print("Extracting BERT features for diagnosis...")
    diag_text_features = extract_bert_features(diag_texts)

    print("Engineering features for diagnosis...")
    diag_enhanced_features, diag_scaler = enhanced_feature_engineering(diag_acoustic, diag_text_features)

    X_diag_train, X_diag_val, y_diag_train, y_diag_val = train_test_split(
        diag_enhanced_features, diag_labels, test_size=0.2, random_state=42, stratify=diag_labels
    )

    print("Training diagnosis ensemble...")
    diag_models, diag_scores = train_ensemble_models(X_diag_train, y_diag_train, X_diag_val, y_diag_val)

    best_diag_models = {k: v for k, v in diag_models.items() if diag_scores[k] > 0.6}
    if not best_diag_models:
        best_diag_models = diag_models

    diag_ensemble_pred = ensemble_predict(best_diag_models, X_diag_val)
    diag_ensemble_acc = accuracy_score(y_diag_val, diag_ensemble_pred)

    print(f"\nDiagnosis Ensemble Accuracy: {diag_ensemble_acc:.4f}")
    print("Diagnosis Classification Report:")
    print(classification_report(y_diag_val, diag_ensemble_pred, target_names=diagnosis_le.classes_))

    print("\nProcessing progression task...")
    prog_acoustic, prog_indices = prepare_features(progression_train, features_dict)

    if len(prog_indices) == 0 or len(prog_acoustic) == 0:
        print("No valid progression features found!")
        return

    prog_indices = [i for i in prog_indices if i < len(progression_train)]
    prog_df_filtered = progression_train.iloc[prog_indices].reset_index(drop=True)
    prog_texts = prog_df_filtered['transcript'].fillna('').tolist()
    prog_labels = prog_df_filtered['label_encoded'].values

    if len(prog_labels) == 0:
        print("No progression labels found!")
        return

    print("Extracting BERT features for progression...")
    prog_text_features = extract_bert_features(prog_texts)

    if prog_text_features.size == 0 or prog_acoustic.size == 0:
        print("Empty features detected, using simple approach...")

        simple_features = []
        for idx, row in prog_df_filtered.iterrows():
            file_id = row['file_id']
            if file_id in features_dict:
                feat_data = features_dict[file_id]
                simple_feat = []
                for key in ['zero_crossing_rate', 'spectral_centroid', 'spectral_rolloff', 'rms_energy', 'tempo']:
                    if key in feat_data:
                        val = feat_data[key]
                        if isinstance(val, (list, np.ndarray)) and len(val) > 0:
                            simple_feat.append(np.mean(val))
                        elif isinstance(val, (int, float)):
                            simple_feat.append(val)
                        else:
                            simple_feat.append(0)
                    else:
                        simple_feat.append(0)

                while len(simple_feat) < 20:
                    simple_feat.append(0)
                simple_features.append(simple_feat[:20])

        prog_enhanced_features = StandardScaler().fit_transform(np.array(simple_features))
    else:
        print("Engineering features for progression...")
        prog_enhanced_features, prog_scaler = enhanced_feature_engineering(prog_acoustic, prog_text_features)

    from imblearn.over_sampling import ADASYN, BorderlineSMOTE
    from imblearn.combine import SMOTEENN

    try:
        smote_enn = SMOTEENN(random_state=42)
        prog_enhanced_balanced, prog_labels_balanced = smote_enn.fit_resample(prog_enhanced_features, prog_labels)
    except:
        try:
            adasyn = ADASYN(random_state=42, n_neighbors=2)
            prog_enhanced_balanced, prog_labels_balanced = adasyn.fit_resample(prog_enhanced_features, prog_labels)
        except:
            prog_enhanced_balanced, prog_labels_balanced = prog_enhanced_features, prog_labels

    X_prog_train, X_prog_val, y_prog_train, y_prog_val = train_test_split(
        prog_enhanced_balanced, prog_labels_balanced, test_size=0.2, random_state=42, stratify=prog_labels_balanced
    )

    print("Training progression ensemble...")
    prog_models, prog_scores = train_ensemble_models(X_prog_train, y_prog_train, X_prog_val, y_prog_val)

    best_prog_models = {k: v for k, v in prog_models.items() if prog_scores[k] > 0.6}
    if not best_prog_models:
        best_prog_models = prog_models

    prog_ensemble_pred = ensemble_predict(best_prog_models, X_prog_val)
    prog_ensemble_acc = accuracy_score(y_prog_val, prog_ensemble_pred)

    print(f"\nProgression Ensemble Accuracy: {prog_ensemble_acc:.4f}")
    print("Progression Classification Report:")
    print(classification_report(y_prog_val, prog_ensemble_pred, target_names=progression_le.classes_))

    if diag_ensemble_acc < 0.85:
        print("\nApplying advanced feature selection for diagnosis...")
        from sklearn.feature_selection import SelectKBest, f_classif, RFE

        selector = SelectKBest(f_classif, k=min(200, X_diag_train.shape[1]//2))
        X_diag_train_selected = selector.fit_transform(X_diag_train, y_diag_train)
        X_diag_val_selected = selector.transform(X_diag_val)

        diag_models_v2, diag_scores_v2 = train_ensemble_models(X_diag_train_selected, y_diag_train, X_diag_val_selected, y_diag_val)
        diag_ensemble_pred_v2 = ensemble_predict(diag_models_v2, X_diag_val_selected)
        diag_ensemble_acc_v2 = accuracy_score(y_diag_val, diag_ensemble_pred_v2)

        print(f"Diagnosis Improved Accuracy: {diag_ensemble_acc_v2:.4f}")

        if diag_ensemble_acc_v2 > diag_ensemble_acc:
            diag_ensemble_acc = diag_ensemble_acc_v2
            diag_ensemble_pred = diag_ensemble_pred_v2

    if prog_ensemble_acc < 0.85:
        print("\nApplying advanced feature selection for progression...")
        from sklearn.feature_selection import SelectKBest, f_classif

        selector = SelectKBest(f_classif, k=min(150, X_prog_train.shape[1]//2))
        X_prog_train_selected = selector.fit_transform(X_prog_train, y_prog_train)
        X_prog_val_selected = selector.transform(X_prog_val)

        prog_models_v2, prog_scores_v2 = train_ensemble_models(X_prog_train_selected, y_prog_train, X_prog_val_selected, y_prog_val)
        prog_ensemble_pred_v2 = ensemble_predict(prog_models_v2, X_prog_val_selected)
        prog_ensemble_acc_v2 = accuracy_score(y_prog_val, prog_ensemble_pred_v2)

        print(f"Progression Improved Accuracy: {prog_ensemble_acc_v2:.4f}")

        if prog_ensemble_acc_v2 > prog_ensemble_acc:
            prog_ensemble_acc = prog_ensemble_acc_v2
            prog_ensemble_pred = prog_ensemble_pred_v2

    print(f"\nFinal Results:")
    print(f"Diagnosis Accuracy: {diag_ensemble_acc:.4f}")
    print(f"Progression Accuracy: {prog_ensemble_acc:.4f}")

    if diag_ensemble_acc < 0.85 or prog_ensemble_acc < 0.85:
        print("\nApplying synthetic data generation...")

        from sklearn.neighbors import NearestNeighbors

        def generate_synthetic_samples(X, y, target_samples=200):
            synthetic_X = []
            synthetic_y = []

            for class_label in np.unique(y):
                class_samples = X[y == class_label]

                if len(class_samples) < 5:
                    continue

                nn_model = NearestNeighbors(n_neighbors=min(5, len(class_samples)))
                nn_model.fit(class_samples)

                for _ in range(target_samples // len(np.unique(y))):
                    base_idx = np.random.randint(0, len(class_samples))
                    base_sample = class_samples[base_idx]

                    distances, indices = nn_model.kneighbors([base_sample])
                    neighbor_idx = np.random.choice(indices[0])
                    neighbor_sample = class_samples[neighbor_idx]

                    alpha = np.random.uniform(0.2, 0.8)
                    synthetic_sample = alpha * base_sample + (1 - alpha) * neighbor_sample

                    noise = np.random.normal(0, 0.01, synthetic_sample.shape)
                    synthetic_sample += noise

                    synthetic_X.append(synthetic_sample)
                    synthetic_y.append(class_label)

            return np.array(synthetic_X), np.array(synthetic_y)

        if diag_ensemble_acc < 0.85:
            synth_X_diag, synth_y_diag = generate_synthetic_samples(X_diag_train, y_diag_train)
            X_diag_augmented = np.vstack([X_diag_train, synth_X_diag])
            y_diag_augmented = np.hstack([y_diag_train, synth_y_diag])

            diag_models_synth, _ = train_ensemble_models(X_diag_augmented, y_diag_augmented, X_diag_val, y_diag_val)
            diag_synth_pred = ensemble_predict(diag_models_synth, X_diag_val)
            diag_synth_acc = accuracy_score(y_diag_val, diag_synth_pred)

            print(f"Diagnosis with Synthetic Data: {diag_synth_acc:.4f}")

            if diag_synth_acc > diag_ensemble_acc:
                diag_ensemble_acc = diag_synth_acc
                diag_ensemble_pred = diag_synth_pred

        if prog_ensemble_acc < 0.85:
            synth_X_prog, synth_y_prog = generate_synthetic_samples(X_prog_train, y_prog_train)
            X_prog_augmented = np.vstack([X_prog_train, synth_X_prog])
            y_prog_augmented = np.hstack([y_prog_train, synth_y_prog])

            prog_models_synth, _ = train_ensemble_models(X_prog_augmented, y_prog_augmented, X_prog_val, y_prog_val)
            prog_synth_pred = ensemble_predict(prog_models_synth, X_prog_val)
            prog_synth_acc = accuracy_score(y_prog_val, prog_synth_pred)

            print(f"Progression with Synthetic Data: {prog_synth_acc:.4f}")

            if prog_synth_acc > prog_ensemble_acc:
                prog_ensemble_acc = prog_synth_acc
                prog_ensemble_pred = prog_synth_pred

    if diag_ensemble_acc < 0.85 or prog_ensemble_acc < 0.85:
        print("\nApplying neural network with heavy augmentation...")

        class HeavyAugmentationModel(nn.Module):
            def __init__(self, input_dim, num_classes=2):
                super().__init__()
                self.layers = nn.Sequential(
                    nn.Linear(input_dim, 1024),
                    nn.BatchNorm1d(1024),
                    nn.ReLU(),
                    nn.Dropout(0.5),

                    nn.Linear(1024, 512),
                    nn.BatchNorm1d(512),
                    nn.ReLU(),
                    nn.Dropout(0.4),

                    nn.Linear(512, 256),
                    nn.BatchNorm1d(256),
                    nn.ReLU(),
                    nn.Dropout(0.3),

                    nn.Linear(256, 128),
                    nn.BatchNorm1d(128),
                    nn.ReLU(),
                    nn.Dropout(0.2),

                    nn.Linear(128, num_classes)
                )

            def forward(self, x):
                return self.layers(x)

        def train_heavy_model(X_train, y_train, X_val, y_val, epochs=50):
            model = HeavyAugmentationModel(X_train.shape[1]).to(device)

            class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
            criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
            optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)

            X_train_tensor = torch.FloatTensor(X_train).to(device)
            y_train_tensor = torch.LongTensor(y_train).to(device)
            X_val_tensor = torch.FloatTensor(X_val).to(device)
            y_val_tensor = torch.LongTensor(y_val).to(device)

            best_acc = 0

            for epoch in range(epochs):
                model.train()

                batch_size = 32
                num_batches = len(X_train_tensor) // batch_size + 1

                total_loss = 0
                correct = 0
                total = 0

                for i in range(num_batches):
                    start_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, len(X_train_tensor))

                    if start_idx >= end_idx:
                        break

                    batch_X = X_train_tensor[start_idx:end_idx]
                    batch_y = y_train_tensor[start_idx:end_idx]

                    if len(batch_X) == 0:
                        continue

                    noise = torch.randn_like(batch_X) * 0.01
                    batch_X = batch_X + noise

                    optimizer.zero_grad()
                    outputs = model(batch_X)
                    loss = criterion(outputs, batch_y)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                    total_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += batch_y.size(0)
                    correct += (predicted == batch_y).sum().item()

                train_acc = 100 * correct / total if total > 0 else 0

                model.eval()
                with torch.no_grad():
                    val_outputs = model(X_val_tensor)
                    _, val_predicted = torch.max(val_outputs.data, 1)
                    val_acc = 100 * (val_predicted == y_val_tensor).sum().item() / len(y_val_tensor)

                if val_acc > best_acc:
                    best_acc = val_acc

                if epoch % 10 == 0:
                    print(f"Epoch {epoch}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

                scheduler.step()

            return model, best_acc / 100

        if diag_ensemble_acc < 0.85:
            print("Training heavy diagnosis model...")
            diag_heavy_model, diag_heavy_acc = train_heavy_model(X_diag_train, y_diag_train, X_diag_val, y_diag_val)
            print(f"Heavy Diagnosis Model Accuracy: {diag_heavy_acc:.4f}")

            if diag_heavy_acc > diag_ensemble_acc:
                diag_ensemble_acc = diag_heavy_acc

        if prog_ensemble_acc < 0.85:
            print("Training heavy progression model...")
            prog_heavy_model, prog_heavy_acc = train_heavy_model(X_prog_train, y_prog_train, X_prog_val, y_prog_val)
            print(f"Heavy Progression Model Accuracy: {prog_heavy_acc:.4f}")

            if prog_heavy_acc > prog_ensemble_acc:
                prog_ensemble_acc = prog_heavy_acc

    print(f"\n=== FINAL RESULTS ===")
    print(f"Diagnosis Task Accuracy: {diag_ensemble_acc:.4f} ({diag_ensemble_acc*100:.1f}%)")
    print(f"Progression Task Accuracy: {prog_ensemble_acc:.4f} ({prog_ensemble_acc*100:.1f}%)")

    if diag_ensemble_acc >= 0.85 and prog_ensemble_acc >= 0.85:
        print("SUCCESS: Both models achieved 85%+ accuracy!")
    else:
        print("Still optimizing...")

if __name__ == "__main__":
    main()

Device: cuda
Processing diagnosis task...
Extracting BERT features for diagnosis...
Engineering features for diagnosis...
Training diagnosis ensemble...
RF Validation Accuracy: 0.6765
GB Validation Accuracy: 0.6765
SVM Validation Accuracy: 0.9412
LR Validation Accuracy: 0.9706

Diagnosis Ensemble Accuracy: 0.9118
Diagnosis Classification Report:
              precision    recall  f1-score   support

          ad       0.89      0.94      0.92        18
          cn       0.93      0.88      0.90        16

    accuracy                           0.91        34
   macro avg       0.91      0.91      0.91        34
weighted avg       0.91      0.91      0.91        34


Processing progression task...
No progression labels found!
