In [3]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

In [4]:
class SimpleProteinClassifier(nn.Module):
    vocab_size: int  # Number of unique amino acids + padding token
    embed_dim: int   # Dimension for embeddings
    num_heads: int   # Number of attention heads
    ff_dim: int      # Dimension of the feed-forward layer

    @nn.compact
    def __call__(self, x, training: bool):
        # 1. Embed amino acid sequence
        # Input shape: (batch_size, seq_len)
        # Output shape: (batch_size, seq_len, embed_dim)
        x = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)(x)

        # BUG 1: Missing component for sequence order awareness

        # 2. Simplified Transformer Block (Self-Attention + MLP)
        # 2a. Multi-Head Self-Attention
        attn_output = nn.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.embed_dim,
            deterministic=not training, # Use dropout during training
        )(x) # Input shape: (batch_size, seq_len, embed_dim)
             # Output shape: (batch_size, seq_len, embed_dim)

        x = x + attn_output # Residual connection
        x = nn.LayerNorm()(x)

        # 2b. Feed-Forward Network
        ff_output = nn.Sequential([
            nn.Dense(features=self.ff_dim),
            nn.relu,
            nn.Dense(features=self.embed_dim) # Project back to embed_dim
        ])(x)

        x = x + ff_output # Residual connection
        x = nn.LayerNorm()(x)

        # 3. Pooling & Classification Head
        # Average pool across the sequence dimension
        x = jnp.mean(x, axis=1) # Shape: (batch_size, embed_dim)

        # Final dense layer for binary classification output (logits)
        x = nn.Dense(features=1)(x) # Output shape: (batch_size, 1)
        # Note: No sigmoid activation here, expecting loss function to handle logits

        return x # Return logits

In [11]:
# BUG 2: Inappropriate loss function for the task/output
def calculate_loss(logits, labels):
    # Expects one-hot encoded labels and multi-class logits
    return optax.sigmoid_binary_cross_entropy(logits=logits.squeeze(), labels=labels)
    # For binary tasks with single logit output, sigmoid_binary_cross_entropy is needed
    # Also, the label format might be mismatched (e.g., expects float, gets int)

In [16]:
def train_step(params, opt_state, batch, key, model_apply_fn, optimizer):
    sequences, labels = batch

    def loss_fn(params):
        logits = model_apply_fn({'params': params}, sequences, training=True, rngs={'dropout': key})
        loss = jnp.mean(calculate_loss(logits, labels.astype(jnp.float32))) # Ensure labels are float
        return loss

    loss_val, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss_val

In [13]:
def generate_dummy_batch(key, batch_size, seq_len, vocab_size):
    key, subkey1, subkey2 = jax.random.split(key, 3)
    # Generate random sequences
    sequences = jax.random.randint(subkey1, (batch_size, seq_len), 0, vocab_size)

    # BUG 3: Labels are completely random, no correlation with sequences
    labels = jax.random.randint(subkey2, (batch_size, 1), 0, 2) # Random 0 or 1

    return (sequences, labels), key

In [4]:
VOCAB_SIZE = 22 # 20 amino acids + padding + mask token (example)
EMBED_DIM = 64
NUM_HEADS = 4
FF_DIM = 128
SEQ_LEN = 50
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
NUM_STEPS = 10

In [9]:
key = jax.random.PRNGKey(0)
model_key, params_key, data_key, dropout_key, loop_key = jax.random.split(key, 5)

In [17]:
model = SimpleProteinClassifier(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM
)

In [18]:
dummy_input = jnp.ones((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)
params = model.init(params_key, dummy_input, training=False)['params']

In [24]:
optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(params)

In [None]:
print("Starting training...")
# Training loop
for step in range(NUM_STEPS):
    loop_key, data_key, dropout_key = jax.random.split(loop_key, 3)
    batch, data_key = generate_dummy_batch(data_key, BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)

    params, opt_state, loss = train_step(
        params,
        opt_state,
        batch,
        dropout_key, # Pass dropout key
        model.apply, # Pass model's apply function
        optimizer
    )

    if step % 100 == 0:
        print(f"Step: {step}, Loss: {loss:.4f}")

print("Training finished.")

Starting training...
Step: 0, Loss: 0.7017


In [None]:
def sinusoidal_positional_encoding(max_len, embed_dim):
    """Generates sinusoidal positional encoding matrix."""
    position = jnp.arange(max_len)[:, jnp.newaxis] # (max_len, 1)
    div_term = jnp.exp(jnp.arange(0, embed_dim, 2) * -(jnp.log(10000.0) / embed_dim)) # (embed_dim/2,)
    pe = jnp.zeros((max_len, embed_dim))

    pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
    pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))

    # Add batch dimension: (1, max_len, embed_dim)
    pe = pe[jnp.newaxis, :, :]
    return pe # shape (1, max_len, embed_dim)

In [None]:
sinusoidal_positional_encoding(SEQ_LEN, EMBED_DIM)