In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import evaluate

In [2]:
# Параметры
model_name = "google/gemma-2b-it"
dataset_name = "tatsu-lab/alpaca"  # можно заменить на локальный путь или huggingface dataset id

In [3]:
# 1. Загрузка модели и токенизатора
tokenizer = AutoTokenizer.from_pretrained(model_name)

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    # llm_int8_threshold=6.0
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# 2. Настройка LoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # в зависимости от модели
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_config)

In [5]:
# 3. Подготовка данных
def formatting_func(example):
    prompt = f"<bos><start_of_turn>system\n{example['instruction']}<end_of_turn>\n<start_of_turn>user\n{example['input']}<end_of_turn>\n<start_of_turn>model\n"
    full_text = prompt + example["output"]
    return {"text": full_text}

dataset = load_dataset(dataset_name, split="train")
split_dataset = dataset.train_test_split(test_size=0.01, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

dataset = dataset.map(formatting_func)

In [6]:
metric = evaluate.load("perplexity", module_type="metric")

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = logits.argmax(-1)

    # Поскольку perplexity принимает списки строк — декодируем
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Вычисляем perplexity
    results = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"perplexity": results["perplexity"]}

In [7]:
# 4. Настройка обучения
training_args = SFTConfig(
    output_dir="./results",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    max_seq_length=512,
    packing=False,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_steps=500,
    learning_rate=2e-4,
    num_train_epochs=3,
    # fp16=True,
    save_total_limit=2,
    report_to="none"
)

In [8]:
# 5. Тренировка
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args,
    compute_metrics=None,
)

Truncating train dataset:   0%|          | 0/51481 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/521 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
trainer.train()

Step,Training Loss,Validation Loss
50,1.4338,1.339693
100,1.2728,1.275633
150,1.2895,1.210593
200,1.2688,1.209696
250,1.2342,1.200176
300,1.2029,1.197046
350,1.1767,1.190623
400,1.2588,1.190975
450,1.1932,1.190723
500,1.2529,1.184475
