In [1]:
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer
import torch

torch.set_grad_enabled(False)
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = Qwen2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = model.eval()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [2]:
model.get_input_embeddings().weight.dtype, model.dtype

(torch.bfloat16, torch.bfloat16)

In [3]:
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Give me a short introduction to large language model..1"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'A large language model is a type of artificial intelligence that can generate human-like text based on input data, such as natural language sentences or questions. These models are designed to mimic the complexity and creativity of human language, allowing them to produce coherent and meaningful responses that can be used in various applications, including speech recognition, chatbots, and virtual assistants. Large language models have been trained on vast amounts of text data, enabling them to understand and respond to complex queries and ideas more accurately than traditional AI systems.'

In [6]:
model_output_1 = model(
    **model_inputs,
    output_hidden_states=True,
)
model_output_1.hidden_states[-1].shape, model_output_1.hidden_states[-1].dtype

(torch.Size([1, 30, 896]), torch.bfloat16)

In [7]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_2.hidden_states[-1].shape

torch.Size([2, 30, 896])

In [8]:
model_output_1.hidden_states[-1].shape

torch.Size([1, 30, 896])

In [9]:
model_output_1.hidden_states[-1].reshape(10, 3, -1).shape

torch.Size([10, 3, 896])

In [11]:
model_output_2 = model(
    inputs_embeds=torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_2.hidden_states[-1].shape, model_output_2.hidden_states[-1].dtype

(torch.Size([2, 30, 896]), torch.bfloat16)

In [12]:
model_output_3 = model(
    inputs_embeds=model_output_1.hidden_states[-1].reshape(10, 3, -1),
    # attention_mask=model_inputs["attention_mask"],
    output_hidden_states=True,
)
model_output_3.hidden_states[-1].shape

torch.Size([10, 3, 896])

In [13]:
model_output_1.hidden_states[-1].reshape(1, 10, 3, -1).shape

torch.Size([1, 10, 3, 896])

In [12]:
# model(
#     inputs_embeds=model_output_1.hidden_states[-1].reshape(1, 10, 3, -1),
#     attention_mask=model_inputs["attention_mask"],
#     output_hidden_states=True,
# )

In [36]:
class Qwen2ModelEmbedPooler(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config).cuda()
        self.post_init()

    def forward(self, input_embeds, attention_mask, window_size=3):
        # разбиваем входящие эмбединги на бакеты и усредняем их
        sum_mask = attention_mask.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        embeds_sum = input_embeds.reshape(
            attention_mask.shape[0],
            window_size,
            attention_mask.shape[1] // window_size,
            -1,
        ).sum(1)
        input_embeds = embeds_sum / sum_mask
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )
        return input_embeds


embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-0.5B")
result = embed_pooler(
    # model_output_1.hidden_states[-1],
    torch.cat(
        [
            model_output_1.hidden_states[-1],
            model_output_1.hidden_states[-1],
        ],
        dim=0,
    ),
    torch.cat(
        [
            model_inputs["attention_mask"],
            model_inputs["attention_mask"],
        ],
        dim=0,
    ),
)
result[0].shape

torch.Size([2, 10, 896])

In [15]:
result.last_hidden_state.shape

torch.Size([2, 10, 896])

In [9]:
torch.tensor([1, 2, 3, 4, 5, 6]).reshape(3, -1)

tensor([[1, 2],
        [3, 4],
        [5, 6]])

