In [146]:
from typing import Any, Dict, OrderedDict, Optional, List, Tuple, Union

from dataclasses import dataclass
import json
import math

import regex as re
import torch
from transformers import AutoConfig, MistralForCausalLM, LlamaTokenizer
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file

In [147]:
pretrained_model_name_or_path = "echarlaix/tiny-random-mistral"

config = AutoConfig.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
model = MistralForCausalLM.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    torch_dtype=torch.bfloat16,
)
tokenizer = LlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
)

In [148]:
config

MistralConfig {
  "_name_or_path": "echarlaix/tiny-random-mistral",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 32,
  "initializer_range": 0.02,
  "intermediate_size": 37,
  "is_decoder": true,
  "max_position_embeddings": 512,
  "model_type": "mistral",
  "num_attention_heads": 4,
  "num_hidden_layers": 2,
  "num_key_value_heads": 2,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.39.1",
  "type_vocab_size": 16,
  "use_cache": true,
  "vocab_size": 32000
}

In [149]:
class MistralConfig:
    def __init__(
        self,
        _name_or_path: str = "echarlaix/tiny-random-mistral",
        _attn_implementation: str = "sdpa",
        architectures: List[str] = [
            "MistralForCausalLM"
        ],
        attention_dropout: float = 0.0,
        attention_probs_dropout_prob: float = 0.1,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        hidden_size: int = 32,
        initializer_range: float = 0.02,
        intermediate_size: int = 37,
        is_decoder: bool = True,
        max_position_embeddings: int = 512,
        model_type: str = "mistral",
        num_attention_heads: int = 4,
        num_hidden_layers: int = 2,
        num_key_value_heads: int = 2,
        pad_token_id: int = 0,
        rms_norm_eps: float = 1e-06,
        rope_theta: float = 10000.0,
        sliding_window: int = 4096,
        tie_word_embeddings: bool = False,
        torch_dtype: str = "float32",
        transformers_version: str = "4.39.1",
        type_vocab_size: int = 16,
        use_cache: bool = True,
        vocab_size: int = 32000,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> None:
        self._name_or_path = _name_or_path
        self._attn_implementation = _attn_implementation
        self.architectures = architectures
        self.attention_dropout = attention_dropout
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.intermediate_size = intermediate_size
        self.is_decoder = is_decoder
        self.max_position_embeddings = max_position_embeddings
        self.model_type = model_type
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.pad_token_id = pad_token_id
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
        self.tie_word_embeddings = tie_word_embeddings
        self.torch_dtype = torch_dtype
        self.transformers_version = transformers_version
        self.type_vocab_size = type_vocab_size
        self.use_cache = use_cache
        self.vocab_size = vocab_size
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        
    @staticmethod
    def from_pretrained_model_or_path(pretrained_model_name_or_path: str) -> "MistralConfig":
        resolved_archive_file = cached_file(
            path_or_repo_id=pretrained_model_name_or_path,
            filename=CONFIG_NAME,
            _raise_exceptions_for_missing_entries=False,
        )
        
        config_content = json.load(open(resolved_archive_file))
        return MistralConfig(**config_content)        

In [150]:
config = MistralConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)

## Output Format

In [151]:
@dataclass
class BaseModelOutputWithPast:
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class CausalLMOutputWithPast:
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

In [152]:
def repeat_kv(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_kv_groups: int,
) -> torch.Tensor:
    batch_size, num_heads, seq_len, head_size = key_states.size()
    key_states = key_states[:, :, None, :, :].expand(
        batch_size,
        num_heads,
        num_kv_groups,
        seq_len,
        head_size,
    )

    value_states = value_states[:, :, None, :, :].expand(
        batch_size,
        num_heads,
        num_kv_groups,
        seq_len,
        head_size,
    )

    return (
        key_states.reshape(batch_size, num_heads * num_kv_groups, seq_len, head_size),
        value_states.reshape(batch_size, num_heads * num_kv_groups, seq_len, head_size),
    )

## Cache

In [153]:
class Cache:
    """Base, abstract class for all caches."""
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the updated key and value states."""
        raise NotImplementedError("Make sure to implement `update` method in a subclass.")
    
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        raise NotImplementedError("Make sure to implement `get_seq_length` in subclass.")
    
    def get_max_length(self) -> int:
        raise NotImplementedError("Make sure to implement `get_max_length` in subclass.")
    
    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx=layer_idx)

        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        
        return previous_seq_length
    

class DynamicCache(Cache):
    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.seen_tokens = 0  # Used in `generate`: how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
    
    def __iter__(self):
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        return len(self.key_cache)
    
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if layer_idx == 0:
            self.seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
    
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx:
            return 0
        
        return self.key_cache[layer_idx].shape[-2]
    
    def get_max_length(self) -> Optional[int]:
        return None
    
    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))

            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)

        return legacy_cache
    
    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(
                    key_states=key_states,
                    value_states=value_states,
                    layer_idx=layer_idx,
                )
        return cache

## MistralRMSNorm
$\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n}x_{i}^{2}} \newline$
$\widehat{x} = g\frac{x_i}{\sigma}+b$

But mistral does not use `bias` in RMSNorm

In [154]:
class MistralRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps=1e-6) -> None:
        """
        The RMSNorm is implemented according `modeling_mistral.py`.
        It is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(original_input_dtype)

