# Protein Secondary Structure Inference with BayesFlow

_Authors: Bhanu Prasanna, Simulation-Based Inference Course Project_

## Introduction

This notebook demonstrates amortized Bayesian inference for protein secondary structure prediction using a two-state Hidden Markov Model (HMM) and BayesFlow. The goal is to train a neural network to predict state membership probabilities (alpha-helix vs. other) from amino acid sequences, essentially learning an amortized approximation to the Forward-Backward algorithm.

### Problem Setup

We use a two-state HMM where:
- **State 0 ("other")**: Beta-sheets and random coils
- **State 1 ("alpha-helix")**: Alpha-helix secondary structure

The HMM has fixed emission and transition probabilities based on empirical data from protein structure analysis. Given an amino acid sequence, we want to infer the probability that each position belongs to an alpha-helix or other structure.

### Approach

1. **Simulator**: Generate amino acid sequences using the HMM generative model
2. **Forward-Backward**: Compute true state probabilities for each sequence
3. **BayesFlow**: Train a neural network to map sequences → state probabilities
4. **Validation**: Compare predictions to known protein structures (e.g., human insulin)

## Import Required Libraries

In [2]:
import os
import sys

# Set the Keras backend before importing BayesFlow (adjust as needed).
# setdefault keeps a backend that was already chosen in the environment.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")  # or "jax", "torch"
print(f"Using '{os.environ['KERAS_BACKEND']}' backend")

# Add project root to path for imports
project_root = "/Users/bhanuprasanna/Documents/TU Dortmund/SS 25 - Simulation Based Interference/final"
if project_root not in sys.path:
    sys.path.append(project_root)

Using 'tensorflow' backend


In [5]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import logging
from typing import Dict, Tuple, Any

# BayesFlow import
import bayesflow as bf

# Project-specific imports
from src.bayesflow.simulator import ProteinSimulator, create_protein_simulator
from src.hmm.protein_hmm import ProteinHMM
from src.utils.visualization import plot_state_probabilities, plot_sequence_alignment

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure matplotlib
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✓ All imports successful!")
print(f"✓ BayesFlow version: {bf.__version__}")
print(f"✓ Using Keras backend: {os.environ.get('KERAS_BACKEND', 'default')}")

ImportError: cannot import name 'BayesFlowModel' from 'src.bayesflow.networks' (/Users/bhanuprasanna/Documents/TU Dortmund/SS 25 - Simulation Based Interference/final/src/bayesflow/networks.py)

## Simulator: Protein Secondary Structure HMM

Our simulator generates amino acid sequences using a two-state HMM and computes state probabilities using the Forward-Backward algorithm. This creates training pairs of (sequence, state_probabilities) for the neural network to learn from.
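
For reference, the Forward-Backward step that produces the training targets can be sketched in plain NumPy. This is a minimal illustration and not the project's `ProteinHMM` code: the function name `forward_backward`, the uniform initial distribution, and the randomly drawn emission matrix are assumptions made only for the example. The transition values are the ones listed under HMM Configuration.

```python
import numpy as np

def forward_backward(obs, init, trans, emit):
    """Posterior state marginals P(z_t = k | x_1..x_T) for a discrete HMM.

    obs:   int array [T] of observation indices
    init:  [K] initial state distribution
    trans: [K, K], trans[i, j] = P(z_{t+1} = j | z_t = i)
    emit:  [K, V], emit[k, v] = P(x_t = v | z_t = k)
    """
    T, K = len(obs), len(init)
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))

    # Forward pass, renormalised at every step to avoid numerical underflow
    alpha[0] = init * emit[:, obs[0]]
    alpha[0] /= alpha[0].sum()
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ trans) * emit[:, obs[t]]
        alpha[t] /= alpha[t].sum()

    # Backward pass with the same per-step rescaling
    for t in range(T - 2, -1, -1):
        beta[t] = trans @ (emit[:, obs[t + 1]] * beta[t + 1])
        beta[t] /= beta[t].sum()

    # Per-step scale factors cancel once each row is normalised
    gamma = alpha * beta
    return gamma / gamma.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
trans = np.array([[0.95, 0.05],    # other -> other, other -> alpha-helix
                  [0.10, 0.90]])   # alpha-helix -> other, alpha-helix -> alpha-helix
init = np.array([0.5, 0.5])                 # assumed, not taken from the project
emit = rng.dirichlet(np.ones(20), size=2)   # placeholder emissions, not empirical
obs = rng.integers(0, 20, size=30)          # toy sequence of amino acid indices

gamma = forward_backward(obs, init, trans, emit)
print(gamma.shape)  # (30, 2): columns are P(other), P(alpha-helix)
```

Each row of `gamma` is a probability vector over the two states, which matches the `[length, 2]` layout of the state probabilities inspected later in the notebook.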

### HMM Configuration

The HMM uses empirically-derived emission and transition probabilities:

**Transition Probabilities:**
- From "other" → "alpha-helix": 5%
- From "other" → "other": 95%  
- From "alpha-helix" → "alpha-helix": 90%
- From "alpha-helix" → "other": 10%

**Emission Probabilities:** Different probability distributions over 20 amino acids for each state, based on structural analysis of known proteins.
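
As a rough sketch of the generative side, the snippet below samples a hidden state path and an amino acid sequence from a two-state HMM with the transition probabilities listed above. The emission matrix is a random placeholder and the initial distribution is assumed uniform, since neither is given in this notebook; `sample_hmm` is a hypothetical helper, not part of `ProteinSimulator`.

```python
import numpy as np

# Rows = current state, columns = next state (0 = "other", 1 = "alpha-helix")
TRANS = np.array([[0.95, 0.05],
                  [0.10, 0.90]])

def sample_hmm(length, trans, emit, init, rng):
    """Draw one hidden state path and one observation sequence."""
    states = np.empty(length, dtype=np.int64)
    obs = np.empty(length, dtype=np.int64)
    states[0] = rng.choice(len(init), p=init)
    obs[0] = rng.choice(emit.shape[1], p=emit[states[0]])
    for t in range(1, length):
        # Next state depends only on the previous state (Markov property)
        states[t] = rng.choice(trans.shape[0], p=trans[states[t - 1]])
        # Observation depends only on the current state
        obs[t] = rng.choice(emit.shape[1], p=emit[states[t]])
    return states, obs

rng = np.random.default_rng(42)
EMIT = rng.dirichlet(np.ones(20), size=2)  # placeholder, not the empirical table
INIT = np.array([0.5, 0.5])                # assumed initial distribution

states, obs = sample_hmm(100, TRANS, EMIT, INIT, rng)

# Long-run state frequencies implied by TRANS alone
stationary = np.array([2 / 3, 1 / 3])
print(np.allclose(stationary @ TRANS, stationary))
```

Two consequences follow from the transition matrix itself: in the long run the chain spends about two thirds of its time in "other" and one third in "alpha-helix", and a helix segment lasts 1 / 0.10 = 10 residues on average.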

In [None]:
# Create the protein simulator
simulator = create_protein_simulator(
    min_length=50,
    max_length=150,
    fixed_length=False  # Variable length sequences for more realistic training
)

print("Simulator Configuration:")
info = simulator.get_simulator_info()
for key, value in info.items():
    print(f"  {key}: {value}")

print(f"\nAmino acids: {info['amino_acids']}")
print(f"States: {info['states']}")

In [None]:
# Generate a batch of simulations to see the output format
batch_size = 4
sim_batch = simulator(batch_size=batch_size)

print("Simulator Output Structure:")
print(f"Keys: {list(sim_batch.keys())}")
print()

print("Summary Conditions (sequences and metadata):")
summary_cond = sim_batch['summary_conditions']
for key, value in summary_cond.items():
    print(f"  {key}: shape {value.shape}, dtype {value.dtype}")

print(f"\nParameters (state probabilities):")
params = sim_batch['parameters']
print(f"  state_probs: shape {params.shape}, dtype {params.dtype}")

print(f"\nExample sequence lengths: {summary_cond['lengths']}")
print(f"Max sequence length in batch: {summary_cond['sequences'].shape[1]}")

In [None]:
# Examine a specific sequence from the batch
seq_idx = 0
seq_length = summary_cond['lengths'][seq_idx]
sequence_indices = summary_cond['sequences'][seq_idx, :seq_length]
state_probs = sim_batch['parameters'][seq_idx, :seq_length, :]
mask = summary_cond['masks'][seq_idx, :seq_length]

print(f"Example Sequence {seq_idx + 1} (length: {seq_length}):")
print(f"Sequence indices: {sequence_indices}")
print(f"Mask: {mask}")
print(f"State probabilities shape: {state_probs.shape}")
print(f"State probs (first 10 positions):")
print(f"  P(other): {state_probs[:10, 0]}")
print(f"  P(alpha): {state_probs[:10, 1]}")

# Convert indices back to amino acids for display
amino_acids = info['amino_acids']
sequence_str = ''.join([amino_acids[i] for i in sequence_indices])
print(f"\nAmino acid sequence: {sequence_str}")

# Quick visualization of state probabilities
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6))

# Plot state probabilities
positions = np.arange(seq_length)
ax1.plot(positions, state_probs[:, 0], 'b-', label='P(other)', linewidth=2)
ax1.plot(positions, state_probs[:, 1], 'r-', label='P(alpha-helix)', linewidth=2)
ax1.fill_between(positions, 0, state_probs[:, 1], alpha=0.3, color='red')
ax1.set_ylabel('State Probability')
ax1.set_title(f'State Probabilities for Sequence {seq_idx + 1}')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot amino acid sequence as bars colored by most likely state
most_likely_state = np.argmax(state_probs, axis=1)
colors = ['blue' if s == 0 else 'red' for s in most_likely_state]
ax2.bar(positions, np.ones(seq_length), color=colors, alpha=0.7)
ax2.set_ylabel('Amino Acid')
ax2.set_xlabel('Position')
ax2.set_title('Most Likely State Assignment (Blue=Other, Red=Alpha-helix)')
ax2.set_yticks([])

plt.tight_layout()
plt.show()

## Adapter: Data Preparation for BayesFlow

The adapter transforms simulator outputs into the format expected by BayesFlow networks. In this notebook it does two things:
- Converts integer and float64 arrays to float32 for deep learning
- Renames `parameters` to `inference_variables` to match BayesFlow conventions

Padding and masking of variable-length sequences are not done here; the simulator already returns `sequences`, `masks`, and `lengths`.

BayesFlow expects the following keys:
- `summary_conditions`: Data to be processed by summary networks (sequences, masks, metadata)
- `inference_variables`: Target variables to be predicted (state probabilities)
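
To make the `sequences` / `masks` / `lengths` layout concrete, here is a small NumPy sketch of how variable-length index sequences can be padded into one array with a matching mask. `pad_and_mask` is a hypothetical helper written for illustration; how `ProteinSimulator` pads internally is not shown in this notebook.

```python
import numpy as np

def pad_and_mask(sequences, pad_value=0):
    """Right-pad integer sequences to a common length and build a 0/1 mask."""
    lengths = np.array([len(s) for s in sequences], dtype=np.int32)
    max_len = int(lengths.max())
    padded = np.full((len(sequences), max_len), pad_value, dtype=np.int32)
    for i, s in enumerate(sequences):
        padded[i, :len(s)] = s
    # mask[i, t] is 1.0 where position t holds a real residue, 0.0 where padded
    masks = (np.arange(max_len)[None, :] < lengths[:, None]).astype(np.float32)
    return padded, masks, lengths

seqs = [np.array([3, 7, 1]), np.array([5, 2, 9, 9, 0])]
padded, masks, lengths = pad_and_mask(seqs)
print(padded)
print(masks)
```

Because the pad value 0 is also a valid amino acid index in this toy example, the mask is the only reliable way to tell real residues from padding.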

In [None]:
# Create the adapter for preprocessing simulator outputs
adapter = (
    bf.adapters.Adapter()
    # Convert data types to deep learning friendly formats
    .convert_dtype("int32", "float32")  # sequences as float32
    .convert_dtype("float64", "float32")  # state probabilities as float32
    
    # Rename the target (state probabilities) to match BayesFlow conventions.
    # "summary_conditions" (sequences, masks, lengths) keeps its name, so it
    # needs no rename step.
    .rename("parameters", "inference_variables")
)

print("Adapter configuration:")
print(adapter)

# Test the adapter on our sample batch
adapted_batch = adapter(sim_batch)

print("\nAfter adaptation:")
print(f"Keys: {list(adapted_batch.keys())}")
print()

for key, value in adapted_batch.items():
    if isinstance(value, dict):
        print(f"{key}:")
        for subkey, subvalue in value.items():
            print(f"  {subkey}: shape {subvalue.shape}, dtype {subvalue.dtype}")
    else:
        print(f"{key}: shape {value.shape}, dtype {value.dtype}")

print(f"\nData ready for BayesFlow training!")

## Generate Training and Validation Data

For this example, we'll generate offline training and validation datasets. In practice, BayesFlow also supports online training where data is generated on-the-fly during training.
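
One practical point for offline generation: if each simulated batch is padded only up to its own longest sequence (an assumption about the simulator, not something confirmed here), batches will have different widths and cannot be concatenated directly. The sketch below pads arrays along the sequence axis to a shared length first. `pad_to_length` and `concat_padded` are hypothetical helpers.

```python
import numpy as np

def pad_to_length(array, target_len, axis=1, pad_value=0):
    """Right-pad `array` with `pad_value` along `axis` up to `target_len`."""
    pad_amount = target_len - array.shape[axis]
    if pad_amount < 0:
        raise ValueError("array is longer than target_len")
    pad_width = [(0, 0)] * array.ndim
    pad_width[axis] = (0, pad_amount)
    return np.pad(array, pad_width, mode="constant", constant_values=pad_value)

def concat_padded(arrays, axis=1):
    """Pad all arrays to the longest size along `axis`, then stack on axis 0."""
    target = max(a.shape[axis] for a in arrays)
    return np.concatenate([pad_to_length(a, target, axis) for a in arrays], axis=0)

# Two toy batches of state probabilities with different padded lengths
batch_a = np.ones((4, 60, 2))
batch_b = np.ones((3, 75, 2))
stacked = concat_padded([batch_a, batch_b])
print(stacked.shape)  # (7, 75, 2)
```

If the simulator always pads to `max_length`, this step is unnecessary and a plain `np.concatenate` is enough.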

In [None]:
# Configuration for data generation
num_training_samples = 5000  # Moderate size for demo - increase for better performance
num_validation_samples = 500
batch_size = 64
epochs = 30  # Moderate training for demo

print("Generating training data...")
# Generate training data
training_batches = []
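total_generated = 0
next_report = 1000  # print progress each time this many samples is crossed
while total_generated < num_training_samples:
    current_batch_size = min(batch_size, num_training_samples - total_generated)
    batch = simulator(batch_size=current_batch_size)
    adapted_batch = adapter(batch)
    training_batches.append(adapted_batch)
    total_generated += current_batch_size
    # batch_size (64) does not divide 1000, so test for crossing a threshold
    # instead of an exact multiple
    if total_generated >= next_report:
        print(f"  Generated {total_generated}/{num_training_samples} samples")
        next_report += 1000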

# Combine batches into single training dataset
def combine_batches(batches):
    """Combine list of batches into single dataset"""
    combined = {}
    for key in batches[0].keys():
        if isinstance(batches[0][key], dict):
            combined[key] = {}
            for subkey in batches[0][key].keys():
                combined[key][subkey] = np.concatenate([b[key][subkey] for b in batches], axis=0)
        else:
            combined[key] = np.concatenate([b[key] for b in batches], axis=0)
    return combined

training_data = combine_batches(training_batches)

print("Generating validation data...")
validation_data = adapter(simulator(batch_size=num_validation_samples))

print(f"✓ Training data: {training_data['inference_variables'].shape[0]} samples")
print(f"✓ Validation data: {validation_data['inference_variables'].shape[0]} samples")

# Quick statistics on sequence lengths
train_lengths = training_data['summary_conditions']['lengths']
val_lengths = validation_data['summary_conditions']['lengths']

print(f"\nSequence length statistics:")
print(f"Training - min: {train_lengths.min()}, max: {train_lengths.max()}, mean: {train_lengths.mean():.1f}")
print(f"Validation - min: {val_lengths.min()}, max: {val_lengths.max()}, mean: {val_lengths.mean():.1f}")

## Define and Configure Neural Network Approximator

We'll set up a BayesFlow architecture consisting of:
1. **Summary Network**: Processes variable-length amino acid sequences into fixed-size representations
2. **Inference Network**: Maps sequence representations to state probability predictions

For this protein sequence task, we'll use a Flow Matching network as the backbone, which can handle the complex multimodal nature of protein structure prediction.
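
The core idea behind flow matching can be shown in a few lines of NumPy. This sketch uses straight-line paths between Gaussian noise and the target and is meant only to illustrate the training objective; it is not BayesFlow's `FlowMatching` implementation, which may differ in details such as how noise and targets are paired. The function names are hypothetical.

```python
import numpy as np

def flow_matching_pair(x1, rng):
    """Build inputs and regression targets for one flow matching batch.

    x1: [batch, dim] array of target samples (e.g. flattened state probabilities)
    """
    x0 = rng.standard_normal(x1.shape)        # base noise sample
    t = rng.uniform(size=(x1.shape[0], 1))    # one time point per example
    xt = (1.0 - t) * x0 + t * x1              # point on the straight path x0 -> x1
    v_target = x1 - x0                        # velocity of that path (constant in t)
    return xt, t, v_target

def flow_matching_loss(v_pred, v_target):
    """Mean squared error between predicted and target velocities."""
    return np.mean(np.sum((v_pred - v_target) ** 2, axis=-1))

rng = np.random.default_rng(0)
x1 = rng.uniform(size=(8, 4))                 # toy targets
xt, t, v_target = flow_matching_pair(x1, rng)

# A perfect velocity model would have zero loss
print(flow_matching_loss(v_target, v_target))  # 0.0
```

In training, a network receives `xt`, `t`, and the summary of the sequence, and is fit to predict `v_target`. Sampling then integrates the learned velocity field from noise at `t = 0` to a draw at `t = 1`.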

In [None]:
# Custom summary network for protein sequences
import keras  # BayesFlow 2.x builds on Keras 3
import tensorflow as tf  # tensor ops below assume the "tensorflow" backend

class ProteinSummaryNetwork(bf.networks.SummaryNetwork):
    """
    Summary network for processing variable-length amino acid sequences.
    
    Architecture:
    1. Embedding layer for amino acid indices
    2. LSTM for sequence processing  
    3. Attention mechanism for important positions
    4. Dense layers for final representation
    """
    
    def __init__(self, 
                 vocab_size=20,
                 embedding_dim=32,
                 lstm_units=64,
                 attention_units=32,
                 summary_dim=64,
                 **kwargs):
        super().__init__(**kwargs)
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.attention_units = attention_units
        self.summary_dim = summary_dim
        
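        # Embedding for amino acid indices.
        # mask_zero stays False: index 0 appears to be a real amino acid here
        # (vocab_size=20, indices 0-19), and padding is tracked through the
        # explicit `masks` input instead.
        self.embedding = keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
            mask_zero=False,
            name="amino_acid_embedding"
        )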
        
        # Bidirectional LSTM for sequence processing
        self.lstm = keras.layers.Bidirectional(
            keras.layers.LSTM(lstm_units, return_sequences=True),
            name="sequence_lstm"
        )
        
        # Attention mechanism
        self.attention = keras.layers.MultiHeadAttention(
            num_heads=4,
            key_dim=attention_units,
            name="sequence_attention"
        )
        
        # Global pooling and final layers
        self.global_pool = keras.layers.GlobalAveragePooling1D()
        self.dense1 = keras.layers.Dense(summary_dim, activation='relu')
        self.dropout = keras.layers.Dropout(0.1)
        self.dense2 = keras.layers.Dense(summary_dim, activation='relu')
        
    def call(self, summary_conditions, **kwargs):
        """
        Process protein sequences into fixed-size summaries.
        
        Args:
            summary_conditions: Dict containing 'sequences', 'masks', 'lengths'
        """
        sequences = summary_conditions['sequences']  # [batch, max_length]
        masks = summary_conditions['masks']          # [batch, max_length]
        
        # Embed amino acid sequences
        embedded = self.embedding(sequences)  # [batch, max_length, embedding_dim]
        
        # Apply sequence mask for padding
        mask_expanded = tf.expand_dims(masks, -1)
        embedded = embedded * mask_expanded
        
        # Process with LSTM
        lstm_out = self.lstm(embedded)  # [batch, max_length, 2*lstm_units]
        
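        # Apply self-attention. Keras expects attention_mask with shape
        # [batch, query_len, key_len], so expand the per-position mask.
        valid = tf.cast(masks, tf.bool)  # [batch, max_length]
        attention_mask = tf.logical_and(valid[:, :, None], valid[:, None, :])
        attended = self.attention(lstm_out, lstm_out, attention_mask=attention_mask)
        
        # Masked average pooling to get a fixed-size representation
        pooled = self.global_pool(attended, mask=valid)  # [batch, 2*lstm_units]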
        
        # Final dense layers
        x = self.dense1(pooled)
        x = self.dropout(x, training=kwargs.get("stage") == "training")
        summary = self.dense2(x)  # [batch, summary_dim]
        
        return summary

# Create the summary network
summary_network = ProteinSummaryNetwork(
    vocab_size=20,  # 20 amino acids
    embedding_dim=32,
    lstm_units=64,
    attention_units=32,
    summary_dim=64
)

print("✓ Summary network created")
print(f"  Architecture: Embedding({20}) → LSTM({64}) → Attention → Dense({64})")
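
When pooling over the sequence axis, padded positions should not contribute to the average. A masked mean can be sketched in NumPy as follows; `masked_mean` is a hypothetical helper shown only to clarify the pooling step, independent of Keras.

```python
import numpy as np

def masked_mean(features, mask):
    """Average `features` [batch, time, dim] over valid time steps only.

    mask: [batch, time] with 1 for real positions and 0 for padding.
    """
    mask = mask[..., None].astype(features.dtype)   # [batch, time, 1]
    total = (features * mask).sum(axis=1)           # sum over valid positions
    count = np.maximum(mask.sum(axis=1), 1.0)       # guard against empty rows
    return total / count

rng = np.random.default_rng(0)
features = rng.standard_normal((2, 4, 3))
mask = np.array([[1, 1, 0, 0],
                 [1, 1, 1, 1]], dtype=np.float32)

pooled = masked_mean(features, mask)
print(pooled.shape)  # (2, 3)
```

For the first example only the first two time steps enter the mean, so its summary does not change with how much padding the batch happens to contain.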

In [None]:
# Create the inference network
# Flow Matching is well-suited for protein structure prediction due to its ability
# to handle complex, multimodal posterior distributions
inference_network = bf.networks.FlowMatching(
    subnet="mlp",
    subnet_kwargs={
        "dropout": 0.1,
        "widths": (128, 128, 64)  # Network architecture for flow matching
    }
)

print("✓ Inference network created")
print(f"  Type: Flow Matching")
print(f"  Architecture: MLP with layers {(128, 128, 64)}")

# Test the networks with a small batch to ensure compatibility
print("\nTesting network compatibility...")
test_batch = adapter(simulator(batch_size=2))

# Test summary network
try:
    summary_output = summary_network(test_batch['summary_conditions'])
    print(f"✓ Summary network test: input {test_batch['summary_conditions']['sequences'].shape} → output {summary_output.shape}")
except Exception as e:
    print(f"✗ Summary network test failed: {e}")

# Get expected shapes for inference network
max_seq_len = test_batch['summary_conditions']['sequences'].shape[1]
state_prob_dim = test_batch['inference_variables'].shape[-1]  # Should be 2 (binary states)

print(f"  Expected inference input dim: {state_prob_dim * max_seq_len}")
print(f"  State probability dimensionality: {state_prob_dim}")
print(f"  Maximum sequence length: {max_seq_len}")

## Workflow: Training the Posterior Approximator

Now we'll combine all components into a BayesFlow workflow and train the network to learn the mapping from amino acid sequences to state probabilities.

In [None]:
# Create the BayesFlow workflow
workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=inference_network,
    summary_network=summary_network,
    # Standardize the state probabilities for better training stability
    standardize=["inference_variables"]
)

print("✓ Workflow created successfully")
print("  Components:")
print(f"    - Simulator: {type(simulator).__name__}")
print(f"    - Adapter: BayesFlow Adapter")
print(f"    - Summary Network: {type(summary_network).__name__}")
print(f"    - Inference Network: {type(inference_network).__name__}")

# Train the workflow
print(f"\nStarting training...")
print(f"  Training samples: {training_data['inference_variables'].shape[0]}")
print(f"  Validation samples: {validation_data['inference_variables'].shape[0]}")
print(f"  Epochs: {epochs}")
print(f"  Batch size: {batch_size}")

history = workflow.fit_offline(
    data=training_data,
    validation_data=validation_data,
    epochs=epochs,
    batch_size=batch_size,
    verbose=1
)

print("✓ Training completed!")

In [None]:
# Plot training history
f = bf.diagnostics.plots.loss(history, figsize=(12, 4))
plt.suptitle("Training Progress: Protein HMM Posterior Approximation")
plt.tight_layout()
plt.show()

# Extract some training metrics
final_train_loss = history.history['loss'][-1]
final_val_loss = history.history['val_loss'][-1]
best_val_loss = min(history.history['val_loss'])

print(f"\nTraining Summary:")
print(f"  Final training loss: {final_train_loss:.4f}")
print(f"  Final validation loss: {final_val_loss:.4f}")
print(f"  Best validation loss: {best_val_loss:.4f}")
print(f"  Training epochs: {len(history.history['loss'])}")

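# Rough heuristic only: a validation loss far above the training loss hints at
# overfitting, but passing this check does not by itself show a good fit.
if final_val_loss < final_train_loss * 2:
    print("✓ No severe overfitting by the 2x validation/training loss heuristic")
else:
    print("⚠ Possible overfitting detected - consider more regularization or data")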

## Validation: Posterior Diagnostics and Calibration

We'll validate our trained model using simulation-based calibration (SBC) and other diagnostics to ensure the posterior approximation is accurate and well-calibrated.
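
The rank statistic at the heart of SBC is simple to compute by hand. The sketch below uses a toy conjugate Gaussian model whose exact posterior is known, so the ranks should be close to uniform; it stands in for the protein model purely for illustration, and `sbc_ranks` is a hypothetical helper rather than a BayesFlow function.

```python
import numpy as np

def sbc_ranks(true_values, posterior_samples):
    """Rank of each true value among its posterior draws.

    true_values: [N], posterior_samples: [N, S]; ranks lie in 0..S.
    """
    return np.sum(posterior_samples < true_values[:, None], axis=1)

# Toy model: theta ~ N(0, 1), x | theta ~ N(theta, 1)
# Exact posterior: theta | x ~ N(x / 2, 1 / 2)
rng = np.random.default_rng(0)
N, S = 2000, 99
theta = rng.standard_normal(N)
x = theta + rng.standard_normal(N)
post = x[:, None] / 2 + np.sqrt(0.5) * rng.standard_normal((N, S))

ranks = sbc_ranks(theta, post)
hist = np.bincount(ranks, minlength=S + 1)
print(ranks.mean())  # close to S / 2 for a calibrated posterior
```

A histogram of `ranks` that is roughly flat indicates calibration. A U-shape suggests posteriors that are too narrow, and a hump in the middle suggests posteriors that are too wide.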

In [None]:
# Generate test data for validation
num_test_samples = 300
num_posterior_samples = 1000

print("Generating test data for validation...")
test_data = adapter(simulator(batch_size=num_test_samples))

print("Sampling from approximate posterior...")
posterior_samples = workflow.sample(
    conditions=test_data, 
    num_samples=num_posterior_samples
)

print(f"✓ Generated {num_posterior_samples} posterior samples for {num_test_samples} test cases")
print(f"Posterior samples shape: {posterior_samples['inference_variables'].shape}")

# Run automated diagnostics
print("\nComputing diagnostic metrics...")
metrics = workflow.compute_default_diagnostics(test_data=num_test_samples)

print("Diagnostic Results:")
for metric_name, metric_value in metrics.items():
    if isinstance(metric_value, (int, float)):
        print(f"  {metric_name}: {metric_value:.4f}")
    else:
        print(f"  {metric_name}: {metric_value}")

# Generate diagnostic plots
print("\nGenerating diagnostic plots...")
diagnostic_figures = workflow.plot_default_diagnostics(
    test_data=num_test_samples,
    calibration_ecdf_kwargs={"difference": True, "figsize": (15, 3)},
    recovery_kwargs={"figsize": (15, 3)},
    z_score_contraction_kwargs={"figsize": (15, 3)}
)

plt.show()

In [None]:
# Manual validation: Compare predictions to ground truth for specific examples
print("Manual Validation: Detailed comparison for selected sequences")

# Select a few test cases for detailed analysis
test_indices = [0, 1, 2]
for idx in test_indices:
    print(f"\n--- Test Sequence {idx + 1} ---")
    
    # Get true state probabilities
    true_probs = test_data['inference_variables'][idx]
    seq_length = int(test_data['summary_conditions']['lengths'][idx])
    
    # Get posterior samples for this sequence
    predicted_probs = posterior_samples['inference_variables'][idx]  # [num_samples, max_len, 2]
    
    # Compute posterior statistics
    mean_probs = np.mean(predicted_probs, axis=0)[:seq_length]  # [seq_length, 2]
    std_probs = np.std(predicted_probs, axis=0)[:seq_length]   # [seq_length, 2]
    
    # Truncate true probabilities to actual sequence length
    true_probs_truncated = true_probs[:seq_length]
    
    # Compute metrics
    mse = np.mean((mean_probs - true_probs_truncated) ** 2)
    mae = np.mean(np.abs(mean_probs - true_probs_truncated))
    
    print(f"  Sequence length: {seq_length}")
    print(f"  MSE: {mse:.4f}")
    print(f"  MAE: {mae:.4f}")
    
    # Plot comparison for this sequence
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))
    positions = np.arange(seq_length)
    
    # Plot state probabilities - Other state
    axes[0].plot(positions, true_probs_truncated[:, 0], 'b-', linewidth=3, label='True P(other)')
    axes[0].plot(positions, mean_probs[:, 0], 'r--', linewidth=2, label='Predicted P(other)')
    axes[0].fill_between(positions, 
                        mean_probs[:, 0] - 2*std_probs[:, 0],
                        mean_probs[:, 0] + 2*std_probs[:, 0],
                        alpha=0.2, color='red', label='Mean ± 2 SD')
    axes[0].set_ylabel('P(other)')
    axes[0].set_title(f'Test Sequence {idx + 1}: Other State Probabilities')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Plot state probabilities - Alpha-helix state  
    axes[1].plot(positions, true_probs_truncated[:, 1], 'b-', linewidth=3, label='True P(alpha-helix)')
    axes[1].plot(positions, mean_probs[:, 1], 'r--', linewidth=2, label='Predicted P(alpha-helix)')
    axes[1].fill_between(positions,
                        mean_probs[:, 1] - 2*std_probs[:, 1], 
                        mean_probs[:, 1] + 2*std_probs[:, 1],
                        alpha=0.2, color='red', label='Mean ± 2 SD')
    axes[1].set_ylabel('P(alpha-helix)')
    axes[1].set_xlabel('Sequence Position')
    axes[1].set_title(f'Test Sequence {idx + 1}: Alpha-helix State Probabilities')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

print("\n✓ Manual validation completed")
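
The shaded bands in the plots above are drawn as mean ± 2 standard deviations, which can extend below 0 or above 1 for probabilities near the boundaries. A percentile interval computed from the posterior draws always stays inside the range of the samples. The sketch below uses synthetic Beta-distributed draws as a stand-in for real posterior samples, and `percentile_band` is a hypothetical helper.

```python
import numpy as np

def percentile_band(samples, level=0.95):
    """Central credible band from draws of shape [num_samples, positions]."""
    tail = (1.0 - level) / 2.0
    lower, upper = np.quantile(samples, [tail, 1.0 - tail], axis=0)
    return lower, upper

rng = np.random.default_rng(1)
samples = rng.beta(0.5, 5.0, size=(1000, 20))  # synthetic draws, skewed toward 0

lower, upper = percentile_band(samples)
mean, std = samples.mean(axis=0), samples.std(axis=0)

print((mean - 2 * std).min())  # negative: the +/- 2 SD band leaves [0, 1]
print(lower.min())             # non-negative: the percentile band does not
```

Passing `lower` and `upper` to `fill_between` in place of the `mean ± 2*std` expressions gives a band that respects the [0, 1] range.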