<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/DetectAlzFromSpeechGBA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
from google.colab import drive

# Mount Google Drive.
# If the first attempt raises, retry once with a forced remount (the option
# Colab's own "already mounted" message points to). `except Exception` is used
# instead of a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed,
# and the original failure is reported rather than hidden.
try:
    drive.mount('/content/drive', force_remount=False)
except Exception as mount_error:
    print(f"Initial Drive mount failed ({mount_error}); retrying with force_remount=True")
    drive.mount('/content/drive', force_remount=True)

# Define the base directory
base_directory = '/content/drive/MyDrive/Speech'

# Check if the directory exists
if os.path.exists(base_directory):
    # Separate lists for files and directories
    file_paths = []
    directory_paths = []

    # os.walk yields (root, sub-directory names, file names) for every folder
    # below base_directory; join each name with its root to get full paths.
    for root, directories, files in os.walk(base_directory):
        directory_paths.extend(os.path.join(root, directory) for directory in directories)
        file_paths.extend(os.path.join(root, file) for file in files)

    print("Directories:")
    for directory in directory_paths:
        print(f"  {directory}")

    print("\nFiles:")
    for file in file_paths:
        print(f"  {file}")

    print("\nSummary:")
    print(f"  Total directories: {len(directory_paths)}")
    print(f"  Total files: {len(file_paths)}")
    print(f"  Total paths: {len(file_paths) + len(directory_paths)}")

