## Loading the Data

In [1]:
from datasets import load_from_disk

dataset = load_from_disk('rat-poc-ds-w-context')

In [2]:
dataset['train'][0]

{'Source': 'འཇིག་ལས་འདས་པའི་གང་འདུལ་ལོ།།',
 'Target': 'Taming with transcendent beings.',
 'File_Name': 'TM3076',
 'Machine Aligned': False,
 '__index_level_0__': 1176089,
 'Tag': 'Intrinsic Existence, Conventional Existence',
 'context': ['འགྲོ་ཀུན་སྒྲིབ་པ་གཉིས་སྤངས་ཏེ།\xa0། -> May all beings conquer the two obscurations',
  'དགེ་བས་མཁའ་མཉམ་ལུས་ཅན་མ་ལུས་པ།། ཐེག་མཆོག་གོ་གྱོན་ཤེས་རབ་མཚོན་ཐོགས་ནས།། བདུད་བཞིའི་དགྲ་སྡེ་མ་ལུས་ཀུན་བཅོམ་སྟེ།། སྐུ་གསུམ་ནོར་བུའི་ཁྲི་ལ་འཁོད་གྱུར་ཅིག། -> Through this virtue, may all embodied beings throughout space without exception, Put on the armor of the Supreme Vehicle and having raised the weapon of wisdom, May they overcome all without exception of the host of enemies which are the four demons And be set on the jeweled throne of the three bodies.',
  'སྐྱེ་འགག་ཡོད་མེད་ལ་སོགས་པའི་དམིགས་པ་དང་འཛིན་པའི་ཡུལ་ལས་འདས་པའི་རིག་སྟོང་སྤྲོས་བྲལ་མཉམ་པ་ཉིད་ཀྱི་ཁོར་ཡུག་ཡིན་ཏེ། -> Phenomena therefore transcend all objects of reference and clinging, such as origin and cessat

In [3]:
dataset['test'][0]

{'Source': ' དབང་ཤེས་ནི་རྟགས་ལས་དཔག་མི་དགོས་པར་མངོན་སུམ་དུ་ངེས་པའི་ཕྱིར་རོ།།',
 'Target': ' Sense cognitions need not infer from signs but can ascertain things directly.',
 'File_Name': 'TM0713',
 'Machine Aligned': False,
 '__index_level_0__': 168915,
 'Tag': 'Prophecies, Rituals',
 'context': ['འགྲོ་ཀུན་སྒྲིབ་པ་གཉིས་སྤངས་ཏེ།\xa0། -> May all beings conquer the two obscurations',
  'དགེ་བས་མཁའ་མཉམ་ལུས་ཅན་མ་ལུས་པ།། ཐེག་མཆོག་གོ་གྱོན་ཤེས་རབ་མཚོན་ཐོགས་ནས།། བདུད་བཞིའི་དགྲ་སྡེ་མ་ལུས་ཀུན་བཅོམ་སྟེ།། སྐུ་གསུམ་ནོར་བུའི་ཁྲི་ལ་འཁོད་གྱུར་ཅིག། -> Through this virtue, may all embodied beings throughout space without exception, Put on the armor of the Supreme Vehicle and having raised the weapon of wisdom, May they overcome all without exception of the host of enemies which are the four demons And be set on the jeweled throne of the three bodies.',
  'སྐྱེ་འགག་ཡོད་མེད་ལ་སོགས་པའི་དམིགས་པ་དང་འཛིན་པའི་ཡུལ་ལས་འདས་པའི་རིག་སྟོང་སྤྲོས་བྲལ་མཉམ་པ་ཉིད་ཀྱི་ཁོར་ཡུག་ཡིན་ཏེ། -> Phenomena therefore transcend all obj

## Load Tokenizer, Model, and Data Collator

In [4]:
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM

checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map="cuda:0")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [5]:
# Generate a list of all Tibetan Unicode characters (U+0F00 to U+0FFF)
tibetan_chars = [chr(codepoint) for codepoint in range(0x0F00, 0x0FFF)]

# Add the Tibetan characters to the tokenizer's vocabulary
new_tokens = [char for char in tibetan_chars if char not in tokenizer.get_vocab()]

# Add new tokens to the tokenizer
tokenizer.add_tokens(new_tokens)

# Resize model embeddings to accommodate the new vocabulary size
model.resize_token_embeddings(len(tokenizer))

Embedding(32355, 512)

In [6]:
enc = tokenizer.encode(dataset['train'][0]['Source'])
dec = tokenizer.decode(enc)
dec

'འཇིག་ལས་འདས་པའི་གང་འདུལ་ལོ།།</s>'

## Preprocess Data

The dataset can now be tokenized for training.

In [8]:
def preprocess_with_multiple_contexts(example):
    input_text = (
        f"translate Tibetan to English: {example['Source']} Context: "
        + " ".join([f"{i+1}. {context}" for i, context in enumerate(example['context'])])
    )
    target_text = example['Target']

    return {
        "input_ids": tokenizer(input_text, padding="max_length", truncation=True, max_length=512).input_ids,
        "labels": tokenizer(target_text, padding="max_length", truncation=True, max_length=512).input_ids,
    }

# Apply preprocessing
dataset_with_contexts = dataset.map(
    preprocess_with_multiple_contexts,
    batched=False,  # Process one example at a time
    #remove_columns=dataset.column_names,  # Remove old columns to update the schema
)



Map:   0%|          | 0/45000 [00:00<?, ? examples/s]

## Define Metric

In [9]:
import numpy as np
import evaluate

# Load BLEU and CHRF metrics
bleu_metric = evaluate.load("sacrebleu")
chrf_metric = evaluate.load("chrf")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    
    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Postprocess text
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    # Compute BLEU score
    bleu_result = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)
    bleu_score = bleu_result["score"]

    # Compute CHRF score
    chrf_result = chrf_metric.compute(predictions=decoded_preds, references=decoded_labels)
    chrf_score = chrf_result["score"]

    # Compute generation length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    avg_gen_len = np.mean(prediction_lens)

    # Return rounded results
    return {
        "bleu": round(bleu_score, 4),
        "chrf": round(chrf_score, 4),
        "gen_len": round(avg_gen_len, 4),
    }

## Train the Model

In [10]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, Adafactor
from accelerate import Accelerator

accelerator = Accelerator()

optimizer = Adafactor(
    model.parameters(), 
    scale_parameter=True, 
    relative_step=False, 
    warmup_init=False, 
    lr=3e-4
)

model, optimizer = accelerator.prepare(model, optimizer)

In [11]:
training_args = Seq2SeqTrainingArguments(
    output_dir=f"rat-poc-single-context",
    auto_find_batch_size=True,
    predict_with_generate=True,
    fp16=False, #check this
    push_to_hub=False,
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    num_train_epochs=3
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_with_contexts['train'],
    eval_dataset=dataset_with_contexts['test'],
    tokenizer=tokenizer,
    optimizers=(optimizer, None),
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbillingsmoore[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/16875 [00:00<?, ?it/s]

  0%|          | 0/33750 [00:00<?, ?it/s]

{'loss': 0.6475, 'grad_norm': 0.1861901581287384, 'learning_rate': 0.0002955555555555555, 'epoch': 0.04}
{'loss': 0.5436, 'grad_norm': 0.2835758626461029, 'learning_rate': 0.0002911111111111111, 'epoch': 0.09}
{'loss': 0.489, 'grad_norm': 0.2705881595611572, 'learning_rate': 0.0002866666666666667, 'epoch': 0.13}
{'loss': 0.5044, 'grad_norm': 0.23340293765068054, 'learning_rate': 0.00028222222222222223, 'epoch': 0.18}
{'loss': 0.4828, 'grad_norm': 0.3132828176021576, 'learning_rate': 0.0002777777777777778, 'epoch': 0.22}
{'loss': 0.4916, 'grad_norm': 0.11836644262075424, 'learning_rate': 0.00027333333333333333, 'epoch': 0.27}
{'loss': 0.4816, 'grad_norm': 0.17120428383350372, 'learning_rate': 0.0002688888888888889, 'epoch': 0.31}
{'loss': 0.4909, 'grad_norm': 0.2121029943227768, 'learning_rate': 0.00026444444444444443, 'epoch': 0.36}
{'loss': 0.486, 'grad_norm': 0.4586763381958008, 'learning_rate': 0.00026, 'epoch': 0.4}
{'loss': 0.5026, 'grad_norm': 0.39106279611587524, 'learning_rate'



  0%|          | 0/625 [00:00<?, ?it/s]

{'eval_loss': 0.4654269516468048, 'eval_bleu': 0.0333, 'eval_chrf': 5.4894, 'eval_gen_len': 16.98, 'eval_runtime': 184.2989, 'eval_samples_per_second': 27.13, 'eval_steps_per_second': 3.391, 'epoch': 1.0}
{'loss': 0.458, 'grad_norm': 0.13817499577999115, 'learning_rate': 0.00019777777777777776, 'epoch': 1.02}
{'loss': 0.47, 'grad_norm': 0.21915124356746674, 'learning_rate': 0.00019333333333333333, 'epoch': 1.07}
{'loss': 0.4623, 'grad_norm': 0.1733018159866333, 'learning_rate': 0.00018888888888888888, 'epoch': 1.11}
{'loss': 0.4255, 'grad_norm': 0.32128384709358215, 'learning_rate': 0.00018444444444444443, 'epoch': 1.16}
{'loss': 0.4476, 'grad_norm': 0.15475602447986603, 'learning_rate': 0.00017999999999999998, 'epoch': 1.2}
{'loss': 0.4598, 'grad_norm': 0.2829039394855499, 'learning_rate': 0.00017555555555555553, 'epoch': 1.24}
{'loss': 0.442, 'grad_norm': 0.1287246197462082, 'learning_rate': 0.0001711111111111111, 'epoch': 1.29}
{'loss': 0.4483, 'grad_norm': 0.2688215374946594, 'lear



  0%|          | 0/625 [00:00<?, ?it/s]

{'eval_loss': 0.45450741052627563, 'eval_bleu': 0.0241, 'eval_chrf': 5.5786, 'eval_gen_len': 16.8316, 'eval_runtime': 184.9053, 'eval_samples_per_second': 27.041, 'eval_steps_per_second': 3.38, 'epoch': 2.0}
{'loss': 0.4528, 'grad_norm': 0.2827492654323578, 'learning_rate': 9.555555555555555e-05, 'epoch': 2.04}
{'loss': 0.4588, 'grad_norm': 0.18485760688781738, 'learning_rate': 9.11111111111111e-05, 'epoch': 2.09}
{'loss': 0.4614, 'grad_norm': 0.23722995817661285, 'learning_rate': 8.666666666666665e-05, 'epoch': 2.13}
{'loss': 0.44, 'grad_norm': 0.2688820958137512, 'learning_rate': 8.222222222222222e-05, 'epoch': 2.18}
{'loss': 0.4641, 'grad_norm': 0.11892783641815186, 'learning_rate': 7.777777777777777e-05, 'epoch': 2.22}
{'loss': 0.4424, 'grad_norm': 0.20135915279388428, 'learning_rate': 7.333333333333332e-05, 'epoch': 2.27}
{'loss': 0.4433, 'grad_norm': 0.782459557056427, 'learning_rate': 6.888888888888888e-05, 'epoch': 2.31}
{'loss': 0.454, 'grad_norm': 0.14318709075450897, 'learni



  0%|          | 0/625 [00:00<?, ?it/s]

{'eval_loss': 0.4517495036125183, 'eval_bleu': 0.0262, 'eval_chrf': 5.678, 'eval_gen_len': 16.7604, 'eval_runtime': 185.1364, 'eval_samples_per_second': 27.007, 'eval_steps_per_second': 3.376, 'epoch': 3.0}


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


{'train_runtime': 10737.2834, 'train_samples_per_second': 12.573, 'train_steps_per_second': 3.143, 'train_loss': 0.46500030698423034, 'epoch': 3.0}


TrainOutput(global_step=33750, training_loss=0.46500030698423034, metrics={'train_runtime': 10737.2834, 'train_samples_per_second': 12.573, 'train_steps_per_second': 3.143, 'total_flos': 1.827114319872e+16, 'train_loss': 0.46500030698423034, 'epoch': 3.0})