<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Step 1: Set Up Google Colab Environment

In [None]:
import os
from google.colab import drive
import tarfile

# Mount Google Drive
drive.mount('/content/drive')

# Install required libraries
!pip install torch torchvision torchaudio
!pip install transformers
!pip install librosa
!pip install numpy pandas scikit-learn
!pip install matplotlib

# Extract datasets
data_dir = '/content/drive/MyDrive/Voice/'
extract_dir = '/content/ADReSSo21/'

os.makedirs(extract_dir, exist_ok=True)

datasets = [
    'ADReSSo21-diagnosis-train.tgz',
    'ADReSSo21-progression-test.tgz',
    'ADReSSo21-progression-train.tgz'
]

for dataset in datasets:
    tar_path = os.path.join(data_dir, dataset)
    with tarfile.open(tar_path, 'r:gz') as tar:
        tar.extractall(extract_dir)
    print(f"Extracted {dataset}")

# Verify GPU availability
import torch
print("GPU Available:", torch.cuda.is_available())

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Extracted ADReSSo21-diagnosis-train.tgz
Extracted ADReSSo21-progression-test.tgz
Extracted ADReSSo21-progression-train.tgz
GPU Available: True


# Step 2: Prepare the Dataset

In [22]:
import librosa
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
import glob

# Define base path
base_dir = '/content/ADReSSo21/ADReSSo21/'
train_base_dir = os.path.join(base_dir, 'diagnosis/train')
test_base_dir = os.path.join(base_dir, 'progression/test-dist')
output_dir = '/content/ADReSSo21/'  # Directory to save pickle files

# Function to extract log-Mel spectrogram and MFCCs with delta and delta-delta
def extract_audio_features(audio_path, sr=16000, n_mels=128, n_mfcc=13):
    # Load audio
    y, sr = librosa.load(audio_path, sr=sr)

    # Log-Mel spectrogram
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
    log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

    # MFCCs
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc)

    # Delta and delta-delta
    delta_mfcc = librosa.feature.delta(mfcc)
    delta_delta_mfcc = librosa.feature.delta(mfcc, order=2)

    # Stack features as 3-channel image
    log_mel_image = np.stack([log_mel_spec, librosa.feature.delta(log_mel_spec), librosa.feature.delta(log_mel_spec, order=2)], axis=-1)
    mfcc_image = np.stack([mfcc, delta_mfcc, delta_delta_mfcc], axis=-1)

    return log_mel_image, mfcc_image

