# 04. Transfer-learning
* 전처리를 완료한 txt 파일로 KoGPT2를 `transfer-learning`하는 단계
* 현재 모델이 안정적인 결과물을 만드는지 샘플을 확인해본다.
* 학교에서 빌린 GPU 서버로 진행

### transfer learning.py

In [None]:
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import GPT2LMHeadModel
from transformers import Trainer, TrainingArguments
from transformers import PreTrainedTokenizerFast

def load_dataset(file_path, tokenizer, block_size=128):
    """Build the training dataset from a plain-text file.

    Args:
        file_path: Path of the text file passed to ``TextDataset``.
        tokenizer: Tokenizer passed to ``TextDataset``.
        block_size: Block size passed to ``TextDataset``; defaults to 128.

    Returns:
        The ``TextDataset`` instance constructed from the arguments.
    """
    return TextDataset(
        file_path=file_path,
        tokenizer=tokenizer,
        block_size=block_size,
    )

def load_data_collator(tokenizer, mlm=False):
    """Create the language-modeling data collator used for training.

    Args:
        tokenizer: Tokenizer passed to ``DataCollatorForLanguageModeling``.
        mlm: Value forwarded as the collator's ``mlm`` flag; defaults to
            ``False``.

    Returns:
        The configured ``DataCollatorForLanguageModeling`` instance.
    """
    return DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=mlm)

def train(train_file_path, model_name, output_dir, overwrite_output_dir,
          per_device_train_batch_size, num_train_epochs, save_steps):
    """Fine-tune a pretrained GPT-2 LM-head model on a text file.

    Loads the tokenizer and model identified by ``model_name``, builds the
    dataset and data collator, runs ``Trainer.train()`` and saves the result.

    Args:
        train_file_path: Path of the training text file.
        model_name: Name or path given to ``from_pretrained`` for both the
            tokenizer and the model.
        output_dir: Directory that receives the tokenizer files, the model
            and the trainer output.
        overwrite_output_dir: Forwarded to ``TrainingArguments``.
        per_device_train_batch_size: Forwarded to ``TrainingArguments``.
        num_train_epochs: Forwarded to ``TrainingArguments``.
        save_steps: Forwarded to ``TrainingArguments`` as the checkpoint
            saving interval.
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained(
        model_name,
        bos_token='</s>', eos_token='</s>', unk_token='<unk>',
        pad_token='<pad>', mask_token='<mask>',
    )
    train_dataset = load_dataset(train_file_path, tokenizer)
    data_collator = load_data_collator(tokenizer)

    # Keep the tokenizer files in output_dir alongside the model.
    tokenizer.save_pretrained(output_dir, legacy_format=False)

    model = GPT2LMHeadModel.from_pretrained(model_name)

    # Writes the pretrained (not yet fine-tuned) weights to output_dir before
    # training starts; trainer.save_model() below saves the trained model.
    model.save_pretrained(output_dir)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=overwrite_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        num_train_epochs=num_train_epochs,
        # Fix: save_steps was accepted by train() but never passed on, so the
        # caller's value was silently ignored.
        save_steps=save_steps,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_model()

# Input corpus, base checkpoint and output location for this run.
train_file_path = '/root/dora/shortdata.txt'
model_name = 'skt/kogpt2-base-v2'
output_dir = '/root/dora/checkpoint'

# Training settings passed to train().
overwrite_output_dir = False
per_device_train_batch_size = 8
num_train_epochs = 5.0
save_steps = 500

# Start fine-tuning with the configuration above.
train(
    train_file_path=train_file_path,
    model_name=model_name,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    num_train_epochs=num_train_epochs,
    save_steps=save_steps,
)