# In [None]:
# TMPDEL 2 HMM GMM 3.9




import numpy as np
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import logging
from pathlib import Path
import joblib
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import time
from enum import Enum, auto  
import warnings


""" --- OLD ! 
'helix': np.array([0.52, 0.30, 0.18]),  # Helix mixing proportions
'sheet': np.array([0.38, 0.34, 0.28]),  # Sheet mixing proportions
'coil': np.array([0.34, 0.33, 0.33])    # Coil mixing proportions

# # Emission configuration from structure analysis --- outdated and has mistakes.
# self.emission_config = {
#     'base_std': 0.285,        # From PSSM analysis
#     'init_noise_scale': 0.05,  # Reduce for better stability
#     'state_biases': {
#         'helix': [0.46, 0.35, 0.19],  # Match mixture weights
#         'sheet': [0.46, 0.35, 0.19],  # Keep consistent
#         'coil': [0.46, 0.35, 0.19]    # Use analyzed proportions
#     },
#     'update_clip': 0.1        # For gradient stability
# }
"""


# Configure logging
# Root logger setup: each record is written both to a log file whose name
# carries the start timestamp (created in the current working directory) and
# to the default console stream.
logging.basicConfig(
    handlers=[
        logging.FileHandler(f'protein_structure_pred_{datetime.now():%Y%m%d_%H%M}.log'),
        logging.StreamHandler(),
    ],
    datefmt='%H:%M:%S',
    format='%(asctime)s | %(levelname)s | %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every class in this file.
logger = logging.getLogger(__name__)




## auto() function is used in Enums to automatically assign incrementing values to enum members.

class StructureState(Enum):
    """Three-state secondary-structure classes reduced from DSSP.

    Members are numbered 1, 2, 3 in declaration order (the same values
    ``auto()`` would assign).
    """
    HELIX = 1  # H: alpha helix, 3-10 helix, pi helix
    SHEET = 2  # E: extended strand, bridge
    COIL = 3   # C: coil, turn, bend




class ModelConfig:
    """Bundle of every hyper-parameter read by the mixture-Gaussian HMM.

    The constructor takes no arguments. All settings are plain instance
    attributes: scalars, one NumPy array (``target_state_dist``) and several
    nested dictionaries (``init_params``, ``feature_config``,
    ``stability_config``, ``emission_config``, ``logging_config``).
    Provenance remarks in the comments are carried over from the original
    annotations of this file.
    """

    def __init__(self):
        # ---- Model dimensions ----
        # 3 structure states, 3 mixture components per state, 46 input features.
        self.n_states, self.n_mixtures, self.n_features = 3, 3, 46
        self.random_seed = 42  # arbitrary choice

        # ---- Optimisation settings ----
        # Earlier values were learning_rate 0.005, lr_decay 0.99,
        # clip_value 2.0 and momentum 0.1; batch_size was already 32.
        self.learning_rate = 0.094353  # from feature variance analysis
        self.lr_decay = 0.9791         # from feature stability
        self.clip_value = 0.5871       # from gradient analysis
        self.momentum = 0.0941         # from autocorrelation analysis
        self.batch_size = 32

        # ---- State / mixture balance ----
        self.min_state_prob = 0.016    # from state distribution
        self.max_state_prob = 0.047    # from state distribution
        self.min_mixture_prob = 0.1
        self.state_balance_weight = 0.5

        # Target share of each state, ordered helix, sheet, coil.
        self.target_state_dist = np.array([0.492, 0.162, 0.346])

        # ---- Initialisation parameters ----
        # A softer transition set was suggested but is NOT active:
        #   helix_self 0.8, sheet_self 0.6, coil_self 0.4,
        #   helix_to_sheet 0.15, sheet_to_coil 0.15, coil_to_helix 0.15
        state_priors = {'helix': 0.492, 'sheet': 0.162, 'coil': 0.346}
        transition_params = {
            'helix_self': 0.91,
            'sheet_self': 0.67,
            'coil_self': 0.39,
            'helix_to_sheet': 0.15,
            'sheet_to_coil': 0.18,
            'coil_to_helix': 0.10,
        }
        emission_params = {
            'mean_scale': 0.136,    # from PSSM analysis
            'std_scale': 0.285,     # from feature std analysis
            'noise_scale': 0.1782,  # from mixture analysis
            'helix_boost': 1.2,
            'sheet_boost': 1.0,
            'coil_boost': 0.8,
        }
        feature_weights = {'one_hot': 0.42, 'pssm': 0.39, 'aux': 0.19}
        mixture_weights = {'primary': 0.46, 'secondary': 0.35, 'tertiary': 0.19}
        self.init_params = {
            'state_priors': state_priors,
            'transition_params': transition_params,
            'emission_params': emission_params,
            'feature_weights': feature_weights,
            'mixture_weights': mixture_weights,
        }

        # ---- Feature layout ----
        # Each group covers the half-open column range [start, end) and takes
        # its weight from init_params['feature_weights'].
        def _feature_group(name, start, end, std_scale, mean):
            return {
                'start': start,
                'end': end,
                'weight': feature_weights[name],
                'std_scale': std_scale,
                'mean': mean,
            }

        self.feature_config = {
            'one_hot': _feature_group('one_hot', 0, 21, 1.0, 0.047),
            'pssm': _feature_group('pssm', 21, 42, 1.0, 0.136),
            'aux': _feature_group('aux', 42, 46, 0.7, 0.268),
        }

        # ---- Numerical stability ----
        self.stability_config = {
            'max_lr_scale': 2.0,
            'balance_threshold': 1.399835e-04,
            'update_scale': 1.0,
        }

        # ---- Emission settings ----
        self.emission_config = {
            'base_std': 0.089,
            'init_noise_scale': 0.098,
            'state_biases': {
                'helix': [0.030, 0.325, 0.645],
                'sheet': [0.059, 0.449, 0.492],
                'coil': [0.172, 0.765, 0.063],
            },
            'update_clip': 0.021,
        }

        # ---- Logging control ----
        self.logging_config = {
            'state_collapse_window': 100,  # number of positions to average over
            'logging_frequency': 10,       # log every N iterations
            'min_warning_interval': 2.0,   # minimum seconds between collapse warnings
        }


        
## ProteinFeatures class
## Previously: valid positions were taken from a mask in seq_data[:, -1], which was incorrect because the actual data sits in the first ~67 positions.
## Fixed: valid positions are now found from the row sums of the one-hot encoding, so real sequence positions are no longer missed.
## Key issue: the original code looked at positions 67+ while the actual data was in the 0-66 range.
class ProteinFeatures:
    """Sequential feature extractor over a raw NPY array of protein sequences.

    Each call to :meth:`extract_features` consumes the sequence at
    ``current_idx`` and advances the cursor by one.
    """

    def __init__(self, npy_data: np.ndarray):
        """Keep a reference to the raw array and point the cursor at sequence 0."""
        self.raw_data = npy_data
        self.current_idx = 0
        logger.info(f"Initializing feature extraction for sequence shape: {npy_data.shape}")

    def extract_features(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(features, ss3)`` for the current sequence and advance.

        The record is reshaped to 700 rows x 57 columns. Only rows whose first
        21 columns sum to a positive value are kept. The returned feature
        matrix has 46 columns (21 + 21 + 1 + 3); ``ss3`` is the 3-column
        one-hot label matrix derived from columns 42-49.

        The full per-group arrays are stored on ``self.feature_info``.

        Raises:
            StopIteration: the cursor is past the last sequence.
            ValueError: the sequence has no valid rows (the cursor is still
                advanced), or the assembled matrix is not 46 columns wide.
        """
        n_sequences = len(self.raw_data)
        if self.current_idx >= n_sequences:
            raise StopIteration("No more sequences to process")

        idx = self.current_idx
        table = self.raw_data[idx].reshape(700, 57)
        # Rows with at least one active one-hot entry are the real residues.
        occupied = np.nonzero(table[:, :21].sum(axis=1) > 0)[0]

        if occupied.size == 0:
            logger.warning(f"Empty sequence at index {self.current_idx}")
            self.current_idx += 1
            raise ValueError(f"Empty sequence at index {self.current_idx-1}")

        seq_length = occupied.size
        if idx % 100 == 0 or idx in (1, n_sequences - 1):
            logger.debug(f"Processing sequence {self.current_idx}/{len(self.raw_data)}, length: {seq_length}")

        self.current_idx += 1

        def _columns(lo, hi=None):
            # Column range [lo, hi) restricted to the valid rows.
            return table[occupied, lo:hi]

        # NOTE(review): column ranges are used as found; confirm them against
        # the documented layout of the dataset being loaded.
        one_hot = _columns(0, 21)       # 21 features
        pssm = _columns(21, 42)         # 21 features
        ss8 = _columns(42, 50)          # secondary structure (labels only)
        disorder = _columns(50, 51)     # 1 feature
        additional = _columns(51)       # remaining columns, kept for inspection

        pos_features = self._create_position_features(seq_length)

        # Everything is retained for debugging, including unused groups.
        self.feature_info = {
            'one_hot': one_hot,
            'pssm': pssm,
            'ss8': ss8,
            'disorder': disorder,
            'additional': additional,
            'positional': pos_features
        }

        features = np.concatenate([one_hot, pssm, disorder, pos_features], axis=1)
        ss3 = self._convert_dssp8_to_dssp3(ss8)

        if features.shape[1] != 46:
            logger.error(f"Feature dimension mismatch: {features.shape[1]} != 46")
            logger.debug("Feature shapes:")
            for name, feat in self.feature_info.items():
                logger.debug(f"  {name}: {feat.shape}")
            raise ValueError(f"Feature dimension mismatch: {features.shape[1]} != 46")

        return features, ss3

    def _convert_dssp8_to_dssp3(self, ss8: np.ndarray) -> np.ndarray:
        """Collapse 8-column labels to a 3-column one-hot float matrix.

        The winning column per row (``argmax``) is mapped as
        0-2 -> column 0 (helix), 3-4 -> column 1 (sheet), 5-7 -> column 2 (coil).
        """
        winners = np.argmax(ss8, axis=1)
        column_groups = ([0, 1, 2], [3, 4], [5, 6, 7])
        membership = [np.isin(winners, group) for group in column_groups]
        return np.stack(membership, axis=1).astype(float)

    def _create_position_features(self, length: int) -> np.ndarray:
        """Return a ``(length, 3)`` matrix of position features.

        Columns are: relative position, distance from start, distance from
        end, each as ``index / length``.
        """
        # NOTE(review): the first two columns are computed identically.
        forward = np.arange(length) / length
        backward = forward[::-1]
        return np.stack([forward, forward, backward], axis=1)


## -----------------------------------------------------------------------------------------------------------------------------------------------------------


    
class MixtureGaussianHMM:
    """HMM with mixture of Gaussians emission for protein structure prediction"""
    
    
    def __init__(self, config: ModelConfig):
        """Initialize model with given configuration"""
        self.config = config
        np.random.seed(config.random_seed)
        self._initialize_model()
        logger.info(f"Initialized HMM with {config.n_states} states and {config.n_mixtures} mixtures per state")

    ## Do NOT use the mask column to locate valid positions - it is broken. Use this helper instead.
    @staticmethod
    def _get_valid_positions(x: np.ndarray) -> np.ndarray:
        """Centralized method for getting valid positions"""
        if x.ndim == 1:
            x = x.reshape(-1, 57)
        one_hot_sums = np.sum(x[:, :21], axis=1)
        return np.where(one_hot_sums > 0)[0]

    # Add Feature Validation Method to MixtureGaussianHMM
    def _validate_feature_dimensions(self, x: np.ndarray) -> None:
        """Validate feature dimensions while preserving all error tracking"""
        if x.shape[1] != self.config.n_features:
            error_msg = f"Feature dimension mismatch. Expected: {self.config.n_features}, Got: {x.shape[1]}"
            logger.error(error_msg)
            logger.debug(f"Feature shape details: {x.shape}")
            logger.debug(f"Model feature config: {self.feature_config}")
            raise ValueError(error_msg)

                        
    def _initialize_model(self):
        """Initialize model with validation"""
        # Get initial stats from config
        state_priors = np.array([self.config.init_params['state_priors'][s] for s in ['helix', 'sheet', 'coil']])
        tp = self.config.init_params['transition_params']

        # Calculate mixing with boundary check
        mixing = max(0.05, min(tp['helix_to_sheet'], tp['sheet_to_coil'], tp['coil_to_helix']))

        # Initialize transitions with guaranteed non-zero probabilities
        self.transitions = np.array([
            [tp['helix_self'], mixing, 1 - tp['helix_self'] - mixing],
            [mixing, tp['sheet_self'], 1 - tp['sheet_self'] - mixing],
            [mixing, 1 - tp['coil_self'] - mixing, tp['coil_self']]
        ])
            
        # Ensure valid probabilities
        self.transitions = np.maximum(self.transitions, 0.0)  # Ensure non-negative
        row_sums = self.transitions.sum(axis=1, keepdims=True)
        self.transitions = self.transitions / row_sums  # Normalize rows

        # Initialize state priors
        self.state_priors = state_priors / state_priors.sum()
        
        # Verify transition matrix before proceeding
        self._check_transition_matrix()
        # Initialize remaining components
        self._initialize_mixture_weights()
        self._initialize_emissions()


    def _check_transition_matrix(self):
        """Validate transition matrix properties"""
        # Check row sums
        row_sums = np.sum(self.transitions, axis=1)
        if not np.allclose(row_sums, 1.0, rtol=1e-5, atol=1e-5):
            logger.error(f"Invalid transition matrix:\nRow sums: {row_sums}\nMin value: {np.min(self.transitions)}")
            raise ValueError("Invalid transition matrix: row sums not 1 or negative values present")
        
        # Check for and fix zero probabilities
        min_prob = 1e-5
        if np.any(self.transitions < min_prob):
            logger.warning(f"Found near-zero probabilities in transition matrix. Adjusting...")
            self.transitions = np.maximum(self.transitions, min_prob)
            self.transitions /= self.transitions.sum(axis=1, keepdims=True)


    
    def _initialize_transitions(self):
        """Initialize transitions using config parameters"""
        tp = self.config.init_params['transition_params']
        
        # Calculate mixing based on transition probabilities
        ## old --> mixing = min(tp['helix_to_sheet'], tp['sheet_to_coil'], tp['coil_to_helix'])
        mixing = max(0.05, min(tp['helix_to_sheet'], tp['sheet_to_coil'], tp['coil_to_helix']))  # Ensure minimum mixing

        self.transitions = np.array([
            [tp['helix_self'], mixing, 1 - tp['helix_self'] - mixing],
            [mixing, tp['sheet_self'], 1 - tp['sheet_self'] - mixing],
            [mixing, 1 - tp['coil_self'] - mixing, tp['coil_self']]
        ])
        
        # Ensure minimum probabilities
        self.transitions = np.maximum(self.transitions, 0.05)
        self.transitions /= self.transitions.sum(axis=1, keepdims=True)
    

        
    
    def _initialize_emissions(self):
        """Initialize emissions using parameterized config"""
        n_states, n_mix, n_feat = self.config.n_states, self.config.n_mixtures, self.config.n_features
        ec = self.config.emission_config
        ep = self.config.init_params['emission_params']
        fc = self.config.feature_config
            
        # Initialize base parameters with smaller noise scale for stability
        self.emission_means = np.random.normal(0, ec['init_noise_scale'] * 0.5,  # Reduced noise scale
                                             size=(n_states, n_mix, n_feat))
        self.emission_covs = np.ones((n_states, n_mix, n_feat)) * ec['base_std'] * 1.5  # Slightly increased base variance
        
        # Define boost factors from config
        boost_factors = [ ep['helix_boost'], ep['sheet_boost'], ep['coil_boost'] ]

        for state in range(n_states):
            # Add gradual scaling across states
            state_scale = 1.0 + (state - 1) * 0.1  # Small differential between states
            
            for feat_type, conf in fc.items():
                start, end = conf['start'], conf['end']
                # Balanced feature-specific scaling
                self.emission_means[state, :, start:end] *= boost_factors[state] * state_scale
                self.emission_covs[state, :, start:end] *= conf['std_scale'] * state_scale
        
        # Square covariances for positive definiteness
        self.emission_covs = self.emission_covs ** 2
        

        
    def _initialize_mixture_weights(self):
        """Initialize mixture weights using parameterized weights"""
        # Use mixture weights from config instead of hardcoded values
        base_weights = np.array([
            self.config.init_params['mixture_weights']['primary'], self.config.init_params['mixture_weights']['secondary'],
            self.config.init_params['mixture_weights']['tertiary'] ])
        
        # Create tile and apply minimum probability constraint
        self.mixture_weights = np.tile(base_weights, (self.config.n_states, 1))
        min_prob = self.config.min_mixture_prob
        
        # Apply constraints and normalize
        self.mixture_weights = np.maximum(self.mixture_weights, min_prob)
        self.mixture_weights /= self.mixture_weights.sum(axis=1, keepdims=True)

    
    """
    Old version: Direct gradient updates with clipping
    New version: Balance-weighted updates with feature-type specific handling
    More sophisticated learning rate modulation based on state distribution
    
    -- Think of it like adding a correction term to keep the model from falling into state collapse, while preserving the underlying statistical learning mechanism.
    -- added an overlay of state balance maintenance.
    """       
    

    ## State Balance Enforcement
    ## Critical Issue: Was enforcing balance on possibly invalid state distributions
    ## Fix: Now uses valid position counts and proper weighting
    def _update_parameters(self, stats: Dict) -> None:
        """Parameter update coordinator with state balance enforcement and emission integration"""
        # Get valid state usage with stability threshold
        state_usage = stats['state_occurences'] 
        total_valid = np.sum(state_usage) + self.config.stability_config['balance_threshold']
        state_dist = state_usage / total_valid
        
        # State balance analysis with deviation tracking
        dist_deviation = self.config.target_state_dist - state_dist
        max_scale = self.config.stability_config['max_lr_scale']
        state_scales = np.clip(self.config.target_state_dist / (state_dist + 1e-10), 1.0, max_scale)
        state_lr = self.config.learning_rate * state_scales
        
        # Analyze emission patterns and component usage
        emission_stats = stats['emission_stats']
        component_balance = emission_stats['component_usage'] / (np.sum(emission_stats['component_usage'], axis=1, keepdims=True) + 1e-10)
        feature_importance = emission_stats['feature_contributions'] / (np.sum(emission_stats['feature_contributions'], axis=1, keepdims=True) + 1e-10)
        
        # Update transitions with balance enforcement and component influence
        self._update_transitions(
            transition_counts=stats['transition_counts'],
            state_dist=state_dist,
            dist_deviation=dist_deviation,
            component_usage=component_balance
        )
        
        # Update emissions with state-specific learning rates and feature importance
        self._update_emission_parameters(
            emission_stats=emission_stats,
            state_lr=state_lr,
            feature_weights=feature_importance,
            state_deviation=dist_deviation
        )
        
        # Monitor parameter updates and stability
        if np.any(dist_deviation > 0.2):
            logger.warning(f"Large state distribution deviation detected: {dist_deviation}")
            logger.debug(f"Component balance: {component_balance}")
            logger.debug(f"Feature importance: {feature_importance}")
    
        
    def _update_transitions(self, transition_counts: np.ndarray, state_dist: np.ndarray, dist_deviation: np.ndarray, component_usage: np.ndarray) -> None:
        """Update transition matrix with balance constraints and component influence"""
        # Normalize counts with stability threshold
        norm_counts = transition_counts / (transition_counts.sum(axis=1, keepdims=True) + self.config.stability_config['balance_threshold'])
        
        # Compute balance-adjusted updates with component influence
        balance_factor = self.config.state_balance_weight
        component_influence = np.mean(component_usage, axis=1)  # Average component usage per state
        updates = norm_counts - self.transitions
        
        # Apply stronger updates for underrepresented states with component weighting
        for i in range(self.config.n_states):
            balance_update = balance_factor * dist_deviation
            component_weight = np.clip(component_influence[i], 0.1, 0.9)  # Limit component influence
            updates[i] += balance_update * component_weight
        
        # Apply updates with bounds and stability constraints
        self.transitions = np.clip(
            self.transitions + updates * self.config.stability_config['update_scale'],
            self.config.min_state_prob,
            self.config.max_state_prob
        )
        
        # Renormalize while preserving minimum probability constraints
        row_sums = self.transitions.sum(axis=1, keepdims=True)
        self.transitions /= np.maximum(row_sums, 1e-10)
        
        # Monitor transition stability
        if np.any(np.abs(updates) > 0.1):
            logger.debug(f"Large transition updates detected: max={np.max(np.abs(updates)):.3f}")
            logger.debug(f"Component influence: {component_influence}")

            
    def _update_emission_parameters(self, emission_stats: Dict, state_lr: np.ndarray, feature_weights: np.ndarray, state_deviation: np.ndarray) -> Dict:
        """Update emission parameters with state balance and feature importance integration"""
        weights = emission_stats['weights_num']; means_num = emission_stats['means_num']; covs_num = emission_stats['covs_num']
        param_changes = {'mean_shifts': np.zeros((self.config.n_states, self.config.n_mixtures)), 'cov_changes': np.zeros((self.config.n_states, self.config.n_mixtures))}
        
        for state in range(self.config.n_states):
            for mix in range(self.config.n_mixtures):
                if weights[state, mix] < 1e-10: continue
                
                # New means computation with feature weighting
                new_means = means_num[state, mix] / (weights[state, mix] + 1e-10)
                new_covs = covs_num[state, mix] / (weights[state, mix] + 1e-10)
                
                # Feature-type specific updates with learning rate adjustment
                for feat_type, conf in self.config.feature_config.items():
                    start, end = conf['start'], conf['end']; feat_slice = slice(start, end)
                    feat_lr = state_lr[state] * conf['weight'] * feature_weights[state, conf['end'] - conf['start']]
                    
                    old_means = self.emission_means[state, mix, feat_slice].copy()
                    self.emission_means[state, mix, feat_slice] = (1 - feat_lr) * old_means + feat_lr * new_means[feat_slice]
                    
                    # Update covariances with bounds and stability constraints
                    new_covs_bounded = np.clip(new_covs[feat_slice], self.config.stability_config['balance_threshold'], self.config.emission_config['base_std'] * 4)
                    self.emission_covs[state, mix, feat_slice] = (1 - feat_lr) * self.emission_covs[state, mix, feat_slice] + feat_lr * new_covs_bounded
                    
                    param_changes['mean_shifts'][state, mix] += np.mean(np.abs(self.emission_means[state, mix, feat_slice] - old_means))
                
                # Ensure minimum variance with state balance influence
                min_var = self.config.stability_config['balance_threshold'] * (1 + np.abs(state_deviation[state]))
                self.emission_covs[state, mix] = np.maximum(self.emission_covs[state, mix], min_var)
        
        if np.any(param_changes['mean_shifts'] > 0.1): logger.debug(f"Large emission shifts detected: {param_changes['mean_shifts']}")
        return param_changes
            
        
        
        
    ## MixtureGaussianHMM Core Methods
    ## Previous Issue: Emission likelihood computation was using wrong positions
    ## Fixed: Now properly handling sequence positions and structure state extraction
    def _compute_emission_likelihood(self, x: np.ndarray) -> np.ndarray:
        """Compute emission likelihood with corrected feature handling and stability monitoring"""
        # Initialize output array for full sequence
        seq_len = len(x)
        emission_ll = np.zeros((seq_len, self.config.n_states))
        state_stats = {state: {'max_ll': -np.inf, 'min_ll': np.inf} for state in range(self.config.n_states)}
    
        # Get valid positions and prepare data
        valid_pos = self._get_valid_positions(x)
        x_valid = x[valid_pos]
        
        # Define feature ranges based on config - ensures consistent slicing
        feature_ranges = {'one_hot': slice(0, 21), 'pssm': slice(21, 42), 'aux': slice(42, 46)}  # one-hot: 21, PSSM: 21, aux: 4 features
        
        # Process each state
        for state in range(self.config.n_states):
            mixture_ll = np.zeros((len(valid_pos), self.config.n_mixtures))  # Setup mixture likelihoods array
           
            # Process each mixture component
            for mix in range(self.config.n_mixtures):
                group_ll = np.zeros(len(valid_pos))  # Initialize group log probabilities
               
                # Process each feature group (one-hot, PSSM, auxiliary)
                for feat_name, feat_slice in feature_ranges.items():
                    feat_ll = self._compute_feature_likelihood(x_valid[:, feat_slice], self.emission_means[state, mix, feat_slice], self.emission_covs[state, mix, feat_slice], self.config.feature_config[feat_name]['weight'])  # Compute feature likelihood with proper weight
                    group_ll += feat_ll
               
                mixture_ll[:, mix] = group_ll + np.log(self.mixture_weights[state, mix] + self.config.stability_config['balance_threshold'])  # Add mixture weight and store
           
            max_ll = mixture_ll.max(axis=1, keepdims=True)  # Compute state emission likelihood with numerical stability (log-sum-exp trick)
            emission_ll[valid_pos, state] = max_ll.squeeze() + np.log(np.sum(np.exp(mixture_ll - max_ll), axis=1))
            state_stats[state].update({'max_ll': emission_ll[valid_pos, state].max(), 'min_ll': emission_ll[valid_pos, state].min(), 'mean_ll': emission_ll[valid_pos, state].mean()})  # Update state statistics
        
        return self._normalize_probabilities(emission_ll)  # Return normalized probabilities
        
    
        
    def _compute_feature_likelihood(self, x: np.ndarray, mean: np.ndarray, 
                                  cov: np.ndarray, weight: float) -> np.ndarray:
        """Compute weighted feature likelihood with numerical stability"""
        # Input validation
        if x.ndim != 2:   x = np.atleast_2d(x)
        
        n_samples, n_features = x.shape
        if mean.size != n_features or cov.size != n_features:
            raise ValueError(f"Shape mismatch - x: {x.shape}, mean: {mean.shape}, "
                            f"cov: {cov.shape}, expected features: {n_features}")
        
        # Reshape mean and cov for broadcasting
        mean = np.atleast_2d(mean);        cov = np.atleast_2d(cov);
        
        if mean.shape[1] != n_features:            mean = mean.T
        if cov.shape[1] != n_features:            cov = cov.T
            
        # Ensure minimum variance for stability
        stable_cov = np.maximum(cov, self.config.stability_config['balance_threshold'])
        
        # Compute likelihood
        diff = x - mean
        exp_term = -0.5 * np.sum((diff ** 2) / stable_cov, axis=1)
        log_norm = -0.5 * np.sum(np.log(2 * np.pi * stable_cov))
        
        return weight * (exp_term + log_norm)
        

    
    
    def _normalize_probabilities(self, log_probs: np.ndarray) -> np.ndarray:
        """Normalize log probabilities with reduced warning frequency"""
        max_lp = log_probs.max(axis=1, keepdims=True)
        probs = np.exp(log_probs - max_lp)
        normalized = probs / (probs.sum(axis=1, keepdims=True) + self.config.stability_config['balance_threshold'])
        
        # Compute average state probabilities over window
        window_size = min(self.config.logging_config['state_collapse_window'], normalized.shape[0])
        state_probs = normalized[:window_size].mean(axis=0)
        # Check state collapse only periodically using averaged probabilities
        curr_time = time.time()
        if not hasattr(self, '_last_warning_time'):
            self._last_warning_time = 0

        # Only log warning if sufficient time has passed and probability is below threshold
        if (min(state_probs) < self.config.min_state_prob and 
            curr_time - self._last_warning_time >= self.config.logging_config['min_warning_interval']):
            logger.warning(f"State probability collapse detected:")
            logger.warning(f"State distributions: {[f'{p:.6f}' for p in state_probs]}")
            logger.warning(f"Min allowed prob: {self.config.min_state_prob}")
            logger.warning(f"Max allowed prob: {self.config.max_state_prob}")
            self._last_warning_time = curr_time
                
        return normalized



    def monitor_training_iteration(self, iteration: int, stats: Dict) -> None:
        """Log state usage, self-transition and emission-mean ranges.

        Output is produced only on every fifth iteration; other iterations
        return immediately without touching ``stats``.

        Args:
            iteration: Current training iteration number.
            stats: Statistics dictionary; ``stats['state_occurences']`` is read
                and normalized by its own sum.
        """
        if iteration % 5 != 0:
            return

        occurrences = stats['state_occurences']
        state_usage = occurrences / occurrences.sum()
        self_transitions = np.diag(self.transitions)

        logger.info(f"\nIteration {iteration} Statistics:")
        logger.info(f"State Distribution: {[f'{p:.3f}' for p in state_usage]}")
        logger.info(f"Self-Transition Probs: {[f'{p:.3f}' for p in self_transitions]}")

        # Spread of the emission means, one line per state
        for state in range(self.config.n_states):
            state_means = self.emission_means[state]
            lowest, highest = state_means.min(), state_means.max()
            logger.info(f"State {state} Emission Range: [{lowest:.3f}, {highest:.3f}]")
    
    

    
    def _compute_gaussian_ll(self, x: np.ndarray, mean: np.ndarray, cov: np.ndarray, weight: float = 1.0) -> np.ndarray:
        """Compute weighted Gaussian log-likelihood with numerical stability"""
        cov_stable = np.clip(cov, self.config.min_std, self.config.max_std)
        diff = x - mean
        exp_term = -0.5 * np.sum((diff ** 2) / cov_stable, axis=1)
        log_norm = -0.5 * np.sum(np.log(2 * np.pi * cov_stable))
        return weight * (exp_term + log_norm)

    
    
    ## Mixture Component Responsibilities
    ## Critical Issue: Not handling feature ranges correctly in probability computation
    ## Fix: Now properly computes probabilities for valid features only
    def _compute_mixture_responsibilities(self, x: np.ndarray, state: int, posteriors: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Compute responsibilities with enhanced component analysis"""
        
        valid_pos = self._get_valid_positions(x)
        x_valid = x[valid_pos]
        seq_len = len(x); 
        resp = np.zeros((seq_len, self.config.n_mixtures)); 
        mixture_stats = {'component_weights': np.zeros(self.config.n_mixtures), 
                         'feature_contributions': {feat: np.zeros(self.config.n_mixtures) for feat in self.config.feature_config}, 
                         'dominance_patterns': np.zeros(self.config.n_mixtures), 'confidence_metrics': []}
        
        # Handle each mixture component with feature-specific tracking
        for mix in range(self.config.n_mixtures):
            log_probs = np.zeros(seq_len); mixture_feat_contribs = np.zeros(len(self.config.feature_config)); start_idx = 0
            
            # Process each feature group with contribution tracking
            for feat_idx, (feat_type, conf) in enumerate(self.config.feature_config.items()):
                feat_len = conf['end'] - conf['start']; feat_slice = slice(start_idx, start_idx + feat_len)
                feat_contribution = self._compute_feature_likelihood(x[:, feat_slice], self.emission_means[state, mix, feat_slice], self.emission_covs[state, mix, feat_slice], conf['weight'])
                log_probs += feat_contribution; mixture_feat_contribs[feat_idx] = np.mean(np.abs(feat_contribution)); start_idx += feat_len
            
            # Add mixture weight and track statistics
            log_probs += np.log(self.mixture_weights[state, mix] + 1e-10); resp[:, mix] = np.exp(log_probs)
            mixture_stats['component_weights'][mix] = np.mean(resp[:, mix])
            for feat_idx, feat_type in enumerate(self.config.feature_config):
                mixture_stats['feature_contributions'][feat_type][mix] = mixture_feat_contribs[feat_idx]
        
        # Normalize responsibilities and track statistics
        normalizer = resp.sum(axis=1, keepdims=True); resp /= np.maximum(normalizer, 1e-10)
        mixture_stats['dominance_patterns'] = np.mean(resp, axis=0); mixture_stats['confidence_metrics'] = np.max(resp, axis=1)
        
        # Weight by state posteriors and log significant patterns
        resp *= posteriors.reshape(-1, 1)
        if np.max(mixture_stats['dominance_patterns']) > 0.8:
            logger.debug(f"Strong component dominance in state {state}: {mixture_stats['dominance_patterns']}")
        
        return resp, mixture_stats    
    
    
    """
    Training Flow:
    1. forward() and backward() - Core probability computation
    2. _compute_mixture_responsibilities() - Component analysis
    3. _collect_statistics() - Aggregates data from 1 & 2
    4. _process_training_iteration() - Uses 3's output
    5. _update_parameters() - Uses 4's output
    6. train() - Orchestrates 4 & 5
    """

    
    ## Forward Algorithm
    ## Critical Issue: Was processing all positions including invalid ones
    ## Fix: Now properly handles only valid sequence positions
    def forward(self, x: np.ndarray) -> Tuple[np.ndarray, float, Dict]:
        """Run the scaled forward algorithm over the valid positions of ``x``.

        A position is treated as valid when the first 21 columns of its row
        sum to a positive value. Emission likelihoods are computed for exactly
        those rows (``x[valid_pos]``), the same way ``backward`` and
        ``viterbi`` do, so ``emission_probs[t]`` always belongs to the t-th
        valid position even when invalid rows precede or interleave them.

        Args:
            x: 2-D feature array, one row per sequence position.

        Returns:
            Tuple ``(alpha, log_likelihood, state_probs)``. ``alpha`` has shape
            ``(n_valid, n_states)`` and holds the scaled forward variables,
            ``log_likelihood`` is the sum of the log scaling factors and
            ``state_probs`` is a monitoring dictionary (the scaling factors
            are available under ``state_probs['detailed']['scaling_factors']``).
        """
        # Find valid positions using one-hot encoding
        one_hot = x[:, :21]
        valid_pos = np.where(np.sum(one_hot, axis=1) > 0)[0]
        seq_len = len(valid_pos)
        n_states = self.config.n_states

        alpha = np.zeros((seq_len, n_states))
        scaling = np.zeros(seq_len)

        # NOTE(review): the top-level 'stability_metrics' list is never filled
        # here; only the copy under 'detailed' receives values.
        state_probs = {
            'max_prob': np.zeros(n_states),
            'min_prob': np.ones(n_states),
            'mean_prob': np.zeros(n_states),
            'std_prob': [],
            'stability_metrics': [],
            'detailed': {
                'position_distributions': [],
                'scaling_factors': [],
                'stability_metrics': []
            }
        }
        detailed = state_probs['detailed']

        # Emission likelihoods for valid positions only, aligned with alpha/scaling
        emission_probs = self._compute_emission_likelihood(x[valid_pos])

        # Initialisation: prior times emission, then rescale to sum to one
        alpha[0] = self.state_priors * emission_probs[0]
        scaling[0] = np.sum(alpha[0]) + 1e-10
        alpha[0] /= scaling[0]

        state_probs['max_prob'] = alpha[0].copy()
        state_probs['min_prob'] = alpha[0].copy()
        state_probs['mean_prob'] = alpha[0].copy()
        detailed['position_distributions'].append(alpha[0].copy())
        detailed['scaling_factors'].append(scaling[0])

        # Recursion: alpha[t, j] = sum_i alpha[t-1, i] * A[i, j] * b_j(t)
        for t in range(1, seq_len):
            unscaled = (alpha[t - 1] @ self.transitions) * emission_probs[t]

            # Per-state tracking uses the values before rescaling
            state_probs['max_prob'] = np.maximum(state_probs['max_prob'], unscaled)
            state_probs['min_prob'] = np.minimum(state_probs['min_prob'], unscaled)
            state_probs['mean_prob'] += unscaled

            scaling[t] = np.sum(unscaled) + 1e-10
            alpha[t] = unscaled / scaling[t]

            detailed['position_distributions'].append(alpha[t].copy())
            detailed['scaling_factors'].append(scaling[t])
            state_probs['std_prob'].append(np.std(alpha[t]))
            detailed['stability_metrics'].append(np.mean(np.abs(alpha[t])))

        # Finalize mean probabilities
        state_probs['mean_prob'] /= seq_len
        log_likelihood = np.sum(np.log(scaling))

        # Monitor for numerical stability issues
        if np.any(np.isnan(alpha)) or np.any(np.isinf(alpha)):
            logger.warning("Numerical stability issues detected in forward algorithm")
            logger.debug(f"State probability ranges: {state_probs}")

        return alpha, log_likelihood, state_probs


    
    ## Backward Algorithm 
    ## Critical Issue: Was using scaling values for all positions, mismatched with forward pass
    ## Fix: Now properly aligned with valid positions from forward pass
    def backward(self, x: np.ndarray, scaling: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Run the scaled backward recursion over the valid positions of ``x``.

        Valid positions are the rows whose first 21 columns sum to a positive
        value. ``beta`` is seeded with ``1 / scaling[-1]`` and every earlier
        row is divided by ``scaling[t]`` after its recursion step.

        Args:
            x: 2-D feature array, one row per sequence position.
            scaling: Scaling factors indexed by valid position.

        Returns:
            Tuple ``(beta, backward_stats)`` where ``beta`` has shape
            ``(n_valid, n_states)`` and ``backward_stats`` collects monitoring
            values. Lists in it are appended while walking from the last
            position to the first.
        """
        n_states = self.config.n_states

        # Rows with a populated one-hot block are the valid positions
        valid_pos = np.where(np.sum(x[:, :21], axis=1) > 0)[0]
        seq_len = len(valid_pos)

        beta = np.zeros((seq_len, n_states))
        emission_probs = self._compute_emission_likelihood(x[valid_pos])

        detailed = {
            'position_distributions': [],
            'transition_influence': np.zeros((n_states, n_states)),
            'scaling_impact': []
        }
        backward_stats = {
            'max_beta': np.zeros(n_states),
            'min_beta': np.ones(n_states),
            'mean_beta': np.zeros(n_states),
            'stability_metrics': [],
            'detailed': detailed
        }

        # Seed the last position and the running per-state statistics
        beta[-1] = 1.0 / scaling[-1]
        detailed['position_distributions'].append(beta[-1].copy())
        for key in ('max_beta', 'min_beta', 'mean_beta'):
            backward_stats[key][:] = beta[-1]

        t = seq_len - 2
        while t >= 0:
            for i in range(n_states):
                # Contribution of every successor state j: A[i, j] * b_j(t+1) * beta[t+1, j]
                successor_terms = self.transitions[i] * emission_probs[t + 1] * beta[t + 1]
                beta[t, i] = np.sum(successor_terms)

                detailed['transition_influence'][i] += successor_terms / (np.sum(successor_terms) + 1e-10)

                # Statistics are taken before the row is rescaled below
                backward_stats['max_beta'][i] = max(backward_stats['max_beta'][i], beta[t, i])
                backward_stats['min_beta'][i] = min(backward_stats['min_beta'][i], beta[t, i])
                backward_stats['mean_beta'][i] += beta[t, i]

            step_scale = scaling[t]
            beta[t] /= step_scale
            detailed['scaling_impact'].append(step_scale)

            stability_metric = np.mean(np.abs(beta[t]))
            backward_stats['stability_metrics'].append(stability_metric)
            detailed['position_distributions'].append(beta[t].copy())

            if stability_metric > 1e3 or stability_metric < 1e-3:
                logger.debug(f"Potential stability issue at position {t}: metric = {stability_metric}")

            t -= 1

        backward_stats['mean_beta'] /= seq_len

        # Report NaN/inf contamination of the result
        if np.any(np.isnan(beta)) or np.any(np.isinf(beta)):
            logger.warning("Numerical instability detected in backward algorithm")
            logger.debug(f"Stability metrics: mean={np.mean(backward_stats['stability_metrics'])}, std={np.std(backward_stats['stability_metrics'])}")

        return beta, backward_stats



    ## Viterbi Algorithm
    ## Critical Issue: Was finding paths through invalid positions
    ## Fix: Now only considers valid sequence positions
    def viterbi(self, x: np.ndarray) -> Tuple[List[int], float, Dict]:
        """Find the most likely state path over the valid positions of ``x``.

        Valid positions are the rows whose first 21 columns sum to a positive
        value. The recursion runs in the log domain with a 1e-10 offset
        inside every logarithm.

        Confidence tracking: ``state_confidences`` keeps, per position, the
        scores relative to the best state (the best state is always 1.0).
        ``path_probabilities`` holds the winning state's share after
        normalizing those relative scores to sum to one, for every position
        including the first, so ``mean_confidence`` / ``min_confidence`` are
        informative and defined for single-position sequences too.

        Args:
            x: 2-D feature array, one row per sequence position.

        Returns:
            Tuple ``(path, path_score, path_stats)``: the state index per valid
            position (plain ``int`` values), the log score of that path and a
            dictionary of path statistics.
        """
        one_hot = x[:, :21]
        valid_pos = np.where(np.sum(one_hot, axis=1) > 0)[0]
        seq_len = len(valid_pos)
        n_states = self.config.n_states

        viterbi_vars = np.zeros((seq_len, n_states))
        backpointers = np.zeros((seq_len, n_states), dtype=np.int32)
        path_stats = {
            'state_transitions': np.zeros((n_states, n_states)),
            'state_confidences': [],
            'path_probabilities': []
        }

        emission_probs = self._compute_emission_likelihood(x[valid_pos])
        # Loop-invariant log tables
        log_transitions = np.log(self.transitions + 1e-10)
        log_emissions = np.log(emission_probs + 1e-10)

        def _track_confidence(log_scores: np.ndarray) -> None:
            """Record relative scores and the normalized winner probability."""
            relative = np.exp(log_scores - np.max(log_scores))
            path_stats['state_confidences'].append(relative)
            # max(relative) is always 1.0, so normalize before taking the winner
            path_stats['path_probabilities'].append(np.max(relative / np.sum(relative)))

        viterbi_vars[0] = np.log(self.state_priors + 1e-10) + log_emissions[0]
        _track_confidence(viterbi_vars[0])

        for t in range(1, seq_len):
            # candidate[i, j]: best score ending in i at t-1, then moving i -> j
            candidate = viterbi_vars[t - 1][:, None] + log_transitions
            backpointers[t] = np.argmax(candidate, axis=0)
            viterbi_vars[t] = np.max(candidate, axis=0) + log_emissions[t]
            _track_confidence(viterbi_vars[t])

        # Backtrack from the best final state, counting transitions on the way
        path = [0] * seq_len
        path[-1] = int(np.argmax(viterbi_vars[-1]))
        path_score = viterbi_vars[-1, path[-1]]

        for t in range(seq_len - 2, -1, -1):
            path[t] = int(backpointers[t + 1, path[t + 1]])
            path_stats['state_transitions'][path[t], path[t + 1]] += 1

        path_stats.update({
            'mean_confidence': np.mean(path_stats['path_probabilities']),
            'min_confidence': np.min(path_stats['path_probabilities']),
            'state_frequencies': np.bincount(path, minlength=n_states) / len(path),
            'transition_frequencies': path_stats['state_transitions'] / max(1, np.sum(path_stats['state_transitions']))
        })

        if path_stats['mean_confidence'] < 0.5:
            logger.warning(f"Low confidence path detected: mean={path_stats['mean_confidence']:.3f}")

        return path, path_score, path_stats

    
    
            
    def compute_posteriors(self, alpha: np.ndarray, beta: np.ndarray) -> np.ndarray:
        """Combine forward and backward variables into per-position state posteriors.

        The element-wise product ``alpha * beta`` is divided by its row sum,
        with the row sum floored at the configured ``balance_threshold``.

        Args:
            alpha: Forward variables, one row per position.
            beta: Backward variables with the same shape as ``alpha``.

        Returns:
            Array of the same shape holding the row-normalized posteriors.
        """
        joint = alpha * beta
        floor = self.config.stability_config['balance_threshold']
        row_mass = np.maximum(joint.sum(axis=1, keepdims=True), floor)
        normalized_posteriors = joint / row_mass

        # Debug-level trace whenever any entry is close to 0 or 1
        has_extremes = np.any(normalized_posteriors > 0.99) or np.any(normalized_posteriors < 0.01)
        if has_extremes:
            state_dist = normalized_posteriors.mean(axis=0)
            logger.debug(f"Extreme posterior probabilities detected: min={normalized_posteriors.min():.3f}, max={normalized_posteriors.max():.3f}")
            logger.debug(f"State distribution: {[f'{p:.3f}' for p in state_dist]}")

        return normalized_posteriors


    ## ----------------------------------------------------------------------------------------------------------------------------------------------------------


    def _debug_state_statistics(self, x: np.ndarray, state: int) -> Dict:
        """Analyze state-specific statistics for debugging"""
        # Get emission probabilities
        emission_probs = self._compute_emission_likelihood(x)
        
        # Analyze state probabilities
        state_stats = { 'max_prob': emission_probs[:, state].max(), 'min_prob': emission_probs[:, state].min(), 'mean_prob': emission_probs[:, state].mean(), 
            'std_prob': emission_probs[:, state].std() }
        
        # Analyze transitions
        state_stats.update({ 'incoming_trans': self.transitions[:, state].copy(), 
                            'outgoing_trans': self.transitions[state, :].copy(), 'self_trans': self.transitions[state, state] })
        
        # Analyze emissions
        for mix in range(self.config.n_mixtures):
            mean_stats = self.emission_means[state, mix]
            cov_stats = self.emission_covs[state, mix]
            state_stats.update({ f'mix_{mix}_mean_range': (mean_stats.min(), mean_stats.max()), f'mix_{mix}_cov_range': (cov_stats.min(), cov_stats.max()) })
        
        return state_stats

        
            



        
    def _validate_sequence(self, seq: np.ndarray, seq_idx: int) -> Optional[np.ndarray]:
        """Helper to validate and extract valid positions from a sequence"""
        one_hot = seq[:, :21]; valid_pos = np.where(np.sum(one_hot, axis=1) > 0)[0]
        return seq[valid_pos] if len(valid_pos) > 0 else None
    
    def _validate_sequences(self, sequences: List[np.ndarray]) -> List[np.ndarray]:
        """Keep the sequences that contain at least one valid position.

        Each sequence is checked with ``_validate_sequence``; sequences for
        which that check returns ``None`` are dropped. The surviving entries
        are the original, unfiltered arrays.

        Args:
            sequences: Candidate sequences.

        Returns:
            The sequences that passed the check, in their original order.

        Raises:
            ValueError: If no sequence passes the check.
        """
        valid_seqs = []
        for idx, seq in enumerate(sequences):
            if self._validate_sequence(seq, idx) is not None:
                valid_seqs.append(seq)

        logger.info(f"Found {len(valid_seqs)} valid sequences out of {len(sequences)}")
        if not valid_seqs:
            raise ValueError("No valid sequences found for training")
        return valid_seqs


    
    def train(self, sequences: List[np.ndarray], val_sequences: Optional[List[np.ndarray]] = None, val_labels: Optional[List[np.ndarray]] = None, n_iterations: int = 100, tolerance: float = 1e-4) -> Tuple[Dict[str, List[float]], Dict[str, List]]:
        """Train the model and record history and monitoring statistics.

        Sequences without any valid position are dropped up front. Each
        iteration collects statistics, checks convergence, updates the
        parameters and adapts the learning rate.

        Args:
            sequences: Training sequences.
            val_sequences: Optional validation sequences.
            val_labels: Optional labels paired with ``val_sequences``. Labels
                belonging to rejected validation sequences are dropped so the
                pairing is preserved.
            n_iterations: Maximum number of training iterations.
            tolerance: Improvement threshold forwarded to ``_check_convergence``.

        Returns:
            Tuple ``(history, training_stats)``; both are also stored on the
            instance as ``self.history`` and ``self.training_stats``.

        Raises:
            ValueError: Propagated from ``_validate_sequences`` when no
                sequence has a valid position.
        """
        logger.info(f"Starting training with {len(sequences)} sequences")
        valid_sequences = self._validate_sequences(sequences)

        # Diagnostic dump of the starting parameters
        logger.info("\nInitial Model State:")
        logger.info(f"Transition probabilities:\n{self.transitions}")
        logger.info(f"Mixture weights:\n{self.mixture_weights}")
        logger.info(f"State priors:\n{self.state_priors}")

        # Inspect the first *validated* sequence so the diagnostics cannot hit
        # a sequence that was rejected above.
        first_seq = valid_sequences[0]
        valid_pos = self._get_valid_positions(first_seq)
        logger.info(f"\nFirst sequence details:")
        logger.info(f"Valid positions: {len(valid_pos)}")
        emission_probs = self._compute_emission_likelihood(first_seq[valid_pos])
        logger.info(f"Initial emission probability ranges:")
        logger.info(f"Min: {emission_probs.min():.6f}, Max: {emission_probs.max():.6f}")
        logger.info(f"Mean per state: {emission_probs.mean(axis=0)}")

        # Validate validation data and keep labels paired with the survivors
        valid_val = None
        if val_sequences is not None:
            valid_val = self._validate_sequences(val_sequences)
            if val_labels is not None and len(valid_val) != len(val_sequences):
                val_labels = [
                    label
                    for idx, (seq, label) in enumerate(zip(val_sequences, val_labels))
                    if self._validate_sequence(seq, idx) is not None
                ]
        self.history = self._initialize_history(valid_val is not None)

        # Training state tracking.
        # NOTE(review): no_improve and plateau_count are passed to
        # _check_convergence by value and it returns only
        # (should_stop, best_ll, best_params), so these counters never change here.
        best_ll, best_params, no_improve, plateau_count = float('-inf'), None, 0, 0
        self._collapse_counter = 0
        self._last_warning_time = time.time()

        self.training_stats = {
            'component_evolution': [],  # Track mixture component changes
            'feature_importance': [],   # Track feature contributions
            'state_dynamics': [],       # Track state distribution changes
            'emission_variations': [],  # Track emission parameter changes
            'convergence_metrics': []   # Track detailed convergence metrics
        }

        # Pre-bind so the final logging step is well defined for n_iterations <= 0
        iter_stats: Optional[Dict] = None
        iteration = -1

        for iteration in range(n_iterations):
            iter_stats = self._process_training_iteration(iteration, valid_sequences, self.history, valid_val, val_labels)
            self._update_training_stats(iter_stats)

            should_stop, best_ll, best_params = self._check_convergence(iter_stats, best_ll, best_params, no_improve, plateau_count, tolerance)

            # Early stopping is only honoured after the first five iterations
            if should_stop and iteration >= 5:
                if best_params:
                    logger.info("Restoring best parameters before stopping")
                    self._restore_parameters(best_params)
                break

            param_changes = self._update_parameters(iter_stats['train_stats'])
            self.training_stats['emission_variations'].append(param_changes)

            self._adjust_learning_rate(iter_stats)

            if iteration % self.config.logging_config['logging_frequency'] == 0:
                self._perform_detailed_analysis(iter_stats)

        if iter_stats is not None:
            self._log_final_training_results(iter_stats, iteration)
        else:
            logger.warning(f"No training iterations were run (n_iterations={n_iterations})")
        return self.history, self.training_stats


        
    def _update_training_stats(self, iter_stats: Dict) -> None:
        """Update comprehensive training statistics"""
        # Component evolution tracking
        self.training_stats['component_evolution'].append({
            'weights': [comp['component_weights'] for comp in iter_stats['mixture_stats']],
            'dominance': [comp['dominance_patterns'] for comp in iter_stats['mixture_stats']]
        })
        
        # Feature importance tracking
        self.training_stats['feature_importance'].append({
            'feature_contribs': iter_stats['feature_contributions'],
            'importance_scores': iter_stats['feature_importance']
        })
        
        # State dynamics
        self.training_stats['state_dynamics'].append({
            'distribution': iter_stats['state_dist'],
            'transition_patterns': iter_stats['transition_patterns'],
            'confidence': iter_stats['state_confidence']
        })
        
        # Convergence metrics
        self.training_stats['convergence_metrics'].append({
            'likelihood_change': iter_stats['ll_change'],
            'param_stability': iter_stats['param_stability'],
            'state_balance': iter_stats['state_balance']
        })
    
    def _perform_detailed_analysis(self, iter_stats: Dict) -> None:
        """Log a detailed breakdown of the current training iteration.

        The state distribution is always reported. The mixture-component,
        feature-contribution and parameter-stability sections are reported
        only when ``iter_stats`` carries the corresponding key
        (``mixture_stats``, ``feature_importance``, ``param_stability``), so a
        dictionary without them no longer raises ``KeyError``.

        Args:
            iter_stats: Per-iteration statistics; ``state_dist`` is required.
        """
        logger.info("\nDetailed Training Analysis:")

        # State dynamics; zip stops at the shorter of names / distribution
        state_dist = iter_stats['state_dist']
        logger.info("\nState Distribution Analysis:")
        for state_name, prob in zip(['Helix', 'Sheet', 'Coil'], state_dist):
            logger.info(f"{state_name}: {prob:.3f}")

        # Mixture components (optional)
        mixture_stats = iter_stats.get('mixture_stats')
        if mixture_stats:
            logger.info("\nMixture Component Analysis:")
            for state in range(min(self.config.n_states, len(mixture_stats))):
                comps = mixture_stats[state]['component_weights']
                logger.info(f"State {state} components: {[f'{w:.3f}' for w in comps]}")

        # Feature contributions (optional)
        feature_importance = iter_stats.get('feature_importance')
        if feature_importance:
            logger.info("\nFeature Contribution Analysis:")
            for feat, score in feature_importance.items():
                logger.info(f"{feat}: {score:.3f}")

        # Emission parameter stability (optional)
        param_stability = iter_stats.get('param_stability')
        if param_stability:
            logger.info("\nEmission Parameter Stability:")
            logger.info(f"Mean change: {param_stability['mean_change']:.3e}")
            logger.info(f"Max change: {param_stability['max_change']:.3e}")


    
    def _adjust_learning_rate(self, iter_stats: Dict) -> None:
        """Adaptive learning rate adjustment with warmup and stability control"""
        # Initialize tracking if not exists
        if not hasattr(self, '_lr_tracking'):
            self._lr_tracking = { 'warmup_steps': 10, 'min_lr': 1e-5, 'max_lr': self.config.learning_rate * 2.0, 'best_ll': float('-inf'), 'no_improve_count': 0 }
        
        # Get current likelihood and state distribution
        current_ll = iter_stats['total_ll']
        state_dist = iter_stats['state_dist']
        iteration = len(self.history['train_ll']) if hasattr(self, 'history') else 0
        
        # Basic warmup phase
        if iteration < self._lr_tracking['warmup_steps']:
            warmup_factor = (iteration + 1) / self._lr_tracking['warmup_steps']
            self.config.learning_rate *= warmup_factor
            return
        
        # Compute adaptive factors
        stability_factor = 1.0
        balance_factor = 1.0
        
        # State balance based adjustment
        min_state_prob = np.min(state_dist)
        if min_state_prob < self.config.min_state_prob:
            balance_factor = min_state_prob / self.config.min_state_prob
        
        # Performance based adjustment
        if current_ll > self._lr_tracking['best_ll']:
            self._lr_tracking['best_ll'] = current_ll
            self._lr_tracking['no_improve_count'] = 0
        else:
            self._lr_tracking['no_improve_count'] += 1
            stability_factor = 0.95 ** self._lr_tracking['no_improve_count']
        
        # Parameter stability check
        if 'param_stability' in iter_stats and iter_stats['param_stability']['max_change'] > self.config.emission_config['update_clip']:
            stability_factor *= 0.9
        
        # Apply all adjustments
        self.config.learning_rate *= (
            self.config.lr_decay *      # Base decay
            stability_factor *          # Stability adjustment
            np.sqrt(balance_factor)     # State balance influence
        )
        
        # Bound the learning rate
        self.config.learning_rate = np.clip( self.config.learning_rate, self._lr_tracking['min_lr'], self._lr_tracking['max_lr'] )
        
        # Log significant changes
        if stability_factor < 0.95 or balance_factor < 0.9:
            logger.debug(f"Learning rate adjusted: {self.config.learning_rate:.6f} "
                        f"(stability: {stability_factor:.3f}, balance: {balance_factor:.3f})")


            
    

    def _log_iteration_details(self, iter_stats: Dict, iteration: int, has_validation: bool) -> None:
        """Write the per-iteration summary to the log.

        Args:
            iter_stats: Per-iteration statistics; reads ``state_dist`` and
                ``total_ll`` and, when ``has_validation`` is true,
                ``val_stats['val_ll']`` / ``val_stats['val_acc']``.
            iteration: Iteration number shown in the header line.
            has_validation: Whether validation metrics should be logged.
        """
        emit = logger.info
        state_dist = iter_stats['state_dist']

        emit(f"\nIteration {iteration}:")
        emit(f"Training Log-Likelihood: {iter_stats['total_ll']:.2f}")
        emit(f"State Distribution: {[f'{p:.3f}' for p in state_dist]}")
        emit(f"Learning Rate: {self.config.learning_rate:.6f}")

        if not has_validation:
            return

        val_stats = iter_stats['val_stats']
        emit(f"Validation Log-Likelihood: {val_stats['val_ll']:.2f}")
        emit(f"Validation Accuracy: {val_stats['val_acc']:.3f}")


        
    def _initialize_history(self, has_validation: bool) -> Dict[str, List]:
        """Initialize history dictionary with required tracking lists"""
        return { 'train_ll': [], 'val_ll': [] if has_validation else None, 'val_accuracy': [] if has_validation else None, 'state_usage': [], 'transition_diag': [] }
    

    ## sequence validation → feature extraction → emission computation → forward/backward → statistics collection → parameter updates
    def _process_training_iteration(self, iteration: int, sequences: List[np.ndarray], history: Dict[str, List], val_sequences: Optional[List[np.ndarray]], val_labels: Optional[List[np.ndarray]]) -> Dict:
        """Collect statistics for one iteration, update ``history`` and log progress.

        Args:
            iteration: Current iteration number.
            sequences: Training sequences handed to ``_collect_statistics``.
            history: History dictionary that is appended to in place.
            val_sequences: Validation sequences; metrics are computed only
                when this is truthy.
            val_labels: Labels forwarded to ``_compute_validation_metrics``.

        Returns:
            Dictionary with the keys ``train_stats``, ``val_stats``,
            ``state_dist``, ``total_ll`` and ``mixture_monitoring``.
        """
        emit = logger.info

        train_stats = self._collect_statistics(sequences)
        total_ll = train_stats['log_likelihood']

        # Normalized state occupancy; the threshold keeps the divisor non-zero
        occupancy = train_stats['state_occurences']
        state_dist = occupancy / (np.sum(occupancy) + self.config.stability_config['balance_threshold'])

        history['train_ll'].append(total_ll)
        history['state_usage'].append(state_dist)
        history['transition_diag'].append(np.diag(self.transitions).copy())

        # (min, max) of the emission means, one tuple per state
        ranges_log = history.setdefault('emission_ranges', [])
        current_ranges = []
        for state in range(self.config.n_states):
            state_means = self.emission_means[state]
            current_ranges.append((state_means.min(), state_means.max()))
        ranges_log.append(current_ranges)

        emit(f"\nIteration {iteration}:")
        emit(f"Training Log-Likelihood: {total_ll:.2f}")
        emit(f"State Distribution: {[f'{p:.3f}' for p in state_dist]}")
        emit(f"Learning Rate: {self.config.learning_rate:.6f}")

        # Validation metrics, only when validation data was supplied
        val_stats = None
        if val_sequences:
            val_stats = self._compute_validation_metrics(val_sequences, val_labels)
            history['val_ll'].append(val_stats['val_ll'])
            history['val_accuracy'].append(val_stats['val_acc'])

            emit(f"Validation Log-Likelihood: {val_stats['val_ll']:.2f}")
            emit(f"Validation Accuracy: {val_stats['val_acc']:.3f}")

        # Periodic per-state report based on the first training sequence
        if iteration % self.config.logging_config['logging_frequency'] == 0:
            emit("\nDetailed State Analysis:")
            for state in range(self.config.n_states):
                stats = self._debug_state_statistics(sequences[0], state)
                emit(f"\nState {state}:")
                emit(f"  Mean Emission: {stats['mean_prob']:.3f}")
                emit(f"  Transitions - In: {stats['incoming_trans']}")
                emit(f"  Transitions - Out: {stats['outgoing_trans']}")
                emit(f"  Self-Transition: {stats['self_trans']:.3f}")

        # Rate-limited warning when any state falls below the minimum share
        if np.any(state_dist < self.config.min_state_prob):
            now = time.time()
            if not hasattr(self, '_last_collapse_warning_time'):
                self._last_collapse_warning_time = 0

            elapsed = now - self._last_collapse_warning_time
            if elapsed >= self.config.logging_config['min_warning_interval']:
                logger.warning(f"State collapse detected: {[f'{p:.4f}' for p in state_dist]}")
                self._log_collapse_warning(state_dist, sequences[0])
                self._last_collapse_warning_time = now

        return {
            'train_stats': train_stats,
            'val_stats': val_stats,
            'state_dist': state_dist,
            'total_ll': total_ll,
            'mixture_monitoring': train_stats.get('mixture_monitoring', {})
        }



        
    def _compute_validation_metrics(self, val_sequences: List[np.ndarray], 
                                  val_labels: List[np.ndarray]) -> Dict:
        """Compute validation metrics during training"""
        val_ll = sum(self.score(x) for x in val_sequences)
        val_pred = self.predict_batch(val_sequences)
        val_true = [np.argmax(label, axis=1) for label in val_labels]
        val_acc = np.mean([np.mean(p == t) for p, t in zip(val_pred, val_true)])
        
        return {'val_ll': val_ll, 'val_acc': val_acc}


    def _log_iteration_stats(self, iteration: int, train_ll: float, 
                            val_stats: Dict, state_dist: np.ndarray) -> None:
        """Emit a five-line INFO summary for one training iteration.

        Logs the iteration number, the training log-likelihood, the
        validation log-likelihood and accuracy read from ``val_stats``
        ('val_ll', 'val_acc'), and the state distribution with three decimals.
        """
        emit = logger.info
        emit(f"Iteration {iteration}:")
        emit(f"  Train LL: {train_ll:.2f}")
        emit(f"  Val LL: {val_stats['val_ll']:.2f}")
        emit(f"  Val Acc: {val_stats['val_acc']:.3f}")
        formatted_dist = ['{:.3f}'.format(p) for p in state_dist]
        emit(f"  State Distribution: {formatted_dist}")
        
    
    def _check_convergence(self, iter_stats: Dict, best_ll: float, best_params: Optional[Dict],
                          no_improve: int, plateau_count: int, tolerance: float) -> Tuple[bool, float, Optional[Dict]]:
        """Decide whether training should stop early, based on validation log-likelihood.

        Args:
            iter_stats: Per-iteration results; reads 'val_stats' and 'state_dist'.
            best_ll: Best validation log-likelihood seen so far.
            best_params: Parameter snapshot taken at ``best_ll`` (or None).
            no_improve: Number of consecutive iterations without improvement.
            plateau_count: Number of consecutive iterations within ``tolerance``
                of ``best_ll``.
            tolerance: Minimum gain in validation log-likelihood that counts as
                an improvement.

        Returns:
            ``(should_stop, best_ll, best_params)``. When the validation
            log-likelihood improved, ``best_ll``/``best_params`` are the updated
            values; otherwise they are returned unchanged.

        NOTE(review): ``no_improve`` and ``plateau_count`` are updated on the
        local copies only and are not part of the return value, so the caller
        cannot see the increments made here. Confirm the caller maintains
        these counters itself; otherwise the 10/5 thresholds below can only be
        reached through the values passed in.
        """
        should_stop = False
        val_stats = iter_stats['val_stats']
        # Iteration index is inferred from the recorded training history (0 if none yet)
        current_iteration = len(self.history['train_ll']) if hasattr(self, 'history') else 0
        
        # Allow some initial iterations before any stopping criterion is evaluated
        min_iterations = 5  # Give at least 5 iterations to stabilize
        
        # Without validation statistics nothing is checked and should_stop stays False
        if val_stats:
            if val_stats['val_ll'] > best_ll + tolerance:
                # Improvement: snapshot the current parameters and reset both counters
                best_ll = val_stats['val_ll']
                best_params = self._get_model_params()
                no_improve = 0
                plateau_count = 0
            else:
                no_improve += 1
                # A plateau is a non-improving step that stays within tolerance of the best
                if abs(val_stats['val_ll'] - best_ll) < tolerance:
                    plateau_count += 1
                else:
                    plateau_count = 0
            
            # Check stopping criteria with more tolerance
            if current_iteration >= min_iterations:
                state_dist = iter_stats['state_dist']
                if no_improve >= 10:  # Increased from 5
                    logger.info("Early stopping: No improvement for 10 iterations")
                    should_stop = True
                elif plateau_count >= 5:  # Increased from 3
                    logger.info("Early stopping: Detected convergence plateau")
                    should_stop = True
                elif np.any(state_dist < self.config.min_state_prob) and best_params:
                    # Collapse only counts once a parameter snapshot exists;
                    # the counter lives on the instance so it persists across calls
                    if hasattr(self, '_collapse_counter'):
                        self._collapse_counter += 1
                    else:
                        self._collapse_counter = 1
                    
                    if self._collapse_counter >= 3:  # Require persistent collapse
                        logger.info("Early stopping: Persistent state collapse detected")
                        should_stop = True
                else:
                    # Reached only when none of the branches above matched
                    self._collapse_counter = 0  # Reset counter if distribution improves
        
        return should_stop, best_ll, best_params


    
        
    def _log_state_analysis(self, sample_sequence: np.ndarray) -> None:
        """Log per-state emission and transition diagnostics at INFO level.

        For every state index below ``self.config.n_states`` the statistics
        returned by ``_debug_state_statistics`` for ``sample_sequence`` are
        written as two log records (emission summary, then transitions).
        """
        logger.info("\nDetailed State Analysis:")
        for state_idx in range(self.config.n_states):
            state_stats = self._debug_state_statistics(sample_sequence, state_idx)

            emission_line = (
                f"\nState {state_idx}: Emission mean={state_stats['mean_prob']:.3f}, "
                f"std={state_stats['std_prob']:.3f}"
            )
            logger.info(emission_line)

            transition_line = (
                f"Transitions - In: {state_stats['incoming_trans']}, "
                f"Out: {state_stats['outgoing_trans']}"
            )
            logger.info(transition_line)
    
    def _log_collapse_warning(self, state_dist: np.ndarray, sample_sequence: np.ndarray) -> None:
        """Write WARNING records describing every collapsed state.

        A state counts as collapsed when its entry in ``state_dist`` is below
        ``self.config.min_state_prob``. For each such state the statistics
        from ``_debug_state_statistics`` on ``sample_sequence`` are logged.
        """
        warn = logger.warning
        threshold = self.config.min_state_prob

        warn(f"State collapse detected: {[f'{p:.4f}' for p in state_dist]}")
        warn("State Details:")
        for state_idx in range(self.config.n_states):
            # Guard clause: healthy states are not reported
            if not state_dist[state_idx] < threshold:
                continue
            details = self._debug_state_statistics(sample_sequence, state_idx)
            warn(f"Collapsed State {state_idx}:")
            warn(f"  - Mean Emission: {details['mean_prob']:.4f}")
            warn(f"  - Incoming Trans: {details['incoming_trans']}")
            warn(f"  - Self Trans: {details['self_trans']:.4f}")
        
        


    
    ## Collect Statistics Method
    ## Critical Issue: Was accumulating statistics for all positions
    ## Fix: Now only processes valid sequence positions
    def _collect_statistics(self, sequences: List[np.ndarray]) -> Dict:
        """Collect sufficient statistics with enhanced monitoring and validation"""
        stats = {
            'transition_counts': np.zeros((self.config.n_states, self.config.n_states)), 
            'emission_stats': {
                'means_num': np.zeros((self.config.n_states, self.config.n_mixtures, self.config.n_features)),
                'covs_num': np.zeros((self.config.n_states, self.config.n_mixtures, self.config.n_features)),
                'weights_num': np.zeros((self.config.n_states, self.config.n_mixtures)),
                'feature_contributions': np.zeros((self.config.n_states, len(self.config.feature_config))),
                'component_usage': np.zeros((self.config.n_states, self.config.n_mixtures))
            },
            'state_occurences': np.zeros(self.config.n_states),
            'log_likelihood': 0.0,
            'sequence_stats': [],
            'stability_metrics': {'forward': [], 'backward': [], 'mixture': []}
        }
        
        for seq_idx, x in enumerate(sequences):
            ## valid: one_hot = x[:, :21]; valid_pos = np.where(np.sum(one_hot, axis=1) > 0)[0]
            valid_pos = self._get_valid_positions(x)
            if len(valid_pos) == 0: logger.warning(f"Skipping empty sequence {seq_idx}"); continue
                
            # Forward-backward pass with full statistics collection
            x_valid = x[valid_pos]; alpha, seq_ll, forward_stats = self.forward(x_valid)
            if np.isnan(seq_ll) or np.isinf(seq_ll): logger.warning(f"Invalid log-likelihood for sequence {seq_idx}, skipping"); continue
            
            stats['log_likelihood'] += seq_ll
            beta, backward_stats = self.backward(x_valid, alpha.sum(axis=1))
            posteriors = self.compute_posteriors(alpha, beta)


            # Update sequence-level statistics
            seq_stats = {'length': len(valid_pos), 'state_dist': posteriors.sum(axis=0) / len(valid_pos), 
                         'forward_stats': forward_stats, 'backward_stats': backward_stats}
            stats['sequence_stats'].append(seq_stats)

            # Properly access nested stability metrics
            stats['stability_metrics']['forward'].extend(forward_stats['detailed']['stability_metrics'])
            if 'detailed' in backward_stats and 'stability_metrics' in backward_stats['detailed']:
                stats['stability_metrics']['backward'].extend(backward_stats['detailed']['stability_metrics'])
                
            # Update transition and emission statistics with mixture responsibilities
            stats['state_occurences'] += posteriors.sum(axis=0)
            for t in range(len(valid_pos) - 1):
                for i in range(self.config.n_states):
                    for j in range(self.config.n_states):
                        trans_prob = posteriors[t, i] * self.transitions[i, j] * self._compute_emission_likelihood(x_valid[t+1:t+2])[0, j] * posteriors[t+1, j]
                        stats['transition_counts'][i, j] += trans_prob
            
            # Update emission statistics with mixture information
            self._update_emission_statistics(x_valid, posteriors, stats['emission_stats'])
            
            if seq_idx % 100 == 0: logger.debug(f"Processing sequence {seq_idx}, length={len(valid_pos)}")
        
        return stats

    

    
    def _update_sequence_statistics(self, x: np.ndarray, posteriors: np.ndarray, stats: Dict) -> None:
        """Update sufficient statistics for a single sequence"""
        seq_len = len(x)
        emission_probs = self._compute_emission_likelihood(x)
        
        # Update state occurrences
        stats['state_occurences'] += posteriors.sum(axis=0)
        
        # Update transition counts
        for t in range(seq_len - 1):
            for i in range(self.config.n_states):
                for j in range(self.config.n_states):
                    stats['transition_counts'][i, j] += (
                        posteriors[t, i] * self.transitions[i, j] * 
                        emission_probs[t+1, j] * posteriors[t+1, j]
                    )
        
        # Update emission statistics
        self._update_emission_statistics(x, posteriors, stats['emission_stats'])

    
    ## Update Emission Statistics
    ## Critical Issue: Was processing features incorrectly
    ## Fix: Now properly handles feature ranges and valid positions
    def _update_emission_statistics(self, x: np.ndarray, posteriors: np.ndarray, emission_stats: Dict) -> None:
        """Accumulate per-state, per-mixture emission statistics in place.

        For every state the mixture responsibilities are obtained from
        ``_compute_mixture_responsibilities`` and added into the numerators
        held in ``emission_stats`` ('component_usage', 'weights_num',
        'means_num', 'covs_num', 'feature_contributions').

        Args:
            x: Feature matrix of one sequence; reduced to its valid positions
                before use.
            posteriors: State posteriors; column ``state`` is passed to the
                responsibility computation.
            emission_stats: Accumulator dict, modified in place.
        """

        # NOTE(review): only x is filtered here; posteriors is used as passed,
        # so this assumes x already contains valid positions only (rows of x
        # and posteriors must line up) - confirm against callers.
        valid_pos = self._get_valid_positions(x)
        x_valid = x[valid_pos]
        x = x_valid
        for state in range(self.config.n_states):
            # Get mixture responsibilities and track usage patterns
            mix_resp, mix_stats = self._compute_mixture_responsibilities(x, state, posteriors[:, state])
            emission_stats['component_usage'][state] += mix_stats['dominance_patterns']

            # Column widths of the three feature groups, in column order.
            # NOTE(review): these sum to 47 while ModelConfig sets
            # n_features = 46; out-of-range slice ends are clipped by NumPy,
            # so the last group would silently cover fewer columns - confirm
            # the intended layout.
            feature_sizes = {
                    'one_hot': 21,  # 21 amino acids
                    'pssm': 21,    # 21 PSSM scores
                    'aux': 5       # 5 auxiliary features
                }
                
            # Update per-mixture statistics with feature-specific handling
            for mix in range(self.config.n_mixtures):
                # Column vector so it broadcasts across the feature columns
                resp = mix_resp[:, mix].reshape(-1, 1)
                emission_stats['weights_num'][state, mix] += resp.sum()
                
                # Update means and covariances for each feature type
                curr_pos = 0
                for feat_type, size in feature_sizes.items():
                    feat_slice = slice(curr_pos, curr_pos + size)
                    # Update means
                    emission_stats['means_num'][state, mix, feat_slice] += (resp * x[:, feat_slice]).sum(axis=0)
                    # Update covariances (squared deviation from the current emission means)
                    diff = x[:, feat_slice] - self.emission_means[state, mix, feat_slice]
                    emission_stats['covs_num'][state, mix, feat_slice] += (resp * (diff ** 2)).sum(axis=0)
                    # Update feature contributions
                    # NOTE(review): the value is added to the whole row for this
                    # state rather than to a per-feature-type column; if the
                    # helper returns a scalar it is broadcast to every column -
                    # confirm this is intended.
                    emission_stats['feature_contributions'][state] += mix_stats['feature_contributions'][feat_type][mix]
                    
                    curr_pos += size
    

    
    def _get_model_params(self) -> Dict:
        """Get current model parameters for checkpointing"""
        return {
            'transitions': self.transitions.copy(),
            'emission_means': self.emission_means.copy(),
            'emission_covs': self.emission_covs.copy(),
            'mixture_weights': self.mixture_weights.copy()
        }

    def _restore_parameters(self, params: Dict) -> None:
        """Restore model parameters from checkpoint"""
        self.transitions = params['transitions']
        self.emission_means = params['emission_means']
        self.emission_covs = params['emission_covs']
        self.mixture_weights = params['mixture_weights']



    """Compute log-likelihood score for a sequence"""
    def score(self, x: np.ndarray) -> float:
        # Forward pass to get likelihood
        ## _, log_likelihood = ....
        
        alpha, log_likelihood, state_probs = self.forward(x)

        # Check for numerical issues
        if np.isnan(log_likelihood) or np.isinf(log_likelihood):
            logger.warning(f"Invalid log-likelihood: {log_likelihood}")
            return float('-inf')
        
        return log_likelihood




            
    def predict(self, x: np.ndarray, method: str = 'viterbi') -> np.ndarray:
        """Predict secondary structure states with confidence tracking"""
        # Validate and process sequence
        one_hot = x[:, :21]; valid_pos = np.where(np.sum(one_hot, axis=1) > 0)[0]
        if len(valid_pos) == 0: raise ValueError("Empty sequence provided")
        x_valid = x[valid_pos]; prediction_stats = {'confidence': [], 'method': method, 'sequence_length': len(valid_pos)}
        
        # Get predictions based on method
        if method == 'viterbi':
            path, score, path_stats = self.viterbi(x_valid)
            predictions = np.array(path)
            prediction_stats.update({'path_score': score, 'path_stats': path_stats, 'confidence': path_stats['path_probabilities']})
        elif method == 'posterior':
            alpha, _, forward_stats = self.forward(x_valid)
            beta, backward_stats = self.backward(x_valid, alpha.sum(axis=1))
            posteriors = self.compute_posteriors(alpha, beta)
            predictions = np.argmax(posteriors, axis=1)
            prediction_stats.update({'posteriors': posteriors, 'max_probs': np.max(posteriors, axis=1), 'confidence': np.max(posteriors, axis=1)})
        else:
            raise ValueError(f"Unknown prediction method: {method}")
        
        # Monitor prediction confidence
        mean_conf = np.mean(prediction_stats['confidence'])
        if mean_conf < 0.5: logger.warning(f"Low prediction confidence: {mean_conf:.3f}")
        
        return predictions, prediction_stats

    
        
    def predict_batch(self, sequences: List[np.ndarray], method: str = 'viterbi') -> Tuple[List[np.ndarray], Dict]:
        """Batch prediction with comprehensive statistics tracking"""
        predictions = []; batch_stats = {'confidence': [], 'lengths': [], 'method_stats': {method: {'success': 0, 'failed': 0}}}; logger.info(f"Processing batch of {len(sequences)} sequences")
        
        for i, seq in enumerate(sequences):
            try:
                pred, pred_stats = self.predict(seq, method)
                predictions.append(pred)
                batch_stats['confidence'].append(pred_stats['confidence'])
                batch_stats['lengths'].append(pred_stats['sequence_length'])
                batch_stats['method_stats'][method]['success'] += 1
                
                if i % 100 == 0: logger.debug(f"Processed {i + 1} sequences, avg confidence: {np.mean(batch_stats['confidence'][-100:]):.3f}")
                
            except Exception as e:
                logger.warning(f"Failed to process sequence {i}: {str(e)}")
                batch_stats['method_stats'][method]['failed'] += 1
                continue
        
        # Compute batch statistics
        batch_stats.update({
            'mean_confidence': np.mean([np.mean(conf) for conf in batch_stats['confidence']]),
            'sequence_length_stats': {'mean': np.mean(batch_stats['lengths']), 'std': np.std(batch_stats['lengths'])},
            'success_rate': batch_stats['method_stats'][method]['success'] / len(sequences)
        })
        
        if batch_stats['mean_confidence'] < 0.6: logger.warning(f"Low average batch confidence: {batch_stats['mean_confidence']:.3f}")
        
        return predictions, batch_stats


    




    
    def evaluate(self, sequences: List[np.ndarray], true_labels: List[np.ndarray]) -> Dict[str, float]:
        """Evaluate model with comprehensive metrics and state-wise analysis"""
        predictions, batch_stats = self.predict_batch(sequences); metrics = {}
        
        # Flatten predictions and labels with proper handling
        y_pred = np.concatenate(predictions); y_true = np.concatenate([np.argmax(label, axis=1) if label.ndim > 1 else label for label in true_labels])
        state_metrics = {state: {'tp': 0, 'fp': 0, 'fn': 0} for state in range(self.config.n_states)}
        
        # Compute metrics with state-wise tracking
        metrics['accuracy'] = np.mean(y_pred == y_true)
        for state in range(self.config.n_states):
            state_metrics[state]['tp'] = np.sum((y_pred == state) & (y_true == state))
            state_metrics[state]['fp'] = np.sum((y_pred == state) & (y_true != state))
            state_metrics[state]['fn'] = np.sum((y_pred != state) & (y_true == state))
            metrics[f'state_{state}_precision'] = state_metrics[state]['tp'] / (state_metrics[state]['tp'] + state_metrics[state]['fp'] + 1e-10)
            metrics[f'state_{state}_recall'] = state_metrics[state]['tp'] / (state_metrics[state]['tp'] + state_metrics[state]['fn'] + 1e-10)
            metrics[f'state_{state}_f1'] = 2 * metrics[f'state_{state}_precision'] * metrics[f'state_{state}_recall'] / (metrics[f'state_{state}_precision'] + metrics[f'state_{state}_recall'] + 1e-10)
        
        # Log results and warnings
        logger.info(f"\nEvaluation Results:\nAccuracy: {metrics['accuracy']:.3f}")
        for state, name in enumerate(['Helix', 'Sheet', 'Coil']): logger.info(f"{name} F1: {metrics[f'state_{state}_f1']:.3f}")
        if any(metrics[f'state_{s}_f1'] < 0.3 for s in range(self.config.n_states)): logger.warning("Low performance detected for some states")
        
        return {**metrics, 'batch_stats': batch_stats}
        

    
    
    def save_model(self, filepath: str) -> None:
        """Serialize the model parameters plus a small stats record with joblib.

        The saved dict contains the config, state priors, transitions,
        emission means/covariances, mixture weights and a 'model_stats' entry
        (diagonal of the transition matrix stored as 'state_distributions',
        min/max of the emission parameters, and a minute-resolution
        timestamp). Missing parent directories are created.
        """
        means, covs = self.emission_means, self.emission_covs

        save_dict = {
            name: getattr(self, name)
            for name in ('config', 'state_priors', 'transitions',
                         'emission_means', 'emission_covs', 'mixture_weights')
        }
        save_dict['model_stats'] = {
            'state_distributions': np.diag(self.transitions).copy(),
            'emission_ranges': {
                'means': (means.min(), means.max()),
                'covs': (covs.min(), covs.max())
            },
            'timestamp': datetime.now().strftime('%Y%m%d_%H%M')
        }

        target = Path(filepath)
        target.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(save_dict, filepath)

        self_transitions = [f'{p:.3f}' for p in np.diag(self.transitions)]
        logger.info(f"Model saved to {filepath} with state distribution: {self_transitions}")
    
    @classmethod
    def load_model(cls, filepath: str) -> 'MixtureGaussianHMM':
        """Rebuild a model from a file written by ``save_model``.

        Raises ``ValueError`` when a required entry is missing. Suspicious
        transition matrices (negative entries or rows not summing to 1) and
        non-positive emission covariances only produce warnings.

        NOTE: joblib files are pickle-based; load only files you trust.
        """
        logger.info(f"Loading model from {filepath}")
        saved_dict = joblib.load(filepath)

        # Validate loaded parameters
        restored_attrs = ('state_priors', 'transitions', 'emission_means', 'emission_covs', 'mixture_weights')
        missing = [key for key in ('config',) + restored_attrs if key not in saved_dict]
        if missing:
            raise ValueError("Invalid model file: missing required parameters")

        # Initialize and restore model
        model = cls(saved_dict['config'])
        for attr in restored_attrs:
            setattr(model, attr, saved_dict[attr])

        # Sanity checks on the restored parameters
        rows_normalized = np.allclose(model.transitions.sum(axis=1), 1.0)
        if np.any(model.transitions < 0) or not rows_normalized:
            logger.warning("Loaded transitions may be invalid")
        if np.any(model.emission_covs <= 0):
            logger.warning("Invalid emission covariances detected")

        stats = saved_dict.get('model_stats', {})
        reported = stats.get('state_distributions', np.diag(model.transitions))
        logger.info(f"Model loaded with state distribution: {[f'{p:.3f}' for p in reported]}")

        return model

    
## ------------------------------------------------------------------------------------------------------------------------------------------------------------

def preprocess_npy_file(input_path: str, output_path: Optional[str] = None) -> str:
    """Load an NPY file and re-save it, returning the path that was written.

    Args:
        input_path: Path of the source ``.npy`` file.
        output_path: Destination path. Defaults to the input path with its
            trailing ``.npy`` replaced by ``_py3.npy``.

    Returns:
        The path of the saved file, including the ``.npy`` suffix that
        ``np.save`` appends when it is missing.
    """
    suffix = '.npy'
    if output_path is None:
        # Only strip the trailing extension; str.replace would also rewrite
        # ".npy" occurring inside directory names.
        stem = input_path[:-len(suffix)] if input_path.endswith(suffix) else input_path
        output_path = f"{stem}_py3{suffix}"
    elif not output_path.endswith(suffix):
        # np.save appends ".npy" itself; report the name it really writes.
        output_path = output_path + suffix

    logger.info("Converting NPY file from Python 2 to Python 3 format")
    # Load and immediately save in new format
    data = np.load(input_path)
    np.save(output_path, data)
    logger.info(f"Saved converted file to {output_path}")

    return output_path



def visualize_state_predictions(model: 'MixtureGaussianHMM', sequence: np.ndarray, true_labels: Optional[np.ndarray] = None, save_path: Optional[str] = None) -> None:
    """Plot the predicted state path above a heatmap of the state posteriors.

    Args:
        model: Trained model providing ``predict``.
        sequence: Feature matrix of one sequence.
        true_labels: Optional reference labels (class indices, or one-hot
            rows that are converted with argmax).
        save_path: When given, the figure is also written to this path.
    """
    predictions, pred_stats = model.predict(sequence)

    # The default (Viterbi) prediction carries no posteriors. Request them
    # through the posterior decoder, which runs the forward and backward
    # passes on the same valid positions, instead of passing the forward
    # log-likelihood to compute_posteriors in place of beta.
    posteriors = pred_stats.get('posteriors')
    if posteriors is None:
        _, posterior_stats = model.predict(sequence, method='posterior')
        posteriors = posterior_stats['posteriors']

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), gridspec_kw={'height_ratios': [1, 3]})
    plt.tight_layout(pad=3.0)

    # Plot predictions; the grey band is a fixed +/-0.2 visual aid, not a
    # computed confidence interval
    ax1.plot(predictions, 'k-', label='Predicted', alpha=0.7)
    if true_labels is not None:
        label_idx = true_labels if true_labels.ndim == 1 else np.argmax(true_labels, axis=1)
        ax1.plot(label_idx, 'r--', label='True', alpha=0.5)
    ax1.fill_between(range(len(predictions)), predictions - 0.2, predictions + 0.2, alpha=0.2, color='gray', label='Confidence')
    ax1.set_title('Structure Predictions with Confidence')
    ax1.set_ylabel('State')
    ax1.legend()

    # Plot state posteriors heatmap (states on the y-axis)
    im = ax2.imshow(posteriors.T, aspect='auto', cmap='viridis', interpolation='nearest')
    ax2.set_title('State Posteriors')
    ax2.set_xlabel('Position')
    ax2.set_ylabel('State')
    plt.colorbar(im, ax=ax2, label='Probability')

    if save_path:
        plt.savefig(save_path)
        logger.info(f"Saved prediction visualization to {save_path}")
    plt.show()  # Display plot
    plt.close()


def plot_training_progress(history: Dict[str, List], save_path: str = None) -> None:
    """Draw a 2x2 overview of a training run.

    Panels: log-likelihood (train and, if present, validation), state usage
    per iteration, self-transition probabilities, and the spread (max - min)
    of the emission ranges per state. Reads 'train_ll', 'val_ll' (optional),
    'state_usage', 'transition_diag' and 'emission_ranges' from ``history``.
    """
    state_labels = ['Helix', 'Sheet', 'Coil']

    def _finish(ax, title, ylabel):
        # Shared axis decoration; every panel is indexed by iteration
        ax.set_title(title)
        ax.set_xlabel('Iteration')
        ax.set_ylabel(ylabel)
        ax.legend()

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    (ll_ax, usage_ax), (trans_ax, spread_ax) = axes
    plt.tight_layout(pad=3.0)

    # Log likelihood progress
    ll_ax.plot(history['train_ll'], label='Train', color='blue', alpha=0.7)
    if history.get('val_ll'):
        ll_ax.plot(history['val_ll'], label='Validation', color='red', alpha=0.7)
    _finish(ll_ax, 'Log Likelihood', 'LL')

    # State distribution evolution
    state_usage = np.array(history['state_usage'])
    for col, label in enumerate(state_labels):
        usage_ax.plot(state_usage[:, col], label=label, alpha=0.7)
    usage_ax.axhline(y=0.15, color='r', linestyle='--', alpha=0.3, label='Min Threshold')
    _finish(usage_ax, 'State Distribution', 'Probability')

    # Self-transition evolution
    trans_diag = np.array(history['transition_diag'])
    for col, label in enumerate(['H→H', 'E→E', 'C→C']):
        trans_ax.plot(trans_diag[:, col], label=label, alpha=0.7)
    _finish(trans_ax, 'Self-Transitions', 'Probability')

    # Emission parameter ranges
    emission_ranges = np.array(history['emission_ranges'])
    for col, label in enumerate(state_labels):
        spreads = [r[col][1] - r[col][0] for r in emission_ranges]
        means = np.mean(spreads)
        spread_ax.plot(spreads, label=f'{label} (μ={means:.2f})', alpha=0.7)
    _finish(spread_ax, 'Emission Spread', 'Range')

    if save_path:
        plt.savefig(save_path)
        logger.info(f"Saved training visualization to {save_path}")
    plt.show()  # Display plot
    plt.close()




def train_and_evaluate_model(train_seqs: List[np.ndarray], val_seqs: List[np.ndarray], train_labels: List[np.ndarray], val_labels: List[np.ndarray], config: ModelConfig) -> Tuple[MixtureGaussianHMM, Dict]:
    """Fit a fresh model, evaluate it on the validation set and save it.

    Training runs for up to 100 iterations with the validation data passed
    along; ``train_labels`` is accepted but not used here. The model is saved
    to ``models/hmm_model_<timestamp>.joblib``.

    Returns:
        ``(model, val_metrics)`` where ``val_metrics`` is the result of
        ``model.evaluate`` on the validation data.
    """
    model = MixtureGaussianHMM(config)
    logger.info("Starting model training...")

    history = model.train(
        train_seqs,
        val_sequences=val_seqs,
        val_labels=val_labels,
        n_iterations=100,
    )
    val_metrics = model.evaluate(val_seqs, val_labels)

    stamp = f"{datetime.now():%Y%m%d_%H%M}"
    model.save_model(f'models/hmm_model_{stamp}.joblib')
    return model, val_metrics



def evaluate_detailed(model: 'MixtureGaussianHMM', sequences: List[np.ndarray], 
                     labels: List[np.ndarray]) -> Dict[str, float]:
    """Evaluate batch predictions with overall and per-state metrics.

    Args:
        model: Model providing ``predict_batch``.
        sequences: Feature matrices, one per sequence.
        labels: Matching labels; 2-D one-hot labels are converted to class
            indices, anything else is used as given.

    Returns:
        Dict with 'accuracy' and, for each of 'Helix', 'Sheet' and 'Coil',
        '<name>_precision', '<name>_recall', '<name>_f1' and
        '<name>_specificity'.
    """
    # predict_batch returns (predictions, batch_stats); only the predictions
    # are concatenated below.
    predictions, _ = model.predict_batch(sequences)
    metrics = {}

    # Convert one-hot labels to indices if needed
    processed_labels = [
        np.argmax(label, axis=1) if label.ndim == 2 and label.shape[1] > 1 else label
        for label in labels
    ]

    # Overall metrics
    y_pred = np.concatenate(predictions)
    y_true = np.concatenate(processed_labels)
    metrics['accuracy'] = np.mean(y_pred == y_true)

    # Per-state metrics; the epsilon keeps empty classes from dividing by zero
    eps = 1e-10
    for state, name in enumerate(['Helix', 'Sheet', 'Coil']):
        true_pos = np.sum((y_pred == state) & (y_true == state))
        false_pos = np.sum((y_pred == state) & (y_true != state))
        false_neg = np.sum((y_pred != state) & (y_true == state))
        true_neg = np.sum((y_pred != state) & (y_true != state))

        precision = true_pos / (true_pos + false_pos + eps)
        recall = true_pos / (true_pos + false_neg + eps)
        f1 = 2 * (precision * recall) / (precision + recall + eps)

        metrics[f'{name}_precision'] = precision
        metrics[f'{name}_recall'] = recall
        metrics[f'{name}_f1'] = f1
        metrics[f'{name}_specificity'] = true_neg / (true_neg + false_pos + eps)

    return metrics




def save_metrics(metrics: Dict[str, float], save_path: str = 'results') -> None:
    """Write a plain-text metrics report to ``<save_path>/metrics_<timestamp>.txt``.

    Args:
        metrics: Metric values keyed by name ('accuracy',
            'mean_seq_accuracy', 'std_seq_accuracy', 'min_seq_accuracy',
            'max_seq_accuracy' and '<State>_precision' / '_recall' / '_f1' /
            '_specificity' for Helix, Sheet and Coil).
        save_path: Output directory; created if missing.

    Metrics that are absent are reported as ``n/a`` instead of aborting, and
    the report is assembled before the file is opened so a failure cannot
    leave a half-written file behind.
    """
    def _fmt(key: str) -> str:
        value = metrics.get(key)
        return "n/a" if value is None else f"{value:.3f}"

    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M')

    lines = [
        "=== Protein Structure Prediction Results ===\n\n",
        # Overall metrics
        f"Overall Accuracy: {_fmt('accuracy')}\n",
        f"Mean Sequence Accuracy: {_fmt('mean_seq_accuracy')} (±{_fmt('std_seq_accuracy')})\n",
        f"Sequence Accuracy Range: [{_fmt('min_seq_accuracy')}, {_fmt('max_seq_accuracy')}]\n\n",
        # Per-state metrics
        "Per-State Metrics:\n",
    ]
    for state in ['Helix', 'Sheet', 'Coil']:
        lines.extend([
            f"\n{state}:\n",
            f"  Precision:    {_fmt(f'{state}_precision')}\n",
            f"  Recall:       {_fmt(f'{state}_recall')}\n",
            f"  F1 Score:     {_fmt(f'{state}_f1')}\n",
            f"  Specificity:  {_fmt(f'{state}_specificity')}\n",
        ])

    # Explicit encoding: the report contains a non-ASCII "±"
    with open(out_dir / f'metrics_{timestamp}.txt', 'w', encoding='utf-8') as f:
        f.write("".join(lines))








# """
# data_path = r"C:\Users\joems\OneDrive\Desktop\MLCV Project Items\Machine Learning CS6140\dataset\CB513.npy"
# """

def load_and_validate_data(data_path: str) -> Optional[np.ndarray]:
    """Load an NPY file, re-saving it when loading emitted a ``UserWarning``.

    Args:
        data_path: Path of the ``.npy`` file.

    Returns:
        The loaded array, or ``None`` when loading failed (the error is
        logged).
    """
    try:
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning regardless of the filters active in the
            # calling process.
            warnings.simplefilter("always")
            data = np.load(data_path)
            if any(issubclass(w.category, UserWarning) for w in caught):
                # Rewrite only the trailing extension of the path
                stem = data_path[:-len('.npy')] if data_path.endswith('.npy') else data_path
                py3_path = f"{stem}_py3.npy"
                np.save(py3_path, data)
                data = np.load(py3_path)
                logger.info(f"Converted and saved Python 3 format to {py3_path}")
        return data
    except Exception as e:
        # Pipeline boundary: report the problem and let the caller check for None
        logger.error(f"Data loading failed: {str(e)}")
        return None

def preprocess_sequences(data: np.ndarray) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Turn the raw data array into per-sequence feature matrices and labels.

    ``ProteinFeatures.extract_features`` is called repeatedly until it raises
    ``StopIteration``. A ``ValueError`` is logged as a warning and the loop
    moves on to the next call; any other exception is logged as an error and
    ends processing. Results with zero rows are discarded.

    Returns:
        ``(sequences, labels)`` as two parallel lists.
    """
    sequences = []
    labels = []
    processor = ProteinFeatures(data)

    try:
        exhausted = False
        while not exhausted:
            try:
                features, label = processor.extract_features()
                if features.shape[0] > 0:
                    sequences.append(features)
                    labels.append(label)
            except StopIteration:
                exhausted = True
            except ValueError as e:
                logger.warning(str(e))
    except Exception as e:
        logger.error(f"Sequence processing failed: {str(e)}")

    logger.info(f"Successfully processed {len(sequences)} sequences")
    return sequences, labels


def split_dataset(sequences: List[np.ndarray], labels: List[np.ndarray], val_size: float = 0.3, test_size: float = 0.5, random_state: int = 42) -> Tuple[List[np.ndarray], ...]:
    """Split sequences and labels into train, validation and test parts.

    The split is done in two steps: ``val_size`` is the fraction held out
    from training, and ``test_size`` is the fraction of that held-out part
    that becomes the test set (the rest is the validation set). Both steps
    use the same ``random_state``.

    Returns:
        ``(train_seqs, val_seqs, test_seqs, train_labels, val_labels, test_labels)``.
    """
    # Step 1: training data vs. everything held out
    first_split = train_test_split(sequences, labels, test_size=val_size, random_state=random_state)
    train_seqs, held_seqs, train_labels, held_labels = first_split

    # Step 2: divide the held-out part into validation and test
    second_split = train_test_split(held_seqs, held_labels, test_size=test_size, random_state=random_state)
    val_seqs, test_seqs, val_labels, test_labels = second_split

    return train_seqs, val_seqs, test_seqs, train_labels, val_labels, test_labels

def plot_training_metrics(history: Dict[str, List[float]], save_dir: str = 'results') -> None:
    """Plot the log-likelihood curves and state usage, saved as a PNG.

    Args:
        history: Training history; reads 'train_ll', 'state_usage' and, when
            present and non-empty, 'val_ll'.
        save_dir: Output directory (created if missing); the figure is
            written to ``<save_dir>/training_metrics.png``.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot likelihoods; the validation curve is optional, so a history
    # without 'val_ll' must not raise (same handling as plot_training_progress)
    ax1.plot(history['train_ll'], label='Train', color='blue', alpha=0.7)
    if history.get('val_ll'):
        ax1.plot(history['val_ll'], label='Validation', color='red', alpha=0.7)
    ax1.set_title('Training Progress (Log Likelihood)')
    ax1.set_xlabel('Iteration')
    ax1.set_ylabel('Log Likelihood')
    ax1.legend()

    # Plot state usage (one line per column of the usage matrix)
    state_usage = np.array(history['state_usage'])
    ax2.plot(state_usage, alpha=0.7)
    ax2.set_title('State Usage Distribution')
    ax2.set_xlabel('Iteration')
    ax2.set_ylabel('Probability')
    ax2.legend(['H', 'E', 'C'])

    plt.tight_layout()
    plt.savefig(f"{save_dir}/training_metrics.png")
    plt.close()

def save_evaluation_results(metrics: Dict[str, float], history: Dict[str, List[float]], save_dir: str = 'results') -> None:
    """Store the final metrics and the list-valued history entries as one NPY file.

    The payload is a dict, which ``np.save`` stores as a pickled object
    array in ``<save_dir>/evaluation_results.npy``; reading it back requires
    ``np.load(..., allow_pickle=True)``.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Only history entries that are plain lists are kept
    list_history = {}
    for key, value in history.items():
        if isinstance(value, list):
            list_history[key] = value

    payload = {'final_metrics': metrics, 'training_history': list_history}
    np.save(f"{save_dir}/evaluation_results.npy", payload)


def main(data_path: Optional[str] = None) -> Tuple[Optional[MixtureGaussianHMM], Optional[Dict[str, float]], Optional[Dict[str, List[float]]]]:
    """Run the full pipeline: load data, train the HMM, evaluate, save.

    Steps, in order: load and validate the data set, convert it to
    sequences/labels, split into train/validation/test, build a
    ``MixtureGaussianHMM`` from a ``ModelConfig``, train it, evaluate on the
    test split, write the training plot and the evaluation results, log a
    summary, and save the model when test accuracy exceeds 0.5.

    Args:
        data_path: Location of the data set file. When ``None`` the
            original hard-coded CB513 path is used, so ``main()`` behaves
            as before.

    Returns:
        ``(model, test_metrics, history)`` on success, or
        ``(None, None, None)`` when loading fails or no sequences are
        produced.
    """
    start_time = time.time()
    logger.info("Starting protein structure prediction pipeline")

    # Load and process data
    if data_path is None:
        data_path = r"C:\Users\joems\OneDrive\Desktop\MLCV Project Items\Machine Learning CS6140\dataset\CB513.npy"
    data = load_and_validate_data(data_path)
    if data is None:
        return None, None, None

    sequences, labels = process_sequences(data)
    if not sequences:
        return None, None, None

    # Split dataset (plain random split; no stratification is applied).
    # NOTE: val_size=0.2 holds out 20% of the data and test_size=0.2 then
    # gives 20% of that hold-out to the test set, i.e. roughly
    # 80% train / 16% validation / 4% test.
    train_seqs, val_seqs, test_seqs, train_labels, val_labels, test_labels = split_dataset(
        sequences, labels, val_size=0.2, test_size=0.2
    )

    # Initialize configuration and model; override selected defaults.
    config = ModelConfig()
    config.learning_rate = 0.001
    config.lr_decay = 0.98
    config.min_std = 0.1
    config.max_std = 2.0

    model = MixtureGaussianHMM(config)

    # Sanity-check the freshly initialised parameters before training.
    logger.info(f"State priors: {model.state_priors}")
    logger.info(f"Transition matrix shape: {model.transitions.shape}")
    logger.info(f"Emission means shape: {model.emission_means.shape}")
    logger.info(f"Mixture weights shape: {model.mixture_weights.shape}")

    logger.info("Starting model training...")
    history = model.train(
        train_seqs,
        val_sequences=val_seqs,
        val_labels=val_labels,
        n_iterations=100,
        tolerance=1e-4,
    )

    # Evaluate and save results
    test_metrics = evaluate_detailed(model, test_seqs, test_labels)
    plot_training_metrics(history)
    save_evaluation_results(test_metrics, history)

    # Log the training summary once (guarding against an empty history).
    logger.info("\nTraining Results:")
    if history['train_ll']:
        logger.info(f"Final Train LL: {history['train_ll'][-1]:.2f}")
    if history['val_ll']:
        logger.info(f"Final Val LL: {history['val_ll'][-1]:.2f}")

    # Log the test metrics once, with all three per-state scores together.
    logger.info("\nFinal Results:")
    logger.info(f"Test Accuracy: {test_metrics['accuracy']:.3f}")
    for state in ('Helix', 'Sheet', 'Coil'):
        logger.info(f"{state} Metrics:")
        logger.info(f"  F1: {test_metrics[f'{state}_f1']:.3f}")
        logger.info(f"  Precision: {test_metrics[f'{state}_precision']:.3f}")
        logger.info(f"  Recall: {test_metrics[f'{state}_recall']:.3f}")

    # Save model if performance is good enough
    if test_metrics['accuracy'] > 0.5:  # Adjust threshold as needed
        model_path = f"models/hmm_model_{time.strftime('%Y%m%d_%H%M')}.joblib"
        # Make sure the target directory exists before handing the path over.
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)
        model.save_model(model_path)
        logger.info(f"Model saved to {model_path}")

    logger.info(f"Pipeline finished in {time.time() - start_time:.1f}s")
    return model, test_metrics, history

# Script entry point: run the pipeline only when this file is executed
# directly, keeping the returned model, metrics and history as module-level
# names (each is None if main() bailed out early).
if __name__ == "__main__":
    model, metrics, history = main()




09:38:39 | INFO | Starting protein structure prediction pipeline
09:38:40 | INFO | Converted and saved Python 3 format to C:\Users\joems\OneDrive\Desktop\MLCV Project Items\Machine Learning CS6140\dataset\CB513_py3.npy
09:38:40 | INFO | Initializing feature extraction for sequence shape: (514, 39900)
09:38:40 | INFO | Successfully processed 514 sequences
09:38:40 | WARNING | Found near-zero probabilities in transition matrix. Adjusting...
09:38:40 | INFO | Initialized HMM with 3 states and 3 mixtures per state
09:38:40 | INFO | Starting model training...
09:38:40 | INFO | Starting training with 411 sequences
09:38:40 | INFO | Found 411 valid sequences out of 411
09:38:40 | INFO | 
Initial Model State:
09:38:40 | INFO | Transition probabilities:
[[9.00981089e-01 9.90089109e-02 9.99990000e-06]
 [1.00000000e-01 6.70000000e-01 2.30000000e-01]
 [1.00000000e-01 5.10000000e-01 3.90000000e-01]]
09:38:40 | INFO | Mixture weights:
[[0.46 0.35 0.19]
 [0.46 0.35 0.19]
 [0.46 0.35 0.19]]
09:38:40 | INFO | State priors:
[0.492 0.162 0.346]
09:38:40 | INFO | 
First sequence details:
09:38:40 | INFO | Valid positions: 69
09:38:40 | WARNING | State probability collapse detected:
09:38:40 | WARNING | State distributions: ['0.000013', '0.011012', '0.988836']
09:38:40 | WARNING | Min allowed prob: 0.016
09:38:40 | WARNING | Max allowed prob: 0.047
09:38:40 | INFO | Initial emission probability ranges:
09:38:40 | INFO | Min: 0.000000, Max: 0.999860
09:38:40 | INFO | Mean per state: [1.29035766e-05 1.10123848e-02 9.88836291e-01]
09:38:40 | INFO | Found 82 valid sequences out of 82
State priors: [0.492 0.162 0.346]
Transition matrix shape: (3, 3)
Emission means shape: (3, 3, 46)
Mixture weights shape: (3, 3)
09:38:42 | WARNING | State probability collapse detected:
09:38:42 | WARNING | State distributions: ['0.000000', '0.000037', '0.999823']
09:38:42 | WARNING | Min allowed prob: 0.016
09:38:42 | WARNING | Max allowed prob: 0.047
09:38:44 | WARNING | State probability collapse detected:
09:38:44 | WARNING | State distributions: ['0.000000', '0.000048', '0.999812']
09:38:44 | WARNING | Min allowed prob: 0.016
09:38:44 | WARNING | Max allowed prob: 0.047
09:38:46 | WARNING | State probability collapse detected:
09:38:46 | WARNING | State distributions: ['0.000002', '0.042266', '0.957599']
09:38:46 | WARNING | Min allowed prob: 0.016
09:38:46 | WARNING | Max allowed prob: 0.047
09:38:48 | WARNING | State probability collapse detected:
09:38:48 | WARNING | State distributions: ['0.000002', '0.033224', '0.966639']
09:38:48 | WARNING | Min allowed prob: 0.016
09:38:48 | WARNING | Max allowed prob: 0.047
09:38:50 | WARNING | State probability collapse detected:
09:38:50 | WARNING | State distributions: ['0.000063', '0.091233', '0.908577']
09:38:50 | WARNING | Min allowed prob: 0.016
09:38:50 | WARNING | Max allowed prob: 0.047
09:38:52 | WARNING | State probability collapse detected:
09:38:52 | WARNING | State distributions: ['0.000000', '0.000662', '0.999198']
09:38:52 | WARNING | Min allowed prob: 0.016
09:38:52 | WARNING | Max allowed prob: 0.047
09:38:54 | WARNING | State probability collapse detected:
09:38:54 | WARNING | State distributions: ['0.000000', '0.000001', '0.999859']
09:38:54 | WARNING | Min allowed prob: 0.016
09:38:54 | WARNING | Max allowed prob: 0.047
09:38:56 | WARNING | State probability collapse detected:
09:38:56 | WARNING | State distributions: ['0.000000', '0.000329', '0.999531']
09:38:56 | WARNING | Min allowed prob: 0.016
09:38:56 | WARNING | Max allowed prob: 0.047
09:38:58 | WARNING | State probability collapse detected:
09:38:58 | WARNING | State distributions: ['0.000002', '0.001236', '0.998623']
09:38:58 | WARNING | Min allowed prob: 0.016
09:38:58 | WARNING | Max allowed prob: 0.047
09:39:00 | WARNING | State probability collapse detected:
09:39:00 | WARNING | State distributions: ['0.000000', '0.003906', '0.995954']
09:39:00 | WARNING | Min allowed prob: 0.016
09:39:00 | WARNING | Max allowed prob: 0.047
09:39:02 | WARNING | State probability collapse detected:
09:39:02 | WARNING | State distributions: ['0.000000', '0.007894', '0.991967']
09:39:02 | WARNING | Min allowed prob: 0.016
09:39:02 | WARNING | Max allowed prob: 0.047
09:39:04 | WARNING | State probability collapse detected:
09:39:04 | WARNING | State distributions: ['0.000000', '0.008381', '0.991480']
09:39:04 | WARNING | Min allowed prob: 0.016
09:39:04 | WARNING | Max allowed prob: 0.047
09:39:06 | WARNING | State probability collapse detected:
09:39:06 | WARNING | State distributions: ['0.000000', '0.011795', '0.988066']
09:39:06 | WARNING | Min allowed prob: 0.016
09:39:06 | WARNING | Max allowed prob: 0.047
09:39:08 | WARNING | State probability collapse detected:
09:39:08 | WARNING | State distributions: ['0.000023', '0.018587', '0.981253']
09:39:08 | WARNING | Min allowed prob: 0.016
09:39:08 | WARNING | Max allowed prob: 0.047
09:39:10 | WARNING | State probability collapse detected:
09:39:10 | WARNING | State distributions: ['0.000000', '0.002495', '0.997365']
09:39:10 | WARNING | Min allowed prob: 0.016
09:39:10 | WARNING | Max allowed prob: 0.047
09:39:12 | WARNING | State probability collapse detected:
09:39:12 | WARNING | State distributions: ['0.000000', '0.000006', '0.999854']
09:39:12 | WARNING | Min allowed prob: 0.016
09:39:12 | WARNING | Max allowed prob: 0.047
09:39:14 | WARNING | State probability collapse detected:
09:39:14 | WARNING | State distributions: ['0.000000', '0.000033', '0.999827']
09:39:14 | WARNING | Min allowed prob: 0.016
09:39:14 | WARNING | Max allowed prob: 0.047
09:39:16 | WARNING | State probability collapse detected:
09:39:16 | WARNING | State distributions: ['0.000029', '0.016469', '0.983364']
09:39:16 | WARNING | Min allowed prob: 0.016
09:39:16 | WARNING | Max allowed prob: 0.047
09:39:18 | WARNING | State probability collapse detected:
09:39:18 | WARNING | State distributions: ['0.000000', '0.001579', '0.998281']
09:39:18 | WARNING | Min allowed prob: 0.016
09:39:18 | WARNING | Max allowed prob: 0.047
09:39:20 | WARNING | State probability collapse detected:
09:39:20 | WARNING | State distributions: ['0.000000', '0.000001', '0.999859']
09:39:20 | WARNING | Min allowed prob: 0.016
09:39:20 | WARNING | Max allowed prob: 0.047
09:39:22 | WARNING | State probability collapse detected:
09:39:22 | WARNING | State distributions: ['0.000000', '0.000778', '0.999082']
09:39:22 | WARNING | Min allowed prob: 0.016
09:39:22 | WARNING | Max allowed prob: 0.047
09:39:24 | WARNING | State probability collapse detected:
09:39:24 | WARNING | State distributions: ['0.000001', '0.001213', '0.998646']
09:39:24 | WARNING | Min allowed prob: 0.016
09:39:24 | WARNING | Max allowed prob: 0.047
09:39:26 | WARNING | State probability collapse detected:
09:39:26 | WARNING | State distributions: ['0.000000', '0.000001', '0.999859']
09:39:26 | WARNING | Min allowed prob: 0.016
09:39:26 | WARNING | Max allowed prob: 0.047
09:39:28 | WARNING | State probability collapse detected:
09:39:28 | WARNING | State distributions: ['0.000000', '0.006286', '0.993574']
09:39:28 | WARNING | Min allowed prob: 0.016
09:39:28 | WARNING | Max allowed prob: 0.047
09:39:30 | WARNING | State probability collapse detected:
09:39:30 | WARNING | State distributions: ['0.000001', '0.019100', '0.980762']
09:39:30 | WARNING | Min allowed prob: 0.016
09:39:30 | WARNING | Max allowed prob: 0.047
09:39:32 | WARNING | State probability collapse detected:
09:39:32 | WARNING | State distributions: ['0.000017', '0.010099', '0.989745']
09:39:32 | WARNING | Min allowed prob: 0.016
09:39:32 | WARNING | Max allowed prob: 0.047
09:39:34 | WARNING | State probability collapse detected:
09:39:34 | WARNING | State distributions: ['0.000005', '0.010922', '0.988934']
09:39:34 | WARNING | Min allowed prob: 0.016
09:39:34 | WARNING | Max allowed prob: 0.047
09:39:36 | WARNING | State probability collapse detected:
09:39:36 | WARNING | State distributions: ['0.000000', '0.000010', '0.999850']
09:39:36 | WARNING | Min allowed prob: 0.016
09:39:36 | WARNING | Max allowed prob: 0.047
09:39:38 | WARNING | State probability collapse detected:
09:39:38 | WARNING | State distributions: ['0.000000', '0.000166', '0.999694']
09:39:38 | WARNING | Min allowed prob: 0.016
09:39:38 | WARNING | Max allowed prob: 0.047
09:39:40 | WARNING | State probability collapse detected:
09:39:40 | WARNING | State distributions: ['0.000000', '0.000158', '0.999702']
09:39:40 | WARNING | Min allowed prob: 0.016
09:39:40 | WARNING | Max allowed prob: 0.047
09:39:42 | WARNING | State probability collapse detected:
09:39:42 | WARNING | State distributions: ['0.000001', '0.001568', '0.998292']
09:39:42 | WARNING | Min allowed prob: 0.016
09:39:42 | WARNING | Max allowed prob: 0.047
09:39:44 | WARNING | State probability collapse detected:
09:39:44 | WARNING | State distributions: ['0.000000', '0.001183', '0.998677']
09:39:44 | WARNING | Min allowed prob: 0.016
09:39:44 | WARNING | Max allowed prob: 0.047
09:39:46 | WARNING | State probability collapse detected:
09:39:46 | WARNING | State distributions: ['0.000000', '0.000565', '0.999295']
09:39:46 | WARNING | Min allowed prob: 0.016
09:39:46 | WARNING | Max allowed prob: 0.047
09:39:48 | WARNING | State probability collapse detected:
09:39:48 | WARNING | State distributions: ['0.000004', '0.038076', '0.961786']
09:39:48 | WARNING | Min allowed prob: 0.016
09:39:48 | WARNING | Max allowed prob: 0.047
09:39:50 | WARNING | State probability collapse detected:
09:39:50 | WARNING | State distributions: ['0.000001', '0.037711', '0.962153']
09:39:50 | WARNING | Min allowed prob: 0.016
09:39:50 | WARNING | Max allowed prob: 0.047
09:39:52 | WARNING | State probability collapse detected:
09:39:52 | WARNING | State distributions: ['0.000000', '0.000820', '0.999041']
09:39:52 | WARNING | Min allowed prob: 0.016
09:39:52 | WARNING | Max allowed prob: 0.047
09:39:54 | WARNING | State probability collapse detected:
09:39:54 | WARNING | State distributions: ['0.000000', '0.004230', '0.995630']
09:39:54 | WARNING | Min allowed prob: 0.016
09:39:54 | WARNING | Max allowed prob: 0.047
09:39:56 | WARNING | State probability collapse detected:
09:39:56 | WARNING | State distributions: ['0.000000', '0.002260', '0.997600']
09:39:56 | WARNING | Min allowed prob: 0.016
09:39:56 | WARNING | Max allowed prob: 0.047
09:39:58 | WARNING | State probability collapse detected:
09:39:58 | WARNING | State distributions: ['0.000002', '0.001384', '0.998474']
09:39:58 | WARNING | Min allowed prob: 0.016
09:39:58 | WARNING | Max allowed prob: 0.047
09:40:00 | WARNING | State probability collapse detected:
09:40:00 | WARNING | State distributions: ['0.000000', '0.000166', '0.999694']
09:40:00 | WARNING | Min allowed prob: 0.016
09:40:00 | WARNING | Max allowed prob: 0.047
09:40:02 | WARNING | State probability collapse detected:
09:40:02 | WARNING | State distributions: ['0.000000', '0.000031', '0.999829']
09:40:02 | WARNING | Min allowed prob: 0.016
09:40:02 | WARNING | Max allowed prob: 0.047
09:40:04 | WARNING | State probability collapse detected:
09:40:04 | WARNING | State distributions: ['0.000000', '0.002492', '0.997368']
09:40:04 | WARNING | Min allowed prob: 0.016
09:40:04 | WARNING | Max allowed prob: 0.047
09:40:06 | WARNING | State probability collapse detected:
09:40:06 | WARNING | State distributions: ['0.000000', '0.000046', '0.999814']
09:40:06 | WARNING | Min allowed prob: 0.016
09:40:06 | WARNING | Max allowed prob: 0.047
09:40:08 | WARNING | State probability collapse detected:
09:40:08 | WARNING | State distributions: ['0.000000', '0.000032', '0.999828']
09:40:08 | WARNING | Min allowed prob: 0.016
09:40:08 | WARNING | Max allowed prob: 0.047
09:40:10 | WARNING | State probability collapse detected:
09:40:10 | WARNING | State distributions: ['0.000000', '0.000000', '0.999860']
09:40:10 | WARNING | Min allowed prob: 0.016
09:40:10 | WARNING | Max allowed prob: 0.047
09:40:12 | WARNING | State probability collapse detected:
09:40:12 | WARNING | State distributions: ['0.000000', '0.008973', '0.990888']
09:40:12 | WARNING | Min allowed prob: 0.016
09:40:12 | WARNING | Max allowed prob: 0.047
09:40:14 | WARNING | State probability collapse detected:
09:40:14 | WARNING | State distributions: ['0.000000', '0.000329', '0.999531']
09:40:14 | WARNING | Min allowed prob: 0.016
09:40:14 | WARNING | Max allowed prob: 0.047
09:40:16 | WARNING | State probability collapse detected:
09:40:16 | WARNING | State distributions: ['0.000000', '0.001816', '0.998044']
09:40:16 | WARNING | Min allowed prob: 0.016
09:40:16 | WARNING | Max allowed prob: 0.047
09:40:18 | WARNING | State probability collapse detected:
09:40:18 | WARNING | State distributions: ['0.000002', '0.001263', '0.998595']
09:40:18 | WARNING | Min allowed prob: 0.016
09:40:18 | WARNING | Max allowed prob: 0.047
09:40:20 | WARNING | State probability collapse detected:
09:40:20 | WARNING | State distributions: ['0.000000', '0.001154', '0.998707']
09:40:20 | WARNING | Min allowed prob: 0.016
09:40:20 | WARNING | Max allowed prob: 0.047
09:40:22 | WARNING | State probability collapse detected:
09:40:22 | WARNING | State distributions: ['0.000001', '0.003030', '0.996829']
09:40:22 | WARNING | Min allowed prob: 0.016
09:40:22 | WARNING | Max allowed prob: 0.047
09:40:24 | WARNING | State probability collapse detected:
09:40:24 | WARNING | State distributions: ['0.000133', '0.002382', '0.997345']
09:40:24 | WARNING | Min allowed prob: 0.016
09:40:24 | WARNING | Max allowed prob: 0.047
09:40:26 | WARNING | State probability collapse detected:
09:40:26 | WARNING | State distributions: ['0.000002', '0.001030', '0.998828']
09:40:26 | WARNING | Min allowed prob: 0.016
09:40:26 | WARNING | Max allowed prob: 0.047
09:40:28 | WARNING | State probability collapse detected:
09:40:28 | WARNING | State distributions: ['0.000000', '0.000753', '0.999108']
09:40:28 | WARNING | Min allowed prob: 0.016
09:40:28 | WARNING | Max allowed prob: 0.047
09:40:30 | WARNING | State probability collapse detected:
09:40:30 | WARNING | State distributions: ['0.000000', '0.001735', '0.998124']
09:40:30 | WARNING | Min allowed prob: 0.016
09:40:30 | WARNING | Max allowed prob: 0.047
09:40:32 | WARNING | State probability collapse detected:
09:40:32 | WARNING | State distributions: ['0.000000', '0.000268', '0.999592']
09:40:32 | WARNING | Min allowed prob: 0.016
09:40:32 | WARNING | Max allowed prob: 0.047
09:40:34 | WARNING | State probability collapse detected:
09:40:34 | WARNING | State distributions: ['0.000000', '0.004259', '0.995602']
09:40:34 | WARNING | Min allowed prob: 0.016
09:40:34 | WARNING | Max allowed prob: 0.047
09:40:36 | WARNING | State probability collapse detected:
09:40:36 | WARNING | State distributions: ['0.000000', '0.000653', '0.999207']
09:40:36 | WARNING | Min allowed prob: 0.016
09:40:36 | WARNING | Max allowed prob: 0.047
09:40:38 | WARNING | State probability collapse detected:
09:40:38 | WARNING | State distributions: ['0.000000', '0.002638', '0.997222']
09:40:38 | WARNING | Min allowed prob: 0.016
09:40:38 | WARNING | Max allowed prob: 0.047
09:40:40 | WARNING | State probability collapse detected:
09:40:40 | WARNING | State distributions: ['0.000000', '0.000182', '0.999678']
09:40:40 | WARNING | Min allowed prob: 0.016
09:40:40 | WARNING | Max allowed prob: 0.047
09:40:42 | WARNING | State probability collapse detected:
09:40:42 | WARNING | State distributions: ['0.000001', '0.000778', '0.999082']
09:40:42 | WARNING | Min allowed prob: 0.016
09:40:42 | WARNING | Max allowed prob: 0.047
09:40:44 | WARNING | State probability collapse detected:
09:40:44 | WARNING | State distributions: ['0.000000', '0.005545', '0.994316']
09:40:44 | WARNING | Min allowed prob: 0.016
09:40:44 | WARNING | Max allowed prob: 0.047
09:40:46 | WARNING | State probability collapse detected:
09:40:46 | WARNING | State distributions: ['0.000000', '0.007597', '0.992263']
09:40:46 | WARNING | Min allowed prob: 0.016
09:40:46 | WARNING | Max allowed prob: 0.047
09:40:48 | WARNING | State probability collapse detected:
09:40:48 | WARNING | State distributions: ['0.000000', '0.004769', '0.995092']
09:40:48 | WARNING | Min allowed prob: 0.016
09:40:48 | WARNING | Max allowed prob: 0.047
09:40:50 | WARNING | State probability collapse detected:
09:40:50 | WARNING | State distributions: ['0.000000', '0.000369', '0.999491']
09:40:50 | WARNING | Min allowed prob: 0.016
09:40:50 | WARNING | Max allowed prob: 0.047
09:40:52 | WARNING | State probability collapse detected:
09:40:52 | WARNING | State distributions: ['0.000000', '0.000689', '0.999171']
09:40:52 | WARNING | Min allowed prob: 0.016
09:40:52 | WARNING | Max allowed prob: 0.047
09:40:54 | WARNING | State probability collapse detected:
09:40:54 | WARNING | State distributions: ['0.000000', '0.000993', '0.998868']
09:40:54 | WARNING | Min allowed prob: 0.016
09:40:54 | WARNING | Max allowed prob: 0.047
09:40:56 | WARNING | State probability collapse detected:
09:40:56 | WARNING | State distributions: ['0.000000', '0.023470', '0.976394']
09:40:56 | WARNING | Min allowed prob: 0.016
09:40:56 | WARNING | Max allowed prob: 0.047
09:40:58 | WARNING | State probability collapse detected:
09:40:58 | WARNING | State distributions: ['0.000000', '0.014153', '0.985708']
09:40:58 | WARNING | Min allowed prob: 0.016
09:40:58 | WARNING | Max allowed prob: 0.047
09:41:00 | WARNING | State probability collapse detected:
09:41:00 | WARNING | State distributions: ['0.000000', '0.001030', '0.998830']
09:41:00 | WARNING | Min allowed prob: 0.016
09:41:00 | WARNING | Max allowed prob: 0.047
09:41:02 | WARNING | State probability collapse detected:
09:41:02 | WARNING | State distributions: ['0.000000', '0.002507', '0.997353']
09:41:02 | WARNING | Min allowed prob: 0.016
09:41:02 | WARNING | Max allowed prob: 0.047
09:41:04 | WARNING | State probability collapse detected:
09:41:04 | WARNING | State distributions: ['0.000000', '0.000031', '0.999829']
09:41:04 | WARNING | Min allowed prob: 0.016
09:41:04 | WARNING | Max allowed prob: 0.047
---------------------------------------------------------------------------