In [None]:
!pip install git+https://github.com/cosmoquester/transformers-tf-finetune.git

In [None]:
import csv
import random
import urllib.request

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

from transformers_tf_finetune.losses import SparseCategoricalCrossentropy
from transformers_tf_finetune.metrics import SparseCategoricalAccuracy
from transformers_tf_finetune.utils import LRScheduler, get_device_strategy, path_join, set_random_seed

# Config

In [None]:
# ---- model / tokenizer ----
#: pretrained model id or path handed to transformers
pretrained_model = "cosmoquester/bart-ko-small"
#: pretrained (fast) tokenizer id or path
pretrained_tokenizer = "cosmoquester/bart-ko-small"
#: set to True to load the model from PyTorch weights
from_pytorch = False
#: huggingface credential, only needed for a private model
use_auth_token = None

# ---- data / output ----
#: conversation CSV, given as a local file path or an https url
dataset_path = "https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv"
#: directory for logs and model checkpoints (None disables both); should be a GCS path with TPU
output_path = None

# ---- sequence / decoding ----
max_sequence_length = 128
#: beam size, greedy search is used when this is zero
beam_size = 0

# ---- training params ----
epochs = 2
learning_rate = 0.0001
min_learning_rate = 0.00001
warmup_rate = 0.06
warmup_steps = None
#: batch size of the training split
batch_size = 16
#: batch size of the dev split
dev_batch_size = 256
#: number of examples held out as the dev split; should be a multiple of 8 with TPU
num_dev_dataset = 128
#: forwarded to the TensorBoard callback as its update_freq
tensorboard_update_freq = 1

# ---- runtime ----
#: device to use (TPU or GPU or CPU)
device = "TPU"
#: use mixed precision (bfloat16 on TPU, float16 otherwise)
mixed_precision = False
#: random seed, None leaves the random generators unseeded
seed = None

In [None]:
# Writing logs/checkpoints to a Google Cloud Storage bucket needs an authenticated
# Colab user, so the Colab-only auth helper is imported just in that case.
if output_path is not None and output_path.startswith("gs://"):
    from google.colab import auth

    auth.authenticate_user()

In [None]:
def load_dataset(dataset_path: str, tokenizer: AutoTokenizer, shuffle: bool = False) -> tf.data.Dataset:
    """
    Load Chatbot Conversation dataset from local file or web

    The data is read as UTF-8 CSV. The first row is treated as a header and dropped,
    the first two columns of every remaining row are used as question and answer,
    and any further columns are ignored. Blank rows are skipped.

    :param dataset_path: local file path or http(s) url
    :param tokenizer: PreTrainedTokenizer for tokenizing
    :param shuffle: whether shuffling rows or not
    :returns: conversation dataset of (model inputs with "decoder_input_ids", target token ids),
        where the targets are the answer tokens shifted one position to the left
    :raises ValueError: if a row has fewer than two columns or the dataset has no rows
    """
    import io

    if dataset_path.startswith(("http://", "https://")):
        with urllib.request.urlopen(dataset_path) as response:
            data = response.read().decode("utf-8")
    else:
        # Explicit encoding keeps the local branch consistent with the url branch
        # regardless of the platform default; newline="" leaves line endings to csv.
        with open(dataset_path, encoding="utf-8", newline="") as f:
            data = f.read()

    # Let the csv module split records so quoted fields containing line breaks stay intact,
    # then drop the header row and any blank rows.
    rows = list(csv.reader(io.StringIO(data, newline="")))[1:]
    rows = [row for row in rows if row]
    if shuffle:
        random.shuffle(rows)

    # Fall back to an empty string when the tokenizer defines no such special token,
    # so the concatenation below never receives None.
    bos = tokenizer.bos_token or tokenizer.cls_token or ""
    eos = tokenizer.eos_token or tokenizer.sep_token or ""

    questions = []
    answers = []
    for row in rows:
        if len(row) < 2:
            raise ValueError(f"Expected at least a question and an answer column, got: {row!r}")
        question, answer = row[:2]
        questions.append(bos + question + eos)
        answers.append(bos + answer + eos)

    if not questions:
        raise ValueError(f"No conversation rows found in {dataset_path!r}")

    # Every example is padded to one fixed length: the longest text measured in characters.
    # NOTE(review): this assumes the token count never exceeds the character count - confirm
    # for the tokenizer in use; longer token sequences are truncated.
    max_length = max(len(text) for text in questions + answers)
    inputs = tokenizer(
        questions,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="tf",
        return_token_type_ids=False,
        return_attention_mask=True,
    )

    target_tokens = tokenizer(
        answers,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="tf",
        return_token_type_ids=False,
        return_attention_mask=False,
    )["input_ids"]

    # Teacher forcing: the decoder sees the answer without its last token and is trained
    # to predict the answer without its first token.
    dataset = tf.data.Dataset.from_tensor_slices(
        ({**inputs, "decoder_input_ids": target_tokens[:, :-1]}, target_tokens[:, 1:])
    )
    return dataset

In [None]:
# Seed the random generators only when a seed is configured.
# Compare against None explicitly: 0 is a valid seed and must not be skipped.
if seed is not None:
    set_random_seed(seed)

