In [1]:
import jax
import jax.numpy as jnp
import optax # Commonly used for loss functions like cross-entropy

In [13]:
def calculate_loss(logits, targets, padding_value=0):
  """
  Calculates the cross-entropy loss for sequence prediction.

  Args:
    logits: Predicted logits from the model. Shape: (batch_size, seq_len, num_classes)
    targets: Ground truth target labels (integers). Shape: (batch_size, seq_len)
    padding_value: The integer value used for padding in the targets.

  Returns:
    The mean loss calculated over the batch.
  """

  # Calculate cross-entropy loss for each position
  # optax.softmax_cross_entropy_with_integer_labels expects logits of shape (..., num_classes)
  # and labels of shape (...)
  # It returns loss of shape (...)
  mask = targets == -1
  
  token_loss = optax.softmax_cross_entropy_with_integer_labels(logits[~mask], targets[~mask])
  # Shape: (batch_size, seq_len)

  # Calculate the mean loss across all tokens in the batch
  mean_loss = jnp.mean(token_loss)

  return mean_loss

In [14]:
key = jax.random.PRNGKey(0)
batch_size = 4
seq_len = 50
num_classes = 8 # e.g., 8 secondary structure types

# Dummy model output (logits)
dummy_logits = jax.random.normal(key, (batch_size, seq_len, num_classes))

# Dummy targets with padding (using 0 as padding token)
# Let's assume sequence lengths are [40, 30, 50, 20]
dummy_targets = jax.random.randint(key, (batch_size, seq_len), 1, num_classes) # Values from 1 to num_classes
dummy_targets = dummy_targets.at[0, 40:].set(-1) # Pad first sequence
dummy_targets = dummy_targets.at[1, 30:].set(-1) # Pad second sequence
# Third sequence is full length (50)
dummy_targets = dummy_targets.at[3, 20:].set(-1) # Pad fourth sequence

# Calculate loss
loss = calculate_loss(dummy_logits, dummy_targets, padding_value=0)

print(f"Calculated Loss: {loss}")

Calculated Loss: 2.4950876235961914


In [15]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Sequence, Optional

In [16]:
def sinusoidal_init(max_len=2048, embed_dim=512):
    position = np.arange(max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, embed_dim, 2) * -(np.log(10000.0) / embed_dim))
    pe = np.zeros((max_len, embed_dim))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return jnp.array(pe)[np.newaxis, :, :] # Add batch dim: (1, max_len, embed_dim)

In [17]:
sinusoidal_init(10, 4)

Array([[[ 0.        ,  1.        ,  0.        ,  1.        ],
        [ 0.84147096,  0.5403023 ,  0.00999983,  0.99995   ],
        [ 0.9092974 , -0.41614684,  0.01999867,  0.9998    ],
        [ 0.14112   , -0.9899925 ,  0.0299955 ,  0.99955004],
        [-0.7568025 , -0.6536436 ,  0.03998933,  0.9992001 ],
        [-0.9589243 ,  0.2836622 ,  0.04997917,  0.99875027],
        [-0.2794155 ,  0.96017027,  0.059964  ,  0.99820054],
        [ 0.6569866 ,  0.75390226,  0.06994285,  0.997551  ],
        [ 0.98935825, -0.14550003,  0.0799147 ,  0.99680173],
        [ 0.4121185 , -0.91113025,  0.08987855,  0.9959527 ]]],      dtype=float32)

In [18]:
class PositionalEncoding(nn.Module):
    embed_dim: int
    max_len: int = 2048 # Max sequence length anticipate

    @nn.compact
    def __call__(self, seq_len: int):
        # Returns PE for the actual seq_len: (1, seq_len, embed_dim)
        full_pe = self.param('pe', lambda: sinusoidal_init(self.max_len, self.embed_dim))
        # Ensure requires_grad is False if using self.param for fixed PE
        # Or just compute directly if truly fixed:
        # full_pe = sinusoidal_init(self.max_len, self.embed_dim)
        return jax.lax.dynamic_slice_in_dim(full_pe, 0, seq_len, axis=1)

