In [1]:
# -*- coding: utf-8 -*-
import os
import warnings
from typing import Dict, List, Tuple, Optional

# warnings.filterwarnings('ignore')
os.environ['KERAS_BACKEND'] = 'tensorflow'

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import keras
import tensorflow as tf

import bayesflow as bf

from hmmlearn import hmm
from hmmlearn.hmm import CategoricalHMM

from sklearn.preprocessing import LabelEncoder

current_backend = tf.keras.backend.backend()
print(f"tf.keras is using the '{current_backend}' backend.")

2025-07-13 12:48:19.733365: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-07-13 12:48:19.733402: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-07-13 12:48:19.733408: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
I0000 00:00:1752403699.733424 6220395 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1752403699.733443 6220395 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
INFO:bayesflow:Using backend 'tensorflow'


tf.keras is using the 'tensorflow' backend.


In [2]:
# HMM PARAMETERS FROM TASK DESCRIPTION

# 20 amino acids in standard order
AMINO_ACIDS = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I', 
               'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

# Emission probabilities from task tables
# Alpha-helix state (state 0)
EMISSION_ALPHA = [0.12, 0.06, 0.03, 0.05, 0.01, 0.09, 0.05, 0.04, 0.02, 0.07,
                  0.12, 0.06, 0.03, 0.04, 0.02, 0.05, 0.04, 0.01, 0.03, 0.06]

# Other state (state 1) 
EMISSION_OTHER = [0.06, 0.05, 0.05, 0.06, 0.02, 0.05, 0.03, 0.09, 0.03, 0.05,
                  0.08, 0.06, 0.02, 0.04, 0.06, 0.07, 0.06, 0.01, 0.04, 0.07]

# Transition probabilities from task description
# [alpha->alpha, alpha->other]
TRANS_FROM_ALPHA = [0.90, 0.10]
# [other->alpha, other->other]  
TRANS_FROM_OTHER = [0.05, 0.95]

# Initial state probabilities (always starts in "other" state)
INITIAL_PROBS = [0.0, 1.0]  # [alpha-helix, other]

# Validation
print("PARAMETER VALIDATION:")
print(f"Amino acids: {len(AMINO_ACIDS)} types")
print(f"Alpha emission sum: {sum(EMISSION_ALPHA):.3f}")
print(f"Other emission sum: {sum(EMISSION_OTHER):.3f}")
print(f"Alpha transitions sum: {sum(TRANS_FROM_ALPHA):.3f}")
print(f"Other transitions sum: {sum(TRANS_FROM_OTHER):.3f}")
print(f"Initial probs sum: {sum(INITIAL_PROBS):.3f}")
print("\n✓ All probabilities are valid!")

PARAMETER VALIDATION:
Amino acids: 20 types
Alpha emission sum: 1.000
Other emission sum: 1.000
Alpha transitions sum: 1.000
Other transitions sum: 1.000
Initial probs sum: 1.000

✓ All probabilities are valid!


In [3]:
# FIXED HMM MODEL CREATION

def create_fixed_hmm():
    """
    Create HMM with fixed parameters from task description.
    
    States: 0=alpha-helix, 1=other
    Features: 20 amino acids (0-19 indices)
    
    Returns:
        CategoricalHMM with fixed empirical parameters
    """
    # Create model with fixed parameters (no learning)
    model = hmm.CategoricalHMM(
        n_components=2,        # 2 states: alpha-helix, other
        n_features=20,         # 20 amino acids
        params="",             # Don't update any parameters
        init_params="",        # Don't initialize any parameters
        algorithm="viterbi",   # Use Viterbi algorithm for decoding
        verbose=True
    )
    
    # Set fixed parameters from task description
    model.startprob_ = np.array(INITIAL_PROBS)
    model.transmat_ = np.array([TRANS_FROM_ALPHA, TRANS_FROM_OTHER])
    model.emissionprob_ = np.array([EMISSION_ALPHA, EMISSION_OTHER])
    
    return model

# Test HMM creation
print("TESTING HMM CREATION:\n")
hmm_model = create_fixed_hmm()

print(f"States: {hmm_model.n_components}")
print(f"Features: {hmm_model.n_features}")
print(f"Start probabilities: {hmm_model.startprob_}")
print(f"Transition matrix shape: {hmm_model.transmat_.shape}")
print(f"Emission matrix shape: {hmm_model.emissionprob_.shape}")

print("\nTransition probabilities:")
print("From alpha-helix:", hmm_model.transmat_[0])
print("From other:     ", hmm_model.transmat_[1])

print("\nEmission probabilities (first 5 amino acids):")
print("Alpha-helix:", hmm_model.emissionprob_[0][:5])
print("Other:      ", hmm_model.emissionprob_[1][:5])
print("\n✓ HMM model created successfully!")

TESTING HMM CREATION:

States: 2
Features: 20
Start probabilities: [0. 1.]
Transition matrix shape: (2, 2)
Emission matrix shape: (2, 20)

Transition probabilities:
From alpha-helix: [0.9 0.1]
From other:      [0.05 0.95]

Emission probabilities (first 5 amino acids):
Alpha-helix: [0.12 0.06 0.03 0.05 0.01]
Other:       [0.06 0.05 0.05 0.06 0.02]

✓ HMM model created successfully!


In [4]:
# HMM DATA GENERATION AND SIMULATOR FUNCTIONS

def generate_amino_acid_sequence(n_samples=50, random_state=None):
    """
    Generate amino acid sequences from the fixed HMM.
    
    Args:
        n_samples: Number of amino acids to generate
        random_state: Random state for reproducibility
        
    Returns:
        dict with 'amino_acids', 'true_states', and 'state_probs'
    """
    # Create the fixed HMM model
    model = create_fixed_hmm()
    
    # Generate sequence from HMM
    X, Z = model.sample(n_samples, random_state=random_state)
    
    # X is shape (n_samples, 1) - amino acid indices
    # Z is shape (n_samples,) - true hidden states
    amino_acids = X.flatten()  # Convert to 1D array of amino acid indices
    
    # Get state membership probabilities using Forward-Backward algorithm
    # Need to reshape X for predict_proba (expects (n_samples, 1))
    state_probs = model.predict_proba(X)  # Shape: (n_samples, n_states)
    
    return {
        'amino_acids': amino_acids,       # Shape: (n_samples,) - amino acid indices (0-19)
        'true_states': Z,                 # Shape: (n_samples,) - true hidden states (0=alpha, 1=other) 
        'state_probs': state_probs        # Shape: (n_samples, 2) - state membership probabilities
    }

# Test the data generation
print("TESTING HMM DATA GENERATION:\n")
test_data = generate_amino_acid_sequence(n_samples=20, random_state=42)

print(f"Amino acids shape: {test_data['amino_acids'].shape}")
print(f"True states shape: {test_data['true_states'].shape}")
print(f"State probabilities shape: {test_data['state_probs'].shape}")

