In [1]:
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer
import torch

torch.set_grad_enabled(False)

model = Qwen2ForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = model.eval()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [16]:
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

prompt = "Give me a short introduction to large language model..1"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

'A large language model is a type of artificial intelligence that can generate human-like text, including written and spoken language. These models use massive amounts of data and complex algorithms to understand the context and meaning behind natural language input. Large language models have been trained on vast datasets such as Wikipedia, books, articles, and other texts, allowing them to process and generate coherent and meaningful responses. Some examples of popular large language models include GPT, BERT, and Qwen.'

In [17]:
model_output_1 = model(
    **model_inputs,
    output_hidden_states=True,
)
model_output_1.hidden_states[-1].shape

torch.Size([1, 30, 896])

In [18]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_2.hidden_states[-1].shape

torch.Size([2, 30, 896])

In [19]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)

In [21]:
model_output_1.hidden_states[-1].reshape(1, 10, 3, -1).shape

torch.Size([1, 10, 3, 896])

In [23]:
# model(
#     inputs_embeds=model_output_1.hidden_states[-1].reshape(1, 10, 3, -1),
#     attention_mask=model_inputs["attention_mask"],
#     output_hidden_states=True,
# )

In [None]:
torch.tensor(
    [
        [1, 2],
        [
            1,
        ],
    ]
)

ValueError: expected sequence of length 2 at dim 1 (got 1)

In [35]:
class Qwen2ModelEmbedPooler(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.post_init()

    def forward(self, input_embeds, attention_mask, window_size=3):
        # разбиваем входящие эмбединги на бакеты и усредняем их
        sum_mask = attention_mask.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        embeds_sum = input_embeds.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        return embeds_sum / sum_mask


embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-0.5B")
result = embed_pooler(
    # model_output_1.hidden_states[-1],
    torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    torch.cat(
        [
            model_inputs["attention_mask"],
            model_inputs["attention_mask"],
        ],
        dim=0,
    ),
)
result.shape

torch.Size([2, 10, 896])

In [9]:
torch.tensor([1, 2, 3, 4, 5, 6]).reshape(3, -1)

tensor([[1, 2],
        [3, 4],
        [5, 6]])

In [None]:
from typing import Callable, List, Optional, Tuple, Union
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    SlidingWindowCache,
    StaticCache,
)
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import LossKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2ForCausalEmbedModeling(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )
        self.embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-1.5B")

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        # pooled_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds_tokens = self.model.embed_tokens(input_ids)
        # if pixel_values is not None:
        #     pixel_values = pixel_values.type(self.visual.get_dtype())
        #     image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        #     n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        #     n_image_features = image_embeds.shape[0]
        #     if n_image_tokens != n_image_features:
        #         raise ValueError(
        #             f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        #         )
        #     image_mask = (
        #         (input_ids == self.config.image_token_id)
        #         .unsqueeze(-1)
        #         .expand_as(inputs_embeds)
        #         .to(inputs_embeds.device)
        #     )
        #     image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        #     inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        window_size = torch.tensor(
            [
                [4],
                [3],
            ]
        ).cuda()
        tokens_amount = torch.tensor(
            [
                [2],
                [4],
            ]
        ).cuda()
        lengths = window_size * tokens_amount
        # max_len = lengths.max()
        max_len = inputs_embeds_tokens.shape[1]
        batch_size = window_size.shape[0]
        pooled_mask = (
            torch.arange(max_len, device=device)
            .unsqueeze(0)
            .expand(batch_size, max_len)
        )
        pooled_mask = (lengths >= pooled_mask).to(torch.long)
        pooled_embeds = inputs_embeds * pooled_mask.to(inputs_embeds.dtype)
        pooled_embeds = self.embed_pooler(pooled_embeds, pooled_mask)
        embed_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        inputs_embeds = inputs_embeds.masked_scatter(embed_mask, pooled_embeds)

        # Из-за смешанной структуры, мы будем всегда подавать только эмбединги
        # Идея позаимствована из qwen2vl
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [None]:
tokenizer.encode("<|im_end|>")

[151645]

In [23]:
tokenizer.pad_token

'<|endoftext|>'

In [None]:
window_size = torch.tensor(
    [
        [4],
        [6],
    ]
).cuda()
tokens_amount = torch.tensor(
    [
        [2],
        [2],
    ]
).cuda()
lengths = window_size * tokens_amount
max_len = lengths.max()
batch_size = window_size.shape[0]
pooled_mask = (
    torch.arange(max_len, device=device).unsqueeze(0).expand(batch_size, max_len)
)
pooled_mask = (lengths >= pooled_mask).to(torch.long)
pooled_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')

In [46]:
lengths

tensor([[ 8],
        [12]], device='cuda:0')