<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/Detecting_dementia_from_speech_and_transcripts_using_transformers_May243.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
# ============================================================================
# STEP 1: AUTOMATIC SPEECH RECOGNITION MODULE
# ============================================================================

class SpeechTranscriber:
    """Automatic Speech Recognition for generating transcripts from audio.

    Wraps either a Whisper checkpoint (any model name containing "whisper")
    or a Wav2Vec2 CTC checkpoint (everything else) and adds an on-disk
    transcript cache so repeated runs do not re-transcribe the same files.
    """

    # Sample rate the audio is resampled to before it is handed to the model.
    SAMPLE_RATE = 16000
    # Clips with fewer samples than this (0.1 s) are treated as empty.
    MIN_SAMPLES = 1600
    # Whisper's feature extractor pads/truncates every input to 30 s, so
    # longer recordings are transcribed window by window.
    WHISPER_WINDOW_SECONDS = 30
    # Placeholder returned by _transcribe_audio_array when the model fails.
    TRANSCRIPTION_ERROR = "[Transcription error]"

    # Participant-ID patterns, tried in order; the first match wins.
    # NOTE: the bare 3-digit pattern shadows r'(S\d{3})' completely and
    # r'([A-Z]\d{2,3})' whenever three digits are present. The order is kept
    # as-is because the data processor derives its IDs with the same rules
    # and both sides must produce identical keys.
    _ID_PATTERNS = (
        r'adrso?(\d{3})',         # adrs123 or adrso123
        r'adrsp?(\d{3})',         # adrsp123
        r'adrspt?(\d{1,3})',      # adrspt1, adrspt12
        r'(\d{3})',               # 3-digit numbers
        r'([A-Z]\d{2,3})',        # Letter followed by 2-3 digits
        r'(S\d{3})',              # S followed by 3 digits
    )

    def __init__(self, model_name="openai/whisper-base", cache_dir="./asr_cache"):
        """
        Initialize ASR model
        Options:
        - openai/whisper-base: Good balance of speed/accuracy
        - facebook/wav2vec2-base-960h: Faster but less accurate
        - openai/whisper-small: More accurate but slower

        Args:
            model_name: Hugging Face checkpoint name; a name containing
                "whisper" selects the Whisper code path, anything else is
                loaded as Wav2Vec2.
            cache_dir: Directory for pickled transcripts (created if missing,
                including parent directories).
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        print(f"Loading ASR model: {model_name}")

        if "whisper" in model_name.lower():
            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
            self.asr_type = "whisper"
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
            self.asr_type = "wav2vec2"

        # Move to GPU if available
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"ASR model loaded on {self.device}")

    def transcribe_audio_file(self, audio_path, use_cache=True):
        """Transcribe a single audio file.

        Args:
            audio_path: Path (str or Path) of the audio file.
            use_cache: When True, reuse a cached transcript produced by the
                same model and store new successful transcripts.

        Returns:
            The transcript string. Clips shorter than 0.1 s give "". If
            loading or transcription raises, a bracketed
            "[Transcription failed for <name>]" placeholder is returned
            instead of propagating the exception.
        """
        audio_path = Path(audio_path)
        cache_file = self.cache_dir / f"{audio_path.stem}_transcript.pkl"

        if use_cache:
            cached = self._read_cache(cache_file)
            if cached is not None:
                return cached

        try:
            audio, _ = librosa.load(str(audio_path), sr=self.SAMPLE_RATE)

            # Handle empty or very short audio
            if len(audio) < self.MIN_SAMPLES:
                transcript = ""
            else:
                transcript = self._transcribe_audio_array(audio)

            if use_cache:
                self._write_cache(cache_file, audio_path, transcript)

            return transcript

        except Exception as e:
            print(f"Error transcribing {audio_path}: {e}")
            return f"[Transcription failed for {audio_path.name}]"

    def _read_cache(self, cache_file):
        """Return the cached transcript, or None when it cannot be used.

        An entry is ignored when the file is missing or unreadable, when it
        was written by a different model, or when it holds the error
        placeholder (which older runs may have cached).
        """
        if not cache_file.exists():
            return None
        try:
            with open(cache_file, 'rb') as f:
                # NOTE(security): unpickling can execute arbitrary code; only
                # point cache_dir at directories you trust.
                entry = pickle.load(f)
            if entry.get('model') != self.model_name:
                return None
            transcript = entry['transcript']
        except Exception:
            return None  # Cache corrupted, proceed with transcription
        if transcript == self.TRANSCRIPTION_ERROR:
            return None
        return transcript

    def _write_cache(self, cache_file, audio_path, transcript):
        """Best-effort write of a transcript to the cache.

        Error placeholders are never stored, so a transient model failure is
        retried on the next run instead of being replayed from disk.
        """
        if transcript == self.TRANSCRIPTION_ERROR:
            return
        try:
            with open(cache_file, 'wb') as f:
                pickle.dump({
                    'audio_path': str(audio_path),
                    'transcript': transcript,
                    'model': self.model_name
                }, f)
        except Exception as e:
            # Caching failed, but transcription succeeded
            print(f"Could not cache transcript for {audio_path}: {e}")

    def _transcribe_audio_array(self, audio_array):
        """Transcribe an audio array with the loaded model.

        Returns the transcript, or "[Transcription error]" when the model
        raises.
        """
        try:
            if self.asr_type == "whisper":
                return self._whisper_transcribe(audio_array)
            return self._wav2vec2_transcribe(audio_array)
        except Exception as e:
            print(f"Transcription error: {e}")
            return self.TRANSCRIPTION_ERROR

    def _whisper_transcribe(self, audio_array):
        """Transcribe using Whisper, 30 seconds at a time.

        Whisper only sees 30 s of audio per call, so the array is split into
        consecutive 30 s windows and the partial transcripts are joined with
        single spaces. Audio of 30 s or less is handled in one call.
        """
        window = self.WHISPER_WINDOW_SECONDS * self.SAMPLE_RATE
        pieces = []

        for start in range(0, len(audio_array), window):
            chunk = audio_array[start:start + window]
            # A trailing sliver shorter than 0.1 s carries no speech.
            if start > 0 and len(chunk) < self.MIN_SAMPLES:
                continue

            inputs = self.processor(
                chunk,
                sampling_rate=self.SAMPLE_RATE,
                return_tensors="pt"
            ).input_features.to(self.device)

            with torch.no_grad():
                predicted_ids = self.model.generate(inputs, max_length=448)

            text = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0].strip()
            if text:
                pieces.append(text)

        return " ".join(pieces)

    def _wav2vec2_transcribe(self, audio_array):
        """Transcribe using Wav2Vec2 (greedy CTC decoding, lower-cased)."""
        inputs = self.processor(
            audio_array,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
            padding=True
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcript = self.processor.batch_decode(predicted_ids)[0]

        return transcript.strip().lower()

    def batch_transcribe(self, audio_paths, batch_size=8):
        """Transcribe multiple audio files with a progress bar.

        Args:
            audio_paths: Iterable of audio file paths.
            batch_size: Accepted for interface compatibility; files are
                currently transcribed one at a time.

        Returns:
            Dict mapping participant ID to transcript. When several files
            resolve to the same participant ID, the last one wins.
        """
        transcripts = {}

        print(f"Transcribing {len(audio_paths)} audio files...")

        for audio_path in tqdm(audio_paths, desc="Transcribing"):
            transcript = self.transcribe_audio_file(audio_path)
            participant_id = self._extract_participant_id(Path(audio_path).name)
            transcripts[participant_id] = transcript

        return transcripts

    def _extract_participant_id(self, filename):
        """Extract participant ID from filename.

        Returns the whole matched text (e.g. "adrso123") for every pattern
        except the bare 3-digit one, which returns just the digits. Falls
        back to the filename stem when nothing matches.
        """
        for pattern in self._ID_PATTERNS:
            match = re.search(pattern, filename)
            if match:
                return match.group(1) if pattern.startswith(r'(\d') else match.group(0)

        return Path(filename).stem

# ============================================================================
# STEP 2: ENHANCED DATA PROCESSOR WITH ASR
# ============================================================================

class EnhancedADReSSDataProcessor:
    """Enhanced ADReSS data processor with automatic speech recognition.

    Extracts a dataset archive, finds the audio files, transcribes them with
    a SpeechTranscriber, derives binary labels from the directory layout and
    pairs everything into per-participant samples.
    """

    # Participant-ID patterns, tried in order; the first match wins. These
    # must stay identical to the ones the transcriber uses, because the
    # transcripts and the labels are joined on the resulting IDs.
    _ID_PATTERNS = (
        r'adrso?(\d{3})',
        r'adrsp?(\d{3})',
        r'adrspt?(\d{1,3})',
        r'(\d{3})',
        r'([A-Z]\d{2,3})',
        r'(S\d{3})',
    )

    def __init__(self, output_dir='./extracted_data', asr_model="openai/whisper-base"):
        """Create the output directory and load the tokenizer and ASR model.

        Args:
            output_dir: Directory that extracted datasets are written below.
            asr_model: Model name passed to SpeechTranscriber.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Initialize ASR
        self.transcriber = SpeechTranscriber(model_name=asr_model)

    def extract_adress_dataset(self, tar_path, dataset_name):
        """Extract an ADReSS archive into ``output_dir / dataset_name``.

        The compression is auto-detected. Archives containing a member whose
        path would land outside the target directory are rejected.

        Returns:
            The extraction directory, or None when extraction failed (the
            error is printed).
        """
        extract_path = self.output_dir / dataset_name
        extract_path.mkdir(parents=True, exist_ok=True)

        print(f"Extracting {tar_path} to {extract_path}")

        try:
            with tarfile.open(tar_path, 'r:*') as tar:
                self._safe_extract(tar, extract_path)
            print(f"Successfully extracted {dataset_name}")

            # Find the actual dataset directory structure
            self._explore_directory_structure(extract_path)
            return extract_path
        except Exception as e:
            print(f"Error extracting {tar_path}: {e}")
            return None

    @staticmethod
    def _safe_extract(tar, extract_path):
        """Extract ``tar`` after checking that no member escapes ``extract_path``.

        Raises:
            ValueError: if a member name resolves outside the target directory.
        """
        root = Path(extract_path).resolve()
        for member in tar.getmembers():
            target = (root / member.name).resolve()
            if target != root and root not in target.parents:
                raise ValueError(
                    f"Refusing to extract {member.name!r}: path escapes {extract_path}"
                )
        # Use the stdlib 'data' extraction filter where the running Python
        # provides it; it additionally rejects unsafe links and special files.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=extract_path, filter='data')
        else:
            tar.extractall(path=extract_path)

    def _explore_directory_structure(self, base_path):
        """Print the directory tree below ``base_path`` (at most 5 files per directory)."""
        print(f"\nDirectory structure for {base_path.name}:")
        for root, dirs, files in os.walk(base_path):
            level = root.replace(str(base_path), '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files[:5]:
                print(f"{subindent}{file}")
            if len(files) > 5:
                print(f"{subindent}... and {len(files) - 5} more files")

    def process_adress_dataset_with_asr(self, extract_path):
        """Process an extracted ADReSS dataset with automatic speech recognition.

        Returns:
            Dict with keys 'audio_files', 'transcript_files',
            'metadata_files', 'labels', 'paired_data' and
            'generated_transcripts'. 'transcript_files' and 'metadata_files'
            are left empty by this method.
        """
        dataset_info = {
            'audio_files': [],
            'transcript_files': [],
            'metadata_files': [],
            'labels': {},
            'paired_data': [],
            'generated_transcripts': {}
        }

        # Look for an ADReSS directory; only directories qualify (a README
        # named like the dataset must not become the search root) and the
        # shallowest one is preferred so the choice is deterministic.
        adress_dirs = sorted(
            (p for p in extract_path.rglob("*ADReSS*") if p.is_dir()),
            key=lambda p: (len(p.parts), str(p))
        )
        main_dir = adress_dirs[0] if adress_dirs else extract_path

        print(f"Processing from directory: {main_dir}")

        # Find audio files (sorted so the processing order is reproducible)
        audio_patterns = ['**/*.wav', '**/*.mp3', '**/*.flac']
        for pattern in audio_patterns:
            dataset_info['audio_files'].extend(sorted(main_dir.glob(pattern)))

        print(f"Found {len(dataset_info['audio_files'])} audio files")

        # Generate transcripts using ASR
        print("Generating transcripts using ASR...")
        audio_paths = [str(path) for path in dataset_info['audio_files']]
        generated_transcripts = self.transcriber.batch_transcribe(audio_paths)
        dataset_info['generated_transcripts'] = generated_transcripts

        print(f"Generated {len(generated_transcripts)} transcripts")

        # Process labels from directory structure
        dataset_info['labels'] = self._extract_labels_from_structure(dataset_info['audio_files'])

        # Create paired dataset with generated transcripts
        dataset_info['paired_data'] = self._create_paired_data_with_asr(dataset_info)

        return dataset_info

    def _extract_labels_from_structure(self, audio_files):
        """Derive a label for every audio file from its path.

        Returns:
            Dict mapping participant ID to
            ``{'label': 0|1, 'class_name': 'CN'|'AD', 'audio_path': path}``.
            Files whose path carries no class hint default to control (0);
            the number of such files is printed.
        """
        labels = {}
        defaulted = 0

        for audio_file in audio_files:
            participant_id = self._extract_participant_id(audio_file.name)
            label, class_name, was_default = self._label_for_path(audio_file)
            defaulted += was_default

            labels[participant_id] = {
                'label': label,
                'class_name': class_name,
                'audio_path': audio_file
            }

        if defaulted:
            print(f"Warning: {defaulted} audio files had no class hint in their path; "
                  f"defaulted to CN")

        return labels

    @staticmethod
    def _label_for_path(audio_file):
        """Classify one audio path as AD (1) or CN (0).

        Returns:
            ``(label, class_name, was_default)`` where ``was_default`` is True
            when no hint was found and the control class was assumed.
        """
        path = Path(audio_file)
        path_str = path.as_posix().lower()
        # Compare whole directory names so 'ad'/'cn' work for relative paths
        # and on any path separator.
        dir_parts = {part.lower() for part in path.parent.parts}

        if 'ad' in dir_parts or 'dementia' in path_str or 'alzheimer' in path_str:
            return 1, 'AD', False
        if 'cn' in dir_parts or 'control' in path_str or 'normal' in path_str:
            return 0, 'CN', False
        # 'no_decline' contains 'decline', so it has to be tested first.
        if 'no_decline' in path_str or 'no-decline' in path_str:
            return 0, 'CN', False
        if 'decline' in path_str:
            return 1, 'AD', False

        # Filename fallback. 'ad' must be a standalone token: as a plain
        # substring it would match the 'adrso'/'adrsp' dataset prefix and
        # label every such file as AD.
        name = path.name.lower()
        tokens = re.split(r'[^a-z]+', path.stem.lower())
        if 'ad' in tokens or 'dem' in name or 'alz' in name:
            return 1, 'AD', False

        return 0, 'CN', True  # Default to control

    def _extract_participant_id(self, filename):
        """Extract participant ID from filename.

        Returns the whole matched text (e.g. "adrso123") for every pattern
        except the bare 3-digit one, which returns just the digits. Falls
        back to the filename stem when nothing matches.
        """
        for pattern in self._ID_PATTERNS:
            match = re.search(pattern, filename)
            if match:
                return match.group(1) if pattern.startswith(r'(\d') else match.group(0)

        return Path(filename).stem

    def _create_paired_data_with_asr(self, dataset_info):
        """Create paired audio-transcript samples using ASR-generated transcripts.

        Transcripts are cleaned first; when the cleaned text is shorter than
        10 characters it is replaced by a fixed placeholder sentence naming
        the participant. Also prints the class distribution and the first
        three samples.

        Returns:
            List of dicts with keys 'participant_id', 'audio_path',
            'transcript', 'label', 'class_name' and 'transcript_source'.
        """
        paired_data = []

        for participant_id, label_info in dataset_info['labels'].items():
            transcript = dataset_info['generated_transcripts'].get(participant_id, "")
            transcript = self._clean_and_validate_transcript(transcript)

            # If transcript is still empty or invalid, create a meaningful placeholder
            if not transcript or len(transcript.strip()) < 10:
                transcript = f"Audio sample from participant {participant_id}. Speech content unclear or silent."

            paired_data.append({
                'participant_id': participant_id,
                'audio_path': str(label_info['audio_path']),
                'transcript': transcript,
                'label': label_info['label'],
                'class_name': label_info['class_name'],
                'transcript_source': 'ASR_generated'
            })

        print(f"Created {len(paired_data)} paired samples with ASR transcripts")

        # Print class distribution
        labels = [item['label'] for item in paired_data]
        unique, counts = np.unique(labels, return_counts=True)
        for cls, count in zip(unique, counts):
            class_name = 'CN' if cls == 0 else 'AD'
            print(f"  {class_name}: {count} samples ({count/len(labels)*100:.1f}%)")

        # Print sample transcripts for verification
        print("\nSample generated transcripts:")
        print("-" * 50)
        for sample in paired_data[:3]:
            print(f"Participant {sample['participant_id']} ({sample['class_name']}):")
            print(f"Transcript: {sample['transcript'][:100]}...")
            print()

        return paired_data

    def _clean_and_validate_transcript(self, transcript):
        """Clean and validate an ASR-generated transcript.

        Removes bracketed and angle-bracketed tags, collapses whitespace and
        trims the result. Returns "" for empty input, for results shorter
        than 5 characters, and for transcripts in which fewer than 30% of
        the words are distinct.
        """
        if not transcript:
            return ""

        # Remove common ASR artifacts
        transcript = re.sub(r'\[.*?\]', '', transcript)  # Remove [NOISE], [MUSIC], etc.
        transcript = re.sub(r'<.*?>', '', transcript)    # Remove <unk>, <pad>, etc.
        # Normalize whitespace; strip last, because removing a leading or
        # trailing tag leaves a stray space behind.
        transcript = re.sub(r'\s+', ' ', transcript).strip()

        if len(transcript) < 5:
            return ""

        # Check for repetitive patterns (common ASR error)
        words = transcript.split()
        if len(words) > 1 and len(set(words)) / len(words) < 0.3:
            return ""

        return transcript

# ============================================================================
# STEP 3: ENHANCED MULTIMODAL DATASET
# ============================================================================

class EnhancedMultiModalDataset(Dataset):
    """Enhanced dataset with ASR-generated transcripts and linguistic features.

    Each item combines tokenized transcript text, a 3-channel spectrogram
    image built by ``audio_processor``, an 8-value linguistic feature vector
    and the class label.
    """

    # Hesitation words counted as fillers (compared lower-cased).
    _FILLER_WORDS = frozenset({'um', 'uh', 'er', 'ah'})
    # Punctuation stripped from a token before the filler comparison, so
    # that "Um," or "uh." are recognised too.
    _TOKEN_PUNCTUATION = '.,;:!?"\'()[]{}-\u2026'

    def __init__(self, data_samples, audio_processor, tokenizer,
                 max_text_length=512, audio_max_length=16*16000,
                 image_size=(224, 224)):
        """Store the configuration and precompute linguistic features.

        Args:
            data_samples: List of sample dicts; each needs 'transcript',
                'audio_path', 'label', 'participant_id' and 'class_name'.
                NOTE: the dicts are modified in place - a
                'linguistic_features' entry is added to each.
            audio_processor: Object providing ``load_audio``,
                ``extract_mel_spectrogram`` and ``resize_spectrogram_to_image``.
            tokenizer: Callable tokenizer used for the transcript text.
            max_text_length: Token length the text is padded/truncated to.
            audio_max_length: Maximum audio length in samples passed to
                ``load_audio``.
            image_size: (height, width) of the spectrogram image.
        """
        self.data_samples = data_samples
        self.audio_processor = audio_processor
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.audio_max_length = audio_max_length
        self.image_size = image_size

        # Precompute linguistic features for efficiency
        self._precompute_linguistic_features()

    @staticmethod
    def _compute_linguistic_features(transcript):
        """Return the surface-level linguistic metrics of one transcript.

        Words are whitespace-separated tokens; sentences are the non-empty
        pieces obtained by splitting on runs of '.', '!' and '?'.
        """
        words = transcript.split()
        sentences = re.split(r'[.!?]+', transcript)
        distinct = set(words)
        fillers = EnhancedMultiModalDataset._FILLER_WORDS
        punctuation = EnhancedMultiModalDataset._TOKEN_PUNCTUATION

        return {
            'word_count': len(words),
            'sentence_count': len([s for s in sentences if s.strip()]),
            'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
            'unique_words': len(distinct),
            'lexical_diversity': len(distinct) / len(words) if words else 0,
            'pause_markers': transcript.count('[pause]') + transcript.count('...'),
            'filler_words': sum(
                1 for word in words
                if word.lower().strip(punctuation) in fillers
            ),
            'transcript_length': len(transcript)
        }

    def _precompute_linguistic_features(self):
        """Attach a 'linguistic_features' dict to every sample (in place)."""
        print("Precomputing linguistic features...")

        for sample in tqdm(self.data_samples, desc="Computing linguistic features"):
            sample['linguistic_features'] = self._compute_linguistic_features(
                sample['transcript']
            )

    def __len__(self):
        return len(self.data_samples)

    def __getitem__(self, idx):
        """Build the model inputs for sample ``idx``.

        Tokenization and audio failures are reported via print and replaced
        with all-zero tensors, so one bad sample does not abort a training
        run and the fallback is identical on every access.
        """
        sample = self.data_samples[idx]

        # Process text
        text = sample['transcript']
        if not text or text.strip() == "":
            text = "No speech content detected in audio sample"

        try:
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=self.max_text_length,
                return_tensors='pt'
            )
            # Drop only the leading batch dimension added by return_tensors.
            input_ids = encoding['input_ids'].squeeze(0)
            attention_mask = encoding['attention_mask'].squeeze(0)
        except Exception as e:
            print(f"Error tokenizing text for {sample['participant_id']}: {e}")
            # Create dummy encoding
            input_ids = torch.zeros(self.max_text_length, dtype=torch.long)
            attention_mask = torch.zeros(self.max_text_length, dtype=torch.long)

        # Process audio
        try:
            audio = self.audio_processor.load_audio(
                sample['audio_path'],
                max_length=self.audio_max_length
            )

            # Extract spectrogram features
            spectrogram = self.audio_processor.extract_mel_spectrogram(audio)

            # Resize for ViT input
            audio_features = self.audio_processor.resize_spectrogram_to_image(
                spectrogram, self.image_size
            )

        except Exception as e:
            print(f"Error processing audio for {sample['participant_id']}: {e}")
            # Deterministic dummy features: random noise here would change
            # on every epoch and hide the failure inside the training data.
            audio_features = np.zeros(
                (3, self.image_size[0], self.image_size[1]), dtype=np.float32
            )

        # Linguistic features, in a fixed order
        ling_features = sample.get('linguistic_features', {})
        linguistic_vector = np.array([
            ling_features.get('word_count', 0),
            ling_features.get('sentence_count', 0),
            ling_features.get('avg_word_length', 0),
            ling_features.get('unique_words', 0),
            ling_features.get('lexical_diversity', 0),
            ling_features.get('pause_markers', 0),
            ling_features.get('filler_words', 0),
            ling_features.get('transcript_length', 0)
        ], dtype=np.float32)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'audio_features': torch.as_tensor(audio_features, dtype=torch.float32),
            'linguistic_features': torch.as_tensor(linguistic_vector, dtype=torch.float32),
            'label': torch.tensor(sample['label'], dtype=torch.long),
            'participant_id': sample['participant_id'],
            'class_name': sample['class_name'],
            'transcript_preview': text[:100] + "..." if len(text) > 100 else text
        }

# ============================================================================
# STEP 4: AUDIO PROCESSOR FOR MULTIMODAL FEATURES
# ============================================================================

class SimpleAudioProcessor:
    def __init__(self, sample_rate=16000, n_mels=128):
        self.sample_rate = sample_rate
        self.n_mels = n_mels

    def load_audio(self, audio_path, max_length=None):
        try:
            audio, sr = librosa.load(audio_path, sr=self.sample_rate)
            if audio.ndim > 1:
                audio = np.mean(audio, axis=1)
            audio, _ = librosa.effects.trim(audio, top_db=20)
            if np.max(np.abs(audio)) > 0:
                audio = librosa.util.normalize(audio)

            if max_length is not None:
                if len(audio) > max_length:
                    start = (len(audio) - max_length) // 2
                    audio = audio[start:start + max_length]
                elif len(audio) < max_length:
                    pad_length = max_length - len(audio)
                    audio = np.pad(audio, (0, pad_length), mode='constant')

            return audio
        except:
            length = max_length if max_length else self.sample_rate * 10
            return np.random.randn(length) * 0.01

    def extract_mel_spectrogram(self, audio):
        try:
            mel_spec = librosa.feature.melspectrogram(
                y=audio, sr=self.sample_rate, n_mels=self.n_mels
            )
            log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
            delta = librosa.feature.delta(log_mel_spec)
            delta2 = librosa.feature.delta(log_mel_spec, order=2)
            return np.stack([log_mel_spec, delta, delta2], axis=0)
        except:
            return np.random.randn(3, self.n_mels, 100)

    def resize_spectrogram_to_image(self, spectrogram, target_size=(224, 224)):
        try:
            if spectrogram.ndim == 3:
                resized_channels = []
                for i in range(spectrogram.shape[0]):
                    channel = spectrogram[i]
                    zoom_factors = [target_size[j] / channel.shape[j] for j in range(2)]
                    resized_channel = ndimage.zoom(channel, zoom_factors, order=1)
                    resized_channels.append(resized_channel)
                resized = np.stack(resized_channels, axis=0)
            else:
                zoom_factors = [target_size[i] / spectrogram.shape[i] for i in range(2)]
                resized = ndimage.zoom(spectrogram, zoom_factors, order=1)
                resized = np.stack([resized] * 3, axis=0)

            resized = (resized - resized.min()) / (resized.max() - resized.min() + 1e-8)
            return resized
        except:
            return np.random.rand(3, target_size[0], target_size[1])

# ============================================================================
# STEP 5: ENHANCED MULTIMODAL MODEL
# ============================================================================

