# Description


# Modules and Global Variables

In [1]:
from transformers import (
    AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, 
    DefaultDataCollator, DataCollatorWithPadding, 
    TrainingArguments, Trainer,
)

from transformers.optimization import (
    AdamW, get_linear_schedule_with_warmup,
    Adafactor, AdafactorSchedule,
)

import torch
import wandb

import datasets
import evaluate

from sklearn.metrics import accuracy_score, f1_score

import numpy as np
import pandas as pd

import os
import re
import random

import demoji

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(f'torch.__version__: {torch.__version__}')
print(f'torch.cuda.is_available(): {torch.cuda.is_available()}')
NGPU = torch.cuda.device_count()
print(f'NGPU: {NGPU}')
# NGPU = torch.cuda.device_count()
# if NGPU > 1:
#     model = torch.nn.DataParallel(model, device_ids=list(range(NGPU)))

torch.__version__: 1.12.1
torch.cuda.is_available(): True
NGPU: 4


In [3]:
### labels

ce_labels = ['True', 'False']
pc_labels = ['positive', 'negative', 'neutral']
pc_binary_labels = ['True', 'False']

labels = ce_labels

label2id = {k: i for i, k in enumerate(labels)}
id2label = {i: k for i, k in enumerate(labels)}
num_labels = len(labels)

print(label2id)
print(id2label)

{'True': 0, 'False': 1}
{0: 'True', 1: 'False'}


In [4]:
### paths and names

PROJECT_NAME = 'aspect_category_detection'
RUN_ID = 'uncleaned_v4'

DATA_V = 'uncleaned_v4'
DATA_T = 'ce' # ce or pc or pc_binary
AUGMENTATION = False
AUG_NAME = 'balanced'

model_checkpoint = 'snunlp/KR-ELECTRA-discriminator'

notebook_name = 'acd_binary_trainer.ipynb'

### fixed

model_name = re.sub(r'[/-]', r'_', model_checkpoint).lower()
run_name = f'{model_name}_{RUN_ID}'

ROOT_PATH = './'
SAVE_PATH = os.path.join(ROOT_PATH, 'training_results', run_name, 'acd')
NOTEBOOK_PATH = os.path.join('./', notebook_name)

augornot = f'_{AUG_NAME}' if AUGMENTATION is True else ''
TRAIN_DATA_PATH = os.path.join(ROOT_PATH, 'dataset', DATA_V, f'{DATA_T}_train{augornot}.csv')
EVAL_DATA_PATH = os.path.join(ROOT_PATH, 'dataset', DATA_V, f'{DATA_T}_dev.csv')

!mkdir -p {SAVE_PATH}

In [5]:
if os.path.exists(SAVE_PATH):
    print(f'{SAVE_PATH} exists.')
else:
    print(f'{SAVE_PATH} does not exist.')
if os.path.exists(NOTEBOOK_PATH):
    print(f'{NOTEBOOK_PATH} exists.')
else:
    print(f'{NOTEBOOK_PATH} does not exist.')
if os.path.exists(TRAIN_DATA_PATH):
    print(f'{TRAIN_DATA_PATH} exists.')
else:
    print(f'{TRAIN_DATA_PATH} does not exist.')
if os.path.exists(EVAL_DATA_PATH):
    print(f'{EVAL_DATA_PATH} exists.')
else:
    print(f'{EVAL_DATA_PATH} does not exist.')

./training_results/snunlp_kr_electra_discriminator_uncleaned_v4/acd exists.
./acd_binary_trainer.ipynb exists.
./dataset/uncleaned_v4/ce_train.csv exists.
./dataset/uncleaned_v4/ce_dev.csv exists.


In [6]:
### rest of training args

report_to="wandb"

fp16 = False

num_train_epochs = 10
batch_size = 25 * 2
gradient_accumulation_steps = 1

optim = 'adamw_torch' # 'adamw_hf'

learning_rate = 3e-6 / 8 * batch_size * 4 # 5e-5
weight_decay = 0.01 # 0
adam_epsilon = 1e-8

lr_scheduler_type = 'cosine'
warmup_ratio = 0

save_total_limit = 2

load_best_model_at_end = True
metric_for_best_model ='eval_loss'

save_strategy = "epoch"
evaluation_strategy = "epoch"

logging_strategy = "steps"
logging_first_step = True 
logging_steps = 500

# WandB Configuration

In [7]:
%env WANDB_PROJECT={PROJECT_NAME}
%env WANDB_NOTEBOOK_NAME={NOTEBOOK_PATH}
%env WANDB_LOG_MODEL=true
%env WANDB_WATCH=all
wandb.login()

env: WANDB_PROJECT=aspect_category_detection
env: WANDB_NOTEBOOK_NAME=./acd_binary_trainer.ipynb
env: WANDB_LOG_MODEL=true
env: WANDB_WATCH=all


