In [1]:
# Step 1: Environment Setup and Import Libraries
# Task 5: Inference of protein secondary structure motifs using BayesFlow

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# BayesFlow imports
try:
    import bayesflow as bf
    print(f"BayesFlow version: {bf.__version__}")
except ImportError:
    print("BayesFlow not installed. Please install with: pip install bayesflow")
    raise

# hmmlearn imports
try:
    from hmmlearn import hmm
    import hmmlearn
    print(f"hmmlearn version: {hmmlearn.__version__}")
except ImportError:
    print("hmmlearn not installed. Please install with: pip install hmmlearn")
    raise

# Additional scientific computing libraries
try:
    import scipy.stats as stats
    from scipy.special import logsumexp
    import sklearn.metrics as metrics
    print("Scientific computing libraries loaded successfully")
except ImportError as e:
    print(f"Missing scientific computing library: {e}")
    raise

# Set TensorFlow backend for BayesFlow
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

# Configure TensorFlow for optimal performance
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth enabled")
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Global configuration
CONFIG = {
    'random_seed': RANDOM_SEED,
    'n_states': 2,  # alpha-helix and other
    'n_amino_acids': 20,
    'sequence_length_range': (50, 200),  # Variable length sequences
    'batch_size': 64,
    'n_train_samples': 10000,
    'n_val_samples': 2000,
    'n_test_samples': 1000,
    'verbose': True
}

print("=" * 50)
print("TASK 5: PROTEIN SECONDARY STRUCTURE INFERENCE")
print("=" * 50)
print("Environment setup completed successfully!")
print(f"Configuration: {CONFIG}")
print("=" * 50)

2025-07-08 23:04:15.518696: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-07-08 23:04:15.518721: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-07-08 23:04:15.518726: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
I0000 00:00:1752008655.518737 5236854 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1752008655.518755 5236854 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
INFO:bayesflow:Using backend 'tensorflow'


BayesFlow version: 2.0.5
hmmlearn version: 0.3.3
Scientific computing libraries loaded successfully
TensorFlow version: 2.19.0
GPU available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU configuration error: Physical devices cannot be modified after being initialized
TASK 5: PROTEIN SECONDARY STRUCTURE INFERENCE
Environment setup completed successfully!
Configuration: {'random_seed': 42, 'n_states': 2, 'n_amino_acids': 20, 'sequence_length_range': (50, 200), 'batch_size': 64, 'n_train_samples': 10000, 'n_val_samples': 2000, 'n_test_samples': 1000, 'verbose': True}


In [2]:
# Step 2: Data Preparation and Constants
# Define amino acid alphabet, emission probabilities, and transition probabilities

# Standard 20 amino acids (single letter codes)
AMINO_ACIDS = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I', 
               'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

# Create mapping dictionaries
AA_TO_INDEX = {aa: idx for idx, aa in enumerate(AMINO_ACIDS)}
INDEX_TO_AA = {idx: aa for idx, aa in enumerate(AMINO_ACIDS)}

print("Amino Acid Mapping:")
print("AA_TO_INDEX:", AA_TO_INDEX)
print("INDEX_TO_AA:", INDEX_TO_AA)

# Define emission probabilities as specified in the task
# Alpha-helix emission probabilities (state 0)
ALPHA_HELIX_EMISSIONS = {
    'A': 0.12, 'R': 0.06, 'N': 0.03, 'D': 0.05, 'C': 0.01,
    'E': 0.09, 'Q': 0.05, 'G': 0.04, 'H': 0.02, 'I': 0.07,
    'L': 0.12, 'K': 0.06, 'M': 0.03, 'F': 0.04, 'P': 0.02,
    'S': 0.05, 'T': 0.04, 'W': 0.01, 'Y': 0.03, 'V': 0.06
}

# Other (beta-sheet/coil) emission probabilities (state 1)
OTHER_EMISSIONS = {
    'A': 0.06, 'R': 0.05, 'N': 0.05, 'D': 0.06, 'C': 0.02,
    'E': 0.05, 'Q': 0.03, 'G': 0.09, 'H': 0.03, 'I': 0.05,
    'L': 0.08, 'K': 0.06, 'M': 0.02, 'F': 0.04, 'P': 0.06,
    'S': 0.07, 'T': 0.06, 'W': 0.01, 'Y': 0.04, 'V': 0.07
}

