In [None]:
import os
import pickle
import random
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
    roc_curve, auc, accuracy_score, precision_score, recall_score,
    f1_score, balanced_accuracy_score, confusion_matrix
)
from transformers import ASTFeatureExtractor, ASTForAudioClassification
from collections import defaultdict
from google.colab import drive
import seaborn as sns
import pandas as pd

# --- Environment: mount Google Drive and work from the thesis folder ---
drive.mount('/content/drive')

tesi_path = '/content/drive/My Drive/TESI'
os.chdir(tesi_path)
print("Current working directory:", os.getcwd())

# --- Dataset: list of per-recording dicts stored as a pickle ---
# NOTE(review): pickle can execute arbitrary code on load; only open files you trust.
file_path = '/content/drive/MyDrive/TESI/newdata_updated.pkl'
with open(file_path, 'rb') as pickle_file:
    newdata = pickle.load(pickle_file)

print("Pickle file loaded successfully.")
print(f"Total samples in dataset: {len(newdata)}")

# Sampling rate used to convert sample counts into seconds.
sampling_rate = 44100

# Duration of every recording, in seconds.
audio_lengths_sec = []
for entry in newdata:
    audio_lengths_sec.append(len(entry['audio']) / sampling_rate)

# --- Histogram of raw durations ---
hist_style = dict(bins=50, color='blue', alpha=0.7)
plt.figure(figsize=(10, 6))
plt.hist(audio_lengths_sec, **hist_style)
plt.title('Distribution of Audio Lengths (Before Processing)')
plt.xlabel('Audio Length (seconds)')
plt.ylabel('Count')
plt.show()

# --- Summary of the duration range ---
shortest_sec = min(audio_lengths_sec)
longest_sec = max(audio_lengths_sec)
print(f"Number of audios: {len(audio_lengths_sec)}")
print(f"Minimum audio length: {shortest_sec:.2f} seconds")
print(f"Maximum audio length: {longest_sec:.2f} seconds")

def print_updrs_distribution(data, label_key='label', updrs_keys=('updrs', 'UPDRS')):
    """
    Print the class balance and the per-UPDRS-level sample counts of a dataset.

    Args:
        data: Iterable of per-sample mappings (dict-like).
        label_key: Key holding the class label. A value equal to 0 is counted
            as a control and a value equal to 1 as a Parkinsonian; samples
            without the key, or with any other value, are not counted.
        updrs_keys: Candidate keys for the UPDRS score, tried in order. The
            first key that is present in a sample is used; if its value is
            None the sample is left out of the UPDRS tally. Any iterable of
            keys is accepted (the default is an immutable tuple so it cannot
            be modified by accident between calls).

    Returns:
        None. The summary is written to stdout.
    """
    updrs_counts = defaultdict(int)
    control_count, parkinsonian_count = 0, 0

    for item in data:
        # Class tally: only the explicit labels 0 and 1 are recognised.
        if label_key in item:
            label = item[label_key]
            if label == 0:
                control_count += 1
            elif label == 1:
                parkinsonian_count += 1

        # UPDRS tally: take the first spelling of the key that the sample has.
        present_key = next((key for key in updrs_keys if key in item), None)
        updrs_value = item[present_key] if present_key is not None else None
        if updrs_value is not None:
            updrs_counts[updrs_value] += 1

    print(f"Number of controls: {control_count}")
    print(f"Number of Parkinsonians: {parkinsonian_count}")
    print("UPDRS Distribution:")
    for updrs_value, count in sorted(updrs_counts.items()):
        print(f"  UPDRS {updrs_value}: {count}")




In [None]:
# Remove outliers based on audio length using IQR.
# The bounds are computed on raw sample counts (not seconds) because the filter
# below compares them against len(item['audio']).
audio_lengths = [len(item['audio']) for item in newdata]
q1 = np.percentile(audio_lengths, 25)
q3 = np.percentile(audio_lengths, 75)
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr

filtered_data = [item for item in newdata if lower_bound <= len(item['audio']) <= upper_bound]
filtered_audio_lengths_sec = [len(item['audio']) / sampling_rate for item in filtered_data]

plt.figure(figsize=(10, 6))
plt.hist(filtered_audio_lengths_sec, bins=50, color='green', alpha=0.7)
plt.title('Distribution of Audio Lengths (After Removing Outliers)')
plt.xlabel('Audio Length (seconds)')
plt.ylabel('Count')
plt.show()

