In [None]:
!pip install git+https://github.com/cosmoquester/transformers-tf-finetune.git

In [None]:
import argparse
import json
import random
import sys
import urllib.request

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

from transformers_tf_finetune.losses import PearsonCorrelationLoss
from transformers_tf_finetune.metrics import (
    BinaryF1Score,
    PearsonCorrelationMetric,
    SpearmanCorrelationMetric,
    pearson_correlation_coefficient,
    spearman_correlation_coefficient,
)
from transformers_tf_finetune.models import SemanticTextualSimailarityWrapper
from transformers_tf_finetune.utils import LRScheduler, get_device_strategy, get_logger, path_join, set_random_seed

# Config

In [None]:
#: transformers pretrained model name or path, passed to TFAutoModel.from_pretrained
pretrained_model = "cosmoquester/bart-ko-small"
#: pretrained tokenizer name or path, passed to AutoTokenizer.from_pretrained
pretrained_tokenizer = "cosmoquester/bart-ko-small"
#: load from pytorch weight (forwarded as from_pt when the model is loaded)
from_pytorch = False
#: huggingface credential for private model, forwarded as use_auth_token to both
#: the tokenizer and the model loader; do not commit a real token here
use_auth_token = ""

#: training data handed to load_dataset (local file path or https url)
train_dataset_path = "https://raw.githubusercontent.com/KLUE-benchmark/KLUE/main/klue_benchmark/klue-sts-v1.1/klue-sts-v1.1_train.json"
#: dev data handed to load_dataset, used for the final evaluation cell
dev_dataset_path = "https://raw.githubusercontent.com/KLUE-benchmark/KLUE/main/klue_benchmark/klue-sts-v1.1/klue-sts-v1.1_dev.json"
#: output directory to save log and model checkpoints, should be GCS path with TPU
#: when None, training runs without checkpoint / TensorBoard callbacks
output_path = None

#: training params
#: number of epochs for fit; also multiplied by the number of train batches
#: to get the total step count given to LRScheduler
epochs = 3
#: the next four values are passed positionally to LRScheduler after the total
#: step count (exact schedule semantics live in LRScheduler -- confirm there)
learning_rate = 5e-5
min_learning_rate = 1e-5
warmup_rate = 0.06
warmup_steps = None
#: batch size of the training dataset
batch_size = 128
#: batch size of the validation and dev datasets
dev_batch_size = 512
#: number of examples taken from the shuffled training data as the validation set;
#: the remaining examples are used for training
num_valid_dataset = 2000
#: forwarded as update_freq to the TensorBoard callback
tensorboard_update_freq = 1

#: device to use (TPU or GPU or CPU), passed to get_device_strategy
device = "TPU"
#: Use mixed precision (mixed_bfloat16 when device is "TPU", mixed_float16 otherwise)
mixed_precision = False
#: Set random seed; with None, set_random_seed is not called
seed = None

In [None]:
# Authenticate the Colab user only when outputs are written to a Google Cloud
# Storage path; the import stays local because google.colab exists only on Colab.
if (output_path or "").startswith("gs://"):
    from google.colab import auth

    auth.authenticate_user()

In [None]:
def load_dataset(dataset_path: str, tokenizer: AutoTokenizer, shuffle: bool = False) -> tf.data.Dataset:
    """
    Load KLUE STS dataset from local file or web

    The content is parsed as a JSON list of examples. Each example must provide
    "sentence1", "sentence2" and labels["real-label"]; the label is divided by 5.0.
    Both sentences are wrapped with the tokenizer's start/end tokens and tokenized
    separately, each padded to the longest sentence of its own side.

    :param dataset_path: local file path or http(s) url
    :param tokenizer: PreTrainedTokenizer for tokenizing
    :param shuffle: whether to shuffle the examples (with the ``random`` module) before tokenizing
    :returns: unbatched dataset of ((tokens1, tokens2), normalized_label) examples
    """
    if dataset_path.startswith(("http://", "https://")):
        with urllib.request.urlopen(dataset_path) as response:
            data = response.read().decode("utf-8")
    else:
        # Read as UTF-8 explicitly, matching the URL branch, instead of relying on
        # the platform default encoding (which breaks non-ASCII text on some systems).
        with open(dataset_path, encoding="utf-8") as f:
            data = f.read()
    examples = json.loads(data)
    if shuffle:
        random.shuffle(examples)

    # Fall back to CLS/SEP for tokenizers that define no BOS/EOS tokens.
    start_token = tokenizer.bos_token or tokenizer.cls_token
    end_token = tokenizer.eos_token or tokenizer.sep_token

    sentences1 = [start_token + example["sentence1"] + end_token for example in examples]
    sentences2 = [start_token + example["sentence2"] + end_token for example in examples]
    normalized_labels = [float(example["labels"]["real-label"]) / 5.0 for example in examples]

    # Same tokenizer options for both sides of the sentence pair.
    tokenize_kwargs = dict(
        padding=True,
        return_tensors="tf",
        return_token_type_ids=False,
        return_attention_mask=True,
    )
    tokens1 = tokenizer(sentences1, **tokenize_kwargs)
    tokens2 = tokenizer(sentences2, **tokenize_kwargs)

    dataset = tf.data.Dataset.from_tensor_slices(((dict(tokens1), dict(tokens2)), normalized_labels))
    return dataset