print(f"\nFirst 10 amino acids (indices): {test_data['amino_acids'][:10]}")
print(f"First 10 true states: {test_data['true_states'][:10]}")
print(f"First 5 state probabilities:\n{test_data['state_probs'][:5]}")

# Verify state probabilities sum to 1
print(f"\nState probabilities sum check: {np.allclose(test_data['state_probs'].sum(axis=1), 1.0)}")

# Convert amino acid indices to actual amino acid letters for readability
amino_acid_letters = [AMINO_ACIDS[idx] for idx in test_data['amino_acids'][:10]]
print(f"First 10 amino acids (letters): {amino_acid_letters}")
print("\n✓ HMM data generation working correctly!")

TESTING HMM DATA GENERATION:

Amino acids shape: (20,)
True states shape: (20,)
State probabilities shape: (20, 2)

First 10 amino acids (indices): [19 11  2 16 14 19  3  2  9  5]
First 10 true states: [1 1 1 1 1 0 0 0 0 0]
First 5 state probabilities:
[[0.         1.        ]
 [0.01768884 0.98231116]
 [0.0253218  0.9746782 ]
 [0.03656372 0.96343628]
 [0.05153765 0.94846235]]

State probabilities sum check: True
First 10 amino acids (letters): ['V', 'K', 'N', 'T', 'P', 'V', 'D', 'N', 'I', 'E']

✓ HMM data generation working correctly!


In [5]:
# BAYESFLOW SIMULATOR IMPLEMENTATION

def hmm_simulator_function(batch_shape, sequence_length=50, **kwargs):
    """
    Simulator function for BayesFlow that generates HMM data.
    
    This function will be wrapped by BayesFlow's LambdaSimulator.
    
    Args:
        batch_shape: Shape of the batch to generate (from BayesFlow)
        sequence_length: Length of amino acid sequences to generate
        **kwargs: Additional keyword arguments
        
    Returns:
        dict: Dictionary with simulation outputs for BayesFlow
    """
    # Handle both int and tuple batch_shape
    if isinstance(batch_shape, int):
        batch_size = batch_shape
    else:
        batch_size = batch_shape[0] if len(batch_shape) > 0 else 1
    
    # Generate multiple sequences
    amino_acids_batch = []
    true_states_batch = []
    state_probs_batch = []
    
    for i in range(batch_size):
        # Generate one sequence with different random state for each
        data = generate_amino_acid_sequence(
            n_samples=sequence_length, 
            random_state=np.random.randint(0, 10000)
        )
        
        amino_acids_batch.append(data['amino_acids'])
        true_states_batch.append(data['true_states'])
        state_probs_batch.append(data['state_probs'])
    
    # Stack into batch format
    return {
        'amino_acids': np.array(amino_acids_batch),      # Shape: (batch_size, sequence_length)
        'true_states': np.array(true_states_batch),      # Shape: (batch_size, sequence_length)
        'state_probs': np.array(state_probs_batch),      # Shape: (batch_size, sequence_length, 2)
    }

# Create BayesFlow simulator
print("CREATING BAYESFLOW SIMULATOR:\n")
hmm_simulator = bf.simulators.LambdaSimulator(
    sample_fn=hmm_simulator_function,
    is_batched=True  # Our function handles batching internally
)

print("✓ BayesFlow LambdaSimulator created successfully!")

# Test the BayesFlow simulator
print("\nTESTING BAYESFLOW SIMULATOR:")
batch_size = 3
sequence_length = 15

# Sample from the simulator
simulation_data = hmm_simulator.sample(
    batch_shape=(batch_size,), 
    sequence_length=sequence_length
)

print(f"Simulation data keys: {list(simulation_data.keys())}")
print(f"Amino acids batch shape: {simulation_data['amino_acids'].shape}")
print(f"True states batch shape: {simulation_data['true_states'].shape}")
print(f"State probabilities batch shape: {simulation_data['state_probs'].shape}")

# Show multiple sequences
num_seq = 2
print(f"\nFirst {num_seq} sequences:")
for i in range(num_seq):
    amino_acids = simulation_data['amino_acids'][i]
    true_states = simulation_data['true_states'][i]
    state_probs = simulation_data['state_probs'][i]
    
    print(f"\nSequence {i}:")
    print(f"Amino acids: {amino_acids}")
    print(f"True states: {true_states}")
    print(f"State probabilities shape: {state_probs.shape}")
    print(f"State probabilities sum check: {np.allclose(state_probs.sum(axis=1), 1.0)}")
    print(f"Sequnce length: {len(amino_acids)}")

# Convert first sequence to amino acid letters
example_letters = [AMINO_ACIDS[idx] for idx in simulation_data['amino_acids'][0]]
print(f"Amino acid letters: {example_letters}")

print("\n✓ BayesFlow simulator working correctly!")

CREATING BAYESFLOW SIMULATOR:

✓ BayesFlow LambdaSimulator created successfully!

TESTING BAYESFLOW SIMULATOR:
Simulation data keys: ['amino_acids', 'true_states', 'state_probs']
Amino acids batch shape: (3, 15)
True states batch shape: (3, 15)
State probabilities batch shape: (3, 15, 2)

First 2 sequences:

Sequence 0:
Amino acids: [16  1 10 13 15  0 10  2 19  6 16  0 19 19  4]
True states: [1 1 1 1 1 1 0 0 0 0 0 0 1 1 1]
State probabilities shape: (15, 2)
State probabilities sum check: True
Sequnce length: 15

Sequence 1:
Amino acids: [10  2  7 10 14 17 19 12 14 18 15 19  9  9 14]
True states: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
State probabilities shape: (15, 2)
State probabilities sum check: True
Sequnce length: 15
Amino acid letters: ['T', 'R', 'L', 'F', 'S', 'A', 'L', 'N', 'V', 'Q', 'T', 'A', 'V', 'V', 'C']

✓ BayesFlow simulator working correctly!


In [20]:
# BAYESFLOW WORKFLOW SETUP

# Define the data variables for the adapter
INFERENCE_VARIABLES = ['state_probs']  # What we want to infer (state membership probabilities)
SUMMARY_VARIABLES = ['amino_acids']     # What the summary network processes (amino acid sequences)

print("SETTING UP BAYESFLOW WORKFLOW COMPONENTS:")
print(f"Inference variables: {INFERENCE_VARIABLES}")
print(f"Summary variables: {SUMMARY_VARIABLES}")

