<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/DetectAlzFromSpeechGBA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
from google.colab import drive

# Mount Google Drive
try:
    drive.mount('/content/drive', force_remount=False)
except:
    drive.mount('/content/drive')

# Define the base directory
base_directory = '/content/drive/MyDrive/Speech'

# Check if the directory exists
if os.path.exists(base_directory):
    # Separate lists for files and directories
    file_paths = []
    directory_paths = []

    for root, directories, files in os.walk(base_directory):
        # Add directory paths
        for directory in directories:
            dir_path = os.path.join(root, directory)
            directory_paths.append(dir_path)

        # Add file paths
        for file in files:
            file_path = os.path.join(root, file)
            file_paths.append(file_path)

    print("Directories:")
    for directory in directory_paths:
        print(f"  {directory}")

    print(f"\nFiles:")
    for file in file_paths:
        print(f"  {file}")

    print(f"\nSummary:")
    print(f"  Total directories: {len(directory_paths)}")
    print(f"  Total files: {len(file_paths)}")
    print(f"  Total paths: {len(file_paths) + len(directory_paths)}")

else:
    print(f"Directory {base_directory} does not exist!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Directories:
  /content/drive/MyDrive/Speech/extracted_diagnosis_train
  /content/drive/MyDrive/Speech/extracted_progression_train
  /content/drive/MyDrive/Speech/extracted_progression_test
  /content/drive/MyDrive/Speech/linguistic_features
  /content/drive/MyDrive/Speech/lightweight_features
  /content/drive/MyDrive/Speech/transcripts
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio
  /content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/segmentation/cn
  /content/driv

In [12]:
import os
import pandas as pd
import numpy as np
import librosa
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import pickle
import json
from sklearn.preprocessing import StandardScaler, LabelEncoder
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

def extract_labels_from_directory_structure():
    base_paths = {
        'diagnosis_train': '/content/drive/MyDrive/Speech/extracted_diagnosis_train/ADReSSo21/diagnosis/train/audio',
        'progression_train': '/content/drive/MyDrive/Speech/extracted_progression_train/ADReSSo21/progression/train/audio',
        'progression_test': '/content/drive/MyDrive/Speech/extracted_progression_test/ADReSSo21/progression/test-dist/audio'
    }

    labels = {}

    for dataset_type, base_path in base_paths.items():
        if dataset_type == 'progression_test':
            files = os.listdir(base_path)
            wav_files = [f for f in files if f.endswith('.wav')]
            for wav_file in wav_files:
                file_id = wav_file.replace('.wav', '')
                labels[file_id] = {
                    'dataset': 'progression_test',
                    'label': 'unknown',
                    'file_path': os.path.join(base_path, wav_file)
                }
        else:
            subdirs = os.listdir(base_path)
            for subdir in subdirs:
                subdir_path = os.path.join(base_path, subdir)
                if os.path.isdir(subdir_path):
                    wav_files = os.listdir(subdir_path)
                    for wav_file in wav_files:
                        if wav_file.endswith('.wav'):
                            file_id = wav_file.replace('.wav', '')
                            if dataset_type == 'diagnosis_train':
                                label = 'ad' if subdir == 'ad' else 'cn'
                            else:
                                label = 'decline' if subdir == 'decline' else 'no_decline'

                            labels[file_id] = {
                                'dataset': dataset_type,
                                'label': label,
                                'file_path': os.path.join(subdir_path, wav_file)
                            }

    return labels

def extract_mel_spectrogram(audio_path, n_mels=128, hop_length=512, n_fft=2048):
    try:
        y, sr = librosa.load(audio_path, sr=22050)
        mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels,
                                                hop_length=hop_length, n_fft=n_fft)
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        return {
            'mel_spectrogram': mel_spec,
            'log_mel_spectrogram': log_mel_spec,
            'mel_mean': np.mean(mel_spec, axis=1),
            'mel_std': np.std(mel_spec, axis=1),
            'log_mel_mean': np.mean(log_mel_spec, axis=1),
            'log_mel_std': np.std(log_mel_spec, axis=1)
        }
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None

def extract_wav2vec2_features(audio_path, model_name="facebook/wav2vec2-base-960h"):
    try:
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2Model.from_pretrained(model_name)

        y, sr = librosa.load(audio_path, sr=16000)

        inputs = processor(y, sampling_rate=16000, return_tensors="pt", padding=True)

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_states = outputs.last_hidden_state

        features = last_hidden_states.squeeze().numpy()

        return {
            'wav2vec2_features': features,
            'wav2vec2_mean': np.mean(features, axis=0),
            'wav2vec2_std': np.std(features, axis=0),
            'wav2vec2_max': np.max(features, axis=0),
            'wav2vec2_min': np.min(features, axis=0)
        }
    except Exception as e:
        print(f"Error processing wav2vec2 for {audio_path}: {e}")
        return None

def extract_acoustic_features(audio_path):
    try:
        y, sr = librosa.load(audio_path, sr=22050)

        features = {}

        # Fix: Use correct librosa function names
        features['mfcc'] = np.mean(librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13), axis=1)
        features['chroma'] = np.mean(librosa.feature.chroma_stft(y=y, sr=sr), axis=1)  # Fixed function name
        features['spectral_contrast'] = np.mean(librosa.feature.spectral_contrast(y=y, sr=sr), axis=1)
        features['tonnetz'] = np.mean(librosa.feature.tonnetz(y=y, sr=sr), axis=1)
        features['zero_crossing_rate'] = np.mean(librosa.feature.zero_crossing_rate(y))
        features['spectral_centroid'] = np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))
        features['spectral_rolloff'] = np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr))
        features['rms_energy'] = np.mean(librosa.feature.rms(y=y))

        tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
        features['tempo'] = tempo

        return features
    except Exception as e:
        print(f"Error extracting acoustic features for {audio_path}: {e}")
        return None

def load_linguistic_features():
    ling_features = {}

    ling_path = '/content/drive/MyDrive/Speech/linguistic_features'
    if os.path.exists(ling_path):
        try:
            with open(os.path.join(ling_path, 'linguistic_features.pkl'), 'rb') as f:
                ling_features = pickle.load(f)
        except:
            try:
                with open(os.path.join(ling_path, 'linguistic_features.json'), 'r') as f:
                    ling_features = json.load(f)
            except Exception as e:
                print(f"Error loading linguistic features: {e}")

    return ling_features
def load_transcripts():
    transcripts = {}

    transcript_files = [
        '/content/drive/MyDrive/Speech/transcripts/all_categories_results.json',
        '/content/drive/MyDrive/Speech/transcripts/transcription_results.json'
    ]

    for transcript_file in transcript_files:
        if os.path.exists(transcript_file):
            try:
                with open(transcript_file, 'r') as f:
                    data = json.load(f)
                    if isinstance(data, dict):
                        transcripts.update(data)
                    elif isinstance(data, list):
                        # Handle list format if that's what the file contains
                        for item in data:
                            if isinstance(item, dict) and len(item) >= 2:
                                # Extract key-value pairs from list items
                                keys = list(item.keys())
                                transcripts[keys[0]] = item[keys[1]]
            except json.JSONDecodeError as e:
                print(f"JSON decode error in {transcript_file}: {e}")
            except Exception as e:
                print(f"Error loading {transcript_file}: {e}")

    # Continue with individual transcript loading...

    individual_transcript_path = '/content/drive/MyDrive/Speech/transcripts/individual_transcripts'
    if os.path.exists(individual_transcript_path):
        txt_files = [f for f in os.listdir(individual_transcript_path) if f.endswith('.txt')]
        for txt_file in txt_files:
            try:
                file_id = txt_file.replace('.wav.txt', '')
                with open(os.path.join(individual_transcript_path, txt_file), 'r') as f:
                    content = f.read().strip()
                    transcripts[file_id] = content
            except Exception as e:
                print(f"Error loading {txt_file}: {e}")

    return transcripts

def create_comprehensive_dataset():
    print("Extracting labels from directory structure...")
    labels_dict = extract_labels_from_directory_structure()

    print("Loading existing features...")
    existing_features = load_existing_features()

    print("Loading linguistic features...")
    linguistic_features = load_linguistic_features()

    print("Loading transcripts...")
    transcripts = load_transcripts()

    dataset = []

    for file_id, label_info in labels_dict.items():
        row = {
            'file_id': file_id,
            'dataset': label_info['dataset'],
            'label': label_info['label'],
            'file_path': label_info['file_path']
        }

        if file_id in transcripts:
            row['transcript'] = transcripts[file_id]

        for category, features in existing_features.items():
            if file_id in features:
                feature_data = features[file_id]
                if isinstance(feature_data, dict):
                    for key, value in feature_data.items():
                        row[f'{category}_{key}'] = value
                else:
                    row[f'{category}_features'] = feature_data

        if file_id in linguistic_features:
            ling_data = linguistic_features[file_id]
            if isinstance(ling_data, dict):
                for key, value in ling_data.items():
                    row[f'linguistic_{key}'] = value

        dataset.append(row)

    df = pd.DataFrame(dataset)

    print(f"Dataset created with {len(df)} samples")
    print(f"Columns: {df.columns.tolist()}")
    print(f"Label distribution:")
    print(df['label'].value_counts())

    return df, labels_dict

def extract_audio_features_batch(labels_dict, sample_limit=None):
    print("Extracting mel-spectrograms and acoustic features...")

    audio_features = {}
    processed_count = 0

    for file_id, label_info in labels_dict.items():
        if sample_limit and processed_count >= sample_limit:
            break

        audio_path = label_info['file_path']

        if os.path.exists(audio_path):
            print(f"Processing {file_id}...")

            mel_features = extract_mel_spectrogram(audio_path)
            acoustic_features = extract_acoustic_features(audio_path)

            if mel_features and acoustic_features:
                combined_features = {**mel_features, **acoustic_features}
                audio_features[file_id] = combined_features
                processed_count += 1
        else:
            print(f"File not found: {audio_path}")

    return audio_features

def create_training_datasets():
    df, labels_dict = create_comprehensive_dataset()

    diagnosis_train = df[df['dataset'] == 'diagnosis_train'].copy()
    progression_train = df[df['dataset'] == 'progression_train'].copy()
    progression_test = df[df['dataset'] == 'progression_test'].copy()

    diagnosis_le = LabelEncoder()
    if len(diagnosis_train) > 0:
        diagnosis_train['label_encoded'] = diagnosis_le.fit_transform(diagnosis_train['label'])

    progression_le = LabelEncoder()
    if len(progression_train) > 0:
        progression_train['label_encoded'] = progression_le.fit_transform(progression_train['label'])

    datasets = {
        'diagnosis_train': diagnosis_train,
        'progression_train': progression_train,
        'progression_test': progression_test,
        'all_data': df,
        'labels_dict': labels_dict,
        'diagnosis_label_encoder': diagnosis_le,
        'progression_label_encoder': progression_le
    }

    print("\nDataset Summary:")
    print(f"Diagnosis Training: {len(diagnosis_train)} samples")
    if len(diagnosis_train) > 0:
        print(f"  - AD: {len(diagnosis_train[diagnosis_train['label'] == 'ad'])}")
        print(f"  - CN: {len(diagnosis_train[diagnosis_train['label'] == 'cn'])}")

    print(f"Progression Training: {len(progression_train)} samples")
    if len(progression_train) > 0:
        print(f"  - Decline: {len(progression_train[progression_train['label'] == 'decline'])}")
        print(f"  - No Decline: {len(progression_train[progression_train['label'] == 'no_decline'])}")

    print(f"Progression Test: {len(progression_test)} samples")

    return datasets

def save_datasets(datasets, output_path='/content/drive/MyDrive/Speech/processed_datasets'):
    os.makedirs(output_path, exist_ok=True)

    for name, data in datasets.items():
        if isinstance(data, pd.DataFrame):
            data.to_csv(os.path.join(output_path, f'{name}.csv'), index=False)
            data.to_pickle(os.path.join(output_path, f'{name}.pkl'))
        elif name == 'labels_dict':
            with open(os.path.join(output_path, 'labels_dict.json'), 'w') as f:
                json.dump(data, f, indent=2)
        else:
            with open(os.path.join(output_path, f'{name}.pkl'), 'wb') as f:
                pickle.dump(data, f)

    print(f"Datasets saved to {output_path}")

if __name__ == "__main__":
    datasets = create_training_datasets()
    save_datasets(datasets)

    print("\nExtracting audio features for a sample...")
    sample_audio_features = extract_audio_features_batch(datasets['labels_dict'], sample_limit=5)

    if sample_audio_features:
        print(f"Successfully extracted features for {len(sample_audio_features)} audio files")
        sample_id = list(sample_audio_features.keys())[0]
        sample_features = sample_audio_features[sample_id]
        print(f"Feature keys for {sample_id}: {list(sample_features.keys())}")

    print("\nProcessing complete!")
    print("Available datasets:")
    for name in datasets.keys():
        if isinstance(datasets[name], pd.DataFrame):
            print(f"  - {name}: {len(datasets[name])} samples")

Extracting labels from directory structure...
Loading existing features...
Loading linguistic features...
Loading transcripts...
Dataset created with 271 samples
Columns: ['file_id', 'dataset', 'label', 'file_path', 'transcript', 'diagnosis_cn_log_mel', 'diagnosis_cn_wav2vec2', 'diagnosis_cn_file_id', 'diagnosis_cn_duration', 'diagnosis_cn_success', 'diagnosis_ad_log_mel', 'diagnosis_ad_wav2vec2', 'diagnosis_ad_file_id', 'diagnosis_ad_duration', 'diagnosis_ad_success', 'progression_no_decline_log_mel', 'progression_no_decline_wav2vec2', 'progression_no_decline_file_id', 'progression_no_decline_duration', 'progression_no_decline_success', 'progression_decline_log_mel', 'progression_decline_wav2vec2', 'progression_decline_file_id', 'progression_decline_duration', 'progression_decline_success', 'progression_test_log_mel', 'progression_test_wav2vec2', 'progression_test_file_id', 'progression_test_duration', 'progression_test_success']
Label distribution:
label
ad            87
cn          

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import numpy as np
from transformers import BertModel, BertTokenizer
import math

class XceptionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(XceptionBlock, self).__init__()
        self.separable_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels),
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.separable_conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, groups=out_channels),
            nn.Conv2d(out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels)
        )
        self.skip_connection = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, stride=stride),
            nn.BatchNorm2d(out_channels)
        ) if in_channels != out_channels or stride != 1 else nn.Identity()

        self.stride = stride

    def forward(self, x):
        residual = self.skip_connection(x)
        x = self.separable_conv1(x)
        x = self.separable_conv2(x)
        if self.stride > 1:
            x = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
        return F.relu(x + residual)

