In [None]:
! pip install datasets transformers

In [None]:
from datasets import load_dataset

import glob
import pickle
import re 
from termcolor import colored
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import torch
import math


from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

In [None]:
# Mount Google Drive (Colab) so the data and model files referenced below are reachable.
# NOTE(review): the '...' path prefixes used later presumably resolve under this mount — confirm.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

## Run 1: 20% validation split, 3 epochs

In [None]:
# Run-1 data files (20% validation split).
# NOTE(review): val_path points at a file named "*train_verse*" and train_path at "*val_verse*".
# Confirm this mapping is intended and not a swap. If it is a swap, the model trains on the
# validation file and is evaluated on the training file.
val_path, train_path = (
    '.../Data/all_poetry_train_verse.csv',
    '.../Data/all_poetry_val_verse.csv',
)

In [None]:
# Load both CSV files as the 'train' and 'test' splits; the text is read from the 'poetry' column.
poetry_files = {'train': train_path, 'test': val_path}
dataset_poetry = load_dataset('csv', data_files=poetry_files)

In [None]:
# Base checkpoint from the Hugging Face Hub, loaded with a masked-LM head.
# NOTE(review): the next cell rebuilds both `model` and `tokenizer` from this same checkpoint,
# so this cell's results are overwritten if both cells are run.
model_checkpoint = 'HooshvareLab/bert-fa-zwnj-base'

model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [None]:
# Rebuild the architecture from the public checkpoint, then overwrite its weights with a
# state_dict saved after an earlier training run ("BERT_0.15_Verse").
model_checkpoint = 'HooshvareLab/bert-fa-zwnj-base'
ModelPath = '.../Pretrained Models/Pretrained on verses/BERT_0.15_Verse/BERT_model_all_poems.hpt'
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
# Deserialize onto the CPU. `model` was just created on the CPU, and load_state_dict copies the
# tensors into its existing parameters anyway. map_location='cuda' only added a GPU round-trip
# and raised on CPU-only runtimes. Trainer moves the model to the available device later.
# NOTE: torch.load unpickles the file, so only load checkpoints you produced yourself.
model.load_state_dict(torch.load(ModelPath, map_location='cpu'))
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [None]:
# Run-1 training configuration. num_train_epochs keeps the library default (3).
training_args = TrainingArguments(
    "test-clm",  # output_dir. NOTE(review): run 2 reuses this directory for its checkpoints.
    # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword (transformers >= 4.41).
    eval_strategy="epoch",
    # load_best_model_at_end requires the save strategy to match the eval strategy. With the
    # default save_strategy ("steps"), TrainingArguments raises ValueError.
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
)

In [None]:
def tokenize_function(examples):
    """Tokenize one batch of rows for `Dataset.map(batched=True)`.

    Args:
        examples: batch mapping from the datasets library. Its 'poetry' entry holds the texts.

    Returns:
        The encoding produced by the global `tokenizer` (e.g. input_ids).

    Sequences longer than the tokenizer's maximum length are truncated. Without this, an
    over-long verse would index BERT's position embeddings out of range during training.
    """
    return tokenizer(examples['poetry'], truncation=True)


In [None]:
# Tokenize both splits in batches of 512 rows, spread over 4 worker processes.
tokenized_datasets = dataset_poetry.map(
    tokenize_function,
    batch_size=512,
    batched=True,
    num_proc=4,
)

# Peek at one tokenized training example.
tokenized_datasets["train"][1]

In [None]:
# Inspect the tokenized 'train'/'test' splits (last expression, so it is rendered as the cell output).
tokenized_datasets

In [None]:
# Masked-LM collator: each batch has ~15% of its tokens selected for masking.
MLM_PROBABILITY = 0.15
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=MLM_PROBABILITY,
)

In [None]:
# Wire the model, run-1 arguments, MLM collator and the two tokenized splits into a Trainer.
train_split = tokenized_datasets["train"]
eval_split = tokenized_datasets["test"]
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_split,
    eval_dataset=eval_split,
)

In [None]:
# Fine-tune `model` on the 'train' split. The TrainingArguments above evaluate on 'test' after every epoch.
trainer.train()

