<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from transformers import AutoTokenizer, AutoModel, ViTFeatureExtractor, ViTModel

import librosa
import librosa.display
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import warnings

# Install a blanket "ignore" filter so library warnings do not clutter the
# notebook output (this hides *all* warning categories).
warnings.simplefilter("ignore")

print("All necessary libraries imported successfully.")


All necessary libraries imported successfully.


In [2]:
# --- Configuration ---
# IMPORTANT: the ADReSSo .tgz archives must be extracted manually before this
# cell is run; nothing in this notebook unpacks them.
# For example, extract ADReSSo21-progression-train.tgz to a directory like 'ADReSSo21_extracted/progression/train/'
# and ADReSSo21-diagnosis-train.tgz to 'ADReSSo21_extracted/diagnosis/train/' etc.
# The data loader below lists '<base>/<task>/<split>/audio/<class>/' and looks
# for matching CSVs under '<base>/<task>/<split>/segmentation/<class>/', so a
# wrong base path surfaces as a FileNotFoundError from os.listdir.

# Root directory of the extracted ADReSSo dataset. This is a relative path, so
# it is resolved against the current working directory of the notebook.
BASE_EXTRACTED_DIR = "ADReSSo21_extracted" # Adjust this path as per your extraction location

# --- Data Collection Function ---
def collect_adress_data(base_path, task_type, split_type):
    """
    Collect audio/transcript file pairs and their labels from an extracted
    ADReSSo directory tree.

    Directory layout that is read (relative to ``base_path``)::

        <task_type>/<split_type>/audio/<class>/<id>.wav
        <task_type>/<split_type>/segmentation/<class>/<id>.csv

    For ``task_type='progression'`` with ``split_type='test'`` the files are
    instead read directly from ``progression/test-dist/audio`` and
    ``progression/test-dist/segmentation`` (no class sub-folders), and every
    row receives the placeholder label ``-1``.

    Audio files whose transcript CSV is missing are reported on stdout and
    skipped.

    Args:
        base_path (str): The root directory of the extracted ADReSSo dataset.
        task_type (str): 'progression' or 'diagnosis'.
        split_type (str): 'train' or 'test'.

    Returns:
        pd.DataFrame: One row per recording with the columns 'patient_id',
        'audio_path', 'transcript_path' and 'label'. Rows are ordered by
        class folder, then by file name, so the result is reproducible
        across filesystems. The columns are present even when no recording
        was found.

    Raises:
        ValueError: If ``task_type`` is unknown, or ``split_type`` is invalid
            for the 'progression' task.
        FileNotFoundError: If an expected audio directory does not exist.
    """
    columns = ["patient_id", "audio_path", "transcript_path", "label"]

    # Define the base path for the current task and split
    current_base_path = os.path.join(base_path, task_type, split_type)
    unlabelled = False

    if task_type == 'progression':
        if split_type == 'train':
            class_subdirs = ['no_decline', 'decline']
        elif split_type == 'test':
            # test-dist keeps the files directly under audio/ and
            # segmentation/, so there is no class folder to derive a label from.
            class_subdirs = ['']
            current_base_path = os.path.join(base_path, task_type, 'test-dist')
            unlabelled = True
        else:
            raise ValueError(f"Invalid split_type '{split_type}' for 'progression' task.")

        label_mapping = {'no_decline': 0, 'decline': 1} # Example mapping

    elif task_type == 'diagnosis':
        class_subdirs = ['cn', 'ad']
        label_mapping = {'cn': 0, 'ad': 1} # CN: Control, AD: Alzheimer's Disease
    else:
        raise ValueError(f"Invalid task_type: {task_type}. Must be 'progression' or 'diagnosis'.")

    records = []
    for subdir in class_subdirs:
        audio_dir = os.path.join(current_base_path, 'audio', subdir)
        segmentation_dir = os.path.join(current_base_path, 'segmentation', subdir)

        try:
            entries = os.listdir(audio_dir)
        except FileNotFoundError as err:
            # Same exception type as before, but with a hint about the most
            # likely cause (archives not extracted / wrong base_path).
            raise FileNotFoundError(
                err.errno,
                "ADReSSo audio directory not found (extract the dataset "
                "archives and check base_path)",
                audio_dir,
            ) from err

        # Placeholder -1 marks rows without a ground-truth label.
        current_label = -1 if unlabelled else label_mapping[subdir]

        # os.listdir order is arbitrary; sort so that downstream seeded
        # splits see the rows in the same order on every machine.
        for audio_filename in sorted(f for f in entries if f.endswith('.wav')):
            patient_id = os.path.splitext(audio_filename)[0]
            audio_path = os.path.join(audio_dir, audio_filename)
            transcript_path = os.path.join(segmentation_dir, patient_id + '.csv')

            if not os.path.exists(transcript_path):
                print(f"Warning: Transcript file not found for {audio_path} at {transcript_path}. Skipping.")
                continue

            records.append({
                "patient_id": patient_id,
                "audio_path": audio_path,
                "transcript_path": transcript_path,
                "label": current_label
            })

    # Explicit columns keep df['label'] etc. usable even for an empty result.
    return pd.DataFrame(records, columns=columns)