# Create the adapter for data processing
adapter = bf.adapters.Adapter([
    # First, convert specific arrays to float32
    bf.adapters.transforms.MapTransform({
        'state_probs': bf.adapters.transforms.ConvertDType(from_dtype='float64', to_dtype='float32'),
        'amino_acids': bf.adapters.transforms.ConvertDType(from_dtype='int64', to_dtype='float32'),
        'true_states': bf.adapters.transforms.ConvertDType(from_dtype='int64', to_dtype='float32'),
    }),
    
    # Flatten amino acids to 1D for MLP processing: (batch_size, seq_len) -> (batch_size, seq_len*1)
    bf.adapters.transforms.MapTransform({
        'amino_acids': bf.adapters.transforms.NumpyTransform(lambda x: x.reshape(x.shape[0], -1)),
    }),
    
    # Flatten state probabilities for inference: (batch_size, seq_len, 2) -> (batch_size, seq_len*2)  
    bf.adapters.transforms.MapTransform({
        'state_probs': bf.adapters.transforms.NumpyTransform(lambda x: x.reshape(x.shape[0], -1)),
    }),
    
    # Rename variables to match BayesFlow conventions
    bf.adapters.transforms.Rename(from_key='state_probs', to_key='inference_variables'),
    bf.adapters.transforms.Rename(from_key='amino_acids', to_key='summary_variables'),
    
    # Drop true_states as we don't need it for training
    bf.adapters.transforms.Drop(keys=['true_states']),
])

print("✓ Adapter created with data transforms")

# Create summary network for processing amino acid sequences
# Using MLP instead of DeepSet for simpler shape handling
summary_network = bf.networks.MLP(
    widths=[128, 128, 64],       # Hidden layer widths
    activation='silu',           # Activation function
    dropout=0.1,                 # Dropout rate
    kernel_initializer='he_normal'
)

print("✓ Summary network (DeepSet) created")

# Create inference network for posterior approximation
# CouplingFlow is excellent for complex posterior distributions
inference_network = bf.networks.CouplingFlow(
    depth=8,                     # Number of coupling layers
    transform='affine',          # Type of coupling transformation
    permutation='random',        # Permutation between layers
    use_actnorm=True,           # Use activation normalization
    base_distribution='normal',  # Base distribution 
    subnet='mlp',               # Subnet architecture
    subnet_kwargs={             # Subnet configuration
        'widths': [128, 128, 128],   # Hidden layer widths
        'activation': 'silu',        # Activation function
        'dropout': 0.1               # Dropout rate
    }
)

print("✓ Inference network (CouplingFlow) created")

# Create the BasicWorkflow
workflow = bf.workflows.BasicWorkflow(
    simulator=hmm_simulator,
    adapter=adapter,
    inference_network=inference_network,
    summary_network=summary_network,
    initial_learning_rate=5e-4,
    inference_variables='inference_variables',
    summary_variables='summary_variables'
)

print("✓ BasicWorkflow created successfully!")

# Test the complete workflow with a small sample
print("\nTESTING COMPLETE WORKFLOW:")
test_batch_size = 2
test_sequence_length = 20

# Generate test data using the workflow
test_simulation = workflow.simulate(
    batch_shape=(test_batch_size,),
    sequence_length=test_sequence_length
)

print(f"Simulated data keys: {list(test_simulation.keys())}")
for key, value in test_simulation.items():
    print(f"  {key}: shape {value.shape}, dtype {value.dtype}")

# Apply adapter to the simulated data
adapted_data = workflow.adapter(test_simulation)
print(f"\nAdapted data keys: {list(adapted_data.keys())}")
for key, value in adapted_data.items():
    print(f"  {key}: shape {value.shape}, dtype {value.dtype}")

print("\n✓ Complete workflow pipeline tested successfully!")
print("Ready for training!")

SETTING UP BAYESFLOW WORKFLOW COMPONENTS:
Inference variables: ['state_probs']
Summary variables: ['amino_acids']


ValueError: Forward transformation must be a NumPy Universal Function (ufunc).

In [6]:
# CUSTOM PROTEIN SUMMARY NETWORK

class ProteinSummaryNetwork(bf.networks.SummaryNetwork):
    """
    Custom summary network for protein amino acid sequences.
    
    This network is specifically designed for the protein secondary structure task:
    - Embeds amino acid indices into dense representations
    - Uses bidirectional LSTM to capture sequential dependencies
    - Applies attention mechanism to focus on important positions
    - Outputs summary statistics for the entire sequence
    """
    
    def __init__(self, 
                 vocab_size=20,              # Number of amino acids
                 embedding_dim=32,           # Amino acid embedding dimension
                 lstm_units=64,              # LSTM hidden units
                 attention_dim=32,           # Attention mechanism dimension
                 summary_dim=64,             # Output summary dimension
                 dropout_rate=0.1,           # Dropout rate
                 **kwargs):
        super().__init__(**kwargs)
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units
        self.attention_dim = attention_dim
        self.summary_dim = summary_dim
        self.dropout_rate = dropout_rate
        
        # Amino acid embedding layer
        self.embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim,
            mask_zero=False,  # Don't mask zero values as amino acid 'A' has index 0
            name='amino_acid_embedding'
        )
        
        # Bidirectional LSTM for sequence processing
        self.lstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(
                lstm_units,
                return_sequences=True,  # Return full sequence for attention
                dropout=dropout_rate,
                recurrent_dropout=dropout_rate,
                name='sequence_lstm'
            ),
            name='bidirectional_lstm'
        )
        
        # Attention mechanism layers
        self.attention_dense = tf.keras.layers.Dense(
            attention_dim, 
            activation='tanh',
            name='attention_dense'
        )
        self.attention_weights = tf.keras.layers.Dense(
            1, 
            activation=None,  # Don't use softmax here, apply it later
            name='attention_weights'
        )
        
        # Final summary layers
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.summary_dense1 = tf.keras.layers.Dense(
            summary_dim * 2,
            activation='silu',
            name='summary_dense1'
        )
        self.summary_dense2 = tf.keras.layers.Dense(
            summary_dim,
            activation='silu', 
            name='summary_dense2'
        )
        
    def call(self, x, training=False, **kwargs):
        """
        Forward pass of the protein summary network.
        
        Args:
            x: Input tensor of shape (batch_size, sequence_length, 1) containing amino acid indices
            training: Whether in training mode
            
        Returns:
            Summary tensor of shape (batch_size, summary_dim)
        """
        # Remove the last dimension if present: (batch_size, seq_len, 1) -> (batch_size, seq_len)
        if x.shape[-1] == 1:
            x = tf.squeeze(x, axis=-1)
            
        # Convert to integer indices for embedding
        x = tf.cast(x, tf.int32)
        
        # Embed amino acid indices: (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        embedded = self.embedding(x)
        
        # Process with bidirectional LSTM: (batch_size, seq_len, embedding_dim) -> (batch_size, seq_len, 2*lstm_units)
        lstm_output = self.lstm(embedded, training=training)
        
        # Apply attention mechanism
        # Compute attention scores: (batch_size, seq_len, 2*lstm_units) -> (batch_size, seq_len, attention_dim)
        attention_scores = self.attention_dense(lstm_output)
        
        # Compute attention weights: (batch_size, seq_len, attention_dim) -> (batch_size, seq_len, 1)
        attention_logits = self.attention_weights(attention_scores)
        
        # Apply softmax along the sequence dimension to get proper attention weights
        attention_weights = tf.nn.softmax(attention_logits, axis=1)  # Softmax over sequence dimension
        
        # Apply attention: weighted sum of LSTM outputs
        # (batch_size, seq_len, 2*lstm_units) * (batch_size, seq_len, 1) -> (batch_size, 2*lstm_units)
        attended_output = tf.reduce_sum(lstm_output * attention_weights, axis=1)
        
        # Apply dropout
        attended_output = self.dropout(attended_output, training=training)
        
        # Generate final summary through dense layers
        summary = self.summary_dense1(attended_output)
        summary = self.dropout(summary, training=training)
        summary = self.summary_dense2(summary)
        
        return summary
    
    def get_config(self):
        """Return the configuration of the layer."""
        config = super().get_config()
        config.update({
            'vocab_size': self.vocab_size,
            'embedding_dim': self.embedding_dim,
            'lstm_units': self.lstm_units,
            'attention_dim': self.attention_dim,
            'summary_dim': self.summary_dim,
            'dropout_rate': self.dropout_rate,
        })
        return config
    
    @classmethod
    def from_config(cls, config):
        """Create layer from configuration."""
        return cls(**config)