In [None]:
from typing import Callable, List, Optional, Tuple, Union
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    SlidingWindowCache,
    StaticCache,
)
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import LossKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2ForCausalEmbedModeling(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )
        self.embed_pooler = Qwen2ModelEmbedPooler.from_pretrained("Qwen/Qwen2.5-1.5B")

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        # pooled_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds_tokens = self.model.embed_tokens(input_ids)
        # if pixel_values is not None:
        #     pixel_values = pixel_values.type(self.visual.get_dtype())
        #     image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        #     n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        #     n_image_features = image_embeds.shape[0]
        #     if n_image_tokens != n_image_features:
        #         raise ValueError(
        #             f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        #         )
        #     image_mask = (
        #         (input_ids == self.config.image_token_id)
        #         .unsqueeze(-1)
        #         .expand_as(inputs_embeds)
        #         .to(inputs_embeds.device)
        #     )
        #     image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        #     inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        window_size = torch.tensor(
            [
                [4],
                [3],
            ]
        ).cuda()
        tokens_amount = torch.tensor(
            [
                [2],
                [4],
            ]
        ).cuda()
        lengths = window_size * tokens_amount
        # max_len = lengths.max()
        max_len = inputs_embeds_tokens.shape[1]
        batch_size = window_size.shape[0]
        pooled_mask = (
            torch.arange(max_len, device=device)
            .unsqueeze(0)
            .expand(batch_size, max_len)
        )
        pooled_mask = (lengths >= pooled_mask).to(torch.long)
        pooled_embeds = inputs_embeds * pooled_mask.to(inputs_embeds.dtype)
        pooled_embeds = self.embed_pooler(pooled_embeds, pooled_mask)
        embed_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        inputs_embeds = inputs_embeds.masked_scatter(embed_mask, pooled_embeds)

        # Из-за смешанной структуры, мы будем всегда подавать только эмбединги
        # Идея позаимствована из qwen2vl
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

## Использование концепции qwen2vl для текстовых токенов и эмбедингов

In [14]:
text_example = "Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?"

tokenizer.encode(text_example)[:10]

[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11]

In [15]:
tokenizer.encode("<|image_pad|>")

[151655]

In [16]:
tokenizer.encode("text_example")

[1318, 39304]

### 1. Случай первый - мы сжимаем только текстовые токены (максимально упрощенный вариант)

In [17]:
window_size = 3
chunk_size = 4
new_tokens = []
original_tokens = tokenizer.encode(text_example)
new_tokens += tokenizer.encode("<|object_ref_start|>") * chunk_size
new_tokens += original_tokens[chunk_size * window_size :]
tokenizer.decode(new_tokens)

'<|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|>ized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?'

In [18]:
tokenizer.decode(original_tokens)

'Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful?'

In [21]:
original_tokens_torch = torch.tensor(
    [
        original_tokens,
        original_tokens,
    ],
    device="cuda",
)
new_tokens_torch = torch.tensor(
    [
        new_tokens,
        new_tokens,
    ],
    device="cuda",
)
# torch.Size([2, 51, 896])
original_embeds = model.get_input_embeddings()(original_tokens_torch)
# torch.Size([2, 43, 896]) 51 - 3*4 + 4
compressed_embeds_template = model.get_input_embeddings()(new_tokens_torch)
compressed_embeds_template.shape, compressed_embeds_template.dtype, original_embeds.dtype

(torch.Size([2, 43, 896]), torch.bfloat16, torch.bfloat16)

In [None]:
class Qwen2ModelEmbedPoolerV2(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.post_init()

    def forward(self, input_embeds):
        # print(input_embeds.dtype)
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )[0]
        # print(input_embeds.dtype)
        input_embeds = input_embeds.sum(1) / torch.tensor(
            input_embeds.shape[1],
            device=input_embeds.device,
        )
        # print(input_embeds.dtype)
        input_embeds = input_embeds.unsqueeze(1)
        return input_embeds


embed_pooler_v2 = Qwen2ModelEmbedPoolerV2.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

# torch.Size([8, 3, 896])
compressed_embeds = original_embeds[:, : chunk_size * window_size].reshape(
    chunk_size * original_embeds.shape[0],
    window_size,
    -1,
)
compressed_embeds.shape
# torch.Size([8, 1, 896])
pooled_embeds = embed_pooler_v2(compressed_embeds)
# torch.Size([2, 4, 896])
pooled_embeds = pooled_embeds.reshape(
    original_embeds.shape[0],
    chunk_size,
    -1,
)
# torch.Size([2, 43, 896])
compressed_embeds_template[:, :chunk_size] = pooled_embeds
compressed_embeds_template.shape, compressed_embeds_template.dtype, pooled_embeds.dtype

(torch.Size([2, 43, 896]), torch.bfloat16, torch.bfloat16)

In [25]:
embed_pooler_v2.model.dtype

torch.bfloat16

In [26]:
labels = new_tokens_torch.clone()
text_token_id = tokenizer.encode("<|object_ref_start|>")[0]
labels[labels == text_token_id] = -100
labels

