## **Dementia Classification with CNN+BiLSTM**

### **Imports**


In [1]:
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
import random
import os
import yaml
import torch.nn as nn
import scipy
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import parselmouth
from parselmouth.praat import call
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, Gain
import wandb



### **Configurations**


In [21]:
%%writefile config.yaml

data_dir: "D:/2025/ADReSS-2020/New folder"
sr: 16000
chunk_length: 5.0
chunk_overlap: 2
n_mels: 70
augment_prob: 0.8
num_workers: 0
optimizer: "adamw"
weight_decay: 0.01
lr: 0.001
batch_size: 4
lr_scheduler: "cosine_w/restart"  #["cosine_w/restart", "ReduceOnPlateau"]
use_clinical_features: True
epochs: 20
features: "melbanks and clinical features"


Overwriting config.yaml


In [22]:
# Parse the YAML file written by the previous cell into a plain dictionary.
with Path('config.yaml').open() as f:
    config = yaml.safe_load(f)

### **Time-Based Train/Val Split (80/20)**

In [4]:
def time_based_split(data_dir, train_ratio=0.8, seed=42):
    """
    Split audio files into training and validation sets by total duration, per class.

    Each class ('control', 'dementia') is shuffled and then filled greedily: a file goes
    to the training set if adding its duration keeps the class's training duration at or
    below ``train_ratio`` of the class total; otherwise it goes to the validation set.

    Note: this reseeds Python's global ``random`` module with ``seed``.

    Args:
        data_dir (str): Path to the root directory containing 'control' and 'dementia' subfolders with .wav files.
        train_ratio (float): Proportion of total duration to allocate to the training set (default: 0.8).
        seed (int): Random seed for reproducibility (default: 42).

    Returns:
        tuple: (train_set, val_set)
            train_set (list): List of file paths for training.
            val_set (list): List of file paths for validation.

    Raises:
        ValueError: If ``train_ratio`` is outside the range [0, 1].

    Example:
        train_files, val_files = time_based_split('data/', train_ratio=0.8)
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")

    random.seed(seed)

    # Collect files with durations per class
    class_files = {'control': [], 'dementia': []}
    for class_name, entries in class_files.items():
        class_dir = os.path.join(data_dir, class_name)
        # os.listdir returns entries in arbitrary order; sorting makes the seeded
        # shuffle below produce the same split on every machine/filesystem.
        for file in sorted(os.listdir(class_dir)):
            # Accept '.WAV' as well as '.wav'
            if file.lower().endswith('.wav'):
                path = os.path.join(class_dir, file)
                duration = librosa.get_duration(path=path)
                entries.append((path, duration))

    # Split each class separately
    train_set = []
    val_set = []

    for class_name, files in class_files.items():
        # Shuffle files while keeping duration info
        random.shuffle(files)

        # Duration budget for the training portion of this class
        total_duration = sum(d for _, d in files)
        train_duration = total_duration * train_ratio
        current_duration = 0

        # Greedy assignment: a file that would overshoot the budget goes to
        # validation, but later (shorter) files may still fit into training.
        for path, duration in files:
            if current_duration + duration <= train_duration:
                train_set.append(path)
                current_duration += duration
            else:
                val_set.append(path)

    print(f"Train: {len(train_set)} files")
    print(f"Val: {len(val_set)} files")
    return train_set, val_set

In [5]:
# Build the duration-balanced train/validation file lists from the configured data root.
print("Splitting Dataset.....")
data_root = config['data_dir']
train_files, val_files = time_based_split(data_root)

Splitting Dataset.....
Train: 85 files
Val: 23 files
Train: 85 files
Val: 23 files


In [6]:
# Count training files per class.
# The class is taken from the file's parent folder (the folder name that
# time_based_split joined into the path) instead of a substring of the whole path,
# so a data_dir that itself contains "control" or "dementia" cannot skew the counts.
dementia_files = [f for f in train_files if os.path.basename(os.path.dirname(f)) == 'dementia']
control_files = [f for f in train_files if os.path.basename(os.path.dirname(f)) == 'control']
total = len(dementia_files) + len(control_files)
print("Dataset Class Distribution........ ")
if total > 0:
    print(f"Dementia: {len(dementia_files)} ({len(dementia_files)/total:.2%})")
    print(f"Control: {len(control_files)} ({len(control_files)/total:.2%})")
else:
    # Avoid a ZeroDivisionError when the split produced no training files
    print("No training files found.")


Dataset Class Distribution........ 
Dementia: 43 (50.59%)
Control: 42 (49.41%)


## **Clinical Feature Extraction**

In [7]:
class ClinicalFeatureExtractor:
    """
    Extracts clinical speech features from audio waveforms for dementia detection.

    Three feature groups are computed and can be merged with ``extract_all_features``:
      * prosodic (Parselmouth pitch/intensity): F0 statistics, voiced ratio, intensity statistics
      * voice quality (Praat calls): jitter, shimmer, mean HNR, spectral centre of gravity
      * temporal (numpy, energy-based activity detection): speech/pause statistics

    Every ``extract_*`` method is best-effort: on any exception it prints the error and
    returns a dictionary of zeros with the same keys.
    """

    # Fallback values returned when a feature group cannot be computed.
    _PROSODIC_DEFAULTS = {
        'f0_mean': 0, 'f0_std': 0, 'f0_range': 0, 'f0_cv': 0, 'f0_slope': 0,
        'voiced_frames_ratio': 0, 'intensity_mean': 0, 'intensity_std': 0, 'intensity_range': 0
    }
    _VOICE_QUALITY_DEFAULTS = {'jitter': 0, 'shimmer': 0, 'hnr_mean': 0, 'spectral_centroid': 0}
    _TEMPORAL_DEFAULTS = {
        'speech_rate': 0, 'pause_rate': 0, 'mean_pause_duration': 0,
        'speech_to_pause_ratio': 0, 'voiced_frame_ratio': 0
    }

    def __init__(self, sr=16000):
        """
        Args:
            sr (int): Sampling rate (Hz) of the waveforms that will be passed in.
        """
        self.sr = sr

    @staticmethod
    def _to_numpy(waveform):
        """
        Convert a waveform to a numpy array.

        Torch tensors are detached and moved to host memory first, so tensors that
        track gradients or live on an accelerator are accepted as well.

        Args:
            waveform (np.ndarray or torch.Tensor): Audio waveform.
        Returns:
            np.ndarray: The waveform as a numpy array.
        """
        if hasattr(waveform, 'detach'):
            return waveform.detach().cpu().numpy()
        if hasattr(waveform, 'numpy'):
            return waveform.numpy()
        return np.asarray(waveform)

    def extract_prosodic_features(self, waveform):
        """
        Extract prosodic features (F0, intensity, etc.) from a waveform using Parselmouth.

        Args:
            waveform (np.ndarray or torch.Tensor): Audio waveform (mono).
        Returns:
            dict: Prosodic features including F0 mean, std, range, CV, slope, voiced frame ratio, intensity mean/std/range.
        """
        try:
            audio_np = self._to_numpy(waveform)

            # Create Parselmouth Sound object
            sound = parselmouth.Sound(audio_np, sampling_frequency=self.sr)

            # F0 extraction (10 ms step, 75-300 Hz search range)
            pitch = sound.to_pitch(time_step=0.01, pitch_floor=75, pitch_ceiling=300)
            f0_values = pitch.selected_array['frequency']
            f0_values = f0_values[f0_values != 0]  # Remove unvoiced frames

            # Intensity extraction
            intensity = sound.to_intensity(time_step=0.01, minimum_pitch=75.0)
            intensity_values = intensity.values[0]

            prosodic_features = {}

            if len(f0_values) > 0:
                f0_mean = np.mean(f0_values)
                f0_std = np.std(f0_values)
                prosodic_features.update({
                    'f0_mean': f0_mean,
                    'f0_std': f0_std,
                    'f0_range': np.max(f0_values) - np.min(f0_values),
                    'f0_cv': f0_std / f0_mean if f0_mean > 0 else 0,
                    'f0_slope': self._calculate_f0_slope(f0_values),
                    'voiced_frames_ratio': len(f0_values) / pitch.get_number_of_frames()
                })
            else:
                prosodic_features.update({
                    'f0_mean': 0, 'f0_std': 0, 'f0_range': 0,
                    'f0_cv': 0, 'f0_slope': 0, 'voiced_frames_ratio': 0
                })

            if len(intensity_values) > 0:
                prosodic_features.update({
                    'intensity_mean': np.mean(intensity_values),
                    'intensity_std': np.std(intensity_values),
                    'intensity_range': np.max(intensity_values) - np.min(intensity_values)
                })
            else:
                prosodic_features.update({
                    'intensity_mean': 0, 'intensity_std': 0, 'intensity_range': 0
                })

            return prosodic_features

        except Exception as e:
            print(f"Error in prosodic feature extraction: {e}")
            # Return zero features if extraction fails
            return dict(self._PROSODIC_DEFAULTS)

    def _calculate_f0_slope(self, f0_values):
        """
        Calculate the slope of F0 values using linear regression over the frame index.

        Args:
            f0_values (np.ndarray): Array of F0 values.
        Returns:
            float: Slope of F0 contour (0 when fewer than two values are given).
        """
        if len(f0_values) < 2:
            return 0
        # Import the submodule explicitly: a bare `import scipy` does not guarantee
        # that `scipy.stats` is loaded, and the resulting AttributeError would be
        # swallowed by the caller's broad `except`, silently zeroing all prosodic features.
        from scipy import stats
        frame_index = np.arange(len(f0_values))
        return stats.linregress(frame_index, f0_values)[0]

    def extract_voice_quality_features(self, waveform):
        """
        Extract voice quality features (jitter, shimmer, HNR, spectral centroid) from a waveform.

        NaN results from Praat are replaced by 0.

        Args:
            waveform (np.ndarray or torch.Tensor): Audio waveform (mono).
        Returns:
            dict: Voice quality features.
        """
        try:
            audio_np = self._to_numpy(waveform)
            sound = parselmouth.Sound(audio_np, sampling_frequency=self.sr)

            # Jitter and Shimmer
            pointprocess = call(sound, "To PointProcess (periodic, cc)", 75, 300)
            jitter = call(pointprocess, "Get jitter (local)", 0, 0, 0.0001, 0.02, 1.3)
            shimmer = call([sound, pointprocess], "Get shimmer (local)", 0, 0, 0.0001, 0.02, 1.3, 1.6)

            # Harmonics-to-Noise Ratio
            harmonicity = call(sound, "To Harmonicity (cc)", 0.01, 75, 0.1, 1.0)
            hnr_mean = call(harmonicity, "Get mean", 0, 0)

            # Spectral measures
            spectrum = call(sound, "To Spectrum", "yes")
            spectral_centroid = call(spectrum, "Get centre of gravity", 2)

            return {
                'jitter': jitter if not np.isnan(jitter) else 0,
                'shimmer': shimmer if not np.isnan(shimmer) else 0,
                'hnr_mean': hnr_mean if not np.isnan(hnr_mean) else 0,
                'spectral_centroid': spectral_centroid if not np.isnan(spectral_centroid) else 0
            }

        except Exception as e:
            print(f"Error in voice quality feature extraction: {e}")
            return dict(self._VOICE_QUALITY_DEFAULTS)

    def extract_temporal_features(self, waveform):
        """
        Extract temporal and fluency features (speech rate, pause rate, etc.) from a waveform.

        Frames of 25 ms with a 10 ms hop are marked "voiced" when their energy exceeds
        the 30th percentile of all frame energies.

        Args:
            waveform (np.ndarray or torch.Tensor): Audio waveform (mono).
        Returns:
            dict: Temporal and fluency features. ``speech_to_pause_ratio`` is ``np.inf``
            when no pause longer than 100 ms is found. All values are 0 when the
            waveform is shorter than one frame or when extraction fails.
        """
        try:
            audio_np = self._to_numpy(waveform)

            # Voice activity detection (simple energy-based)
            frame_length = int(0.025 * self.sr)  # 25ms frames
            hop_length = int(0.01 * self.sr)     # 10ms hop

            # Frame energies. The "+ 1" keeps the frame that ends exactly on the
            # last sample, which a plain `len - frame_length` bound would drop.
            energy = np.array([
                np.sum(audio_np[start:start + frame_length] ** 2)
                for start in range(0, len(audio_np) - frame_length + 1, hop_length)
            ])
            if energy.size == 0:
                # Waveform shorter than one frame: nothing to analyse
                return dict(self._TEMPORAL_DEFAULTS)

            # Adaptive threshold. Because it is the 30th percentile of the energies
            # themselves, roughly 70% of the frames are marked voiced by construction.
            threshold = np.percentile(energy, 30)
            voiced_frames = energy > threshold

            # Calculate pause statistics
            speech_segments = self._get_speech_segments(voiced_frames, hop_length)
            pause_segments = self._get_pause_segments(voiced_frames, hop_length)

            total_duration = len(audio_np) / self.sr
            total_speech_time = sum(end - start for start, end in speech_segments)
            pause_durations = [end - start for start, end in pause_segments]
            total_pause_time = sum(pause_durations)

            return {
                'speech_rate': total_speech_time / total_duration if total_duration > 0 else 0,
                'pause_rate': len(pause_segments) / total_duration if total_duration > 0 else 0,
                'mean_pause_duration': np.mean(pause_durations) if pause_segments else 0,
                'speech_to_pause_ratio': total_speech_time / total_pause_time if total_pause_time > 0 else np.inf,
                'voiced_frame_ratio': np.sum(voiced_frames) / len(voiced_frames) if len(voiced_frames) > 0 else 0
            }

        except Exception as e:
            print(f"Error in temporal feature extraction: {e}")
            return dict(self._TEMPORAL_DEFAULTS)

    def _collect_runs(self, voiced_frames, hop_length, target):
        """
        Find maximal runs of consecutive frames whose voiced flag equals ``target``.

        Args:
            voiced_frames (np.ndarray): Boolean array indicating voiced frames.
            hop_length (int): Hop length in samples.
            target (bool): Flag value that defines a run.
        Returns:
            list: List of (start, end) tuples in seconds; ``end`` is the start time of
            the first frame after the run (or of the frame past the end of the array).
        """
        runs = []
        run_start = None
        for i, is_voiced in enumerate(voiced_frames):
            in_run = bool(is_voiced) == target
            if in_run and run_start is None:
                run_start = i * hop_length / self.sr
            elif not in_run and run_start is not None:
                runs.append((run_start, i * hop_length / self.sr))
                run_start = None
        if run_start is not None:
            # Run continues to the end of the signal
            runs.append((run_start, len(voiced_frames) * hop_length / self.sr))
        return runs

    def _get_speech_segments(self, voiced_frames, hop_length):
        """
        Get continuous speech segments from voiced frame mask.

        Args:
            voiced_frames (np.ndarray): Boolean array indicating voiced frames.
            hop_length (int): Hop length in samples.
        Returns:
            list: List of (start, end) tuples for speech segments (in seconds).
        """
        return self._collect_runs(voiced_frames, hop_length, True)

    def _get_pause_segments(self, voiced_frames, hop_length):
        """
        Get pause segments from unvoiced frame mask.

        Args:
            voiced_frames (np.ndarray): Boolean array indicating voiced frames.
            hop_length (int): Hop length in samples.
        Returns:
            list: List of (start, end) tuples for pause segments (in seconds).
        """
        # Only count pauses longer than 100ms
        return [
            (start, end)
            for start, end in self._collect_runs(voiced_frames, hop_length, False)
            if end - start > 0.1
        ]

    def extract_all_features(self, waveform):
        """
        Extract all clinical features (prosodic, voice quality, temporal) from a waveform and combine them into a single dictionary.

        Args:
            waveform (np.ndarray or torch.Tensor): Audio waveform (mono).
        Returns:
            dict: All extracted clinical features.
        """
        prosodic = self.extract_prosodic_features(waveform)
        voice_quality = self.extract_voice_quality_features(waveform)
        temporal = self.extract_temporal_features(waveform)

        # Combine all features (key sets of the three groups do not overlap)
        return {**prosodic, **voice_quality, **temporal}

# Example usage
def test_feature_extraction():
    """
    Smoke-test the clinical feature extractor on one random training file.

    Picks a file from the notebook-level ``train_files`` list, loads it, averages the
    channels to mono, extracts all clinical features and prints them.

    Returns:
        dict: The extracted clinical features.
    """
    # Load a sample file
    sample_file = random.choice(train_files)
    print(sample_file)
    waveform, sr = torchaudio.load(sample_file)
    waveform = torch.mean(waveform, dim=0)  # Convert to mono

    # Build the extractor with the file's actual sampling rate; the default of
    # 16 kHz would mis-scale every time/frequency feature for other rates.
    extractor = ClinicalFeatureExtractor(sr=sr)
    features = extractor.extract_all_features(waveform)

    print("Extracted clinical features:")
    for feature_name, value in features.items():
        print(f"{feature_name}: {value:.4f}")

    return features

# Smoke-test the extractor on one randomly chosen training file
test_features = test_feature_extraction()

D:/2025/ADReSS-2020/New folder\control\S067.wav
Extracted clinical features:
f0_mean: 135.5611
f0_std: 28.5041
f0_range: 220.4983
f0_cv: 0.2103
f0_slope: -0.0304
voiced_frames_ratio: 0.3914
intensity_mean: 29.7116
intensity_std: 92.1761
intensity_range: 391.6471
jitter: 0.0198
shimmer: 0.1055
hnr_mean: 13.0375
spectral_centroid: 654.6001
speech_rate: 0.6996
pause_rate: 0.5857
mean_pause_duration: 0.4515
speech_to_pause_ratio: 2.6456
voiced_frame_ratio: 0.7000
Extracted clinical features:
f0_mean: 135.5611
f0_std: 28.5041
f0_range: 220.4983
f0_cv: 0.2103
f0_slope: -0.0304
voiced_frames_ratio: 0.3914
intensity_mean: 29.7116
intensity_std: 92.1761
intensity_range: 391.6471
jitter: 0.0198
shimmer: 0.1055
hnr_mean: 13.0375
spectral_centroid: 654.6001
speech_rate: 0.6996
pause_rate: 0.5857
mean_pause_duration: 0.4515
speech_to_pause_ratio: 2.6456
voiced_frame_ratio: 0.7000


### **Create Dataset**

In [8]:
class DementiaDataset(Dataset):
    """
    Unified PyTorch Dataset for dementia detection from audio, supporting acoustic and clinical features.

    Each item is one audio file. The file is loaded, mixed down to mono, resampled to the
    configured rate if needed, optionally augmented (waveform-level and SpecAugment),
    converted to per-file normalized Kaldi filterbank features, and split into
    overlapping fixed-length chunks. Clinical features are optionally extracted from the
    (possibly augmented) waveform and normalized.

    Args:
        file_list (list): List of audio file paths.
        config (dict): Configuration dictionary; the keys read here are 'sr', 'n_mels',
            'chunk_length' (seconds) and 'chunk_overlap' (seconds).
        is_train (bool): If True, enables data augmentation.
        use_clinical_features (bool): If True, extracts and returns clinical features.
        normalizer (ClinicalFeatureNormalizer or None): Normalizer for clinical features.
    """

    # Frame shift passed to the Kaldi fbank call, in milliseconds. The number of
    # feature frames per second is 1000 / FRAME_SHIFT_MS, independent of `sr`.
    FRAME_SHIFT_MS = 10.0

    def __init__(self, file_list, config, is_train=False, use_clinical_features=True, normalizer=None):
        """
        Initialize the DementiaDataset.

        Args:
            file_list (list): List of audio file paths.
            config (dict): Configuration dictionary with keys for sr, n_mels, chunk_length, chunk_overlap, etc.
            is_train (bool): If True, enables data augmentation.
            use_clinical_features (bool): If True, extracts and returns clinical features.
            normalizer (ClinicalFeatureNormalizer or None): Normalizer for clinical features.

        Raises:
            ValueError: If the chunk length is not positive or the overlap is not
                smaller than the chunk length (the chunk stride would be <= 0).
        """
        self.files = file_list
        self.config = config
        self.is_train = is_train
        self.use_clinical_features = use_clinical_features
        self.sr = config['sr']
        self.n_mels = config['n_mels']

        # Convert seconds to feature frames using the fbank frame rate rather than
        # `sr / 160`, which only equals the frame rate when sr is 16 kHz.
        frames_per_second = 1000.0 / self.FRAME_SHIFT_MS
        self.chunk_frames = int(config['chunk_length'] * frames_per_second)
        self.overlap_frames = int(config['chunk_overlap'] * frames_per_second)
        if self.chunk_frames <= 0 or self.overlap_frames >= self.chunk_frames:
            raise ValueError(
                "chunk_length must be positive and chunk_overlap must be smaller than "
                f"chunk_length (got {config['chunk_length']} and {config['chunk_overlap']})"
            )
        self.normalizer = normalizer  # Pass an instance of ClinicalFeatureNormalizer or None

        if use_clinical_features:
            self.clinical_extractor = ClinicalFeatureExtractor(sr=self.sr)
        else:
            self.clinical_extractor = None

        # Audiomentations pipeline for waveform-level augmentation
        self.augment = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
            TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
            PitchShift(min_semitones=-2, max_semitones=2, p=0.3),
            Shift(min_shift=-0.2, max_shift=0.2, p=0.3),
            Gain(min_gain_db=-6, max_gain_db=6, p=0.3)
        ])

        # SpecAugment: time and frequency masking for spectrograms
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=30)
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=13)

    def __len__(self):
        """Return the number of audio files."""
        return len(self.files)

    @staticmethod
    def _label_from_path(path):
        """
        Derive the class label (0 = control, 1 = dementia) from a file path.

        The name of the file's parent folder decides the label. Only when that folder
        is neither 'control' nor 'dementia' does it fall back to the substring rule
        ("control" anywhere in the path -> 0, otherwise 1).
        """
        text = str(path)
        # Accept both '/' and '\\' separators (paths may mix them on Windows)
        parts = [part for part in text.replace("\\", "/").split("/") if part]
        parent = parts[-2].lower() if len(parts) >= 2 else ""
        if parent == "control":
            return 0
        if parent == "dementia":
            return 1
        return 0 if "control" in text else 1

    def _spec_augment(self, fbank):
        """
        Apply time and frequency masking to a (n_frames, n_mels) feature matrix.

        The torchaudio masking transforms expect (..., freq, time), so the matrix is
        transposed before masking and transposed back afterwards. Without this the
        "time" mask would blank mel bins and the "frequency" mask would blank frames.
        """
        spec = fbank.transpose(0, 1).unsqueeze(0)  # (1, n_mels, n_frames)
        spec = self.time_mask(spec)
        spec = self.freq_mask(spec)
        return spec.squeeze(0).transpose(0, 1).contiguous()

    def _split_into_chunks(self, fbank):
        """
        Cut a (n_frames, n_mels) matrix into overlapping chunks of ``chunk_frames`` frames.

        Chunks start every ``chunk_frames - overlap_frames`` frames; a chunk that runs
        past the end is zero-padded at the bottom to full length.

        Returns:
            torch.Tensor: Stacked chunks of shape (n_chunks, chunk_frames, n_mels).
        """
        stride = self.chunk_frames - self.overlap_frames
        chunks = []
        for start in range(0, fbank.shape[0], stride):
            chunk = fbank[start:start + self.chunk_frames]
            missing = self.chunk_frames - chunk.shape[0]
            if missing > 0:
                chunk = torch.nn.functional.pad(chunk, (0, 0, 0, missing))
            chunks.append(chunk)
        return torch.stack(chunks)

    def __getitem__(self, idx):
        """
        Load and process a single audio file, returning acoustic and (optionally) clinical features and label.

        Args:
            idx (int): Index of the file to load.

        Returns:
            tuple: (acoustic_chunks, clinical_features, label) if clinical features enabled, else (acoustic_chunks, label)
        """
        path = self.files[idx]
        label = self._label_from_path(path)

        # Load waveform and convert to mono
        waveform, sr = torchaudio.load(path)
        waveform = torch.mean(waveform, dim=0)
        if sr != self.sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)
            waveform = resampler(waveform)

        # Apply audiomentations (on numpy) if training
        if self.is_train:
            waveform_np = waveform.cpu().numpy()
            waveform_np = self.augment(samples=waveform_np, sample_rate=self.sr)
            waveform = torch.tensor(waveform_np, dtype=torch.float32)

        # Extract acoustic features (FBanks), then normalize each mel bin over time
        fbank = torchaudio.compliance.kaldi.fbank(
            waveform.unsqueeze(0),
            num_mel_bins=self.n_mels,
            sample_frequency=self.sr,
            frame_shift=self.FRAME_SHIFT_MS
        )
        fbank = (fbank - fbank.mean(dim=0)) / (fbank.std(dim=0) + 1e-6)

        # Apply SpecAugment (time/freq masking) if training
        if self.is_train:
            fbank = self._spec_augment(fbank)

        # Extract clinical features if enabled
        clinical_features = None
        if self.use_clinical_features:
            clinical_dict = self.clinical_extractor.extract_all_features(waveform)
            # Fixed (alphabetical) key order; NaN/inf values are replaced by 0.0
            clinical_values = [
                float(clinical_dict[k]) if np.isfinite(clinical_dict[k]) else 0.0
                for k in sorted(clinical_dict)
            ]
            clinical_features = torch.tensor(clinical_values, dtype=torch.float32)
            # Normalize clinical features if normalizer is provided
            if self.normalizer is not None:
                clinical_features = self.normalizer.transform(clinical_features)

        # Split acoustic features into chunks
        acoustic_chunks = self._split_into_chunks(fbank)

        if self.use_clinical_features:
            return acoustic_chunks, clinical_features, label
        else:
            return acoustic_chunks, label

In [9]:
def collate_fn(batch):
    """
    Flatten a batch of per-file samples into one batch of chunks.

    Every file contributes a stack of chunks; the stacks are concatenated along the
    first dimension and the file's label (and, when present, its clinical feature
    vector) is repeated once per chunk.

    Args:
        batch (list): Samples shaped either (acoustic_chunks, clinical_features, label)
            or (acoustic_chunks, label). The layout of the first sample decides which
            form is expected for the whole batch.

    Returns:
        tuple: (chunks, clinical, labels) or (chunks, labels), matching the input layout.
    """
    with_clinical = len(batch[0]) == 3

    if not with_clinical:
        # Acoustic-only samples
        chunk_groups = []
        flat_labels = []
        for chunk_group, target in batch:
            chunk_groups.append(chunk_group)
            flat_labels += [target] * len(chunk_group)
        return torch.cat(chunk_groups, dim=0), torch.tensor(flat_labels)

    chunk_groups = []
    clinical_groups = []
    flat_labels = []
    for chunk_group, clinical_vec, target in batch:
        n_chunks = len(chunk_group)
        chunk_groups.append(chunk_group)
        # One copy of the file-level clinical vector per chunk
        clinical_groups.append(clinical_vec.unsqueeze(0).repeat(n_chunks, 1))
        flat_labels += [target] * n_chunks

    stacked_chunks = torch.cat(chunk_groups, dim=0)
    stacked_clinical = torch.cat(clinical_groups, dim=0)
    return stacked_chunks, stacked_clinical, torch.tensor(flat_labels)

In [10]:
# Fit and apply normalization to clinical features

class ClinicalFeatureNormalizer:
    """
    Z-score normalizer for clinical feature vectors.

    ``fit`` computes the per-feature mean and standard deviation over a dataset;
    ``transform`` applies ``(x - mean) / std`` to a tensor.
    """

    def __init__(self):
        self.mean = None  # per-feature mean (np.ndarray) once fitted
        self.std = None   # per-feature std + 1e-8 (np.ndarray) once fitted

    @staticmethod
    def _collect_clinical(dataset):
        """
        Gather the clinical feature vector of every 3-element item in ``dataset``.

        Args:
            dataset: Indexable with ``len()``; items shaped
                (acoustic, clinical_tensor, label) are used, other items are skipped.
        Returns:
            np.ndarray: Array of shape (n_items, n_features).
        Raises:
            ValueError: If no item provides clinical features.
        """
        feats = []
        for i in range(len(dataset)):
            item = dataset[i]
            if len(item) == 3:
                feats.append(item[1].detach().cpu().numpy())
        if not feats:
            raise ValueError("dataset yields no clinical features to fit the normalizer on")
        return np.stack(feats)

    def _fit_array(self, feats):
        """Store mean and std of a (n_items, n_features) array."""
        self.mean = feats.mean(axis=0)
        self.std = feats.std(axis=0) + 1e-8  # avoid division by zero

    def fit(self, dataset):
        """
        Compute normalization statistics from the clinical features of ``dataset``.

        Args:
            dataset: Dataset whose items are (acoustic, clinical_tensor, label).
        Raises:
            ValueError: If the dataset provides no clinical features.
        """
        self._fit_array(self._collect_clinical(dataset))

    def transform(self, clinical_tensor):
        """
        Normalize a single tensor (1D or batched 2D).

        Args:
            clinical_tensor (torch.Tensor): Clinical features, last dimension = n_features.
        Returns:
            torch.Tensor: Normalized features with the same dtype and device as the input.
        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.mean is None or self.std is None:
            raise RuntimeError("ClinicalFeatureNormalizer.transform called before fit")
        # Build the statistics on the input's device so inputs that live on an
        # accelerator can be normalized as well.
        mean = torch.as_tensor(self.mean, dtype=clinical_tensor.dtype, device=clinical_tensor.device)
        std = torch.as_tensor(self.std, dtype=clinical_tensor.dtype, device=clinical_tensor.device)
        return (clinical_tensor - mean) / std

    def fit_transform(self, dataset):
        """
        Fit on ``dataset`` and return its normalized clinical feature vectors.

        The features are collected once and reused for both steps, so the returned
        vectors are exactly the ones the statistics were computed from.

        Args:
            dataset: Dataset whose items are (acoustic, clinical_tensor, label).
        Returns:
            list: One normalized float32 tensor per item that has clinical features.
        """
        feats = self._collect_clinical(dataset)
        self._fit_array(feats)
        return [self.transform(torch.tensor(f, dtype=torch.float32)) for f in feats]