class EnhancedMultiModalADClassifier(nn.Module):
    """Multimodal Alzheimer's classifier fusing text, audio and linguistic inputs.

    A BERT encoder handles the transcript, a ViT encoder handles the
    spectrogram image, and a small MLP handles the hand-crafted linguistic
    vector. The three representations are projected to a shared width,
    mixed by multi-head self-attention, concatenated and classified.
    """

    def __init__(self, text_hidden_size=768, audio_hidden_size=768,
                 linguistic_feature_size=8, fusion_hidden_size=512,
                 num_classes=2, dropout=0.3):
        """
        Args:
            text_hidden_size: Width of the text encoder's pooled output.
            audio_hidden_size: Width of the audio encoder's pooled output.
            linguistic_feature_size: Length of the linguistic feature vector.
            fusion_hidden_size: Shared width used for attention and fusion.
            num_classes: Number of output classes.
            dropout: Dropout probability used in every dropout/attention layer.
        """
        super().__init__()

        # Pretrained encoders: BERT for the transcript, ViT for the spectrogram.
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.audio_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # MLP that embeds the linguistic feature vector into 32 dimensions.
        linguistic_width = 32
        self.linguistic_processor = nn.Sequential(
            nn.Linear(linguistic_feature_size, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, linguistic_width),
            nn.ReLU(),
        )

        # Self-attention over the three modality vectors.
        self.attention = nn.MultiheadAttention(
            embed_dim=fusion_hidden_size, num_heads=8, dropout=dropout
        )

        # Per-modality projections into the shared fusion width.
        self.text_projection = nn.Linear(text_hidden_size, fusion_hidden_size)
        self.audio_projection = nn.Linear(audio_hidden_size, fusion_hidden_size)
        self.linguistic_projection = nn.Linear(linguistic_width, fusion_hidden_size)

        # Fusion MLP over the concatenated, attended modality vectors.
        half_width = fusion_hidden_size // 2
        self.fusion_layers = nn.Sequential(
            nn.Linear(fusion_hidden_size * 3, fusion_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden_size, half_width),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(half_width, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize the linear layers of the three sequential heads.

        Only the linguistic MLP, the fusion MLP and the classifier are
        touched; the encoders, the projections and the attention layer keep
        their own initialization.
        """
        sequential_heads = (self.linguistic_processor, self.fusion_layers, self.classifier)
        for head in sequential_heads:
            for layer in head:
                if not isinstance(layer, nn.Linear):
                    continue
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask, audio_features, linguistic_features):
        """Run the full multimodal forward pass.

        Returns:
            Dict with 'logits', the pooled 'text_features' and
            'audio_features', the embedded 'linguistic_features', the
            'attention_weights' from the fusion attention and the
            'fused_features' fed to the classifier.
        """
        n_samples = input_ids.size(0)

        # Pooled encoder outputs, one vector per sample and modality.
        text_pooled = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        audio_pooled = self.audio_encoder(pixel_values=audio_features).pooler_output
        linguistic_embedded = self.linguistic_processor(linguistic_features)

        # Stack the projected modalities along a leading "sequence" axis,
        # which is the layout the attention layer is called with:
        # [3, batch, fusion_hidden_size].
        modality_stack = torch.stack(
            [
                self.text_projection(text_pooled),
                self.audio_projection(audio_pooled),
                self.linguistic_projection(linguistic_embedded),
            ],
            dim=0,
        )

        # Self-attention across modalities (query = key = value).
        attended, attention_weights = self.attention(
            modality_stack, modality_stack, modality_stack
        )

        # Back to batch-first, then concatenate the three modality vectors.
        per_sample = attended.transpose(0, 1)
        fused_input = per_sample.reshape(n_samples, -1)

        fused_features = self.fusion_layers(fused_input)
        logits = self.classifier(fused_features)

        return {
            'logits': logits,
            'text_features': text_pooled,
            'audio_features': audio_pooled,
            'linguistic_features': linguistic_embedded,
            'attention_weights': attention_weights,
            'fused_features': fused_features
        }


# ============================================================================
# STEP 6: TRAINING AND EVALUATION (CONTINUED)
# ============================================================================

class ModelTrainer:
    """Enhanced model trainer with comprehensive evaluation"""

    def __init__(self, model, device, learning_rate=2e-5, weight_decay=0.01):
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        self.criterion = nn.CrossEntropyLoss()
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', patience=3, factor=0.5, verbose=True
        )

        # Training history
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.best_val_accuracy = 0.0
        self.best_model_state = None

    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Every batch is read as a dict providing ``input_ids``,
        ``attention_mask``, ``audio_features``, ``linguistic_features`` and
        ``label`` tensors. A batch whose forward/backward pass raises is
        reported and skipped (best effort), and does not count towards the
        averages.

        Returns:
            tuple: ``(avg_loss, accuracy)`` over the batches that completed.
            When no batch completed (empty loader or every batch failed) the
            result is ``(nan, 0.0)`` rather than a ``ZeroDivisionError``.
        """
        self.model.train()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        completed_batches = 0

        progress_bar = tqdm(train_loader, desc="Training", leave=False)

        for batch in progress_bar:
            # Move batch to device
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            audio_features = batch['audio_features'].to(self.device)
            linguistic_features = batch['linguistic_features'].to(self.device)
            labels = batch['label'].to(self.device)

            # Zero gradients (also discards anything left by a failed batch)
            self.optimizer.zero_grad()

            try:
                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    audio_features=audio_features,
                    linguistic_features=linguistic_features
                )

                logits = outputs['logits']
                loss = self.criterion(logits, labels)

                # Backward pass
                loss.backward()

                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.optimizer.step()

                # Statistics
                batch_loss = loss.item()
                total_loss += batch_loss
                predictions = torch.argmax(logits, dim=1)
                correct_predictions += (predictions == labels).sum().item()
                total_predictions += labels.size(0)
                completed_batches += 1

                # Update progress bar; guard against a zero-sized batch
                current_accuracy = correct_predictions / max(total_predictions, 1)
                progress_bar.set_postfix({
                    'Loss': f'{batch_loss:.4f}',
                    'Acc': f'{current_accuracy:.4f}'
                })

            except Exception as e:
                print(f"Error in training batch: {e}")
                continue

        if completed_batches == 0 or total_predictions == 0:
            print("Warning: no training batch completed successfully")
            return float('nan'), 0.0

        # Average only over batches that actually contributed a loss value
        avg_loss = total_loss / completed_batches
        accuracy = correct_predictions / total_predictions

        return avg_loss, accuracy

    def validate_epoch(self, val_loader):
        """Evaluate the model on ``val_loader`` without updating weights.

        Every batch is read as a dict providing ``input_ids``,
        ``attention_mask``, ``audio_features``, ``linguistic_features``,
        ``label`` and ``participant_id``. A batch whose forward pass raises is
        reported and skipped, and does not count towards the averages.

        Returns:
            tuple: ``(avg_loss, accuracy, predictions, labels,
            participant_ids)``. The three lists cover the batches that
            completed. When no batch completed the result is
            ``(nan, 0.0, [], [], [])`` rather than a ``ZeroDivisionError``.
        """
        self.model.eval()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        completed_batches = 0
        all_predictions = []
        all_labels = []
        all_participant_ids = []

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc="Validation", leave=False)

            for batch in progress_bar:
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                audio_features = batch['audio_features'].to(self.device)
                linguistic_features = batch['linguistic_features'].to(self.device)
                labels = batch['label'].to(self.device)

                try:
                    # Forward pass
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        audio_features=audio_features,
                        linguistic_features=linguistic_features
                    )

                    logits = outputs['logits']
                    loss = self.criterion(logits, labels)

                    # Statistics
                    batch_loss = loss.item()
                    total_loss += batch_loss
                    predictions = torch.argmax(logits, dim=1)
                    correct_predictions += (predictions == labels).sum().item()
                    total_predictions += labels.size(0)
                    completed_batches += 1

                    # Store for detailed analysis
                    all_predictions.extend(predictions.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())
                    all_participant_ids.extend(batch['participant_id'])

                    # Update progress bar; guard against a zero-sized batch
                    current_accuracy = correct_predictions / max(total_predictions, 1)
                    progress_bar.set_postfix({
                        'Loss': f'{batch_loss:.4f}',
                        'Acc': f'{current_accuracy:.4f}'
                    })

                except Exception as e:
                    print(f"Error in validation batch: {e}")
                    continue

        if completed_batches == 0 or total_predictions == 0:
            print("Warning: no validation batch completed successfully")
            return float('nan'), 0.0, all_predictions, all_labels, all_participant_ids

        # Average only over batches that actually contributed a loss value
        avg_loss = total_loss / completed_batches
        accuracy = correct_predictions / total_predictions

        return avg_loss, accuracy, all_predictions, all_labels, all_participant_ids

    def train(self, train_loader, val_loader, num_epochs=10, save_path='best_model.pt'):
        """Train for ``num_epochs`` epochs, validating after each one.

        After every epoch the scheduler is stepped with the validation
        accuracy. Whenever validation accuracy improves, a snapshot of the
        weights is kept in memory and a checkpoint (weights, optimizer state
        and history so far) is written to ``save_path``. When training ends
        the best snapshot is loaded back into the model.

        Args:
            train_loader: Loader for training batches; must expose ``.dataset``.
            val_loader: Loader for validation batches; must expose ``.dataset``.
            num_epochs: Number of epochs to run.
            save_path: Destination file for the best checkpoint.
        """
        print(f"Starting training for {num_epochs} epochs...")
        print(f"Training on {len(train_loader.dataset)} samples")
        print(f"Validating on {len(val_loader.dataset)} samples")
        print(f"Device: {self.device}")

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 50)

            # Training
            train_loss, train_accuracy = self.train_epoch(train_loader)
            self.train_losses.append(train_loss)
            self.train_accuracies.append(train_accuracy)

            # Validation
            val_loss, val_accuracy, val_predictions, val_labels, val_ids = self.validate_epoch(val_loader)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_accuracy)

            # Learning rate scheduling
            self.scheduler.step(val_accuracy)

            # Save best model
            if val_accuracy > self.best_val_accuracy:
                self.best_val_accuracy = val_accuracy
                # state_dict() hands back references to the live tensors, and
                # dict.copy() is shallow, so later optimizer steps would
                # silently overwrite the "best" weights. Clone every tensor to
                # get a snapshot that is independent of further training.
                self.best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.best_model_state,
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_val_accuracy': self.best_val_accuracy,
                    'train_losses': self.train_losses,
                    'val_losses': self.val_losses,
                    'train_accuracies': self.train_accuracies,
                    'val_accuracies': self.val_accuracies
                }, save_path)
                print(f"✓ New best model saved with validation accuracy: {val_accuracy:.4f}")

            # Print epoch results
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

            # Detailed validation metrics for best epochs (ties with the best
            # accuracy are reported too)
            if val_accuracy == self.best_val_accuracy:
                self._print_detailed_metrics(val_labels, val_predictions)

        print(f"\nTraining completed!")
        print(f"Best validation accuracy: {self.best_val_accuracy:.4f}")

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print("✓ Best model loaded for final evaluation")

    def _print_detailed_metrics(self, true_labels, predictions):
        """Print accuracy plus weighted precision, recall and F1.

        Args:
            true_labels: Ground-truth class labels.
            predictions: Predicted class labels, aligned with ``true_labels``.
        """
        # Precision/recall/F1 share the same averaging options.
        weighted = {'average': 'weighted', 'zero_division': 0}

        report = (
            ("Accuracy", accuracy_score(true_labels, predictions)),
            ("Precision", precision_score(true_labels, predictions, **weighted)),
            ("Recall", recall_score(true_labels, predictions, **weighted)),
            ("F1-Score", f1_score(true_labels, predictions, **weighted)),
        )

        print("\nDetailed Metrics:")
        for title, value in report:
            print(f"  {title}: {value:.4f}")

    def evaluate_model(self, test_loader, class_names=('CN', 'AD')):
        """Evaluate the model on ``test_loader`` and report the results.

        Prints overall and per-class metrics, a confusion matrix and up to
        five misclassified samples, and writes one row per evaluated sample to
        ``detailed_evaluation_results.csv``. Batches whose forward pass raises
        are reported and skipped.

        Args:
            test_loader: Loader yielding dict batches with ``input_ids``,
                ``attention_mask``, ``audio_features``, ``linguistic_features``,
                ``label``, ``participant_id``, ``class_name`` and
                ``transcript_preview``.
            class_names: Display name for each class index. The confusion
                matrix printout and the ``cn_probability``/``ad_probability``
                columns are laid out for the two-class default.

        Returns:
            dict: ``accuracy`` (float), per-class ``precision``, ``recall`` and
            ``f1_score`` arrays, ``confusion_matrix``, ``detailed_results``
            (list of per-sample dicts), ``predictions``, ``labels`` and
            ``probabilities``.
        """
        print("\n" + "="*60)
        print("COMPREHENSIVE MODEL EVALUATION")
        print("="*60)

        # Pin the label set so every per-class metric and the confusion matrix
        # have one entry per class even if a class never occurs in this split.
        label_ids = list(range(len(class_names)))

        self.model.eval()
        all_predictions = []
        all_labels = []
        all_probabilities = []
        detailed_results = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating"):
                # Move batch to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                audio_features = batch['audio_features'].to(self.device)
                linguistic_features = batch['linguistic_features'].to(self.device)
                labels = batch['label'].to(self.device)

                try:
                    # Forward pass
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        audio_features=audio_features,
                        linguistic_features=linguistic_features
                    )

                    logits = outputs['logits']

                    # One device-to-host transfer per tensor instead of one per sample
                    batch_probabilities = torch.softmax(logits, dim=1).cpu().numpy()
                    batch_predictions = torch.argmax(logits, dim=1).cpu().numpy()
                    batch_labels = labels.cpu().numpy()

                    # Build the per-sample rows first so that a malformed batch
                    # cannot leave the aggregate lists out of sync with them.
                    batch_rows = []
                    for i, participant_id in enumerate(batch['participant_id']):
                        true_label = int(batch_labels[i])
                        predicted_label = int(batch_predictions[i])
                        batch_rows.append({
                            'participant_id': participant_id,
                            'true_label': true_label,
                            'predicted_label': predicted_label,
                            'cn_probability': float(batch_probabilities[i][0]),
                            'ad_probability': float(batch_probabilities[i][1]),
                            'correct': true_label == predicted_label,
                            'class_name': batch['class_name'][i],
                            'transcript_preview': batch['transcript_preview'][i]
                        })

                    # Store results
                    all_predictions.extend(batch_predictions)
                    all_labels.extend(batch_labels)
                    all_probabilities.extend(batch_probabilities)
                    detailed_results.extend(batch_rows)

                except Exception as e:
                    print(f"Error in evaluation batch: {e}")
                    continue

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_predictions)
        precision = precision_score(
            all_labels, all_predictions, labels=label_ids, average=None, zero_division=0
        )
        recall = recall_score(
            all_labels, all_predictions, labels=label_ids, average=None, zero_division=0
        )
        f1 = f1_score(
            all_labels, all_predictions, labels=label_ids, average=None, zero_division=0
        )

        # Print comprehensive results
        print(f"\nOVERALL PERFORMANCE:")
        print(f"Total Samples: {len(all_labels)}")
        print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.1f}%)")
        print(f"Overall Precision: {np.mean(precision):.4f}")
        print(f"Overall Recall: {np.mean(recall):.4f}")
        print(f"Overall F1-Score: {np.mean(f1):.4f}")

        print(f"\nPER-CLASS PERFORMANCE:")
        for i, class_name in enumerate(class_names):
            class_count = sum(1 for label in all_labels if label == i)
            print(f"{class_name}:")
            print(f"  Count: {class_count}")
            print(f"  Precision: {precision[i]:.4f}")
            print(f"  Recall: {recall[i]:.4f}")
            print(f"  F1-Score: {f1[i]:.4f}")

        # Confusion Matrix (shape fixed by label_ids, so the indexing below is
        # safe even when one class is missing from the evaluated samples)
        cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
        print(f"\nCONFUSION MATRIX:")
        print(f"        Predicted")
        print(f"        CN    AD")
        print(f"Actual CN {cm[0,0]:4d}  {cm[0,1]:4d}")
        print(f"       AD {cm[1,0]:4d}  {cm[1,1]:4d}")

        # Error Analysis
        print(f"\nERROR ANALYSIS:")
        errors = [result for result in detailed_results if not result['correct']]
        print(f"Total Errors: {len(errors)}")

        if errors:
            print(f"\nSample Errors:")
            for i, error in enumerate(errors[:5]):  # Show first 5 errors
                true_class = class_names[error['true_label']]
                pred_class = class_names[error['predicted_label']]
                confidence = max(error['cn_probability'], error['ad_probability'])
                print(f"  {i+1}. Participant {error['participant_id']}: {true_class} → {pred_class} (conf: {confidence:.3f})")
                print(f"     Transcript: {error['transcript_preview']}")

        # Save detailed results
        results_df = pd.DataFrame(detailed_results)
        results_df.to_csv('detailed_evaluation_results.csv', index=False)
        print(f"\nDetailed results saved to 'detailed_evaluation_results.csv'")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'confusion_matrix': cm,
            'detailed_results': detailed_results,
            'predictions': all_predictions,
            'labels': all_labels,
            'probabilities': all_probabilities
        }

    def plot_training_history(self, save_path='training_history.png'):
        """Plot training history"""
        if not self.train_losses:
            print("No training history to plot")
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot losses
        epochs = range(1, len(self.train_losses) + 1)
        ax1.plot(epochs, self.train_losses, 'b-', label='Training Loss', linewidth=2)
        ax1.plot(epochs, self.val_losses, 'r-', label='Validation Loss', linewidth=2)
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot accuracies
        ax2.plot(epochs, self.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
        ax2.plot(epochs, self.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy')
        ax2.set_title('Training and Validation Accuracy')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 1)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"Training history plot saved to {save_path}")