# Convert to numpy arrays in the correct order
alpha_helix_probs = np.array([ALPHA_HELIX_EMISSIONS[aa] for aa in AMINO_ACIDS])
other_probs = np.array([OTHER_EMISSIONS[aa] for aa in AMINO_ACIDS])

# Normalize to ensure they sum to 1 (handle any rounding errors)
alpha_helix_probs = alpha_helix_probs / np.sum(alpha_helix_probs)
other_probs = other_probs / np.sum(other_probs)

# Create emission probability matrix (2 states x 20 amino acids)
EMISSION_PROBS = np.array([alpha_helix_probs, other_probs])

print("\nEmission Probabilities:")
print("Alpha-helix state probabilities sum:", np.sum(alpha_helix_probs))
print("Other state probabilities sum:", np.sum(other_probs))
print("Emission matrix shape:", EMISSION_PROBS.shape)

# Define transition probabilities as specified in the task
# States: 0 = alpha-helix, 1 = other
# Transition matrix: rows = from state, columns = to state
TRANSITION_PROBS = np.array([
    [0.90, 0.10],  # From alpha-helix: 90% stay, 10% to other
    [0.05, 0.95]   # From other: 5% to alpha-helix, 95% stay
])

print("\nTransition Probabilities:")
print("Transition matrix:")
print("From alpha-helix: alpha-helix=90%, other=10%")
print("From other: alpha-helix=5%, other=95%")
print("Transition matrix shape:", TRANSITION_PROBS.shape)
print("Row sums (should be 1.0):", np.sum(TRANSITION_PROBS, axis=1))

# Initial state probabilities (always start in "other" state)
INITIAL_PROBS = np.array([0.0, 1.0])  # [alpha-helix, other]

print("\nInitial State Probabilities:")
print("Always start in 'other' state:", INITIAL_PROBS)

# State names for interpretability
STATE_NAMES = ['alpha-helix', 'other']
STATE_TO_INDEX = {'alpha-helix': 0, 'other': 1}
INDEX_TO_STATE = {0: 'alpha-helix', 1: 'other'}

print("\nState Mapping:")
print("STATE_TO_INDEX:", STATE_TO_INDEX)
print("INDEX_TO_STATE:", INDEX_TO_STATE)

# Create a summary dataframe for emission probabilities
emission_df = pd.DataFrame({
    'Amino_Acid': AMINO_ACIDS,
    'Alpha_Helix': alpha_helix_probs,
    'Other': other_probs
})

print("\nEmission Probabilities Summary:")
print(emission_df.round(3))

# Verify probabilities are valid
def verify_probabilities():
    """Verify that all probability matrices are valid"""
    print("\n" + "="*50)
    print("PROBABILITY VERIFICATION")
    print("="*50)
    
    # Check emission probabilities
    alpha_sum = np.sum(alpha_helix_probs)
    other_sum = np.sum(other_probs)
    print(f"Alpha-helix emissions sum: {alpha_sum:.6f}")
    print(f"Other emissions sum: {other_sum:.6f}")
    
    # Check transition probabilities
    trans_sums = np.sum(TRANSITION_PROBS, axis=1)
    print(f"Transition matrix row sums: {trans_sums}")
    
    # Check initial probabilities
    init_sum = np.sum(INITIAL_PROBS)
    print(f"Initial probabilities sum: {init_sum:.6f}")
    
    # Verify all are valid
    valid = (abs(alpha_sum - 1.0) < 1e-10 and 
             abs(other_sum - 1.0) < 1e-10 and 
             abs(init_sum - 1.0) < 1e-10 and 
             all(abs(s - 1.0) < 1e-10 for s in trans_sums))
    
    print(f"All probabilities valid: {valid}")
    print("="*50)
    
    return valid

# Run verification
probability_validation = verify_probabilities()