In [None]:
# Build the distribution strategy for the configured device. The cells below enter its
# scope for setup, training and evaluation, and use it to distribute generation.
strategy = get_device_strategy(device)

# Mixed Precision

In [None]:
with strategy.scope():
    if mixed_precision:
        # TPU runs use the bfloat16 policy, every other device the float16 policy.
        mixed_type = "mixed_bfloat16" if device == "TPU" else "mixed_float16"
        if hasattr(tf.keras.mixed_precision, "set_global_policy"):
            # Stable API; the `experimental` namespace used below is deprecated and
            # absent from newer TensorFlow releases.
            tf.keras.mixed_precision.set_global_policy(mixed_type)
        else:
            # Older TensorFlow releases only provide the experimental API.
            policy = tf.keras.mixed_precision.experimental.Policy(mixed_type)
            tf.keras.mixed_precision.experimental.set_policy(policy)

# Load Dataset

In [None]:
with strategy.scope():
    tokenizer = AutoTokenizer.from_pretrained(pretrained_tokenizer, use_auth_token=use_auth_token)

    # The examples are shuffled once while loading, so holding out the leading
    # `num_dev_dataset` examples yields a random dev split; the rest is used for training.
    dataset = load_dataset(dataset_path, tokenizer, shuffle=True)
    dev_dataset = dataset.take(num_dev_dataset).batch(dev_batch_size)
    train_dataset = dataset.skip(num_dev_dataset).batch(batch_size)

# Load Model

In [None]:
with strategy.scope():
    # Create the model inside the strategy scope, optionally converting from PyTorch
    # weights. `use_cache=False` is forwarded to the model configuration.
    model = TFAutoModelForSeq2SeqLM.from_pretrained(
        pretrained_model,
        use_auth_token=use_auth_token,
        from_pt=from_pytorch,
        use_cache=False,
    )

# Model Compilation

In [None]:
with strategy.scope():
    # The schedule is sized for the whole run: number of training batches times epochs.
    total_train_steps = epochs * len(train_dataset)
    lr_schedule = LRScheduler(
        total_train_steps,
        learning_rate,
        min_learning_rate,
        warmup_rate,
        warmup_steps,
    )

    # Padding positions are passed as `ignore_index` to both the loss and the accuracy.
    pad_token_id = tokenizer.pad_token_id

    model.compile(
        optimizer=tf.optimizers.Adam(lr_schedule),
        loss=SparseCategoricalCrossentropy(from_logits=True, ignore_index=pad_token_id),
        metrics=SparseCategoricalAccuracy(ignore_index=pad_token_id, name="accuracy"),
    )

# Model Training

In [None]:
with strategy.scope():
    # Checkpointing and TensorBoard logging both write below `output_path`, so they are
    # only set up when an output path is configured.
    if output_path is None:
        training_callbacks = None
    else:
        # Keep only the weights of the epoch with the highest validation accuracy.
        best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
            path_join(output_path, "best_model.ckpt"),
            save_weights_only=True,
            save_best_only=True,
            monitor="val_accuracy",
            mode="max",
            verbose=1,
        )
        tensorboard = tf.keras.callbacks.TensorBoard(
            path_join(output_path, "logs"), update_freq=tensorboard_update_freq
        )
        training_callbacks = [best_checkpoint, tensorboard]

    model.fit(
        train_dataset,
        validation_data=dev_dataset,
        epochs=epochs,
        callbacks=training_callbacks,
    )

# Model Evaluation

In [None]:
with strategy.scope():
    # Final pass over the dev split; the two returned values are unpacked as the loss
    # and the "accuracy" metric the model was compiled with.
    loss, accuracy = model.evaluate(dev_dataset)

# Prediction

In [None]:
# Prediction is not supported on TPU. Run this cell on GPU (or CPU) after training and saving the model.
with strategy.scope():
    input_tokens = []
    predict_tokens = []

    # `beam_size` of zero means greedy search, so beam search is requested only for a
    # positive beam size; otherwise the model's own generation defaults apply.
    generate_kwargs = {"num_beams": beam_size} if beam_size and beam_size > 0 else {}

    for batch, _ in strategy.experimental_distribute_dataset(dev_dataset):
        output = strategy.run(
            model.generate,
            kwargs={
                "input_ids": batch["input_ids"],
                "attention_mask": batch["attention_mask"],
                **generate_kwargs,
            },
        )

        # With several replicas both the batch and the generated output are per-replica
        # containers without `.numpy()`. Unwrap them replica by replica (same order for
        # both) so questions stay aligned with their predictions, and so outputs of
        # different generated lengths never have to be concatenated.
        replica_inputs = strategy.experimental_local_results(batch["input_ids"])
        replica_outputs = strategy.experimental_local_results(output)
        for local_input_ids, local_generated in zip(replica_inputs, replica_outputs):
            input_tokens.extend(local_input_ids.numpy())
            predict_tokens.extend(local_generated.numpy())

    input_sentences = tokenizer.batch_decode(input_tokens, skip_special_tokens=True)
    predict_sentences = tokenizer.batch_decode(predict_tokens, skip_special_tokens=True)
    for question, answer in zip(input_sentences, predict_sentences):
        print(f"Q: {question} A: {answer}")