In [None]:
# Compare against None instead of truthiness so that seed = 0 is applied too.
if seed is not None:
    set_random_seed(seed)

In [None]:
# Strategy object for the configured device (presumably a tf.distribute strategy --
# confirm in get_device_strategy); the following cells enter strategy.scope().
strategy = get_device_strategy(device)

# Mixed Precision

In [None]:
with strategy.scope():
    if mixed_precision:
        # bfloat16 policy on TPU, float16 policy on every other device.
        mixed_type = "mixed_bfloat16" if device == "TPU" else "mixed_float16"
        if hasattr(tf.keras.mixed_precision, "set_global_policy"):
            # Stable API; the experimental namespace is deprecated and missing
            # from newer TensorFlow releases.
            policy = tf.keras.mixed_precision.Policy(mixed_type)
            tf.keras.mixed_precision.set_global_policy(policy)
        else:
            # Older TensorFlow versions that only ship the experimental API.
            policy = tf.keras.mixed_precision.experimental.Policy(mixed_type)
            tf.keras.mixed_precision.experimental.set_policy(policy)

# Load Dataset

In [None]:
with strategy.scope():
    tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer, use_auth_token=use_auth_token)

    # The training file is shuffled once while loading, so take/skip below split
    # it into a validation part (first num_valid_dataset examples) and a training
    # part (everything after them).
    dataset = load_dataset(train_dataset_path, tokenizer, shuffle=True)
    valid_dataset = dataset.take(num_valid_dataset).batch(dev_batch_size)
    train_dataset = dataset.skip(num_valid_dataset).batch(batch_size)

    # The dev file keeps its original order and is only used for final evaluation.
    dev_dataset = load_dataset(dev_dataset_path, tokenizer, shuffle=False).batch(dev_batch_size)

# Load Model

In [None]:
with strategy.scope():
    # Pretrained backbone, optionally converted from PyTorch weights.
    model = TFAutoModel.from_pretrained(pretrained_model, from_pt=from_pytorch, use_auth_token=use_auth_token)
    # Wrap the backbone with the sentence-pair similarity wrapper.
    model_sts = SemanticTextualSimailarityWrapper(embedding_dropout=0.1, model=model)

# Model Compilation

In [None]:
with strategy.scope():
    # The schedule spans every training step: batches per epoch times epochs.
    lr_schedule = LRScheduler(
        len(train_dataset) * epochs,
        learning_rate,
        min_learning_rate,
        warmup_rate,
        warmup_steps,
    )
    optimizer = tf.keras.optimizers.Adam(lr_schedule)

    # Two losses are configured, weighted 0.25 (Pearson) and 0.75 (MSE).
    model_sts.compile(
        optimizer=optimizer,
        loss=[PearsonCorrelationLoss(), tf.keras.losses.MeanSquaredError()],
        loss_weights=[0.25, 0.75],
        metrics=[
            BinaryF1Score(),
            PearsonCorrelationMetric(name="pearson_coef"),
            SpearmanCorrelationMetric(name="spearman_coef"),
        ],
    )

# Model Training

In [None]:
with strategy.scope():
    # Checkpointing and TensorBoard logging need somewhere to write, so they are
    # attached only when an output path is configured.
    training_callbacks = None
    if output_path is not None:
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            path_join(output_path, "best_model.ckpt"),
            save_weights_only=True,
            save_best_only=True,
            # keep the weights with the highest validation Pearson coefficient
            monitor="val_pearson_coef",
            mode="max",
            verbose=1,
        )
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            path_join(output_path, "logs"), update_freq=tensorboard_update_freq
        )
        training_callbacks = [checkpoint_callback, tensorboard_callback]

    model_sts.fit(
        train_dataset,
        validation_data=valid_dataset,
        epochs=epochs,
        callbacks=training_callbacks,
    )

# Model Evaluation

In [None]:
with strategy.scope():
    # F1 is updated batch by batch; the correlation coefficients are computed
    # once over the predictions and labels collected from the whole dev set.
    f1 = BinaryF1Score()
    preds, labels = [], []
    for batch_inputs, batch_labels in dev_dataset:
        batch_preds = model_sts(batch_inputs, training=False)
        preds.extend(batch_preds.numpy())
        labels.extend(batch_labels.numpy())
        f1.update_state(batch_labels, batch_preds)

    pearson_score = pearson_correlation_coefficient(labels, preds)
    spearman_score = spearman_correlation_coefficient(labels, preds)

    report = (
        f"Dev F1 Score: {f1.result():.4f}, "
        f"Dev Pearson: {pearson_score:.4f}, "
        f"Dev Spearman: {spearman_score:.4f}"
    )
    print(report)