# Further Pre-training Models on Transliterated Bangla

This sample notebook further pre-trains mBERT, but mBERT can be replaced by any pre-trained masked language model from HuggingFace. To further pre-train a different model, modify the `model_name` and `tokenizer` accordingly.

## Import Necessary Libraries

In [1]:
import torch
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

from transformers import (AutoModel,AutoModelForMaskedLM,
                          AutoTokenizer, LineByLineTextDataset,
                          DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)

## Load models and tokenizers from HuggingFace

In [None]:
# Checkpoint to further pre-train. The model and the tokenizer are both loaded
# from this single name so they cannot drift apart when it is changed.
model_name = 'google-bert/bert-base-multilingual-cased'
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Store the tokenizer files in ./mbert; the trailing semicolon only suppresses
# the notebook's echo of the returned value.
tokenizer.save_pretrained('./mbert');

## Load the Dataset

In [None]:
# Input text files and the per-example token limit, gathered in one place.
train_file = "/kaggle/input/banglish/banglish.txt"  # mention train text file here
valid_file = "/kaggle/input/banglish/banglish.txt"  # mention valid text file here
block_size = 256

# Build one line-by-line dataset per split with the shared tokenizer.
train_dataset = LineByLineTextDataset(
    tokenizer=tokenizer, file_path=train_file, block_size=block_size
)
valid_dataset = LineByLineTextDataset(
    tokenizer=tokenizer, file_path=valid_file, block_size=block_size
)

# Collator for the masked-language-modelling objective (mlm=True) with a
# masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

## Configure the Trainer

In [None]:
# Training configuration.
#
# Evaluation is switched off for this run and no eval_dataset is given to the
# Trainer, so there is no metric from which a "best" checkpoint could be
# chosen. `load_best_model_at_end` is therefore False: with it set to True,
# TrainingArguments raises a ValueError because the evaluation strategy ("no")
# does not match the save strategy (default "steps"). `greater_is_better` is
# dropped as well, since it only matters when a best checkpoint is selected.
#
# To pick the best checkpoint instead, pass `eval_dataset=...` to Trainer, use
# the same evaluation and save strategy (e.g. "epoch"), and re-enable
# `load_best_model_at_end`.
training_args = TrainingArguments(
    output_dir="./mBERTcheckPoint",  # select model path for checkpoint
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # NOTE(review): newer transformers releases name this `eval_strategy`;
    # confirm against the installed version.
    evaluation_strategy="no",
    do_eval=False,
    save_total_limit=1,  # keep only the most recent checkpoint on disk
    load_best_model_at_end=False,
    prediction_loss_only=True,
    report_to="none",  # no external experiment-tracking integrations
)

# Trainer for the masked-LM objective; only a training set is supplied.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

## Start Further Pre-training

In [None]:
# Run the further pre-training loop, then save the resulting model to ./mbert
# (the directory the tokenizer was saved to when it was loaded).
trainer.train()
trainer.save_model('./mbert')