# ============================================================================
# STEP 7: MAIN EXECUTION AND PIPELINE
# ============================================================================

def main():
    """Run the end-to-end pipeline: data, training, evaluation, analysis.

    Looks for a dataset archive under the configured data directory and falls
    back to synthetic samples when none is found or extraction fails. Any
    exception raised along the way is printed with its traceback instead of
    propagating.
    """
    banner = "=" * 80

    print(banner)
    print("ENHANCED ALZHEIMER'S DETECTION WITH AUTOMATIC SPEECH RECOGNITION")
    print(banner)

    # Configuration
    config = {
        'data_dir': './data',
        'output_dir': './extracted_data',
        'model_save_path': './best_ad_model.pt',
        'batch_size': 8,
        'num_epochs': 15,
        'learning_rate': 2e-5,
        'max_text_length': 512,
        'audio_max_length': 16*16000,  # 16 seconds
        'test_size': 0.2,
        'val_size': 0.15,
        'random_state': 42,
        'asr_model': 'openai/whisper-base'  # Can change to 'facebook/wav2vec2-base-960h'
    }

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    try:
        # Step 1: data processor (with ASR)
        print("\n1. Initializing Enhanced Data Processor with ASR...")
        processor = EnhancedADReSSDataProcessor(
            output_dir=config['output_dir'],
            asr_model=config['asr_model']
        )

        # Step 2: locate a compressed dataset archive
        print("\n2. Looking for ADReSS dataset files...")
        data_dir = Path(config['data_dir'])
        dataset_files = [
            candidate
            for pattern in ['*.tar.gz', '*.tgz', '*.zip']
            for candidate in data_dir.glob(pattern)
        ]

        # extract_path stays None whenever synthetic data has to be used
        extract_path = None
        if dataset_files:
            print(f"Found {len(dataset_files)} dataset files")

            # Only the first archive is processed
            dataset_file = dataset_files[0]
            print(f"Processing: {dataset_file}")

            extract_path = processor.extract_adress_dataset(
                dataset_file,
                f"dataset_{dataset_file.stem}"
            )
            if extract_path is None:
                print("❌ Failed to extract dataset. Creating synthetic data...")
        else:
            print("⚠️  No dataset files found. Creating synthetic data for demonstration...")

        if extract_path is None:
            dataset_info = create_synthetic_dataset_with_asr(processor)
        else:
            dataset_info = processor.process_adress_dataset_with_asr(extract_path)

        samples = dataset_info['paired_data']
        if not samples:
            print("❌ No valid data found. Exiting...")
            return

        print(f"✅ Successfully processed {len(dataset_info['paired_data'])} samples")

        # Step 3: stratified train / validation / test split
        print("\n3. Preparing datasets...")

        holdout_fraction = config['test_size'] + config['val_size']
        train_data, temp_data = train_test_split(
            samples,
            test_size=holdout_fraction,
            random_state=config['random_state'],
            stratify=[item['label'] for item in samples]
        )

        # The held-out part is divided again into validation and test
        val_data, test_data = train_test_split(
            temp_data,
            test_size=config['test_size'] / holdout_fraction,
            random_state=config['random_state'],
            stratify=[item['label'] for item in temp_data]
        )

        print(f"Train samples: {len(train_data)}")
        print(f"Validation samples: {len(val_data)}")
        print(f"Test samples: {len(test_data)}")

        # Shared preprocessing components
        audio_processor = SimpleAudioProcessor()
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        dataset_options = {
            'max_text_length': config['max_text_length'],
            'audio_max_length': config['audio_max_length'],
        }
        train_dataset, val_dataset, test_dataset = (
            EnhancedMultiModalDataset(split, audio_processor, tokenizer, **dataset_options)
            for split in (train_data, val_data, test_data)
        )

        def build_loader(dataset, shuffle):
            # num_workers=0 avoids multiprocessing issues
            return DataLoader(
                dataset,
                batch_size=config['batch_size'],
                shuffle=shuffle,
                num_workers=0
            )

        train_loader = build_loader(train_dataset, True)
        val_loader = build_loader(val_dataset, False)
        test_loader = build_loader(test_dataset, False)

        # Step 4: model
        print("\n4. Initializing Enhanced Multimodal Model...")
        model = EnhancedMultiModalADClassifier()

        # Step 5: trainer
        print("\n5. Initializing Model Trainer...")
        trainer = ModelTrainer(
            model=model,
            device=device,
            learning_rate=config['learning_rate']
        )

        # Step 6: training
        print("\n6. Starting Model Training...")
        trainer.train(
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=config['num_epochs'],
            save_path=config['model_save_path']
        )

        # Step 7: evaluation on the held-out test split
        print("\n7. Evaluating Model...")
        evaluation_results = trainer.evaluate_model(test_loader)

        # Step 8: training curves
        print("\n8. Generating Training History Plot...")
        trainer.plot_training_history()

        # Step 9: attention / feature analysis
        print("\n9. Performing Feature Analysis...")
        analyze_model_features(trainer.model, test_loader, device)

        print("\n" + banner)
        print("✅ ENHANCED ALZHEIMER'S DETECTION PIPELINE COMPLETED SUCCESSFULLY!")
        print(banner)
        print(f"Final Test Accuracy: {evaluation_results['accuracy']:.4f}")
        print(f"Model saved to: {config['model_save_path']}")
        print("Check the generated plots and CSV files for detailed analysis.")

    except Exception as e:
        # Top-level boundary: report the failure and its traceback
        print(f"❌ Error in main pipeline: {e}")
        import traceback
        traceback.print_exc()

def create_synthetic_dataset_with_asr(processor):
    """Build 100 synthetic samples, alternating CN (label 0) and AD (label 1).

    Args:
        processor: Accepted for call-site symmetry; not used here.

    Returns:
        dict: Dataset-info structure whose ``paired_data`` list holds the
        synthetic samples; every other collection is left empty.
    """
    print("Creating synthetic dataset with ASR for demonstration...")

    paired = []
    for index in range(100):  # 100 synthetic samples
        pid = f"SYNTH_{index:03d}"
        label = index % 2  # even index -> CN, odd index -> AD

        if label == 1:
            class_name = 'AD'
            transcript = generate_ad_like_transcript()
        else:
            class_name = 'CN'
            transcript = generate_cn_like_transcript()

        sample = {
            'participant_id': pid,
            'audio_path': f'synthetic_audio_{pid}.wav',
            'transcript': transcript,
            'label': label,
            'class_name': class_name,
            'transcript_source': 'synthetic'
        }
        paired.append(sample)

    dataset_info = {
        'audio_files': [],
        'transcript_files': [],
        'metadata_files': [],
        'labels': {},
        'paired_data': paired,
        'generated_transcripts': {}
    }
    return dataset_info

def generate_ad_like_transcript():
    """Return one randomly chosen hesitant, fragmented picture description.

    The choice is drawn with ``np.random.choice``, so it follows NumPy's
    global random state.
    """
    hesitant_descriptions = [
        "Um, let me see... the boy is... um... he's climbing on the... the thing there...",
        "There's a woman in the kitchen and she's... what is she doing... oh yes, washing dishes I think...",
        "The... the thing with water is overflowing and there's... there's problems happening...",
        "I see children playing and... um... something about cookies or... or food...",
        "The lady is trying to... to do something with the... with the sink and water is...",
        "There are people in the picture and they're... um... doing things but I can't... I can't remember...",
    ]
    chosen = np.random.choice(hesitant_descriptions)
    return chosen

def generate_cn_like_transcript():
    """Return one randomly chosen fluent, complete picture description.

    The choice is drawn with ``np.random.choice``, so it follows NumPy's
    global random state.
    """
    fluent_descriptions = [
        "In this picture, I can see a kitchen scene where a woman is washing dishes at the sink. The sink appears to be overflowing with water onto the floor.",
        "There's a boy who has climbed up on a stool to reach the cookie jar on the counter. His sister is asking him to give her a cookie.",
        "The scene shows a typical kitchen with a woman doing dishes while children are nearby. The boy is reaching for cookies while standing on a chair.",
        "I can observe a domestic scene with a mother washing dishes. There are two children in the kitchen, and one of them is trying to get cookies from a jar.",
        "The picture depicts a kitchen where a woman is at the sink with running water. There are children present, and one child is reaching up to get something from the counter.",
        "This shows a busy kitchen scene with a woman washing dishes while water overflows. Meanwhile, children are nearby, with one trying to access a cookie jar.",
    ]
    chosen = np.random.choice(fluent_descriptions)
    return chosen