# Store all constants in a configuration dictionary
DATA_CONFIG = {
    'amino_acids': AMINO_ACIDS,
    'aa_to_index': AA_TO_INDEX,
    'index_to_aa': INDEX_TO_AA,
    'emission_probs': EMISSION_PROBS,
    'transition_probs': TRANSITION_PROBS,
    'initial_probs': INITIAL_PROBS,
    'state_names': STATE_NAMES,
    'state_to_index': STATE_TO_INDEX,
    'index_to_state': INDEX_TO_STATE,
    'n_states': len(STATE_NAMES),
    'n_amino_acids': len(AMINO_ACIDS),
    'valid_probabilities': probability_validation
}

print("Data preparation completed successfully!")
print(f"Data configuration keys: {list(DATA_CONFIG.keys())}")

Amino Acid Mapping:
AA_TO_INDEX: {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'E': 5, 'Q': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19}
INDEX_TO_AA: {0: 'A', 1: 'R', 2: 'N', 3: 'D', 4: 'C', 5: 'E', 6: 'Q', 7: 'G', 8: 'H', 9: 'I', 10: 'L', 11: 'K', 12: 'M', 13: 'F', 14: 'P', 15: 'S', 16: 'T', 17: 'W', 18: 'Y', 19: 'V'}

Emission Probabilities:
Alpha-helix state probabilities sum: 0.9999999999999998
Other state probabilities sum: 0.9999999999999998
Emission matrix shape: (2, 20)

Transition Probabilities:
Transition matrix:
From alpha-helix: alpha-helix=90%, other=10%
From other: alpha-helix=5%, other=95%
Transition matrix shape: (2, 2)
Row sums (should be 1.0): [1. 1.]

Initial State Probabilities:
Always start in 'other' state: [0. 1.]

State Mapping:
STATE_TO_INDEX: {'alpha-helix': 0, 'other': 1}
INDEX_TO_STATE: {0: 'alpha-helix', 1: 'other'}

Emission Probabilities Summary:
   Amino_Acid  Alpha_Helix  Other
0        

In [3]:
# Step 3: HMM Model Setup with hmmlearn
# Initialize CategoricalHMM and implement helper functions

from hmmlearn.hmm import CategoricalHMM
import numpy as np