else:
    print(f"Directory {base_directory} does not exist!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Directories:
  /content/drive/MyDrive/Speech/extracted_diagnosis_train
  /content/drive/MyDrive/Speech/extracted_progression_train
  /content/drive/MyDrive/Speech/extracted_progression_test
  /content/drive/MyDrive/Speech/linguistic_features
  /content/drive/MyDrive/Speech/lightweight_features
  /content/drive/MyDrive/Speech/transcripts
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/cn
  /content/driv

In [None]:
import os
import pandas as pd
import numpy as np
import librosa
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import pickle
import json
from sklearn.preprocessing import StandardScaler, LabelEncoder
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

def extract_labels_from_directory_structure(base_paths=None):
    """Build a ``file_id -> metadata`` mapping by scanning the audio folders.

    Args:
        base_paths: Optional mapping of dataset name to the directory holding
            its audio. ``'progression_test'`` is scanned as a flat folder of
            ``.wav`` files (label ``'unknown'``); every other dataset is
            scanned one level deep, with the sub-folder name deciding the
            label. ``'diagnosis_train'`` uses ``ad``/``cn``; any other dataset
            name uses ``decline``/``no_decline``. Defaults to the three
            ADReSSo21 locations on the mounted Drive.

    Returns:
        dict: ``{file_id: {'dataset', 'label', 'file_path'}}`` where
        ``file_id`` is the file name without its ``.wav`` extension. Entries
        with the same ``file_id`` overwrite earlier ones.

    Notes:
        * A dataset whose directory is missing is skipped with a message
          instead of raising from ``os.listdir``.
        * Directory listings are sorted so the resulting order is stable
          across runs.
        * A sub-folder with an unexpected name still receives the negative
          label (legacy behaviour) but a warning is printed so silent
          mislabelling is visible.
    """
    if base_paths is None:
        base_paths = {
            'diagnosis_train': '/content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio',
            'progression_train': '/content/drive/MyDrive/Speech/extracted_progression_train/ADReSSo21/progression/train/audio',
            'progression_test': '/content/drive/MyDrive/Speech/extracted_progression_test/ADReSSo21/progression/test-dist/audio'
        }

    labels = {}

    for dataset_type, base_path in base_paths.items():
        if not os.path.isdir(base_path):
            print(f"Audio directory not found, skipping {dataset_type}: {base_path}")
            continue

        if dataset_type == 'progression_test':
            # Unlabelled split: wav files sit directly in the folder.
            for wav_file in sorted(os.listdir(base_path)):
                if not wav_file.endswith('.wav'):
                    continue
                # splitext strips only the trailing extension, unlike
                # str.replace which would remove every '.wav' occurrence.
                file_id = os.path.splitext(wav_file)[0]
                labels[file_id] = {
                    'dataset': 'progression_test',
                    'label': 'unknown',
                    'file_path': os.path.join(base_path, wav_file)
                }
            continue

        if dataset_type == 'diagnosis_train':
            positive_label, negative_label = 'ad', 'cn'
        else:
            positive_label, negative_label = 'decline', 'no_decline'

        for subdir in sorted(os.listdir(base_path)):
            subdir_path = os.path.join(base_path, subdir)
            if not os.path.isdir(subdir_path):
                continue

            label = positive_label if subdir == positive_label else negative_label
            if subdir not in (positive_label, negative_label):
                print(f"Warning: unexpected sub-directory '{subdir}' in {base_path}; "
                      f"its files are labelled '{label}'")

            for wav_file in sorted(os.listdir(subdir_path)):
                if not wav_file.endswith('.wav'):
                    continue
                file_id = os.path.splitext(wav_file)[0]
                labels[file_id] = {
                    'dataset': dataset_type,
                    'label': label,
                    'file_path': os.path.join(subdir_path, wav_file)
                }

    return labels

def extract_mel_spectrogram(audio_path, n_mels=128, hop_length=512, n_fft=2048):
    """Compute mel-spectrogram features for one audio file.

    The file is loaded with ``librosa`` at 22050 Hz. Any exception raised
    while loading or transforming is printed and turned into ``None``.

    Args:
        audio_path: Path of the audio file to load.
        n_mels: Passed to ``librosa.feature.melspectrogram``.
        hop_length: Passed to ``librosa.feature.melspectrogram``.
        n_fft: Passed to ``librosa.feature.melspectrogram``.

    Returns:
        dict | None: the mel spectrogram, its ``power_to_db`` version
        (referenced to the array maximum), and the mean / standard deviation
        of each taken along axis 1; ``None`` on failure.
    """
    try:
        waveform, sample_rate = librosa.load(audio_path, sr=22050)

        power_spec = librosa.feature.melspectrogram(
            y=waveform,
            sr=sample_rate,
            n_mels=n_mels,
            hop_length=hop_length,
            n_fft=n_fft,
        )
        db_spec = librosa.power_to_db(power_spec, ref=np.max)

        result = {}
        result['mel_spectrogram'] = power_spec
        result['log_mel_spectrogram'] = db_spec
        # Summary statistics are reduced along axis 1 of each spectrogram.
        result['mel_mean'] = np.mean(power_spec, axis=1)
        result['mel_std'] = np.std(power_spec, axis=1)
        result['log_mel_mean'] = np.mean(db_spec, axis=1)
        result['log_mel_std'] = np.std(db_spec, axis=1)
        return result
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None

# Cache of (processor, model) pairs keyed by checkpoint name, so pretrained
# weights are loaded once per kernel session instead of once per audio file.
_WAV2VEC2_CACHE = {}


def _get_wav2vec2(model_name):
    """Return the cached ``(processor, model)`` pair for ``model_name``, loading it on first use."""
    if model_name not in _WAV2VEC2_CACHE:
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2Model.from_pretrained(model_name)
        model.eval()  # inference only
        _WAV2VEC2_CACHE[model_name] = (processor, model)
    return _WAV2VEC2_CACHE[model_name]


def extract_wav2vec2_features(audio_path, model_name="facebook/wav2vec2-base-960h"):
    """Extract wav2vec 2.0 hidden-state features for one audio file.

    The audio is loaded at 16 kHz, run through the pretrained model without
    gradient tracking, and summarised along axis 0 of the resulting feature
    array.

    Args:
        audio_path: Path of the audio file to load.
        model_name: Hugging Face checkpoint identifier. The processor and
            model for each checkpoint are loaded once and reused on later
            calls.

    Returns:
        dict | None: ``wav2vec2_features`` (the last hidden state of the
        single input, as a NumPy array) plus ``wav2vec2_mean`` / ``_std`` /
        ``_max`` / ``_min`` reduced along axis 0; ``None`` if anything fails
        (the error is printed).
    """
    try:
        # Loading is kept inside the try so download/initialisation errors
        # are reported the same way as per-file errors.
        processor, model = _get_wav2vec2(model_name)

        y, sr = librosa.load(audio_path, sr=16000)

        inputs = processor(y, sampling_rate=16000, return_tensors="pt", padding=True)

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_states = outputs.last_hidden_state

        # Index the batch axis explicitly. A blanket squeeze() would also
        # drop the next axis when it has length 1, which would make the
        # axis-0 statistics below collapse to scalars.
        features = last_hidden_states[0].numpy()

        return {
            'wav2vec2_features': features,
            'wav2vec2_mean': np.mean(features, axis=0),
            'wav2vec2_std': np.std(features, axis=0),
            'wav2vec2_max': np.max(features, axis=0),
            'wav2vec2_min': np.min(features, axis=0)
        }
    except Exception as e:
        print(f"Error processing wav2vec2 for {audio_path}: {e}")
        return None

def extract_acoustic_features(audio_path):
    """Compute a set of summary acoustic descriptors for one audio file.

    The file is loaded with ``librosa`` at 22050 Hz. Multi-row features are
    averaged along axis 1; single-row features are averaged to one number.

    Args:
        audio_path: Path of the audio file to load.

    Returns:
        dict | None: keys ``mfcc`` (13 coefficients requested), ``chroma``,
        ``spectral_contrast``, ``tonnetz``, ``zero_crossing_rate``,
        ``spectral_centroid``, ``spectral_rolloff``, ``rms_energy`` and
        ``tempo`` (a plain ``float``). ``None`` if any step raises; the error
        is printed.
    """
    try:
        y, sr = librosa.load(audio_path, sr=22050)

        features = {}

        features['mfcc'] = np.mean(librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13), axis=1)
        # librosa exposes chroma as chroma_stft / chroma_cqt / chroma_cens;
        # there is no `librosa.feature.chroma`, so the previous call raised
        # AttributeError and this function always returned None.
        features['chroma'] = np.mean(librosa.feature.chroma_stft(y=y, sr=sr), axis=1)
        features['spectral_contrast'] = np.mean(librosa.feature.spectral_contrast(y=y, sr=sr), axis=1)
        features['tonnetz'] = np.mean(librosa.feature.tonnetz(y=y, sr=sr), axis=1)
        features['zero_crossing_rate'] = np.mean(librosa.feature.zero_crossing_rate(y))
        features['spectral_centroid'] = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
        features['spectral_rolloff'] = np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr))
        features['rms_energy'] = np.mean(librosa.feature.rms(y=y))

        tempo, _beats = librosa.beat.beat_track(y=y, sr=sr)
        # Depending on the librosa version, tempo may be a scalar or an
        # array; store a plain float so the value has a consistent type.
        features['tempo'] = float(np.atleast_1d(tempo)[0])

        return features
    except Exception as e:
        print(f"Error extracting acoustic features for {audio_path}: {e}")
        return None