print("✓ Custom ProteinSummaryNetwork class defined")

✓ Custom ProteinSummaryNetwork class defined


In [27]:
# UPDATED BAYESFLOW WORKFLOW WITH CUSTOM SUMMARY NETWORK

def create_protein_bayesflow_workflow(
    param_dim=4,           # Transition probabilities (2x2)
    seq_len=20,           # Sequence length
    vocab_size=20,        # Number of amino acids
    embedding_dim=32,     # Amino acid embedding dimension
    lstm_units=64,        # LSTM hidden units
    attention_dim=32,     # Attention dimension
    summary_dim=64,       # Summary network output dimension
    coupling_layers=8,    # Number of coupling layers
    hidden_units=[128, 128],  # Hidden units for coupling networks
    seed=42
):
    """
    Create BayesFlow workflow with custom protein summary network.
    
    Returns:
        Configured BasicWorkflow ready for training
    """
    
    # 1. USE EXISTING SIMULATOR
    # The hmm_simulator is already created and available
    simulator = hmm_simulator
    
    # 2. CUSTOM SUMMARY NETWORK
    protein_summary_net = ProteinSummaryNetwork(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        lstm_units=lstm_units,
        attention_dim=attention_dim,
        summary_dim=summary_dim,
        name='ProteinSummaryNetwork'
    )
    
    # 3. INFERENCE NETWORK (unchanged)
    inference_net = bf.networks.CouplingFlow(
        num_params=param_dim,
        num_coupling_layers=coupling_layers,
        coupling_settings={'units': hidden_units, 'activation': 'silu'},
        name='ProteinInferenceNetwork'
    )
    
    # 4. PROPER ADAPTER WITH TRANSFORMS
    # Create adapter that handles the HMM simulation data properly
    adapter_transforms = [
        # Rename the data to match BayesFlow conventions
        bf.adapters.transforms.Rename(from_key='amino_acids', to_key='summary_variables'),
        bf.adapters.transforms.Rename(from_key='state_probs', to_key='inference_variables'),
        
        # Drop true_states as we don't need it for inference
        bf.adapters.transforms.Drop(keys=['true_states']),
        
        # Convert dtypes to float32 for neural networks using MapTransform
        bf.adapters.transforms.MapTransform({
            'summary_variables': bf.adapters.transforms.ConvertDType(from_dtype='int64', to_dtype='float32'),
            'inference_variables': bf.adapters.transforms.ConvertDType(from_dtype='float64', to_dtype='float32'),
        }),
    ]
    
    custom_adapter = bf.Adapter(transforms=adapter_transforms)
        workflow = bf.BasicWorkflow(
    # 5. CREATE WORKFLOW
        adapter=custom_adapter,
        simulator=simulator,
        adapter=simple_adapter,
        inference_network=inference_net,
        summary_network=protein_summary_net
    )
    
    print(f"✓ Created BayesFlow workflow with custom protein summary network")
    print(f"  - Summary network output dim: {summary_dim}")
    print(f"  - Inference network params: {param_dim}")
    print(f"  - Embedding dimension: {embedding_dim}")
    print(f"  - LSTM units: {lstm_units}")
    
    return workflow

# Create the workflow with custom summary network
protein_workflow = create_protein_bayesflow_workflow()

print("\n✓ Protein BayesFlow workflow created successfully with custom summary network!")

IndentationError: unexpected indent (2365904385.py, line 62)

In [23]:
# TEST CUSTOM SUMMARY NETWORK WITH SAMPLE DATA

print("Testing custom protein summary network...")

# Generate some test data using the existing simulator
test_data = protein_workflow.simulate(2)  # Generate 2 samples
print("✓ Test data generated")

# Check the data structure
print("\nTest data structure:")
for key, value in test_data.items():
    print(f"  {key}: shape {value.shape}, dtype {value.dtype}")

# Extract the amino acid sequences for testing
if 'summary_variables' in test_data:
    test_sequences = test_data['summary_variables']
elif 'amino_acids' in test_data:
    test_sequences = test_data['amino_acids']
else:
    # Find the right key for amino acid sequences
    for key, value in test_data.items():
        if 'amino' in key.lower() or len(value.shape) == 2:
            test_sequences = value
            print(f"Using {key} as amino acid sequences")
            break

print(f"\nTest sequences shape: {test_sequences.shape}")
print(f"Test sequences dtype: {test_sequences.dtype}")

# Test the custom summary network directly
print("\nTesting ProteinSummaryNetwork...")
try:
    # Ensure sequences are the right shape and type for our network
    if len(test_sequences.shape) == 2:
        # Add the feature dimension if needed: (batch, seq_len) -> (batch, seq_len, 1)
        test_sequences_formatted = test_sequences[..., np.newaxis]
    else:
        test_sequences_formatted = test_sequences
    
    print(f"Formatted sequences shape: {test_sequences_formatted.shape}")
    
    # Create a standalone instance of our summary network to test
    test_summary_net = ProteinSummaryNetwork(
        vocab_size=20,
        embedding_dim=32,
        lstm_units=64,
        summary_dim=64
    )
    
    # Test the forward pass
    summary_output = test_summary_net(test_sequences_formatted)
    print(f"✓ Summary network output shape: {summary_output.shape}")
    print(f"✓ Summary network output dtype: {summary_output.dtype}")
    print(f"✓ Expected shape: (batch_size={test_sequences.shape[0]}, summary_dim=64)")
    
    if summary_output.shape == (test_sequences.shape[0], 64):
        print("✅ Custom summary network working correctly!")
    else:
        print("❌ Shape mismatch - need to debug")
        
except Exception as e:
    print(f"❌ Error testing summary network: {e}")
    import traceback
    traceback.print_exc()

Testing custom protein summary network...
✓ Test data generated

Test data structure:
  amino_acids: shape (2, 50), dtype int64
  true_states: shape (2, 50), dtype int64
  state_probs: shape (2, 50, 2), dtype float64

Test sequences shape: (2, 50)
Test sequences dtype: int64

Testing ProteinSummaryNetwork...
Formatted sequences shape: (2, 50, 1)