## MistralRotaryEmbedding

In [155]:
class MistralRotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        head_size: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
        device: Optional[Union[torch.device, str]] = None,
    ) -> None:
        super().__init__()
        self.theta = 1 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
        self.theta = torch.cat([self.theta, self.theta], dim=-1).to(device)
        self.position_ids = torch.arange(0, max_position_embeddings).to(device)

    def forward(self, hidden_states: torch.Tensor, position_ids: torch.LongTensor) -> torch.Tensor:
        position_maxtrix = torch.outer(self.position_ids, self.theta)
        cos = torch.cos(position_maxtrix)
        sin = torch.sin(position_maxtrix)

        x1 = hidden_states[..., :hidden_states.shape[-1] // 2]
        x2 = hidden_states[..., hidden_states.shape[-1] // 2 :]
        _x = torch.cat([-x2, x1], dim=-1)
        out = hidden_states * cos[position_ids] + _x * sin[position_ids]

        return out

## MistralAttention

In [156]:
class MistralAttention(torch.nn.Module):
    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None) -> None:
        super().__init__()

        # Init
        self.num_q_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_q_heads // self.num_kv_heads
        self.head_size = config.hidden_size // self.num_q_heads
        self.hidden_size = config.hidden_size
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.layer_idx = layer_idx

        # QKVO Layer
        self.q_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.hidden_size,
            bias=False,
        )
        self.k_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.num_kv_heads * self.head_size,
            bias=False,
        )
        self.v_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.num_kv_heads * self.head_size,
            bias=False,
        )
        self.o_proj = torch.nn.Linear(
            in_features=self.hidden_size,
            out_features=self.hidden_size,
            bias=False,
        )

        # RoPE
        self.rotary_emb = MistralRotaryEmbedding(
            head_size=self.head_size,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )
    
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Init
        batch_size, seq_len, hidden_size = hidden_states.size()

        # QKV
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape
        query_states = query_states.view(batch_size, seq_len, self.num_q_heads, self.head_size).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)

        # KV Cache
        kv_seq_len = key_states.size(2)
        if past_key_value is not None and self.layer_idx is not None:
            kv_seq_len += past_key_value.get_usable_length(
                new_seq_length=kv_seq_len,
                layer_idx=self.layer_idx,
            )

        query_states = self.rotary_emb(
            hidden_states=query_states,
            position_ids=position_ids,
        )
        key_states = self.rotary_emb(
            hidden_states=key_states,
            position_ids=position_ids,
        )

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                layer_idx=self.layer_idx,
            )
        
        # Repeat kv heads
        key_states, value_states = repeat_kv(
            key_states=key_states,
            value_states=value_states,
            num_kv_groups=self.num_kv_groups,
        )

        # Attention weights (Q * K^T)
        attention_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_size)

        # Attention mask
        if attention_mask is not None:
            attention_weights = attention_weights + attention_mask

        # Upcast attention to fp32
        attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attention_weights = torch.nn.functional.dropout(attention_weights, p=self.attention_dropout, training=self.training)

        # Attention output (A = Q * K^T, A * V)
        attention_output = torch.matmul(attention_weights, value_states).reshape(batch_size, seq_len, self.hidden_size)
        attention_output = self.o_proj(attention_output)

        if not output_attentions:
            attention_weights = None

        return attention_output, attention_weights, past_key_value

## MistralSdpaAttention

In [157]:
class MistralSdpaAttention(MistralAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        
        batch_size, seq_len, hidden_size = hidden_states.size()

        # QKV
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape
        query_states = query_states.view(batch_size, seq_len, self.num_q_heads, self.head_size).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_size).transpose(1, 2)

        # KV Cache
        kv_seq_len = key_states.size(2)
        if past_key_value is not None and self.layer_idx is not None:
            kv_seq_len += past_key_value.get_usable_length(
                new_seq_length=kv_seq_len,
                layer_idx=self.layer_idx,
            )

        query_states = self.rotary_emb(
            hidden_states=query_states,
            position_ids=position_ids,
        )
        key_states = self.rotary_emb(
            hidden_states=key_states,
            position_ids=position_ids,
        )

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                layer_idx=self.layer_idx,
            )
        
        # Repeat kv heads
        key_states, value_states = repeat_kv(
            key_states=key_states,
            value_states=value_states,
            num_kv_groups=self.num_kv_groups,
        )

        # Contiguous
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

        # SDPA
        attention_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout,
            is_causal=self.is_causal and attention_mask is None and seq_len > 1,
        )

        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, hidden_size)
        attention_output = self.o_proj(attention_output)

        return attention_output, None, past_key_value