# Load dataset
def load_dataset(train_base_dir, test_base_dir=None, audio_only=False):
    data = []
    train_audio_dir = os.path.join(train_base_dir, 'audio')
    train_transcript_dir = os.path.join(train_base_dir, 'segmentation')

    # Debug: List directory contents
    print("Checking train base directory:", train_base_dir)
    if os.path.exists(train_base_dir):
        print("Files in train base:", os.listdir(train_base_dir))
        print("Audio subfolders:", os.listdir(train_audio_dir))
        print("Segmentation subfolders:", os.listdir(train_transcript_dir))
    else:
        print("Train base directory does not exist:", train_base_dir)

    # List potential transcript files
    transcript_candidates = []
    for ext in ['*.cha', '*.txt', '*.transcript', '*.par', '*.TextGrid']:
        transcript_candidates.extend(glob.glob(os.path.join(train_base_dir, '**', ext), recursive=True))
    print("Potential transcript files:", transcript_candidates[:10])

    # Check for metadata file
    metadata_file = None
    for fname in ['diagnosis.csv', 'metadata.csv', 'labels.csv', 'adresso-train-mmse-scores.csv']:
        if os.path.exists(os.path.join(train_base_dir, fname)):
            metadata_file = os.path.join(train_base_dir, fname)
            break

    if metadata_file:
        print("Found metadata file:", metadata_file)
        metadata = pd.read_csv(metadata_file)
        print("Metadata columns:", metadata.columns.tolist())
        print("Unique dx values:", metadata['dx'].unique())
        for _, row in metadata.iterrows():
            audio_id = str(row.get('adressfname', ''))
            if not audio_id:
                continue
            dx = str(row.get('dx', '')).lower()
            subfolder = 'ad' if dx == 'ad' else 'cn'
            audio_file = os.path.join(train_audio_dir, subfolder, f"{audio_id}.wav")
            # Try multiple transcript naming conventions
            transcript_files = [
                os.path.join(train_transcript_dir, subfolder, f"{audio_id}.cha"),
                os.path.join(train_transcript_dir, subfolder, f"{audio_id}.txt"),
                os.path.join(train_transcript_dir, subfolder, f"S{audio_id[-3:]}.cha"),
                os.path.join(train_transcript_dir, subfolder, f"{audio_id}.transcript"),
                os.path.join(train_transcript_dir, subfolder, f"{audio_id}.par"),
                os.path.join(train_transcript_dir, subfolder, f"{audio_id}.TextGrid"),
                os.path.join(train_audio_dir, subfolder, f"{audio_id}.cha"),
                os.path.join(train_audio_dir, subfolder, f"{audio_id}.txt"),
                os.path.join(train_transcript_dir, f"{audio_id}.cha"),
                os.path.join(train_audio_dir, f"{audio_id}.cha")
            ]
            transcript_file = None
            for tf in transcript_files:
                if os.path.exists(tf):
                    transcript_file = tf
                    break
            if os.path.exists(audio_file) and (transcript_file or audio_only):
                label = 1 if dx == 'ad' else 0
                data.append({
                    'audio_path': audio_file,
                    'transcript_path': transcript_file if transcript_file else None,
                    'label': label
                })
            else:
                print(f"Missing pair for ID {audio_id} (dx={dx}): Audio exists={os.path.exists(audio_file)}, Transcript exists={transcript_file is not None}")
        print(f"Loaded {len(data)} samples from metadata")

    # Fallback: Pair audio files without metadata
    else:
        print("No metadata file found, pairing audio files")
        for subfolder in ['ad', 'cn']:
            audio_files = glob.glob(os.path.join(train_audio_dir, subfolder, '*.wav'))
            print(f"Found {len(audio_files)} audio files in {train_audio_dir}/{subfolder}")
            print("Sample audio files:", audio_files[:5])
            for audio_file in audio_files:
                audio_id = os.path.basename(audio_file).replace('.wav', '')
                transcript_files = [
                    os.path.join(train_transcript_dir, subfolder, f"{audio_id}.cha"),
                    os.path.join(train_transcript_dir, subfolder, f"{audio_id}.txt"),
                    os.path.join(train_transcript_dir, subfolder, f"S{audio_id[-3:]}.cha"),
                    os.path.join(train_transcript_dir, subfolder, f"{audio_id}.transcript"),
                    os.path.join(train_transcript_dir, subfolder, f"{audio_id}.par"),
                    os.path.join(train_transcript_dir, subfolder, f"{audio_id}.TextGrid"),
                    os.path.join(train_audio_dir, subfolder, f"{audio_id}.cha"),
                    os.path.join(train_audio_dir, subfolder, f"{audio_id}.txt")
                ]
                transcript_file = None
                for tf in transcript_files:
                    if os.path.exists(tf):
                        transcript_file = tf
                        break
                if transcript_file or audio_only:
                    label = 1 if subfolder == 'ad' else 0
                    data.append({
                        'audio_path': audio_file,
                        'transcript_path': transcript_file if transcript_file else None,
                        'label': label
                    })
                else:
                    print(f"No transcript found for {audio_id} in {subfolder}. Checked: {transcript_files}")
        print(f"Loaded {len(data)} samples from file pairing")

    train_df = pd.DataFrame(data)
    print(f"Total samples loaded: {len(train_df)}")

    # Debug: Show sample data
    if not train_df.empty:
        print("Sample data:", train_df.head().to_dict())

    # Check if train_df is empty
    if train_df.empty:
        audio_samples = glob.glob(os.path.join(train_audio_dir, '**/*.wav'), recursive=True)
        raise ValueError(f"No valid audio-transcript pairs found. Check directories:\n- Audio: {train_audio_dir}\n- Transcripts: {train_transcript_dir}\nRun '!ls -R /content/ADReSSo21/ADReSSo21/diagnosis/train/' to inspect.\nSample audio files: {audio_samples[:5]}\nPotential transcripts: {transcript_candidates[:5]}")

    # Split train and validation (65%-35%)
    train_df, val_df = train_test_split(train_df, test_size=0.35, random_state=42)

    # Load test data
    test_df = pd.DataFrame()
    if test_base_dir and os.path.exists(test_base_dir):
        data = []
        test_audio_dir = os.path.join(test_base_dir, 'audio')
        test_transcript_dir = os.path.join(test_base_dir, 'segmentation')
        audio_files = glob.glob(os.path.join(test_audio_dir, '**/*.wav'), recursive=True)
        print(f"Found {len(audio_files)} test audio files in {test_audio_dir}")
        for audio_file in audio_files:
            audio_id = os.path.basename(audio_file).replace('.wav', '')
            transcript_files = [
                os.path.join(test_transcript_dir, f"{audio_id}.cha"),
                os.path.join(test_transcript_dir, f"{audio_id}.txt"),
                os.path.join(test_audio_dir, f"{audio_id}.cha"),
                os.path.join(test_audio_dir, f"{audio_id}.txt")
            ]
            transcript_file = None
            for tf in transcript_files:
                if os.path.exists(tf):
                    transcript_file = tf
                    break
            if transcript_file or audio_only:
                label = 0  # Placeholder
                data.append({
                    'audio_path': audio_file,
                    'transcript_path': transcript_file if transcript_file else None,
                    'label': label
                })
        test_df = pd.DataFrame(data)
        print(f"Test samples loaded: {len(test_df)}")

    # Save dataframes as pickle files
    os.makedirs(output_dir, exist_ok=True)
    train_df.to_pickle(os.path.join(output_dir, 'train_df.pkl'))
    val_df.to_pickle(os.path.join(output_dir, 'val_df.pkl'))
    test_df.to_pickle(os.path.join(output_dir, 'test_df.pkl'))
    print(f"Saved dataframes to {output_dir}: train_df.pkl, val_df.pkl, test_df.pkl")

    return train_df, val_df, test_df