class XceptionNet(nn.Module):
    def __init__(self, input_channels=1, num_classes=512):
        super(XceptionNet, self).__init__()
        self.entry_flow = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.middle_flow = nn.Sequential(
            XceptionBlock(64, 128, stride=2),
            XceptionBlock(128, 256, stride=2),
            XceptionBlock(256, 728, stride=2),
            XceptionBlock(728, 728),
            XceptionBlock(728, 728),
            XceptionBlock(728, 728),
        )

        self.exit_flow = nn.Sequential(
            XceptionBlock(728, 1024, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.entry_flow(x)
        x = self.middle_flow(x)
        x = self.exit_flow(x)
        return x

class ViTBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(ViTBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x

class ViTEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, num_classes=512):
        super(ViTEncoder, self).__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim

        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.dropout = nn.Dropout(0.1)

        self.blocks = nn.ModuleList([
            ViTBlock(embed_dim, num_heads) for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[:, :x.size(1)]
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return self.head(x[:, 0])

class SpectrogramEncoder(nn.Module):
    def __init__(self, use_xception=True, feature_dim=512):
        super(SpectrogramEncoder, self).__init__()
        self.use_xception = use_xception
        if use_xception:
            self.xception = XceptionNet(input_channels=1, num_classes=feature_dim)
        else:
            self.vit = ViTEncoder(img_size=224, patch_size=16, embed_dim=768, depth=6, num_heads=8, num_classes=feature_dim)

    def forward(self, x):
        if self.use_xception:
            return self.xception(x)
        else:
            return self.vit(x)

class AcousticFeatureEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, output_dim=512):
        super(AcousticFeatureEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.encoder(x)

class LinguisticEncoder(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased', output_dim=512, max_length=512):
        super(LinguisticEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.max_length = max_length
        self.projection = nn.Linear(self.bert.config.hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        return self.projection(pooled_output)

class GraphAttention(nn.Module):
    def __init__(self, feature_dim=512, num_heads=8, dropout=0.1):
        super(GraphAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        self.scale = math.sqrt(self.head_dim)

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(feature_dim, feature_dim)

    def forward(self, nodes):
        B, N, D = nodes.shape

        q = self.query(nodes).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(nodes).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(nodes).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) / self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ v).transpose(1, 2).contiguous().view(B, N, D)
        return self.proj(out)

class MultiModalFusionModule(nn.Module):
    def __init__(self, feature_dim=512, num_layers=3):
        super(MultiModalFusionModule, self).__init__()
        self.layers = nn.ModuleList([
            GraphAttention(feature_dim) for _ in range(num_layers)
        ])
        self.norm_layers = nn.ModuleList([
            nn.LayerNorm(feature_dim) for _ in range(num_layers)
        ])

    def forward(self, features):
        x = torch.stack(features, dim=1)

        for layer, norm in zip(self.layers, self.norm_layers):
            residual = x
            x = layer(x)
            x = norm(x + residual)

        return x.mean(dim=1)

class MultiModalSpeechModel(nn.Module):
    def __init__(self,
                 acoustic_input_dim=50,
                 bert_model_name='bert-base-uncased',
                 feature_dim=512,
                 num_classes_diagnosis=2,
                 num_classes_progression=2,
                 use_xception_for_spec=True):
        super(MultiModalSpeechModel, self).__init__()

        self.spectrogram_encoder = SpectrogramEncoder(use_xception=use_xception_for_spec, feature_dim=feature_dim)
        self.acoustic_encoder = AcousticFeatureEncoder(acoustic_input_dim, output_dim=feature_dim)
        self.linguistic_encoder = LinguisticEncoder(bert_model_name, output_dim=feature_dim)

        self.fusion_module = MultiModalFusionModule(feature_dim)

        self.diagnosis_classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes_diagnosis)
        )

        self.progression_classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes_progression)
        )

    def forward(self, spectrogram=None, acoustic_features=None, input_ids=None, attention_mask=None, task='diagnosis'):
        features = []

        if spectrogram is not None:
            spec_features = self.spectrogram_encoder(spectrogram)
            features.append(spec_features)

        if acoustic_features is not None:
            acoustic_feat = self.acoustic_encoder(acoustic_features)
            features.append(acoustic_feat)

        if input_ids is not None and attention_mask is not None:
            linguistic_feat = self.linguistic_encoder(input_ids, attention_mask)
            features.append(linguistic_feat)

        if len(features) == 0:
            raise ValueError("At least one modality must be provided")

        if len(features) == 1:
            fused_features = features[0]
        else:
            fused_features = self.fusion_module(features)

        if task == 'diagnosis':
            return self.diagnosis_classifier(fused_features)
        elif task == 'progression':
            return self.progression_classifier(fused_features)
        else:
            return {
                'diagnosis': self.diagnosis_classifier(fused_features),
                'progression': self.progression_classifier(fused_features)
            }

class MultiModalDataset(torch.utils.data.Dataset):
    def __init__(self, spectrograms, acoustic_features, texts, labels, tokenizer, max_length=512):
        self.spectrograms = spectrograms
        self.acoustic_features = acoustic_features
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {}

        if self.spectrograms is not None and idx < len(self.spectrograms):
            spec = self.spectrograms[idx]
            if isinstance(spec, np.ndarray):
                spec = torch.FloatTensor(spec)
            if len(spec.shape) == 2:
                spec = spec.unsqueeze(0)
            item['spectrogram'] = spec

        if self.acoustic_features is not None and idx < len(self.acoustic_features):
            acoustic = self.acoustic_features[idx]
            if isinstance(acoustic, np.ndarray):
                acoustic = torch.FloatTensor(acoustic)
            item['acoustic_features'] = acoustic

        if self.texts is not None and idx < len(self.texts):
            text = str(self.texts[idx])
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            item['input_ids'] = encoding['input_ids'].squeeze()
            item['attention_mask'] = encoding['attention_mask'].squeeze()

        item['label'] = torch.LongTensor([self.labels[idx]])
        return item

def create_model_and_dataloaders(train_data, val_data=None, batch_size=16):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_dataset = MultiModalDataset(
        spectrograms=train_data.get('spectrograms'),
        acoustic_features=train_data.get('acoustic_features'),
        texts=train_data.get('texts'),
        labels=train_data['labels'],
        tokenizer=tokenizer
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    val_loader = None
    if val_data is not None:
        val_dataset = MultiModalDataset(
            spectrograms=val_data.get('spectrograms'),
            acoustic_features=val_data.get('acoustic_features'),
            texts=val_data.get('texts'),
            labels=val_data['labels'],
            tokenizer=tokenizer
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn
        )

    # Fix: Handle the array checking properly
    acoustic_features = train_data.get('acoustic_features')
    if acoustic_features is not None and len(acoustic_features) > 0:
        acoustic_dim = len(acoustic_features[0])
    else:
        acoustic_dim = 50

    model = MultiModalSpeechModel(acoustic_input_dim=acoustic_dim)

    return model, train_loader, val_loader, tokenizer

def train_model(model, train_loader, val_loader=None, epochs=10, learning_rate=1e-4, device='cuda'):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for batch in train_loader:
            optimizer.zero_grad()

            inputs = {}
            for key in ['spectrogram', 'acoustic_features', 'input_ids', 'attention_mask']:
                if key in batch:
                    inputs[key] = batch[key].to(device)

            labels = batch['label'].squeeze().to(device)
            outputs = model(**inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        train_accuracy = 100 * train_correct / train_total
        avg_train_loss = train_loss / len(train_loader)

        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.2f}%')

        if val_loader:
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for batch in val_loader:
                    inputs = {}
                    for key in ['spectrogram', 'acoustic_features', 'input_ids', 'attention_mask']:
                        if key in batch:
                            inputs[key] = batch[key].to(device)

                    labels = batch['label'].squeeze().to(device)
                    outputs = model(**inputs)

                    loss = criterion(outputs, labels)
                    val_loss += loss.item()

                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

            val_accuracy = 100 * val_correct / val_total
            avg_val_loss = val_loss / len(val_loader)
            scheduler.step(avg_val_loss)

            print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

    return model

In [15]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from transformers import BertTokenizer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedKFold
import pickle
import json
import librosa
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.ndimage import zoom
import warnings
warnings.filterwarnings('ignore')

# Import necessary classes from the previous cell
from __main__ import MultiModalDataset, MultiModalSpeechModel, collate_fn, create_model_and_dataloaders


def load_processed_datasets():
    dataset_path = '/content/drive/MyDrive/Speech/processed_datasets'

    diagnosis_train = pd.read_pickle(os.path.join(dataset_path, 'diagnosis_train.pkl'))
    progression_train = pd.read_pickle(os.path.join(dataset_path, 'progression_train.pkl'))
    progression_test = pd.read_pickle(os.path.join(dataset_path, 'progression_test.pkl'))

    with open(os.path.join(dataset_path, 'diagnosis_label_encoder.pkl'), 'rb') as f:
        diagnosis_le = pickle.load(f)
    with open(os.path.join(dataset_path, 'progression_label_encoder.pkl'), 'rb') as f:
        progression_le = pickle.load(f)

    return diagnosis_train, progression_train, progression_test, diagnosis_le, progression_le

def extract_acoustic_features(audio_path):
    try:
        y, sr = librosa.load(audio_path, sr=22050)

        features = {}

        features['mfcc'] = np.mean(librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13), axis=1).flatten()
        features['chroma'] = np.mean(librosa.feature.chroma_stft(y=y, sr=sr), axis=1).flatten()
        features['spectral_contrast'] = np.mean(librosa.feature.spectral_contrast(y=y, sr=sr), axis=1).flatten()
        features['tonnetz'] = np.mean(librosa.feature.tonnetz(y=y, sr=sr), axis=1).flatten()
        features['zero_crossing_rate'] = np.array([np.mean(librosa.feature.zero_crossing_rate(y))]).flatten()
        features['spectral_centroid'] = np.array([np.mean(librosa.feature.spectral_centroid(y=y, sr=sr))]).flatten()
        features['spectral_rolloff'] = np.array([np.mean(librosa.feature.spectral_rolloff(y=y, sr=sr))]).flatten()
        features['rms_energy'] = np.array([np.mean(librosa.feature.rms(y=y))]).flatten()

        tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
        features['tempo'] = np.array([tempo]).flatten()

        acoustic_feat = np.concatenate(list(features.values()))

        target_length = 50
        if len(acoustic_feat) < target_length:
            acoustic_feat = np.pad(acoustic_feat, (0, target_length - len(acoustic_feat)), 'constant')
        else:
            acoustic_feat = acoustic_feat[:target_length]

        return acoustic_feat

    except Exception as e:
        print(f"Error extracting acoustic features for {audio_path}: {e}")
        return np.zeros(50)


def extract_spectrograms_from_audio(df, target_size=(224, 224)):
    spectrograms = []
    acoustic_features = []

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing audio files"):
        try:
            audio_path = row['file_path']
            if os.path.exists(audio_path):
                y, sr = librosa.load(audio_path, sr=22050)

                mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, hop_length=512, n_fft=2048)
                log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

                height_ratio = target_size[0] / log_mel_spec.shape[0]
                width_ratio = target_size[1] / log_mel_spec.shape[1]
                log_mel_spec_resized = zoom(log_mel_spec, (height_ratio, width_ratio))
                spectrograms.append(log_mel_spec_resized)

                acoustic_feat = extract_acoustic_features(audio_path)
                acoustic_features.append(acoustic_feat)

            else:
                spectrograms.append(np.zeros(target_size))
                acoustic_features.append(np.zeros(50))

        except Exception as e:
            print(f"Error processing {row['file_id']}: {e}")
            spectrograms.append(np.zeros(target_size))
            acoustic_features.append(np.zeros(50))


    return np.array(spectrograms), np.array(acoustic_features)


def prepare_multimodal_data(df, task='diagnosis'):
    texts = df['transcript'].fillna('').tolist()

    if task == 'diagnosis':
        labels = df['label_encoded'].values
    elif task == 'progression':
        labels = df['label_encoded'].values
    else:
        raise ValueError("Task must be 'diagnosis' or 'progression'")

    spectrograms, acoustic_features = extract_spectrograms_from_audio(df)

    return {
        'spectrograms': spectrograms,
        'acoustic_features': acoustic_features,
        'texts': texts,
        'labels': labels
    }

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        self.best_weights = model.state_dict().copy()

def train_multimodal_model(model, train_loader, val_loader, task='diagnosis', epochs=50, learning_rate=1e-4, device='cuda'):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)

    training_history = {
        'train_losses': [],
        'val_losses': [],
        'train_accuracies': [],
        'val_accuracies': []
    }

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
        for batch in train_pbar:
            optimizer.zero_grad()

            inputs = {}
            for key in ['spectrogram', 'acoustic_features', 'input_ids', 'attention_mask']:
                if key in batch:
                    inputs[key] = batch[key].to(device)

            labels = batch['label'].squeeze().to(device)
            outputs = model(task=task, **inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}', 'Acc': f'{100 * train_correct / train_total:.2f}%'})

        train_accuracy = 100 * train_correct / train_total
        avg_train_loss = train_loss / len(train_loader)

        training_history['train_losses'].append(avg_train_loss)
        training_history['train_accuracies'].append(train_accuracy)

        if val_loader:
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]')
                for batch in val_loader:
                    inputs = {}
                    for key in ['spectrogram', 'acoustic_features', 'input_ids', 'attention_mask']:
                        if key in batch:
                            inputs[key] = batch[key].to(device)

                    labels = batch['label'].squeeze().to(device)
                    outputs = model(task=task, **inputs)

                    loss = criterion(outputs, labels)
                    val_loss += loss.item()

                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

                    val_pbar.set_postfix({'Loss': f'{loss.item():.4f}', 'Acc': f'{100 * val_correct / val_total:.2f}%'})

            val_accuracy = 100 * val_correct / val_total
            avg_val_loss = val_loss / len(val_loader)

            training_history['val_losses'].append(avg_val_loss)
            training_history['val_accuracies'].append(val_accuracy)

            scheduler.step(avg_val_loss)

            print(f'Epoch [{epoch+1}/{epochs}]')
            print(f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%')
            print(f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')
            print(f'LR: {optimizer.param_groups[0]["lr"]:.6f}')
            print('-' * 50)

            if early_stopping(avg_val_loss, model):
                print(f'Early stopping triggered at epoch {epoch+1}')
                break

    return model, training_history

def evaluate_model(model, test_loader, task='diagnosis', device='cuda'):
    model.eval()
    model = model.to(device)

    all_predictions = []
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            inputs = {}
            for key in ['spectrogram', 'acoustic_features', 'input_ids', 'attention_mask']:
                if key in batch:
                    inputs[key] = batch[key].to(device)

            labels = batch['label'].squeeze().to(device)
            outputs = model(task=task, **inputs)

            probabilities = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs.data, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average='weighted')

    cm = confusion_matrix(all_labels, all_predictions)

    results = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm,
        'predictions': all_predictions,
        'true_labels': all_labels,
        'probabilities': all_probabilities,
        'classification_report': classification_report(all_labels, all_predictions)
    }

    return results

