In [1]:
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer
import torch

torch.set_grad_enabled(False)
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = Qwen2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = model.eval()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [2]:
model.get_input_embeddings().weight.dtype, model.dtype

(torch.bfloat16, torch.bfloat16)

In [2]:
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Give me a short introduction to large language model..1"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'A large language model is a type of artificial intelligence that can generate human-like text, such as written or spoken words, based on input data. These models are typically trained using large amounts of data and have been shown to perform well in a variety of tasks, including generating new content, answering questions, and performing natural language processing tasks.'

In [4]:
model_output_1 = model(
    **model_inputs,
    output_hidden_states=True,
)
model_output_1.hidden_states[-1].shape, model_output_1.hidden_states[-1].dtype

(torch.Size([1, 30, 896]), torch.bfloat16)

In [5]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_2.hidden_states[-1].shape

torch.Size([2, 30, 896])

In [8]:
model_output_1.hidden_states[-1].shape

torch.Size([1, 30, 896])

In [9]:
model_output_1.hidden_states[-1].reshape(10, 3, -1).shape

torch.Size([10, 3, 896])

In [6]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_2.hidden_states[-1].shape, model_output_2.hidden_states[-1].dtype

(torch.Size([2, 30, 896]), torch.bfloat16)

In [7]:
model_output_3 = model(
    inputs_embeds=model_output_1.hidden_states[-1].reshape(10, 3, -1),
    # attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_3.hidden_states[-1].shape

torch.Size([10, 3, 896])

In [13]:
model_output_1.hidden_states[-1].reshape(1, 10, 3, -1).shape

torch.Size([1, 10, 3, 896])

In [12]:
# model(
#     inputs_embeds=model_output_1.hidden_states[-1].reshape(1, 10, 3, -1),
#     attention_mask=model_inputs["attention_mask"],
#     output_hidden_states=True,
# )

In [8]:
class Qwen2ModelEmbedPooler(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config).cuda()
        self.post_init()

    def forward(self, input_embeds, attention_mask, window_size=3):
        # разбиваем входящие эмбединги на бакеты и усредняем их
        sum_mask = attention_mask.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        embeds_sum = input_embeds.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        input_embeds = embeds_sum / sum_mask
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )
        return input_embeds


embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-0.5B")
result = embed_pooler(
    # model_output_1.hidden_states[-1],
    torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    torch.cat(
        [
            model_inputs["attention_mask"],
            model_inputs["attention_mask"],
        ],
        dim=0,
    ),
)
result[0].shape

torch.Size([2, 10, 896])

In [15]:
result.last_hidden_state.shape

torch.Size([2, 10, 896])

In [9]:
torch.tensor([1, 2, 3, 4, 5, 6]).reshape(3, -1)

tensor([[1, 2],
        [3, 4],
        [5, 6]])

In [None]:
from typing import Callable, List, Optional, Tuple, Union
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    SlidingWindowCache,
    StaticCache,
)
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import LossKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2ForCausalEmbedModeling(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )
        self.embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-1.5B")

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        # pooled_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds_tokens = self.model.embed_tokens(input_ids)
        # if pixel_values is not None:
        #     pixel_values = pixel_values.type(self.visual.get_dtype())
        #     image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        #     n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        #     n_image_features = image_embeds.shape[0]
        #     if n_image_tokens != n_image_features:
        #         raise ValueError(
        #             f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        #         )
        #     image_mask = (
        #         (input_ids == self.config.image_token_id)
        #         .unsqueeze(-1)
        #         .expand_as(inputs_embeds)
        #         .to(inputs_embeds.device)
        #     )
        #     image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        #     inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        window_size = torch.tensor(
            [
                [4],
                [3],
            ]
        ).cuda()
        tokens_amount = torch.tensor(
            [
                [2],
                [4],
            ]
        ).cuda()
        lengths = window_size * tokens_amount
        # max_len = lengths.max()
        max_len = inputs_embeds_tokens.shape[1]
        batch_size = window_size.shape[0]
        pooled_mask = (
            torch.arange(max_len, device=device)
            .unsqueeze(0)
            .expand(batch_size, max_len)
        )
        pooled_mask = (lengths >= pooled_mask).to(torch.long)
        pooled_embeds = inputs_embeds * pooled_mask.to(inputs_embeds.dtype)
        pooled_embeds = self.embed_pooler(pooled_embeds, pooled_mask)
        embed_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        inputs_embeds = inputs_embeds.masked_scatter(embed_mask, pooled_embeds)

        # Из-за смешанной структуры, мы будем всегда подавать только эмбединги
        # Идея позаимствована из qwen2vl
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

## Использование концепции qwen2vl для текстовых токенов и эмбедингов

In [9]:
text_example = "Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?"

tokenizer.encode(text_example)[:10]

[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11]

In [10]:
tokenizer.encode("<|image_pad|>")

[151655]

In [11]:
tokenizer.encode("text_example")

[1318, 39304]

### 1. Случай первый - мы сжимаем только текстовые токены (максимально упрощенный вариант)

In [12]:
window_size = 3
chunk_size = 4
new_tokens = []
original_tokens = tokenizer.encode(text_example)
new_tokens += tokenizer.encode("<|object_ref_start|>") * chunk_size
new_tokens += original_tokens[chunk_size * window_size :]
tokenizer.decode(new_tokens)

'<|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|>ized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?'

In [13]:
tokenizer.decode(original_tokens)

'Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?'

In [14]:
original_tokens_torch = torch.tensor(
    [
        original_tokens,
        original_tokens,
    ],
    device="cuda",
)
new_tokens_torch = torch.tensor(
    [
        new_tokens,
        new_tokens,
    ],
    device="cuda",
)
# torch.Size([2, 51, 896])
original_embeds = model.get_input_embeddings()(original_tokens_torch)
# torch.Size([2, 43, 896]) 51 - 3*4 + 4
compressed_embeds_template = model.get_input_embeddings()(new_tokens_torch)
compressed_embeds_template.shape, compressed_embeds_template.dtype, original_embeds.dtype