class ProteinHMM:
    """
    Protein secondary structure HMM using hmmlearn's CategoricalHMM
    """
    
    def __init__(self, emission_probs, transition_probs, initial_probs, 
                 aa_to_index, index_to_aa, state_names):
        """
        Initialize the HMM with fixed probabilities
        
        Parameters:
        -----------
        emission_probs : np.ndarray
            Emission probability matrix (n_states x n_features)
        transition_probs : np.ndarray
            Transition probability matrix (n_states x n_states)
        initial_probs : np.ndarray
            Initial state probabilities
        aa_to_index : dict
            Mapping from amino acid to index
        index_to_aa : dict
            Mapping from index to amino acid
        state_names : list
            Names of the states
        """
        self.n_states = len(state_names)
        self.n_features = len(aa_to_index)
        self.aa_to_index = aa_to_index
        self.index_to_aa = index_to_aa
        self.state_names = state_names
        
        # Initialize CategoricalHMM
        self.model = CategoricalHMM(
            n_components=self.n_states,
            random_state=CONFIG['random_seed'],
            init_params="",  # Don't initialize parameters, we'll set them manually
            params=""        # Don't update parameters during fitting
        )
        
        # Set the probabilities
        self.model.startprob_ = initial_probs.copy()
        self.model.transmat_ = transition_probs.copy()
        self.model.emissionprob_ = emission_probs.copy()
        
        print(f"Initialized ProteinHMM with {self.n_states} states and {self.n_features} features")
        print(f"State names: {self.state_names}")
        
    def sequence_to_indices(self, sequence):
        """Convert amino acid sequence to indices"""
        try:
            return np.array([self.aa_to_index[aa] for aa in sequence])
        except KeyError as e:
            raise ValueError(f"Unknown amino acid: {e}")
    
    def indices_to_sequence(self, indices):
        """Convert indices back to amino acid sequence"""
        return ''.join([self.index_to_aa[idx] for idx in indices])
    
    def generate_sequence(self, length):
        """
        Generate a random amino acid sequence of given length
        
        Parameters:
        -----------
        length : int
            Length of sequence to generate
            
        Returns:
        --------
        sequence : str
            Generated amino acid sequence
        states : np.ndarray
            True hidden states for the sequence
        """
        # Generate sequence using the HMM
        sequence_indices, states = self.model.sample(length)
        
        # Convert to amino acid sequence
        sequence = self.indices_to_sequence(sequence_indices.flatten())
        
        return sequence, states.flatten()
    
    def predict_states(self, sequence):
        """
        Predict the most likely state sequence using Viterbi algorithm
        
        Parameters:
        -----------
        sequence : str
            Amino acid sequence
            
        Returns:
        --------
        states : np.ndarray
            Predicted state sequence
        """
        # Convert sequence to indices
        sequence_indices = self.sequence_to_indices(sequence)
        
        # Reshape for hmmlearn (n_samples, n_features) format
        X = sequence_indices.reshape(-1, 1)
        
        # Predict states using Viterbi algorithm
        states = self.model.predict(X)
        
        return states
    
    def predict_proba(self, sequence):
        """
        Predict state probabilities using Forward-Backward algorithm
        
        Parameters:
        -----------
        sequence : str
            Amino acid sequence
            
        Returns:
        --------
        state_probs : np.ndarray
            State probabilities for each position (n_positions x n_states)
        """
        # Convert sequence to indices
        sequence_indices = self.sequence_to_indices(sequence)
        
        # Reshape for hmmlearn (n_samples, n_features) format
        X = sequence_indices.reshape(-1, 1)
        
        # Predict probabilities using Forward-Backward algorithm
        state_probs = self.model.predict_proba(X)
        
        return state_probs
    
    def score_sequence(self, sequence):
        """
        Calculate log-likelihood of a sequence
        
        Parameters:
        -----------
        sequence : str
            Amino acid sequence
            
        Returns:
        --------
        log_likelihood : float
            Log-likelihood of the sequence
        """
        # Convert sequence to indices
        sequence_indices = self.sequence_to_indices(sequence)
        
        # Reshape for hmmlearn format
        X = sequence_indices.reshape(-1, 1)
        
        # Calculate log-likelihood
        log_likelihood = self.model.score(X)
        
        return log_likelihood
    
    def get_model_info(self):
        """Return model parameters for inspection"""
        return {
            'n_states': self.n_states,
            'n_features': self.n_features,
            'startprob': self.model.startprob_,
            'transmat': self.model.transmat_,
            'emissionprob': self.model.emissionprob_,
            'state_names': self.state_names
        }

# Initialize the protein HMM
protein_hmm = ProteinHMM(
    emission_probs=DATA_CONFIG['emission_probs'],
    transition_probs=DATA_CONFIG['transition_probs'],
    initial_probs=DATA_CONFIG['initial_probs'],
    aa_to_index=DATA_CONFIG['aa_to_index'],
    index_to_aa=DATA_CONFIG['index_to_aa'],
    state_names=DATA_CONFIG['state_names']
)

# Test the HMM functionality
print("\n" + "="*50)
print("TESTING HMM FUNCTIONALITY")
print("="*50)

# Test 1: Generate a sample sequence
test_length = 50
test_sequence, true_states = protein_hmm.generate_sequence(test_length)
print(f"Generated sequence (length {len(test_sequence)}):")
print(f"Sequence: {test_sequence}")
print(f"True states: {true_states}")
print(f"State names: {[protein_hmm.state_names[s] for s in true_states]}")

# Test 2: Predict states using Viterbi
predicted_states = protein_hmm.predict_states(test_sequence)
print(f"\nPredicted states (Viterbi): {predicted_states}")
print(f"Predicted state names: {[protein_hmm.state_names[s] for s in predicted_states]}")

# Test 3: Get state probabilities using Forward-Backward
state_probs = protein_hmm.predict_proba(test_sequence)
print(f"\nState probabilities shape: {state_probs.shape}")
print(f"First 5 positions state probabilities:")
for i in range(min(5, len(state_probs))):
    print(f"Position {i}: alpha-helix={state_probs[i,0]:.3f}, other={state_probs[i,1]:.3f}")

# Test 4: Calculate sequence likelihood
log_likelihood = protein_hmm.score_sequence(test_sequence)
print(f"\nSequence log-likelihood: {log_likelihood:.3f}")

