In [1]:
from hidden_capacity_reasoning.utils import tokenize_single_turn
from transformers import Qwen2ForCausalLM, Qwen2Model, AutoTokenizer, BitsAndBytesConfig
import torch
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
)


class Qwen2ModelEmbedPoolerV1(Qwen2ForCausalLM):
    """Qwen2 backbone used as a pooler over ready-made embeddings.

    ``forward`` runs the backbone on the given embeddings and averages the
    final hidden states over dim 1. The token embedding table and the LM head
    are set to ``None`` because this class never looks up token ids and never
    produces logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        # The pooler is fed embeddings directly, so these modules are dropped
        # to save memory.
        self.model.embed_tokens = None
        self.lm_head = None
        self.post_init()

    def forward(self, input_embeds):
        """Encode ``input_embeds`` with the backbone and mean-pool over dim 1.

        Args:
            input_embeds: Tensor handed to the backbone as ``inputs_embeds``
                (presumably ``(num_windows, window_size, hidden_size)`` —
                TODO confirm against the caller).

        Returns:
            The backbone's first output (final hidden states) averaged over
            dim 1, with that dim kept as size 1.
        """
        # Only the final hidden states are pooled, so per-layer hidden states
        # are not requested from the backbone.
        last_hidden_state = self.model(inputs_embeds=input_embeds)[0]
        # keepdim=True keeps the pooled axis as size 1, matching the previous
        # sum / length + unsqueeze(1) formulation.
        return last_hidden_state.mean(dim=1, keepdim=True)


class Qwen2ForCausalLMCompressionV1(Qwen2ForCausalLM):
    """Qwen2 causal LM that additionally owns an ``embed_pooler`` sub-model.

    The pooler (``Qwen2ModelEmbedPoolerV1``) is loaded with ``from_pretrained``
    from the same path the config was loaded from (``config._name_or_path``).
    ``forward`` itself is a plain pass-through to ``Qwen2ForCausalLM.forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): the parent constructor presumably already builds
        # `model` and `lm_head`; re-creating them here looks redundant —
        # confirm before removing.
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )
        # print(config._name_or_path)
        self.embed_pooler = Qwen2ModelEmbedPoolerV1.from_pretrained(
            config._name_or_path,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        logits_to_keep=0,
        **kwargs
    ):
        """Delegate to ``Qwen2ForCausalLM.forward`` without modification.

        All named arguments and any extra ``kwargs`` are handed to the parent
        implementation unchanged, and its result is returned as is.
        """
        # Forward by keyword so the call does not depend on the positional
        # order of the parent's signature.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs
        )


# Checkpoint to load (a locally saved copy can be swapped in instead).
# torch.set_grad_enabled(False)
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# model_name = "./test_model/"

# 4-bit NF4 quantization with double quantization and bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)
model = Qwen2ForCausalLMCompressionV1.from_pretrained(
    model_name,
    device_map=get_kbit_device_map(),
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)

# model = model.eval().cuda()
# model.model = model.embed_pooler.model
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build a single-turn chat prompt and render it with the tokenizer's template.
prompt = "how many wings has a bird?"
messages = [
    dict(role="system", content="You are a helpful assistant."),
    dict(role="user", content=prompt),
]
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
print(text)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

# Sanity check: greedy decoding (do_sample=False) through the wrapped model.
with torch.no_grad():
    # generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=1)
    generated_ids = model.generate(
        model_inputs.input_ids,
        do_sample=False,
        max_new_tokens=1000,
    )

