In [19]:
from dataclasses import dataclass
import json
import math
from typing import Any, Dict, OrderedDict, Optional, List, Tuple, Union

import regex as re
import torch
from transformers import MistralConfig, MistralForCausalLM, LlamaTokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file

In [None]:
from transformers import LlavaForConditionalGeneration

In [20]:
# Model identifier handed to both `from_pretrained` calls below.
pretrained_model_name_or_path = "echarlaix/tiny-random-mistral"

# Load the configuration and the reference `MistralForCausalLM` for that identifier.
config = MistralConfig.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
# `torch_dtype=torch.bfloat16` is requested for the model only; `config` is loaded separately above.
model = MistralForCausalLM.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    torch_dtype=torch.bfloat16,
)

  return self.fget.__get__(instance, owner)()


In [22]:
# Display the loaded configuration (the last expression of a cell is rendered as its output).
config

MistralConfig {
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 32,
  "initializer_range": 0.02,
  "intermediate_size": 37,
  "is_decoder": true,
  "max_position_embeddings": 512,
  "model_type": "mistral",
  "num_attention_heads": 4,
  "num_hidden_layers": 2,
  "num_key_value_heads": 2,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.38.2",
  "type_vocab_size": 16,
  "use_cache": true,
  "vocab_size": 32000
}

In [9]:
# Maps an attention-implementation name to the class implementing it.
# NOTE(review): this cell raises `NameError` (see the recorded output below) because the
# attention classes are defined further down in the notebook, so it cannot run first on a
# fresh kernel. `MistralFlashAttention2` is not defined anywhere in the visible part of
# this notebook — TODO: define it (or drop the entry) and run this after the class cells.
MISTRAL_ATTENTION_CLASSES = {
    "eager": MistralAttention,
    "flash_attention_2": MistralFlashAttention2,
    "sdpa": MistralSdpaAttention
}

NameError: name 'MistralAttention' is not defined

In [23]:
def _repeat_heads(states: torch.Tensor, num_kv_groups: int) -> torch.Tensor:
    """Repeat each head of a (batch, heads, seq, head_size) tensor `num_kv_groups` times."""
    batch_size, num_heads, seq_len, head_size = states.size()
    # `expand` adds the group axis without copying; `reshape` then folds it into the head axis,
    # so every original head appears `num_kv_groups` times in a row.
    expanded = states[:, :, None, :, :].expand(
        batch_size,
        num_heads,
        num_kv_groups,
        seq_len,
        head_size,
    )
    return expanded.reshape(batch_size, num_heads * num_kv_groups, seq_len, head_size)


def repeat_kv(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_kv_groups: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Repeat key/value heads so they line up with the query heads (grouped-query attention).

    Args:
        key_states: tensor of shape (batch, num_kv_heads, seq_len, head_size).
        value_states: tensor of shape (batch, num_kv_heads, seq_len, head_size).
            Its shape is read independently of `key_states`.
        num_kv_groups: how many times every key/value head is repeated.

    Returns:
        A `(key_states, value_states)` tuple, each of shape
        (batch, num_kv_heads * num_kv_groups, seq_len, head_size). Head `h` of the
        input occupies output heads `h * num_kv_groups ... (h + 1) * num_kv_groups - 1`.
    """
    if num_kv_groups == 1:
        # Nothing to repeat: the tensors already have one KV head per query head.
        return key_states, value_states

    return (
        _repeat_heads(key_states, num_kv_groups),
        _repeat_heads(value_states, num_kv_groups),
    )

## Cache

In [27]:
class Cache:
    """Abstract interface for key/value caches.

    Subclasses implement `update`, `get_seq_length` and `get_max_length`;
    `get_usable_length` is derived from the latter two.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Store the new states for `layer_idx` and return the updated key and value states."""
        raise NotImplementedError("Make sure to implement `update` method in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the sequence length currently cached for `layer_idx`."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in subclass.")

    def get_max_length(self) -> Optional[int]:
        """Return the maximum cached sequence length, or `None` when there is no limit."""
        raise NotImplementedError("Make sure to implement `get_max_length` in subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Return how many cached positions can be kept when `new_seq_length` tokens are added.

        Without a maximum length (`get_max_length()` is `None`), or while the
        cached plus new tokens still fit, this is the cached length. Otherwise it
        is `max_length - new_seq_length`.
        """
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx=layer_idx)

        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length

        return previous_seq_length
    

class DynamicCache(Cache):
    """Cache that grows along the sequence axis as new key/value states arrive.

    One key tensor and one value tensor are kept per layer; `update` concatenates
    new states to them on dim -2.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Number of tokens seen so far; only incremented by `update` calls for layer 0.
        self.seen_tokens = 0

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the `(key, value)` pair cached for `layer_idx`.

        Raises:
            KeyError: if `layer_idx` is not smaller than the number of cached layers.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])

        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """Yield the `(key, value)` pair of every cached layer, in layer order."""
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self) -> int:
        """Return the number of layers that have cached states."""
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append the new states to layer `layer_idx` and return the full cached states.

        `cache_kwargs` is accepted for interface compatibility and is not used.
        """
        if layer_idx == 0:
            # Count tokens once per step, not once per layer.
            self.seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            # First states for this layer.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the cached sequence length of `layer_idx` (0 if that layer is not cached yet)."""
        if len(self.key_cache) <= layer_idx:
            return 0

        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Return `None`: this cache has no maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Re-select dim 0 of every cached tensor with `beam_idx`, in place."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))

            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """Return the cache as a tuple holding one `(key, value)` tuple per layer."""
        return tuple(zip(self.key_cache, self.value_cache))

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Build a cache from a tuple of per-layer `(key, value)` pairs (empty cache for `None`)."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx, (key_states, value_states) in enumerate(past_key_values):
                cache.update(
                    key_states=key_states,
                    value_states=value_states,
                    layer_idx=layer_idx,
                )
        return cache

## MistralRMSNorm
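$\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n}x_{i}^{2}}$

$\widehat{x}_i = g_i\frac{x_i}{\sigma}+b_i$

Mistral's RMSNorm does not use a `bias` term (i.e. $b = 0$), and the implementation below adds a small `eps` under the square root for numerical stability.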

In [24]:
class MistralRMSNorm(torch.nn.Module):
    """RMS normalisation with a learned per-feature gain and no bias.

    Implemented following `modeling_mistral.py`; equivalent to T5LayerNorm.
    """

    def __init__(self, hidden_size: int, eps=1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise the last dimension of `hidden_states` by its root mean square."""
        input_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the input dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (upcast * inverse_rms).to(input_dtype)
        return self.weight * normalized

## MistralRotaryEmbedding

In [25]:
class MistralRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding (RoPE) that rotates query/key states directly.

    Args:
        head_size: size of one attention head (the last dimension of the states).
        max_position_embeddings: number of positions in the angle table; position
            ids must be smaller than this value.
        base: base of the geometric progression of rotation frequencies.
        device: optional device on which the frequency and position tensors are created.
    """

    def __init__(
        self,
        head_size: int,
        max_position_embeddings: int = 2048,
        base: Union[int, float] = 10000,
        device: Optional[Union[torch.device, str]] = None,
    ) -> None:
        super().__init__()
        inverse_frequencies = 1 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
        # Duplicate the frequencies so channels `j` and `j + head_size // 2` share an angle,
        # matching the "rotate half" layout used in `forward`.
        self.theta = torch.cat([inverse_frequencies, inverse_frequencies], dim=-1).to(device)
        self.position_ids = torch.arange(0, max_position_embeddings).to(device)

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.LongTensor) -> torch.Tensor:
        """Rotate `hidden_states` according to `position_ids`.

        Args:
            hidden_states: states whose last two dimensions are (seq_len, head_size),
                e.g. (batch, heads, seq_len, head_size).
            position_ids: integer positions of shape (seq_len,) or (batch, seq_len).

        Returns:
            The rotated states, with the same shape and dtype as `hidden_states`.
        """
        device = hidden_states.device
        # `theta`/`position_ids` are plain attributes, so they do not follow `module.to(...)`;
        # move what is needed to the input's device here.
        angle_table = torch.outer(self.position_ids, self.theta).to(device)
        angles = angle_table[position_ids.to(device)]

        if position_ids.dim() == 2 and hidden_states.dim() == 4:
            # (batch, seq, dim) -> (batch, 1, seq, dim) so the angles broadcast over the
            # head axis instead of being misaligned with it.
            angles = angles.unsqueeze(1)

        # Cast to the input dtype so half-precision states are not promoted to float32.
        cos = torch.cos(angles).to(hidden_states.dtype)
        sin = torch.sin(angles).to(hidden_states.dtype)

        half = hidden_states.shape[-1] // 2
        first_half = hidden_states[..., :half]
        second_half = hidden_states[..., half:]
        rotated = torch.cat([-second_half, first_half], dim=-1)

        return hidden_states * cos + rotated * sin

## MistralAttention

In [28]:
class MistralAttention(torch.nn.Module):
    """Eager multi-head self-attention with grouped key/value heads and rotary embeddings.

    Args:
        config: model configuration; reads `num_attention_heads`, `num_key_value_heads`,
            `hidden_size`, `attention_dropout`, `max_position_embeddings` and `rope_theta`.
        layer_idx: index of this layer in the model. Required when a KV cache is
            passed to `forward`, because it selects this layer's slot in the cache.
    """

    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None) -> None:
        super().__init__()

        # Init
        self.num_q_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_q_heads // self.num_kv_heads
        self.head_size = config.hidden_size // self.num_q_heads
        self.hidden_size = config.hidden_size
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.layer_idx = layer_idx

        if self.head_size * self.num_q_heads != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_attention_heads`: {self.num_q_heads})."
            )

        # QKVO Layer
        self.q_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.hidden_size,
            bias=False,
        )
        self.k_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.num_kv_heads * self.head_size,
            bias=False,
        )
        self.v_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.num_kv_heads * self.head_size,
            bias=False,
        )
        self.o_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.hidden_size,
            bias=False,
        )

        # RoPE
        self.rotary_emb = MistralRotaryEmbedding(
            head_size=self.head_size,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Compute self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape (batch, seq_len, hidden_size).
            attention_mask: optional additive mask, added to the attention scores
                before the softmax (must broadcast to (batch, heads, seq_len, kv_len)).
            position_ids: positions of the new tokens. When `None`, consecutive
                positions starting right after the cached tokens are used.
            past_key_value: optional cache; when given it is updated with the new
                key/value states and the cached states take part in attention.
            output_attentions: whether to return the attention probabilities.
            use_cache: accepted for interface compatibility; not read here.

        Returns:
            `(attention_output, attention_weights or None, past_key_value)` where
            `attention_output` has shape (batch, seq_len, hidden_size).

        Raises:
            ValueError: if a cache is passed but the layer has no `layer_idx`.
        """
        batch_size, seq_len, _ = hidden_states.size()

        # QKV projections, split into heads: (batch, heads, seq_len, head_size)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(batch_size, seq_len, self.num_q_heads, self.head_size).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)

        # Number of tokens already in the cache for this layer
        past_length = 0
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    "A cache was passed to the attention layer but it was created without `layer_idx`; "
                    "pass `layer_idx` when constructing the layer to use the KV cache."
                )
            past_length = past_key_value.get_usable_length(
                new_seq_length=seq_len,
                layer_idx=self.layer_idx,
            )

        if position_ids is None:
            # Default to the positions that directly follow the cached tokens.
            position_ids = torch.arange(
                past_length,
                past_length + seq_len,
                dtype=torch.long,
                device=hidden_states.device,
            ).unsqueeze(0)

        query_states = self.rotary_emb(
            hidden_states=query_states,
            position_ids=position_ids,
        )
        key_states = self.rotary_emb(
            hidden_states=key_states,
            position_ids=position_ids,
        )

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                layer_idx=self.layer_idx,
            )

        # Repeat kv heads
        key_states, value_states = repeat_kv(
            key_states=key_states,
            value_states=value_states,
            num_kv_groups=self.num_kv_groups,
        )

        # Attention weights (Q * K^T)
        attention_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_size)

        # Attention mask
        if attention_mask is not None:
            attention_weights = attention_weights + attention_mask

        # Upcast attention to fp32
        attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attention_weights = torch.nn.functional.dropout(attention_weights, p=self.attention_dropout, training=self.training)

        # Attention output (A = Q * K^T, A * V): (batch, heads, seq_len, head_size)
        attention_output = torch.matmul(attention_weights, value_states)
        # Bring the sequence axis back in front of the heads BEFORE merging heads; reshaping
        # (batch, heads, seq, dim) directly would interleave heads and positions.
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.reshape(batch_size, seq_len, self.hidden_size)
        attention_output = self.o_proj(attention_output)

        if not output_attentions:
            attention_weights = None

        return attention_output, attention_weights, past_key_value

## MistralSdpaAttention

In [29]:
class MistralSdpaAttention(MistralAttention):
    """`MistralAttention` variant that uses `torch.nn.functional.scaled_dot_product_attention`.

    The weights and constructor are inherited unchanged; only `forward` differs.
    SDPA cannot return attention probabilities, so calls with
    `output_attentions=True` are delegated to the eager parent implementation.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        """Compute self-attention over `hidden_states` with the fused SDPA kernel.

        Args:
            hidden_states: input of shape (batch, seq_len, hidden_size).
            attention_mask: optional mask forwarded to SDPA as `attn_mask`.
                NOTE(review): SDPA expects a boolean mask or a float mask — confirm
                what the caller passes.
            position_ids: positions of the new tokens. When `None`, consecutive
                positions starting right after the cached tokens are used.
            past_key_value: optional cache, updated with the new key/value states.
            output_attentions: when true, falls back to the eager parent `forward`.
            use_cache: accepted for interface compatibility; not read here.

        Returns:
            `(attention_output, None, past_key_value)` where `attention_output`
            has shape (batch, seq_len, hidden_size).

        Raises:
            ValueError: if a cache is passed but the layer has no `layer_idx`.
        """
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )

        batch_size, seq_len, _ = hidden_states.size()

        # QKV projections (all three take the hidden states as input)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split into heads, using the attribute names defined by `MistralAttention.__init__`
        query_states = query_states.view(batch_size, seq_len, self.num_q_heads, self.head_size).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)

        # Number of tokens already in the cache for this layer
        past_length = 0
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    "A cache was passed to the attention layer but it was created without `layer_idx`; "
                    "pass `layer_idx` when constructing the layer to use the KV cache."
                )
            past_length = past_key_value.get_usable_length(
                new_seq_length=seq_len,
                layer_idx=self.layer_idx,
            )

        if position_ids is None:
            # Default to the positions that directly follow the cached tokens.
            position_ids = torch.arange(
                past_length,
                past_length + seq_len,
                dtype=torch.long,
                device=hidden_states.device,
            ).unsqueeze(0)

        query_states = self.rotary_emb(
            hidden_states=query_states,
            position_ids=position_ids,
        )
        key_states = self.rotary_emb(
            hidden_states=key_states,
            position_ids=position_ids,
        )

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                layer_idx=self.layer_idx,
            )

        # Repeat kv heads
        key_states, value_states = repeat_kv(
            key_states=key_states,
            value_states=value_states,
            num_kv_groups=self.num_kv_groups,
        )

        # Contiguous
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

        # SDPA
        attention_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            # The functional API applies dropout whenever `dropout_p > 0`, so it has to be
            # switched off explicitly outside of training.
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=self.is_causal and attention_mask is None and seq_len > 1,
        )

        # (batch, heads, seq_len, head_size) -> (batch, seq_len, hidden_size)
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.hidden_size)
        attention_output = self.o_proj(attention_output)

        return attention_output, None, past_key_value

## MistralMLP

In [None]:
class MistralMLP(torch.nn.Module):
    """Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`.

    Args:
        config: model configuration; reads `hidden_size`, `intermediate_size`
            and `hidden_act` (the name of the activation applied to the gate).
    """

    def __init__(self, config: MistralConfig) -> None:
        super().__init__()
        self.gate_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
            bias=False,
        )
        self.up_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
            bias=False,
        )
        self.down_proj = torch.nn.Linear(
            in_features=config.intermediate_size,
            out_features=config.hidden_size,
            bias=False,
        )
        # Use the activation named in the config instead of hard-coding gelu.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to `x` (last dimension `hidden_size`)."""
        gate_output = self.act_fn(self.gate_proj(x))
        up_output = self.up_proj(x)
        # The activated gate multiplies the up projection element-wise (it is not added to it).
        return self.down_proj(gate_output * up_output)

## MistralDecoderLayer

In [42]:
class MistralDecoderLayer(torch.nn.Module):
    """One pre-norm decoder block: self-attention and MLP, each with a residual connection.

    Args:
        config: model configuration.
        layer_idx: index of this layer in the model. It is forwarded to the
            attention module, which needs it to address its slot in a KV cache.
    """

    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None) -> None:
        super().__init__()
        self.self_attn = MistralSdpaAttention(config=config, layer_idx=layer_idx)
        self.mlp = MistralMLP(config=config)
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, ...]:
        """Run the block on `hidden_states`.

        Args:
            hidden_states: input of shape (batch, seq_len, hidden_size).
            attention_mask: optional mask forwarded to the attention module.
            position_ids: optional token positions forwarded to the attention module.
            past_key_value: optional KV cache forwarded to the attention module.
            output_attentions: whether to append the attention weights to the output.
            use_cache: whether to append the cache returned by the attention module.

        Returns:
            A tuple starting with the new hidden states, followed by the attention
            weights (if `output_attentions`) and the cache (if `use_cache`).
        """
        # Self-attention sub-block (pre-norm + residual)
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attention_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + residual

        # Feed-forward sub-block (pre-norm + residual)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attention_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


## MistralModel

In [None]:
class MistralModel(torch.nn.Module):
    """Stack of decoder layers on top of a token embedding, followed by a final RMSNorm.

    Args:
        config: model configuration; reads `vocab_size`, `hidden_size`,
            `pad_token_id`, `num_hidden_layers` and `rms_norm_eps`.
    """

    def __init__(self, config: MistralConfig):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=config.pad_token_id,
        )
        # `ModuleList` needs an iterable of modules: build one list of layers, rather than a
        # generator that yields single-element lists.
        self.layers = torch.nn.ModuleList(
            [MistralDecoderLayer(config=config) for _ in range(config.num_hidden_layers)]
        )
        # Give every attention module its layer index so it can address its own slot in a
        # shared KV cache.
        for layer_idx, layer in enumerate(self.layers):
            layer.self_attn.layer_idx = layer_idx
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        input_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Forward pass of the decoder stack.

        TODO: the body was left empty in the notebook (a syntax error). Until it is
        written this raises explicitly so the cell can at least be executed.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("`MistralModel.forward` is not implemented yet.")
        

## MistralModelForCausalLM

In [None]:
class MyMistralModelForCausalLM(torch.nn.Module):
    def __init__(self, config: MistralConfig) -> None:
        """Build the decoder stack and the language-modelling head.

        Args:
            config: model configuration; reads `hidden_size` and `vocab_size`
                for the head and is passed on to `MistralModel`.
        """
        super().__init__()
        self.model = MistralModel(config=config)
        # `torch.nn.Linear` takes `in_features` (there is no `input_features` keyword).
        # The head has no bias, matching the reference `MistralForCausalLM.lm_head`.
        self.lm_head = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor: