## **Inference Pipeline**

In [None]:
import torch
import torchaudio
import librosa
import matplotlib.pyplot as plt
import random
import os
import yaml
import torch.nn as nn
import scipy
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import parselmouth
from parselmouth.praat import call
from sklearn.metrics import confusion_matrix, classification_report
import wandb
import pandas as pd
# utils 
from utils import *



### **Configuration and Helper functions**

In [3]:
# Inference settings; these must match the values used at training time,
# otherwise feature shapes / chunk sizes will not line up with the model.
config = dict(
    sr=16000,            # target sample rate (Hz) for all audio
    n_mels=70,           # number of Mel filterbank bins
    chunk_length=5.0,    # chunk duration, seconds
    chunk_overlap=2.0,   # overlap between consecutive chunks, seconds
)

def load_audio(file_path, target_sr=config['sr']):
    """
    Read an audio file, down-mix it to mono and resample it to ``target_sr``.

    Args:
        file_path: path to any file ``torchaudio.load`` can read.
        target_sr: desired sample rate in Hz (defaults to ``config['sr']``,
            captured when this function is defined).

    Returns:
        1-D tensor of samples (all channels averaged together).
    """
    signal, native_sr = torchaudio.load(file_path)
    mono = signal.mean(dim=0)  # (channels, time) -> (time,)
    if native_sr == target_sr:
        return mono
    return torchaudio.transforms.Resample(orig_freq=native_sr, new_freq=target_sr)(mono)

def extract_fbank(waveform, config):
    """
    Compute Kaldi-style log Mel filterbank features with per-utterance normalization.

    Args:
        waveform: mono waveform tensor (as returned by ``load_audio``).
        config: dict providing ``n_mels`` and ``sr``.

    Returns:
        Tensor of shape (n_frames, n_mels); each Mel bin is shifted to zero mean
        and scaled to unit standard deviation over the frames of this utterance.
    """
    feats = torchaudio.compliance.kaldi.fbank(
        waveform.unsqueeze(dim=0),  # kaldi.fbank expects a leading channel dim
        num_mel_bins=config['n_mels'],
        sample_frequency=config['sr'],
    )
    per_bin_mean = feats.mean(dim=0)
    per_bin_std = feats.std(dim=0)
    # small epsilon guards against division by zero for constant bins
    return (feats - per_bin_mean) / (per_bin_std + 1e-6)

def extract_clinical_features(waveform):
    """
    Build the clinical-feature vector fed to the model alongside the fbank chunks.

    Features come from ``ClinicalFeatureExtractor.extract_all_features`` and are
    laid out in sorted key order so every call produces the same ordering.
    Non-finite values (inf / NaN) are replaced with 0.0.

    Returns:
        1-D float32 tensor with one entry per extracted feature.
    """
    feature_map = ClinicalFeatureExtractor().extract_all_features(waveform)
    cleaned = [
        feature_map[name] if np.isfinite(feature_map[name]) else 0.0
        for name in sorted(feature_map)
    ]
    return torch.tensor(cleaned, dtype=torch.float32)

def visualize_fbank(fbank, title="FBank Visualization"):
    """
    Display a filterbank feature matrix as a heat map (time on x, Mel bin on y).

    Args:
        fbank: feature tensor; assumed (n_frames, n_mels) as produced by
            ``extract_fbank`` — it is transposed so frames run along the x axis.
        title: figure title.
    """
    spectrogram = fbank.cpu().numpy().T
    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(spectrogram, aspect='auto', origin='lower', interpolation='nearest')
    ax.set_title(title)
    ax.set_xlabel("Time Frame")
    ax.set_ylabel("Mel Bin")
    fig.colorbar(image, ax=ax, label='Amplitude')
    fig.tight_layout()
    plt.show()
    