## MistralMLP

In [158]:
class MistralMLP(torch.nn.Module):
    def __init__(self, config: MistralConfig) -> None:
        super().__init__()
        self.gate_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
            bias=False,
        )
        self.up_proj = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
            bias=False,
        )
        self.down_proj = torch.nn.Linear(
            in_features=config.intermediate_size,
            out_features=config.hidden_size,
            bias=False,
        )
        self.act_fn = torch.nn.functional.gelu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up_output = self.up_proj(x)
        gate_output = self.gate_proj(x)
        intermediate_output = self.act_fn(gate_output) + up_output
        down_output = self.down_proj(intermediate_output)
        return down_output

## MistralDecoderLayer

In [159]:
class MistralDecoderLayer(torch.nn.Module):
    def __init__(self, config: MistralConfig, layer_idx: int) -> None:
        super().__init__()
        self.self_attn = MistralSdpaAttention(config=config, layer_idx=layer_idx)
        self.mlp = MistralMLP(config=config)
        self.input_layernorm = MistralRMSNorm(config.hidden_size)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attention_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        # Redisual connection
        hidden_states = hidden_states + residual

        # Fully connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attention_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


In [160]:
def _my_prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
):
    """
    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
    """
    batch_size, query_length = input_shape
    key_value_length = query_length + past_key_values_length

    if attention_mask is None:
        # Creating a causal mask for all positions
        mask = torch.full((query_length, key_value_length), torch.finfo(inputs_embeds.dtype).min, device=inputs_embeds.device)
        mask_cond = torch.arange(key_value_length, device=inputs_embeds.device)
        mask[:, :query_length] = (mask_cond[None, :] < (mask_cond[:query_length] + 1)[:, None]).float()
        expanded_4d_mask = mask[None, None, :, :].expand(batch_size, 1, query_length, key_value_length)
    elif len(attention_mask.shape) == 4:
        # If a 4D attention mask is already provided, use it directly
        expanded_4d_mask = attention_mask
    else:
        # Expanding a 2D mask to 4D
        expanded_mask = attention_mask[:, None, None, :]
        expanded_4d_mask = expanded_mask.expand(batch_size, 1, query_length, key_value_length).to(dtype=inputs_embeds.dtype)
        # Apply causal mask to the expanded mask
        causal_mask = torch.triu(torch.ones((query_length, key_value_length), device=inputs_embeds.device, dtype=torch.bool), diagonal=1)
        padding_mask = attention_mask == 0
        padding_mask = padding_mask.view(batch_size, 1, 1, key_value_length)
        expanded_4d_mask = expanded_4d_mask.masked_fill(~padding_mask, 0.)
        expanded_4d_mask = expanded_4d_mask.masked_fill(padding_mask, torch.finfo(inputs_embeds.dtype).min)
        expanded_4d_mask = expanded_4d_mask.masked_fill(causal_mask, torch.finfo(inputs_embeds.dtype).min)

    return expanded_4d_mask


## MistralModel

In [161]:
class MistralModel(torch.nn.Module):
    def __init__(self, config: MistralConfig):
        super().__init__()
        self.config = config
        self._attn_implementation = config._attn_implementation

        self.embed_tokens = torch.nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=config.pad_token_id,
        )
        self.layers = torch.nn.ModuleList([MistralDecoderLayer(config=config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = MistralRMSNorm(config.hidden_size)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Config
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Init
        batch_size, seq_length = input_ids.shape
        past_key_values_length = 0

        # If use cache
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values=past_key_values)

            past_key_values_length = past_key_values.get_usable_length(seq_length)

        # Position ids
        if position_ids is None:
            position_ids = torch.arange(
                start=past_key_values_length,
                end=past_key_values_length + seq_length,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # Input embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Attention mask (output_attentions is not supported when using SDPA)
        if self._attn_implementation == "sdpa" and not output_attentions:
            attention_mask = _my_prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask=attention_mask,
                input_shape=(batch_size, seq_length),
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            raise NotImplementedError(
                "_my_prepare_4d_causal_attention_mask() if not implemented for now.",
            )
        
        # Feed-Forward
        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            # No matter how many data returned, the first one is the `hidden_states`
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attentions += (layer_outputs[1],)
        
        hidden_states = self.norm(hidden_states)

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

## MistralModelForCausalLM

In [162]:
class MyMistralForCausalLM(torch.nn.Module):
    def __init__(self, config: MistralConfig) -> None:
        super().__init__()
        self.config = config
        self.model = MistralModel(config=config)
        self.lm_head = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.vocab_size,
            bias=False,
        )

        self.tie_weights()

    def tie_weights(self) -> None:
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # Settings
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)

        # decoder outputs
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,   
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states).float()

        # Loss
        loss = None
        if labels is not None:
            criterion = torch.nn.CrossEntropyLoss()

            # Shift
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            # Make sure they are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss = criterion(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            attentions=attention_mask,
        )
    
    @staticmethod
    def from_pretrained(pretrained_model_name_or_path: str) -> "MyMistralForCausalLM":
        """Load pretrained weights from HuggingFace into model.
        
        Args:
            pretrained_model_name_or_path: One of
                * "echarlaix/tiny-random-mistral"
                ...

        Returns:
            model: MyMistralModelForCausalLM model with weights loaded
        """

        def load_state_dict_hf(path_or_repo_id: str) -> OrderedDict:
            resolved_archive_file = cached_file(
                path_or_repo_id=path_or_repo_id,
                filename=WEIGHTS_NAME,
            )
            return torch.load(resolved_archive_file, weights_only=True)

        # Load config
        config = MistralConfig.from_pretrained_model_or_path(pretrained_model_name_or_path=pretrained_model_name_or_path)

        # Load weights
        state_dict = load_state_dict_hf(pretrained_model_name_or_path)

        # Load model
        model = MyMistralForCausalLM(config=config)
        model.load_state_dict(state_dict=state_dict, strict=True)

        return model

