In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map={"": 0},
    attn_implementation="sdpa",
)
# model = model.eval()
model.requires_grad_(False)

from datasets import load_dataset

dataset = load_dataset(
    # "dim/hendrycks_math_train_12k_DeepSeek-R1-Distill-Qwen-1.5B_max_len_4096"
    # "dim/hendrycks_math_test_500_DeepSeek-R1-Distill-Qwen-1.5B_max_len_4096_greedy"
    # "dim/hendrycks_math_train_1k_DeepSeek-R1-Distill-Qwen-1.5B_max_len_4096_greedy"
    "dim/hendrycks_math_test_500_DeepSeek-R1-Distill-Qwen-1.5B_max_len_4096_greedy"
)

dataset = dataset["train"].train_test_split(
    # test_size=250,
    test_size=350,
    # test_size=999,
    # test_size=1,
    seed=42,
)
dataset = dataset["test"].filter(lambda x: x["model_answer"].count("</think>") == 1)

from lm_eval.tasks.hendrycks_math.utils import strip_string, remove_boxed, is_equiv
from hidden_capacity_reasoning.evaluation.math_500.utils import (
    dataset_answer_filter,
    model_answer_filter,
)

correct_dataset = []

for pos, item in enumerate(dataset):
    try:
        answer = dataset_answer_filter(item["answer"])
        model_answer = model_answer_filter(item["model_answer"])
        # print(answer, model_answer)
        # break
        if is_equiv(answer, model_answer):
            correct_dataset.append(item)
    except:
        pass

print(len(dataset), len(correct_dataset), len(correct_dataset) / len(dataset))

correct_dataset = correct_dataset[:30]
len(correct_dataset)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
'NoneType' object has no attribute 'group'
224 202 0.9017857142857143


30

## test time train generation (single train)

In [28]:
import torch

from lm_eval.tasks.hendrycks_math.utils import strip_string, remove_boxed, is_equiv
from hidden_capacity_reasoning.evaluation.math_500.utils import (
    dataset_answer_filter,
    model_answer_filter,
)
from tqdm.notebook import tqdm
from tqdm import tqdm as text_tqdm
from hidden_capacity_reasoning.utils import tokenize_single_turn

torch.manual_seed(0)
base_prompt = open(
    "hidden_capacity_reasoning/evaluation/math_500/math_500_prompt"
).read()

max_new_tokens = 400
compression_tokens = 8 * 2

evaluation_dataset = []
correct_items = 0
model.generation_config.pad_token_id = tokenizer.pad_token_id