# Test 5: Verify accuracy of Viterbi vs true states
accuracy = np.mean(predicted_states == true_states)
print(f"Viterbi accuracy vs true states: {accuracy:.3f}")

# Test 6: Show model information
model_info = protein_hmm.get_model_info()
print(f"\nModel Information:")
print(f"Number of states: {model_info['n_states']}")
print(f"Number of features: {model_info['n_features']}")
print(f"Initial probabilities: {model_info['startprob']}")
print(f"Transition matrix diagonal: {np.diag(model_info['transmat'])}")

# Helper function for batch sequence generation
def generate_training_data(n_sequences, length_range=(50, 200)):
    """
    Generate training data for BayesFlow
    
    Parameters:
    -----------
    n_sequences : int
        Number of sequences to generate
    length_range : tuple
        Range of sequence lengths (min, max)
        
    Returns:
    --------
    sequences : list
        List of amino acid sequences
    state_probabilities : list
        List of state probability arrays
    true_states : list
        List of true state sequences
    """
    sequences = []
    state_probabilities = []
    true_states = []
    
    for i in range(n_sequences):
        # Random sequence length
        length = np.random.randint(length_range[0], length_range[1] + 1)
        
        # Generate sequence
        seq, states = protein_hmm.generate_sequence(length)
        
        # Get state probabilities using Forward-Backward
        probs = protein_hmm.predict_proba(seq)
        
        sequences.append(seq)
        state_probabilities.append(probs)
        true_states.append(states)
        
        if (i + 1) % 1000 == 0:
            print(f"Generated {i + 1}/{n_sequences} sequences")
    
    return sequences, state_probabilities, true_states

# Test batch generation
print("\n" + "="*50)
print("TESTING BATCH GENERATION")
print("="*50)

# Generate small test batch
test_sequences, test_probs, test_states = generate_training_data(5, (20, 30))
print(f"Generated {len(test_sequences)} test sequences")
print(f"Sequence lengths: {[len(s) for s in test_sequences]}")
print(f"State probability shapes: {[p.shape for p in test_probs]}")

print("\nStep 3 completed successfully!")
print("HMM model is ready for sequence generation and state inference")

Initialized ProteinHMM with 2 states and 20 features
State names: ['alpha-helix', 'other']