def chunk_fbank(fbank, config):
    """
    Split a frame-level feature matrix into fixed-length, overlapping chunks.

    Args:
        fbank: tensor of shape (n_frames, n_features), e.g. from ``extract_fbank``.
        config: dict with ``chunk_length`` and ``chunk_overlap`` in seconds.
            An optional ``frame_shift_ms`` (default 10.0, the Kaldi fbank
            default that ``extract_fbank`` relies on) sets the feature frame rate.

    Returns:
        Tensor of shape (n_chunks, chunk_frames, n_features). A new chunk starts
        every ``chunk_length - chunk_overlap`` seconds; chunks that run past the
        end of ``fbank`` are zero-padded at the end to ``chunk_frames``.

    Raises:
        ValueError: if ``fbank`` has no frames, or if the overlap is not
            strictly shorter than the chunk length.
    """
    # Kaldi fbank emits one frame per frame shift (10 ms by default) no matter
    # what the sample rate is, so the frame rate is 1000 / frame_shift_ms.
    # Deriving it from ``sr / 160`` would only be correct at 16 kHz.
    frames_per_second = 1000.0 / config.get('frame_shift_ms', 10.0)
    chunk_frames = int(config['chunk_length'] * frames_per_second)
    overlap_frames = int(config['chunk_overlap'] * frames_per_second)
    stride = chunk_frames - overlap_frames
    if stride <= 0:
        raise ValueError(
            f"chunk_overlap ({config['chunk_overlap']}s) must be shorter than "
            f"chunk_length ({config['chunk_length']}s)"
        )

    n_frames = fbank.shape[0]
    if n_frames == 0:
        raise ValueError("Cannot chunk an empty feature matrix (0 frames).")

    chunks = []
    for start in range(0, n_frames, stride):
        chunk = fbank[start:start + chunk_frames]
        pad_size = chunk_frames - chunk.shape[0]
        if pad_size > 0:
            # F.pad pads last dim first: (feat_left, feat_right, time_front, time_back)
            chunk = F.pad(chunk, (0, 0, 0, pad_size))
        chunks.append(chunk)
    return torch.stack(chunks)

def run_enhanced_inference(model, chunks, clinical_features, device):
    """
    Score one recording from its fbank chunks plus its clinical-feature vector.

    The same clinical vector is paired with every chunk, the model's per-chunk
    outputs are averaged, and the sigmoid of that average is returned
    (so the model outputs are presumably logits).

    Args:
        model: module called as ``model(chunks, clinical_batch)``.
        chunks: stacked chunk tensor, first dim = number of chunks.
        clinical_features: 1-D clinical feature tensor for the whole recording.
        device: torch device to run on.

    Returns:
        float probability in [0, 1].
    """
    model.eval()
    batch = chunks.to(device)
    n = len(batch)
    clinical_batch = clinical_features.unsqueeze(0).repeat(n, 1).to(device)
    with torch.no_grad():
        logits = model(batch, clinical_batch)
    return torch.sigmoid(logits.mean(dim=0)).item()

def enhanced_inference_pipeline(file_path, model, device, config, visualize=False):
    """
    Run the full audio -> features -> model pipeline for one file, logging each step.

    Steps: load + resample audio, extract fbank and clinical features,
    optionally plot the fbank, chunk it, and score it with the model.

    Returns:
        float dementia-risk probability from ``run_enhanced_inference``.
    """
    def log_step(message):
        print(f"✓ {message}")

    print(f"Processing: {file_path}")

    waveform = load_audio(file_path, target_sr=config['sr'])
    log_step(f"Audio loaded: {waveform.shape}")

    fbank = extract_fbank(waveform, config)
    log_step(f"Acoustic features extracted: {fbank.shape}")

    clinical_features = extract_clinical_features(waveform)
    log_step(f"Clinical features extracted: {clinical_features.shape}")

    if visualize:
        visualize_fbank(fbank, title="Mel Spectrogram")

    chunks = chunk_fbank(fbank, config)
    log_step(f"Audio chunked: {chunks.shape}")

    prediction = run_enhanced_inference(model, chunks, clinical_features, device)
    log_step(f"Prediction: {prediction*100:.2f}% dementia risk")
    return prediction



