# Mixtral-8x7B

- [Paper](https://arxiv.org/abs/2401.04088)
- [Implementation](https://github.com/mistralai/mistral-inference)
- [Website](https://mistral.ai/news/mixtral-of-experts)

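## Overview

**Mixtral 8x7B** is a Sparse Mixture of Experts (**SMoE**) language model.

Unlike Mistral 7B, each layer is composed of 8 feed-forward networks (**experts**).

For every token, a **router network** selects two experts and combines their outputs.

Because experts are selected dynamically, each token uses only 13B of the 47B total parameters.

This keeps compute cost and latency low while reaching performance comparable to Llama 2 70B and GPT-3.5.

**Mixtral 8x7B Instruct**, trained with SFT (Supervised Fine-Tuning) and DPO (Direct Preference Optimization), has also been released.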

## Architecture

Mixtral makes the following changes to Mistral 7B:

- Sliding Window Attention (SWA) is replaced with fully dense attention
- The feed-forward block is replaced with a Mixture of Experts layer (MoE layer)

![](image/table_1.png)


### Sparse Mixture of Experts

For details on mixture of experts, see the [paper](https://arxiv.org/pdf/2401.04088).

The MoE output for an input $x$ is the **weighted sum of the expert network outputs**.

The weights of the expert networks are given by the **output of the gating network**.

Overview of the sparse mixture of experts:

![](image/figure_1.png)

Given $n$ expert networks $\{E_0, E_1, \dots, E_{n-1}\}$, the weighted sum of their outputs is:

$$
\sum_{i=0}^{n-1}G(x)_{i}\cdot E_{i}(x).
$$

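- $G(x)_i$: weight of the $i$-th expert for input $x$ (gating network)
- $E_i(x)$: output of the $i$-th expert for input $x$ (expert network)

Mixtral's gating network applies a softmax to the top $K$ (Top-K) logits:

$$
G(x) := \text{Softmax}(\text{TopK}(x \cdot W_g))
$$

- $W_g$: weight matrix of the gating network
- $x\cdot W_g$: score (logit) of each expert
- $\text{TopK}$: function that sets every logit outside the top $K$ to negative infinity
- $\text{Softmax}(\cdot)$: function that converts the logits into a probability distribution summing to $1.0$, giving zero weight to the logits set to negative infinity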
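The gating function can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the logits are made-up values standing in for $x \cdot W_g$, and the helper names (`softmax`, `top_k_mask`, `gate`) are hypothetical. The sketch also checks that applying the softmax to all logits first, then keeping the top $K$ and renormalizing (the order used by the `MixtralSparseMoeBlock` code later in this document) gives the same weights.

```python
import math


def softmax(values):
    """Numerically stable softmax; entries equal to -inf get weight 0."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_mask(logits, k):
    """Keep the k largest logits and set every other entry to -inf."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order[:k])
    return [v if i in keep else -math.inf for i, v in enumerate(logits)]


def gate(logits, k=2):
    """G(x) = Softmax(TopK(x . W_g)), taking the logits x . W_g as input."""
    return softmax(top_k_mask(logits, k))


# Made-up router logits for one token and n = 8 experts.
logits = [0.1, 2.0, -1.0, 0.5, 1.5, -0.3, 0.0, 0.7]
weights = gate(logits, k=2)

# Alternative order: softmax over all logits, keep the top k, renormalize.
full = softmax(logits)
top = sorted(range(len(full)), key=lambda i: full[i], reverse=True)[:2]
top_sum = sum(full[i] for i in top)
renormalized = [full[i] / top_sum if i in top else 0.0 for i in range(len(full))]

print(weights)
print(renormalized)
```

Both orders agree because the softmax denominator cancels when the surviving entries are renormalized.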

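Keeping the number of experts used per token, $K$, fixed while increasing the total number of experts $n$ raises the total parameter count efficiently:

- The number of parameters used per token is called the **active parameter count**
- The total number of parameters in the model is called the **sparse parameter count**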
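As a rough check of the 47B and 13B figures, both counts can be estimated from the model dimensions. The hidden size (4096), FFN size (14336), 8 experts and top-2 routing appear in the code later in this document, and 32 layers is consistent with the routing analysis, which looks at layers 0, 15 and 31. The vocabulary size and attention head counts below are assumptions taken from the published Mixtral-8x7B configuration, not from this document, so treat the result as an estimate.

```python
# Dimensions that appear elsewhere in this document.
hidden_size = 4096
intermediate_size = 14336
num_experts = 8
experts_per_token = 2
num_layers = 32

# Assumptions from the published Mixtral-8x7B configuration.
vocab_size = 32000
num_heads = 32
num_kv_heads = 8
head_dim = hidden_size // num_heads

# Parameters of one SwiGLU expert (w1, w2, w3, no biases).
expert = 3 * hidden_size * intermediate_size
# Attention projections: query and output use all heads, key and value use the KV heads.
attention = 2 * hidden_size * num_heads * head_dim + 2 * hidden_size * num_kv_heads * head_dim
# Router (gating network) and the two normalization weights per layer.
router = hidden_size * num_experts
norms = 2 * hidden_size
# Input embedding, output head, and final normalization.
embeddings = 2 * vocab_size * hidden_size + hidden_size


def count(experts_used):
    """Parameter count when `experts_used` experts per layer are included."""
    per_layer = experts_used * expert + attention + router + norms
    return num_layers * per_layer + embeddings


total = count(num_experts)          # sparse parameter count
active = count(experts_per_token)   # active parameter count per token
print(f"total: {total / 1e9:.1f}B, active: {active / 1e9:.1f}B")
```

Only the expert term scales with $n$; attention, router, norms and embeddings are shared, which is why the total is well below 8 × 7B.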

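MoE layers can run efficiently on a single GPU or across multiple GPUs:

- Efficiency techniques on a single GPU
    - [Megablocks][1]: casts the FFN operations of the MoE as large sparse matrix multiplications, which speeds up execution
- Efficiency techniques across multiple GPUs
    - Model parallelism (Model Parallelism techniques): splits the model by layer and places the parts on multiple GPUs
    - [Expert parallelism][2] (Expert Parallelism: EP): splits the experts into groups and places the groups on multiple GPUs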

[1]: https://proceedings.mlsys.org/paper_files/paper/2023/hash/5a54f79333768effe7e8927bcccffe40-Abstract-mlsys2023.html
[2]: https://arxiv.org/abs/1701.06538


Mixtral implements each expert with the SwiGLU architecture and uses $K=2$ experts per token:

$$
y = \sum_{i=0}^{n-1} \text{Softmax}(\text{Top2}(x\cdot W_g))_i \cdot \text{SwiGLU}_i(x)
$$

- $\text{Softmax}(\text{Top2}(x\cdot W_g))_i$: weight of the $i$-th expert
- $\text{SwiGLU}_i(x)$: output of the $i$-th expert
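The formula can be traced for a single token with a toy sketch in plain Python. Everything here is illustrative: the dimensions are tiny, the weights are random, and the helper names (`matvec`, `swiglu`, `moe_top2`) are hypothetical. Matrices are stored as rows of output features, so `matvec(w_g, x)` plays the role of $x \cdot W_g$.

```python
import math
import random


def matvec(matrix, vector):
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def silu(v):
    return v / (1.0 + math.exp(-v))


def swiglu(x, w1, w2, w3):
    """SwiGLU expert: w2 applied to SiLU(w1 x) multiplied element-wise by w3 x."""
    gate = [silu(v) for v in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])


def moe_top2(x, w_g, experts):
    """y = sum_i Softmax(Top2(x . W_g))_i * SwiGLU_i(x); only two experts are evaluated."""
    logits = matvec(w_g, x)
    top2 = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:2]
    peak = max(logits[i] for i in top2)
    exps = {i: math.exp(logits[i] - peak) for i in top2}
    norm = sum(exps.values())
    y = [0.0] * len(x)
    for i in top2:
        out = swiglu(x, *experts[i])
        y = [acc + exps[i] / norm * o for acc, o in zip(y, out)]
    return y, top2


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
hidden, ffn, n = 4, 6, 8  # toy sizes, not Mixtral's 4096 / 14336
w_g = random_matrix(n, hidden, rng)
experts = [
    (random_matrix(ffn, hidden, rng), random_matrix(hidden, ffn, rng), random_matrix(ffn, hidden, rng))
    for _ in range(n)
]
x = [0.5, -1.0, 0.25, 2.0]
y, chosen = moe_top2(x, w_g, experts)
print(chosen, y)
```

The six experts that are not selected are never evaluated, which is where the compute saving comes from.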

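## Benchmark results

Mixtral and Llama were evaluated and compared on the following benchmarks:

- Commonsense reasoning (0-shot)
    - Hellaswag: choose the ending that follows most naturally from the context
    - Winogrande: choose the word that a pronoun refers to
    - PIQA (Physical Interaction Question Answering): choose the option that requires physical understanding
    - SIQA (Social Interaction QA): choose the option that requires understanding human emotions
    - OpenbookQA: choose the option that requires understanding general science facts (the open book)
    - ARC-Easy (AI2 Reasoning Challenge): choose the option for the easier grade-school science questions
    - ARC-Challenge: choose the option for the harder grade-school science questions
    - CommonsenseQA: choose the option that requires commonsense knowledge
- World knowledge (5-shot)
    - NaturalQuestions: extract a long answer and a short answer from a given Wikipedia page
    - TriviaQA: given web pages or Wikipedia pages, extract an answer that requires combining them
- Reading comprehension (0-shot)
    - BoolQ: answer yes or no about a given passage
    - QuAC (Question Answering in Context): keep answering a sequence of questions asked by one user
- Math
    - GSM8K (8-shot): grade-school math word problems
    - MATH (4-shot): difficult competition-level math problems
- Code
    - Humaneval (0-shot): complete hand-written Python functions from the given hints
    - MBPP (Mostly Basic Python Programming) (3-shot): solve basic, entry-level Python programming problems
- Aggregated
    - MMLU (Massive Multitask Language Understanding) (5-shot): multiple-choice questions across 57 subjects
    - BBH (Big-Bench Hard) (3-shot): 23 challenging tasks that current models struggle with
    - AGI Eval (3-5-shot): multiple-choice questions from US college entrance exams, law school admission tests, medical licensing exams, and similar exams

Mixtral outperformed Llama 2 70B, which uses 5x more active parameters, on many benchmarks (it is especially strong in code and math):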

![](image/figure_2.png)

![](image/table_2.png)

Mixtral reaches high performance with few active parameters:

![](image/figure_3.png)

Mixtral outperforms Llama 2 70B and is on par with GPT-3.5 (GPT-3.5-Turbo):

![](image/table_3.png)

Beyond English, Mixtral significantly outperforms Llama 2 70B in French, German, Spanish, and Italian:

![](image/table_4.png)

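On a task that retrieves information from a long text ([passkey retrieval][3]), Mixtral reaches 100% retrieval accuracy, and its perplexity decreases as the context length grows:

![](image/figure_4.png)

[3]: https://arxiv.org/abs/2305.16300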

Compared with Llama 2, Mixtral shows less bias on BBQ (Bias Benchmark for QA):

![](image/figure_5.png)

The instruction-tuned model scores highest among open-weights models on MT-Bench (as of December 2023):

![](image/figure_6.png)

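## Routing analysis

Expert selection was measured per topic at layers 0, 15, and 31, using the validation split of The Pile.

No clear pattern of expert selection based on topic was observed (only mathematics showed a slightly different distribution):

![](image/figure_7.png)

![](image/table_5.png)

In the per-token expert assignment, `self`, `Question`, indentation, and consecutive tokens tend to get the same routing: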

![](image/figure_8.png)
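One simple way to quantify how often consecutive tokens get the same routing is the fraction of adjacent token pairs that share the same first-choice expert. The sketch below is an assumption about how such a statistic could be computed, not the exact procedure behind the figures above; the function name `repeat_rate` and the expert ids are made up for illustration.

```python
def repeat_rate(assignments):
    """Fraction of adjacent token pairs routed to the same first-choice expert."""
    pairs = list(zip(assignments, assignments[1:]))
    if not pairs:
        return 0.0
    return sum(a == b for a, b in pairs) / len(pairs)


# Made-up first-choice expert ids for ten consecutive tokens.
first_choice = [3, 3, 3, 5, 5, 1, 1, 1, 1, 7]
rate = repeat_rate(first_choice)
print(f"{rate:.2%}")
```

With 8 experts chosen uniformly at random, the expected rate would be 1/8 = 12.5%, so a value well above that indicates positional locality in the routing.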

## Implementation

In [None]:
%pip install -qU transformers==4.57.1
%pip install -qU sentencepiece protobuf bitsandbytes accelerate

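try:
    # On Google Colab, read the token from the notebook secrets
    from google.colab import userdata
    HF_TOKEN = userdata.get("HF_TOKEN")
except ImportError:
    # Elsewhere, read the token from a .env file or the environment
    from dotenv import load_dotenv
    import os
    load_dotenv()
    HF_TOKEN = os.getenv("HF_TOKEN")

assert HF_TOKEN, "HF_TOKEN is not set"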

import os
import logging as logging_
import transformers
import bitsandbytes
from transformers import PretrainedConfig
from transformers.utils import logging

from typing import Callable, Optional, Union

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

# Tokenizer

import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Optional

import sentencepiece as spm

from transformers.convert_slow_tokenizer import import_protobuf
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
from transformers.utils.import_utils import requires

# Model

from typing import Callable, Optional, Union

import torch
from torch import nn

assert torch.cuda.is_available(), "CUDA is not available"

from transformers.utils.generic import check_model_inputs

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import OutputRecorder
from transformers.models.mixtral.configuration_mixtral import MixtralConfig

# Device settings

torch.set_default_device("cuda")

# Logging settings

if os.path.exists('debug.log'):
    os.remove('debug.log')

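def custom_format(record):
    # Prefix each message with an emoji that reflects its log level
    match record.levelno:
        case logging_.DEBUG:
            level = '🟦'
        case logging_.INFO:
            level = '🟩'
        case logging_.WARNING:
            level = '🟨'
        case logging_.ERROR:
            level = '🟥'
        case logging_.CRITICAL:
            level = '🛑'
        case _:
            # Fallback so that custom levels do not leave `level` unbound
            level = '⬜'
    return f"{level} {record.getMessage()}"

logging.set_verbosity_debug()
logger = logging.get_logger()

# Iterate over a copy because removeHandler mutates logger.handlers
for handler in list(logger.handlers):
    logger.removeHandler(handler)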

formatter = logging_.Formatter()
formatter.format = custom_format

file_handler = logging_.FileHandler('debug.log')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

stream_handler = logging_.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

logger.info(f"Transformers version: {transformers.__version__}")
logger.info(f"Numpy version: {np.__version__}")
logger.info(f"BitsAndBytes version: {bitsandbytes.__version__}")

### LlamaTokenizer

In [None]:
# File name of the pretrained vocabulary
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

# Special character that marks the start of a word
SPIECE_UNDERLINE = "▁"

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
 that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""  # fmt: skip

In [None]:
@requires(backends=("sentencepiece",))
class LlamaTokenizer(PreTrainedTokenizer):
    """
    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
    no padding token in the original model.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
            attention mechanisms or loss computation.
        sp_model_kwargs (`dict[str, Any]`, `Optional`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add an `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
            extra spaces.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Llama should be used.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
        legacy (`bool`, *optional*):
            Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
            and #25224 which includes fixes to properly handle tokens that appear after special tokens.
            Make sure to also set `from_slow` to `True`.
            A simple example:

            - `legacy=True`:
            ```python
            >>> from transformers import LlamaTokenizerFast

            >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True)
            >>> tokenizer.encode("Hello <s>.") # 869 is '▁.'
            [1, 15043, 29871, 1, 869]
            ```
            - `legacy=False`:
            ```python
            >>> from transformers import LlamaTokenizerFast

            >>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True)
            >>> tokenizer.encode("Hello <s>.")  # 29889 is '.'
            [1, 15043, 29871, 1, 29889]
            ```
            Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
        add_prefix_space (`bool`, *optional*, defaults to `True`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. Again, this should be set with `from_slow=True` to make sure it's taken into account.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token=None,
        sp_model_kwargs: Optional[dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        legacy=None,
        add_prefix_space=True,
        **kwargs,
    ):
        logger.info(f"LlamaTokenizer initialization started vocab_file={vocab_file} unk_token={unk_token} bos_token={bos_token} eos_token={eos_token} pad_token={pad_token} legacy={legacy} add_prefix_space={add_prefix_space}")

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        if legacy is None:
            logger.warning_once(
                f"You are using the default legacy behaviour of the {self.__class__}. This is"
                " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
                " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
                " means, and thoroughly read the reason why this was added as explained in"
                " https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file"
                " you can ignore this message"
            )
            legacy = True

        self.legacy = legacy
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt
        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
        self.add_prefix_space = add_prefix_space

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            legacy=legacy,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )
        logger.info("LlamaTokenizer initialization finished")

    @property
    def unk_token_length(self):
        return len(self.sp_model.encode(str(self.unk_token)))

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
    def get_spm_processor(self, from_slow=False):
        logger.info(f"Getting SentencePieceProcessor from_slow={from_slow} legacy={self.legacy}")

        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        if self.legacy or from_slow:  # no dependency on protobuf
            tokenizer.Load(self.vocab_file)
            return tokenizer

        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
            model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
            model = model_pb2.ModelProto.FromString(sp_model)
            normalizer_spec = model_pb2.NormalizerSpec()
            normalizer_spec.add_dummy_prefix = False
            model.normalizer_spec.MergeFrom(normalizer_spec)
            sp_model = model.SerializeToString()
            tokenizer.LoadFromSerializedProto(sp_model)

        logger.info("SentencePieceProcessor loaded")
        return tokenizer

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        self.__dict__.update(d)
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Returns vocab as a dict"""
        logger.info("Building vocabulary")

        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)

        logger.info(f"Vocabulary built vocab_size={len(vocab)}")
        return vocab

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
    def tokenize(self, text: "TextInput", **kwargs) -> list[str]:
        """
        Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
        first token is special.
        """
        logger.info(f"Tokenization started text={text} legacy={self.legacy} add_prefix_space={self.add_prefix_space}")

        # True
        if self.legacy or len(text) == 0:

            # 'Hello' -> '_Hello'
            res = super().tokenize(text, **kwargs)
            logger.info(f"Tokenization finished (legacy) tokens={res}")
            return res

        text = text.replace(SPIECE_UNDERLINE, " ")
        if self.add_prefix_space:
            text = SPIECE_UNDERLINE + text

        tokens = super().tokenize(text, **kwargs)

        if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
            tokens = tokens[1:]
        return tokens

    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
        """
        logger.info(f"_tokenize started text={text} legacy={self.legacy}")

        # True
        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
            # 'Hello' -> '_Hello'
            res = self.sp_model.encode(text, out_type=str)
            logger.info(f"_tokenize finished (legacy) tokens={res}")
            return res

        # 1. Encode string + prefix ex: "<unk> Hey"
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
        res = tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
        return res

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        logger.info(f"_convert_token_to_id started token={token}")
        # '_Hello' -> 22557
        res = self.sp_model.piece_to_id(token)
        logger.info(f"_convert_token_to_id finished token={token} id={res}")
        return res

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # logger.info(f"_convert_id_to_token started id={index}")
        token = self.sp_model.IdToPiece(index)
        res = token
        # logger.info(f"_convert_id_to_token finished id={index} token={res}")
        return res

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        logger.info(f"Converting tokens to string tokens={tokens} legacy={self.legacy} add_prefix_space={self.add_prefix_space}")

        # since we manually add the prefix space, we have to remove it when decoding
        # ['_Hello', ','] -> ['Hello', ',']
        if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
            tokens[0] = tokens[0][1:]
            logger.debug(f"Removed the prefix space from the first token {tokens=}")

        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for i, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special and i != 0 and self.legacy:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                logger.debug(f"Decoded {current_sub_tokens=} and appended it {out_string=}")
                prev_is_special = True
                current_sub_tokens = []
            else:
                if prev_is_special and i == 1 and self.add_prefix_space and not token.startswith(SPIECE_UNDERLINE):
                    out_string += " "
                current_sub_tokens.append(token)
                logger.debug(f"Not a special token, added to the buffer {current_sub_tokens=}")
                prev_is_special = False

        out_string += self.sp_model.decode(current_sub_tokens)
        logger.debug(f"Decoded the final buffer {current_sub_tokens=} and appended it {out_string=}")

        logger.info(f"Finished converting tokens to string {out_string=}")
        return out_string

    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        logger.info(f"Saving vocabulary save_directory={save_directory} filename_prefix={filename_prefix}")
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        logger.info(f"Vocabulary saved out_vocab_file={out_vocab_file}")
        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        logger.info(f"Building inputs with special tokens token_ids_0={token_ids_0} token_ids_1={token_ids_1}")

        # [1]
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []

        # []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        # [1, 22557]
        output = bos_token_id + token_ids_0 + eos_token_id

        # False
        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        logger.info(f"Built inputs with special tokens output={output}")
        return output

    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`list[int]`):
                List of IDs.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        logger.info(f"Getting special tokens mask token_ids_0={token_ids_0} token_ids_1={token_ids_1} already_has_special_tokens={already_has_special_tokens}")
        if already_has_special_tokens:
            res = super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
            logger.info(f"Got special tokens mask (special tokens already present) mask={res}")
            return res

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            res = bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
            logger.info(f"Got special tokens mask (token_ids_1 is None) mask={res}")
            return res
        res = (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )
        logger.info(f"Got special tokens mask mask={res}")
        return res

    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`list[int]`):
                List of ids.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        logger.info(f"Creating token type IDs from sequences token_ids_0={token_ids_0} token_ids_1={token_ids_1}")

        # [1]
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []

        # []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        # [1, 22557] -> [0, 0]
        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        # False
        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        logger.info(f"Created token type IDs from sequences output={output}")
        return output


### MixtralBlockSparseTop2MLP

MixtralBlockSparseTop2MLP is the expert class of the SMoE.

It is a feed-forward network implemented with SwiGLU.
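Written as a formula, the computation in the class below is the following, where $W_1$, $W_2$, $W_3$ denote the weights of `w1`, `w2`, `w3`, $\odot$ is element-wise multiplication, and SiLU is the activation selected by `config.hidden_act`:

$$
\text{SwiGLU}(x) = W_2 \left( \text{SiLU}(W_1 x) \odot W_3 x \right)
$$

$W_1$ and $W_3$ project from the hidden size (4096) to the FFN size (14336), and $W_2$ projects back to the hidden size.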

In [None]:
class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        logger.info(f"MixtralBlockSparseTop2MLP initialization started {config.intermediate_size=}, {config.hidden_size=}, {config.hidden_act=}")

        super().__init__()

        # 14336
        self.ffn_dim = config.intermediate_size

        # 4096
        self.hidden_dim = config.hidden_size

        # 4096 -> 14336
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        # 14336 -> 4096
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)

        # 4096 -> 14336
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        # SiLU
        self.act_fn = ACT2FN[config.hidden_act]

        logger.info("MixtralBlockSparseTop2MLP initialization finished")

    def forward(self, hidden_states):
        logger.info(f"MixtralBlockSparseTop2MLP forward pass started {hidden_states.shape=}")

        # (1, 4096) -> (1, 14336)
        gate = self.act_fn(self.w1(hidden_states))
        logger.debug(f"Computed the gate {gate.shape=}")

        # (1, 4096) -> (1, 14336)
        up = self.w3(hidden_states)
        logger.debug(f"Computed the up projection {up.shape=}")

        # (1, 14336) -> (1, 4096)
        current_hidden_states = self.w2(gate * up)
        logger.debug(f"Computed the down projection {current_hidden_states.shape=}")

        logger.info(f"MixtralBlockSparseTop2MLP forward pass finished {current_hidden_states.shape=}")
        return current_hidden_states

### MixtralSparseMoeBlock

MixtralSparseMoeBlock is the class that combines the routing network with the expert networks.
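Before reading the class, it may help to see the dispatch step in isolation. The forward pass groups tokens by the expert they selected, so that each expert runs once on all of its tokens. The sketch below mimics that grouping in plain Python; the function name `group_tokens_by_expert` and the sample assignments are hypothetical, and the real code does this with a one-hot mask and `torch.where` rather than a dictionary.

```python
from collections import defaultdict


def group_tokens_by_expert(selected_experts):
    """Map each selected expert id to its sorted (top-k slot, token index) pairs."""
    groups = defaultdict(list)
    for token, experts in enumerate(selected_experts):
        for slot, expert in enumerate(experts):
            groups[expert].append((slot, token))
    return {expert: sorted(pairs) for expert, pairs in groups.items()}


# Made-up top-2 selections for three tokens (expert ids in 0..7).
selected_experts = [[1, 4], [4, 6], [1, 6]]
groups = group_tokens_by_expert(selected_experts)

for expert, pairs in sorted(groups.items()):
    print(expert, pairs)
```

Experts that no token selected do not appear in the result, which corresponds to the `expert_hit` filter in the forward pass.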

In [None]:
class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """

    def __init__(self, config):
        logger.info(f"MixtralSparseMoeBlock initialization started {config.num_local_experts=}, {config.num_experts_per_tok=}, {config.hidden_size=}, {config.intermediate_size=}, {config.router_jitter_noise=}")

        super().__init__()

        # 4096
        self.hidden_dim = config.hidden_size

        # 14336
        self.ffn_dim = config.intermediate_size

        # 8
        self.num_experts = config.num_local_experts

        # 2
        self.top_k = config.num_experts_per_tok

        # Initialize the routing network
        # 4096 -> 8
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        # Initialize the 8 expert networks
        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        # Jitter parameters
        # 0.0
        self.jitter_noise = config.router_jitter_noise

        logger.info("MixtralSparseMoeBlock initialization finished")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logger.info(f"MixtralSparseMoeBlock forward pass started {hidden_states.shape=}")

        ##################
        # Initialization #
        ##################

        # (1, 2, 4096)
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        # False
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)

        # (1 * 2, 4096)
        hidden_states = hidden_states.view(-1, hidden_dim)
        logger.debug(f"Reshaped hidden_states {hidden_states.shape=}")

        ###########
        # Routing #
        ###########

        # Compute the scores (logits)
        # (2, 4096) -> (2, 8)
        router_logits = self.gate(hidden_states)
        logger.debug(f"Computed the router logits {router_logits.shape=}")

        # Compute the expert weights (upcast, then apply softmax)
        # (2, 8)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        logger.debug(f"Computed the routing weights {routing_weights.shape=}")

        # Select the top 2 experts
        # (2, 2), (2, 2)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        logger.debug(f"Selected the top 2 experts {routing_weights.shape=} {selected_experts.shape=}")

        # Normalize the weights of the 2 experts, then downcast
        # (2, 2)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        logger.debug(f"Normalized the routing weights {routing_weights.shape=}")

        ####################
        # Expert execution #
        ####################

        # Initialize the output buffer
        # (2, 4096)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device
        )
        logger.debug(f"Initialized final hidden_states {final_hidden_states.shape=}")

        # One-hot encode the selected experts to create an expert mask;
        # it is used to look up which tokens each expert is responsible for
        # (num_experts, top_k, num_tokens) = (8, 2, 2)
        expert_mask = torch.nn.functional.one_hot(
            selected_experts,
            num_classes=self.num_experts
        ).permute(2, 1, 0)
        logger.debug(f"Created expert mask {expert_mask.shape=}")

        # Indices of the experts that received at least one token
        # (num_hit_experts, 1), e.g. (4, 1)
        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        logger.debug(f"Computed expert hits {expert_hit.shape=}")

        for expert_idx in expert_hit:
            # Fetch the expert
            expert_layer = self.experts[expert_idx]

            # idx: top-k slot (0 or 1), top_x: index of the token routed to this expert
            # (1,), (1,)
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            logger.debug(f"Read mask entries {expert_idx[0]=} {idx=} {top_x=}")

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            # (1, 4096)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            logger.debug(f"Gathered hidden_states for the current expert {current_state.shape=}")

            # Run the expert forward pass and scale its output by the routing weight
            # (1, 4096) -> (1, 4096)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
            logger.debug(f"Ran expert forward pass {current_hidden_states.shape=}")

            # Accumulate the weighted expert output into the rows of the routed tokens
            # (2, 4096)
            # `index_add_` only supports torch tensors for indexing, so the
            # `top_x` tensor is used here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
            logger.debug(f"Accumulated into final hidden_states {final_hidden_states.shape=}")

        # Restore the original shape
        # (2, 4096) -> (1, 2, 4096)
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        logger.debug(f"Restored hidden_states shape {final_hidden_states.shape=}")

        logger.info(f"MixtralSparseMoeBlock forward done {final_hidden_states.shape=} {router_logits.shape=}")
        return final_hidden_states, router_logits
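
The forward pass above applies softmax over all 8 logits, keeps the top 2 weights, and renormalizes them, whereas the gating formula in the architecture section applies softmax only to the top-$K$ logits. The two orders give the same weights. The following is a minimal plain-Python sketch of that equivalence (not the PyTorch implementation); the helper names and the logit values are hypothetical and chosen only for illustration.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def route_softmax_then_topk(logits, k):
    """Softmax over all experts, keep the top-k weights, renormalize (order used in the code above)."""
    probs = softmax(logits)
    top = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)
    return {i: probs[i] / norm for i in top}


def route_topk_then_softmax(logits, k):
    """Softmax restricted to the top-k logits (order used in the gating formula)."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    weights = softmax([logits[i] for i in top])
    return dict(zip(top, weights))


# Hypothetical router logits for one token and 8 experts
logits = [0.3, -1.2, 2.0, 0.1, 1.5, -0.7, 0.0, 0.9]

a = route_softmax_then_topk(logits, k=2)
b = route_topk_then_softmax(logits, k=2)
print(a)
print(b)
```

Both calls select experts 2 and 4 here, and the weights agree because the normalizer of the full softmax cancels when the selected weights are renormalized.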

### MixtralRMSNorm

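MixtralRMSNorm implements Root Mean Square Layer Normalization (RMSNorm)

It stabilizes training by keeping the scale of the activations from growing or shrinking as they pass through the layers

It is a simplified, cheaper variant of layer normalization: the mean-centering step is skipped and the input is only rescaled by its root mean square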

$$
y_i = \frac{x_i}{\sqrt{\frac{1}{n} \sum_{j=1}^{n} x_j^2 + \epsilon}} \cdot g_i
$$
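
In the formula, $n$ is the hidden size, $g_i$ is the learnable gain (`weight` in the class below), and $\epsilon$ is a small constant for numerical stability. As a worked example, here is a minimal plain-Python sketch of the same computation; the function name `rms_norm` and the input values are assumptions made for illustration, not part of the Transformers code.

```python
import math


def rms_norm(x, gain, eps=1e-5):
    """Scale x by the reciprocal of its root mean square, then apply the per-element gain."""
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [v * scale * g for v, g in zip(x, gain)]


x = [1.0, -2.0, 3.0, -4.0]
y = rms_norm(x, gain=[1.0] * len(x))

# After normalization the root mean square is close to 1,
# but the mean is not forced to 0 because no centering is applied.
rms_after = math.sqrt(sum(v * v for v in y) / len(y))
mean_after = sum(y) / len(y)
print(y, rms_after, mean_after)
```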

In [None]:
@use_kernel_forward_from_hub("RMSNorm")
class MixtralRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MixtralRMSNorm is equivalent to T5LayerNorm
        """
        logger.info(f"Initializing MixtralRMSNorm {hidden_size=}, {eps=}")

        super().__init__()

        # Learnable gain g, shape (4096,)
        self.weight = nn.Parameter(torch.ones(hidden_size))

        # 1e-05
        self.variance_epsilon = eps

        logger.info("MixtralRMSNorm initialized")

    def forward(self, hidden_states):
        logger.info(f"MixtralRMSNorm forward start {hidden_states.shape=}")

        # Upcast to float32 so the reduction is numerically stable
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        # Mean of squares over the hidden dimension (no mean is subtracted, so this is not a true variance)
        # (1, 2, 4096) -> (1, 2, 1)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        logger.debug(f"Computed mean square {variance.shape=}")

        # (1, 2, 4096)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        logger.debug(f"Applied normalization {hidden_states.shape=}")

        # (1, 2, 4096)
        res = self.weight * hidden_states.to(input_dtype)
        logger.debug(f"Applied gain {res.shape=}")

        logger.info(f"MixtralRMSNorm forward done {res.shape=}")
        return res

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

### MixtralRotaryEmbedding

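MixtralRotaryEmbedding computes the sine and cosine values needed to apply Rotary Position Embedding (RoPE)

RoPE adds position information by rotating the query and key vectors by an angle that depends on the token position

How a query is rotated:

1. Split the 128-element query $Q$ of one attention head into a first half $Q_1$ and a second half $Q_2$ of 64 elements each
2. Take the element at the same index from $Q_1$ and $Q_2$ to form a pair $(q_1, q_2)$, which spans a 2D plane
3. Rotate each pair by the angle $\theta$

2D rotation formula:

$$
(q_1 \cos{\theta} - q_2 \sin{\theta}, q_2 \cos{\theta} + q_1 \sin{\theta})
$$

Rotation angle:

$$
\theta_{m, i} = m\cdot b^{-\frac{2i}{d}}
$$

- $m$: position of the token in the sequence ($0, 1, 2,...$)
- $b^{-\frac{2i}{d}}$: inverse frequency (`inv_freq` in the code), i.e. the angle advanced per position for pair $i$
    - $b$: base (`rope_theta` in the config; $10000$ in the original RoPE formulation, $10^6$ in the released Mixtral-8x7B config)
    - $i$: pair index ($0, 1, 2, ..., d/2-1$)
    - $d$: dimension of the vector being rotated (the head dimension, $128$)

Because both vectors are rotated, the query-key dot product depends only on the relative position; the farther apart two tokens are (the larger the difference in rotation angle), the less their directions stay aligned, so the dot product tends to shrink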
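
The relative-position property can be checked numerically. The sketch below rotates a single $(q_1, q_2)$ pair and a single $(k_1, k_2)$ pair in plain Python; the helper names, the example vectors, and the use of base $10000$ are assumptions for illustration only.

```python
import math


def rotate_pair(v1, v2, theta):
    """Apply the 2D rotation formula to one pair."""
    return (v1 * math.cos(theta) - v2 * math.sin(theta),
            v2 * math.cos(theta) + v1 * math.sin(theta))


def angle(m, i, d=128, base=10000.0):
    """theta_{m,i} = m * base^(-2i/d)."""
    return m * base ** (-2 * i / d)


def score(q, k, m, n, i=0):
    """Dot product of a query pair at position m and a key pair at position n after rotation."""
    qr = rotate_pair(q[0], q[1], angle(m, i))
    kr = rotate_pair(k[0], k[1], angle(n, i))
    return qr[0] * kr[0] + qr[1] * kr[1]


q, k = (1.0, 0.5), (0.8, -0.3)

s_near_start = score(q, k, m=5, n=3)    # relative distance 2
s_shifted = score(q, k, m=12, n=10)     # same relative distance, shifted positions
s_same_pos = score(q, k, m=7, n=7)      # distance 0: equals the unrotated dot product
print(s_near_start, s_shifted, s_same_pos)
```

Shifting both positions by the same amount leaves the score unchanged, which is why the absolute position does not need to be stored anywhere else in the model.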

The implementation expands the 2D rotation formula so that it can be evaluated with element-wise operations (the `apply_rotary_pos_emb` function):

$$
(q_1 \cos{\theta} - q_2 \sin{\theta}, q_2 \cos{\theta} + q_1 \sin{\theta})
= [q_1, q_2] \cdot \cos{\theta} + [-q_2, q_1] \cdot \sin{\theta}
$$

$$
(k_1 \cos{\theta} - k_2 \sin{\theta}, k_2 \cos{\theta} + k_1 \sin{\theta})
= [k_1, k_2] \cdot \cos{\theta} + [-k_2, k_1] \cdot \sin{\theta}
$$

The `rotate_half` function, which maps $x = [x_1, x_2]$ to $[-x_2, x_1]$, simplifies the implementation further:


$$
[q_1, q_2] \cdot \cos{\theta} + [-q_2, q_1] \cdot \sin{\theta}
= Q \cdot \cos{\theta} + \text{rotate\_half}(Q) \cdot \sin{\theta}
$$

$$
[k_1, k_2] \cdot \cos{\theta} + [-k_2, k_1] \cdot \sin{\theta}
= K \cdot \cos{\theta} + \text{rotate\_half}(K) \cdot \sin{\theta}
$$
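
To make the last identity concrete, the following plain-Python sketch compares the vectorized form $Q \cdot \cos{\theta} + \text{rotate\_half}(Q) \cdot \sin{\theta}$ with an explicit pair-by-pair rotation. It operates on lists rather than tensors; `rope_vectorized`, `rope_pairwise`, and the sample values are hypothetical names and data used only for this check.

```python
import math


def rotate_half(x):
    """[x1, x2] -> [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rope_vectorized(x, thetas):
    """x * cos + rotate_half(x) * sin, with the per-pair angles duplicated for both halves."""
    full = thetas + thetas  # same role as concatenating the frequencies twice
    rotated = rotate_half(x)
    return [v * math.cos(t) + r * math.sin(t) for v, r, t in zip(x, rotated, full)]


def rope_pairwise(x, thetas):
    """Rotate each pair (x[i], x[i + half]) with the 2D rotation formula."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i, t in enumerate(thetas):
        q1, q2 = x[i], x[i + half]
        out[i] = q1 * math.cos(t) - q2 * math.sin(t)
        out[i + half] = q2 * math.cos(t) + q1 * math.sin(t)
    return out


x = [1.0, 2.0, 3.0, 4.0]   # two pairs: (1, 3) and (2, 4)
thetas = [0.3, 1.1]        # one angle per pair

vectorized = rope_vectorized(x, thetas)
pairwise = rope_pairwise(x, thetas)
print(vectorized)
print(pairwise)
```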

In [None]:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    logger.info(f"rotate_half start {x.shape=}")

    # First half of the last dimension
    x1 = x[..., : x.shape[-1] // 2]
    logger.debug(f"Took x1 {x1.shape=}")

    # Second half of the last dimension
    x2 = x[..., x.shape[-1] // 2 :]
    logger.debug(f"Took x2 {x2.shape=}")

    res = torch.cat((-x2, x1), dim=-1)
    logger.info(f"rotate_half done {res.shape=}")
    return res

In [None]:
class MixtralRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: MixtralConfig, device=None):
        logger.info(f"Initializing MixtralRotaryEmbedding {config.max_position_embeddings=}")
        super().__init__()

        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        # 32768
        self.max_seq_len_cached = config.max_position_embeddings

        # 32768
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # default
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        logger.debug(f"{self.rope_type=}")

        # (64,), 1.0
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        logger.debug(f"{inv_freq.shape=}, {self.attention_scaling=}")

        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.original_inv_freq = self.inv_freq

        logger.info("MixtralRotaryEmbedding initialized")

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        logger.info(f"MixtralRotaryEmbedding forward start {x.shape=} {position_ids.shape=}")

        # (1, 64, 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        logger.debug(f"{inv_freq_expanded.shape=}")

        # (1, 1, 2)
        position_ids_expanded = position_ids[:, None, :].float()
        logger.debug(f"{position_ids_expanded.shape=}")

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            # (1, 64, 1) @ (1, 1, 2) -> (1, 64, 2) -> (1, 2, 64)
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            logger.debug(f"{freqs.shape=}")

            # (1, 2, 128)
            emb = torch.cat((freqs, freqs), dim=-1)
            logger.debug(f"{emb.shape=}")

            # (1, 2, 128)
            cos = emb.cos() * self.attention_scaling

            # (1, 2, 128)
            sin = emb.sin() * self.attention_scaling

        logger.info(f"MixtralRotaryEmbedding forward done {cos.shape=} {sin.shape=}")
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

In [None]:
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    logger.info(f"apply_rotary_pos_emb start {q.shape=} {k.shape=} {cos.shape=} {sin.shape=} {unsqueeze_dim=}")

    # (1, 2, 128) -> (1, 1, 2, 128)
    cos = cos.unsqueeze(unsqueeze_dim)

    # (1, 2, 128) -> (1, 1, 2, 128)
    sin = sin.unsqueeze(unsqueeze_dim)

    # (1, 32, 2, 128) * (1, 1, 2, 128) + (1, 32, 2, 128) * (1, 1, 2, 128) -> (1, 32, 2, 128)
    q_embed = (q * cos) + (rotate_half(q) * sin)

    # (1, 8, 2, 128) * (1, 1, 2, 128) + (1, 8, 2, 128) * (1, 1, 2, 128) -> (1, 8, 2, 128)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    logger.info(f"apply_rotary_pos_emb done {q_embed.shape=} {k_embed.shape=}")
    return q_embed, k_embed

### MixtralAttention

In [None]:
class MixtralAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: MixtralConfig, layer_idx: int):
        logger.info(f"Initializing MixtralAttention {config.hidden_size=}, {config.num_attention_heads=}, {config.num_key_value_heads=}, {getattr(config, 'head_dim', None) or config.hidden_size // config.num_attention_heads=}, {config.attention_dropout=}, {layer_idx=}")

        super().__init__()

        self.config = config

        # 0, 1, 2, ..., 31
        self.layer_idx = layer_idx

        # 4096 // 32 = 128
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        logger.debug(f"{self.head_dim=}")

        # 32 // 8 = 4
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        logger.debug(f"{self.num_key_value_groups=}")

        # 0.088388347
        self.scaling = self.head_dim**-0.5
        logger.debug(f"{self.scaling=}")

        # 0.0
        self.attention_dropout = config.attention_dropout
        logger.debug(f"{self.attention_dropout=}")

        self.is_causal = True

        # 4096 -> 32 * 128 = 4096
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)

        # 4096 -> 8 * 128 = 1024
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)

        # 4096 -> 8 * 128 = 1024
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)

        # 32 * 128 = 4096 -> 4096
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        logger.info("MixtralAttention initialized")

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        logger.info(f"MixtralAttention forward start {hidden_states.shape=} {attention_mask.shape if attention_mask is not None else None} {past_key_values is not None=} {cache_position.shape if cache_position is not None else None}")

        # (1, 2, 4096) -> (1, 2)
        input_shape = hidden_states.shape[:-1]

        # (1, 2, 4096) -> (1, 2, -1, 128)
        hidden_shape = (*input_shape, -1, self.head_dim)
        logger.debug(f"{hidden_shape=}")

        # (1, 2, 4096) -> (1, 2, 4096) -> (1, 2, 32, 128) -> (1, 32, 2, 128)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        logger.debug(f"{query_states.shape=}")

        # (1, 2, 4096) -> (1, 2, 1024) -> (1, 2, 8, 128) -> (1, 8, 2, 128)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        logger.debug(f"{key_states.shape=}")

        # (1, 2, 4096) -> (1, 2, 1024) -> (1, 2, 8, 128) -> (1, 8, 2, 128)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        logger.debug(f"{value_states.shape=}")

        # (1, 2, 128), (1, 2, 128)
        cos, sin = position_embeddings

        # (1, 32, 2, 128), (1, 8, 2, 128)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # (1, 8, 2, 128), (1, 8, 2, 128)
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            logger.debug(f"Updated past key/values {key_states.shape=} {value_states.shape=}")

        # SDPA
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        logger.debug(f"Selected attention implementation {attention_interface=}")

        logger.debug(f"Attention computation start {query_states.shape=} {key_states.shape=} {value_states.shape=} {attention_mask.shape if attention_mask is not None else None} {self.scaling=} sliding_window={getattr(self.config, 'sliding_window', None)} dropout={0.0 if not self.training else self.attention_dropout}")

        # (1, 32, 2, 128) -> (1, 2, 32, 128): the attention function returns its output with the head and sequence axes swapped back
        attn_output, attn_weights = attention_interface(
            self,
            query_states, # (1, 32, 2, 128)
            key_states, # (1, 8, 2, 128)
            value_states, # (1, 8, 2, 128)
            attention_mask, # None
            dropout=0.0 if not self.training else self.attention_dropout, # 0.0
            scaling=self.scaling, # 0.088388347
            sliding_window=getattr(self.config, "sliding_window", None), # None
            **kwargs,
        )

        logger.debug(f"Attention computation done {attn_output.shape=} {attn_weights.shape if attn_weights is not None else None}")

        # Merge the heads back into the hidden dimension
        # (1, 2, 32, 128) -> (1, 2, 4096)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        logger.debug(f"Reshaped attention output {attn_output.shape=}")

        # (1, 2, 4096) -> (1, 2, 4096)
        attn_output = self.o_proj(attn_output)
        logger.debug(f"Applied output projection {attn_output.shape=}")

        logger.info(f"MixtralAttention forward done {attn_output.shape=}")
        return attn_output, attn_weights
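
The projections above produce 32 query heads but only 8 key/value heads, so each key/value head is shared by a group of 4 query heads (grouped-query attention). The sketch below reproduces the head arithmetic from the code comments in plain Python. The mapping `h // num_key_value_groups` is an assumption about how the attention backend repeats key/value heads; that step is not shown in this section.

```python
# Values taken from the shape comments in MixtralAttention
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

head_dim = hidden_size // num_attention_heads                         # 128
num_key_value_groups = num_attention_heads // num_key_value_heads     # 4
scaling = head_dim ** -0.5                                            # 1 / sqrt(128)

# Assumed mapping: consecutive query heads share one key/value head
kv_head_for_query = [h // num_key_value_groups for h in range(num_attention_heads)]

# Output widths of the projections
q_width = num_attention_heads * head_dim    # 4096
kv_width = num_key_value_heads * head_dim   # 1024

print(head_dim, num_key_value_groups, round(scaling, 9))
print(kv_head_for_query)
print(q_width, kv_width)
```

With this layout the key/value tensors, and therefore the KV cache, are 4 times smaller than they would be with 32 key/value heads.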

### MixtralDecoderLayer

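MixtralDecoderLayer is a Transformer decoder layer made of a self-attention block followed by an SMoE block, each preceded by RMSNorm and wrapped in a residual connection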
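
The data flow of one layer can be sketched in a few lines of plain Python. The sub-blocks are passed in as callables, and the toy stand-ins (`identity`, `zeros`, `double`) are hypothetical placeholders rather than the real attention or MoE modules; only the pre-normalization and residual pattern is being illustrated.

```python
def add(a, b):
    """Element-wise sum of two equally sized lists (the residual connection)."""
    return [x + y for x, y in zip(a, b)]


def decoder_layer(x, norm_1, attention, norm_2, moe):
    """Pre-norm residual layout: x + attention(norm(x)), then h + moe(norm(h))."""
    h = add(x, attention(norm_1(x)))
    return add(h, moe(norm_2(h)))


# Toy stand-ins for the real sub-blocks
identity = lambda v: v
zeros = lambda v: [0.0] * len(v)
double = lambda v: [2.0 * t for t in v]

x = [1.0, -2.0, 0.5]

# If both sub-blocks output zeros, the residual path passes x through unchanged
unchanged = decoder_layer(x, identity, zeros, identity, zeros)

# If attention doubles its input, the residual sum is x + 2x
out = decoder_layer(x, identity, double, identity, zeros)
print(unchanged, out)
```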

In [None]:
class MixtralDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        logger.info(f"Initializing MixtralDecoderLayer {config.hidden_size=}, {config.num_attention_heads=}, {config.intermediate_size=}, {layer_idx=}")
        super().__init__()

        # 4096
        self.hidden_size = config.hidden_size

        # Initialize self-attention
        self.self_attn = MixtralAttention(config, layer_idx)

        # Initialize the MoE block
        self.block_sparse_moe = MixtralSparseMoeBlock(config)

        # RMSNorm applied before self-attention
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # RMSNorm applied after self-attention, before the MoE block
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        logger.info("MixtralDecoderLayer initialized")

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        logger.info(f"MixtralDecoderLayer forward start {hidden_states.shape=} {attention_mask.shape if attention_mask is not None else None} {position_ids.shape if position_ids is not None else None} {past_key_values is not None=} {cache_position.shape if cache_position is not None else None}")

        ##################
        # Self-attention #
        ##################

        # Keep a reference to the input for the residual connection
        # (1, 2, 4096)
        residual = hidden_states

        # Apply RMSNorm
        # (1, 2, 4096) -> (1, 2, 4096)
        hidden_states = self.input_layernorm(hidden_states)

        # Apply self-attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )

        # Apply the residual connection
        # (1, 2, 4096) + (1, 2, 4096) -> (1, 2, 4096)
        hidden_states = residual + hidden_states
        logger.debug(f"Applied residual connection {hidden_states.shape=}")

        ##############
        # SMoE block #
        ##############

        # Keep a reference to the input for the residual connection
        # (1, 2, 4096)
        residual = hidden_states

        # Apply RMSNorm
        # (1, 2, 4096) -> (1, 2, 4096)
        hidden_states = self.post_attention_layernorm(hidden_states)

        # Apply the MoE block (the router logits are discarded here)
        # (1, 2, 4096) -> (1, 2, 4096)
        hidden_states, _ = self.block_sparse_moe(hidden_states)

        # Apply the residual connection
        # (1, 2, 4096) + (1, 2, 4096) -> (1, 2, 4096)
        hidden_states = residual + hidden_states
        logger.debug(f"Applied residual connection {hidden_states.shape=}")

        logger.info(f"MixtralDecoderLayer forward done {hidden_states.shape=}")
        return hidden_states


### MixtralPreTrainedModel

In [None]:
class MixtralPreTrainedModel(PreTrainedModel):
    config: MixtralConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MixtralDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
    _supports_attention_backend = True
    _can_record_outputs = {
        "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1),
        "hidden_states": MixtralDecoderLayer,
        "attentions": MixtralAttention,
    }

### MixtralModel

In [None]:
class MixtralModel(MixtralPreTrainedModel):
    def __init__(self, config: MixtralConfig):
        logger.info(f"Initializing MixtralModel {config.vocab_size=}, {config.hidden_size=}, {config.num_hidden_layers=}, {config.pad_token_id=}")

        super().__init__(config)

        self.padding_idx = config.pad_token_id

        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # Initialize the 32 decoder layers
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.rotary_emb = MixtralRotaryEmbedding(config=config)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

        logger.info("MixtralModel initialized")

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
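        # Guard every optional argument: input_ids, attention_mask and position_ids may be None
        logger.info(f"MixtralModel forward start {input_ids.shape if input_ids is not None else None} {attention_mask.shape if attention_mask is not None else None} {position_ids.shape if position_ids is not None else None} {past_key_values=} {inputs_embeds.shape if inputs_embeds is not None else None} {use_cache=} {cache_position=}")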

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
            logger.debug("Initialized DynamicCache")

        if inputs_embeds is None:
            # (1, 2) -> (1, 2, 4096)
            inputs_embeds = self.embed_tokens(input_ids)
            logger.debug(f"Looked up input embeddings {inputs_embeds.shape=}")

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # create_causal_mask (sliding_window is None for Mixtral-8x7B)
        mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
        logger.debug(f"Selected mask function {mask_function=}")

        # None in this run
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        logger.debug(f"Created causal mask {causal_mask=}")

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        # (1, 2, 128), (1, 2, 128)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        logger.debug(f"Created position embeddings {position_embeddings[0].shape=} {position_embeddings[1].shape=}")

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        # (1, 2, 4096) -> (1, 2, 4096)
        hidden_states = self.norm(hidden_states)

        res = MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
        logger.info(f"MixtralModel forward done {hidden_states.shape=}")
        return res
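
When `cache_position` and `position_ids` are not supplied, the forward pass above derives them from the number of tokens already held in the cache. The plain-Python sketch below mirrors that logic with lists in place of tensors; `make_cache_position` is a hypothetical helper, and the two-token prompt matches the `(1, 2, 4096)` shapes used in the comments.

```python
def make_cache_position(past_seen_tokens, num_new_tokens):
    """Positions of the new tokens, continuing after the tokens already in the cache."""
    return list(range(past_seen_tokens, past_seen_tokens + num_new_tokens))


# Prefill: nothing cached yet, two prompt tokens
prefill_positions = make_cache_position(past_seen_tokens=0, num_new_tokens=2)

# Next decoding step: two tokens cached, one new token
decode_positions = make_cache_position(past_seen_tokens=2, num_new_tokens=1)

# position_ids adds a leading batch axis, like cache_position.unsqueeze(0)
position_ids = [decode_positions]

print(prefill_positions, decode_positions, position_ids)
```

These position values are what `MixtralRotaryEmbedding` turns into the rotation angles, so a token generated later is rotated according to its absolute index in the sequence, not its index within the current call.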

### MixtralForCausalLM

In [None]:
class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        logger.info("Initializing MixtralForCausalLM")
        super().__init__(config)
        self.model = MixtralModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()
        logger.info("MixtralForCausalLM initialized")

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        logger.info(f"MixtralForCausalLM forward start {input_ids.shape if input_ids is not None else None} {attention_mask.shape if attention_mask is not None else None} {position_ids.shape if position_ids is not None else None} {past_key_values is not None=} {inputs_embeds.shape if inputs_embeds is not None else None} {labels.shape if labels is not None else None} {use_cache=} {output_router_logits=} {cache_position.shape if cache_position is not None else None} {logits_to_keep=}")

        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        logger.debug(f"{output_router_logits=}")

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

        # (1, 2, 4096) -> (1, 2, 32000)
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        res = MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
        logger.info(f"MixtralForCausalLM forward done {logits.shape=}")
        return res
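
The `logits_to_keep` slicing in the forward pass relies on a small Python detail: `-0` is `0`, so the default value `0` selects every position instead of none. A minimal sketch with a list standing in for the sequence axis (the token labels are made up for illustration):

```python
def keep_slice(logits_to_keep):
    """Same expression as in forward() for the integer case."""
    return slice(-logits_to_keep, None)


positions = ["t0", "t1", "t2", "t3"]  # stand-in for the sequence axis of hidden_states

all_positions = positions[keep_slice(0)]   # slice(0, None): keep everything
last_only = positions[keep_slice(1)]       # slice(-1, None): keep the final position
last_two = positions[keep_slice(2)]        # slice(-2, None): keep the final two positions

print(all_positions, last_only, last_two)
```

During generation only the final position is needed to pick the next token, so keeping one position avoids projecting the whole sequence onto the 32000-entry vocabulary.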

### Inference

In [None]:
tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
tokenizer

In [None]:
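# Pass the 4-bit setting through BitsAndBytesConfig; passing load_in_4bit directly to from_pretrained is deprecated
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto")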
model
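
As a rough motivation for 4-bit loading, the weight memory can be estimated from the parameter count quoted in the overview (about 47B). This back-of-the-envelope sketch counts weights only; it ignores quantization metadata, activations, and the KV cache, so real usage will be higher.

```python
total_params = 47e9  # approximate total parameter count from the overview

# Bytes needed to store one parameter in each format
bytes_per_param = {"bfloat16": 2.0, "4-bit": 0.5}

# Weight memory in gigabytes (1 GB = 1e9 bytes)
weight_gb = {name: total_params * size / 1e9 for name, size in bytes_per_param.items()}
print(weight_gb)
```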

In [None]:
text = "Hello"
logger.info(f"Input prompt {text=}")

In [None]:
# Move the inputs to the device that holds the model's first layers
model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
model_inputs

In [None]:
generated_ids = model.generate(**model_inputs, max_new_tokens=1)
generated_ids

In [None]:
decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
logger.info(f"Generated text {decoded}")

In [None]:
# Silence the per-layer logs before generating a longer continuation
logger.setLevel(logging.WARNING)
text = "Fukuoka is"
model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
generated_ids = model.generate(**model_inputs)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))