In [19]:
class MLPBlock(nn.Module):
  hidden_dim: int
  output_dim: int
  dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, x, *, deterministic: bool):
    x = nn.Dense(features=self.hidden_dim)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    x = nn.Dense(features=self.output_dim)(x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    return x

In [20]:
class EncoderBlock(nn.Module):
  embed_dim: int
  num_heads: int
  mlp_dim: int
  dropout_rate: float = 0.1
  max_len: int = 2048 # Needed for PE inside

  @nn.compact
  def __call__(self, x, mask: Optional[jnp.ndarray] = None, *, deterministic: bool):
    seq_len = x.shape[1]
    pos_encoding_layer = PositionalEncoding(embed_dim=self.embed_dim, max_len=self.max_len)
    # Apply positional encoding *inside* the block
    x = x + pos_encoding_layer(seq_len) # Add PE to input 'x'

    attn_output = nn.SelfAttention(
        num_heads=self.num_heads,
        qkv_features=self.embed_dim,
        dropout_rate=self.dropout_rate,
        deterministic=deterministic
    )(inputs_q=x, mask=mask)

    x_attn = x + attn_output
    x_attn_norm = nn.LayerNorm()(x_attn)

    mlp_output = MLPBlock(
        hidden_dim=self.mlp_dim,
        output_dim=self.embed_dim,
        dropout_rate=self.dropout_rate
    )(x_attn_norm, deterministic=deterministic)

    x_mlp = x_attn_norm + mlp_output
    output = nn.LayerNorm()(x_mlp)

    return output

In [21]:
class TransformerEncoder(nn.Module):
    num_layers: int
    vocab_size: int # e.g., 21 for amino acids + padding
    embed_dim: int
    num_heads: int
    mlp_dim: int
    max_len: int = 2048
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x_tokens, padding_mask: Optional[jnp.ndarray] = None, *, train: bool):
        # 1. Input Embedding
        embed_layer = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)
        x = embed_layer(x_tokens) # Shape: (batch, seq_len, embed_dim)

        # Apply embedding dropout
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        # 2. Encoder Stack
        for _ in range(self.num_layers):
             # Note: This EncoderBlock internally adds PE
             block = EncoderBlock(embed_dim=self.embed_dim,
                                  num_heads=self.num_heads,
                                  mlp_dim=self.mlp_dim,
                                  dropout_rate=self.dropout_rate,
                                  max_len=self.max_len)
             x = block(x, mask=padding_mask, deterministic=not train) # `mask` needs correct format for SelfAttention

        # 3. Final LayerNorm (Optional, sometimes applied)
        # x = nn.LayerNorm()(x)

        return x

In [22]:
key = jax.random.PRNGKey(2)
batch_size = 2
seq_len = 50
vocab_size = 21
embed_dim = 64
num_heads = 2
mlp_dim = 128
num_layers = 3

model = TransformerEncoder(num_layers=num_layers, vocab_size=vocab_size, embed_dim=embed_dim,
                           num_heads=num_heads, mlp_dim=mlp_dim)

# Dummy input tokens (integers)
dummy_tokens = jax.random.randint(key, (batch_size, seq_len), 0, vocab_size)

# Initialize and apply (training mode)
variables = model.init(key, dummy_tokens, train=False) # Use train=False for init determinism
output_embeddings = model.apply(variables, dummy_tokens, train=True) # Use train=True for apply

print(f"Input tokens shape: {dummy_tokens.shape}")
print(f"Output embeddings shape: {output_embeddings.shape}")

# The permutation invariance issue would be tested by actually shuffling
# dummy_tokens along the seq_len dimension and observing the output.

TypeError: PositionalEncoding.__call__.<locals>.<lambda>() takes 0 positional arguments but 1 was given

In [23]:
import jax
import jax.numpy as jnp
import flax.linen as nn
# Using jax.ops for segment_sum is common, but let's use the newer jax.ops.segment_sum
# which is equivalent to jax.ops.index_add with x=data at indices=indices
# For clarity, we'll use the explicit jax.ops.segment_sum style if needed,
# or illustrate with index operations directly. Let's use segment_sum.
# Note: As of newer JAX versions, prefer scatter operations.
# Let's use the scatter approach for modernity.
# from jax.ops import segment_sum # Older style
from jax import ops # For scatter add

