In [1]:
import os

os.environ["WANDB_PROJECT"] = "hidden_capacity_reasoning"
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer, BitsAndBytesConfig
import torch
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
)

from datasets import load_dataset
from tqdm import tqdm
from hidden_capacity_reasoning.utils import (
    generate_train_examples,
    pad_train_examples,
    tokenize_single_turn,
)
from datasets import Dataset
import gc
import types

# need for auto SFTTrainer patch(possible increase speed)
from unsloth import is_bfloat16_supported
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from hidden_capacity_reasoning.utils import (
    EOS_TOKEN_ID,
    TEXT_TOKEN_ID,
    WINDOW_SIZE,
    VISION_START,
    VISION_END,
    find_all_linear_names_v3,
)

import time
from datetime import datetime


from hidden_capacity_reasoning.models import (
    Qwen2ForCausalLMCompressionV1,
    Qwen2ModelEmbedPoolerV1,
    Qwen2ForCausalLMCompressionV2,
    Qwen2ModelEmbedPoolerV2,
)

# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model_name = "my_r1_model"
model = Qwen2ForCausalLMCompressionV2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
    attn_implementation="flash_attention_2",
)
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.model.requires_grad_(False)

# temp_model = Qwen2ModelEmbedPoolerV2.from_pretrained(
#     model_name,
#     attn_implementation="flash_attention_2",
#     torch_dtype=torch.bfloat16,
#     device_map={"": 0},
#     # quantization_config=BitsAndBytesConfig(load_in_4bit=True),
# )
# print(model.embed_pooler.load_state_dict(temp_model.state_dict()))
# temp_model = temp_model.cpu()
# del temp_model
# gc.collect()
# torch.cuda.empty_cache()

dataset = load_dataset("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=10, seed=42)

# test pass
tokenize_single_turn(
    question=dataset["train"][0]["question"],
    answer=dataset["train"][0]["answer"],
    tokenizer=tokenizer,
)
train_examples = [
    tokenize_single_turn(tokenizer=tokenizer, **item)
    for item in tqdm(dataset["train"].to_list()[:3])
]

prepared_train_examples = []
for item in tqdm(train_examples):
    for example in generate_train_examples(
        dataset_batch=[item],
        window_size=WINDOW_SIZE,
    ):
        prepared_train_examples.append(example)

print(
    "max_len",
    max([len(item["original_tokens"]) for item in prepared_train_examples]),
)

new_dataset = Dataset.from_list(prepared_train_examples)
print(dataset)


def collate_fn(batch):
    padded_batch = pad_train_examples(
        train_examples=batch,
        tokenizer=tokenizer,
    )
    padded_batch = {
        "replaced_original_tokens": padded_batch["replaced_original_tokens"][
            "input_ids"
        ],
        "compressed_input_ids": padded_batch["compressed_input_ids"]["input_ids"],
        "original_tokens": padded_batch["original_tokens"]["input_ids"],
        "attention_mask": padded_batch["compressed_input_ids"]["attention_mask"],
        "labels": padded_batch["compressed_input_ids"]["input_ids"],
        "content_compression_mask": padded_batch["content_compression_mask"][
            "input_ids"
        ],
    }
    for key in padded_batch.keys():
        padded_batch[key] = torch.tensor(padded_batch[key])
    skip_ids = [
        TEXT_TOKEN_ID,
        EOS_TOKEN_ID,
        VISION_START,
        VISION_END,
    ]
    for skip_id in skip_ids:
        padded_batch["labels"][padded_batch["labels"] == skip_id] = -100

    # часть инпута от пользователя
    padded_batch["labels"][
        padded_batch["content_compression_mask"][:, padded_batch["labels"].shape[-1]]
        == 1
    ] = -100
    # print(padded_batch)
    return padded_batch


peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
    target_modules=find_all_linear_names_v3(model=model),
    modules_to_save=["embed_pooler.model.embed_tokens"],
)

formatted_date = datetime.fromtimestamp(time.time()).strftime("%Y_%m_%d_%H_%M_%S_%f")
model.embed_pooler = prepare_model_for_kbit_training(model.embed_pooler)
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=new_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    args=SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        warmup_steps=5,
        num_train_epochs=1,  # 90,  # Set this for 1 full training run.
        # num_train_epochs=90,  # Set this for 1 full training run.
        # max_steps=10000,
        learning_rate=1e-4,
        bf16=model.dtype == torch.bfloat16,
        # fp16=model.dtype == torch.float16,
        logging_steps=8,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=f"outputs/{formatted_date}",
        report_to="wandb",
        # report_to="none",
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},
        # gradient_checkpointing=True,
        save_steps=10000,
        run_name=formatted_date,
    ),
)
trainer.train()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:00<00:00, 871.33it/s]
100%|██████████| 3/3 [00:00<00:00,  6.19it/s]


