In [1]:
from typing import Any, Dict, OrderedDict, Optional, List, Tuple, Union

from dataclasses import dataclass
import json
import math

import regex as re
import torch
import torch.nn.functional as F
from transformers import MistralConfig, MistralForCausalLM, LlamaTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hub id of the small randomly-initialised Mistral checkpoint used as the reference model
# (its config, displayed further down, has 2 hidden layers and hidden_size 32).
pretrained_model_name_or_path = "echarlaix/tiny-random-mistral"

# Load the configuration and the reference Hugging Face model from that checkpoint.
config = MistralConfig.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
model = MistralForCausalLM.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    # NOTE(review): the checkpoint config records "torch_dtype": "float32"; bfloat16 is requested here.
    torch_dtype=torch.bfloat16,
    # NOTE(review): the recorded output of this cell warns that Flash Attention 2.0 is used with a
    # model that was not initialized on GPU and suggests calling `model.to('cuda')` afterwards.
    attn_implementation="flash_attention_2",
)

  return self.fget.__get__(instance, owner)()
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [3]:
# Display the reference model's module tree so the custom classes below can mirror its layout.
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 32, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x MistralDecoderLayer(
        (self_attn): MistralFlashAttention2(
          (q_proj): Linear(in_features=32, out_features=32, bias=False)
          (k_proj): Linear(in_features=32, out_features=16, bias=False)
          (v_proj): Linear(in_features=32, out_features=16, bias=False)
          (o_proj): Linear(in_features=32, out_features=32, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=32, out_features=37, bias=False)
          (up_proj): Linear(in_features=32, out_features=37, bias=False)
          (down_proj): Linear(in_features=37, out_features=32, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  

In [4]:
# Display the checkpoint configuration (head counts, hidden size, rope_theta, rms_norm_eps, ...).
config

MistralConfig {
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 32,
  "initializer_range": 0.02,
  "intermediate_size": 37,
  "is_decoder": true,
  "max_position_embeddings": 512,
  "model_type": "mistral",
  "num_attention_heads": 4,
  "num_hidden_layers": 2,
  "num_key_value_heads": 2,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.38.2",
  "type_vocab_size": 16,
  "use_cache": true,
  "vocab_size": 32000
}

In [9]:
# Maps an attention-implementation name to the class that implements it.
# NOTE(review): the recorded run of this cell failed with
# `NameError: name 'MistralAttention' is not defined`. `MistralAttention` and
# `MistralSdpaAttention` are only defined in later cells, and `MistralFlashAttention2`
# is not defined anywhere in the visible notebook, so this cell can only run after
# those names exist — TODO confirm where `MistralFlashAttention2` should come from.
MISTRAL_ATTENTION_CLASSES = {
    "eager": MistralAttention,
    "flash_attention_2": MistralFlashAttention2,
    "sdpa": MistralSdpaAttention
}

NameError: name 'MistralAttention' is not defined

In [16]:
def repeat_kv(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_kv_groups: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Repeat every key/value head `num_kv_groups` times along the head axis.

    Each head is repeated consecutively (head 0 `num_kv_groups` times, then head 1, ...),
    i.e. the result equals `torch.repeat_interleave(states, num_kv_groups, dim=1)`.

    Args:
        key_states: Tensor of shape (batch_size, num_kv_heads, seq_len, head_size).
        value_states: Tensor of shape (batch_size, num_kv_heads, seq_len, head_size).
        num_kv_groups: Number of copies made of each head.

    Returns:
        A `(key_states, value_states)` tuple, each of shape
        (batch_size, num_kv_heads * num_kv_groups, seq_len, head_size).
    """

    def _repeat_heads(states: torch.Tensor) -> torch.Tensor:
        # Each tensor is measured on its own, so keys and values may differ in shape.
        batch_size, num_heads, seq_len, head_size = states.size()
        # `expand` only creates a broadcast view; the copy happens in `reshape`.
        expanded = states[:, :, None, :, :].expand(
            batch_size,
            num_heads,
            num_kv_groups,
            seq_len,
            head_size,
        )
        return expanded.reshape(batch_size, num_heads * num_kv_groups, seq_len, head_size)

    return _repeat_heads(key_states), _repeat_heads(value_states)

## Cache

In [18]:
class Cache:
    """
    Base, abstract class for all caches.

    Subclasses implement `update`, `get_seq_length` and `get_max_length`;
    `get_usable_length` is derived from the latter two.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store the new key/value states for `layer_idx` and return the updated key and value states.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Make sure to implement `update` method in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the number of positions already cached for `layer_idx`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Make sure to implement `get_seq_length` in subclass.")

    def get_max_length(self) -> Optional[int]:
        """
        Return the maximum number of positions the cache can hold, or `None` for no limit.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Make sure to implement `get_max_length` in subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """
        Return how many cached positions remain usable once `new_seq_length` new ones arrive.

        Without a maximum length, or when old and new positions fit within it, this is the
        cached length. Otherwise it is `max_length - new_seq_length`.
        """
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx=layer_idx)

        would_overflow = max_length is not None and previous_seq_length + new_seq_length > max_length
        if would_overflow:
            return max_length - new_seq_length

        return previous_seq_length
    

class DynamicCache(Cache):
    """
    Cache that grows as new tokens are processed.

    Holds one key tensor and one value tensor per layer (`key_cache[i]` / `value_cache[i]`).
    New states are concatenated onto the stored ones along dim -2.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.seen_tokens = 0  # Used in `generate`: how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the `(key, value)` pair stored for `layer_idx`; raise `KeyError` if there is none."""
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])

        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """Yield the `(key, value)` pair of every layer, in layer order."""
        yield from zip(self.key_cache, self.value_cache)

    def __len__(self) -> int:
        """Return the number of layers that currently have cached states."""
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Add new key/value states for `layer_idx` and return that layer's full cached states.

        The first call for a layer stores the tensors as they are; later calls concatenate
        along dim -2. `cache_kwargs` is accepted but not used by this cache.
        """
        # Count tokens once per step: only when the first layer is updated.
        if layer_idx == 0:
            self.seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the cached length (size of dim -2) for `layer_idx`, or 0 if nothing is cached."""
        if len(self.key_cache) <= layer_idx:
            return 0

        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Return `None`: this cache has no maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Re-select dim 0 of every cached tensor according to `beam_idx`, in place."""
        for layer_idx, (keys, values) in enumerate(zip(self.key_cache, self.value_cache)):
            self.key_cache[layer_idx] = keys.index_select(0, beam_idx.to(keys.device))
            self.value_cache[layer_idx] = values.index_select(0, beam_idx.to(values.device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """Return the cache as a tuple with one `(key, value)` tuple per layer."""
        return tuple(zip(self.key_cache, self.value_cache))

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Build a cache from a per-layer sequence of `(key, value)` pairs; `None` gives an empty cache."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx, (key_states, value_states) in enumerate(past_key_values):
                cache.update(
                    key_states=key_states,
                    value_states=value_states,
                    layer_idx=layer_idx,
                )
        return cache

## MistralRMSNorm
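$\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^{2}}$

$\widehat{x}_i = g_i \frac{x_i}{\sigma} + b_i$

Mistral's RMSNorm does not use the bias term $b$, so only the learned gain $g$ (`weight`) is applied. The implementation below also adds a small `eps` under the square root for numerical stability.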

In [23]:
class MistralRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps=1e-6) -> None:
        """
        Root-mean-square normalisation with a learned per-feature gain and no bias.

        Implemented according to `modeling_mistral.py`; it is equivalent to T5LayerNorm.
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise over the last dimension in float32, cast back to the input dtype, then apply the gain."""
        input_dtype = hidden_states.dtype
        upcast = hidden_states.float()
        # Mean of squares over the feature axis (the squared RMS).
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.eps)
        return self.weight * normalized.to(input_dtype)

## MistralRotaryEmbedding

In [27]:
class MistralRotaryEmbedding(torch.nn.Module):
    """
    Rotary position embedding (RoPE) in the "rotate half" layout.

    Feature `i` of the first half of the last dimension is rotated together with
    feature `i` of the second half, by the angle `position * theta[i]`.
    """

    def __init__(
        self,
        head_size: int,
        max_position_embeddings: int = 2048,
        base: Union[int, float] = 10000,
        device: Optional[Union[torch.device, str]] = None,
    ) -> None:
        """
        Args:
            head_size: Size of the last dimension of the tensors to rotate.
            max_position_embeddings: Number of positions in the lookup table; position ids
                must be smaller than this.
            base: Base of the geometric progression of rotation frequencies.
            device: Device on which the frequency and position tables are created.
        """
        super().__init__()
        inverse_frequencies = 1 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
        # Duplicate the frequencies so both halves of a head share the same angles.
        self.theta = torch.cat([inverse_frequencies, inverse_frequencies], dim=-1).to(device)
        self.position_ids = torch.arange(0, max_position_embeddings).to(device)

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.LongTensor) -> torch.Tensor:
        """
        Rotate `hidden_states` according to `position_ids`.

        Args:
            hidden_states: Tensor whose last two dimensions are (seq_len, head_size),
                e.g. (batch_size, num_heads, seq_len, head_size).
            position_ids: Either (seq_len,) or (batch_size, seq_len) integer positions.

        Returns:
            The rotated tensor, with the shape and dtype of `hidden_states`.
        """
        device = hidden_states.device
        # The tables are plain attributes, so `module.to(...)` does not move them; follow the input.
        theta = self.theta.to(device)
        positions = self.position_ids.to(device=device, dtype=theta.dtype)
        position_ids = position_ids.to(device)

        position_matrix = torch.outer(positions, theta)
        cos = torch.cos(position_matrix)[position_ids]
        sin = torch.sin(position_matrix)[position_ids]

        # Batched ids give (batch, seq, head_size); add a head axis so they broadcast
        # against (batch, heads, seq, head_size) states.
        if position_ids.dim() > 1 and cos.dim() < hidden_states.dim():
            cos = cos.unsqueeze(-3)
            sin = sin.unsqueeze(-3)

        half = hidden_states.shape[-1] // 2
        x1 = hidden_states[..., :half]
        x2 = hidden_states[..., half:]
        rotated = torch.cat([-x2, x1], dim=-1)

        # Rotate in float32, then return the caller's dtype so bf16/fp16 inputs are not promoted.
        out = hidden_states.float() * cos + rotated.float() * sin
        return out.to(hidden_states.dtype)

In [2]:
# Seed the generator so the random Q/K/V tensors, and every output derived from them,
# are identical on each Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Toy attention problem size.
batch_size = 4
num_heads = 2
seq_len = 100
head_size = 32

# Random inputs of shape (batch_size, num_heads, seq_len, head_size), uniform in [0, 1).
query = torch.rand(batch_size, num_heads, seq_len, head_size)
key = torch.rand(batch_size, num_heads, seq_len, head_size)
value = torch.rand(batch_size, num_heads, seq_len, head_size)

In [3]:
# Reference result from PyTorch's built-in scaled dot-product attention (no mask, no dropout).
sdpa_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
sdpa_output

tensor([[[[0.4986, 0.4978, 0.5310,  ..., 0.4961, 0.5031, 0.5094],
          [0.4945, 0.5042, 0.5389,  ..., 0.4876, 0.5009, 0.5086],
          [0.4953, 0.5013, 0.5344,  ..., 0.4923, 0.4959, 0.5067],
          ...,
          [0.4956, 0.5026, 0.5377,  ..., 0.4932, 0.4978, 0.5067],
          [0.4969, 0.4958, 0.5298,  ..., 0.4919, 0.5036, 0.5153],
          [0.4948, 0.5036, 0.5320,  ..., 0.4947, 0.4992, 0.5061]],

         [[0.5094, 0.4495, 0.5039,  ..., 0.4583, 0.5013, 0.5393],
          [0.5086, 0.4448, 0.5077,  ..., 0.4569, 0.5004, 0.5357],
          [0.5084, 0.4494, 0.5007,  ..., 0.4601, 0.5049, 0.5335],
          ...,
          [0.5111, 0.4432, 0.5004,  ..., 0.4572, 0.5039, 0.5322],
          [0.5077, 0.4446, 0.4999,  ..., 0.4566, 0.5044, 0.5335],
          [0.5064, 0.4423, 0.5060,  ..., 0.4557, 0.5018, 0.5348]]],


        [[[0.5381, 0.5128, 0.5104,  ..., 0.5144, 0.5449, 0.5217],
          [0.5419, 0.5128, 0.5113,  ..., 0.5127, 0.5414, 0.5202],
          [0.5393, 0.5156, 0.5109,  ...,

In [4]:
def my_spda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """
    Hand-written scaled dot-product attention: `softmax(Q K^T / sqrt(d)) V`.

    Args:
        query: Tensor of shape (..., q_len, d).
        key: Tensor of shape (..., kv_len, d).
        value: Tensor of shape (..., kv_len, d_v).

    Returns:
        Attention output of shape (..., q_len, d_v). No mask and no dropout are applied.
    """
    # Take the scaling dimension from the input itself, not from a notebook-global `head_size`.
    scale = math.sqrt(query.size(-1))

    # Attention scores (Q * K^T), scaled
    attention_weights = torch.matmul(query, key.transpose(-2, -1)) / scale

    # Softmax in fp32 for numerical stability, then back to the input dtype
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # Attention output (A * V)
    return torch.matmul(attention_weights, value)


# Run the hand-written attention on the same Q/K/V as the PyTorch reference and display it.
my_spda_output = my_spda(query, key, value)
my_spda_output

tensor([[[[0.4986, 0.4978, 0.5310,  ..., 0.4961, 0.5031, 0.5094],
          [0.4945, 0.5042, 0.5389,  ..., 0.4876, 0.5009, 0.5086],
          [0.4953, 0.5013, 0.5344,  ..., 0.4923, 0.4959, 0.5067],
          ...,
          [0.4956, 0.5026, 0.5377,  ..., 0.4932, 0.4978, 0.5067],
          [0.4969, 0.4958, 0.5298,  ..., 0.4919, 0.5036, 0.5153],
          [0.4948, 0.5036, 0.5320,  ..., 0.4947, 0.4992, 0.5061]],

         [[0.5094, 0.4495, 0.5039,  ..., 0.4583, 0.5013, 0.5393],
          [0.5086, 0.4448, 0.5077,  ..., 0.4569, 0.5004, 0.5357],
          [0.5084, 0.4494, 0.5007,  ..., 0.4601, 0.5049, 0.5335],
          ...,
          [0.5111, 0.4432, 0.5004,  ..., 0.4572, 0.5039, 0.5322],
          [0.5077, 0.4446, 0.4999,  ..., 0.4566, 0.5044, 0.5335],
          [0.5064, 0.4423, 0.5060,  ..., 0.4557, 0.5018, 0.5348]]],


        [[[0.5381, 0.5128, 0.5104,  ..., 0.5144, 0.5449, 0.5217],
          [0.5419, 0.5128, 0.5113,  ..., 0.5127, 0.5414, 0.5202],
          [0.5393, 0.5156, 0.5109,  ...,

In [12]:
import torch
import torch.utils.benchmark as benchmark
import math


def my_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, head_size: float) -> torch.Tensor:
    """
    Scaled dot-product attention written with plain tensor operations.

    Computes `softmax(Q K^T / sqrt(head_size)) V` without a mask. The dropout call runs
    with `training=False`, so it leaves the weights unchanged.
    """
    # Scaled similarity between every query and every key.
    scores = (query @ key.transpose(-1, -2)) / math.sqrt(head_size)

    # Softmax is evaluated in fp32 and cast back to the query dtype.
    probabilities = torch.softmax(scores, dim=-1, dtype=torch.float32)
    probabilities = probabilities.to(query.dtype)
    probabilities = torch.nn.functional.dropout(probabilities, p=0.1, training=False)

    # Weighted sum of the values.
    return probabilities @ value


def benchmark_my_sdpa():
    """Zero-argument wrapper for the timer: run `my_sdpa` on the notebook-level `query`, `key`, `value`, `head_size`."""
    return my_sdpa(query, key, value, head_size)

# Testing
# The timer executes the `stmt` string; `globals=globals()` exposes this notebook's
# namespace to it, and the thread count is taken from PyTorch's current setting.
t = benchmark.Timer(
    stmt="benchmark_my_sdpa()",
    setup="from __main__ import benchmark_my_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

# Result
# Runs the statement 100000 times and prints the resulting measurement.
print("My Implementation:", t.timeit(100000))

My Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7fd5b23f88b0>
benchmark_my_sdpa()
setup: from __main__ import benchmark_my_sdpa
  131.17 us
  1 measurement, 100000 runs , 8 threads


In [13]:
def torch_sdpa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Thin wrapper around PyTorch's built-in scaled dot-product attention, using its default options."""
    builtin_attention = torch.nn.functional.scaled_dot_product_attention
    return builtin_attention(query, key, value)


def benchmark_torch_sdpa():
    """Zero-argument wrapper for the timer: run `torch_sdpa` on the notebook-level `query`, `key`, `value`."""
    return torch_sdpa(query, key, value)


# Testing
# Same timer setup as for the hand-written version, so the two measurements are comparable.
t = benchmark.Timer(
    stmt="benchmark_torch_sdpa()",
    setup="from __main__ import benchmark_torch_sdpa",
    globals=globals(),
    num_threads=torch.get_num_threads(),
)

# Result
# Runs the statement 100000 times and prints the resulting measurement.
print("Torch Implementation:", t.timeit(100000))

Torch Implementation: <torch.utils.benchmark.utils.common.Measurement object at 0x7fd5b17d5a20>
benchmark_torch_sdpa()
setup: from __main__ import benchmark_torch_sdpa
  90.94 us
  1 measurement, 100000 runs , 8 threads


In [68]:
# Check that the hand-written attention reproduces PyTorch's result (absolute tolerance 1e-8).
torch.allclose(sdpa_output, my_spda_output, atol=1e-8)

True

## MistralAttention

In [40]:
class MistralAttention(torch.nn.Module):
    """
    Eager (matmul + softmax) self-attention with grouped key/value heads and rotary embeddings.

    Queries use `config.num_attention_heads` heads. Keys and values use
    `config.num_key_value_heads` heads, each repeated `num_kv_groups` times by `repeat_kv`
    so that they line up with the query heads.
    """

    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None) -> None:
        """
        Args:
            config: Model configuration; reads `num_attention_heads`, `num_key_value_heads`,
                `hidden_size`, `attention_dropout`, `max_position_embeddings` and `rope_theta`.
            layer_idx: Index of this layer, used to address its slot in the key/value cache.
                Required whenever a cache is passed to `forward`.
        """
        super().__init__()

        # Init
        self.num_q_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_q_heads // self.num_kv_heads
        self.head_size = config.hidden_size // self.num_q_heads
        self.hidden_size = config.hidden_size
        self.attention_dropout = config.attention_dropout

        self.layer_idx = layer_idx

        # QKVO Layer
        self.q_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.hidden_size,
            bias=False,
        )
        self.k_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.num_kv_heads * self.head_size,
            bias=False,
        )
        self.v_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.num_kv_heads * self.head_size,
            bias=False,
        )
        self.o_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.hidden_size,
            bias=False,
        )

        # RoPE
        self.rotary_emb = MistralRotaryEmbedding(
            head_size=self.head_size,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """
        Apply self-attention to `hidden_states`.

        Args:
            hidden_states: Tensor of shape (batch_size, seq_len, hidden_size).
            attention_mask: Optional additive mask that is added to the attention scores
                before the softmax; it must broadcast to
                (batch_size, num_q_heads, seq_len, kv_seq_len).
            position_ids: Positions of the tokens in `hidden_states`. When omitted, consecutive
                positions starting after the cached length (0 without a cache) are used.
            past_key_value: Optional cache; the new keys/values are appended to it and the
                full cached keys/values are attended over.
            output_attentions: Whether to return the attention probabilities.
            use_cache: Accepted for interface compatibility; not used here.

        Returns:
            `(attention_output, attention_weights, past_key_value)`, where `attention_output`
            has shape (batch_size, seq_len, hidden_size) and `attention_weights` is `None`
            unless `output_attentions` is set.

        Raises:
            ValueError: If a cache is passed but the layer was built without `layer_idx`.
        """
        batch_size, seq_len, _ = hidden_states.size()

        if past_key_value is not None and self.layer_idx is None:
            raise ValueError(
                "A key/value cache was passed to `MistralAttention.forward`, but the layer was created "
                "without `layer_idx`; pass `layer_idx` when constructing the layer."
            )

        if position_ids is None:
            # Continue counting after whatever is already cached for this layer.
            past_length = 0 if past_key_value is None else past_key_value.get_seq_length(self.layer_idx)
            position_ids = torch.arange(past_length, past_length + seq_len, device=hidden_states.device)

        # QKV
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (batch_size, num_heads, seq_len, head_size)
        query_states = query_states.view(batch_size, seq_len, self.num_q_heads, self.head_size).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)

        # RoPE is applied to the new queries/keys only; cached keys were rotated when first seen.
        query_states = self.rotary_emb(
            hidden_states=query_states,
            position_ids=position_ids,
        )
        key_states = self.rotary_emb(
            hidden_states=key_states,
            position_ids=position_ids,
        )

        # KV Cache
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                layer_idx=self.layer_idx,
            )

        # Repeat kv heads
        key_states, value_states = repeat_kv(
            key_states=key_states,
            value_states=value_states,
            num_kv_groups=self.num_kv_groups,
        )

        # Attention weights (Q * K^T)
        attention_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_size)

        # Attention mask
        if attention_mask is not None:
            attention_weights = attention_weights + attention_mask

        # Upcast attention to fp32
        attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attention_weights = torch.nn.functional.dropout(attention_weights, p=self.attention_dropout, training=self.training)

        # Attention output (A * V): (batch_size, num_q_heads, seq_len, head_size)
        attention_output = torch.matmul(attention_weights, value_states)
        # Bring the head axis next to head_size before merging; reshaping the
        # (batch, heads, seq, head_size) layout directly would mix heads and positions.
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.reshape(batch_size, seq_len, self.hidden_size)
        attention_output = self.o_proj(attention_output)

        if not output_attentions:
            attention_weights = None

        return attention_output, attention_weights, past_key_value

## MistralSdpaAttention

In [None]:
class MistralSdpaAttention(torch.nn.Module):
    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None) -> None:
        """
        Projection layers and rotary embedding for the SDPA attention variant.

        Multiple query heads share the same key/value head, so the key/value projections
        are sized by `num_key_value_heads` while the query projection is sized by
        `num_attention_heads`.

        NOTE(review): no `forward` is defined for this class yet.

        Args:
            config: Model configuration; reads `hidden_size`, `num_attention_heads`,
                `num_key_value_heads`, `attention_dropout`, `max_position_embeddings`
                and `rope_theta`.
            layer_idx: Optional index of this layer, stored for addressing a key/value cache
                (same meaning as in `MistralAttention`).
        """
        super().__init__()

        # Init (same bookkeeping attributes as `MistralAttention`)
        self.num_q_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_q_heads // self.num_kv_heads
        self.head_size = config.hidden_size // config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_dropout = config.attention_dropout

        self.layer_idx = layer_idx

        # QKVO Layer
        self.q_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.k_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.num_key_value_heads * self.head_size,
            bias=False,
        )
        self.v_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.num_key_value_heads * self.head_size,
            bias=False,
        )
        self.o_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        # RoPE
        self.rotary_emb = MistralRotaryEmbedding(
            head_size=self.head_size,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

## MistralDecoderLayer

In [None]:
class MistralDecoderLayer(torch.nn.Module):
    """
    One decoder block: self-attention and an MLP, each paired with an RMSNorm.

    NOTE(review): no `forward` is defined yet, and `MistralMLP` is not defined in the
    visible part of the notebook.
    """

    def __init__(self, config: MistralConfig) -> None:
        super().__init__()
        self.self_attn = MistralSdpaAttention(config=config)
        self.mlp = MistralMLP(config=config)
        # Take epsilon from the checkpoint configuration, not the norm's default.
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

## MistralModel

In [None]:
class MistralModel(torch.nn.Module):
    """
    Token embedding, a stack of `config.num_hidden_layers` decoder layers, and a final RMSNorm.

    NOTE(review): no `forward` is defined yet.
    """

    def __init__(self, config: MistralConfig) -> None:
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=config.pad_token_id,
        )
        # `ModuleList` needs an iterable of modules: build one fresh layer per index
        # (the original passed a generator of one-element lists).
        self.layers = torch.nn.ModuleList(
            [MistralDecoderLayer(config=config) for _ in range(config.num_hidden_layers)]
        )
        # Take epsilon from the checkpoint configuration, not the norm's default.
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

## MistralModelForCausalLM

In [None]:
class MyMistralModelForCausalLM(torch.nn.Module):
    def __init__(self, config: MistralConfig) -> None:
        """Build the backbone (`MistralModel`) and the linear head mapping hidden states to vocabulary logits."""
        super().__init__()
        self.model = MistralModel(config=config)
        # `torch.nn.Linear` takes `in_features`; the `input_features` keyword raises a TypeError.
        # The head is bias-free like every other projection in this notebook
        # (TODO confirm against the reference `model.lm_head`).
        self.lm_head = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor: