<a href="https://colab.research.google.com/github/bbang3/korean-text-augmentation/blob/klue/experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd /content/drive/MyDrive/korean-text-augmentation

/content/drive/MyDrive/korean-text-augmentation


# Import

In [4]:
!pip install transformers[pytorch]
!pip install datasets
!pip install accelerate
!pip install evaluate
!pip install wandb

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6
Collecting accelerate
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     

In [5]:
import os
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm

from transformers import BertModel, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import evaluate

import wandb

In [6]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [7]:
os.environ["WANDB_PROJECT"] = "KLUE - KRBERT"

In [None]:
train_path = 'data/train_mlm_2.csv'
val_path = 'data/val_low.csv'
test_path = 'data/test.csv'
dataset = load_dataset('csv', data_files={'train': train_path, 'validation': val_path, 'test': test_path})

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()
del model

In [11]:
num_labels = 7
tokenizer = BertTokenizer.from_pretrained("snunlp/KR-Medium")
model = BertForSequenceClassification.from_pretrained("snunlp/KR-Medium", num_labels=num_labels).to("cuda")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at snunlp/KR-Medium and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
def tokenization(example):
    return tokenizer(example["title"], padding="max_length", truncation=True)

tokenized_dataset = dataset.map(tokenization, batched=True, remove_columns=["title"])
tokenized_dataset.set_format("torch")

Map:   0%|          | 0/840 [00:00<?, ? examples/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/70 [00:00<?, ? examples/s]

Map:   0%|          | 0/7000 [00:00<?, ? examples/s]

In [13]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [31]:
training_args = TrainingArguments(
    output_dir='./checkpoints/mlm',
    num_train_epochs=4,
    warmup_steps=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    seed=42,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    report_to='wandb',
    run_name='mlm_3',
    save_total_limit=2,
    metric_for_best_model='accuracy',
    greater_is_better=True,
    load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

In [None]:
wandb.finish()

# Load model

In [34]:
trainer.state.best_model_checkpoint

'./checkpoints/mlm/checkpoint-53'

In [35]:
num_labels = 7
tokenizer = BertTokenizer.from_pretrained("snunlp/KR-Medium")
ft_model = BertForSequenceClassification.from_pretrained(trainer.state.best_model_checkpoint, num_labels=num_labels).to("cuda")
# ft_model = BertForSequenceClassification.from_pretrained("checkpoints/bt_2/checkpoint-848", num_labels=num_labels).to("cuda")

In [36]:
training_args = TrainingArguments(
    output_dir='./checkpoints/tmp',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    metric_for_best_model='accuracy',
    greater_is_better=True,
    seed=42,
    evaluation_strategy='no',
    logging_strategy='epoch',
    save_strategy='no',
    report_to='wandb',
    run_name='test-mlm'
)

trainer = Trainer(
    model=ft_model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
    compute_metrics=compute_metrics
)

In [None]:
trainer.evaluate(eval_dataset=tokenized_dataset['test'])

In [None]:
wandb.finish()

In [42]:
predictions = trainer.predict(test_dataset=tokenized_dataset['validation'])
predictions = np.argmax(predictions.predictions, axis=-1)
predictions

array([5, 6, 2, 4, 0, 6, 2, 0, 0, 4, 4, 0, 4, 0, 3, 6, 5, 4, 0, 4, 5, 6,
       5, 4, 0, 2, 5, 4, 1, 0, 0, 1, 1, 3, 6, 5, 2, 4, 1, 4, 1, 6, 3, 4,
       0, 3, 3, 2, 5, 5, 6, 2, 2, 4, 0, 5, 1, 2, 4, 1, 1, 0, 2, 6, 6, 3,
       5, 6, 1, 0])