## Load Data and Tokenizer

In [1]:
from transformers import BertTokenizer
from datasets import load_dataset

ds = load_dataset('openpecha/tagged_cleaned_MT_v1.0.3')

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

ds['train'][0]

{'Source': 'ཐུབ་པས་རྟག་ཏུ་དེ་བཞིན་སྤྱད།།',
 'Target': 'The aspirant should move in such a way at all times.',
 'File_Name': 'TM2382',
 'Machine Aligned': True,
 '__index_level_0__': 0,
 'Tag': 'Prophecies, Rituals'}

## Preprocess Data

### Remove Blank Tags

In [2]:
def condition(example):
    return example['Tag'] != ''

ds = ds.filter(condition)

In [3]:
ds

DatasetDict({
    train: Dataset({
        features: ['Source', 'Target', 'File_Name', 'Machine Aligned', '__index_level_0__', 'Tag'],
        num_rows: 1163105
    })
    test: Dataset({
        features: ['Source', 'Target', 'File_Name', 'Machine Aligned', '__index_level_0__', 'Tag'],
        num_rows: 0
    })
})

In [4]:
ds['train'][0]

{'Source': 'ཐུབ་པས་རྟག་ཏུ་དེ་བཞིན་སྤྱད།།',
 'Target': 'The aspirant should move in such a way at all times.',
 'File_Name': 'TM2382',
 'Machine Aligned': True,
 '__index_level_0__': 0,
 'Tag': 'Prophecies, Rituals'}

### Collapse Buddhist Lables into One

In [5]:
buddhist_labels = ['Mantras',
                    'Dzogchen',
                    'Astrology',
                    'Monastery',
                    'Mahamudra',
                    'Mind',
                    'Meditation',
                    'Self, Logic, Aggregates',
                    'Tantra',
                    'Emptiness',
                    'Dreams',
                    'Education, Teaching',
                    'Ethics, Enlightenment, Wisdom',
                    'Prophecies, Rituals',
                    'Lama',
                    'Samsara, Nirvana',
                    'Milarepa, Realization, Biography',
                    'Kayas',
                    'Intrinsic Existence, Conventional Existence',
                    'Time, Causality, Perception',
                    'Natural State',
                    'Karma, Consequences',
                    'Dharma']

In [6]:
def collapse_labels(example):
    if example['Tag'] in buddhist_labels:
        example['Tag'] = 'Buddhist'
    return example

# Apply the function to the dataset
ds = ds.map(collapse_labels)

In [7]:
ds

DatasetDict({
    train: Dataset({
        features: ['Source', 'Target', 'File_Name', 'Machine Aligned', '__index_level_0__', 'Tag'],
        num_rows: 1163105
    })
    test: Dataset({
        features: ['Source', 'Target', 'File_Name', 'Machine Aligned', '__index_level_0__', 'Tag'],
        num_rows: 0
    })
})

### Convert Labels to Id Numbers

In [8]:
all_tags = list(set(ds['train']['Tag']))

# Create a label-to-index mapping
label2id = {label: idx for idx, label in enumerate(all_tags)}
id2label = {idx: label for label, idx in label2id.items()}

# Save label mappings for future use
import json
with open("simple_op_label_mapping.json", "w") as f:
    json.dump(label2id, f)


In [9]:
all_tags

['Journalism',
 'History, Politics, Law',
 'Business',
 'Fiction',
 'Science & Medicine',
 'Buddhist',
 'Language & Culture']

In [10]:
def preprocess(examples):
    tokens = tokenizer(examples["Target"], padding="max_length", truncation=True, max_length=128)
    tokens["labels"] = [label2id[label] for label in examples["Tag"]]    
    return tokens

encoded_dataset = ds.map(preprocess, batched=True)


Map:   0%|          | 0/1163105 [00:00<?, ? examples/s]

In [11]:
encoded_dataset = encoded_dataset.remove_columns(['Source', 'Target', 'File_Name', 'Machine Aligned', '__index_level_0__', 'Tag'])

In [12]:
encoded_dataset = encoded_dataset['train'].train_test_split(.15)

## Train Model

In [13]:
from transformers import BertForSequenceClassification

# Load tokenizer and model
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=len(label2id))

# Resize embeddings to match the new tokenizer
model.resize_token_embeddings(len(tokenizer))

