## Understanding Task 5

- Protiens have totally 4 structures, but our focus is on Protein Secondary structure:
  1. Secondary (Our focus in this project) - Contains many patterns but focus on 2 states only for this project.
  2. But dividing them into only 2 pattern Alpha Helix and Others (include rest of the patterns).
  3. Focus specifically on predicting the alpha helix patterns using a 2 state HMM.
  4. The 2 states are "alpha helix" and "other".

In [1]:
# -*- coding: utf-8 -*-
import os
import warnings
from typing import Dict, List, Tuple, Optional

# warnings.filterwarnings('ignore')
os.environ['KERAS_BACKEND'] = 'tensorflow'

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import keras
import tensorflow as tf

import bayesflow as bf
from bayesflow.adapters import Adapter
from bayesflow.datasets import OnlineDataset
from bayesflow.workflows import BasicWorkflow
from bayesflow.adapters.transforms import OneHot
from bayesflow.approximators import ContinuousApproximator
from bayesflow.adapters.transforms import MapTransform, ExpandDims

from hmmlearn import hmm

from sklearn.preprocessing import LabelEncoder

current_backend = tf.keras.backend.backend()
print(f"tf.keras is using the '{current_backend}' backend.")

2025-07-12 17:20:11.885169: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-07-12 17:20:11.885205: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-07-12 17:20:11.885213: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
I0000 00:00:1752333611.885226 4489514 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1752333611.885252 4489514 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
INFO:bayesflow:Using backend 'tensorflow'


tf.keras is using the 'tensorflow' backend.


