In [1]:
import os

os.environ["WANDB_PROJECT"] = "hidden_capacity_reasoning"
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer, BitsAndBytesConfig
import torch
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
)


class Qwen2ModelEmbedPoolerV1(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.model.embed_tokens = None
        self.lm_head = None
        self.post_init()

    def forward(self, input_embeds):
        # print(input_embeds.dtype)
        input_embeds = self.model(
            inputs_embeds=input_embeds,
            output_hidden_states=True,
        )[0]
        # print(input_embeds.dtype)
        input_embeds = input_embeds.sum(1) / torch.tensor(
            input_embeds.shape[1],
            device=input_embeds.device,
            dtype=input_embeds.dtype,
        )
        # print(input_embeds.dtype)
        input_embeds = input_embeds.unsqueeze(1)
        return input_embeds


class Qwen2ForCausalLMCompressionV1(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )
        # print(config._name_or_path)
        self.embed_pooler = Qwen2ModelEmbedPoolerV1.from_pretrained(
            config._name_or_path,
            config=config,
            attn_implementation="flash_attention_2",
        )

        self.post_init()
        # Initialize weights and apply final processing

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        logits_to_keep=0,
        **kwargs
    ):
        if "replaced_original_tokens" in kwargs:
            pass
        return super().forward(
            input_ids,
            attention_mask,
            position_ids,
            past_key_values,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            cache_position,
            logits_to_keep,
            **kwargs
        )


# torch.set_grad_enabled(False)
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# model_name = "./test_model/"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = Qwen2ForCausalLMCompressionV1.from_pretrained(
    model_name,
    # torch_dtype=torch.float16,
    torch_dtype=torch.bfloat16,
    # quantization_config=bnb_config,
    device_map=get_kbit_device_map(),
    attn_implementation="flash_attention_2",
)

# model = model.eval().cuda()
# model.model = model.embed_pooler.model
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.model.requires_grad_(False)

prompt = "how many wings has a bird?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

with torch.no_grad():
    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=1)
    # generated_ids = model.generate(
    #     model_inputs.input_ids,
    #     max_new_tokens=1000,
    #     do_sample=False,
    # )
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