(torch.Size([2, 43, 896]), torch.bfloat16, torch.bfloat16)

In [15]:
class Qwen2ModelEmbedPoolerV2(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.post_init()

    def forward(self, input_embeds):
        # print(input_embeds.dtype)
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )[0]
        # print(input_embeds.dtype)
        input_embeds = input_embeds.sum(1) / torch.tensor(
            input_embeds.shape[1],
            device=input_embeds.device,
        )
        # print(input_embeds.dtype)
        input_embeds = input_embeds.unsqueeze(1)
        return input_embeds


embed_pooler_v2 = Qwen2ModelEmbedPoolerV2.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

# torch.Size([8, 3, 896])
compressed_embeds = original_embeds[:, : chunk_size * window_size].reshape(
    chunk_size * original_embeds.shape[0],
    window_size,
    -1,
)
compressed_embeds.shape
# torch.Size([8, 1, 896])
pooled_embeds = embed_pooler_v2(compressed_embeds)
# torch.Size([2, 4, 896])
pooled_embeds = pooled_embeds.reshape(
    original_embeds.shape[0],
    chunk_size,
    -1,
)
# torch.Size([2, 43, 896])
compressed_embeds_template[:, :chunk_size] = pooled_embeds
compressed_embeds_template.shape, compressed_embeds_template.dtype, pooled_embeds.dtype

(torch.Size([2, 43, 896]), torch.bfloat16, torch.bfloat16)

In [16]:
embed_pooler_v2.model.dtype

torch.bfloat16

In [17]:
labels = new_tokens_torch.clone()
text_token_id = tokenizer.encode("<|object_ref_start|>")[0]
labels[labels == text_token_id] = -100
labels

tensor([[ -100,  -100,  -100,  -100,  1506,   279, 18414,   315, 20443, 11229,
           304,  3213,  1635,    13,  5542,   656, 59711,  9331,   311, 34549,
         15712,    11,  1181,  8357,   525, 10454, 14756, 70767,    13,  1988,
          1128,  6896,   374,  5538,  6832,    30,  1597,  1128,  3643,   432,
           773,  7988,    30],
        [ -100,  -100,  -100,  -100,  1506,   279, 18414,   315, 20443, 11229,
           304,  3213,  1635,    13,  5542,   656, 59711,  9331,   311, 34549,
         15712,    11,  1181,  8357,   525, 10454, 14756, 70767,    13,  1988,
          1128,  6896,   374,  5538,  6832,    30,  1597,  1128,  3643,   432,
           773,  7988,    30]], device='cuda:0')

In [18]:
model_output_1 = model(
    inputs_embeds=compressed_embeds_template,
    labels=labels,
    output_hidden_states=True,
)
model_output_1.loss

tensor(8.7288, device='cuda:0')

### 2. Усложненная версия токенизации текста. Сжатые токены возникают между обычным тектом

Входная последовательность может быть 5 видов.

- когда у нас на входе только текст, который мы просто моделируем 
- когда на входе текст, но мы хотим сжать его некоторые части
- когда на входе текст и эмбединги с последнего слоя, где мы хотим сжать только части текста
- когда на входе текст и эмбединги с последнего слоя, где мы хотим сжать, часть текста и эмбедингов, совместно (токены переводим в эмбединги и сжимаем вместе с hidden states)
- когда на входе текст и эмбединги с последнего слоя, где мы хотим сжать, только эмбединги

Как мы решаем какие токены хотим сжать?
- никак. Просто говорим сожми с такого по такой.
- это норм только на этапе обучения, так как мы можем перебрать все комбинации.
- не норм на этапе инференса. например мы можем сжимать после каждых сгенеренных 10 токенов или 100, 1000. а какое окно контекста? 3, 10, 100? А что если менять стратегию. Сначала мы сжимали с окном 5 токенов, потом 20?
- Кажется что эти гиперпараметры можно найти простым перебором на валидации. Однако перебор стратегии уже не кажется таким очевидным. Напрашивается RL.


In [19]:
text_example = [
    "Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful? ",
    """Before diving into the "deep" part, let's establish a foundation with the basics of artificial neural networks (ANNs). An ANN consists of interconnected nodes, called neurons, organized in layers. These layers typically include: These layers typically include: These layers typically include: networks (ANNs). An ANN consists of interconnected nodes, called neurons, organized in layers. These layers typically include: These layers typically include: These layers typically include: networks (ANNs). An ANN consists of interconnected nodes, called neurons, organized in layers. These layers typically include: These layers typically include: These layers typically include:""",
]

tokenizer.batch_encode_plus(
    text_example,
    padding=True,
)