def analyze_model_features(model, test_loader, device):
    """Plot the average attention weight per modality.

    Runs the model on at most the first 10 batches of ``test_loader``,
    collects ``outputs['attention_weights']`` from each and saves a bar chart
    to ``modality_attention_analysis.png``. Batches whose forward pass raises
    are reported and skipped.

    Args:
        model: Model called with ``input_ids``, ``attention_mask``,
            ``audio_features`` and ``linguistic_features``; its output dict
            must contain ``attention_weights``.
        test_loader: Iterable of dict batches holding those four tensors.
        device: Device the batch tensors are moved to.
    """
    print("Analyzing model features and attention patterns...")

    model.eval()
    attention_batches = []

    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            if i >= 10:  # Analyze first 10 batches
                break

            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            audio_features = batch['audio_features'].to(device)
            linguistic_features = batch['linguistic_features'].to(device)

            try:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    audio_features=audio_features,
                    linguistic_features=linguistic_features
                )

                # Store attention weights
                attention_batches.append(outputs['attention_weights'].cpu().numpy())

            except Exception as e:
                print(f"Error in feature analysis: {e}")
                continue

    if not attention_batches:
        print("⚠️  No attention weights collected for analysis.")
        return

    # Batches can differ in size (the last one is usually smaller), so stacking
    # them for np.mean would fail on ragged shapes. Join them along the leading
    # axis instead. NOTE(review): assumes the leading axis is the batch axis and
    # the last axis indexes the attended-to modality — confirm against the model.
    all_attention = np.concatenate(attention_batches, axis=0)

    # Average over every axis except the last one
    attention_by_modality = all_attention.reshape(-1, all_attention.shape[-1]).mean(axis=0)

    # Plot attention patterns
    plt.figure(figsize=(10, 6))

    modalities = ['Text', 'Audio', 'Linguistic']
    plt.bar(modalities, attention_by_modality)
    plt.title('Average Attention Weights by Modality')
    plt.ylabel('Attention Weight')
    plt.xlabel('Modality')
    plt.grid(True, alpha=0.3)

    # Annotate every bar with its value
    for i, v in enumerate(attention_by_modality):
        plt.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig('modality_attention_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✅ Feature analysis completed. Attention analysis plot saved.")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


SyntaxError: invalid syntax (<ipython-input-23-d9ab2c23f4da>, line 1)

In [24]:
# ============================================================================
# ENHANCED ADReSSo21 DATA PROCESSOR WITH SEGMENTATION FILES
# ============================================================================

import os
import re
import tarfile
import pickle
import pandas as pd
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from transformers import (
    BertTokenizer, BertModel,
    WhisperProcessor, WhisperForConditionalGeneration,
    Wav2Vec2Processor, Wav2Vec2ForCTC,
    ViTModel
)
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy import ndimage
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# STEP 1: ENHANCED ADRESSO21 DATA PROCESSOR
# ============================================================================

class ADReSSo21DataProcessor:
    """Enhanced ADReSSo21 data processor with segmentation file support"""

    def __init__(self, output_dir='./extracted_data', use_asr=True, asr_model="openai/whisper-base"):
        """Set up the output directory, the BERT tokenizer and (optionally) ASR.

        Args:
            output_dir: Directory that extracted archives are written under.
                It is created if missing, including any missing parents.
            use_asr: When true, a ``SpeechTranscriber`` is created so that
                transcripts can be generated from the audio files.
            asr_model: Model name forwarded to ``SpeechTranscriber``.
        """
        self.output_dir = Path(output_dir)
        # parents=True: a nested output_dir (e.g. "runs/exp1/data") must not
        # fail with FileNotFoundError just because its parent is missing.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # The transcriber is only built on request because it loads a model.
        self.use_asr = use_asr
        self.transcriber = SpeechTranscriber(model_name=asr_model) if use_asr else None

    def extract_adresso_dataset(self, tar_path, dataset_name):
        """Extract ADReSSo21 dataset and organize files properly"""
        extract_path = self.output_dir / dataset_name
        extract_path.mkdir(exist_ok=True)

        print(f"Extracting {tar_path} to {extract_path}")

        try:
            with tarfile.open(tar_path, 'r:gz') as tar:
                tar.extractall(path=extract_path)
            print(f"Successfully extracted {dataset_name}")

            # Find the actual dataset directory structure
            self._explore_directory_structure(extract_path)
            return extract_path
        except Exception as e:
            print(f"Error extracting {tar_path}: {e}")
            return None

    def _explore_directory_structure(self, base_path):
        """Explore and print directory structure"""
        print(f"\nDirectory structure for {base_path.name}:")
        for root, dirs, files in os.walk(base_path):
            level = root.replace(str(base_path), '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files[:5]:
                print(f"{subindent}{file}")
            if len(files) > 5:
                print(f"{subindent}... and {len(files) - 5} more files")

    def process_adresso_dataset(self, extract_path, dataset_type='diagnosis'):
        """Collect audio, segmentation CSVs, labels and transcripts of a dataset.

        Args:
            extract_path: ``Path`` of the directory the archive was extracted to.
            dataset_type: Task name, forwarded to label extraction and stored
                in the result (``'diagnosis'`` or ``'progression'``).

        Returns:
            A dict with the keys ``audio_files``, ``segmentation_files``,
            ``metadata_files``, ``labels``, ``paired_data``,
            ``generated_transcripts`` and ``dataset_type``.
        """
        dataset_info = {
            'audio_files': [],
            'segmentation_files': [],
            'metadata_files': [],
            'labels': {},
            'paired_data': [],
            'generated_transcripts': {},
            'dataset_type': dataset_type
        }

        # rglob also yields *files* whose name contains "ADReSSo21"; only a
        # directory can serve as the root that the globs below search.
        main_dir = next(
            (candidate for candidate in extract_path.rglob("*ADReSSo21*") if candidate.is_dir()),
            extract_path,
        )
        print(f"Processing from directory: {main_dir}")

        # Glob order depends on the filesystem; sorting keeps the sample order
        # (and everything derived from it) reproducible between runs.
        for pattern in ('**/*.wav', '**/*.mp3', '**/*.flac'):
            dataset_info['audio_files'].extend(sorted(main_dir.glob(pattern)))

        # CSVs whose name mentions MMSE scores or test results are metadata;
        # every other CSV is treated as a segmentation file.
        metadata_markers = ('mmse', 'test_results')
        for csv_file in sorted(main_dir.glob('**/*.csv')):
            csv_name = csv_file.name.lower()
            if any(marker in csv_name for marker in metadata_markers):
                dataset_info['metadata_files'].append(csv_file)
            else:
                dataset_info['segmentation_files'].append(csv_file)

        print(f"Found {len(dataset_info['audio_files'])} audio files")
        print(f"Found {len(dataset_info['segmentation_files'])} segmentation files")
        print(f"Found {len(dataset_info['metadata_files'])} metadata files")

        # Transcripts found in the segmentation CSVs take precedence later on.
        transcripts_from_segmentation = self._process_segmentation_files(dataset_info['segmentation_files'])

        # Optionally transcribe every audio file with the ASR model.
        if self.use_asr and self.transcriber:
            print("Generating additional transcripts using ASR...")
            audio_paths = [str(path) for path in dataset_info['audio_files']]
            dataset_info['generated_transcripts'] = self.transcriber.batch_transcribe(audio_paths)

        # Labels come from the directory / file names of the audio files.
        dataset_info['labels'] = self._extract_labels_from_structure(
            dataset_info['audio_files'], dataset_type
        )

        # Join labels, audio paths and transcripts into one record per participant.
        dataset_info['paired_data'] = self._create_paired_data(
            dataset_info, transcripts_from_segmentation
        )

        return dataset_info

    def _process_segmentation_files(self, segmentation_files):
        """Read every segmentation CSV and return its transcript per participant.

        Args:
            segmentation_files: Iterable of CSV paths (objects with ``.name``).

        Returns:
            Dict mapping participant ID to transcript text. Files that fail
            to parse, or that yield no text, are left out.
        """
        print("Processing segmentation files...")
        transcripts_by_participant = {}

        for seg_path in tqdm(segmentation_files, desc="Processing CSV files"):
            try:
                # The participant ID is encoded in the file name.
                pid = self._extract_participant_id(seg_path.name)

                # Column names differ between datasets; the helper picks the
                # column(s) to read the text from.
                text = self._extract_transcript_from_df(pd.read_csv(seg_path))

                if text and len(text.strip()) > 0:
                    transcripts_by_participant[pid] = text
            except Exception as err:
                # One broken file must not abort the whole run.
                print(f"Error processing {seg_path.name}: {err}")

        print(f"Extracted transcripts from {len(transcripts_by_participant)} segmentation files")
        return transcripts_by_participant

    def _extract_transcript_from_df(self, df):
        """Extract transcript text from segmentation DataFrame"""
        transcript_parts = []

        # Common column name variations for transcripts
        text_columns = []
        for col in df.columns:
            col_lower = col.lower().strip()
            if any(keyword in col_lower for keyword in ['transcript', 'text', 'word', 'utterance', 'speech']):
                text_columns.append(col)

        if not text_columns:
            # If no obvious text column, try common patterns
            possible_cols = ['transcript', 'text', 'word', 'utterance', 'speech_text', 'content']
            for col in possible_cols:
                if col in df.columns:
                    text_columns.append(col)
                    break

        # If still no text columns found, use the last column (common pattern)
        if not text_columns and len(df.columns) > 0:
            text_columns = [df.columns[-1]]

        # Extract text from identified columns
        for col in text_columns:
            try:
                # Filter out non-speech entries (like [NOISE], [MUSIC], etc.)
                valid_texts = df[col].dropna()
                valid_texts = valid_texts[~valid_texts.astype(str).str.contains(r'^\[.*\]$', na=False)]
                valid_texts = valid_texts[valid_texts.astype(str).str.len() > 1]

                transcript_parts.extend(valid_texts.astype(str).tolist())
            except Exception as e:
                continue

        # Join all transcript parts
        full_transcript = ' '.join(transcript_parts).strip()

        # Clean the transcript
        full_transcript = self._clean_transcript(full_transcript)

        return full_transcript

    def _clean_transcript(self, transcript):
        """Clean and normalize transcript text"""
        if not transcript:
            return ""

        # Remove excessive whitespace
        transcript = re.sub(r'\s+', ' ', transcript)

        # Remove common artifacts
        transcript = re.sub(r'\[.*?\]', '', transcript)  # Remove [NOISE], [MUSIC], etc.
        transcript = re.sub(r'<.*?>', '', transcript)    # Remove <unk>, etc.
        transcript = re.sub(r'\*.*?\*', '', transcript)  # Remove *action* markers

        # Normalize punctuation
        transcript = re.sub(r'[.]{2,}', '...', transcript)  # Normalize ellipses
        transcript = re.sub(r'[?!]{2,}', '!', transcript)   # Normalize repeated punctuation

        # Remove very short words that might be artifacts
        words = transcript.split()
        cleaned_words = [word for word in words if len(word) > 1 or word.lower() in ['i', 'a']]

        return ' '.join(cleaned_words).strip()

    def _extract_labels_from_structure(self, audio_files, dataset_type):
        """Extract labels from file paths based on dataset type"""
        labels = {}

        for audio_file in audio_files:
            participant_id = self._extract_participant_id(audio_file.name)
            path_str = str(audio_file).lower()

            if dataset_type == 'diagnosis':
                # For diagnosis task: AD vs CN (Control Normal)
                if '/ad/' in path_str or 'dementia' in path_str or 'alzheimer' in path_str:
                    label = 1  # AD
                    class_name = 'AD'
                elif '/cn/' in path_str or 'control' in path_str or 'normal' in path_str:
                    label = 0  # CN
                    class_name = 'CN'
                else:
                    # Default based on filename patterns
                    if any(marker in audio_file.name.lower() for marker in ['ad', 'dem', 'alz']):
                        label = 1
                        class_name = 'AD'
                    else:
                        label = 0
                        class_name = 'CN'

            elif dataset_type == 'progression':
                # For progression task: decline vs no_decline
                if '/decline/' in path_str and '/no_decline/' not in path_str:
                    label = 1  # Decline
                    class_name = 'Decline'
                elif '/no_decline/' in path_str or 'no-decline' in path_str:
                    label = 0  # No decline
                    class_name = 'No_Decline'
                else:
                    # Default classification for progression
                    if 'decline' in path_str and 'no' not in path_str:
                        label = 1
                        class_name = 'Decline'
                    else:
                        label = 0
                        class_name = 'No_Decline'
            else:
                # Unknown dataset type, default to binary classification
                label = 0
                class_name = 'Unknown'

            labels[participant_id] = {
                'label': label,
                'class_name': class_name,
                'audio_path': audio_file
            }

        return labels

    def _extract_participant_id(self, filename):
        """Extract participant ID from filename"""
        patterns = [
            r'adrso(\d{3})',          # adrso123
            r'adrsp(\d{3})',          # adrsp123
            r'adrspt(\d{1,3})',       # adrspt1, adrspt12
            r'(\d{3})',               # 3-digit numbers
            r'([A-Z]\d{2,3})',        # Letter followed by 2-3 digits
            r'(S\d{3})',              # S followed by 3 digits
        ]

        for pattern in patterns:
            match = re.search(pattern, filename)
            if match:
                return match.group(1) if pattern.startswith(r'(\d') else match.group(0)

        return Path(filename).stem

    def _create_paired_data(self, dataset_info, transcripts_from_segmentation):
        """Create paired audio-transcript dataset"""
        paired_data = []

        for participant_id, label_info in dataset_info['labels'].items():
            # Get transcript from segmentation file first
            transcript = transcripts_from_segmentation.get(participant_id, "")

            # If no segmentation transcript and ASR is available, use ASR
            if not transcript and dataset_info.get('generated_transcripts'):
                transcript = dataset_info['generated_transcripts'].get(participant_id, "")

            # Clean and validate transcript
            transcript = self._clean_transcript(transcript)

            # Create meaningful placeholder if transcript is still empty
            if not transcript or len(transcript.strip()) < 10:
                transcript = f"Audio sample from participant {participant_id}. Limited speech content available."

            # Determine transcript source
            if participant_id in transcripts_from_segmentation and transcripts_from_segmentation[participant_id]:
                transcript_source = 'segmentation_file'
            elif dataset_info.get('generated_transcripts', {}).get(participant_id):
                transcript_source = 'ASR_generated'
            else:
                transcript_source = 'placeholder'

            paired_data.append({
                'participant_id': participant_id,
                'audio_path': str(label_info['audio_path']),
                'transcript': transcript,
                'label': label_info['label'],
                'class_name': label_info['class_name'],
                'transcript_source': transcript_source,
                'dataset_type': dataset_info['dataset_type']
            })

        print(f"Created {len(paired_data)} paired samples")

        # Print class distribution
        labels = [item['label'] for item in paired_data]
        unique, counts = np.unique(labels, return_counts=True)
        for cls, count in zip(unique, counts):
            class_names = list(set([item['class_name'] for item in paired_data if item['label'] == cls]))
            class_name = class_names[0] if class_names else f'Class_{cls}'
            print(f"  {class_name}: {count} samples ({count/len(labels)*100:.1f}%)")

        # Print transcript source distribution
        sources = [item['transcript_source'] for item in paired_data]
        unique_sources, source_counts = np.unique(sources, return_counts=True)
        print(f"\nTranscript sources:")
        for source, count in zip(unique_sources, source_counts):
            print(f"  {source}: {count} samples ({count/len(sources)*100:.1f}%)")

        # Print sample transcripts for verification
        print(f"\nSample transcripts:")
        print("-" * 50)
        for i, sample in enumerate(paired_data[:3]):
            print(f"Participant {sample['participant_id']} ({sample['class_name']}) - Source: {sample['transcript_source']}")
            print(f"Transcript: {sample['transcript'][:150]}...")
            print()

        return paired_data

# ============================================================================
# STEP 2: SPEECH TRANSCRIBER (OPTIONAL ASR)
# ============================================================================

class SpeechTranscriber:
    """Automatic Speech Recognition for generating transcripts from audio"""

    def __init__(self, model_name="openai/whisper-base", cache_dir="./asr_cache"):
        """Load the ASR model and prepare the transcript cache directory.

        Args:
            model_name: Name of the pretrained model. If it contains
                "whisper" (case-insensitive) the Whisper classes are used,
                otherwise the Wav2Vec2 CTC classes.
            cache_dir: Directory for cached transcripts. It is created if
                missing, including any missing parents.
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir)
        # parents=True: a nested cache_dir must not fail on a missing parent.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        print(f"Loading ASR model: {model_name}")

        if "whisper" in model_name.lower():
            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
            self.asr_type = "whisper"
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
            self.asr_type = "wav2vec2"

        # Move to GPU if available; eval() because the model is only used
        # for inference here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"ASR model loaded on {self.device}")

    def transcribe_audio_file(self, audio_path, use_cache=True):
        """Transcribe a single audio file"""
        audio_path = Path(audio_path)

        # Check cache first
        cache_file = self.cache_dir / f"{audio_path.stem}_transcript.pkl"
        if use_cache and cache_file.exists():
            try:
                with open(cache_file, 'rb') as f:
                    cached_result = pickle.load(f)
                return cached_result['transcript']
            except:
                pass

        try:
            # Load audio
            audio, sr = librosa.load(str(audio_path), sr=16000)

            if len(audio) < 1600:  # Less than 0.1 seconds
                transcript = ""
            else:
                transcript = self._transcribe_audio_array(audio)

            # Cache result
            if use_cache:
                try:
                    with open(cache_file, 'wb') as f:
                        pickle.dump({
                            'audio_path': str(audio_path),
                            'transcript': transcript,
                            'model': self.model_name
                        }, f)
                except:
                    pass

            return transcript

        except Exception as e:
            print(f"Error transcribing {audio_path}: {e}")
            return f"[Transcription failed for {audio_path.name}]"

    def _transcribe_audio_array(self, audio_array):
        """Transcribe audio array using the loaded model"""
        try:
            if self.asr_type == "whisper":
                return self._whisper_transcribe(audio_array)
            else:
                return self._wav2vec2_transcribe(audio_array)
        except Exception as e:
            print(f"Transcription error: {e}")
            return "[Transcription error]"

    def _whisper_transcribe(self, audio_array):
        """Transcribe using Whisper model"""
        inputs = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(self.device)

        with torch.no_grad():
            predicted_ids = self.model.generate(inputs, max_length=448)

        transcript = self.processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]

        return transcript.strip()

    def _wav2vec2_transcribe(self, audio_array):
        """Transcribe using Wav2Vec2 model"""
        inputs = self.processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        ).input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcript = self.processor.batch_decode(predicted_ids)[0]

        return transcript.strip().lower()

    def batch_transcribe(self, audio_paths, batch_size=8):
        """Transcribe several audio files, showing a progress bar.

        Args:
            audio_paths: Sequence of audio file paths.
            batch_size: Accepted but not used; files are transcribed one
                after another.

        Returns:
            Dict mapping the participant ID parsed from each file name to its
            transcript. A later file with the same ID replaces an earlier one.
        """
        print(f"Transcribing {len(audio_paths)} audio files...")

        results = {}
        for path in tqdm(audio_paths, desc="Transcribing"):
            text = self.transcribe_audio_file(path)
            file_name = Path(path).name
            results[self._extract_participant_id(file_name)] = text

        return results

    def _extract_participant_id(self, filename):
        """Extract participant ID from filename"""
        patterns = [
            r'adrso(\d{3})',
            r'adrsp(\d{3})',
            r'adrspt(\d{1,3})',
            r'(\d{3})',
            r'([A-Z]\d{2,3})',
            r'(S\d{3})',
        ]

        for pattern in patterns:
            match = re.search(pattern, filename)
            if match:
                return match.group(1) if pattern.startswith(r'(\d') else match.group(0)

        return Path(filename).stem

# ============================================================================
# STEP 3: ENHANCED MULTIMODAL DATASET
# ============================================================================

class EnhancedMultiModalDataset(Dataset):
    """Enhanced dataset with segmentation-based transcripts and linguistic features"""

    def __init__(self, data_samples, audio_processor, tokenizer,
                 max_text_length=512, audio_max_length=16*16000,
                 image_size=(224, 224)):
        """Store the samples and processing settings, then precompute features.

        Args:
            data_samples: List of sample dicts. Each one receives a
                ``linguistic_features`` entry during construction.
            audio_processor: Object used to load audio and turn it into a
                spectrogram image.
            tokenizer: Callable used to encode the transcript.
            max_text_length: Token length passed to the tokenizer.
            audio_max_length: Maximum audio length in samples.
            image_size: (height, width) of the spectrogram image.
        """
        # The samples and the collaborators that turn them into tensors
        self.data_samples = data_samples
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor

        # Size limits applied in __getitem__
        self.max_text_length, self.audio_max_length = max_text_length, audio_max_length
        self.image_size = image_size

        # Text statistics are computed once here rather than per item.
        self._precompute_linguistic_features()

    def _precompute_linguistic_features(self):
        """Attach a ``linguistic_features`` dict to every sample, in place.

        The features are simple counts and ratios computed from
        ``sample['transcript']`` by whitespace tokenisation and by splitting
        sentences on ``.``, ``!`` and ``?``. All ratios and averages are 0 for
        an empty transcript.
        """
        print("Precomputing linguistic features...")

        # Built once as sets: membership is tested for every word.
        filler_vocab = {'um', 'uh', 'er', 'ah', 'hmm'}
        function_vocab = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on',
                          'at', 'to', 'for', 'of', 'with'}

        for sample in tqdm(self.data_samples, desc="Computing linguistic features"):
            transcript = sample['transcript']

            # Basic linguistic metrics
            words = transcript.split()
            lowered = [word.lower() for word in words]
            sentences = [s.strip() for s in re.split(r'[.!?]+', transcript) if s.strip()]
            sentence_lengths = [len(s.split()) for s in sentences]
            word_count = len(words)
            unique_count = len(set(words))

            # Each run of two or more dots is ONE pause. Counting '...' and
            # '..' separately would count every ellipsis twice, because '..'
            # also occurs inside '...'.
            dot_pauses = len(re.findall(r'\.{2,}', transcript))

            sample['linguistic_features'] = {
                # Basic counts
                'word_count': word_count,
                'sentence_count': len(sentences),
                'character_count': len(transcript),

                # Word-level features
                'avg_word_length': np.mean([len(word) for word in words]) if words else 0,
                'unique_words': unique_count,
                'lexical_diversity': unique_count / word_count if words else 0,
                'long_words_ratio': sum(1 for word in words if len(word) > 6) / word_count if words else 0,

                # Sentence-level features
                'avg_sentence_length': np.mean(sentence_lengths) if sentences else 0,
                'max_sentence_length': max(sentence_lengths) if sentences else 0,

                # Fluency indicators
                'pause_markers': transcript.count('[pause]') + dot_pauses,
                'filler_words': sum(1 for word in lowered if word in filler_vocab),
                'repetitions': self._count_repetitions(words),

                # Semantic features (word length > 3 is used as a rough proxy
                # for content words)
                'content_words': sum(1 for word in words if len(word) > 3),
                'function_words': sum(1 for word in lowered if word in function_vocab),

                # Complexity measures (plain substring counts)
                'subordinate_clauses': transcript.count('that') + transcript.count('which') + transcript.count('because'),
                'coordination': transcript.count(' and ') + transcript.count(' or ') + transcript.count(' but '),
            }

    def _count_repetitions(self, words):
        """Count word repetitions (potential indicator of cognitive issues)"""
        if len(words) < 2:
            return 0

        repetitions = 0
        for i in range(len(words) - 1):
            if words[i].lower() == words[i + 1].lower():
                repetitions += 1

        return repetitions

    def __len__(self) -> int:
        """Return the number of samples held in ``self.data_samples``."""
        return len(self.data_samples)

    def __getitem__(self, idx):
        sample = self.data_samples[idx]

        # Process text
        text = sample['transcript']
        if not text or text.strip() == "":
            text = "No speech content detected in audio sample"

        try:
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=self.max_text_length,
                return_tensors='pt'
            )
        except Exception as e:
            print(f"Error tokenizing text for {sample['participant_id']}: {e}")
            # Create dummy encoding
            encoding = {
                'input_ids': torch.zeros(self.max_text_length, dtype=torch.long),
                'attention_mask': torch.zeros(self.max_text_length, dtype=torch.long),
                'token_type_ids': torch.zeros(self.max_text_length, dtype=torch.long)
            }

        # Process audio
        try:
            audio = self.audio_processor.load_audio(
                sample['audio_path'],
                max_length=self.audio_max_length
            )

            spectrogram = self.audio_processor.extract_mel_spectrogram(audio)
            audio_features = self.audio_processor.resize_spectrogram_to_image(
                spectrogram, self.image_size
            )

        except Exception as e:
            print(f"Error processing audio for {sample['participant_id']}: {e}")
            audio_features = np.random.rand(3, self.image_size[0], self.image_size[1])

        # Enhanced linguistic features vector
        ling_features = sample.get('linguistic_features', {})
        linguistic_vector = np.array([
            ling_features.get('word_count', 0),
            ling_features.get('sentence_count', 0),
            ling_features.get('avg_word_length', 0),
            ling_features.get('unique_words', 0),
            ling_features.get('lexical_diversity', 0),
            ling_features.get('pause_markers', 0),
            ling_features.get('filler_words', 0),
            ling_features.get('repetitions', 0),
            ling_features.get('avg_sentence_length', 0),
            ling_features.get('long_words_ratio', 0),
            ling_features.get('content_words', 0),
            ling_features.get('function_words', 0),
            ling_features.get('subordinate_clauses', 0),
            ling_features.get('coordination', 0),
            ling_features.get('character_count', 0)
        ], dtype=np.float32)

        return {
            'input_ids': encoding['input_ids'].squeeze() if hasattr(encoding['input_ids'], 'squeeze') else encoding['input_ids'],
            'attention_mask': encoding['attention_mask'].squeeze() if hasattr(encoding['attention_mask'], 'squeeze') else encoding['attention_mask'],
            'audio_features': torch.FloatTensor(audio_features),
            'linguistic_features': torch.FloatTensor(linguistic_vector),
            'label': torch.LongTensor([sample['label']]).squeeze(),
            'participant_id': sample['participant_id'],
            'class_name': sample['class_name'],
            'transcript_preview': text[:100] + "..." if len(text) > 100 else text,
            'transcript_source': sample.get('transcript_source', 'unknown'),
            'dataset_type': sample.get('dataset_type', 'unknown')
        }


In [27]:
# ============================================================================
# STEP 4: COMPLETE AUDIO PROCESSOR
# ============================================================================

class SimpleAudioProcessor:
    def __init__(self, sample_rate=16000, n_mels=128):
        self.sample_rate = sample_rate
        self.n_mels = n_mels

    def load_audio(self, audio_path, max_length=None):
        try:
            audio, sr = librosa.load(audio_path, sr=self.sample_rate)
            if audio.ndim > 1:
                audio = np.mean(audio, axis=1)
            audio, _ = librosa.effects.trim(audio, top_db=20)
            if np.max(np.abs(audio)) > 0:
                audio = librosa.util.normalize(audio)

            if max_length is not None:
                if len(audio) > max_length:
                    start = (len(audio) - max_length) // 2
                    audio = audio[start:start + max_length]
                elif len(audio) < max_length:
                    pad_length = max_length - len(audio)
                    audio = np.pad(audio, (0, pad_length), mode='constant')

            return audio
        except Exception as e:
            print(f"Error loading audio: {e}")
            length = max_length if max_length else self.sample_rate * 10
            return np.random.randn(length) * 0.01

    def extract_mel_spectrogram(self, audio):
        try:
            mel_spec = librosa.feature.melspectrogram(
                y=audio, sr=self.sample_rate, n_mels=self.n_mels
            )
            log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
            return log_mel_spec
        except Exception as e:
            print(f"Error extracting mel spectrogram: {e}")
            return np.random.randn(self.n_mels, 100)

    def resize_spectrogram_to_image(self, spectrogram, target_size=(224, 224)):
        """Convert spectrogram to image format for vision models"""
        try:
            # Normalize to [0, 1]
            spec_normalized = (spectrogram - np.min(spectrogram)) / (np.max(spectrogram) - np.min(spectrogram) + 1e-8)

            # Resize to target size
            spec_resized = ndimage.zoom(spec_normalized,
                                      (target_size[0] / spec_normalized.shape[0],
                                       target_size[1] / spec_normalized.shape[1]))

            # Convert to 3-channel image (RGB)
            spec_image = np.stack([spec_resized, spec_resized, spec_resized], axis=0)

            return spec_image
        except Exception as e:
            print(f"Error resizing spectrogram: {e}")
            return np.random.rand(3, target_size[0], target_size[1])

    def extract_acoustic_features(self, audio):
        """Extract comprehensive acoustic features"""
        try:
            features = {}

            # Basic spectral features
            features['spectral_centroid'] = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
            features['spectral_bandwidth'] = np.mean(librosa.feature.spectral_bandwidth(y=audio, sr=self.sample_rate))
            features['spectral_rolloff'] = np.mean(librosa.feature.spectral_rolloff(y=audio, sr=self.sample_rate))
            features['zero_crossing_rate'] = np.mean(librosa.feature.zero_crossing_rate(audio))

            # MFCC features
            mfccs = librosa.feature.mfcc(y=audio, sr=self.sample_rate, n_mfcc=13)
            for i in range(13):
                features[f'mfcc_{i}'] = np.mean(mfccs[i])
                features[f'mfcc_{i}_std'] = np.std(mfccs[i])

            # Prosodic features (fundamental frequency)
            f0 = librosa.yin(audio, fmin=50, fmax=300)
            f0_valid = f0[f0 > 0]  # Remove unvoiced frames
            if len(f0_valid) > 0:
                features['f0_mean'] = np.mean(f0_valid)
                features['f0_std'] = np.std(f0_valid)
                features['f0_range'] = np.max(f0_valid) - np.min(f0_valid)
            else:
                features['f0_mean'] = features['f0_std'] = features['f0_range'] = 0

            # Energy features
            rms = librosa.feature.rms(y=audio)[0]
            features['energy_mean'] = np.mean(rms)
            features['energy_std'] = np.std(rms)

            # Temporal features
            features['duration'] = len(audio) / self.sample_rate

            return features
        except Exception as e:
            print(f"Error extracting acoustic features: {e}")
            # Return default features
            return {f'feature_{i}': 0.0 for i in range(30)}

# ============================================================================
# STEP 5: ENHANCED MULTIMODAL MODEL
# ============================================================================

class EnhancedMultiModalModel(nn.Module):
    """Multimodal classifier fusing transcript text, spectrogram images and linguistic features.

    Three branches are encoded independently and then fused:

    * text: a pretrained BERT encoder whose pooled output is projected to 512 dims;
    * audio: a 4-block CNN over 3-channel images, pooled to 256x4x4 and projected to 512 dims;
    * linguistic: an MLP mapping ``linguistic_features_dim`` hand-crafted values to 128 dims.

    The text and audio projections exchange information through a shared
    multi-head attention layer before being concatenated with the linguistic
    encoding and classified.

    Args:
        num_classes: Number of output logits.
        bert_model: Name or path handed to ``BertModel.from_pretrained``.
        linguistic_features_dim: Width of the linguistic feature vector.
        dropout_rate: Dropout probability used in the dense layers and attention.
    """

    def __init__(self, num_classes=2, bert_model='bert-base-uncased',
                 linguistic_features_dim=15, dropout_rate=0.3):
        super().__init__()

        # Text encoder (BERT)
        self.bert = BertModel.from_pretrained(bert_model)
        self.text_dim = self.bert.config.hidden_size

        # Audio encoder (CNN for spectrogram images); expects 3 input channels.
        self.audio_encoder = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),

            # Second conv block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),

            # Third conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),

            # Fourth conv block; adaptive pooling fixes the spatial size at 4x4
            # regardless of the input image resolution.
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Dropout2d(0.5),
        )

        # Flattened size of the audio encoder output: 256 channels * 4 * 4 = 4096
        self.audio_dim = 256 * 4 * 4

        # Linguistic features encoder
        self.linguistic_encoder = nn.Sequential(
            nn.Linear(linguistic_features_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        )
        self.linguistic_dim = 128

        # Feature dimension reduction layers (both modalities land in a 512-d space)
        self.text_projector = nn.Sequential(
            nn.Linear(self.text_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        )

        self.audio_projector = nn.Sequential(
            nn.Linear(self.audio_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        )

        # Cross-modal attention mechanism (one layer shared by both directions)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=512, num_heads=8, dropout=dropout_rate, batch_first=True
        )

        # Fusion layer
        total_dim = 512 + 512 + self.linguistic_dim  # text + audio + linguistic
        self.fusion_layer = nn.Sequential(
            nn.Linear(total_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )

        # Initialize the freshly created layers (the BERT encoder is left alone)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize the newly created layers, leaving the BERT encoder untouched.

        ``self.modules()`` also yields every sub-module of ``self.bert``.
        Re-initializing those would throw away the weights that were just
        loaded by ``from_pretrained``, so they are skipped explicitly.
        """
        pretrained_modules = set(self.bert.modules())

        for m in self.modules():
            if m in pretrained_modules:
                continue
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, input_ids, attention_mask, audio_features, linguistic_features):
        """Run all three branches, fuse them and classify.

        Args:
            input_ids: Token ids passed to BERT.
            attention_mask: Attention mask passed to BERT.
            audio_features: Image batch fed to the CNN; the first convolution
                requires 3 channels.
            linguistic_features: Batch of linguistic vectors whose width must
                equal the ``linguistic_features_dim`` given at construction.

        Returns:
            dict with ``logits`` plus the intermediate ``text_features``
            (projected, pre-attention), ``audio_features`` (projected,
            pre-attention), ``linguistic_features`` (encoded) and
            ``fused_features``.
        """
        batch_size = input_ids.size(0)

        # Text encoding
        text_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.pooler_output  # [batch_size, hidden_size]
        text_projected = self.text_projector(text_features)  # [batch_size, 512]

        # Audio encoding
        audio_encoded = self.audio_encoder(audio_features)  # [batch_size, 256, 4, 4]
        audio_flattened = audio_encoded.view(batch_size, -1)  # [batch_size, 4096]
        audio_projected = self.audio_projector(audio_flattened)  # [batch_size, 512]

        # Linguistic features encoding
        linguistic_encoded = self.linguistic_encoder(linguistic_features)  # [batch_size, 128]

        # Cross-modal attention between text and audio.
        # NOTE(review): each modality is a single token here, so every query
        # attends to exactly one key; the attention weight carries no choice
        # and each "attended" vector is essentially a learned projection of
        # the *other* modality. Confirm this is the intended design.
        text_expanded = text_projected.unsqueeze(1)  # [batch_size, 1, 512]
        audio_expanded = audio_projected.unsqueeze(1)  # [batch_size, 1, 512]

        # Apply cross-attention in both directions (query, key, value)
        attended_text, _ = self.cross_attention(text_expanded, audio_expanded, audio_expanded)
        attended_audio, _ = self.cross_attention(audio_expanded, text_expanded, text_expanded)

        attended_text = attended_text.squeeze(1)  # [batch_size, 512]
        attended_audio = attended_audio.squeeze(1)  # [batch_size, 512]

        # Feature fusion
        fused_features = torch.cat([attended_text, attended_audio, linguistic_encoded], dim=1)
        fused_output = self.fusion_layer(fused_features)  # [batch_size, 512]

        # Classification
        logits = self.classifier(fused_output)  # [batch_size, num_classes]

        return {
            'logits': logits,
            'text_features': text_projected,
            'audio_features': audio_projected,
            'linguistic_features': linguistic_encoded,
            'fused_features': fused_output
        }


In [28]:
# ============================================================================
# STEP 6: TRAINING AND EVALUATION
# ============================================================================

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

class ModelTrainer:
    """Enhanced trainer with comprehensive evaluation metrics"""

    def __init__(self, model, device, num_classes=2):
        self.model = model
        self.device = device
        self.num_classes = num_classes
        self.model.to(device)

        # Initialize metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.learning_rates = []

    def train_model(self, train_loader, val_loader, num_epochs=50,
                   learning_rate=2e-5, weight_decay=1e-4, patience=10,
                   warmup_epochs=3):
        """Train the model with early stopping and learning rate scheduling"""

        # Optimizer with different learning rates for different components
        bert_params = list(self.model.bert.parameters())
        other_params = [p for p in self.model.parameters() if p not in bert_params]

        optimizer = optim.AdamW([
            {'params': bert_params, 'lr': learning_rate * 0.1},  # Lower LR for BERT
            {'params': other_params, 'lr': learning_rate}
        ], weight_decay=weight_decay)

        # Learning rate scheduler with warmup
        total_steps = len(train_loader) * num_epochs
        warmup_steps = len(train_loader) * warmup_epochs

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return max(0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)))

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        # Loss function with label smoothing
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Early stopping
        best_val_loss = float('inf')
        best_val_f1 = 0.0
        patience_counter = 0
        best_model_state = None

        print(f"Starting training for {num_epochs} epochs...")
        print(f"Device: {self.device}")
        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")
        print(f"Batch size: {train_loader.batch_size}")
        print(f"Steps per epoch: {len(train_loader)}")
        print("-" * 60)

        for epoch in range(num_epochs):
            # Training phase
            train_loss, train_acc = self._train_epoch(train_loader, optimizer, criterion, scheduler)

            # Validation phase
            val_loss, val_acc, val_metrics = self._validate_epoch(val_loader, criterion)

            # Save metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.learning_rates.append(scheduler.get_last_lr()[0])

            # Print progress
            print(f"Epoch {epoch+1:3d}/{num_epochs}: "
                  f"Train[Loss: {train_loss:.4f}, Acc: {train_acc:.4f}] "
                  f"Val[Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_metrics['f1_score']:.4f}] "
                  f"LR: {scheduler.get_last_lr()[0]:.2e}")

            # Early stopping check (using F1 score as primary metric)
            if val_metrics['f1_score'] > best_val_f1:
                best_val_f1 = val_metrics['f1_score']
                best_val_loss = val_loss
                patience_counter = 0
                best_model_state = self.model.state_dict().copy()
                print(f"  → New best model! F1: {best_val_f1:.4f}")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"\nEarly stopping at epoch {epoch+1} (patience: {patience})")
                break

            # Print detailed metrics every 5 epochs
            if (epoch + 1) % 5 == 0:
                print(f"  Detailed metrics: Precision: {val_metrics['precision']:.4f}, "
                      f"Recall: {val_metrics['recall']:.4f}")
                self._print_confusion_matrix(val_metrics['confusion_matrix'])

        # Load best model
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
            print(f"\nLoaded best model with F1: {best_val_f1:.4f}")

        return {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accuracies': self.train_accuracies,
            'val_accuracies': self.val_accuracies,
            'learning_rates': self.learning_rates,
            'best_val_loss': best_val_loss,
            'best_val_f1': best_val_f1,
            'total_epochs': len(self.train_losses)
        }

    def _train_epoch(self, train_loader, optimizer, criterion, scheduler):
        """Run one optimisation pass over ``train_loader``.

        Every batch performs forward, backward, gradient clipping (max norm
        1.0), an optimizer step and a scheduler step, in that order.

        Returns:
            tuple ``(mean batch loss, accuracy over all samples seen)``.
        """
        self.model.train()
        running_loss = 0
        n_correct = 0
        n_seen = 0

        tensor_keys = ('input_ids', 'attention_mask', 'audio_features',
                       'linguistic_features', 'label')
        bar = tqdm(train_loader, desc="Training", leave=False)

        for batch in bar:
            # Move every tensor of the batch onto the training device
            (input_ids, attention_mask, audio_features,
             linguistic_features, labels) = (batch[key].to(self.device) for key in tensor_keys)

            # Forward
            optimizer.zero_grad()
            logits = self.model(input_ids, attention_mask, audio_features, linguistic_features)['logits']
            loss = criterion(logits, labels)

            # Backward; the scheduler advances per batch, not per epoch
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Running statistics
            batch_loss = loss.item()
            running_loss += batch_loss
            predicted = torch.max(logits, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            bar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{n_correct / n_seen:.4f}',
                'LR': f'{scheduler.get_last_lr()[0]:.2e}'
            })

        return running_loss / len(train_loader), n_correct / n_seen

    def _validate_epoch(self, val_loader, criterion):
        """Evaluate the model on ``val_loader`` without updating weights.

        Returns:
            tuple ``(mean batch loss, accuracy, metrics)`` where ``metrics``
            holds ``accuracy``, macro ``precision`` / ``recall`` / ``f1_score``,
            the ``confusion_matrix`` and the raw ``predictions``,
            ``true_labels`` and ``probabilities`` lists.
        """
        self.model.eval()
        total_loss = 0
        all_predictions = []
        all_labels = []
        all_probabilities = []

        progress_bar = tqdm(val_loader, desc="Validation", leave=False)

        with torch.no_grad():
            for batch in progress_bar:
                # Move to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                audio_features = batch['audio_features'].to(self.device)
                linguistic_features = batch['linguistic_features'].to(self.device)
                labels = batch['label'].to(self.device)

                # Forward pass
                outputs = self.model(input_ids, attention_mask, audio_features, linguistic_features)
                loss = criterion(outputs['logits'], labels)

                # Get predictions and probabilities
                probabilities = torch.softmax(outputs['logits'], dim=1)
                _, predicted = torch.max(outputs['logits'], 1)

                # Statistics
                total_loss += loss.item()
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())

                # Update progress bar
                progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})

        avg_loss = total_loss / len(val_loader)

        # Calculate comprehensive metrics
        accuracy = accuracy_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions, average='macro', zero_division=0)
        recall = recall_score(all_labels, all_predictions, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
        # Pin the label set so the matrix is always num_classes x num_classes.
        # Without it, a validation split where one class never appears (in
        # labels or predictions) yields a smaller matrix and the confusion
        # matrix printer fails when indexing the missing column.
        cm = confusion_matrix(all_labels, all_predictions,
                              labels=list(range(self.num_classes)))

        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'confusion_matrix': cm,
            'predictions': all_predictions,
            'true_labels': all_labels,
            'probabilities': all_probabilities
        }

        return avg_loss, accuracy, metrics

    def evaluate_model(self, test_loader, detailed=True):
        """Evaluate the model on ``test_loader`` and collect metrics.

        Args:
            test_loader: Loader yielding test batches. Batches may carry an
                optional ``participant_id`` entry used to key per-sample results.
            detailed: When true, also collect per-participant results and
                per-class projected text/audio features.

        Returns:
            dict with ``overall_metrics`` (accuracy, macro precision/recall/F1,
            confusion matrix), ``class_wise_metrics`` (one value per class in
            ``range(num_classes)``), ``confidence_analysis``,
            ``participant_results`` and ``feature_analysis`` (both ``None``
            when ``detailed`` is false) and the raw ``predictions``,
            ``true_labels`` and ``probabilities`` lists.
        """
        self.model.eval()
        all_predictions = []
        all_labels = []
        all_probabilities = []
        participant_results = {}
        feature_analysis = defaultdict(list)

        print("Evaluating model on test set...")
        progress_bar = tqdm(test_loader, desc="Testing")

        with torch.no_grad():
            for batch in progress_bar:
                # Move to device
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                audio_features = batch['audio_features'].to(self.device)
                linguistic_features = batch['linguistic_features'].to(self.device)
                labels = batch['label'].to(self.device)

                # Forward pass
                outputs = self.model(input_ids, attention_mask, audio_features, linguistic_features)
                probabilities = torch.softmax(outputs['logits'], dim=1)
                _, predicted = torch.max(outputs['logits'], 1)

                # One device-to-host transfer per batch instead of one per sample
                predicted_np = predicted.cpu().numpy()
                labels_np = labels.cpu().numpy()
                probabilities_np = probabilities.cpu().numpy()

                # Store results
                all_predictions.extend(predicted_np)
                all_labels.extend(labels_np)
                all_probabilities.extend(probabilities_np)

                if not detailed:
                    continue

                # Store per-participant results and feature analysis
                text_np = outputs['text_features'].cpu().numpy()
                audio_np = outputs['audio_features'].cpu().numpy()
                linguistic_np = outputs['linguistic_features'].cpu().numpy()
                batch_ids = batch['participant_id'] if 'participant_id' in batch else None

                for i in range(len(labels_np)):
                    if batch_ids is not None:
                        participant_id = batch_ids[i]
                        # Tensors hash by identity, which makes them useless
                        # as lookup keys; unwrap scalar ids to plain Python values.
                        if torch.is_tensor(participant_id) and participant_id.numel() == 1:
                            participant_id = participant_id.item()
                    else:
                        participant_id = f"sample_{len(participant_results)}"

                    label = int(labels_np[i])
                    participant_results[participant_id] = {
                        'true_label': label,
                        'predicted_label': int(predicted_np[i]),
                        'probability': probabilities_np[i],
                        'confidence': float(probabilities_np[i].max()),
                        'text_features': text_np[i],
                        'audio_features': audio_np[i],
                        'linguistic_features': linguistic_np[i],
                    }

                    # Feature analysis, grouped by the true class
                    feature_analysis[f'text_features_class_{label}'].append(text_np[i])
                    feature_analysis[f'audio_features_class_{label}'].append(audio_np[i])

        # Pin the label set so the confusion matrix and the class-wise arrays
        # always have one entry per class, even when a class is missing from
        # both the labels and the predictions; the plotting code assumes that.
        class_labels = list(range(self.num_classes))

        # Calculate metrics
        accuracy = accuracy_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions, average='macro', zero_division=0)
        recall = recall_score(all_labels, all_predictions, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
        cm = confusion_matrix(all_labels, all_predictions, labels=class_labels)

        # Detailed class-wise metrics
        class_precision = precision_score(all_labels, all_predictions, average=None,
                                          labels=class_labels, zero_division=0)
        class_recall = recall_score(all_labels, all_predictions, average=None,
                                    labels=class_labels, zero_division=0)
        class_f1 = f1_score(all_labels, all_predictions, average=None,
                            labels=class_labels, zero_division=0)

        # Calculate confidence statistics
        confidence_stats = self._analyze_confidence(all_probabilities, all_predictions, all_labels)

        evaluation_results = {
            'overall_metrics': {
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1_score': f1,
                'confusion_matrix': cm
            },
            'class_wise_metrics': {
                'precision': class_precision,
                'recall': class_recall,
                'f1_score': class_f1
            },
            'confidence_analysis': confidence_stats,
            'participant_results': participant_results if detailed else None,
            'feature_analysis': dict(feature_analysis) if detailed else None,
            'predictions': all_predictions,
            'true_labels': all_labels,
            'probabilities': all_probabilities
        }

        return evaluation_results

    def _analyze_confidence(self, probabilities, predictions, true_labels):
        """Analyze model confidence and calibration"""
        probs_array = np.array(probabilities)
        max_probs = np.max(probs_array, axis=1)
        predictions = np.array(predictions)
        true_labels = np.array(true_labels)

        # Overall confidence stats
        confidence_stats = {
            'mean_confidence': np.mean(max_probs),
            'std_confidence': np.std(max_probs),
            'min_confidence': np.min(max_probs),
            'max_confidence': np.max(max_probs)
        }

        # Confidence for correct vs incorrect predictions
        correct_mask = predictions == true_labels
        confidence_stats['correct_confidence'] = np.mean(max_probs[correct_mask])
        confidence_stats['incorrect_confidence'] = np.mean(max_probs[~correct_mask])

        # Calibration analysis
        n_bins = 10
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        calibration_data = []
        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
            in_bin = (max_probs > bin_lower) & (max_probs <= bin_upper)
            prop_in_bin = in_bin.mean()

            if prop_in_bin > 0:
                accuracy_in_bin = correct_mask[in_bin].mean()
                avg_confidence_in_bin = max_probs[in_bin].mean()
                calibration_data.append({
                    'bin_lower': bin_lower,
                    'bin_upper': bin_upper,
                    'accuracy': accuracy_in_bin,
                    'confidence': avg_confidence_in_bin,
                    'count': in_bin.sum()
                })

        confidence_stats['calibration_data'] = calibration_data

        return confidence_stats

    def _print_confusion_matrix(self, cm, class_names=None):
        """Print formatted confusion matrix"""
        if class_names is None:
            class_names = ['Control', 'Dementia']

        print("  Confusion Matrix:")
        print("  " + "-" * 25)
        print(f"          Predicted")
        print(f"        {class_names[0]:<8} {class_names[1]:<8}")
        print(f"  True")
        for i, true_class in enumerate(class_names):
            print(f"  {true_class:<8} {cm[i,0]:<8} {cm[i,1]:<8}")
        print()

    def plot_training_history(self, save_path=None):
        """Draw a 2x3 dashboard of the recorded training curves.

        Panels: loss, accuracy, learning rate (log scale), absolute
        train/validation loss gap, absolute accuracy gap and a text summary.

        Args:
            save_path: If given, the figure is also written there at 300 dpi.

        Returns:
            The matplotlib figure.
        """
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))

        def finish(ax, title, ylabel, log_y=False):
            # Shared decoration for every curve panel
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            if log_y:
                ax.set_yscale('log')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Loss curves
        loss_ax = axes[0, 0]
        loss_ax.plot(self.train_losses, label='Training Loss', color='blue', alpha=0.7)
        loss_ax.plot(self.val_losses, label='Validation Loss', color='red', alpha=0.7)
        finish(loss_ax, 'Training and Validation Loss', 'Loss')

        # Accuracy curves
        acc_ax = axes[0, 1]
        acc_ax.plot(self.train_accuracies, label='Training Accuracy', color='blue', alpha=0.7)
        acc_ax.plot(self.val_accuracies, label='Validation Accuracy', color='red', alpha=0.7)
        finish(acc_ax, 'Training and Validation Accuracy', 'Accuracy')

        # Learning rate schedule
        lr_ax = axes[0, 2]
        lr_ax.plot(self.learning_rates, label='Learning Rate', color='green')
        finish(lr_ax, 'Learning Rate Schedule', 'Learning Rate', log_y=True)

        # Train/validation loss gap (overfitting indicator)
        loss_gap = [abs(train - val) for train, val in zip(self.train_losses, self.val_losses)]
        gap_ax = axes[1, 0]
        gap_ax.plot(loss_gap, label='|Train - Val| Loss', color='orange')
        finish(gap_ax, 'Training-Validation Loss Gap', 'Absolute Difference')

        # Train/validation accuracy gap
        acc_gap = [abs(train - val) for train, val in zip(self.train_accuracies, self.val_accuracies)]
        acc_gap_ax = axes[1, 1]
        acc_gap_ax.plot(acc_gap, label='|Train - Val| Accuracy', color='purple')
        finish(acc_gap_ax, 'Training-Validation Accuracy Gap', 'Absolute Difference')

        # Text summary panel
        if self.train_losses:
            summary_text = "\n".join([
                'Training Complete!',
                '',
                f'Total Epochs: {len(self.train_losses)}',
                f'Final Train Acc: {self.train_accuracies[-1]:.4f}',
                f'Final Val Acc: {self.val_accuracies[-1]:.4f}',
                f'Final Train Loss: {self.train_losses[-1]:.4f}',
                f'Final Val Loss: {self.val_losses[-1]:.4f}',
                f'Best Val Acc: {max(self.val_accuracies):.4f}',
                f'Min Val Loss: {min(self.val_losses):.4f}',
            ])
        else:
            summary_text = 'No training data available'

        summary_ax = axes[1, 2]
        summary_ax.text(0.5, 0.5, summary_text,
                        horizontalalignment='center', verticalalignment='center',
                        transform=summary_ax.transAxes, fontsize=11,
                        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8))
        summary_ax.set_title('Training Summary')
        summary_ax.axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Training history plot saved to: {save_path}")

        return fig

    def plot_evaluation_results(self, evaluation_results, save_path=None):
        """Draw a 2x2 dashboard for a result dict produced by ``evaluate_model``.

        Panels: confusion-matrix heatmap, class-wise precision/recall/F1 bars,
        confidence histogram (only when ``confidence_analysis`` is present)
        and a text summary. Class labels are fixed to Control/Dementia.

        Args:
            evaluation_results: Result dict; must contain ``overall_metrics``
                and ``class_wise_metrics``.
            save_path: If given, the figure is also written there at 300 dpi.

        Returns:
            The matplotlib figure.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        heatmap_ax, bars_ax = axes[0, 0], axes[0, 1]
        hist_ax, summary_ax = axes[1, 0], axes[1, 1]

        # Confusion matrix heatmap
        class_names = ['Control', 'Dementia']
        sns.heatmap(evaluation_results['overall_metrics']['confusion_matrix'],
                    annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    ax=heatmap_ax)
        heatmap_ax.set_title('Confusion Matrix')
        heatmap_ax.set_xlabel('Predicted')
        heatmap_ax.set_ylabel('True')

        # Grouped bars: one group per class, one bar per metric
        class_metrics = evaluation_results['class_wise_metrics']
        positions = np.arange(len(class_names))
        bar_width = 0.25
        metric_labels = (('precision', 'Precision'), ('recall', 'Recall'), ('f1_score', 'F1-Score'))

        for offset, (key, display_name) in enumerate(metric_labels):
            bars_ax.bar(positions + offset * bar_width, class_metrics[key], bar_width,
                        label=display_name, alpha=0.8)

        bars_ax.set_xlabel('Classes')
        bars_ax.set_ylabel('Score')
        bars_ax.set_title('Class-wise Performance Metrics')
        bars_ax.set_xticks(positions + bar_width)
        bars_ax.set_xticklabels(class_names)
        bars_ax.legend()
        bars_ax.grid(True, alpha=0.3)

        # Confidence histogram (skipped when no confidence analysis is available)
        has_confidence = 'confidence_analysis' in evaluation_results
        if has_confidence:
            conf_stats = evaluation_results['confidence_analysis']
            top_prob = np.max(np.array(evaluation_results['probabilities']), axis=1)
            mean_conf = conf_stats['mean_confidence']

            hist_ax.hist(top_prob, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            hist_ax.axvline(mean_conf, color='red', linestyle='--',
                            label=f'Mean: {mean_conf:.3f}')
            hist_ax.set_xlabel('Model Confidence')
            hist_ax.set_ylabel('Count')
            hist_ax.set_title('Confidence Distribution')
            hist_ax.legend()
            hist_ax.grid(True, alpha=0.3)

        # Text summary
        overall = evaluation_results['overall_metrics']
        metrics_text = "\n".join([
            'Overall Performance',
            '',
            f'Accuracy: {overall["accuracy"]:.4f}',
            f'Precision: {overall["precision"]:.4f}',
            f'Recall: {overall["recall"]:.4f}',
            f'F1-Score: {overall["f1_score"]:.4f}',
            '',
            '',
        ])

        if has_confidence:
            conf_stats = evaluation_results['confidence_analysis']
            metrics_text += "\n".join([
                'Confidence Stats',
                f'Mean Confidence: {conf_stats["mean_confidence"]:.4f}',
                f'Correct Predictions: {conf_stats["correct_confidence"]:.4f}',
                f'Incorrect Predictions: {conf_stats["incorrect_confidence"]:.4f}',
            ])

        summary_ax.text(0.1, 0.5, metrics_text,
                        horizontalalignment='left', verticalalignment='center',
                        transform=summary_ax.transAxes, fontsize=12,
                        bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgreen", alpha=0.8))
        summary_ax.set_title('Performance Summary')
        summary_ax.axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Evaluation results plot saved to: {save_path}")

        return fig

    def save_model(self, save_path, evaluation_results=None):
        """Save model and training information"""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'model_config': {
                'num_classes': self.num_classes,
            },
            'training_history': {
                'train_losses': self.train_losses,
                'val_losses': self.val_losses,
                'train_accuracies': self.train_accuracies,
                'val_accuracies': self.val_accuracies,
                'learning_rates': self.learning_rates,
            }
        }

        if evaluation_results:
            checkpoint['evaluation_results'] = evaluation_results

        torch.save(checkpoint, save_path)
        print(f"Model checkpoint saved to: {save_path}")

    def load_model(self, load_path):
        """Load model from checkpoint"""
        checkpoint = torch.load(load_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        if 'training_history' in checkpoint:
            history = checkpoint['training_history']
            self.train_losses = history.get('train_losses', [])
            self.val_losses = history.get('val_losses', [])
            self.train_accuracies = history.get('train_accuracies', [])
            self.val_accuracies = history.get('val_accuracies', [])
            self.learning_rates = history.get('learning_rates', [])

        print(f"Model loaded from: {load_path}")
        return checkpoint.get('evaluation_results', None)