# Урок 4: Адаптация генеративных моделей (LLM)

**Задача:** дообучение (fine-tuning) Mistral-7B под заданный стиль общения с помощью LoRA и 4-bit квантизации (рассчитано на GPU T4 в Colab).

In [None]:
# Install the fine-tuning dependencies; -q keeps pip's output quiet.
!pip install -q transformers datasets peft accelerate bitsandbytes trl

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 1. 4-bit quantization (targeting a Colab T4)
model_id = "mistralai/Mistral-7B-v0.1"

# NF4 4-bit weights, with float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

# The EOS token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Load the quantized checkpoint and prepare it for k-bit training in one go.
model = prepare_model_for_kbit_training(
    AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config,
    )
)

In [None]:
# 2. LoRA adapters for a generative (causal LM) model
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter size and scaling.
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # Names of the modules that receive adapters (the q/k/v/o projections).
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Wrap the base model with the adapters and report the trainable-parameter count.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# 3. Dataset preparation (Alpaca-style)
from datasets import load_dataset

# Example instruction dataset; swap in a Russian Alpaca variant or any other
# instruction dataset. Only the first 1000 training rows are kept.
dataset = load_dataset("tatsu-lab/alpaca", split="train")
dataset = dataset.select(range(1000))

def format_instruction(sample):
    """Render one Alpaca-style record as a single training prompt.

    Args:
        sample: Mapping with the fields ``instruction``, ``input`` and
            ``output``.

    Returns:
        The prompt text: an ``### Instruction:`` section, an ``### Input:``
        section only when ``input`` is non-blank, and a ``### Response:``
        section, with a blank line between sections.
    """
    sections = [f"### Instruction:\n{sample['instruction']}"]

    # A record may carry no extra context. Emitting a bare "### Input:"
    # header followed by nothing would put a meaningless section into the
    # training text, so it is skipped for empty or whitespace-only input.
    context = sample["input"]
    if context and context.strip():
        sections.append(f"### Input:\n{context}")

    sections.append(f"### Response:\n{sample['output']}")
    return "\n\n".join(sections)

# Render each row into a "text" field and drop the original columns.
dataset = dataset.map(
    lambda row: {"text": format_instruction(row)},
    remove_columns=dataset.column_names,
)

In [None]:
# 4. Supervised fine-tuning with SFTTrainer
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    args=SFTConfig(
        max_length=512,
        per_device_train_batch_size=4,
        # 4 micro-batches of 4 are accumulated per optimizer step.
        gradient_accumulation_steps=4,
        warmup_steps=50,
        max_steps=500,
        learning_rate=2e-4,
        # fp16 rather than bf16: this notebook targets a Colab T4, which has
        # no native bfloat16 support, and the 4-bit quantization config
        # already uses float16 as its compute dtype.
        fp16=True,
        logging_steps=25,
        output_dir="outputs",
        optim="paged_adamw_8bit",
    ),
)

trainer.train()