In [2]:
from transformers.configuration_utils import PretrainedConfig

In [None]:
# 1. attribute_map 的作用
# key：外部使用的属性名（“别名”）。
# value：Hugging Face 内部配置里使用的“规范化名称”（唯一属性）。
# 这样做的好处是：
# 兼容不同模型代码里习惯用的不同名字。
# 在统一的 API（例如 AutoConfig、from_pretrained）里，始终能映射到 Hugging Face 自己的标准字段。
# 因为 Hugging Face 想统一：
# 内部保存和导出 checkpoint 时，用 标准字段（如 hidden_size）。
# 外部用户或者不同论文/框架的命名习惯，也能无缝兼容。
# 这样 Hugging Face 的 AutoModel / AutoConfig 就能加载不同来源的模型，而不需要用户手动修改配置。
# 2. base_config_key 在 Hugging Face 里的作用
# 场景：很多模型是“组合模型”（composite models），比如：
# EncoderDecoderModel → 由一个 encoder 配置 + 一个 decoder 配置组合而成。
# VisionTextDualEncoder → 一个 vision 模型配置 + 一个 text 模型配置。
# 问题：这种组合模型保存的时候，config.json 里会有多个子配置。那怎么区分哪个是“主”配置呢？
# 这时就用到 base_config_key，它标识 哪个子配置是基础配置，也就是整个模型的核心。
# 3. 为什么要有 base_config_key
# 因为 Hugging Face 里很多工具（比如 AutoTokenizer, AutoConfig）需要知道：
# 当我加载这个组合模型时，哪个子配置才是我应该默认用的？
# 比如 encoder-decoder 模型，它可能需要用 encoder 的 vocab 来初始化 tokenizer。
# 所以 base_config_key 就是一个“指路标”。
# ✅ 总结
# base_config_key = 在组合模型配置里，标识“主配置”是谁。
# 它和推理阶段的“小模型过滤候选 token → 大模型精排”完全不是一回事。
# multi_label_classification 多标签分类
# 每个样本可以同时属于 多个类别。
# 比如新闻分类：一条新闻可能既是 体育，又涉及 娱乐。
# 普通赋值
# torch_dtype = kwargs.pop("torch_dtype", None)
# if torch_dtype is not None:
#     ...
# 赋值操作和条件判断是 两步。
# 不能写成
# if (torch_dtype = kwargs.pop("torch_dtype", None)) is not None:  # ❌
# 因为 = 是语句，不是表达式，Python 不允许在 if 条件里直接使用。
# 海象赋值 :=
# if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
# torch_dtype := ... 是一个 表达式，会先把值赋给 torch_dtype，
# 同时整个表达式的值就是 kwargs.pop(...) 的返回值，
# 因此可以直接在 if 条件里判断。
# 正确做法是：先初始化模型，再调用方法启用梯度检查点：
# model = MyModel(config)
# model.gradient_checkpointing_enable()
# 如果用 Trainer API 训练模型，可以在 TrainingArguments 里直接设置，不需要在 config 里传：
# training_args = TrainingArguments(
#     gradient_checkpointing=True,
#     ...
# )
# 1. 不能直接实例化 PretrainedConfig
# PretrainedConfig 是基类，不能直接创建实例。
# 示例用 BertConfig（继承自 PretrainedConfig）来演示。
# 2. 从预训练配置加载模型配置
# config = BertConfig.from_pretrained("google-bert/bert-base-uncased")
# 从 Hugging Face Hub 下载 BERT base uncased 的配置，并缓存到本地。
# config = BertConfig.from_pretrained("./test/saved_model/")
# 从本地保存的目录加载配置（以前用 save_pretrained 保存的配置）。
# config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
# 也可以直接从具体的 JSON 文件加载配置。
# 3. 使用额外参数覆盖配置
# config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
# # assert config.output_attentions == True
# 获取未使用的参数
# config, unused_kwargs = BertConfig.from_pretrained(
#     "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
# )
# assert config.output_attentions == True
# assert unused_kwargs == {"foo": False}
# return_unused_kwargs=True 时，返回一个元组 (config, unused_kwargs)。
# 任何 BertConfig 不认识的参数（这里是 foo=False）会收集到 unused_kwargs。
# BertConfig() 默认就是 base 模型的默认配置（hidden_size=768，num_attention_heads=12 等）。
# 如果你用 BertConfig.from_pretrained("bert-large-uncased")：
# 得到的配置是 large 模型的参数（hidden_size=1024，num_attention_heads=16 等）。
# 与 BertConfig() 的 base 默认值相比，自然有差异。
# 注意：use_diff=True 的机制并不记录“哪些字段被用户改过”，而是按值比较：
# 把当前实例的每个字段与默认配置（新建的 PretrainedConfig() 以及当前配置类的默认实例）逐项比较，值不同就保存。
# 因此 from_pretrained 加载到的、与类默认值不同的参数同样会被视作“差异”，不需要再手动修改。
# 换句话说：
# 用户修改后与默认值不同的字段	✅ 会保存
# 从 from_pretrained("large") 加载、与默认值不同的字段（如 hidden_size=1024）	✅ 同样会保存
# 总结：use_diff=True 关注的是“当前值与默认值是否不同”，而不是字段是否被用户显式修改过。
# 用户修改配置通常发生在以下几种情况：
# 1. 改变模型结构

# 用户希望在原始预训练模型基础上调整层数、隐藏维度、注意力头数等：
# config = BertConfig.from_pretrained("bert-base-uncased")
# config.hidden_size = 1024         # 改大隐藏层维度
# config.num_attention_heads = 16   # 改变注意力头数
# 2. 改变训练行为相关参数

# 控制输出或行为的参数，如是否输出中间隐藏状态、注意力权重：
# config.output_hidden_states = True
# config.output_attentions = True
# 3. 改变任务相关设置

# 对于下游任务，修改标签数量或问题类型：
# config.num_labels = 5
# config.problem_type = "multi_label_classification"
# 4. 自定义 tokenizer 或特殊 token

# 指定起始、结束、填充或分隔符 token id：
# config.pad_token_id = 0
# config.bos_token_id = 101
# config.eos_token_id = 102
# 5. 实验性调整或微调

# 例如修改 tie_word_embeddings、chunk_size_feed_forward 等优化训练速度或内存的参数。

# ✅ 总结：

# 用户修改配置 = 显式改变了原本从 from_pretrained 或默认配置加载的字段

# 这些修改只要使字段值不同于默认值，就会在 use_diff=True 时被保存到 JSON 文件

# 否则，加载时依然会使用模型默认值或从预训练模型加载的值
# use_diff	保存内容
# False	保存全部字段：包括默认值和用户修改的字段。JSON 是完整配置，加载时不用依赖默认值。
# True（to_json_file / to_json_string 的默认值，save_pretrained 也使用它）	只保存与默认值不同的字段，加载时会用默认值补全未保存字段。
# use_diff=False → 保存全部字段，兼容性最强

# use_diff=True → 文件更小，只保存“差异”，方便追踪用户修改
# 在 from_pretrained 方法里，实现“差异保存 JSON 也能完整加载”的机制主要体现在这几步：
# 获取配置字典
# config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get_config_dict 会读取 JSON 文件（无论是完整配置还是差异 JSON）

# 如果是差异 JSON，只返回修改过的字段
# 实例化配置
# return cls.from_dict(config_dict, **kwargs)
# from_dict 会创建一个 默认配置实例

# 然后用 config_dict 中的字段 覆盖默认值

# 这一步就完成了“差异 JSON 补全未保存字段”的逻辑
# 关键点：__init__ 里有 所有字段的默认值：
# def __init__(self, hidden_size=768, num_attention_heads=12, ...):
#     self.hidden_size = hidden_size
#     ...
# 因此 cls(**{"hidden_size":1024}) 会：
# 用 JSON 中提供的值覆盖默认值（hidden_size=1024）。
# 其他未修改的字段使用 __init__ 中的默认值（比如 num_attention_heads=12）。
# cls(**config_dict) 这一行代码执行时，cls 是 PretrainedConfig 的 具体子类，比如 BertConfig、GPT2Config 等。

# config_dict 是从模型的配置文件加载的字典，这个字典会包含加载时的所有配置（可能是默认配置，也可能是修改过的配置）。

# 默认参数的设置：

# BertConfig（或其他具体模型类）会在 __init__ 方法中定义模型的默认参数（比如 hidden_size=768、num_attention_heads=12）。

# 如果 config_dict 里没有这些参数，BertConfig 的构造函数会使用 默认值（即 __init__ 中定义的默认值）。

# use_diff=True 保存的差异：

# 如果在保存配置时设置了 use_diff=True，那么只会保存与 默认配置 的差异部分，比如 {"hidden_size": 1024}。

# 加载配置时（即 cls(**config_dict)），config_dict 中没有差异的部分会使用 子类 __init__ 方法中定义的默认值。例如，如果 config_dict 中没有 num_attention_heads，它会使用 BertConfig 中的默认值 num_attention_heads=12。

In [None]:
class PretrainedConfig(PushToHubMixin):
    # Class-level settings shared by every configuration; subclasses typically override them.
    model_type: str = ""  # model type identifier (e.g. "bert", "gpt2"); usually overridden by subclasses
    base_config_key: str = ""  # key under which this config is nested inside a composite model's config dict
    sub_configs: dict[str, type["PretrainedConfig"]] = {}  # registry of sub-configs: attribute name -> PretrainedConfig class
    has_no_defaults_at_init: bool = False  # whether the class has no default values at init time
    attribute_map: dict[str, str] = {}  # alias -> real attribute name, applied on every attribute get/set
    base_model_tp_plan: Optional[dict[str, Any]] = None  # tensor-parallel plan
    base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None  # pipeline-parallel plan
    base_model_ep_plan: Optional[dict[str, tuple[list[str]]]] = None  # expert-parallel plan (MoE use case)
    _auto_class: Optional[str] = None  # auto class name; when set, save_pretrained also saves the custom object files
    def __setattr__(self, key, value):
        """Assign `value` to `key`, redirecting aliased names through `attribute_map`."""
        # Fetch the alias table through the parent implementation so our own
        # `__getattribute__` override is not involved.
        alias_table = super().__getattribute__("attribute_map")
        # An alias is replaced by the real attribute name; any other name is used as-is.
        target_name = alias_table.get(key, key)
        super().__setattr__(target_name, value)
    def __getattribute__(self, key):
        """Return the attribute `key`, resolving aliases declared in `attribute_map`."""
        # `attribute_map` itself is returned verbatim and never looked up in the alias table.
        if key == "attribute_map":
            return super().__getattribute__(key)
        alias_table = super().__getattribute__("attribute_map")
        # An alias is replaced by the real attribute name; any other name is used as-is.
        return super().__getattribute__(alias_table.get(key, key))
    def __init__(
        self,
        *,
        # All models common arguments
        output_hidden_states: bool = False,  # whether to return the hidden states of all layers
        output_attentions: bool = False,  # whether to return the attention weights
        return_dict: bool = True,  # whether model outputs use the dict-style format
        torchscript: bool = False,  # whether TorchScript is enabled
        dtype: Optional[Union[str, "torch.dtype"]] = None,  # parameter dtype (e.g. "float32", "float16")
        # Common arguments
        pruned_heads: Optional[dict[int, list[int]]] = None,  # record of pruned attention heads
        tie_word_embeddings: bool = True,  # whether input and output embeddings are shared
        chunk_size_feed_forward: int = 0,  # chunk size of the feed-forward layers
        is_encoder_decoder: bool = False,  # whether the model is an encoder-decoder
        is_decoder: bool = False,  # whether this is the decoder (within encoder-decoder models)
        cross_attention_hidden_size: Optional[int] = None,  # hidden size of the cross-attention
        add_cross_attention: bool = False,  # whether cross-attention is added to the decoder
        tie_encoder_decoder: bool = False,  # whether encoder and decoder share weights
        # Fine-tuning task arguments
        architectures: Optional[list[str]] = None,  # architecture names, e.g. ["BertForSequenceClassification"]
        finetuning_task: Optional[str] = None,  # fine-tuning task name, e.g. "mrpc", "sst2"
        id2label: Optional[dict[int, str]] = None,  # classification: id -> label name
        label2id: Optional[dict[str, int]] = None,  # classification: label name -> id
        num_labels: Optional[int] = None,  # classification: number of labels
        task_specific_params: Optional[dict[str, Any]] = None,  # task-specific settings
        problem_type: Optional[str] = None,  # regression / single_label_classification / multi_label_classification
        # Tokenizer kwargs
        tokenizer_class: Optional[str] = None,  # tokenizer class name
        prefix: Optional[str] = None,  # input prefix
        bos_token_id: Optional[int] = None,  # beginning-of-sequence token id
        pad_token_id: Optional[int] = None,  # padding token id
        eos_token_id: Optional[int] = None,  # end-of-sequence token id
        sep_token_id: Optional[int] = None,  # separator token id
        decoder_start_token_id: Optional[int] = None,  # decoder start token id
        **kwargs,  # extra attributes, kept for extensibility and backward compatibility
    ):
        """
        Validate the arguments and store them as attributes of the configuration.

        All parameters are keyword-only. Named parameters are stored on the instance under the same
        name (`output_attentions` is stored as `_output_attentions` and exposed through a property).
        Remaining `kwargs` are handled as follows:

        - `torch_dtype` is accepted as a legacy alias of `dtype`; `dtype` wins when both are given.
        - generation defaults returned by `_get_global_generation_defaults()` are popped and stored.
        - `name_or_path`, `_commit_hash`, `attn_implementation` and `transformers_version` are popped
          and stored.
        - `gradient_checkpointing` only triggers a deprecation warning here.
        - every key still left in `kwargs` is set as an attribute.

        Raises:
            ValueError: if `label2id` or `id2label` is not a dict, or if `problem_type` is not one of
                the three supported values.
            AttributeError: re-raised (after logging) when a leftover kwarg cannot be set.
        """
        # Argument validation
        if label2id is not None and not isinstance(label2id, dict):
            raise ValueError("Argument label2id should be a dictionary.")
        if id2label is not None and not isinstance(id2label, dict):
            raise ValueError("Argument id2label should be a dictionary.")
        if num_labels is not None and id2label is not None and len(id2label) != num_labels:
            # Inconsistent label settings are only reported, not rejected.
            logger.warning(
                f"You passed `num_labels={num_labels}` which is incompatible to "
                f"the `id2label` map of length `{len(id2label)}`."
            )
        if problem_type is not None and problem_type not in (
            "regression",
            "single_label_classification",
            "multi_label_classification",
        ):
            raise ValueError(
                f"The config parameter `problem_type` was not understood: received {problem_type} "
                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
            )
        # Backward compatibility with the legacy `torch_dtype` argument. No warning is emitted here,
        # otherwise it would fire constantly because most configs on the hub still carry `torch_dtype`.
        # The walrus expression pops the value, binds it, and tests it in one step.
        if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
            # If both are provided, `dtype` takes precedence
            dtype = dtype if dtype is not None else torch_dtype
        if dtype is not None and isinstance(dtype, str) and is_torch_available():
            # we will start using self.dtype in v5, but to be consistent with
            # from_pretrained's dtype arg convert it to an actual torch.dtype object
            # (a string such as "float32" becomes the `torch.float32` object)
            import torch
            dtype = getattr(torch, dtype)
        # Attributes common for all models
        self.return_dict = return_dict
        self.output_hidden_states = output_hidden_states
        self.torchscript = torchscript
        self.dtype = dtype
        self._output_attentions = output_attentions  # has public property
        # Less common kwargs, only used by some models
        self.pruned_heads = pruned_heads if pruned_heads is not None else {}
        self.tie_word_embeddings = tie_word_embeddings
        self.chunk_size_feed_forward = chunk_size_feed_forward
        # Encoder-decoder models attributes
        self.is_encoder_decoder = is_encoder_decoder
        self.is_decoder = is_decoder  # used in encoder-decoder models to differentiate encoder from decoder
        self.cross_attention_hidden_size = cross_attention_hidden_size
        self.add_cross_attention = add_cross_attention
        self.tie_encoder_decoder = tie_encoder_decoder
        # Fine-tuning task attributes
        self.architectures = architectures
        self.finetuning_task = finetuning_task
        self.id2label = id2label
        self.label2id = label2id
        self.task_specific_params = task_specific_params
        self.problem_type = problem_type
        if self.id2label is None:
            # No explicit mapping: generate LABEL_i names, two labels when `num_labels` is not given.
            self._create_id_label_maps(num_labels if num_labels is not None else 2)
        else:
            # Keys are always strings in JSON so convert ids to int here.
            self.id2label = {int(key): value for key, value in self.id2label.items()}
        # Tokenizer attributes
        self.tokenizer_class = tokenizer_class
        self.prefix = prefix
        self.bos_token_id = bos_token_id
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.sep_token_id = sep_token_id
        self.decoder_start_token_id = decoder_start_token_id
        # Sequence-generation parameters, kept so that older configs still load
        for parameter_name, default_value in self._get_global_generation_defaults().items():
            setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))
        # Commit hash of the loaded config, if any
        self._commit_hash = kwargs.pop("_commit_hash", None)
        # Attention implementation (goes through the `_attn_implementation` property setter)
        self._attn_implementation = kwargs.pop("attn_implementation", None)
        # Transformers version recorded in the config
        self.transformers_version = kwargs.pop("transformers_version", None)
        # Gradient checkpointing passed through the config is deprecated; only warn about it.
        # `get` (not `pop`) is used, so the key is still set as an attribute by the loop below.
        if kwargs.get("gradient_checkpointing", False):
            warnings.warn(
                "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
                "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
                "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
            )
        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err
        # TODO: remove later, deprecated arguments for TF models
        self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False)
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
    def _create_id_label_maps(self, num_labels: int):
        """Build default `id2label` / `label2id` maps with names `LABEL_0` .. `LABEL_{num_labels - 1}`."""
        generated_names = {}
        for label_index in range(num_labels):
            generated_names[label_index] = f"LABEL_{label_index}"
        self.id2label = generated_names
        # The reverse map is derived from the stored forward map.
        self.label2id = {label_name: label_index for label_index, label_name in self.id2label.items()}
    @property
    def name_or_path(self) -> Optional[str]:
        """Stored checkpoint name or path, or `None` when `_name_or_path` has not been set."""
        try:
            return self._name_or_path
        except AttributeError:
            return None
    @name_or_path.setter
    def name_or_path(self, value):
        # Always store a string so the value stays JSON-encodable.
        as_text = str(value)
        self._name_or_path = as_text
    @property
    def output_attentions(self):
        """
        `bool`: Whether or not the model should return all attentions.
        """
        return self._output_attentions
    @output_attentions.setter
    def output_attentions(self, value: bool):
        if value:
            # Enabling attention outputs before an implementation was chosen selects "eager".
            if self._attn_implementation is None:
                self._attn_implementation = "eager"
            # Any implementation other than "eager" is rejected.
            if self._attn_implementation != "eager":
                raise ValueError(
                    "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
                    f"{self._attn_implementation}. Please set it to 'eager' instead."
                )
        self._output_attentions = value
    @property
    def use_return_dict(self) -> bool:
        """
        `bool`: Whether to return a [`~utils.ModelOutput`] rather than tuples.
        """
        dict_requested = self.return_dict
        if not dict_requested:
            return dict_requested
        # With torchscript enabled, dict outputs are disabled to avoid jit errors.
        return not self.torchscript
    @property
    def num_labels(self) -> int:
        """
        `int`: Number of labels, computed as the size of the `id2label` map.
        """
        label_names = self.id2label
        return len(label_names)
    @num_labels.setter
    def num_labels(self, num_labels: int):
        # `num_labels` is never stored; it is derived from `id2label`. The maps are regenerated
        # only when they are missing or their size differs from the requested count.
        if self.id2label is not None and self.num_labels == num_labels:
            return
        self._create_id_label_maps(num_labels)

    @property
    def _attn_implementation(self):
        """Attention implementation stored in `_attn_implementation_internal`."""
        return self._attn_implementation_internal

    @_attn_implementation.setter
    def _attn_implementation(self, value: Optional[Union[str, dict]]):
        """Store the attention implementation here and on every existing sub-config."""
        per_config = isinstance(value, dict)

        # This config: a dict is looked up under the "" key, falling back to the current value.
        previous = getattr(self, "_attn_implementation", None)
        self._attn_implementation_internal = value.get("", previous) if per_config else value

        # Sub-configs: a dict is looked up under the sub-config's key, with the same fallback rule.
        for child_key in self.sub_configs:
            child = getattr(self, child_key, None)
            if child is None:
                continue
            child_previous = getattr(child, "_attn_implementation", None)
            child._attn_implementation = value.get(child_key, child_previous) if per_config else value
    @property
    def torch_dtype(self):
        """Deprecated alias of `dtype`; logs a one-time deprecation warning on read."""
        deprecation_note = "`torch_dtype` is deprecated! Use `dtype` instead!"
        logger.warning_once(deprecation_note)
        return self.dtype

    @torch_dtype.setter
    def torch_dtype(self, value):
        # Writes are forwarded to `dtype`, with the same one-time deprecation warning.
        deprecation_note = "`torch_dtype` is deprecated! Use `dtype` instead!"
        logger.warning_once(deprecation_note)
        self.dtype = value

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Write this configuration as a JSON file (`CONFIG_NAME`) inside `save_directory`.

        Args:
            save_directory: Target directory (created if missing). May be a `str` or an `os.PathLike`.
            push_to_hub: When true, create the remote repo and upload the files modified by this call.
            **kwargs: `token` / `use_auth_token` are normalised by `_set_token_in_kwargs`. When
                `push_to_hub` is true, `commit_message` and `repo_id` are popped here and the rest is
                forwarded to `_create_repo`.

        Raises:
            AssertionError: if `save_directory` points to an existing file.
        """
        # Normalise the hub token found in kwargs
        self._set_token_in_kwargs(kwargs)
        # The target must be a directory, not a file
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
        # Generation parameters with non-default values do not belong in the model config: warn about them
        non_default_generation_parameters = self._get_non_default_generation_parameters()
        if len(non_default_generation_parameters) > 0:
            # TODO: this should become an exception once users are expected to have migrated
            warnings.warn(
                "Some non-default generation parameters are set in the model config. These should go into either a) "
                "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
                "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
                "This warning will become an exception in the future."
                f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
                UserWarning,
            )
        # Create the target directory if it does not exist yet
        os.makedirs(save_directory, exist_ok=True)
        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            # The default repo id is the last path component. The default is computed eagerly, so the
            # path is converted to `str` first: an `os.PathLike` (e.g. `pathlib.Path`) has no `.split`.
            default_repo_id = str(save_directory).split(os.path.sep)[-1]
            repo_id = kwargs.pop("repo_id", default_repo_id)
            repo_id = self._create_repo(repo_id, **kwargs)
            # Timestamps taken before saving, used later to upload only the modified files
            files_timestamps = self._get_files_timestamps(save_directory)
        # `transformers_weights` is not serialized: drop it before saving
        if "transformers_weights" in self:
            delattr(self, "transformers_weights")
        # Custom config classes also save their custom object files so they can be loaded from the Hub
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)
        # Output path uses the predefined file name CONFIG_NAME
        output_config_file = os.path.join(save_directory, CONFIG_NAME)
        # `use_diff=True` is passed so that only the differences from the default config are written
        self.to_json_file(output_config_file, use_diff=True)
        logger.info(f"Configuration saved in {output_config_file}")
        # Upload the files changed since the timestamps were recorded
        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

    @staticmethod
    def _set_token_in_kwargs(kwargs, token=None):
        """
        Normalise the authentication token inside `kwargs` (modified in place).

        `use_auth_token` is always removed from `kwargs`; when set, it emits a `FutureWarning` and is used
        as the token. The resulting token, if any, is stored under `kwargs["token"]`.

        Raises:
            ValueError: if both a token and `use_auth_token` are provided.
        """
        # Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet,
        # so the token may arrive through kwargs instead of the explicit parameter.
        resolved_token = kwargs.pop("token", None) if token is None else token
        legacy_token = kwargs.pop("use_auth_token", None)

        if legacy_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if resolved_token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            resolved_token = legacy_token

        if resolved_token is None:
            return
        kwargs["token"] = resolved_token

    @classmethod
    def from_pretrained(
        cls: type[SpecificPretrainedConfigType],
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    ) -> SpecificPretrainedConfigType:
        """
        Load a config dict through `get_config_dict` and build an instance of `cls` with `from_dict`.

        The explicit loading arguments are merged into `kwargs`, the token is normalised, and the
        remaining kwargs returned by `get_config_dict` are forwarded to `from_dict`.
        """
        # Fold the explicit loading arguments into kwargs so they travel together.
        kwargs.update(
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            revision=revision,
        )
        cls._set_token_in_kwargs(kwargs, token)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # When `base_config_key` is set and present, use the config nested under that key.
        base_key = cls.base_config_key
        if base_key and base_key in config_dict:
            config_dict = config_dict[base_key]

        type_mismatch = (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        )
        if type_mismatch:
            # Some configs have no `base_config_key`: look through the nested dicts for a matching
            # `model_type`. If several match, the last one is used.
            nested_matches = [
                candidate
                for candidate in config_dict.values()
                if isinstance(candidate, dict) and candidate.get("model_type") == cls.model_type
            ]
            if nested_matches:
                config_dict = nested_matches[-1]
            # Warn only if there is still no match in `model_type`.
            if config_dict["model_type"] != cls.model_type:
                logger.warning(
                    f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(config_dict, **kwargs)
    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Resolve `pretrained_model_name_or_path` to a config dict plus the kwargs left unused.

        Returns `({}, kwargs)` when no config dict could be obtained. If the loaded config lists
        `configuration_files`, the selected file is loaded instead.
        """
        cls._set_token_in_kwargs(kwargs)
        # Keep an untouched copy: `_get_config_dict` consumes entries from the kwargs it is given.
        pristine_kwargs = copy.deepcopy(kwargs)

        config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
        if config_dict is None:
            return {}, kwargs
        if "_commit_hash" in config_dict:
            pristine_kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # Without a redirection to another config file, the first result is final.
        if "configuration_files" not in config_dict:
            return config_dict, kwargs

        redirected_file = get_configuration_file(config_dict["configuration_files"])
        config_dict, kwargs = cls._get_config_dict(
            pretrained_model_name_or_path, _configuration_file=redirected_file, **pristine_kwargs
        )
        return config_dict, kwargs

    @classmethod
    def _get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Locate one configuration file and load it into a dict.

        The file is resolved from a direct local file path, a remote URL, or through `cached_file`.
        It is parsed as JSON, or read through `load_gguf_checkpoint` when `gguf_file` is given.

        Returns:
            `(config_dict, kwargs)` where `kwargs` holds the entries not consumed here, or
            `(None, kwargs)` when `cached_file` returns `None`.

        Raises:
            OSError: when the file cannot be resolved or is not valid JSON.
        """
        # Cache directory passed on to `cached_file`
        cache_dir = kwargs.pop("cache_dir", None)
        # Whether to force a new download; defaults to False
        force_download = kwargs.pop("force_download", False)
        # Resume-download option, passed on to `cached_file`
        resume_download = kwargs.pop("resume_download", None)
        # Proxy settings used for downloads, e.g. {"http": "http://127.0.0.1:8080"}
        proxies = kwargs.pop("proxies", None)
        # Hub access token
        token = kwargs.pop("token", None)
        # Whether to use local files only; defaults to False
        local_files_only = kwargs.pop("local_files_only", False)
        # Revision to load, e.g. "main" or a commit hash
        revision = kwargs.pop("revision", None)
        # Only meaningful with Auto classes; merely triggers a warning below
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        subfolder = kwargs.pop("subfolder", "")  # sub-folder that holds the configuration file
        # Internal marker: set when called through a pipeline; recorded in the user agent
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)  # internal marker: set by Auto classes; recorded in the user agent
        commit_hash = kwargs.pop("_commit_hash", None)  # commit hash of the config, if already known
        # GGUF checkpoint file; read with `get`, so it stays in the returned kwargs
        gguf_file = kwargs.get("gguf_file")
        # `trust_remote_code` has no effect here
        if trust_remote_code is True:
            logger.warning(
                "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
                " ignored."
            )
        # User-agent information passed on to `cached_file`
        user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        # Whether the path is a local directory
        is_local = os.path.isdir(pretrained_model_name_or_path)
        if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
            # Special case: a local file was given directly
            resolved_config_file = pretrained_model_name_or_path
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):  # remote URL: download it
            configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file
            resolved_config_file = download_url(pretrained_model_name_or_path)
        else:  # local folder, cache, or Hub
            configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
            try:
                # Load from local folder or from cache or download from model Hub and cache
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    configuration_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder,
                    _commit_hash=commit_hash,
                )
                if resolved_config_file is None:
                    return None, kwargs
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)  # commit hash derived from the resolved file
            except OSError:  # environment errors from `cached_file` are propagated unchanged
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
                # the original exception.
                raise
            except Exception:  # anything else becomes a generic OSError
                # For any other exception, we throw a generic error.
                raise OSError(
                    f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
                    f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
                    f" containing a {configuration_file} file"
                )

        try:
            if gguf_file:  # GGUF checkpoint: the config comes from the checkpoint itself
                config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
            else:
                # Read the config dict from the JSON file
                config_dict = cls._dict_from_json_file(resolved_config_file)
            config_dict["_commit_hash"] = commit_hash  # record the commit hash in the dict
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.")

        if is_local:
            logger.info(f"loading configuration file {resolved_config_file}")
        else:
            logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

        # Dicts recognised by `is_timm_config_dict` that lack a `model_type` get "timm_wrapper"
        if "model_type" not in config_dict and is_timm_config_dict(config_dict):
            config_dict["model_type"] = "timm_wrapper"

        return config_dict, kwargs  # the config dict and the kwargs not consumed above

    @classmethod
    def from_dict(
        cls: type[SpecificPretrainedConfigType], config_dict: dict[str, Any], **kwargs
    ) -> SpecificPretrainedConfigType:
        """
        Build a config instance from `config_dict`, then apply `kwargs` as attribute overrides.

        Fields missing from `config_dict` take the defaults of `cls.__init__`, which is why a
        diff-only JSON file can still produce a complete config. Kwargs naming an existing
        attribute override it; the rest stay in `kwargs`.

        Returns:
            The config, or `(config, unused_kwargs)` when `return_unused_kwargs` is passed as true.

        Raises:
            ValueError: if `num_labels` and `id2label` are both passed in kwargs and disagree.
        """
        want_leftovers = kwargs.pop("return_unused_kwargs", False)
        # Telemetry-only markers: removed so they never appear among the unused kwargs.
        for telemetry_key in ("_from_auto", "_from_pipeline"):
            kwargs.pop(telemetry_key, None)
        # The commit hash recorded in `config_dict` must not be erased by the one in kwargs.
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # Backward compatibility with the old `torch_dtype`; an explicit `dtype` wins.
        legacy_dtype = kwargs.pop("torch_dtype", None)
        if legacy_dtype is not None:
            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
            kwargs.setdefault("dtype", legacy_dtype)

        # Moved into the config dict so that it does not appear among the unused kwargs.
        config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            # Head indices are converted to int keys.
            config.pruned_heads = {int(layer): heads for layer, heads in config.pruned_heads.items()}

        # Consistency check between `num_labels` and `id2label` when both are overridden.
        if "num_labels" in kwargs and "id2label" in kwargs:
            requested_labels = kwargs["num_labels"]
            label_map = kwargs["id2label"]
            if label_map is None:
                label_map = []
            if len(label_map) != requested_labels:
                raise ValueError(
                    f"You passed along `num_labels={requested_labels}` with an incompatible id to label map: "
                    f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
                    "one of them."
                )

        consumed_keys = []
        for name, override in kwargs.items():
            if not hasattr(config, name):
                continue
            existing = getattr(config, name)
            # A dict given for a nested config only updates the listed fields; the other
            # attributes of that sub-config are kept.
            if isinstance(existing, PretrainedConfig) and isinstance(override, dict):
                merged = existing.to_dict()
                merged.update(override)
                override = existing.__class__(**merged)
            setattr(config, name, override)
            # `dtype` is applied but deliberately left in kwargs.
            if name != "dtype":
                consumed_keys.append(name)
        for name in consumed_keys:
            kwargs.pop(name, None)

        logger.info(f"Model config {config}")
        return (config, kwargs) if want_leftovers else config

    @classmethod
    def from_json_file(
        cls: type[SpecificPretrainedConfigType], json_file: Union[str, os.PathLike]
    ) -> SpecificPretrainedConfigType:
        
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def __iter__(self):
        yield from self.__dict__

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Serialize this instance to a dictionary that keeps only the non-default values.

        Each entry of `to_dict()` is compared with the defaults of a bare `PretrainedConfig()` and,
        unless `has_no_defaults_at_init` is set, with the defaults of a fresh instance of this class.
        Nested configs are reduced recursively with `recursive_diff_dict`.

        Returns:
            `dict[str, Any]`: the reduced dictionary. `_name_or_path` is never included, and
            `quantization_config` is always included (as a dict) when the attribute exists.
        """
        config_dict = self.to_dict()
        # Get the default config dict (from a fresh PretrainedConfig instance)
        default_config_dict = PretrainedConfig().to_dict()
        # Get the class-specific default dict; left empty when the class cannot be built without arguments
        class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}

        serializable_config_dict = {}

        # Only serialize values that differ from the defaults. `transformers_version` and
        # `vocab_file` are always kept, as is any key unknown to the base defaults.
        for key, value in config_dict.items():
            # `and` binds tighter than `or`: membership in `self.sub_configs` alone selects this branch.
            if (
                isinstance(getattr(self, key, None), PretrainedConfig)
                and key in class_config_dict
                and isinstance(class_config_dict[key], dict)
                or key in self.sub_configs
            ):
                # For nested configs we need to clean the diff recursively
                diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
                if "model_type" in value:
                    # Needs to be set even if it's not in the diff
                    diff["model_type"] = value["model_type"]

                serializable_config_dict[key] = diff
            elif (
                key not in default_config_dict
                or key == "transformers_version"
                or key == "vocab_file"
                or value != default_config_dict[key]
                # Also kept when equal to the base default but different from this class's own default.
                or (key in default_config_dict and value != class_config_dict.get(key, value))
            ):
                serializable_config_dict[key] = value

        self._remove_keys_not_serialized(serializable_config_dict)

        # Key removed only in diff dict
        if "_name_or_path" in serializable_config_dict:
            del serializable_config_dict["_name_or_path"]

        # The quantization config is written in full, bypassing the diff above.
        if hasattr(self, "quantization_config"):
            serializable_config_dict["quantization_config"] = (
                self.quantization_config.to_dict()
                if not isinstance(self.quantization_config, dict)
                else self.quantization_config
            )
        # Convert `dtype` entries to strings, in place and recursively.
        self.dict_dtype_to_str(serializable_config_dict)

        return serializable_config_dict

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this instance to a Python dictionary.

        Returns:
            `dict[str, Any]`: a deep copy of the instance attributes, plus `model_type` (when the class
            defines one) and `transformers_version`, with nested configs expanded to dictionaries.
        """
        serialized = copy.deepcopy(self.__dict__)

        config_class = self.__class__
        if hasattr(config_class, "model_type"):
            serialized["model_type"] = config_class.model_type

        # Transformers version when serializing the model
        serialized["transformers_version"] = __version__

        # Expand nested configs (e.g. CLIP-style composites); only values of existing keys are
        # replaced, so mutating the dict while iterating over it is safe.
        for name, entry in serialized.items():
            if not isinstance(entry, PretrainedConfig):
                continue
            nested = entry.to_dict()
            del nested["transformers_version"]
            serialized[name] = nested

        self._remove_keys_not_serialized(serialized)

        if hasattr(self, "quantization_config"):
            quantization = self.quantization_config
            serialized["quantization_config"] = (
                quantization if isinstance(quantization, dict) else quantization.to_dict()
            )
        self.dict_dtype_to_str(serialized)

        return serialized

    def to_json_string(self, use_diff: bool = True) -> str:
        
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
       
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: dict[str, Any]):
        
        for key, value in config_dict.items():
            setattr(self, key, value)

    def update_from_string(self, update_str: str):
        d = dict(x.split("=") for x in update_str.split(","))
        for k, v in d.items():
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")

            old_v = getattr(self, k)
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise TypeError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )

            setattr(self, k, v)

    def dict_dtype_to_str(self, d: dict[str, Any]) -> None:
       
        if d.get("dtype") is not None:
            if isinstance(d["dtype"], dict):
                d["dtype"] = {k: str(v).split(".")[-1] for k, v in d["dtype"].items()}
            # models like Emu3 can have "dtype" as token in config's vocabulary map,
            # so we also exclude int type here to avoid error in this special case.
            elif not isinstance(d["dtype"], (str, int)):
                d["dtype"] = str(d["dtype"]).split(".")[1]
        for value in d.values():
            if isinstance(value, dict):
                self.dict_dtype_to_str(value)

    def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
       
        if hasattr(self, "quantization_config"):
            # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
            _ = d.pop("_pre_quantization_dtype", None)

        if "_auto_class" in d:
            del d["_auto_class"]
        if "_output_attentions" in d:
            d["output_attentions"] = d.pop("_output_attentions")
        if "_commit_hash" in d:
            del d["_commit_hash"]
        if "_attn_implementation_internal" in d:
            del d["_attn_implementation_internal"]
        # Do not serialize `base_model_tp_plan` for now
        if "base_model_tp_plan" in d:
            del d["base_model_tp_plan"]
        # Do not serialize `base_model_pp_plan` for now
        if "base_model_pp_plan" in d:
            del d["base_model_pp_plan"]
        for value in d.values():
            if isinstance(value, dict):
                self._remove_keys_not_serialized(value)

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoConfig"):
        """
        Record on this class the name of the auto class it is registered with.

        Args:
            auto_class (`str` or `type`, defaults to `"AutoConfig"`):
                The auto class, or its name. It must be an attribute of `transformers.models.auto`.

        Raises:
            ValueError: If `transformers.models.auto` has no attribute with that name.
        """
        auto_class_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

        import transformers.models.auto as auto_module

        if hasattr(auto_module, auto_class_name):
            cls._auto_class = auto_class_name
            return
        raise ValueError(f"{auto_class_name} is not a valid auto class.")

    @staticmethod
    def _get_global_generation_defaults() -> dict[str, Any]:
        return {
            "max_length": 20,
            "min_length": 0,
            "do_sample": False,
            "early_stopping": False,
            "num_beams": 1,
            "num_beam_groups": 1,
            "diversity_penalty": 0.0,
            "temperature": 1.0,
            "top_k": 50,
            "top_p": 1.0,
            "typical_p": 1.0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "encoder_no_repeat_ngram_size": 0,
            "bad_words_ids": None,
            "num_return_sequences": 1,
            "output_scores": False,
            "return_dict_in_generate": False,
            "forced_bos_token_id": None,
            "forced_eos_token_id": None,
            "remove_invalid_values": False,
            "exponential_decay_length_penalty": None,
            "suppress_tokens": None,
            "begin_suppress_tokens": None,
        }

    def _get_non_default_generation_parameters(self) -> dict[str, Any]:
       
        non_default_generation_parameters = {}
        decoder_attribute_name = None

        # Composite models don't have a default config, use their decoder config as a fallback for default values
        # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
        try:
            default_config = self.__class__()
        except ValueError:
            decoder_config = self.get_text_config(decoder=True)
            if decoder_config is not self:
                default_config = decoder_config.__class__()
            else:
                default_config = None

        # If it is a composite model, we want to check the subconfig that will be used for generation
        self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)

        for parameter_name, default_global_value in self._get_global_generation_defaults().items():
            if hasattr(self_decoder_config, parameter_name):
                is_default_in_config = is_default_generation_value = None
                parameter_value = getattr(self_decoder_config, parameter_name)
                # Three cases in which is okay for the model config to hold generation config parameters:
                # 1. The parameter is set to `None`, effectively delegating its value to the generation config
                if parameter_value is None:
                    continue
                # 2. If we have a default config, then the instance should hold the same generation defaults
                if default_config is not None:
                    is_default_in_config = parameter_value == getattr(default_config, parameter_name)
                # 3. if we don't have a default config, then the instance should hold the global generation defaults
                else:
                    is_default_generation_value = parameter_value == default_global_value

                is_non_default = (is_default_in_config is False) or (
                    is_default_in_config is None and is_default_generation_value is False
                )
                if is_non_default:
                    non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)

        return non_default_generation_parameters

    def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
       
        return_both = decoder == encoder  # both unset or both set -> search all possible names

        decoder_possible_text_config_names = ("decoder", "generator", "text_config")
        encoder_possible_text_config_names = ("text_encoder",)
        if return_both:
            possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
        elif decoder:
            possible_text_config_names = decoder_possible_text_config_names
        else:
            possible_text_config_names = encoder_possible_text_config_names

        valid_text_config_names = []
        for text_config_name in possible_text_config_names:
            if hasattr(self, text_config_name):
                text_config = getattr(self, text_config_name, None)
                if text_config is not None:
                    valid_text_config_names += [text_config_name]

        if len(valid_text_config_names) > 1:
            raise ValueError(
                f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
                "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
                "e.g. `text_config = config.sub_config_name`"
            )
        elif len(valid_text_config_names) == 1:
            config_to_return = getattr(self, valid_text_config_names[0])
        else:
            config_to_return = self

        # handle legacy models with flat config structure, when we only want one of the configs
        if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
            config_to_return = copy.deepcopy(config_to_return)
            prefix_to_discard = "encoder" if decoder else "decoder"
            for key in config_to_return.to_dict():
                if key.startswith(prefix_to_discard):
                    delattr(config_to_return, key)
            # old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
            if decoder and hasattr(config_to_return, "decoder_layers"):
                config_to_return.num_hidden_layers = config_to_return.decoder_layers
            elif encoder and hasattr(config_to_return, "encoder_layers"):
                config_to_return.num_hidden_layers = config_to_return.encoder_layers

        return config_to_return

    @classmethod
    def from_text_vision_configs(cls, text_config, vision_config, **kwargs):
       
        warnings.warn(
            "The `from_text_vision_configs` method is deprecated and will be removed in v4.60 of Transformers. Please instantiate "
            "the config class directly with `MyConfig(text_config=text_config, vision_config=vision_config, **kwargs)` instead.",
            FutureWarning,
        )

        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)

    @classmethod
    def from_text_audio_configs(cls, text_config, audio_config, **kwargs):
       
        warnings.warn(
            "The `from_text_audio_configs` method is deprecated and will be removed in v4.60 of Transformers. Please instantiate "
            "the config class directly with `MyConfig(text_config=text_config, audio_config=audio_config, **kwargs)` instead.",
            FutureWarning,
        )

        return cls(text_config=text_config.to_dict(), audio_config=audio_config.to_dict(), **kwargs)


def get_configuration_file(configuration_files: list[str]) -> str:
    """
    Pick the configuration file to use for the running library version.

    Files named `config.<version>.json` are candidates. The one with the highest `<version>` that is
    not newer than `__version__` is chosen, and `CONFIG_NAME` is returned when no candidate qualifies.

    Args:
        configuration_files (`list[str]`): the available file names.

    Returns:
        `str`: the name of the configuration file to load.
    """
    versioned_files = {}
    for file_name in configuration_files:
        if file_name.startswith("config.") and file_name.endswith(".json") and file_name != "config.json":
            tag = file_name.removeprefix("config.").removesuffix(".json")
            try:
                parsed_tag = version.parse(tag)
            except ValueError:
                # Not a version suffix (e.g. `config.backup.json`): ignore the file.
                # NOTE(review): assumes `version.parse` signals this with a `ValueError` subclass,
                # as `packaging.version.InvalidVersion` is — confirm against the `version` import.
                continue
            versioned_files[tag] = (parsed_tag, file_name)

    # Order by parsed version, not by string: as plain text "4.10.0" sorts before "4.9.0", which
    # would make the early `break` below skip files that are valid for the current version.
    candidates = sorted(versioned_files.items(), key=lambda item: (item[1][0], item[0]))

    # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
    configuration_file = CONFIG_NAME
    transformers_version = version.parse(__version__)
    for _, (parsed_tag, file_name) in candidates:
        if parsed_tag <= transformers_version:
            configuration_file = file_name
        else:
            # No point going further since the versions are sorted.
            break

    return configuration_file


def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Recursively compute the entries of `dict_a` that are worth keeping relative to defaults.

    Args:
        dict_a (`dict`): the values to reduce.
        dict_b (`dict`): a dictionary of default values; keys of `dict_a` absent from it are always kept.
        config_obj (*optional*):
            The config object `dict_a` was produced from. When given, values are compared with the
            `to_dict()` of a fresh instance of its class. When `None`, values are compared with
            `dict_b` directly.

    Returns:
        `dict`: the reduced dictionary. Nested configs are reduced recursively.
    """
    diff = {}
    default = config_obj.__class__().to_dict() if config_obj is not None else {}
    for key, value in dict_a.items():
        obj_value = getattr(config_obj, str(key), None) if config_obj is not None else None
        if (
            obj_value is not None
            and isinstance(obj_value, PretrainedConfig)
            and key in dict_b
            and isinstance(dict_b[key], dict)
        ):
            diff[key] = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
        elif key not in dict_b:
            diff[key] = value
        elif config_obj is None:
            # No class defaults are available: fall back to a plain diff against `dict_b`
            # (indexing the empty class defaults here used to raise `KeyError`).
            if value != dict_b[key]:
                diff[key] = value
        elif key not in default or value != default[key]:
            # A key missing from the class defaults cannot be proven to be a default, so it is kept.
            diff[key] = value
    return diff


# Rebind `push_to_hub` to a copy made by `copy_func`, then fill the `{object}`, `{object_class}` and
# `{object_files}` placeholders of its docstring with config-specific wording.
# NOTE(review): presumably the copy keeps the formatted docstring from leaking onto the function
# object shared with other classes — confirm against `copy_func`.
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
# `__doc__` can be None (e.g. when docstrings are stripped), in which case there is nothing to format.
if PretrainedConfig.push_to_hub.__doc__ is not None:
    PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
        object="config", object_class="AutoConfig", object_files="configuration file"
    )
# Layer type names accepted by `layer_type_validation`.
ALLOWED_LAYER_TYPES = (
    "full_attention",
    "sliding_attention",
    "chunked_attention",
    "linear_attention",  # used in minimax
)


def layer_type_validation(layer_types: list[str]):
    """
    Check that every entry of `layer_types` is one of `ALLOWED_LAYER_TYPES`.

    Raises:
        ValueError: If at least one entry is not allowed.
    """
    if any(layer_type not in ALLOWED_LAYER_TYPES for layer_type in layer_types):
        raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")

In [None]:
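# GemmaConfig inherits from PretrainedConfig and is used to create a new Gemma configuration.
# model_type = "gemma": identifies the model type (here, gemma); it is typically used to recognise
# which model a configuration belongs to when it is loaded.
# keys_to_ignore_at_inference = ["past_key_values"]: the keys of the model's dictionary outputs that
# are ignored by default when those outputs are inspected during inference. The model still computes
# and uses them (past_key_values is the KV cache); they are only skipped when collecting outputs.
# base_model_tp_plan and base_model_pp_plan: settings for distributing the model across devices.
# "tp" is tensor parallelism (the weights of a single layer are split across devices) and "pp" is
# pipeline parallelism (different layers are placed on different devices). They describe, per layer
# or operation, how it is parallelized and how its parameters or inputs/outputs are split.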

In [None]:
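# "layers.*.self_attn.q_proj", "layers.*.self_attn.k_proj", "layers.*.self_attn.v_proj": the query (Q),
# key (K) and value (V) projection matrices of self-attention. With "colwise" these matrices are split
# column-wise: each device holds and processes a shard (a subset of the columns) of the matrix, not a
# single column.
# "layers.*.self_attn.o_proj": the output projection matrix of self-attention. With "rowwise" it is
# split row-wise: each device holds and processes a shard (a subset of the rows) of the matrix.
# "layers.*.mlp.gate_proj", "layers.*.mlp.up_proj", "layers.*.mlp.down_proj": the projection matrices
# of the multi-layer perceptron (MLP). gate_proj and up_proj are "colwise" (split by columns), while
# down_proj is "rowwise" (split by rows).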

In [None]:
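# "embed_tokens": the embedding layer of the model, which turns input_ids into inputs_embeds. It may,
# for example, run on the first device: input_ids is its input and inputs_embeds is the output handed
# to the following layers.
# "layers": the stack of transformer layers, which may be spread over several devices. Each stage
# receives hidden_states and attention_mask and passes on only hidden_states, which keeps the amount
# of data transferred between devices small.
# "norm": the normalization layer. It takes hidden_states and passes its result on to whatever
# comes next.
# Structure of the key/value pairs
# Key: the name of a layer or operation of the model, e.g. embed_tokens is the embedding layer,
# layers is the stack of transformer layers, and norm is the normalization layer.
# Value: a tuple with two elements:
# First list (inputs): the data this layer/operation needs as input (possibly outputs of other layers).
# Second list (outputs): the data this layer/operation produces, usually passed on to the next
# operation or layer.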

In [None]:
class GemmaConfig(PretrainedConfig):
    r"""
    Configuration class for Gemma models.

    The constructor stores the following arguments, unchanged, as attributes of the same name
    (defaults in parentheses): `vocab_size` (256000), `max_position_embeddings` (8192),
    `hidden_size` (3072), `intermediate_size` (24576), `num_hidden_layers` (28),
    `num_attention_heads` (16), `head_dim` (256), `num_key_value_heads` (16),
    `hidden_act` ("gelu_pytorch_tanh"), `hidden_activation` (None), `initializer_range` (0.02),
    `rms_norm_eps` (1e-6), `use_cache` (True), `rope_theta` (10000.0), `attention_bias` (False)
    and `attention_dropout` (0.0).

    `pad_token_id` (0), `bos_token_id` (2), `eos_token_id` (1), `tie_word_embeddings` (True) and
    any extra keyword arguments are forwarded to `PretrainedConfig.__init__`.

    Both `hidden_act` and `hidden_activation` are stored; this class does not reconcile them.

    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor Parallelism (TP) is a distributed strategy in which the layers of the model are split
    # and parallelized across multiple devices. Each pattern below maps a projection to its split style.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Per-module (input names, output names) pairs; presumably consumed for pipeline parallelism —
    # TODO confirm against the base class.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model-specific fields are set before the parent constructor runs.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Special token ids, weight tying and any remaining kwargs are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


# Public names exported by this cell/module.
__all__ = ["GemmaConfig"]