### **Load Model and Run End-to-End Inference**

In [4]:
def load_model(model_path, device, clinical_feature_dim=18):
    """
    Load the CNN-BiLSTM model that takes clinical features as a second input.

    Args:
        model_path: path to a checkpoint holding the model's state_dict.
        device: torch device the model is moved to.
        clinical_feature_dim: length of the clinical-feature vector the model expects.

    Returns:
        (model, use_clinical_features) — the loaded model and a flag that is
        always True for this model variant.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist. Previously this path
            fell through to ``return`` with ``use_clinical_features`` unbound and
            surfaced as a confusing UnboundLocalError.
    """
    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects; only use
        # trusted checkpoints. Since the result goes straight into
        # load_state_dict, weights_only=True should work — confirm before switching.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    except FileNotFoundError:
        print("⚠ Model not found")
        raise

    model = EnhancedDementiaCNNBiLSTM(
        use_clinical_features=True, clinical_feature_dim=clinical_feature_dim
    ).to(device)
    model.load_state_dict(checkpoint)
    print("✓ Enhanced model with clinical features loaded successfully")
    use_clinical_features = True
    return model, use_clinical_features

def print_clinical_features(file_path, config):
    """
    Print the clinical features extracted from an audio file.

    Features are listed in sorted-name order, the same order
    ``extract_clinical_features`` uses to build the model input tensor, so the
    i-th printed line corresponds to the i-th tensor entry. Raw values are shown;
    note that inf/NaN values printed here are replaced with 0.0 in the tensor.
    """
    waveform = load_audio(file_path, target_sr=config['sr'])
    clinical_dict = ClinicalFeatureExtractor().extract_all_features(waveform)
    print("\nClinical Features:")
    for feature_name in sorted(clinical_dict):
        print(f"  {feature_name}: {clinical_dict[feature_name]:.4f}")

def interpret_prediction(prediction):
    """
    Print a coarse risk band for a probability.

    Bands: > 0.7 -> HIGH, > 0.5 -> MODERATE, otherwise LOW
    (boundaries are exclusive, so exactly 0.7 is MODERATE and 0.5 is LOW).
    """
    bands = (
        (0.7, "HIGH", "🔴"),
        (0.5, "MODERATE", "🟡"),
    )
    level, icon = "LOW", "🟢"
    for cutoff, band_level, band_icon in bands:
        if prediction > cutoff:
            level, icon = band_level, band_icon
            break
    print(f"{icon} Risk Level: {level}")

def run_inference_example(file_path, model, device, config, show_features=False):
    """
    Run the inference pipeline on one file and print a framed summary.

    Args:
        show_features: if True, also print the extracted clinical features.

    Returns:
        float dementia-risk probability.
    """
    rule = '=' * 60
    print(f"\n{rule}\nRunning inference on: {file_path}\n{rule}")

    prediction = enhanced_inference_pipeline(file_path, model, device, config, visualize=False)
    if show_features:
        print_clinical_features(file_path, config)

    print(f"\n🔍 Final Prediction: {prediction*100:.2f}% dementia risk")
    interpret_prediction(prediction)
    print(rule)
    return prediction


In [5]:

# --- Main Inference Flow ---
# Pick the device, load the trained model once, then score a single recording.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model_path = 'mybest_model.pth'
model, use_clinical_features = load_model(model_path, device, clinical_feature_dim=18)
print("✓ Model loading complete!")

# NOTE(review): absolute local path — adjust for your machine.
file_path = "D:/2025/ADReSS-2020/kin-keeper-audios/audio_20241215-185828.wav"
run_inference_example(file_path, model, device, config, show_features=False)

Using device: cpu
✓ Enhanced model with clinical features loaded successfully
✓ Model loading complete!