# Preprocess dataset
try:
    # Keep audio_only=True since transcripts are missing
    train_df, val_df, test_df = load_dataset(train_base_dir, test_base_dir, audio_only=True)
    print("Training samples:", len(train_df))
    print("Validation samples:", len(val_df))
    print("Test samples:", len(test_df))
except ValueError as e:
    print("Error:", e)

Checking train base directory: /content/ADReSSo21/ADReSSo21/diagnosis/train
Files in train base: ['segmentation', 'audio', 'adresso-train-mmse-scores.csv']
Audio subfolders: ['ad', 'cn']
Segmentation subfolders: ['ad', 'cn']
Potential transcript files: []
Found metadata file: /content/ADReSSo21/ADReSSo21/diagnosis/train/adresso-train-mmse-scores.csv
Metadata columns: ['Unnamed: 0', 'adressfname', 'mmse', 'dx']
Unique dx values: ['ad' 'cn']
Loaded 166 samples from metadata
Total samples loaded: 166
Sample data: {'audio_path': {0: '/content/ADReSSo21/ADReSSo21/diagnosis/train/audio/ad/adrso024.wav', 1: '/content/ADReSSo21/ADReSSo21/diagnosis/train/audio/ad/adrso025.wav', 2: '/content/ADReSSo21/ADReSSo21/diagnosis/train/audio/ad/adrso027.wav', 3: '/content/ADReSSo21/ADReSSo21/diagnosis/train/audio/ad/adrso028.wav', 4: '/content/ADReSSo21/ADReSSo21/diagnosis/train/audio/ad/adrso031.wav'}, 'transcript_path': {0: None, 1: None, 2: None, 3: None, 4: None}, 'label': {0: 1, 1: 1, 2: 1, 3: 1, 4:

In [25]:
import torch
import torch.nn as nn
from transformers import ViTModel, BertModel, BertTokenizer
import librosa
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os

# Define paths
data_dir = '/content/ADReSSo21/'

class ADReSSoDataset(Dataset):
    def __init__(self, dataframe, vit_model, bert_model, tokenizer):
        self.dataframe = dataframe
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        item = self.dataframe.iloc[idx]
        try:
            log_mel_image, mfcc_image = extract_audio_features(item['audio_path'])
            text_features = torch.zeros(768)
            if item['transcript_path'] is not None:
                text = clean_cha_file(item['transcript_path'])
                encoding = self.tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding=True)
                with torch.no_grad():
                    outputs = self.bert_model(**encoding)
                    text_features = outputs.last_hidden_state[:, 0, :]
            return {
                'log_mel_image': torch.tensor(log_mel_image).permute(2, 0, 1),  # [C, H, W]
                'mfcc_image': torch.tensor(mfcc_image).permute(2, 0, 1),
                'text_features': text_features,
                'label': torch.tensor(item['label'], dtype=torch.long)
            }
        except Exception as e:
            print(f"Error processing {item['audio_path']}: {e}")
            return None

# Custom collate function to filter None items
def custom_collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        raise ValueError("Empty batch after filtering")
    return torch.utils.data.dataloader.default_collate(batch)

# Placeholder for clean_cha_file
def clean_cha_file(file_path):
    with open(file_path, 'r') as f:
        text = f.read()
    return text

# Updated extract_audio_features with fixed length
def extract_audio_features(audio_path, sr=16000, n_mels=128, n_mfcc=13, max_frames=1000):
    try:
        y, sr = librosa.load(audio_path, sr=sr)
        # Log-Mel spectrogram
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
        # MFCCs
        mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=n_mfcc)
        delta_mfcc = librosa.feature.delta(mfcc)
        delta_delta_mfcc = librosa.feature.delta(mfcc, order=2)
        # Pad or truncate
        def pad_or_truncate(feature, max_frames):
            h, w = feature.shape
            if w < max_frames:
                padded = np.zeros((h, max_frames))
                padded[:, :w] = feature
                return padded
            return feature[:, :max_frames]
        log_mel_spec = pad_or_truncate(log_mel_spec, max_frames)
        delta_mel = pad_or_truncate(librosa.feature.delta(log_mel_spec), max_frames)
        delta_delta_mel = pad_or_truncate(librosa.feature.delta(log_mel_spec, order=2), max_frames)
        mfcc = pad_or_truncate(mfcc, max_frames)
        delta_mfcc = pad_or_truncate(delta_mfcc, max_frames)
        delta_delta_mfcc = pad_or_truncate(delta_delta_mfcc, max_frames)
        # Stack features
        log_mel_image = np.stack([log_mel_spec, delta_mel, delta_delta_mel], axis=-1)
        mfcc_image = np.stack([mfcc, delta_mfcc, delta_delta_mfcc], axis=-1)
        return log_mel_image, mfcc_image
    except Exception as e:
        print(f"Error loading audio {audio_path}: {e}")
        raise

