In [1]:
from transformers import AutoTokenizer
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "facebook/xglm-1.7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [3]:
model_input = tokenizer("hello world", return_tensors="pt")
model_input


{'input_ids': tensor([[     2, 113677,   3038]]), 'attention_mask': tensor([[1, 1, 1]])}

In [4]:
from typing import List, Optional, Tuple

import torch
from torch import nn

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange

from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input

# для того чтобы это заработало нужно открыть исходники и закоментировать все упоминания 
# TORCH_CHECK из сурсов, а затем скомпилировать это
# искать в файле csrc/flash_attn/fmha_api.cpp
def flash_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **other_keys
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    assert past_key_value is None, "past_key_value is not supported"
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad,
            "nnz (three h d) -> nnz three h d",
            three=3,
            h=nheads,
        )
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad,
            cu_q_lens,
            max_s,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
                indices,
                bsz,
                q_len,
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.out_proj(rearrange(output, "b s h d -> b s (h d)")), None, None


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


In [5]:
transformers.models.xglm.modeling_xglm.XGLMAttention.forward = flash_forward
transformers.models.xglm.modeling_xglm.XGLMModel._prepare_decoder_attention_mask = (
    _prepare_decoder_attention_mask
)


In [8]:
from transformers import XGLMForCausalLM

model = XGLMForCausalLM.from_pretrained(model_name, device_map="auto")
model.half()

XGLMForCausalLM(
  (model): XGLMModel(
    (embed_tokens): Embedding(256008, 2048, padding_idx=1)
    (embed_positions): XGLMSinusoidalPositionalEmbedding()
    (layers): ModuleList(
      (0-23): 24 x XGLMDecoderLayer(
        (self_attn): XGLMAttention(
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=2048, out_features=8192, bias=True)
        (fc2): Linear(in_features=8192, out_features=2048, bias=True)
        (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine

In [12]:
model(**model_input)


CausalLMOutputWithCrossAttentions(loss={'logits': tensor([[[ -0.3433,  -0.4675, 109.5000,  ...,  -0.7705,  -1.8184,  -1.6680],
         [  3.4414,   3.0742, 234.7500,  ...,   1.6592,  -0.4905,   1.4746],
         [  3.8652,   3.6523, 241.8750,  ...,   2.3770,   0.2998,   2.6523]]],
       dtype=torch.float16, grad_fn=<ToCopyBackward0>), 'past_key_values': (None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)}, logits=tensor([[[ -0.3433,  -0.4675, 109.5000,  ...,  -0.7705,  -1.8184,  -1.6680],
         [  3.4414,   3.0742, 234.7500,  ...,   1.6592,  -0.4905,   1.4746],
         [  3.8652,   3.6523, 241.8750,  ...,   2.3770,   0.2998,   2.6523]]],
       dtype=torch.float16, grad_fn=<ToCopyBackward0>), past_key_values=(None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None), hidden_states=None, attentions=None, 