In [36]:
import jax

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [35]:
!free -mh

              total        used        free      shared  buff/cache   available
Mem:          334Gi       6.0Gi       319Gi       1.0Mi       8.8Gi       326Gi
Swap:            0B          0B          0B


In [3]:
language = "si"

In [4]:
model_config = "distilgpt2"

In [5]:
model_dir = "../models/" + model_config + f"-pretrained-{language}"

In [6]:
from pathlib import Path

Path(model_dir).mkdir(parents=True, exist_ok=True)

In [7]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_config)

Downloading:   0%|          | 0.00/762 [00:00<?, ?B/s]

In [8]:
config.save_pretrained(f"{model_dir}")

In [9]:
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
from pathlib import Path

In [10]:
raw_dataset = load_dataset("oscar", f"unshuffled_deduplicated_{language}")

Downloading and preparing dataset oscar/unshuffled_deduplicated_si (download: 167.48 MiB, generated: 802.48 MiB, post-processed: Unknown size, total: 969.95 MiB) to /home/Keshan/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_si/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2...


Downloading:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/176M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

Dataset oscar downloaded and prepared to /home/Keshan/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_si/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2. Subsequent calls will reuse this data.


In [11]:
tokenizer = ByteLevelBPETokenizer()

In [12]:
def batch_iterator(batch_size=1000):
    for i in range(0, len(raw_dataset), batch_size):
        yield raw_dataset["train"][i: i + batch_size]["text"]

In [13]:
tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])






In [14]:
tokenizer.save(f"{model_dir}/tokenizer.json")

In [37]:
max_seq_length = 512

In [38]:
raw_dataset["train"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[5%:]")



In [39]:
raw_dataset["validation"] = load_dataset("oscar", f"unshuffled_deduplicated_{language}", split="train[:5%]")



In [40]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [41]:
def tokenize_function(examples):
    return tokenizer(examples["text"])

In [42]:
tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset["train"].column_names)



In [43]:
def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // max_seq_length) * max_seq_length
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [44]:
tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)

https://symbolize.stripped_domain/r/?trace=7fa4697e7f99,7fa4695fc20f&map= 
*** SIGTERM received by PID 123988 (TID 123988) on cpu 71 from PID 121832; stack trace: ***
PC: @     0x7fa4697e7f99  (unknown)  munmap
    @     0x7fa457ca9800        976  (unknown)
    @     0x7fa4695fc210  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7fa4697e7f99,7fa457ca97ff,7fa4695fc20f&map=2a762cd764e70bc90ae4c7f9747c08d7:7fa44ad67000-7fa457fe8280 
E0708 06:22:01.272365  123988 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0708 06:22:01.279528  123988 process_state.cc:771] RAW: Raising signal 15 with default behavior
https://symbolize.stripped_domain/r/?trace=7fa4699c860a,7fa4695fc20f&map= 
*** SIGTERM received by PID 124272 (TID 124272) on cpu 24 from PID 121832; stack trace: ***
PC: @     0x7fa4699c860a  (unknown)  (unknown)
    @     0x7fa457ca9800        976  (unknown)
    @     0x7fa4695fc210  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace

In [57]:
import jax
import optax
import flax
import jax.numpy as jnp
import math

from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard

import numpy as np

from tqdm.auto import tqdm

In [46]:
per_device_batch_size = 64
num_epochs = 10
training_seed = 0
learning_rate = 3e-4

total_batch_size = per_device_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // total_batch_size * num_epochs

In [47]:
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))

In [48]:
linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)

In [49]:
adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)

In [50]:
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

In [51]:
def data_loader(rng, dataset, batch_size, shuffle=False):
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}

        batch = shard(batch)

        yield batch

In [52]:
def train_step(state, batch, dropout_rng):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        
        loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)

    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
    )

    return new_state, metrics, new_dropout_rng

In [53]:
parallel_train_step = jax.pmap(train_step, "batch")

In [54]:
def eval_step(params, batch):
    labels = batch.pop("labels")

    logits = model(**batch, params=params, train=False)[0]

    loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()

    # summarize metrics
    metrics = {"loss": loss, "perplexity": jnp.exp(loss)}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics

In [55]:
parallel_eval_step = jax.pmap(eval_step, "batch")
state = flax.jax_utils.replicate(state)

rng = jax.random.PRNGKey(training_seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())

In [56]:
for epoch in tqdm(range(1, num_epochs + 1), desc=f"Epoch ...", position=0, leave=True):
    rng, input_rng = jax.random.split(rng)

    # -- Train --
    train_loader = data_loader(input_rng, tokenized_datasets["train"], total_batch_size, shuffle=True)
    with tqdm(total=len(tokenized_datasets["train"]) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for model_inputs in train_loader:
            # Model forward
            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)

            progress_bar_train.update(1)

        progress_bar_train.write(
              f"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})"
        )

    # -- Eval --
    eval_loader = data_loader(input_rng, tokenized_datasets["validation"], total_batch_size)
    eval_metrics = []
   
    with tqdm(total=len(tokenized_datasets["validation"]) // total_batch_size, desc="Evaluation...", leave=False) as progress_bar_eval:
        for model_inputs in eval_loader:
            # Model forward
            eval_metric = parallel_eval_step(state.params, model_inputs)
            eval_metrics.append(eval_metric)

            progress_bar_eval.update(1)
 
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
        progress_bar_eval.write(
            f"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics['loss']} | Perplexity: {eval_metrics['perplexity']})"
        )

Epoch ...:   0%|          | 0/10 [00:00<?, ?it/s]

Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (1/10 | Loss: 2.509000062942505, Learning Rate: 0.0002699999895412475)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (1/10 | Loss: 2.480623960494995 | Perplexity: 12.450997352600098)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (2/10 | Loss: 2.1080000400543213, Learning Rate: 0.00023999999393709004)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (2/10 | Loss: 2.1132447719573975 | Perplexity: 8.5897855758667)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (3/10 | Loss: 2.005000114440918, Learning Rate: 0.0002099999983329326)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (3/10 | Loss: 1.9953031539916992 | Perplexity: 7.633769512176514)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (4/10 | Loss: 1.942000150680542, Learning Rate: 0.00018000000272877514)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (4/10 | Loss: 1.9307587146759033 | Perplexity: 7.157529354095459)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (5/10 | Loss: 1.8970000743865967, Learning Rate: 0.00014999999257270247)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (5/10 | Loss: 1.8898165225982666 | Perplexity: 6.869416236877441)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (6/10 | Loss: 1.852000117301941, Learning Rate: 0.00011999999696854502)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (6/10 | Loss: 1.862762451171875 | Perplexity: 6.687985897064209)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (7/10 | Loss: 1.8250000476837158, Learning Rate: 9.000000136438757e-05)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (7/10 | Loss: 1.8416788578033447 | Perplexity: 6.548841953277588)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (8/10 | Loss: 1.8100000619888306, Learning Rate: 5.999999848427251e-05)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (8/10 | Loss: 1.8258028030395508 | Perplexity: 6.445103645324707)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (9/10 | Loss: 1.811000108718872, Learning Rate: 2.9999999242136255e-05)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (9/10 | Loss: 1.8161194324493408 | Perplexity: 6.384854793548584)


Training...:   0%|          | 0/758 [00:00<?, ?it/s]

Train... (10/10 | Loss: 1.7760001420974731, Learning Rate: 0.0)


Evaluation...:   0%|          | 0/40 [00:00<?, ?it/s]

Eval... (10/10 | Loss: 1.8110542297363281 | Perplexity: 6.354654788970947)