tensor([[ -100,  -100,  -100,  -100,  1506,   279, 18414,   315, 20443, 11229,
           304,  3213,  1635,    13,  5542,   656, 59711,  9331,   311, 34549,
         15712,    11,  1181,  8357,   525, 10454, 14756, 70767,    13,  1988,
          1128,  6896,   374,  5538,  6832,    30,  1597,  1128,  3643,   432,
           773,  7988,    30],
        [ -100,  -100,  -100,  -100,  1506,   279, 18414,   315, 20443, 11229,
           304,  3213,  1635,    13,  5542,   656, 59711,  9331,   311, 34549,
         15712,    11,  1181,  8357,   525, 10454, 14756, 70767,    13,  1988,
          1128,  6896,   374,  5538,  6832,    30,  1597,  1128,  3643,   432,
           773,  7988,    30]], device='cuda:0')

In [27]:
model_output_1 = model(
    inputs_embeds=compressed_embeds_template,
    labels=labels,
    output_hidden_states=True,
)
model_output_1.loss

tensor(8.7288, device='cuda:0')

### 2. Усложненная версия токенизации текста. Сжатые токены возникают между обычным тектом

Входная последовательность может быть 5 видов.

- когда у нас на входе только текст, который мы просто моделируем 
- когда на входе текст, но мы хотим сжать его некоторые части
- когда на входе текст и эмбединги с последнего слоя, где мы хотим сжать только части текста
- когда на входе текст и эмбединги с последнего слоя, где мы хотим сжать, часть текста и эмбедингов, совместно (токены переводим в эмбединги и сжимаем вместе с hidden states)
- когда на входе текст и эмбединги с последнего слоя, где мы хотим сжать, только эмбединги

Как мы решаем какие токены хотим сжать?
- никак. Просто говорим сожми с такого по такой.
- это норм только на этапе обучения, так как мы можем перебрать все комбинации.
- не норм на этапе инференса. например мы можем сжимать после каждых сгенеренных 10 токенов или 100, 1000. а какое окно контекста? 3, 10, 100? А что если менять стратегию. Сначала мы сжимали с окном 5 токенов, потом 20?
- Кажется что эти гиперпараметры можно найти простым перебором на валидации. Однако перебор стратегии уже не кажется таким очевидным. Напрашивается RL.


In [28]:
text_example = [
    "Deep learning, a subfield of machine learning, has revolutionized the landscape of artificial intelligence in recent years. From self-driving cars to personalized medicine, its applications are becoming increasingly pervasive. But what exactly is deep learning? And what makes it so powerful? ",
    """Before diving into the "deep" part, let's establish a foundation with the basics of artificial neural networks (ANNs). An ANN consists of interconnected nodes, called neurons, organized in layers. These layers typically include: These layers typically include: These layers typically include: networks (ANNs). An ANN consists of interconnected nodes, called neurons, organized in layers. These layers typically include: These layers typically include: These layers typically include: networks (ANNs). An ANN consists of interconnected nodes, called neurons, organized in layers. These layers typically include: These layers typically include: These layers typically include:""",
]

tokenizer.batch_encode_plus(
    text_example,
    padding=True,
)

