In [10]:
import os
import json
import numpy as np
import pandas as pd
import re
import string
from collections import Counter
from tqdm import tqdm

import torch
from transformers import (
    Trainer,
    TrainingArguments,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForQuestionAnswering,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling
)
from peft import (
    prepare_model_for_kbit_training, 
    LoraConfig, 
    TaskType,
    get_peft_model
)

from trl import DataCollatorForCompletionOnlyLM, SFTTrainer, SFTConfig
from datasets import load_dataset, Dataset, DatasetDict
from accelerate import Accelerator

# 모델 정의

In [4]:
repo = "beomi/OPEN-SOLAR-KO-10.7B"

accelerater = Accelerator()

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
        repo,
        torch_dtype="auto",
        attn_implementation="eager",
        quantization_config=quantization_config
)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj'],
    task_type=TaskType.CAUSAL_LM
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
tokenizer = AutoTokenizer.from_pretrained(repo)
model, tokenizer = accelerater.prepare(model, tokenizer)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
`low_cpu_mem_usage` was None, now set to True since model is quantized.
Downloading shards: 100%|██████████| 8/8 [59:27<00:00, 445.95s/it]
Loading checkpoint shards: 100%|██████████| 8/8 [00:42<00:00,  5.33s/it]


# 데이터셋 정의

In [23]:
train_dataset = load_dataset("csv", data_files="/home/jovyan/work/prj_data/open/train.csv")

def get_template(context, question, answer):
    return [
        {
            "role": "system",
            "content": "너는 주어진 Context에서 Question에 대한 Answer를 찾는 챗봇이야. Context에서 Answer가 될 수 있는 부분을 찾아서 그대로 적어줘. 단, Answer는 주관식이 아니라 단답형으로 적어야 해"
        },
        {
            "role": "user",
            "content": f"Context: {context}\nQuestion: {question}"
        },
        {
            "role": "assistant",
            "content": answer
        }
    ]
print(tokenizer.apply_chat_template(get_template('d', 'e', 'a'), tokenize=False))

INSTRUCTION_TEMPLATE = """
너는 주어진 Context에서 Question에 대한 Answer를 찾는 챗봇이야. Context에서 Answer가 될 수 있는 부분을 찾아서 그대로 적어줘. 단, Answer는 주관식이 아니라 단답형으로 적어야 해.

Context: {context}
Question: {question}"""

RESPONSE_TEMPLATE = "{answer}"

class QADataCollator(DataCollatorForLanguageModeling):
    def __init__(self, tokenizer, inst, resp, mlm=False):
        super().__init__(tokenizer=tokenizer, mlm=mlm)
    
    def __call__(self, examples):
        batch = []
        for example in examples:
            context = example['context']
            question = example['question']
            answer = example['answer']
            
            instruction = INSTRUCTION_TEMPLATE.format(context=context, question=question)
            response = RESPONSE_TEMPLATE.format(answer=answer)
            
            prompt = self.tokenizer.apply_chat_template([
                {"role": "user", "content": instruction},
                {"role": "assistant", "content": response}
            ], tokenize=False)
            
            encoded = self.tokenizer.encode(prompt, truncation=True, max_length=512)
            batch.append(encoded)
        
        return self.tokenizer.pad(
            {"input_ids": batch},
            padding=True,
            return_tensors="pt",
        )
    
data_collator = QADataCollator(tokenizer=tokenizer,
                              inst = INSTRUCTION_TEMPLATE,
                              resp = RESPONSE_TEMPLATE)

<s>[INST] <<SYS>>
너는 주어진 Context에서 Question에 대한 Answer를 찾는 챗봇이야. Context에서 Answer가 될 수 있는 부분을 찾아서 그대로 적어줘. 단, Answer는 주관식이 아니라 단답형으로 적어야 해
<</SYS>>

Context: d
Question: e [/INST] a </s>


# 학습

In [6]:
import wandb
wandb.login()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingfac

True

In [26]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

args = SFTConfig(
    output_dir='my_model',
    eval_strategy='epoch',
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    warmup_ratio=0.1,
    do_train=True,
    do_eval=True,
    logging_strategy='steps',
    logging_dir='logs',
    logging_steps=1,
    save_steps=0.2,
    report_to="wandb",
    max_seq_length=4096
)

def formatting_func(example):
    return f"Context: {example['context']}\nQuestion: {example['question']}\nAnswer: {example['answer']}"

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset['train'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    peft_config=lora_config,
    formatting_func=formatting_func
)

trainer.train(resume_from_checkpoint=False)

Map:   0%|          | 0/33716 [00:00<?, ? examples/s]


KeyError: None

---