In [163]:
custom_model = MyMistralForCausalLM.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)

## Test

In [164]:
tokenizer.pad_token_id = config.pad_token_id

In [165]:
texts = [
    "Today is a nice day.",
    "I want to go to play, do you want to join us?",
    "???",
]

inputs = tokenizer(texts, padding=True, return_tensors="pt")

In [166]:
custom_model(**inputs)

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 0.0572,  0.3510,  0.1318,  ..., -0.2449, -0.0390, -0.0728],
         [ 0.0973,  0.0157,  0.1455,  ..., -0.0860, -0.1055, -0.1053],
         [ 0.1604,  0.1928,  0.0469,  ..., -0.0620,  0.0588, -0.0535],
         ...,
         [ 0.3447, -0.1340, -0.1853,  ...,  0.0474, -0.0168, -0.2040],
         [ 0.3447, -0.1342, -0.1855,  ...,  0.0477, -0.0167, -0.2039],
         [ 0.3446, -0.1344, -0.1857,  ...,  0.0479, -0.0166, -0.2038]],

        [[ 0.0572,  0.3510,  0.1318,  ..., -0.2449, -0.0390, -0.0728],
         [-0.0568, -0.2399,  0.0718,  ...,  0.1150, -0.1147, -0.0202],
         [-0.0531,  0.0850,  0.0348,  ..., -0.1261,  0.1661,  0.1647],
         ...,
         [ 0.1463, -0.1691, -0.1373,  ...,  0.1508, -0.0556,  0.0063],
         [-0.0097, -0.0065,  0.2922,  ..., -0.0257, -0.1324,  0.0456],
         [-0.0217, -0.0284,  0.1220,  ...,  0.0391, -0.1025, -0.2051]],

        [[ 0.0572,  0.3510,  0.1318,  ..., -0.2449, -0.0390, -0.0728],
    

In [168]:
model(**inputs)

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 0.1035,  0.1514, -0.0029,  ..., -0.1748,  0.0884, -0.0410],
         [-0.0264,  0.1084,  0.2236,  ..., -0.1660, -0.0938, -0.0349],
         [-0.1338, -0.0393,  0.0269,  ...,  0.0515, -0.0090, -0.0048],
         ...,
         [ 0.0952,  0.0056,  0.0962,  ..., -0.0537, -0.0767, -0.0312],
         [ 0.0957,  0.0058,  0.0962,  ..., -0.0537, -0.0762, -0.0320],
         [ 0.0952,  0.0051,  0.0962,  ..., -0.0537, -0.0757, -0.0320]],

        [[ 0.1035,  0.1514, -0.0029,  ..., -0.1748,  0.0884, -0.0410],
         [-0.0010, -0.1348,  0.0942,  ..., -0.0315, -0.0618,  0.0996],
         [-0.1904,  0.0284,  0.0840,  ..., -0.0344, -0.0830,  0.0430],
         ...,
         [-0.1064, -0.1475,  0.0023,  ..., -0.1108,  0.0693, -0.0087],
         [-0.1084, -0.2188, -0.1719,  ...,  0.0064,  0.0845,  0.0054],
         [-0.1079,  0.0247,  0.1157,  ..., -0.0403, -0.0820,  0.0530]],

        [[ 0.1035,  0.1514, -0.0029,  ..., -0.1748,  0.0884, -0.0410],
    