{'input_ids': [[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11, 702, 13791, 1506, 279, 18414, 315, 20443, 11229, 304, 3213, 1635, 13, 5542, 656, 59711, 9331, 311, 34549, 15712, 11, 1181, 8357, 525, 10454, 14756, 70767, 13, 1988, 1128, 6896, 374, 5538, 6832, 30, 1597, 1128, 3643, 432, 773, 7988, 30, 220, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], [10227, 42415, 1119, 279, 330, 32880, 1, 949, 11, 1077, 594, 5695, 264, 16266, 448, 279, 31774, 315, 20443, 29728, 14155, 320, 

In [24]:
torch.tensor(
    tokenizer.batch_encode_plus(
        text_example,
        padding=True,
    )["input_ids"]
).shape

torch.Size([2, 122])

In [29]:
tokenizer.decode(151643)

'<|endoftext|>'

In [None]:
tokenizer

In [30]:
from pprint import pprint
from more_itertools import chunked
import numpy as np
import random

text_token_id = tokenizer.encode("<|object_ref_start|>")[0]
eos_token_id = tokenizer.encode("<|endoftext|>")[0]

window_size = 3
chunk_size = 4
new_tokens = []
original_tokens = tokenizer.batch_encode_plus(
    text_example,
    padding=True,
)
compressed_tokens = []
replaced_original_tokens_batch = []
# original_tokens
for tokens in original_tokens["input_ids"]:
    original_lines = np.array(tokens)
    pure_tokens = original_lines[original_lines != eos_token_id].tolist()
    print(pure_tokens)
    full_chunks_amount = len(pure_tokens) // (window_size * chunk_size)
    print(full_chunks_amount)
    max_percent = 0.8
    pure_tokens_chunks = list(chunked(pure_tokens, window_size * chunk_size))
    print(pure_tokens_chunks)
    prob = 0.3
    random_mask = np.random.random(int(full_chunks_amount * max_percent))
    mask = random_mask < prob
    chunks_for_tokenization = np.where(mask)[0].tolist()
    chunks_for_tokenization = set(chunks_for_tokenization)
    if len(chunks_for_tokenization) == 0:
        chunks_for_tokenization = set(
            [
                random.randint(
                    0,
                    int(full_chunks_amount * max_percent),
                )
            ]
        )
    print("chunks_for_tokenization ", chunks_for_tokenization)
    replaced_original_tokens = []
    new_input_tokens = []
    for i, tokens in enumerate(pure_tokens_chunks):
        if i in chunks_for_tokenization:
            replaced_original_tokens.extend([text_token_id] * len(tokens))
            new_input_tokens.extend([text_token_id] * chunk_size)
        else:
            replaced_original_tokens.extend(tokens)
            new_input_tokens.extend(tokens)
    print("=== replaced_original_tokens")
    print(tokenizer.decode(replaced_original_tokens))
    print("=== new_input_tokens")
    print(tokenizer.decode(new_input_tokens))
    print(len(new_input_tokens))
    compressed_tokens.append(new_input_tokens)
    replaced_original_tokens_batch.append(replaced_original_tokens)
    print("==")
    print("==")
    print("==")

compressed_tokens_attention = []
max_compressed_len = max([len(item) for item in compressed_tokens])
max_replaced_len = max([len(item) for item in replaced_original_tokens_batch])

for compressed_seq, replaced_seq in zip(
    compressed_tokens,
    replaced_original_tokens_batch,
):
    compressed_seq_len = len(compressed_seq)
    replaced_seq_len = len(replaced_seq)
    attention_mask = [1] * (compressed_seq_len)

    if compressed_seq_len < max_compressed_len:
        compressed_seq += [eos_token_id] * (max_compressed_len - compressed_seq_len)
        attention_mask += [0] * (max_compressed_len - compressed_seq_len)

    if compressed_seq_len < max_replaced_len:
        replaced_seq += [eos_token_id] * (max_replaced_len - replaced_seq_len)

    compressed_tokens_attention.append(attention_mask)
# len(compressed_tokens[0]), len(compressed_tokens[1])
# compressed_tokens = torch.tensor(compressed_tokens)
# compressed_tokens_attention = torch.tensor(compressed_tokens_attention)
np.array(compressed_tokens).shape, np.array(
    replaced_original_tokens_batch
).shape, np.array(original_tokens["input_ids"]).shape

[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11, 702, 13791, 1506, 279, 18414, 315, 20443, 11229, 304, 3213, 1635, 13, 5542, 656, 59711, 9331, 311, 34549, 15712, 11, 1181, 8357, 525, 10454, 14756, 70767, 13, 1988, 1128, 6896, 374, 5538, 6832, 30, 1597, 1128, 3643, 432, 773, 7988, 30, 220]
4
[[33464, 6832, 11, 264, 1186, 2566, 315, 5662, 6832, 11, 702, 13791], [1506, 279, 18414, 315, 20443, 11229, 304, 3213, 1635, 13, 5542, 656], [59711, 9331, 311, 34549, 15712, 11, 1181, 8357, 525, 10454, 14756, 70767], [13, 1988, 1128, 6896, 374, 5538, 6832, 30, 1597, 1128, 3643, 432], [773, 7988, 30, 220]]
chunks_for_tokenization  {0}
=== replaced_original_tokens
<|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|><|object_ref_start|>ized the landscape of artificial intelligence in recent years. From self-driving cars to persona

((2, 98), (2, 122), (2, 122))

In [42]:
original_tokens_torch = torch.tensor(
    original_tokens["input_ids"],
    device="cuda",
)
replaced_tokens_torch = torch.tensor(
    replaced_original_tokens_batch,
    device="cuda",
)
compressed_tokens_torch = torch.tensor(
    compressed_tokens,
    device="cuda",
)
# torch.Size([2, 51, 896])
original_embeds = model.get_input_embeddings()(original_tokens_torch)
replaced_embeds = model.get_input_embeddings()(replaced_tokens_torch)
# torch.Size([2, 35, 896]) 51 - 3*4*2 + 4*2
compressed_embeds_template = model.get_input_embeddings()(compressed_tokens_torch)
compressed_embeds_template.shape, original_embeds.shape, replaced_embeds.shape

(torch.Size([2, 98, 896]),
 torch.Size([2, 122, 896]),
 torch.Size([2, 122, 896]))

In [32]:
compressed_embeds_template.dtype, original_embeds.dtype, replaced_embeds.dtype

(torch.bfloat16, torch.bfloat16, torch.bfloat16)

In [33]:
text_token_id

151646

In [43]:
tokens_for_compression_mask = replaced_tokens_torch == text_token_id
compressed_tokens_mask = compressed_tokens_torch == text_token_id
# original_tokens_torch[tokens_for_compression_mask].shape
original_embeds[tokens_for_compression_mask].shape, compressed_tokens_torch[
    compressed_tokens_mask
].shape

(torch.Size([48, 896]), torch.Size([16]))

In [44]:
embeds_for_compression = original_embeds[tokens_for_compression_mask].reshape(
    -1,
    window_size,
    original_embeds.shape[-1],
)
embeds_for_compression.shape

torch.Size([16, 3, 896])

In [45]:
pooled_embeds = embed_pooler_v2(embeds_for_compression)
# pooled_embeds = pooled_embeds.reshape(pooled_embeds.shape[0], -1)
pooled_embeds.shape, pooled_embeds.dtype

(torch.Size([16, 1, 896]), torch.bfloat16)

In [38]:
import torch

self = torch.tensor(
    [
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ]
)
mask = torch.tensor(
    [
        [0, 0, 0, 0, 1],
        [1, 1, 0, 1, 1],
    ],
    dtype=torch.bool,
)
source = torch.tensor(
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
)
self.masked_scatter_(mask, source)

tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])

In [52]:
compressed_embeds_template.shape

torch.Size([2, 98, 896])

In [74]:
compressed_tokens_mask.shape

torch.Size([2, 98])

In [75]:
compressed_embeds_template.shape

torch.Size([2, 98, 896])

In [76]:
compressed_tokens_mask.shape
# .expand_as(compressed_embeds_template).shape

torch.Size([2, 98])

In [None]:
compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template).shape