Some weights of Qwen2ForCausalLMCompressionV1 were not initialized from the model checkpoint at deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B and are newly initialized: ['embed_pooler.model.layers.0.input_layernorm.weight', 'embed_pooler.model.layers.0.mlp.down_proj.weight', 'embed_pooler.model.layers.0.mlp.gate_proj.weight', 'embed_pooler.model.layers.0.mlp.up_proj.weight', 'embed_pooler.model.layers.0.post_attention_layernorm.weight', 'embed_pooler.model.layers.0.self_attn.k_proj.bias', 'embed_pooler.model.layers.0.self_attn.k_proj.weight', 'embed_pooler.model.layers.0.self_attn.o_proj.weight', 'embed_pooler.model.layers.0.self_attn.q_proj.bias', 'embed_pooler.model.layers.0.self_attn.q_proj.weight', 'embed_pooler.model.layers.0.self_attn.v_proj.bias', 'embed_pooler.model.layers.0.self_attn.v_proj.weight', 'embed_pooler.model.layers.1.input_layernorm.weight', 'embed_pooler.model.layers.1.mlp.down_proj.weight', 'embed_pooler.model.layers.1.mlp.gate_proj.weight', 'embed_pooler.model.layers

<｜begin▁of▁sentence｜>You are a helpful assistant.<｜User｜>how many wings has a bird?<｜Assistant｜><think>



'Okay'

In [2]:
model.embed_pooler

Qwen2ModelEmbedPoolerV1(
  (model): Qwen2Model(
    (embed_tokens): None
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmb

In [None]:
import gc

temp_model = Qwen2ModelEmbedPoolerV1.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
print(model.embed_pooler.load_state_dict(temp_model.state_dict()))
temp_model = temp_model.cpu()
del temp_model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# если у модели есть ембединги они должны быть равны(по умолчанию я выставил None для экономии)
# несмотря на то что нам пишут что веса были newly initialized
# model.model.embed_tokens.weight == model.embed_pooler.model.embed_tokens.weight

In [3]:
from datasets import load_dataset
from tqdm import tqdm
from hidden_capacity_reasoning.utils import (
    generate_train_examples,
    pad_train_examples,
    tokenize_single_turn,
)

dataset = load_dataset("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=5, seed=42)
dataset

# test
tokenize_single_turn(
    question=dataset["train"][0]["question"],
    answer=dataset["train"][0]["answer"],
    tokenizer=tokenizer,
)
train_examples = [
    tokenize_single_turn(tokenizer=tokenizer, **item)
    for item in tqdm(dataset["train"].to_list())
]


prepared_train_examples = []
for item in tqdm(train_examples):
    for example in generate_train_examples(
        dataset_batch=[item], tokenizer=tokenizer, window_size=4
    ):
        prepared_train_examples.append(example)

print(
    "max_len", max([len(item["original_tokens"]) for item in prepared_train_examples])
)
from datasets import Dataset

new_dataset = Dataset.from_list(prepared_train_examples)
new_dataset

100%|██████████| 900/900 [00:01<00:00, 631.96it/s]
100%|██████████| 900/900 [00:04<00:00, 186.62it/s]


Dataset({
    features: ['replaced_original_tokens', 'compressed_input_ids', 'original_tokens'],
    num_rows: 124217
})

In [3]:
# from more_itertools import chunked

# batch_size = 4
# train_examples_batches = [
#     pad_train_examples(
#         train_examples=item,
#         tokenizer=tokenizer,
#     )
#     for item in tqdm(
#         list(
#             chunked(
#                 prepared_train_examples,
#                 batch_size,
#             )
#         )
#     )
# ]

In [4]:
# train_examples_batches_clean = []
# for item in train_examples_batches:
#     train_examples_batches_clean.append(
#         {
#             "replaced_original_tokens": item["replaced_original_tokens"]["input_ids"],
#             "compressed_input_ids": item["compressed_input_ids"]["input_ids"],
#             "original_tokens": item["original_tokens"]["input_ids"],
#             "attention_mask": item["compressed_input_ids"]["attention_mask"],
#             "labels": item["compressed_input_ids"]["input_ids"],
#         }
#     )
# print(len(train_examples_batches_clean))
# # train_examples_batches_clean[0]

### Test one Pass (for debug only)

In [4]:
from hidden_capacity_reasoning.utils import EOS_TOKEN_ID, TEXT_TOKEN_ID, WINDOW_SIZE

kwargs = {}
for k in prepared_train_examples[0].keys():
    kwargs[k] = torch.tensor(
        prepared_train_examples[0][k],
        device="cuda",
    ).unsqueeze(0)

# kwargs

original_tokens_torch = kwargs["original_tokens"].to(model.device)
replaced_tokens_torch = kwargs["replaced_original_tokens"].to(model.device)
compressed_tokens_torch = kwargs["compressed_input_ids"].to(model.device)

original_embeds = model.get_input_embeddings()(original_tokens_torch)
# replaced_embeds = self.model.get_input_embeddings()(replaced_tokens_torch)
compressed_embeds_template = model.get_input_embeddings()(compressed_tokens_torch)

tokens_for_compression_mask = replaced_tokens_torch == TEXT_TOKEN_ID
compressed_tokens_mask = compressed_tokens_torch == TEXT_TOKEN_ID
embeds_for_compression = original_embeds[tokens_for_compression_mask].reshape(
    -1,
    WINDOW_SIZE,
    original_embeds.shape[-1],
)
pooled_embeds = model.embed_pooler(embeds_for_compression)
pooled_embeds = pooled_embeds.to(compressed_embeds_template.dtype)
compressed_embeds_template = compressed_embeds_template.masked_scatter(
    compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
    pooled_embeds,
)
pooled_embeds

tensor([[[-0.1230,  1.2969, -0.4004,  ..., -1.5781,  1.8594,  1.7734]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsqueezeBackward0>)

In [5]:
# model.get_input_embeddings().requires_grad_(True)

In [36]:
import time
from datetime import datetime

formatted_date = datetime.fromtimestamp(time.time()).strftime("%Y_%m_%d_%H_%M_%S_%f")
formatted_date

'2025_03_13_14_23_53_931398'

### Use TRL

In [None]:
# https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py
from trl import SFTTrainer, SFTConfig
from transformers import TrainingArguments, DataCollatorForSeq2Seq

from unsloth import is_bfloat16_supported
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from hidden_capacity_reasoning.utils import EOS_TOKEN_ID, TEXT_TOKEN_ID, WINDOW_SIZE


def find_all_linear_names_v2(model):
    lora_module_names = set()
    target_modules = set(
        [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
        ]
    )
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if "embed_pooler" in name:
                names = name.split(".")[-1]
                if names in target_modules:
                    lora_module_names.add(name)
    return lora_module_names


def collate_fn(batch):
    padded_batch = pad_train_examples(
        train_examples=batch,
        tokenizer=tokenizer,
    )
    padded_batch = {
        "replaced_original_tokens": padded_batch["replaced_original_tokens"][
            "input_ids"
        ],
        "compressed_input_ids": padded_batch["compressed_input_ids"]["input_ids"],
        "original_tokens": padded_batch["original_tokens"]["input_ids"],
        "attention_mask": padded_batch["compressed_input_ids"]["attention_mask"],
        "labels": padded_batch["compressed_input_ids"]["input_ids"],
    }
    for key in padded_batch.keys():
        padded_batch[key] = torch.tensor(padded_batch[key])
    skip_ids = [TEXT_TOKEN_ID, EOS_TOKEN_ID]
    for skip_id in skip_ids:
        padded_batch["labels"][padded_batch["labels"] == skip_id] = -100
    # print(padded_batch)
    return padded_batch


import types


def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    cache_position=None,
    logits_to_keep=0,
    **kwargs,
):
    if "replaced_original_tokens" in kwargs:
        original_tokens_torch = kwargs["original_tokens"].to(self.model.device)
        replaced_tokens_torch = kwargs["replaced_original_tokens"].to(self.model.device)
        compressed_tokens_torch = kwargs["compressed_input_ids"].to(self.model.device)

        original_embeds = self.model.get_input_embeddings()(original_tokens_torch)
        # replaced_embeds = self.model.get_input_embeddings()(replaced_tokens_torch)
        compressed_embeds_template = self.model.get_input_embeddings()(
            compressed_tokens_torch
        )

        tokens_for_compression_mask = replaced_tokens_torch == TEXT_TOKEN_ID
        compressed_tokens_mask = compressed_tokens_torch == TEXT_TOKEN_ID
        embeds_for_compression = original_embeds[tokens_for_compression_mask].reshape(
            -1,
            WINDOW_SIZE,
            original_embeds.shape[-1],
        )
        pooled_embeds = self.embed_pooler(embeds_for_compression)
        pooled_embeds = pooled_embeds.to(compressed_embeds_template.dtype)
        compressed_embeds_template = compressed_embeds_template.masked_scatter(
            compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
            pooled_embeds,
        )
        inputs_embeds = compressed_embeds_template

    return super(type(self), self).forward(
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        inputs_embeds,
        labels,
        use_cache,
        output_attentions,
        output_hidden_states,
        return_dict,
        cache_position,
        logits_to_keep,
        **kwargs,
    )


# model.forward = forward
model.forward = types.MethodType(forward, model)


peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=find_all_linear_names_v2(model=model),
)
import time
from datetime import datetime

formatted_date = datetime.fromtimestamp(time.time()).strftime("%Y_%m_%d_%H_%M_%S_%f")

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=new_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        warmup_steps=5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps=10000,
        learning_rate=2e-4,
        # fp16=not is_bfloat16_supported(),
        # bf16=is_bfloat16_supported(),
        bf16=model.dtype == torch.bfloat16,
        fp16=model.dtype == torch.float16,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=f"outputs/{formatted_date}",
        # report_to="none",  # Use this for WandB etc
        report_to="wandb",  # Use this for WandB etc
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},
        gradient_checkpointing=True,
    ),
)
trainer.train()

In [7]:
model.dtype

torch.float16

In [11]:
for name, param in trainer.model.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Gradient: {param.requires_grad}")

Layer: embed_pooler.model.layers.0.self_attn.q_proj.lora_A.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.0.self_attn.q_proj.lora_B.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.0.self_attn.k_proj.lora_A.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.0.self_attn.k_proj.lora_B.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.0.self_attn.v_proj.lora_A.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.0.self_attn.v_proj.lora_B.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.1.self_attn.q_proj.lora_A.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.1.self_attn.q_proj.lora_B.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.1.self_attn.k_proj.lora_A.default.weight, Requires Gradient: True
Layer: embed_pooler.model.layers.1.self_attn.k_proj.lora_B.default.weight, Requires Gradient: True
Layer: emb

### Inference

In [1]:
from hidden_capacity_reasoning.models import Qwen2ForCausalLMCompressionV1
from transformers import AutoTokenizer
import torch

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = Qwen2ForCausalLMCompressionV1.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.model.requires_grad_(False)
None

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


Some weights of Qwen2ForCausalLMCompressionV1 were not initialized from the model checkpoint at deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B and are newly initialized: ['embed_pooler.model.layers.0.input_layernorm.weight', 'embed_pooler.model.layers.0.mlp.down_proj.weight', 'embed_pooler.model.layers.0.mlp.gate_proj.weight', 'embed_pooler.model.layers.0.mlp.up_proj.weight', 'embed_pooler.model.layers.0.post_attention_layernorm.weight', 'embed_pooler.model.layers.0.self_attn.k_proj.bias', 'embed_pooler.model.layers.0.self_attn.k_proj.weight', 'embed_pooler.model.layers.0.self_attn.o_proj.weight', 'embed_pooler.model.layers.0.self_attn.q_proj.bias', 'embed_pooler.model.layers.0.self_attn.q_proj.weight', 'embed_pooler.model.layers.0.self_attn.v_proj.bias', 'embed_pooler.model.layers.0.self_attn.v_proj.weight', 'embed_pooler.model.layers.1.input_layernorm.weight', 'embed_pooler.model.layers.1.mlp.down_proj.weight', 'embed_pooler.model.layers.1.mlp.gate_proj.weight', 'embed_pooler.model.layers

In [2]:
from peft import PeftModel

lora_name = "outputs/2025_03_13_17_26_24_256272/checkpoint-95800"
model = PeftModel.from_pretrained(model, lora_name)

In [None]:
model

In [4]:
prompt = "how many wings has a bird?"
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

with torch.no_grad():
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=1,
        do_sample=False,
    )
    # generated_ids = model.generate(
    #     model_inputs.input_ids,
    #     max_new_tokens=1000,
    #     do_sample=False,
    # )
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<｜begin▁of▁sentence｜>You are a helpful assistant.<｜User｜>how many wings has a bird?<｜Assistant｜><think>



'Okay'

### default embed generation

In [5]:
generated_tokens = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "2+2*10"},
    ],
    tokenize=True,
    add_generation_prompt=True,
)
generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
generated_embeds = model.get_input_embeddings()(generated_tokens)
max_steps = 400
for _ in range(max_steps):
    logits = model(
        inputs_embeds=generated_embeds,
    ).logits
    top_token = logits.argmax(-1)[-1][-1]
    top_token_embed = model.get_input_embeddings()(top_token)
    # print(top)
    generated_tokens = torch.cat([generated_tokens, top_token.reshape(1, 1)], dim=1)
    generated_embeds = torch.cat(
        [generated_embeds, top_token_embed.reshape(1, 1, -1)], dim=1
    )
    # break
