In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple

# -----------------------------
# Utilities
# -----------------------------
def glorot_init(key, shape, dtype=jnp.float32):
    fan_in, fan_out = shape[0], shape[1]
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, shape, dtype, -limit, limit)

# -----------------------------
# Simulated data
# -----------------------------
def simulate_data(
    num_individuals: int,
    num_markers: int,
    key: jax.random.PRNGKey
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    key, subkey1, subkey2 = jax.random.split(key, 3)
    genotypes = jax.random.randint(
        subkey1,
        shape=(num_individuals, num_markers * 2),
        minval=0,
        maxval=2
    ).astype(jnp.float32)
    phenotypes = jax.random.normal(subkey2, shape=(num_individuals,)).astype(jnp.float32)
    return genotypes, phenotypes

# -----------------------------
# Encoder (creates params outside scan)
# -----------------------------
class Encoder(nn.Module):
    d_model: int
    num_heads: int
    num_layers: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, train: bool):
        # initial projection to model dim
        x = nn.Dense(self.d_model)(x)
        for _ in range(self.num_layers):
            x_norm = nn.LayerNorm()(x)
            # Use SelfAttention which creates params here (but this call is outside scan)
            x_attn = nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.d_model)(x_norm)
            x = x + nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x_attn)

            x_norm = nn.LayerNorm()(x)
            x_ffn = nn.Dense(self.d_model * 4)(x_norm)
            x_ffn = nn.relu(x_ffn)
            x_ffn = nn.Dense(self.d_model)(x_ffn)
            x = x + nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x_ffn)

        return nn.LayerNorm()(x)

# -----------------------------
# Pointer Network (no param creation in scan)
# -----------------------------
class PointerNetwork(nn.Module):
    d_model: int
    num_heads: int
    num_encoder_layers: int

    def setup(self):
        # Encoder is a module; calling encoder will create params during init/apply outside scan.
        self.encoder = Encoder(d_model=self.d_model, num_heads=self.num_heads, num_layers=self.num_encoder_layers)

        # Create projection matrices as explicit params in setup so they are available before scan.
        # Shapes: Wk: [d_model, d_model], Wq: [d_model, d_model]
        # We'll initialize them here with zeros placeholders; actual initial values are provided via self.param()
        # Use lambda initializers that accept rng.
        self.Wk = self.param("Wk", lambda k, s=(self.d_model, self.d_model): glorot_init(k, s))
        self.Wq = self.param("Wq", lambda k, s=(self.d_model, self.d_model): glorot_init(k, s))

        # Optionally a bias for query or keys (not necessary but kept for parity)
        self.bk = self.param("bk", lambda k, s=(self.d_model,): jnp.zeros(s))
        self.bq = self.param("bq", lambda k, s=(self.d_model,): jnp.zeros(s))

    def __call__(self, genotypes: jnp.ndarray, phenotypes: jnp.ndarray, train: bool):
        # Combine genotype + phenotype
        phenotypes_reshaped = jnp.expand_dims(phenotypes, axis=-1)
        inputs = jnp.concatenate([genotypes, phenotypes_reshaped], axis=-1)  # [N, feat]

        # Run encoder (this creates/uses encoder params). This happens outside scan.
        encoder_outputs = self.encoder(inputs, train)  # shape [N, d_model]

        # Precompute normalized keys & key projections outside scan
        # We'll use a simple LayerNorm over encoder outputs (the layernorm in encoder already applied; but we can reuse)
        # Use the stored Wk to compute key projections once:
        # encoder_outputs: [N, d_model], Wk: [d_model, d_model]
        key_proj = jnp.dot(encoder_outputs, self.Wk) + self.bk  # [N, d_model]

        # We'll use Wq for projecting the last embedding inside scan; Wq is a DeviceArray (static parameter).
        # No Flax param creation happens inside scan now.

        num_individuals = encoder_outputs.shape[0]

        # initial selection: pick index 0 (deterministic demo)
        initial_idx = jnp.array(0, dtype=jnp.int32)
        # selected array: -1 for empty slots; we will fill positions 0..N-1 with selected indices
        selected_array = -jnp.ones((num_individuals,), dtype=jnp.int32).at[0].set(initial_idx)

        # boolean mask indicating which indices are already selected
        mask = jnp.zeros((num_individuals,), dtype=jnp.bool_).at[0].set(True)

        # carry: (last_idx, selected_array, mask, step)
        initial_carry = (initial_idx, selected_array, mask, jnp.array(1, dtype=jnp.int32))

        def step_fn(carry, _):
            last_idx, selected_array, mask, step = carry

            # get last embedding (safe traced indexing via jnp.take)
            last_emb = jnp.take(encoder_outputs, last_idx, axis=0)  # [d_model]

            # Project query using pre-created Wq (no param creation here)
            q_proj = jnp.dot(last_emb, self.Wq) + self.bq  # [d_model]

            # compute scores: dot(q_proj, key_proj[i]) for each i
            # key_proj: [N, d_model], q_proj: [d_model] -> scores: [N]
            scores = jnp.einsum("d,nd->n", q_proj, key_proj) / jnp.sqrt(self.d_model)

            # mask already selected
            scores = jnp.where(mask, -1e9, scores)

            # greedy pick
            next_idx = jnp.argmax(scores).astype(jnp.int32)

            # write into selected array at position `step`
            selected_array = selected_array.at[step].set(next_idx)

            # update mask
            mask = mask.at[next_idx].set(True)

            next_step = step + 1
            new_carry = (next_idx, selected_array, mask, next_step)
            return new_carry, None

        # run scan for remaining selections (we already set initial index)
        num_steps = num_individuals - 1
        final_carry, _ = jax.lax.scan(step_fn, initial_carry, None, length=num_steps)

        return final_carry[1]  # selected_array

# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    num_individuals = 12   # small demo
    num_markers = 8
    genotypes, phenotypes = simulate_data(num_individuals, num_markers, key)

    key, model_key, dropout_key = jax.random.split(key, 3)
    d_model = 64
    num_heads = 4
    num_encoder_layers = 2

    model = PointerNetwork(d_model=d_model, num_heads=num_heads, num_encoder_layers=num_encoder_layers)

    print("Initializing model parameters...")
    variables = model.init({"params": model_key}, genotypes, phenotypes, train=False)
    params = variables["params"]
    print("Initialization complete.\n")

    print("Running forward pass...")
    # forward; dropout not used here so no rngs necessary
    selected_sequence = model.apply({"params": params}, genotypes, phenotypes, train=False)
    print("Forward pass complete.\n")

    print("Selected sequence of individuals (indices):")
    print(selected_sequence)

    seq_list = list(map(int, jnp.array(selected_sequence).tolist()))
    print("\nIs the sequence a permutation (no repeated indices)?", len(set(seq_list)) == len(seq_list))


In [4]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Sequence

# --- 1. Simulate Input Data ---
def simulate_data(num_individuals, num_markers, batch_size):
    """
    Simulates a batch of 0/1 coded vectors.

    Args:
        num_individuals: The number of individuals in the input sequence.
        num_markers: The feature dimension for each individual.
        batch_size: The number of samples in the batch.

    Returns:
        A jnp.ndarray of shape (batch_size, num_individuals, num_markers).
    """
    return jnp.array(np.random.randint(2, size=(batch_size, num_individuals, num_markers)))

# --- 2. Pointer Network with Transformer Encoder ---

class Encoder(nn.Module):
    """A simple transformer encoder."""
    num_heads: int
    emb_dim: int

    @nn.compact
    def __call__(self, x, mask=None):
        # The nn.MultiHeadDotProductAttention layer is a standard implementation
        # of multi-head attention, a key component of the transformer architecture.
        x = nn.MultiHeadDotProductAttention(num_heads=self.num_heads, qkv_features=self.emb_dim)(x, mask=mask)

        # We apply layer normalization before the feedforward network, which is a common
        # practice in transformer architectures (pre-LN).
        x = nn.LayerNorm()(x)

        # The feedforward network consists of two dense layers with a GELU activation in between.
        # This is a standard component of the transformer architecture.
        y = nn.Dense(features=self.emb_dim * 4)(x)
        y = nn.gelu(y)
        y = nn.Dense(features=self.emb_dim)(y)

        # A residual connection is used to add the output of the feedforward network to the input.
        return x + y

class PointerNetwork(nn.Module):
    """A Pointer Network with a Transformer Encoder."""
    num_heads: int
    emb_dim: int
    num_layers: int

    @nn.compact
    def __call__(self, x, deterministic=True):
        # --- Encoder ---
        # The input is first passed through a dense layer to project it to the embedding dimension.
        x = nn.Dense(features=self.emb_dim, name="input_embedding")(x)

        # The encoded input is then passed through a series of transformer encoder layers.
        for _ in range(self.num_layers):
            x = Encoder(num_heads=self.num_heads, emb_dim=self.emb_dim)(x)

        # --- Decoder ---
        # The decoder uses a similar architecture to the encoder, with a final dense layer
        # to produce the output logits.
        logits = nn.Dense(features=x.shape[-1], name="output_logits")(x)

        # The output logits are then passed through a softmax function to produce a probability
        # distribution over the input individuals.
        return nn.log_softmax(logits, axis=-1)


# --- 3. Sample Run ---

# --- Hyperparameters ---
NUM_INDIVIDUALS = 10
NUM_MARKERS = 50
EMB_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 3
BATCH_SIZE = 4
KEY = jax.random.PRNGKey(0)

# --- Initialize Model ---
model = PointerNetwork(num_heads=NUM_HEADS, emb_dim=EMB_DIM, num_layers=NUM_LAYERS)
params = model.init(KEY, jnp.ones((BATCH_SIZE, NUM_INDIVIDUALS, NUM_MARKERS)))["params"]