print(f"Number of audios after removing outliers: {len(filtered_audio_lengths_sec)}")
print(f"Min audio length after filtering: {min(filtered_audio_lengths_sec):.2f} seconds")
print(f"Max audio length after filtering: {max(filtered_audio_lengths_sec):.2f} seconds")

# Length (in samples) of the shortest remaining clip; every clip is cropped to it below.
min_length = min(len(item['audio']) for item in filtered_data)
print(f"Shortest audio length: {min_length / sampling_rate:.2f} seconds ({min_length} samples)")

print("Distribution BEFORE outlier removal:")
print_updrs_distribution(newdata)
print("Distribution AFTER outlier removal:")
print_updrs_distribution(filtered_data)

# From here on the notebook works on the filtered dataset only.
newdata = filtered_data

# Centre-crop every clip to the shortest length so all inputs have the same size.
for item in newdata:
    start = (len(item['audio']) - min_length) // 2
    end = start + min_length
    item['audio'] = item['audio'][start:end]

# Peak-normalise each clip to the range [-1, 1].
for item in newdata:
    audio = np.array(item['audio'])
    peak = np.max(np.abs(audio))
    # An all-zero (silent) clip has peak 0; dividing would produce NaNs, so it is left as is.
    item['audio'] = audio / peak if peak > 0 else audio

print(f"Example audio shape: {newdata[0]['audio'].shape}")

# NOTE(review): durations above are computed with sampling_rate (44100 Hz) but the
# extractor is configured for 16000 Hz and the audio is never resampled in this
# notebook. If the recordings really are 44.1 kHz they should be resampled to
# 16 kHz before feature extraction -- confirm the true rate of the pickled audio.
feature_extractor = ASTFeatureExtractor.from_pretrained(
    'MIT/ast-finetuned-audioset-10-10-0.4593',
    sampling_rate=16000,
    return_attention_mask=False
)