Running inference on: D:/2025/ADReSS-2020/kin-keeper-audios/audio_20241215-185828.wav
Processing: D:/2025/ADReSS-2020/kin-keeper-audios/audio_20241215-185828.wav
✓ Audio loaded: torch.Size([799680])
✓ Acoustic features extracted: torch.Size([4996, 70])
✓ Clinical features extracted: torch.Size([18])
✓ Audio chunked: torch.Size([17, 500, 70])
✓ Prediction: 34.02% dementia risk

🔍 Final Prediction: 34.02% dementia risk
🟢 Risk Level: LOW


0.3401888608932495

In [6]:
# Score a second recording with the already-loaded model.
file_path = "D:/2025/ADReSS-2020/kin-keeper-audios/audio_20250207-074612.wav"
run_inference_example(file_path, model=model, device=device, config=config, show_features=False)


Running inference on: D:/2025/ADReSS-2020/kin-keeper-audios/audio_20250207-074612.wav
Processing: D:/2025/ADReSS-2020/kin-keeper-audios/audio_20250207-074612.wav
✓ Audio loaded: torch.Size([581632])
✓ Acoustic features extracted: torch.Size([3633, 70])
✓ Clinical features extracted: torch.Size([18])
✓ Audio chunked: torch.Size([13, 500, 70])
✓ Prediction: 46.01% dementia risk

🔍 Final Prediction: 46.01% dementia risk
🟢 Risk Level: LOW


0.46014609932899475

### **Testing on the ADReSS 2020 (IS2020) Test Set**

### **CNN-BiLSTM with Mel filterbanks + extracted clinical features**

In [17]:
# Download the ADReSS test-set metadata (IDs, demographics and ground-truth labels)
!curl -o test_results.txt https://luzs.gitlab.io/adress/meta_data_test.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  1134  100  1134    0     0    558      0  0:00:02  0:00:02 --:--:--   558
100  1134  100  1134    0     0    558      0  0:00:02  0:00:02 --:--:--   558


In [7]:
import pandas as pd
# Header names in the metadata file carry padding spaces; strip them.
df = pd.read_csv('test_results.txt', sep=';')
df = df.rename(columns=str.strip)
df.head()


Unnamed: 0,ID,age,gender,Label,mmse
0,S160,63,1,0,28
1,S161,55,1,0,29
2,S162,67,1,1,24
3,S163,71,0,0,30
4,S164,73,1,1,21


In [39]:
import os
import torch
import torchaudio
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt

# ----- Enhanced Inference Function -----
def run_enhanced_inference_with_label(model, chunks, clinical_features, device, threshold=0.5):
    """
    Score one recording and threshold the score into a binary label.

    This deliberately does NOT reuse the name ``run_enhanced_inference``: that
    name is defined in the inference-pipeline section to return a bare float and
    is used by ``enhanced_inference_pipeline``. Redefining it here with a tuple
    return would silently break that pipeline after this cell runs.

    Args:
        model: module called as ``model(chunks, clinical_batch)``.
        chunks: stacked chunk tensor, first dim = number of chunks.
        clinical_features: 1-D clinical feature tensor for the recording.
        device: torch device to run on.
        threshold: probability at or above which the label is 1.

    Returns:
        (probability, pred_label) — float in [0, 1] and int 0/1.
    """
    model.eval()
    chunks = chunks.to(device)
    num_chunks = len(chunks)
    # the same clinical vector is paired with every chunk
    clinical_features_repeated = clinical_features.unsqueeze(0).repeat(num_chunks, 1).to(device)
    with torch.no_grad():
        outputs = model(chunks, clinical_features_repeated)
    # average per-chunk outputs, then squash to a probability
    probability = torch.sigmoid(outputs.mean(dim=0)).item()
    pred_label = 1 if probability >= threshold else 0
    return probability, pred_label

# ----- Ground Truth Loading -----
def load_ground_truth(gt_path):
    """Read the ';'-separated test metadata file and strip padded column names."""
    df = pd.read_csv(gt_path, sep=';', engine='python')
    df.columns = [col.strip() for col in df.columns]
    return df