def load_existing_features():
    """Load every pickled feature file from the ``lightweight_features`` folder.

    Each ``*.pkl`` file becomes one entry in the result, keyed by the file
    name with ``_features.pkl`` and then ``.pkl`` removed. Files that fail to
    load are reported with a printed message and skipped.

    Returns:
        dict: ``{category: unpickled object}``; empty when the folder does
        not exist.
    """
    lightweight_path = '/content/drive/MyDrive/Speech/lightweight_features'
    loaded = {}

    if not os.path.exists(lightweight_path):
        return loaded

    pickle_names = [name for name in os.listdir(lightweight_path) if name.endswith('.pkl')]
    for pickle_name in pickle_names:
        try:
            # NOTE: pickle.load executes arbitrary code from the file; only
            # point this at feature files you created yourself.
            with open(os.path.join(lightweight_path, pickle_name), 'rb') as handle:
                payload = pickle.load(handle)
            category = pickle_name.replace('_features.pkl', '').replace('.pkl', '')
            loaded[category] = payload
        except Exception as e:
            print(f"Error loading {pickle_name}: {e}")

    return loaded

def load_linguistic_features(ling_path='/content/drive/MyDrive/Speech/linguistic_features'):
    """Load precomputed linguistic features, preferring the pickle over JSON.

    Looks for ``linguistic_features.pkl`` in ``ling_path``; if that file is
    absent or cannot be unpickled, falls back to ``linguistic_features.json``.

    Args:
        ling_path: Directory holding the feature files. Defaults to the
            project's Drive location, so existing no-argument calls behave
            as before.

    Returns:
        The unpickled / decoded object, or an empty dict when the directory
        is missing or neither file can be read (the reason is printed).
    """
    ling_features = {}

    if not os.path.isdir(ling_path):
        return ling_features

    pickle_file = os.path.join(ling_path, 'linguistic_features.pkl')
    json_file = os.path.join(ling_path, 'linguistic_features.json')

    if os.path.isfile(pickle_file):
        try:
            # NOTE: pickle.load can execute arbitrary code; only use with
            # feature files produced by this project.
            with open(pickle_file, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            # Report why the pickle was rejected instead of hiding it behind
            # a bare `except:`, then fall through to the JSON copy.
            print(f"Error loading linguistic features pickle, falling back to JSON: {e}")

    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            ling_features = json.load(f)
    except Exception as e:
        print(f"Error loading linguistic features: {e}")

    return ling_features

def load_transcripts(transcripts_dir='/content/drive/MyDrive/Speech/transcripts'):
    """Collect transcripts from the aggregate JSON files and per-file text files.

    Sources, applied in this order (later sources overwrite earlier keys):

    1. ``all_categories_results.json``
    2. ``transcription_results.json``
    3. every ``*.txt`` file in ``individual_transcripts/``

    Args:
        transcripts_dir: Directory containing the sources above. Defaults to
            the project's Drive location, so existing no-argument calls
            behave as before.

    Returns:
        dict: the merged JSON content plus ``{file_id: stripped text}`` for
        each individual transcript. ``file_id`` is the text file's name
        without its ``.wav.txt`` (or plain ``.txt``) suffix.
    """
    transcripts = {}

    transcript_files = [
        os.path.join(transcripts_dir, 'all_categories_results.json'),
        os.path.join(transcripts_dir, 'transcription_results.json')
    ]

    for transcript_file in transcript_files:
        if os.path.exists(transcript_file):
            try:
                with open(transcript_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    transcripts.update(data)
            except Exception as e:
                print(f"Error loading {transcript_file}: {e}")

    individual_transcript_path = os.path.join(transcripts_dir, 'individual_transcripts')
    if os.path.exists(individual_transcript_path):
        # Sorted so that, if two files map to the same id, the winner is
        # the same on every run.
        txt_files = sorted(f for f in os.listdir(individual_transcript_path) if f.endswith('.txt'))
        for txt_file in txt_files:
            try:
                # Strip the suffix only. A plain "<id>.txt" previously kept
                # its extension in the key and could never match a file id.
                if txt_file.endswith('.wav.txt'):
                    file_id = txt_file[:-len('.wav.txt')]
                else:
                    file_id = txt_file[:-len('.txt')]
                with open(os.path.join(individual_transcript_path, txt_file), 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                    transcripts[file_id] = content
            except Exception as e:
                print(f"Error loading {txt_file}: {e}")

    return transcripts

def create_comprehensive_dataset():
    """Join labels, transcripts and precomputed features into one DataFrame.

    One row is produced per entry of the label mapping. Optional columns are
    added when data for that ``file_id`` is available:

    * ``transcript`` - from the loaded transcripts;
    * ``<category>_<key>`` for dict-valued lightweight features, or
      ``<category>_features`` for any other value;
    * ``linguistic_<key>`` for dict-valued linguistic features.

    Returns:
        tuple: ``(df, labels_dict)`` - the assembled DataFrame and the raw
        label mapping it was built from. When no labelled files were found,
        the DataFrame is empty but still has the ``file_id``, ``dataset``,
        ``label`` and ``file_path`` columns.
    """
    print("Extracting labels from directory structure...")
    labels_dict = extract_labels_from_directory_structure()

    print("Loading existing features...")
    existing_features = load_existing_features()

    print("Loading linguistic features...")
    linguistic_features = load_linguistic_features()

    print("Loading transcripts...")
    transcripts = load_transcripts()

    base_columns = ['file_id', 'dataset', 'label', 'file_path']
    dataset = []

    for file_id, label_info in labels_dict.items():
        row = {
            'file_id': file_id,
            'dataset': label_info['dataset'],
            'label': label_info['label'],
            'file_path': label_info['file_path']
        }

        if file_id in transcripts:
            row['transcript'] = transcripts[file_id]

        for category, features in existing_features.items():
            if file_id in features:
                feature_data = features[file_id]
                if isinstance(feature_data, dict):
                    # Flatten dict-valued features into prefixed columns.
                    for key, value in feature_data.items():
                        row[f'{category}_{key}'] = value
                else:
                    row[f'{category}_features'] = feature_data

        if file_id in linguistic_features:
            ling_data = linguistic_features[file_id]
            if isinstance(ling_data, dict):
                for key, value in ling_data.items():
                    row[f'linguistic_{key}'] = value

        dataset.append(row)

    if dataset:
        df = pd.DataFrame(dataset)
    else:
        # pd.DataFrame([]) has no columns at all, so df['label'] below (and
        # any caller filtering on df['dataset']) would raise KeyError.
        df = pd.DataFrame(columns=base_columns)

    print(f"Dataset created with {len(df)} samples")
    print(f"Columns: {df.columns.tolist()}")
    print("Label distribution:")
    print(df['label'].value_counts())

    return df, labels_dict

def extract_audio_features_batch(labels_dict, sample_limit=None):
    """Run the mel-spectrogram and acoustic extractors over labelled files.

    Args:
        labels_dict: Mapping of ``file_id`` to a dict that has a
            ``'file_path'`` entry.
        sample_limit: Stop once this many files have been extracted
            successfully. A falsy value (``None`` or ``0``) means no limit.

    Returns:
        dict: ``{file_id: merged feature dict}`` for every file where both
        extractors returned a non-empty result. Missing files are reported
        with a printed message and skipped.
    """
    print("Extracting mel-spectrograms and acoustic features...")

    collected = {}

    for file_id, label_info in labels_dict.items():
        # One entry is stored per success, so len(collected) is the number
        # of files processed so far.
        if sample_limit and len(collected) >= sample_limit:
            break

        audio_path = label_info['file_path']

        if not os.path.exists(audio_path):
            print(f"File not found: {audio_path}")
            continue

        print(f"Processing {file_id}...")

        mel_features = extract_mel_spectrogram(audio_path)
        acoustic_features = extract_acoustic_features(audio_path)

        if not (mel_features and acoustic_features):
            # One of the extractors failed (it already printed why).
            continue

        # On a key clash the acoustic values win, since they are unpacked last.
        collected[file_id] = {**mel_features, **acoustic_features}

    return collected

def create_training_datasets():
    """Split the combined dataset by task and integer-encode the labels.

    Builds the full DataFrame, takes copies of the ``diagnosis_train``,
    ``progression_train`` and ``progression_test`` rows, and adds a
    ``label_encoded`` column to each non-empty training split using its own
    ``LabelEncoder``. A summary is printed.

    Returns:
        dict: the three splits, the full frame under ``'all_data'``, the raw
        ``'labels_dict'``, and the two encoders under
        ``'diagnosis_label_encoder'`` / ``'progression_label_encoder'``
        (left unfitted when the matching split is empty).
    """
    df, labels_dict = create_comprehensive_dataset()

    def _rows_for(split_name):
        # Independent copy so adding label_encoded does not touch `df`.
        return df[df['dataset'] == split_name].copy()

    diagnosis_train = _rows_for('diagnosis_train')
    progression_train = _rows_for('progression_train')
    progression_test = _rows_for('progression_test')

    diagnosis_le = LabelEncoder()
    has_diagnosis = len(diagnosis_train) > 0
    if has_diagnosis:
        diagnosis_train['label_encoded'] = diagnosis_le.fit_transform(diagnosis_train['label'])

    progression_le = LabelEncoder()
    has_progression = len(progression_train) > 0
    if has_progression:
        progression_train['label_encoded'] = progression_le.fit_transform(progression_train['label'])

    datasets = dict(
        diagnosis_train=diagnosis_train,
        progression_train=progression_train,
        progression_test=progression_test,
        all_data=df,
        labels_dict=labels_dict,
        diagnosis_label_encoder=diagnosis_le,
        progression_label_encoder=progression_le,
    )

    def _count(frame, label_value):
        # Number of rows in `frame` carrying the given label.
        return len(frame[frame['label'] == label_value])

    print("\nDataset Summary:")
    print(f"Diagnosis Training: {len(diagnosis_train)} samples")
    if has_diagnosis:
        print(f"  - AD: {_count(diagnosis_train, 'ad')}")
        print(f"  - CN: {_count(diagnosis_train, 'cn')}")

    print(f"Progression Training: {len(progression_train)} samples")
    if has_progression:
        print(f"  - Decline: {_count(progression_train, 'decline')}")
        print(f"  - No Decline: {_count(progression_train, 'no_decline')}")

    print(f"Progression Test: {len(progression_test)} samples")

    return datasets

def save_datasets(datasets, output_path='/content/drive/MyDrive/Speech/processed_datasets'):
    """Write every entry of ``datasets`` to ``output_path``.

    * DataFrames are written twice: ``<name>.csv`` (no index) and
      ``<name>.pkl``.
    * The entry named ``'labels_dict'`` is written as ``labels_dict.json``
      with two-space indentation.
    * Anything else is pickled to ``<name>.pkl``.

    Args:
        datasets: Mapping of name to the object to persist.
        output_path: Destination directory; created if it does not exist.
    """
    os.makedirs(output_path, exist_ok=True)

    def _target(filename):
        # Full path of `filename` inside the output directory.
        return os.path.join(output_path, filename)

    for name, data in datasets.items():
        if isinstance(data, pd.DataFrame):
            data.to_csv(_target(f'{name}.csv'), index=False)
            data.to_pickle(_target(f'{name}.pkl'))
            continue

        if name == 'labels_dict':
            with open(_target('labels_dict.json'), 'w') as json_out:
                json.dump(data, json_out, indent=2)
            continue

        with open(_target(f'{name}.pkl'), 'wb') as pickle_out:
            pickle.dump(data, pickle_out)

    print(f"Datasets saved to {output_path}")

if __name__ == "__main__":
    # Build every split and persist it to the processed-datasets folder.
    datasets = create_training_datasets()
    save_datasets(datasets)

    # Smoke-test the audio extractors on at most five recordings.
    print("\nExtracting audio features for a sample...")
    sample_audio_features = extract_audio_features_batch(
        datasets['labels_dict'],
        sample_limit=5,
    )

    if sample_audio_features:
        print(f"Successfully extracted features for {len(sample_audio_features)} audio files")
        # Show the feature names of the first extracted file as an example.
        sample_id = next(iter(sample_audio_features))
        sample_features = sample_audio_features[sample_id]
        print(f"Feature keys for {sample_id}: {list(sample_features)}")

    print("\nProcessing complete!")
    print("Available datasets:")
    # Only the DataFrame entries are listed; encoders and the label map are skipped.
    for name in datasets:
        if not isinstance(datasets[name], pd.DataFrame):
            continue
        print(f"  - {name}: {len(datasets[name])} samples")

Extracting labels from directory structure...
Loading existing features...