torch.Size([2, 98, 896])

In [83]:
compressed_embeds_template.dtype, pooled_embeds.dtype

(torch.bfloat16, torch.float32)

In [48]:
compressed_embeds_template.masked_scatter(
    compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
    pooled_embeds,
).shape

torch.Size([2, 98, 896])

In [50]:
compressed_embeds_template = compressed_embeds_template.masked_scatter_(
    compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
    pooled_embeds,
)

In [None]:
compressed_embeds_template[compressed_tokens_mask][0]

In [None]:
compressed_embeds_template[compressed_tokens_mask][0]

In [None]:
labels = compressed_tokens_torch.clone()
text_token_id = tokenizer.encode("<|object_ref_start|>")[0]
labels[labels == text_token_id] = -100
# labels

In [56]:
model_output_1 = model(
    inputs_embeds=compressed_embeds_template,
    labels=labels,
    output_hidden_states=True,
)
model_output_1.loss

tensor(5.9123, device='cuda:0')

### R1 Chat template

In [57]:
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer
import torch

torch.set_grad_enabled(False)
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = Qwen2ForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model = model.eval()

config.json:   0%|          | 0.00/679 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

In [64]:
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "how many wings has a bird?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=1024)
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


<｜begin▁of▁sentence｜>You are a helpful assistant.<｜User｜>how many wings has a bird?<｜Assistant｜><think>