{'input_ids': [[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11, 702, 13791, 1506, 279, 18414, 315, 20443, 11229, 304, 3213, 1635, 13, 5542, 656, 59711, 9331, 311, 34549, 15712, 11, 1181, 8357, 525, 10454, 14756, 70767, 13, 1988, 1128, 6896, 374, 5538, 6832, 30, 1597, 1128, 3643, 432, 773, 7988, 30, 220, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], [10227, 42415, 1119, 279, 330, 32880, 1, 949, 11, 1077, 594, 5695, 264, 16266, 448, 279, 31774, 315, 20443, 29728, 14155, 320, 

In [20]:
torch.tensor(
    tokenizer.batch_encode_plus(
        text_example,
        padding=True,
    )["input_ids"]
).shape

torch.Size([2, 122])

In [21]:
tokenizer.decode(151643)

'<|endoftext|>'

In [None]:
tokenizer

In [22]:
from pprint import pprint
from more_itertools import chunked
import numpy as np
import random

text_token_id = tokenizer.encode("<|object_ref_start|>")[0]
eos_token_id = tokenizer.encode("<|endoftext|>")[0]

window_size = 3
chunk_size = 4
new_tokens = []
original_tokens = tokenizer.batch_encode_plus(
    text_example,
    padding=True,
)
compressed_tokens = []
replaced_original_tokens_batch = []
# original_tokens
for tokens in original_tokens["input_ids"]:
    original_lines = np.array(tokens)
    pure_tokens = original_lines[original_lines != eos_token_id].tolist()
    print(pure_tokens)
    full_chunks_amount = len(pure_tokens) // (window_size * chunk_size)
    print(full_chunks_amount)
    max_percent = 0.8
    pure_tokens_chunks = list(chunked(pure_tokens, window_size * chunk_size))
    print(pure_tokens_chunks)
    prob = 0.3
    random_mask = np.random.random(int(full_chunks_amount * max_percent))
    mask = random_mask < prob
    chunks_for_tokenization = np.where(mask)[0].tolist()
    chunks_for_tokenization = set(chunks_for_tokenization)
    if len(chunks_for_tokenization) == 0:
        chunks_for_tokenization = set(
            [
                random.randint(
                    0,
                    int(full_chunks_amount * max_percent),
                )
            ]
        )
    print("chunks_for_tokenization ", chunks_for_tokenization)
    replaced_original_tokens = []
    new_input_tokens = []
    for i, tokens in enumerate(pure_tokens_chunks):
        if i in chunks_for_tokenization:
            replaced_original_tokens.extend([text_token_id] * len(tokens))
            new_input_tokens.extend([text_token_id] * chunk_size)
        else:
            replaced_original_tokens.extend(tokens)
            new_input_tokens.extend(tokens)
    print("=== replaced_original_tokens")
    print(tokenizer.decode(replaced_original_tokens))
    print("=== new_input_tokens")
    print(tokenizer.decode(new_input_tokens))
    print(len(new_input_tokens))
    compressed_tokens.append(new_input_tokens)
    replaced_original_tokens_batch.append(replaced_original_tokens)
    print("==")
    print("==")
    print("==")

compressed_tokens_attention = []
max_compressed_len = max([len(item) for item in compressed_tokens])
max_replaced_len = max([len(item) for item in replaced_original_tokens_batch])

for compressed_seq, replaced_seq in zip(
    compressed_tokens,
    replaced_original_tokens_batch,
):
    compressed_seq_len = len(compressed_seq)
    replaced_seq_len = len(replaced_seq)
    attention_mask = [1] * (compressed_seq_len)

    if compressed_seq_len < max_compressed_len:
        compressed_seq += [eos_token_id] * (max_compressed_len - compressed_seq_len)
        attention_mask += [0] * (max_compressed_len - compressed_seq_len)

    if compressed_seq_len < max_replaced_len:
        replaced_seq += [eos_token_id] * (max_replaced_len - replaced_seq_len)

    compressed_tokens_attention.append(attention_mask)
# len(compressed_tokens[0]), len(compressed_tokens[1])
# compressed_tokens = torch.tensor(compressed_tokens)
# compressed_tokens_attention = torch.tensor(compressed_tokens_attention)
np.array(compressed_tokens).shape, np.array(
    replaced_original_tokens_batch
).shape, np.array(original_tokens["input_ids"]).shape

[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11, 702, 13791, 1506, 279, 18414, 315, 20443, 11229, 304, 3213, 1635, 13, 5542, 656, 59711, 9331, 311, 34549, 15712, 11, 1181, 8357, 525, 10454, 14756, 70767, 13, 1988, 1128, 6896, 374, 5538, 6832, 30, 1597, 1128, 3643, 432, 773, 7988, 30, 220]
4
[[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11, 702, 13791], [1506, 279, 18414, 315, 20443, 11229, 304, 3213, 1635, 13, 5542, 656], [59711, 9331, 311, 34549, 15712, 11, 1181, 8357, 525, 10454, 14756, 70767], [13, 1988, 1128, 6896, 374, 5538, 6832, 30, 1597, 1128, 3643, 432], [773, 7988, 30, 220]]
chunks_for_tokenization  {2}
=== replaced_original_tokens
Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self<|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|o

((2, 114), (2, 122), (2, 122))

In [23]:
original_tokens_torch = torch.tensor(
    original_tokens["input_ids"],
    device="cuda",
)
replaced_tokens_torch = torch.tensor(
    replaced_original_tokens_batch,
    device="cuda",
)
compressed_tokens_torch = torch.tensor(
    compressed_tokens,
    device="cuda",
)
# torch.Size([2, 51, 896])
original_embeds = model.get_input_embeddings()(original_tokens_torch)
replaced_embeds = model.get_input_embeddings()(replaced_tokens_torch)
# torch.Size([2, 35, 896]) 51 - 3*4*2 + 4*2
compressed_embeds_template = model.get_input_embeddings()(compressed_tokens_torch)
compressed_embeds_template.shape, original_embeds.shape, replaced_embeds.shape

(torch.Size([2, 114, 896]),
 torch.Size([2, 122, 896]),
 torch.Size([2, 122, 896]))

In [24]:
compressed_embeds_template.dtype, original_embeds.dtype, replaced_embeds.dtype

(torch.bfloat16, torch.bfloat16, torch.bfloat16)

In [33]:
text_token_id

151646

In [25]:
tokens_for_compression_mask = replaced_tokens_torch == text_token_id
compressed_tokens_mask = compressed_tokens_torch == text_token_id
# original_tokens_torch[tokens_for_compression_mask].shape
original_embeds[tokens_for_compression_mask].shape, compressed_tokens_torch[
    compressed_tokens_mask
].shape

(torch.Size([24, 896]), torch.Size([8]))

In [26]:
embeds_for_compression = original_embeds[tokens_for_compression_mask].reshape(
    -1,
    window_size,
    original_embeds.shape[-1],
)
embeds_for_compression.shape

torch.Size([8, 3, 896])

In [27]:
pooled_embeds = embed_pooler_v2(embeds_for_compression)
# pooled_embeds = pooled_embeds.reshape(pooled_embeds.shape[0], -1)
pooled_embeds.shape, pooled_embeds.dtype

(torch.Size([8, 1, 896]), torch.bfloat16)

In [28]:
import torch

self = torch.tensor(
    [
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ]
)
mask = torch.tensor(
    [
        [0, 0, 0, 0, 1],
        [1, 1, 0, 1, 1],
    ],
    dtype=torch.bool,
)
source = torch.tensor(
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
)
self.masked_scatter_(mask, source)

tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])

In [52]:
compressed_embeds_template.shape

torch.Size([2, 98, 896])

In [74]:
compressed_tokens_mask.shape

torch.Size([2, 98])

In [75]:
compressed_embeds_template.shape

torch.Size([2, 98, 896])

In [76]:
compressed_tokens_mask.shape
# .expand_as(compressed_embeds_template).shape

torch.Size([2, 98])

In [29]:
compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template).shape