# Strip the prompt tokens so that only the generated continuation is decoded.
prompt_length = model_inputs.input_ids.shape[1]
generated_ids = [output_ids[prompt_length:] for output_ids in generated_ids]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
response

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Some weights of Qwen2ForCausalLMCompressionV1 were not initialized from the model checkpoint at deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B and are newly initialized: ['embed_pooler.model.layers.0.input_layernorm.weight', 'embed_pooler.model.layers.0.mlp.down_proj.weight', 'embed_pooler.model.layers.0.mlp.gate_proj.weight', 'embed_pooler.model.layers.0.mlp.up_proj.weight', 'embed_pooler.model.layers.0.post_attention_layernorm.weight', 'embed_pooler.model.layers.0.self_attn.k_proj.bias', 'embed_pooler.model.layers.0.self_attn.k_proj.weight', 'embed_pooler.model.layers.0.self_attn.o_proj.weight', 'embed_pooler.model.layers.0.self_attn.q_proj.bias', 'embed_pooler.model.layers.0.self_attn.q_proj.weight', 'embed_pooler.model.layers.0.self_attn.v_proj.bias', 'embed_pooler.model.layers.0.self_attn.v_proj.weight', 'embed_pooler.model.layers.1.input_layernorm.weight', 'embed_pooler.model.lay

<｜begin▁of▁sentence｜>You are a helpful assistant.<｜User｜>how many wings has a bird?<｜Assistant｜><think>



'Okay, so I need to figure out how many wings a bird has. Let me start by recalling what I know about birds. I know that most birds have wings, but I\'m not exactly sure how many. Maybe I should think about different types of birds and their wing counts.\n\nFirst, let\'s consider common birds. The most common bird I can think of is the sparrow. I remember seeing sparrow wings, and I think they have two wings. So, maybe sparrow has two wings. But wait, I\'m not entirely sure. I should check if there are any exceptions or special cases.\n\nThen there are other birds like the ostrich. Wait, no, ostrich is a bird, but I think it\'s actually a mammal. No, wait, no, I\'m confusing. Ostrich is a type of bird, right? No, wait, no, I think it\'s a type of bird. Wait, no, I\'m getting confused. Let me think again. Ostrich is a type of bird, but I think it\'s actually a type of bird called a "flying bird." Wait, no, I\'m mixing up. Let me clarify. Ostrich is a type of bird, but I think it\'s actu

In [2]:
# Display the model's repr to inspect its module structure.
model

In [None]:
# If the pooler keeps its own embedding table, it has to be equal to the base
# model's table, even though the load log reports the pooler weights as
# "newly initialized".
# By default the pooler's embed_tokens is set to None to save memory; in that
# case there is nothing to compare and this cell evaluates to None.
pooler_embed_tokens = model.embed_pooler.model.embed_tokens
(
    # Reduce the element-wise comparison to a single bool instead of
    # displaying the whole comparison tensor.
    bool((model.model.embed_tokens.weight == pooler_embed_tokens.weight).all())
    if pooler_embed_tokens is not None
    else None
)

In [4]:
# tokenizer.save_pretrained("./test_model")

In [5]:
# model.save_pretrained("./test_model")

In [None]:
# Display the model's repr to inspect its module structure.
model

In [4]:
# model.save_pretrained('test_model')

In [2]:
from datasets import load_dataset
from tqdm import tqdm
from hidden_capacity_reasoning.utils import generate_train_examples, pad_train_examples

# Load the "train" split and carve out a 5-row test split (fixed seed).
dataset = load_dataset("dim/open_orca_905_DeepSeek-R1-Distill-Qwen-1.5B")[
    "train"
].train_test_split(test_size=5, seed=42)

# Smoke test: tokenize the first training row; the result is discarded.
first_row = dataset["train"][0]
tokenize_single_turn(
    question=first_row["question"],
    answer=first_row["answer"],
    tokenizer=tokenizer,
)

# Tokenize every training row; each row dict is expanded into keyword arguments.
train_examples = []
for row in tqdm(dataset["train"].to_list()):
    train_examples.append(tokenize_single_turn(tokenizer=tokenizer, **row))

# Expand each tokenized row into its training examples, one row per call.
# NOTE(review): the training forward reshapes windows with WINDOW_SIZE from
# hidden_capacity_reasoning.utils — confirm it equals the literal 4 used here.
prepared_train_examples = []
for item in tqdm(train_examples):
    prepared_train_examples.extend(
        generate_train_examples(
            dataset_batch=[item],
            tokenizer=tokenizer,
            window_size=4,
        )
    )

100%|██████████| 900/900 [00:01<00:00, 645.04it/s]
100%|██████████| 900/900 [00:05<00:00, 164.67it/s]


In [3]:
# Total number of prepared training examples.
len(prepared_train_examples)

124217

In [None]:
# from more_itertools import chunked

# batch_size = 4
# train_examples_batches = [
#     pad_train_examples(
#         train_examples=item,
#         tokenizer=tokenizer,
#     )
#     for item in tqdm(
#         list(
#             chunked(
#                 prepared_train_examples,
#                 batch_size,
#             )
#         )
#     )
# ]

  0%|          | 0/31055 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 31055/31055 [00:06<00:00, 4568.52it/s]


In [None]:
# train_examples_batches_clean = []
# for item in train_examples_batches:
#     train_examples_batches_clean.append(
#         {
#             "replaced_original_tokens": item["replaced_original_tokens"]["input_ids"],
#             "compressed_input_ids": item["compressed_input_ids"]["input_ids"],
#             "original_tokens": item["original_tokens"]["input_ids"],
#             "attention_mask": item["compressed_input_ids"]["attention_mask"],
#             "labels": item["compressed_input_ids"]["input_ids"],
#         }
#     )
# print(len(train_examples_batches_clean))
# # train_examples_batches_clean[0]

31055


### Use TRL

In [None]:
# from peft import PeftModel

# PeftModel.from_pretrained(model, "./test")

In [4]:
from datasets import Dataset

# Wrap the prepared examples (a list of dicts) in a `datasets.Dataset`;
# this object is later passed to the trainer as `train_dataset`.
new_dataset = Dataset.from_list(prepared_train_examples)
new_dataset

Dataset({
    features: ['replaced_original_tokens', 'compressed_input_ids', 'original_tokens'],
    num_rows: 124217
})

In [None]:
# https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py
from trl import SFTTrainer, SFTConfig
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from hidden_capacity_reasoning.utils import EOS_TOKEN_ID, TEXT_TOKEN_ID, WINDOW_SIZE


def find_all_linear_names_v2(model):
    """Collect the names of the pooler's q/k/v projection layers.

    A module is selected when it is a ``torch.nn.Linear``, its dotted name
    contains ``"embed_pooler"``, and the last component of that name is one of
    ``q_proj``, ``k_proj`` or ``v_proj`` (``o_proj`` is deliberately left out).

    Args:
        model: Any ``torch.nn.Module`` to scan with ``named_modules()``.

    Returns:
        A set with the full dotted names of the selected modules.
    """
    wanted_suffixes = {"q_proj", "k_proj", "v_proj"}  # "o_proj" is excluded
    return {
        module_path
        for module_path, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
        and "embed_pooler" in module_path
        and module_path.rsplit(".", 1)[-1] in wanted_suffixes
    }


def collate_fn(batch):
    """Pad a list of training examples and convert them into model-ready tensors.

    The batch is padded with ``pad_train_examples`` (using the notebook-level
    ``tokenizer``), then flattened into a dict of tensors with the keys
    ``replaced_original_tokens``, ``compressed_input_ids``, ``original_tokens``,
    ``attention_mask`` and ``labels``. ``labels`` starts as a separate copy of
    the compressed input ids, in which every ``TEXT_TOKEN_ID`` and
    ``EOS_TOKEN_ID`` position is set to ``-100``.

    Args:
        batch: List of examples as produced for ``pad_train_examples``.

    Returns:
        Dict mapping each key above to a ``torch.Tensor``.
    """
    padded = pad_train_examples(train_examples=batch, tokenizer=tokenizer)
    compressed = padded["compressed_input_ids"]
    columns = {
        "replaced_original_tokens": padded["replaced_original_tokens"]["input_ids"],
        "compressed_input_ids": compressed["input_ids"],
        "original_tokens": padded["original_tokens"]["input_ids"],
        "attention_mask": compressed["attention_mask"],
        "labels": compressed["input_ids"],
    }
    # Each value gets its own tensor, so editing `labels` below does not touch
    # `compressed_input_ids`.
    tensors = {name: torch.tensor(values) for name, values in columns.items()}

    labels = tensors["labels"]
    ignore_mask = (labels == TEXT_TOKEN_ID) | (labels == EOS_TOKEN_ID)
    labels.masked_fill_(ignore_mask, -100)
    return tensors


import types


def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    cache_position=None,
    logits_to_keep=0,
    **kwargs
):
    """Compression-aware forward meant to be bound onto the model instance.

    When ``replaced_original_tokens`` is present in ``kwargs`` (together with
    ``original_tokens`` and ``compressed_input_ids``), the input embeddings are
    built as follows:

    1. ``original_tokens`` and ``compressed_input_ids`` are embedded with the
       model's input embedding layer.
    2. Embeddings of the original tokens at positions where
       ``replaced_original_tokens == TEXT_TOKEN_ID`` are gathered and reshaped
       into windows of ``WINDOW_SIZE``.
    3. Each window is pooled by ``self.embed_pooler``.
    4. The pooled vectors are written into the compressed embeddings at the
       positions where ``compressed_input_ids == TEXT_TOKEN_ID``.

    The result replaces ``inputs_embeds``. Without those keys the call is a
    plain pass-through. In both cases the remaining arguments are handed to
    ``Qwen2ForCausalLM.forward`` and its result is returned.

    Raises:
        KeyError: if ``replaced_original_tokens`` is given without
            ``original_tokens`` or ``compressed_input_ids``.
    """
    if "replaced_original_tokens" in kwargs:
        device = self.model.device
        # Pop the collator-only keys so they are consumed here and not
        # forwarded into the parent forward through **kwargs.
        replaced_tokens_torch = kwargs.pop("replaced_original_tokens").to(device)
        original_tokens_torch = kwargs.pop("original_tokens").to(device)
        compressed_tokens_torch = kwargs.pop("compressed_input_ids").to(device)

        embed_tokens = self.model.get_input_embeddings()
        original_embeds = embed_tokens(original_tokens_torch)
        compressed_embeds_template = embed_tokens(compressed_tokens_torch)

        tokens_for_compression_mask = replaced_tokens_torch == TEXT_TOKEN_ID
        compressed_tokens_mask = compressed_tokens_torch == TEXT_TOKEN_ID

        # Boolean indexing flattens the selected positions to (N, hidden);
        # they are regrouped into consecutive windows of WINDOW_SIZE.
        # NOTE(review): assumes the number of selected positions is a multiple
        # of WINDOW_SIZE and that windows are contiguous — confirm against the
        # data preparation code.
        embeds_for_compression = original_embeds[tokens_for_compression_mask].reshape(
            -1,
            WINDOW_SIZE,
            original_embeds.shape[-1],
        )
        pooled_embeds = self.embed_pooler(embeds_for_compression)
        pooled_embeds = pooled_embeds.to(compressed_embeds_template.dtype)

        # Out-of-place scatter: the embedding output is left untouched, which
        # also avoids autograd's restriction on in-place edits of leaf tensors
        # that require grad.
        inputs_embeds = compressed_embeds_template.masked_scatter(
            compressed_tokens_mask.unsqueeze(-1).expand_as(compressed_embeds_template),
            pooled_embeds,
        )

    # Call the parent implementation explicitly: `super(type(self), self)`
    # recurses forever as soon as `self` is an instance of a subclass.
    # Keyword forwarding keeps the call independent of positional order.
    return Qwen2ForCausalLM.forward(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        **kwargs
    )