class AudioDataset(Dataset):
    """
    Dataset that turns raw audio samples into feature-extractor inputs.

    Each element of ``data`` is a mapping with an ``'audio'`` entry, a
    ``'label'`` entry and optionally an ``'updrs'`` or ``'UPDRS'`` entry.

    Args:
        data: Indexable collection of sample mappings.
        extractor: Callable invoked as
            ``extractor(audio, sampling_rate=..., return_tensors="pt", padding=True)``
            whose result provides ``['input_values']`` as a tensor with a
            leading batch dimension of size 1.
        dropout_rate: Probability used by the dropout applied to the
            extracted features (input-level augmentation).
        augment: When True (default, the original behaviour) dropout is
            active. Pass False for validation/test datasets so that the
            features are returned unchanged and evaluation is deterministic.
        sampling_rate: Sampling rate forwarded to the extractor (default
            16000, the value that was previously hard-coded).
    """

    def __init__(self, data, extractor, dropout_rate=0.15, augment=True, sampling_rate=16000):
        self.data = data
        self.extractor = extractor
        self.sampling_rate = sampling_rate
        self.dropout = nn.Dropout(p=dropout_rate)
        # nn.Dropout is only a no-op in eval mode; select the mode explicitly.
        self.dropout.train(augment)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return ``(input_values, label, metadata)`` for sample ``idx``.

        ``metadata`` is ``{'updrs': value}`` where value is read from the
        ``'updrs'`` key, falling back to ``'UPDRS'`` and finally to -1 when
        neither key is present.
        """
        sample = self.data[idx]
        audio = sample['audio']
        label = sample['label']

        updrs_value = sample.get('updrs', sample.get('UPDRS', -1))
        metadata = {'updrs': updrs_value}

        inputs = self.extractor(
            audio, sampling_rate=self.sampling_rate, return_tensors="pt", padding=True
        )
        # Drop the extractor's batch dimension; the DataLoader adds its own.
        input_values = self.dropout(inputs['input_values'].squeeze(0))

        return input_values, label, metadata

def stratified_group_split(all_ids, grouped_by_id, label_key='label', n_splits=5):
    """
    Split subject IDs into ``n_splits`` folds, stratified by class label.

    A subject's class is read from the first sample of its group
    (``grouped_by_id[id][0][label_key]``); 0 means control and 1 means
    Parkinsonian. Subjects whose label is neither 0 nor 1 are not placed in
    any fold. Each class is shuffled with the global ``random`` module (seed
    it beforehand for reproducible folds) and then dealt round-robin, so
    every subject lands in exactly one fold and per-class fold sizes differ
    by at most one. The previous ``len // 5`` chunking dropped the remainder
    subjects entirely and produced empty folds for classes with fewer
    subjects than folds.

    Args:
        all_ids: Iterable of subject IDs.
        grouped_by_id: Mapping from subject ID to a non-empty list of samples.
        label_key: Key of the class label inside a sample.
        n_splits: Number of folds to create (default 5).

    Returns:
        A list of ``n_splits`` lists of subject IDs (controls first, then
        Parkinsonians, within each fold).

    Raises:
        ValueError: If ``n_splits`` is smaller than 1.
    """
    if n_splits < 1:
        raise ValueError("n_splits must be at least 1")

    controls = [id_ for id_ in all_ids if grouped_by_id[id_][0][label_key] == 0]
    parkinsons = [id_ for id_ in all_ids if grouped_by_id[id_][0][label_key] == 1]

    random.shuffle(controls)
    random.shuffle(parkinsons)

    # Strided slices deal the shuffled IDs out like cards: nothing is left over.
    return [controls[i::n_splits] + parkinsons[i::n_splits] for i in range(n_splits)]

def analyze_updrs(val_loader, model, device):
    """
    Collect per-sample model outputs grouped by UPDRS level.

    Runs ``model`` without gradients over every batch of ``val_loader`` (which
    yields ``(inputs, labels, metadata)``) and, for each sample whose UPDRS
    value is available and different from -1, stores the tuple
    ``(class-1 logit, class-1 softmax probability, true label)`` under that
    UPDRS value.

    Returns:
        Dict with keys 0..4 mapping to lists of those tuples.
    """
    per_level = {level: [] for level in range(5)}

    with torch.no_grad():
        for batch_inputs, batch_labels, batch_meta in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            batch_logits = model(batch_inputs).logits

            # Column 1 is the positive class, both as probability and raw logit.
            positive_probs = torch.softmax(batch_logits, dim=1)[:, 1].cpu().numpy()
            positive_logits = batch_logits[:, 1]

            updrs_batch = batch_meta.get('updrs', None)
            if updrs_batch is None:
                # No UPDRS information for this batch: nothing to record.
                continue
            updrs_is_tensor = isinstance(updrs_batch, torch.Tensor)

            for idx, prob in enumerate(positive_probs):
                logit = positive_logits[idx].item()
                if updrs_is_tensor:
                    level = updrs_batch[idx].item()
                else:
                    level = int(updrs_batch[idx])

                # -1 marks samples without a UPDRS score.
                if level != -1:
                    per_level[level].append((logit, prob, batch_labels[idx].item()))

    return per_level


In [None]:
def aggregate_updrs_results(fold_results):
    """
    Aggregate per-UPDRS-level results across all folds.

    Args:
        fold_results: Iterable of per-fold dicts, each with an
            ``'updrs_results'`` entry mapping a UPDRS level to a list of
            ``(logit, probability, true_label)`` tuples.

    Returns:
        ``(metrics, probabilities_by_updrs)`` where

        * ``metrics[level]`` holds ``total_count``, ``mean_probability``,
          ``mean_logit`` and ``percentage_classified_as_parkinsonian`` (share
          of samples at that level with probability >= 0.5, regardless of
          their true label). Levels without samples get zeros.
        * ``probabilities_by_updrs[level]`` is the flat list of probabilities
          collected for that level.

        Levels 0..4 are always present; any other level found in the input is
        added rather than raising ``KeyError``.
    """
    aggregated_results = {level: [] for level in range(5)}
    probabilities_by_updrs = {level: [] for level in range(5)}

    for fold in fold_results:
        for level, values in fold['updrs_results'].items():
            aggregated_results.setdefault(level, []).extend(values)
            probabilities_by_updrs.setdefault(level, []).extend(prob for _, prob, _ in values)

    metrics = {}
    for level, results in aggregated_results.items():
        if not results:
            metrics[level] = {
                'total_count': 0,
                'mean_probability': 0.0,
                'mean_logit': 0.0,
                'percentage_classified_as_parkinsonian': 0.0,
            }
            continue

        logits, probs, _true_labels = zip(*results)
        total_count = len(results)
        # "Classified as Parkinsonian" is a statement about the prediction only.
        # The earlier version also required label == 1, which turned this into a
        # true-positive count and always reported 0% for levels made of controls.
        classified_as_parkinsonian = sum(1 for prob in probs if prob >= 0.5)

        metrics[level] = {
            'total_count': total_count,
            'mean_probability': np.mean(probs),
            'mean_logit': np.mean(logits),
            'percentage_classified_as_parkinsonian': (classified_as_parkinsonian / total_count) * 100,
        }

    return metrics, probabilities_by_updrs


def compute_confusion_matrix_metrics(labels, preds, fold_num):
    """
    Compute and print the binary confusion matrix with derived metrics.

    Args:
        labels: True class labels (0 = negative, 1 = positive).
        preds: Predicted class labels, same length as ``labels``.
        fold_num: Fold identifier, used only in the printed header.

    Returns:
        Tuple ``(sensitivity, specificity, balanced_accuracy)``. Sensitivity
        or specificity is 0 when the corresponding class has no samples.
    """
    # Pin the label set so the matrix is always 2x2. Without it, a fold in which
    # only one class appears yields a 1x1 matrix and the 4-way unpack fails.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    balanced_acc = (sensitivity + specificity) / 2

    print(f"\nConfusion Matrix for Fold {fold_num}:")
    print(cm)
    print(f"Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}, Balanced Accuracy: {balanced_acc:.4f}")

    return sensitivity, specificity, balanced_acc


In [None]:

# ---------------------------------------------------------------------------
# 5-fold, subject-wise cross-validation of the AST classifier.
# ---------------------------------------------------------------------------
num_epochs = 50
early_stopping_patience = 5
learning_rate = 1e-3
gamma = 0.995
seed = 42

# Seed every RNG involved so fold assignment, shuffling and dropout are repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

fold_results = []

criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Group data by subject ID so that a subject never appears in both train and validation.
grouped_by_id = defaultdict(list)
for item in newdata:
    grouped_by_id[item['id']].append(item)

all_ids = list(grouped_by_id.keys())
folds = stratified_group_split(all_ids, grouped_by_id)

for fold in range(5):
    print(f"Processing Fold {fold+1}/5...")

    val_ids = folds[fold]
    val_id_set = set(val_ids)  # O(1) membership test instead of scanning the list
    train_ids = [id_ for id_ in all_ids if id_ not in val_id_set]

    train_samples = [item for id_s in train_ids for item in grouped_by_id[id_s]]
    val_samples = [item for id_s in val_ids for item in grouped_by_id[id_s]]

    print(f"Training set size: {len(train_samples)}, Validation set size: {len(val_samples)}")

    train_dataset = AudioDataset(train_samples, feature_extractor)
    val_dataset = AudioDataset(val_samples, feature_extractor)
    # The dataset applies dropout to its features as augmentation; switch it off
    # for validation so the reported metrics are deterministic and unperturbed.
    val_dataset.dropout.eval()

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

    model = ASTForAudioClassification.from_pretrained(
        'MIT/ast-finetuned-audioset-10-10-0.4593',
        num_labels=2,
        ignore_mismatched_sizes=True
    )

    # Freeze all layers except last 2 encoder layers and classifier
    for name, param in model.named_parameters():
        param.requires_grad = (
            'encoder.layer.11' in name or 'encoder.layer.10' in name or 'classifier' in name
        )

    # Build the optimizer AFTER choosing the trainable parameters and hand it all
    # of them. It used to receive only model.classifier.parameters(), so the two
    # unfrozen encoder layers accumulated gradients but were never updated.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=learning_rate, weight_decay=0.01)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    model.to(device)

    best_val_loss = float('inf')
    patience_counter = 0

    epoch_train_accuracies, epoch_val_accuracies = [], []
    epoch_train_precisions, epoch_val_precisions = [], []
    epoch_train_recalls, epoch_val_recalls = [], []
    epoch_train_f1s, epoch_val_f1s = [], []
    epoch_train_losses, epoch_val_losses = [], []

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs} for Fold {fold+1}")

        # Training
        model.train()
        train_loss, correct_train = 0.0, 0
        train_preds, train_true = [], []

        for inputs, labels, metadata in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch loss is a true per-sample mean.
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(outputs, dim=1)
            correct_train += torch.sum(preds == labels).item()
            train_preds.extend(preds.cpu().numpy())
            train_true.extend(labels.cpu().numpy())

        train_loss /= len(train_loader.dataset)
        train_accuracy = correct_train / len(train_loader.dataset)
        train_precision = precision_score(train_true, train_preds, zero_division=0)
        train_recall = recall_score(train_true, train_preds, zero_division=0)
        train_f1 = f1_score(train_true, train_preds, zero_division=0)

        # Validation
        model.eval()
        val_loss, correct_val = 0.0, 0
        val_preds, val_true = [], []

        with torch.no_grad():
            for inputs, labels, metadata in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs).logits
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                preds = torch.argmax(outputs, dim=1)
                correct_val += torch.sum(preds == labels).item()
                val_preds.extend(preds.cpu().numpy())
                val_true.extend(labels.cpu().numpy())

        val_loss /= len(val_loader.dataset)
        val_accuracy = correct_val / len(val_loader.dataset)
        val_precision = precision_score(val_true, val_preds, zero_division=0)
        val_recall = recall_score(val_true, val_preds, zero_division=0)
        val_f1 = f1_score(val_true, val_preds, zero_division=0)

        val_true_arr, val_preds_arr = np.array(val_true), np.array(val_preds)
        true_negatives = np.sum((val_true_arr == 0) & (val_preds_arr == 0))
        false_positives = np.sum((val_true_arr == 0) & (val_preds_arr == 1))
        negatives = true_negatives + false_positives
        # A fold without controls has no negatives; report 0 instead of dividing by zero.
        val_specificity = true_negatives / negatives if negatives > 0 else 0.0
        val_sensitivity = val_recall

        # Store epoch metrics
        epoch_train_accuracies.append(train_accuracy)
        epoch_val_accuracies.append(val_accuracy)
        epoch_train_precisions.append(train_precision)
        epoch_val_precisions.append(val_precision)
        epoch_train_recalls.append(train_recall)
        epoch_val_recalls.append(val_recall)
        epoch_train_f1s.append(train_f1)
        epoch_val_f1s.append(val_f1)
        epoch_train_losses.append(train_loss)
        epoch_val_losses.append(val_loss)

        print(f"Fold {fold+1}, Epoch {epoch+1} - Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f'best_model_fold_{fold+1}.pt')
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        scheduler.step()

    # UPDRS breakdown for the final-epoch model. Only the last epoch's result was
    # ever stored, so it is computed once here instead of after every epoch.
    # NOTE(review): all stored fold metrics describe the LAST epoch, not the best
    # checkpoint saved above; reload best_model_fold_{n}.pt first if the best
    # model should be reported instead.
    model.eval()
    updrs_results = analyze_updrs(val_loader, model, device)

    # Store fold results
    fold_results.append({
        'train_accuracy': epoch_train_accuracies,
        'val_accuracy': epoch_val_accuracies,
        'train_precision': epoch_train_precisions,
        'val_precision': epoch_val_precisions,
        'train_recall': epoch_train_recalls,
        'val_recall': epoch_val_recalls,
        'train_f1': epoch_train_f1s,
        'val_f1': epoch_val_f1s,
        'train_loss': epoch_train_losses,
        'val_loss': epoch_val_losses,
        'val_sensitivity': val_sensitivity,
        'val_specificity': val_specificity,
        'val_preds': val_preds,
        'val_true': val_true,
        'updrs_results': updrs_results
    })

# Compute final averages (last epoch of each fold)
average_train_accuracy = np.mean([r['train_accuracy'][-1] for r in fold_results])
average_val_accuracy = np.mean([r['val_accuracy'][-1] for r in fold_results])
average_train_precision = np.mean([r['train_precision'][-1] for r in fold_results])
average_val_precision = np.mean([r['val_precision'][-1] for r in fold_results])
average_train_recall = np.mean([r['train_recall'][-1] for r in fold_results])
average_val_recall = np.mean([r['val_recall'][-1] for r in fold_results])
average_train_f1 = np.mean([r['train_f1'][-1] for r in fold_results])
average_val_f1 = np.mean([r['val_f1'][-1] for r in fold_results])
average_val_sensitivity = np.mean([r['val_sensitivity'] for r in fold_results])
average_val_specificity = np.mean([r['val_specificity'] for r in fold_results])

print("\n===== Average Metrics Across Folds (Last Epoch Only) =====")
print(f"Training Accuracy: {average_train_accuracy:.4f}")
print(f"Validation Accuracy: {average_val_accuracy:.4f}")
print(f"Training Precision: {average_train_precision:.4f}")
print(f"Validation Precision: {average_val_precision:.4f}")
print(f"Training Recall: {average_train_recall:.4f}")
print(f"Validation Recall: {average_val_recall:.4f}")
print(f"Training F1-Score: {average_train_f1:.4f}")
print(f"Validation F1-Score: {average_val_f1:.4f}")
print(f"Validation Sensitivity: {average_val_sensitivity:.4f}")
print(f"Validation Specificity: {average_val_specificity:.4f}")




In [None]:
# Merge the per-fold UPDRS results into one set of per-level statistics.
aggregated_updrs_metrics, probabilities_by_updrs = aggregate_updrs_results(fold_results)

# Text report: one short section per UPDRS level, closed by a separator line.
separator = "-" * 40
print("\n===== Aggregated UPDRS Metrics Across Folds =====")
for level in aggregated_updrs_metrics:
    level_stats = aggregated_updrs_metrics[level]
    print(f"UPDRS Level {level}:")
    print(f"  Total Count: {level_stats['total_count']}")
    print(f"  Mean Probability (Parkinsonian): {level_stats['mean_probability']:.4f}")
    print(f"  Percentage Classified as Parkinsonian: {level_stats['percentage_classified_as_parkinsonian']:.2f}%")
    print(separator)


In [None]:


def _plot_updrs_bar(levels, values, palette, ylabel, title):
    """Draw one bar chart of a per-UPDRS-level quantity in its own figure."""
    plt.figure(figsize=(10, 6))
    sns.barplot(x=levels, y=values, palette=palette)
    plt.xlabel('UPDRS Level')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


def plot_updrs_metrics(aggregated_updrs_metrics, probabilities_by_updrs):
    """
    Plot the aggregated UPDRS statistics as four separate figures.

    Figures, in order: mean probability per level (bar), distribution of
    probabilities per level (box), percentage classified as Parkinsonian per
    level (bar) and sample count per level (bar).

    Args:
        aggregated_updrs_metrics: Mapping from UPDRS level to a dict with the
            keys ``'total_count'``, ``'mean_probability'`` and
            ``'percentage_classified_as_parkinsonian'``.
        probabilities_by_updrs: Mapping from UPDRS level to the list of
            individual probabilities recorded for that level.

    Returns:
        None. The figures are shown with ``plt.show()``.
    """
    levels = list(aggregated_updrs_metrics.keys())
    total_counts = [aggregated_updrs_metrics[level]['total_count'] for level in levels]
    mean_probs = [aggregated_updrs_metrics[level]['mean_probability'] for level in levels]
    percentages_classified = [
        aggregated_updrs_metrics[level]['percentage_classified_as_parkinsonian'] for level in levels
    ]

    _plot_updrs_bar(
        levels, mean_probs, 'Blues_d',
        ylabel='Mean Probability',
        title='Mean Probability of Being Classified as Parkinsonian by UPDRS Level',
    )

    # Box plot needs one row per individual prediction (long format).
    plt.figure(figsize=(10, 6))
    data = [(level, prob) for level, probs in probabilities_by_updrs.items() for prob in probs]
    df = pd.DataFrame(data, columns=['UPDRS Level', 'Probability'])
    sns.boxplot(x='UPDRS Level', y='Probability', data=df, palette='Pastel1')
    plt.xlabel('UPDRS Level')
    plt.ylabel('Probability')
    plt.title('Distribution of Probabilities by UPDRS Level')
    plt.show()

    _plot_updrs_bar(
        levels, percentages_classified, 'Oranges_d',
        ylabel='Percentage Classified as Parkinsonian',
        title='Percentage Classified as Parkinsonian by UPDRS Level',
    )

    _plot_updrs_bar(
        levels, total_counts, 'Greens_d',
        ylabel='Total Count',
        title='Total Data Count by UPDRS Level',
    )

# Draw the UPDRS summary plots from the aggregated cross-fold metrics.
plot_updrs_metrics(aggregated_updrs_metrics, probabilities_by_updrs)
