In [1]:
from transformers import AutoTokenizer
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "facebook/xglm-1.7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [3]:
model_input = tokenizer("hello world", return_tensors="pt")
model_input


{'input_ids': tensor([[     2, 113677,   3038]]), 'attention_mask': tensor([[1, 1, 1]])}

In [4]:
len(tokenizer)

256008

In [5]:
len(tokenizer) // 64 * 65

260000

In [6]:
from typing import List, Optional, Tuple

import torch
from torch import nn

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange

from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


# для того чтобы это заработало нужно открыть исходники и закоментировать все упоминания
# TORCH_CHECK из сурсов, а затем скомпилировать это
# искать в файле csrc/flash_attn/fmha_api.cpp
def flash_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **other_keys
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    # assert past_key_value is None, "past_key_value is not supported"
    # assert not output_attentions, "output_attentions is not supported"
    # assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad,
            "nnz (three h d) -> nnz three h d",
            three=3,
            h=nheads,
        )
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad,
            cu_q_lens,
            max_s,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
                indices,
                bsz,
                q_len,
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.out_proj(rearrange(output, "b s h d -> b s (h d)")), None, None


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


In [7]:
# transformers.models.xglm.modeling_xglm.XGLMAttention.forward = flash_forward
transformers.models.xglm.modeling_xglm.XGLMModel._prepare_decoder_attention_mask = (
    _prepare_decoder_attention_mask
)


In [8]:
class NanoGPTXGLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.resid_dropout = nn.Dropout(self.dropout)

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **other_parameters,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        y = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=True,
        )
        y = self.resid_dropout(self.out_proj(y))
        return y, None, None


In [9]:
import math
from torch.nn import functional as F

# !XGLM ATTENTION
# def __init__(
# self,
# embed_dim: int,
# num_heads: int,
# dropout: float = 0.0,
# is_decoder: bool = False,
# bias: bool = True,
# ):
# super().__init__()
# self.embed_dim = embed_dim
# self.num_heads = num_heads
# self.dropout = dropout
# self.head_dim = embed_dim // num_heads

# if (self.head_dim * num_heads) != self.embed_dim:
#     raise ValueError(
#         f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
#         f" and `num_heads`: {num_heads})."
#     )
# self.scaling = self.head_dim**-0.5
# self.is_decoder = is_decoder

# self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)


def nanoGPT_forward(
    self,
    hidden_states: torch.Tensor,
):
    x = hidden_states
    B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

    # calculate query, key, values for all heads in batch and move head forward to be the batch dim
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
    # efficient attention using Flash Attention CUDA kernels
    y = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=self.dropout if self.training else 0,
        is_causal=True,
    )
    y = (
        y.transpose(1, 2).contiguous().view(B, T, C)
    )  # re-assemble all head outputs side by side

    # output projection
    y = self.resid_dropout(self.c_proj(y))
    return y


In [10]:
transformers.models.xglm.modeling_xglm.XGLMAttention = NanoGPTXGLMAttention


In [11]:
from transformers import XGLMForCausalLM

model = XGLMForCausalLM.from_pretrained(model_name, device_map="auto")
model.resize_token_embeddings(len(tokenizer) // 64 * 65)
model.half()


XGLMForCausalLM(
  (model): XGLMModel(
    (embed_tokens): Embedding(260000, 2048)
    (embed_positions): XGLMSinusoidalPositionalEmbedding()
    (layers): ModuleList(
      (0-23): 24 x XGLMDecoderLayer(
        (self_attn): NanoGPTXGLMAttention(
          (resid_dropout): Dropout(p=0.1, inplace=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=2048, out_features=8192, bias=True)
        (fc2): Linear(in_features=8192, out_features=2048, bias=True)
        (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layer_norm):

In [12]:
model(**model_input)


CausalLMOutputWithCrossAttentions(loss={'logits': tensor([[[-3.4326e-01, -4.6753e-01,  1.0950e+02,  ..., -1.0352e+00,
          -3.9038e-01,  6.0742e-01],
         [ 6.3818e-01,  8.8623e-01,  1.3825e+02,  ..., -1.2451e+00,
          -4.8169e-01,  7.6660e-01],
         [ 5.0859e+00,  4.7422e+00,  2.1538e+02,  ..., -1.4395e+00,
          -7.0435e-02, -6.3538e-02]]], dtype=torch.float16,
       grad_fn=<ToCopyBackward0>), 'past_key_values': (None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)}, logits=tensor([[[-3.4326e-01, -4.6753e-01,  1.0950e+02,  ..., -1.0352e+00,
          -3.9038e-01,  6.0742e-01],
         [ 6.3818e-01,  8.8623e-01,  1.3825e+02,  ..., -1.2451e+00,
          -4.8169e-01,  7.6660e-01],
         [ 5.0859e+00,  4.7422e+00,  2.1538e+02,  ..., -1.4395e+00,
          -7.0435e-02, -6.3538e-02]]], dtype=torch.float16,
       grad_fn=<ToCopyBackward0>), past_key_values=(None, None, No