In [1]:
%cd /content/drive/MyDrive/goorm_project3
%pwd

/content/drive/MyDrive/goorm_project3


'/content/drive/MyDrive/goorm_project3'

In [2]:
!pip install transformers sentencepiece wandb;



In [3]:
#%%
from typing import Dict, List
import csv

from transformers import (
    EncoderDecoderModel,
    GPT2Tokenizer as BaseGPT2Tokenizer,
    ElectraTokenizer,
    BertTokenizer,
    PreTrainedTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Trainer,

    PreTrainedTokenizerFast,
    GPT2LMHeadModel
)

from lib.tokenization_kobert import KoBertTokenizer

import torch

import wandb

In [4]:
wandb.login()

wandb.init(project = 'test_project', entity = 'chohs1221')

# wandb.config.learning_rate = args.learning_rate
# wandb.config.epochs = args.epochs
# wandb.config.batch_size = args.batch_size

[34m[1mwandb[0m: Currently logged in as: [33mchohs1221[0m (use `wandb login --relogin` to force relogin)


In [5]:
#%%
class GPT2Tokenizer(BaseGPT2Tokenizer):
    def build_inputs_with_special_tokens(self, token_ids: List[int], _) -> List[int]:
        return token_ids + [self.eos_token_id]

In [6]:
#%%
enc_tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert')
dec_tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'KoBertTokenizer'.


In [7]:
# %%
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'monologg/kobert',
    'distilgpt2',    
    pad_token_id=dec_tokenizer.bos_token_id
)
model.config.decoder_start_token_id = dec_tokenizer.bos_token_id

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['transformer.h.5.crossattention.masked_bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.5.crossattention.bias', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.2.crossattention.bias', 'transformer.h.4.crossattention.masked_bias', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.0.crossattention.bias', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.1.crossattention.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.3.crossattention.masked_bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.ln_cross_attn.weight', 'tr

In [8]:
#%%
class PairedDataset:
    def __init__(self, data, enc_tokenizer=enc_tokenizer, dec_tokenizer=dec_tokenizer):
        self.data = data

        self.enc_tokenizer = enc_tokenizer
        self.dec_tokenizer = dec_tokenizer

    @classmethod
    def loads(cls, *file_names):
        data = []
        for file_name in file_names:
            with open(file_name, 'r', encoding='cp949') as fd:
                data += [row[1:] for row in csv.reader(fd)]
        
        return cls(data)
    
    @classmethod
    def split(cls, datasets, ratio = 0.1):
        valid_length = int(len(datasets) * ratio)
        train = [datasets[i] for i in range(len(datasets) - valid_length)]
        valid = [datasets[i] for i in range(valid_length, len(datasets))]

        return cls(train), cls(valid)

    def __getitem__(self, index: int) -> List[str]:
        return self.data[index]

    def __len__(self):
        return len(self.data)

In [9]:
#%%
dataset = PairedDataset.loads('./data/kor2en.csv')
train_dataset_, valid_dataset_ = PairedDataset.split(dataset)
print(train_dataset_[0])
print(valid_dataset_[0])

['또 져버린 것 같아 넌 화가 나 보여 아른대는 Game over over over', 'I think I lost again You look mad In a blur, game over over over']
['어쩌면 난 너를 쉽게 잊을지 몰라. 혹시 너 아닌 다른 기억도 지워진다면. ', "Maybe I'll easily forget you Maybe I'll erase all the other memories"]


In [10]:
#%%
class TokenizeDataset:
    def __init__(self, dataset, enc_tokenizer, dec_tokenizer):
        self.dataset = dataset
        self.enc_tokenizer = enc_tokenizer
        self.dec_tokenizer = dec_tokenizer
    
    def __getitem__(self, index: int):
        src, trg = self.dataset[index]
        input = self.enc_tokenizer(src, return_attention_mask=False, return_token_type_ids=False, padding='max_length', truncation = True, max_length = 512)
        input['labels'] = self.dec_tokenizer(trg, return_attention_mask=False)['input_ids']

        return input
    
    def __len__(self):
        return len(self.dataset)

In [11]:
train_dataset = TokenizeDataset(train_dataset_, enc_tokenizer, dec_tokenizer)
valid_dataset = TokenizeDataset(valid_dataset_, enc_tokenizer, dec_tokenizer)
print(train_dataset[0])
print(enc_tokenizer.convert_ids_to_tokens(train_dataset[0]['input_ids']))
print(dec_tokenizer.convert_ids_to_tokens(train_dataset[0]['labels']))

{'input_ids': [2, 1861, 517, 7245, 6325, 905, 832, 517, 5695, 5112, 5330, 1370, 2376, 3093, 6115, 5808, 5760, 650, 373, 389, 517, 427, 454, 517, 427, 454, 517, 427, 454, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [12]:
# %%
collator = DataCollatorForSeq2Seq(enc_tokenizer, model, max_length = 512)

arguments = Seq2SeqTrainingArguments(
    output_dir='dump',
    do_train=True,
    do_eval=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_ratio=0.1,
    gradient_accumulation_steps=1,
    save_total_limit=5,
    dataloader_num_workers=1,
    fp16=True,
    load_best_model_at_end=True,
    report_to='wandb',
    run_name='test1'

)

trainer = Trainer(
    model,
    arguments,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset
)

Using amp half precision backend


In [13]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# %%
trainer.train()

model.save_pretrained("dump/best_model")

***** Running training *****
  Num examples = 15107
  Num Epochs = 10
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 37770
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
  "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "


Epoch,Training Loss,Validation Loss


In [None]:
wandb.finish()