for dataset_pos in tqdm(range(len(correct_dataset))):
    tokenized_turn = tokenize_single_turn(
        question=base_prompt.format(question=correct_dataset[dataset_pos]["problem"]),
        answer=correct_dataset[dataset_pos]["model_answer"],
        tokenizer=tokenizer,
    )
    for key in tokenized_turn.keys():
        tokenized_turn[key] = torch.tensor(tokenized_turn[key])

    device = "cuda"

    content_compression_mask = tokenized_turn["content_compression_mask"]

    input_part_end = (content_compression_mask == 0).nonzero()[-3][0]
    # get only question part
    question_input_ids = (
        tokenized_turn["input_ids"][: int(input_part_end) + 1].unsqueeze(0).cuda()
    )
    # print(tokenizer.decode(question_input_ids[-1]))

    ########
    ######## generate first part of tokens
    ########
    with torch.no_grad():

        # input_ids = torch.tensor(question_input_ids).cuda()
        input_ids_embeds = model.get_input_embeddings()(question_input_ids)

        inputs_embeds = torch.cat(
            [
                input_ids_embeds,
            ],
            dim=1,
        )
        generated_ids_new = model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=torch.ones(
                inputs_embeds.shape[:2],
                device="cuda",
            ).long(),
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        # break
    generated_result = tokenizer.decode(generated_ids_new[-1])
    # print(generated_result)

    ########
    ######## get original language loss
    ########
    labels = torch.cat(
        [
            question_input_ids.cuda(),
            generated_ids_new.cuda(),
        ],
        dim=1,
    )
    # print(tokenizer.decode(labels[-1]))

    question_content_mask = content_compression_mask[: int(input_part_end) + 1].clone()
    question_content_mask[question_content_mask == 0] = 4
    question_content_mask[question_content_mask == 1] = 0
    question_content_mask[question_content_mask == 4] = 1
    train_content_mask_new = torch.cat(
        [
            question_content_mask,
            torch.ones(
                generated_ids_new.shape[1],
            ),
        ]
    ).long()
    # print(question_content_mask)

    generated_embeds = model.get_input_embeddings()(generated_ids_new)
    new_input_embeds = torch.cat(
        [
            input_ids_embeds,
            generated_embeds,
        ],
        dim=1,
    )
    labels[:, train_content_mask_new == 0] = -100

    with torch.no_grad():
        original_loss = model(
            inputs_embeds=new_input_embeds,
            labels=labels,
        ).loss
    print("original_loss", original_loss)
    ########
    ######## generate compress embeddings
    ########

    compression_tensor = torch.nn.Parameter(
        torch.rand_like(
            new_input_embeds[:, :compression_tokens, :],
        )
        * model.get_input_embeddings().weight.data.std(),
        requires_grad=True,
    )
    compressed_inputs_embeds = torch.cat(
        [
            input_ids_embeds.detach(),
            compression_tensor,
            generated_embeds[:, -(max_new_tokens // 2) :, :].detach(),
        ],
        dim=1,
    )
    question_labels = question_input_ids.clone()
    question_labels[0][question_content_mask == 0] = -100
    question_labels = question_labels.cuda()
    compressed_part = torch.ones(compression_tensor.shape[:2]).long().cuda()

    compressed_labels = torch.cat(
        [
            question_labels,
            compressed_part,
            generated_ids_new[:, -(max_new_tokens // 2) :],
        ],
        dim=-1,
    )

    ########
    ######## train
    ########
    epoch_amount = 100

    optimizer = torch.optim.Adam([compression_tensor], lr=0.1)
    acclumulation_steps = 1
    for epoch in range(epoch_amount):
        compressed_inputs_embeds = torch.cat(
            [
                input_ids_embeds.detach(),
                compression_tensor,
                generated_embeds[:, -(max_new_tokens // 2) :, :].detach(),
            ],
            dim=1,
        )
        compression_loss = model(
            inputs_embeds=compressed_inputs_embeds,
            labels=compressed_labels,
        ).loss
        compression_loss.backward()
        if (epoch + 1) % acclumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        # print(epoch, compression_loss)
        # if compression_loss.item() <= (original_loss.item() - 0.01):
        #     break
        if (compression_loss.item() - 0.02) <= original_loss.item():
            break
        # if compression_loss.item() <= original_loss.item():
        #     break
    print("compression_loss", compression_loss)
    ########
    ######## evaluate
    ########
    with torch.no_grad():

        compressed_inputs_embeds = torch.cat(
            [
                input_ids_embeds.detach(),
                # compression_tensor,
                torch.rand_like(compression_tensor).cuda(),
                generated_embeds[:, -(max_new_tokens // 2) :, :].detach(),
            ],
            dim=1,
        )
        generated_ids_compressed = model.generate(
            inputs_embeds=compressed_inputs_embeds,
            attention_mask=torch.ones(
                compressed_inputs_embeds.shape[:2],
                device="cuda",
            ).long(),
            max_new_tokens=4096,
            do_sample=False,
        )
        # break
    generated_result = tokenizer.decode(generated_ids_compressed[-1])
    gold_answer = correct_dataset[dataset_pos]["answer"]
    answer = dataset_answer_filter(gold_answer)
    model_answer = model_answer_filter(generated_result)
    if is_equiv(answer, model_answer):
        correct_items += 1
        print("CORRECT")
    else:
        print("WRONG", gold_answer)
        print(generated_result)
    compressed_total_len = (
        compression_tensor.shape[1]
        + generated_embeds[:, -(max_new_tokens // 2) :, :].shape[1]
        + generated_ids_compressed.shape[1]
    )
    original_total_len = len(
        tokenizer.encode(
            correct_dataset[dataset_pos]["model_answer"],
            add_special_tokens=False,
        )
    )
    print(
        f"сгенерированно={generated_ids_compressed.shape[1]}, вопрос+сжатые+сгенерированные={compressed_total_len} оригинальная генерация={original_total_len}"
    )
    evaluation_dataset.append(
        {
            "original_total_len": original_total_len,
            "compressed_total_len": compressed_total_len,
        }
    )
    # break

  0%|          | 0/30 [00:00<?, ?it/s]

original_loss tensor(0.2270, device='cuda:0')
compression_loss tensor(0.2568, device='cuda:0', grad_fn=<NllLossBackward0>)
CORRECT
сгенерированно=1737, вопрос+сжатые+сгенерированные=1953 оригинальная генерация=1959
original_loss tensor(0.2759, device='cuda:0')
compression_loss tensor(0.2916, device='cuda:0', grad_fn=<NllLossBackward0>)
CORRECT
сгенерированно=625, вопрос+сжатые+сгенерированные=841 оригинальная генерация=1125
original_loss tensor(0.2294, device='cuda:0')
compression_loss tensor(0.2493, device='cuda:0', grad_fn=<NllLossBackward0>)
CORRECT
сгенерированно=1173, вопрос+сжатые+сгенерированные=1389 оригинальная генерация=1548
original_loss tensor(0.2595, device='cuda:0')
compression_loss tensor(0.2987, device='cuda:0', grad_fn=<NllLossBackward0>)
CORRECT
сгенерированно=1459, вопрос+сжатые+сгенерированные=1675 оригинальная генерация=3960
original_loss tensor(0.2939, device='cuda:0')
compression_loss tensor(0.3127, device='cuda:0', grad_fn=<NllLossBackward0>)
CORRECT
сгенерирова

In [30]:
len(correct_dataset) / len(dataset), correct_items / len(dataset), correct_items / len(
    correct_dataset
)

(0.13392857142857142, 0.12946428571428573, 0.9666666666666667)

In [32]:
original_total_len = 0
compressed_total_len = 0
for item in evaluation_dataset:
    original_total_len += item["original_total_len"]
    compressed_total_len += item["compressed_total_len"]
original_total_len, compressed_total_len, compressed_total_len / original_total_len

(56056, 41238, 0.7356571999429142)

In [None]:
# (56056, 38735, 0.6910054231482803) - обучение с 8 токенами, до того как лосс не станет оригинальным, все ответы правильные (0.9666666666666667) просто ошибка парсинга
# (56056, 37788, 0.6741116026830313) - - обучение с 8 токенами, до того как лосс не станет оригинальным - 0.01, все ответы правильные (0.9666666666666667) просто ошибка парсинга
# (56056, 41238, 0.7356571999429142) - никакого обучения, просто вставка рандомных 8 векторов, точность 0.9666666666666667