"Okay, so I need to figure out how many wings a bird has. Hmm, I remember that birds are animals that can fly, but I'm not exactly sure about their wings. Let me think about what I know. I know that birds have wings, but do they have multiple wings? I think some birds have wings that are like flaps or maybe even wings with parts, like flaps and maybe a tail.\n\nWait, no, I think most birds have just one pair of wings. That makes sense because having more wings would complicate their flight. So, if I'm not mistaken, every bird has two wings: one upper and one lower. These wings are used for flapping to flap the wings up and down, which allows the bird to fly. \n\nLet me try to recall any specific examples. For instance, a bird like a sparrow has a pair of wings that are shaped like a V, one pointing up and one pointing down. Similarly, a crow has a similar structure with a V-shaped wing. I think other birds like eagles have wings that are more streamlined, but they still only have two w

In [65]:
print(text)

<｜begin▁of▁sentence｜>You are a helpful assistant.<｜User｜>how many wings has a bird?<｜Assistant｜><think>



In [63]:
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0])

Alright, so I'm trying to figure out how many wings a bird has. I remember that birds are birds, so maybe they have wings. But how many exactly? I know some birds have wings, like the eagle, but I'm not sure about others. Let me think about different types of birds.

First, the eagle. I think they have two wings, but I'm not entirely certain. Maybe I should confirm that. Then there's the sparrow. I believe they have two wings too. What about a crow? I'm not sure if they have wings or just airways. Maybe they don't have wings at all.

There's also the penguin, but wait, penguins are birds too, right? But they don't have wings, do they? They have beaks. So that's a different category. Then there's the ostrich, I think they have two wings. How about the eagle? Yeah, two wings. So far, all the birds I can think of have two wings.

Wait, what about the peacock? I think they have two wings, but I'm not sure. Also, maybe some other birds like the osprey? I'm not sure, but I think they have tw

In [71]:
tokenizer

LlamaTokenizerFast(name_or_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B', vocab_size=151643, model_max_length=16384, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<｜begin▁of▁sentence｜>', 'eos_token': '<｜end▁of▁sentence｜>', 'pad_token': '<｜end▁of▁sentence｜>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	151643: AddedToken("<｜end▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151644: AddedToken("<｜User｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151645: AddedToken("<｜Assistant｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151646: AddedToken("<｜begin▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	151647: AddedToken("<|EOT|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	151648: AddedToken("<think>", rstrip=False

In [None]:
tokenizer.encode("<|fim_pad|>", add_special_tokens=False)[0]

151662

### Generate Dataset

In [83]:
from huggingface_hub import InferenceClient

client = InferenceClient(f"http://{open('ip').read()}:1338")

user_prompt = "9.11 and 9.9 -- which is bigger? Let's think step by step."
output = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "user",
            "content": user_prompt,
        },
    ],
    stream=False,
    max_tokens=10000,
    temperature=0.0,
)
result = output.choices[0].message.content
print("total_len", len(tokenizer.encode(result)))
print(result)

total_len 437
First, I need to compare the two numbers: 9.11 and 9.9.

To make the comparison easier, I'll align their decimal places by writing 9.9 as 9.90.

Now, I'll compare each corresponding digit from left to right.

Both numbers have 9 in the units place, so they are equal there.

Next, I'll look at the tenths place. In 9.11, the tenths digit is 1, while in 9.90, it's 9.

Since 9 is greater than 1, 9.90 is larger than 9.11.

Therefore, 9.9 is bigger than 9.11.
</think>

**Solution:**

To determine which number is larger between **9.11** and **9.9**, follow these steps:

1. **Align the Decimal Places:**
   
   To make the comparison easier, write both numbers with the same number of decimal places:
   
   \[
   9.11 \quad \text{and} \quad 9.90
   \]

2. **Compare the Numbers Digit by Digit:**
   
   - **Units Place:**
     
     Both numbers have **9** in the units place.
     
     \[
     9 \quad \text{vs} \quad 9
     \]
     
     They are equal in this place.

   - **Tenths 

In [None]:
import concurrent

from huggingface_hub import InferenceClient

client = InferenceClient(f"http://{open('ip').read()}:1338")


def gen_text(text):
    chat_completion = client.chat.completions.create(
        model="tgi",
        messages=[{"role": "user", "content": text}],
        temperature=0.0,
        max_tokens=5_000,
    )
    return chat_completion.choices[0].message.content