In [2]:
class ProteinHMMSimulator:
    """
    A simulator for protein secondary structure prediction using HMM.
    
    This class implements a two-state HMM (alpha-helix vs other) with fixed
    emission and transition probabilities
    """
    
    def __init__(self):
        # Define amino acids
        self.amino_acids = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
                           'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
        
        # State definitions: 0 = "other", 1 = "alpha-helix"
        self.states = ['other', 'alpha-helix']
        self.n_states = 2
        self.n_features = len(self.amino_acids)
        
        # Initialize label encoder for amino acids
        self.aa_encoder = LabelEncoder()
        self.aa_encoder.fit(self.amino_acids)
        
        # Define emission probabilities
        self.emission_probs = self._setup_emission_probabilities()
        
        # Define transition probabilities
        self.transition_probs = self._setup_transition_probabilities()
        
        # Initialize start probabilities (always start in "other" state)
        self.start_probs = np.array([1.0, 0.0])  # [other, alpha-helix]
        
        # Setup HMM model
        self.hmm_model = self._setup_hmm_model()
        
    def _setup_emission_probabilities(self) -> np.ndarray:
        """Setup emission probability matrix."""
        
        # Alpha-helix probabilities (state 1)
        alpha_helix_probs = [
            12, 6, 3, 5, 1, 9, 5, 4, 2, 7,  # A, R, N, D, C, E, Q, G, H, I
            12, 6, 3, 4, 2, 5, 4, 1, 3, 6   # L, K, M, F, P, S, T, W, Y, V
        ]
        
        # Other probabilities (state 0)
        other_probs = [
            6, 5, 5, 6, 2, 5, 3, 9, 3, 5,   # A, R, N, D, C, E, Q, G, H, I
            8, 6, 2, 4, 6, 7, 6, 1, 4, 7    # L, K, M, F, P, S, T, W, Y, V
        ]
        
        # Convert percentages to probabilities
        alpha_helix_probs = np.array(alpha_helix_probs) / 100.0
        other_probs = np.array(other_probs) / 100.0
        
        # Ensure probabilities sum to 1 (normalize if needed)
        alpha_helix_probs = alpha_helix_probs / np.sum(alpha_helix_probs)
        other_probs = other_probs / np.sum(other_probs)
        
        # Create emission matrix: [n_states, n_features]
        emission_matrix = np.array([other_probs, alpha_helix_probs])
        
        return emission_matrix
    
    def _setup_transition_probabilities(self) -> np.ndarray:
        """Setup transition probability matrix."""
        
        # Transition probabilities:
        # From "other" (state 0): 95% stay, 5% to alpha-helix
        # From "alpha-helix" (state 1): 10% to other, 90% stay
        
        transition_matrix = np.array([
            [0.95, 0.05],  # From "other" to ["other", "alpha-helix"]
            [0.10, 0.90]   # From "alpha-helix" to ["other", "alpha-helix"]
        ])
        
        return transition_matrix
    
    def _setup_hmm_model(self) -> hmm.CategoricalHMM:
        """Initialize and configure the HMM model."""
        
        model = hmm.CategoricalHMM(
            n_components=self.n_states,
            random_state=42,
            init_params="",  # Don't initialize parameters randomly
            params=""        # Don't update parameters during fitting
        )
        
        # Set the fixed probabilities
        model.startprob_ = self.start_probs
        model.transmat_ = self.transition_probs
        model.emissionprob_ = self.emission_probs
        
        return model
    
    def generate_sequence(self, length: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate a single amino acid sequence with corresponding states.
        
        Args:
            length: Length of the sequence to generate
            
        Returns:
            Tuple of (amino_acid_sequence, state_sequence)
        """
        # Generate sequence using HMM
        amino_acid_indices, state_sequence = self.hmm_model.sample(length)
        
        # Convert indices back to amino acid letters
        amino_acid_sequence = self.aa_encoder.inverse_transform(amino_acid_indices.flatten())
        
        return amino_acid_sequence, state_sequence.flatten()
    
    def get_state_probabilities(self, amino_acid_sequence: np.ndarray) -> np.ndarray:
        """
        Get state probabilities for a given amino acid sequence using Forward-Backward algorithm.
        
        Args:
            amino_acid_sequence: Array of amino acid characters
            
        Returns:
            Array of state probabilities [n_positions, n_states]
        """
        # Convert amino acids to indices
        aa_indices = self.aa_encoder.transform(amino_acid_sequence).reshape(-1, 1)
        
        # Use predict_proba to get state probabilities (Forward-Backward algorithm)
        state_probs = self.hmm_model.predict_proba(aa_indices)
        
        return state_probs
    
    def sample_batch(self, batch_size: int, min_length: int = 50, max_length: int = 300) -> Dict:
        """
        Generate a batch of sequences for training.
        
        Args:
            batch_size: Number of sequences to generate
            min_length: Minimum sequence length
            max_length: Maximum sequence length
            
        Returns:
            Dictionary with amino acid sequences and state probabilities
        """
        sequences = []
        state_probs_batch = []
        sequence_lengths = []
        
        for _ in range(batch_size):
            # Random sequence length
            length = np.random.randint(min_length, max_length + 1)
            
            # Generate sequence
            aa_seq, true_states = self.generate_sequence(length)
            
            # Get state probabilities using Forward-Backward
            state_probs = self.get_state_probabilities(aa_seq)
            
            sequences.append(aa_seq)
            state_probs_batch.append(state_probs)
            sequence_lengths.append(length)
        
        return {
            'amino_acid_sequences': sequences,
            'state_probabilities': state_probs_batch,
            'sequence_lengths': sequence_lengths
        }

In [3]:
class BayesFlowProteinSimulator(bf.simulators.Simulator):
    """
    BayesFlow-compatible simulator for protein secondary structure prediction.
    """
    
    def __init__(self, max_length: int = 300, min_length: int = 50):
        super().__init__()
        self.protein_hmm = ProteinHMMSimulator()
        self.max_length = max_length
        self.min_length = min_length
        
        # For padding sequences to fixed length
        self.pad_token = 20  # Use index 20 for padding (beyond 20 amino acids)
        
    def _pad_sequence(self, sequence: np.ndarray, target_length: int) -> np.ndarray:
        """Pad sequence to target length."""
        if len(sequence) >= target_length:
            return sequence[:target_length]
        else:
            # Pad with pad_token
            padded = np.full(target_length, self.pad_token)
            padded[:len(sequence)] = sequence
            return padded
    
    def _encode_amino_acids(self, aa_sequence: np.ndarray) -> np.ndarray:
        """Convert amino acid letters to indices."""
        return self.protein_hmm.aa_encoder.transform(aa_sequence)
    
    def sample(self, batch_size: int) -> Dict:
        """
        Generate a batch of protein sequences and their state probabilities.
        
        Args:
            batch_size: Number of sequences to generate
            
        Returns:
            Dictionary with encoded sequences and state probabilities
        """
        if isinstance(batch_size, tuple):
            batch_size = int(np.prod(batch_size))
            
        # Generate sequences
        batch_data = self.protein_hmm.sample_batch(
            batch_size=batch_size,
            min_length=self.min_length,
            max_length=self.max_length
        )
        
        # Process sequences for fixed-length input
        encoded_sequences = []
        state_probs_padded = []
        valid_lengths = []
        masks = []
        
        for i in range(batch_size):
            aa_seq = batch_data['amino_acid_sequences'][i]
            state_probs = batch_data['state_probabilities'][i]
            seq_len = len(aa_seq)
            
            # Encode amino acids
            encoded_aa = self._encode_amino_acids(aa_seq)
            
            # Pad sequences to max_length
            padded_sequence = self._pad_sequence(encoded_aa, self.max_length)
            
            # Pad state probabilities (pad with zeros)
            padded_state_probs = np.zeros((self.max_length, 2))
            padded_state_probs[:seq_len] = state_probs
            
            # Create mask: 1 for valid positions, 0 for padded positions
            mask = np.zeros(self.max_length)
            mask[:seq_len] = 1
            
            encoded_sequences.append(padded_sequence)
            state_probs_padded.append(padded_state_probs)
            valid_lengths.append(seq_len)
            masks.append(mask)
        
        return {
            'amino_acid_sequences': np.array(encoded_sequences),
            'state_probabilities': np.array(state_probs_padded),
            'sequence_lengths': np.array(valid_lengths),
            'alpha_helix_probs': np.array(state_probs_padded)[:, :, 1],  # Extract alpha-helix probabilities
            'mask': np.array(masks)
        }

In [10]:
simulator = BayesFlowProteinSimulator(max_length=1632, min_length=20)

def _flat(x):
    x = tf.cast(x, tf.float32)
    return tf.reshape(x, (tf.shape(x)[0], -1))

adapter = ContinuousApproximator.build_adapter(
    inference_variables={"state_probabilities" : MapTransform(_flat)},
    inference_conditions="mask",
    summary_variables="amino_acid_sequences",
)

dataset = OnlineDataset(
    simulator=simulator,
    adapter=adapter,
    batch_size=64,
    num_batches=1000
)

test_sample_data = dataset.simulator.sample(30)
print("Shapes of sample data items:")
for key, value in test_sample_data.items():
    print(f"{key}: {value.shape} (dtype: {value.dtype})")

Shapes of sample data items:
amino_acid_sequences: (30, 1632) (dtype: int64)
state_probabilities: (30, 1632, 2) (dtype: float64)
sequence_lengths: (30,) (dtype: int64)
alpha_helix_probs: (30, 1632) (dtype: float64)
mask: (30, 1632) (dtype: float64)


In [11]:
adapter

Adapter([0: ToArray -> 1: ConvertDType -> 2: Concatenate(['state_probabilities'] -> 'inference_variables') -> 3: Rename('mask' -> 'inference_conditions') -> 4: AsSet -> 5: Rename('amino_acid_sequences' -> 'summary_variables') -> 6: Keep(['inference_variables', 'inference_conditions', 'summary_variables', 'sample_weight'])])

In [12]:
from bayesflow.networks import TimeSeriesNetwork

# Summary network: LSTNet-like for sequences of length 1632 (single channel)
summary_network = TimeSeriesNetwork(
    summary_dim=32,           # embed into 32-dimensional summary space
    filters=[64, 64],         # two conv layers with 64 filters each
    kernel_sizes=[5, 3],      # first conv kernel=5, second=3
    strides=[1, 1],
    activation="relu",
    recurrent_type="lstm",    # capture dependencies via LSTM
    recurrent_dim=64,
    bidirectional=True,
    dropout=0.1,
    skip_steps=4
)

In [13]:
from bayesflow.networks import CouplingFlow, MLP

# Inference network: invertible flow for 2-dimensional state probs per position
inference_network = CouplingFlow(
    subnet=MLP(
        widths=[512, 512, 256],  # Increased capacity and correct output dim
        activation="mish",
        dropout=0.05
    ),
    depth=4,                          # fewer coupling layers
    subnet_kwargs={"widths": [64, 64], "activation": "relu"},
    transform="affine",
    permutation="random",
    use_actnorm=True,
    base_distribution="normal"
)

In [14]:
from bayesflow.workflows import BasicWorkflow

workflow = BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=inference_network,
    summary_network=summary_network,
)

In [15]:
history = workflow.fit_online(
    epochs=50,                  # total passes through the data
    num_batches_per_epoch=200,  # batches generated per epoch
    batch_size=64
)

INFO:bayesflow:Fitting on dataset instance of OnlineDataset.
INFO:bayesflow:Building on a test batch.
2025-07-12 17:23:36.174115: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: INVALID_ARGUMENT: ConcatOp : Ranks of all input tensors should match: shape[0] = [64,1632,1] vs. shape[1] = [64,1664]


InvalidArgumentError: {{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [64,1632,1] vs. shape[1] = [64,1664] [Op:ConcatV2] name: concat

In [None]:
# First, let's test the basic ProteinHMMSimulator
def test_basic_simulator():
    print("=" * 60)
    print("TESTING BASIC PROTEIN HMM SIMULATOR")
    print("=" * 60)
    
    # Initialize the simulator
    simulator = ProteinHMMSimulator()
    
    # Print basic information about the simulator
    print("Amino acids:", simulator.amino_acids)
    print("States:", simulator.states)
    print("Number of states:", simulator.n_states)
    print("Number of features (amino acids):", simulator.n_features)
    print()
    
    # Show emission probabilities
    print("Emission Probabilities:")
    print("Shape:", simulator.emission_probs.shape)
    emission_df = pd.DataFrame(
        simulator.emission_probs.T,
        index=simulator.amino_acids,
        columns=simulator.states
    )
    print(emission_df.round(4))
    print()
    
    # Show transition probabilities
    print("Transition Probabilities:")
    transition_df = pd.DataFrame(
        simulator.transition_probs,
        index=simulator.states,
        columns=simulator.states
    )
    print(transition_df)
    print()
    
    # Show start probabilities
    print("Start Probabilities:")
    start_df = pd.DataFrame(
        simulator.start_probs.reshape(1, -1),
        columns=simulator.states
    )
    print(start_df)
    print()

def test_sequence_generation():
    print("=" * 60)
    print("TESTING SEQUENCE GENERATION")
    print("=" * 60)
    
    simulator = ProteinHMMSimulator()
    
    # Generate a few sequences of different lengths
    lengths = [20, 50, 100]
    
    for length in lengths:
        print(f"\nGenerating sequence of length {length}:")
        aa_seq, state_seq = simulator.generate_sequence(length)
        
        print(f"Amino acid sequence: {''.join(aa_seq)}")
        print(f"State sequence: {state_seq}")
        print(f"State names: {[simulator.states[s] for s in state_seq]}")
        
        # Count states
        other_count = np.sum(state_seq == 0)
        alpha_count = np.sum(state_seq == 1)
        print(f"State counts - Other: {other_count}, Alpha-helix: {alpha_count}")
        print(f"Alpha-helix percentage: {alpha_count/length*100:.1f}%")
        print("-" * 40)

def test_state_probabilities():
    print("=" * 60)
    print("TESTING STATE PROBABILITY CALCULATION")
    print("=" * 60)
    
    simulator = ProteinHMMSimulator()
    
    # Generate a sequence
    aa_seq, true_states = simulator.generate_sequence(30)
    
    print(f"Generated sequence: {''.join(aa_seq)}")
    print(f"True states: {true_states}")
    print(f"True state names: {[simulator.states[s] for s in true_states]}")
    print()
    
    # Get state probabilities using Forward-Backward algorithm
    state_probs = simulator.get_state_probabilities(aa_seq)
    
    print("State Probabilities (Forward-Backward):")
    print("Shape:", state_probs.shape)
    
    # Create a detailed view
    prob_df = pd.DataFrame({
        'Position': range(len(aa_seq)),
        'AA': aa_seq,
        'True_State': [simulator.states[s] for s in true_states],
        'P(Other)': state_probs[:, 0],
        'P(Alpha-helix)': state_probs[:, 1],
        'Predicted_State': [simulator.states[np.argmax(probs)] for probs in state_probs]
    })
    
    print(prob_df.round(4))
    print()
    
    # Calculate accuracy
    predicted_states = np.argmax(state_probs, axis=1)
    accuracy = np.mean(predicted_states == true_states)
    print(f"Prediction accuracy: {accuracy:.2f}")

def test_batch_generation():
    print("=" * 60)
    print("TESTING BATCH GENERATION")
    print("=" * 60)
    
    simulator = ProteinHMMSimulator()
    
    # Generate a batch
    batch_size = 5
    batch_data = simulator.sample_batch(batch_size, min_length=20, max_length=50)
    
    print(f"Generated batch with {batch_size} sequences:")
    print(f"Keys in batch data: {list(batch_data.keys())}")
    print()
    
    for i in range(batch_size):
        aa_seq = batch_data['amino_acid_sequences'][i]
        state_probs = batch_data['state_probabilities'][i]
        seq_len = batch_data['sequence_lengths'][i]
        
        print(f"Sequence {i+1}:")
        print(f"  Length: {seq_len}")
        print(f"  Amino acids: {''.join(aa_seq)}")
        print(f"  State probabilities shape: {state_probs.shape}")
        print(f"  Average alpha-helix probability: {np.mean(state_probs[:, 1]):.3f}")
        print()

def test_bayesflow_simulator():
    print("=" * 60)
    print("TESTING BAYESFLOW SIMULATOR")
    print("=" * 60)
    
    # Initialize BayesFlow simulator
    bf_simulator = BayesFlowProteinSimulator(max_length=100, min_length=20)
    
    print("BayesFlow Simulator initialized:")
    print(f"  Max length: {bf_simulator.max_length}")
    print(f"  Min length: {bf_simulator.min_length}")
    print(f"  Pad token: {bf_simulator.pad_token}")
    print()
    
    # Generate a small batch
    batch_size = 3
    batch_data = bf_simulator.sample(batch_size)
    
    print(f"Generated BayesFlow batch with {batch_size} sequences:")
    print(f"Keys in batch data: {list(batch_data.keys())}")
    print()
    
    for key, value in batch_data.items():
        if isinstance(value, np.ndarray):
            print(f"  {key}: shape = {value.shape}, dtype = {value.dtype}")
        else:
            print(f"  {key}: {type(value)}")
    print()
    
    # Show details for first sequence
    print("Details for first sequence:")
    seq_idx = 0
    encoded_seq = batch_data['amino_acid_sequences'][seq_idx]
    state_probs = batch_data['state_probabilities'][seq_idx]
    seq_len = batch_data['sequence_lengths'][seq_idx]
    alpha_probs = batch_data['alpha_helix_probs'][seq_idx]
    mask = batch_data['mask'][seq_idx]
    
    print(f"  Sequence length: {seq_len}")
    print(f"  Encoded sequence (first 20): {encoded_seq[:20]}")
    print(f"  Valid positions (mask sum): {np.sum(mask)}")
    print(f"  State probabilities shape: {state_probs.shape}")
    print(f"  Alpha-helix probabilities shape: {alpha_probs.shape}")
    
    # Show valid part of sequence
    valid_encoded = encoded_seq[:seq_len]
    valid_alpha_probs = alpha_probs[:seq_len]
    
    # Convert back to amino acids (excluding padding)
    valid_amino_acids = []
    for idx in valid_encoded:
        if idx < len(bf_simulator.protein_hmm.amino_acids):
            valid_amino_acids.append(bf_simulator.protein_hmm.amino_acids[idx])
        else:
            valid_amino_acids.append('PAD')
    
    print(f"  Valid amino acids: {''.join(valid_amino_acids)}")
    print(f"  Valid alpha-helix probs (first 10): {valid_alpha_probs[:10].round(3)}")
    print(f"  Average alpha-helix probability: {np.mean(valid_alpha_probs):.3f}")

if __name__ == "__main__":
    # Run all tests
    test_basic_simulator()
    test_sequence_generation()
    test_state_probabilities()
    test_batch_generation()
    test_bayesflow_simulator()
    
    print("\n" + "=" * 60)
    print("ALL TESTS COMPLETED!")
    print("=" * 60)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def visualize_emission_probabilities():
    """Visualize the emission probabilities for both states."""
    simulator = ProteinHMMSimulator()
    
    # Create a heatmap of emission probabilities
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Emission probabilities as DataFrame
    emission_df = pd.DataFrame(
        simulator.emission_probs.T,
        index=simulator.amino_acids,
        columns=simulator.states
    )
    
    # Heatmap
    sns.heatmap(emission_df, annot=True, fmt='.3f', cmap='Blues', ax=ax1)
    ax1.set_title('Emission Probabilities by State')
    ax1.set_xlabel('State')
    ax1.set_ylabel('Amino Acid')
    
    # Bar plot comparison
    x = np.arange(len(simulator.amino_acids))
    width = 0.35
    
    ax2.bar(x - width/2, emission_df['other'], width, label='Other', alpha=0.7)
    ax2.bar(x + width/2, emission_df['alpha-helix'], width, label='Alpha-helix', alpha=0.7)
    
    ax2.set_xlabel('Amino Acid')
    ax2.set_ylabel('Emission Probability')
    ax2.set_title('Emission Probabilities Comparison')
    ax2.set_xticks(x)
    ax2.set_xticklabels(simulator.amino_acids)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

def visualize_sequence_example():
    """Visualize a generated sequence and its state probabilities."""
    simulator = ProteinHMMSimulator()
    
    # Generate a sequence
    aa_seq, true_states = simulator.generate_sequence(100)
    state_probs = simulator.get_state_probabilities(aa_seq)
    
    # Create visualization
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 10))
    
    # Plot 1: True states
    positions = np.arange(len(aa_seq))
    colors = ['lightblue' if s == 0 else 'lightcoral' for s in true_states]
    ax1.bar(positions, np.ones(len(positions)), color=colors, alpha=0.7)
    ax1.set_title('True Secondary Structure States')
    ax1.set_ylabel('State')
    ax1.set_yticks([0, 1])
    ax1.set_yticklabels(['', 'Other/Alpha-helix'])
    ax1.grid(True, alpha=0.3)
    
    # Add amino acid labels
    for i, aa in enumerate(aa_seq):
        ax1.text(i, 0.5, aa, ha='center', va='center', fontsize=8)
    
    # Plot 2: State probabilities
    ax2.plot(positions, state_probs[:, 0], label='P(Other)', color='blue', linewidth=2)
    ax2.plot(positions, state_probs[:, 1], label='P(Alpha-helix)', color='red', linewidth=2)
    ax2.fill_between(positions, state_probs[:, 0], alpha=0.3, color='blue')
    ax2.fill_between(positions, state_probs[:, 1], alpha=0.3, color='red')
    ax2.set_title('State Probabilities (Forward-Backward Algorithm)')
    ax2.set_ylabel('Probability')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 1)
    
    # Plot 3: Prediction accuracy
    predicted_states = np.argmax(state_probs, axis=1)
    correct = (predicted_states == true_states).astype(int)
    ax3.bar(positions, correct, color=['green' if c else 'red' for c in correct], alpha=0.7)
    ax3.set_title('Prediction Accuracy (Green=Correct, Red=Incorrect)')
    ax3.set_xlabel('Position')
    ax3.set_ylabel('Correct')
    ax3.grid(True, alpha=0.3)
    
    overall_accuracy = np.mean(correct)
    ax3.text(0.02, 0.95, f'Overall Accuracy: {overall_accuracy:.2%}', 
             transform=ax3.transAxes, fontsize=12, 
             bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))
    
    plt.tight_layout()
    plt.show()
    
    return aa_seq, true_states, state_probs, overall_accuracy

def visualize_batch_statistics():
    """Visualize statistics from a batch of sequences."""
    simulator = ProteinHMMSimulator()
    
    # Generate a larger batch
    batch_size = 50
    batch_data = simulator.sample_batch(batch_size, min_length=50, max_length=200)
    
    # Extract statistics
    sequence_lengths = batch_data['sequence_lengths']
    alpha_helix_percentages = []
    avg_alpha_probs = []
    
    for i in range(batch_size):
        state_probs = batch_data['state_probabilities'][i]
        alpha_probs = state_probs[:, 1]
        
        # Calculate percentage of positions with high alpha-helix probability
        alpha_helix_pct = np.mean(alpha_probs > 0.5) * 100
        alpha_helix_percentages.append(alpha_helix_pct)
        
        # Average alpha-helix probability
        avg_alpha_probs.append(np.mean(alpha_probs))
    
    # Create visualizations
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # Plot 1: Sequence lengths
    ax1.hist(sequence_lengths, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax1.set_title('Distribution of Sequence Lengths')
    ax1.set_xlabel('Sequence Length')
    ax1.set_ylabel('Frequency')
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Alpha-helix percentages
    ax2.hist(alpha_helix_percentages, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
    ax2.set_title('Distribution of Alpha-helix Percentages')
    ax2.set_xlabel('Alpha-helix Percentage (%)')
    ax2.set_ylabel('Frequency')
    ax2.grid(True, alpha=0.3)
    
    # Plot 3: Average alpha-helix probabilities
    ax3.hist(avg_alpha_probs, bins=20, alpha=0.7, color='lightgreen', edgecolor='black')
    ax3.set_title('Distribution of Average Alpha-helix Probabilities')
    ax3.set_xlabel('Average Alpha-helix Probability')
    ax3.set_ylabel('Frequency')
    ax3.grid(True, alpha=0.3)
    
    # Plot 4: Relationship between length and alpha-helix content
    ax4.scatter(sequence_lengths, alpha_helix_percentages, alpha=0.6, color='purple')
    ax4.set_title('Sequence Length vs Alpha-helix Content')
    ax4.set_xlabel('Sequence Length')
    ax4.set_ylabel('Alpha-helix Percentage (%)')
    ax4.grid(True, alpha=0.3)
    
    # Add correlation coefficient
    correlation = np.corrcoef(sequence_lengths, alpha_helix_percentages)[0, 1]
    ax4.text(0.02, 0.95, f'Correlation: {correlation:.3f}', 
             transform=ax4.transAxes, fontsize=12,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    print("BATCH STATISTICS SUMMARY:")
    print(f"Number of sequences: {batch_size}")
    print(f"Sequence lengths - Mean: {np.mean(sequence_lengths):.1f}, Std: {np.std(sequence_lengths):.1f}")
    print(f"Alpha-helix percentages - Mean: {np.mean(alpha_helix_percentages):.1f}%, Std: {np.std(alpha_helix_percentages):.1f}%")
    print(f"Average alpha-helix probabilities - Mean: {np.mean(avg_alpha_probs):.3f}, Std: {np.std(avg_alpha_probs):.3f}")
    print(f"Correlation (length vs alpha-helix): {correlation:.3f}")

def compare_amino_acid_distributions():
    """Compare amino acid distributions between different states."""
    simulator = ProteinHMMSimulator()
    
    # Generate a large batch to get good statistics
    batch_size = 100
    batch_data = simulator.sample_batch(batch_size, min_length=100, max_length=200)
    
    # Collect amino acids by predicted state
    other_amino_acids = []
    alpha_amino_acids = []
    
    for i in range(batch_size):
        aa_seq = batch_data['amino_acid_sequences'][i]
        state_probs = batch_data['state_probabilities'][i]
        
        # Classify positions based on highest probability
        predicted_states = np.argmax(state_probs, axis=1)
        
        for j, aa in enumerate(aa_seq):
            if predicted_states[j] == 0:
                other_amino_acids.append(aa)
            else:
                alpha_amino_acids.append(aa)
    
    # Count amino acids
    other_counts = pd.Series(other_amino_acids).value_counts()
    alpha_counts = pd.Series(alpha_amino_acids).value_counts()
    
    # Normalize to percentages
    other_pct = other_counts / len(other_amino_acids) * 100
    alpha_pct = alpha_counts / len(alpha_amino_acids) * 100
    
    # Create comparison plot
    fig, ax = plt.subplots(figsize=(15, 8))
    
    amino_acids = simulator.amino_acids
    x = np.arange(len(amino_acids))
    width = 0.35
    
    other_values = [other_pct.get(aa, 0) for aa in amino_acids]
    alpha_values = [alpha_pct.get(aa, 0) for aa in amino_acids]
    
    ax.bar(x - width/2, other_values, width, label='Other State', alpha=0.7, color='blue')
    ax.bar(x + width/2, alpha_values, width, label='Alpha-helix State', alpha=0.7, color='red')
    
    ax.set_xlabel('Amino Acid')
    ax.set_ylabel('Percentage (%)')
    ax.set_title('Amino Acid Distribution by Predicted State')
    ax.set_xticks(x)
    ax.set_xticklabels(amino_acids)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Add difference annotation
    for i, aa in enumerate(amino_acids):
        diff = alpha_values[i] - other_values[i]
        if abs(diff) > 0.5:  # Only show significant differences
            ax.annotate(f'{diff:+.1f}', xy=(i, max(other_values[i], alpha_values[i]) + 0.2),
                       ha='center', va='bottom', fontsize=8, 
                       color='green' if diff > 0 else 'red')
    
    plt.tight_layout()
    plt.show()
    
    print("AMINO ACID DISTRIBUTION ANALYSIS:")
    print(f"Total positions analyzed: {len(other_amino_acids) + len(alpha_amino_acids)}")
    print(f"Other state positions: {len(other_amino_acids)}")
    print(f"Alpha-helix state positions: {len(alpha_amino_acids)}")
    print("\nTop 5 amino acids in each state:")
    print("Other state:", other_pct.head().round(2).to_dict())
    print("Alpha-helix state:", alpha_pct.head().round(2).to_dict())

if __name__ == "__main__":
    print("Starting protein HMM data visualization...")
    
    # Run visualizations
    print("\n1. Visualizing emission probabilities...")
    visualize_emission_probabilities()
    
    print("\n2. Visualizing sequence example...")
    aa_seq, true_states, state_probs, accuracy = visualize_sequence_example()
    print(f"Example sequence accuracy: {accuracy:.2%}")
    
    print("\n3. Visualizing batch statistics...")
    visualize_batch_statistics()
    
    print("\n4. Comparing amino acid distributions...")
    compare_amino_acid_distributions()
    
    print("\nAll visualizations completed!")