In [11]:
# 1. Create initial train and val datasets (without normalizer)
train_dataset = DementiaDataset(
    train_files, config, is_train=True, use_clinical_features=True, normalizer=None
)
val_dataset = DementiaDataset(
    val_files, config, is_train=False, use_clinical_features=True, normalizer=None
)

# 2. Fit the normalizer on an un-augmented view of the training files.
#    train_dataset applies random waveform augmentation on every access, so fitting
#    on it directly would make the mean/std differ from run to run.
normalizer = ClinicalFeatureNormalizer()
normalizer.fit(
    DementiaDataset(
        train_files, config, is_train=False, use_clinical_features=True, normalizer=None
    )
)

# 3. Assign the fitted normalizer to both datasets
train_dataset.normalizer = normalizer
val_dataset.normalizer = normalizer

# 4. Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config['num_workers'],
    collate_fn=collate_fn
)

print(f"Train batches: {len(train_loader)}")

Train batches: 22


In [12]:
# Rebuild both loaders from the prepared datasets; they share every setting
# except shuffling (enabled for training only).
loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    collate_fn=collate_fn,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
        

In [13]:
# Pull the first training batch and report its tensor shapes as a sanity check.
for batch_idx, batch in enumerate(train_loader):
    acoustic_data, clinical_data, labels = batch
    report = [
        f"Batch {batch_idx + 1}:",
        f"  Acoustic data shape: {acoustic_data.shape}",
        f"  Clinical data shape: {clinical_data.shape}",
        f"  Labels shape: {labels.shape}",
        f"  Sample clinical features: {clinical_data[0][:5]}...",  # first 5 features only
    ]
    print("\n".join(report))
    break  # one batch is enough