batch_size = 4
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


def batch_generation(prompts):
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(prompts)) as executor:
        prompts_results = list(executor.map(gen_text, prompts))
    return prompts_results


prompts_results = batch_generation(prompts)

for prompt, prompts_result in zip(prompts, prompts_results):
    print(f"Prompt: {prompt!r}, Generated text: {prompts_result!r}")

Prompt: 'Hello, my name is', Generated text: 'Okay, so I\'m trying to figure out what the user named is. They just said "Hello, my name is" and then stopped. Hmm, that\'s a bit confusing. Maybe they forgot to finish the sentence. I should probably ask them to complete it so I can help better. I don\'t want to assume anything about their identity or name. It\'s better to be safe than sorry. I\'ll let them know I\'m here to help once they provide their full name.\n</think>\n\nIt seems like your name is incomplete. Could you please provide your full name so I can assist you better?'
Prompt: 'The president of the United States is', Generated text: "Okay, so I need to figure out what the president of the United States is. Hmm, I'm not exactly sure, but I think the president is the head of the executive branch of the United States. I remember hearing that they're usually named after a person, like George W. Bush or Barack Obama. Maybe I should look up the current president to get the most ac

In [85]:
from datasets import load_dataset

dataset = load_dataset("Open-Orca/OpenOrca")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=10_00, seed=42)
dataset = dataset["test"]
dataset

Dataset({
    features: ['id', 'system_prompt', 'question', 'response'],
    num_rows: 1000
})

In [None]:
from more_itertools import chunked
from tqdm.notebook import tqdm

batch_size = 32 * 2 * 2
questions = list(
    chunked(
        dataset["question"],
        batch_size,
    )
)
# 5 min 26 sec - 1000
correct_qa_pairs = []
for question_chunk in tqdm(questions):
    answers = batch_generation(question_chunk)
    for question, answer in zip(question_chunk, answers):
        if answer.count("</think>") == 1:
            correct_qa_pairs.append(
                [
                    question,
                    answer,
                ]
            )

  0%|          | 0/8 [00:00<?, ?it/s]

In [95]:
len(correct_qa_pairs), len(dataset["question"])

(905, 1000)

In [96]:
new_dataset = []
for question, answer in correct_qa_pairs:
    new_dataset.append(
        {
            "question": question,
            "answer": answer,
        }
    )
new_dataset[0]