torch.Size([2, 114, 896])

In [83]:
compressed_embeds_template.dtype, pooled_embeds.dtype

(torch.bfloat16, torch.float32)

In [30]:
compressed_embeds_template.masked_scatter(
    compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
    pooled_embeds,
).shape

torch.Size([2, 114, 896])

In [31]:
compressed_embeds_template = compressed_embeds_template.masked_scatter_(
    compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
    pooled_embeds,
)

In [None]:
compressed_embeds_template[compressed_tokens_mask][0]

In [None]:
compressed_embeds_template[compressed_tokens_mask][0]

In [32]:
labels = compressed_tokens_torch.clone()
text_token_id = tokenizer.encode("<|object_ref_start|>")[0]
labels[labels == text_token_id] = -100
# labels

In [33]:
model_output_1 = model(
    inputs_embeds=compressed_embeds_template,
    labels=labels,
    output_hidden_states=True,
)
model_output_1.loss

tensor(5.0430, device='cuda:0')

### R1 Chat template

In [3]:
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer
import torch

torch.set_grad_enabled(False)
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = Qwen2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = model.eval()

In [4]:
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "how many wings has a bird?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=1024)
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


<｜begin▁of▁sentence｜>You are a helpful assistant.<｜User｜>how many wings has a bird?<｜Assistant｜><think>



"Okay, so I need to figure out how many wings a bird has. I'm not entirely sure, but I know that most birds have two wings. I think that's the common case, but maybe there are some exceptions. Let me think... I remember that some birds have more than two wings, like maybe three or four. For example, I think there's a bird called a trident that has three wings. But wait, I'm not sure if that's accurate. Maybe I should break it down.\n\nFirst, I should recall the basic structure of a bird's wings. Wings are part of the bird's airframe, which is the skin and structure that allows the bird to fly. Each wing is made up of several parts, like a root, a chord, and a span. The number of wings can vary depending on the bird's type.\n\nI think most birds have two wings. That makes sense because it's the simplest structure and allows for flight. But I've heard that some birds have more wings, especially those that are more specialized or have certain needs. For example, maybe some birds have thre

In [5]:
print(text)

<｜begin▁of▁sentence｜>You are a helpful assistant.<｜User｜>how many wings has a bird?<｜Assistant｜><think>



In [6]:
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0])

Okay, so I need to figure out how many wings a bird has. I'm not entirely sure, but I know that most birds have two wings. I think that's the common case, but maybe there are some exceptions. Let me think... I remember that some birds have more than two wings, like maybe three or four. For example, I think there's a bird called a trident that has three wings. But wait, I'm not sure if that's accurate. Maybe I should break it down.

First, I should recall the basic structure of a bird's wings. Wings are part of the bird's airframe, which is the skin and structure that allows the bird to fly. Each wing is made up of several parts, like a root, a chord, and a span. The number of wings can vary depending on the bird's type.

I think most birds have two wings. That makes sense because it's the simplest structure and allows for flight. But I've heard that some birds have more wings, especially those that are more specialized or have certain needs. For example, maybe some birds have three win

In [71]:
tokenizer

LlamaTokenizerFast(name_or_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', vocab_size=151643, model_max_length=16384, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<｜begin▁of▁sentence｜>', 'eos_token': '<｜end▁of▁sentence｜>', 'pad_token': '<｜end▁of▁sentence｜>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	151643: AddedToken("<｜end▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151644: AddedToken("<｜User｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151645: AddedToken("<｜Assistant｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151646: AddedToken("<｜begin▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151647: AddedToken("<|EOT|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151648: AddedToken("<think>", rstrip=False

In [None]:
tokenizer.encode("<|fim_pad|>", add_special_tokens=False)[0]

151662

### Generate Dataset

In [83]:
from huggingface_hub import InferenceClient

client = InferenceClient(f"http://{open('ip').read()}:1338")

user_prompt = "9.11 and 9.9 -- which is bigger? Let's think step by step."
output = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "user",
            "content": user_prompt,
        },
    ],
    stream=False,
    max_tokens=10000,
    temperature=0.0,
)
result = output.choices[0].message.content
print("total_len", len(tokenizer.encode(result)))
print(result)

total_len 437
First, I need to compare the two numbers: 9.11 and 9.9.

To make the comparison easier, I'll align their decimal places by writing 9.9 as 9.90.

Now, I'll compare each corresponding digit from left to right.

Both numbers have 9 in the units place, so they are equal there.

Next, I'll look at the tenths place. In 9.11, the tenths digit is 1, while in 9.90, it's 9.

Since 9 is greater than 1, 9.90 is larger than 9.11.

Therefore, 9.9 is bigger than 9.11.
</think>

**Solution:**

To determine which number is larger between **9.11** and **9.9**, follow these steps:

1. **Align the Decimal Places:**
   
   To make the comparison easier, write both numbers with the same number of decimal places:
   
   \[
   9.11 \quad \text{and} \quad 9.90
   \]