# ----- Pipeline for a Single Audio File -----
def process_file(model, file_path, device, config):
    """
    Run load -> fbank + clinical features -> chunk -> score for one file.

    Returns:
        (probability, pred_label) as produced by ``run_enhanced_inference_with_label``.
    """
    waveform = load_audio(file_path, target_sr=config['sr'])
    fbank = extract_fbank(waveform, config)
    clinical_features = extract_clinical_features(waveform)
    chunks = chunk_fbank(fbank, config)
    return run_enhanced_inference_with_label(model, chunks, clinical_features, device)

# ----- Main Pipeline -----
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load enhanced model with clinical features.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # trusted checkpoints.
    model = EnhancedDementiaCNNBiLSTM(use_clinical_features=True, clinical_feature_dim=18).to(device)
    checkpoint = torch.load('mybest_model.pth', map_location=device, weights_only=False)
    model.load_state_dict(checkpoint)
    model.eval()

    gt_file = 'test_results.txt'
    gt_df = load_ground_truth(gt_file)
    # NOTE(review): absolute local path — adjust for your machine.
    test_audio_dir = "D:/2025/ADReSS-2020/ADReSS-IS2020-test/ADReSS-IS2020-data/test/Full_wave_enhanced_audio"
    predictions_csv = "bilstm-predictions.csv"
    result_columns = ['ID', 'GroundTruth', 'PredictedLabel', 'PredictedProbability']
    results = []

    for _, row in gt_df.iterrows():
        file_id = str(row['ID']).strip()
        gt_label = int(row['Label'])
        audio_file = os.path.join(test_audio_dir, f"{file_id}.wav")
        if not os.path.exists(audio_file):
            print(f"File {audio_file} not found, skipping.")
            continue
        prob, pred_label = process_file(model, audio_file, device, config)
        results.append({
            'ID': file_id,
            'GroundTruth': gt_label,
            'PredictedLabel': pred_label,
            'PredictedProbability': prob
        })

    # explicit columns keep the CSV schema stable even when nothing was processed
    results_df = pd.DataFrame(results, columns=result_columns)
    results_df.to_csv(predictions_csv, index=False)
    print(f"Predictions saved to {predictions_csv}")

    if results_df.empty:
        print("No test files were processed; accuracy is undefined.")
    else:
        accuracy = (results_df['GroundTruth'] == results_df['PredictedLabel']).mean()
        print(f"Overall accuracy: {accuracy*100:.2f}%")

Predictions saved to predictions.csv
Overall accuracy: 77.08%


### **Andy Model (TFLite MFCC classifier baseline)**

In [15]:
import os
import math
import librosa
import numpy as np
import tensorflow as tf
import logging
import pandas as pd

# module-level logger for the TFLite baseline helpers below
logger = logging.getLogger(name=__name__)

def extract_mfcc_from_file(audio_file_path, num_mfcc=13, n_fft=2048, hop_length=512, num_segments=10):
    """
    Extract per-segment MFCC matrices from one audio file.

    The file is loaded at 22050 Hz and forced to exactly 30 s: shorter audio is
    zero-padded at the end and longer audio is truncated, so only the first 30 s
    of a long recording are used. That 30 s window is split into
    ``num_segments`` equal, non-overlapping segments and MFCCs are computed
    for each one.

    :param audio_file_path: Path to the audio file.
    :param num_mfcc: Number of MFCC coefficients.
    :param n_fft: FFT window size.
    :param hop_length: Hop length for the FFT.
    :param num_segments: Number of segments to divide the 30 s window into.
    :return: List of arrays shaped (frames, num_mfcc). A segment is kept only
             if it has exactly ceil(samples_per_segment / hop_length) frames.
    """
    target_sr = 22050
    track_seconds = 30
    track_len = target_sr * track_seconds

    audio, sr = librosa.load(audio_file_path, sr=target_sr)
    if len(audio) < track_len:
        audio = np.pad(audio, (0, track_len - len(audio)), mode='constant')
    audio = audio[:track_len]

    segment_len = int(track_len / num_segments)
    expected_frames = math.ceil(segment_len / hop_length)

    segment_mfccs = (
        librosa.feature.mfcc(
            y=audio[i * segment_len:(i + 1) * segment_len],
            sr=sr,
            n_mfcc=num_mfcc,
            n_fft=n_fft,
            hop_length=hop_length,
        ).T  # -> (time, num_mfcc)
        for i in range(num_segments)
    )
    return [m for m in segment_mfccs if len(m) == expected_frames]