print("✓ Enhanced dataset integration successful!")

Batch 1:
  Acoustic data shape: torch.Size([54, 500, 70])
  Clinical data shape: torch.Size([54, 18])
  Labels shape: torch.Size([54])
  Sample clinical features: tensor([-0.7122, -0.2327,  0.3489,  0.4537, -0.8729])...
✓ Enhanced dataset integration successful!


### **Define Model**

In [14]:
class LockedDropout(nn.Module):
    """
    Dropout that uses one mask per sequence, shared across all time steps.

    For an input of shape (batch, seq_len, hidden) a Bernoulli keep-mask of shape
    (batch, 1, hidden) is sampled, scaled by 1 / (1 - p) and broadcast over the
    sequence dimension. A new mask is sampled on every forward call in training
    mode; in eval mode (or with p == 0) the input is returned unchanged.

    Args:
        p (float): Probability of dropping a unit, in [0, 1).

    Raises:
        ValueError: If ``p`` is outside [0, 1).
    """
    def __init__(self, p=0.3):
        super().__init__()
        if not 0.0 <= p < 1.0:
            # p == 1 would divide by zero when rescaling the kept units
            raise ValueError(f"dropout probability must be in [0, 1), got {p}")
        self.p = p
        self.mask = None  # most recently sampled mask (for inspection only)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input of shape (batch, seq_len, hidden).
        Returns:
            torch.Tensor: Tensor of the same shape with locked dropout applied.
        """
        if not self.training or self.p == 0:
            return x

        # Sample a fresh mask on every call. Reusing a cached mask across
        # iterations would drop the same units for the whole training run and
        # would go stale if the hidden size or device changed.
        keep_prob = 1 - self.p
        mask = x.new_empty(x.size(0), 1, x.size(2),
                           requires_grad=False).bernoulli_(keep_prob) / keep_prob
        self.mask = mask

        return mask.expand_as(x) * x
    


In [15]:

# Enhanced Model Architecture with Feature Fusion
class EnhancedDementiaCNNBiLSTM(nn.Module):
    """
    Enhanced CNN-BiLSTM model for dementia detection from audio, supporting both acoustic and clinical features.

    Acoustic features of shape (batch, seq_len, n_mels) go through two 1-D convolutions
    over time, locked dropout, a bidirectional LSTM and mean pooling over time, giving a
    64-dimensional vector per sample. When clinical features are enabled and supplied,
    they are passed through a small MLP, concatenated with the acoustic vector and fused
    by a linear layer before the final classifier. The model outputs one logit per sample.

    Args:
        use_clinical_features (bool): If True, expects and uses clinical features in addition to acoustic features.
        clinical_feature_dim (int): Number of clinical features per sample.
        n_mels (int): Number of filterbank channels in the acoustic input (default: 70).

    Attributes:
        acoustic_cnn (nn.Sequential): CNN layers for acoustic feature extraction.
        acoustic_lstm (nn.LSTM): BiLSTM for temporal modeling of acoustic features.
        locked_dropout (LockedDropout): Dropout layer applied to the LSTM input.
        clinical_processor (nn.Sequential): MLP for processing clinical features (if enabled).
        fusion_layer (nn.Sequential): Layer for fusing acoustic and clinical features (if enabled).
        classifier (nn.Sequential): Final classifier for binary prediction.
    """
    def __init__(self, use_clinical_features=True, clinical_feature_dim=18, n_mels=70):
        super(EnhancedDementiaCNNBiLSTM, self).__init__()
        self.use_clinical_features = use_clinical_features

        # Acoustic feature processing (CNN + BiLSTM)
        self.acoustic_cnn = nn.Sequential(
            nn.Conv1d(n_mels, 128, kernel_size=5, padding=2),
            nn.Conv1d(128, 64, kernel_size=3, padding=1),
            nn.SiLU()
        )

        # batch_first=True because forward() feeds (batch, seq_len, features).
        # With the default layout the LSTM would treat the batch axis as time and
        # run its recurrence across the different samples of a batch.
        self.acoustic_lstm = nn.LSTM(
            input_size=64,
            hidden_size=32,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )

        self.locked_dropout = LockedDropout(p=0.3)

        # Clinical feature processing
        if use_clinical_features:
            self.clinical_processor = nn.Sequential(
                nn.Linear(clinical_feature_dim, 32),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(32, 16),
                nn.ReLU()
            )

            # Feature fusion
            self.fusion_layer = nn.Sequential(
                nn.Linear(64 + 16, 64),  # acoustic_features + clinical_features
                nn.ReLU(),
                nn.Dropout(0.4)
            )

        # Final classifier: takes the 64-d fused vector, or the 64-d acoustic
        # vector when clinical features are not used.
        self.classifier = nn.Sequential(
            nn.Linear(64, 32),
            nn.SiLU(),
            nn.Dropout(0.4),
            nn.Linear(32, 1)
        )

    def forward(self, acoustic_input, clinical_input=None):
        """
        Forward pass for the model.

        Args:
            acoustic_input (torch.Tensor): Acoustic features of shape (batch_size, seq_len, n_mels).
            clinical_input (torch.Tensor, optional): Clinical features of shape (batch_size, clinical_feature_dim).
                If omitted, only the acoustic branch is used, even when the model was
                built with clinical features enabled.

        Returns:
            torch.Tensor: Model output logits of shape (batch_size, 1).
        """
        # Conv1d expects channels first: (batch, n_mels, seq_len)
        x_acoustic = acoustic_input.permute(0, 2, 1)

        # CNN processing
        x_acoustic = self.acoustic_cnn(x_acoustic)  # (batch, 64, seq_len)
        x_acoustic = x_acoustic.permute(0, 2, 1)    # (batch, seq_len, 64)

        # BiLSTM processing
        x_acoustic = self.locked_dropout(x_acoustic)
        lstm_out, _ = self.acoustic_lstm(x_acoustic)  # (batch, seq_len, 64)

        # Temporal pooling for acoustic features
        acoustic_features = torch.mean(lstm_out, dim=1)  # (batch, 64)

        if self.use_clinical_features and clinical_input is not None:
            # Process clinical features
            clinical_features = self.clinical_processor(clinical_input)  # (batch, 16)

            # Feature fusion
            combined_features = torch.cat([acoustic_features, clinical_features], dim=1)  # (batch, 80)
            fused_features = self.fusion_layer(combined_features)  # (batch, 64)

            # Classification
            return self.classifier(fused_features)

        # Use only acoustic features
        return self.classifier(acoustic_features)

In [16]:
# Smoke test: build the model, report its size, and run one no-grad forward pass
# on the first training batch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EnhancedDementiaCNNBiLSTM(use_clinical_features=True, clinical_feature_dim=18)
model = model.to(device)
model.eval()

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {trainable_params}")

for batch in train_loader:
    acoustic_data, clinical_data, labels = batch
    acoustic_data, clinical_data = acoustic_data.to(device), clinical_data.to(device)
    with torch.no_grad():
        logits = model(acoustic_data, clinical_data)
        probs = torch.sigmoid(logits)  # logits -> probability of the positive class
    print("Probabilities for the positive class (dementia):")
    print(probs[:10].squeeze())  # Show first 10 probabilities
    print(f"Labels {labels[:10]}")
    break  # a single batch is enough

Model parameters: 103089
Probabilities for the positive class (dementia):
tensor([0.5031, 0.5032, 0.5034, 0.5035, 0.5036, 0.5038, 0.5038, 0.5037, 0.5035,
        0.5032])
Labels tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Probabilities for the positive class (dementia):
tensor([0.5031, 0.5032, 0.5034, 0.5035, 0.5036, 0.5038, 0.5038, 0.5037, 0.5035,
        0.5032])
Labels tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


### **Model Training**

In [17]:
#training functions
def train_epoch(model, loader, optimizer, criterion, device):
    """
    Train the model for one epoch.

    Each batch yielded by ``loader`` is either ``(acoustic, clinical, labels)``
    or ``(acoustic, labels)``. When a clinical tensor is present it is passed to
    the model as the second positional argument.

    Args:
        model (nn.Module): The model to train.
        loader (DataLoader): DataLoader for the training data.
        optimizer (torch.optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the training on.

    Returns:
        tuple: Average loss and accuracy for the epoch.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in loader:
        if len(batch) == 3:  # With clinical features
            acoustic_inputs, clinical_inputs, labels = batch
            model_inputs = (acoustic_inputs.to(device), clinical_inputs.to(device))
        else:  # Without clinical features
            acoustic_inputs, labels = batch
            model_inputs = (acoustic_inputs.to(device),)

        # Flatten to 1-D. A bare .squeeze() turns a (1, 1) output into a 0-d
        # tensor, which makes the loss raise on a batch of size 1 (e.g. the
        # last partial batch); reshape(-1) always keeps the batch dimension.
        labels = labels.to(device).float().reshape(-1)

        optimizer.zero_grad()
        logits = model(*model_inputs).reshape(-1)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Calculate metrics (detached: no graph needed for bookkeeping)
        batch_size = labels.size(0)
        preds = torch.sigmoid(logits.detach()).round()

        total_loss += loss.item() * batch_size
        total_correct += (preds == labels).sum().item()
        total_samples += batch_size

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

def validate(model, loader, criterion, device):
    """
    Evaluate the model on the validation set.

    Each batch yielded by ``loader`` is either ``(acoustic, clinical, labels)``
    or ``(acoustic, labels)``. When a clinical tensor is present it is passed to
    the model as the second positional argument. No gradients are computed.

    Args:
        model (nn.Module): The model to evaluate.
        loader (DataLoader): DataLoader for the validation data.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the evaluation on.

    Returns:
        tuple: Average loss and accuracy for the validation set.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in loader:
            if len(batch) == 3:  # With clinical features
                acoustic_inputs, clinical_inputs, labels = batch
                model_inputs = (acoustic_inputs.to(device), clinical_inputs.to(device))
            else:  # Without clinical features
                acoustic_inputs, labels = batch
                model_inputs = (acoustic_inputs.to(device),)

            # Flatten to 1-D. A bare .squeeze() turns a (1, 1) output into a
            # 0-d tensor, which makes the loss raise on a batch of size 1;
            # reshape(-1) always keeps the batch dimension.
            labels = labels.to(device).float().reshape(-1)
            logits = model(*model_inputs).reshape(-1)

            loss = criterion(logits, labels)

            # Calculate metrics
            batch_size = labels.size(0)
            preds = torch.sigmoid(logits).round()

            total_loss += loss.item() * batch_size
            total_correct += (preds == labels).sum().item()
            total_samples += batch_size

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

In [23]:
# Authenticate with Weights & Biases without embedding the API key in the notebook.
# The key is read from the WANDB_API_KEY environment variable; when it is unset,
# key=None lets wandb fall back to its own credential lookup / interactive prompt.
# NOTE(review): an API key was previously hardcoded in this cell in plain text.
# Treat that key as leaked and revoke/rotate it in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Resume the existing run. To start a fresh run instead, call wandb.init with
# only `project` and `config` (no `id` / `resume`).
run = wandb.init(
    project="Dementia detection w CNN-BiLSTM",
    id="86611r5b",           # Explicit run ID
    resume="must",           # Ensures strict resume: fails if the run does not exist
    config=config            # Optional: only used if resuming with config
)


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\brian\_netrc
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\brian\_netrc


In [None]:
# Enhanced Training Pipeline with Clinical Features
# Updated to use the enhanced training functions that handle both acoustic and clinical features

def enhanced_train_model(model, train_loader, val_loader, optimizer, criterion, device, epochs, scheduler,
                         checkpoint_path='best_model.pth'):
    """
    Train the enhanced model with clinical features, logging to wandb.

    Runs ``train_epoch`` and ``validate`` once per epoch, steps the learning-rate
    scheduler, saves the model's ``state_dict`` whenever validation accuracy
    improves, logs per-epoch metrics with ``wandb.log`` and finally calls
    ``wandb.finish()``.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): DataLoader for the training data.
        val_loader (DataLoader): DataLoader for the validation data.
        optimizer (torch.optim.Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function.
        device (torch.device): Device to run the training on.
        epochs (int): Number of training epochs.
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler, or None
            to keep the learning rate under the optimizer's control only.
        checkpoint_path (str): Where the best model's state_dict is written
            (default: 'best_model.pth').

    Returns:
        float: Best validation accuracy achieved during training.
    """
    best_val_acc = 0.0

    # Decide how to step from the scheduler object itself rather than from the
    # notebook-global config, so the function works with whatever it is given.
    # ReduceLROnPlateau is the only scheduler here that needs the monitored metric.
    steps_on_metric = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    for epoch in range(epochs):
        # Learning rate in effect for this epoch (read before the scheduler steps)
        current_lr = optimizer.param_groups[0]['lr']

        # Training with enhanced function
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)

        # Validation with enhanced function
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Update learning rate
        if scheduler is not None:
            if steps_on_metric:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # Print progress
        print(f'Epoch {epoch+1:02}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}% | LR: {current_lr:.7f}')
        print('-'*100)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"model saved with accuracy: {best_val_acc*100:.2f}%")

        # logging to wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "learning_rate": current_lr,
        })

    print(f'Best Validation Accuracy: {best_val_acc*100:.2f}%')
    wandb.finish()
    return best_val_acc

# Setup enhanced training
print("Setting up training with clinical features...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
# Take weight decay from config so the value actually used matches the config
# that is logged to wandb (previously a hard-coded 1e-3 silently overrode it).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['lr'],
    weight_decay=config.get('weight_decay', 1e-3),
)

if config["lr_scheduler"] == "cosine_w/restart":
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=max(1, config['epochs'] // 2),  # T_0 must be a positive integer
        T_mult=2,
        eta_min=1e-4,
    )
elif config["lr_scheduler"] == "ReduceOnPlateau":
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=4)
else:
    # Fail here with a clear message instead of a NameError on `scheduler` below.
    raise ValueError(
        f"Unsupported lr_scheduler {config['lr_scheduler']!r}; "
        "expected 'cosine_w/restart' or 'ReduceOnPlateau'"
    )

# Gradient clipping is applied inside train_epoch after every backward pass.
# A one-off clip_grad_norm_ call at setup time has no effect on training, so it is not repeated here.

print("✓ Enhanced training setup complete")
print("✓ Model: CNN-BiLSTM with clinical features")
print(f"✓ Device: {device}")
print(f"✓ Learning rate: {config['lr']}")
print(f"✓ Batch size: {config['batch_size']}")

# Start enhanced training
print("\n" + "="*100)
print("Starting Training with Clinical Features")
print("="*100)

best_accuracy = enhanced_train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=config['epochs'],
    scheduler=scheduler
)

print(f"\n✓ Training completed with best accuracy: {best_accuracy*100:.2f}%")

# No wandb.finish() needed here: enhanced_train_model already closes the run.