In [None]:
# Score the model on the 'test' split. Because load_best_model_at_end=True, this is the best
# checkpoint from training. eval_results['eval_loss'] is turned into a perplexity further down.
eval_results = trainer.evaluate()

In [None]:
# Persist only the weights (state_dict). Reload them with load_state_dict into a fresh
# AutoModelForMaskedLM built from the same checkpoint.
# NOTE(review): this path differs textually from the ModelPath read when loading weights — confirm
# which location is intended.
ModelPath = '.../BERT_model_all_poems.hpt'
torch.save(obj=model.state_dict(), f=ModelPath)

In [None]:
# Perplexity = exp(evaluation loss). For MLM this is the loss over the masked tokens only.
# `math` comes from the imports cell at the top of the notebook.
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

## Run 2: 10% validation split, 5 epochs


In [None]:
# Run-2 data files (10% validation split).
# NOTE(review): the "train_verse"/"val_verse" prefixes look inverted relative to the variables,
# but the 10p_Val / 90p_Train suffixes agree with them. Confirm the intended mapping.
val_path, train_path = (
    '.../Data/all_poetry_train_verse_10p_Val.csv',
    '.../Data/all_poetry_val_verse_90p_Train.csv',
)

In [None]:
# Reload the 'train'/'test' splits from the run-2 CSV files (replaces the run-1 dataset).
poetry_files = {'train': train_path, 'test': val_path}
dataset_poetry = load_dataset('csv', data_files=poetry_files)

In [None]:
# Run-2 training configuration: same learning rate and weight decay as run 1, but 5 epochs.
training_args = TrainingArguments(
    "test-clm",  # output_dir. NOTE(review): shared with run 1, so both runs write checkpoints here.
    # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword (transformers >= 4.41).
    eval_strategy="epoch",
    # load_best_model_at_end requires the save strategy to match the eval strategy. With the
    # default save_strategy ("steps"), TrainingArguments raises ValueError.
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    num_train_epochs=5,
)

In [None]:
def tokenize_function(examples):
    """Tokenize one batch of rows for `Dataset.map(batched=True)`.

    Args:
        examples: batch mapping from the datasets library. Its 'poetry' entry holds the texts.

    Returns:
        The encoding produced by the global `tokenizer` (e.g. input_ids).

    Over-long sequences are truncated to the tokenizer's maximum length, so BERT's position
    embeddings are never indexed out of range.
    NOTE(review): this redefines the run-1 tokenize_function; keep the two definitions in sync.
    """
    return tokenizer(examples['poetry'], truncation=True)


# Tokenize the run-2 splits (batches of 128 rows, 4 worker processes).
tokenized_datasets = dataset_poetry.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    batch_size=128)


# Peek at one tokenized training example.
tokenized_datasets["train"][1]

In [None]:
# MLM collator for run 2, using the same 15% masking probability as run 1.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

# NOTE(review): `model` is the same object already trained in run 1, so this run continues
# from those weights. If run 1's training file overlaps this run's validation file, the
# resulting perplexity is optimistic.
train_split = tokenized_datasets["train"]
eval_split = tokenized_datasets["test"]
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_split,
    eval_dataset=eval_split,
)

In [None]:
# Run 2: keep fine-tuning `model` (not reloaded since run 1) on the run-2 'train' split.
trainer.train()

In [None]:
# Save tokenizer and model together in Hugging Face format, so the directory can be reloaded with
# from_pretrained.
RUN2_OUTPUT_DIR = '.../Pretrained Models/Pretrained on verses/BERT_0.15_Verse_5epochs/'
for artifact in (tokenizer, model):
    artifact.save_pretrained(RUN2_OUTPUT_DIR)

In [None]:
# Evaluate run 2 on its 'test' split (loaded from val_path). The result feeds the perplexity printed below.
eval_results = trainer.evaluate()

In [None]:
# Show every metric returned by trainer.evaluate(), including eval_loss.
eval_results

In [None]:
# Run-2 perplexity = exp(evaluation loss). For MLM this is the loss over the masked tokens only.
# `math` comes from the imports cell at the top of the notebook.
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")