# --- Collect data for a specific task (e.g., 'diagnosis') ---
# Set this to 'progression' to work on the decline-prediction task instead.
TASK_TO_USE = 'diagnosis'

# Gather the provided train partition first, then the test partition.
train_df, test_df = (
    collect_adress_data(BASE_EXTRACTED_DIR, TASK_TO_USE, _split)
    for _split in ('train', 'test')
)

# Carve a stratified 35% validation subset out of the provided train
# partition; the separately collected test partition is left untouched.
train_df, val_df = train_test_split(
    train_df,
    test_size=0.35,
    stratify=train_df['label'],
    random_state=42,
)

# Report the size and class balance of each split.
print(f"--- Data Collection Summary for Task: {TASK_TO_USE} ---")
for _title, _frame in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
    print(f"{_title} samples: {len(_frame)}")
for _title, _frame in (("training", train_df), ("validation", val_df), ("test", test_df)):
    print(f"Labels in {_title} set:", _frame['label'].value_counts())


class ADRESSODataset(Dataset):
    """
    Multimodal dataset pairing each recording's spectrogram "image" with its
    tokenized transcript.

    Each item is a dict with the keys 'pixel_values', 'input_ids',
    'attention_mask', 'token_type_ids' and 'labels'.
    """

    # Number of spectrogram frames kept per recording (image width before the
    # feature extractor's own preprocessing).
    TARGET_FRAMES = 224

    def __init__(self, dataframe, tokenizer, feature_extractor, max_seq_len=512, sr=16000):
        """
        Initializes the dataset.
        Args:
            dataframe (pd.DataFrame): DataFrame containing 'audio_path', 'transcript_path', and 'label'.
            tokenizer (transformers.PreTrainedTokenizer): Tokenizer for text (e.g., BertTokenizer).
            feature_extractor (transformers.ViTFeatureExtractor): Feature extractor for Vision Transformer.
            max_seq_len (int): Maximum sequence length for BERT tokenizer.
            sr (int): Sampling rate for audio processing.
        """
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_seq_len = max_seq_len
        self.sr = sr
        self.n_mels = 224 # As per paper for log-Mel spectrograms
        # MFCC settings are kept as attributes, but MFCCs are not computed by
        # this class.
        self.n_mfcc = 40
        self.n_fft = 2048
        self.hop_length_mel = 1024
        self.hop_length_mfcc = 512

    def __len__(self):
        """Return the number of recordings in the underlying DataFrame."""
        return len(self.dataframe)

    def _image_size(self):
        """
        Return the (height, width) configured on the feature extractor.

        ``feature_extractor.size`` may be a plain int (older transformers
        releases) or a dict such as ``{"height": 224, "width": 224}`` or
        ``{"shortest_edge": 224}`` (newer releases); both forms are handled.
        """
        size = self.feature_extractor.size
        if isinstance(size, dict):
            if "height" in size and "width" in size:
                return int(size["height"]), int(size["width"])
            edge = int(size["shortest_edge"])
            return edge, edge
        return int(size), int(size)

    @staticmethod
    def _fit_frames(feature, n_frames, pad_value=0.0):
        """Right-pad with ``pad_value`` or truncate a (bins, frames) array to ``n_frames`` frames."""
        current = feature.shape[1]
        if current < n_frames:
            return np.pad(feature, ((0, 0), (0, n_frames - current)),
                          mode='constant', constant_values=pad_value)
        return feature[:, :n_frames]

    def _build_speech_image(self, audio_path, height):
        """Load one recording and return a float32 (3, height, TARGET_FRAMES) tensor of log-mel, delta and delta-delta."""
        audio, sr = librosa.load(audio_path, sr=self.sr)

        mel_spectrogram = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_fft=self.n_fft,
            hop_length=self.hop_length_mel, n_mels=self.n_mels)
        log_mel = librosa.power_to_db(mel_spectrogram, ref=np.max)
        delta = librosa.feature.delta(log_mel)
        delta_delta = librosa.feature.delta(log_mel, order=2)

        channels = [
            # log_mel is in dB relative to the clip's peak (0 dB is the
            # loudest bin), so pad with the quietest value rather than 0 to
            # make the padded frames look like silence.
            self._fit_frames(log_mel, self.TARGET_FRAMES, pad_value=log_mel.min()),
            # A zero delta means "no change", which is the right padding here.
            self._fit_frames(delta, self.TARGET_FRAMES),
            self._fit_frames(delta_delta, self.TARGET_FRAMES),
        ]
        speech_image = torch.tensor(np.stack(channels, axis=0), dtype=torch.float32)

        if speech_image.shape[1] != height:
            # Interpolate along the mel axis. np.resize would only tile or
            # truncate the flattened data and scramble the spectrogram.
            speech_image = torch.nn.functional.interpolate(
                speech_image.unsqueeze(0),
                size=(height, self.TARGET_FRAMES),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)
        return speech_image

    def _read_transcript(self, transcript_path):
        """Return the raw file content of the transcript, or a placeholder when it is unreadable or blank."""
        # NOTE(review): the segmentation CSV is read as raw text; parsing the
        # relevant column is still to be done once the schema is confirmed.
        # NOTE(review): a tokenizer that adds its own special tokens will wrap
        # this placeholder again - consider an empty string instead.
        placeholder = "[CLS] [SEP]"
        try:
            with open(transcript_path, 'r', encoding='utf-8') as f:
                transcript_text = f.read()
        except Exception as e:
            # Deliberate best-effort: one bad transcript must not abort training.
            print(f"Error reading transcript file {transcript_path}: {e}")
            return placeholder

        if not transcript_text.strip():
            print(f"Warning: Empty transcript for {transcript_path}. Using placeholder.")
            return placeholder
        return transcript_text

    def __getitem__(self, idx):
        """
        Build the model inputs for the recording at position ``idx``.

        Returns:
            dict: 'pixel_values' (image tensor from the feature extractor, or
            zeros of shape (3, height, width) if audio processing failed),
            'input_ids', 'attention_mask', 'token_type_ids' (each of length
            ``max_seq_len``; 'token_type_ids' is all zeros when the tokenizer
            does not provide it) and 'labels' (scalar long tensor).
        """
        row = self.dataframe.iloc[idx]
        audio_path = row['audio_path']
        transcript_path = row['transcript_path']
        label = row['label']

        height, width = self._image_size()

        # --- Audio Feature Extraction (Log-Mel Spectrogram + deltas) ---
        try:
            speech_image = self._build_speech_image(audio_path, height)
            # The feature extractor applies its own normalization/resizing.
            pixel_values = self.feature_extractor(
                images=speech_image, return_tensors="pt").pixel_values.squeeze(0)
        except Exception as e:
            # Deliberate best-effort: substitute a blank image so a single
            # unreadable recording does not stop the whole epoch.
            print(f"Error processing audio file {audio_path}: {e}")
            pixel_values = torch.zeros(3, height, width)

        # --- Text Feature Extraction ---
        transcript_text = self._read_transcript(transcript_path)
        text_inputs = self.tokenizer(transcript_text, return_tensors="pt",
                                     max_length=self.max_seq_len, truncation=True, padding="max_length")

        # Squeeze the batch dimension added by return_tensors="pt"
        input_ids = text_inputs['input_ids'].squeeze(0)
        attention_mask = text_inputs['attention_mask'].squeeze(0)
        if 'token_type_ids' in text_inputs:
            token_type_ids = text_inputs['token_type_ids'].squeeze(0)
        else:
            # A None value cannot be batched by the default DataLoader
            # collate function; use an all-zero segment tensor instead.
            token_type_ids = torch.zeros_like(input_ids)

        return {
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Pretrained text tokenizer and ViT image pre-processor shared by every split
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# One dataset per split, all built with the same tokenizer / feature extractor
train_dataset, val_dataset, test_dataset = (
    ADRESSODataset(_frame, tokenizer, feature_extractor)
    for _frame in (train_df, val_df, test_df)
)

# DataLoaders: only the training loader shuffles; evaluation order stays fixed
batch_size = 8 # Adjust based on your GPU memory
_loader_kwargs = {"batch_size": batch_size, "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader, test_loader = (
    DataLoader(_dataset, shuffle=False, **_loader_kwargs)
    for _dataset in (val_dataset, test_dataset)
)

print("\nDataset and DataLoaders created.")


FileNotFoundError: [Errno 2] No such file or directory: 'ADReSSo21_extracted/diagnosis/train/audio/cn'