In [2]:
# 导入数学和随机模块，分别用于数值运算与随机数生成
import math
import random
# 从functools导入partial，用于函数偏应用（预填参数）
from functools import partial
# 引入类型提示相关工具
from typing import Callable, Optional, Tuple
import flax.linen as nn # 引入Flax的模块定义接口
# 导入JAX核心模块及其NumPy接口
import jax
import jax.numpy as jnp
# 用于处理Flax中的冻结字典（不可变参数容器）
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
# 导入Flax中用于合并attention mask的工具函数
from flax.linen import combine_masks, make_causal_mask
# 导入Flax中attention模块的核心函数：点积注意力权重计算
from flax.linen.attention import dot_product_attention_weights
# 导入用于嵌套字典展开/恢复的工具（常用于权重转换或配置结构化）
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax # JAX中用于跨设备同步的控制流工具
from jax.random import PRNGKey # JAX随机种子生成接口
# 从HuggingFace Transformers库中导入多种模型输出类型（用于兼容模型返回结构）
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
    FlaxSeq2SeqQuestionAnsweringModelOutput,
    FlaxSeq2SeqSequenceClassifierOutput,
)
# 引入Flax模型相关的辅助工具（主要用于封装模型行为和生成文档字符串）
from transformers.modeling_flax_utils import (
    ACT2FN,  # 激活函数映射表
    FlaxPreTrainedModel,  # 预训练模型基类
    append_call_sample_docstring,  # 添加样例调用文档
    append_replace_return_docstrings,  # 添加或替换模型返回值文档
    overwrite_call_docstring,  # 覆盖模型 __call__ 的文档
)
# 引入用于文档生成与日志的工具函数
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
# 导入BART模型的配置类（用于构建与加载模型结构）
from transformers.models.bart.configuration_bart import BartConfig

In [3]:
logger = logging.get_logger(name=__name__)  # module-level logger from the transformers logging utilities

In [4]:
_CHECKPOINT_FOR_DOC = "facebook/bart-base"
_CONFIG_FOR_DOC = "BartConfig"

In [5]:
# Introductory text for the BART model class docstrings (constructor parameters and JAX features).
BART_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BartConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

# Argument documentation for the full encoder-decoder forward call (encoder and decoder inputs).
BART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
            range `[0, config.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Argument documentation for the encoder-only call (input ids, padding mask, positions, output flags;
# no decoder inputs).
BART_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for the decoder-only call (decoder inputs, encoder outputs, cache, output flags).
BART_DECODE_INPUTS_DOCSTRING = r"""
    Args:
        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        encoder_outputs (`tuple(tuple(jnp.ndarray))`):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
            `last_hidden_state` (of shape `(batch_size, sequence_length, hidden_size)`) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
            range `[0, config.max_position_embeddings - 1]`.
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length, num_heads, head_dim]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

In [6]:
# Builds decoder inputs for training: the label sequence is shifted one step to the right, the decoder
# start token is placed in front, and every -100 (ignore index) is replaced by pad_token_id.
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """Shift ``input_ids`` one position to the right along axis 1.

    Args:
        input_ids: integer array, presumably ``(batch, seq_len)``; its last column is dropped.
        pad_token_id: value substituted for every ``-100`` in the shifted result.
        decoder_start_token_id: value written into column 0.

    Returns:
        An array with the same shape and dtype as ``input_ids``.
    """
    shifted = (
        jnp.zeros_like(input_ids)
        .at[:, 0].set(decoder_start_token_id)  # column 0: decoder start token
        .at[:, 1:].set(input_ids[:, :-1])  # columns 1..: the previous tokens
    )
    # -100 marks positions ignored by the loss; they cannot be fed to the embedding, so use padding.
    is_ignore_index = shifted == -100
    return jnp.where(is_ignore_index, pad_token_id, shifted)

In [7]:
config = BartConfig()  # default BART hyper-parameters; inspected by the cells below

In [10]:
 # Demo: the causal mask for 8 positions has shape (1, 1, 8, 8) and is lower-triangular
 # (True = the query position may attend to that key position).
 make_causal_mask(jnp.ones((1, 8), dtype="bool"), dtype="bool")

Array([[[[ True, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True]]]],      dtype=bool)

In [25]:
 # Star-unpacking collects all leading items into a list: a == [2], (b, c, d) == (128, 8, 64).
 *a, b, c, d = 2, 128, 8, 64

In [27]:
# The starred target of an unpacking assignment is always bound to a list, even for a tuple source.
type(a)

list

In [28]:
# Exactly one leading (batch) item was collected into `a`.
len(a)

1

In [30]:
# Shape used by the pad-mask broadcast in `_concatenate_to_cache`: batch_dims + (1, tgt_len, max_length).
(*[2], 1, 1, 2)

(2, 1, 1, 2)

In [29]:
# Pad mask as built in `_concatenate_to_cache` with max_length=2, cur_index=1, one new token and
# batch_dims=[2]: both cache slots are visible, result shape (2, 1, 1, 2).
jnp.broadcast_to(jnp.arange(2) < 1 + 1, (*[2], 1, 1, 2))

Array([[[[ True,  True]]],


       [[[ True,  True]]]], dtype=bool)

In [31]:
# `@nn.compact` is not limited to `__call__`; it may decorate other methods too, provided that
# the method itself registers submodules or variables (self.param, self.variable, Module(), ...).
# In Flax, submodules and parameters must be defined either in setup() or in an @nn.compact method:
# - setup() suits a static module structure;
# - an @nn.compact method suits structure that depends on the inputs or is created lazily.

class FlaxBartAttention(nn.Module):
    """Multi-head dot-product attention used by BART (self-attention and cross-attention).

    Cross-attention is used when `key_value_states` is passed to `__call__`. With `causal=True`
    a lower-triangular mask is applied and keys/values can be cached in the "cache" variable
    collection for incremental decoding (see `_concatenate_to_cache`).
    """
    config: BartConfig # configuration (init_std and max_position_embeddings are read)
    embed_dim: int # projection width d (also the output width)
    num_heads: int # number of heads h
    dropout: float = 0.0 # dropout rate applied to the attention weights
    causal: bool = False # apply a causal (autoregressive) mask
    bias: bool = True # use a bias in the q/k/v/out projections
    dtype: jnp.dtype = jnp.float32  # dtype used for the computation
    # Flax Linen submodules are declared in setup()
    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads # per-head size hd
        if self.head_dim * self.num_heads != self.embed_dim: # embed_dim must be divisible by num_heads
            raise ValueError( # otherwise fail early
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        # shared Dense factory: every projection maps to embed_dim features
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias, # whether to add a bias term
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # query / key / value projections
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense() # output projection
        # dropout layer (NOTE(review): not referenced elsewhere in this class; attention dropout is
        # applied inside dot_product_attention_weights instead)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal: # autoregressive attention
            # precompute a (1, 1, max_pos, max_pos) boolean lower-triangular mask once; it is sliced per call
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )
    # split the feature dim into heads: (b, s, d) -> (b, s, h, hd)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
    # merge the heads back: (b, s, h, hd) -> (b, s, d)
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
    # Write the current step's keys/values into the cache (cached_key / cached_value) so that
    # autoregressive decoding can reuse the keys/values of earlier steps.
    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """Update the key/value cache and restrict the mask to already-written cache slots.

        On the first call (cache variables not yet present) the variables are created
        zero-filled and the inputs are returned unchanged.

        Returns:
            `(key, value, attention_mask)`; after initialization, key/value are the full
            cached tensors and the mask hides cache slots that have not been written yet.
        """
        # Was the cache already created? (False on the very first call.)
        # i.e. does this module's "cache" collection already contain "cached_key"?
        is_initialized = self.has_variable("cache", "cached_key")
        # Declare the cache variables (key, value, write index), zero-initialized.
        # self.variable(...) only evaluates the init function when the variable does not exist yet;
        # if it already exists, its stored value is kept and only a handle
        # (a flax.core.scope.Variable) is returned.
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
        # has_variable(...) checks existence; variable(...) registers (if missing) and returns a handle.
        # Hence `is_initialized` above is only True on calls after the one that created the cache.
        if is_initialized:
            # cache shape: batch dims + (max_length, num_heads, head_dim)
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # current write position in the cache
            cur_index = cache_index.value
            # start indices for dynamic_update_slice: (0, ..., 0, cur_index, 0, 0)
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            # write the new keys/values starting at position cur_index
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            # store the updated cache
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]  # number of positions written now (typically 1 in step-by-step decoding)
            # advance the write index by the number of positions just written
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # Pad mask: cache slot j is visible iff j < cur_index + num_updated_cache_vectors, i.e. it
            # holds a key that has already been written; the zero-filled tail stays hidden.
            # jnp.arange(max_length) < ... yields a 1-D bool vector [True, ..., True, False, ...],
            # broadcast to (*batch_dims, 1, num_updated_cache_vectors, max_length).
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # combine the pad mask with the incoming attention mask
            attention_mask = combine_masks(pad_mask, attention_mask)
        # On the first call (is_initialized == False) key, value and attention_mask are returned unchanged;
        # the update above only runs on later calls.
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Compute multi-head attention.

        Args:
            hidden_states: query-side input, presumably (batch, q_len, features).
            key_value_states: if given, keys/values are projected from it (cross-attention);
                otherwise from `hidden_states` (self-attention).
            attention_mask: optional padding mask over key positions, presumably (batch, kv_len);
                positions where the mask is > 0 may be attended to.
            init_cache: for causal attention, create the "cache" variables on this call.
            deterministic: if False and `dropout > 0`, attention dropout is applied using the
                "dropout" PRNG stream.

        Returns:
            `(attn_output, attn_weights)`: output of shape (batch, q_len, embed_dim) and
            weights of shape (batch, num_heads, q_len, kv_len).
        """
        # cross-attention when key_value_states is given (e.g. the decoder's encoder-decoder attention)
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # query projection, shape (b, t, h * hd)
        query_states = self.q_proj(hidden_states)
        # choose the key/value source according to the attention type;
        # for cross-attention key_value_states is presumably the encoder output
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
        # split heads, shape (b, t, h, hd)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # ========== causal mask ========== #
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # the cache already exists: incremental (step-by-step) decoding
                mask_shift = self.variables["cache"]["cache_index"] # current write position in the cache
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1] # length of the cached keys
                # take query_length rows starting at the cache position, across all cached key columns
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else: # no cache: take the top-left (query_length, key_length) block
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            # broadcast over the batch
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # ========== combine padding and causal masks ========== #
        if attention_mask is not None and self.causal:
            # expand the padding mask with head and query axes, broadcast it to the causal-mask shape, then AND them
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal: # no padding mask: use the causal mask alone
            attention_mask = causal_mask
        elif attention_mask is not None: # non-causal: only add the head and query axes for broadcasting
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding one position is fed at a time and keys/values are cached step by step.
        # ========== cache handling (incremental decoding) ========== #
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            # write key/value into the cache and restrict the mask to cache slots that are already filled
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # # convert the boolean attention mask into an additive attention bias
        # ========== mask -> attention bias ========== #
        if attention_mask is not None:
            # True -> 0.0; False -> the most negative finite value of dtype (effectively -inf)
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else: 
            attention_bias = None
        # ========== dropout PRNG ========== #
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout") # draw a key from the "dropout" PRNG stream
        # ========== attention weights (with bias and dropout) ========== #
        # broadcast_dropout=True: per the Flax docs, one dropout mask is broadcast along the batch
        # dimensions instead of sampling an independent mask per element.
        attn_weights = dot_product_attention_weights(
            query_states,  # queries, shape (..., q_len, num_heads, head_dim)
            key_states,  # keys, shape (..., kv_len, num_heads, head_dim)
            bias=attention_bias,  # additive bias built from the mask above
            dropout_rng=dropout_rng,   # PRNG key for attention dropout (None when disabled)
            dropout_rate=self.dropout,   # dropout rate; only used when not deterministic
            broadcast_dropout=True,  # broadcast the dropout mask (see comment above)
            deterministic=deterministic,  # True disables dropout
            dtype=self.dtype,  # computation dtype, e.g. float32 / bfloat16
            precision=None,   # default matmul precision
        )
        # ========== weight the values by the attention weights ========== #
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)  # (b,h,q,k) x (b,k,h,hd) -> (b,q,h,hd)
        attn_output = self._merge_heads(attn_output) # merge heads back to (b, q, embed_dim)
        attn_output = self.out_proj(attn_output)  # final output projection
        # return the attention output and the attention weights
        return attn_output, attn_weights

In [32]:
# Encoder layer
class FlaxBartEncoderLayer(nn.Module):
    """A single BART encoder layer.

    A self-attention block followed by a position-wise feed-forward block. Each block adds
    its input back as a residual and is then layer-normalized (post-LayerNorm).
    """
    config: BartConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # dtype used for the computation

    def setup(self) -> None:
        self.embed_dim = self.config.d_model  # model width
        self.self_attn = FlaxBartAttention(  # bidirectional self-attention (causal defaults to False)
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)  # LayerNorm after the attention block
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)  # dropout on each sub-block output
        self.activation_fn = ACT2FN[self.config.activation_function]  # activation function
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)  # dropout after the activation
        self.fc1 = nn.Dense(  # feed-forward expansion to encoder_ffn_dim
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(  # feed-forward projection back to d_model
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)  # LayerNorm after the feed-forward block

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Apply the encoder layer.

        Args:
            hidden_states: layer input, presumably (batch, seq_len, d_model).
            attention_mask: padding mask forwarded to the self-attention (> 0 = may attend).
            output_attentions: if True, also return the self-attention weights.
            deterministic: if False, dropout (including attention dropout) is active; the caller
                must then provide a "dropout" PRNG.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is True.
        """
        # --- self-attention block ---
        residual = hidden_states
        # `deterministic` is forwarded so that `config.attention_dropout` actually takes effect in
        # training; without it the attention always ran with its default `deterministic=True`.
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, deterministic=deterministic
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- feed-forward block ---
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))  # expand + activation
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)  # project back to d_model
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

In [33]:
# Number of encoder layers in the default BartConfig.
config.encoder_layers

12

In [34]:
# LayerDrop probability for the encoder; with 0.0 the skip condition `p < layerdrop` is never true.
config.encoder_layerdrop

0.0

In [35]:
# The kind of draw LayerDrop makes; no seed is set in the visible cells, so the value changes per run.
random.uniform(0, 1)

0.294972321158809

In [36]:
class FlaxBartEncoderLayerCollection(nn.Module):  # stack of encoder layers
    """Stack of `config.encoder_layers` encoder layers with LayerDrop during training."""
    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # encoder layers, named "0", "1", ...
        self.layers = [
            FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
        ]
        self.layerdrop = self.config.encoder_layerdrop  # probability of skipping an entire layer in training

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the encoder layers in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask: padding mask passed unchanged to every layer.
            deterministic: if False, dropout and LayerDrop are active.
            output_attentions: collect each layer's attention weights (None for skipped layers).
            output_hidden_states: collect every layer's input plus the final output.
            return_dict: return a `FlaxBaseModelOutput` instead of a tuple.

        Returns:
            `FlaxBaseModelOutput`, or a tuple of the non-None values of
            `(last_hidden_state, hidden_states, attentions)`.
        """
        all_attentions = () if output_attentions else None  # per-layer attention weights
        all_hidden_states = () if output_hidden_states else None  # per-layer inputs, then the final output
        for encoder_layer in self.layers:
            if output_hidden_states:  # record this layer's input
                all_hidden_states = all_hidden_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            # NOTE: Python's global `random` is evaluated while tracing; if this module runs under
            # jax.jit, the skip pattern is fixed for the lifetime of the compiled function.
            dropout_probability = random.uniform(0, 1)
            if not deterministic and (dropout_probability < self.layerdrop):  # skip the layer
                # A skipped layer is the identity: `hidden_states` is left untouched (it used to be
                # overwritten with None, which broke every following layer).
                layer_outputs = (None, None)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
                hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
        # append the output of the last layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)
        if not return_dict:
            return tuple(v for v in outputs if v is not None)
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

In [None]:
class FlaxBartDecoderLayer(nn.Module):
    """A single BART decoder layer.

    Causal self-attention, optional cross-attention over the encoder output, and a
    feed-forward block. Each block adds its input back as a residual and is then
    layer-normalized (post-LayerNorm).
    """
    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.embed_dim = self.config.d_model  # model width
        self.self_attn = FlaxBartAttention(  # decoder self-attention
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,  # apply the causal mask
            dtype=self.dtype,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)  # dropout on each sub-block output
        self.activation_fn = ACT2FN[self.config.activation_function]  # activation function
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)  # dropout inside the feed-forward
        # LayerNorm after the self-attention block
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.encoder_attn = FlaxBartAttention(  # cross-attention over the encoder output
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.fc1 = nn.Dense(  # feed-forward expansion to decoder_ffn_dim
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(  # feed-forward projection back to d_model
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Apply the decoder layer.

        Args:
            hidden_states: layer input, presumably (batch, tgt_len, d_model).
            attention_mask: decoder padding mask, combined with the causal mask inside self-attention.
            encoder_hidden_states: encoder output; when None the cross-attention block is skipped.
            encoder_attention_mask: encoder padding mask used by the cross-attention.
            init_cache: create the self-attention key/value cache on this call.
            output_attentions: if True, also return the attention weights.
            deterministic: if False, dropout (including attention dropout) is active; the caller
                must then provide a "dropout" PRNG.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights, cross_attn_weights)` when
            `output_attentions` is True (`cross_attn_weights` is None without encoder states).
        """
        # --- causal self-attention block ---
        residual = hidden_states
        # `deterministic` is forwarded so that `config.attention_dropout` actually takes effect in
        # training; without it the attention always ran with its default `deterministic=True`.
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- cross-attention block (only when encoder states are given) ---
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,  # output of the self-attention block
                key_value_states=encoder_hidden_states,  # keys/values come from the encoder
                attention_mask=encoder_attention_mask,  # encoder padding mask
                deterministic=deterministic,
            )
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # --- feed-forward block ---
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))  # expand + activation
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)  # project back to d_model
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)
        return outputs

In [37]:
# Number of decoder layers in the default BartConfig.
config.decoder_layers

12

In [38]:
class FlaxBartDecoderLayerCollection(nn.Module):
    """Stack of ``config.decoder_layers`` decoder layers with optional LayerDrop.

    Attributes:
        config: BART configuration; ``decoder_layers`` sets the depth and ``decoder_layerdrop``
            the probability of skipping an entire layer while training.
        dtype: the dtype of the computation.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Layers are named "0", "1", ...; these names become the parameter paths.
        self.layers = [
            FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
        ]
        # LayerDrop probability: chance of skipping a whole layer in training mode.
        self.layerdrop = self.config.decoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run ``hidden_states`` through every decoder layer in order.

        Args:
            hidden_states: input to the first decoder layer.
            attention_mask: decoder self-attention mask, passed to every layer.
            encoder_hidden_states: encoder output used for cross-attention; ``None`` skips cross-attention.
            encoder_attention_mask: mask over encoder positions for cross-attention.
            deterministic: ``False`` enables dropout and LayerDrop (training mode).
            init_cache: forwarded unchanged to each layer.
            output_attentions: collect per-layer self- (and cross-) attention weights.
            output_hidden_states: collect the input of every layer plus the final output.
            return_dict: return a ``FlaxBaseModelOutputWithPastAndCrossAttentions`` instead of a tuple.

        Returns:
            With ``return_dict`` the structured output; otherwise a tuple of the non-``None`` values
            among (last hidden state, all hidden states, self-attentions, cross-attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)  # record each layer's input

            # LayerDrop (see https://arxiv.org/abs/1909.11556): in training, skip the whole layer with
            # probability `self.layerdrop`. Python's RNG is only consulted when LayerDrop can apply,
            # so inference (or layerdrop == 0) leaves the global `random` state untouched.
            # NOTE: this is a Python-level draw; under jax.jit it happens once at trace time, so the
            # skip pattern is baked into the compiled function.
            skip_layer = (not deterministic) and self.layerdrop > 0 and random.uniform(0, 1) < self.layerdrop
            if skip_layer:
                # A skipped layer is the identity: keep hidden_states unchanged. Emitting None here
                # would feed None into the next layer and into the final output.
                layer_outputs = (hidden_states, None, None)
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    init_cache=init_cache,
                    output_attentions=output_attentions,
                    deterministic=deterministic,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Also record the output of the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            outputs = (hidden_states, all_hidden_states, all_self_attns, all_cross_attentions)
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

In [39]:
class FlaxBartClassificationHead(nn.Module):
    """Sentence-level classification head: dropout -> dense -> tanh -> dropout -> output projection.

    Attributes:
        config: BART configuration; only ``init_std`` (std of the normal kernel initializer) is read.
        inner_dim: width of the hidden dense layer.
        num_classes: number of output logits.
        pooler_dropout: dropout rate applied before each dense layer.
        dtype: dtype of the computation.
    """

    config: BartConfig
    inner_dim: int
    num_classes: int
    pooler_dropout: float
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        kernel_init = jax.nn.initializers.normal(self.config.init_std)
        self.dense = nn.Dense(self.inner_dim, dtype=self.dtype, kernel_init=kernel_init)
        self.dropout = nn.Dropout(rate=self.pooler_dropout)
        self.out_proj = nn.Dense(self.num_classes, dtype=self.dtype, kernel_init=kernel_init)

    def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
        """Map pooled features to ``num_classes`` logits; dropout is active only when not deterministic."""
        x = self.dropout(hidden_states, deterministic=deterministic)
        x = jnp.tanh(self.dense(x))  # hidden projection with tanh non-linearity
        x = self.dropout(x, deterministic=deterministic)
        return self.out_proj(x)

In [40]:
# Maximum number of positions (cell output: 1024). The encoder/decoder position tables below allocate
# this many rows plus an offset of 2.
# NOTE(review): `config` is assumed to be a BartConfig created in an earlier cell.
config.max_position_embeddings 

1024

In [41]:
class FlaxBartEncoder(nn.Module):
    """BART encoder: scaled token embeddings + learned positions -> LayerNorm -> dropout -> encoder layers.

    Attributes:
        config: BART configuration.
        embed_tokens: token embedding module supplied by the parent (FlaxBartModule passes its shared table).
        dtype: the dtype of the computation.
    """

    config: BartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        hidden_size = cfg.d_model

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)
        self.padding_idx = cfg.pad_token_id
        self.max_source_positions = cfg.max_position_embeddings
        # Token embeddings are multiplied by sqrt(d_model) only when scale_embedding is set.
        if cfg.scale_embedding:
            self.embed_scale = math.sqrt(hidden_size)
        else:
            self.embed_scale = 1.0

        # BART looks learned positions up at `position_id + offset` (offset 2), so the table is
        # allocated with `offset` extra rows beyond max_position_embeddings.
        self.offset = 2
        self.embed_positions = nn.Embed(
            cfg.max_position_embeddings + self.offset,
            hidden_size,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
            dtype=self.dtype,
        )
        self.layers = FlaxBartEncoderLayerCollection(cfg, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Embed ``input_ids`` and run them through the encoder layer stack.

        Returns:
            ``FlaxBaseModelOutput`` when ``return_dict``; otherwise the tuple produced by the layer collection.
        """
        # Fold any leading dimensions into a single batch axis: (..., seq_len) -> (batch, seq_len).
        flat_ids = input_ids.reshape(-1, input_ids.shape[-1])

        token_embeds = self.embed_tokens(flat_ids) * self.embed_scale
        position_embeds = self.embed_positions(position_ids + self.offset)

        hidden_states = self.layernorm_embedding(token_embeds + position_embeds)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        layer_outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if not return_dict:
            return layer_outputs

        return FlaxBaseModelOutput(
            last_hidden_state=layer_outputs.last_hidden_state,
            hidden_states=layer_outputs.hidden_states,
            attentions=layer_outputs.attentions,
        )

In [43]:
class FlaxBartDecoder(nn.Module):
    """BART decoder: embeds target tokens and positions, then runs the decoder layer stack.

    Attributes:
        config: BART configuration.
        embed_tokens: token embedding module supplied by the parent (FlaxBartModule shares it with the encoder).
        dtype: the dtype of the computation.
    """

    config: BartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dropout_layer = nn.Dropout(rate=cfg.dropout)
        self.padding_idx = cfg.pad_token_id
        self.max_target_positions = cfg.max_position_embeddings
        self.embed_scale = 1.0 if not cfg.scale_embedding else math.sqrt(cfg.d_model)

        # Positions are looked up at `position_id + offset`, so the table carries `offset` extra rows.
        self.offset = 2
        self.embed_positions = nn.Embed(
            num_embeddings=cfg.max_position_embeddings + self.offset,
            features=cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
            dtype=self.dtype,
        )

        self.layers = FlaxBartDecoderLayerCollection(cfg, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Embed the decoder inputs and run them through the decoder layers (with cross-attention).

        Returns:
            ``FlaxBaseModelOutputWithPastAndCrossAttentions`` when ``return_dict``; otherwise the tuple
            produced by the layer collection.
        """
        # (..., seq_len) -> (batch, seq_len)
        target_ids = input_ids.reshape(-1, input_ids.shape[-1])

        scaled_tokens = self.embed_tokens(target_ids) * self.embed_scale
        learned_positions = self.embed_positions(position_ids + self.offset)

        hidden_states = self.dropout_layer(
            self.layernorm_embedding(scaled_tokens + learned_positions),
            deterministic=deterministic,
        )

        stack_outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=stack_outputs.last_hidden_state,
                hidden_states=stack_outputs.hidden_states,
                attentions=stack_outputs.attentions,
                cross_attentions=stack_outputs.cross_attentions,
            )
        return stack_outputs

In [44]:
class FlaxBartModule(nn.Module):
    """BART encoder-decoder backbone whose encoder and decoder share one token-embedding table.

    Attributes:
        config: BART configuration.
        dtype: the dtype of the computation.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        # One (vocab_size x d_model) embedding table, handed to both the encoder and the decoder.
        self.shared = nn.Embed(
            num_embeddings=cfg.vocab_size,
            features=cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
            dtype=self.dtype,
        )
        self.encoder = FlaxBartEncoder(cfg, dtype=self.dtype, embed_tokens=self.shared)
        self.decoder = FlaxBartDecoder(cfg, dtype=self.dtype, embed_tokens=self.shared)

    def _get_encoder_module(self):
        """Return the encoder sub-module."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the decoder sub-module."""
        return self.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode the source sequence, then decode the target with cross-attention to the encoder output.

        Returns:
            ``FlaxSeq2SeqModelOutput`` when ``return_dict``; otherwise the decoder tuple followed by the
            encoder tuple.
        """
        common = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "deterministic": deterministic,
        }

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **common,
        )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            # Cross-attention keys/values come from the last encoder hidden state, masked by the
            # same attention mask the encoder used.
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            **common,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

In [None]:
# 1. 模型保存/加载时的参数嵌套结构
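# base_model_prefix 必须与模块中持有基础模型的属性名一致（条件生成模块用 self.model 持有 FlaxBartModule）。
# .from_pretrained() 依靠它对齐两种嵌套结构：把只含基础模型的权重加载进带任务头的模型，或者反过来。例如：
# from transformers import FlaxBartForConditionalGeneration
# model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base")
# print(model.params.keys())
# 你会看到 BART 主体参数嵌套在 'model' 这个 key 下（与 base_model_prefix = "model" 对应），
# 旁边还有任务头自己的参数 'final_logits_bias'；只有 tie_word_embeddings=False 时才会再多出 'lm_head'
# （共享词嵌入时 lm_head 直接借用 shared embedding 作为 kernel，自身不保存参数）。
# 这意味着：权重文件里大致是：
# {
#   "model": { ... 所有 Bart 模型参数 ... },
#   "final_logits_bias": ...,
#   "lm_head": ...  # 仅当 tie_word_embeddings=False 时存在
# }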

In [48]:
class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
    """Wrapper shared by the BART Flax models.

    Instantiates ``module_class`` from the config and provides parameter initialization
    (``init_weights``), decoding-cache initialization (``init_cache``) and the ``encode``,
    ``decode`` and ``__call__`` entry points, each of which runs the module via ``self.module.apply``.
    """

    config_class = BartConfig  # configuration class of this model family
    base_model_prefix: str = "model"  # key of the base-model sub-tree, used when matching pretrained weights
    module_class: nn.Module = None  # core nn.Module to instantiate; set by each subclass

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """Build the network module and hand it to ``FlaxPreTrainedModel``.

        Args:
            config: BART configuration.
            input_shape: shape of the dummy inputs forwarded to the base class for weight initialization.
            seed: PRNG seed forwarded to the base class.
            dtype: dtype of the computation.
            _do_init: forwarded to the base class (whether it initializes weights).
            **kwargs: extra keyword arguments forwarded to ``module_class``.
        """
        # Instantiate the network module (e.g. the encoder-decoder) with the config and extra arguments.
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # The base-class constructor finishes setup (weight initialization etc.); it may call
        # `init_weights` below indirectly.
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Create parameters by running ``module.init`` once on dummy inputs.

        Args:
            rng: PRNG key; split into the "params" and "dropout" streams.
            input_shape: ``(batch_size, sequence_length)`` of the dummy token ids.
            params: optional existing parameters. If given, only the keys listed in
                ``self._missing_keys`` are filled in from the fresh initialization.

        Returns:
            The freshly initialized "params" collection, or ``params`` with its missing keys
            filled in (re-frozen).
        """
        # Dummy inputs, used only to run one forward pass that creates the parameters.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)  # put EOS at the last position of every row
        attention_mask = jnp.ones_like(input_ids)  # encoder attention mask: attend to every position
        decoder_input_ids = input_ids  # reuse the dummy ids as decoder inputs
        decoder_attention_mask = jnp.ones_like(input_ids)  # decoder attention mask: all ones
        batch_size, sequence_length = input_ids.shape
        # Position ids 0..sequence_length-1, broadcast over the batch.
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        # Split the key into separate streams for parameter initialization and dropout.
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        # Run module.init and keep only the "params" collection.
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]
        # When parameters are supplied (e.g. partially loaded pretrained weights), copy only the
        # keys recorded in self._missing_keys from the random initialization.
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]  # fill a missing entry
            self._missing_keys = set()
            return freeze(unflatten_dict(params))  # re-freeze into an immutable FrozenDict
        else:  # no parameters supplied: return the random initialization
            return random_params
    # Initialize the decoding cache (e.g. the key/value cache) used for fast auto-regressive generation.
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.

        Returns:
            The "cache" collection created by running only the decoder with `init_cache=True`, as a mutable dict.
        """
        # Dummy decoder inputs, used only to trigger cache creation inside the decoder.
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )
        # Entry point that runs only the decoder sub-module.
        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
        # Run init through the decoder only; init_cache=True makes the decoder create its cache variables.
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # initialize through the decoder sub-module only
        )
        return unfreeze(init_variables["cache"])  # the initialized cache, used during generation

    @add_start_docstrings(BART_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BartConfig)
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

        >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:  # default: attend to every input token
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:  # default position ids 0..sequence_length-1
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:  # provide the dropout RNG stream
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()  # the module's encoder sub-module
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        return self.module.apply(
            {"params": params or self.params},  # explicit params if given, else the model's own
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,  # dropout is enabled only when train=True
            rngs=rngs,
            method=_encoder_forward,  # run only the encoder
        )

    @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> import jax.numpy as jnp
        >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

        >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)

        >>> decoder_start_token_id = model.config.decoder_start_token_id
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> last_decoder_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]  # last encoder hidden state
        if encoder_attention_mask is None:  # default: attend to every encoder position
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:  # default decoder attention mask: all ones
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))
        # Default decoder position ids 0..sequence_length-1; with a cache the caller must supply them.
        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:  # provide the dropout RNG stream
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}  # variables for apply: params, plus the cache below

        # If past_key_values are passed, the cache is already initialized and is supplied as the "cache"
        # collection; it must be marked mutable so FlaxBartAttention can update it.
        # NOTE(review): this branch tests truthiness, while the branches after apply test `is not None`;
        # an empty dict would take inconsistent paths.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False
        # Entry point that runs only the decoder sub-module.
        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,  # variables: params and, when decoding with a cache, "cache"
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),  # decoder (target-side) input ids
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,  # encoder output used by cross-attention
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,  # RNG streams (dropout)
            mutable=mutable,  # collections the decoder may update ("cache", or none)
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs  # apply with mutable collections returns (outputs, updated variables)
            outputs["past_key_values"] = unfreeze(past["cache"])  # attach the updated cache
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]  # insert the cache after the first element

        return outputs

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Full encoder-decoder forward pass: fill in default masks, position ids and decoder inputs,
        # then run the whole module.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:  # default: attend to every input token
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:  # default encoder position ids 0..sequence_length-1
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # When decoder_input_ids is not given, derive it from input_ids by shifting right with
        # decoder_start_token_id (shift_tokens_right). This reuses input_ids as the decoder target as a
        # compatibility fallback (e.g. pre-training-style use); for tasks such as translation,
        # pass decoder_input_ids explicitly.
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(  # decoder inputs derived from input_ids
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )
        if decoder_attention_mask is None:  # default decoder attention mask: all ones
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:  # default decoder position ids 0..sequence_length-1
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed (dropout stream)
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
        # self.module.apply runs the nn.Module's __call__ (the network itself); this is not a
        # recursive call of this wrapper's __call__.
        return self.module.apply(
            {"params": params or self.params},  # explicit params if given, else the model's own
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

In [None]:
# Exploration only: show the documentation of the wrapper's __call__ (decorated with
# add_start_docstrings_to_model_forward above).
help(FlaxBartPreTrainedModel.__call__)

In [None]:
# >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

# >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
# >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

# >>> text = "My friends are cool but they eat too many carbs."
# >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
# >>> encoder_outputs = model.encode(**inputs)

# >>> decoder_start_token_id = model.config.decoder_start_token_id
# >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

# >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
# >>> last_decoder_hidden_states = outputs.last_hidden_state

In [52]:
@add_start_docstrings(
    "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
class FlaxBartModel(FlaxBartPreTrainedModel):
    # Wrapper around FlaxBartModule: returns the encoder-decoder outputs (last decoder hidden state
    # plus optional hidden states / attentions) with no task-specific head.
    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    module_class = FlaxBartModule  # network module instantiated by FlaxBartPreTrainedModel.__init__

# Append a usage example (checkpoint _CHECKPOINT_FOR_DOC, output type FlaxSeq2SeqModelOutput) to the docs of FlaxBartModel.__call__.
append_call_sample_docstring(FlaxBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)

In [None]:
# Exploration only: show FlaxBartModel.__call__ documentation, including the sample appended by
# append_call_sample_docstring.
help(FlaxBartModel.__call__)

In [54]:
# Import only the tokenizer. Importing `FlaxBartModel` from transformers here would rebind the name and
# silently shadow the FlaxBartModel class defined above, so this demo (and every later cell) would run
# the library implementation instead of this notebook's.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_FOR_DOC)  # "facebook/bart-base"
model = FlaxBartModel.from_pretrained(_CHECKPOINT_FOR_DOC)  # the FlaxBartModel defined in this notebook

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

flax_model.msgpack:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [55]:
# Tokenize one example sentence; return_tensors="jax" yields JAX arrays that can be passed as model(**inputs).
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

In [56]:
# Full encoder-decoder forward pass. decoder_input_ids is not passed, so (as in FlaxBartPreTrainedModel.__call__)
# the wrapper derives it from input_ids by shifting right.
outputs = model(**inputs)

In [57]:
# Output of the last decoder layer (FlaxSeq2SeqModelOutput.last_hidden_state).
last_hidden_states = outputs.last_hidden_state

In [59]:
# Expected shape: (batch_size, sequence_length, d_model). The sentence tokenizes to 8 ids and bart-base
# has d_model = 768, hence (1, 8, 768). Assert this rather than only eyeballing the output.
assert last_hidden_states.shape == inputs["input_ids"].shape + (model.config.d_model,)
last_hidden_states.shape

(1, 8, 768)

In [60]:
class FlaxBartForConditionalGenerationModule(nn.Module):
    """BART encoder-decoder plus a language-modeling head mapping decoder states to vocabulary logits.

    Attributes:
        config: BART configuration (``tie_word_embeddings`` selects tied vs. separate output weights).
        dtype: the dtype of the computation.
        bias_init: initializer for the ``final_logits_bias`` parameter (zeros by default).
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        vocab_size = self.model.shared.num_embeddings
        self.lm_head = nn.Dense(
            vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # Registered as a (1, vocab_size) parameter so it lives in the params tree (and can be loaded
        # from checkpoints); __call__ wraps it in stop_gradient, so training never updates it.
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, vocab_size))

    def _get_encoder_module(self):
        """Return the encoder of the wrapped FlaxBartModule."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the decoder of the wrapped FlaxBartModule."""
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the encoder-decoder and project the last decoder hidden state to vocabulary logits.

        Returns:
            ``FlaxSeq2SeqLMOutput`` when ``return_dict``; otherwise ``(lm_logits,)`` followed by the
            backbone outputs after its first element.
        """
        backbone_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        decoder_last_hidden = backbone_outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(decoder_last_hidden)
        else:
            # Tied weights: reuse the shared token embedding (transposed) as the projection kernel.
            embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": embedding.T}}, decoder_last_hidden)

        # Add the bias behind stop_gradient so it receives no gradient.
        lm_logits = lm_logits + jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))

        if return_dict:
            return FlaxSeq2SeqLMOutput(
                logits=lm_logits,
                decoder_hidden_states=backbone_outputs.decoder_hidden_states,
                decoder_attentions=backbone_outputs.decoder_attentions,
                cross_attentions=backbone_outputs.cross_attentions,
                encoder_last_hidden_state=backbone_outputs.encoder_last_hidden_state,
                encoder_hidden_states=backbone_outputs.encoder_hidden_states,
                encoder_attentions=backbone_outputs.encoder_attentions,
            )
        return (lm_logits,) + backbone_outputs[1:]

In [None]:
# final_logits_bias 的意义：在微调、加载 checkpoint 时保持参数结构对齐
# self.param("final_logits_bias", ...) 把它注册成了参数，哪怕它在前向中被 stop_gradient 包住、不参与训练；
# 这样在微调或加载权重时（例如 flax_model.msgpack，或由 PyTorch 权重转换而来），已保存的偏置值会被自动载入；
# 不需要写自定义加载逻辑，也不需要自定义 forward 接口。
# 换句话说：
# 即使你从不更新这个 bias，你依然可能加载到非零值，并在推理中用到它。

In [62]:
@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):  # conditional-generation (seq2seq LM) model
    # Flax module that performs the actual computation (LM logits over the vocabulary).
    module_class = FlaxBartForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> import jax.numpy as jnp
        >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

        >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)

        >>> decoder_start_token_id = model.config.decoder_start_token_id
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        # Run only the decoder + LM head on precomputed encoder outputs, optionally continuing from a cache.
        # Resolve the output flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]  # encoder last hidden state
        if encoder_attention_mask is None:  # default: attend to every encoder position
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:  # default: no decoder padding (causality is handled in the decoder)
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache the caller must supply the absolute positions of the new tokens.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # PRNG streams: only dropout needs one.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized: hand it in as the "cache" collection
        # and mark that collection mutable so the FlaxBartAttention modules can update it.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            # Executed inside `module.apply`: decoder only, followed by the LM head.
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]  # decoder last hidden state

            if self.config.tie_word_embeddings:
                # Tied weights: the shared embedding matrix (transposed) is used as the LM-head kernel.
                shared_embedding = module.model.variables["params"]["shared"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            else:
                lm_logits = module.lm_head(hidden_states)

            # final_logits_bias is a fixed, buffer-like bias. Stop its gradient exactly like the module's own
            # __call__ does, so that training through `decode` cannot update it either (forward values unchanged).
            lm_logits += jax.lax.stop_gradient(module.final_logits_bias.astype(self.dtype))
            return lm_logits, outputs

        # Decoder-only forward pass of the full module.
        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable collection, apply() additionally returns the updated variables.
        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # Attach the updated cache: as a field of the output object, or as the second tuple element.
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # Shape of the decoder prompt so far.
        batch_size, seq_length = decoder_input_ids.shape
        # Allocate the key/value cache for fast incremental decoding up to max_length.
        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Normally positions x > input_ids.shape[-1] and x < cache_length would need 0's in the attention mask,
        # but the decoder's causal mask hides them anyway. A single static all-ones mask of length max_length
        # is therefore enough, and it is cheaper to compile.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            # Positions follow the real tokens: cumulative count of unmasked tokens, starting at 0.
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            # Copy the given mask into the front of the static mask.
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:  # no mask: plain consecutive positions 0..seq_length-1
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
        # Everything `decode` needs for the first generation step.
        return {
            "past_key_values": past_key_values,  # cached attention keys/values
            "encoder_outputs": encoder_outputs,  # consumed by cross-attention
            "encoder_attention_mask": attention_mask,  # encoder-side padding mask
            "decoder_attention_mask": extended_attention_mask,  # static decoder mask of length max_length
            "decoder_position_ids": position_ids,  # decoder positions
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # Carry the updated cache into the next step.
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        # Next step decodes a single token: its position is the last position + 1.
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs

In [None]:
# Show the docstring assembled by the decorators on `decode` (inputs doc + generated "Returns:" section).
help(FlaxBartForConditionalGeneration.decode)

In [65]:
# Use the FlaxBartForConditionalGeneration class defined earlier in this notebook. Importing the class of the
# same name from `transformers` here would shadow the notebook's definition for the rest of the session, so this
# demo and the later docstring-patching cell (overwrite_call_docstring) would silently act on the library class.
from transformers import AutoTokenizer

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

flax_model.msgpack:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [66]:
text = "My friends are cool but they eat too many carbs."
# Request truncation explicitly. With max_length alone the tokenizer warns and falls back to the
# 'longest_first' strategy anyway, so the result is the same, just without the warning.
inputs = tokenizer(text, max_length=1024, truncation=True, return_tensors="jax")

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [67]:
# Run only the encoder; its outputs are passed to `model.decode` below.
encoder_outputs = model.encode(**inputs)

In [68]:
# Token id the decoder starts from, read from the model config (this checkpoint prints 2).
decoder_start_token_id = model.config.decoder_start_token_id
decoder_start_token_id

2

In [72]:
# Scratch check: a (batch_size, 1) int32 array of ones, scaled to the start token in the next cell.
jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4")

Array([[1]], dtype=int32)

In [69]:
# One decoder-start token per batch row: shape (batch_size, 1).
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

In [73]:
# Display the start-token input.
decoder_input_ids

Array([[2]], dtype=int32)

In [74]:
# Single decoder step conditioned on the encoder outputs, without a cache (past_key_values defaults to None).
outputs = model.decode(decoder_input_ids, encoder_outputs)

In [75]:
# Logits for the single decoder position: (batch_size, 1, vocab_size).
logits = outputs.logits
logits.shape

(1, 1, 50264)

In [76]:
# Example section appended to the `__call__` docstring of FlaxBartForConditionalGeneration.
# The leading "Returns:" line is presumably the placeholder that append_replace_return_docstrings (below)
# fills with the FlaxSeq2SeqLMOutput description — keep it.
FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
    Returns:

    Summarization example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"]).sequences
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
    ```

    Mask filling example:

    ```python
    >>> import jax
    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

    >>> TXT = "My friends are <mask> but they eat too many carbs."
    >>> input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"]

    >>> logits = model(input_ids).logits
    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
    >>> values, predictions = jax.lax.top_k(probs, k=1)

    >>> tokenizer.decode(predictions).split()
    ```
"""
# Replace __call__'s docstring with the generic BART inputs doc followed by the examples above.
# NOTE(review): this patches whatever `FlaxBartForConditionalGeneration` names in the kernel right now. A
# `from transformers import FlaxBartForConditionalGeneration` executed earlier shadows the notebook's class,
# in which case the library class is patched instead.
overwrite_call_docstring(
    FlaxBartForConditionalGeneration, 
    BART_INPUTS_DOCSTRING + FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING
)
# Fill the "Returns:" section with the documentation of FlaxSeq2SeqLMOutput.
append_replace_return_docstrings(
    FlaxBartForConditionalGeneration,
    output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)

In [79]:
# Scratch check of the mask idiom used for <eos> detection below: 1 where the element equals 2, else 0.
jnp.where(jnp.array([[1,2,3]]) == 2, 1, 0)

Array([[0, 1, 0]], dtype=int32, weak_type=True)

In [80]:
class FlaxBartForSequenceClassificationModule(nn.Module):  # sequence-classification module
    """BART encoder-decoder plus a classification head.

    The head is applied to the hidden state in ``outputs[0]`` at the position of the last ``<eos>`` token
    of ``input_ids``.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    num_labels: Optional[int] = None

    def setup(self):
        # Full encoder-decoder; outputs[0] is used as the per-position hidden state (b, s, d).
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.classification_head = FlaxBartClassificationHead(  # classification head
            config=self.config,
            inner_dim=self.config.d_model,
            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
            pooler_dropout=self.config.classifier_dropout,
        )

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]  # last hidden state
        # 1 at every <eos> position of input_ids, 0 elsewhere.
        eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)

        # The value checks need concrete arrays. Skip them for any tracer (jit, vmap, ...) to avoid
        # ConcretizationTypeError; jax.core.Tracer is the common base class of all tracer types.
        if not isinstance(eos_mask, jax.core.Tracer):
            # Every example must contain the same number of <eos> tokens ...
            if len(jnp.unique(eos_mask.sum(1))) > 1:
                raise ValueError("All examples must have the same number of <eos> tokens.")
            # ... and at least one.
            if jnp.any(eos_mask.sum(1) == 0):
                raise ValueError("There are missing <eos> tokens in input_ids")

        # Keep only the LAST <eos> of each row. This is pure array code, so it also runs while tracing.
        # Previously it lived inside the concrete-only branch, so jitted calls summed the hidden states of
        # *every* <eos> while eager calls used only the last one. argmax over the reversed mask finds the
        # last 1 exactly (no float-noise trick needed).
        seq_len = eos_mask.shape[1]
        last_eos_index = seq_len - 1 - jnp.argmax(eos_mask[:, ::-1], axis=1)
        eos_mask = jnp.where(jnp.arange(seq_len)[None, :] == last_eos_index[:, None], 1, 0)

        # Pick the hidden state at the last <eos>: hidden_states * eos_mask[:, :, None], summed over seq_len.
        sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
        # Classification logits.
        logits = self.classification_head(sentence_representation, deterministic=deterministic)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqSequenceClassifierOutput(
            logits=logits,  # classification scores
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

In [81]:
@add_start_docstrings(
    """
    Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    BART_START_DOCSTRING,
)
class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel):
    # Thin user-facing wrapper: everything else is inherited from FlaxBartPreTrainedModel and the forward
    # computation is delegated to the Flax module below.
    dtype = jnp.float32
    module_class = FlaxBartForSequenceClassificationModule


# Add a sample-call example (checkpoint `_CHECKPOINT_FOR_DOC`) to the class's __call__ docstring.
append_call_sample_docstring(
    FlaxBartForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)

In [None]:
# Show the __call__ docstring including the appended sample-call example.
help(FlaxBartForSequenceClassification.__call__)

In [83]:
# Use the FlaxBartForSequenceClassification defined above. Importing the library class of the same name would
# shadow the notebook's implementation for the rest of the session, so the demo would not exercise this code.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForSequenceClassification.from_pretrained("facebook/bart-base")

Some weights of FlaxBartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: {('classification_head', 'dense', 'bias'), ('classification_head', 'out_proj', 'bias'), ('classification_head', 'dense', 'kernel'), ('classification_head', 'out_proj', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [84]:
# Tokenize one sentence into JAX arrays.
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

In [85]:
# Forward pass; logits are (batch_size, num_labels). The classification head was freshly initialized
# (see the loading warning above), so the scores are meaningless until the model is fine-tuned.
outputs = model(**inputs)
logits = outputs.logits
logits.shape

(1, 3)

In [86]:
# Toy (batch, seq, 2) tensor for checking how jnp.split behaves on the last axis, as in the QA head.
# Sizes get descriptive names instead of `b`/`s`: a later cell rebinds `b` to one of the split halves.
key = jax.random.PRNGKey(0)
batch_size, seq_len = 4, 5  # example sizes
x = jax.random.normal(key, shape=(batch_size, seq_len, 2))

In [87]:
# Shape check of the toy tensor.
x.shape

(4, 5, 2)

In [88]:
# Split the last axis into two equal halves, each (..., 1): the same jnp.split pattern the QA head uses for
# start/end logits. NOTE: this (re)binds `a` and `b`, overwriting any earlier value of those names.
a,b = jnp.split(x,2, axis=-1)

In [89]:
# Both halves keep a trailing axis of size 1 (hence the .squeeze(-1) in the QA head).
print(a.shape,b.shape)

(4, 5, 1) (4, 5, 1)


In [90]:
class FlaxBartForQuestionAnsweringModule(nn.Module):
    """BART encoder-decoder with a linear span head producing start/end logits for every position."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    num_labels = 2

    def setup(self):
        # Full encoder-decoder; outputs[0] is the per-position hidden state (b, s, d).
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        # Span head: (b, s, d) -> (b, s, num_labels), i.e. one start score and one end score per position.
        self.qa_outputs = nn.Dense(
            self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        span_logits = self.qa_outputs(model_outputs[0])  # (b, s, 2)
        # Move the label axis to the front and unpack it: channel 0 scores answer starts, channel 1 answer ends.
        # Each result is (b, s).
        start_logits, end_logits = jnp.moveaxis(span_logits, -1, 0)

        if return_dict:
            return FlaxSeq2SeqQuestionAnsweringModelOutput(
                start_logits=start_logits,
                end_logits=end_logits,
                decoder_hidden_states=model_outputs.decoder_hidden_states,
                decoder_attentions=model_outputs.decoder_attentions,
                cross_attentions=model_outputs.cross_attentions,
                encoder_last_hidden_state=model_outputs.encoder_last_hidden_state,
                encoder_hidden_states=model_outputs.encoder_hidden_states,
                encoder_attentions=model_outputs.encoder_attentions,
            )
        return (start_logits, end_logits) + model_outputs[1:]

In [91]:
@add_start_docstrings(
    """
    BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BART_START_DOCSTRING,
)
class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel):
    # Thin user-facing wrapper: everything else is inherited from FlaxBartPreTrainedModel and the forward
    # computation is delegated to the Flax module below.
    dtype = jnp.float32
    module_class = FlaxBartForQuestionAnsweringModule


# Add a sample-call example (checkpoint `_CHECKPOINT_FOR_DOC`) to the class's __call__ docstring.
append_call_sample_docstring(
    FlaxBartForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)

In [None]:
# Show the __call__ docstring including the appended sample-call example.
help(FlaxBartForQuestionAnswering.__call__)

In [93]:
# Use the FlaxBartForQuestionAnswering defined above. Importing the library class of the same name would shadow
# the notebook's implementation for the rest of the session, so the demo would not exercise this code.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForQuestionAnswering.from_pretrained("facebook/bart-base")

Some weights of FlaxBartForQuestionAnswering were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: {('qa_outputs', 'bias'), ('qa_outputs', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [94]:
# Encode the question/context pair as one sequence pair.
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="jax")

In [95]:
# Forward pass of the QA model (span head is untrained, see the loading warning above).
outputs = model(**inputs)

In [96]:
# Per-token scores for the answer start and end positions.
start_scores = outputs.start_logits
end_scores = outputs.end_logits

In [98]:
# Both are (batch_size, seq_len).
print(start_scores.shape,end_scores.shape)

(1, 17) (1, 17)


In [99]:
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):  # base class for decoder-only BART models
    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None  # set by subclasses (e.g. FlaxBartForCausalLMModule)

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Configure BART as a standalone decoder (note: this mutates the config object passed in).
        config.is_decoder = True
        config.is_encoder_decoder = False
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Initialize the decoder parameters.

        Args:
            rng: PRNG key, split into a "params" stream and a "dropout" stream.
            input_shape: ``(batch_size, sequence_length)`` of the dummy decoder inputs used for tracing.
            params: optional, possibly partial, parameter tree (e.g. loaded with ``_do_init=False``). When
                given, only the keys recorded in ``self._missing_keys`` are filled from a fresh random init;
                all provided weights are kept. Previously this argument was ignored and a fully random tree
                was returned, discarding any pretrained weights.

        Returns:
            The random parameter tree, or the completed (frozen) ``params`` tree.
        """
        # Dummy decoder inputs.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        # Separate PRNG streams for parameter init and dropout.
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        # Dummy encoder states (b, s, d) so that cross-attention parameters are created too.
        encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
        random_params = module_init_outputs["params"]

        if params is None:
            return random_params

        # Complete a partially loaded tree: take only the missing entries from the random init.
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in getattr(self, "_missing_keys", set()):
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # Dummy inputs of the full cache length; init_cache=True makes the attention layers create the
        # "cache" collection, which is returned (unfrozen).
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(BART_DECODE_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Resolve the output flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # Encoder states given without a mask: attend to every encoder position.
        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # Default decoder inputs.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # PRNG streams: only dropout needs one.
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
        # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
        # changed by FlaxBartAttention module
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs

In [100]:
class FlaxBartDecoderWrapper(nn.Module):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Token embedding created here and handed to the decoder; it is not assigned to an attribute of the wrapper.
        # NOTE(review): FlaxBartForCausalLMModule reads it back at params["decoder"]["embed_tokens"]["embedding"],
        # i.e. its parameters end up nested under the decoder. Confirm the resulting parameter paths before
        # restructuring this setup.
        embed_dim = self.config.d_model
        embed_tokens = nn.Embed(
            self.config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )
        self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        # Pure pass-through to the decoder.
        return self.decoder(*args, **kwargs)

In [101]:
class FlaxBartForCausalLMModule(nn.Module):
    """Decoder-only BART followed by a bias-free vocabulary projection (LM head)."""

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
        # Only used with its own kernel when word embeddings are NOT tied.
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        decoder_outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = decoder_outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(last_hidden_state)
        else:
            # Tied weights: the decoder's token-embedding matrix (transposed) serves as the head kernel.
            embedding_matrix = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": embedding_matrix.T}}, last_hidden_state)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        return (lm_logits,) + decoder_outputs[1:]

In [102]:
@add_start_docstrings(
    """
    Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
    e.g for autoregressive tasks.
    """,
    BART_START_DOCSTRING,
)
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
    module_class = FlaxBartForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """Build the kwargs for the first cached generation step (cache, static mask, positions)."""
        batch_size, seq_length = input_ids.shape

        # Positions x > input_ids.shape[-1] and x < cache_length would normally need 0's in the mask, but the
        # decoder's causal mask hides them anyway, so one static all-ones mask of length max_length suffices
        # (cheaper to compile).
        static_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is None:
            # Plain consecutive positions 0..seq_length-1.
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
        else:
            # Positions count only the unmasked tokens; the given mask is copied into the front of the static one.
            position_ids = attention_mask.cumsum(axis=-1) - 1
            static_mask = lax.dynamic_update_slice(static_mask, attention_mask, (0, 0))

        return {
            "past_key_values": self.init_cache(batch_size, max_length),
            "attention_mask": static_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance the position of the single next token by one."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        last_position = model_kwargs["position_ids"][:, -1:]
        model_kwargs["position_ids"] = last_position + 1
        return model_kwargs


# Add a sample-call example (checkpoint `_CHECKPOINT_FOR_DOC`) to the class's __call__ docstring.
append_call_sample_docstring(
    FlaxBartForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)

In [None]:
# Show the __call__ docstring including the appended sample-call example.
help(FlaxBartForCausalLM.__call__)

In [None]:
# Use the FlaxBartForCausalLM defined above. Importing the library class of the same name would shadow the
# notebook's implementation for the rest of the session, so the demo would not exercise this code.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForCausalLM.from_pretrained("facebook/bart-base")

In [105]:
# Tokenize into NumPy arrays and run the decoder-only LM over the whole prompt (no cache passed).
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)

In [109]:
# (batch_size, seq_len, vocab_size)
outputs.logits.shape

(1, 8, 50265)

In [106]:
# Logits at the last prompt position, i.e. the distribution over the next token.
next_token_logits = outputs.logits[:, -1]

In [108]:
# (batch_size, vocab_size)
next_token_logits.shape

(1, 50265)

In [None]:
# Public names, mirroring the upstream module. `__all__` only affects `from module import *`, so it has no
# effect inside the notebook itself.
__all__ = [
    "FlaxBartDecoderPreTrainedModel",
    "FlaxBartForCausalLM",
    "FlaxBartForConditionalGeneration",
    "FlaxBartForQuestionAnswering",
    "FlaxBartForSequenceClassification",
    "FlaxBartModel",
    "FlaxBartPreTrainedModel",
]