✓ Summary network output shape: (2, 64)
✓ Summary network output dtype: <dtype: 'float32'>
✓ Expected shape: (batch_size=2, summary_dim=64)
✅ Custom summary network working correctly!


In [24]:
# TRAINING FUNCTION FOR CUSTOM PROTEIN WORKFLOW

def train_protein_workflow(
    workflow,
    batch_size=16,
    epochs=50,
    print_every=10,
    save_path=None
):
    """
    Train the protein BayesFlow workflow with our custom summary network.
    
    Args:
        workflow: The BayesFlow workflow to train
        batch_size: Batch size for training
        epochs: Number of training epochs
        print_every: Print progress every N epochs
        save_path: Path to save the trained model (optional)
    
    Returns:
        training_history: Dictionary with training metrics
    """
    
    print(f"Starting training for {epochs} epochs with batch size {batch_size}")
    print("=" * 60)
    
    training_history = {
        'epoch': [],
        'loss': [],
        'validation_loss': []
    }
    
    try:
        # Configure the workflow for training
        config = {
            'epochs': epochs,
            'batch_size': batch_size,
            'validation_sims': 1000,  # Generate validation data
            'checkpoint_interval': max(1, epochs // 10),  # Save checkpoints
        }
        
        print("Training configuration:")
        for key, value in config.items():
            print(f"  {key}: {value}")
        print()
        
        # Start online training
        print("🚀 Starting online training...")
        training_info = workflow.fit_online(
            epochs=config['epochs'],
            batch_size=config['batch_size'],
            print_every=print_every
        )
        
        print("✅ Training completed successfully!")
        
        # Extract training history if available
        if hasattr(training_info, 'history') and training_info.history:
            history = training_info.history
            training_history['loss'] = history.get('loss', [])
            training_history['validation_loss'] = history.get('val_loss', [])
            training_history['epoch'] = list(range(1, len(training_history['loss']) + 1))
        
        # Save the model if path provided
        if save_path:
            print(f"💾 Saving model to {save_path}")
            workflow.save_model(save_path)
            
        return training_history
        
    except Exception as e:
        print(f"❌ Training failed with error: {e}")
        import traceback
        traceback.print_exc()
        return training_history

print("✓ Training function defined")

# Test training with a few epochs
print("\n🧪 Testing training with a small number of epochs...")
test_training_history = train_protein_workflow(
    workflow=protein_workflow,
    batch_size=8,
    epochs=5,
    print_every=1
)

INFO:bayesflow:Fitting on dataset instance of OnlineDataset.
INFO:bayesflow:Building on a test batch.


✓ Training function defined

🧪 Testing training with a small number of epochs...
Starting training for 5 epochs with batch size 8
Training configuration:
  epochs: 5
  batch_size: 8
  validation_sims: 1000
  checkpoint_interval: 1

🚀 Starting online training...
❌ Training failed with error: 'NoneType' object is not iterable


Traceback (most recent call last):
  File "/var/folders/1r/h80d31y92rn7dxwn7_1yfhsh0000gn/T/ipykernel_74054/3088207306.py", line 49, in train_protein_workflow
    training_info = workflow.fit_online(
                    ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/ukk/lib/python3.12/site-packages/bayesflow/workflows/basic_workflow.py", line 784, in fit_online
    return self._fit(
           ^^^^^^^^^^
  File "/opt/anaconda3/envs/ukk/lib/python3.12/site-packages/bayesflow/workflows/basic_workflow.py", line 954, in _fit
    self.history = self.approximator.fit(
                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/ukk/lib/python3.12/site-packages/bayesflow/approximators/continuous_approximator.py", line 316, in fit
    return super().fit(*args, **kwargs, adapter=self.adapter)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/ukk/lib/python3.12/site-packages/bayesflow/approximators/approximator.py", line 134, in fit
    mock_dat

In [26]:
# CHECK CONVERTDTYPE SIGNATURE
print("ConvertDType signature:")
print(inspect.signature(bf.adapters.transforms.ConvertDType.__init__))

# Check what transforms are available
print("\nAvailable transforms:")
transforms_list = [attr for attr in dir(bf.adapters.transforms) if not attr.startswith('_')]
print(transforms_list)

# Check MapTransform as an alternative
if hasattr(bf.adapters.transforms, 'MapTransform'):
    print("\nMapTransform signature:")
    print(inspect.signature(bf.adapters.transforms.MapTransform.__init__))

ConvertDType signature:
(self, from_dtype: str, to_dtype: str)

Available transforms:
['AsSet', 'AsTimeSeries', 'Broadcast', 'Concatenate', 'Constrain', 'ConvertDType', 'Drop', 'ElementwiseTransform', 'ExpandDims', 'FilterTransform', 'Group', 'Keep', 'Log', 'MapTransform', 'NNPE', 'NanToNum', 'NumpyTransform', 'OneHot', 'RandomSubsample', 'Rename', 'Scale', 'SerializableCustomTransform', 'Shift', 'Split', 'Sqrt', 'Squeeze', 'Standardize', 'Take', 'ToArray', 'ToDict', 'Transform', 'Ungroup', 'as_set', 'as_time_series', 'broadcast', 'concatenate', 'constrain', 'convert_dtype', 'drop', 'elementwise_transform', 'expand_dims', 'filter_transform', 'group', 'keep', 'log', 'map_transform', 'nan_to_num', 'nnpe', 'numpy_transform', 'one_hot', 'random_subsample', 'rename', 'scale', 'serializable_custom_transform', 'shift', 'split', 'sqrt', 'squeeze', 'standardize', 'take', 'to_array', 'to_dict', 'transform', 'ungroup']

MapTransform signature:
(self, transform_map: dict[str, bayesflow.adapters.tr

In [7]:
# CORRECTED PROTEIN BAYESFLOW WORKFLOW CREATION

def create_corrected_protein_workflow():
    """
    Create BayesFlow workflow with custom protein summary network and proper adapter.
    """
    print("Creating corrected protein BayesFlow workflow...")
    
    # 1. USE EXISTING SIMULATOR
    simulator = hmm_simulator
    print("✓ Using existing HMM simulator")
    
    # 2. CUSTOM SUMMARY NETWORK
    protein_summary_net = ProteinSummaryNetwork(
        vocab_size=20,
        embedding_dim=32,
        lstm_units=64,
        attention_dim=32,
        summary_dim=64,
        name='ProteinSummaryNetwork'
    )
    print("✓ Custom summary network created")
    
    # 3. INFERENCE NETWORK
    inference_net = bf.networks.CouplingFlow(
        num_params=4,  # HMM transition probabilities
        num_coupling_layers=8,
        coupling_settings={'units': [128, 128], 'activation': 'silu'},
        name='ProteinInferenceNetwork'
    )
    print("✓ Inference network created")
    
    # 4. ADAPTER WITH CORRECT TRANSFORMS
    adapter_transforms = [
        # Rename variables to BayesFlow conventions
        bf.adapters.transforms.Rename(from_key='amino_acids', to_key='summary_variables'),
        bf.adapters.transforms.Rename(from_key='state_probs', to_key='inference_variables'),
        
        # Drop unused variables
        bf.adapters.transforms.Drop(keys=['true_states']),
        
        # Convert data types using MapTransform
        bf.adapters.transforms.MapTransform({
            'summary_variables': bf.adapters.transforms.ConvertDType(
                from_dtype='int64', to_dtype='float32'
            ),
            'inference_variables': bf.adapters.transforms.ConvertDType(
                from_dtype='float64', to_dtype='float32'
            ),
        }),
    ]
    
    adapter = bf.Adapter(transforms=adapter_transforms)
    print("✓ Adapter with transforms created")
    
    # 5. CREATE WORKFLOW
    workflow = bf.BasicWorkflow(
        simulator=simulator,
        adapter=adapter,
        inference_network=inference_net,
        summary_network=protein_summary_net
    )
    print("✓ BayesFlow workflow created")
    
    return workflow

# Create the corrected workflow
print("=" * 60)
corrected_protein_workflow = create_corrected_protein_workflow()
print("=" * 60)
print("🎉 Corrected protein BayesFlow workflow created successfully!")

Creating corrected protein BayesFlow workflow...
✓ Using existing HMM simulator
✓ Custom summary network created
✓ Inference network created
✓ Adapter with transforms created
✓ BayesFlow workflow created
🎉 Corrected protein BayesFlow workflow created successfully!


In [8]:
# TEST TRAINING WITH CORRECTED WORKFLOW

print("🧪 Testing training with corrected workflow...")
print("=" * 50)

# Test simulation first
print("1. Testing simulation...")
test_sim_data = corrected_protein_workflow.simulate(2)
print("✓ Simulation successful")
print("   Data keys:", list(test_sim_data.keys()))
for key, value in test_sim_data.items():
    print(f"   {key}: shape {value.shape}, dtype {value.dtype}")

# Test adaptation
print("\n2. Testing data adaptation...")
test_adapted = corrected_protein_workflow.adapter(test_sim_data)
print("✓ Adaptation successful")
print("   Adapted keys:", list(test_adapted.keys()))
for key, value in test_adapted.items():
    print(f"   {key}: shape {value.shape}, dtype {value.dtype}")

# Now test a few training epochs
print("\n3. Testing training...")
try:
    test_history = train_protein_workflow(
        workflow=corrected_protein_workflow,
        batch_size=4,  # Small batch for testing
        epochs=3,      # Just a few epochs for testing
        print_every=1
    )
    print("✅ Training test successful!")
    
except Exception as e:
    print(f"❌ Training test failed: {e}")
    import traceback
    traceback.print_exc()

🧪 Testing training with corrected workflow...
1. Testing simulation...
✓ Simulation successful
   Data keys: ['amino_acids', 'true_states', 'state_probs']
   amino_acids: shape (2, 50), dtype int64
   true_states: shape (2, 50), dtype int64
   state_probs: shape (2, 50, 2), dtype float64

2. Testing data adaptation...
✓ Adaptation successful
   Adapted keys: ['summary_variables', 'inference_variables']
   summary_variables: shape (2, 50), dtype float32
   inference_variables: shape (2, 50, 2), dtype float32

3. Testing training...
❌ Training test failed: name 'train_protein_workflow' is not defined


Traceback (most recent call last):
  File "/var/folders/1r/h80d31y92rn7dxwn7_1yfhsh0000gn/T/ipykernel_97426/1707173917.py", line 25, in <module>
    test_history = train_protein_workflow(
                   ^^^^^^^^^^^^^^^^^^^^^^
NameError: name 'train_protein_workflow' is not defined. Did you mean: 'corrected_protein_workflow'?


In [16]:
# CHECK BAYESFLOW MODULE STRUCTURE
print("BayesFlow module attributes:")
print([attr for attr in dir(bf) if not attr.startswith('_')])

print("\nChecking for simulation module...")
if hasattr(bf, 'simulation'):
    print("✓ bf.simulation exists")
    print("  Attributes:", [attr for attr in dir(bf.simulation) if not attr.startswith('_')])
else:
    print("✗ bf.simulation not found")

print("\nChecking for simulators module...")
if hasattr(bf, 'simulators'):
    print("✓ bf.simulators exists")
    print("  Attributes:", [attr for attr in dir(bf.simulators) if not attr.startswith('_')])
else:
    print("✗ bf.simulators not found")

print("\nChecking other potential simulation-related modules...")
for module_name in ['simulators', 'simulation', 'training', 'trainers']:
    if hasattr(bf, module_name):
        module = getattr(bf, module_name)
        print(f"✓ bf.{module_name} exists")
        attrs = [attr for attr in dir(module) if not attr.startswith('_')]
        if 'Simulator' in ' '.join(attrs) or 'Lambda' in ' '.join(attrs):
            print(f"  Relevant attributes: {[attr for attr in attrs if 'Simulator' in attr or 'Lambda' in attr]}")
    else:
        print(f"✗ bf.{module_name} not found")

BayesFlow module attributes:
['Adapter', 'BasicWorkflow', 'ContinuousApproximator', 'DiskDataset', 'OfflineDataset', 'OnlineDataset', 'PointApproximator', 'adapters', 'approximators', 'datasets', 'diagnostics', 'distributions', 'experimental', 'links', 'make_simulator', 'metrics', 'networks', 'scores', 'simulators', 'types', 'utils', 'workflows', 'wrappers']

Checking for simulation module...
✗ bf.simulation not found

Checking for simulators module...
✓ bf.simulators exists
  Attributes: ['BernoulliGLM', 'BernoulliGLMRaw', 'GaussianLinear', 'GaussianLinearUniform', 'GaussianMixture', 'HierarchicalSimulator', 'InverseKinematics', 'LambdaSimulator', 'LotkaVolterra', 'ModelComparisonSimulator', 'SIR', 'SLCP', 'SLCPDistractors', 'SequentialSimulator', 'Simulator', 'TwoMoons', 'benchmark_simulators', 'hierarchical_simulator', 'lambda_simulator', 'make_simulator', 'model_comparison_simulator', 'sequential_simulator', 'simulator']

Checking other potential simulation-related modules...
✓ bf.

In [17]:
# CHECK BASICWORKFLOW STRUCTURE
print("BasicWorkflow attributes:")
print([attr for attr in dir(bf.BasicWorkflow) if not attr.startswith('_')])

print("\nBasicWorkflow docstring:")
print(bf.BasicWorkflow.__doc__)

print("\nBasicWorkflow init signature:")
import inspect
print(inspect.signature(bf.BasicWorkflow.__init__))

BasicWorkflow attributes:
['adapter', 'build_graph', 'build_optimizer', 'compute_custom_diagnostics', 'compute_default_diagnostics', 'compute_diagnostics', 'default_adapter', 'estimate', 'fit', 'fit_disk', 'fit_offline', 'fit_online', 'log_prob', 'make_simulator', 'plot_custom_diagnostics', 'plot_default_diagnostics', 'plot_diagnostics', 'sample', 'samples_to_data_frame', 'simulate', 'simulate_adapted']

BasicWorkflow docstring:
None

BasicWorkflow init signature:
(self, simulator: bayesflow.simulators.simulator.Simulator = None, adapter: bayesflow.adapters.adapter.Adapter = None, inference_network: bayesflow.networks.inference_network.InferenceNetwork | str = 'coupling_flow', summary_network: bayesflow.networks.summary_network.SummaryNetwork | str = None, initial_learning_rate: float = 0.0005, optimizer: keras.src.optimizers.optimizer.Optimizer | type = None, checkpoint_filepath: str = None, checkpoint_name: str = 'model', save_weights_only: bool = False, save_best_only: bool = False,

In [18]:
# CHECK AVAILABLE ADAPTERS
print("BayesFlow adapters module:")
print([attr for attr in dir(bf.adapters) if not attr.startswith('_')])

print("\nAdapter class signature:")
print(inspect.signature(bf.Adapter.__init__))

BayesFlow adapters module:
['Adapter', 'adapter', 'transforms']

Adapter class signature:
(self, transforms: collections.abc.Sequence[bayesflow.adapters.transforms.transform.Transform] | None = None)


In [21]:
# CHECK EXISTING HMM FUNCTIONS
print("Checking for hmm_prior function...")
if 'hmm_prior' in globals():
    print("✓ hmm_prior found")
else:
    print("✗ hmm_prior not found")

print("\nChecking for hmm_simulator...")
if 'hmm_simulator' in globals():
    print("✓ hmm_simulator found")
    print("  Type:", type(hmm_simulator))
else:
    print("✗ hmm_simulator not found")

print("\nExisting simulator-related variables:")
for var_name in ['hmm_simulator', 'hmm_model']:
    if var_name in globals():
        print(f"  {var_name}: {type(globals()[var_name])}")
        
# Let's check if the existing hmm_simulator is a LambdaSimulator
if 'hmm_simulator' in globals():
    simulator = globals()['hmm_simulator']
    print(f"\nSimulator details:")
    print(f"  Type: {type(simulator)}")
    if hasattr(simulator, 'prior_fun'):
        print(f"  Has prior_fun: {simulator.prior_fun}")
    if hasattr(simulator, 'simulator_fun'):
        print(f"  Has simulator_fun: {simulator.simulator_fun}")

Checking for hmm_prior function...
✗ hmm_prior not found

Checking for hmm_simulator...
✓ hmm_simulator found
  Type: <class 'bayesflow.simulators.lambda_simulator.LambdaSimulator'>

Existing simulator-related variables:
  hmm_simulator: <class 'bayesflow.simulators.lambda_simulator.LambdaSimulator'>
  hmm_model: <class 'hmmlearn.hmm.CategoricalHMM'>

Simulator details:
  Type: <class 'bayesflow.simulators.lambda_simulator.LambdaSimulator'>


In [12]:
# BAYESFLOW TRAINING IMPLEMENTATION

def train_bayesflow_workflow(workflow, training_config):
    """
    Train the BayesFlow workflow with proper configuration.
    
    Args:
        workflow: BayesFlow BasicWorkflow instance
        training_config: Dictionary with training parameters
        
    Returns:
        Training history
    """
    print(f"STARTING BAYESFLOW TRAINING:")
    print(f"Training strategy: {training_config['strategy']}")
    print(f"Epochs: {training_config['epochs']}")
    print(f"Batch size: {training_config['batch_size']}")
    
    if training_config['strategy'] == 'online':
        # Online training - generate data on-the-fly
        history = workflow.fit_online(
            epochs=training_config['epochs'],
            num_batches_per_epoch=training_config['num_batches_per_epoch'],
            batch_size=training_config['batch_size'],
            validation_data=training_config.get('validation_samples', None),
            sequence_length=training_config['sequence_length']
        )
    else:
        # Offline training - pre-generate training data
        print("Generating training dataset...")
        training_data = workflow.simulate(
            batch_shape=(training_config['total_samples'],),
            sequence_length=training_config['sequence_length']
        )
        
        # Generate validation data if specified
        validation_data = None
        if training_config.get('validation_samples'):
            print("Generating validation dataset...")
            validation_data = workflow.simulate(
                batch_shape=(training_config['validation_samples'],),
                sequence_length=training_config['sequence_length']
            )
        
        # Train offline with pre-generated data
        history = workflow.fit_offline(
            data=training_data,
            epochs=training_config['epochs'],
            batch_size=training_config['batch_size'],
            validation_data=validation_data
        )
    
    print("✓ Training completed!")
    return history

# Define training configuration
training_config = {
    'strategy': 'online',           # 'online' or 'offline'
    'epochs': 50,                   # Number of training epochs
    'batch_size': 32,               # Batch size for training
    'num_batches_per_epoch': 100,   # For online training
    'sequence_length': 50,          # Length of amino acid sequences
    'total_samples': 10000,         # For offline training
    'validation_samples': 1000      # Number of validation samples
}

print("TRAINING CONFIGURATION:")
for key, value in training_config.items():
    print(f"  {key}: {value}")

# Run a quick training test with reduced parameters for demonstration
quick_config = training_config.copy()
quick_config.update({
    'epochs': 5,                    # Reduced for quick demo
    'batch_size': 16,              # Smaller batch size
    'num_batches_per_epoch': 20,   # Fewer batches per epoch
    'sequence_length': 30          # Shorter sequences
})

print(f"\nRUNNING QUICK TRAINING DEMO:")
try:
    # Train the workflow
    training_history = train_bayesflow_workflow(workflow, quick_config)
    
    # Display training results
    print(f"\nTRAINING RESULTS:")
    print(f"Training completed successfully!")
    print(f"Final loss: {training_history.history['loss'][-1]:.4f}")
    
    # Plot training loss
    plt.figure(figsize=(10, 6))
    plt.plot(training_history.history['loss'], label='Training Loss')
    if 'val_loss' in training_history.history:
        plt.plot(training_history.history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('BayesFlow Training Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print("✓ Quick training demo completed successfully!")
    
except Exception as e:
    print(f"Training error: {e}")
    print("This might occur due to shape mismatches - let's debug the data flow...")
    
    # Debug data shapes
    debug_data = workflow.simulate(batch_shape=(2,), sequence_length=20)
    print(f"\nDEBUG - Raw simulation output:")
    for key, value in debug_data.items():
        print(f"  {key}: {value.shape}")
    
    debug_adapted = workflow.adapter(debug_data)
    print(f"\nDEBUG - Adapted data:")
    for key, value in debug_adapted.items():
        print(f"  {key}: {value.shape}")

TRAINING CONFIGURATION:
  strategy: online
  epochs: 50
  batch_size: 32
  num_batches_per_epoch: 100
  sequence_length: 50
  total_samples: 10000
  validation_samples: 1000

RUNNING QUICK TRAINING DEMO:
STARTING BAYESFLOW TRAINING:
Training strategy: online
Epochs: 5
Batch size: 16


INFO:bayesflow:Fitting on dataset instance of OnlineDataset.
INFO:bayesflow:Building on a test batch.


Training error: {{function_node __wrapped__ConcatV2_N_2_device_/job:localhost/replica:0/task:0/device:GPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [16,50,1] vs. shape[1] = [16,64] [Op:ConcatV2] name: concat
This might occur due to shape mismatches - let's debug the data flow...

DEBUG - Raw simulation output:
  amino_acids: (2, 20)
  true_states: (2, 20)
  state_probs: (2, 20, 2)

DEBUG - Adapted data:
  inference_variables: (2, 20, 2)
  summary_variables: (2, 20, 1)


2025-07-13 12:36:42.016133: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: INVALID_ARGUMENT: ConcatOp : Ranks of all input tensors should match: shape[0] = [16,50,1] vs. shape[1] = [16,64]


In [None]:
# BAYESFLOW INFERENCE AND EVALUATION

def perform_inference(workflow, amino_acid_sequence, num_samples=1000):
    """
    Perform posterior inference on a given amino acid sequence.
    
    Args:
        workflow: Trained BayesFlow workflow
        amino_acid_sequence: Array of amino acid indices
        num_samples: Number of posterior samples to draw
        
    Returns:
        Dictionary with posterior samples and statistics
    """
    # Ensure input is properly shaped (batch_size=1, sequence_length)
    if amino_acid_sequence.ndim == 1:
        amino_acid_sequence = amino_acid_sequence[np.newaxis, :]
    
    # Create conditions dictionary for inference
    conditions = {
        'amino_acids': amino_acid_sequence.astype(np.float32)
    }
    
    # Apply adapter to get proper format for inference
    adapted_conditions = workflow.adapter(conditions)
    
    # Sample from posterior
    posterior_samples = workflow.sample(
        num_samples=num_samples,
        conditions=adapted_conditions
    )
    
    # Compute posterior statistics
    inference_vars = posterior_samples['inference_variables']
    
    # Reshape to (num_samples, sequence_length, num_states)
    if inference_vars.ndim == 2:
        sequence_length = amino_acid_sequence.shape[1]
        num_states = 2  # alpha-helix, other
        inference_vars = inference_vars.reshape(num_samples, sequence_length, num_states)
    
    # Compute statistics
    posterior_mean = np.mean(inference_vars, axis=0)
    posterior_std = np.std(inference_vars, axis=0)
    posterior_quantiles = np.quantile(inference_vars, [0.025, 0.5, 0.975], axis=0)
    
    return {
        'samples': inference_vars,
        'mean': posterior_mean,
        'std': posterior_std,
        'quantiles': posterior_quantiles,
        'conditions': conditions
    }

def visualize_inference_results(inference_results, amino_acid_sequence):
    """
    Visualize the results of posterior inference.
    
    Args:
        inference_results: Output from perform_inference()
        amino_acid_sequence: Original amino acid sequence
    """
    fig, axes = plt.subplots(3, 1, figsize=(15, 10))
    
    sequence_length = len(amino_acid_sequence)
    positions = np.arange(sequence_length)
    
    # Plot 1: Amino acid sequence
    axes[0].scatter(positions, amino_acid_sequence, c='darkblue', alpha=0.7, s=50)
    axes[0].set_ylabel('Amino Acid Index')
    axes[0].set_title('Input Amino Acid Sequence')
    axes[0].grid(True, alpha=0.3)
    
    # Add amino acid letters as labels
    amino_acid_letters = [AMINO_ACIDS[int(idx)] for idx in amino_acid_sequence]
    for i, letter in enumerate(amino_acid_letters):
        axes[0].text(i, amino_acid_sequence[i] + 0.5, letter, 
                    ha='center', va='bottom', fontsize=8)
    
    # Plot 2: Posterior mean probabilities
    mean_probs = inference_results['mean']
    axes[1].plot(positions, mean_probs[:, 0], 'b-', linewidth=2, label='P(Alpha-helix)', marker='o')
    axes[1].plot(positions, mean_probs[:, 1], 'r-', linewidth=2, label='P(Other)', marker='s')
    
    # Add uncertainty bands
    std_probs = inference_results['std']
    axes[1].fill_between(positions, 
                        mean_probs[:, 0] - std_probs[:, 0],
                        mean_probs[:, 0] + std_probs[:, 0], 
                        alpha=0.3, color='blue')
    axes[1].fill_between(positions,
                        mean_probs[:, 1] - std_probs[:, 1], 
                        mean_probs[:, 1] + std_probs[:, 1],
                        alpha=0.3, color='red')
    
    axes[1].set_ylabel('Posterior Probability')
    axes[1].set_title('Posterior Mean State Probabilities (with uncertainty)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    axes[1].set_ylim(0, 1)
    
    # Plot 3: Credible intervals
    quantiles = inference_results['quantiles']
    axes[2].fill_between(positions, quantiles[0, :, 0], quantiles[2, :, 0], 
                        alpha=0.4, color='blue', label='Alpha-helix 95% CI')
    axes[2].fill_between(positions, quantiles[0, :, 1], quantiles[2, :, 1],
                        alpha=0.4, color='red', label='Other 95% CI') 
    axes[2].plot(positions, quantiles[1, :, 0], 'b-', linewidth=2, label='Alpha-helix median')
    axes[2].plot(positions, quantiles[1, :, 1], 'r-', linewidth=2, label='Other median')
    
    axes[2].set_xlabel('Position in Sequence')
    axes[2].set_ylabel('Probability')
    axes[2].set_title('Posterior Credible Intervals (95%)')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    axes[2].set_ylim(0, 1)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    print("INFERENCE SUMMARY:")
    print(f"Sequence length: {sequence_length}")
    print(f"Mean alpha-helix probability: {np.mean(mean_probs[:, 0]):.3f}")
    print(f"Mean other probability: {np.mean(mean_probs[:, 1]):.3f}")
    print(f"Positions with high alpha-helix confidence (>0.7): {np.sum(mean_probs[:, 0] > 0.7)}")
    print(f"Positions with high other confidence (>0.7): {np.sum(mean_probs[:, 1] > 0.7)}")

# Example usage after training
print("INFERENCE AND EVALUATION SETUP COMPLETE!")
print("\nTo use after training:")
print("1. Generate or provide an amino acid sequence")
print("2. Run: results = perform_inference(workflow, sequence)")
print("3. Visualize: visualize_inference_results(results, sequence)")

# Prepare a test sequence for inference (when workflow is trained)
print(f"\nSample amino acid sequence for testing:")
test_sequence = np.array([0, 5, 10, 2, 15, 8, 12, 3, 17, 6])  # Mix of amino acids
test_letters = [AMINO_ACIDS[idx] for idx in test_sequence]
print(f"Amino acids: {test_letters}")
print(f"Indices: {test_sequence}")

print("\n✓ Inference and evaluation functions ready!")