TESTING HMM FUNCTIONALITY
Generated sequence (length 50):
Sequence: VKNTPVDNIERQFIARVSRILWKINSYYDEEELPVNSMAALAELIKLIGA
True states: [1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0]
State names: ['other', 'other', 'other', 'other', 'other', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'other', 'other', 'other', 'other', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'other', 'other', 'other', 'other', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'alpha-helix', 'a

In [4]:
# Step 4: Generative Model Implementation 
# Create simulator using BayesFlow LambdaSimulator and Adapter for protein sequences

import bayesflow as bf
from bayesflow.simulators import LambdaSimulator
from bayesflow.adapters import Adapter
import numpy as np
import tensorflow as tf

def protein_simulator_function(batch_shape):
    """
    Main simulator function for BayesFlow protein secondary structure task
    
    This function generates protein sequences with known HMM parameters
    and returns both the simulated data and the ground truth for training.
    
    Parameters:
    -----------
    batch_shape : int or tuple
        Number of protein sequences to simulate (BayesFlow passes batch_shape)
        
    Returns:
    --------
    sim_data : dict
        Dictionary containing:
        - 'parameters': HMM parameters (transition, emission, initial probs)
        - 'observables': amino acid sequences as integer indices
        - 'summary_stats': forward-backward state probabilities
        - 'sequence_lengths': actual lengths of each sequence
    """
    
    # Handle batch_shape (can be int or tuple)
    if isinstance(batch_shape, tuple):
        batch_size = batch_shape[0]
    else:
        batch_size = batch_shape
    
    # Generate variable length sequences
    min_len, max_len = CONFIG['sequence_length_range']
    sequence_lengths = np.random.randint(min_len, max_len + 1, size=batch_size)
    
    # Storage for batch
    all_sequences = []
    all_state_probs = []
    all_true_states = []
    max_length = np.max(sequence_lengths)
    
    # Generate each sequence in the batch
    for i in range(batch_size):
        length = sequence_lengths[i]
        
        # Generate sequence using our HMM
        sequence, true_states = protein_hmm.generate_sequence(length)
        
        # Get forward-backward probabilities (summary statistics)
        state_probs = protein_hmm.predict_proba(sequence)
        
        all_sequences.append(sequence)
        all_state_probs.append(state_probs)
        all_true_states.append(true_states)
    
    # Convert sequences to padded integer arrays
    sequences_int = np.full((batch_size, max_length), -1, dtype=np.int32)  # -1 for padding
    for i, seq in enumerate(all_sequences):
        seq_indices = [protein_hmm.aa_to_index[aa] for aa in seq]
        sequences_int[i, :len(seq_indices)] = seq_indices
    
    # Convert state probabilities to padded arrays
    state_probs_padded = np.zeros((batch_size, max_length, 2), dtype=np.float32)
    for i, probs in enumerate(all_state_probs):
        state_probs_padded[i, :len(probs), :] = probs
    
    # Create parameter vectors (fixed for this task but included for completeness)
    # Parameters: [transition_00, transition_01, transition_10, transition_11, initial_0, initial_1]
    param_vector = np.concatenate([
        DATA_CONFIG['transition_probs'].flatten(),
        DATA_CONFIG['initial_probs']
    ])
    
    # Replicate for batch
    parameters = np.tile(param_vector, (batch_size, 1)).astype(np.float32)
    
    # Return simulation data in BayesFlow format
    return {
        'parameters': parameters,           # Ground truth HMM parameters
        'observables': sequences_int,       # Observed amino acid sequences  
        'summary_stats': state_probs_padded, # Forward-backward probabilities
        'sequence_lengths': sequence_lengths.astype(np.int32),
        # Additional metadata for debugging/analysis
        'raw_sequences': all_sequences,
        'true_states': all_true_states
    }

# Create the BayesFlow LambdaSimulator
protein_simulator = LambdaSimulator(
    sample_fn=protein_simulator_function,
    is_batched=True  # Our function handles batching internally
)

# Test the simulator
print("=" * 50)
print("TESTING PROTEIN SIMULATOR")
print("=" * 50)

# Generate test batch
test_batch = protein_simulator.sample(batch_size=3)
print(f"Simulator output keys: {list(test_batch.keys())}")
print(f"Parameters shape: {test_batch['parameters'].shape}")
print(f"Observables shape: {test_batch['observables'].shape}")
print(f"Summary stats shape: {test_batch['summary_stats'].shape}")
print(f"Sequence lengths: {test_batch['sequence_lengths']}")

# Verify data quality
print(f"\nData validation:")
print(f"Parameter vector (first): {test_batch['parameters'][0]}")
print(f"Observable sequence (first 10 positions): {test_batch['observables'][0][:10]}")
print(f"Summary stats sum check (first seq, first 3 pos): {np.sum(test_batch['summary_stats'][0][:3], axis=1)}")

# Check for padding
print(f"Padding positions in first sequence: {np.where(test_batch['observables'][0] == -1)[0]}")

# Create adapter for data preprocessing
adapter = Adapter()

# Add transforms to handle protein-specific data
adapter.rename("observables", "sequences")  # Rename for clarity
adapter.rename("summary_stats", "target_probs")  # Target probabilities

# Simple transforms
adapter.to_array(include='target_probs')
adapter.to_array(include='parameters')
adapter.to_array(include='sequence_lengths')

# Test the adapter
print(f"\n" + "=" * 50)
print("TESTING ADAPTER")
print("=" * 50)

try:
    adapted = adapter(test_batch)
    print(f"Adapted data keys: {list(adapted.keys())}")
    print(f"Adapted sequences shape: {adapted['sequences'].shape}")
    print(f"Adapted sequences dtype: {adapted['sequences'].dtype}")
    print(f"Adapted target_probs shape: {adapted['target_probs'].shape}")
    print(f"Sample adapted sequence (first 10): {adapted['sequences'][0][:10]}")
    
    # Test inverse transform
    inverse_data = adapter(adapted, inverse=True)
    print(f"\nInverse transform successful")
    print(f"Original vs inverse sequences match: {np.array_equal(test_batch['observables'], inverse_data['observables'])}")
    
except Exception as e:
    print(f"Adapter error: {e}")
    # Use test_batch directly for now
    adapted = test_batch
    print(f"Using original data format")

# Validation function
def validate_simulation_output(sim_data):
    """Validate the simulation output format and content"""
    required_keys = ['parameters', 'observables', 'summary_stats', 'sequence_lengths']
    
    validation_results = {
        'valid': True,
        'errors': [],
        'warnings': []
    }
    
    # Check required keys
    for key in required_keys:
        if key not in sim_data:
            validation_results['errors'].append(f"Missing required key: {key}")
            validation_results['valid'] = False
    
    if not validation_results['valid']:
        return validation_results
    
    batch_size = len(sim_data['sequence_lengths'])
    
    # Check shapes
    if sim_data['parameters'].shape[0] != batch_size:
        validation_results['errors'].append("Parameters batch size mismatch")
        validation_results['valid'] = False
    
    if sim_data['observables'].shape[0] != batch_size:
        validation_results['errors'].append("Observables batch size mismatch")
        validation_results['valid'] = False
    
    # Check probability validity
    summary_stats = sim_data['summary_stats']
    for i in range(batch_size):
        seq_len = sim_data['sequence_lengths'][i]
        for j in range(seq_len):
            prob_sum = np.sum(summary_stats[i, j, :])
            if abs(prob_sum - 1.0) > 1e-4:
                validation_results['warnings'].append(
                    f"Sequence {i}, position {j}: probabilities sum to {prob_sum:.4f}"
                )
    
    return validation_results

# Test validation
print(f"\n" + "=" * 50)
print("TESTING SIMULATION VALIDATION")
print("=" * 50)

validation = validate_simulation_output(test_batch)
print(f"Validation passed: {validation['valid']}")
print(f"Errors: {len(validation['errors'])}")
print(f"Warnings: {len(validation['warnings'])}")

if validation['errors']:
    print("Errors found:")
    for error in validation['errors']:
        print(f"  - {error}")

if validation['warnings']:
    print("Warnings:")
    for warning in validation['warnings'][:3]:  # Show first 3
        print(f"  - {warning}")

# Analysis function
def analyze_simulation(sim_data):
    """Analyze simulation output for quality metrics"""
    batch_size = len(sim_data['sequence_lengths'])
    
    analysis = {
        'batch_size': batch_size,
        'avg_seq_length': np.mean(sim_data['sequence_lengths']),
        'std_seq_length': np.std(sim_data['sequence_lengths']),
        'min_seq_length': np.min(sim_data['sequence_lengths']),
        'max_seq_length': np.max(sim_data['sequence_lengths']),
        'parameter_dim': sim_data['parameters'].shape[1],
        'avg_alpha_helix_prob': np.mean(sim_data['summary_stats'][:, :, 0])
    }
    
    return analysis

analysis = analyze_simulation(test_batch)
print(f"\nSimulation Analysis:")
for key, value in analysis.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.3f}")
    else:
        print(f"  {key}: {value}")