2. **Compare the Numbers Digit by Digit:**
   
   - **Units Place:**
     
     Both numbers have **9** in the units place.
     
     \[
     9 \quad \text{vs} \quad 9
     \]
     
     They are equal in this place.

   - **Tenths 

In [None]:
import concurrent

from huggingface_hub import InferenceClient

client = InferenceClient(f"http://{open('ip').read()}:1338")


def gen_text(text):
    chat_completion = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": text}],
        temperature=0.0,
        max_tokens=5_000,
    )
    return chat_completion.choices[0].message.content


batch_size = 4
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


def batch_generation(prompts):
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(prompts)) as executor:
        prompts_results = list(executor.map(gen_text, prompts))
    return prompts_results


prompts_results = batch_generation(prompts)

for prompt, prompts_result in zip(prompts, prompts_results):
    print(f"Prompt: {prompt!r}, Generated text: {prompts_result!r}")

Prompt: 'Hello, my name is', Generated text: 'Okay, so I\'m trying to figure out what the user named is. They just said "Hello, my name is" and then stopped. Hmm, that\'s a bit confusing. Maybe they forgot to finish the sentence. I should probably ask them to complete it so I can help better. I don\'t want to assume anything about their identity or name. It\'s better to be safe than sorry. I\'ll let them know I\'m here to help once they provide their full name.\n</think>\n\nIt seems like your name is incomplete. Could you please provide your full name so I can assist you better?'
Prompt: 'The president of the United States is', Generated text: "Okay, so I need to figure out what the president of the United States is. Hmm, I'm not exactly sure, but I think the president is the head of the executive branch of the United States. I remember hearing that they're usually named after a person, like George W. Bush or Barack Obama. Maybe I should look up the current president to get the most ac

In [36]:
from datasets import load_dataset

dataset = load_dataset("Open-Orca/OpenOrca")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=10_00, seed=42)
dataset = dataset["test"]
dataset

Dataset({
    features: ['id', 'system_prompt', 'question', 'response'],
    num_rows: 1000
})

In [None]:
from more_itertools import chunked
from tqdm.notebook import tqdm

batch_size = 32 * 2 * 2
questions = list(
    chunked(
        dataset["question"],
        batch_size,
    )
)
# 5 min 26 sec - 1000
correct_qa_pairs = []
for question_chunk in tqdm(questions):
    answers = batch_generation(question_chunk)
    for question, answer in zip(question_chunk, answers):
        if answer.count("</think>") == 1:
            correct_qa_pairs.append(
                [
                    question,
                    answer,
                ]
            )

  0%|          | 0/8 [00:00<?, ?it/s]

In [95]:
len(correct_qa_pairs), len(dataset["question"])

(905, 1000)

In [96]:
new_dataset = []
for question, answer in correct_qa_pairs:
    new_dataset.append(
        {
            "question": question,
            "answer": answer,
        }
    )
new_dataset[0]

{'question': 'NEW: Peterson to media on handcuffs, chains: "I got the bling. Can\'t complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio\'s death came after Peterson\'s fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.\n\nWrite an article based on these highlights.',
 'answer': "Okay, so I need to write an article based on the highlights provided about Drew Peterson and his fourth wife, Kathleen Savio. The user has given me some specific points to include, so I should make sure to cover all of them.\n\nFirst, I should start by introducing Drew Peterson and his fourth wife, Kathleen Savio. I need to mention that he's been arrested in the deaths of his first and fourth wives. That's a significant event because it's the second time he's been involved in such a death, which could be a big deal.\n\nI should also note that the renewed interest in Savio's death came after his fourth wife disappe

In [37]:
from datasets import Dataset

new_dataset = Dataset.from_list(new_dataset)

NameError: name 'new_dataset' is not defined

In [None]:
new_dataset.push_to_hub("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B/commit/82cf73531b7b8daaee327e4813aa774abbc0fcc4', commit_message='Upload dataset', commit_description='', oid='82cf73531b7b8daaee327e4813aa774abbc0fcc4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B', endpoint='https://huggingface.co', repo_type='dataset', repo_id='dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B'), pr_revision=None, pr_num=None)

### Compress thinking

In [7]:
from datasets import load_dataset

dataset = load_dataset("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 905
})

In [39]:
think_end_id = tokenizer.encode("</think>")[0]
think_end_id

151646

In [8]:
messages = [
    # {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": dataset[0]["question"]},
    # {"role": "assistant", "content": result},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# print(text)
text

'<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can\'t complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio\'s death came after Peterson\'s fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.\n\nWrite an article based on these highlights.<｜Assistant｜><think>\n'

In [9]:
tokenizer.decode(tokenizer.encode(text, add_special_tokens=False))

'<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can\'t complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio\'s death came after Peterson\'s fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.\n\nWrite an article based on these highlights.<｜Assistant｜><think>\n'

In [10]:
tokenizer.encode("<｜begin▁of▁sentence｜><｜User｜>", add_special_tokens=False)

[151646, 151644]

In [11]:
# https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L150
messages = [
    {"role": "user", "content": dataset[0]["question"]},
    # {"role": "assistant", "content": dataset[0]["answer"]},
]
part_1 = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    # return_dict=True,
)
part_2 = tokenizer.encode(
    dataset[0]["answer"],
    add_special_tokens=False,
)
part_3 = tokenizer.encode(
    "<｜end▁of▁sentence｜>",
    add_special_tokens=False,
)
print(tokenizer.decode(part_1 + part_2 + part_3))
labels = len(part_1) * [-100] + part_2 + [-100]

<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can't complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio's death came after Peterson's fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.

Write an article based on these highlights.<｜Assistant｜><think>
Okay, so I need to write an article based on the highlights provided about Drew Peterson and his fourth wife, Kathleen Savio. The user has given me some specific points to include, so I should make sure to cover all of them.

First, I should start by introducing Drew Peterson and his fourth wife, Kathleen Savio. I need to mention that he's been arrested in the deaths of his first and fourth wives. That's a significant event because it's the second time he's been involved in such a death, which could be a big deal.

I should also note that the renewed interest in Savio's death came after his fourth 

In [12]:
messages = [
    {"role": "user", "content": dataset[0]["question"]},
    {"role": "assistant", "content": dataset[0]["answer"]},
]

part_3 = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    # add_generation_prompt=False,
    # add_generation_prompt=True,
    # continue_final_message=True,
    # return_dict=True,
    # return_assistant_tokens_mask=True
)
part_3
print(tokenizer.decode(part_3))