In [24]:
class SimpleGNNLayer(nn.Module):
    out_dim: int
    activation: callable = nn.relu

    @nn.compact
    def __call__(self, node_features: jnp.ndarray, senders: jnp.ndarray, receivers: jnp.ndarray):
        """
        Applies a simple GNN layer.

        Args:
          node_features: Node features. Shape: (num_nodes, in_dim)
          senders: Sender indices for each edge. Shape: (num_edges,)
          receivers: Receiver indices for each edge. Shape: (num_edges,)

        Returns:
          Updated node features. Shape: (num_nodes, out_dim)
        """
        num_nodes = node_features.shape[0]
        in_dim = node_features.shape[-1]

        # 1. Transform node features (linear layer for messages)
        message_features = nn.Dense(features=self.out_dim)(node_features)

        # 2. Gather features from sending nodes for each edge
        sender_features = message_features[senders] # Shape: (num_edges, out_dim)

        # 3. Aggregate messages for each receiving node using sum
        # We sum all incoming messages for each node 'i'.
        # Use jax.ops.scatter_add: adds values into updates at indices.
        # Initialize aggregated messages to zeros
        aggregated_messages = jnp.zeros((num_nodes, self.out_dim), dtype=sender_features.dtype)
        # Add features from senders to their corresponding receivers
        # For each edge (sender -> receiver), add sender_feature to aggregated_messages[receiver]
        aggregated_messages = aggregated_messages.at[receivers].add(sender_features)
        # Alternative using older segment_sum:
        # aggregated_messages = ops.segment_sum(sender_features, receivers, num_segments=num_nodes)


        # 4. Update node representation using aggregated messages
        # (Note: A common update also includes the node's previous features, e.g., via sum or concat,
        # but this simplified version uses only aggregated messages for the update basis)
        new_node_features = self.activation(aggregated_messages)

        return new_node_features

In [25]:
# --- Example Graph Data & Usage ---
key = jax.random.PRNGKey(4)
num_nodes = 10
num_edges = 30
in_dim = 16
out_dim = 32

# Dummy node features
node_features = jax.random.normal(key, (num_nodes, in_dim))

# Dummy graph structure (edge list)
# Ensure some nodes have high degree, others low degree
senders = jax.random.randint(key, (num_edges,), 0, num_nodes)
receivers = jax.random.randint(key, (num_edges,), 0, num_nodes)

# Make node 0 a high-degree node (receives many messages)
high_degree_node_idx = 0
num_extra_edges = 20
extra_senders = jax.random.randint(key, (num_extra_edges,), 1, num_nodes) # Avoid self-loops for simplicity
extra_receivers = jnp.full((num_extra_edges,), high_degree_node_idx)

senders = jnp.concatenate([senders, extra_senders])
receivers = jnp.concatenate([receivers, extra_receivers])
num_edges = senders.shape[0]

print(f"Total nodes: {num_nodes}")
print(f"Total edges: {num_edges}")
# Calculate in-degrees to show variance
in_degrees = jnp.zeros((num_nodes,)).at[receivers].add(1)
print(f"Node in-degrees: {in_degrees}")
print(f"Max in-degree: {jnp.max(in_degrees)}")

# Initialize and apply the GNN layer
gnn_layer = SimpleGNNLayer(out_dim=out_dim)
variables = gnn_layer.init(key, node_features, senders, receivers)

# Forward pass
# Wrap in a simple function to simulate potential gradient calculation context
@jax.jit
def apply_layer(params, features, senders, receivers):
    return gnn_layer.apply({'params': params}, features, senders, receivers)

output_features = apply_layer(variables['params'], node_features, senders, receivers)

print(f"Input features shape: {node_features.shape}")
print(f"Output features shape: {output_features.shape}")
print(f"Sample output features (node 0): {output_features[0, :5]}...")
print(f"Output feature norm (node 0): {jnp.linalg.norm(output_features[0])}")
print(f"Output feature norm (node 1): {jnp.linalg.norm(output_features[1])}") # Likely lower degree

# In a real training loop, large norms here could lead to NaNs after loss/gradients.

Total nodes: 10
Total edges: 50
Node in-degrees: [26.  0.  3.  4.  3.  3.  2.  2.  2.  5.]
Max in-degree: 26.0
Input features shape: (10, 16)
Output features shape: (10, 32)
Sample output features (node 0): [ 2.2878036 19.458845   3.3662817  0.        18.386833 ]...
Output feature norm (node 0): 45.07743453979492
Output feature norm (node 1): 0.0


