In [1]:
from transformers import EsmTokenizer, EsmModel
import torch
import sys
import os
import numpy as np
import pandas as pd

# 🧬 Protein Aggregation Prediction Model

## Overview
This notebook implements a deep learning model that predicts protein aggregation behavior by combining:
1. **Protein sequence information** using ESM2 transformer embeddings
2. **Environmental conditions** (temperature, pH, concentration) through neural network processing  
3. **Multimodal fusion** to create unified representations for downstream prediction

## Architecture Pipeline
```
Protein Sequence → ESM2 → Attention Pooling → [1280 dims]
                                                    ↓
Environmental Data → MLP + BatchNorm → [16 dims]   ↓
                                                    ↓
                          Fusion (Concatenation) → [1296 dims] → Prediction Head
```

## Sections
1. **Sequence Embedding**: Process protein sequences with ESM2 and attention pooling
2. **Environmental Embedding**: Neural network processing of environmental conditions
3. **Fusion**: Combine protein and environmental representations
4. **Prediction**: MLP classification head for aggregation prediction (defined here but not yet trained)
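
The shape bookkeeping in the pipeline above can be checked without loading ESM2. The following is a minimal sketch that uses a random tensor as a stand-in for the ESM2 hidden states; `pool_scores`, `env_layer`, and the stand-in tensors are illustrative names and are not part of the notebook's own cells.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, hidden_dim, env_dim = 2, 42, 1280, 16

# Stand-in for ESM2 last_hidden_state (assumption: random values, realistic shape)
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
env_features = torch.randn(batch_size, 3)  # temperature, pH, concentration

# Attention pooling: one score per position, softmax over positions
pool_scores = nn.Linear(hidden_dim, 1)
weights = torch.softmax(pool_scores(hidden_states).squeeze(-1), dim=1)
pooled = (hidden_states * weights.unsqueeze(-1)).sum(dim=1)  # [2, 1280]

# Environmental branch: Linear -> BatchNorm -> ReLU
env_layer = nn.Sequential(nn.Linear(3, env_dim), nn.BatchNorm1d(env_dim), nn.ReLU())
env_embedding = env_layer(env_features)  # [2, 16]

# Fusion by concatenation along the feature dimension
fused = torch.cat([pooled, env_embedding], dim=1)
print(fused.shape)  # torch.Size([2, 1296])
```

The fused width is simply the sum of the two branch widths, 1280 + 16 = 1296.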

# Sequence Embedding

In [2]:
protein_sequence = ["DAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVV"]

In [3]:
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
print(model)  # Prints full architecture

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 1280, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-32): 33 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          (dense): Linear(in_features=1280, out_feature

In [4]:
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

In [5]:
# No truncation argument: without max_length, truncation=True has no effect and only emits a warning
tokens = tokenizer(protein_sequence, return_tensors="pt", padding=True)


In [6]:
outputs = model(**tokens)
print(f"Model output shape: {outputs.last_hidden_state.shape}")
print(f"Pooler output shape: {outputs.pooler_output.shape}")

Model output shape: torch.Size([1, 42, 1280])
Pooler output shape: torch.Size([1, 1280])
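
The load-time warning above reports that `pooler.dense` was newly initialized, so `pooler_output` comes from untrained weights. A common parameter-free baseline is masked mean pooling over `last_hidden_state`. The sketch below uses small stand-in tensors instead of real model outputs, and `masked_mean_pool` is a hypothetical helper name, not something defined elsewhere in this notebook.

```python
import torch

def masked_mean_pool(hidden_states, attention_mask):
    """Average token embeddings over positions where attention_mask == 1."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # [B, L, 1]
    summed = (hidden_states * mask).sum(dim=1)                   # [B, H]
    counts = mask.sum(dim=1).clamp(min=1.0)                      # [B, 1]
    return summed / counts

# Stand-in tensors with the same layout as the tokenizer and model outputs
hidden_states = torch.randn(2, 6, 8)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0, 0]])
pooled = masked_mean_pool(hidden_states, attention_mask)
print(pooled.shape)  # torch.Size([2, 8])
```

With the real outputs, the call would take `outputs.last_hidden_state` and `tokens["attention_mask"]` as its two arguments.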


In [7]:
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    """
    Attention pooling layer to reduce sequence embeddings to a fixed size.
    Reduces [batch_size, sequence_length, hidden_dim] -> [batch_size, hidden_dim]
    """
    def __init__(self, hidden_dim):
        super(AttentionPooling, self).__init__()
        self.attention = nn.Linear(hidden_dim, 1)
        
    def forward(self, embeddings, attention_mask=None):
        # embeddings: [batch_size, sequence_length, hidden_dim]
        # attention_mask: [batch_size, sequence_length] - 1 for valid tokens, 0 for padding
        
        # Compute attention scores
        attention_scores = self.attention(embeddings)  # [batch_size, sequence_length, 1]
        attention_scores = attention_scores.squeeze(-1)  # [batch_size, sequence_length]
        
        # Apply attention mask if provided
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e9)
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, sequence_length]
        
        # Apply attention weights to embeddings
        pooled_output = torch.sum(embeddings * attention_weights.unsqueeze(-1), dim=1)  # [batch_size, hidden_dim]
        
        return pooled_output, attention_weights

# Get the hidden states from ESM2 model output
hidden_states = outputs.last_hidden_state  # [batch_size, sequence_length, hidden_dim]
print(f"Original shape: {hidden_states.shape}")

# Initialize attention pooling layer
hidden_dim = hidden_states.shape[-1]  # Get hidden dimension from ESM2 output
attention_pooler = AttentionPooling(hidden_dim)

# Apply attention pooling
# Get attention mask from tokenizer output (to handle padding)
attention_mask = tokens.get('attention_mask', None)
pooled_embedding, attention_weights = attention_pooler(hidden_states, attention_mask)

print(f"Pooled shape: {pooled_embedding.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f"Attention weights sum (should be ~1.0): {attention_weights.sum(dim=1)}")

# Visualize which positions get the most attention
print("\nAttention weights for each position:")
for i, weight in enumerate(attention_weights[0]):  # Show weights for first sequence in batch
    print(f"Position {i}: {weight.item():.4f}")

Original shape: torch.Size([1, 42, 1280])
Pooled shape: torch.Size([1, 1280])
Attention weights shape: torch.Size([1, 42])
Attention weights sum (should be ~1.0): tensor([1.0000], grad_fn=<SumBackward1>)

Attention weights for each position:
Position 0: 0.0224
Position 1: 0.0219
Position 2: 0.0253
Position 3: 0.0255
Position 4: 0.0247
Position 5: 0.0228
Position 6: 0.0231
Position 7: 0.0231
Position 8: 0.0247
Position 9: 0.0248
Position 10: 0.0239
Position 11: 0.0252
Position 12: 0.0230
Position 13: 0.0224
Position 14: 0.0236
Position 15: 0.0267
Position 16: 0.0262
Position 17: 0.0255
Position 18: 0.0222
Position 19: 0.0260
Position 20: 0.0251
Position 21: 0.0242
Position 22: 0.0255
Position 23: 0.0233
Position 24: 0.0217
Position 25: 0.0230
Position 26: 0.0230
Position 27: 0.0239
Position 28: 0.0246
Position 29: 0.0249
Position 30: 0.0244
Position 31: 0.0229
Position 32: 0.0208
Position 33: 0.0238
Position 34: 0.0254
Position 35: 0.0247
Position 36: 0.0210
Position 37: 0.0232
Position
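
The single sequence above contains no padding, so the mask branch of `AttentionPooling` is never exercised. The sketch below repeats a compact copy of the class and runs it on a small padded batch of random stand-in embeddings to show that padded positions receive zero weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
    """Compact copy of the pooling layer defined above."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, embeddings, attention_mask=None):
        scores = self.attention(embeddings).squeeze(-1)
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = F.softmax(scores, dim=1)
        return (embeddings * weights.unsqueeze(-1)).sum(dim=1), weights

torch.manual_seed(0)
pooler = AttentionPooling(hidden_dim=8)
embeddings = torch.randn(2, 5, 8)           # stand-in for ESM2 hidden states
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])      # last two positions are padding
pooled, weights = pooler(embeddings, mask)
print(weights[1])          # padded positions receive weight 0
print(weights.sum(dim=1))  # each row still sums to 1
```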

In [8]:
# Additional utility functions for attention pooling

def visualize_attention_weights(attention_weights, sequence, title="Attention Weights"):
    """
    Visualize attention weights across the protein sequence
    """
    import matplotlib.pyplot as plt
    
    # Convert sequence list to string if needed
    if isinstance(sequence, list):
        sequence = sequence[0]
    
    # Create visualization
    fig, ax = plt.subplots(figsize=(15, 6))
    positions = range(len(attention_weights[0]))
    weights = attention_weights[0].detach().numpy()
    
    bars = ax.bar(positions, weights)
    ax.set_xlabel('Sequence Position')
    ax.set_ylabel('Attention Weight')
    ax.set_title(title)
    
    # Add amino acid labels if sequence is short enough
    if len(sequence) <= 50:
        ax.set_xticks(positions[1:-1])  # Skip the special tokens (<cls> at the start, <eos> at the end)
        ax.set_xticklabels([sequence[i-1] for i in positions[1:-1]], rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    return fig

def get_attention_statistics(attention_weights):
    """
    Get statistics about attention distribution
    """
    weights = attention_weights[0].detach().numpy()
    stats = {
        'mean': np.mean(weights),
        'std': np.std(weights),
        'min': np.min(weights),
        'max': np.max(weights),
        'entropy': -np.sum(weights * np.log(weights + 1e-8))  # Shannon entropy
    }
    return stats

# Apply utilities to our current results
print("Attention Statistics:")
stats = get_attention_statistics(attention_weights)
for key, value in stats.items():
    print(f"{key}: {value:.4f}")

# Store the pooled embedding for later use
sequence_embedding = pooled_embedding.detach().numpy()
print(f"\nSequence embedding shape: {sequence_embedding.shape}")
print(f"Sequence embedding sample values: {sequence_embedding[0][:10]}")  # Show first 10 values

Attention Statistics:
mean: 0.0238
std: 0.0015
min: 0.0208
max: 0.0267
entropy: 3.7357

Sequence embedding shape: (1, 1280)
Sequence embedding sample values: [-6.2805954e-03  4.5681186e-02 -1.4986906e-02  3.5999343e-05
 -1.3745151e-02 -1.2572345e-03 -7.8304186e-02  8.1339180e-02
  1.5063080e-01 -5.1894628e-02]
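
The entropy printed above is easier to read when compared with its maximum. For a distribution over N positions the Shannon entropy (in nats) is at most ln N, reached when all weights are equal. The short calculation below, a sketch using only the numbers already printed, shows that the reported value is almost exactly that maximum, which is consistent with a pooling layer that has not been trained yet.

```python
import math

num_positions = 42          # 40 residues + 2 special tokens
reported_entropy = 3.7357   # value printed above

max_entropy = math.log(num_positions)        # entropy of a uniform distribution
normalized = reported_entropy / max_entropy  # 1.0 means perfectly uniform

print(f"ln(42) = {max_entropy:.4f}")
print(f"normalized entropy = {normalized:.4f}")
```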


## 📝 Sequence Embedding Summary

### What We Accomplished:
- **ESM2 Model Loading**: Successfully loaded the `facebook/esm2_t33_650M_UR50D` pre-trained protein language model
- **Tokenization**: Converted protein sequence to tokens compatible with ESM2
- **Feature Extraction**: Generated contextualized embeddings `[1, 42, 1280]`: one 1280-dimensional vector for each of the 40 residues plus the two special tokens added by the tokenizer
- **Attention Pooling**: Implemented a learnable attention layer that reduces `[1, 42, 1280]` → `[1, 1280]`
- **Interpretability**: Added utilities to inspect the attention weights; the layer is untrained here, so the weights are close to uniform (about 1/42 ≈ 0.024 each) and do not yet indicate which positions matter

### Key Benefits:
- ✅ **Contextual Understanding**: ESM2 embeddings encode each residue in the context of the full sequence
- ✅ **Adaptive Pooling**: Once trained, the attention layer can learn task-specific position weights
- ✅ **Fixed Output Size**: Consistent 1280-dimensional representation regardless of sequence length
- ✅ **Gradient Flow**: The pooling layer is differentiable; this notebook stores the pooled embedding as a NumPy array via `.detach()`, so end-to-end training requires keeping it as a tensor instead

**Output**: Ready-to-use 1280-dimensional protein sequence embedding
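
Cell 8 stores the pooled embedding with `pooled_embedding.detach().numpy()`, which is convenient for inspection but removes it from the autograd graph. The sketch below, which uses a small stand-in layer and random tensors rather than the real model, illustrates the difference between using the tensor directly and round-tripping it through NumPy.

```python
import torch
import torch.nn as nn

pool = nn.Linear(8, 1)                 # stand-in for a trainable pooling layer
hidden_states = torch.randn(1, 5, 8)   # stand-in for ESM2 hidden states

weights = torch.softmax(pool(hidden_states).squeeze(-1), dim=1)
pooled = (hidden_states * weights.unsqueeze(-1)).sum(dim=1)

print(pooled.requires_grad)            # True: still attached to the graph
as_numpy = pooled.detach().numpy()     # leaves the autograd graph
restored = torch.tensor(as_numpy)
print(restored.requires_grad)          # False: no gradient can reach `pool`

pooled.sum().backward()
print(pool.weight.grad is not None)    # True when the tensor is used directly
```

For training, the pooled tensor would be passed straight to the fusion module; the NumPy copy is only needed for printing or plotting.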

# Environmental Embedding

In [9]:
environmental_conditions = {
    "Temperature (°C)": 37.0,
    "pH": 7.4,
    "Protein concentration (µM)": 50.0,
}

df = pd.DataFrame([environmental_conditions])
print(df)

   Temperature (°C)   pH  Protein concentration (µM)
0              37.0  7.4                        50.0


In [10]:
class EnvironmentalEmbedding(nn.Module):
    """
    Environmental embedding layer with batch normalization.
    Converts environmental features to a 16-dimensional embedding.
    """
    def __init__(self, input_dim=3, output_dim=16):
        super(EnvironmentalEmbedding, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim)
        self.activation = nn.ReLU()
        
    def forward(self, env_features):
        # env_features: [batch_size, input_dim]
        x = self.linear(env_features)  # [batch_size, output_dim]
        x = self.batch_norm(x)         # Apply batch normalization
        x = self.activation(x)         # Apply ReLU activation
        return x

# Convert environmental conditions to tensor
env_features = torch.tensor(df.values, dtype=torch.float32)
print(f"Environmental features shape: {env_features.shape}")
print(f"Environmental features: {env_features}")

# Initialize environmental embedding layer
env_embedding_layer = EnvironmentalEmbedding(input_dim=3, output_dim=16)

# Apply environmental embedding
# Use evaluation mode for single-sample inference: in training mode,
# BatchNorm1d cannot compute batch statistics from a single sample
env_embedding_layer.eval()
env_embedding = env_embedding_layer(env_features)
print(f"Environmental embedding shape: {env_embedding.shape}")
print(f"Environmental embedding sample values: {env_embedding[0][:10]}")

# Display layer parameters
print(f"\nEnvironmental embedding layer parameters:")
print(f"Linear layer weight shape: {env_embedding_layer.linear.weight.shape}")
print(f"Linear layer bias shape: {env_embedding_layer.linear.bias.shape}")
print(f"Batch norm running mean shape: {env_embedding_layer.batch_norm.running_mean.shape}")
print(f"Batch norm running var shape: {env_embedding_layer.batch_norm.running_var.shape}")

# For training with larger batches, you would use:
print(f"\nNote: For training, use env_embedding_layer.train() and batch_size > 1")

Environmental features shape: torch.Size([1, 3])
Environmental features: tensor([[37.0000,  7.4000, 50.0000]])
Environmental embedding shape: torch.Size([1, 16])
Environmental embedding sample values: tensor([14.5434,  0.0000,  0.0000,  4.5256, 20.9454,  3.7892,  4.7910,  0.0000,
         0.0000,  5.6818], grad_fn=<SliceBackward0>)

Environmental embedding layer parameters:
Linear layer weight shape: torch.Size([16, 3])
Linear layer bias shape: torch.Size([16])
Batch norm running mean shape: torch.Size([16])
Batch norm running var shape: torch.Size([16])

Note: For training, use env_embedding_layer.train() and batch_size > 1
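
The note above can be made concrete. In training mode `nn.BatchNorm1d` rejects a batch of one sample, while in evaluation mode a freshly created layer uses its initial running statistics (mean 0, variance 1) and therefore acts almost as an identity. A minimal sketch on random input:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(16)
single = torch.randn(1, 16)

bn.train()
try:
    bn(single)
    train_error = None
except ValueError as err:
    train_error = str(err)
print(f"train mode, batch of 1: {train_error}")

bn.eval()
out = bn(single)  # uses running_mean=0 and running_var=1 on a fresh layer
print(torch.allclose(out, single, atol=1e-4))  # True
```

This also means the single-sample embedding printed above is essentially `ReLU(Linear(x))` on the raw feature values, which explains its large magnitudes.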


In [11]:
# Demonstration: Environmental embedding with multiple samples (batch processing)

# Create sample batch of environmental conditions
batch_env_conditions = [
    {"Temperature (°C)": 37.0, "pH": 7.4, "Protein concentration (µM)": 50.0},
    {"Temperature (°C)": 25.0, "pH": 6.8, "Protein concentration (µM)": 100.0},
    {"Temperature (°C)": 42.0, "pH": 8.0, "Protein concentration (µM)": 25.0},
    {"Temperature (°C)": 30.0, "pH": 7.0, "Protein concentration (µM)": 75.0}
]

batch_df = pd.DataFrame(batch_env_conditions)
batch_env_features = torch.tensor(batch_df.values, dtype=torch.float32)

print("Batch environmental conditions:")
print(batch_df)
print(f"\nBatch tensor shape: {batch_env_features.shape}")

# Create new embedding layer for training demonstration
env_embedding_layer_batch = EnvironmentalEmbedding(input_dim=3, output_dim=16)
env_embedding_layer_batch.train()  # Set to training mode

# Process batch
batch_env_embedding = env_embedding_layer_batch(batch_env_features)
print(f"\nBatch environmental embeddings shape: {batch_env_embedding.shape}")
print(f"Sample embeddings for first condition: {batch_env_embedding[0][:5]}")
print(f"Sample embeddings for second condition: {batch_env_embedding[1][:5]}")

# Store the single sample embedding for later fusion
env_embedding_single = env_embedding.detach()
print(f"\nSingle sample embedding stored for fusion: {env_embedding_single.shape}")

Batch environmental conditions:
   Temperature (°C)   pH  Protein concentration (µM)
0              37.0  7.4                        50.0
1              25.0  6.8                       100.0
2              42.0  8.0                        25.0
3              30.0  7.0                        75.0

Batch tensor shape: torch.Size([4, 3])

Batch environmental embeddings shape: torch.Size([4, 16])
Sample embeddings for first condition: tensor([0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)
Sample embeddings for second condition: tensor([1.3583, 1.3393, 1.3365, 1.3597, 1.3444], grad_fn=<SliceBackward0>)

Single sample embedding stored for fusion: torch.Size([1, 16])
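
The three raw features live on different scales (pH near 7, concentration up to 100). One common option, which this notebook does not use, is to z-score each column before the linear layer. The sketch below assumes the statistics are taken from the four-sample batch above purely for illustration; `standardize` is a hypothetical helper.

```python
import torch

batch_env_features = torch.tensor([
    [37.0, 7.4, 50.0],
    [25.0, 6.8, 100.0],
    [42.0, 8.0, 25.0],
    [30.0, 7.0, 75.0],
])

def standardize(features, mean, std, eps=1e-8):
    """Z-score each column using precomputed statistics."""
    return (features - mean) / (std + eps)

# Assumption: statistics come from this small batch for illustration only;
# with real data they would be computed once on the training split.
mean = batch_env_features.mean(dim=0)
std = batch_env_features.std(dim=0, unbiased=False)
scaled = standardize(batch_env_features, mean, std)

print(scaled.mean(dim=0))                 # approximately 0 per column
print(scaled.std(dim=0, unbiased=False))  # approximately 1 per column
```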


## 📝 Environmental Embedding Summary

### What We Accomplished:
- **Feature Engineering**: Processed 3 key environmental factors (Temperature, pH, Protein concentration)
- **Neural Network Layer**: Implemented a single Linear → BatchNorm → ReLU block
- **Dimensionality Expansion**: Transformed 3 environmental features → 16-dimensional embedding
- **Batch Processing**: Handled both single samples (`.eval()` mode) and batches (`.train()` mode)
- **Stability Features**: Added batch normalization, which normalizes activations with batch statistics during training

### Key Benefits:
- ✅ **Feature Learning**: The layer can learn environmental representations once it is trained
- ✅ **Batch Normalization**: Typically improves training stability and convergence speed (not measured in this notebook)  
- ✅ **Non-linear Processing**: ReLU adds a non-linearity on top of the linear projection
- ✅ **Scalable Design**: Easy to add more environmental factors as inputs
- ✅ **Training Ready**: Proper handling of single-sample inference and batch training

**Output**: Ready-to-use 16-dimensional environmental embedding

# Protein + Environment Fusion

In [12]:
class ProteinEnvironmentFusion(nn.Module):
    """
    Fusion module for combining protein sequence embeddings with environmental embeddings.
    Uses concatenation to combine the features.
    """
    def __init__(self, sequence_dim=1280, env_dim=16):
        super(ProteinEnvironmentFusion, self).__init__()
        self.sequence_dim = sequence_dim
        self.env_dim = env_dim
        self.fused_dim = sequence_dim + env_dim
        
    def forward(self, sequence_embedding, env_embedding):
        # sequence_embedding: [batch_size, sequence_dim]
        # env_embedding: [batch_size, env_dim]
        
        # Ensure both embeddings have the same batch size
        assert sequence_embedding.shape[0] == env_embedding.shape[0], \
            f"Batch size mismatch: sequence {sequence_embedding.shape[0]} vs env {env_embedding.shape[0]}"
        
        # Concatenate along feature dimension
        fused_embedding = torch.cat([sequence_embedding, env_embedding], dim=1)
        # Output: [batch_size, sequence_dim + env_dim]
        
        return fused_embedding

# Initialize fusion module
fusion_module = ProteinEnvironmentFusion(sequence_dim=1280, env_dim=16)

# Get the current embeddings (convert sequence embedding back to tensor)
sequence_embedding_tensor = torch.tensor(sequence_embedding, dtype=torch.float32)
env_embedding_tensor = env_embedding_single

print("=== Embedding Fusion ===")
print(f"Sequence embedding shape: {sequence_embedding_tensor.shape}")
print(f"Environmental embedding shape: {env_embedding_tensor.shape}")

# Apply fusion
fused_embedding = fusion_module(sequence_embedding_tensor, env_embedding_tensor)
print(f"Fused embedding shape: {fused_embedding.shape}")
print(f"Expected fused dimension: {fusion_module.fused_dim}")

# Show sample values from each component
print("\n=== Fusion Components ===")
print(f"Sequence portion (first 5 values): {fused_embedding[0][:5]}")
print(f"Environmental portion (last 5 values): {fused_embedding[0][-5:]}")

# Verify the concatenation is correct
print("\n=== Verification ===")
print(f"Original sequence sample: {sequence_embedding_tensor[0][:5]}")
print(f"Original env sample: {env_embedding_tensor[0][-5:]}")
print(f"Fused sequence portion matches: {torch.allclose(fused_embedding[0][:1280], sequence_embedding_tensor[0])}")
print(f"Fused env portion matches: {torch.allclose(fused_embedding[0][1280:], env_embedding_tensor[0])}")

# Store for downstream use
protein_env_fused = fused_embedding.detach()
print(f"\nFused embedding stored for downstream tasks: {protein_env_fused.shape}")

=== Embedding Fusion ===
Sequence embedding shape: torch.Size([1, 1280])
Environmental embedding shape: torch.Size([1, 16])
Fused embedding shape: torch.Size([1, 1296])
Expected fused dimension: 1296

=== Fusion Components ===
Sequence portion (first 5 values): tensor([-6.2806e-03,  4.5681e-02, -1.4987e-02,  3.5999e-05, -1.3745e-02])
Environmental portion (last 5 values): tensor([ 4.4430,  0.0000, 19.4043,  0.0000, 11.8405])

=== Verification ===
Original sequence sample: tensor([-6.2806e-03,  4.5681e-02, -1.4987e-02,  3.5999e-05, -1.3745e-02])
Original env sample: tensor([ 4.4430,  0.0000, 19.4043,  0.0000, 11.8405])
Fused sequence portion matches: True
Fused env portion matches: True

Fused embedding stored for downstream tasks: torch.Size([1, 1296])
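
The printed values show that the two branches differ in magnitude: sequence features are on the order of 0.01 to 0.1, while environmental features reach roughly 20. One possible way to balance them, shown here as a hypothetical variant and not as the notebook's method, is to apply `nn.LayerNorm` to each branch before concatenation. The inputs below are random stand-ins with similar magnitudes.

```python
import torch
import torch.nn as nn

class NormalizedConcatFusion(nn.Module):
    """Hypothetical variant: LayerNorm on each branch, then concatenate."""
    def __init__(self, sequence_dim=1280, env_dim=16):
        super().__init__()
        self.seq_norm = nn.LayerNorm(sequence_dim)
        self.env_norm = nn.LayerNorm(env_dim)

    def forward(self, sequence_embedding, env_embedding):
        return torch.cat(
            [self.seq_norm(sequence_embedding), self.env_norm(env_embedding)],
            dim=1,
        )

torch.manual_seed(0)
seq = 0.05 * torch.randn(1, 1280)   # stand-in: small-magnitude sequence branch
env = 10.0 * torch.rand(1, 16)      # stand-in: large-magnitude env branch
fused = NormalizedConcatFusion()(seq, env)
print(fused.shape)                   # torch.Size([1, 1296])
print(fused[0, :1280].std(), fused[0, 1280:].std())  # both close to 1
```

Whether this helps depends on the downstream head and the data, so it is best treated as something to compare against plain concatenation.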


In [13]:
# Demonstration: Batch fusion processing

# Create mock batch of sequence embeddings (normally these would come from attention pooling)
batch_size = 4
sequence_dim = 1280
mock_sequence_batch = torch.randn(batch_size, sequence_dim)

# Use the batch environmental embeddings we created earlier
# batch_env_embedding was created in the previous environmental embedding cell

print("=== Batch Fusion Demonstration ===")
print(f"Mock sequence batch shape: {mock_sequence_batch.shape}")
print(f"Environmental batch shape: {batch_env_embedding.shape}")

# Apply batch fusion
batch_fused_embedding = fusion_module(mock_sequence_batch, batch_env_embedding)
print(f"Batch fused embedding shape: {batch_fused_embedding.shape}")

# Show how each sample in the batch gets processed
print("\n=== Per-Sample Breakdown ===")
for i in range(batch_size):
    seq_portion = batch_fused_embedding[i][:5]  # First 5 values (from sequence)
    env_portion = batch_fused_embedding[i][-5:]  # Last 5 values (from environment)
    print(f"Sample {i+1} - Seq: {seq_portion.detach().numpy()}")
    print(f"Sample {i+1} - Env: {env_portion.detach().numpy()}")
    print()

# Summary statistics
print("=== Fusion Summary ===")
print(f"Input dimensions: Sequence ({sequence_dim}) + Environment ({fusion_module.env_dim})")
print(f"Output dimension: {fusion_module.fused_dim}")
print(f"Feature concatenation: [seq_features | env_features]")
print(f"Ready for downstream prediction layers!")

=== Batch Fusion Demonstration ===
Mock sequence batch shape: torch.Size([4, 1280])
Environmental batch shape: torch.Size([4, 16])
Batch fused embedding shape: torch.Size([4, 1296])

=== Per-Sample Breakdown ===
Sample 1 - Seq: [ 1.3823773  -0.02419936 -1.7813597  -1.0548543   0.82001823]
Sample 1 - Env: [0.         0.42615467 0.43299773 0.         0.51579916]

Sample 2 - Seq: [-1.0216062   0.11305421 -2.934657   -0.71725214  1.216467  ]
Sample 2 - Env: [1.3577632 0.        0.        1.3295695 0.       ]

Sample 3 - Seq: [-0.26486784  0.08872905  0.88852507  0.43666533  1.8734921 ]
Sample 3 - Env: [0.        1.3483403 1.3459644 0.        1.3108011]

Sample 4 - Seq: [0.24590443 0.86854786 0.01486731 1.3056583  2.0413456 ]
Sample 4 - Env: [0.41502786 0.         0.         0.47161758 0.        ]

=== Fusion Summary ===
Input dimensions: Sequence (1280) + Environment (16)
Output dimension: 1296
Feature concatenation: [seq_features | env_features]
Ready for downstream prediction layers!


## 📝 Protein + Environment Fusion Summary

### What We Accomplished:
- **Multimodal Integration**: Successfully combined protein sequence and environmental embeddings
- **Concatenation Fusion**: Implemented simple but effective feature concatenation strategy
- **Dimension Mapping**: `[1280 sequence] + [16 environmental] = [1296 fused]` representation
- **Batch Processing**: Demonstrated fusion works with both single samples and batches
- **Verification System**: Built-in checks ensure proper concatenation and data integrity

### Architecture Flow:
1. **Protein Sequence**: ESM2 → Attention Pooling → `[batch_size, 1280]`
2. **Environmental**: Raw features → MLP + BatchNorm → `[batch_size, 16]`  
3. **Fusion**: Concatenation → `[batch_size, 1296]`

### Key Benefits:
- ✅ **Information Preservation**: Concatenation passes every feature from both modalities through unchanged
- ✅ **Learnable Integration**: Downstream layers can learn how to combine the features
- ✅ **Batch Compatibility**: The same module handles single samples and batches
- ✅ **Gradient Flow**: Concatenation is differentiable, so both branches can be optimized jointly as long as their outputs are not detached
- ✅ **Extensible Design**: Easy to add more input modalities

**Output**: Complete 1296-dimensional multimodal embedding ready for aggregation prediction

### Next Steps:
- Add classification/regression heads for aggregation prediction
- Implement training loops with aggregation datasets  
- Add regularization and optimization strategies
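
As a rough illustration of what a training loop over fused embeddings could look like, the sketch below runs a few optimization steps on synthetic data. Everything here is an assumption made for illustration: the random inputs, the random labels, the small two-layer `head`, and the optimizer settings. It shows only the mechanics (logits into `nn.CrossEntropyLoss`, backward pass, optimizer step), not a recommended configuration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic stand-ins (assumption): random fused embeddings and random labels
fused = torch.randn(8, 1296)
labels = torch.randint(0, 2, (8,))

head = nn.Sequential(nn.Linear(1296, 64), nn.ReLU(), nn.Linear(64, 2))
criterion = nn.CrossEntropyLoss()  # expects raw logits, not softmax output
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3, weight_decay=1e-2)

losses = []
head.train()
for step in range(20):
    optimizer.zero_grad()
    loss = criterion(head(fused), labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

The loss falls here only because the head memorizes eight random samples; a real run needs labeled aggregation data and a held-out validation split.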

# Prediction

In [14]:
class AggregationPredictor(nn.Module):
    """
    Multi-layer perceptron for protein aggregation prediction.
    Takes fused embeddings and predicts aggregation outcomes through 4 blocks.
    """
    def __init__(self, input_dim=1296, dropout_rate=0.3):
        super(AggregationPredictor, self).__init__()
        
        # Block 1: 1296 -> 256
        self.block1 = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        
        # Block 2: 256 -> 128
        self.block2 = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        
        # Block 3: 128 -> 64
        self.block3 = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        
        # Block 4: 64 -> 2 (binary classification: aggregates/doesn't aggregate)
        self.block4 = nn.Sequential(
            nn.Linear(64, 2),
            # Note: No activation here. nn.CrossEntropyLoss expects raw logits during training; apply softmax only at inference
        )
        
    def forward(self, fused_embedding):
        # fused_embedding: [batch_size, 1296]
        
        x = self.block1(fused_embedding)  # [batch_size, 256]
        x = self.block2(x)               # [batch_size, 128]
        x = self.block3(x)               # [batch_size, 64]
        logits = self.block4(x)          # [batch_size, 2]
        
        return logits

# Initialize the prediction model
predictor = AggregationPredictor(input_dim=1296, dropout_rate=0.3)

# Display model architecture
print("=== Aggregation Predictor Architecture ===")
print(predictor)

# Count parameters
total_params = sum(p.numel() for p in predictor.parameters())
trainable_params = sum(p.numel() for p in predictor.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Test with our fused embedding
print(f"\n=== Prediction Test ===")
print(f"Input shape: {protein_env_fused.shape}")

# Set model to evaluation mode for inference
predictor.eval()
with torch.no_grad():
    prediction_logits = predictor(protein_env_fused)
    
print(f"Output logits shape: {prediction_logits.shape}")
print(f"Output logits: {prediction_logits}")

# Apply softmax to get probabilities
prediction_probs = F.softmax(prediction_logits, dim=1)
print(f"Prediction probabilities: {prediction_probs}")
print(f"Predicted class: {'Aggregates' if prediction_probs[0][1] > 0.5 else 'No Aggregation'}")
print(f"Confidence: {prediction_probs[0].max().item():.4f}")

# Show layer-by-layer output shapes
print(f"\n=== Layer-by-layer Processing ===")
predictor.eval()  # Keep dropout disabled; the output shapes are identical in either mode
x = protein_env_fused
print(f"Input: {x.shape}")

x = predictor.block1(x)
print(f"After Block 1: {x.shape}")

x = predictor.block2(x)
print(f"After Block 2: {x.shape}")

x = predictor.block3(x)
print(f"After Block 3: {x.shape}")

x = predictor.block4(x)
print(f"After Block 4: {x.shape}")

=== Aggregation Predictor Architecture ===
AggregationPredictor(
  (block1): Sequential(
    (0): Linear(in_features=1296, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
  )
  (block2): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
  )
  (block3): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
  )
  (block4): Sequential(
    (0): Linear(in_features=64, out_features=2, bias=True)
  )
)

Total parameters: 373,314
Trainable parameters: 373,314

=== Prediction Test ===
Input shape: torch.Size([1, 1296])
Output logits shape: torch.Size([1, 2])
Output logits: tensor([[-0.0103,  0.0952]])
Prediction probabilities: tensor([[0.4737, 0.5263]])
Predicted class: Aggregates
Confidence: 0.5263

=== Layer-by-layer Processing ===
Input: torch.Size([1, 1296])
After Block 1: torch.Size([1, 256])
After Bloc
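
Intermediate shapes can also be recorded without calling each block by hand, using forward hooks. The sketch below uses a compact `nn.Sequential` stand-in with the same block layout as `AggregationPredictor`; `make_hook` and the `shapes` dictionary are illustrative names.

```python
import torch
import torch.nn as nn

# Compact stand-in with the same block layout as AggregationPredictor
predictor = nn.Sequential(
    nn.Sequential(nn.Linear(1296, 256), nn.ReLU(), nn.Dropout(0.3)),
    nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3)),
    nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.3)),
    nn.Sequential(nn.Linear(64, 2)),
)

shapes = {}

def make_hook(name):
    def hook(module, inputs, output):
        shapes[name] = tuple(output.shape)
    return hook

handles = [
    block.register_forward_hook(make_hook(f"block{i + 1}"))
    for i, block in enumerate(predictor)
]

predictor.eval()  # dropout stays off while inspecting
with torch.no_grad():
    predictor(torch.randn(1, 1296))

for handle in handles:
    handle.remove()  # detach hooks once the shapes are recorded

print(shapes)
```

The recorded shapes follow the 1296 → 256 → 128 → 64 → 2 progression, and a single forward pass leaves the model in evaluation mode.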

In [15]:
# Batch processing demonstration
print("=== Batch Processing Demonstration ===")

# Use the batch fused embeddings from earlier
print(f"Batch fused embedding shape: {batch_fused_embedding.shape}")

# Apply predictor to batch
predictor.eval()
with torch.no_grad():
    batch_logits = predictor(batch_fused_embedding)
    batch_probs = F.softmax(batch_logits, dim=1)

print(f"Batch predictions shape: {batch_logits.shape}")
print(f"Batch probabilities shape: {batch_probs.shape}")

# Show predictions for each sample in batch
print(f"\n=== Individual Sample Predictions ===")
for i in range(batch_fused_embedding.shape[0]):
    logits = batch_logits[i]
    probs = batch_probs[i]
    predicted_class = "Aggregates" if probs[1] > 0.5 else "No Aggregation"
    confidence = probs.max().item()
    
    print(f"Sample {i+1}:")
    print(f"  Logits: [{logits[0]:.4f}, {logits[1]:.4f}]")
    print(f"  Probabilities: [No Agg: {probs[0]:.4f}, Agg: {probs[1]:.4f}]")
    print(f"  Prediction: {predicted_class} (confidence: {confidence:.4f})")
    print()

# Utility functions for prediction analysis
def analyze_prediction(logits):
    """
    Analyze prediction logits and return interpretable results
    """
    probs = F.softmax(logits, dim=1)
    predictions = torch.argmax(probs, dim=1)
    confidence = torch.max(probs, dim=1)[0]
    
    results = []
    for i in range(logits.shape[0]):
        result = {
            'prediction': 'Aggregates' if predictions[i] == 1 else 'No Aggregation',
            'confidence': confidence[i].item(),
            'prob_no_agg': probs[i][0].item(),
            'prob_agg': probs[i][1].item(),
            'logit_no_agg': logits[i][0].item(),
            'logit_agg': logits[i][1].item()
        }
        results.append(result)
    
    return results

# Test the analysis function
print("=== Prediction Analysis Results ===")
analysis_results = analyze_prediction(batch_logits)
for i, result in enumerate(analysis_results):
    print(f"Sample {i+1}: {result['prediction']} ({result['confidence']:.3f} confidence)")

# Store predictor for later use
print(f"\n=== Model Summary ===")
print(f"Predictor ready for training with {trainable_params:,} trainable parameters")
print(f"Input: Fused embeddings [batch_size, 1296]")
print(f"Output: Aggregation logits [batch_size, 2]")
print(f"Architecture: 1296 → 256 → 128 → 64 → 2")

=== Batch Processing Demonstration ===
Batch fused embedding shape: torch.Size([4, 1296])
Batch predictions shape: torch.Size([4, 2])
Batch probabilities shape: torch.Size([4, 2])

=== Individual Sample Predictions ===
Sample 1:
  Logits: [0.0808, 0.0808]
  Probabilities: [No Agg: 0.5000, Agg: 0.5000]
  Prediction: No Aggregation (confidence: 0.5000)

Sample 2:
  Logits: [0.0005, 0.1293]
  Probabilities: [No Agg: 0.4678, Agg: 0.5322]
  Prediction: Aggregates (confidence: 0.5322)

Sample 3:
  Logits: [0.0067, 0.1007]
  Probabilities: [No Agg: 0.4765, Agg: 0.5235]
  Prediction: Aggregates (confidence: 0.5235)

Sample 4:
  Logits: [0.0105, 0.1183]
  Probabilities: [No Agg: 0.4731, Agg: 0.5269]
  Prediction: Aggregates (confidence: 0.5269)

=== Prediction Analysis Results ===
Sample 1: No Aggregation (0.500 confidence)
Sample 2: Aggregates (0.532 confidence)
Sample 3: Aggregates (0.523 confidence)
Sample 4: Aggregates (0.527 confidence)

=== Model Summary ===
Predictor ready for training w
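
With two output classes, the softmax probability of the "Aggregates" class depends only on the difference between the two logits: p(agg) = sigmoid(z_agg - z_no_agg). The sketch below checks this identity on the single-sample logits printed in the earlier prediction test, and it explains why equal logits (as for Sample 1) give exactly 0.5.

```python
import torch
import torch.nn.functional as F

# Single-sample logits printed in the prediction test: [no_agg, agg]
logits = torch.tensor([[-0.0103, 0.0952]])

probs = F.softmax(logits, dim=1)
p_agg_from_sigmoid = torch.sigmoid(logits[:, 1] - logits[:, 0])

print(probs[0, 1].item(), p_agg_from_sigmoid.item())
```

Both values agree with the 0.5263 reported earlier. Because the model is untrained, logit differences this small reflect random initialization rather than evidence about aggregation.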

## 📝 Prediction Summary

### What We Accomplished:
- **4-Block MLP Architecture**: Implemented deep neural network with progressive dimensionality reduction
- **Regularization**: Added dropout (30%) to the first three blocks to reduce overfitting; the output block has none
- **Binary Classification**: Final layer outputs logits for aggregation vs. no-aggregation prediction
- **Batch Processing**: Demonstrated inference on multiple samples simultaneously
- **Analysis Tools**: Created utilities for interpreting prediction results and confidence scores

### Architecture Details:
```
Input:    [batch_size, 1296] (fused embeddings)
Block 1:  Linear(1296→256) → ReLU → Dropout(0.3)
Block 2:  Linear(256→128)  → ReLU → Dropout(0.3)
Block 3:  Linear(128→64)   → ReLU → Dropout(0.3)  
Block 4:  Linear(64→2)     (logits output)
Output:   [batch_size, 2]   (aggregation predictions)
```
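
The parameter count reported by the notebook can be reproduced by hand from the layer sizes above: each `Linear(in, out)` contributes `in * out` weights plus `out` biases, and ReLU and Dropout add none. A short check:

```python
layer_dims = [1296, 256, 128, 64, 2]

total = 0
for fan_in, fan_out in zip(layer_dims[:-1], layer_dims[1:]):
    params = fan_in * fan_out + fan_out  # weight matrix plus bias vector
    print(f"Linear({fan_in}->{fan_out}): {params:,}")
    total += params

print(f"Total: {total:,}")  # Total: 373,314
```

The first layer accounts for 332,032 of the parameters, so the input width dominates the size of the head.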

### Key Benefits:
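- ✅ **Deep Feature Learning**: Three hidden layers give the head capacity to model non-linear patterns
- ✅ **Regularization**: Dropout helps reduce overfitting during training
- ✅ **Efficient Architecture**: Progressive reduction (1296 → 256 → 128 → 64 → 2) keeps the head small
- ✅ **Probabilistic Output**: Softmax converts logits into class probabilities; these are not calibrated confidences until the model is trained and validated
- ✅ **Training Ready**: 373,314 trainable parameters, randomly initialized and not yet trained

**Output**: Binary aggregation predictions with softmax probabilities (the values near 0.5 shown above come from an untrained model and carry no predictive meaning)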

### Model Pipeline Complete! 🎯
```
Protein Sequence → ESM2 → Attention Pooling → [1280]
                                                  ↓
Environmental Data → MLP + BatchNorm → [16]      ↓
                                                  ↓
                    Fusion (Concat) → [1296] → MLP → [2] → Aggregation Prediction
```

**Ready for training on aggregation datasets!**