In [1]:
import os

os.environ["WANDB_PROJECT"] = "hidden_capacity_reasoning"
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer, BitsAndBytesConfig
import torch
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
)

from datasets import load_dataset
from tqdm import tqdm
from hidden_capacity_reasoning.utils import (
    generate_train_examples,
    pad_train_examples,
    tokenize_single_turn,
)
from datasets import Dataset
import gc
import types

# need for auto SFTTrainer patch(possible increase speed)
from unsloth import is_bfloat16_supported
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from hidden_capacity_reasoning.utils import (
    EOS_TOKEN_ID,
    TEXT_TOKEN_ID,
    WINDOW_SIZE,
    # find_all_linear_names_v2,
)

import time
from datetime import datetime


from hidden_capacity_reasoning.models import (
    Qwen2ForCausalLMCompressionV1,
    Qwen2ModelEmbedPoolerV1,
    Qwen2ForCausalLMCompressionV2,
    Qwen2ModelEmbedPoolerV2,
)


# def main():
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# model = Qwen2ForCausalLMCompressionV1.from_pretrained(
model = Qwen2ForCausalLMCompressionV2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
    attn_implementation="flash_attention_2",
)
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.model.requires_grad_(False)

# temp_model = Qwen2ModelEmbedPoolerV1.from_pretrained(
temp_model = Qwen2ModelEmbedPoolerV2.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
print(model.embed_pooler.load_state_dict(temp_model.state_dict()))
temp_model = temp_model.cpu()
del temp_model
gc.collect()
torch.cuda.empty_cache()

dataset = load_dataset("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=10, seed=42)

# test pass
tokenize_single_turn(
    question=dataset["train"][0]["question"],
    answer=dataset["train"][0]["answer"],
    tokenizer=tokenizer,
)
train_examples = [
    tokenize_single_turn(tokenizer=tokenizer, **item)
    for item in tqdm(dataset["train"].to_list()[:3])
]

prepared_train_examples = []
for item in tqdm(train_examples):
    for example in generate_train_examples(
        dataset_batch=[item],
        window_size=WINDOW_SIZE,
    ):
        prepared_train_examples.append(example)

print(
    "max_len",
    max([len(item["original_tokens"]) for item in prepared_train_examples]),
)

new_dataset = Dataset.from_list(prepared_train_examples)
print(new_dataset)


def collate_fn(batch):
    padded_batch = pad_train_examples(
        train_examples=batch,
        tokenizer=tokenizer,
    )
    padded_batch = {
        "replaced_original_tokens": padded_batch["replaced_original_tokens"][
            "input_ids"
        ],
        "compressed_input_ids": padded_batch["compressed_input_ids"]["input_ids"],
        "original_tokens": padded_batch["original_tokens"]["input_ids"],
        "attention_mask": padded_batch["compressed_input_ids"]["attention_mask"],
        "labels": padded_batch["compressed_input_ids"]["input_ids"],
    }
    for key in padded_batch.keys():
        padded_batch[key] = torch.tensor(padded_batch[key])
    skip_ids = [TEXT_TOKEN_ID, EOS_TOKEN_ID]
    for skip_id in skip_ids:
        padded_batch["labels"][padded_batch["labels"] == skip_id] = -100
    # print(padded_batch)
    return padded_batch


def find_all_linear_names_v2(model):
    lora_module_names = set()
    target_modules = set(
        [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "embed_tokens",
        ]
    )
    for name, module in model.named_modules():
        if "embed_pooler" in name:
            names = name.split(".")[-1]
            if names in target_modules:
                lora_module_names.add(name)
        # if isinstance(module, torch.nn.Linear):
    return lora_module_names


peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=find_all_linear_names_v2(model=model),
    modules_to_save=["embed_pooler.model.embed_tokens"],
)