# --- Simulate Data and Run Model ---
input_data = simulate_data(NUM_INDIVIDUALS, NUM_MARKERS, BATCH_SIZE)
output_log_probs = model.apply({"params": params}, input_data)

# --- Get Predicted Sequence ---
predicted_sequence = jnp.argmax(output_log_probs, axis=-1)

# --- Print Results ---
print("--- Input Data Shape ---")
print(input_data.shape)
print("\n--- Output Log Probs Shape ---")
print(output_log_probs.shape)
print("\n--- Sample Input (first 5 markers of first individual in batch) ---")
print(input_data[0, 0, :5])
print("\n--- Predicted Sequence (for first sample in batch) ---")
print(predicted_sequence[0])

--- Input Data Shape ---
(4, 10, 50)

--- Output Log Probs Shape ---
(4, 10, 128)

--- Sample Input (first 5 markers of first individual in batch) ---
[0 0 0 1 0]

--- Predicted Sequence (for first sample in batch) ---
[65 65 65 65 65 65 65 65 65 65]


In [5]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Sequence

# --- 1. Simulate Input Data ---
def simulate_data(num_individuals, num_markers, batch_size):
    """
    Simulates a batch of 0/1 coded vectors.

    Args:
        num_individuals: The number of individuals in the input sequence.
        num_markers: The feature dimension for each individual.
        batch_size: The number of samples in the batch.

    Returns:
        A jnp.ndarray of shape (batch_size, num_individuals, num_markers).
    """
    return jnp.array(np.random.randint(2, size=(batch_size, num_individuals, num_markers)))

# --- 2. Pointer Network with Transformer Encoder ---

class Encoder(nn.Module):
    """A simple transformer encoder."""
    num_heads: int
    emb_dim: int

    @nn.compact
    def __call__(self, x, mask=None):
        x = nn.MultiHeadDotProductAttention(num_heads=self.num_heads, qkv_features=self.emb_dim)(x, mask=mask)
        x = nn.LayerNorm()(x)
        y = nn.Dense(features=self.emb_dim * 4)(x)
        y = nn.gelu(y)
        y = nn.Dense(features=self.emb_dim)(y)
        return x + y

class PointerNetwork(nn.Module):
    """A Pointer Network with a Transformer Encoder."""
    num_heads: int
    emb_dim: int
    num_layers: int

    @nn.compact
    def __call__(self, x, deterministic=True):
        # --- Encoder ---
        x = nn.Dense(features=self.emb_dim, name="input_embedding")(x)

        for _ in range(self.num_layers):
            x = Encoder(num_heads=self.num_heads, emb_dim=self.emb_dim)(x)

        # --- Decoder (Corrected Pointer Mechanism) ---
        # The original code had a Dense layer that outputted a vector of size `emb_dim`.
        # The corrected version uses matrix multiplication to create scores for each
        # input individual, effectively "pointing" to them.
        # This results in logits with the correct shape: (batch_size, num_individuals, num_individuals)
        logits = jnp.matmul(x, x.transpose((0, 2, 1)))

        return nn.log_softmax(logits, axis=-1)


# --- 3. Sample Run ---

# --- Hyperparameters ---
NUM_INDIVIDUALS = 10
NUM_MARKERS = 50
EMB_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 3
BATCH_SIZE = 4
KEY = jax.random.PRNGKey(0)

# --- Initialize Model ---
model = PointerNetwork(num_heads=NUM_HEADS, emb_dim=EMB_DIM, num_layers=NUM_LAYERS)
# Pass the correct shape to the init method
init_data = jnp.ones((BATCH_SIZE, NUM_INDIVIDUALS, NUM_MARKERS))
params = model.init(KEY, init_data)["params"]

# --- Simulate Data and Run Model ---
input_data = simulate_data(NUM_INDIVIDUALS, NUM_MARKERS, BATCH_SIZE)
output_log_probs = model.apply({"params": params}, input_data)

# --- Get Predicted Sequence ---
predicted_sequence = jnp.argmax(output_log_probs, axis=-1)

# --- Print Results ---
print("--- Input Data Shape ---")
print(input_data.shape)
print("\n--- Output Log Probs Shape (Corrected) ---")
print(output_log_probs.shape)
print("\n--- Sample Input (first 5 markers of first individual in batch) ---")
print(input_data[0, 0, :5])
print("\n--- Predicted Sequence (for first sample in batch) ---")
print(predicted_sequence[0])

--- Input Data Shape ---
(4, 10, 50)

--- Output Log Probs Shape (Corrected) ---
(4, 10, 10)

--- Sample Input (first 5 markers of first individual in batch) ---
[1 1 0 0 0]

--- Predicted Sequence (for first sample in batch) ---
[0 0 7 4 5 4 0 4 4 3]