{'question': 'NEW: Peterson to media on handcuffs, chains: "I got the bling. Can\'t complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio\'s death came after Peterson\'s fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.\n\nWrite an article based on these highlights.',
 'answer': "Okay, so I need to write an article based on the highlights provided about Drew Peterson and his fourth wife, Kathleen Savio. The user has given me some specific points to include, so I should make sure to cover all of them.\n\nFirst, I should start by introducing Drew Peterson and his fourth wife, Kathleen Savio. I need to mention that he's been arrested in the deaths of his first and fourth wives. That's a significant event because it's the second time he's been involved in such a death, which could be a big deal.\n\nI should also note that the renewed interest in Savio's death came after his fourth wife disappe

In [98]:
from datasets import Dataset

new_dataset = Dataset.from_list(new_dataset)

In [None]:
new_dataset.push_to_hub("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B/commit/82cf73531b7b8daaee327e4813aa774abbc0fcc4', commit_message='Upload dataset', commit_description='', oid='82cf73531b7b8daaee327e4813aa774abbc0fcc4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B', endpoint='https://huggingface.co', repo_type='dataset', repo_id='dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B'), pr_revision=None, pr_num=None)

### Compress thinking

In [None]:
from datasets import load_dataset

dataset = load_dataset("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 905
})

In [None]:
think_end_id = tokenizer.encode("</think>")[0]
think_end_id

151646

In [None]:
messages = [
    # {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": dataset[0]["question"]},
    # {"role": "assistant", "content": result},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# print(text)
text

'<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can\'t complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio\'s death came after Peterson\'s fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.\n\nWrite an article based on these highlights.<｜Assistant｜><think>\n'

In [132]:
tokenizer.decode(tokenizer.encode(text, add_special_tokens=False))

'<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can\'t complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio\'s death came after Peterson\'s fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.\n\nWrite an article based on these highlights.<｜Assistant｜><think>\n'

In [None]:
tokenizer.encode("<｜begin▁of▁sentence｜><｜User｜>", add_special_tokens=False)

[151646, 151644]

In [None]:
# https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L150
messages = [
    {"role": "user", "content": dataset[0]["question"]},
    # {"role": "assistant", "content": dataset[0]["answer"]},
]
part_1 = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    # return_dict=True,
)
part_2 = tokenizer.encode(
    dataset[0]["answer"],
    add_special_tokens=False,
)
part_3 = tokenizer.encode(
    "<｜end▁of▁sentence｜>",
    add_special_tokens=False,
)
print(tokenizer.decode(part_1 + part_2 + part_3))
labels = len(part_1) * [-100] + part_2 + [-100]

<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can't complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio's death came after Peterson's fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.

Write an article based on these highlights.<｜Assistant｜><think>
Okay, so I need to write an article based on the highlights provided about Drew Peterson and his fourth wife, Kathleen Savio. The user has given me some specific points to include, so I should make sure to cover all of them.

First, I should start by introducing Drew Peterson and his fourth wife, Kathleen Savio. I need to mention that he's been arrested in the deaths of his first and fourth wives. That's a significant event because it's the second time he's been involved in such a death, which could be a big deal.

I should also note that the renewed interest in Savio's death came after his fourth 

In [146]:
messages = [
    {"role": "user", "content": dataset[0]["question"]},
    {"role": "assistant", "content": dataset[0]["answer"]},
]

part_3 = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    # add_generation_prompt=False,
    # add_generation_prompt=True,
    # continue_final_message=True,
    # return_dict=True,
    # return_assistant_tokens_mask=True
)
part_3
print(tokenizer.decode(part_3))

<｜begin▁of▁sentence｜><｜User｜>NEW: Peterson to media on handcuffs, chains: "I got the bling. Can't complain" Drew Peterson arrested in the death of his third wife, Kathleen Savio. Renewed interest in Savio's death came after Peterson's fourth wife disappeared. Peterson, through his attorney, denies any wrongdoing in either case.

Write an article based on these highlights.<｜Assistant｜>

**Drew Peterson's Fourth Wife: A Legal Journey and Renewed Interest**

Drew Peterson, a controversial figure known for his controversial past, has recently come to light as the subject of legal proceedings involving his fourth wife, Kathleen Savio. Peterson, who has been arrested in the deaths of his first and fourth wives, has faced charges related to these tragic incidents. His attorney has denied any wrongdoing in either case, presenting a positive narrative for the legal community.

**Introduction to Drew Peterson and His Fourth Wife**

Drew Peterson, a controversial figure, has been the subject of l

### manual chat templating (single turn)

In [None]:
tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "question"},
        {"role": "assistant", "content": "answer"},
    ],
    tokenize=False,
)

'<｜begin▁of▁sentence｜><｜User｜>question<｜Assistant｜>answer<｜end▁of▁sentence｜>'

In [None]:
tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "question"},
    ],
    tokenize=False,
    add_generation_prompt=True,
)

'<｜begin▁of▁sentence｜><｜User｜>question<｜Assistant｜><think>\n'

In [160]:
content_compression_mask = []
part_1 = """<｜begin▁of▁sentence｜><｜User｜>"""
content_compression_mask += len(
    tokenizer.encode(
        part_1,
        add_special_tokens=False,
    )
) * [0]
print(content_compression_mask)
part_2 = 'NEW: Peterson to media on handcuffs, chains: "I got the bling. Can'
content_compression_mask += len(
    tokenizer.encode(
        part_2,
        add_special_tokens=False,
    )
) * [1]
print(content_compression_mask)
part_3 = "<｜Assistant｜><think>\n"
content_compression_mask += len(
    tokenizer.encode(
        part_3,
        add_special_tokens=False,
    )
) * [0]
print(content_compression_mask)
part_4 = "Peterson attorney that</think>Peterson is facing"
content_compression_mask += len(
    tokenizer.encode(
        part_4,
        add_special_tokens=False,
    )
) * [1]
print(content_compression_mask)
content_compression_mask += [0]
print(content_compression_mask)

[0, 0]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