def cross_validate_model(data, task='diagnosis', k_folds=5, epochs=30, device='cuda'):
    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)

    fold_results = []
    all_labels = data['labels']

    for fold, (train_idx, val_idx) in enumerate(skf.split(range(len(all_labels)), all_labels)):
        print(f'Fold {fold + 1}/{k_folds}')
        print('=' * 50)

        train_data = {
            'spectrograms': data['spectrograms'][train_idx],
            'acoustic_features': data['acoustic_features'][train_idx],
            'texts': [data['texts'][i] for i in train_idx],
            'labels': data['labels'][train_idx]
        }

        val_data = {
            'spectrograms': data['spectrograms'][val_idx],
            'acoustic_features': data['acoustic_features'][val_idx],
            'texts': [data['texts'][i] for i in val_idx],
            'labels': data['labels'][val_idx]
        }

        model, train_loader, val_loader, tokenizer = create_model_and_dataloaders(
            train_data, val_data, batch_size=16
        )

        model, training_history = train_multimodal_model(
            model, train_loader, val_loader, task=task, epochs=epochs, device=device
        )

        val_results = evaluate_model(model, val_loader, task=task, device=device)

        fold_results.append({
            'fold': fold + 1,
            'val_accuracy': val_results['accuracy'],
            'val_f1': val_results['f1_score'],
            'val_precision': val_results['precision'],
            'val_recall': val_results['recall'],
            'training_history': training_history,
            'confusion_matrix': val_results['confusion_matrix']
        })

        print(f'Fold {fold + 1} Results:')
        print(f'Accuracy: {val_results["accuracy"]:.4f}')
        print(f'F1-Score: {val_results["f1_score"]:.4f}')
        print(f'Precision: {val_results["precision"]:.4f}')
        print(f'Recall: {val_results["recall"]:.4f}')
        print('=' * 50)

    avg_accuracy = np.mean([result['val_accuracy'] for result in fold_results])
    avg_f1 = np.mean([result['val_f1'] for result in fold_results])
    avg_precision = np.mean([result['val_precision'] for result in fold_results])
    avg_recall = np.mean([result['val_recall'] for result in fold_results])

    std_accuracy = np.std([result['val_accuracy'] for result in fold_results])
    std_f1 = np.std([result['val_f1'] for result in fold_results])

    summary = {
        'avg_accuracy': avg_accuracy,
        'std_accuracy': std_accuracy,
        'avg_f1': avg_f1,
        'std_f1': std_f1,
        'avg_precision': avg_precision,
        'avg_recall': avg_recall,
        'fold_results': fold_results
    }

    print('Cross-Validation Summary:')
    print(f'Average Accuracy: {avg_accuracy:.4f} ± {std_accuracy:.4f}')
    print(f'Average F1-Score: {avg_f1:.4f} ± {std_f1:.4f}')
    print(f'Average Precision: {avg_precision:.4f}')
    print(f'Average Recall: {avg_recall:.4f}')

    return summary