def predict(interpreter, X, input_details, output_details):
    """
    Run one MFCC segment through a TFLite interpreter.

    ``X`` gets a leading batch axis and is cast to float32 before being set as
    the first input tensor; the first row of the first output tensor is returned.
    """
    batch = X[None, ...].astype(np.float32)
    in_idx = input_details[0]['index']
    out_idx = output_details[0]['index']
    interpreter.set_tensor(in_idx, batch)
    interpreter.invoke()
    return interpreter.get_tensor(out_idx)[0]

def inference_on_file(audio_file_path, model_path, num_segments=10):
    """
    Classify one audio file with the TFLite model by averaging segment predictions.

    Each MFCC segment from ``extract_mfcc_from_file`` is scored separately, the
    output vectors are averaged, and the arg-max class is reported.

    NOTE(review): the index->label mapping {0: "dementia", 1: "control"} must
    match how the model was trained — confirm against the training code.
    NOTE(review): tf.lite.Interpreter is deprecated in recent TensorFlow
    releases (LiteRT / ai_edge_litert is the replacement).

    :return: (predicted_class, predicted_probability_percent), or None if no
             MFCC segments could be extracted.
    """
    segments = extract_mfcc_from_file(audio_file_path, num_segments=num_segments)
    if not segments:
        logger.error("No MFCC data extracted from the audio file: {}".format(audio_file_path))
        return None

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    in_details, out_details = interpreter.get_input_details(), interpreter.get_output_details()

    mean_probs = np.mean(
        [predict(interpreter, seg, in_details, out_details) for seg in segments],
        axis=0,
    )
    best = int(np.argmax(mean_probs))
    label = {0: "dementia", 1: "control"}.get(best, "Unknown")
    return label, mean_probs[best] * 100

def evaluate_folder(test_folder, model_path):
    """
    Run TFLite inference on every ``.wav`` file in a folder and save the results.

    Files are processed in sorted filename order (``os.listdir`` order is
    arbitrary) so the output CSV is reproducible. Files for which
    ``inference_on_file`` returns None are skipped. Results are written to
    ``andy-predictions.csv`` in the current working directory.

    :param test_folder: Folder containing the test audio files.
    :param model_path: Path to the TFLite model.
    :return: DataFrame with columns filename, predicted_class,
             predicted_probability. The columns are present even when no file
             was processed, so downstream merges on them do not fail.
    """
    columns = ['filename', 'predicted_class', 'predicted_probability']
    results = []

    for file in sorted(os.listdir(test_folder)):
        if not file.endswith('.wav'):
            continue
        file_path = os.path.join(test_folder, file)
        prediction = inference_on_file(file_path, model_path, num_segments=10)
        if prediction is None:
            continue
        predicted_class, predicted_probability = prediction
        results.append({
            'filename': file,
            'predicted_class': predicted_class,
            'predicted_probability': round(predicted_probability, 2)
        })

    df_results = pd.DataFrame(results, columns=columns)
    df_results.to_csv("andy-predictions.csv", index=False)
    return df_results