print("\nStep 4 completed successfully!")
print("Protein simulator with BayesFlow LambdaSimulator is ready!")
print("Next: Implement data adapters and neural network components")

TESTING PROTEIN SIMULATOR
Simulator output keys: ['parameters', 'observables', 'summary_stats', 'sequence_lengths', 'raw_sequences', 'true_states']
Parameters shape: (3, 6)
Observables shape: (3, 171)
Summary stats shape: (3, 171, 2)
Sequence lengths: [152 171 124]

Data validation:
Parameter vector (first): [0.9  0.1  0.05 0.95 0.   1.  ]
Observable sequence (first 10 positions): [19 11  2 16 14 19  3  2  9  5]
Summary stats sum check (first seq, first 3 pos): [1. 1. 1.]
Padding positions in first sequence: [152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
 170]

TESTING ADAPTER
Adapted data keys: ['parameters', 'sequence_lengths', 'raw_sequences', 'true_states', 'sequences', 'target_probs']
Adapted sequences shape: (3, 171)
Adapted sequences dtype: int32
Adapted target_probs shape: (3, 171, 2)
Sample adapted sequence (first 10): [19 11  2 16 14 19  3  2  9  5]

Inverse transform successful
Original vs inverse sequences match: True

TESTING SIMULATION VALIDATION