<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can't complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio's death came after Peterson's fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.

Write an article based on these highlights.<｜Assistant｜>

**Drew Peterson's Fourth Wife: A Legal Journey and Renewed Interest**

Drew Peterson, a controversial figure known for his controversial past, has recently come to light as the subject of legal proceedings involving his fourth wife, Kathleen Savio. Peterson, who has been arrested in the deaths of his first and fourth wives, has faced charges related to these tragic incidents. His attorney has denied any wrongdoing in either case, presenting a positive narrative for the legal community.

**Introduction to Drew Peterson and His Fourth Wife**

Drew Peterson, a controversial figure, has been the subject of l

### manual chat templating (single turn)

In [13]:
tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "question"},
        {"role": "assistant", "content": "answer"},
    ],
    tokenize=False,
)

'<｜begin▁of▁sentence｜><｜User｜>question<｜Assistant｜>answer<｜end▁of▁sentence｜>'

In [14]:
tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "question"},
    ],
    tokenize=False,
    add_generation_prompt=True,
)

'<｜begin▁of▁sentence｜><｜User｜>question<｜Assistant｜><think>\n'

In [15]:
content_compression_mask = []
part_1 = """<｜begin▁of▁sentence｜><｜User｜>"""
content_compression_mask += len(
    tokenizer.encode(
        part_1,
        add_special_tokens=False,
    )
) * [0]
print(content_compression_mask)
part_2 = 'NEW: Peterson to media on handcuffs, chains: "I got the bling. Can'
content_compression_mask += len(
    tokenizer.encode(
        part_2,
        add_special_tokens=False,
    )
) * [1]
print(content_compression_mask)
part_3 = "<｜Assistant｜><think>\n"
content_compression_mask += len(
    tokenizer.encode(
        part_3,
        add_special_tokens=False,
    )
) * [0]
print(content_compression_mask)
part_4 = "Peterson attorney that</think>Peterson is facing"
content_compression_mask += len(
    tokenizer.encode(
        part_4,
        add_special_tokens=False,
    )
) * [2]
print(content_compression_mask)
content_compression_mask += [0]
print(content_compression_mask)

[0, 0]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0]


In [16]:
def tokenize_single_turn(question, answer):
    content_compression_mask = []

    part_1 = """<｜begin▁of▁sentence｜><｜User｜>"""
    content_compression_mask += len(
        tokenizer.encode(
            part_1,
            add_special_tokens=False,
        )
    ) * [0]

    # question
    part_2 = question
    content_compression_mask += len(
        tokenizer.encode(
            part_2,
            add_special_tokens=False,
        )
    ) * [1]

    part_3 = "<｜Assistant｜><think>\n"
    content_compression_mask += len(
        tokenizer.encode(
            part_3,
            add_special_tokens=False,
        )
    ) * [0]

    # answer
    part_4 = answer
    content_compression_mask += len(
        tokenizer.encode(
            part_4,
            add_special_tokens=False,
        )
    ) * [2]

    part_5 = "<｜end▁of▁sentence｜>"
    content_compression_mask += len(
        tokenizer.encode(
            part_5,
            add_special_tokens=False,
        )
    ) * [0]

    complete_prompt = ""
    for part in [part_1, part_2, part_3, part_4, part_5]:
        complete_prompt += part
    original_tokens = tokenizer.encode(
        complete_prompt,
        add_special_tokens=False,
    )
    attention_mask = len(original_tokens) * [1]
    return {
        "input_ids": original_tokens,
        "attention_mask": attention_mask,
        "content_compression_mask": content_compression_mask,
    }


example = tokenize_single_turn(
    question=dataset[6]["question"],
    answer=dataset[6]["answer"],
)
print(tokenizer.decode(example["input_ids"]))
print(example["content_compression_mask"])

<｜begin▁of▁sentence｜><｜User｜>In this task you will be given an arithmetic operation in Italian and you have to find its answer. The operations 'addition' and 'subtraction' have been replaced with their italian translations i.e you need to perform addition when you see 'aggiunta' and subtraction in case of 'sottrazione'.
Q: 8680 sottrazione 9504 sottrazione 9115 aggiunta 4098
A: <｜Assistant｜><think>
Okay, so I've got this arithmetic problem in Italian, and I need to figure out the answer. The problem is: 8680 sottrazione 9504 sottrazione 9115 aggiunta 4098. Hmm, let me break this down step by step because I'm not entirely sure how the operations are being translated.

First, I know that in English, "sottrazione" means subtraction and "aggiunta" means addition. So, in this problem, whenever I see "sottrazione," I should subtract, and "aggiunta" means I should add. Got it.

Let me rewrite the problem with the correct operations. So, it should be: 8680 minus 9504 minus 9115 plus 4098. Wait

### Compression tokenization

In [17]:
tokenizer.encode("<|fim_pad|>", add_special_tokens=False)

[151662]

- сжимаем последовательности последовательно
- К примеру сначала мы сжимаем первые последовательные 12 токенов, потом первые 24 и тд
- чтобы на инференсе мы могли просто после каждой генерации 12 токенов, сжать ее и генерировать дальше  

In [85]:
from pprint import pprint
from more_itertools import chunked
import numpy as np
import random

text_token_id = tokenizer.encode("<|fim_pad|>", add_special_tokens=False)[0]
eos_token_id = tokenizer.encode("<｜end▁of▁sentence｜>", add_special_tokens=False)[0]