In [4]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence

class SimpleRNNClassifier(nn.Module):
    """A simple RNN for sequence classification."""
    num_classes: int
    rnn_hidden_size: int
    embedding_dim: int
    vocab_size: int

    @nn.compact
    def __call__(self, inputs: jnp.ndarray, train: bool = True):
        """
        Args:
            inputs: Batch of input token sequences, shape (batch_size, seq_length).
            train: Boolean indicating if the model is in training mode.

        Returns:
            Output logits for classification, shape (batch_size, num_classes).
        """
        # 1. Embed the input tokens
        embed = nn.Embed(num_embeddings=self.vocab_size, features=self.embedding_dim)
        embedded_inputs = embed(inputs)
        # embedded_inputs shape: (batch_size, seq_length, embedding_dim)

        # 2. Process sequence with RNN
        # Initialize RNN cell
        rnn_cell = nn.LSTMCell(features=self.rnn_hidden_size, name="lstm_cell")

        # Initialize hidden state
        batch_size = inputs.shape[0]
        initial_carry = rnn_cell.initialize_carry(jax.random.PRNGKey(0), (batch_size,))

        # Scan RNN cell over sequence
        # Use nn.scan for efficient processing over the sequence length axis
        scan_rnn = nn.scan(
            rnn_cell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1,  # Scan over seq_length
            out_axes=1
        )
        final_carry, outputs = scan_rnn(initial_carry, embedded_inputs)
        # outputs shape: (batch_size, seq_length, rnn_hidden_size)
        # final_carry is a tuple (hidden_state, cell_state)
        # final_hidden_state shape: (batch_size, rnn_hidden_size)

        final_hidden_state = final_carry[0] # Use the final hidden state for classification

        # 3. Classification Head
        # Dense layer to map final hidden state to number of classes
        output_logits = nn.Dense(features=self.num_classes, name="output_dense")(final_hidden_state)
        # output_logits shape: (batch_size, num_classes)

        # Apply activation function
        output_probs = nn.relu(output_logits)

        return output_probs # Should represent class scores/probabilities

# --- Example Usage ---
# Assume some dummy data
batch_size = 4
seq_length = 10
vocab_size = 100
embedding_dim = 32
rnn_hidden_size = 64
num_classes = 5

key = jax.random.PRNGKey(1)
dummy_inputs = jax.random.randint(key, (batch_size, seq_length), 0, vocab_size)

model = SimpleRNNClassifier(
    num_classes=num_classes,
    rnn_hidden_size=rnn_hidden_size,
    embedding_dim=embedding_dim,
    vocab_size=vocab_size
)

# Initialize model parameters
variables = model.init(key, dummy_inputs, train=False)

# Forward pass
output = model.apply(variables, dummy_inputs, train=False)

print("Input shape:", dummy_inputs.shape)
print("Output shape:", output.shape)
# Expected output shape: (batch_size, num_classes)
# Output values should be suitable for cross-entropy loss

# --- Training Loss (Conceptual) ---
# Assume 'dummy_labels' are integer class indices, shape (batch_size,)
# dummy_labels = jnp.array([1, 0, 4, 2])

# def cross_entropy_loss(outputs, labels):
#     one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
#     # Standard cross-entropy often expects logits or log-probabilities
#     # For instance, optax.softmax_cross_entropy expects logits
#     # loss = -jnp.sum(one_hot_labels * jnp.log(outputs + 1e-7), axis=-1) # Manual CE assuming probabilities
#     # Or using a library function (e.g., from optax)
#     # import optax
#     # loss = optax.softmax_cross_entropy(logits=???, labels=one_hot_labels)
#     return jnp.mean(loss)

# # Problem: Why might training using standard cross-entropy loss on 'output' fail?

TransformTargetError: Linen transformations must be applied to Modules classes or functions taking a Module instance as the first argument. The provided target is not a Module class or callable: LSTMCell(
    # attributes
    features = 64
    gate_fn = sigmoid
    activation_fn = tanh
    kernel_init = init
    recurrent_kernel_init = init
    bias_init = zeros
    dtype = None
    param_dtype = float32
    carry_init = zeros
) (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TransformTargetError)