[34m[1mwandb[0m: Currently logged in as: [33mdotsnangles[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Load Model, Tokenizer, and Collator

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, label2id=label2id, id2label=id2label, num_labels=num_labels
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Some weights of the model checkpoint at snunlp/KR-ELECTRA-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at snunlp/KR-ELECTRA-discriminator and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
train_path = f'./dataset/{DATA_V}/raw_train.csv'
dev_path = f'./dataset/{DATA_V}/raw_dev.csv'
test_path = f'./dataset/{DATA_V}/raw_test.csv'
train = pd.read_csv(train_path)
dev = pd.read_csv(dev_path)
test = pd.read_csv(test_path)

### new
entity_property_pair = [
    '본품#가격', '본품#다양성', '본품#디자인', '본품#인지도', '본품#일반', '본품#편의성', '본품#품질',
    '브랜드#가격', '브랜드#디자인', '브랜드#인지도', '브랜드#일반', '브랜드#품질',
    '제품 전체#가격', '제품 전체#다양성', '제품 전체#디자인', '제품 전체#인지도', '제품 전체#일반', '제품 전체#편의성', '제품 전체#품질',
    '패키지/구성품#가격', '패키지/구성품#다양성', '패키지/구성품#디자인', '패키지/구성품#일반', '패키지/구성품#편의성', '패키지/구성품#품질'
]
special_tokens = ['&name&', '&affiliation&', '&social-security-num&', '&tel-num&', '&card-num&', '&bank-account&', '&num&', '&online-account&']
emojis = pd.concat([train.sentence_form, dev.sentence_form, test.sentence_form], ignore_index=True, verify_integrity=True).to_frame()
emojis = list(set(demoji.findall(' '.join(emojis.sentence_form.to_list())).keys()))
ep_labels = pd.Series(entity_property_pair, name='sentence_form', copy=True)

tokens2add = special_tokens + emojis

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
print(len(tokenizer))
tokenizer_train_data = pd.concat([train.sentence_form, dev.sentence_form, test.sentence_form], ignore_index=True, verify_integrity=True).to_frame().drop_duplicates()
tokenizer_train_data = tokenizer_train_data.sentence_form.to_list()
new_tokenizer = tokenizer.train_new_from_iterator(tokenizer_train_data, vocab_size=1)
new_tokens = set(list(new_tokenizer.vocab.keys()) + tokens2add) - set(tokenizer.vocab.keys())
tokenizer.add_tokens(list(new_tokens))
print(len(new_tokenizer))
print(len(tokenizer))
model.resize_token_embeddings(len(tokenizer))

30000





3018
30111


Embedding(30111, 768)

In [10]:
print(len(new_tokens))
print(new_tokens)

111
{'읒', '🙋🏻', '💡', '쓩', 'ᴜ', '##◍', '⏰', '💄', '##ᵕ', '##ᴜ', '##ˇ', '##💆', 'ɴ', '🧚\u200d♀️', '🙏🏻', 'ғ', '‼️', '⁉️', '🙋\u200d♀️', '챳', '☝️', '👠', '✔️', 'ɢ', 'ʀ', 'ᴛ', '🍷', '##💄', '👦🏼', '##ɢ', '🙋🏻\u200d♀️', '☺️', '죱', '&social-security-num&', '##👠', '##ɴ', '〰️', '👉🏻', '&affiliation&', '##🚗', '##💇', '❣️', '🐄', '😯', '✌️', '◍', '##㉦', '##쫜', '🤡', '##🤡', '💬', 'ᴍ', 'ɪ', 'ᴠ', '뿤', '👋🏻', '##읒', '💪🏻', '&bank-account&', '##ᴠ', '🙌🏻', '&num&', '🚗', '👌🏻', '💆🏻\u200d♀️', '❤️', '&tel-num&', '##ꈍ', 'ᴡ', '🙆🏻', '##ᴛ', '👩\u200d👦', '쨕', '🍼', '👨\u200d👧', '쫜', '##ᴡ', '##🥤', '&online-account&', '&card-num&', '😺', '✌🏻', 'ꈍ', '💆\u200d♀️', '💆', '🏃\u200d♀️', '##ᴍ', '☝🏻', '🕺', 'ᴘ', 'ʜ', '🥤', '♥️', '🤘🏻', '👏🏻', '##뜌', '㉦', '&name&', '💇🏼\u200d♀️', '##ᴘ', '##쨕', '🙆\u200d♂️', '##죱', '##➕', '##ʀ', '➕', '💇', '##ɪ', '뜌', 'ˇ', 'ᵕ'}


In [11]:
model.config.label2id, model.config.id2label, model.num_labels

({'True': 0, 'False': 1}, {0: 'True', 1: 'False'}, 2)

In [12]:
# entity_property_pair = [
#     '본품#가격', '본품#다양성', '본품#디자인', '본품#인지도', '본품#일반', '본품#편의성', '본품#품질',
#     '브랜드#가격', '브랜드#디자인', '브랜드#인지도', '브랜드#일반', '브랜드#품질',
#     '제품 전체#가격', '제품 전체#다양성', '제품 전체#디자인', '제품 전체#인지도', '제품 전체#일반', '제품 전체#편의성', '제품 전체#품질',
#     '패키지/구성품#가격', '패키지/구성품#다양성', '패키지/구성품#디자인', '패키지/구성품#일반', '패키지/구성품#편의성', '패키지/구성품#품질'
# ]
# polarity_id_to_name = ['positive', 'negative', 'neutral']
# tokenizer_tester = []
# for pair in entity_property_pair:
#     for polarity in polarity_id_to_name:
#         tokenizer_tester.append('#'.join([pair, polarity]))
# for e in tokenizer_tester:
#     print(tokenizer.decode(tokenizer.encode(e)))
# for e in tokenizer_tester:
#     print(tokenizer.encode(e))

# Define Metric

In [13]:
accuracy_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')

In [14]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    
    accuracy = accuracy_metric.compute(references=labels, predictions=predictions)['accuracy']
    f1_true, f1_false = tuple(f1_metric.compute(references=labels, predictions=predictions, average=None, labels=[0,1])['f1'])
    f1_macro = f1_metric.compute(references=labels, predictions=predictions, average='macro')['f1']
    f1_micro = f1_metric.compute(references=labels, predictions=predictions, average='micro')['f1']
    
    return {'accuracy': accuracy, 'f1_true': f1_true, 'f1_false': f1_false, 'f1_macro': f1_macro, 'f1_micro': f1_micro}

# Load Data

In [15]:
def preprocess_function(examples):
    return tokenizer(examples["form"], examples["pair"], truncation=True)

In [16]:
train_dataset = pd.read_csv(TRAIN_DATA_PATH)
eval_dataset = pd.read_csv(EVAL_DATA_PATH)
# train_dataset = pd.concat([train_dataset, eval_dataset])
train_dataset = datasets.Dataset.from_pandas(train_dataset) #.shuffle(seed=42)
eval_dataset = datasets.Dataset.from_pandas(eval_dataset) #.shuffle(seed=42)
train_dataset = train_dataset.map(preprocess_function, batched=False)
eval_dataset = eval_dataset.map(preprocess_function, batched=False)

  0%|          | 0/67900 [00:00<?, ?ex/s]

  1%|          | 495/67900 [00:00<00:13, 4946.20ex/s]

  1%|▏         | 990/67900 [00:00<00:31, 2104.07ex/s]

  2%|▏         | 1282/67900 [00:00<00:29, 2293.97ex/s]

  3%|▎         | 1790/67900 [00:00<00:21, 3048.77ex/s]

  3%|▎         | 2236/67900 [00:00<00:19, 3441.52ex/s]

  4%|▍         | 2691/67900 [00:00<00:17, 3756.65ex/s]

  5%|▍         | 3110/67900 [00:00<00:16, 3834.83ex/s]

  5%|▌         | 3622/67900 [00:01<00:15, 4203.82ex/s]

  6%|▌         | 4067/67900 [00:01<00:15, 4152.50ex/s]

  7%|▋         | 4569/67900 [00:01<00:14, 4401.27ex/s]

  7%|▋         | 5023/67900 [00:01<00:15, 4055.34ex/s]

  8%|▊         | 5565/67900 [00:01<00:14, 4430.96ex/s]

  9%|▉         | 6022/67900 [00:01<00:13, 4439.08ex/s]

 10%|▉         | 6514/67900 [00:01<00:13, 4574.01ex/s]

 10%|█         | 7000/67900 [00:01<00:13, 4450.68ex/s]

 11%|█         | 7451/67900 [00:01<00:13, 4447.57ex/s]

 12%|█▏        | 7945/67900 [00:02<00:13, 4589.04ex/s]

 12%|█▏        | 8421/67900 [00:02<00:12, 4638.24ex/s]

 13%|█▎        | 8893/67900 [00:02<00:12, 4658.84ex/s]

 14%|█▍        | 9361/67900 [00:02<00:13, 4403.11ex/s]

 14%|█▍        | 9818/67900 [00:02<00:13, 4447.31ex/s]

 15%|█▌        | 10266/67900 [00:02<00:13, 4178.91ex/s]

 16%|█▌        | 10697/67900 [00:02<00:13, 4213.40ex/s]

 16%|█▋        | 11122/67900 [00:02<00:13, 4107.16ex/s]

 17%|█▋        | 11585/67900 [00:02<00:13, 4253.85ex/s]

 18%|█▊        | 12014/67900 [00:02<00:13, 4262.18ex/s]

 18%|█▊        | 12454/67900 [00:03<00:12, 4298.91ex/s]

 19%|█▉        | 12886/67900 [00:03<00:12, 4266.26ex/s]

 20%|█▉        | 13314/67900 [00:03<00:13, 4104.16ex/s]

 20%|██        | 13730/67900 [00:03<00:13, 4117.11ex/s]

 21%|██        | 14144/67900 [00:03<00:13, 3982.53ex/s]

 22%|██▏       | 14607/67900 [00:03<00:12, 4167.74ex/s]

 22%|██▏       | 15026/67900 [00:03<00:12, 4090.39ex/s]

 23%|██▎       | 15464/67900 [00:03<00:12, 4172.45ex/s]

 23%|██▎       | 15905/67900 [00:03<00:12, 4240.80ex/s]

 24%|██▍       | 16331/67900 [00:04<00:12, 4059.50ex/s]

 25%|██▍       | 16828/67900 [00:04<00:11, 4319.58ex/s]

 25%|██▌       | 17263/67900 [00:04<00:11, 4286.51ex/s]

 26%|██▌       | 17778/67900 [00:04<00:11, 4537.19ex/s]

 27%|██▋       | 18234/67900 [00:04<00:11, 4457.29ex/s]

 28%|██▊       | 18682/67900 [00:04<00:11, 4462.35ex/s]

 28%|██▊       | 19130/67900 [00:04<00:11, 4164.26ex/s]

 29%|██▉       | 19576/67900 [00:04<00:11, 4245.43ex/s]

 29%|██▉       | 20005/67900 [00:04<00:11, 4084.28ex/s]

 30%|███       | 20531/67900 [00:04<00:10, 4414.35ex/s]

 31%|███       | 21003/67900 [00:05<00:10, 4501.79ex/s]

 32%|███▏      | 21473/67900 [00:05<00:10, 4556.66ex/s]

 32%|███▏      | 21932/67900 [00:05<00:10, 4502.88ex/s]

 33%|███▎      | 22385/67900 [00:05<00:10, 4381.16ex/s]

 34%|███▎      | 22829/67900 [00:05<00:10, 4394.48ex/s]

 34%|███▍      | 23270/67900 [00:05<00:10, 4262.92ex/s]

 35%|███▌      | 23775/67900 [00:05<00:09, 4488.51ex/s]

 36%|███▌      | 24226/67900 [00:05<00:10, 4363.37ex/s]

 36%|███▋      | 24723/67900 [00:05<00:09, 4537.16ex/s]

 37%|███▋      | 25179/67900 [00:06<00:09, 4435.86ex/s]

 38%|███▊      | 25625/67900 [00:06<00:09, 4358.71ex/s]

 38%|███▊      | 26063/67900 [00:06<00:10, 3856.84ex/s]

 39%|███▉      | 26573/67900 [00:06<00:09, 4187.37ex/s]

 40%|███▉      | 27018/67900 [00:06<00:09, 4258.51ex/s]

 40%|████      | 27455/67900 [00:06<00:09, 4288.84ex/s]

 41%|████      | 27963/67900 [00:06<00:08, 4514.27ex/s]

 42%|████▏     | 28420/67900 [00:06<00:09, 4268.31ex/s]

 43%|████▎     | 28858/67900 [00:06<00:09, 4297.23ex/s]

 43%|████▎     | 29293/67900 [00:06<00:09, 4072.24ex/s]

 44%|████▎     | 29706/67900 [00:07<00:09, 4083.18ex/s]

 44%|████▍     | 30118/67900 [00:07<00:11, 3299.06ex/s]

 45%|████▌     | 30596/67900 [00:07<00:10, 3664.15ex/s]

 46%|████▌     | 31001/67900 [00:07<00:09, 3759.50ex/s]

 46%|████▋     | 31459/67900 [00:07<00:09, 3980.93ex/s]

 47%|████▋     | 31969/67900 [00:07<00:08, 4293.05ex/s]

 48%|████▊     | 32412/67900 [00:07<00:08, 4188.17ex/s]

 48%|████▊     | 32903/67900 [00:07<00:07, 4390.48ex/s]

 49%|████▉     | 33351/67900 [00:08<00:08, 4300.31ex/s]

 50%|████▉     | 33822/67900 [00:08<00:07, 4415.30ex/s]

 50%|█████     | 34269/67900 [00:08<00:08, 4172.15ex/s]

 51%|█████     | 34692/67900 [00:08<00:09, 3625.51ex/s]

 52%|█████▏    | 35075/67900 [00:08<00:08, 3676.42ex/s]

 52%|█████▏    | 35455/67900 [00:08<00:08, 3626.07ex/s]

 53%|█████▎    | 35826/67900 [00:08<00:09, 3517.80ex/s]

 53%|█████▎    | 36184/67900 [00:08<00:10, 3140.19ex/s]

 54%|█████▍    | 36680/67900 [00:08<00:08, 3604.03ex/s]

 55%|█████▍    | 37178/67900 [00:09<00:07, 3973.58ex/s]

 56%|█████▌    | 37746/67900 [00:09<00:06, 4445.16ex/s]

 56%|█████▋    | 38241/67900 [00:09<00:06, 4587.84ex/s]

 57%|█████▋    | 38795/67900 [00:09<00:05, 4861.91ex/s]

 58%|█████▊    | 39290/67900 [00:09<00:05, 4807.45ex/s]

 59%|█████▊    | 39810/67900 [00:09<00:05, 4920.56ex/s]

 59%|█████▉    | 40307/67900 [00:09<00:05, 4688.56ex/s]

 60%|██████    | 40871/67900 [00:09<00:05, 4957.26ex/s]

 61%|██████    | 41372/67900 [00:09<00:05, 4561.96ex/s]

 62%|██████▏   | 41838/67900 [00:10<00:05, 4477.60ex/s]

 62%|██████▏   | 42292/67900 [00:10<00:05, 4322.48ex/s]

 63%|██████▎   | 42756/67900 [00:10<00:05, 4408.79ex/s]

 64%|██████▎   | 43201/67900 [00:10<00:05, 4284.15ex/s]

 64%|██████▍   | 43677/67900 [00:10<00:05, 4415.87ex/s]

 65%|██████▍   | 44122/67900 [00:10<00:05, 4361.46ex/s]

 66%|██████▌   | 44619/67900 [00:10<00:05, 4535.25ex/s]

 66%|██████▋   | 45075/67900 [00:10<00:05, 4385.38ex/s]

 67%|██████▋   | 45548/67900 [00:10<00:04, 4483.51ex/s]

 68%|██████▊   | 46000/67900 [00:10<00:04, 4468.01ex/s]

 68%|██████▊   | 46500/67900 [00:11<00:04, 4622.97ex/s]

 69%|██████▉   | 47000/67900 [00:11<00:04, 4581.20ex/s]

 70%|██████▉   | 47460/67900 [00:11<00:04, 4581.78ex/s]

 71%|███████   | 47934/67900 [00:11<00:04, 4626.43ex/s]

 71%|███████▏  | 48398/67900 [00:11<00:04, 4402.33ex/s]

 72%|███████▏  | 48885/67900 [00:11<00:04, 4533.28ex/s]

 73%|███████▎  | 49341/67900 [00:11<00:04, 4435.84ex/s]

 73%|███████▎  | 49846/67900 [00:11<00:03, 4611.07ex/s]

 74%|███████▍  | 50310/67900 [00:11<00:03, 4575.82ex/s]

 75%|███████▍  | 50786/67900 [00:12<00:03, 4629.22ex/s]

 75%|███████▌  | 51251/67900 [00:12<00:03, 4425.86ex/s]

 76%|███████▌  | 51697/67900 [00:12<00:03, 4431.83ex/s]

 77%|███████▋  | 52142/67900 [00:12<00:03, 4378.46ex/s]

 78%|███████▊  | 52657/67900 [00:12<00:03, 4599.93ex/s]

 78%|███████▊  | 53119/67900 [00:12<00:03, 4304.85ex/s]

 79%|███████▉  | 53554/67900 [00:12<00:03, 3947.75ex/s]

 80%|███████▉  | 53994/67900 [00:12<00:03, 4068.66ex/s]

 80%|████████  | 54447/67900 [00:12<00:03, 4193.28ex/s]

 81%|████████  | 54873/67900 [00:13<00:03, 3725.45ex/s]

 81%|████████▏ | 55259/67900 [00:13<00:03, 3481.36ex/s]

 82%|████████▏ | 55704/67900 [00:13<00:03, 3731.25ex/s]

 83%|████████▎ | 56089/67900 [00:13<00:03, 3631.95ex/s]

 83%|████████▎ | 56569/67900 [00:13<00:02, 3948.30ex/s]

 84%|████████▍ | 57000/67900 [00:13<00:02, 3878.45ex/s]

 85%|████████▍ | 57395/67900 [00:13<00:02, 3882.95ex/s]

 85%|████████▌ | 57828/67900 [00:13<00:02, 4006.21ex/s]

 86%|████████▌ | 58233/67900 [00:13<00:02, 3952.86ex/s]

 87%|████████▋ | 58734/67900 [00:13<00:02, 4254.57ex/s]

 87%|████████▋ | 59163/67900 [00:14<00:02, 4156.49ex/s]

 88%|████████▊ | 59667/67900 [00:14<00:01, 4411.13ex/s]

 89%|████████▊ | 60111/67900 [00:14<00:01, 4380.72ex/s]

 89%|████████▉ | 60607/67900 [00:14<00:01, 4549.14ex/s]

 90%|████████▉ | 61064/67900 [00:14<00:01, 4470.72ex/s]

 91%|█████████ | 61617/67900 [00:14<00:01, 4779.06ex/s]

 91%|█████████▏| 62097/67900 [00:14<00:01, 4606.50ex/s]

 92%|█████████▏| 62583/67900 [00:14<00:01, 4678.03ex/s]

 93%|█████████▎| 63053/67900 [00:14<00:01, 4554.08ex/s]

 94%|█████████▎| 63562/67900 [00:15<00:00, 4708.47ex/s]

 94%|█████████▍| 64035/67900 [00:15<00:00, 4532.03ex/s]

 95%|█████████▌| 64533/67900 [00:15<00:00, 4659.47ex/s]

 96%|█████████▌| 65002/67900 [00:15<00:00, 4619.20ex/s]

 96%|█████████▋| 65466/67900 [00:15<00:00, 4578.14ex/s]

 97%|█████████▋| 65925/67900 [00:15<00:00, 4418.39ex/s]

 98%|█████████▊| 66369/67900 [00:15<00:00, 4084.39ex/s]

 98%|█████████▊| 66783/67900 [00:15<00:00, 3719.29ex/s]

 99%|█████████▉| 67201/67900 [00:15<00:00, 3839.60ex/s]

100%|█████████▉| 67766/67900 [00:16<00:00, 4332.63ex/s]

100%|██████████| 67900/67900 [00:16<00:00, 4227.16ex/s]




  0%|          | 0/65675 [00:00<?, ?ex/s]

  0%|          | 121/65675 [00:00<01:55, 565.87ex/s]

  1%|          | 626/65675 [00:00<00:27, 2360.30ex/s]

  2%|▏         | 1021/65675 [00:00<00:22, 2937.84ex/s]

  2%|▏         | 1527/65675 [00:00<00:17, 3663.33ex/s]

  3%|▎         | 2000/65675 [00:00<00:16, 3920.32ex/s]

  4%|▍         | 2472/65675 [00:00<00:15, 4171.20ex/s]

  5%|▍         | 2962/65675 [00:00<00:14, 4395.79ex/s]

  5%|▌         | 3419/65675 [00:00<00:14, 4351.66ex/s]

  6%|▌         | 3877/65675 [00:01<00:13, 4418.93ex/s]

  7%|▋         | 4328/65675 [00:01<00:14, 4258.25ex/s]

  7%|▋         | 4811/65675 [00:01<00:13, 4421.48ex/s]

  8%|▊         | 5259/65675 [00:01<00:13, 4368.30ex/s]

  9%|▊         | 5744/65675 [00:01<00:13, 4507.44ex/s]

  9%|▉         | 6198/65675 [00:01<00:13, 4420.04ex/s]

 10%|█         | 6646/65675 [00:01<00:13, 4436.83ex/s]

 11%|█         | 7092/65675 [00:01<00:13, 4273.42ex/s]

 12%|█▏        | 7564/65675 [00:01<00:13, 4399.89ex/s]

 12%|█▏        | 8013/65675 [00:01<00:13, 4424.47ex/s]

 13%|█▎        | 8546/65675 [00:02<00:12, 4688.99ex/s]

 14%|█▎        | 9017/65675 [00:02<00:12, 4478.51ex/s]

 14%|█▍        | 9468/65675 [00:02<00:12, 4401.59ex/s]

 15%|█▌        | 9911/65675 [00:02<00:13, 4277.00ex/s]

 16%|█▌        | 10341/65675 [00:02<00:15, 3462.48ex/s]

 16%|█▋        | 10826/65675 [00:02<00:14, 3804.31ex/s]

 17%|█▋        | 11232/65675 [00:02<00:14, 3774.77ex/s]

 18%|█▊        | 11681/65675 [00:02<00:13, 3958.80ex/s]

 18%|█▊        | 12091/65675 [00:03<00:13, 3906.14ex/s]

 19%|█▉        | 12509/65675 [00:03<00:13, 3980.87ex/s]

 20%|█▉        | 12952/65675 [00:03<00:12, 4108.15ex/s]

 20%|██        | 13369/65675 [00:03<00:13, 3914.86ex/s]

 21%|██        | 13843/65675 [00:03<00:12, 4144.76ex/s]

 22%|██▏       | 14263/65675 [00:03<00:15, 3378.25ex/s]

 22%|██▏       | 14707/65675 [00:03<00:13, 3641.42ex/s]

 23%|██▎       | 15095/65675 [00:03<00:13, 3637.74ex/s]

 24%|██▎       | 15564/65675 [00:03<00:12, 3919.83ex/s]

 24%|██▍       | 16026/65675 [00:04<00:12, 4113.81ex/s]

 25%|██▌       | 16512/65675 [00:04<00:11, 4325.61ex/s]

 26%|██▌       | 16954/65675 [00:04<00:11, 4301.85ex/s]

 26%|██▋       | 17391/65675 [00:04<00:11, 4068.01ex/s]

 27%|██▋       | 17838/65675 [00:04<00:11, 4180.45ex/s]

 28%|██▊       | 18262/65675 [00:04<00:11, 4020.53ex/s]

 29%|██▊       | 18723/65675 [00:04<00:11, 4184.09ex/s]

 29%|██▉       | 19146/65675 [00:04<00:11, 4096.38ex/s]

 30%|██▉       | 19687/65675 [00:04<00:10, 4469.18ex/s]

 31%|███       | 20138/65675 [00:04<00:10, 4449.45ex/s]

 31%|███▏      | 20586/65675 [00:05<00:10, 4428.89ex/s]

 32%|███▏      | 21031/65675 [00:05<00:10, 4331.87ex/s]

 33%|███▎      | 21466/65675 [00:05<00:10, 4179.61ex/s]

 33%|███▎      | 21886/65675 [00:05<00:11, 3779.69ex/s]

 34%|███▍      | 22309/65675 [00:05<00:11, 3900.17ex/s]

 35%|███▍      | 22794/65675 [00:05<00:10, 4163.32ex/s]

 35%|███▌      | 23257/65675 [00:05<00:09, 4294.78ex/s]

 36%|███▌      | 23779/65675 [00:05<00:09, 4559.77ex/s]

 37%|███▋      | 24240/65675 [00:05<00:09, 4344.67ex/s]

 38%|███▊      | 24680/65675 [00:06<00:09, 4337.02ex/s]

 38%|███▊      | 25118/65675 [00:06<00:09, 4144.75ex/s]

 39%|███▉      | 25581/65675 [00:06<00:09, 4279.75ex/s]

 40%|███▉      | 26013/65675 [00:06<00:10, 3932.60ex/s]

 40%|████      | 26440/65675 [00:06<00:09, 4024.51ex/s]

 41%|████      | 26849/65675 [00:06<00:09, 4001.19ex/s]

 41%|████▏     | 27254/65675 [00:06<00:10, 3667.48ex/s]

 42%|████▏     | 27667/65675 [00:06<00:10, 3790.99ex/s]

 43%|████▎     | 28053/65675 [00:06<00:09, 3806.99ex/s]

 43%|████▎     | 28535/65675 [00:07<00:09, 4093.79ex/s]

 44%|████▍     | 29000/65675 [00:07<00:08, 4157.71ex/s]

 45%|████▍     | 29489/65675 [00:07<00:08, 4367.70ex/s]

 46%|████▌     | 29956/65675 [00:07<00:08, 4453.63ex/s]

 46%|████▋     | 30404/65675 [00:07<00:08, 4293.27ex/s]

 47%|████▋     | 30866/65675 [00:07<00:07, 4387.14ex/s]

 48%|████▊     | 31308/65675 [00:07<00:07, 4352.03ex/s]

 48%|████▊     | 31776/65675 [00:07<00:07, 4447.35ex/s]

 49%|████▉     | 32223/65675 [00:07<00:08, 4170.50ex/s]

 50%|████▉     | 32673/65675 [00:07<00:07, 4260.96ex/s]

 50%|█████     | 33103/65675 [00:08<00:07, 4093.61ex/s]

 51%|█████     | 33571/65675 [00:08<00:07, 4258.99ex/s]

 52%|█████▏    | 34001/65675 [00:08<00:07, 4244.14ex/s]

 52%|█████▏    | 34436/65675 [00:08<00:07, 4274.14ex/s]

 53%|█████▎    | 34866/65675 [00:08<00:07, 4276.82ex/s]

 54%|█████▎    | 35295/65675 [00:08<00:07, 3812.68ex/s]

 54%|█████▍    | 35687/65675 [00:08<00:07, 3768.45ex/s]

 55%|█████▌    | 36188/65675 [00:08<00:07, 4110.10ex/s]

 56%|█████▌    | 36743/65675 [00:08<00:06, 4516.04ex/s]

 57%|█████▋    | 37203/65675 [00:09<00:06, 4221.55ex/s]

 57%|█████▋    | 37655/65675 [00:09<00:06, 4302.42ex/s]

 58%|█████▊    | 38092/65675 [00:09<00:06, 4220.49ex/s]

 59%|█████▊    | 38519/65675 [00:09<00:06, 4219.22ex/s]

 59%|█████▉    | 38945/65675 [00:09<00:06, 3836.67ex/s]

 60%|█████▉    | 39337/65675 [00:09<00:07, 3761.82ex/s]

 61%|██████    | 39857/65675 [00:09<00:06, 4155.08ex/s]

 61%|██████▏   | 40280/65675 [00:09<00:06, 4169.07ex/s]

 62%|██████▏   | 40780/65675 [00:09<00:05, 4406.43ex/s]

 63%|██████▎   | 41226/65675 [00:10<00:05, 4316.76ex/s]

 64%|██████▎   | 41731/65675 [00:10<00:05, 4527.90ex/s]

 64%|██████▍   | 42188/65675 [00:10<00:05, 4537.30ex/s]

 65%|██████▌   | 42734/65675 [00:10<00:04, 4806.93ex/s]

 66%|██████▌   | 43217/65675 [00:10<00:04, 4573.05ex/s]

 67%|██████▋   | 43687/65675 [00:10<00:04, 4609.04ex/s]

 67%|██████▋   | 44151/65675 [00:10<00:04, 4350.41ex/s]

 68%|██████▊   | 44647/65675 [00:10<00:04, 4520.51ex/s]

 69%|██████▊   | 45104/65675 [00:10<00:04, 4335.11ex/s]

 70%|██████▉   | 45659/65675 [00:10<00:04, 4675.92ex/s]

 70%|███████   | 46132/65675 [00:11<00:04, 4644.15ex/s]

 71%|███████   | 46600/65675 [00:11<00:04, 4501.37ex/s]

 72%|███████▏  | 47054/65675 [00:11<00:04, 4393.57ex/s]

 72%|███████▏  | 47518/65675 [00:11<00:04, 4462.17ex/s]

 73%|███████▎  | 48000/65675 [00:11<00:04, 4417.69ex/s]

 74%|███████▍  | 48488/65675 [00:11<00:03, 4548.28ex/s]

 75%|███████▍  | 48999/65675 [00:11<00:03, 4708.06ex/s]

 75%|███████▌  | 49472/65675 [00:11<00:03, 4550.53ex/s]

 76%|███████▌  | 49949/65675 [00:11<00:03, 4612.69ex/s]

 77%|███████▋  | 50413/65675 [00:12<00:03, 4339.75ex/s]

 78%|███████▊  | 50935/65675 [00:12<00:03, 4585.58ex/s]

 78%|███████▊  | 51399/65675 [00:12<00:03, 4437.29ex/s]

 79%|███████▉  | 51918/65675 [00:12<00:02, 4647.33ex/s]

 80%|███████▉  | 52387/65675 [00:12<00:03, 4404.37ex/s]

 80%|████████  | 52849/65675 [00:12<00:02, 4463.78ex/s]

 81%|████████  | 53300/65675 [00:12<00:02, 4327.48ex/s]

 82%|████████▏ | 53778/65675 [00:12<00:02, 4453.85ex/s]

 83%|████████▎ | 54227/65675 [00:12<00:02, 4266.04ex/s]

 83%|████████▎ | 54712/65675 [00:13<00:02, 4429.98ex/s]

 84%|████████▍ | 55159/65675 [00:13<00:02, 4286.73ex/s]

 85%|████████▍ | 55591/65675 [00:13<00:02, 4281.61ex/s]

 85%|████████▌ | 56022/65675 [00:13<00:02, 4063.50ex/s]

 86%|████████▌ | 56458/65675 [00:13<00:02, 4145.93ex/s]

 87%|████████▋ | 56963/65675 [00:13<00:01, 4403.85ex/s]

 87%|████████▋ | 57407/65675 [00:13<00:01, 4252.71ex/s]

 88%|████████▊ | 57930/65675 [00:13<00:01, 4530.23ex/s]

 89%|████████▉ | 58400/65675 [00:13<00:01, 4578.78ex/s]

 90%|████████▉ | 58914/65675 [00:13<00:01, 4740.91ex/s]

 90%|█████████ | 59391/65675 [00:14<00:01, 4470.26ex/s]

 91%|█████████ | 59845/65675 [00:14<00:01, 4489.40ex/s]

 92%|█████████▏| 60298/65675 [00:14<00:01, 4349.09ex/s]

 92%|█████████▏| 60736/65675 [00:14<00:01, 4295.49ex/s]

 93%|█████████▎| 61168/65675 [00:14<00:01, 4197.59ex/s]

 94%|█████████▍| 61667/65675 [00:14<00:00, 4421.44ex/s]

 95%|█████████▍| 62112/65675 [00:14<00:00, 4186.77ex/s]

 95%|█████████▌| 62535/65675 [00:14<00:00, 4192.06ex/s]

 96%|█████████▌| 62957/65675 [00:14<00:00, 4074.00ex/s]

 96%|█████████▋| 63367/65675 [00:15<00:00, 3907.10ex/s]

 97%|█████████▋| 63878/65675 [00:15<00:00, 4243.04ex/s]

 98%|█████████▊| 64310/65675 [00:15<00:00, 4263.19ex/s]

 99%|█████████▉| 64924/65675 [00:15<00:00, 4805.02ex/s]

100%|█████████▉| 65409/65675 [00:15<00:00, 4665.94ex/s]

100%|██████████| 65675/65675 [00:15<00:00, 4226.14ex/s]




In [17]:
len(train_dataset), len(eval_dataset)

(67900, 65675)

In [18]:
k = random.randrange(len(train_dataset))
print(tokenizer.decode(train_dataset['input_ids'][k]), train_dataset['labels'][k])
k = random.randrange(len(eval_dataset))
print(tokenizer.decode(eval_dataset['input_ids'][k]), eval_dataset['labels'][k])

[CLS] 에어가 터지면 걸을 때 민망한 소리가 난다는 말을 들은 적도 있고, 나이키 에어는 별로 내구성이 좋지 않다는 소리도 많이 들어서 그 점도 플러스가 됐다. [SEP] 제품 전체 # 다양성 [SEP] 1


[CLS] 김정문알로에 1위 아이템! [SEP] 본품 # 인지도 [SEP] 1


# Load Trainer

In [19]:
args = TrainingArguments(
    output_dir=run_name,
    run_name=run_name,
    report_to=report_to,

    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,

    optim=optim,

    learning_rate=learning_rate,
    weight_decay=weight_decay,
    adam_epsilon=adam_epsilon,

    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,

    save_total_limit=save_total_limit,

    load_best_model_at_end=load_best_model_at_end,
    metric_for_best_model=metric_for_best_model,
    
    save_strategy=save_strategy,
    evaluation_strategy=evaluation_strategy,

    logging_strategy=logging_strategy,
    logging_first_step=logging_first_step, 
    logging_steps=logging_steps,
    
    fp16=fp16,
)

In [20]:
# es = EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    # callbacks=[es],
)

# Run Trainer

In [None]:
trainer.train()
wandb.finish()

In [None]:
keep = [
    'added_tokens.json',
    'config.json',
    'pytorch_model.bin',
    'special_tokens_map.json',
    'tokenizer.json',
    'tokenizer_config.json',
    'vocab.txt'
]

ckpts = os.listdir(run_name)
for ckpt in ckpts:
    ckpt = os.path.join(run_name, ckpt)
    for item in os.listdir(ckpt):
        if item not in keep:
            os.remove(os.path.join(ckpt, item))

!mv wandb {run_name} {SAVE_PATH}/