<a href="https://colab.research.google.com/github/kesperinc/LLaMA_usage_example/blob/master/Document_Classification_Using_KR_SBERT_via_Transformers_(new%2C_including_data_preprocessing%2C_last_update_2022_05_03).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Document Classification Using KR-SBERT via Transformers

*SNU NLP Laboratory*

In this tutorial, we will show how to apply our pre-trained KoRean S-BERT model to a document classification task, using HuggingFace's `transformers` library.

## 0. Preparation

### Libraries

First, you need to install the following libraries.

In [None]:
!pip install -U transformers sentence-transformers kss nltk

### BNC dataset

Next, we need the Balanced News Corpus (BNC) for a topic classification task: each news article is assigned one of nine topic labels.

Download and unzip this file.

👇👇👇👇👇

In [None]:
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1Lg2jL89n3lqkKCulAnk4WwmI8G1hNfIA' -O BalancedNewsCorpusShuffled.zip
!unzip BalancedNewsCorpusShuffled.zip

## 1. Setting up Python

Now we can import all of the required libraries.

In [None]:
import torch
import pandas as pd
import numpy as np

# For Transformer models
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from sentence_transformers import SentenceTransformer

# For train/dev/test datasets
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.nn.functional import pad

# For evaluation
from torch import manual_seed
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

Let us load a `SentenceTransformer` model for sentence embeddings and a `BertForSequenceClassification` model for classification.

In [None]:
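sbert_model_name = 'snunlp/KR-SBERT-V40K-klueNLI-augSTS'
sbert_model = SentenceTransformer(sbert_model_name)  # sentence encoder
# config = sbert_model._first_module().auto_model.config # for bert token embeddings
from transformers import BertConfig
config = BertConfig()  # default BERT-base configuration
config.num_labels = 9  # nine topic labels in the BNC
config.max_position_embeddings = sbert_model.max_seq_length
# Built from a config, so the classifier weights are randomly initialized (not pre-trained)
model = BertForSequenceClassification(config)
model.main_input_name = 'inputs_embeds'  # the classifier receives sentence embeddings, not token ids
max_seq_length = sbert_model.max_seq_length
manual_seed(1234)  # seeds later random operations; the weights above are already initialized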

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

## 2. Building the BNC datasets

We define a new `Dataset` class that loads the Balanced News Corpus dataset for the `BertForSequenceClassification` model.

In [None]:
import re

def clean(text:str):
  # https://github.com/YongWookHa/kor-text-preprocess/blob/master/src/clean.py
  not_used = re.compile('[^ .?!/@$%~|0-9|ㄱ-ㅣ가-힣]+')
  dup_space = re.compile('[ \t]+')  # white space duplicate
  dup_stop = re.compile(r'[\.]+')  # full stop duplicate (raw string avoids an invalid-escape warning)

  cleaned = not_used.sub('', text.strip())
  cleaned = dup_space.sub(' ', cleaned)
  cleaned = dup_stop.sub('.', cleaned)

  return cleaned
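
To see what `clean` keeps and what it drops, here is a small self-contained check that repeats the same three patterns on a made-up sentence. The sample string and the name `clean_demo` are assumptions for illustration and are not taken from the corpus.

```python
import re

# Same patterns as in `clean`, repeated so this snippet runs on its own.
NOT_USED = re.compile('[^ .?!/@$%~|0-9|ㄱ-ㅣ가-힣]+')
DUP_SPACE = re.compile('[ \t]+')
DUP_STOP = re.compile(r'[\.]+')

def clean_demo(text: str) -> str:
    cleaned = NOT_USED.sub('', text.strip())  # drop every character outside the allowed set
    cleaned = DUP_SPACE.sub(' ', cleaned)     # collapse runs of spaces and tabs
    return DUP_STOP.sub('.', cleaned)         # collapse runs of full stops

sample = 'KR-SBERT 모델은   빠르다... 정말!'  # hypothetical input
result = clean_demo(sample)
print(repr(result))
```

Latin letters and hyphens are removed, so `KR-SBERT` disappears entirely. Because `strip()` runs before the removal step, the space that followed the removed word survives as a leading space in the result.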

In [None]:
# from kss import split_sentences # Sentence segmentation for the Korean Language
# sent_tokenize = split_sentences

import nltk
nltk.download('punkt')
from nltk import sent_tokenize

In [None]:
def get_sentence_embeddings(text:str, cls_token='[CLS]', sep_token='[SEP]', padding=True, truncate=True, max_len=128):
  sentences = [cls_token] + sent_tokenize(text) + [sep_token]
  embeddings = sbert_model.encode(sentences, convert_to_tensor=True)
  d = sbert_model.get_sentence_embedding_dimension()
  n = len(sentences)

  seq_len = n

  if padding:
    seq_len = max(n, max_len)

  if truncate:
    seq_len = min(seq_len, max_len)

  output = torch.zeros((seq_len, d), dtype=torch.float32).to(device)
  for i in range(min(n, seq_len)):
    output[i] = embeddings[i]

  return output
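
The padding and truncation logic in `get_sentence_embeddings` fixes every document to `max_len` sentence vectors. The sketch below reproduces only that length logic with plain Python lists instead of tensors; `pad_or_truncate` is a hypothetical helper name, and the toy 2-dimensional vectors are assumptions for illustration.

```python
def pad_or_truncate(vectors, dim, max_len=128, padding=True, truncate=True):
    """Return `seq_len` vectors, zero-padded or cut, mirroring the cell above."""
    n = len(vectors)
    seq_len = n
    if padding:
        seq_len = max(n, max_len)        # grow short documents up to max_len
    if truncate:
        seq_len = min(seq_len, max_len)  # cut long documents down to max_len
    output = [[0.0] * dim for _ in range(seq_len)]
    for i in range(min(n, seq_len)):
        output[i] = list(vectors[i])     # copy the real sentence vectors
    return output

short_doc = [[1.0, 2.0], [3.0, 4.0]]            # 2 "sentences"
long_doc = [[float(i), 0.0] for i in range(6)]  # 6 "sentences"

padded = pad_or_truncate(short_doc, dim=2, max_len=4)
cut = pad_or_truncate(long_doc, dim=2, max_len=4)
print(len(padded), padded[-1])
print(len(cut), cut[-1])
```

With both flags enabled, the output length is always `max_len`: short documents end in zero vectors, and long documents lose their trailing sentences (including the `[SEP]` vector that was appended last).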

In [None]:
class BNCDataset(Dataset):

    labels = ['IT/과학', '경제', '문화', '미용/건강', '사회', '생활', '스포츠', '연예', '정치']

    def __init__(self, data_file='BalancedNewsCorpus_train.csv'):
        data = pd.read_csv(data_file)
        self.text = data['News'].apply(lambda text: text.replace('<p>', '\n').replace('</p>', '\n'))
        self.text = self.text.apply(clean).tolist()
        self.label = data['Topic'].apply(lambda label: self.labels.index(label)).tolist()

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        text = self.text[idx]
        label = torch.tensor(self.label[idx]).to(device)
        feature = {'inputs_embeds': get_sentence_embeddings(text), 'labels': label}
        return feature
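
The class maps each topic string to an integer with `labels.index(label)`. As a minimal sketch of the same mapping, the dictionaries below give identical ids and also make it easy to turn a predicted id back into a topic name; `label2id` and `id2label` are names chosen here for illustration.

```python
labels = ['IT/과학', '경제', '문화', '미용/건강', '사회', '생활', '스포츠', '연예', '정치']

# Topic string -> class id, in the same order as labels.index()
label2id = {label: i for i, label in enumerate(labels)}
# Class id -> topic string, useful for reading model predictions
id2label = dict(enumerate(labels))

print(label2id['스포츠'])
print(id2label[0])
```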

Load the BNC dataset files we have downloaded.

In [None]:
train_dataset = BNCDataset('BalancedNewsCorpus_train.csv')
test_dataset = BNCDataset('BalancedNewsCorpus_test.csv')

In [None]:
train_dataset, val_dataset = random_split(train_dataset, [8100, 900], generator=manual_seed(1234))
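
The split sizes `[8100, 900]` must add up to the length of the training set, so they correspond to a 90/10 split of 9,000 documents. A small sketch of deriving such sizes from a fraction instead of hard-coding them is shown below; `split_sizes` and `val_fraction` are hypothetical names.

```python
def split_sizes(n_total, val_fraction=0.1):
    """Return (train_size, val_size) that always sum to n_total."""
    n_val = round(n_total * val_fraction)
    return n_total - n_val, n_val

train_size, val_size = split_sizes(9000)
print(train_size, val_size)
```

Computing the training size as the remainder guarantees that the two sizes sum to the dataset length, which `random_split` requires when given integer lengths.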

## 3. Training

In [None]:
args = TrainingArguments(
    output_dir="./bnc-results",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    # eval_steps=10,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    num_train_epochs=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
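    dataloader_pin_memory=True, # set to False on GPU: BNCDataset returns tensors that are already on the device, and CUDA tensors cannot be pinned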
)

We will evaluate our classifier using accuracy and macro-averaged F1, precision, and recall. The metric function is defined as follows.

In [None]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)
    print(confusion_matrix(labels, preds))
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
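
To make `average='macro'` concrete, the following sketch computes the scores by hand on a toy example: precision, recall, and F1 are computed per class and then averaged with equal weight for every class. The function `macro_scores` and the toy labels are assumptions for illustration; this is a pure-Python illustration of the idea, not scikit-learn's implementation.

```python
def macro_scores(labels, preds):
    """Per-class precision/recall/F1, averaged with equal class weights."""
    classes = sorted(set(labels) | set(preds))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(classes)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n

toy_labels = [0, 0, 1, 1, 2, 2]
toy_preds = [0, 1, 1, 1, 2, 0]

precision, recall, f1 = macro_scores(toy_labels, toy_preds)
accuracy = sum(y == p for y, p in zip(toy_labels, toy_preds)) / len(toy_labels)
print(round(accuracy, 4), round(precision, 4), round(recall, 4), round(f1, 4))
```

Because every class counts equally, macro averaging suits a balanced corpus such as the BNC; on an imbalanced dataset it would weight rare classes as heavily as frequent ones.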

Instantiate the `Trainer`.

In [None]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

Let's train!

In [None]:
trainer.train()

***** Running training *****
  Num examples = 8100
  Num Epochs = 1
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 64
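
The step count in this log follows from the dataset size and the batch size. The short calculation below reproduces it, assuming a single device, no gradient accumulation (as the log reports), and that the final, smaller batch is kept rather than dropped.

```python
import math

num_examples = 8100  # "Num examples" in the log
batch_size = 128     # "Total train batch size" in the log
num_epochs = 1       # "Num Epochs" in the log

steps_per_epoch = math.ceil(num_examples / batch_size)  # the last batch is smaller
total_steps = steps_per_epoch * num_epochs
last_batch = num_examples - (steps_per_epoch - 1) * batch_size

print(total_steps, last_batch)
```

Sixty-three full batches cover 8,064 examples, and the remaining 36 form the 64th step.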


## 4. Evaluation

In [None]:
trainer.evaluate(test_dataset)