# In[ ]:
# notebooks/02_train_lora.ipynb

# !pip install torch transformers peft accelerate datasets

from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Base model
model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Llama/Mistral-style tokenizers typically define no pad token, and tokenize_fn
# pads to max_length, which raises without one. Reuse EOS as the pad token
# (common practice for decoder-only LMs); the guard keeps an existing pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): no torch_dtype is passed, so weights load in transformers'
# default dtype (float32 in most versions), i.e. roughly 28 GB for a 7B model.
# Consider torch_dtype=torch.bfloat16 or quantization on constrained hardware.
model = AutoModelForCausalLM.from_pretrained(model_name)

# LoRA config: rank-16 adapters with lora_alpha=32 (scale alpha/r = 2) and 10%
# dropout on the adapter path. target_modules is not set, so PEFT falls back to
# its built-in default target modules for this model architecture.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
)
# Wrap the base model so only the injected LoRA parameters are trainable.
model = get_peft_model(model, lora_config)

# Example dataset (replace with rubric/student answers later): first 1% of the
# IMDB train split, used here only as raw text for causal-LM fine-tuning.
dataset = load_dataset("imdb", split="train[:1%]")

def tokenize_fn(examples):
    """Tokenize a batch of examples' ``text`` field into fixed-length sequences.

    Every entry is truncated and padded to exactly 256 tokens using the
    module-level ``tokenizer``. Meant to be used with ``Dataset.map(...,
    batched=True)``, so ``examples["text"]`` is a list of strings.

    Returns the tokenizer's output mapping (e.g. ``input_ids``,
    ``attention_mask``), which ``map`` merges into the dataset as new columns.
    """
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=256,
        padding="max_length",
    )
    return encoded

# Tokenize and drop the raw dataset columns (e.g. IMDB's "text"/"label") so only
# model inputs reach the Trainer. By default Trainer keeps a "label" column, and
# a per-example class label cannot serve as causal-LM targets.
tokenized = dataset.map(
    tokenize_fn,
    batched=True,
    remove_columns=dataset.column_names,
)

# Causal-LM collator (mlm=False): builds `labels` from `input_ids` and masks
# pad positions with -100. Without labels the model returns no loss and
# trainer.train() fails.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Training args
training_args = TrainingArguments(
    output_dir="../checkpoints",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    logging_steps=10,
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=data_collator,
)

trainer.train()

# On a PEFT-wrapped model, save_pretrained writes the adapter weights and
# adapter config rather than the full base model.
adapter_dir = "../checkpoints/lora"
model.save_pretrained(adapter_dir)
print(f"✅ LoRA adapter saved at {adapter_dir}")