def main():
    """
    Evaluate the TFLite baseline on the test folder and print its accuracy.

    Predictions are written to ``andy-predictions.csv`` by ``evaluate_folder``,
    read back, joined on recording ID with the ground-truth labels in
    ``test_results.txt`` (0/1 mapped to 'control'/'dementia'), and accuracy is
    computed over the IDs present in both.
    """
    test_folder = "D:/2025/ADReSS-2020/ADReSS-IS2020-test/ADReSS-IS2020-data/test/Full_wave_enhanced_audio"  # Update with your folder containing test files
    model_path = "models/model.tflite"

    evaluate_folder(test_folder, model_path)  # writes andy-predictions.csv

    labels = pd.read_csv('test_results.txt', sep=';')
    labels.columns = labels.columns.str.strip()
    labels = labels.assign(
        Label=labels['Label'].replace({0: 'control', 1: 'dementia'}),
        ID=labels['ID'].str.strip(),
    )

    preds = pd.read_csv('andy-predictions.csv').rename(columns={'filename': 'ID'})
    preds['ID'] = preds['ID'].str.replace('.wav', '', regex=False).str.strip()

    merged = labels.merge(preds, on='ID', how='inner')
    accuracy = merged["Label"].eq(merged["predicted_class"]).mean() * 100
    print("=" * 60)
    print("Model Accuracy: {:.2f}%".format(accuracy))


# In a notebook kernel __name__ is "__main__", so this runs the evaluation.
if "__main__" == __name__:
    main()


    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


Model Accuracy: 62.50%


### **Error Analysis**


In [50]:
# Class balance of the test set. The raw header names are space-padded
# (e.g. 'Label '), so strip them and use 'Label' like the other cells.
df_test = pd.read_csv('test_results.txt', sep=';')
df_test.columns = df_test.columns.str.strip()
df_test['Label'].value_counts()

Label 
0    24
1    24
Name: count, dtype: int64

#### BiLSTM Model

In [53]:
# Confusion matrix and per-class metrics for the CNN-BiLSTM predictions.
df_pred = pd.read_csv("bilstm-predictions.csv")
y_true, y_pred = df_pred['GroundTruth'], df_pred['PredictedLabel']

print("=" * 60)
print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred))
print("\nClassification Report:")
print(classification_report(y_true, y_pred))



Confusion Matrix:
[[19  5]
 [ 6 18]]

Classification Report:
              precision    recall  f1-score   support

           0       0.76      0.79      0.78        24
           1       0.78      0.75      0.77        24

    accuracy                           0.77        48
   macro avg       0.77      0.77      0.77        48
weighted avg       0.77      0.77      0.77        48



#### Andy Model

In [54]:
print("="*60)
# Load and preprocess test results. Use the same metadata file as the rest of
# the notebook ('test_results.txt', downloaded by the curl cell) rather than a
# separate absolute path that may point at a different copy.
df_test = pd.read_csv('test_results.txt', sep=';')
df_test.columns = df_test.columns.str.strip()
df_test['Label'] = df_test['Label'].replace({0: 'control', 1: 'dementia'})
df_test['ID'] = df_test['ID'].str.strip()

# Load and preprocess predictions (filename -> ID without the .wav suffix)
df_predictions = pd.read_csv('andy-predictions.csv').rename(columns={'filename': 'ID'})
df_predictions['ID'] = df_predictions['ID'].str.replace('.wav', '', regex=False).str.strip()

# Merge on the ID column; only IDs present in both frames are evaluated
df_merged = pd.merge(df_test, df_predictions, on='ID', how='inner')

# Compute the confusion matrix and classification report
y_true = df_merged['Label']
y_pred = df_merged['predicted_class']

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred))

print("\nClassification Report:")
print(classification_report(y_true, y_pred))



Confusion Matrix:
[[23  1]
 [17  7]]

Classification Report:
              precision    recall  f1-score   support

     control       0.57      0.96      0.72        24
    dementia       0.88      0.29      0.44        24

    accuracy                           0.62        48
   macro avg       0.72      0.62      0.58        48
weighted avg       0.72      0.62      0.58        48

