In [12]:
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer
import torch

torch.set_grad_enabled(False)
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = Qwen2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = model.eval()

In [None]:
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Give me a short introduction to large language model..1"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

In [3]:
model_output_1 = model(
    **model_inputs,
    output_hidden_states=True,
)
model_output_1.hidden_states[-1].shape

torch.Size([1, 30, 896])

In [4]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_2.hidden_states[-1].shape

torch.Size([2, 30, 896])

In [5]:
model_output_1.hidden_states[-1].shape

torch.Size([1, 30, 896])

In [6]:
model_output_1.hidden_states[-1].reshape(10, 3, -1).shape

torch.Size([10, 3, 896])

In [7]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_2.hidden_states[-1].shape

torch.Size([2, 30, 896])

In [8]:
model_output_3 = model(
    inputs_embeds=model_output_1.hidden_states[-1].reshape(10, 3, -1),
    # attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_3.hidden_states[-1].shape

torch.Size([10, 3, 896])

In [6]:
model_output_1.hidden_states[-1].reshape(1, 10, 3, -1).shape

torch.Size([1, 10, 3, 896])

In [23]:
# model(
#     inputs_embeds=model_output_1.hidden_states[-1].reshape(1, 10, 3, -1),
#     attention_mask=model_inputs["attention_mask"],
#     output_hidden_states=True,
# )

In [36]:
class Qwen2ModelEmbedPooler(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config).cuda()
        self.post_init()

    def forward(self, input_embeds, attention_mask, window_size=3):
        # разбиваем входящие эмбединги на бакеты и усредняем их
        sum_mask = attention_mask.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        embeds_sum = input_embeds.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        input_embeds = embeds_sum / sum_mask
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )
        return input_embeds


embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-0.5B")
result = embed_pooler(
    # model_output_1.hidden_states[-1],
    torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    torch.cat(
        [
            model_inputs["attention_mask"],
            model_inputs["attention_mask"],
        ],
        dim=0,
    ),
)
result[0].shape

torch.Size([2, 10, 896])

In [15]:
result.last_hidden_state.shape

torch.Size([2, 10, 896])

In [9]:
torch.tensor([1, 2, 3, 4, 5, 6]).reshape(3, -1)

tensor([[1, 2],
        [3, 4],
        [5, 6]])

In [None]:
from typing import Callable, List, Optional, Tuple, Union
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    SlidingWindowCache,
    StaticCache,
)
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import LossKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2ForCausalEmbedModeling(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )
        self.embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-1.5B")

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        # pooled_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds_tokens = self.model.embed_tokens(input_ids)
        # if pixel_values is not None:
        #     pixel_values = pixel_values.type(self.visual.get_dtype())
        #     image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        #     n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        #     n_image_features = image_embeds.shape[0]
        #     if n_image_tokens != n_image_features:
        #         raise ValueError(
        #             f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        #         )
        #     image_mask = (
        #         (input_ids == self.config.image_token_id)
        #         .unsqueeze(-1)
        #         .expand_as(inputs_embeds)
        #         .to(inputs_embeds.device)
        #     )
        #     image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        #     inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        window_size = torch.tensor(
            [
                [4],
                [3],
            ]
        ).cuda()
        tokens_amount = torch.tensor(
            [
                [2],
                [4],
            ]
        ).cuda()
        lengths = window_size * tokens_amount
        # max_len = lengths.max()
        max_len = inputs_embeds_tokens.shape[1]
        batch_size = window_size.shape[0]
        pooled_mask = (
            torch.arange(max_len, device=device)
            .unsqueeze(0)
            .expand(batch_size, max_len)
        )
        pooled_mask = (lengths >= pooled_mask).to(torch.long)
        pooled_embeds = inputs_embeds * pooled_mask.to(inputs_embeds.dtype)
        pooled_embeds = self.embed_pooler(pooled_embeds, pooled_mask)
        embed_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        inputs_embeds = inputs_embeds.masked_scatter(embed_mask, pooled_embeds)

        # Из-за смешанной структуры, мы будем всегда подавать только эмбединги
        # Идея позаимствована из qwen2vl
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

## Использование концепции qwen2vl для текстовых токенов и эмбедингов

In [None]:
text_example = "Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?"

tokenizer.encode(text_example)[:10]

[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11]

In [None]:
tokenizer.encode("<|image_pad|>")

[151655]

In [None]:
tokenizer.encode("text_example")

### 1. Случай первый - мы сжимаем только текстовые токены

In [29]:
window_size = 3
tokens_amount = 4
new_tokens = []
original_tokens = tokenizer.encode(text_example)
new_tokens += tokenizer.encode("<|image_pad|>") * tokens_amount
new_tokens += original_tokens[tokens_amount * window_size :]
tokenizer.decode(new_tokens)

