In [1]:
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from huggingface_hub import snapshot_download

In [None]:
# # 스냅샷 로컬에 저장
# repo_id = "lemon-mint/gemma-ko-2b-instruct-v0.51"
# local_dir = "../model/gemma-ko-2b-instruct-v0.51"

# snapshot_download(repo_id=repo_id, local_dir=local_dir)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/2.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.05G [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/657 [00:00<?, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

'C:\\workspace\\model\\gemma-ko-2b-instruct-v0.51'

In [None]:
local_dir = "../model/Qwen3-VL-2B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(local_dir, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

In [3]:
df = pd.read_excel("./data/emotion_data.xlsx")
df = df.fillna('')

In [4]:
data_list = []
for _, row in df.iterrows():
    data_list.append({
        'emotion': row['감정_소분류'],
        '사람문장1': row['사람문장1'],
        '시스템문장1': row['시스템문장1'],
        '사람문장2': row['사람문장2'],
        '시스템문장2': row['시스템문장2'],
        '사람문장3': row['사람문장3'],
        '시스템문장3': row['시스템문장3'],
    })

dataset = Dataset.from_dict({key: [d[key] for d in data_list] for key in data_list[0].keys()})
dataset[0]

{'emotion': '노여워하는',
 '사람문장1': '일은 왜 해도 해도 끝이 없을까? 화가 난다.',
 '시스템문장1': '많이 힘드시겠어요. 주위에 의논할 상대가 있나요?',
 '사람문장2': '그냥 내가 해결하는 게 나아. 남들한테 부담 주고 싶지도 않고.',
 '시스템문장2': '혼자 해결하기로 했군요. 혼자서 해결하기 힘들면 주위에 의논할 사람을 찾아보세요. ',
 '사람문장3': '',
 '시스템문장3': ''}

In [5]:
def formatting_prompts_func(examples):
    texts = []

    for i in range(len(examples['사람문장1'])):
        messages = []
        for turn in range(1, 4):
            user = examples[f'사람문장{turn}'][i].strip()
            assistant = examples[f'시스템문장{turn}'][i].strip()

            if user:
                messages.append({"role": "user", "content": user})
            if assistant:
                messages.append({"role": "assistant", "content": assistant})

        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

        texts.append(text)

    return {"text": texts}

formatted_dataset = dataset.map(formatting_prompts_func, batched=True, remove_columns=dataset.column_names)
formatted_dataset[0]

Map:   0%|          | 0/51630 [00:00<?, ? examples/s]

{'text': '<|im_start|>user\n일은 왜 해도 해도 끝이 없을까? 화가 난다.<|im_end|>\n<|im_start|>assistant\n많이 힘드시겠어요. 주위에 의논할 상대가 있나요?<|im_end|>\n<|im_start|>user\n그냥 내가 해결하는 게 나아. 남들한테 부담 주고 싶지도 않고.<|im_end|>\n<|im_start|>assistant\n혼자 해결하기로 했군요. 혼자서 해결하기 힘들면 주위에 의논할 사람을 찾아보세요.<|im_end|>\n'}

In [6]:
def tokenize_function(examples):
    tokens = tokenizer(examples["text"], padding=True, return_tensors="pt")
    return tokens

tokenized_dataset = formatted_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset[0]

Map:   0%|          | 0/51630 [00:00<?, ? examples/s]

{'input_ids': [151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  151645,
  15164

In [8]:
from transformers import Qwen3VLForConditionalGeneration

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[
        "q_proj", 
        "k_proj", 
        "v_proj"]
)
base_model = Qwen3VLForConditionalGeneration.from_pretrained(local_dir, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
model = get_peft_model(base_model, lora_config)

In [9]:
split_dataset = tokenized_dataset.train_test_split(test_size=0.1)
train_ds = split_dataset["train"]
eval_ds = split_dataset["test"]

In [12]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

args=TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 500,
        # num_train_epochs=1,
        learning_rate = 2e-4, 
        bf16 = True,
        seed = 1234,
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        output_dir = "outputs",

        logging_steps=100,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="loss",
        greater_is_better=False
)

trainer = Trainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collator,
    args=args
)

The model is already on multiple devices. Skipping the move to device specified in `args`.


In [13]:
trainer.train()

Step,Training Loss,Validation Loss


KeyboardInterrupt: 