In [1]:
from datasets import load_dataset
import os
import torch
from transformers import BertTokenizer

TEST_FOLDER_PATH = "data/test_data_processed"
TRAIN_FOLDER_PATH = "data/train_data_processed"

def load_test_data():
    """
    Helper function to load test data into a dictionary
    Pass **output_dict to BERT model
    """
    input_ids = torch.load(os.path.join(TEST_FOLDER_PATH, 'test_input_ids.pt'))
    attention_mask = torch.load(os.path.join(TEST_FOLDER_PATH, 'test_attention_mask.pt'))
    token_type_ids = torch.load(os.path.join(TEST_FOLDER_PATH, 'test_token_type_ids.pt'))
    # labels = torch.load(os.path.join(TEST_FOLDER_PATH, 'test_labels.pt'))
    
    return {
        'input_ids':input_ids, 
        'attention_mask': attention_mask, 
        'token_type_ids': token_type_ids, 
        # 'labels': labels
    }

def load_test_target_words():
    """ Helper function to load ground truth target words for testing """
    with open(os.path.join(TEST_FOLDER_PATH, 'test_last_words.txt'), 'r') as f:
        target_words = f.read().splitlines()
    return target_words

def load_train_eval_data(train_type, train_prop=.75):
    """
    Helper function to load and split train/eval data as TensorDatasets

    In the datasets, 
        1st col = `input_ids` (to be masked by a data_collator)
        2nd col = `token_type_ids`
        3rd col = `attention_mask`
        4th col = `labels` (ground truths)
    
    Can be passed as values for transformers.Trainer(train_dataset=xxx, val_dataset=xxx).

    Args:
        train_type (str): Type of training (determines relevant data file name).
        train_prop (float): Proportion of training data out of the entire dataset.
                            Val data will take up (1-`train_prop`) of the original dataset. 
    Returns:
        train_set (Dataset): Training set in Hugging Face dataset format
        val_set (Dataset): Training set in Hugging Face dataset format
    """
    # Load processed "prev verse + next verse" texts
    full_texts = load_dataset("text", data_files=os.path.join(TRAIN_FOLDER_PATH, f"{train_type}_train_lines.txt"))

    # Tokenize the processed texts
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenized_texts = full_texts.map(
        lambda x: tokenizer(x["text"], max_length=60, padding='max_length', truncation=True), 
        batched=True, 
        remove_columns=["text"]
    )

    # Duplicate the `input_ids` column (will be masked) to create `labels` (ground truths)
    def add_labels(x):
        x['labels'] = x['input_ids'].copy()
        return x
    labeled_texts = tokenized_texts.map(add_labels, batched=True)

    # Make the dataset compatible with torch formats (i.e., convert to tensors)
    torch_texts = labeled_texts.with_format("torch")

    # Do train-val split and return
    split_texts = torch_texts['train'].train_test_split(train_size=train_prop)
    return split_texts['train'], split_texts['test']



In [4]:
import torch
from transformers import BertTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments

# Replace these lists with your actual dataset
# sentences = ["This is an example sentence from my domain.", "Another example sentence from my domain."]
from utils import load_train_eval_data

# Among the original training data, 75% is used for training, 25% is used for validation
train_prop = 0.75

# Set the fine-tuning type 
train_type = "sonnets" # One of ["sonnets", "shakestrain", "poems"]

train_set, val_set = load_train_eval_data(train_type, train_prop)

# Load the tokenizer and the pre-trained model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Set up the data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# Set up the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    logging_dir="./logs",
    logging_steps=1,
    save_steps=100,
    seed=42,
    learning_rate=2e-5,
    weight_decay=0.01,
)


# Create the Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=data_collator,
)

# Fine-tune the model
trainer.train()

torch.save(model, "model.pth")

Found cached dataset text (C:/Users/Shaobo Liang/.cache/huggingface/datasets/text/default-f8e57489f1a94737/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\Shaobo Liang\.cache\huggingface\datasets\text\default-f8e57489f1a94737\0.0.0\cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2\cache-937f1637ed85bf6f.arrow
Loading cached processed dataset at C:\Users\Shaobo Liang\.cache\huggingface\datasets\text\default-f8e57489f1a94737\0.0.0\cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2\cache-99df46b8f2a33063.arrow
Loading cached split indices for dataset at C:\Users\Shaobo Liang\.cache\huggingface\datasets\text\default-f8e57489f1a94737\0.0.0\cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2\cache-fa80f177049e6cb9.arrow and C:\Users\Shaobo Liang\.cache\huggingface\datasets\text\default-f8e57489f1a94737\0.0.0\cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2\cache-ea871f09751aea53.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.w

Epoch,Training Loss,Validation Loss
1,4.5849,3.717442
2,3.4951,3.563942
3,3.7313,3.527672


In [9]:
output_dir = "./saved_model"
model.save_pretrained(output_dir)

In [14]:
from test import test_main

# model_path = "model.pth"
model_path = "./results/checkpoint-1000/"
# model_path = "./saved_model/"

k = 5 # I want top 5 test metrics

results_dict = test_main(model_path, k)

print(f"Top {k} accuracy is {results_dict['accuracy']}.")
print(f"Top {k} cosine similarity score is {results_dict['cos_sim']}.")
# print(f"Top {k} rhyming score is {results_dict["rhyme"]}.") # To be implemented

Top 5 accuracy is 0.13941018766756033.
Top 5 cosine similarity score is 0.42201218008995056.