window_size = 4

batch = 2

dataset_batch = [dataset[i] for i in range(batch)]
dataset_batch = [tokenize_single_turn(**item) for item in dataset_batch]
aligned_batch = []
# определяем какие именно элементы мы хотим сжимать,а какие
# не помещаются не помещаются в чанк размером window_size
for tokens in dataset_batch:
    input_ids = np.array(tokens["input_ids"])
    content_mask = np.array(tokens["content_compression_mask"])
    user_part = input_ids[content_mask == 1]
    total_parts = len(user_part) // window_size
    new_len_part_1 = total_parts * window_size
    mask_end_pos = np.where(content_mask == 1)[0][-1]
    # print(content_mask.tolist())
    content_mask[
        mask_end_pos - (len(user_part) - new_len_part_1) + 1 : mask_end_pos + 1
    ] = 0
    # print(content_mask.tolist())
    # print(user_part.shape, total_parts, new_len_part_1)

    answer_part = input_ids[content_mask == 2]
    total_parts = len(answer_part) // window_size
    new_len_part_2 = total_parts * window_size
    mask_end_pos = np.where(content_mask == 2)[0][-1]
    # print(content_mask.tolist())
    content_mask[
        mask_end_pos - (len(answer_part) - new_len_part_2) + 1 : mask_end_pos + 1
    ] = 0
    # print(content_mask.tolist())
    # print(answer_part.shape, total_parts, new_len_part_2)
    # content_mask[content_mask == 2] = 1
    # print(content_mask.tolist())
    aligned_batch.append(
        {
            "input_ids": tokens["input_ids"],
            "content_compression_mask": content_mask,
            "attention_mask": tokens["attention_mask"],
        }
    )
    # break

train_examples = []
train_examples_amount = 2
for tokens in aligned_batch:
    input_ids = np.array(tokens["input_ids"])
    content_mask = np.array(tokens["content_compression_mask"])
    for chunks_amount in range(train_examples_amount):
        # фикусируемся на сжатии ответа модели, выбираем 2
        start_pos = np.where(content_mask == 2)[0][0]
        input_ids[start_pos : start_pos + (chunks_amount + 1) * window_size] = (
            text_token_id
        )
        compressed_input_ids = input_ids[:start_pos].tolist()
        compressed_input_ids += [text_token_id] * (chunks_amount + 1)
        compressed_input_ids += input_ids[
            start_pos + (chunks_amount + 1) * window_size :
        ].tolist()
        # print(text_token_id)
        # print(" ".join(f"{num:>{8}}" for num in content_mask.tolist()))
        # print(" ".join(f"{num:>{8}}" for num in input_ids.tolist()))
        # print(" ".join(f"{num:>{8}}" for num in compressed_input_ids))
        # print("===")
        train_examples.append(
            {
                "replaced_original_tokens": input_ids.tolist(),
                "compressed_input_ids": compressed_input_ids,
                "original_tokens": tokens["input_ids"],
            }
        )
    # break


# pad to the same length
new_inputs = {}
for item in train_examples:
    for key, value in item.items():
        if not key in new_inputs:
            new_inputs[key] = []
        new_inputs[key].append(value)


for key, value in new_inputs.items():
    new_inputs[key] = tokenizer.pad(
        {
            "input_ids": new_inputs[key],
        },
        padding=True,
        # return_tensors="pt",
    )
new_inputs