# print(tokenizer.decode(generated_tokens[-1]))
# break
embeds_generation_tokens = generated_tokens[-1]

#### generation with compression

In [6]:
model.base_model.embed_pooler  # .embed_pooler_v3

Qwen2ModelEmbedPoolerV1(
  (model): Qwen2Model(
    (embed_tokens): None
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=1536, out_features=16, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=16, out_features=1536, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=1536, out_features=256, bias=True)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
       

In [7]:
hasattr(model.base_model, "embed_pooler")

True

In [15]:
from datasets import load_dataset

dataset = load_dataset("dim/open_orca_4475_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=1000, seed=42)

In [None]:
dataset["train"][0]

{'question': 'Question: Title: Rajah Medium Curry Powder Product review: The fulfillment was efficient and prompt. Having curry powder in a tin is far better than having a packet in a box. The 2 Pack is great as the sealed content remains fresh. This company has a great range of products and I look forward to supporting them in the future! Would you say this review depicts the product in a flattering or unflattering light?\nAnswer:',
 'answer': "Okay, so I need to figure out whether the review is flattering or unflattering about the Rajah Medium Curry Powder product. Let me start by reading the review again carefully.\n\nThe user says the product's fulfillment was efficient and prompt. That's a positive point. Then they mention having curry powder in a tin is better than having it in a packet in a box. Hmm, that's interesting. I wonder why they think a tin is better. Maybe because it's more convenient or easier to use? I'm not sure if that's a positive or negative thing.\n\nNext, they 

In [17]:
from hidden_capacity_reasoning.utils import WINDOW_SIZE

generated_tokens = tokenizer.apply_chat_template(
    [
        # {"role": "user", "content": "2+2*10"},
        # {"role": "user", "content": "how many wings has a bird?"},
        {"role": "user", "content": dataset["train"][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
generated_embeds = model.get_input_embeddings()(generated_tokens)
max_steps = 50
temp_gen_size = 0
compression_started = False
window_size = WINDOW_SIZE 

for _ in range(max_steps):
    if temp_gen_size == window_size:
        new_embeds_for_compression = generated_embeds[:, -window_size:]
        if hasattr(model.base_model, "embed_pooler"):
            compressed_part = model.base_model.embed_pooler(new_embeds_for_compression)
        else:
            compressed_part = model.embed_pooler(new_embeds_for_compression)

        generated_embeds = torch.cat([generated_embeds, compressed_part], dim=1)
        temp_gen_size = 1

    logits = model(
        inputs_embeds=generated_embeds,
    ).logits
    top_token = logits.argmax(-1)[-1][-1]
    top_token_embed = model.get_input_embeddings()(top_token)
    # print(top)
    generated_tokens = torch.cat([generated_tokens, top_token.reshape(1, 1)], dim=1)

    generated_embeds = torch.cat(
        [generated_embeds, top_token_embed.reshape(1, 1, -1)], dim=1
    )
    print(temp_gen_size, tokenizer.decode(generated_tokens[-1]))
    temp_gen_size += 1

# print(tokenizer.decode(generated_tokens[-1]))

# break

0 <｜begin▁of▁sentence｜><｜User｜>Question: Title: Rajah Medium Curry Powder Product review: The fulfillment was efficient and prompt. Having curry powder in a tin is far better than having a packet in a box. The 2 Pack is great as the sealed content remains fresh. This company has a great range of products and I look forward to supporting them in the future! Would you say this review depicts the product in a flattering or unflattering light?
Answer:<｜Assistant｜><think>
Okay
1 <｜begin▁of▁sentence｜><｜User｜>Question: Title: Rajah Medium Curry Powder Product review: The fulfillment was efficient and prompt. Having curry powder in a tin is far better than having a packet in a box. The 2 Pack is great as the sealed content remains fresh. This company has a great range of products and I look forward to supporting them in the future! Would you say this review depicts the product in a flattering or unflattering light?
Answer:<｜Assistant｜><think>
Okay,
2 <｜begin▁of▁sentence｜><｜User｜>Question: Titl