<a href="https://colab.research.google.com/github/jeongukjae/distilkobert-tfhub-examples/blob/main/distilkobert_kornli_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training DistilKoBERT on the KorNLI dataset

* KorNLI dataset: https://jeongukjae.github.io/tfds-korean/datasets/kornli.html
* DistilKoBERT
    * encoder: https://tfhub.dev/jeongukjae/distilkobert_cased_L-3_H-768_A-12/1
    * preprocessor: https://tfhub.dev/jeongukjae/distilkobert_cased_preprocess/1

## Install packages

In [1]:
!pip install -q \
    tensorflow-text \
    tfds-korean

## Prepare the environment

In [2]:
import math

import tensorflow as tf
import tensorflow_text as text
import tensorflow_hub as hub
import tensorflow_datasets as tfds

import tfds_korean.kornli

In [3]:
import os

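# Load TF Hub models directly from their uncompressed GCS location instead of the
# local download cache, which the remote TPU worker cannot read.
os.environ["TFHUB_MODEL_LOAD_FORMAT"] = "UNCOMPRESSED"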

In [4]:
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)

INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: grpc://10.0.67.138:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


## Set up hyperparameters and build models

In [5]:
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
EPOCHS = 5
WARMUP_RATE = 0.05
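Under `TPUStrategy`, `BATCH_SIZE` is the *global* batch size. Keras splits each batch from the input pipeline evenly across the replicas, so each of the 8 TPU cores listed in the log above sees only a slice of it. The arithmetic below is a small sketch of that split. It hard-codes the replica count from the log; in the notebook itself you would read `strategy.num_replicas_in_sync`. The `PER_CORE_TARGET` value is a hypothetical alternative setting, not something the notebook uses.

```python
GLOBAL_BATCH_SIZE = 64  # BATCH_SIZE in the notebook
NUM_REPLICAS = 8  # "Num TPU Cores: 8" in the log; strategy.num_replicas_in_sync in the notebook

# The dataset's batch is the global batch; TPUStrategy gives each replica an equal slice.
if GLOBAL_BATCH_SIZE % NUM_REPLICAS != 0:
    raise ValueError("the global batch size must divide evenly across replicas")
per_replica_batch = GLOBAL_BATCH_SIZE // NUM_REPLICAS
print(per_replica_batch)  # 8

# Conversely, to keep a fixed per-core batch, scale the global batch by the replica count.
PER_CORE_TARGET = 64  # hypothetical alternative setting
scaled_global_batch = PER_CORE_TARGET * NUM_REPLICAS
print(scaled_global_batch)  # 512
```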

In [6]:
def create_preprocessing_model():
    preprocessor = hub.load("https://tfhub.dev/jeongukjae/distilkobert_cased_preprocess/1")
    tokenize = hub.KerasLayer(preprocessor.tokenize)
    bert_pack_inputs = hub.KerasLayer(preprocessor.bert_pack_inputs)

    text_inputs = [
        tf.keras.Input([], dtype=tf.string),
        tf.keras.Input([], dtype=tf.string),
    ]
    tokens = [tokenize(item) for item in text_inputs]
    model_inputs = bert_pack_inputs(tokens)
    return tf.keras.Model(text_inputs, model_inputs)


def create_model():
    encoder = hub.KerasLayer("https://tfhub.dev/jeongukjae/distilkobert_cased_L-3_H-768_A-12/1", trainable=True)
    inputs = {
        "input_word_ids": tf.keras.Input([None], dtype=tf.int32, name="input_word_ids"),
        "input_mask": tf.keras.Input([None], dtype=tf.int32, name="input_mask"),
    }
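    pooled_output = encoder(inputs)['pooled_output']
    pooled_output = tf.keras.layers.Dropout(0.1)(pooled_output)
    # One logit per NLI label (3 classes); softmax is applied inside the loss (from_logits=True).
    logits = tf.keras.layers.Dense(3)(pooled_output)
    model = tf.keras.Model(inputs, logits)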
    model.summary()
    return model
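`create_preprocessing_model` tokenizes each sentence and then lets the hub module's `bert_pack_inputs` combine the pair into fixed-length `input_word_ids` and `input_mask` tensors. The element spec printed during training shows these are 128 tokens long. The pure-Python sketch below shows conceptually what that packing step does. It is not the hub module's code, and it makes two assumptions:

* The special-token ids (`CLS_ID`, `SEP_ID`, `PAD_ID`) are hypothetical placeholders.
* Trimming is round-robin, which is the default of the TensorFlow Model Garden `BertPackInputs` layer. The DistilKoBERT preprocessor may or may not behave identically.

```python
# Hypothetical special-token ids, for illustration only; the real values come
# from the DistilKoBERT preprocessor's vocabulary.
CLS_ID, SEP_ID, PAD_ID = 2, 3, 0


def round_robin_trim(segments, budget):
    """Take one token at a time from each segment in turn until `budget` is spent."""
    lengths = [0] * len(segments)
    remaining = budget
    progressed = True
    while remaining > 0 and progressed:
        progressed = False
        for i, seg in enumerate(segments):
            if remaining > 0 and lengths[i] < len(seg):
                lengths[i] += 1
                remaining -= 1
                progressed = True
    return [seg[:n] for seg, n in zip(segments, lengths)]


def pack_pair(tokens_a, tokens_b, seq_length=128):
    """Build [CLS] a [SEP] b [SEP], pad to seq_length, and mark real tokens in the mask."""
    # Reserve room for one [CLS] and two [SEP] tokens.
    tokens_a, tokens_b = round_robin_trim([tokens_a, tokens_b], seq_length - 3)
    ids = [CLS_ID] + tokens_a + [SEP_ID] + tokens_b + [SEP_ID]
    mask = [1] * len(ids)
    padding = seq_length - len(ids)
    return {
        "input_word_ids": ids + [PAD_ID] * padding,
        "input_mask": mask + [0] * padding,
    }


print(pack_pair([10, 11, 12, 13], [20, 21, 22], seq_length=7))
```

With `seq_length=7` there is room for only four content tokens. Each sentence keeps two, and the call prints `{'input_word_ids': [2, 10, 11, 3, 20, 21, 3], 'input_mask': [1, 1, 1, 1, 1, 1, 1]}`. Round-robin trimming avoids letting a long premise crowd out the hypothesis entirely.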

In [7]:
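class BertScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to `rate` over the first `warmup_ratio` of training, then linear decay to 0."""

    def __init__(self, rate, warmup_ratio, total_steps, name=None):
        super().__init__()

        self.rate = rate
        self.warmup_ratio = warmup_ratio
        self.total_steps = float(total_steps)
        self.warmup_steps = warmup_ratio * self.total_steps
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or "BertScheduler"):
            total_steps = tf.convert_to_tensor(self.total_steps, name="total_steps")
            warmup_steps = tf.convert_to_tensor(self.warmup_steps, name="warmup_steps")

            # The optimizer may pass an integer step counter; cast so the float math below works.
            current_step = tf.cast(step, tf.float32) + 1.0

            return self.rate * tf.cond(
                current_step < warmup_steps,
                lambda: self.warmup(current_step, warmup_steps),
                lambda: self.decay(current_step, total_steps, warmup_steps),
            )

    def warmup(self, step, warmup_steps):
        # Linear ramp from ~0 up to 1.0 at the end of warmup.
        return step / tf.math.maximum(1.0, warmup_steps)

    def decay(self, step, total_steps, warmup_steps):
        # Linear decay from 1.0 after warmup down to 0.0 at total_steps (clamped at 0).
        return tf.math.maximum(
            0.0, (total_steps - step) / tf.math.maximum(1.0, total_steps - warmup_steps)
        )

    def get_config(self):
        # Lets Keras serialize the schedule if the model is saved with include_optimizer=True.
        return {
            "rate": self.rate,
            "warmup_ratio": self.warmup_ratio,
            "total_steps": self.total_steps,
            "name": self.name,
        }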
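To see the values `BertScheduler` produces, here is a pure-Python mirror of its warmup/decay formula. It is for inspection only and is not a replacement for the Keras class. The defaults plug in the notebook's `LEARNING_RATE`, `WARMUP_RATE`, and the `total num steps: 30680` printed during training. With those values, warmup lasts 0.05 × 30680 = 1534 steps. The rate then peaks at 1e-4 and falls linearly toward zero.

```python
def bert_schedule_lr(step, rate=1e-4, warmup_ratio=0.05, total_steps=30680):
    """Pure-Python mirror of BertScheduler.__call__ (same formula, no TensorFlow)."""
    total_steps = float(total_steps)
    warmup_steps = warmup_ratio * total_steps
    current_step = step + 1.0  # BertScheduler also shifts the step counter by one
    if current_step < warmup_steps:
        # Linear warmup: step / warmup_steps
        factor = current_step / max(1.0, warmup_steps)
    else:
        # Linear decay: 1.0 right after warmup, 0.0 at total_steps, clamped at 0
        factor = max(0.0, (total_steps - current_step) / max(1.0, total_steps - warmup_steps))
    return rate * factor


for s in (0, 767, 1533, 16000, 30679):
    print(s, bert_schedule_lr(s))
```

For example, `bert_schedule_lr(0)` is 1e-4 / 1534 (about 6.5e-8), and `bert_schedule_lr(1533)` is the peak of 1e-4.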

## Prepare datasets

In [8]:
def get_dataset(preprocessor, batch_size):
    with tf.device('/job:localhost'):
        # batch_size=-1 loads each full split into memory as a single batch of tensors
        in_memory_ds = tfds.load("kornli", batch_size=-1, shuffle_files=True)

    tfds_info = tfds.builder("kornli").info
    train_ds = tf.data.Dataset.from_tensor_slices(in_memory_ds['mnli_train'])
    dev_ds = tf.data.Dataset.from_tensor_slices(in_memory_ds['xnli_dev'])
    test_ds = tf.data.Dataset.from_tensor_slices(in_memory_ds['xnli_test'])
    num_examples = tfds_info.splits['mnli_train'].num_examples

    train_ds = (
        train_ds
        .shuffle(num_examples, reshuffle_each_iteration=True)
        .batch(batch_size, drop_remainder=True)
        .map(lambda x: (preprocessor([x['sentence1'], x['sentence2']]), x['gold_label']), num_parallel_calls=tf.data.AUTOTUNE)
    )
    dev_ds = (
        dev_ds
        .batch(batch_size)
        .map(lambda x: (preprocessor([x['sentence1'], x['sentence2']]), x['gold_label']), num_parallel_calls=tf.data.AUTOTUNE)
    )
    test_ds = (
        test_ds
        .batch(batch_size)
        .map(lambda x: (preprocessor([x['sentence1'], x['sentence2']]), x['gold_label']), num_parallel_calls=tf.data.AUTOTUNE)
    )
    return (train_ds, dev_ds, test_ds), num_examples
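The training split is batched with `drop_remainder=True`, which keeps the batch dimension static at 64, as TPU compilation generally expects. As a result, each epoch yields floor(N / batch) batches rather than ceil(N / batch). A quick check with the example count printed in the training cell output below:

```python
import math

NUM_EXAMPLES = 392702  # "Num examples" printed in the training cell output
BATCH_SIZE = 64
EPOCHS = 5

# What math.ceil in the training cell computes (as if the last partial batch were kept)
steps_keep_remainder = math.ceil(NUM_EXAMPLES / BATCH_SIZE)
# What .batch(BATCH_SIZE, drop_remainder=True) actually yields per epoch
steps_drop_remainder = NUM_EXAMPLES // BATCH_SIZE
# Examples skipped each epoch (a different subset every epoch, since the data is reshuffled)
dropped_examples = NUM_EXAMPLES - steps_drop_remainder * BATCH_SIZE

print(steps_keep_remainder, steps_drop_remainder, dropped_examples)  # 6136 6135 62
print(steps_drop_remainder * EPOCHS)  # 30675 optimizer steps over the whole run
```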

## Run training

In [9]:
preprocessor = create_preprocessing_model()


with strategy.scope():
    (train_ds, dev_ds, test_ds), num_examples = get_dataset(preprocessor, BATCH_SIZE)
    print("Element spec:", train_ds.element_spec)
    print("Num examples:", num_examples)
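    # Note: train_ds uses drop_remainder=True, so each epoch actually runs
    # num_examples // BATCH_SIZE (= 6135) steps. ceil() overcounts by one, so the
    # linear decay in BertScheduler ends slightly above zero.
    steps_per_epoch = math.ceil(num_examples / BATCH_SIZE)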
    print("steps per epoch:", steps_per_epoch)
    num_train_steps = steps_per_epoch * EPOCHS
    print("total num steps:", num_train_steps)

    model = create_model()
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.Adam(learning_rate=BertScheduler(LEARNING_RATE, WARMUP_RATE, num_train_steps)),
        metrics=['acc']
    )

    model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=dev_ds,
    )
    model.evaluate(test_ds)

Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Element spec: ({'input_word_ids': TensorSpec(shape=(64, 128), dtype=tf.int32, name=None), 'input_mask': TensorSpec(shape=(64, 128), dtype=tf.int32, name=None)}, TensorSpec(shape=(64,), dtype=tf.int64, name=None))
Num examples: 392702
steps per epoch: 6136
total num steps: 30680
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_mask (InputLayer)        [(None, None)]       0           []                               
                                                                                                  
 input_word_ids (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 keras_layer_2 (KerasLayer)     {'encoder_outputs':  27803904    ['input_mask[0][0]',             
           

  "shape. This may consume a large amount of memory." % value)


Instructions for updating:
use `experimental_local_results` instead.


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## Save model

In [10]:
# Route file writes through the Colab host so the SavedModel lands on the local disk.
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save("distilkobert_kornli", include_optimizer=False, options=save_options)



INFO:tensorflow:Assets written to: distilkobert_kornli/assets