# Bind the compression-aware forward to this model instance; MethodType is
# needed so that `self` is passed (a plain `model.forward = forward` would not).
model.forward = types.MethodType(forward, model)

# LoRA adapters are attached only to the modules picked by the helper.
lora_targets = find_all_linear_names_v2(model=model)
peft_config = LoraConfig(
    r=2,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=lora_targets,
)

# Freeze the base backbone parameters.
model.model.requires_grad_(False)

# NOTE(review): max_seq_length is assigned but not passed to the trainer or
# the collator in this cell, and the recorded run ended with a CUDA
# out-of-memory error — confirm whether sequences were meant to be truncated.
max_seq_length = 5000

# Mixed precision: bf16 when supported, otherwise fp16.
use_bf16 = is_bfloat16_supported()
training_args = SFTConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    warmup_steps=5,
    # num_train_epochs = 1, # Set this for 1 full training run.
    max_steps=60,
    learning_rate=2e-4,
    fp16=not use_bf16,
    bf16=use_bf16,
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="outputs",
    report_to="none",  # Use this for WandB etc
    # The custom collator needs the raw columns and does its own padding,
    # so TRL's column pruning and dataset preparation are switched off.
    remove_unused_columns=False,
    dataset_kwargs={"skip_prepare_dataset": True},
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=new_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    args=training_args,
)
trainer.train()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


  trainer = SFTTrainer(
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
1,2.621
2,2.9071
3,2.2246
4,4.3144
5,3.3574
6,2.0079
7,4.2379
8,4.736


OutOfMemoryError: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 23.63 GiB of which 39.69 MiB is free. Process 2385140 has 23.48 GiB memory in use. Of the allocated memory 22.40 GiB is allocated by PyTorch, and 641.42 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Qwen2Model(
  (embed_tokens): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x Qwen2DecoderLayer(
      (self_attn): Qwen2Attention(
        (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
        (k_proj): Linear(in_features=1536, out_features=256, bias=True)
        (v_proj): Linear(in_features=1536, out_features=256, bias=True)
        (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
      )
      (mlp): Qwen2MLP(
        (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
        (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
        (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
    )
  )
  (norm): Qwen2RMSNorm((1536,), eps=1e-06)
  (rotary_emb): Qwen2RotaryEmbedding()
)