max_len 580
DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 895
    })
    test: Dataset({
        features: ['question', 'answer'],
        num_rows: 10
    })
})
trainable params: 251,838,464 || all params: 3,572,640,768 || trainable%: 7.0491


  trainer = SFTTrainer(
ERROR:tornado.general:SEND Error: Host unreachable
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdimweb[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
8,2.315
16,2.2429
24,1.6932
32,1.4449
40,1.6712
48,1.6086
56,1.5241
64,1.3814
72,1.583
80,1.5979


IndexError: index 576 is out of bounds for dimension 1 with size 575

In [None]:
model.save_pretrained("my_r1_model")

In [3]:
tokenizer.save_pretrained("my_r1_model")

('my_r1_model/tokenizer_config.json',
 'my_r1_model/special_tokens_map.json',
 'my_r1_model/tokenizer.json')

In [None]:
trainer.model

In [None]:
trainer.model.base_model.embed_pooler

In [2]:
for name, param in trainer.model.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Gradient: {param.requires_grad}")

Layer: base_model.model.embed_pooler.model.embed_tokens.modules_to_save.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.q_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.q_proj.lora_B.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.k_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.k_proj.lora_B.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.v_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.v_proj.lora_B.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.o_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.o_proj.lora_B.

In [3]:
model = trainer.model

In [None]:
dataset

In [1]:
from peft import PeftModel
from hidden_capacity_reasoning.models import (
    Qwen2ForCausalLMCompressionV1,
    Qwen2ModelEmbedPoolerV1,
    Qwen2ForCausalLMCompressionV2,
    Qwen2ModelEmbedPoolerV2,
)
import torch

# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model_name = "my_r1_model/"
model = Qwen2ForCausalLMCompressionV2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
    attn_implementation="flash_attention_2",
)
model = PeftModel.from_pretrained(
    model,
    # "outputs/2025_03_17_00_02_13_701993/checkpoint-2970",
    # "outputs/2025_03_17_01_39_35_074194/checkpoint-210000",
    "outputs/2025_03_19_13_29_15_704211/checkpoint-140000",
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
from datasets import load_dataset

dataset = load_dataset("dim/open_orca_4475_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=500, seed=42)

In [None]:
model = trainer.model

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda"
# prompt = "how many wings has a bird?"
prompt = dataset["test"].to_list()[:5][0]["question"]
messages = [
    # {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

with torch.no_grad():
    # generated_ids = model.generate(
    #     model_inputs.input_ids,
    #     max_new_tokens=1,
    #     do_sample=False,
    # )
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=1000,
        do_sample=False,
    )
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<｜begin▁of▁sentence｜><｜User｜>Here's a question: What do many people believe happens after you die?  Here are possible answers to this question: - stop moving - nothing - go to heaven - stop living - stop breathing  I believe the correct choice is "go to heaven", here's why:
Answer:<｜Assistant｜><think>



'Okay, so I\'m trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." The possible answers given are stop moving, nothing, go to heaven, stop living, and stop breathing. \n\nFirst, I remember hearing a lot about the afterlife in religious contexts, especially in Christianity. I think the idea is that after death, people go to heaven. But I\'m not entirely sure about all the details. Let me break this down.\n\nI know that in Christianity, the afterlife is a concept where people after death are sent to heaven. This is part of the Christian belief in a afterlife, which is different from the Buddhist or Hindu afterlife. In Buddhist and Hindu traditions, people go to the earth again after death, but that\'s not the same as heaven. So, if the question is about the afterlife, then heaven is the right answer.\n\nBut wait, the user mentioned that they believe the correct choice is "go to heaven," so I should fo

### Embed Generation

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda"

generated_tokens = tokenizer.apply_chat_template(
    [
        # {"role": "user", "content": "how many wings has a bird?"},
        {"role": "user", "content": dataset["test"].to_list()[:5][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)
with torch.no_grad():
    generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
    generated_embeds = model.get_input_embeddings()(generated_tokens)
    max_steps = 400
    for step in range(max_steps):
        logits = model(
            inputs_embeds=generated_embeds,
            # use_cache=False,
        ).logits
        top_token = logits.argmax(-1)[-1][-1]
        top_token_embed = model.get_input_embeddings()(top_token)
        # print(top)
        generated_tokens = torch.cat([generated_tokens, top_token.reshape(1, 1)], dim=1)
        generated_embeds = torch.cat(
            [generated_embeds, top_token_embed.reshape(1, 1, -1)], dim=1
        )
        print(step, tokenizer.decode(generated_tokens[-1]))
    # break
print(tokenizer.decode(generated_tokens[-1]))
# break
embeds_generation_tokens = generated_tokens[-1]

In [None]:
print(tokenizer.decode(generated_tokens[-1]))

In [None]:
# Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question: "What date did the American Civil War start?"
# Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question: "What date did the American Civil War start?"

In [None]:
from hidden_capacity_reasoning.utils import WINDOW_SIZE, VISION_START, VISION_END
from transformers.cache_utils import DynamicCache


def _crop_past_key_values(model, past_key_values, max_length):
    """Crops the past key values up to a certain maximum length."""
    new_past = []
    if model.config.is_encoder_decoder:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :max_length, :],
                    past_key_values[idx][1][:, :, :max_length, :],
                    past_key_values[idx][2],
                    past_key_values[idx][3],
                )
            )
        past_key_values = tuple(new_past)
    # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
    elif "gptbigcode" in model.__class__.__name__.lower() or (
        model.config.architectures is not None
        and "gptbigcode" in model.config.architectures[0].lower()
    ):
        if model.config.multi_query:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :max_length, :]
        else:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
    elif isinstance(past_key_values, DynamicCache):
        past_key_values.crop(max_length)
    elif past_key_values is not None:
        for idx in range(len(past_key_values)):
            if past_key_values[idx] != ([], []):
                new_past.append(
                    (
                        past_key_values[idx][0][:, :, :max_length, :],
                        past_key_values[idx][1][:, :, :max_length, :],
                    )
                )
            else:
                new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
        past_key_values = tuple(new_past)
    return past_key_values


# model = trainer.model
generated_tokens = tokenizer.apply_chat_template(
    [
        # {"role": "user", "content": "how many wings has a bird?"},
        {"role": "user", "content": dataset["test"].to_list()[:5][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

with torch.no_grad(), torch.autocast(device_type="cuda"):
    start_embed = model.base_model.embed_pooler.model.get_input_embeddings()(
        torch.tensor([[VISION_START]], device="cuda")
    )
    end_embed = model.base_model.embed_pooler.model.get_input_embeddings()(
        torch.tensor([[VISION_END]], device="cuda")
    )
    generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
    generated_embeds = model.get_input_embeddings()(generated_tokens)
    temp_gen_size = 0
    window_size = WINDOW_SIZE  # + 1
    # new_tokens = 4
    new_tokens = 1
    generation_started = False
    max_steps = (new_tokens + window_size) * 15
    past_key_values_big = None
    print("generated_embeds", generated_embeds.shape)
    for step in range(max_steps):
        if temp_gen_size == window_size + new_tokens:
            # print(
            #     "TOKENS FOR EMDED",
            #     tokenizer.decode(
            #         generated_tokens[:, -(window_size + new_tokens) :][:, :WINDOW_SIZE]
            #         .cpu()
            #         .tolist()[0]
            #     ),
            # )
            # tokenizer.decode(generated_tokens[:, : -window_size ].cpu().tolist()[0])
            if hasattr(model.base_model, "embed_pooler"):
                new_embeds_for_compression = (
                    model.base_model.embed_pooler.model.get_input_embeddings()(
                        generated_tokens[:, -(window_size + new_tokens) :][
                            :, :WINDOW_SIZE
                        ]
                    )
                ).to(torch.bfloat16)
                compressed_part = model.base_model.embed_pooler(
                    new_embeds_for_compression
                )
            else:
                compressed_part = model.embed_pooler(new_embeds_for_compression)
            # gen_embeds_prev = generated_tokens.shape[1]
            if generation_started:
                # past_key_values_big = _crop_past_key_values(
                #     model=model,
                #     past_key_values=past_key_values_big,
                #     max_length=generated_embeds.shape[1] - new_tokens - 2,
                # )
                generated_embeds = torch.cat(
                    [
                        generated_embeds[:, : -(window_size + new_tokens + 1)],
                        # generated_embeds[:, : -(window_size + new_tokens)],
                        compressed_part,
                        # torch.randn(1, 1, 1536, device="cuda"),
                        end_embed,
                        generated_embeds[:, -new_tokens:],
                    ],
                    dim=1,
                )
            else:
                # past_key_values_big = _crop_past_key_values(
                #     model=model,
                #     past_key_values=past_key_values_big,
                #     max_length=generated_embeds.shape[1] - new_tokens - 3,
                # )
                generated_embeds = torch.cat(
                    [
                        generated_embeds[:, : -(window_size + new_tokens)],
                        start_embed,
                        # torch.randn(1, 1, 1536, device="cuda"),
                        compressed_part,
                        end_embed,
                        generated_embeds[:, -new_tokens:],
                    ],
                    dim=1,
                )
                generation_started = True
            past_key_values_big = _crop_past_key_values(
                model=model,
                past_key_values=past_key_values_big,
                max_length=generated_embeds.shape[1] - new_tokens - 2,
            )
            temp_gen_size = 1

        outputs = model(
            inputs_embeds=generated_embeds,
            past_key_values=past_key_values_big,
            # use_cache=False,
        )
        logits = outputs.logits
        past_key_values_big = outputs.past_key_values
        top_token = logits.argmax(-1)[-1][-1]
        top_token_embed = model.get_input_embeddings()(top_token)
        # print(top)
        generated_tokens = torch.cat([generated_tokens, top_token.reshape(1, 1)], dim=1)

        generated_embeds = torch.cat(
            [generated_embeds, top_token_embed.reshape(1, 1, -1)], dim=1
        )
        # print(temp_gen_size, tokenizer.decode(generated_tokens[-1]))

        temp_gen_size += 1

print(tokenizer.decode(generated_tokens[-1]))

# break

generated_embeds torch.Size([1, 62, 1536])
<｜begin▁of▁sentence｜><｜User｜>Here's a question: What do many people believe happens after you die?  Here are possible answers to this question: - stop moving - nothing - go to heaven - stop living - stop breathing  I believe the correct choice is "go to heaven", here's why:
Answer:<｜Assistant｜><think>
Okay, so I'm trying to figure out the answer is correct. Let me think about it again. I think the user is trying to figure out the correct answer to the question they're asking. They provided a list of answers, and I need to heaven, but I'm not sure if I'm sure if I'm on the right track. Let me break it down step by step. The question is about what people believe happens after you die. I know that when someone dies, they believe that people often believe in something called the afterlife, which is the belief that after you die, you don't need to move or anything else. So, the correct answer is "go to heaven", which is the correct answer. The othe

# я не знаю, может быть я неправильно управляюсь с KV-cache

In [18]:
generated_tokens.shape

torch.Size([1, 227])

In [35]:
new_embeds_for_compression.shape

torch.Size([1, 10, 1536])

In [34]:
[1, 2, 3, 4, 5, 6, 7, 8][:-2]

[1, 2, 3, 4, 5, 6]

In [None]:
print(dataset["test"][0]["answer"])

In [29]:
generated_embeds.shape

torch.Size([1, 100, 1536])

In [28]:
generated_tokens.shape

torch.Size([1, 227])

In [20]:
print(
    """
Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.

First, I know that when people talk about what happens after death, they often refer to the afterlife. The afterlife is a concept that's been discussed in various religious traditions, including Christianity, Islam, and Buddhism. Each of these religions has its own beliefs about what happens after death.

In Christianity, the afterlife is often depicted as a place where people are separated from their afterlife. It's usually associated with the concept of heaven and hell. I remember hearing that in the Bible, there's a story about a man who was given a soul after death, but he was separated from his soul. This is sometimes referred to as the "man and woman of the last supper." So, in this context, people might believe that after death, they are separated from their soul, which is why they might think they go to heaven.

In Islam, the concept of the afterlife is also significant. The Quran, which is the holy book of Islam, mentions that after death, people are separated from their bodies and are sent to heaven. This is part of the concept of the "finality of the afterlife," where people are sent to heaven after suffering for all their lives. So, in Islam, the belief is that after death, people are sent to heaven.

In Buddhism, the concept of the afterlife is also well-documented. The Buddhist tradition teaches that after death, people are separated from their bodies and are sent to heaven. This is part of the "end of the world" concept, where the end of the world is considered to be the end of this life, and heaven is the place where people are sent after suffering.

So, putting this all together, the common belief among many people is that after death, people are sent to heaven. This is because each of these religions has a consistent belief in the afterlife and the separation of the soul from the body, leading to the concept of heaven as the place where this separation occurs.

I also think about the other options provided: stop moving, nothing, go to heaven, stop living, and stop breathing. The first option, stop moving, might refer to the idea that after death, people stop moving, but that's more of a common belief among some people rather than a specific religious belief. The second option, nothing, doesn't make much sense in this context. The third option, go to heaven, is the most consistent with the religious beliefs I've considered. The fourth option, stop living, might refer to the idea that after death, people stop living, but that's not as strong a belief as going to heaven. The fifth option, stop breathing, is more about the physical state after death, which isn't as central to the religious concepts.

I also recall that in some religious contexts, heaven is associated with the afterlife, so that reinforces the idea that going to heaven is the correct belief. Additionally, the term "heaven" is a common symbol in many religious texts and is often associated with the afterlife, making it a strong candidate for the correct answer.

In summary, considering the common beliefs in Christianity, Islam, and Buddhism, the belief that many people hold after death is that they go to heaven. This is because each of these traditions has a consistent and strong belief in the afterlife and the separation of the soul from the body, leading to the concept of heaven as the place where this separation occurs.
</think>

The correct answer is that many people believe after death they go to heaven. This belief is rooted in various religious traditions, including Christianity, Islam, and Buddhism, each of which has a consistent and strong belief in the afterlife and the separation of the soul from the body, leading to the concept of heaven as the place where this separation occurs.
    
"""
)
print(tokenizer.decode(generated_tokens[-1]))


Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.

First, I know that when people talk about what happens after death, they often refer to the afterlife. The afterlife is a concept that's been discussed in various religious traditions, including Christianity, Islam, and Buddhism. Each of these religions has its own beliefs about what happens after death.

In Christianity, the afterlife is often depicted as a place where people are separated from their afterlife. It's usually associated with the concept of heaven and hell. I remember hearing that in the Bible, there's a story about a man who was given a soul after death, but he was separated from his soul. This is sometimes referred to as the "man and woman of the last supper." So, in this context, people might believe that after death, they are separated from their soul, which is why they might th

In [None]:
oririnal = """
Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.

First, I know that when people talk about what happens after death, they often refer to the afterlife. The afterlife is a concept that's been discussed in various religious traditions, including Christianity, Islam, and Buddhism. Each of these religions has its own beliefs about what happens after death.

In Christianity, the afterlife is often depicted as a place where people are separated from their afterlife. It's usually associated with the concept of heaven and hell. I remember hearing that in the Bible, there's a story about a man who was given a soul after death, but he was separated from his soul. This is sometimes referred to as the "man and woman of the last supper." So, in this context, people might believe that after death, they are separated from their soul, which is why they might think they go to heaven.

In Islam, the concept of the afterlife is also significant. The Quran, which is the holy book of Islam, mentions that after death, people are separated from their bodies and are sent to heaven. This is part of the concept of the "finality of the afterlife," where people are sent to heaven after suffering for all their lives. So, in Islam, the belief is that after death, people are sent to heaven.

In Buddhism, the concept of
"""

compress_1_cut_cache = """
Okay, so I'm trying to figure out the answer is correct. Let me think about it again. I think the user is trying to figure out the correct answer to the question they're asking. They provided a list of answers, including "go to heaven, nothing, stop moving, nothing, stop moving - nothing - go to heaven - stop living - stop breathing - go to heaven - stop living - Here are the possible answers to this question: stop moving - nothing - go to heaven - stop breathing -. The user is trying to figure out the correct choice is "go to heaven", here's the correct answer. But I'm not entirely sure if that's the correct answer. I think I should clarify that I'm confident in my answer. 

So, I think the correct answer is that the correct answer is
"""
compress_2_cut_cache = """
Okay, so I'm trying to figure out the answer is correct. Let me think about it again. I think the user is trying to figure out the correct answer to the question they're asking. They provided a list of answers, and they're asking for feedback on why they chose "go to heaven." They provided four options: stop moving, nothing, go to heaven - stop moving - stop breathing. They're asking why they chose "go to heaven".

Okay, I think I understand now. The user is asking about what happens after you die. They're trying to figure out the correct answer. Let me break it down step by step. Maybe they're trying to figure out the correct answer is correct. I think I've got it now, I'll try to figure out why the correct answer is. 

"""

compress_5 = """
Okay, so I'm trying to figure out why I think the correct answer is "go to heaven." The user provided a list of possible answers: stop moving, nothing, go to heaven, stop breathing, and stop living. I believe the correct answer is to go to heaven, but I need to understand why.

First, I know that when you die, it's a pretty big deal. I remember hearing that when you die, you don't just stop moving or anything like that. You don't just go to heaven on your own, you need to stop breathing. I think you need to stop moving, but that's not the main reason people think they go to heaven. I think it's more about the process of dying and how it affects your body and mind.

I remember hearing that when you die, your body goes through a process called death. I think that process involves moving through different stages. One of those stages is going to heaven. I think that's when your mind, body, and spirit come together to go to

"""
compress_1 = """
Okay, so I'm trying to figure out why I should explain why I think the correct answer is "go to heaven" is the correct answer. Let me try to break this down step by step.

First, I think about what happens when you die. When you die, it's a bit of course, but I'm not entirely sure how that's the correct belief. I know that dying can be a bit confusing because I've heard people say. I mean, I know that dying can have to think about why I'd think that. Maybe it's because I've heard people say that dying is a sign of good luck, but I'm not sure why that's the case.

I remember hearing that dying can bring good luck, but I think it's because of the brain chemistry. When your brain is damaged

"""


random = """
Okay, so I'm trying to figure out why I think the correct answer is "go to heaven." Let me try to break this down. So, the question is asking what people believe happens after you die. The options are stop moving, nothing, go to heaven, stop breathing, etc. So, the question is, what do many people believe happens after you die? The options are stop moving, nothing, go to heaven, stop living, stop breathing, etc. But I think the correct answer is "go to heaven". But I'm not entirely sure why I think that's correct. Let me think through this.

First, I know that when you die, you don't really "move" anymore, but I'm not sure. Maybe it's about not being able to move, which is a physical state. But I'm not sure if "move" is the right word. Maybe it's about the body's function. I think when you die, you don't really move anymore, right? You're just there,
"""

### Предсказание с изначально сжатыми токенами

In [3]:
from transformers import AutoTokenizer
from hidden_capacity_reasoning.utils import WINDOW_SIZE, VISION_START, VISION_END
import torch

torch.manual_seed(0)
# torch.manual_seed(42)

tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda"
dataset_number = 3
example = dataset["test"].to_list()[:5][dataset_number]

generated_tokens = tokenizer.apply_chat_template(
    [
        # {"role": "user", "content": "how many wings has a bird?"},
        {"role": "user", "content": example["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

with torch.no_grad():
    start_embed = model.base_model.embed_pooler.model.get_input_embeddings()(
        torch.tensor([[VISION_START]], device="cuda")
    )
    end_embed = model.base_model.embed_pooler.model.get_input_embeddings()(
        torch.tensor([[VISION_END]], device="cuda")
    )
    generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
    generated_embeds = model.get_input_embeddings()(generated_tokens)
    windows_amount = 50
    next_true_tokens = tokenizer.encode(
        example["answer"],
        add_special_tokens=False,
    )[: WINDOW_SIZE * windows_amount]
    next_true_tokens = torch.tensor(next_true_tokens, device="cuda").unsqueeze(0)

    new_embeds_for_compression = (
        model.base_model.embed_pooler.model.get_input_embeddings()(next_true_tokens)
    ).to(torch.bfloat16)

    compressed_part = model.base_model.embed_pooler(new_embeds_for_compression)
    new_embeds_for_compression = new_embeds_for_compression.reshape(
        windows_amount, WINDOW_SIZE, -1
    )
    compressed_part = model.base_model.embed_pooler(new_embeds_for_compression)
    compressed_part = compressed_part.reshape(1, windows_amount, -1)
    # compressed_part = torch.rand_like(compressed_part)
    # start_embed = torch.rand_like(start_embed)
    # end_embed = torch.rand_like(end_embed)
    generated_embeds = torch.cat(
        [
            generated_embeds,
            start_embed,
            compressed_part,
            end_embed,
        ],
        dim=1,
    )
    generated_tokens = torch.cat(
        [
            generated_tokens,
            next_true_tokens,
        ],
        dim=1,
    )
    print("COMPRESSED PART", tokenizer.decode(next_true_tokens[-1]))
    max_steps = 400
    for step in range(max_steps):
        logits = model(
            inputs_embeds=generated_embeds,
            # use_cache=False,
        ).logits
        top_token = logits.argmax(-1)[-1][-1]
        top_token_embed = model.get_input_embeddings()(top_token)
        # print(top)
        generated_tokens = torch.cat([generated_tokens, top_token.reshape(1, 1)], dim=1)
        generated_embeds = torch.cat(
            [generated_embeds, top_token_embed.reshape(1, 1, -1)], dim=1
        )
        print(step, tokenizer.decode(generated_tokens[-1]))
    # break
print(tokenizer.decode(generated_tokens[-1]))
# break
embeds_generation_tokens = generated_tokens[-1]

COMPRESSED PART Okay, so I need to figure out if the answer provided is correct to the question "What had he just finished doing when he saw the tractor." The answer given is "Feeding the chickens and ducks." 

Let me break this down. The story starts with Joe going out to the field to feed the horses and cows. That's the first action he did. Then, after finishing, he saw the tractor. So, the tractor was in the field, and he noticed it. 

Joe
0 <｜begin▁of▁sentence｜><｜User｜>Joe's parents are farmers and they have a huge farm with cows, chickens, and ducks. Joe loves the farm and all the things he gets to play around and play on. One day, Joe's father told him not to get near a tractor that was sitting in the field. His father was worried that Joe would climb on it and hurt himself. Joe went out to the field and was feeding the horses and cows. When he was done, he saw the tractor his father told him not to get near. He knew that climbing on the tractor wouldn't hurt anything, so he did.

In [4]:
print(tokenizer.decode(generated_tokens[-1]))

<｜begin▁of▁sentence｜><｜User｜>Joe's parents are farmers and they have a huge farm with cows, chickens, and ducks. Joe loves the farm and all the things he gets to play around and play on. One day, Joe's father told him not to get near a tractor that was sitting in the field. His father was worried that Joe would climb on it and hurt himself. Joe went out to the field and was feeding the horses and cows. When he was done, he saw the tractor his father told him not to get near. He knew that climbing on the tractor wouldn't hurt anything, so he did. He climbed on to the seat and sat there. Then, he pretended he was his father and pretended that he was driving the tractor. Joe's father saw him playing on the tractor and called for him. Joe heard his father calling for him and got off the tractor really fast. When he did that, he fell off and hurt his arm. Joe was in pain and his father came running to check on him and picked him up and sat him on a bench and asked him why he did that. Joe l

In [6]:
generated_embeds.shape

torch.Size([1, 514, 1536])

In [5]:
generated_tokens.shape

torch.Size([1, 562])

In [6]:
print(example["answer"])

Okay, so I need to figure out if the answer provided is correct to the question "What had he just finished doing when he saw the tractor." The answer given is "Feeding the chickens and ducks." 

Let me break this down. The story starts with Joe going out to the field to feed the horses and cows. That's the first action he did. Then, after finishing, he saw the tractor. So, the tractor was in the field, and he noticed it. 

Joe went to feed the animals, saw the tractor, and then decided to climb on it. He did that pretending to be driving it, which probably made him feel good because he didn't hurt himself. Then, he got off the tractor and fell, causing him pain. His father came to pick him up and rode him on the tractor. 

So, the sequence is: feed the animals, see the tractor, climb on it, ride it, and then ride with him. The question is about what he did when he saw the tractor. The tractor was in the field, so he was feeding the animals, saw the tractor, and then decided to ride it.

In [None]:
oririnal = """
Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.

First, I know that when people talk about what happens after death, they often refer to the afterlife. The afterlife is a concept that's been discussed in various religious traditions, including Christianity, Islam, and Buddhism. Each of these religions has its own beliefs about what happens after death.

In Christianity, the afterlife is often depicted as a place where people are separated from their afterlife. It's usually associated with the concept of heaven and hell. I remember hearing that in the Bible, there's a story about a man who was given a soul after death, but he was separated from his soul. This is sometimes referred to as the "man and woman of the last supper." So, in this context, people might believe that after death, they are separated from their soul, which is why they might think they go to heaven.

In Islam, the concept of the afterlife is also significant. The Quran, which is the holy book of Islam, mentions that after death, people are separated from their bodies and are sent to heaven. This is part of the concept of the "finality of the afterlife," where people are sent to heaven after suffering for all their lives. So, in Islam, the belief is that after death, people are sent to heaven.

In Buddhism, the concept of the afterlife is also well-documented. The Buddhist tradition teaches that after death, people are separated from their bodies and are sent to heaven. This is part of the "end of the world" concept, where the end of the world is considered to be the end of this life, and heaven is the place where people are sent after suffering.

So, putting this all together, the common belief among many people is that after death, people are sent to heaven. This is because each of these religions has a consistent belief in the afterlife and the separation of the soul from the body, leading to the concept of heaven as the place where this separation occurs.

I also think about the other options provided: stop moving, nothing, go to heaven, stop living, and stop breathing. The first option, stop moving, might refer to the idea that after death, people stop moving, but that's more of a common belief among some people rather than a specific religious belief. The second option, nothing, doesn't make much sense in this context. The third option, go to heaven, is the most consistent with the religious beliefs I've considered. The fourth option, stop living, might refer to the idea that after death, people stop living, but that's not as strong a belief as going to heaven. The fifth option, stop breathing, is more about the physical state after death, which isn't as central to the religious concepts.

I also recall that in some religious contexts, heaven is associated with the afterlife, so that reinforces the idea that going to heaven is the correct belief. Additionally, the term "heaven" is a common symbol in many religious texts and is often associated with the afterlife, making it a strong candidate for the correct answer.

In summary, considering the common beliefs in Christianity, Islam, and Buddhism, the belief that many people hold after death is that they go to heaven. This is because each of these traditions has a consistent and strong belief in the afterlife and the separation of the soul from the body, leading to the concept of heaven as the place where this separation occurs.
</think>

The correct answer is that many people believe after death they go to heaven. This belief is rooted in various religious traditions, including Christianity, Islam, and Buddhism, each of which has a consistent and strong belief in the afterlife and the separation of the soul from the body, leading to the concept of heaven as the place where this separation occurs.
"""

generated_1 = """
Okay, so I'm trying to figure out why the correct answer to the question "What do many people believe happens after you die?" The user provided four possible answers: stop moving, nothing, go to heaven, stop living, and stop breathing. They concluded that the correct choice is "go to heaven" because they believe that's what many people believe happens after death.

Hmm, let me think about this. I know that in many religious beliefs, particularly in Christianity, there are concepts of heaven and hell. For example, in Christianity, after death, people are supposed to go to heaven. This is part of the afterlife concept, where after living their lives, people are sent to heaven. So, stopping breathing and moving might be part of that process, but the main belief is going to heaven.

But wait, the user mentioned that they believe the correct choice is "go to heaven." So, why isn't that the answer? Maybe because the question is asking for the belief, not the action. So, the belief is that after death, you go to heaven. The action of stopping breathing and moving might be part of the process, but the belief itself is going to heaven.

Let me consider other possibilities. In some philosophies, like Buddhism, there's the concept of the afterlife, but I'm not sure if that's what the user is referring to. In Hinduism, there's also the afterlife, but again, that's more about the experience rather than the belief. The user's conclusion was that go to heaven is the correct answer, so I guess the reasoning is that many people believe that after death, they go to heaven.

I should also think about other beliefs. For example, in some cultures, after death, people are supposed to stop breathing and live in peace, which is similar to going to heaven. But in other beliefs, like in some religious communities, people might live in a state of peace but not necessarily go to heaven. So, the user's answer might be based on
"""
generated_2 = """
Okay, so I'm trying to figure out why I think the correct answer to the question "What after you die?" The options given are stop moving, nothing, go to heaven, stop living, and stop breathing. The user believes the correct answer is "go to heaven," but I need to understand why.

First, I'll think about what I know about death and the afterlife. I remember hearing about the afterlife in religious contexts, like in Christianity, where people are supposed to go to heaven after dying. But I'm not entirely sure if that's the only belief or if there are other perspectives.

I also recall that in some cultures, especially in Hinduism and Buddhism, there are different afterlife experiences. In Hinduism, for example, there's the concept of the "Karma World," where one can die and then die again, but there are different outcomes based on actions. However, I'm not sure if that's related to the question's options.

In Buddhism, there's the concept of the "Bodhisattva," who is the mediator between the physical and the afterlife. The user's question doesn't mention anything about the afterlife in a religious context, so maybe the answer is more about the general belief rather than a specific religious belief.

The user's reasoning is that "go to heaven" is the correct answer. I'll try to break down why that might be the case. In many religious perspectives, after death, people are supposed to go to heaven. This is often seen as a symbolic way to end their life and move on to a higher level of existence. It's a common belief among people who have gone through similar experiences.

I also think about the idea of the afterlife in terms of personal choice. Some people might choose to live a life of peace and contentment, while others might choose to die and then go to heaven. This doesn't necessarily mean that everyone goes to heaven; it's more about the individual's choice and the belief system they're part of.

Another angle is the concept of the "after
"""
generated_5 = """
COMPRESSED PART Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.
Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.

First, I know that when, which is the belief that after death, people will return to the afterlife. The afterlife is often associated with the concept of heaven, where people go after they die. This is a common belief among many people, especially those who are religious or have a strong spiritual background.

The options given are: stop moving, nothing, go to heaven, stop living, and stop breathing. Let's go through each one to see why I think "go to heaven" is the correct answer.

1. **Stop moving**: This doesn't make much sense. Moving is a physical act, and after death, you don't move anymore. So stopping movement doesn't make sense in this context.

2. **Nothing**: This is too vague. It doesn't specify what happens after death, which is why it's not a specific answer.

3. **Go to heaven**: This aligns with the religious belief I mentioned earlier. After death, people believe they will go to heaven, where they will find their afterlife. This is a well-established belief among many people.

4. **Stop living**: This is similar to stopping movement. Living is a continuous process, and stopping living doesn't make sense after death.

5. **Stop breathing**: This is also about stopping a physical act, which isn't relevant after death.

So, putting it all together, the correct answer is "go to heaven" because it directly addresses the belief about afterlife and is a well-supported belief among many people.
</think>

The correct answer is:

**Go to heaven**

This is because the belief that after death, people will go to heaven is a well-established religious belief. It aligns with the concept of the afterlife, where individuals are believed to return to heaven after dying. The other options either lack specificity or are not relevant to the belief in afterlife.<｜end▁of▁sentence｜><｜begin▁of▁sentence｜>

Okay, so I need to figure out why the correct answer is "go to heaven." Let me break it down step
"""
# 5 сжатые части это рандом, токены начала и конца идут с обучаемой модели
generated_5_random_end_start_no_random = """
COMPRESSED PART Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.
Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.

First, I know that when question is about what people believe happens after they die. The possible answers are stop moving, nothing, go to heaven, stop living, and stop breathing. The user thinks the correct answer is "go to heaven," but I need to figure out why.

First, I should consider the common beliefs people have about dying. I know that in many cultures, especially in Christianity, dying is seen as a natural part of life. The user is probably thinking of the Last Supper, where Jesus and his disciples die, and then the disciples return to the Father. This is a common belief system.

In that context, dying is often associated with the afterlife, which is a concept rooted in religious beliefs. The afterlife is a place where people are separated from their bodies, and they go to heaven. So, after death, people believe they go to heaven.

I should also think about other beliefs. For example, in some religious traditions, dying is seen as a necessary part of life, and after death, people are transferred to heaven. In other beliefs, like some Eastern philosophies, dying might be associated with the afterlife, but I'm not sure if that's the case here.

Another angle is the idea of the afterlife. In Christian theology, the afterlife is a place where people are separated from their bodies, and they go to heaven. This is a common belief, especially in Christian communities. The user's reasoning seems to align with this.

I should also consider if there are any other plausible answers. "Stop moving" could be a literal interpretation, but it's more about the physical state after death rather than the belief about the afterlife. "Nothing" doesn't make much sense in this context. "Stop living" is similar to stopping moving, which is more about the physical aspect rather than the spiritual afterlife. "Stop breathing" is about the physical state, not the belief about the afterlife.

So, putting it all together, the most plausible
"""
# 5 сжатые части это рандом, токены начала и конца тоже рандом
generated_5_random_end_start_no_random = """
COMPRESSED PART Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.
Okay, so I'm trying to figure out why I think the correct answer to the question "What do many people believe happens after you die?" is "go to heaven." Let me break this down step by step.

First, I know that when, which is the belief that after death, people will return to the afterlife. The afterlife is often associated with the concept of heaven, where people go after they die. This is a common belief among many people, especially those who are religious or have a strong spiritual background.

The options given are: stop moving, nothing, go to heaven, stop living, and stop breathing. Let's go through each one to see why I think "go to heaven" is the correct answer.

1. **Stop moving**: This doesn't make much sense. Moving is a physical act, and after death, you don't move anymore. So stopping movement doesn't make sense in this context.

2. **Nothing**: This is too vague. It doesn't specify what happens after death, which is why it's not a specific answer.

3. **Go to heaven**: This aligns with the religious belief I mentioned earlier. After death, people believe they will go to heaven, where they will find their afterlife. This is a well-established belief among many people.

4. **Stop living**: This is similar to stopping movement. Living is a continuous process, and stopping living doesn't make sense after death.

5. **Stop breathing**: This is also about stopping a physical act, which isn't relevant after death.

So, putting it all together, the correct answer is "go to heaven" because it directly addresses the belief about afterlife and is a well-supported belief among many people.
</think>

The correct answer is:

**Go to heaven**

This is because the belief that after death, people will go to heaven is a well-established religious belief. It aligns with the concept of the afterlife, where individuals are believed to return to heaven after dying. The other options either lack specificity or are not relevant to the belief in afterlife.<｜end▁of▁sentence｜><｜begin▁of▁sentence｜>

Okay, so I need to figure out why the correct answer is "go to heaven." Let me break it down step
"""

In [None]:
oririnal = """
Okay, so I need to figure out if the answer provided is correct to the question "What had he just finished doing when he saw the tractor." The answer given is "Feeding the chickens and ducks." 

Let me break this down. The story starts with Joe going out to the field to feed the horses and cows. That's the first action he did. Then, after finishing, he saw the tractor. So, the tractor was in the field, and he noticed it. 

Joe went to feed the animals, saw the tractor, and then decided to climb on it. He did that pretending to be driving it, which probably made him feel good because he didn't hurt himself. Then, he got off the tractor and fell, causing him pain. His father came to pick him up and rode him on the tractor. 

So, the sequence is: feed the animals, see the tractor, climb on it, ride it, and then ride with him. The question is about what he did when he saw the tractor. The tractor was in the field, so he was feeding the animals, saw the tractor, and then decided to ride it. 

The answer given is feeding the chickens and ducks, which is correct because that's what he was doing before seeing the tractor. The tractor was just a distraction at that point. So, the answer is accurate to the question.
</think>

The answer provided is correct. When Joe saw the tractor, he was feeding the chickens and ducks. The tractor was a distraction at that moment, and the key action was feeding the animals.

**Answer:** Yes, the answer is correct.
"""
# COMPRESSED PART Okay, so I need to figure out if the answer provided is correct to the question "What had he just finished doing when he saw the tractor." The answer given is "Feeding the chickens and ducks."
compressed_50 = """
Okay, so I need to figure out if the answer provided is correct to the question "What had he just finished doing when he saw the tractor." The answer given is "Feeding the chickens and ducks." 

Let me break this down. The story starts with Joe going out to the field to feed the horses and cows. That's the first action he did. Then, after finishing, he saw the tractor. So, the tractor was in the field, and he noticed it. 

Joe is trying to figure out what he did last time he saw the tractor. The answer says he fed the chickens and ducks. That makes sense because that's what he was doing when he was feeding them. 

But wait, the question is asking what he "had just finished doing" when he saw the tractor. So, the last action he took before seeing the tractor was feeding the chickens and ducks. That's the correct answer because that's what he was doing when he was feeding them. 

I don't see any other action mentioned in the story that would lead to seeing the tractor. Joe was feeding the horses and cows, then saw the tractor, and then proceeded to feed them. So, the answer is indeed correct.
</think>

The answer is correct. Joe finished feeding the chickens and ducks when he saw the tractor. 

**Answer:** Feeding the chickens and ducks.
"""

In [17]:
from transformers import AutoTokenizer
from hidden_capacity_reasoning.utils import WINDOW_SIZE, VISION_START, VISION_END
import torch

tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda"

example = dataset["test"].to_list()[:5][0]

generated_tokens = tokenizer.apply_chat_template(
    [
        # {"role": "user", "content": "how many wings has a bird?"},
        {"role": "user", "content": example["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

start_embed = model.base_model.embed_pooler.model.get_input_embeddings()(
    torch.tensor([[VISION_START]], device="cuda")
)
end_embed = model.base_model.embed_pooler.model.get_input_embeddings()(
    torch.tensor([[VISION_END]], device="cuda")
)
generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
generated_embeds = model.get_input_embeddings()(generated_tokens)

windows_amount = 2
next_true_tokens = tokenizer.encode(
    example["answer"],
    add_special_tokens=False,
)[: WINDOW_SIZE * windows_amount]
next_true_tokens = torch.tensor(next_true_tokens, device="cuda").unsqueeze(0)

new_embeds_for_compression = (
    model.base_model.embed_pooler.model.get_input_embeddings()(next_true_tokens)
).to(torch.bfloat16)
new_embeds_for_compression = new_embeds_for_compression.reshape(
    windows_amount, WINDOW_SIZE, -1
)
compressed_part = model.base_model.embed_pooler(new_embeds_for_compression)
compressed_part = compressed_part.reshape(1, windows_amount, -1)
generated_embeds = torch.cat(
    [
        generated_embeds,
        start_embed,
        compressed_part,
        end_embed,
    ],
    dim=1,
)

In [9]:
compressed_part.shape

torch.Size([1, 1, 1536])

In [13]:
compressed_part.shape

torch.Size([2, 1, 1536])

In [16]:
compressed_part.reshape(1, windows_amount, -1).shape

torch.Size([1, 2, 1536])

In [None]:
compressed_part