## Loading the Data

In [3]:
from datasets import load_dataset
dataset = load_dataset('billingsmoore/tibetan-to-english-translation-dataset')

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dataset = dataset['train'].train_test_split(.15)

In [5]:
dataset['train'][0]

{'tibetan': 'མི་བཟད་ཆུ་སྲིན་ལྟ་བུའི་སྡེ་ཚོགས་རྣམས།།',
 'phonetic': 'mi zé chusin tabü dé tsok nam',
 'english': 'Out of the turbulent sea of sixteenfold danger'}

## Load Unfinetuned Tokenizer, Model, and Data Collator

In [6]:
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM

checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map="cuda:0")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

## Add Tibetan to Tokenizer

The T5 tokenizer does not notably support the Tibetan script. So, we need to add it manually. Once the characters have been added to the tokenizer, the model needs to have its token embeddings resized to accomodate the added tokens. This is all pretty straightforward, as seen in the code below.

In [7]:
# Tibetan characters to add
tibetan_chars = [
    # Consonants
    "ཀ", "ཁ", "ག", "ང", "ཅ", "ཆ", "ཇ", "ཉ", "ཏ", "ཐ", "ད", "ན", "པ", "པ", "ཕ", "བ", "མ",
    "ཙ", "ཚ", "ཛ", "ཝ", "ཞ", "ཟ", "འ", "ཡ", "ར", "ལ", "ཤ", "ཥ", "ས", "ཧ", "ཨ",

    # Subjoined Consonants
    "ྐ", "ྑ", "ྒ", "ྒྷ", "ྔ", "ྕ", "ྖ", "ྗ", "྘", "ྙ", "ྚ", "ྛ", "ྜ", "ྜྷ", "ྞ", "ྟ",
    "ྠ", "ྡ", "ྡྷ", "ྣ", "ྤ", "ྥ", "ྦ", "ྦྷ", "ྨ", "ྩ", "ྪ", "ྫ", "ྫྷ", "ྭ", "ྮ", "ྯ",
    "ྰ", "ྱ", "ྲ", "ླ", "ྴ", "ྵ", "ྶ", "ྷ", "ྸ", "ྐྵ", "ྺ", "ྻ", "ྼ", "྽", "྾", "྿",

    # Vowels
    "ི", "ཱི", "ུ", "ཱུ", "ྲྀ", "ཷ", "ླྀ", "ཹ", "ེ", "ཻ", "ོ", "ཽ", "ཾ", "ཿ",

    # Other Marks and Symbols
    "འ", "ཡ", "ར", "ལ", "ཤ", "ཥ", "ས", "ཧ", "ཨ",

    # Additional Tibetan Characters
    "ཀྵ", "ཁྵ", "གྵ", "ངྵ", "ཅྵ", "ཆྵ", "ཇྵ", "ཉྵ", "ཏྵ", "ཐྵ", "དྵ", "ནྵ", "པྵ", 
    "པྵ", "ཕྵ", "བྵ", "མྵ", "ཙྵ", "ཚྵ", "ཛྵ", "ཝྵ", "ཞྵ", "ཟྵ", "འྵ", "ཡྵ", "རྵ", 
    "ལྵ", "ཤྵ", "ཥྵ", "སྵ", "ཧྵ", "ཨྵ", "པྪ", "པྫ", "པྫྷ", "པྭ", "པྮ", "པྯ", "པྰ", 
    "པྱ", "པྲ", "པླ", "པྴ", "པྵ", "པྶ", "པྷ", "པྸ", "པྐྵ", "པྺ", "པྻ", "པྼ", "པ྽", 
    "པ྾", "པ྿"
]


#'ཀཁགངཅཆཇཉཏཐདནཔཕབམཙཚཛཝཞཟའཡརལཤཥསཧཨ'

# Add the Tibetan characters to the tokenizer's vocabulary
new_tokens = [char for char in tibetan_chars if char not in tokenizer.get_vocab()]

# Add new tokens to the tokenizer
tokenizer.add_tokens(new_tokens)

# Resize model embeddings to accommodate the new vocabulary size
model.resize_token_embeddings(len(tokenizer))

Embedding(32245, 512)

## Preprocess Data

The dataset can now be tokenized for training.