formatted_date = datetime.fromtimestamp(time.time()).strftime("%Y_%m_%d_%H_%M_%S_%f")
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=new_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    args=SFTConfig(
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,
        warmup_steps=5,
        # num_train_epochs=1,#90,  # Set this for 1 full training run.
        num_train_epochs=90,  # Set this for 1 full training run.
        # max_steps=10000,
        learning_rate=1e-4,
        bf16=model.dtype == torch.bfloat16,
        fp16=model.dtype == torch.float16,
        logging_steps=8,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=f"outputs/{formatted_date}",
        # report_to="wandb",
        report_to="none",
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},
        gradient_checkpointing=True,
        save_steps=10000,
        run_name=formatted_date,
    ),
)
trainer.train()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


Some weights of Qwen2ForCausalLMCompressionV2 were not initialized from the model checkpoint at deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B and are newly initialized: ['embed_pooler.model.embed_tokens.weight', 'embed_pooler.model.layers.0.input_layernorm.weight', 'embed_pooler.model.layers.0.mlp.down_proj.weight', 'embed_pooler.model.layers.0.mlp.gate_proj.weight', 'embed_pooler.model.layers.0.mlp.up_proj.weight', 'embed_pooler.model.layers.0.post_attention_layernorm.weight', 'embed_pooler.model.layers.0.self_attn.k_proj.bias', 'embed_pooler.model.layers.0.self_attn.k_proj.weight', 'embed_pooler.model.layers.0.self_attn.o_proj.weight', 'embed_pooler.model.layers.0.self_attn.q_proj.bias', 'embed_pooler.model.layers.0.self_attn.q_proj.weight', 'embed_pooler.model.layers.0.self_attn.v_proj.bias', 'embed_pooler.model.layers.0.self_attn.v_proj.weight', 'embed_pooler.model.layers.1.input_layernorm.weight', 'embed_pooler.model.layers.1.mlp.down_proj.weight', 'embed_pooler.model.layers.1.mlp.gat

<All keys matched successfully>


100%|██████████| 3/3 [00:00<00:00, 947.44it/s]
100%|██████████| 3/3 [00:00<00:00, 5245.07it/s]
  trainer = SFTTrainer(


max_len 575
Dataset({
    features: ['replaced_original_tokens', 'compressed_input_ids', 'original_tokens'],
    num_rows: 65
})
trainable params: 251,838,464 || all params: 3,572,640,768 || trainable%: 7.0491


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
8,1.9188
16,1.6788
24,1.3441
32,1.248
40,1.5295
48,1.8469
56,1.5402
64,2.0789
72,2.006
80,1.5561


TrainOutput(global_step=90, training_loss=1.7037616570790608, metrics={'train_runtime': 223.3652, 'train_samples_per_second': 26.19, 'train_steps_per_second': 0.403, 'total_flos': 0.0, 'train_loss': 1.7037616570790608})

In [None]:
trainer.model

In [5]:
trainer.model.base_model.embed_pooler

Qwen2ModelEmbedPoolerV2(
  (model): Qwen2Model(
    (embed_tokens): ModulesToSaveWrapper(
      (original_module): Embedding(151936, 1536)
      (modules_to_save): ModuleDict(
        (default): Embedding(151936, 1536)
      )
    )
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=1536, out_features=16, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=16, out_features=1536, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base

In [2]:
for name, param in trainer.model.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Gradient: {param.requires_grad}")

Layer: base_model.model.embed_pooler.model.embed_tokens.modules_to_save.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.q_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.q_proj.lora_B.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.k_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.k_proj.lora_B.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.v_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.v_proj.lora_B.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.o_proj.lora_A.default.weight, Requires Gradient: True
Layer: base_model.model.embed_pooler.model.layers.0.self_attn.o_proj.lora_B.

In [3]:
model = trainer.model

In [None]:
dataset

In [6]:
dataset = load_dataset("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")
dataset = dataset["train"]
dataset = dataset.train_test_split(test_size=10, seed=42)

In [3]:
model = trainer.model
# prompt = "how many wings has a bird?"
prompt = dataset["train"].to_list()[:5][0]["question"]
messages = [
    # {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

with torch.no_grad():
    # generated_ids = model.generate(
    #     model_inputs.input_ids,
    #     max_new_tokens=1,
    #     do_sample=False,
    # )
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=1000,
        do_sample=False,
    )
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<｜begin▁of▁sentence｜><｜User｜>Choose your answer from: (A). No (B). Yes
Given those answer options, answer the question: Question: what date did the american civil war start? Would "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg ." be a reasonable answer?
A:<｜Assistant｜><think>



'Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee\'s Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question "What date did the American Civil War start?" \n\nFirst, I should recall when the American Civil War actually began. I remember that the Civil War started in 1861 when states seceded from the Union to form the Confederate states. So, the start date is 1861.\n\nNow, looking at the statement, it mentions the Confederate incursion north in 1863. I know that the Union had a major battle in 1863 called the Battle of Gettysburg, which was a significant turning point. That battle was a major conflict between the Union and the Confederacy, and it\'s often cited as the turning point in the Civil War.\n\nSo, the statement is saying that the Confederate incursion ended at Gettysburg in 1863. That makes sense because Gettysburg was a key battle in the Civil War, and it\'s often associated with the end of the Confedera

### Embed Generation

In [5]:
generated_tokens = tokenizer.apply_chat_template(
    [
        # {"role": "user", "content": "how many wings has a bird?"},
        {"role": "user", "content": dataset["train"].to_list()[:5][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)
with torch.no_grad():
    generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
    generated_embeds = model.get_input_embeddings()(generated_tokens)
    max_steps = 400
    for _ in range(max_steps):
        logits = model(
            inputs_embeds=generated_embeds,
        ).logits
        top_token = logits.argmax(-1)[-1][-1]
        top_token_embed = model.get_input_embeddings()(top_token)
        # print(top)
        generated_tokens = torch.cat([generated_tokens, top_token.reshape(1, 1)], dim=1)
        generated_embeds = torch.cat(
            [generated_embeds, top_token_embed.reshape(1, 1, -1)], dim=1
        )
    # break
print(tokenizer.decode(generated_tokens[-1]))
# break
embeds_generation_tokens = generated_tokens[-1]

<｜begin▁of▁sentence｜><｜User｜>Choose your answer from: (A). No (B). Yes
Given those answer options, answer the question: Question: what date did the american civil war start? Would "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg ." be a reasonable answer?
A:<｜Assistant｜><think>
Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question: "What date did the American Civil War start?" 

First, I should recall when the American Civil War actually started. I remember that the Civil War began in 1861 when states seceded from the Union to form the Confederate states. So the start date is 1861.

Now, looking at the statement, it mentions the year 1863. That's two years after the Civil War started. I think the key here is to check if the statement is accurate regarding the end of the Confederate incursion and the Battle of Gettysburg.



In [6]:
print(tokenizer.decode(generated_tokens[-1]))

<｜begin▁of▁sentence｜><｜User｜>Choose your answer from: (A). No (B). Yes
Given those answer options, answer the question: Question: what date did the american civil war start? Would "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg ." be a reasonable answer?
A:<｜Assistant｜><think>
Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question: "What date did the American Civil War start?" 

First, I should recall when the American Civil War actually started. I remember that the Civil War began in 1861 when states seceded from the Union to form the Confederate states. So the start date is 1861.

Now, looking at the statement, it mentions the year 1863. That's two years after the Civil War started. I think the key here is to check if the statement is accurate regarding the end of the Confederate incursion and the Battle of Gettysburg.



In [None]:
# Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question: "What date did the American Civil War start?"
# Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question: "What date did the American Civil War start?"

In [7]:
from hidden_capacity_reasoning.utils import WINDOW_SIZE

model = trainer.model
generated_tokens = tokenizer.apply_chat_template(
    [
        # {"role": "user", "content": "how many wings has a bird?"},
        {"role": "user", "content": dataset["train"].to_list()[:5][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

with torch.no_grad():
    generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
    generated_embeds = model.get_input_embeddings()(generated_tokens)
    temp_gen_size = 0
    compression_started = False
    window_size = WINDOW_SIZE  # + 1
    new_tokens = 4
    max_steps = (new_tokens + window_size) * 5
    print("generated_embeds", generated_embeds.shape)
    for step in range(max_steps):
        if temp_gen_size == window_size + new_tokens:
            print(
                "TOKENS FOR EMDED",
                tokenizer.decode(
                    generated_tokens[:, -(window_size + new_tokens) :][:, :WINDOW_SIZE]
                    .cpu()
                    .tolist()[0]
                ),
            )
            # tokenizer.decode(generated_tokens[:, : -window_size ].cpu().tolist()[0])
            if hasattr(model.base_model, "embed_pooler"):
                new_embeds_for_compression = (
                    model.base_model.embed_pooler.model.get_input_embeddings()(
                        generated_tokens[:, -(window_size + new_tokens) :][
                            :, :WINDOW_SIZE
                        ]
                    )
                )
                compressed_part = model.base_model.embed_pooler(
                    new_embeds_for_compression
                )
            else:
                compressed_part = model.embed_pooler(new_embeds_for_compression)

            # generated_embeds = torch.cat([generated_embeds, compressed_part], dim=1)
            generated_embeds = torch.cat(
                [
                    generated_embeds[:, : -(window_size + new_tokens)],
                    compressed_part,
                    # torch.randn(1, 1, 1536, device="cuda"),
                    generated_embeds[:, -new_tokens:],
                ],
                dim=1,
            )
            temp_gen_size = 1

        logits = model(
            inputs_embeds=generated_embeds,
        ).logits
        top_token = logits.argmax(-1)[-1][-1]
        top_token_embed = model.get_input_embeddings()(top_token)
        # print(top)
        generated_tokens = torch.cat([generated_tokens, top_token.reshape(1, 1)], dim=1)

        generated_embeds = torch.cat(
            [generated_embeds, top_token_embed.reshape(1, 1, -1)], dim=1
        )
        print(temp_gen_size, tokenizer.decode(generated_tokens[-1]))

        temp_gen_size += 1

    # print(tokenizer.decode(generated_tokens[-1]))

# break

generated_embeds torch.Size([1, 74, 1536])
0 <｜begin▁of▁sentence｜><｜User｜>Choose your answer from: (A). No (B). Yes
Given those answer options, answer the question: Question: what date did the american civil war start? Would "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg ." be a reasonable answer?
A:<｜Assistant｜><think>
Okay
1 <｜begin▁of▁sentence｜><｜User｜>Choose your answer from: (A). No (B). Yes
Given those answer options, answer the question: Question: what date did the american civil war start? Would "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg ." be a reasonable answer?
A:<｜Assistant｜><think>
Okay,
2 <｜begin▁of▁sentence｜><｜User｜>Choose your answer from: (A). No (B). Yes
Given those answer options, answer the question: Question: what date did the american civil war start? Would "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg ." be a reasonable answer?
A:<｜Assis

In [59]:
compressed_part.shape

torch.Size([1, 1, 1536])

In [9]:
generated_embeds.shape

torch.Size([1, 99, 1536])

In [33]:
torch.randn(1, 1, 1536).shape

torch.Size([1, 1, 1536])

In [10]:
generated_tokens.shape

torch.Size([1, 194])

In [8]:
print(
    """Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question "What date did the American Civil War start?" """
)
print(tokenizer.decode(generated_tokens[-1]))

Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question "What date did the American Civil War start?" 
<｜begin▁of▁sentence｜><｜User｜>Choose your answer from: (A). No (B). Yes
Given those answer options, answer the question: Question: what date did the american civil war start? Would "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg ." be a reasonable answer?
A:<｜Assistant｜><think>
Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg in 1863. That's a bit confusing. Lee was a Confederate general, so "in 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg. That's correct. The Battle of Gettysburg. That's the correct answer. It's a well-known event in American history. The correct answer is Yes. The correct


In [None]:
# Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." is a reasonable answer to the question "What date did the American Civil War start?"
# Okay, so I need to figure out whether the statement "In 1863, Robert E. Lee's Confederate incursion north ended at the Battle of Gettysburg." as a possible answer to the question: What date did the American Civil War start?". Hmm, let me break this down.

In [None]:
from hidden_capacity_reasoning.utils import WINDOW_SIZE

model = trainer.model
generated_tokens = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": dataset["train"].to_list()[:5][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
generated_embeds = model.get_input_embeddings()(generated_tokens)
max_steps = 80
temp_gen_size = 0
window_size = WINDOW_SIZE

print("generated_embeds", generated_embeds.shape)
for step in range(max_steps):
    # Run model on current embeddings to get next token
    logits = model(inputs_embeds=generated_embeds).logits
    top_token = logits.argmax(-1)[:, -1]
    top_token_embed = model.get_input_embeddings()(top_token.unsqueeze(0))

    # Add new token
    generated_tokens = torch.cat([generated_tokens, top_token.unsqueeze(1)], dim=1)
    generated_embeds = torch.cat([generated_embeds, top_token_embed], dim=1)

    # Increment counter for compression window
    temp_gen_size += 1

    # Print current token
    print(f"Step {step}, Token {temp_gen_size}: {tokenizer.decode(top_token.item())}")

    # Apply compression when window is filled
    if temp_gen_size == window_size:
        print(
            "TOKENS FOR COMPRESSION:",
            tokenizer.decode(generated_tokens[:, -window_size:-1].cpu().tolist()[0]),
        )

        # Get embeddings for window to compress
        new_embeds_for_compression = model.get_input_embeddings()(
            generated_tokens[:, -window_size:-1]
        )

        # Apply compression
        if hasattr(model.base_model, "embed_pooler"):
            compressed_part = model.base_model.embed_pooler(new_embeds_for_compression)
        else:
            compressed_part = model.embed_pooler(new_embeds_for_compression)

        # Replace window with compressed representation, keeping the latest token
        generated_embeds = torch.cat(
            [
                generated_embeds[:, :-window_size],
                compressed_part,
                generated_embeds[:, -1:],
            ],
            dim=1,
        )

        # Reset counter but account for the latest token
        temp_gen_size = 1

print("Final output:", tokenizer.decode(generated_tokens[0].cpu().tolist()))

generated_embeds torch.Size([1, 74, 1536])
Step 0, Token 1: Okay
Step 1, Token 2: ,
Step 2, Token 3:  so
Step 3, Token 4:  I
Step 4, Token 5:  need
Step 5, Token 6:  to
Step 6, Token 7:  figure
Step 7, Token 8:  out
Step 8, Token 9:  whether
Step 9, Token 10:  the
TOKENS FOR COMPRESSION: Okay, so I need to figure out whether
Step 10, Token 2:  user
Step 11, Token 3:  is
Step 12, Token 4:  asking
Step 13, Token 5:  whether
Step 14, Token 6:  the
Step 15, Token 7:  answer
Step 16, Token 8:  "
Step 17, Token 9: In
Step 18, Token 10:  
TOKENS FOR COMPRESSION:  the user is asking whether the answer "In
Step 19, Token 2:  I
Step 20, Token 3:  need
Step 21, Token 4:  to
Step 22, Token 5:  determine
Step 23, Token 6:  whether
Step 24, Token 7:  the
Step 25, Token 8:  statement
Step 26, Token 9:  "
Step 27, Token 10: In
TOKENS FOR COMPRESSION:   I need to determine whether the statement "
Step 28, Token 2:  
Step 29, Token 3: 1
Step 30, Token 4: 8
Step 31, Token 5: 6
Step 32, Token 6: 3
Step 33

In [85]:
from hidden_capacity_reasoning.utils import WINDOW_SIZE
import torch

model = trainer.model
generated_tokens = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": dataset["train"].to_list()[:5][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
generated_embeds = model.get_input_embeddings()(generated_tokens)
max_steps = 100
temp_gen_size = 0
window_size = WINDOW_SIZE + 1  # +1 to account for keeping the latest token
total_tokens = []

print("generated_embeds", generated_embeds.shape)
for step in range(max_steps):
    # Run model forward pass
    with torch.no_grad():
        logits = model(inputs_embeds=generated_embeds).logits
    next_token = logits[:, -1, :].argmax(dim=-1)
    total_tokens.append(next_token.item())

    # Get embedding for the new token
    next_token_embed = model.get_input_embeddings()(next_token.unsqueeze(1))

    # Add new token to sequence
    generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(1)], dim=1)
    generated_embeds = torch.cat([generated_embeds, next_token_embed], dim=1)

    # Print current token
    token_text = tokenizer.decode(next_token.item())
    print(f"Step {step}, Token {temp_gen_size+1}: {token_text}")
    temp_gen_size += 1

    # Apply compression when window is full
    if temp_gen_size == window_size:
        # Show exactly WINDOW_SIZE tokens being compressed (not including the latest one)
        window_tokens = generated_tokens[:, -window_size:][:, :WINDOW_SIZE]
        window_text = tokenizer.decode(window_tokens.cpu().tolist()[0])
        print(f"TOKENS FOR COMPRESSION: {window_text}")

        # Get embeddings for compression window - exactly WINDOW_SIZE tokens
        # window_embeds = model.get_input_embeddings()(window_tokens)
        window_embeds = model.base_model.embed_pooler.get_input_embeddings()(
            window_tokens
        )

        # Apply compression
        if hasattr(model.base_model, "embed_pooler"):
            compressed_embed = model.base_model.embed_pooler(window_embeds)
        else:
            compressed_embed = model.embed_pooler(window_embeds)

        # Replace window with compressed representation, preserving latest token
        generated_embeds = torch.cat(
            [
                generated_embeds[:, :-window_size],  # Tokens before window
                compressed_embed,  # Compressed window (WINDOW_SIZE tokens)
                generated_embeds[:, -1:],  # Latest token
            ],
            dim=1,
        )

        # Reset counter but keep the latest token
        temp_gen_size = 1

    # Optional stopping condition
    if next_token.item() == tokenizer.eos_token_id:
        break

print("Full generated text:", tokenizer.decode(total_tokens))

generated_embeds torch.Size([1, 74, 1536])
Step 0, Token 1: Okay
Step 1, Token 2: ,
Step 2, Token 3:  so
Step 3, Token 4:  I
Step 4, Token 5:  need
Step 5, Token 6:  to
Step 6, Token 7:  figure
Step 7, Token 8:  out
Step 8, Token 9:  whether
Step 9, Token 10:  the
Step 10, Token 11:  statement
TOKENS FOR COMPRESSION: Okay, so I need to figure out whether the
Step 11, Token 2: ,
Step 12, Token 3: 我现在
Step 13, Token 4: 要
Step 14, Token 5: 解决
Step 15, Token 6: 的问题
Step 16, Token 7: 是
Step 17, Token 8: 判断
Step 18, Token 9: 给
Step 19, Token 10: 定
Step 20, Token 11: 的
TOKENS FOR COMPRESSION:  statement,我现在要解决的问题是判断给定
Step 21, Token 2: 用户
Step 22, Token 3: 问
Step 23, Token 4: 的是
Step 24, Token 5: ，
Step 25, Token 6: 给
Step 26, Token 7: 定
Step 27, Token 8: 的
Step 28, Token 9: 选项
Step 29, Token 10: 是
Step 30, Token 11: A
TOKENS FOR COMPRESSION: 的用户问的是，给定的选项是
Step 31, Token 2: 的问题
Step 32, Token 3: 是
Step 33, Token 4: 关于
Step 34, Token 5: 美国
Step 35, Token 6: 的
Step 36, Token 7: 独立
Step 37, Toke

In [None]:
from hidden_capacity_reasoning.utils import WINDOW_SIZE
import torch

model = trainer.model
generated_tokens = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": dataset["train"].to_list()[:5][0]["question"]},
    ],
    tokenize=True,
    add_generation_prompt=True,
)

generated_tokens = torch.tensor(generated_tokens).unsqueeze(0).cuda()
generated_embeds = model.get_input_embeddings()(generated_tokens)
max_steps = 100
temp_gen_size = 0
window_size = WINDOW_SIZE
total_tokens = []

print("generated_embeds", generated_embeds.shape)
for step in range(max_steps):
    # Run model forward pass
    with torch.no_grad():
        logits = model(inputs_embeds=generated_embeds).logits
    next_token = logits[:, -1, :].argmax(dim=-1)
    total_tokens.append(next_token.item())

    # Get embedding for the new token
    next_token_embed = model.get_input_embeddings()(next_token.unsqueeze(1))

    # Add new token to sequence
    generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(1)], dim=1)
    generated_embeds = torch.cat([generated_embeds, next_token_embed], dim=1)

    # Print current token
    token_text = tokenizer.decode(next_token.item())
    print(f"Step {step}, Token {temp_gen_size+1}: {token_text}")
    temp_gen_size += 1

    # Apply compression when window is full
    if temp_gen_size == window_size:
        # Get window tokens for compression (exactly WINDOW_SIZE tokens)
        window_tokens = generated_tokens[:, -window_size:]
        window_text = tokenizer.decode(window_tokens.cpu().tolist()[0])
        print(f"TOKENS FOR COMPRESSION: {window_text}")

        # Get embeddings for compression window
        window_embeds = model.get_input_embeddings()(window_tokens)

        # Apply compression
        if hasattr(model.base_model, "embed_pooler"):
            compressed_embed = model.base_model.embed_pooler(window_embeds)
        else:
            compressed_embed = model.embed_pooler(window_embeds)

        # Replace entire window with just the compressed representation
        generated_embeds = torch.cat(
            [
                generated_embeds[:, :-window_size],  # Tokens before window
                compressed_embed,  # Compressed window
            ],
            dim=1,
        )

        # Reset counter to 0 since we're not keeping any tokens separate
        temp_gen_size = 0

    # Optional stopping condition
    if next_token.item() == tokenizer.eos_token_id:
        break

print("Full generated text:", tokenizer.decode(total_tokens))

generated_embeds torch.Size([1, 74, 1536])
Step 0, Token 1: Okay
Step 1, Token 2: ,
Step 2, Token 3:  so
Step 3, Token 4:  I
Step 4, Token 5:  need
Step 5, Token 6:  to
Step 6, Token 7:  figure
Step 7, Token 8:  out
Step 8, Token 9:  whether
Step 9, Token 10:  the
TOKENS FOR COMPRESSION: Okay, so I need to figure out whether the
Step 10, Token 1: ,
Step 11, Token 2:  �
Step 12, Token 3: �
Step 13, Token 4: 需要
Step 14, Token 5: 判断
Step 15, Token 6: 给
Step 16, Token 7: 定
Step 17, Token 8: 的
Step 18, Token 9: 陈述
Step 19, Token 10: 是否
TOKENS FOR COMPRESSION: , 我需要判断给定的陈述是否
Step 20, Token 1: ，
Step 21, Token 2: 我
Step 22, Token 3: 需要
Step 23, Token 4: 判断
Step 24, Token 5: 给
Step 25, Token 6: 定
Step 26, Token 7: 的
Step 27, Token 8: 陈述
Step 28, Token 9: 是否
Step 29, Token 10: 是
TOKENS FOR COMPRESSION: ，我需要判断给定的陈述是否是
Step 30, Token 1: 问题
Step 31, Token 2: 是要
Step 32, Token 3: 判断
Step 33, Token 4: 给
Step 34, Token 5: 定
Step 35, Token 6: 的
Step 36, Token 7: 陈述
Step 37, Token 8: 是否
Step 38, Token 