'<|image_pad|><|image_pad|><|image_pad|><|image_pad|>ized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?'

In [None]:
original_tokens_torch = torch.tensor(
    [
        original_tokens,
        original_tokens,
    ],
    device="cuda",
)
new_tokens_torch = torch.tensor(
    [
        new_tokens,
        new_tokens,
    ],
    device="cuda",
)
# torch.Size([2, 51, 896])
original_embeds = model.get_input_embeddings()(original_tokens_torch)
# torch.Size([2, 43, 896]) 51 - 3*4 + 4
compressed_embeds_template = model.get_input_embeddings()(new_tokens_torch)
compressed_embeds_template.shape

torch.Size([2, 43, 896])

In [None]:
class Qwen2ModelEmbedPoolerV2(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config).cuda()
        self.post_init()

    def forward(self, input_embeds):
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )[0]
        input_embeds = input_embeds.sum(1) / torch.tensor(
            input_embeds.shape[1],
            device=input_embeds.device,
        )
        input_embeds = input_embeds.unsqueeze(1)
        return input_embeds


embed_pooler_v2 = Qwen2ModelEmbedPoolerV2.from_pretrained("Qwen/Qwen2.5-0.5B")

# torch.Size([4, 3, 896])
compressed_embeds = original_embeds[:, : tokens_amount * window_size].reshape(
    tokens_amount * original_embeds.shape[0],
    window_size,
    -1,
)
compressed_embeds.shape
# torch.Size([8, 1, 896])
pooled_embeds = embed_pooler_v2(compressed_embeds)
# torch.Size([2, 4, 896])
pooled_embeds = pooled_embeds.reshape(
    original_embeds.shape[0],
    tokens_amount,
    -1,
)
# torch.Size([2, 43, 896])
compressed_embeds_template[:, :tokens_amount] = pooled_embeds
compressed_embeds_template.shape

torch.Size([2, 43, 896])

In [67]:
labels = new_tokens_torch.clone()
image_token_id = tokenizer.encode("<|image_pad|>")[0]
labels[labels == image_token_id] = -100
labels

tensor([[ -100,  -100,  -100,  -100,  1506,   279, 18414,   315, 20443, 11229,
           304,  3213,  1635,    13,  5542,   656, 59711,  9331,   311, 34549,
         15712,    11,  1181,  8357,   525, 10454, 14756, 70767,    13,  1988,
          1128,  6896,   374,  5538,  6832,    30,  1597,  1128,  3643,   432,
           773,  7988,    30],
        [ -100,  -100,  -100,  -100,  1506,   279, 18414,   315, 20443, 11229,
           304,  3213,  1635,    13,  5542,   656, 59711,  9331,   311, 34549,
         15712,    11,  1181,  8357,   525, 10454, 14756, 70767,    13,  1988,
          1128,  6896,   374,  5538,  6832,    30,  1597,  1128,  3643,   432,
           773,  7988,    30]], device='cuda:0')

In [None]:
model(inputs_embeds=compressed_embeds_template, labels=labels)

CausalLMOutputWithPast(loss=tensor(8.7477, device='cuda:0'), logits=tensor([[[ 8.4375,  9.6875,  4.1875,  ..., -5.3438, -5.3438, -5.3438],
         [14.1250, 14.8750, 11.6875,  ..., -6.8438, -6.8438, -6.8438],
         [13.6250, 13.9375,  9.5000,  ..., -4.8125, -4.8125, -4.8125],
         ...,
         [11.3750,  3.8125,  2.5625,  ..., -2.4219, -2.4219, -2.4219],
         [11.0625,  2.3281,  1.5625,  ..., -2.6875, -2.6875, -2.6875],
         [ 5.7812, -0.7461, -0.9531,  ..., -3.8125, -3.8125, -3.8125]],

        [[ 8.4375,  9.6875,  4.1875,  ..., -5.3438, -5.3438, -5.3438],
         [14.1250, 14.8750, 11.6875,  ..., -6.8438, -6.8438, -6.8438],
         [13.6250, 13.9375,  9.5000,  ..., -4.8125, -4.8125, -4.8125],
         ...,
         [11.3750,  3.8125,  2.5625,  ..., -2.4219, -2.4219, -2.4219],
         [11.0625,  2.3281,  1.5625,  ..., -2.6875, -2.6875, -2.6875],
         [ 5.7812, -0.7461, -0.9531,  ..., -3.8125, -3.8125, -3.8125]]],
       device='cuda:0', dtype=torch.bfloat16), p