# Move model to GPU
model = model.to('cuda:0')

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    predictions, references = eval_pred
    
    # Get predicted class indices
    predictions = np.argmax(predictions, axis=1)
    
    # Compute metrics
    accuracy = accuracy_score(references, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(references, predictions, average="weighted")
    
    return {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall
    }


In [15]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# Define training arguments
training_args = TrainingArguments(
    output_dir="en-col-op-bert-classifier",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=100,  # Set a maximum number of epochs
    weight_decay=0.01,
    eval_strategy="epoch",  # Evaluate at the end of every epoch
    save_strategy="epoch",  # Save the model at the end of every epoch
    load_best_model_at_end=True,  # Load the best model after training
    metric_for_best_model="accuracy",  # Metric to monitor
    greater_is_better=True,  # Higher accuracy is better
    logging_dir="./logs"
)

# Add the EarlyStoppingCallback
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3  # Stop training if the metric does not improve for 3 evaluation steps
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping]  # Add the early stopping callback
)

# Start training
trainer.train()

  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbillingsmoore[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/6179000 [00:00<?, ?it/s]

{'loss': 0.6483, 'grad_norm': 6.397555828094482, 'learning_rate': 1.9998381615148082e-05, 'epoch': 0.01}
{'loss': 0.5102, 'grad_norm': 11.611562728881836, 'learning_rate': 1.9996763230296166e-05, 'epoch': 0.02}
{'loss': 0.5057, 'grad_norm': 5.505033016204834, 'learning_rate': 1.9995144845444247e-05, 'epoch': 0.02}
{'loss': 0.4955, 'grad_norm': 5.837775707244873, 'learning_rate': 1.999352646059233e-05, 'epoch': 0.03}
{'loss': 0.481, 'grad_norm': 6.1027750968933105, 'learning_rate': 1.9991908075740415e-05, 'epoch': 0.04}
{'loss': 0.4827, 'grad_norm': 3.4290521144866943, 'learning_rate': 1.9990289690888496e-05, 'epoch': 0.05}
{'loss': 0.477, 'grad_norm': 6.998585224151611, 'learning_rate': 1.9988671306036576e-05, 'epoch': 0.06}
{'loss': 0.4776, 'grad_norm': 4.99961519241333, 'learning_rate': 1.998705292118466e-05, 'epoch': 0.06}
{'loss': 0.4714, 'grad_norm': 8.222979545593262, 'learning_rate': 1.998543453633274e-05, 'epoch': 0.07}
{'loss': 0.4774, 'grad_norm': 1.7104167938232422, 'learnin

  0%|          | 0/10905 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 0.4207659065723419, 'eval_accuracy': 0.850268820285901, 'eval_f1': 0.8411203626221172, 'eval_precision': 0.844679736197194, 'eval_recall': 0.850268820285901, 'eval_runtime': 405.7856, 'eval_samples_per_second': 429.946, 'eval_steps_per_second': 26.874, 'epoch': 1.0}
{'loss': 0.4095, 'grad_norm': 5.7561354637146, 'learning_rate': 1.9799320278362196e-05, 'epoch': 1.0}
{'loss': 0.3882, 'grad_norm': 3.9415371417999268, 'learning_rate': 1.979770189351028e-05, 'epoch': 1.01}
{'loss': 0.3695, 'grad_norm': 2.2126357555389404, 'learning_rate': 1.979608350865836e-05, 'epoch': 1.02}
{'loss': 0.3921, 'grad_norm': 7.058658599853516, 'learning_rate': 1.9794465123806445e-05, 'epoch': 1.03}
{'loss': 0.39, 'grad_norm': 3.2204318046569824, 'learning_rate': 1.9792846738954525e-05, 'epoch': 1.04}
{'loss': 0.3899, 'grad_norm': 5.074982643127441, 'learning_rate': 1.9791228354102606e-05, 'epoch': 1.04}
{'loss': 0.3672, 'grad_norm': 6.697738170623779, 'learning_rate': 1.978960996925069e-05, 'epo

  0%|          | 0/10905 [00:00<?, ?it/s]

{'eval_loss': 0.4310838282108307, 'eval_accuracy': 0.8525615306134147, 'eval_f1': 0.8450336886101782, 'eval_precision': 0.8480871413328013, 'eval_recall': 0.8525615306134147, 'eval_runtime': 406.4293, 'eval_samples_per_second': 429.265, 'eval_steps_per_second': 26.831, 'epoch': 2.0}
{'loss': 0.341, 'grad_norm': 4.577843189239502, 'learning_rate': 1.959864055672439e-05, 'epoch': 2.01}
{'loss': 0.3302, 'grad_norm': 3.8836381435394287, 'learning_rate': 1.9597022171872474e-05, 'epoch': 2.01}
{'loss': 0.3301, 'grad_norm': 5.883388996124268, 'learning_rate': 1.9595403787020555e-05, 'epoch': 2.02}
{'loss': 0.3353, 'grad_norm': 2.8376076221466064, 'learning_rate': 1.9593785402168635e-05, 'epoch': 2.03}
{'loss': 0.3227, 'grad_norm': 4.838154315948486, 'learning_rate': 1.959216701731672e-05, 'epoch': 2.04}
{'loss': 0.328, 'grad_norm': 8.006811141967773, 'learning_rate': 1.95905486324648e-05, 'epoch': 2.05}
{'loss': 0.3237, 'grad_norm': 2.5016961097717285, 'learning_rate': 1.9588930247612884e-05,

  0%|          | 0/10905 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 0.44673487544059753, 'eval_accuracy': 0.8526417754748776, 'eval_f1': 0.8443285599436753, 'eval_precision': 0.848795194497865, 'eval_recall': 0.8526417754748776, 'eval_runtime': 406.4284, 'eval_samples_per_second': 429.266, 'eval_steps_per_second': 26.831, 'epoch': 3.0}
{'loss': 0.3487, 'grad_norm': 5.341777801513672, 'learning_rate': 1.9399579219938504e-05, 'epoch': 3.0}
{'loss': 0.2825, 'grad_norm': 13.408658027648926, 'learning_rate': 1.9397960835086584e-05, 'epoch': 3.01}
{'loss': 0.2755, 'grad_norm': 3.749850273132324, 'learning_rate': 1.9396342450234665e-05, 'epoch': 3.02}
{'loss': 0.2808, 'grad_norm': 6.097711086273193, 'learning_rate': 1.939472406538275e-05, 'epoch': 3.03}
{'loss': 0.28, 'grad_norm': 3.202685594558716, 'learning_rate': 1.9393105680530833e-05, 'epoch': 3.03}
{'loss': 0.2649, 'grad_norm': 16.06536293029785, 'learning_rate': 1.9391487295678914e-05, 'epoch': 3.04}
{'loss': 0.2975, 'grad_norm': 8.506192207336426, 'learning_rate': 1.9389868910826998e-05,

  0%|          | 0/10905 [00:00<?, ?it/s]

{'eval_loss': 0.4988257586956024, 'eval_accuracy': 0.8475232996687033, 'eval_f1': 0.8407941748243749, 'eval_precision': 0.8400164388835006, 'eval_recall': 0.8475232996687033, 'eval_runtime': 406.5037, 'eval_samples_per_second': 429.187, 'eval_steps_per_second': 26.826, 'epoch': 4.0}
{'loss': 0.2566, 'grad_norm': 5.924985408782959, 'learning_rate': 1.9198899498300698e-05, 'epoch': 4.01}
{'loss': 0.2516, 'grad_norm': 0.8387424349784851, 'learning_rate': 1.919728111344878e-05, 'epoch': 4.01}
{'loss': 0.2337, 'grad_norm': 1.2511589527130127, 'learning_rate': 1.9195662728596863e-05, 'epoch': 4.02}
{'loss': 0.2427, 'grad_norm': 26.29286003112793, 'learning_rate': 1.9194044343744943e-05, 'epoch': 4.03}
{'loss': 0.2392, 'grad_norm': 2.1782076358795166, 'learning_rate': 1.9192425958893027e-05, 'epoch': 4.04}
{'loss': 0.2484, 'grad_norm': 10.332441329956055, 'learning_rate': 1.9190807574041108e-05, 'epoch': 4.05}
{'loss': 0.2576, 'grad_norm': 15.544615745544434, 'learning_rate': 1.91891891891891

  0%|          | 0/10905 [00:00<?, ?it/s]

{'eval_loss': 0.5627232193946838, 'eval_accuracy': 0.8459986473009068, 'eval_f1': 0.8391218228897186, 'eval_precision': 0.8415472118425453, 'eval_recall': 0.8459986473009068, 'eval_runtime': 406.5015, 'eval_samples_per_second': 429.189, 'eval_steps_per_second': 26.826, 'epoch': 5.0}
{'loss': 0.2723, 'grad_norm': 19.75212860107422, 'learning_rate': 1.899983816151481e-05, 'epoch': 5.0}
{'loss': 0.2115, 'grad_norm': 2.251190423965454, 'learning_rate': 1.8998219776662892e-05, 'epoch': 5.01}
{'loss': 0.2325, 'grad_norm': 19.01184844970703, 'learning_rate': 1.8996601391810973e-05, 'epoch': 5.02}
{'loss': 0.223, 'grad_norm': 6.405007362365723, 'learning_rate': 1.8994983006959057e-05, 'epoch': 5.03}
{'loss': 0.2145, 'grad_norm': 5.565252304077148, 'learning_rate': 1.8993364622107138e-05, 'epoch': 5.03}
{'loss': 0.2129, 'grad_norm': 19.0717716217041, 'learning_rate': 1.8991746237255222e-05, 'epoch': 5.04}
{'loss': 0.2218, 'grad_norm': 0.9892017245292664, 'learning_rate': 1.8990127852403302e-05,

KeyboardInterrupt: 