class ADReSSoModel(nn.Module):
    def __init__(self):
        super(ADReSSoModel, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768 + 768, 2)

    def forward(self, log_mel_image, mfcc_image, text_features):
        vit_outputs = self.vit(pixel_values=log_mel_image)
        vit_features = vit_outputs.last_hidden_state[:, 0, :]
        combined_features = torch.cat([vit_features, text_features], dim=-1)
        output = self.fc(combined_features)
        return output

def train_model(train_df, val_df):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224').to(device)
    bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_dataset = ADReSSoDataset(train_df, vit_model, bert_model, tokenizer)
    val_dataset = ADReSSoDataset(val_df, vit_model, bert_model, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=custom_collate_fn)

    model = ADReSSoModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        model.train()
        for batch in train_loader:
            log_mel_image = batch['log_mel_image'].to(device)
            mfcc_image = batch['mfcc_image'].to(device)
            text_features = batch['text_features'].to(device)
            labels = batch['label'].to(device)
            optimizer.zero_grad()
            outputs = model(log_mel_image, mfcc_image, text_features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            log_mel_image = batch['log_mel_image'].to(device)
            mfcc_image = batch['mfcc_image'].to(device)
            text_features = batch['text_features'].to(device)
            labels = batch['label'].to(device)
            outputs = model(log_mel_image, mfcc_image, text_features)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy}%")

    return model

try:
    train_df = pd.read_pickle(os.path.join(data_dir, 'train_df.pkl'))
    val_df = pd.read_pickle(os.path.join(data_dir, 'val_df.pkl'))
    test_df = pd.read_pickle(os.path.join(data_dir, 'test_df.pkl'))
    print("Loaded dataframes:", len(train_df), len(val_df), len(test_df))
except FileNotFoundError as e:
    print(f"Error loading pickle files: {e}")
    os.system(f"ls -l {data_dir}")
    raise

model = train_model(train_df, val_df)

Loaded dataframes: 107 59 32


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ValueError: Input image size (128*1000) doesn't match model (224*224).