In [8]:
def translation_preprocess_function(examples):

    # Prepare translation inputs and targets
    translation_inputs = ['Translate Tibetan to English: ' + example for example in examples['tibetan']]
    translation_targets = [example for example in examples['english']]
    
    # Tokenize translation inputs and targets
    translation_model_inputs = tokenizer(translation_inputs, text_target=translation_targets, 
                                         max_length=300, truncation=True, padding="max_length")
    
    
    return translation_model_inputs


In [9]:
def transliteration_preprocess_function(examples):
    # Prepare transliteration inputs and targets
    transliteration_inputs = ['Transliterate: ' + example for example in examples['tibetan']]
    transliteration_targets = [example for example in examples['phonetic']]
    
    # Tokenize transliteration inputs and targets
    transliteration_model_inputs = tokenizer(transliteration_inputs, text_target=transliteration_targets, 
                                             max_length=300, truncation=True, padding="max_length")
    
    return transliteration_model_inputs

In [10]:
translation_tokenized_dataset = dataset.map(translation_preprocess_function, batched=True)

Map: 100%|██████████| 69042/69042 [00:18<00:00, 3738.78 examples/s]
Map: 100%|██████████| 12184/12184 [00:03<00:00, 3372.88 examples/s]


In [11]:
transliteration_tokenized_dataset = dataset.map(transliteration_preprocess_function, batched=True)

Map:   0%|          | 0/69042 [00:00<?, ? examples/s]

Map: 100%|██████████| 69042/69042 [00:21<00:00, 3247.81 examples/s]
Map: 100%|██████████| 12184/12184 [00:03<00:00, 3189.78 examples/s]


In [12]:
from datasets import concatenate_datasets

tokenized_dataset = {}

tokenized_dataset['train'] = concatenate_datasets([translation_tokenized_dataset['train'], transliteration_tokenized_dataset['train']])
tokenized_dataset['test'] = concatenate_datasets([translation_tokenized_dataset['test'], transliteration_tokenized_dataset['test']])

## Train the Model

Finally, we can train the model. Note that the optimizer used is Adafactor. This is the optimizer that is preferred for translation tasks and for the T5 model in general. The transformers api includes a built in version of Adafactor, but I define it separately here so that we can optimize it with the 'accelerate' library.

In [13]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, Adafactor
from accelerate import Accelerator

accelerator = Accelerator()

optimizer = Adafactor(
    model.parameters(), 
    scale_parameter=True, 
    relative_step=False, 
    warmup_init=False, 
    lr=3e-4
)

model, optimizer = accelerator.prepare(model, optimizer)

In [16]:
import numpy as np
import evaluate

metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [12]:
import wandb

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbillingsmoore[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [17]:
training_args = Seq2SeqTrainingArguments(
    output_dir=f"dual-task-poc",
    auto_find_batch_size=True,
    predict_with_generate=True,
    fp16=False, #check this
    push_to_hub=False,
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    num_train_epochs=3
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
    tokenizer=tokenizer,
    optimizers=(optimizer, None),
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbillingsmoore[0m. Use [1m`wandb login --relogin`[0m to force relogin


  1%|          | 501/51783 [01:37<2:40:20,  5.33it/s]

{'loss': 0.4351, 'grad_norm': 0.1823342889547348, 'learning_rate': 0.0002971032964486414, 'epoch': 0.03}


  2%|▏         | 1001/51783 [03:11<2:44:55,  5.13it/s]

{'loss': 0.22, 'grad_norm': 0.3234434127807617, 'learning_rate': 0.00029420659289728285, 'epoch': 0.06}


  3%|▎         | 1501/51783 [04:44<2:44:34,  5.09it/s]

{'loss': 0.2091, 'grad_norm': 0.15238432586193085, 'learning_rate': 0.0002913098893459243, 'epoch': 0.09}


  4%|▍         | 2001/51783 [06:17<2:33:50,  5.39it/s]

{'loss': 0.1979, 'grad_norm': 0.11686129122972488, 'learning_rate': 0.00028841318579456577, 'epoch': 0.12}


  5%|▍         | 2501/51783 [07:51<2:32:28,  5.39it/s]

{'loss': 0.196, 'grad_norm': 0.15118508040905, 'learning_rate': 0.0002855164822432072, 'epoch': 0.14}


  6%|▌         | 3001/51783 [09:24<2:29:23,  5.44it/s]

{'loss': 0.1924, 'grad_norm': 0.13265873491764069, 'learning_rate': 0.00028261977869184864, 'epoch': 0.17}


  7%|▋         | 3501/51783 [10:57<2:33:10,  5.25it/s]

{'loss': 0.1855, 'grad_norm': 0.12466584146022797, 'learning_rate': 0.0002797230751404901, 'epoch': 0.2}


  8%|▊         | 4001/51783 [12:30<2:29:13,  5.34it/s]

{'loss': 0.1815, 'grad_norm': 0.26544079184532166, 'learning_rate': 0.0002768263715891315, 'epoch': 0.23}


  9%|▊         | 4501/51783 [14:04<2:29:40,  5.26it/s]

{'loss': 0.1712, 'grad_norm': 0.3441460430622101, 'learning_rate': 0.00027392966803777295, 'epoch': 0.26}


 10%|▉         | 5001/51783 [15:37<2:27:34,  5.28it/s]

{'loss': 0.1571, 'grad_norm': 0.18080788850784302, 'learning_rate': 0.00027103296448641444, 'epoch': 0.29}


 11%|█         | 5500/51783 [17:10<2:22:37,  5.41it/s]

{'loss': 0.1446, 'grad_norm': 0.2512515187263489, 'learning_rate': 0.0002681362609350559, 'epoch': 0.32}


 12%|█▏        | 6001/51783 [18:44<2:24:11,  5.29it/s]

{'loss': 0.1362, 'grad_norm': 0.1613493710756302, 'learning_rate': 0.0002652395573836973, 'epoch': 0.35}


 13%|█▎        | 6501/51783 [20:17<2:19:57,  5.39it/s]

{'loss': 0.1313, 'grad_norm': 0.26526325941085815, 'learning_rate': 0.00026234285383233875, 'epoch': 0.38}


 14%|█▎        | 7001/51783 [21:50<2:23:54,  5.19it/s]

{'loss': 0.1292, 'grad_norm': 0.5206232666969299, 'learning_rate': 0.0002594461502809802, 'epoch': 0.41}


 14%|█▍        | 7501/51783 [23:24<2:15:40,  5.44it/s]

{'loss': 0.1198, 'grad_norm': 0.12677748501300812, 'learning_rate': 0.0002565494467296217, 'epoch': 0.43}


 15%|█▌        | 8001/51783 [24:57<2:15:11,  5.40it/s]

{'loss': 0.1136, 'grad_norm': 0.1538192331790924, 'learning_rate': 0.0002536527431782631, 'epoch': 0.46}


 16%|█▋        | 8500/51783 [26:30<2:15:33,  5.32it/s]

{'loss': 0.1115, 'grad_norm': 0.2646123468875885, 'learning_rate': 0.00025075603962690455, 'epoch': 0.49}


 17%|█▋        | 9001/51783 [28:03<2:14:18,  5.31it/s]

{'loss': 0.1128, 'grad_norm': 0.27708113193511963, 'learning_rate': 0.000247859336075546, 'epoch': 0.52}


 18%|█▊        | 9501/51783 [29:36<2:13:48,  5.27it/s]

{'loss': 0.1061, 'grad_norm': 0.2413697987794876, 'learning_rate': 0.0002449626325241875, 'epoch': 0.55}


 19%|█▉        | 10001/51783 [31:09<2:12:44,  5.25it/s]

{'loss': 0.1092, 'grad_norm': 0.11853358149528503, 'learning_rate': 0.00024206592897282889, 'epoch': 0.58}


 20%|██        | 10501/51783 [32:42<2:07:57,  5.38it/s]

{'loss': 0.0997, 'grad_norm': 0.1577751636505127, 'learning_rate': 0.00023916922542147032, 'epoch': 0.61}


 21%|██        | 11001/51783 [34:15<2:09:18,  5.26it/s]

{'loss': 0.1, 'grad_norm': 0.20955780148506165, 'learning_rate': 0.0002362725218701118, 'epoch': 0.64}


 22%|██▏       | 11501/51783 [35:48<2:03:46,  5.42it/s]

{'loss': 0.1003, 'grad_norm': 0.43981611728668213, 'learning_rate': 0.00023337581831875325, 'epoch': 0.67}


 23%|██▎       | 12001/51783 [37:22<2:03:28,  5.37it/s]

{'loss': 0.0997, 'grad_norm': 0.11100732535123825, 'learning_rate': 0.00023047911476739468, 'epoch': 0.7}


 24%|██▍       | 12501/51783 [38:54<2:02:23,  5.35it/s]

{'loss': 0.1035, 'grad_norm': 0.13933542370796204, 'learning_rate': 0.00022758241121603612, 'epoch': 0.72}


 25%|██▌       | 13001/51783 [40:27<2:01:54,  5.30it/s]

{'loss': 0.0957, 'grad_norm': 0.10604268312454224, 'learning_rate': 0.00022468570766467758, 'epoch': 0.75}


 26%|██▌       | 13501/51783 [42:00<2:00:13,  5.31it/s]

{'loss': 0.0972, 'grad_norm': 0.17232361435890198, 'learning_rate': 0.00022178900411331902, 'epoch': 0.78}


 27%|██▋       | 14001/51783 [43:32<2:00:25,  5.23it/s]

{'loss': 0.0954, 'grad_norm': 0.18954706192016602, 'learning_rate': 0.00021889230056196046, 'epoch': 0.81}


 28%|██▊       | 14501/51783 [45:05<1:57:40,  5.28it/s]

{'loss': 0.0963, 'grad_norm': 0.16822417080402374, 'learning_rate': 0.00021599559701060192, 'epoch': 0.84}


 29%|██▉       | 15001/51783 [46:39<1:55:30,  5.31it/s]

{'loss': 0.0939, 'grad_norm': 0.12941144406795502, 'learning_rate': 0.00021309889345924338, 'epoch': 0.87}


 30%|██▉       | 15501/51783 [48:11<1:54:37,  5.28it/s]

{'loss': 0.0957, 'grad_norm': 0.1256210058927536, 'learning_rate': 0.00021020218990788482, 'epoch': 0.9}


 31%|███       | 16001/51783 [49:44<1:52:21,  5.31it/s]

{'loss': 0.0961, 'grad_norm': 0.11581950634717941, 'learning_rate': 0.00020730548635652625, 'epoch': 0.93}


 32%|███▏      | 16501/51783 [51:17<1:47:58,  5.45it/s]

{'loss': 0.0902, 'grad_norm': 0.12283802032470703, 'learning_rate': 0.0002044087828051677, 'epoch': 0.96}


 33%|███▎      | 17001/51783 [52:49<1:46:10,  5.46it/s]

{'loss': 0.0931, 'grad_norm': 0.09389176964759827, 'learning_rate': 0.00020151207925380913, 'epoch': 0.98}


                                                       
 33%|███▎      | 17261/51783 [1:07:50<1:40:25,  5.73it/s]

{'eval_loss': 0.08057115226984024, 'eval_bleu': 22.3588, 'eval_gen_len': 17.314, 'eval_runtime': 852.7097, 'eval_samples_per_second': 28.577, 'eval_steps_per_second': 3.572, 'epoch': 1.0}


 34%|███▍      | 17501/51783 [1:08:35<1:48:00,  5.29it/s]    

{'loss': 0.0908, 'grad_norm': 0.23076477646827698, 'learning_rate': 0.00019861537570245062, 'epoch': 1.01}


 35%|███▍      | 18001/51783 [1:10:08<1:44:55,  5.37it/s]

{'loss': 0.0891, 'grad_norm': 0.13912959396839142, 'learning_rate': 0.00019571867215109205, 'epoch': 1.04}


 36%|███▌      | 18501/51783 [1:11:41<1:44:30,  5.31it/s]

{'loss': 0.0879, 'grad_norm': 0.12529756128787994, 'learning_rate': 0.0001928219685997335, 'epoch': 1.07}


 37%|███▋      | 19001/51783 [1:13:14<1:41:09,  5.40it/s]

{'loss': 0.0914, 'grad_norm': 0.1795681118965149, 'learning_rate': 0.00018992526504837493, 'epoch': 1.1}


 38%|███▊      | 19501/51783 [1:14:48<1:40:46,  5.34it/s]

{'loss': 0.0876, 'grad_norm': 0.15455688536167145, 'learning_rate': 0.00018702856149701636, 'epoch': 1.13}


 39%|███▊      | 20001/51783 [1:16:21<1:38:09,  5.40it/s]

{'loss': 0.0846, 'grad_norm': 0.1650020033121109, 'learning_rate': 0.00018413185794565783, 'epoch': 1.16}


 40%|███▉      | 20501/51783 [1:17:54<1:37:57,  5.32it/s]

{'loss': 0.0872, 'grad_norm': 0.10488549619913101, 'learning_rate': 0.0001812351543942993, 'epoch': 1.19}


 41%|████      | 21001/51783 [1:19:29<1:40:34,  5.10it/s]

{'loss': 0.0871, 'grad_norm': 0.14840812981128693, 'learning_rate': 0.00017833845084294072, 'epoch': 1.22}


 42%|████▏     | 21501/51783 [1:21:05<1:33:51,  5.38it/s]

{'loss': 0.0859, 'grad_norm': 0.10379995405673981, 'learning_rate': 0.00017544174729158216, 'epoch': 1.25}


 42%|████▏     | 22001/51783 [1:22:38<1:32:26,  5.37it/s]

{'loss': 0.088, 'grad_norm': 0.11183345317840576, 'learning_rate': 0.0001725450437402236, 'epoch': 1.27}


 43%|████▎     | 22501/51783 [1:24:11<1:30:17,  5.40it/s]

{'loss': 0.0857, 'grad_norm': 0.09153014421463013, 'learning_rate': 0.00016964834018886506, 'epoch': 1.3}


 44%|████▍     | 23001/51783 [1:25:44<1:30:33,  5.30it/s]

{'loss': 0.0835, 'grad_norm': 0.07523860782384872, 'learning_rate': 0.0001667516366375065, 'epoch': 1.33}


 45%|████▌     | 23501/51783 [1:27:17<1:31:55,  5.13it/s]

{'loss': 0.0819, 'grad_norm': 0.10975798219442368, 'learning_rate': 0.00016385493308614796, 'epoch': 1.36}


 46%|████▋     | 24001/51783 [1:28:50<1:28:47,  5.22it/s]

{'loss': 0.083, 'grad_norm': 0.13227905333042145, 'learning_rate': 0.0001609582295347894, 'epoch': 1.39}


 47%|████▋     | 24501/51783 [1:30:24<1:25:42,  5.30it/s]

{'loss': 0.0843, 'grad_norm': 0.11836561560630798, 'learning_rate': 0.00015806152598343086, 'epoch': 1.42}


 48%|████▊     | 25001/51783 [1:31:57<1:23:05,  5.37it/s]

{'loss': 0.0825, 'grad_norm': 0.10410425066947937, 'learning_rate': 0.0001551648224320723, 'epoch': 1.45}


 49%|████▉     | 25501/51783 [1:33:30<1:22:39,  5.30it/s]

{'loss': 0.0865, 'grad_norm': 0.1000574603676796, 'learning_rate': 0.00015226811888071373, 'epoch': 1.48}


 50%|█████     | 26001/51783 [1:35:02<1:20:39,  5.33it/s]

{'loss': 0.0823, 'grad_norm': 0.15024282038211823, 'learning_rate': 0.0001493714153293552, 'epoch': 1.51}


 51%|█████     | 26501/51783 [1:36:35<1:18:17,  5.38it/s]

{'loss': 0.0818, 'grad_norm': 0.1603843718767166, 'learning_rate': 0.00014647471177799663, 'epoch': 1.54}


 52%|█████▏    | 27001/51783 [1:38:08<1:17:13,  5.35it/s]

{'loss': 0.0847, 'grad_norm': 0.14846032857894897, 'learning_rate': 0.00014357800822663807, 'epoch': 1.56}


 53%|█████▎    | 27501/51783 [1:39:41<1:17:00,  5.26it/s]

{'loss': 0.0806, 'grad_norm': 0.15820053219795227, 'learning_rate': 0.00014068130467527953, 'epoch': 1.59}


 54%|█████▍    | 28000/51783 [1:41:13<1:15:09,  5.27it/s]

{'loss': 0.0795, 'grad_norm': 0.20389413833618164, 'learning_rate': 0.00013778460112392097, 'epoch': 1.62}


 55%|█████▌    | 28501/51783 [1:42:47<1:11:37,  5.42it/s]

{'loss': 0.0815, 'grad_norm': 0.12447873502969742, 'learning_rate': 0.0001348878975725624, 'epoch': 1.65}


 56%|█████▌    | 29001/51783 [1:44:21<1:12:07,  5.26it/s]

{'loss': 0.0821, 'grad_norm': 0.13548676669597626, 'learning_rate': 0.00013199119402120387, 'epoch': 1.68}


 57%|█████▋    | 29501/51783 [1:45:54<1:08:55,  5.39it/s]

{'loss': 0.0815, 'grad_norm': 0.0979003980755806, 'learning_rate': 0.0001290944904698453, 'epoch': 1.71}


 58%|█████▊    | 30001/51783 [1:47:28<1:08:28,  5.30it/s]

{'loss': 0.0818, 'grad_norm': 0.08415436744689941, 'learning_rate': 0.00012619778691848674, 'epoch': 1.74}


 59%|█████▉    | 30501/51783 [1:49:01<1:07:13,  5.28it/s]

{'loss': 0.0814, 'grad_norm': 0.12154505401849747, 'learning_rate': 0.0001233010833671282, 'epoch': 1.77}


 60%|█████▉    | 31001/51783 [1:50:35<1:05:01,  5.33it/s]

{'loss': 0.0827, 'grad_norm': 0.21466487646102905, 'learning_rate': 0.00012040437981576965, 'epoch': 1.8}


 61%|██████    | 31501/51783 [1:52:08<1:03:33,  5.32it/s]

{'loss': 0.0836, 'grad_norm': 0.10766518861055374, 'learning_rate': 0.00011750767626441109, 'epoch': 1.82}


 62%|██████▏   | 32001/51783 [1:53:41<1:02:51,  5.25it/s]

{'loss': 0.0786, 'grad_norm': 0.1834632158279419, 'learning_rate': 0.00011461097271305252, 'epoch': 1.85}


 63%|██████▎   | 32501/51783 [1:55:13<1:00:26,  5.32it/s]

{'loss': 0.0842, 'grad_norm': 0.08560586720705032, 'learning_rate': 0.00011171426916169399, 'epoch': 1.88}


 64%|██████▎   | 33001/51783 [1:56:46<59:34,  5.25it/s]  

{'loss': 0.0818, 'grad_norm': 0.11757726967334747, 'learning_rate': 0.00010881756561033542, 'epoch': 1.91}


 65%|██████▍   | 33501/51783 [1:58:18<56:19,  5.41it/s]  

{'loss': 0.0801, 'grad_norm': 0.11863401532173157, 'learning_rate': 0.00010592086205897687, 'epoch': 1.94}


 66%|██████▌   | 34001/51783 [1:59:50<54:34,  5.43it/s]  

{'loss': 0.0805, 'grad_norm': 0.1046622097492218, 'learning_rate': 0.00010302415850761832, 'epoch': 1.97}


 67%|██████▋   | 34501/51783 [2:01:23<53:51,  5.35it/s]  

{'loss': 0.0774, 'grad_norm': 0.1220642551779747, 'learning_rate': 0.00010012745495625977, 'epoch': 2.0}


                                                       
 67%|██████▋   | 34522/51783 [2:15:46<48:27,  5.94it/s]

{'eval_loss': 0.07104559987783432, 'eval_bleu': 25.3204, 'eval_gen_len': 17.1703, 'eval_runtime': 858.6616, 'eval_samples_per_second': 28.379, 'eval_steps_per_second': 3.547, 'epoch': 2.0}


 68%|██████▊   | 35001/51783 [2:17:15<51:29,  5.43it/s]      

{'loss': 0.0779, 'grad_norm': 0.29106685519218445, 'learning_rate': 9.723075140490121e-05, 'epoch': 2.03}


 69%|██████▊   | 35501/51783 [2:18:49<50:22,  5.39it/s]

{'loss': 0.079, 'grad_norm': 0.08156785368919373, 'learning_rate': 9.433404785354267e-05, 'epoch': 2.06}


 70%|██████▉   | 36001/51783 [2:20:22<49:36,  5.30it/s]

{'loss': 0.0774, 'grad_norm': 0.14186909794807434, 'learning_rate': 9.14373443021841e-05, 'epoch': 2.09}


 70%|███████   | 36501/51783 [2:21:55<51:19,  4.96it/s]

{'loss': 0.0771, 'grad_norm': 0.08791491389274597, 'learning_rate': 8.854064075082554e-05, 'epoch': 2.11}


 71%|███████▏  | 37001/51783 [2:23:28<46:17,  5.32it/s]

{'loss': 0.0799, 'grad_norm': 0.09943348169326782, 'learning_rate': 8.5643937199467e-05, 'epoch': 2.14}


 72%|███████▏  | 37501/51783 [2:25:01<43:58,  5.41it/s]

{'loss': 0.0762, 'grad_norm': 0.12157502770423889, 'learning_rate': 8.274723364810844e-05, 'epoch': 2.17}


 73%|███████▎  | 38001/51783 [2:26:35<43:28,  5.28it/s]

{'loss': 0.0766, 'grad_norm': 0.13884931802749634, 'learning_rate': 7.985053009674989e-05, 'epoch': 2.2}


 74%|███████▍  | 38501/51783 [2:28:08<41:23,  5.35it/s]

{'loss': 0.0771, 'grad_norm': 0.08467958867549896, 'learning_rate': 7.695382654539134e-05, 'epoch': 2.23}


 75%|███████▌  | 39001/51783 [2:29:41<40:24,  5.27it/s]

{'loss': 0.0769, 'grad_norm': 0.1039431095123291, 'learning_rate': 7.405712299403279e-05, 'epoch': 2.26}


 76%|███████▋  | 39501/51783 [2:31:14<38:07,  5.37it/s]

{'loss': 0.0759, 'grad_norm': 0.14011473953723907, 'learning_rate': 7.116041944267423e-05, 'epoch': 2.29}


 77%|███████▋  | 40001/51783 [2:32:47<37:10,  5.28it/s]

{'loss': 0.08, 'grad_norm': 0.14876465499401093, 'learning_rate': 6.826371589131568e-05, 'epoch': 2.32}


 78%|███████▊  | 40501/51783 [2:34:20<35:42,  5.27it/s]

{'loss': 0.0767, 'grad_norm': 0.10424850136041641, 'learning_rate': 6.536701233995713e-05, 'epoch': 2.35}


 79%|███████▉  | 41001/51783 [2:35:54<33:57,  5.29it/s]

{'loss': 0.0761, 'grad_norm': 0.5227623581886292, 'learning_rate': 6.247030878859856e-05, 'epoch': 2.38}


 80%|████████  | 41501/51783 [2:37:27<31:44,  5.40it/s]

{'loss': 0.076, 'grad_norm': 0.1364286094903946, 'learning_rate': 5.957360523724002e-05, 'epoch': 2.4}


 81%|████████  | 42001/51783 [2:39:00<31:11,  5.23it/s]

{'loss': 0.0744, 'grad_norm': 0.08967331796884537, 'learning_rate': 5.667690168588146e-05, 'epoch': 2.43}


 82%|████████▏ | 42501/51783 [2:40:34<28:53,  5.36it/s]

{'loss': 0.0764, 'grad_norm': 0.1169663816690445, 'learning_rate': 5.3780198134522905e-05, 'epoch': 2.46}


 83%|████████▎ | 43001/51783 [2:42:06<27:37,  5.30it/s]

{'loss': 0.0737, 'grad_norm': 0.08818582445383072, 'learning_rate': 5.0883494583164354e-05, 'epoch': 2.49}


 84%|████████▍ | 43501/51783 [2:43:39<25:35,  5.39it/s]

{'loss': 0.0792, 'grad_norm': 0.10218051820993423, 'learning_rate': 4.7986791031805804e-05, 'epoch': 2.52}


 85%|████████▍ | 44001/51783 [2:45:13<24:31,  5.29it/s]

{'loss': 0.0751, 'grad_norm': 0.3725349009037018, 'learning_rate': 4.509008748044725e-05, 'epoch': 2.55}


 86%|████████▌ | 44501/51783 [2:46:51<22:48,  5.32it/s]

{'loss': 0.0766, 'grad_norm': 0.12443254888057709, 'learning_rate': 4.2193383929088697e-05, 'epoch': 2.58}


 87%|████████▋ | 45001/51783 [2:48:24<21:04,  5.36it/s]

{'loss': 0.0776, 'grad_norm': 0.12859289348125458, 'learning_rate': 3.929668037773013e-05, 'epoch': 2.61}


 88%|████████▊ | 45501/51783 [2:49:58<19:27,  5.38it/s]

{'loss': 0.077, 'grad_norm': 0.11626347154378891, 'learning_rate': 3.639997682637158e-05, 'epoch': 2.64}


 89%|████████▉ | 46001/51783 [2:51:31<17:54,  5.38it/s]

{'loss': 0.0752, 'grad_norm': 0.30511674284935, 'learning_rate': 3.350327327501303e-05, 'epoch': 2.66}


 90%|████████▉ | 46501/51783 [2:53:04<16:46,  5.25it/s]

{'loss': 0.0759, 'grad_norm': 0.11528491973876953, 'learning_rate': 3.060656972365448e-05, 'epoch': 2.69}


 91%|█████████ | 47001/51783 [2:54:38<16:00,  4.98it/s]

{'loss': 0.0767, 'grad_norm': 0.33734071254730225, 'learning_rate': 2.7709866172295925e-05, 'epoch': 2.72}


 92%|█████████▏| 47501/51783 [2:56:11<13:19,  5.35it/s]

{'loss': 0.0746, 'grad_norm': 0.12107425928115845, 'learning_rate': 2.481316262093737e-05, 'epoch': 2.75}


 93%|█████████▎| 48001/51783 [2:57:43<11:40,  5.40it/s]

{'loss': 0.0765, 'grad_norm': 0.11096466332674026, 'learning_rate': 2.191645906957882e-05, 'epoch': 2.78}


 94%|█████████▎| 48501/51783 [2:59:16<10:06,  5.41it/s]

{'loss': 0.076, 'grad_norm': 0.0968431606888771, 'learning_rate': 1.9019755518220264e-05, 'epoch': 2.81}


 95%|█████████▍| 49001/51783 [3:00:49<08:38,  5.36it/s]

{'loss': 0.0714, 'grad_norm': 0.1084090992808342, 'learning_rate': 1.612305196686171e-05, 'epoch': 2.84}


 96%|█████████▌| 49501/51783 [3:02:22<07:10,  5.30it/s]

{'loss': 0.0746, 'grad_norm': 0.09340263903141022, 'learning_rate': 1.3226348415503156e-05, 'epoch': 2.87}


 97%|█████████▋| 50001/51783 [3:03:55<05:30,  5.40it/s]

{'loss': 0.0745, 'grad_norm': 0.11812326312065125, 'learning_rate': 1.0329644864144602e-05, 'epoch': 2.9}


 98%|█████████▊| 50501/51783 [3:05:28<03:56,  5.42it/s]

{'loss': 0.0777, 'grad_norm': 0.12189219892024994, 'learning_rate': 7.432941312786049e-06, 'epoch': 2.93}


 98%|█████████▊| 51001/51783 [3:07:00<02:24,  5.41it/s]

{'loss': 0.0737, 'grad_norm': 0.12548935413360596, 'learning_rate': 4.536237761427495e-06, 'epoch': 2.95}


 99%|█████████▉| 51500/51783 [3:08:33<00:51,  5.45it/s]

{'loss': 0.0736, 'grad_norm': 0.12164194136857986, 'learning_rate': 1.6395342100689414e-06, 'epoch': 2.98}


                                                       
100%|██████████| 51783/51783 [3:23:44<00:00,  5.80it/s]

{'eval_loss': 0.06824082881212234, 'eval_bleu': 26.1597, 'eval_gen_len': 17.0385, 'eval_runtime': 857.5151, 'eval_samples_per_second': 28.417, 'eval_steps_per_second': 3.552, 'epoch': 3.0}


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].
100%|██████████| 51783/51783 [3:23:47<00:00,  4.24it/s]

{'train_runtime': 12228.3308, 'train_samples_per_second': 33.876, 'train_steps_per_second': 4.235, 'train_loss': 0.09931060706740599, 'epoch': 3.0}





TrainOutput(global_step=51783, training_loss=0.09931060706740599, metrics={'train_runtime': 12228.3308, 'train_samples_per_second': 33.876, 'train_steps_per_second': 4.235, 'total_flos': 3.28509444980736e+16, 'train_loss': 0.09931060706740599, 'epoch': 3.0})