In [1]:
import os
import json
import numpy as np
import pandas as pd
import re
import string
from collections import Counter
from tqdm import tqdm

import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from datasets import load_dataset, Dataset, DatasetDict
from accelerate import Accelerator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 파일 경로에서 데이터를 읽어옴
file_path = './content/train.csv'
train_data = pd.read_csv(file_path)
from datasets import load_dataset

ds = load_dataset("wikimedia/wikipedia", "20231101.ab")
# 데이터를 셔플하고 인덱스를 재설정
train_data = train_data.sample(frac=1).reset_index(drop=True)

# 검증 데이터와 학습 데이터로 분할
val_data = train_data[:10]
train_data = train_data[10:]

# 검증 데이터에서 질문과 답변 컬럼을 선택
val_label_df = val_data[['question', 'answer']]

# 학습 데이터를 datasets의 Dataset으로 변환
train_dataset = Dataset.from_pandas(train_data)
val_dataset = Dataset.from_pandas(val_label_df)

In [3]:
# 필요한 라이브러리 불러오기
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
import torch

# torch_dtype 설정
torch_dtype = torch.float16  # 예시로 torch.float16 사용. 필요에 따라 변경 가능

In [4]:
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)


model = AutoModelForCausalLM.from_pretrained(
    "beomi/Llama-3-Open-Ko-8B",
    quantization_config=quant_config,
    device_map={"": 0}
)
model.config.use_cache = False
model.config.pretraining_tp = 1

Loading checkpoint shards: 100%|██████████| 6/6 [00:12<00:00,  2.16s/it]


In [5]:
tokenizer = AutoTokenizer.from_pretrained(
              "beomi/Llama-3-Open-Ko-8B",
              trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
from peft import LoraConfig

peft_params = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

In [7]:
training_params = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,  # epoch는 1로 설정
    max_steps=10000,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_8bit",
    warmup_steps=150,  # warmup_steps을 절대값으로 설정
    learning_rate=2e-4,
    fp16=True,
    logging_steps=100,
    push_to_hub=False,
    report_to='tensorboard',
)

In [8]:
trainer = SFTTrainer(
    model=model,
    args=training_params,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_params,
    dataset_text_field="question",  # 여기에 적절한 필드 이름을 지정
    # 필요한 추가적인 파라미터들
)

trainer.train()


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Map:   0%|          | 0/33706 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 33706/33706 [00:00<00:00, 73550.21 examples/s]
Map: 100%|██████████| 10/10 [00:00<00:00, 2966.69 examples/s]
max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss
100,3.2715
200,2.7406
300,2.783
400,2.6761
500,2.799
600,2.6508
700,2.6597
800,2.6451
900,2.5994
1000,2.6181


KeyboardInterrupt: 

In [None]:
# 저장된 체크포인트 경로
checkpoint_path = "./results/checkpoint-10000"

# 모델 로드
model = YourModelClass.from_pretrained(checkpoint_path)

# SFTTrainer 설정 (동일하게 유지)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_params,
    dataset_text_field="question",
    # 필요한 추가적인 파라미터들
)

# 학습 재개
trainer.train(resume_from_checkpoint=checkpoint_path)


In [None]:
trainer.save_model('./models/20240704')