## Load Data and Tokenizer

In [1]:
from transformers import BertTokenizer
from datasets import load_dataset

# Load the tagged dataset; the result is indexed by split name below.
ds = load_dataset('openpecha/tagged_cleaned_MT_v1.0.3')

# Tokenizer used to encode the 'Source' text.
# NOTE(review): 'tibetan_tokenizer' is presumably a local directory (or hub id) holding
# a custom vocabulary -- confirm it is available wherever this notebook is run.
tokenizer = BertTokenizer.from_pretrained('tibetan_tokenizer')

# Show the first training example to check which fields are available.
ds['train'][0]

README.md:   0%|          | 0.00/479 [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/167M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/172M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/174M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1429192 [00:00<?, ? examples/s]

{'Source': 'ཐུབ་པས་རྟག་ཏུ་དེ་བཞིན་སྤྱད།།',
 'Target': 'The aspirant should move in such a way at all times.',
 'File_Name': 'TM2382',
 'Machine Aligned': True,
 '__index_level_0__': 0,
 'Tag': 'Prophecies, Rituals'}

In [2]:
def condition(example):
    """Return True when the example's 'Tag' field is anything other than the empty string."""
    tag = example['Tag']
    is_untagged = (tag == '')
    return not is_untagged

# Keep only the examples whose 'Tag' is not the empty string (see `condition`).
# `ds` is rebound to the filtered dataset, so later cells only see tagged rows.
ds = ds.filter(condition)

Filter:   0%|          | 0/1429192 [00:00<?, ? examples/s]

In [3]:
# Display the filtered dataset summary (features and remaining row count).
ds

DatasetDict({
    train: Dataset({
        features: ['Source', 'Target', 'File_Name', 'Machine Aligned', '__index_level_0__', 'Tag'],
        num_rows: 709531
    })
})

## Preprocess Data

In [4]:
# Collect the distinct tags in a deterministic (sorted) order.
# A bare `list(set(...))` orders strings by their per-process randomised hash, so the
# label -> id assignment would change between kernel sessions and a saved mapping or
# checkpoint from one run would not line up with the ids produced by another.
all_tags = sorted(set(ds['train']['Tag']))

# Create a label-to-index mapping (and its inverse for decoding predictions)
label2id = {label: idx for idx, label in enumerate(all_tags)}
id2label = {idx: label for label, idx in label2id.items()}

# Save label mappings for future use
import json
with open("op_label_mapping.json", "w") as f:
    json.dump(label2id, f)


In [5]:
def preprocess(examples):
    """Tokenize a batch of examples and attach integer class labels.

    Args:
        examples: a batch mapping column names to lists; the "Source" and "Tag"
            columns are read.

    Returns:
        The tokenizer output for the "Source" texts (padded/truncated to 128 tokens)
        with an extra "labels" entry holding the `label2id` index of each tag.
    """
    encoded = tokenizer(
        examples["Source"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )

    label_ids = []
    for tag in examples["Tag"]:
        label_ids.append(label2id[tag])
    encoded["labels"] = label_ids

    return encoded

# Apply `preprocess` to the dataset in batches; the tokenizer fields and `labels`
# are added alongside the original columns.
encoded_dataset = ds.map(preprocess, batched=True)


Map:   0%|          | 0/709531 [00:00<?, ? examples/s]

In [6]:
# Inspect one encoded example. The printout is long because every sequence is
# padded to 128 tokens by `preprocess`.
encoded_dataset['train'][0]

{'Source': 'ཐུབ་པས་རྟག་ཏུ་དེ་བཞིན་སྤྱད།།',
 'Target': 'The aspirant should move in such a way at all times.',
 'File_Name': 'TM2382',
 'Machine Aligned': True,
 '__index_level_0__': 0,
 'Tag': 'Prophecies, Rituals',
 'input_ids': [2,
  2820,
  9,
  339,
  9,
  1918,
  9,
  40,
  9,
  42,
  9,
  3925,
  9,
  1321,
  11,
  11,
  3,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,

In [7]:
# Drop the raw dataset columns so that only the tokenizer outputs and `labels` remain.
# Only columns that are still present are removed, which keeps this cell re-runnable:
# asking to remove already-dropped columns would otherwise raise.
raw_columns = ['Source', 'Target', 'File_Name', 'Machine Aligned', '__index_level_0__', 'Tag']
columns_to_drop = [name for name in raw_columns if name in encoded_dataset['train'].column_names]
if columns_to_drop:
    encoded_dataset = encoded_dataset.remove_columns(columns_to_drop)

In [8]:
# Hold out 15% of the examples for evaluation. The seed is fixed so the same rows land
# in the test split on every run, keeping metrics comparable across runs.
# NOTE: this rebinds `encoded_dataset` to the {train, test} result, so run it once per
# kernel session -- running it again would re-split the already reduced train split.
encoded_dataset = encoded_dataset['train'].train_test_split(test_size=.15, seed=42)

## Train Model

In [9]:
from transformers import BertTokenizer, BertForSequenceClassification

# Load the pretrained encoder with a freshly initialised classification head.
# The label mappings are stored in the model config so that saved checkpoints report
# the real tag names instead of generic LABEL_<n> placeholders.
import torch

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

# Resize embeddings to match the new tokenizer
# NOTE(review): the checkpoint's embeddings were learned for its own vocabulary, while
# the token ids here come from the separately loaded tokenizer -- the pretrained
# embedding rows presumably do not correspond to these tokens; confirm this pairing
# is intended.
model.resize_token_embeddings(len(tokenizer))

# Move model to GPU when one is available; fall back to CPU instead of crashing.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    """Compute classification metrics for one evaluation pass.

    Args:
        eval_pred: a pair ``(predictions, references)``; the predicted class is taken
            as the argmax over axis 1 of ``predictions`` and compared with
            ``references``.

    Returns:
        dict with "accuracy", "f1", "precision" and "recall". Precision, recall and F1
        are support-weighted averages over the classes (with this averaging, recall
        is numerically equal to accuracy).
    """
    predictions, references = eval_pred

    # Get predicted class indices
    predicted_ids = np.argmax(predictions, axis=1)

    # Compute metrics
    accuracy = accuracy_score(references, predicted_ids)
    # zero_division=0 scores a class that is never predicted as 0 -- the same value the
    # default produces -- without emitting an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        references,
        predicted_ids,
        average="weighted",
        zero_division=0,
    )

    return {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall
    }


In [11]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# Define training arguments
training_args = TrainingArguments(
    output_dir="op-bert-classifier",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=100,  # Upper bound only; early stopping normally ends training sooner
    weight_decay=0.01,
    eval_strategy="epoch",  # Evaluate at the end of every epoch
    save_strategy="epoch",  # Save at the end of every epoch (needed to reload the best model)
    load_best_model_at_end=True,  # Load the best model after training
    metric_for_best_model="accuracy",  # Metric to monitor
    greater_is_better=True,  # Higher accuracy is better
    logging_dir="./logs"
)

# Add the EarlyStoppingCallback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3  # Stop once the monitored metric has not improved for 3 evaluations
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    # `tokenizer=` is deprecated on Trainer; `processing_class=` is its replacement and
    # still saves the tokenizer alongside the model checkpoints.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping]  # Add the early stopping callback
)

# Start training
trainer.train()

  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbillingsmoore[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/3769400 [00:00<?, ?it/s]

{'loss': 2.0894, 'grad_norm': 3.4238860607147217, 'learning_rate': 1.9997347057887198e-05, 'epoch': 0.01}
{'loss': 2.0178, 'grad_norm': 3.957597017288208, 'learning_rate': 1.9994694115774395e-05, 'epoch': 0.03}
{'loss': 1.9869, 'grad_norm': 7.211489677429199, 'learning_rate': 1.999204117366159e-05, 'epoch': 0.04}
{'loss': 1.9831, 'grad_norm': 4.002837657928467, 'learning_rate': 1.9989388231548788e-05, 'epoch': 0.05}
{'loss': 1.9785, 'grad_norm': 5.943065643310547, 'learning_rate': 1.9986735289435985e-05, 'epoch': 0.07}
{'loss': 1.9586, 'grad_norm': 4.799134254455566, 'learning_rate': 1.9984082347323185e-05, 'epoch': 0.08}
{'loss': 1.9361, 'grad_norm': 4.495475769042969, 'learning_rate': 1.9981429405210378e-05, 'epoch': 0.09}
{'loss': 1.9416, 'grad_norm': 4.601923942565918, 'learning_rate': 1.9978776463097578e-05, 'epoch': 0.11}
{'loss': 1.9231, 'grad_norm': 5.0125932693481445, 'learning_rate': 1.9976123520984775e-05, 'epoch': 0.12}
{'loss': 1.9063, 'grad_norm': 5.557928085327148, 'lear

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.6856188774108887, 'eval_accuracy': 0.4737010241473269, 'eval_f1': 0.3758530511519662, 'eval_precision': 0.3867557938476894, 'eval_recall': 0.4737010241473269, 'eval_runtime': 244.7458, 'eval_samples_per_second': 434.859, 'eval_steps_per_second': 27.179, 'epoch': 1.0}
{'loss': 1.6937, 'grad_norm': 5.6133246421813965, 'learning_rate': 1.9798376399426966e-05, 'epoch': 1.01}
{'loss': 1.6543, 'grad_norm': 7.418537139892578, 'learning_rate': 1.9795723457314163e-05, 'epoch': 1.02}
{'loss': 1.6715, 'grad_norm': 12.934019088745117, 'learning_rate': 1.979307051520136e-05, 'epoch': 1.03}
{'loss': 1.6425, 'grad_norm': 10.267112731933594, 'learning_rate': 1.9790417573088556e-05, 'epoch': 1.05}
{'loss': 1.6743, 'grad_norm': 13.30710220336914, 'learning_rate': 1.9787764630975753e-05, 'epoch': 1.06}
{'loss': 1.6789, 'grad_norm': 7.715205669403076, 'learning_rate': 1.978511168886295e-05, 'epoch': 1.07}
{'loss': 1.6758, 'grad_norm': 19.948389053344727, 'learning_rate': 1.978245874675015e

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.6049892902374268, 'eval_accuracy': 0.4882270036643803, 'eval_f1': 0.405253809937267, 'eval_precision': 0.560281023326854, 'eval_recall': 0.4882270036643803, 'eval_runtime': 244.5271, 'eval_samples_per_second': 435.248, 'eval_steps_per_second': 27.204, 'epoch': 2.0}
{'loss': 1.617, 'grad_norm': 9.327131271362305, 'learning_rate': 1.9599405740966734e-05, 'epoch': 2.0}
{'loss': 1.5788, 'grad_norm': 11.49383544921875, 'learning_rate': 1.959675279885393e-05, 'epoch': 2.02}
{'loss': 1.5538, 'grad_norm': 8.606086730957031, 'learning_rate': 1.9594099856741127e-05, 'epoch': 2.03}
{'loss': 1.5896, 'grad_norm': 10.014422416687012, 'learning_rate': 1.9591446914628324e-05, 'epoch': 2.04}
{'loss': 1.595, 'grad_norm': 12.307848930358887, 'learning_rate': 1.958879397251552e-05, 'epoch': 2.06}
{'loss': 1.5752, 'grad_norm': 5.049760818481445, 'learning_rate': 1.958614103040272e-05, 'epoch': 2.07}
{'loss': 1.6009, 'grad_norm': 7.472815990447998, 'learning_rate': 1.9583488088289914e-05, 'e

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5505106449127197, 'eval_accuracy': 0.5000657709292493, 'eval_f1': 0.4200792618764342, 'eval_precision': 0.45923000457985064, 'eval_recall': 0.5000657709292493, 'eval_runtime': 245.4038, 'eval_samples_per_second': 433.693, 'eval_steps_per_second': 27.106, 'epoch': 3.0}
{'loss': 1.5262, 'grad_norm': 7.8854827880859375, 'learning_rate': 1.93977821403937e-05, 'epoch': 3.01}
{'loss': 1.5391, 'grad_norm': 7.9411773681640625, 'learning_rate': 1.9395129198280895e-05, 'epoch': 3.02}
{'loss': 1.5134, 'grad_norm': 12.898775100708008, 'learning_rate': 1.939247625616809e-05, 'epoch': 3.04}
{'loss': 1.52, 'grad_norm': 15.852009773254395, 'learning_rate': 1.9389823314055288e-05, 'epoch': 3.05}
{'loss': 1.5336, 'grad_norm': 6.002817153930664, 'learning_rate': 1.9387170371942485e-05, 'epoch': 3.06}
{'loss': 1.5321, 'grad_norm': 7.17263650894165, 'learning_rate': 1.9384517429829685e-05, 'epoch': 3.08}
{'loss': 1.5358, 'grad_norm': 13.205915451049805, 'learning_rate': 1.9381864487716878e-

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5359846353530884, 'eval_accuracy': 0.5048482570703748, 'eval_f1': 0.4286645608356498, 'eval_precision': 0.4530202554835685, 'eval_recall': 0.5048482570703748, 'eval_runtime': 245.4139, 'eval_samples_per_second': 433.676, 'eval_steps_per_second': 27.105, 'epoch': 4.0}
{'loss': 1.5175, 'grad_norm': 5.8107147216796875, 'learning_rate': 1.9198811481933466e-05, 'epoch': 4.01}
{'loss': 1.4811, 'grad_norm': 7.159397125244141, 'learning_rate': 1.9196158539820663e-05, 'epoch': 4.02}
{'loss': 1.4908, 'grad_norm': 11.363373756408691, 'learning_rate': 1.919350559770786e-05, 'epoch': 4.03}
{'loss': 1.4662, 'grad_norm': 4.984918117523193, 'learning_rate': 1.9190852655595056e-05, 'epoch': 4.05}
{'loss': 1.4782, 'grad_norm': 7.230827331542969, 'learning_rate': 1.9188199713482253e-05, 'epoch': 4.06}
{'loss': 1.4841, 'grad_norm': 10.438305854797363, 'learning_rate': 1.918554677136945e-05, 'epoch': 4.07}
{'loss': 1.4827, 'grad_norm': 7.910827159881592, 'learning_rate': 1.918289382925665e-

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5355561971664429, 'eval_accuracy': 0.5057502583857935, 'eval_f1': 0.4283294248101745, 'eval_precision': 0.45048258825009446, 'eval_recall': 0.5057502583857935, 'eval_runtime': 245.3104, 'eval_samples_per_second': 433.859, 'eval_steps_per_second': 27.117, 'epoch': 5.0}
{'loss': 1.4994, 'grad_norm': 16.956466674804688, 'learning_rate': 1.899984082347323e-05, 'epoch': 5.0}
{'loss': 1.4373, 'grad_norm': 14.953500747680664, 'learning_rate': 1.899718788136043e-05, 'epoch': 5.01}
{'loss': 1.4345, 'grad_norm': 7.3859710693359375, 'learning_rate': 1.8994534939247627e-05, 'epoch': 5.03}
{'loss': 1.433, 'grad_norm': 8.295027732849121, 'learning_rate': 1.8991881997134824e-05, 'epoch': 5.04}
{'loss': 1.4191, 'grad_norm': 12.642525672912598, 'learning_rate': 1.898922905502202e-05, 'epoch': 5.05}
{'loss': 1.3998, 'grad_norm': 9.791640281677246, 'learning_rate': 1.8986576112909217e-05, 'epoch': 5.07}
{'loss': 1.4303, 'grad_norm': 12.365324974060059, 'learning_rate': 1.8983923170796414e

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5265285968780518, 'eval_accuracy': 0.503833505590529, 'eval_f1': 0.439446464109943, 'eval_precision': 0.4585073792525974, 'eval_recall': 0.503833505590529, 'eval_runtime': 244.8355, 'eval_samples_per_second': 434.7, 'eval_steps_per_second': 27.169, 'epoch': 6.0}
{'loss': 1.3958, 'grad_norm': 8.822230339050293, 'learning_rate': 1.87982172229002e-05, 'epoch': 6.01}
{'loss': 1.3965, 'grad_norm': 16.697267532348633, 'learning_rate': 1.8795564280787395e-05, 'epoch': 6.02}
{'loss': 1.3967, 'grad_norm': 9.632466316223145, 'learning_rate': 1.8792911338674592e-05, 'epoch': 6.04}
{'loss': 1.3897, 'grad_norm': 17.311708450317383, 'learning_rate': 1.879025839656179e-05, 'epoch': 6.05}
{'loss': 1.4029, 'grad_norm': 11.993690490722656, 'learning_rate': 1.8787605454448985e-05, 'epoch': 6.06}
{'loss': 1.415, 'grad_norm': 12.110703468322754, 'learning_rate': 1.8784952512336182e-05, 'epoch': 6.08}
{'loss': 1.3956, 'grad_norm': 6.113706588745117, 'learning_rate': 1.878229957022338e-05, 'e

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5374624729156494, 'eval_accuracy': 0.5025838579347929, 'eval_f1': 0.44102470088096163, 'eval_precision': 0.45335571038925093, 'eval_recall': 0.5025838579347929, 'eval_runtime': 247.3336, 'eval_samples_per_second': 430.31, 'eval_steps_per_second': 26.895, 'epoch': 7.0}
{'loss': 1.4114, 'grad_norm': 10.325528144836426, 'learning_rate': 1.8599246564439966e-05, 'epoch': 7.0}
{'loss': 1.3222, 'grad_norm': 17.04103660583496, 'learning_rate': 1.8596593622327163e-05, 'epoch': 7.02}
{'loss': 1.3343, 'grad_norm': 20.90623664855957, 'learning_rate': 1.859394068021436e-05, 'epoch': 7.03}
{'loss': 1.3391, 'grad_norm': 26.000619888305664, 'learning_rate': 1.8591287738101556e-05, 'epoch': 7.04}
{'loss': 1.34, 'grad_norm': 17.28074836730957, 'learning_rate': 1.8588634795988753e-05, 'epoch': 7.06}
{'loss': 1.3477, 'grad_norm': 10.698341369628906, 'learning_rate': 1.858598185387595e-05, 'epoch': 7.07}
{'loss': 1.362, 'grad_norm': 18.605119705200195, 'learning_rate': 1.8583328911763146e-0

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5359383821487427, 'eval_accuracy': 0.5066146763130697, 'eval_f1': 0.44149182723058805, 'eval_precision': 0.4550629188322734, 'eval_recall': 0.5066146763130697, 'eval_runtime': 247.6434, 'eval_samples_per_second': 429.771, 'eval_steps_per_second': 26.861, 'epoch': 8.0}
{'loss': 1.2903, 'grad_norm': 13.124061584472656, 'learning_rate': 1.839762296386693e-05, 'epoch': 8.01}
{'loss': 1.2967, 'grad_norm': 17.717111587524414, 'learning_rate': 1.8394970021754127e-05, 'epoch': 8.03}
{'loss': 1.3039, 'grad_norm': 17.602914810180664, 'learning_rate': 1.8392317079641324e-05, 'epoch': 8.04}
{'loss': 1.2774, 'grad_norm': 11.426953315734863, 'learning_rate': 1.838966413752852e-05, 'epoch': 8.05}
{'loss': 1.2972, 'grad_norm': 14.420337677001953, 'learning_rate': 1.8387011195415717e-05, 'epoch': 8.06}
{'loss': 1.2955, 'grad_norm': 11.884708404541016, 'learning_rate': 1.8384358253302914e-05, 'epoch': 8.08}
{'loss': 1.2973, 'grad_norm': 12.409870147705078, 'learning_rate': 1.838170531119

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.5630689859390259, 'eval_accuracy': 0.5062576341257164, 'eval_f1': 0.4388699899312852, 'eval_precision': 0.46137715986810796, 'eval_recall': 0.5062576341257164, 'eval_runtime': 246.1207, 'eval_samples_per_second': 432.43, 'eval_steps_per_second': 27.027, 'epoch': 9.0}
{'loss': 1.2872, 'grad_norm': 18.207101821899414, 'learning_rate': 1.8198652305406695e-05, 'epoch': 9.01}
{'loss': 1.2414, 'grad_norm': 14.786429405212402, 'learning_rate': 1.8195999363293895e-05, 'epoch': 9.02}
{'loss': 1.2366, 'grad_norm': 18.78693962097168, 'learning_rate': 1.8193346421181092e-05, 'epoch': 9.03}
{'loss': 1.2545, 'grad_norm': 14.792014122009277, 'learning_rate': 1.819069347906829e-05, 'epoch': 9.05}
{'loss': 1.2503, 'grad_norm': 10.622474670410156, 'learning_rate': 1.8188040536955485e-05, 'epoch': 9.06}
{'loss': 1.2141, 'grad_norm': 25.130523681640625, 'learning_rate': 1.8185387594842682e-05, 'epoch': 9.07}
{'loss': 1.2564, 'grad_norm': 19.40913200378418, 'learning_rate': 1.81827346527298

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.592553734779358, 'eval_accuracy': 0.495950389927652, 'eval_f1': 0.43995122769298595, 'eval_precision': 0.4455531113799034, 'eval_recall': 0.495950389927652, 'eval_runtime': 245.9537, 'eval_samples_per_second': 432.724, 'eval_steps_per_second': 27.046, 'epoch': 10.0}
{'loss': 1.2609, 'grad_norm': 20.011045455932617, 'learning_rate': 1.7999681646946466e-05, 'epoch': 10.0}
{'loss': 1.1717, 'grad_norm': 22.799793243408203, 'learning_rate': 1.799702870483366e-05, 'epoch': 10.01}
{'loss': 1.188, 'grad_norm': 13.32534122467041, 'learning_rate': 1.799437576272086e-05, 'epoch': 10.03}
{'loss': 1.1677, 'grad_norm': 20.454381942749023, 'learning_rate': 1.7991722820608056e-05, 'epoch': 10.04}
{'loss': 1.2074, 'grad_norm': 27.273958206176758, 'learning_rate': 1.7989069878495253e-05, 'epoch': 10.05}
{'loss': 1.1823, 'grad_norm': 22.1530818939209, 'learning_rate': 1.798641693638245e-05, 'epoch': 10.07}
{'loss': 1.1558, 'grad_norm': 26.75925064086914, 'learning_rate': 1.798376399426964

  0%|          | 0/6652 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 1.6183372735977173, 'eval_accuracy': 0.49539603495255097, 'eval_f1': 0.4422287990438464, 'eval_precision': 0.4446891482403874, 'eval_recall': 0.49539603495255097, 'eval_runtime': 246.0357, 'eval_samples_per_second': 432.579, 'eval_steps_per_second': 27.037, 'epoch': 11.0}
{'train_runtime': 50323.5068, 'train_samples_per_second': 1198.448, 'train_steps_per_second': 74.903, 'train_loss': 1.4578601861428213, 'epoch': 11.0}


TrainOutput(global_step=414634, training_loss=1.4578601861428213, metrics={'train_runtime': 50323.5068, 'train_samples_per_second': 1198.448, 'train_steps_per_second': 74.903, 'total_flos': 4.364592657569042e+17, 'train_loss': 1.4578601861428213, 'epoch': 11.0})