In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
import torch
%env TOKENIZERS_PARALLELISM=False

# Load the tokenizer and model
model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Weights are loaded in bfloat16 and placed on the (single visible) CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map='cuda'
)

# Load the dataset and hold out 1% of its 'train' split (fixed seed for reproducibility)
ds = load_dataset("ScalableMath/Lean-STaR-plus")['train'].train_test_split(
    test_size=0.01, seed=42
)
ds_train, ds_test = ds['train'], ds['test']

env: TOKENIZERS_PARALLELISM=False


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [2]:
# Format the data for training
def format_instruction(example):
    """Build the training text for one example.

    Args:
        example: mapping that provides the 'input' and 'output' fields.

    Returns:
        A dict with a single 'text' key holding the input followed by the
        output, separated by one newline.
    """
    prompt, completion = example['input'], example['output']
    return {"text": "{}\n{}".format(prompt, completion)}

formatted_ds = ds_train.map(format_instruction)

# Tokenize the dataset
def tokenize_function(examples):
    """Tokenize a batch of formatted examples without padding.

    Args:
        examples: batch (dict of lists) containing a "text" column.

    Returns:
        The tokenizer output for the batch, with every sequence truncated to
        at most 2048 tokens and left at its natural length.

    Padding is intentionally NOT applied here. With ``batched=True`` the
    tokenizer would pad every example to the longest sequence of its map
    chunk, which stores and trains on large numbers of pad tokens and makes
    all examples of a chunk look equally long to length-based grouping.
    Padding is left to the data collator, which pads each training batch
    dynamically. No "labels" column is produced either: the causal-LM
    collator (``mlm=False``) builds labels from ``input_ids`` itself, and a
    ragged pre-computed "labels" column could not be collated for batch
    sizes above 1.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=2048,
    )

tokenized_ds = formatted_ds.map(tokenize_function, batched=True, remove_columns=["text"])

In [3]:
# Directory that receives the periodic checkpoints and the final model/tokenizer.
output_dir = "./qwen-math-finetuned"

# Causal-LM collation (no masked-language-modeling objective).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    # Output / checkpointing
    output_dir=output_dir,
    overwrite_output_dir=True,
    save_steps=1000,
    save_total_limit=20,
    # Schedule: 1 sample per step x 16 accumulation steps per optimizer update
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    warmup_steps=500,
    # Optimizer
    optim='adamw_8bit',
    learning_rate=1e-5,
    weight_decay=0.01,
    # Precision: bf16 mixed precision is used instead of fp16
    bf16=True,
    # Logging
    logging_dir="./logs",
    logging_steps=500,
    prediction_loss_only=True,
    # Data loading
    dataloader_num_workers=1,
    group_by_length=True,
)

# Set up the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds,
    data_collator=data_collator,
)

# Train the model
trainer.train()

# Save the final model and tokenizer
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

Step,Training Loss
500,2.1227
1000,0.9844
1500,0.8768
2000,0.8212
2500,0.7976
3000,0.7836
3500,0.7694
4000,0.7582
4500,0.7499
5000,0.7492


('./qwen-math-finetuned/tokenizer_config.json',
 './qwen-math-finetuned/special_tokens_map.json',
 './qwen-math-finetuned/vocab.json',
 './qwen-math-finetuned/merges.txt',
 './qwen-math-finetuned/added_tokens.json',
 './qwen-math-finetuned/tokenizer.json')

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Path of the fine-tuned checkpoint used for inference.
local_model_path = "/mfrolova/qwen-math-finetuned/checkpoint-11000"

# The checkpoint is loaded without trust_remote_code: the same architecture is
# loaded earlier in this notebook without it, so no custom code needs to run.
model = AutoModelForCausalLM.from_pretrained(
    local_model_path,
    torch_dtype=torch.bfloat16,
    device_map='cuda',
)

tokenizer = AutoTokenizer.from_pretrained(local_model_path)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): the backslash after "<|im_end|> " is a line continuation inside
# the string literal, so "<|im_start|>assistant" follows on the same line after
# a space instead of on a new line — confirm this matches the training format.
input_text = """<|im_start|>user
My LEAN 4 state is:
import Mathlib.Data.Matrix.Basic
import Aesop
import Mathlib.Tactic.NormNum
import Mathlib.Tactic.RewriteSearch
import AutoSolver


@[simp] def x: List ℚ :=  [6, 2, 9]
@[simp] def e_1: List ℚ :=  [5, 0, 4]
@[simp] def e_2: List ℚ :=  [5, -1, 0]
@[simp] def e_3: List ℚ :=  [-1, 0, 4]

@[simp] def x1: ℚ := 73/24
@[simp] def x2: ℚ := -2
@[simp] def x3: ℚ := -19/24
Please write down the reasoning that leads to the possible next tactic and then predict the tactic to help me prove the corectness of the system.<|im_end|> \
<|im_start|>assistant"""

# Keep the inputs on the same device as the model.
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Pass the full tokenizer output so the attention mask reaches generate();
    # passing only input_ids drops it and makes generate() guess the mask.
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
    )

# The decoded sequence contains the prompt followed by the generated continuation.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(generated_text)