In [2]:
from typing import Callable, Optional, Tuple # 用于类型注解，提高代码可读性和静态检查
import flax
import flax.linen as nn # flax.linen 是其模块化 API，类似于 PyTorch 的 nn.Module
import jax
import jax.numpy as jnp # jax.numpy 是其 NumPy 接口，用于张量操作
import numpy as np # 标准的数值计算库
# Flax 中的 FrozenDict 是不可变的嵌套字典，用于存储模型参数和状态；freeze/unfreeze 用于参数的冻结和解冻
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
# Flax 中实现的标准点积注意力机制，用于计算注意力权重
from flax.linen.attention import dot_product_attention_weights
# 用于将嵌套结构的参数字典展开/重建，方便操作模型参数。
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax # JAX 中的底层操作接口，支持控制流、跨设备通信等
# HuggingFace 提供的 Flax 模型输出类，用于组织模型的返回结果，适配不同任务（如分类、问答、填空等）。
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from transformers.modeling_flax_utils import (
    ACT2FN, # 激活函数名称到实际函数的映射，如 "gelu" -> nn.gelu
    FlaxPreTrainedModel,  # 所有Flax模型的基类，实现了权重加载/保存、配置管理等功能
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring, # 用于自动生成文档字符串（docstring）的方法装饰器
)
 # 所有输出类的基类 用于添加文档开头的说明字符串 Transformers 的日志工具
from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
# HuggingFace 中 ALBERT 模型的配置类，定义超参数（如层数、隐藏维度等）
from transformers.models.albert.configuration_albert import AlbertConfig

In [3]:
# Default checkpoint / config names for this module (a later demo cell overrides the checkpoint).
_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

In [4]:
# Structured output (flax.struct.dataclass) holding the several results of an ALBERT pre-training forward pass.
@flax.struct.dataclass
class FlaxAlbertForPreTrainingOutput(ModelOutput):
    """
    Output structure for the ALBERT pre-training task (MLM + SOP).

    Design intent:
    - Give the model a uniform output format compatible with HuggingFace downstream code.
    - `flax.struct.dataclass` produces an immutable object, which suits JAX/Flax's functional style.
    - Inheriting from `ModelOutput` allows access by attribute name or dict-style key.

    Fields:
    - prediction_logits: MLM head output, logits over the vocabulary for every token.
    - sop_logits: SOP (sentence-order prediction) head output.
    - hidden_states: optional; hidden states of all layers (including the embedding output).
    - attentions: optional; attention weights of every layer.
    """
    # MLM logits before softmax; expected (batch_size, seq_len, vocab_size) — not enforced here.
    prediction_logits: Optional[jnp.ndarray] = None
    # SOP logits; expected (batch_size, 2), i.e. whether the two segments are in their original order.
    sop_logits: Optional[jnp.ndarray] = None
    # Optional hidden states of every layer (embedding output included), e.g. for debugging/analysis.
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    # Optional per-layer attention weights (after softmax), e.g. for visualisation.
    attentions: Optional[Tuple[jnp.ndarray]] = None

In [5]:
# Shared class-level documentation text (HuggingFace convention). This string is runtime data — presumably
# injected into model docstrings via the add_start_docstrings* helpers imported at the top — so edit with care.
ALBERT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

# Shared documentation of the forward inputs. `{0}` is a str.format placeholder, presumably filled with the
# input-shape string (e.g. "batch_size, sequence_length") by the caller — TODO confirm.
ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

"""

In [7]:
# Build ALBERT input embeddings: word, position and token-type (segment) embeddings.
class FlaxAlbertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then apply LayerNorm and dropout.

    All three lookup tables have width `config.embedding_size`; in ALBERT this may be smaller
    than `hidden_size` (the encoder projects it up afterwards).

    Attributes:
        config: model configuration (vocab size, embedding size, dropout, initializer range, ...).
        dtype: computation dtype of the LayerNorm.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        embed_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
        # token id -> embedding vector
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.embedding_size,
            embedding_init=embed_init,
        )
        # absolute position (0 .. max_position_embeddings - 1) -> embedding vector
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.embedding_size,
            embedding_init=embed_init,
        )
        # segment id (sentence A / sentence B) -> embedding vector
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.embedding_size,
            embedding_init=embed_init,
        )
        # LayerNorm keeps no running statistics: it normalizes each token vector with its own
        # mean/variance, so training and evaluation behave identically.
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids=None, position_ids=None, deterministic: bool = True):
        """Embed a batch of token ids.

        Args:
            input_ids: integer token ids, typically (batch_size, seq_len).
            token_type_ids: segment ids with the same shape as `input_ids`. Defaults to all
                zeros (every token belongs to sentence A).
            position_ids: position indices broadcastable against `input_ids`. Defaults to
                `0 .. seq_len - 1` along the last axis of `input_ids`.
            deterministic: if True dropout is disabled; if False dropout is active and may need
                a "dropout" RNG stream.

        Returns:
            Summed, normalized embeddings with a trailing `config.embedding_size` axis.
        """
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
        if position_ids is None:
            ids_shape = jnp.shape(input_ids)
            position_ids = jnp.broadcast_to(jnp.arange(ids_shape[-1]), ids_shape)

        # nn.Embed requires integer indices; cast everything to int32 ("i4").
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # BERT/ALBERT-style input representation: element-wise sum of the three embeddings.
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states

In [None]:
# "i4" 是 JAX/NumPy 中的数据类型标识符，表示：
# int32，即 4 字节（32 位）有符号整数。
# 为什么用 "i4"？
# 在 JAX 中，nn.Embed 要求输入必须是整数索引类型（JAX 默认关闭 64 位，整数通常即为 int32）。执行
# input_ids.astype("i4")
# 后，无论原始数组（例如 tokenizer 返回的 NumPy int64 数组）是什么整数类型，都会统一为 int32，符合 nn.Embed 要求，避免类型报错。
# 这段 FlaxAlbertEmbeddings 代码中，嵌入层、LayerNorm、Dropout 都是单设备、单块运行，未使用任何分片策略。
# 如果需要实现分布式训练或分片，可以在模型设计或训练代码中添加相关分片约束和多设备映射

In [81]:
# Switch the demo checkpoint to ALBERT xxlarge-v2 (overrides the module-level value set earlier in the notebook).
_CHECKPOINT_FOR_DOC='albert-xxlarge-v2'

In [82]:
# Load the AlbertConfig for the selected checkpoint; its fields drive every module shape below.
config = AlbertConfig.from_pretrained(_CHECKPOINT_FOR_DOC)

In [83]:
# Display the configuration (note embedding_size=128 vs hidden_size=4096).
config

AlbertConfig {
  "architectures": [
    "AlbertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "down_scale_factor": 1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "gap_size": 0,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 16384,
  "layer_norm_eps": 1e-12,
  "layers_to_keep": [],
  "max_position_embeddings": 512,
  "model_type": "albert",
  "net_structure_type": 0,
  "num_attention_heads": 64,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "num_memory_blocks": 0,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.51.3",
  "type_vocab_size": 2,
  "vocab_size": 30000
}

In [84]:
# Instantiate the embedding module; parameters are only created by `.init` below.
embeddings=FlaxAlbertEmbeddings(config)

In [85]:
# Tokenizer matching the selected checkpoint.
# NOTE(review): this import would normally live in the top import cell.
from transformers import AlbertTokenizer
tokenizer = AlbertTokenizer.from_pretrained(_CHECKPOINT_FOR_DOC)

In [86]:
# Sample input: tokenize one sentence into NumPy arrays and attach explicit position ids.
text = "Deep learning changes the world"
x = tokenizer(text, return_tensors="np")
seq_length = x["input_ids"].shape[-1]  # token count, including the special tokens the tokenizer adds
x["position_ids"] = np.arange(seq_length).reshape(1, seq_length)  # (1, seq_length): 0 .. seq_length-1

In [87]:
# Remove the mask in place: FlaxAlbertEmbeddings.__call__ has no attention_mask parameter, so `**x` must not carry it.
x.pop('attention_mask')

array([[1, 1, 1, 1, 1, 1, 1]])

In [88]:
# Remaining embedding inputs: input_ids, token_type_ids, position_ids.
x

{'input_ids': array([[   2,  855, 2477, 1693,   14,  126,    3]]), 'token_type_ids': array([[0, 0, 0, 0, 0, 0, 0]]), 'position_ids': array([[0, 1, 2, 3, 4, 5, 6]])}

In [89]:
# Initialise the embedding parameters by tracing the module once on the sample inputs.
variables = embeddings.init(jax.random.key(0),**x)

In [90]:
# Parameter shapes of the embedding module (same `jax.tree.map` alias used by the later cells).
jax.tree.map(jnp.shape, variables)

{'params': {'LayerNorm': {'bias': (128,), 'scale': (128,)},
  'position_embeddings': {'embedding': (512, 128)},
  'token_type_embeddings': {'embedding': (2, 128)},
  'word_embeddings': {'embedding': (30000, 128)}}}

In [33]:
# Keep only the 'params' collection (the embedding module creates no other collections).
params = variables['params']

In [38]:
# Hidden dropout rate; 0 for this checkpoint, so dropout is effectively a no-op in the next cell.
config.hidden_dropout_prob

0

In [35]:
# Forward pass in training mode (deterministic=False), so a 'dropout' RNG stream is supplied.
hidden_states = embeddings.apply(
    {'params': params},
    deterministic=False,
    rngs={'dropout': jax.random.key(seed=88)},
    **x,
)

In [37]:
# (batch, seq_len, embedding_size)
hidden_states.shape

(1, 7, 128)

In [41]:
# Epsilon used by every LayerNorm in these modules.
config.layer_norm_eps

1e-12

In [42]:
# Most negative finite float32: FlaxAlbertSelfAttention uses finfo(dtype).min as the additive bias for masked positions.
jnp.finfo(jnp.float32).min

-3.4028235e+38

In [44]:
class FlaxAlbertSelfAttention(nn.Module):
    """Multi-head self-attention sub-layer of an ALBERT transformer block.

    Projects `hidden_states` to per-head queries/keys/values, computes attention weights with
    `flax.linen.attention.dot_product_attention_weights` (masking via an additive bias), mixes the
    values, projects back to `hidden_size` with `dense`, applies dropout, adds the residual input
    and layer-normalizes.

    Attributes:
        config: model configuration (hidden size, head count, dropout rates, ...).
        dtype: computation dtype for the projections, attention weights and LayerNorm.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # hidden_size must split evenly across heads (head_dim = hidden_size // num_attention_heads).
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        # Output projection applied after the heads are merged.
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # LayerNorm applied after the residual connection.
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        # Dropout on the projected output (attention-probability dropout is configured separately).
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
        """Apply self-attention.

        Args:
            hidden_states: input activations, (batch_size, seq_len, hidden_size).
            attention_mask: typically (batch_size, seq_len); values > 0 mark positions that may be
                attended to, 0 marks masked positions. `None` disables masking.
            deterministic: if True all dropout is disabled; if False dropout is active and may
                need a "dropout" RNG stream.
            output_attentions: if True, also return the attention weights.

        Returns:
            `(attention_output,)` or `(attention_output, attn_weights)`. `attention_output` has the
            shape of `hidden_states`; `attn_weights` is laid out (batch, num_heads, q_len, k_len).
        """
        num_heads = self.config.num_attention_heads
        head_dim = self.config.hidden_size // num_heads
        # (batch, seq_len, hidden_size) -> (batch, seq_len, num_heads, head_dim)
        split_heads_shape = hidden_states.shape[:2] + (num_heads, head_dim)
        query_states = self.query(hidden_states).reshape(split_heads_shape)
        value_states = self.value(hidden_states).reshape(split_heads_shape)
        key_states = self.key(hidden_states).reshape(split_heads_shape)

        if attention_mask is not None:
            # Turn the mask into an additive bias: 0 where attention is allowed and the most negative
            # finite value of `dtype` (not -inf) where it is masked, so softmax gives ~0 weight there.
            # (batch, seq_len) -> (batch, 1, 1, seq_len) broadcasts over heads and query positions.
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        # Only draw a dropout RNG when attention-probability dropout will actually run.
        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            # Reuse one dropout mask across broadcast dimensions (see the Flax docs for the exact axes).
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            # None -> JAX's default matmul precision for the backend (dtype is controlled separately).
            precision=None,
        )

        # Weighted sum of values per head: (..., h, q, k) x (..., k, h, d) -> (..., q, h, d).
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # Merge heads: (batch, seq_len, num_heads, head_dim) -> (batch, seq_len, hidden_size).
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        projected_attn_output = self.dense(attn_output)
        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
        # Residual connection followed by LayerNorm.
        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)

        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
        return outputs

In [46]:
# Width of the transformer layers for this checkpoint.
config.hidden_size

4096

In [47]:
# 2. Dummy input: an all-ones activation tensor of shape (batch_size, seq_length, hidden_size).
batch_size, seq_length = 2, 8
hidden_states = jnp.ones(shape=(batch_size, seq_length, config.hidden_size), dtype=jnp.float32)

In [48]:
# (batch_size, seq_length, hidden_size)
hidden_states.shape

(2, 8, 4096)

In [49]:
# All-ones mask: every position may be attended to (no padding).
attention_mask = jnp.ones((batch_size, seq_length), dtype=jnp.int32)  # no masking, all ones

In [50]:
# 3. Build the self-attention module instance.
attn_module = FlaxAlbertSelfAttention(config=config, dtype=jnp.float32)

In [52]:
# 4. Initialise weights: a typed PRNG key, created with the same `jax.random.key` API the other cells use.
key = jax.random.key(0)

In [54]:
# Initialise self-attention parameters by tracing one inference-mode call.
variables = attn_module.init(key, hidden_states, attention_mask, deterministic=True)

In [55]:
# Parameter shapes: Q/K/V/output projections plus the post-residual LayerNorm.
jax.tree.map(jnp.shape, variables)

{'params': {'LayerNorm': {'bias': (4096,), 'scale': (4096,)},
  'dense': {'bias': (4096,), 'kernel': (4096, 4096)},
  'key': {'bias': (4096,), 'kernel': (4096, 4096)},
  'query': {'bias': (4096,), 'kernel': (4096, 4096)},
  'value': {'bias': (4096,), 'kernel': (4096, 4096)}}}

In [56]:
# 5. Forward call in inference mode.
outputs = attn_module.apply(variables, hidden_states, attention_mask, deterministic=True)

In [59]:
# A plain tuple of length 1, because output_attentions defaults to False.
print(type(outputs),len(outputs))

<class 'tuple'> 1


In [60]:
attention_output = outputs[0]  # shape: (batch_size, seq_length, hidden_size)
# Print the result shape.
print("attention_output.shape:", attention_output.shape)

attention_output.shape: (2, 8, 4096)


In [61]:
class FlaxAlbertLayer(nn.Module):
    """One ALBERT transformer block: self-attention followed by a position-wise feed-forward network.

    The attention sub-layer (`FlaxAlbertSelfAttention`) applies its own residual + LayerNorm. The
    FFN is Dense(intermediate_size) -> activation -> Dense(hidden_size) -> dropout, followed by a
    residual connection with the attention output and `full_layer_layer_norm`. Output width equals
    input width, so blocks can be stacked.

    Attributes:
        config: model configuration.
        dtype: computation dtype for every sub-module.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        kernel_init = jax.nn.initializers.normal(cfg.initializer_range)
        self.attention = FlaxAlbertSelfAttention(cfg, dtype=self.dtype)
        # Expansion to intermediate_size (16384 = 4 * hidden_size for the checkpoint loaded above).
        self.ffn = nn.Dense(cfg.intermediate_size, kernel_init=kernel_init, dtype=self.dtype)
        # Non-linearity selected by config.hidden_act.
        self.activation = ACT2FN[cfg.hidden_act]
        # Projection back to hidden_size.
        self.ffn_output = nn.Dense(cfg.hidden_size, kernel_init=kernel_init, dtype=self.dtype)
        self.full_layer_layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        """Run attention then the FFN.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attention_weights)` when `output_attentions`.
        """
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attended = attn_result[0]

        ff = self.ffn_output(self.activation(self.ffn(attended)))
        ff = self.dropout(ff, deterministic=deterministic)
        # Residual around the FFN, then LayerNorm.
        block_out = self.full_layer_layer_norm(ff + attended)

        if output_attentions:
            return (block_out, attn_result[1])
        return (block_out,)

In [None]:
# 设计结构总结（标准 Transformer Block）：
# Self-Attention 子层（残差连接与 LayerNorm 已在 FlaxAlbertSelfAttention 内部完成）。
# 前馈网络子层（两层线性 + 激活 + Dropout + 残差 + LayerNorm）。
# 保持输出维度与输入一致（以便堆叠）。

In [62]:
# Full transformer block (self-attention + feed-forward network).
attn_layer = FlaxAlbertLayer(config=config, dtype=jnp.float32)

In [68]:
# Initialise the block's parameters (attention + FFN).
variables = attn_layer.init(key, hidden_states, attention_mask, deterministic=True)

In [69]:
# Attention sub-tree plus ffn, ffn_output and full_layer_layer_norm.
jax.tree.map(jnp.shape, variables)

{'params': {'attention': {'LayerNorm': {'bias': (4096,), 'scale': (4096,)},
   'dense': {'bias': (4096,), 'kernel': (4096, 4096)},
   'key': {'bias': (4096,), 'kernel': (4096, 4096)},
   'query': {'bias': (4096,), 'kernel': (4096, 4096)},
   'value': {'bias': (4096,), 'kernel': (4096, 4096)}},
  'ffn': {'bias': (16384,), 'kernel': (4096, 16384)},
  'ffn_output': {'bias': (4096,), 'kernel': (16384, 4096)},
  'full_layer_layer_norm': {'bias': (4096,), 'scale': (4096,)}}}

In [70]:
# Run the block in training mode; dropout draws from the 'dropout' RNG stream given here.
outputs = attn_layer.apply(
    variables,
    hidden_states,
    attention_mask,
    deterministic=False,
    rngs={'dropout': jax.random.key(seed=88)},
)

In [71]:
# Same shape as the input: (batch_size, seq_length, hidden_size).
outputs[0].shape

(2, 8, 4096)

In [72]:
# Number of FlaxAlbertLayer instances inside each FlaxAlbertLayerCollection.
config.inner_group_num

1

In [73]:
class FlaxAlbertLayerCollection(nn.Module):
    """Apply `config.inner_group_num` ALBERT layers one after another.

    Each entry of `self.layers` is its own `FlaxAlbertLayer` submodule (named "0", "1", ...), so
    each has separate parameters. Weight sharing across the model's depth comes from the caller
    invoking this collection repeatedly (see `FlaxAlbertLayerGroups`), not from reuse in here.

    Attributes:
        config: model configuration.
        dtype: computation dtype passed to every layer.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxAlbertLayer(self.config, name=str(idx), dtype=self.dtype)
            for idx in range(self.config.inner_group_num)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        """Run the layers sequentially.

        Returns:
            `(last_hidden_state, [per-layer hidden states], [per-layer attentions])`. The bracketed
            tuples are appended only when `output_hidden_states` / `output_attentions` are set, in
            that order, so the attentions (when present) are always the last element.
        """
        per_layer_states = ()
        per_layer_attn = ()

        for layer in self.layers:
            layer_out = layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            hidden_states = layer_out[0]  # becomes the next layer's input
            if output_attentions:
                per_layer_attn += (layer_out[1],)
            if output_hidden_states:
                per_layer_states += (hidden_states,)

        result = (hidden_states,)
        if output_hidden_states:
            result += (per_layer_states,)
        if output_attentions:
            result += (per_layer_attn,)
        return result

In [96]:
# Thin wrapper around FlaxAlbertLayerCollection: one ALBERT "layer group".
class FlaxAlbertLayerCollections(nn.Module):
    """Hold a single `FlaxAlbertLayerCollection` under the attribute name `albert_layers`.

    The wrapper adds an `albert_layers` level to the parameter tree (presumably to mirror the
    reference checkpoint layout) and forwards every call unchanged.

    Attributes:
        config: model configuration.
        dtype: computation dtype.
        layer_index: stored on the module but not read by it.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32
    layer_index: Optional[str] = None

    def setup(self):
        self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        """Delegate to the wrapped collection; returns its tuple
        `(last_hidden_state, [layer hidden states], [layer attentions])` unchanged."""
        return self.albert_layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

In [91]:
# Number of distinct weight groups; with 1 group all num_hidden_layers layers share one set of weights.
config.num_hidden_groups

1

In [97]:
class FlaxAlbertLayerGroups(nn.Module):
    """Apply `config.num_hidden_layers` transformer layers using `config.num_hidden_groups` weight groups.

    ALBERT shares parameters across depth: hidden layer `i` is computed by group
    `i * num_hidden_groups // num_hidden_layers`, so each group's weights are reused for a
    contiguous run of layers (with a single group, every layer reuses the same weights).

    Attributes:
        config: model configuration (layer count, group count, ...).
        dtype: computation dtype passed to every group.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # One FlaxAlbertLayerCollections per group, named "0", "1", ... in the parameter tree.
        self.layers = [
            FlaxAlbertLayerCollections(
                self.config,
                name=str(i),
                layer_index=str(i),
                dtype=self.dtype,
            )
            for i in range(self.config.num_hidden_groups)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the full stack of hidden layers.

        Args:
            hidden_states: input activations, (batch, seq_len, hidden_size).
            attention_mask: mask forwarded to every layer (> 0 = may be attended to).
            deterministic: if True dropout is disabled (inference).
            output_attentions: collect every layer's attention weights.
            output_hidden_states: collect the input plus every layer's output.
            return_dict: return a `FlaxBaseModelOutput` instead of a tuple.

        Returns:
            `FlaxBaseModelOutput(last_hidden_state, hidden_states, attentions)` if `return_dict`,
            otherwise a tuple of the non-None values among
            `(last_hidden_state, all_hidden_states, all_attentions)`.
        """
        num_layers = self.config.num_hidden_layers
        num_groups = self.config.num_hidden_groups

        all_attentions = () if output_attentions else None
        # The first recorded hidden state is this stack's input.
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(num_layers):
            # Exact integer form of int(i / (num_layers / num_groups)); avoids float rounding
            # picking the wrong group at a boundary.
            group_idx = i * num_groups // num_layers
            layer_group_output = self.layers[group_idx](
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                # A group's tuple of attentions is always its last output element.
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

In [98]:
# Full stack of hidden layers (name kept as-is because later cells reference it).
layerGroups=FlaxAlbertLayerGroups(config)

In [99]:
# Initialise the layer-group parameters on the dummy (batch, seq, hidden_size) input.
variables = layerGroups.init(key, hidden_states, attention_mask, deterministic=True)

In [100]:
# Only group '0' with one inner layer '0' exists, although num_hidden_layers layers run: parameter sharing.
jax.tree.map(jnp.shape, variables)

{'params': {'0': {'albert_layers': {'0': {'attention': {'LayerNorm': {'bias': (4096,),
       'scale': (4096,)},
      'dense': {'bias': (4096,), 'kernel': (4096, 4096)},
      'key': {'bias': (4096,), 'kernel': (4096, 4096)},
      'query': {'bias': (4096,), 'kernel': (4096, 4096)},
      'value': {'bias': (4096,), 'kernel': (4096, 4096)}},
     'ffn': {'bias': (16384,), 'kernel': (4096, 16384)},
     'ffn_output': {'bias': (4096,), 'kernel': (16384, 4096)},
     'full_layer_layer_norm': {'bias': (4096,), 'scale': (4096,)}}}}}}

In [101]:
# Inference-mode pass through every hidden layer; return_dict defaults to True, so a FlaxBaseModelOutput comes back.
outputs = layerGroups.apply(variables, hidden_states, attention_mask, deterministic=True)

In [103]:
# (batch_size, seq_length, hidden_size)
outputs.last_hidden_state.shape

(2, 8, 4096)

In [104]:
class FlaxAlbertEncoder(nn.Module):
    """ALBERT encoder: project embeddings up to `hidden_size`, then run the shared layer groups.

    Attributes:
        config: model configuration (hidden size, layer/group counts, heads, ...).
        dtype: computation dtype for the projection and all transformer layers.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Factorised embedding: inputs arrive with width `embedding_size` and are linearly
        # mapped to `hidden_size` before entering the transformer stack.
        self.embedding_hidden_mapping_in = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        # All hidden layers; groups of layers reuse the same parameters.
        self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode embedding-layer output.

        Args:
            hidden_states: embedding output, (batch, seq_len, embedding_size).
            attention_mask: mask forwarded to every layer (> 0 = may be attended to).
            deterministic: if True dropout is disabled (inference).
            output_attentions: also return every layer's attention weights.
            output_hidden_states: also return every layer's hidden states.
            return_dict: if True return a `FlaxBaseModelOutput`, otherwise a plain tuple.

        Returns:
            Whatever `FlaxAlbertLayerGroups` returns for the given `return_dict`.
        """
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
        return self.albert_layer_groups(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Forward the flag so return_dict=False actually yields a tuple.
            return_dict=return_dict,
        )

In [119]:
# Tokenize the demo sentence again (keeping the attention mask this time) and attach explicit position ids.
text = "Deep learning changes the world"
xx = tokenizer(text, return_tensors="np")
seq_length = xx["input_ids"].shape[-1]
xx["position_ids"] = np.arange(seq_length).reshape(1, seq_length)  # (1, seq_length)

In [120]:
# Tokenizer output including attention_mask, plus the position ids added above.
xx

{'input_ids': array([[   2,  855, 2477, 1693,   14,  126,    3]]), 'token_type_ids': array([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1]]), 'position_ids': array([[0, 1, 2, 3, 4, 5, 6]])}

In [121]:
# Move the mask out of `xx`: the embedding module does not accept it, but the encoder does.
attention_mask=xx.pop('attention_mask')

In [123]:
# Fresh embedding module and parameters for the encoder demo (same seed as the first embedding demo).
embeddings=FlaxAlbertEmbeddings(config)
variables = embeddings.init(jax.random.key(0),**xx)

In [124]:
# Embed the sample in training mode; the 'dropout' RNG stream is supplied explicitly.
hidden_states = embeddings.apply(
    variables,
    deterministic=False,
    rngs={'dropout': jax.random.key(seed=88)},
    **xx,
)

In [125]:
# Encoder = embedding->hidden projection + all shared layer groups.
encoder=FlaxAlbertEncoder(config)

In [126]:
# Initialise encoder parameters from the embedding output, reusing the PRNG key `key` defined earlier.
variables = encoder.init(key, hidden_states, attention_mask, deterministic=True)

In [127]:
# Note embedding_hidden_mapping_in: kernel (embedding_size, hidden_size).
jax.tree.map(jnp.shape, variables)

{'params': {'albert_layer_groups': {'0': {'albert_layers': {'0': {'attention': {'LayerNorm': {'bias': (4096,),
        'scale': (4096,)},
       'dense': {'bias': (4096,), 'kernel': (4096, 4096)},
       'key': {'bias': (4096,), 'kernel': (4096, 4096)},
       'query': {'bias': (4096,), 'kernel': (4096, 4096)},
       'value': {'bias': (4096,), 'kernel': (4096, 4096)}},
      'ffn': {'bias': (16384,), 'kernel': (4096, 16384)},
      'ffn_output': {'bias': (4096,), 'kernel': (16384, 4096)},
      'full_layer_layer_norm': {'bias': (4096,), 'scale': (4096,)}}}}},
  'embedding_hidden_mapping_in': {'bias': (4096,), 'kernel': (128, 4096)}}}

In [129]:
# Inference-mode encoder pass with the tokenizer's attention mask.
outputs = encoder.apply(variables, hidden_states, attention_mask, deterministic=True)

In [130]:
# return_dict defaults to True, so the encoder returns a FlaxBaseModelOutput.
print(type(outputs))

<class 'transformers.modeling_flax_outputs.FlaxBaseModelOutput'>


In [133]:
# (batch, seq_len, hidden_size) — the embedding width has been projected up to hidden_size.
outputs.last_hidden_state.shape

(1, 7, 4096)

In [178]:
class FlaxAlbertOnlyMLMHead(nn.Module):
    """Masked-language-modeling head: hidden_size -> embedding_size -> vocabulary logits.

    Pipeline: dense -> activation -> LayerNorm -> decoder (no bias) -> + `bias`. When
    `shared_embedding` is given, its transpose replaces the decoder kernel (weight tying); the
    separate `bias` parameter is added in both cases.

    Attributes:
        config: model configuration.
        dtype: computation dtype.
        bias_init: initializer for the output `bias` parameter of shape (vocab_size,).
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        cfg = self.config
        # hidden_size -> embedding_size, so the output can be matched against embedding vectors.
        self.dense = nn.Dense(cfg.embedding_size, dtype=self.dtype)
        self.activation = ACT2FN[cfg.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        # embedding_size -> vocab_size, without bias (the bias lives in its own parameter).
        self.decoder = nn.Dense(cfg.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (cfg.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        """Compute vocabulary logits.

        Args:
            hidden_states: transformer output, (batch, seq_len, hidden_size).
            shared_embedding: optional embedding table, expected (vocab_size, embedding_size) —
                presumably the word-embedding matrix; its transpose is used as decoder kernel.

        Returns:
            Logits with a trailing `vocab_size` axis.
        """
        features = self.LayerNorm(self.activation(self.dense(hidden_states)))

        if shared_embedding is None:
            logits = self.decoder(features)
        else:
            # Weight tying: run the decoder with the transposed embedding table as its kernel.
            logits = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, features)

        return logits + self.bias

In [None]:
# 此模块对应 ALBERT 的 MLM（Masked Language Modeling）输出头。
# 包含一个非线性前馈层 + LayerNorm + decoder 层（线性层），用于从 Transformer 的输出生成词表维度的 logits。
# 如果提供 shared_embedding（来自嵌入层），则共享权重以节省参数（与 ALBERT/BERT 论文一致）。
# 手动加上 bias 而非交由 nn.Dense 控制：使用 shared_embedding 时，decoder 的 kernel 会被替换为嵌入矩阵的转置，而独立的 bias 参数在两种路径下都会被加上；这样也便于与 HuggingFace 的权重格式对齐。

In [179]:
class FlaxAlbertSOPHead(nn.Module):
    """Sentence-order-prediction (SOP) head: dropout followed by a 2-way linear classifier.

    The input is the pooled, sentence-level representation produced by the backbone's pooler.
    """

    config: AlbertConfig  # configuration object
    dtype: jnp.dtype = jnp.float32  # computation dtype (e.g. float32, float16)

    def setup(self):
        # Prefer the classifier-specific dropout rate and fall back to the hidden-layer rate when it
        # is unset (None), exactly like FlaxAlbertForSequenceClassificationModule does. A None rate
        # would otherwise break nn.Dropout as soon as dropout is active (deterministic=False).
        classifier_dropout = (
            self.config.classifier_dropout_prob
            if self.config.classifier_dropout_prob is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Two logits per example for the binary "in order / swapped" decision.
        self.classifier = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output, deterministic=True):
        """Compute SOP logits.

        Args:
            pooled_output: pooled sentence representation, (batch_size, hidden_size).
            deterministic: when True, dropout is disabled (inference).

        Returns:
            Logits of shape (batch_size, 2).
        """
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)
        return logits

In [None]:
# 模块设计意图说明：
# 本模块用于 SOP (Sentence Order Prediction) 任务，这是 ALBERT 替代 NSP 的预训练目标。
# 输入为整个序列的聚合表示（通常是 [CLS] token 的输出），输出为 [batch_size, 2] 的 logits。
# 包括一个 Dropout 层用于训练时正则化，以及一个 Dense 层进行二分类预测。
# 模块结构简洁，符合 ALBERT 论文中的 SOP head 架构设计

In [139]:
input_ids = jnp.zeros(shape=(2, 8), dtype=jnp.int32)  # dummy batch: 2 sequences of 8 token ids (all 0)

In [140]:
# Promote to at least 2-D; input_ids is already 2-D here, so the shape stays (2, 8).
jnp.atleast_2d(input_ids)

Array([[0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)

In [141]:
# Position indices 0 .. seq_len-1, where seq_len is the last axis of the (at least 2-D) ids.
jnp.arange(jnp.atleast_2d(input_ids).shape[-1])

Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

In [142]:
# Repeat the position row for every sequence in the batch. Broadcasting to input_ids.shape (as the
# model code does) instead of a hard-coded (2, 8) keeps this cell valid if the dummy batch changes.
jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

Array([[0, 1, 2, 3, 4, 5, 6, 7],
       [0, 1, 2, 3, 4, 5, 6, 7]], dtype=int32)

In [143]:
# Abstract base class shared by every Flax ALBERT model: it initialises weights and, through
# FlaxPreTrainedModel, downloads/loads pretrained checkpoints. Concrete models subclass it and
# set `module_class`.
class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
    """Pairs an ALBERT Flax module with HuggingFace weight management.

    Subclasses set ``module_class``. The constructor instantiates that module, ``init_weights``
    traces it on dummy inputs to create parameters, and ``__call__`` fills in default inputs
    before running ``module.apply``.
    """

    config_class = AlbertConfig  # configuration class used when loading checkpoints
    base_model_prefix = "albert"  # key prefix of the base model inside checkpoint parameter trees
    module_class: nn.Module = None  # concrete Flax module; must be provided by subclasses

    def __init__(
        self,
        config: AlbertConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Any extra keyword arguments are forwarded to the Flax module's constructor.
        flax_module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, flax_module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Create module parameters by tracing the module on dummy inputs.

        Args:
            rng: PRNG key, split into the "params" and "dropout" streams.
            input_shape: shape of the dummy ``input_ids`` used for tracing.
            params: optional existing parameters. When given, only the keys listed in
                ``self._missing_keys`` are filled in from the fresh initialisation, and
                ``self._missing_keys`` is then reset to an empty set.

        Returns:
            The freshly initialised parameters when ``params`` is None, otherwise ``params``
            completed with the missing entries (as a frozen dict).
        """
        dummy_ids = jnp.zeros(input_shape, dtype="i4")
        seq_positions = jnp.arange(jnp.atleast_2d(dummy_ids).shape[-1])
        params_rng, dropout_rng = jax.random.split(rng)

        fresh_params = self.module.init(
            {"params": params_rng, "dropout": dropout_rng},
            dummy_ids,
            jnp.ones_like(dummy_ids),  # attention_mask: attend to every position
            jnp.zeros_like(dummy_ids),  # token_type_ids: a single segment
            jnp.broadcast_to(seq_positions, input_shape),  # position_ids
            return_dict=False,
        )["params"]

        if params is None:
            return fresh_params

        # Complete a partially loaded checkpoint with freshly initialised entries.
        flat_fresh = flatten_dict(unfreeze(fresh_params))
        flat_params = flatten_dict(unfreeze(params))
        flat_params.update({key: flat_fresh[key] for key in self._missing_keys})
        self._missing_keys = set()
        return freeze(unflatten_dict(flat_params))

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # No docstring on purpose: the decorator above builds this method's __doc__, and any
        # docstring written here would become part of that generated text.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.return_dict

        # Default inputs for anything the caller omitted.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
        if position_ids is None:
            seq_len = jnp.atleast_2d(input_ids).shape[-1]
            position_ids = jnp.broadcast_to(jnp.arange(seq_len), input_ids.shape)

        # A dropout PRNG stream is only passed when the caller supplied one.
        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

        variables = {"params": params or self.params}
        ids, mask, segment_ids, positions = (
            jnp.array(tensor, dtype="i4")
            for tensor in (input_ids, attention_mask, token_type_ids, position_ids)
        )
        deterministic = not train  # dropout is active only in training mode
        return self.module.apply(
            variables,
            ids,
            mask,
            segment_ids,
            positions,
            deterministic,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

In [144]:
class FlaxAlbertModule(nn.Module):
    """ALBERT backbone: embeddings -> shared-parameter encoder -> optional tanh pooler over token 0."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # computation dtype
    add_pooling_layer: bool = True  # build the pooler that yields a sentence-level vector

    def setup(self):
        # Word + position + token-type embeddings.
        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
        # Transformer encoder with ALBERT's cross-layer parameter sharing.
        self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype)

        if not self.add_pooling_layer:
            self.pooler = None
            self.pooler_activation = None
            return

        # BERT-style pooler: Dense(hidden_size) + tanh applied to the first token's hidden state.
        self.pooler = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
            name="pooler",
        )
        self.pooler_activation = nn.tanh

    def __call__(
        self,
        input_ids,  # token ids, [batch, seq]
        attention_mask,  # 1 for real tokens, 0 for padding
        token_type_ids: Optional[np.ndarray] = None,  # segment ids (A/B sentence); defaults to all zeros
        position_ids: Optional[np.ndarray] = None,  # defaults to 0 .. seq_len-1
        deterministic: bool = True,  # True disables dropout (inference)
        output_attentions: bool = False,  # also return every layer's attention weights
        output_hidden_states: bool = False,  # also return every layer's hidden states
        return_dict: bool = True,  # return a ModelOutput instead of a tuple
    ):
        """Encode a batch of token ids.

        Returns:
            ``FlaxBaseModelOutputWithPooling`` when ``return_dict`` is true. Otherwise a tuple
            ``(last_hidden_state, pooled_output, *encoder_extras)``, where ``pooled_output`` is
            left out entirely when the module was built without a pooling layer.
        """
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
        if position_ids is None:
            seq_len = jnp.atleast_2d(input_ids).shape[-1]
            position_ids = jnp.broadcast_to(jnp.arange(seq_len), input_ids.shape)

        embedded = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
        encoder_outputs = self.encoder(
            embedded,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]  # last layer's hidden states

        pooled = (
            self.pooler_activation(self.pooler(sequence_output[:, 0]))  # [batch, hidden_size]
            if self.add_pooling_layer
            else None
        )

        if return_dict:
            return FlaxBaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        leading = (sequence_output,) if pooled is None else (sequence_output, pooled)
        return leading + encoder_outputs[1:]

In [146]:
# Build a concrete backbone model without mutating the abstract base class. Assigning
# FlaxAlbertPreTrainedModel.module_class directly would leave hidden global state behind (every
# later direct use of the base class would silently wrap FlaxAlbertModule). A throw-away subclass
# keeps the base untouched and makes this cell safe to re-run.
class _FlaxAlbertBackboneDemo(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertModule  # concrete Flax module wrapped by this demo model


model = _FlaxAlbertBackboneDemo(config=config, input_shape=(1, 16))

In [149]:
# Step 2: build a parameter tree explicitly via init_weights (returns new params; model.params is not modified).
rng = jax.random.PRNGKey(0)
params = model.init_weights(rng=rng, input_shape=(1, 16))

In [151]:
jax.tree_util.tree_map(jnp.shape, params)  # same nesting as params, with each array replaced by its shape

{'embeddings': {'LayerNorm': {'bias': (128,), 'scale': (128,)},
  'position_embeddings': {'embedding': (512, 128)},
  'token_type_embeddings': {'embedding': (2, 128)},
  'word_embeddings': {'embedding': (30000, 128)}},
 'encoder': {'albert_layer_groups': {'0': {'albert_layers': {'0': {'attention': {'LayerNorm': {'bias': (4096,),
        'scale': (4096,)},
       'dense': {'bias': (4096,), 'kernel': (4096, 4096)},
       'key': {'bias': (4096,), 'kernel': (4096, 4096)},
       'query': {'bias': (4096,), 'kernel': (4096, 4096)},
       'value': {'bias': (4096,), 'kernel': (4096, 4096)}},
      'ffn': {'bias': (16384,), 'kernel': (4096, 16384)},
      'ffn_output': {'bias': (4096,), 'kernel': (16384, 4096)},
      'full_layer_layer_norm': {'bias': (4096,), 'scale': (4096,)}}}}},
  'embedding_hidden_mapping_in': {'bias': (4096,), 'kernel': (128, 4096)}},
 'pooler': {'bias': (4096,), 'kernel': (4096, 4096)}}

In [153]:
# Step 3: dummy input batch of shape (batch_size=1, seq_len=16), every token id set to 1.
input_ids = jnp.ones((1, 16), dtype=jnp.int32)
# Step 4: forward pass with the params built above; train=False makes the call deterministic (no dropout).
outputs = model(input_ids, params=params, train=False)

In [156]:
# With return_dict=True (the config default), `outputs` is a FlaxBaseModelOutputWithPooling, a
# ModelOutput whose fields can be read either as attributes or with dict-style keys.
print(outputs["last_hidden_state"].shape, outputs["pooler_output"].shape)

(1, 16, 4096) (1, 4096)


In [158]:
# @add_start_docstrings joins the given docstring fragments and uses them as the class __doc__,
# so help(FlaxAlbertModel) shows this description. Only # comments are used here because a class
# docstring would be appended to the generated text.
@add_start_docstrings(
    "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertModule  # Flax module run by this wrapper: the bare ALBERT backbone

In [None]:
# Interactive inspection: shows the class documentation assembled by @add_start_docstrings.
help(FlaxAlbertModel)

In [None]:
# 文档字符串（__doc__）在以下几种场景中非常有用：
# 1. 使用 help() 或交互式查看 API 时
# 在 Python shell 或 Jupyter Notebook 中输入：
# help(FlaxAlbertModel)
# 或
# FlaxAlbertModel?
# 你会看到 @add_start_docstrings 拼接后的完整文档，便于快速了解模型用途、输入输出、参数结构等。
# ✅ 2. 自动文档生成工具
# 如 Sphinx + autodoc 扩展会自动抓取类的 __doc__：
# .. autoclass:: transformers.FlaxAlbertModel
#    :members:
# 文档网站（如 HuggingFace 文档）就是用这种方式生成的，add_start_docstrings 提高了文档一致性和可维护性。
# ✅ 3. IDE 提示与补全
# 在 PyCharm、VS Code 中悬浮鼠标或查看 docstring，可直接看到这段描述，快速理解该类或方法：
# model = FlaxAlbertModel(config)
# 悬停 FlaxAlbertModel 时会看到完整 docstring
# ✅ 4. 开源时为他人/未来自己阅读源码服务
# 例如你阅读 Transformers 源码，看到每个模型头（如 FlaxAlbertModel, FlaxAlbertForSequenceClassification）
# 都有详细 docstring，可以快速了解用途，不必一行行查代码。
# ✅ 5. 写 wrapper 或工具自动分析时
# 有些高级工具会分析类的 __doc__ 生成参数列表、接口规范等。比如 HuggingFace CLI 工具、AutoModel 分支选择时也可能用到。

In [163]:
# HuggingFace Transformers helper that appends an auto-generated usage example to the
# docstring of FlaxAlbertModel.__call__ (shown by help(FlaxAlbertModel.__call__)).
append_call_sample_docstring(
    FlaxAlbertModel,  # class whose __call__ docstring is patched
    _CHECKPOINT_FOR_DOC,  # checkpoint name used in the generated example
    FlaxBaseModelOutputWithPooling,  # output type described in the example
    _CONFIG_FOR_DOC)  # name of the config class referenced in the docs

In [None]:
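# Provides a consistent, auto-generated call example for the API documentation.
# Users can see how to load and call the model in help(), in the IDE, or in the web docs.
# Together with add_start_docstrings, the documentation has two parts:
# the class docstring holds the model description (add_start_docstrings);
# the usage example is appended to the __call__ method's docstring (append_call_sample_docstring).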

In [None]:
# Inspect again: the __call__ section now includes the generated usage example.
help(FlaxAlbertModel)

In [165]:
# NOTE: this import rebinds the name FlaxAlbertModel to the transformers implementation, shadowing
# the class defined earlier in this notebook; the cells below use the library version.
from transformers import AutoTokenizer, FlaxAlbertModel

In [166]:
# Tokenizer matching the albert-xxlarge-v2 checkpoint.
tokenizer = AutoTokenizer.from_pretrained("albert-xxlarge-v2")

In [167]:
# Load albert-xxlarge-v2 weights; the checkpoint's MLM "predictions" weights are unused by the bare model (see warning).
model = FlaxAlbertModel.from_pretrained("albert-xxlarge-v2")

model.safetensors:   0%|          | 0.00/893M [00:00<?, ?B/s]

Some weights of the model checkpoint at albert-xxlarge-v2 were not used when initializing FlaxAlbertModel: {('predictions', 'LayerNorm', 'bias'), ('predictions', 'bias'), ('predictions', 'decoder', 'bias'), ('predictions', 'dense', 'kernel'), ('predictions', 'dense', 'bias'), ('predictions', 'LayerNorm', 'kernel')}
- This IS expected if you are initializing FlaxAlbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxAlbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [168]:
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")  # return JAX arrays

In [170]:
inputs  # input_ids, token_type_ids and attention_mask, each of shape (1, seq_len)

{'input_ids': Array([[    2, 10975,    15,    51,  1952,    25, 10901,     3]], dtype=int32), 'token_type_ids': Array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32), 'attention_mask': Array([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}

In [171]:
outputs = model(**inputs)  # forward pass; tokenizer fields are passed as keyword arguments

In [172]:
last_hidden_states = outputs.last_hidden_state  # final-layer hidden state of every token

In [174]:
last_hidden_states.shape  # (batch, seq_len, hidden_size)

(1, 8, 4096)

In [181]:
# Pre-training module: wraps the ALBERT backbone defined above together with both pre-training heads.
class FlaxAlbertForPreTrainingModule(nn.Module):
    """ALBERT backbone plus the masked-LM head and the sentence-order-prediction head."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # computation dtype used by all sub-modules

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)  # backbone with pooler
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)  # MLM head
        self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype)  # SOP head

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the backbone and both heads.

        Returns:
            ``FlaxAlbertForPreTrainingOutput`` when ``return_dict`` is true, otherwise the tuple
            ``(prediction_logits, sop_logits, *backbone_extras)``.
        """
        backbone_out = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = backbone_out[0]  # per-token states, consumed by the MLM head
        pooled_output = backbone_out[1]  # pooled first-token state, consumed by the SOP head

        # With tied embeddings, the word-embedding matrix doubles as the MLM decoder weight.
        # It is read after the backbone call so the parameters already exist during init.
        tied_embedding = (
            self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            if self.config.tie_word_embeddings
            else None
        )

        mlm_logits = self.predictions(sequence_output, shared_embedding=tied_embedding)
        sop_logits = self.sop_classifier(pooled_output, deterministic=deterministic)

        if return_dict:
            return FlaxAlbertForPreTrainingOutput(
                prediction_logits=mlm_logits,
                sop_logits=sop_logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )
        return (mlm_logits, sop_logits) + backbone_out[2:]

In [182]:
# Public pre-training model: FlaxAlbertPreTrainedModel supplies weight init / from_pretrained, and
# module_class selects the Flax module that runs the forward pass. Only # comments here, because
# @add_start_docstrings builds the class __doc__ and a class docstring would be appended to it.
@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForPreTrainingModule  # backbone + MLM head + SOP head

In [184]:
# Load albert-xxlarge-v2 into the pre-training model; the SOP classifier weights are newly initialised (see warning).
model = FlaxAlbertForPreTraining.from_pretrained("albert-xxlarge-v2")

Some weights of the model checkpoint at albert-xxlarge-v2 were not used when initializing FlaxAlbertForPreTraining: {('predictions', 'decoder', 'bias')}
- This IS expected if you are initializing FlaxAlbertForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxAlbertForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxAlbertForPreTraining were not initialized from the model checkpoint at albert-xxlarge-v2 and are newly initialized: {('sop_classifier', 'classifier', 'bias'), ('sop_classifier', 'classifier', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [185]:
outputs = model(**inputs)  # reuses the tokenized "Hello, my dog is cute" batch from above

In [187]:
jax.tree.map(jnp.shape,outputs)  # shape of each output field (None fields stay None)

FlaxAlbertForPreTrainingOutput(prediction_logits=(1, 8, 30000), sop_logits=(1, 2), hidden_states=None, attentions=None)

In [183]:
# Tail of the __call__ docstring for FlaxAlbertForPreTraining: an empty "Returns:" placeholder
# (to be filled in later from the output type) followed by a usage example.
FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
    >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.sop_logits
    ```
"""

In [None]:
# Preview the full __call__ docstring (inputs description + returns/example tail) before installing it.
print(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING)

In [195]:
# Overwrite the docstring of FlaxAlbertForPreTraining.__call__ with the inputs description plus the
# returns/example tail, so help(FlaxAlbertForPreTraining.__call__) and IDEs show call-specific docs.
overwrite_call_docstring(
    FlaxAlbertForPreTraining,
    ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING,
)

In [None]:
# overwrite_call_docstring and @add_start_docstrings act on different docstrings:
# @add_start_docstrings adds to (extends) the docstring of the class itself.
# overwrite_call_docstring overwrites (replaces) the docstring of the class's __call__ method.
# The two texts are therefore not simply concatenated one after the other.
# @add_start_docstrings affects the FlaxAlbertForPreTraining class docstring, which is what help(FlaxAlbertForPreTraining) shows at the top.
# overwrite_call_docstring affects the FlaxAlbertForPreTraining.__call__ docstring, which help(FlaxAlbertForPreTraining.__call__) shows.

In [196]:
# Check the result of overwrite_call_docstring.
help(FlaxAlbertForPreTraining.__call__)

Help on function __call__ in module __main__:

__call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, params: dict = None, dropout_rng: <function PRNGKey at 0x7e9af89e4ea0> = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None)
    The [`FlaxAlbertPreTrainedModel`] forward method, overrides the `__call__` special method.
    
    <Tip>
    
    Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
    instance afterwards instead of this since the former takes care of running the pre and post processing steps while
    the latter silently ignores them.
    
    </Tip>
    
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
    
            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTo

In [197]:
# NOTE: this import rebinds FlaxAlbertForPreTraining to the transformers implementation, shadowing
# the class defined earlier in this notebook (and the docstring patches applied to it).
from transformers import AutoTokenizer, FlaxAlbertForPreTraining

In [198]:
tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")  # same checkpoint as the docstring example

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

In [199]:
# Pooler and SOP classifier weights are not in this checkpoint and get newly initialised (see warning).
model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

flax_model.msgpack:   0%|          | 0.00/44.9M [00:00<?, ?B/s]

Some weights of FlaxAlbertForPreTraining were not initialized from the model checkpoint at albert/albert-base-v2 and are newly initialized: {('sop_classifier', 'classifier', 'bias'), ('albert', 'pooler', 'kernel'), ('albert', 'pooler', 'bias'), ('sop_classifier', 'classifier', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [200]:
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")  # NumPy arrays; the model converts them

In [201]:
outputs = model(**inputs)  # inference-mode forward pass (train defaults to False)

In [204]:
prediction_logits = outputs.prediction_logits  # MLM logits over the vocabulary for every token
seq_relationship_logits = outputs.sop_logits  # sentence-order-prediction logits, one pair per example

In [205]:
print(prediction_logits.shape,seq_relationship_logits.shape)  # (batch, seq_len, vocab_size) and (batch, 2)

(1, 8, 30000) (1, 2)


In [206]:
_CONFIG_FOR_DOC  # config class name used when generating docstrings

'AlbertConfig'

In [None]:
# Fill the "Returns:" section of FlaxAlbertForPreTraining.__call__'s docstring with a description
# generated from FlaxAlbertForPreTrainingOutput.
# NOTE(review): in top-to-bottom order the name FlaxAlbertForPreTraining was rebound by the
# `from transformers import ...` cell above, so this patches the library class rather than the
# one defined in this notebook. Confirm that this is intended.
append_replace_return_docstrings(
    FlaxAlbertForPreTraining,
    output_type=FlaxAlbertForPreTrainingOutput,
    config_class=_CONFIG_FOR_DOC
)

In [None]:
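# append_replace_return_docstrings patches the return-value part of the class's __call__ docstring.
# It works on the docstring that overwrite_call_docstring already installed: it looks for the
# "Returns:" placeholder line (FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING contains an empty one) and
# replaces it with a description generated from output_type and config_class.
# Order of operations:
# overwrite_call_docstring first sets the whole __call__ docstring (inputs, example, ...);
# append_replace_return_docstrings then fills in the return-value section.
# Where to see the result:
# run help(FlaxAlbertForPreTraining.__call__), or view the __call__ documentation in an IDE that supports it.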

In [209]:
class FlaxAlbertForMaskedLMModule(nn.Module):
    """ALBERT backbone (built without a pooler) followed by the masked-LM head."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # MLM predicts per token, so the sentence-level pooler is not built.
        self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return vocabulary logits for every input position.

        Returns:
            ``FlaxMaskedLMOutput`` when ``return_dict`` is true, otherwise ``(logits, *extras)``
            where ``extras`` are the backbone outputs that follow the last hidden state.
        """
        backbone_out = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # With tied embeddings, the word-embedding matrix is reused as the decoder weight.
        tied_embedding = None
        if self.config.tie_word_embeddings:
            tied_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]

        logits = self.predictions(backbone_out[0], shared_embedding=tied_embedding)

        if return_dict:
            return FlaxMaskedLMOutput(
                logits=logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )
        return (logits,) + backbone_out[1:]

In [None]:
# 设计意图：
# 该模块基于Albert基础模型，专注于Masked Language Model任务，故禁用池化层（不需要句向量）。
# 支持共享词嵌入权重，减少参数冗余。
# 通过返回结构化输出，方便后续调用者获取更多信息（如中间隐藏层、注意力）

In [210]:
# Masked-LM model: inherits FlaxAlbertPreTrainedModel (so from_pretrained works) and wraps
# FlaxAlbertForMaskedLMModule as its Flax module.
@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMaskedLMModule

In [212]:
# Append a generated usage example to FlaxAlbertForMaskedLM.__call__'s docstring.
append_call_sample_docstring(
    FlaxAlbertForMaskedLM,  # class whose __call__ docstring gets the example
    _CHECKPOINT_FOR_DOC,  # checkpoint name used in the generated example
    FlaxMaskedLMOutput,  # output type described in the example
    _CONFIG_FOR_DOC,  # name of the config class referenced in the docs
    revision="refs/pr/11"  # Hub revision of the checkpoint; the example passes it to from_pretrained
)

In [None]:
# Shows the generated example (reproduced in the comment cell below).
help(FlaxAlbertForMaskedLM.__call__)

In [None]:
# ```python
# >>> from transformers import AutoTokenizer, FlaxAlbertForMaskedLM

# >>> tokenizer = AutoTokenizer.from_pretrained("albert-xxlarge-v2", revision="refs/pr/11")
# >>> model = FlaxAlbertForMaskedLM.from_pretrained("albert-xxlarge-v2", revision="refs/pr/11")

# >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="jax")

# >>> outputs = model(**inputs)
# >>> logits = outputs.logits

In [214]:
class FlaxAlbertForSequenceClassificationModule(nn.Module):
    """ALBERT backbone + dropout + linear layer over the pooled output (``num_labels`` logits)."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.albert = FlaxAlbertModule(config=cfg, dtype=self.dtype)
        # Prefer the classifier-specific dropout rate; fall back to the hidden-layer rate when unset.
        if cfg.classifier_dropout_prob is not None:
            classifier_dropout = cfg.classifier_dropout_prob
        else:
            classifier_dropout = cfg.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(cfg.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Classify each sequence from its pooled representation.

        Returns:
            ``FlaxSequenceClassifierOutput`` when ``return_dict`` is true, otherwise
            ``(logits, *backbone_extras)``.
        """
        backbone_out = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Dropout on the pooled sentence vector is active only when deterministic is False.
        features = self.dropout(backbone_out[1], deterministic=deterministic)
        logits = self.classifier(features)

        if return_dict:
            return FlaxSequenceClassifierOutput(
                logits=logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )
        return (logits,) + backbone_out[2:]

In [None]:
# 设计意图：
# 基于Albert模型做序列分类，使用池化层输出句子表示。
# 分类头包含Dropout和全连接层，防止过拟合并实现类别预测。
# 支持返回dict或tuple格式的结果，兼容不同调用习惯。

In [215]:
# Inherits FlaxAlbertPreTrainedModel, so it gets checkpoint download/loading; calling the model
# runs the wrapped Flax module (FlaxAlbertForSequenceClassificationModule) via module.apply.
@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForSequenceClassificationModule

In [216]:
append_call_sample_docstring(  # append a generated usage example to __call__'s docstring
    FlaxAlbertForSequenceClassification,  # class to patch
    _CHECKPOINT_FOR_DOC,  # checkpoint name used in the generated example
    FlaxSequenceClassifierOutput,  # output type described in the example
    _CONFIG_FOR_DOC,  # name of the config class referenced in the docs
)

In [None]:
# help(FlaxAlbertForSequenceClassification.__call__)
 # Example:
    
 #    ```python
 #    >>> from transformers import AutoTokenizer, FlaxAlbertForSequenceClassification
    
 #    >>> tokenizer = AutoTokenizer.from_pretrained("albert-xxlarge-v2")
 #    >>> model = FlaxAlbertForSequenceClassification.from_pretrained("albert-xxlarge-v2")
    
 #    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
    
 #    >>> outputs = model(**inputs)
 #    >>> logits = outputs.logits

In [219]:
class FlaxAlbertForMultipleChoiceModule(nn.Module):
    """ALBERT multiple-choice model: scores every (example, choice) pair with a single logit."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        # One score per choice (not a binary classifier); scores are regrouped per example and
        # compared across choices.
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,  # True disables dropout
        output_attentions: bool = False,  # also return attention weights
        output_hidden_states: bool = False,  # also return intermediate hidden states
        return_dict: bool = True,  # return a ModelOutput instead of a tuple
    ):
        """Score each answer choice.

        Inputs are laid out as (batch_size, num_choices, seq_len); choices are folded into the
        batch axis for the backbone and unfolded again for the logits.

        Returns:
            ``FlaxMultipleChoiceModelOutput`` with logits of shape (batch_size, num_choices) when
            ``return_dict`` is true, otherwise ``(logits, *backbone_extras)``.
        """
        num_choices = input_ids.shape[1]

        def merge_choices(tensor):
            # Fold every leading axis into one: (..., seq_len) -> (N, seq_len); None passes through.
            return None if tensor is None else tensor.reshape(-1, tensor.shape[-1])

        flat_ids, flat_mask, flat_types, flat_positions = (
            merge_choices(tensor) for tensor in (input_ids, attention_mask, token_type_ids, position_ids)
        )

        backbone_out = self.albert(
            flat_ids,
            flat_mask,
            flat_types,
            flat_positions,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled = self.dropout(backbone_out[1], deterministic=deterministic)
        # (batch_size * num_choices, 1) -> (batch_size, num_choices)
        choice_logits = self.classifier(pooled).reshape(-1, num_choices)

        if return_dict:
            return FlaxMultipleChoiceModelOutput(
                logits=choice_logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )
        return (choice_logits,) + backbone_out[2:]

In [None]:
# 设计意图说明：
# 本模块用于多选任务（例如 SWAG、RACE），每个输入样本包含多个选项，模型需输出每个选项的分数。
# 将多选维度合并为批次维度，通过共享的 ALBERT 模型分别处理每个选项，再通过线性分类层输出 logits。
# 最后将 logits reshape 回 (batch_size, num_choices)，用于交叉熵计算或评估。
# 支持结构化输出，便于调试或进一步处理中间层结果

In [220]:
# Public multiple-choice model; module_class selects the Flax module that runs the forward pass.
@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMultipleChoiceModule

In [221]:
# Multiple-choice inputs have an extra choices axis, so the __call__ docstring is rebuilt with the
# 3-D shape description before the generated usage example is appended.
overwrite_call_docstring(
    FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxAlbertForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)

In [222]:
_CONFIG_FOR_DOC  # config class name used when generating docstrings

'AlbertConfig'

In [None]:
# help(FlaxAlbertForMultipleChoice.__call__)

In [224]:
# Import only the tokenizer. Also importing FlaxAlbertForMultipleChoice from transformers would
# rebind that name to the library class and shadow the one defined above in this notebook,
# so the demo below would no longer exercise this notebook's implementation.
from transformers import AutoTokenizer

In [225]:
# Backbone weights come from the checkpoint; the classifier head is freshly initialized
# (see the loading warnings below), so its scores are untrained.
MC_CHECKPOINT = "albert-xlarge-v2"
tokenizer = AutoTokenizer.from_pretrained(MC_CHECKPOINT)
model = FlaxAlbertForMultipleChoice.from_pretrained(MC_CHECKPOINT)

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/236M [00:00<?, ?B/s]

Some weights of the model checkpoint at albert-xlarge-v2 were not used when initializing FlaxAlbertForMultipleChoice: {('predictions', 'LayerNorm', 'bias'), ('predictions', 'bias'), ('predictions', 'decoder', 'bias'), ('predictions', 'dense', 'kernel'), ('predictions', 'dense', 'bias'), ('predictions', 'LayerNorm', 'kernel')}
- This IS expected if you are initializing FlaxAlbertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxAlbertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxAlbertForMultipleChoice were not initialized from the model checkpoint at albert-xlarge-v2 and are newly initialized: {('classifier', 'kernel'), ('classifier', 'bi

In [227]:
# Premise: in Italy, pizza served in formal settings (e.g. a restaurant) comes unsliced.
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
# Candidate continuations: (0) eaten with a fork and a knife, (1) eaten while held in the hand.
choice0, choice1 = (
    "It is eaten with a fork and a knife.",
    "It is eaten while held in the hand.",
)

In [228]:
# Pair the same prompt with each candidate: one padded row per choice.
encoding = tokenizer([prompt] * 2, [choice0, choice1], return_tensors="jax", padding=True)

In [230]:
# One row per (prompt, choice) pair, padded to a common length: (num_choices, seq_len).
encoding['input_ids'].shape

(2, 35)

In [233]:
# The multiple-choice model expects (batch_size, num_choices, seq_len); give every tokenizer
# output a leading batch axis of size 1 before the forward pass.
outputs = model(**{name: tensor[None] for name, tensor in encoding.items()})

In [235]:
# One score per choice, shape (batch_size, num_choices). The classifier head was newly
# initialized (see the loading warning), so these scores carry no trained signal yet.
outputs.logits

Array([[0.08329672, 0.0852083 ]], dtype=float32)

In [236]:
class FlaxAlbertForTokenClassificationModule(nn.Module):
    """ALBERT encoder with a per-token linear classification head (e.g. for NER).

    The backbone is built without its pooling layer because only the per-token
    hidden states (``outputs[0]``) are consumed by the head.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Backbone without the pooler: the head works on token-level states only.
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)

        # Prefer the classifier-specific dropout rate; fall back to the generic hidden dropout.
        dropout_rate = self.config.classifier_dropout_prob
        if dropout_rate is None:
            dropout_rate = self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=dropout_rate)

        # Maps each token's hidden state to num_labels class scores.
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the backbone and classify every token.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: forwarded
                unchanged to the ALBERT backbone.
            deterministic: when True dropout is disabled (inference); pass False
                during training.
            output_attentions: ask the backbone to also return attention weights.
            output_hidden_states: ask the backbone to also return all hidden states.
            return_dict: return a ``FlaxTokenClassifierOutput`` when True, otherwise
                the tuple ``(logits, *backbone_extras)``.
        """
        backbone_out = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 0 is the last-layer hidden states (one vector per token).
        token_states = self.dropout(backbone_out[0], deterministic=deterministic)
        # Per-token class scores: [batch_size, seq_len, num_labels].
        logits = self.classifier(token_states)

        if return_dict:
            return FlaxTokenClassifierOutput(
                logits=logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )
        return (logits,) + backbone_out[1:]

In [237]:
# `add_start_docstrings` attaches the text below to the class docstring, so it is what
# help(FlaxAlbertForTokenClassification) displays.
# Inheriting from FlaxAlbertPreTrainedModel supplies weight initialization and the
# download/loading helpers (e.g. from_pretrained); this class only names its Flax module.
# `module_class` is the inner Flax module that does the computation; the wrapper presumably
# builds it and forwards its own __call__ to it (verify in FlaxAlbertPreTrainedModel).
@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
    # Flax module that performs the forward pass (backbone + per-token classifier).
    module_class = FlaxAlbertForTokenClassificationModule

In [238]:
# NOTE(review): the recorded output is 'albert-xxlarge-v2', but the config cell at the top of
# the notebook assigns "albert/albert-base-v2". Either the name was reassigned in a cell not
# shown here, or this output is stale from an earlier kernel session. Verify with
# Restart & Run All, since this value is passed to append_call_sample_docstring.
_CHECKPOINT_FOR_DOC

'albert-xxlarge-v2'

In [239]:
# Config-class name passed to append_call_sample_docstring in the next cell.
_CONFIG_FOR_DOC

'AlbertConfig'

In [240]:
# Why __call__? help(FlaxAlbertForTokenClassification.__call__) shows the docstring of the
# wrapper's __call__, so the usage example is appended there. Presumably the helper copies
# __call__ first so the shared base-class docstring is not modified - verify in transformers.
append_call_sample_docstring(
    FlaxAlbertForTokenClassification,  # class whose __call__ docstring gets the example
    _CHECKPOINT_FOR_DOC,  # checkpoint name used in the generated example
    FlaxTokenClassifierOutput,  # documented return type of __call__
    _CONFIG_FOR_DOC,  # config class name used in the generated example
)

In [None]:
# help(FlaxAlbertForTokenClassification.__call__)

In [242]:
# Import only the tokenizer. Also importing FlaxAlbertForTokenClassification from transformers
# would rebind that name to the library class and shadow the one defined above in this notebook.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("albert-xxlarge-v2")
model = FlaxAlbertForTokenClassification.from_pretrained("albert-xxlarge-v2")
# A single sentence becomes a batch of one sequence.
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

Some weights of the model checkpoint at albert-xxlarge-v2 were not used when initializing FlaxAlbertForTokenClassification: {('albert', 'pooler', 'kernel'), ('albert', 'pooler', 'bias'), ('predictions', 'LayerNorm', 'bias'), ('predictions', 'bias'), ('predictions', 'decoder', 'bias'), ('predictions', 'dense', 'kernel'), ('predictions', 'dense', 'bias'), ('predictions', 'LayerNorm', 'kernel')}
- This IS expected if you are initializing FlaxAlbertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxAlbertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxAlbertForTokenClassification were not initialized from the model checkpoint at albert

In [243]:
# Tokenizer output: input_ids, token_type_ids and attention_mask, each of shape (1, seq_len).
inputs

{'input_ids': Array([[    2, 10975,    15,    51,  1952,    25, 10901,     3]], dtype=int32), 'token_type_ids': Array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32), 'attention_mask': Array([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}

In [244]:
# Forward pass; the tokenizer's keys map directly onto the model's keyword arguments.
outputs = model(**inputs)

In [247]:
# Per-token logits: (batch_size, seq_len, num_labels), here (1, 8, 2) for the 8 tokens above.
outputs.logits.shape

(1, 8, 2)

In [None]:
class FlaxAlbertForQuestionAnsweringModule(nn.Module):
    """ALBERT encoder with an extractive question-answering (span) head, e.g. for SQuAD.

    A single dense layer maps every token's final hidden state to two scores,
    which are split into ``start_logits`` and ``end_logits``. The backbone is
    built without its pooling layer because only token-level states are used.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The forward pass unpacks exactly two score columns (start, end). Any other label
        # count would only fail later with an opaque tuple-unpacking error while tracing,
        # so reject it here with an explicit message.
        if self.config.num_labels != 2:
            raise ValueError(
                "FlaxAlbertForQuestionAnsweringModule requires config.num_labels == 2 "
                f"(one start and one end logit per token), got {self.config.num_labels}."
            )
        # Token-level backbone; the pooled sentence vector is not needed for span extraction.
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        # Projects each hidden state to [start_score, end_score].
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Score every token as a possible answer-span start and end.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids: forwarded
                unchanged to the ALBERT backbone.
            deterministic: controls dropout inside the backbone (True for inference,
                False for training).
            output_attentions: ask the backbone to also return attention weights.
            output_hidden_states: ask the backbone to also return all hidden states.
            return_dict: return a ``FlaxQuestionAnsweringModelOutput`` when True,
                otherwise the tuple ``(start_logits, end_logits, *backbone_extras)``.

        Returns:
            Start and end logits with the last (label) axis removed, i.e.
            ``[batch_size, seq_len]`` each, plus the backbone's optional extras.
        """
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Last-layer token states: [batch_size, seq_len, hidden_size].
        hidden_states = outputs[0]
        # [batch_size, seq_len, 2]: column 0 scores span starts, column 1 span ends.
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, 2, axis=-1)
        # Drop the size-1 label axis so each position holds a single scalar score.
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [None]:
# 设计意图说明：
# 用于抽取式问答（extractive QA）任务，如 SQuAD。
# 每个 token 都输出一个起始分数和一个结束分数，表示答案的区间。
# 不需要句级特征，因此禁用了 pooling 层。
# 整体结构为：ALBERT 编码器 + 线性输出层 + 拆分 start/end。该模式是问答任务的标准设计。

In [250]:
# Inherits from FlaxAlbertPreTrainedModel, which supplies weight initialization and loading
# helpers such as from_pretrained. The class body only names the inner Flax module that does
# the work; this wrapper's own job is to carry the documentation.
@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
    # Flax module that performs the forward pass (backbone + start/end span head).
    module_class = FlaxAlbertForQuestionAnsweringModule

In [251]:
# Append a usage example to the docstring of FlaxAlbertForQuestionAnswering.__call__.
append_call_sample_docstring(
    FlaxAlbertForQuestionAnswering, # class whose __call__ docstring gets the example (QA task)
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput, # documented return type of __call__
    _CONFIG_FOR_DOC,
)

In [None]:
# help(FlaxAlbertForQuestionAnswering.__call__)  # uncomment to inspect the generated docstring

In [253]:
# Import only the tokenizer. Also importing FlaxAlbertForQuestionAnswering from transformers
# would rebind that name and shadow the class defined above in this notebook.
from transformers import AutoTokenizer

In [254]:
# The pretraining-only weights (pooler, `predictions` MLM head) are dropped on load; see the
# warning below.
QA_CHECKPOINT = "albert-xxlarge-v2"
tokenizer = AutoTokenizer.from_pretrained(QA_CHECKPOINT)
model = FlaxAlbertForQuestionAnswering.from_pretrained(QA_CHECKPOINT)

Some weights of the model checkpoint at albert-xxlarge-v2 were not used when initializing FlaxAlbertForQuestionAnswering: {('albert', 'pooler', 'kernel'), ('albert', 'pooler', 'bias'), ('predictions', 'LayerNorm', 'bias'), ('predictions', 'bias'), ('predictions', 'decoder', 'bias'), ('predictions', 'dense', 'kernel'), ('predictions', 'dense', 'bias'), ('predictions', 'LayerNorm', 'kernel')}
- This IS expected if you are initializing FlaxAlbertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxAlbertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxAlbertForQuestionAnswering were not initialized from the model checkpoint at albert-xxlarge

In [255]:
question = "Who was Jim Henson?"
text = "Jim Henson was a nice puppet"
# Encode as a (question, context) pair: token_type_ids are 0 for question tokens, 1 for context.
inputs = tokenizer(question, text, return_tensors="jax")

In [256]:
# Encoded pair: the token_type_ids switch from 0 to 1 where the context sequence begins.
inputs

{'input_ids': Array([[    2,    72,    23,  2170, 27674,    60,     3,  2170, 27674,
           23,    21,  2210, 10956,     3]], dtype=int32), 'token_type_ids': Array([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]], dtype=int32), 'attention_mask': Array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}

In [257]:
# Forward pass; the tokenizer's keys map directly onto the model's keyword arguments.
outputs = model(**inputs)

In [258]:
# Structured output; hidden_states/attentions are None because they were not requested.
# NOTE(review): some head weights are newly initialized (presumably qa_outputs; the loading
# warning above is truncated), so these logits are not meaningful answers yet.
outputs

FlaxQuestionAnsweringModelOutput(start_logits=Array([[ 0.14820898,  0.16998863,  0.14983207, -0.24656601,  0.06165942,
         0.06819007,  0.14820896, -0.23087387,  0.10601612,  0.2952439 ,
         0.18761368, -0.05515826, -0.1330166 ,  0.14820883]],      dtype=float32), end_logits=Array([[-0.04195   ,  0.06987897,  0.00533415,  0.12311459, -0.23050275,
        -0.04046968, -0.04195002,  0.0745064 , -0.0927614 , -0.06515643,
        -0.19795206, -0.07101451, -0.08701305, -0.04194993]],      dtype=float32), hidden_states=None, attentions=None)

In [259]:
# Unpack both per-token score arrays from the structured QA output.
start_scores, end_scores = outputs.start_logits, outputs.end_logits

In [260]:
# Start/end score for every token, shape (batch_size, seq_len). These are unnormalized
# logits from a linear head, not probabilities (note the negative values above).
print(start_scores.shape,end_scores.shape)

(1, 14) (1, 14)


In [261]:
# Public names exported by `from <module> import *` if this notebook is converted to a module.
# Kept as a literal list so linters and type checkers can read it statically.
__all__ = [
    "FlaxAlbertPreTrainedModel",
    "FlaxAlbertModel",
    "FlaxAlbertForPreTraining",
    "FlaxAlbertForMaskedLM",
    "FlaxAlbertForSequenceClassification",
    "FlaxAlbertForMultipleChoice",
    "FlaxAlbertForTokenClassification",
    "FlaxAlbertForQuestionAnswering",
]