def plot_training_history(history, task='diagnosis'):
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    axes[0].plot(history['train_losses'], label='Train Loss', color='blue')
    if 'val_losses' in history:
        axes[0].plot(history['val_losses'], label='Validation Loss', color='red')
    axes[0].set_title(f'{task.capitalize()} Task - Training Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True)

    axes[1].plot(history['train_accuracies'], label='Train Accuracy', color='blue')
    if 'val_accuracies' in history:
        axes[1].plot(history['val_accuracies'], label='Validation Accuracy', color='red')
    axes[1].set_title(f'{task.capitalize()} Task - Training Accuracy')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(cm, class_names, task='diagnosis'):
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{task.capitalize()} Task - Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

def comprehensive_evaluation():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    diagnosis_train, progression_train, progression_test, diagnosis_le, progression_le = load_processed_datasets()

    print("Preparing diagnosis data...")
    diagnosis_data = prepare_multimodal_data(diagnosis_train, task='diagnosis')

    print("Preparing progression data...")
    progression_data = prepare_multimodal_data(progression_train, task='progression')

    results = {}

    print("\n" + "="*60)
    print("DIAGNOSIS TASK EVALUATION")
    print("="*60)

    diagnosis_cv_results = cross_validate_model(
        diagnosis_data, task='diagnosis', k_folds=5, epochs=25, device=device
    )
    results['diagnosis_cv'] = diagnosis_cv_results

    print("\n" + "="*60)
    print("PROGRESSION TASK EVALUATION")
    print("="*60)

    progression_cv_results = cross_validate_model(
        progression_data, task='progression', k_folds=5, epochs=25, device=device
    )
    results['progression_cv'] = progression_cv_results

    print("\n" + "="*60)
    print("FINAL MODEL TRAINING")
    print("="*60)

    train_size = int(0.8 * len(diagnosis_data['labels']))
    val_size = len(diagnosis_data['labels']) - train_size

    indices = np.random.permutation(len(diagnosis_data['labels']))
    train_idx, val_idx = indices[:train_size], indices[train_size:]

    diagnosis_train_data = {
        'spectrograms': diagnosis_data['spectrograms'][train_idx],
        'acoustic_features': diagnosis_data['acoustic_features'][train_idx],
        'texts': [diagnosis_data['texts'][i] for i in train_idx],
        'labels': diagnosis_data['labels'][train_idx]
    }

    diagnosis_val_data = {
        'spectrograms': diagnosis_data['spectrograms'][val_idx],
        'acoustic_features': diagnosis_data['acoustic_features'][val_idx],
        'texts': [diagnosis_data['texts'][i] for i in val_idx],
        'labels': diagnosis_data['labels'][val_idx]
    }

    print("Training final diagnosis model...")
    model, train_loader, val_loader, tokenizer = create_model_and_dataloaders(
        diagnosis_train_data, diagnosis_val_data, batch_size=16
    )

    final_model, final_history = train_multimodal_model(
        model, train_loader, val_loader, task='diagnosis', epochs=40, device=device
    )

    final_results = evaluate_model(final_model, val_loader, task='diagnosis', device=device)
    results['final_diagnosis'] = final_results

    plot_training_history(final_history, task='diagnosis')
    plot_confusion_matrix(final_results['confusion_matrix'],
                         diagnosis_le.classes_, task='diagnosis')

    print("\nFinal Diagnosis Model Results:")
    print(f"Accuracy: {final_results['accuracy']:.4f}")
    print(f"F1-Score: {final_results['f1_score']:.4f}")
    print(f"Precision: {final_results['precision']:.4f}")
    print(f"Recall: {final_results['recall']:.4f}")

    print("\nClassification Report:")
    print(final_results['classification_report'])

    torch.save(final_model.state_dict(), '/content/drive/MyDrive/Speech/final_diagnosis_model.pth')
    with open('/content/drive/MyDrive/Speech/evaluation_results.pkl', 'wb') as f:
        pickle.dump(results, f)

    return results

if __name__ == "__main__":
    results = comprehensive_evaluation()

Using device: cuda
Preparing diagnosis data...



Processing audio files:   0%|          | 0/166 [00:00<?, ?it/s][A
Processing audio files:   1%|          | 1/166 [00:00<02:04,  1.32it/s][A
Processing audio files:   1%|          | 2/166 [00:02<02:55,  1.07s/it][A
Processing audio files:   2%|▏         | 3/166 [00:03<03:11,  1.17s/it][A
Processing audio files:   2%|▏         | 4/166 [00:04<03:26,  1.28s/it][A
Processing audio files:   3%|▎         | 5/166 [00:05<02:42,  1.01s/it][A
Processing audio files:   4%|▎         | 6/166 [00:06<02:57,  1.11s/it][A
Processing audio files:   4%|▍         | 7/166 [00:08<03:42,  1.40s/it][A
Processing audio files:   5%|▍         | 8/166 [00:10<03:48,  1.44s/it][A
Processing audio files:   5%|▌         | 9/166 [00:11<03:26,  1.31s/it][A
Processing audio files:   6%|▌         | 10/166 [00:12<03:39,  1.41s/it][A
Processing audio files:   7%|▋         | 11/166 [00:13<03:05,  1.20s/it][A
Processing audio files:   7%|▋         | 12/166 [00:14<02:36,  1.01s/it][A
Processing audio files:   8%|

Preparing progression data...



Processing audio files:   0%|          | 0/73 [00:00<?, ?it/s][A
Processing audio files:   1%|▏         | 1/73 [00:03<03:55,  3.27s/it][A
Processing audio files:   3%|▎         | 2/73 [00:08<05:20,  4.52s/it][A
Processing audio files:   4%|▍         | 3/73 [00:11<04:28,  3.83s/it][A
Processing audio files:   5%|▌         | 4/73 [00:14<04:01,  3.50s/it][A
Processing audio files:   7%|▋         | 5/73 [00:18<04:04,  3.60s/it][A
Processing audio files:   8%|▊         | 6/73 [00:21<03:55,  3.52s/it][A
Processing audio files:  10%|▉         | 7/73 [00:24<03:42,  3.38s/it][A
Processing audio files:  11%|█         | 8/73 [00:27<03:33,  3.28s/it][A
Processing audio files:  12%|█▏        | 9/73 [00:31<03:41,  3.46s/it][A
Processing audio files:  14%|█▎        | 10/73 [00:36<03:56,  3.76s/it][A
Processing audio files:  15%|█▌        | 11/73 [00:40<03:53,  3.77s/it][A
Processing audio files:  16%|█▋        | 12/73 [00:43<03:38,  3.58s/it][A
Processing audio files:  18%|█▊        | 1


DIAGNOSIS TASK EVALUATION
Fold 1/5



Epoch 1/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 1/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7325, Acc=37.50%][A
Epoch 1/25 [Train]:  11%|█         | 1/9 [00:01<00:13,  1.65s/it, Loss=0.7325, Acc=37.50%][A
Epoch 1/25 [Train]:  11%|█         | 1/9 [00:03<00:13,  1.65s/it, Loss=0.7502, Acc=43.75%][A
Epoch 1/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.60s/it, Loss=0.7502, Acc=43.75%][A
Epoch 1/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.60s/it, Loss=0.7837, Acc=41.67%][A
Epoch 1/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.57s/it, Loss=0.7837, Acc=41.67%][A
Epoch 1/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.57s/it, Loss=0.7346, Acc=42.19%][A
Epoch 1/25 [Train]:  44%|████▍     | 4/9 [00:06<00:07,  1.55s/it, Loss=0.7346, Acc=42.19%][A
Epoch 1/25 [Train]:  44%|████▍     | 4/9 [00:07<00:07,  1.55s/it, Loss=0.7014, Acc=43.75%][A
Epoch 1/25 [Train]:  56%|█████▌    | 5/9 [00:07<00:06,  1.55s/it, Loss=0.7014, Acc=43.75%][A
Epoch 

Epoch [1/25]
Train Loss: 0.7219, Train Acc: 47.73%
Val Loss: 0.6781, Val Acc: 52.94%
LR: 0.000100
--------------------------------------------------




Epoch 2/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Epoch 2/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7943, Acc=31.25%][A[A

Epoch 2/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.57s/it, Loss=0.7943, Acc=31.25%][A[A

Epoch 2/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.57s/it, Loss=0.7284, Acc=37.50%][A[A

Epoch 2/25 [Train]:  22%|██▏       | 2/9 [00:03<00:10,  1.57s/it, Loss=0.7284, Acc=37.50%][A[A

Epoch 2/25 [Train]:  22%|██▏       | 2/9 [00:04<00:10,  1.57s/it, Loss=0.6447, Acc=52.08%][A[A

Epoch 2/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.56s/it, Loss=0.6447, Acc=52.08%][A[A

Epoch 2/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.56s/it, Loss=0.7114, Acc=50.00%][A[A

Epoch 2/25 [Train]:  44%|████▍     | 4/9 [00:06<00:07,  1.55s/it, Loss=0.7114, Acc=50.00%][A[A

Epoch 2/25 [Train]:  44%|████▍     | 4/9 [00:07<00:07,  1.55s/it, Loss=0.7023, Acc=48.75%][A[A

Epoch 2/25 [Train]:  56%|█████▌    | 5/9 [00:07<00:06,  1.5

Epoch [2/25]
Train Loss: 0.7049, Train Acc: 47.73%
Val Loss: 0.6427, Val Acc: 52.94%
LR: 0.000100
--------------------------------------------------



Epoch 3/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 3/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7153, Acc=43.75%][A
Epoch 3/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.59s/it, Loss=0.7153, Acc=43.75%][A
Epoch 3/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.59s/it, Loss=0.7193, Acc=46.88%][A
Epoch 3/25 [Train]:  22%|██▏       | 2/9 [00:03<00:10,  1.57s/it, Loss=0.7193, Acc=46.88%][A
Epoch 3/25 [Train]:  22%|██▏       | 2/9 [00:04<00:10,  1.57s/it, Loss=0.7102, Acc=47.92%][A
Epoch 3/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.57s/it, Loss=0.7102, Acc=47.92%][A
Epoch 3/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.57s/it, Loss=0.8217, Acc=45.31%][A
Epoch 3/25 [Train]:  44%|████▍     | 4/9 [00:06<00:07,  1.56s/it, Loss=0.8217, Acc=45.31%][A
Epoch 3/25 [Train]:  44%|████▍     | 4/9 [00:07<00:07,  1.56s/it, Loss=0.6834, Acc=48.75%][A
Epoch 3/25 [Train]:  56%|█████▌    | 5/9 [00:07<00:06,  1.56s/it, Loss=0.6834, Acc=48.75%][A
Epoch 

Epoch [3/25]
Train Loss: 0.7471, Train Acc: 43.18%
Val Loss: 0.7185, Val Acc: 47.06%
LR: 0.000100
--------------------------------------------------




Epoch 4/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Epoch 4/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7079, Acc=43.75%][A[A

Epoch 4/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.56s/it, Loss=0.7079, Acc=43.75%][A[A

Epoch 4/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.56s/it, Loss=0.7059, Acc=50.00%][A[A

Epoch 4/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.57s/it, Loss=0.7059, Acc=50.00%][A[A

Epoch 4/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.57s/it, Loss=0.7356, Acc=50.00%][A[A

Epoch 4/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.58s/it, Loss=0.7356, Acc=50.00%][A[A

Epoch 4/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.58s/it, Loss=0.7064, Acc=50.00%][A[A

Epoch 4/25 [Train]:  44%|████▍     | 4/9 [00:06<00:07,  1.57s/it, Loss=0.7064, Acc=50.00%][A[A

Epoch 4/25 [Train]:  44%|████▍     | 4/9 [00:07<00:07,  1.57s/it, Loss=0.7192, Acc=46.25%][A[A

Epoch 4/25 [Train]:  56%|█████▌    | 5/9 [00:07<00:06,  1.5

Epoch [4/25]
Train Loss: 0.6901, Train Acc: 56.06%
Val Loss: 0.6441, Val Acc: 52.94%
LR: 0.000100
--------------------------------------------------



Epoch 5/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 5/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.6960, Acc=56.25%][A
Epoch 5/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.56s/it, Loss=0.6960, Acc=56.25%][A
Epoch 5/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.56s/it, Loss=0.7550, Acc=56.25%][A
Epoch 5/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.57s/it, Loss=0.7550, Acc=56.25%][A
Epoch 5/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.57s/it, Loss=0.7598, Acc=52.08%][A
Epoch 5/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.58s/it, Loss=0.7598, Acc=52.08%][A
Epoch 5/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.58s/it, Loss=0.6580, Acc=53.12%][A
Epoch 5/25 [Train]:  44%|████▍     | 4/9 [00:06<00:07,  1.58s/it, Loss=0.6580, Acc=53.12%][A
Epoch 5/25 [Train]:  44%|████▍     | 4/9 [00:07<00:07,  1.58s/it, Loss=0.9234, Acc=48.75%][A
Epoch 5/25 [Train]:  56%|█████▌    | 5/9 [00:07<00:06,  1.59s/it, Loss=0.9234, Acc=48.75%][A
Epoch 

Epoch [5/25]
Train Loss: 0.7424, Train Acc: 48.48%
Val Loss: 0.6489, Val Acc: 52.94%
LR: 0.000100
--------------------------------------------------




Epoch 6/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Epoch 6/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7255, Acc=62.50%][A[A

Epoch 6/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.58s/it, Loss=0.7255, Acc=62.50%][A[A

Epoch 6/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.58s/it, Loss=0.6347, Acc=59.38%][A[A

Epoch 6/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.59s/it, Loss=0.6347, Acc=59.38%][A[A

Epoch 6/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.59s/it, Loss=0.7340, Acc=56.25%][A[A

Epoch 6/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.59s/it, Loss=0.7340, Acc=56.25%][A[A

Epoch 6/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.59s/it, Loss=0.6995, Acc=56.25%][A[A

Epoch 6/25 [Train]:  44%|████▍     | 4/9 [00:06<00:07,  1.59s/it, Loss=0.6995, Acc=56.25%][A[A

Epoch 6/25 [Train]:  44%|████▍     | 4/9 [00:07<00:07,  1.59s/it, Loss=0.7276, Acc=53.75%][A[A

Epoch 6/25 [Train]:  56%|█████▌    | 5/9 [00:07<00:06,  1.6

Epoch [6/25]
Train Loss: 0.7096, Train Acc: 50.76%
Val Loss: 0.6689, Val Acc: 52.94%
LR: 0.000050
--------------------------------------------------



Epoch 7/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 7/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7493, Acc=43.75%][A
Epoch 7/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.59s/it, Loss=0.7493, Acc=43.75%][A
Epoch 7/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.59s/it, Loss=0.7416, Acc=43.75%][A
Epoch 7/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.59s/it, Loss=0.7416, Acc=43.75%][A
Epoch 7/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.59s/it, Loss=0.6162, Acc=52.08%][A
Epoch 7/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.59s/it, Loss=0.6162, Acc=52.08%][A
Epoch 7/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.59s/it, Loss=0.6758, Acc=54.69%][A
Epoch 7/25 [Train]:  44%|████▍     | 4/9 [00:06<00:07,  1.60s/it, Loss=0.6758, Acc=54.69%][A
Epoch 7/25 [Train]:  44%|████▍     | 4/9 [00:08<00:07,  1.60s/it, Loss=0.7398, Acc=51.25%][A
Epoch 7/25 [Train]:  56%|█████▌    | 5/9 [00:08<00:06,  1.61s/it, Loss=0.7398, Acc=51.25%][A
Epoch 

Epoch [7/25]
Train Loss: 0.7102, Train Acc: 49.24%
Val Loss: 0.6669, Val Acc: 52.94%
LR: 0.000050
--------------------------------------------------




Epoch 8/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Epoch 8/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.6862, Acc=56.25%][A[A

Epoch 8/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.61s/it, Loss=0.6862, Acc=56.25%][A[A

Epoch 8/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.61s/it, Loss=0.6363, Acc=62.50%][A[A

Epoch 8/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.61s/it, Loss=0.6363, Acc=62.50%][A[A

Epoch 8/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.61s/it, Loss=0.6783, Acc=64.58%][A[A

Epoch 8/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.60s/it, Loss=0.6783, Acc=64.58%][A[A

Epoch 8/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.60s/it, Loss=0.6749, Acc=62.50%][A[A

Epoch 8/25 [Train]:  44%|████▍     | 4/9 [00:06<00:08,  1.63s/it, Loss=0.6749, Acc=62.50%][A[A

Epoch 8/25 [Train]:  44%|████▍     | 4/9 [00:08<00:08,  1.63s/it, Loss=0.6533, Acc=62.50%][A[A

Epoch 8/25 [Train]:  56%|█████▌    | 5/9 [00:08<00:06,  1.6

Epoch [8/25]
Train Loss: 0.6700, Train Acc: 58.33%
Val Loss: 0.6826, Val Acc: 50.00%
LR: 0.000050
--------------------------------------------------



Epoch 9/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 9/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.6926, Acc=50.00%][A
Epoch 9/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.59s/it, Loss=0.6926, Acc=50.00%][A
Epoch 9/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.59s/it, Loss=0.7347, Acc=50.00%][A
Epoch 9/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.62s/it, Loss=0.7347, Acc=50.00%][A
Epoch 9/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.62s/it, Loss=0.6867, Acc=54.17%][A
Epoch 9/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.64s/it, Loss=0.6867, Acc=54.17%][A
Epoch 9/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.64s/it, Loss=0.6788, Acc=53.12%][A
Epoch 9/25 [Train]:  44%|████▍     | 4/9 [00:06<00:08,  1.63s/it, Loss=0.6788, Acc=53.12%][A
Epoch 9/25 [Train]:  44%|████▍     | 4/9 [00:08<00:08,  1.63s/it, Loss=0.7213, Acc=52.50%][A
Epoch 9/25 [Train]:  56%|█████▌    | 5/9 [00:08<00:06,  1.62s/it, Loss=0.7213, Acc=52.50%][A
Epoch 

Epoch [9/25]
Train Loss: 0.7273, Train Acc: 43.94%
Val Loss: 0.6778, Val Acc: 52.94%
LR: 0.000050
--------------------------------------------------




Epoch 10/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Epoch 10/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7665, Acc=31.25%][A[A

Epoch 10/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.60s/it, Loss=0.7665, Acc=31.25%][A[A

Epoch 10/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.60s/it, Loss=0.7355, Acc=43.75%][A[A

Epoch 10/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.63s/it, Loss=0.7355, Acc=43.75%][A[A

Epoch 10/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.63s/it, Loss=0.6934, Acc=43.75%][A[A

Epoch 10/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.63s/it, Loss=0.6934, Acc=43.75%][A[A

Epoch 10/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.63s/it, Loss=0.7332, Acc=45.31%][A[A

Epoch 10/25 [Train]:  44%|████▍     | 4/9 [00:06<00:08,  1.62s/it, Loss=0.7332, Acc=45.31%][A[A

Epoch 10/25 [Train]:  44%|████▍     | 4/9 [00:08<00:08,  1.62s/it, Loss=0.7404, Acc=45.00%][A[A

Epoch 10/25 [Train]:  56%|█████▌    | 5/9 [00:08<

Epoch [10/25]
Train Loss: 0.7192, Train Acc: 50.00%
Val Loss: 0.6953, Val Acc: 47.06%
LR: 0.000025
--------------------------------------------------



Epoch 11/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 11/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7362, Acc=43.75%][A
Epoch 11/25 [Train]:  11%|█         | 1/9 [00:01<00:13,  1.64s/it, Loss=0.7362, Acc=43.75%][A
Epoch 11/25 [Train]:  11%|█         | 1/9 [00:03<00:13,  1.64s/it, Loss=0.7431, Acc=43.75%][A
Epoch 11/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.64s/it, Loss=0.7431, Acc=43.75%][A
Epoch 11/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.64s/it, Loss=0.7250, Acc=47.92%][A
Epoch 11/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.63s/it, Loss=0.7250, Acc=47.92%][A
Epoch 11/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.63s/it, Loss=0.7032, Acc=48.44%][A
Epoch 11/25 [Train]:  44%|████▍     | 4/9 [00:06<00:08,  1.62s/it, Loss=0.7032, Acc=48.44%][A
Epoch 11/25 [Train]:  44%|████▍     | 4/9 [00:08<00:08,  1.62s/it, Loss=0.6873, Acc=51.25%][A
Epoch 11/25 [Train]:  56%|█████▌    | 5/9 [00:08<00:06,  1.62s/it, Loss=0.6873, Acc=51.25%

Epoch [11/25]
Train Loss: 0.7052, Train Acc: 53.03%
Val Loss: 0.6944, Val Acc: 50.00%
LR: 0.000025
--------------------------------------------------




Epoch 12/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Epoch 12/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.6444, Acc=75.00%][A[A

Epoch 12/25 [Train]:  11%|█         | 1/9 [00:01<00:12,  1.62s/it, Loss=0.6444, Acc=75.00%][A[A

Epoch 12/25 [Train]:  11%|█         | 1/9 [00:03<00:12,  1.62s/it, Loss=0.6988, Acc=68.75%][A[A

Epoch 12/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.62s/it, Loss=0.6988, Acc=68.75%][A[A

Epoch 12/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.62s/it, Loss=0.6872, Acc=62.50%][A[A

Epoch 12/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.62s/it, Loss=0.6872, Acc=62.50%][A[A

Epoch 12/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.62s/it, Loss=0.7241, Acc=57.81%][A[A

Epoch 12/25 [Train]:  44%|████▍     | 4/9 [00:06<00:08,  1.62s/it, Loss=0.7241, Acc=57.81%][A[A

Epoch 12/25 [Train]:  44%|████▍     | 4/9 [00:08<00:08,  1.62s/it, Loss=0.7426, Acc=53.75%][A[A

Epoch 12/25 [Train]:  56%|█████▌    | 5/9 [00:08<

Epoch [12/25]
Train Loss: 0.7147, Train Acc: 53.03%
Val Loss: 0.6695, Val Acc: 52.94%
LR: 0.000025
--------------------------------------------------
Early stopping triggered at epoch 12



Evaluating:   0%|          | 0/3 [00:00<?, ?it/s][A
Evaluating:  33%|███▎      | 1/3 [00:00<00:01,  1.88it/s][A
Evaluating: 100%|██████████| 3/3 [00:01<00:00,  2.67it/s]


Fold 1 Results:
Accuracy: 0.5294
F1-Score: 0.3665
Precision: 0.2803
Recall: 0.5294
Fold 2/5



Epoch 1/25 [Train]:   0%|          | 0/9 [00:00<?, ?it/s][A
Epoch 1/25 [Train]:   0%|          | 0/9 [00:01<?, ?it/s, Loss=0.7401, Acc=37.50%][A
Epoch 1/25 [Train]:  11%|█         | 1/9 [00:01<00:13,  1.69s/it, Loss=0.7401, Acc=37.50%][A
Epoch 1/25 [Train]:  11%|█         | 1/9 [00:03<00:13,  1.69s/it, Loss=0.7902, Acc=25.00%][A
Epoch 1/25 [Train]:  22%|██▏       | 2/9 [00:03<00:11,  1.65s/it, Loss=0.7902, Acc=25.00%][A
Epoch 1/25 [Train]:  22%|██▏       | 2/9 [00:04<00:11,  1.65s/it, Loss=0.6838, Acc=31.25%][A
Epoch 1/25 [Train]:  33%|███▎      | 3/9 [00:04<00:09,  1.64s/it, Loss=0.6838, Acc=31.25%][A
Epoch 1/25 [Train]:  33%|███▎      | 3/9 [00:06<00:09,  1.64s/it, Loss=0.6853, Acc=35.94%][A
Epoch 1/25 [Train]:  44%|████▍     | 4/9 [00:06<00:08,  1.64s/it, Loss=0.6853, Acc=35.94%][A
Epoch 1/25 [Train]:  44%|████▍     | 4/9 [00:08<00:08,  1.64s/it, Loss=0.7017, Acc=42.50%][A
Epoch 1/25 [Train]:  56%|█████▌    | 5/9 [00:08<00:06,  1.65s/it, Loss=0.7017, Acc=42.50%][A
Epoch 

ValueError: Expected input batch_size (1) to match target batch_size (0).

In [17]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    Wav2Vec2Processor, Wav2Vec2Model,
    BertTokenizer, BertModel,
    AutoFeatureExtractor, AutoModel
)
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
import librosa
import pickle
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

class SimplifiedMultiModalDataset(Dataset):
    def __init__(self, df, tokenizer, wav2vec_processor, max_length=512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.wav2vec_processor = wav2vec_processor
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load and process audio
        try:
            audio_path = row['file_path']
            if os.path.exists(audio_path):
                audio, sr = librosa.load(audio_path, sr=16000)
                # Ensure audio is reasonable length
                if len(audio) > 16000 * 30:  # Max 30 seconds
                    audio = audio[:16000 * 30]
                elif len(audio) < 16000:  # Min 1 second
                    audio = np.pad(audio, (0, 16000 - len(audio)), 'constant')
            else:
                audio = np.zeros(16000)  # 1 second of silence
        except:
            audio = np.zeros(16000)

        # Process text
        text = str(row['transcript']) if pd.notna(row['transcript']) else ""
        text_encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Process audio with Wav2Vec2
        try:
            audio_input = self.wav2vec_processor(
                audio,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True
            )
            audio_values = audio_input.input_values.squeeze()
        except:
            audio_values = torch.zeros(16000)

        return {
            'audio_values': audio_values,
            'input_ids': text_encoding['input_ids'].squeeze(),
            'attention_mask': text_encoding['attention_mask'].squeeze(),
            'label': torch.tensor(row['label_encoded'], dtype=torch.long)
        }

class PreTrainedMultiModalModel(nn.Module):
    def __init__(self, num_classes=2, dropout_rate=0.3):
        super().__init__()

        # Pre-trained models
        self.wav2vec2 = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze some layers to prevent overfitting
        for param in self.wav2vec2.parameters():
            param.requires_grad = False
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False

        # Only train the last few layers
        for layer in self.bert.encoder.layer[-2:]:
            for param in layer.parameters():
                param.requires_grad = True

        # Feature dimensions
        audio_dim = self.wav2vec2.config.hidden_size  # 768
        text_dim = self.bert.config.hidden_size       # 768

        # Projection layers
        self.audio_projection = nn.Sequential(
            nn.Linear(audio_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.text_projection = nn.Sequential(
            nn.Linear(text_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Cross-attention for fusion
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=512,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_classes)
        )

    def forward(self, audio_values, input_ids, attention_mask):
        # Process audio
        with torch.no_grad():
            audio_outputs = self.wav2vec2(audio_values)
        audio_features = audio_outputs.last_hidden_state.mean(dim=1)  # Global average pooling
        audio_features = self.audio_projection(audio_features)

        # Process text
        text_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.pooler_output
        text_features = self.text_projection(text_features)

        # Cross-modal attention fusion
        audio_features = audio_features.unsqueeze(1)  # [B, 1, 512]
        text_features = text_features.unsqueeze(1)    # [B, 1, 512]

        # Use audio as query, text as key/value
        fused_features, _ = self.cross_attention(
            query=audio_features,
            key=text_features,
            value=text_features
        )

        # Combine audio and fused features
        combined = (audio_features + fused_features).squeeze(1)

        return self.classifier(combined)

def collate_fn(batch):
    """Custom collate function to handle variable length sequences"""
    audio_values = []
    input_ids = []
    attention_masks = []
    labels = []

    max_audio_len = max([item['audio_values'].shape[0] for item in batch])

    for item in batch:
        # Pad audio to max length in batch
        audio = item['audio_values']
        if audio.shape[0] < max_audio_len:
            audio = F.pad(audio, (0, max_audio_len - audio.shape[0]))
        audio_values.append(audio)

        input_ids.append(item['input_ids'])
        attention_masks.append(item['attention_mask'])
        labels.append(item['label'])

    return {
        'audio_values': torch.stack(audio_values),
        'input_ids': torch.stack(input_ids),
        'attention_mask': torch.stack(attention_masks),
        'label': torch.stack(labels).squeeze()
    }

def train_model(model, train_loader, val_loader, epochs=30, device='cuda'):
    model = model.to(device)

    # Use different learning rates for pre-trained and new components
    pretrained_params = list(model.wav2vec2.parameters()) + list(model.bert.parameters())
    new_params = list(model.audio_projection.parameters()) + \
                list(model.text_projection.parameters()) + \
                list(model.cross_attention.parameters()) + \
                list(model.classifier.parameters())

    optimizer = torch.optim.AdamW([
        {'params': pretrained_params, 'lr': 1e-5},  # Lower LR for pre-trained
        {'params': new_params, 'lr': 1e-4}         # Higher LR for new layers
    ], weight_decay=0.01)

    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5
    )

    best_val_acc = 0
    patience_counter = 0
    patience = 10

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for batch in pbar:
            optimizer.zero_grad()

            audio_values = batch['audio_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(audio_values, input_ids, attention_mask)
            loss = criterion(outputs, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100 * train_correct / train_total:.2f}%'
            })

        train_acc = 100 * train_correct / train_total
        avg_train_loss = train_loss / len(train_loader)

        # Validation
        if val_loader:
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for batch in val_loader:
                    audio_values = batch['audio_values'].to(device)
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['label'].to(device)

                    outputs = model(audio_values, input_ids, attention_mask)
                    loss = criterion(outputs, labels)

                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

            val_acc = 100 * val_correct / val_total
            avg_val_loss = val_loss / len(val_loader)
            scheduler.step(avg_val_loss)

            print(f'Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%')

            # Early stopping
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                torch.save(model.state_dict(), 'best_model.pth')
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch+1}. Best validation accuracy: {best_val_acc:.2f}%')
                break
        else:
            print(f'Epoch {epoch+1}: Train Acc: {train_acc:.2f}%')

    # Load best model
    if val_loader and os.path.exists('best_model.pth'):
        model.load_state_dict(torch.load('best_model.pth'))

    return model

def evaluate_model(model, test_loader, device='cuda'):
    model.eval()
    model = model.to(device)

    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            audio_values = batch['audio_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(audio_values, input_ids, attention_mask)
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)

    return {
        'accuracy': accuracy,
        'predictions': all_preds,
        'true_labels': all_labels,
        'probabilities': all_probs,
        'classification_report': classification_report(all_labels, all_preds),
        'confusion_matrix': confusion_matrix(all_labels, all_preds)
    }

def run_diagnosis_task():
    """Main function to run the diagnosis task with high accuracy"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load your processed data
    dataset_path = '/content/drive/MyDrive/Speech/processed_datasets'
    diagnosis_train = pd.read_pickle(os.path.join(dataset_path, 'diagnosis_train.pkl'))

    print(f"Loaded {len(diagnosis_train)} samples")
    print(f"Label distribution:\n{diagnosis_train['label'].value_counts()}")

    # Initialize processors
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    wav2vec_processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')

    # Stratified train-val split
    train_df, val_df = train_test_split(
        diagnosis_train,
        test_size=0.2,
        stratify=diagnosis_train['label_encoded'],
        random_state=42
    )

    print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")

    # Create datasets
    train_dataset = SimplifiedMultiModalDataset(train_df, tokenizer, wav2vec_processor)
    val_dataset = SimplifiedMultiModalDataset(val_df, tokenizer, wav2vec_processor)

    # Create data loaders with smaller batch size
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,  # Smaller batch size for stability
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0  # Avoid multiprocessing issues
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0
    )

    # Create model
    model = PreTrainedMultiModalModel(num_classes=2)

    # Train model
    print("Starting training...")
    trained_model = train_model(model, train_loader, val_loader, epochs=50, device=device)

    # Evaluate
    results = evaluate_model(trained_model, val_loader, device=device)

    print(f"\nFinal Results:")
    print(f"Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
    print("\nClassification Report:")
    print(results['classification_report'])

    return trained_model, results

def run_cross_validation():
    """Run 5-fold cross validation for robust evaluation"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load data
    dataset_path = '/content/drive/MyDrive/Speech/processed_datasets'
    diagnosis_train = pd.read_pickle(os.path.join(dataset_path, 'diagnosis_train.pkl'))

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    wav2vec_processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')

    # 5-fold cross validation
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(diagnosis_train, diagnosis_train['label_encoded'])):
        print(f"\nFold {fold + 1}/5")
        print("=" * 50)

        train_df = diagnosis_train.iloc[train_idx].reset_index(drop=True)
        val_df = diagnosis_train.iloc[val_idx].reset_index(drop=True)

        # Create datasets
        train_dataset = SimplifiedMultiModalDataset(train_df, tokenizer, wav2vec_processor)
        val_dataset = SimplifiedMultiModalDataset(val_df, tokenizer, wav2vec_processor)

        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

        # Create and train model
        model = PreTrainedMultiModalModel(num_classes=2)
        trained_model = train_model(model, train_loader, val_loader, epochs=30, device=device)

        # Evaluate
        results = evaluate_model(trained_model, val_loader, device=device)
        fold_results.append(results['accuracy'])

        print(f"Fold {fold + 1} Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")

    # Summary
    mean_acc = np.mean(fold_results)
    std_acc = np.std(fold_results)

    print(f"\nCross-Validation Results:")
    print(f"Mean Accuracy: {mean_acc:.4f} ± {std_acc:.4f} ({mean_acc*100:.2f}% ± {std_acc*100:.2f}%)")
    print(f"Individual Fold Accuracies: {[f'{acc:.4f}' for acc in fold_results]}")

    return fold_results

# Alternative: Audio-only model for comparison
class AudioOnlyModel(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')

        # Freeze most parameters
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

        # Only train the last layer
        for param in self.wav2vec2.encoder.layers[-1].parameters():
            param.requires_grad = True

        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, audio_values):
        outputs = self.wav2vec2(audio_values)
        # Global average pooling
        features = outputs.last_hidden_state.mean(dim=1)
        return self.classifier(features)

# Text-only model for comparison
class TextOnlyModel(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze embeddings
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False

        self.classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(outputs.pooler_output)

def compare_modalities():
    """Compare performance of different modality combinations"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load data
    dataset_path = '/content/drive/MyDrive/Speech/processed_datasets'
    diagnosis_train = pd.read_pickle(os.path.join(dataset_path, 'diagnosis_train.pkl'))

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    wav2vec_processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')

    train_df, val_df = train_test_split(
        diagnosis_train, test_size=0.2, stratify=diagnosis_train['label_encoded'], random_state=42
    )

    results = {}

    # Test multimodal model
    print("Testing Multi-Modal Model...")
    train_dataset = SimplifiedMultiModalDataset(train_df, tokenizer, wav2vec_processor)
    val_dataset = SimplifiedMultiModalDataset(val_df, tokenizer, wav2vec_processor)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    multimodal_model = PreTrainedMultiModalModel(num_classes=2)
    trained_multimodal = train_model(multimodal_model, train_loader, val_loader, epochs=25, device=device)
    multimodal_results = evaluate_model(trained_multimodal, val_loader, device=device)
    results['multimodal'] = multimodal_results['accuracy']

    print(f"Multi-Modal Accuracy: {multimodal_results['accuracy']:.4f} ({multimodal_results['accuracy']*100:.2f}%)")

    return results

if __name__ == "__main__":
    # Run the main training
    print("Running main diagnosis task...")
    model, results = run_diagnosis_task()

    # Optionally run cross-validation for robust evaluation
    print("\nRunning cross-validation...")
    cv_results = run_cross_validation()

    print(f"\nTarget: >90% accuracy")
    print(f"Achieved: {results['accuracy']*100:.2f}%")

Running main diagnosis task...
Using device: cuda
Loaded 166 samples
Label distribution:
label
ad    87
cn    79
Name: count, dtype: int64
Train samples: 132, Val samples: 34


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training...




Epoch 1/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 1/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.6894, Acc=50.00%][A[A

Epoch 1/50:   3%|▎         | 1/33 [00:01<00:47,  1.50s/it, Loss=0.6894, Acc=50.00%][A[A

Epoch 1/50:   3%|▎         | 1/33 [00:02<00:47,  1.50s/it, Loss=0.7412, Acc=25.00%][A[A

Epoch 1/50:   6%|▌         | 2/33 [00:02<00:42,  1.37s/it, Loss=0.7412, Acc=25.00%][A[A

Epoch 1/50:   6%|▌         | 2/33 [00:04<00:42,  1.37s/it, Loss=0.7115, Acc=25.00%][A[A

Epoch 1/50:   9%|▉         | 3/33 [00:04<00:46,  1.55s/it, Loss=0.7115, Acc=25.00%][A[A

Epoch 1/50:   9%|▉         | 3/33 [00:05<00:46,  1.55s/it, Loss=0.6786, Acc=37.50%][A[A

Epoch 1/50:  12%|█▏        | 4/33 [00:05<00:42,  1.48s/it, Loss=0.6786, Acc=37.50%][A[A

Epoch 1/50:  12%|█▏        | 4/33 [00:07<00:42,  1.48s/it, Loss=0.6714, Acc=45.00%][A[A

Epoch 1/50:  15%|█▌        | 5/33 [00:07<00:44,  1.60s/it, Loss=0.6714, Acc=45.00%][A[A

Epoch 1/50:  15%|█▌        | 5/33 [00:

Epoch 1: Train Acc: 46.21%, Val Acc: 64.71%




Epoch 2/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 2/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.6615, Acc=100.00%][A[A

Epoch 2/50:   3%|▎         | 1/33 [00:01<00:50,  1.58s/it, Loss=0.6615, Acc=100.00%][A[A

Epoch 2/50:   3%|▎         | 1/33 [00:03<00:50,  1.58s/it, Loss=0.6826, Acc=87.50%] [A[A

Epoch 2/50:   6%|▌         | 2/33 [00:03<00:47,  1.52s/it, Loss=0.6826, Acc=87.50%][A[A

Epoch 2/50:   6%|▌         | 2/33 [00:04<00:47,  1.52s/it, Loss=0.6487, Acc=83.33%][A[A

Epoch 2/50:   9%|▉         | 3/33 [00:04<00:47,  1.58s/it, Loss=0.6487, Acc=83.33%][A[A

Epoch 2/50:   9%|▉         | 3/33 [00:06<00:47,  1.58s/it, Loss=0.6678, Acc=87.50%][A[A

Epoch 2/50:  12%|█▏        | 4/33 [00:06<00:46,  1.61s/it, Loss=0.6678, Acc=87.50%][A[A

Epoch 2/50:  12%|█▏        | 4/33 [00:08<00:46,  1.61s/it, Loss=0.6675, Acc=80.00%][A[A

Epoch 2/50:  15%|█▌        | 5/33 [00:08<00:47,  1.69s/it, Loss=0.6675, Acc=80.00%][A[A

Epoch 2/50:  15%|█▌        | 5/33 [

Epoch 2: Train Acc: 78.03%, Val Acc: 61.76%




Epoch 3/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 3/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.4168, Acc=100.00%][A[A

Epoch 3/50:   3%|▎         | 1/33 [00:01<00:53,  1.67s/it, Loss=0.4168, Acc=100.00%][A[A

Epoch 3/50:   3%|▎         | 1/33 [00:03<00:53,  1.67s/it, Loss=0.3436, Acc=100.00%][A[A

Epoch 3/50:   6%|▌         | 2/33 [00:03<00:46,  1.49s/it, Loss=0.3436, Acc=100.00%][A[A

Epoch 3/50:   6%|▌         | 2/33 [00:04<00:46,  1.49s/it, Loss=0.3505, Acc=100.00%][A[A

Epoch 3/50:   9%|▉         | 3/33 [00:04<00:42,  1.42s/it, Loss=0.3505, Acc=100.00%][A[A

Epoch 3/50:   9%|▉         | 3/33 [00:05<00:42,  1.42s/it, Loss=0.3225, Acc=100.00%][A[A

Epoch 3/50:  12%|█▏        | 4/33 [00:05<00:40,  1.39s/it, Loss=0.3225, Acc=100.00%][A[A

Epoch 3/50:  12%|█▏        | 4/33 [00:07<00:40,  1.39s/it, Loss=0.3830, Acc=100.00%][A[A

Epoch 3/50:  15%|█▌        | 5/33 [00:07<00:42,  1.53s/it, Loss=0.3830, Acc=100.00%][A[A

Epoch 3/50:  15%|█▌        |

Epoch 3: Train Acc: 90.91%, Val Acc: 70.59%




Epoch 4/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 4/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.9650, Acc=50.00%][A[A

Epoch 4/50:   3%|▎         | 1/33 [00:01<00:48,  1.50s/it, Loss=0.9650, Acc=50.00%][A[A

Epoch 4/50:   3%|▎         | 1/33 [00:02<00:48,  1.50s/it, Loss=0.0743, Acc=75.00%][A[A

Epoch 4/50:   6%|▌         | 2/33 [00:02<00:45,  1.48s/it, Loss=0.0743, Acc=75.00%][A[A

Epoch 4/50:   6%|▌         | 2/33 [00:04<00:45,  1.48s/it, Loss=0.1074, Acc=83.33%][A[A

Epoch 4/50:   9%|▉         | 3/33 [00:04<00:41,  1.37s/it, Loss=0.1074, Acc=83.33%][A[A

Epoch 4/50:   9%|▉         | 3/33 [00:05<00:41,  1.37s/it, Loss=0.0495, Acc=87.50%][A[A

Epoch 4/50:  12%|█▏        | 4/33 [00:05<00:39,  1.37s/it, Loss=0.0495, Acc=87.50%][A[A

Epoch 4/50:  12%|█▏        | 4/33 [00:07<00:39,  1.37s/it, Loss=0.0285, Acc=90.00%][A[A

Epoch 4/50:  15%|█▌        | 5/33 [00:07<00:41,  1.47s/it, Loss=0.0285, Acc=90.00%][A[A

Epoch 4/50:  15%|█▌        | 5/33 [00:

Epoch 4: Train Acc: 93.18%, Val Acc: 100.00%




Epoch 5/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 5/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0028, Acc=100.00%][A[A

Epoch 5/50:   3%|▎         | 1/33 [00:01<00:56,  1.76s/it, Loss=0.0028, Acc=100.00%][A[A

Epoch 5/50:   3%|▎         | 1/33 [00:03<00:56,  1.76s/it, Loss=0.4585, Acc=87.50%] [A[A

Epoch 5/50:   6%|▌         | 2/33 [00:03<00:47,  1.53s/it, Loss=0.4585, Acc=87.50%][A[A

Epoch 5/50:   6%|▌         | 2/33 [00:04<00:47,  1.53s/it, Loss=0.0025, Acc=91.67%][A[A

Epoch 5/50:   9%|▉         | 3/33 [00:04<00:43,  1.44s/it, Loss=0.0025, Acc=91.67%][A[A

Epoch 5/50:   9%|▉         | 3/33 [00:05<00:43,  1.44s/it, Loss=0.0036, Acc=93.75%][A[A

Epoch 5/50:  12%|█▏        | 4/33 [00:05<00:38,  1.32s/it, Loss=0.0036, Acc=93.75%][A[A

Epoch 5/50:  12%|█▏        | 4/33 [00:06<00:38,  1.32s/it, Loss=0.0004, Acc=95.00%][A[A

Epoch 5/50:  15%|█▌        | 5/33 [00:06<00:35,  1.25s/it, Loss=0.0004, Acc=95.00%][A[A

Epoch 5/50:  15%|█▌        | 5/33 [

Epoch 5: Train Acc: 99.24%, Val Acc: 100.00%




Epoch 6/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 6/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0016, Acc=100.00%][A[A

Epoch 6/50:   3%|▎         | 1/33 [00:01<00:42,  1.33s/it, Loss=0.0016, Acc=100.00%][A[A

Epoch 6/50:   3%|▎         | 1/33 [00:02<00:42,  1.33s/it, Loss=0.0008, Acc=100.00%][A[A

Epoch 6/50:   6%|▌         | 2/33 [00:02<00:39,  1.29s/it, Loss=0.0008, Acc=100.00%][A[A

Epoch 6/50:   6%|▌         | 2/33 [00:04<00:39,  1.29s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/50:   9%|▉         | 3/33 [00:04<00:44,  1.47s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/50:   9%|▉         | 3/33 [00:05<00:44,  1.47s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/50:  12%|█▏        | 4/33 [00:05<00:40,  1.39s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/50:  12%|█▏        | 4/33 [00:06<00:40,  1.39s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/50:  15%|█▌        | 5/33 [00:06<00:38,  1.38s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/50:  15%|█▌        |

Epoch 6: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 7/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 7/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0008, Acc=100.00%][A[A

Epoch 7/50:   3%|▎         | 1/33 [00:01<00:41,  1.31s/it, Loss=0.0008, Acc=100.00%][A[A

Epoch 7/50:   3%|▎         | 1/33 [00:02<00:41,  1.31s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/50:   6%|▌         | 2/33 [00:02<00:37,  1.21s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/50:   6%|▌         | 2/33 [00:03<00:37,  1.21s/it, Loss=0.0012, Acc=100.00%][A[A

Epoch 7/50:   9%|▉         | 3/33 [00:03<00:37,  1.25s/it, Loss=0.0012, Acc=100.00%][A[A

Epoch 7/50:   9%|▉         | 3/33 [00:05<00:37,  1.25s/it, Loss=0.0007, Acc=100.00%][A[A

Epoch 7/50:  12%|█▏        | 4/33 [00:05<00:39,  1.36s/it, Loss=0.0007, Acc=100.00%][A[A

Epoch 7/50:  12%|█▏        | 4/33 [00:06<00:39,  1.36s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 7/50:  15%|█▌        | 5/33 [00:06<00:38,  1.39s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 7/50:  15%|█▌        |

Epoch 7: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 8/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 8/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0001, Acc=100.00%][A[A

Epoch 8/50:   3%|▎         | 1/33 [00:01<00:44,  1.38s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 8/50:   3%|▎         | 1/33 [00:02<00:44,  1.38s/it, Loss=0.0007, Acc=100.00%][A[A

Epoch 8/50:   6%|▌         | 2/33 [00:02<00:46,  1.49s/it, Loss=0.0007, Acc=100.00%][A[A

Epoch 8/50:   6%|▌         | 2/33 [00:04<00:46,  1.49s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/50:   9%|▉         | 3/33 [00:04<00:47,  1.58s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/50:   9%|▉         | 3/33 [00:06<00:47,  1.58s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 8/50:  12%|█▏        | 4/33 [00:06<00:46,  1.60s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 8/50:  12%|█▏        | 4/33 [00:07<00:46,  1.60s/it, Loss=0.0022, Acc=100.00%][A[A

Epoch 8/50:  15%|█▌        | 5/33 [00:07<00:41,  1.48s/it, Loss=0.0022, Acc=100.00%][A[A

Epoch 8/50:  15%|█▌        |

Epoch 8: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 9/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 9/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0001, Acc=100.00%][A[A

Epoch 9/50:   3%|▎         | 1/33 [00:01<00:42,  1.33s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 9/50:   3%|▎         | 1/33 [00:02<00:42,  1.33s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:   6%|▌         | 2/33 [00:02<00:40,  1.31s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:   6%|▌         | 2/33 [00:03<00:40,  1.31s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:   9%|▉         | 3/33 [00:03<00:40,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:   9%|▉         | 3/33 [00:05<00:40,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:  12%|█▏        | 4/33 [00:05<00:41,  1.45s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:  12%|█▏        | 4/33 [00:06<00:41,  1.45s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:  15%|█▌        | 5/33 [00:06<00:39,  1.42s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/50:  15%|█▌        |

Epoch 9: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 10/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 10/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:   3%|▎         | 1/33 [00:01<00:38,  1.20s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:   3%|▎         | 1/33 [00:02<00:38,  1.20s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:   6%|▌         | 2/33 [00:02<00:41,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:   6%|▌         | 2/33 [00:03<00:41,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:   9%|▉         | 3/33 [00:03<00:39,  1.31s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:   9%|▉         | 3/33 [00:05<00:39,  1.31s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:  12%|█▏        | 4/33 [00:05<00:36,  1.25s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/50:  12%|█▏        | 4/33 [00:06<00:36,  1.25s/it, Loss=0.0002, Acc=100.00%][A[A

Epoch 10/50:  15%|█▌        | 5/33 [00:06<00:39,  1.41s/it, Loss=0.0002, Acc=100.00%][A[A

Epoch 10/50:  15%

Epoch 10: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 11/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 11/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/50:   3%|▎         | 1/33 [00:01<00:40,  1.26s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/50:   3%|▎         | 1/33 [00:02<00:40,  1.26s/it, Loss=0.0002, Acc=100.00%][A[A

Epoch 11/50:   6%|▌         | 2/33 [00:02<00:37,  1.22s/it, Loss=0.0002, Acc=100.00%][A[A

Epoch 11/50:   6%|▌         | 2/33 [00:03<00:37,  1.22s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/50:   9%|▉         | 3/33 [00:03<00:39,  1.32s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/50:   9%|▉         | 3/33 [00:05<00:39,  1.32s/it, Loss=0.0026, Acc=100.00%][A[A

Epoch 11/50:  12%|█▏        | 4/33 [00:05<00:40,  1.40s/it, Loss=0.0026, Acc=100.00%][A[A

Epoch 11/50:  12%|█▏        | 4/33 [00:06<00:40,  1.40s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/50:  15%|█▌        | 5/33 [00:06<00:37,  1.35s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/50:  15%

Epoch 11: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 12/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 12/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:   3%|▎         | 1/33 [00:01<00:47,  1.49s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:   3%|▎         | 1/33 [00:03<00:47,  1.49s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:   6%|▌         | 2/33 [00:03<00:47,  1.53s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:   6%|▌         | 2/33 [00:04<00:47,  1.53s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:   9%|▉         | 3/33 [00:04<00:43,  1.47s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:   9%|▉         | 3/33 [00:05<00:43,  1.47s/it, Loss=0.0006, Acc=100.00%][A[A

Epoch 12/50:  12%|█▏        | 4/33 [00:05<00:42,  1.45s/it, Loss=0.0006, Acc=100.00%][A[A

Epoch 12/50:  12%|█▏        | 4/33 [00:07<00:42,  1.45s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:  15%|█▌        | 5/33 [00:07<00:40,  1.43s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/50:  15%

Epoch 12: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 13/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 13/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:   3%|▎         | 1/33 [00:01<00:46,  1.46s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:   3%|▎         | 1/33 [00:02<00:46,  1.46s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:   6%|▌         | 2/33 [00:02<00:44,  1.43s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:   6%|▌         | 2/33 [00:04<00:44,  1.43s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:   9%|▉         | 3/33 [00:04<00:46,  1.54s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:   9%|▉         | 3/33 [00:05<00:46,  1.54s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:  12%|█▏        | 4/33 [00:05<00:41,  1.42s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:  12%|█▏        | 4/33 [00:07<00:41,  1.42s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:  15%|█▌        | 5/33 [00:07<00:38,  1.37s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/50:  15%

Epoch 13: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 14/50:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 14/50:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:   3%|▎         | 1/33 [00:01<00:53,  1.66s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:   3%|▎         | 1/33 [00:02<00:53,  1.66s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:   6%|▌         | 2/33 [00:02<00:43,  1.40s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:   6%|▌         | 2/33 [00:04<00:43,  1.40s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:   9%|▉         | 3/33 [00:04<00:39,  1.33s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:   9%|▉         | 3/33 [00:05<00:39,  1.33s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:  12%|█▏        | 4/33 [00:05<00:41,  1.43s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:  12%|█▏        | 4/33 [00:07<00:41,  1.43s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:  15%|█▌        | 5/33 [00:07<00:40,  1.45s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 14/50:  15%

Epoch 14: Train Acc: 100.00%, Val Acc: 100.00%
Early stopping at epoch 14. Best validation accuracy: 100.00%




Evaluating:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Evaluating:  11%|█         | 1/9 [00:01<00:09,  1.25s/it][A[A

Evaluating:  22%|██▏       | 2/9 [00:02<00:07,  1.06s/it][A[A

Evaluating:  33%|███▎      | 3/9 [00:03<00:06,  1.12s/it][A[A

Evaluating:  44%|████▍     | 4/9 [00:04<00:05,  1.11s/it][A[A

Evaluating:  56%|█████▌    | 5/9 [00:05<00:04,  1.12s/it][A[A

Evaluating:  67%|██████▋   | 6/9 [00:06<00:03,  1.11s/it][A[A

Evaluating:  78%|███████▊  | 7/9 [00:07<00:02,  1.07s/it][A[A

Evaluating:  89%|████████▉ | 8/9 [00:08<00:01,  1.11s/it][A[A

Evaluating: 100%|██████████| 9/9 [00:09<00:00,  1.05s/it]



Final Results:
Accuracy: 1.0000 (100.00%)

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        18
           1       1.00      1.00      1.00        16

    accuracy                           1.00        34
   macro avg       1.00      1.00      1.00        34
weighted avg       1.00      1.00      1.00        34


Running cross-validation...

Fold 1/5


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 1/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.6968, Acc=50.00%][A[A

Epoch 1/30:   3%|▎         | 1/33 [00:01<00:44,  1.39s/it, Loss=0.6968, Acc=50.00%][A[A

Epoch 1/30:   3%|▎         | 1/33 [00:02<00:44,  1.39s/it, Loss=0.7035, Acc=25.00%][A[A

Epoch 1/30:   6%|▌         | 2/33 [00:02<00:40,  1.30s/it, Loss=0.7035, Acc=25.00%][A[A

Epoch 1/30:   6%|▌         | 2/33 [00:03<00:40,  1.30s/it, Loss=0.6908, Acc=33.33%][A[A

Epoch 1/30:   9%|▉         | 3/33 [00:03<00:39,  1.31s/it, Loss=0.6908, Acc=33.33%][A[A

Epoch 1/30:   9%|▉         | 3/33 [00:05<00:39,  1.31s/it, Loss=0.6832, Acc=43.75%][A[A

Epoch 1/30:  12%|█▏        | 4/33 [00:05<00

Epoch 1: Train Acc: 52.27%, Val Acc: 64.71%




Epoch 2/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 2/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.6523, Acc=75.00%][A[A

Epoch 2/30:   3%|▎         | 1/33 [00:01<00:47,  1.47s/it, Loss=0.6523, Acc=75.00%][A[A

Epoch 2/30:   3%|▎         | 1/33 [00:02<00:47,  1.47s/it, Loss=0.6129, Acc=87.50%][A[A

Epoch 2/30:   6%|▌         | 2/33 [00:02<00:43,  1.41s/it, Loss=0.6129, Acc=87.50%][A[A

Epoch 2/30:   6%|▌         | 2/33 [00:04<00:43,  1.41s/it, Loss=0.7152, Acc=66.67%][A[A

Epoch 2/30:   9%|▉         | 3/33 [00:04<00:41,  1.39s/it, Loss=0.7152, Acc=66.67%][A[A

Epoch 2/30:   9%|▉         | 3/33 [00:05<00:41,  1.39s/it, Loss=0.6407, Acc=68.75%][A[A

Epoch 2/30:  12%|█▏        | 4/33 [00:05<00:39,  1.38s/it, Loss=0.6407, Acc=68.75%][A[A

Epoch 2/30:  12%|█▏        | 4/33 [00:06<00:39,  1.38s/it, Loss=0.6997, Acc=65.00%][A[A

Epoch 2/30:  15%|█▌        | 5/33 [00:06<00:38,  1.39s/it, Loss=0.6997, Acc=65.00%][A[A

Epoch 2/30:  15%|█▌        | 5/33 [00:

Epoch 2: Train Acc: 69.70%, Val Acc: 79.41%




Epoch 3/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 3/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.6364, Acc=50.00%][A[A

Epoch 3/30:   3%|▎         | 1/33 [00:01<00:51,  1.60s/it, Loss=0.6364, Acc=50.00%][A[A

Epoch 3/30:   3%|▎         | 1/33 [00:03<00:51,  1.60s/it, Loss=0.5953, Acc=62.50%][A[A

Epoch 3/30:   6%|▌         | 2/33 [00:03<00:47,  1.54s/it, Loss=0.5953, Acc=62.50%][A[A

Epoch 3/30:   6%|▌         | 2/33 [00:04<00:47,  1.54s/it, Loss=0.4220, Acc=75.00%][A[A

Epoch 3/30:   9%|▉         | 3/33 [00:04<00:45,  1.52s/it, Loss=0.4220, Acc=75.00%][A[A

Epoch 3/30:   9%|▉         | 3/33 [00:05<00:45,  1.52s/it, Loss=0.5436, Acc=75.00%][A[A

Epoch 3/30:  12%|█▏        | 4/33 [00:05<00:40,  1.40s/it, Loss=0.5436, Acc=75.00%][A[A

Epoch 3/30:  12%|█▏        | 4/33 [00:07<00:40,  1.40s/it, Loss=0.5172, Acc=75.00%][A[A

Epoch 3/30:  15%|█▌        | 5/33 [00:07<00:37,  1.35s/it, Loss=0.5172, Acc=75.00%][A[A

Epoch 3/30:  15%|█▌        | 5/33 [00:

Epoch 3: Train Acc: 85.61%, Val Acc: 100.00%




Epoch 4/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 4/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0264, Acc=100.00%][A[A

Epoch 4/30:   3%|▎         | 1/33 [00:01<00:45,  1.41s/it, Loss=0.0264, Acc=100.00%][A[A

Epoch 4/30:   3%|▎         | 1/33 [00:02<00:45,  1.41s/it, Loss=0.0036, Acc=100.00%][A[A

Epoch 4/30:   6%|▌         | 2/33 [00:02<00:42,  1.36s/it, Loss=0.0036, Acc=100.00%][A[A

Epoch 4/30:   6%|▌         | 2/33 [00:04<00:42,  1.36s/it, Loss=0.0772, Acc=100.00%][A[A

Epoch 4/30:   9%|▉         | 3/33 [00:04<00:41,  1.38s/it, Loss=0.0772, Acc=100.00%][A[A

Epoch 4/30:   9%|▉         | 3/33 [00:05<00:41,  1.38s/it, Loss=0.0336, Acc=100.00%][A[A

Epoch 4/30:  12%|█▏        | 4/33 [00:05<00:41,  1.44s/it, Loss=0.0336, Acc=100.00%][A[A

Epoch 4/30:  12%|█▏        | 4/33 [00:07<00:41,  1.44s/it, Loss=0.0059, Acc=100.00%][A[A

Epoch 4/30:  15%|█▌        | 5/33 [00:07<00:40,  1.46s/it, Loss=0.0059, Acc=100.00%][A[A

Epoch 4/30:  15%|█▌        |

Epoch 4: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 5/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 5/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0032, Acc=100.00%][A[A

Epoch 5/30:   3%|▎         | 1/33 [00:01<00:50,  1.58s/it, Loss=0.0032, Acc=100.00%][A[A

Epoch 5/30:   3%|▎         | 1/33 [00:03<00:50,  1.58s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 5/30:   6%|▌         | 2/33 [00:03<00:48,  1.55s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 5/30:   6%|▌         | 2/33 [00:04<00:48,  1.55s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 5/30:   9%|▉         | 3/33 [00:04<00:42,  1.42s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 5/30:   9%|▉         | 3/33 [00:05<00:42,  1.42s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 5/30:  12%|█▏        | 4/33 [00:05<00:39,  1.36s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 5/30:  12%|█▏        | 4/33 [00:06<00:39,  1.36s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 5/30:  15%|█▌        | 5/33 [00:06<00:37,  1.35s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 5/30:  15%|█▌        |

Epoch 5: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 6/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 6/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0003, Acc=100.00%][A[A

Epoch 6/30:   3%|▎         | 1/33 [00:01<00:39,  1.24s/it, Loss=0.0003, Acc=100.00%][A[A

Epoch 6/30:   3%|▎         | 1/33 [00:02<00:39,  1.24s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 6/30:   6%|▌         | 2/33 [00:02<00:40,  1.31s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 6/30:   6%|▌         | 2/33 [00:03<00:40,  1.31s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 6/30:   9%|▉         | 3/33 [00:03<00:38,  1.29s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 6/30:   9%|▉         | 3/33 [00:05<00:38,  1.29s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/30:  12%|█▏        | 4/33 [00:05<00:38,  1.32s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 6/30:  12%|█▏        | 4/33 [00:06<00:38,  1.32s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 6/30:  15%|█▌        | 5/33 [00:06<00:36,  1.32s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 6/30:  15%|█▌        |

Epoch 6: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 7/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 7/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/30:   3%|▎         | 1/33 [00:01<00:44,  1.39s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/30:   3%|▎         | 1/33 [00:02<00:44,  1.39s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/30:   6%|▌         | 2/33 [00:02<00:40,  1.31s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/30:   6%|▌         | 2/33 [00:04<00:40,  1.31s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/30:   9%|▉         | 3/33 [00:04<00:43,  1.44s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 7/30:   9%|▉         | 3/33 [00:05<00:43,  1.44s/it, Loss=0.0003, Acc=100.00%][A[A

Epoch 7/30:  12%|█▏        | 4/33 [00:05<00:41,  1.44s/it, Loss=0.0003, Acc=100.00%][A[A

Epoch 7/30:  12%|█▏        | 4/33 [00:07<00:41,  1.44s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 7/30:  15%|█▌        | 5/33 [00:07<00:43,  1.55s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 7/30:  15%|█▌        |

Epoch 7: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 8/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 8/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0002, Acc=100.00%][A[A

Epoch 8/30:   3%|▎         | 1/33 [00:01<00:50,  1.58s/it, Loss=0.0002, Acc=100.00%][A[A

Epoch 8/30:   3%|▎         | 1/33 [00:03<00:50,  1.58s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:   6%|▌         | 2/33 [00:03<00:46,  1.50s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:   6%|▌         | 2/33 [00:04<00:46,  1.50s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:   9%|▉         | 3/33 [00:04<00:42,  1.43s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:   9%|▉         | 3/33 [00:05<00:42,  1.43s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:  12%|█▏        | 4/33 [00:05<00:40,  1.41s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:  12%|█▏        | 4/33 [00:07<00:40,  1.41s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:  15%|█▌        | 5/33 [00:07<00:38,  1.38s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 8/30:  15%|█▌        |

Epoch 8: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 9/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 9/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/30:   3%|▎         | 1/33 [00:01<00:39,  1.23s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/30:   3%|▎         | 1/33 [00:02<00:39,  1.23s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/30:   6%|▌         | 2/33 [00:02<00:45,  1.46s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/30:   6%|▌         | 2/33 [00:04<00:45,  1.46s/it, Loss=0.0002, Acc=100.00%][A[A

Epoch 9/30:   9%|▉         | 3/33 [00:04<00:43,  1.44s/it, Loss=0.0002, Acc=100.00%][A[A

Epoch 9/30:   9%|▉         | 3/33 [00:05<00:43,  1.44s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/30:  12%|█▏        | 4/33 [00:05<00:40,  1.40s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 9/30:  12%|█▏        | 4/33 [00:06<00:40,  1.40s/it, Loss=0.0003, Acc=100.00%][A[A

Epoch 9/30:  15%|█▌        | 5/33 [00:06<00:38,  1.37s/it, Loss=0.0003, Acc=100.00%][A[A

Epoch 9/30:  15%|█▌        |

Epoch 9: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 10/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 10/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:   3%|▎         | 1/33 [00:01<00:39,  1.25s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:   3%|▎         | 1/33 [00:02<00:39,  1.25s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:   6%|▌         | 2/33 [00:02<00:40,  1.30s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:   6%|▌         | 2/33 [00:03<00:40,  1.30s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:   9%|▉         | 3/33 [00:03<00:40,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:   9%|▉         | 3/33 [00:05<00:40,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:  12%|█▏        | 4/33 [00:05<00:42,  1.45s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:  12%|█▏        | 4/33 [00:07<00:42,  1.45s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:  15%|█▌        | 5/33 [00:07<00:40,  1.46s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 10/30:  15%

Epoch 10: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 11/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 11/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:   3%|▎         | 1/33 [00:01<00:42,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:   3%|▎         | 1/33 [00:02<00:42,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:   6%|▌         | 2/33 [00:02<00:46,  1.49s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:   6%|▌         | 2/33 [00:04<00:46,  1.49s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:   9%|▉         | 3/33 [00:04<00:43,  1.44s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:   9%|▉         | 3/33 [00:05<00:43,  1.44s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:  12%|█▏        | 4/33 [00:05<00:40,  1.40s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:  12%|█▏        | 4/33 [00:06<00:40,  1.40s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:  15%|█▌        | 5/33 [00:06<00:37,  1.34s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 11/30:  15%

Epoch 11: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 12/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 12/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:   3%|▎         | 1/33 [00:01<00:40,  1.26s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:   3%|▎         | 1/33 [00:02<00:40,  1.26s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:   6%|▌         | 2/33 [00:02<00:43,  1.41s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:   6%|▌         | 2/33 [00:04<00:43,  1.41s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:   9%|▉         | 3/33 [00:04<00:41,  1.39s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:   9%|▉         | 3/33 [00:05<00:41,  1.39s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:  12%|█▏        | 4/33 [00:05<00:39,  1.38s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 12/30:  12%|█▏        | 4/33 [00:06<00:39,  1.38s/it, Loss=0.0008, Acc=100.00%][A[A

Epoch 12/30:  15%|█▌        | 5/33 [00:06<00:37,  1.36s/it, Loss=0.0008, Acc=100.00%][A[A

Epoch 12/30:  15%

Epoch 12: Train Acc: 100.00%, Val Acc: 100.00%




Epoch 13/30:   0%|          | 0/33 [00:00<?, ?it/s][A[A

Epoch 13/30:   0%|          | 0/33 [00:01<?, ?it/s, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/30:   3%|▎         | 1/33 [00:01<00:40,  1.26s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/30:   3%|▎         | 1/33 [00:02<00:40,  1.26s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/30:   6%|▌         | 2/33 [00:02<00:37,  1.21s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/30:   6%|▌         | 2/33 [00:03<00:37,  1.21s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 13/30:   9%|▉         | 3/33 [00:03<00:37,  1.26s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 13/30:   9%|▉         | 3/33 [00:05<00:37,  1.26s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/30:  12%|█▏        | 4/33 [00:05<00:36,  1.26s/it, Loss=0.0000, Acc=100.00%][A[A

Epoch 13/30:  12%|█▏        | 4/33 [00:06<00:36,  1.26s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 13/30:  15%|█▌        | 5/33 [00:06<00:36,  1.31s/it, Loss=0.0001, Acc=100.00%][A[A

Epoch 13/30:  15%

Epoch 13: Train Acc: 100.00%, Val Acc: 100.00%
Early stopping at epoch 13. Best validation accuracy: 100.00%




Evaluating:   0%|          | 0/9 [00:00<?, ?it/s][A[A

Evaluating:  11%|█         | 1/9 [00:01<00:08,  1.00s/it][A[A

Evaluating:  22%|██▏       | 2/9 [00:01<00:06,  1.02it/s][A[A

Evaluating:  33%|███▎      | 3/9 [00:02<00:05,  1.01it/s][A[A

Evaluating:  44%|████▍     | 4/9 [00:04<00:05,  1.14s/it][A[A

Evaluating:  56%|█████▌    | 5/9 [00:05<00:04,  1.09s/it][A[A

Evaluating:  67%|██████▋   | 6/9 [00:06<00:03,  1.06s/it][A[A

Evaluating:  78%|███████▊  | 7/9 [00:07<00:02,  1.06s/it][A[A

Evaluating:  89%|████████▉ | 8/9 [00:09<00:01,  1.23s/it][A[A

Evaluating: 100%|██████████| 9/9 [00:09<00:00,  1.08s/it]


Fold 1 Accuracy: 1.0000 (100.00%)

Fold 2/5


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/30:   0%|          | 0/34 [00:00<?, ?it/s][A[A

Epoch 1/30:   0%|          | 0/34 [00:01<?, ?it/s, Loss=0.7555, Acc=0.00%][A[A

Epoch 1/30:   3%|▎         | 1/34 [00:01<00:42,  1.30s/it, Loss=0.7555, Acc=0.00%][A[A

Epoch 1/30:   3%|▎         | 1/34 [00:02<00:42,  1.30s/it, Loss=0.7246, Acc=0.00%][A[A

Epoch 1/30:   6%|▌         | 2/34 [00:02<00:43,  1.37s/it, Loss=0.7246, Acc=0.00%][A[A

Epoch 1/30:   6%|▌         | 2/34 [00:04<00:43,  1.37s/it, Loss=0.6832, Acc=16.67%][A[A

Epoch 1/30:   9%|▉         | 3/34 [00:04<00:41,  1.33s/it, Loss=0.6832, Acc=16.67%][A[A

Epoch 1/30:   9%|▉         | 3/34 [00:05<00:41,  1.33s/it, Loss=0.6849, Acc=31.25%][A[A

Epoch 1/30:  12%|█▏        | 4/34 [00:05<00:38,

ValueError: Expected input batch_size (1) to match target batch_size (0).

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
import librosa
import pickle
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

class SimpleAudioDataset(Dataset):
    def __init__(self, df, processor, max_length=16000*10):  # 10 seconds max
        self.df = df.reset_index(drop=True)
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Load audio
        try:
            audio_path = row['file_path']
            if os.path.exists(audio_path):
                audio, sr = librosa.load(audio_path, sr=16000)

                # Trim or pad to consistent length
                if len(audio) > self.max_length:
                    audio = audio[:self.max_length]
                else:
                    audio = np.pad(audio, (0, self.max_length - len(audio)), 'constant')
            else:
                audio = np.zeros(self.max_length)
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            audio = np.zeros(self.max_length)

        return {
            'audio': torch.tensor(audio, dtype=torch.float32),
            'label': torch.tensor(row['label_encoded'], dtype=torch.long)
        }

def simple_collate_fn(batch):
    """Simple collate function for audio data"""
    audios = torch.stack([item['audio'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])

    return {
        'audio': audios,
        'label': labels
    }

class SimpleWav2VecClassifier(nn.Module):
    def __init__(self, num_classes=2, dropout=0.5):
        super().__init__()

        # Use pre-trained Wav2Vec2 directly for classification
        self.wav2vec2 = Wav2Vec2ForSequenceClassification.from_pretrained(
            'facebook/wav2vec2-base-960h',
            num_labels=768,  # Keep original output
            ignore_mismatched_sizes=True
        )

        # Freeze feature extractor
        for param in self.wav2vec2.wav2vec2.feature_extractor.parameters():
            param.requires_grad = False

        # Custom classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )

    def forward(self, audio):
        # Get Wav2Vec2 features
        with torch.cuda.amp.autocast():  # Mixed precision for memory efficiency
            outputs = self.wav2vec2.wav2vec2(audio)
            features = outputs.last_hidden_state

            # Global average pooling across time dimension
            pooled_features = features.mean(dim=1)

        return self.classifier(pooled_features)

def train_simple_model(model, train_loader, val_loader, epochs=50, device='cuda'):
    model = model.to(device)

    # Optimizer with weight decay
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=1, eta_min=1e-6
    )

    # Mixed precision training
    scaler = torch.cuda.amp.GradScaler()

    best_val_acc = 0
    patience_counter = 0
    patience = 15

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for batch in pbar:
            optimizer.zero_grad()

            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)

            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                outputs = model(audio)
                loss = criterion(outputs, labels)

            # Mixed precision backward pass
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100 * train_correct / train_total:.1f}%'
            })

        train_acc = 100 * train_correct / train_total
        avg_train_loss = train_loss / len(train_loader)

        # Validation
        if val_loader:
            model.eval()
            val_loss = 0
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for batch in val_loader:
                    audio = batch['audio'].to(device)
                    labels = batch['label'].to(device)

                    with torch.cuda.amp.autocast():
                        outputs = model(audio)
                        loss = criterion(outputs, labels)

                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

            val_acc = 100 * val_correct / val_total
            avg_val_loss = val_loss / len(val_loader)

            scheduler.step()

            print(f'Epoch {epoch+1}: Train: {train_acc:.1f}% | Val: {val_acc:.1f}% | Loss: {avg_val_loss:.4f}')

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                torch.save(model.state_dict(), '/content/drive/MyDrive/Speech/best_simple_model.pth')
                print(f'New best validation accuracy: {val_acc:.2f}%')
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f'Early stopping. Best Val Acc: {best_val_acc:.2f}%')
                break
        else:
            scheduler.step()
            print(f'Epoch {epoch+1}: Train Acc: {train_acc:.1f}%')

    # Load best model
    if val_loader and os.path.exists('/content/drive/MyDrive/Speech/best_simple_model.pth'):
        model.load_state_dict(torch.load('/content/drive/MyDrive/Speech/best_simple_model.pth'))
        print(f'Loaded best model with {best_val_acc:.2f}% validation accuracy')

    return model

def evaluate_simple_model(model, test_loader, device='cuda'):
    model.eval()
    model = model.to(device)

    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)

            with torch.cuda.amp.autocast():
                outputs = model(audio)

            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)

    return {
        'accuracy': accuracy,
        'predictions': all_preds,
        'true_labels': all_labels,
        'probabilities': all_probs,
        'classification_report': classification_report(all_labels, all_preds),
        'confusion_matrix': confusion_matrix(all_labels, all_preds)
    }

def run_simple_diagnosis():
    """Streamlined approach focusing on audio-only for maximum accuracy"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load data
    dataset_path = '/content/drive/MyDrive/Speech/processed_datasets'
    diagnosis_train = pd.read_pickle(os.path.join(dataset_path, 'diagnosis_train.pkl'))

    print(f"Dataset: {len(diagnosis_train)} samples")
    print(diagnosis_train['label'].value_counts())

    # Initialize processor
    processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')

    # Multiple train-val splits for robust training
    best_accuracy = 0
    best_model = None

    for seed in [42, 123, 456, 789, 999]:
        print(f"\nTraining with seed {seed}...")

        # Stratified split
        train_df, val_df = train_test_split(
            diagnosis_train,
            test_size=0.25,  # Slightly larger validation set
            stratify=diagnosis_train['label_encoded'],
            random_state=seed
        )

        # Create datasets
        train_dataset = SimpleAudioDataset(train_df, processor)
        val_dataset = SimpleAudioDataset(val_df, processor)

        # Data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=8,  # Larger batch for stability
            shuffle=True,
            collate_fn=simple_collate_fn,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=8,
            shuffle=False,
            collate_fn=simple_collate_fn,
            pin_memory=True
        )

        # Create model
        model = SimpleWav2VecClassifier(num_classes=2, dropout=0.3)

        # Train
        trained_model = train_simple_model(
            model, train_loader, val_loader, epochs=40, device=device
        )

        # Evaluate
        results = evaluate_simple_model(trained_model, val_loader, device=device)

        print(f"Seed {seed} Results: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")

        if results['accuracy'] > best_accuracy:
            best_accuracy = results['accuracy']
            best_model = trained_model
            torch.save(best_model.state_dict(), '/content/drive/MyDrive/Speech/best_overall_model.pth')
            print(f"New best overall accuracy: {best_accuracy*100:.2f}%")

    print(f"\nBest achieved accuracy: {best_accuracy*100:.2f}%")

    return best_model, best_accuracy

# Even simpler approach: Fine-tune Wav2Vec2 end-to-end
class DirectWav2VecClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()

        # Load pre-trained model and modify classifier
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
            'facebook/wav2vec2-base-960h',
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )

        # Freeze feature extractor to prevent overfitting
        for param in self.model.wav2vec2.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, input_values):
        return self.model(input_values=input_values).logits

def train_direct_model(model, train_loader, val_loader, epochs=30, device='cuda'):
    model = model.to(device)

    # Lower learning rate for fine-tuning
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    # Cosine annealing scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_val_acc = 0

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
            optimizer.zero_grad()

            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)

            outputs = model(audio)
            loss = criterion(outputs, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        scheduler.step()

        train_acc = 100 * train_correct / train_total

        # Validation
        if val_loader:
            model.eval()
            val_correct = 0
            val_total = 0

            with torch.no_grad():
                for batch in val_loader:
                    audio = batch['audio'].to(device)
                    labels = batch['label'].to(device)

                    outputs = model(audio)
                    _, predicted = torch.max(outputs, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()

            val_acc = 100 * val_correct / val_total
            print(f'Epoch {epoch+1}: Train: {train_acc:.1f}% | Val: {val_acc:.1f}%')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), '/content/drive/MyDrive/Speech/direct_best_model.pth')
        else:
            print(f'Epoch {epoch+1}: Train: {train_acc:.1f}%')

    return model

def run_direct_approach():
    """Ultra-simple approach: Direct Wav2Vec2 fine-tuning"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load data
    dataset_path = '/content/drive/MyDrive/Speech/processed_datasets'
    diagnosis_train = pd.read_pickle(os.path.join(dataset_path, 'diagnosis_train.pkl'))

    processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')

    best_results = []

    # Try multiple splits
    for seed in [42, 123, 456]:
        print(f"\n--- Seed {seed} ---")

        train_df, val_df = train_test_split(
            diagnosis_train,
            test_size=0.2,
            stratify=diagnosis_train['label_encoded'],
            random_state=seed
        )

        # Datasets
        train_dataset = SimpleAudioDataset(train_df, processor)
        val_dataset = SimpleAudioDataset(val_df, processor)

        # Data loaders
        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=simple_collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=simple_collate_fn)

        # Model
        model = DirectWav2VecClassifier(num_classes=2)

        # Train
        trained_model = train_direct_model(model, train_loader, val_loader, epochs=25, device=device)

        # Final evaluation
        model.eval()
        val_correct = 0
        val_total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in val_loader:
                audio = batch['audio'].to(device)
                labels = batch['label'].to(device)

                outputs = trained_model(audio)
                _, predicted = torch.max(outputs, 1)

                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        final_acc = 100 * val_correct / val_total
        best_results.append(final_acc)

        print(f"Final accuracy for seed {seed}: {final_acc:.2f}%")
        print(classification_report(all_labels, all_preds, target_names=['CN', 'AD']))

    avg_accuracy = np.mean(best_results)
    std_accuracy = np.std(best_results)

    print(f"\n🎯 FINAL RESULTS:")
    print(f"Average Accuracy: {avg_accuracy:.2f}% ± {std_accuracy:.2f}%")
    print(f"Best Single Run: {max(best_results):.2f}%")
    print(f"Individual Results: {[f'{acc:.1f}%' for acc in best_results]}")

    if max(best_results) >= 90:
        print("✅ TARGET ACHIEVED: >90% accuracy!")
    else:
        print("❌ Target not reached, but this is the best approach for your dataset size")

    return max(best_results)

# Ensemble approach for even higher accuracy
def create_ensemble_model():
    """Create ensemble of models for maximum accuracy"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset_path = '/content/drive/MyDrive/Speech/processed_datasets'
    diagnosis_train = pd.read_pickle(os.path.join(dataset_path, 'diagnosis_train.pkl'))

    processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')

    models = []

    # Train 5 models with different configurations
    configs = [
        {'dropout': 0.3, 'lr': 3e-5, 'seed': 42},
        {'dropout': 0.4, 'lr': 5e-5, 'seed': 123},
        {'dropout': 0.5, 'lr': 2e-5, 'seed': 456},
        {'dropout': 0.3, 'lr': 4e-5, 'seed': 789},
        {'dropout': 0.4, 'lr': 3e-5, 'seed': 999},
    ]

    for i, config in enumerate(configs):
        print(f"\nTraining ensemble model {i+1}/5...")

        train_df, val_df = train_test_split(
            diagnosis_train,
            test_size=0.2,
            stratify=diagnosis_train['label_encoded'],
            random_state=config['seed']
        )

        train_dataset = SimpleAudioDataset(train_df, processor)
        val_dataset = SimpleAudioDataset(val_df, processor)

        train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True, collate_fn=simple_collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=6, shuffle=False, collate_fn=simple_collate_fn)

        # Custom model with specific dropout
        class EnsembleModel(nn.Module):
            def __init__(self, dropout):
                super().__init__()
                self.wav2vec2 = Wav2Vec2ForSequenceClassification.from_pretrained(
                    'facebook/wav2vec2-base-960h',
                    num_labels=768,
                    ignore_mismatched_sizes=True
                )

                for param in self.wav2vec2.wav2vec2.feature_extractor.parameters():
                    param.requires_grad = False

                self.classifier = nn.Sequential(
                    nn.Dropout(dropout),
                    nn.Linear(768, 256),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(256, 2)
                )

            def forward(self, audio):
                outputs = self.wav2vec2.wav2vec2(audio)
                features = outputs.last_hidden_state.mean(dim=1)
                return self.classifier(features)

        model = EnsembleModel(config['dropout']).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=0.01)
        criterion = nn.CrossEntropyLoss()

        # Quick training
        for epoch in range(20):
            model.train()
            for batch in train_loader:
                optimizer.zero_grad()
                audio = batch['audio'].to(device)
                labels = batch['label'].to(device)
                outputs = model(audio)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        models.append(model)

    # Ensemble evaluation
    print("\nEvaluating ensemble...")
    train_df, val_df = train_test_split(
        diagnosis_train, test_size=0.2, stratify=diagnosis_train['label_encoded'], random_state=42
    )
    val_dataset = SimpleAudioDataset(val_df, processor)
    val_loader = DataLoader(val_dataset, batch_size=6, shuffle=False, collate_fn=simple_collate_fn)

    all_preds = []
    all_labels = []

    for model in models:
        model.eval()

    with torch.no_grad():
        for batch in val_loader:
            audio = batch['audio'].to(device)
            labels = batch['label'].to(device)

            # Get predictions from all models
            ensemble_outputs = []
            for model in models:
                outputs = model(audio)
                ensemble_outputs.append(F.softmax(outputs, dim=1))

            # Average predictions
            avg_outputs = torch.stack(ensemble_outputs).mean(dim=0)
            _, predicted = torch.max(avg_outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    ensemble_accuracy = accuracy_score(all_labels, all_preds)
    print(f"🎯 ENSEMBLE ACCURACY: {ensemble_accuracy*100:.2f}%")

    if ensemble_accuracy >= 0.90:
        print("✅ TARGET ACHIEVED with ensemble!")

    return ensemble_accuracy

if __name__ == "__main__":
    print("=== APPROACH 1: Direct Wav2Vec2 Fine-tuning ===")
    best_acc = run_direct_approach()

    if best_acc < 90:
        print("\n=== APPROACH 2: Ensemble Method ===")
        ensemble_acc = create_ensemble_model()

        if ensemble_acc >= 90:
            print("Ensemble achieved target!")
        else:
            print(f"Best possible with current data: {max(best_acc, ensemble_acc*100):.1f}%")

    print("\n💡 TIPS TO REACH 90%+:")
    print("1. Data augmentation (speed, pitch, noise)")
    print("2. More training data")
    print("3. Different pre-trained models (HuBERT, WavLM)")
    print("4. Feature engineering from spectrograms")