<a href="https://colab.research.google.com/github/jwengr/KoDeBERTa/blob/main/lit_deberta_colab_tpu_pretrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
! pip install --quiet transformers datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m720.6/720.6 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m91.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.6/485.6 kB[0m [31m45.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m46.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m106.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m82.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [13]:
cd /content/drive/MyDrive/KoDeBERTa

/content/drive/MyDrive/KoDeBERTa
/content/drive/MyDrive/KoDeBERTa


In [None]:
import os
import tensorflow as tf

from datasets import load_dataset
from tokenizers import Tokenizer

from Model.DebertaV3.TFDebertaV3 import TFDebertaV3ForPretraining
from Data.DataCollator import DataCollatorForHFUnigramSpanMLM

In [None]:
# --- Data & tokenizer --------------------------------------------------------
data_path = 'gs://your_bucket/your_text_file.txt'  # placeholder; read with load_dataset("text", ...)
tokenizer_path = 'tokenizer.json'  # relative to the current working directory
mask_token = '[MASK]'
pad_token = '[PAD]'
max_length = 512  # truncation length handed to the data collator
mask_prob = 0.15  # masking probability handed to the data collator

# --- Model -------------------------------------------------------------------
model_name = 'microsoft/deberta-v3-xsmall'
pretrained_model_path = None  # None -> build a fresh model; otherwise loaded via tf.keras.models.load_model

# --- Optimisation schedule (passed to the model constructor) -----------------
learning_rate = 1e-4
warmup_steps = 10_000
total_steps = 1_000_000
batch_size = 32  # batch size passed to to_tf_dataset

# --- Logging & checkpoints ---------------------------------------------------
log_per_steps = 100
log_dir = 'logs'
save_per_steps = 10_000
save_dir = 'checkpoints'

In [None]:
tokenizer = Tokenizer.from_file(tokenizer_path)

# Look up the special-token ids directly. `get_vocab()` builds and returns a full
# vocab dict on every call, which the previous code did twice just for two ids.
# `token_to_id` returns None for an unknown token, so fail loudly here rather than
# passing None on to the model constructor.
mask_id = tokenizer.token_to_id(mask_token)
pad_id = tokenizer.token_to_id(pad_token)
missing_tokens = [tok for tok, tok_id in ((mask_token, mask_id), (pad_token, pad_id)) if tok_id is None]
if missing_tokens:
    raise ValueError(f"special token(s) {missing_tokens} not found in tokenizer {tokenizer_path!r}")


In [None]:
# The next cell builds the TPU gRPC address from COLAB_TPU_ADDR. Check it here and
# raise an actionable error instead of letting that cell fail with a bare KeyError
# (presumably the runtime has no TPU attached when the variable is unset).
if not os.environ.get('COLAB_TPU_ADDR'):
    raise RuntimeError('COLAB_TPU_ADDR is not set: switch the Colab runtime type to TPU.')
os.environ['COLAB_TPU_ADDR']

In [None]:
# Connect to the Colab TPU over gRPC and build a distribution strategy for it.
TPU_PATH = f"grpc://{os.environ['COLAB_TPU_ADDR']}"

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_PATH)
tf.config.experimental_connect_to_cluster(resolver)
# Initialize the TPU system before creating the strategy.
tf.tpu.experimental.initialize_tpu_system(resolver)
# `tf.distribute.TPUStrategy` is the stable API (TF >= 2.3); the
# `tf.distribute.experimental.TPUStrategy` alias is deprecated.
strategy = tf.distribute.TPUStrategy(resolver)

In [None]:
# Raw text file as a Hugging Face dataset (the "text" builder yields one example per line).
hf_train = load_dataset("text", data_files={"train": data_path})['train']

# Batched tf.data pipeline; the collator (configured with max_length truncation and
# mask_prob) builds each batch.
ds = hf_train.to_tf_dataset(
    batch_size=batch_size,
    shuffle=False,
    collate_fn=DataCollatorForHFUnigramSpanMLM(tokenizer, truncation_argument={'max_length': max_length}, mask_prob=mask_prob),
    # TPU programs are compiled for static shapes; drop the short final batch so the
    # batch dimension never changes.
    drop_remainder=True,
).repeat()  # loop over the file indefinitely: total_steps is step-based, not epoch-based

In [None]:
with strategy.scope():
    # Create variables (model weights, metrics) under the strategy scope so they are
    # placed on / replicated across the TPU cores.
    if pretrained_model_path:
        model = tf.keras.models.load_model(pretrained_model_path)
    else:
        model = TFDebertaV3ForPretraining(
            model_name=model_name,
            mask_id=mask_id,
            pad_id=pad_id,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            total_steps=total_steps
        )

    # Running means of the two losses; the training loop resets them after each log.
    training_loss_generator = tf.keras.metrics.Mean('training_loss_generator', dtype=tf.float32)
    training_loss_discriminator = tf.keras.metrics.Mean('training_loss_discriminator', dtype=tf.float32)

    # `ds` is batched with the *global* batch size. experimental_distribute_dataset
    # splits each global batch across the replicas, so each replica sees
    # `per_replica_batch_size` examples. (distribute_datasets_from_function expects a
    # dataset already batched per replica; returning `ds` there gave every replica a
    # full global batch.)
    per_replica_batch_size, leftover = divmod(batch_size, strategy.num_replicas_in_sync)
    if leftover:
        raise ValueError(
            f"batch_size={batch_size} must be divisible by the number of replicas "
            f"({strategy.num_replicas_in_sync})"
        )
    train_dataset = strategy.experimental_distribute_dataset(ds)


In [None]:
@tf.function
def train_multiple_steps(iterator, steps):
    """Run `steps` distributed steps inside one traced tf.function call.

    Args:
        iterator: iterator over the distributed training dataset; one element is
            consumed per step via `next(iterator)`.
        steps: number of steps to run in this call.

    Each step runs `step_fn` on every replica through `strategy.run` and
    accumulates the generator/discriminator losses into the global Mean metrics
    `training_loss_generator` / `training_loss_discriminator`.
    """
    def step_fn(inputs):
        # Per-replica body executed by strategy.run.
        # NOTE(review): assumes each batch unpacks into exactly three tensors
        # (masked ids, attention mask, label ids); confirm against the collator.
        masked_ids, attention_mask, label_ids = inputs
        # NOTE(review): only a forward call is visible here (no GradientTape or
        # optimizer step). Confirm TFDebertaV3ForPretraining applies its weight
        # updates inside call(); otherwise this loop never trains the model.
        loss_generator, loss_discriminator = model(masked_ids=masked_ids, attention_mask=attention_mask, label_ids=label_ids)
        # Scaling by the replica count presumes the per-replica loss was divided by
        # the global batch size. TODO confirm how the model scales its losses;
        # otherwise the logged means are inflated by num_replicas_in_sync.
        training_loss_generator.update_state(loss_generator * strategy.num_replicas_in_sync)
        training_loss_discriminator.update_state(loss_discriminator * strategy.num_replicas_in_sync)

    # Iterating over tf.range lets AutoGraph lower this into a graph while-loop, so
    # all `steps` steps run without returning to Python in between.
    for _ in tf.range(steps):
        strategy.run(step_fn, args=(next(iterator),))

In [None]:
# TensorBoard event-file writer for the scalar loss summaries written by the training loop.
train_summary_writer = tf.summary.create_file_writer(logdir=log_dir)


In [None]:
train_iterator = iter(train_dataset)

# Walk the schedule one chunk of `log_per_steps` steps at a time. Each
# train_multiple_steps call runs a whole chunk on the TPU, so no Python iteration
# is spent on steps that do nothing.
for step in range(0, total_steps, log_per_steps):
    # The last chunk is shorter when total_steps is not a multiple of log_per_steps.
    chunk = min(log_per_steps, total_steps - step)
    train_multiple_steps(train_iterator, chunk)
    completed_steps = step + chunk

    # Mean losses over this chunk; read each metric once.
    loss_g = float(training_loss_generator.result())
    loss_d = float(training_loss_discriminator.result())

    # Save whenever a save_per_steps boundary was crossed during this chunk (works
    # even if save_per_steps is not a multiple of log_per_steps), plus at the end.
    crossed_save_boundary = completed_steps // save_per_steps > step // save_per_steps
    if crossed_save_boundary or completed_steps == total_steps:
        # Keras cannot save subclassed models to HDF5 (.h5). The model is called with
        # keyword inputs, which suggests a subclassed model, so use the SavedModel
        # directory format; tf.keras.models.load_model in the model cell reads it
        # back. experimental_io_device makes the local Colab host do the file I/O
        # instead of the remote TPU worker.
        model.save(
            f"{save_dir}/loss_g={loss_g:.4f}-loss_d={loss_d:.4f}-step={completed_steps}",
            options=tf.saved_model.SaveOptions(experimental_io_device='/job:localhost'),
        )

    # Log against the number of steps actually completed, not the chunk start.
    with train_summary_writer.as_default():
        tf.summary.scalar('training_loss_generator', loss_g, step=completed_steps)
        tf.summary.scalar('training_loss_discriminator', loss_d, step=completed_steps)

    training_loss_generator.reset_state()
    training_loss_discriminator.reset_state()
    