{'replaced_original_tokens': {'input_ids': [[151646, 151644, 20594, 25, 39791, 311, 3687, 389, 65045, 32568, 11, 26179, 25, 330, 40, 2684, 279, 1501, 287, 13, 2980, 944, 27911, 1, 40108, 39791, 12517, 304, 279, 4545, 315, 806, 4843, 7403, 11, 64063, 20079, 815, 13, 48986, 291, 2734, 304, 20079, 815, 594, 4545, 3697, 1283, 39791, 594, 11737, 7403, 28396, 13, 39791, 11, 1526, 806, 13747, 11, 46491, 894, 64228, 304, 2987, 1142, 382, 7985, 458, 4549, 3118, 389, 1493, 21314, 13, 151645, 151648, 198, 151662, 151662, 151662, 151662, 1184, 311, 3270, 458, 4549, 3118, 389, 279, 21314, 3897, 911, 40108, 39791, 323, 806, 11737, 7403, 11, 64063, 20079, 815, 13, 576, 1196, 702, 2661, 752, 1045, 3151, 3501, 311, 2924, 11, 773, 358, 1265, 1281, 2704, 311, 3421, 678, 315, 1105, 382, 5338, 11, 358, 1265, 1191, 553, 31918, 40108, 39791, 323, 806, 11737, 7403, 11, 64063, 20079, 815, 13, 358, 1184, 311, 6286, 429, 566, 594, 1012, 12517, 304, 279, 16375, 315, 806, 1156, 323, 11737, 38620, 13, 2938, 594, 26

In [90]:
batch_index = 2
print(text_token_id, eos_token_id)
print(
    " ".join(
        f"{num:>{8}}"
        for num in new_inputs["replaced_original_tokens"]["input_ids"][batch_index]
    )
)
print(
    " ".join(
        f"{num:>{8}}"
        for num in new_inputs["replaced_original_tokens"]["attention_mask"][batch_index]
    )
)

151662 151643
  151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151643   151

In [91]:
new_inputs.keys()

dict_keys(['replaced_original_tokens', 'compressed_input_ids', 'original_tokens'])

In [92]:
original_tokens_torch = torch.tensor(
    new_inputs["original_tokens"]["input_ids"],
    device="cuda",
)
replaced_tokens_torch = torch.tensor(
    new_inputs["replaced_original_tokens"]["input_ids"],
    device="cuda",
)
compressed_tokens_torch = torch.tensor(
    new_inputs["compressed_input_ids"]["input_ids"],
    device="cuda",
)
# torch.Size([2, 51, 896])
original_embeds = model.get_input_embeddings()(original_tokens_torch)
replaced_embeds = model.get_input_embeddings()(replaced_tokens_torch)
# torch.Size([2, 35, 896]) 51 - 3*4*2 + 4*2
compressed_embeds_template = model.get_input_embeddings()(compressed_tokens_torch)
compressed_embeds_template.shape, original_embeds.shape, replaced_embeds.shape

(torch.Size([4, 1111, 1536]),
 torch.Size([4, 1114, 1536]),
 torch.Size([4, 1114, 1536]))

In [93]:
tokens_for_compression_mask = replaced_tokens_torch == text_token_id
compressed_tokens_mask = compressed_tokens_torch == text_token_id
# original_tokens_torch[tokens_for_compression_mask].shape
original_embeds[tokens_for_compression_mask].shape, compressed_tokens_torch[
    compressed_tokens_mask
].shape

(torch.Size([24, 1536]), torch.Size([6]))

In [94]:
embeds_for_compression = original_embeds[tokens_for_compression_mask].reshape(
    -1,
    window_size,
    original_embeds.shape[-1],
)
embeds_for_compression.shape

torch.Size([6, 4, 1536])

In [98]:
class Qwen2ModelEmbedPoolerV2(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.post_init()

    def forward(self, input_embeds):
        # print(input_embeds.dtype)
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )[0]
        # print(input_embeds.dtype)
        input_embeds = input_embeds.sum(1) / torch.tensor(
            input_embeds.shape[1],
            device=input_embeds.device,
        )
        # print(input_embeds.dtype)
        input_embeds = input_embeds.unsqueeze(1)
        return input_embeds


embed_pooler_v3 = Qwen2ModelEmbedPoolerV2.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

In [99]:
pooled_embeds = embed_pooler_v3(embeds_for_compression)
# pooled_embeds = pooled_embeds.reshape(pooled_embeds.shape[0], -1)
pooled_embeds.shape, pooled_embeds.dtype

(torch.Size([6, 1, 1536]), torch.bfloat16)

In [None]:
compressed_embeds_template = compressed_embeds_template.masked_scatter_(
    compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
    pooled_embeds,
)

In [100]:
labels = compressed_tokens_torch.clone()
labels[labels == text_token_id] = -100
labels[labels == eos_token_id] = -100
# labels

In [101]:
model_output_1 = model(
    inputs_embeds=compressed_embeds_template,
    labels=labels,
    output_hidden_states=True,
)
model_output_1.loss

tensor(1.5505, device='cuda:0')

#### inference stage

#### Default text generation

In [None]:
generated_tokens = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "2+2*10"},
    ],
    tokenize=True,
    add_generation_prompt=True,
)
generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
max_steps = 400
for _ in range(max_steps):
    logits = model(input_ids=generated_tokens).logits
    top = logits.argmax(-1)[-1][-1]
    # print(top)
    generated_tokens = torch.cat([generated_tokens, top.reshape(1, 1)], dim=1)
print(tokenizer.decode(generated_tokens[-1]))
# break

<｜begin▁of▁sentence｜><｜User｜>2+2*10<｜Assistant｜><think>
First, I need to evaluate the expression 2 + 2 * 10.

According to the order of operations, I should perform the multiplication before the addition.

Calculating 2 * 10 gives me 20.

Then, I add 2 to 20, resulting in 22.

Therefore, the final answer is 22.
</think>

Sure, let's solve the expression step by step:

\[
2 + 2 \times 10
\]

**Step 1: Perform the multiplication**

According to the order of operations (PEMDAS/BODMAS), multiplication comes before addition.

\[
2 \times 10 = 20
\]

**Step 2: Add the result to 2**

\[
2 + 20 = 22
\]

**Final Answer:**

\[
\boxed{22}
\]<｜end▁of▁sentence｜><｜begin▁of▁sentence｜>

Sure, let's solve the expression step by step:

\[
2 + 2 \times 10
\]

**Step 1: Perform the multiplication**

According to the order of operations (PEMDAS/BODMAS), multiplication comes before addition.

\[
2 \times 10 = 20
\]

**Step 2: Add the result to 2**

\[
2 + 20 = 22
\]

**Final Answer:**

\[
\boxed{22}
\]<｜end

#### generation with compression

In [None]:
generated_tokens = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "2+2*10"},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
max_steps = 10
temp_gen_size = 0
compression_started = False
window_size = 4
embeddings_cache = model.get_input_embeddings()(generated_tokens)
for _ in range(max_steps):
    if compression_started:
        if temp_gen_size == window_size:
            print("state 1")
            new_tokens_for_compression = generated_tokens[:, -4:]
            new_embeds_for_compression = model.get_input_embeddings()(
                new_tokens_for_compression
            )
            compressed_part = embed_pooler_v3(new_embeds_for_compression)
            input_embeds = torch.cat(
                [embeddings_cache, compressed_part],
                dim=1,
            )
            logits = model(input_embeds=input_embeds).logits
        break
        pass
    else:
        logits = model(input_embeds=embeddings_cache).logits
        top = logits.argmax(-1)[-1][-1]
        temp_gen_size += 1
        # print(top)
        generated_tokens = torch.cat([generated_tokens, top.reshape(1, 1)], dim=1)
        if temp_gen_size == window_size:
            compression_started = True

print(tokenizer.decode(generated_tokens[-1]))
# break

state 1


ValueError: You must specify exactly one of input_ids or inputs_embeds

In [150]:
embeddings_cache.shape

torch.Size([1, 11, 1536])

In [136]:
generated_tokens = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "2+2*10"},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
generated_tokens[:, :-4]

tensor([[151646, 151644,     17,     10,     17,      9,     16]],
       device='cuda:0')

In [151]:
compressed_part = embed_pooler_v3(new_embeds_for_compression)
torch.cat([embeddings_cache, compressed_part], dim=1).shape

torch.Size([1, 12, 1536])