In [None]:
from transformers import AutoTokenizer
import json, os
from datasets import Dataset

# Fast tokenizer for the mmBERT base checkpoint; tokenize_batch uses this
# module-level instance.
tokenizer = AutoTokenizer.from_pretrained(
    "jhu-clsp/mmBERT-base",
    use_fast=True,
)

def build_sentence_pair_dataset(folder):
    """Build a sentence-pair classification dataset from one data folder.

    Every ``<problem_id>.txt`` file in *folder* is read line by line. Its
    non-blank, stripped lines are treated as sentences. The matching
    ``truth-<problem_id>.json`` file supplies a ``"changes"`` list holding
    one label per boundary between consecutive sentences. Boundary ``i``
    yields the sample ``(sentences[i], sentences[i + 1])`` with label
    ``changes[i]``.

    Args:
        folder: Directory containing the ``.txt`` problem files and their
            ``truth-<problem_id>.json`` label files.

    Returns:
        A ``datasets.Dataset`` with columns ``"text_pair"`` (two-sentence
        pairs) and ``"label"`` (the corresponding ``changes`` entries).

    Raises:
        FileNotFoundError: If a ``.txt`` file has no matching truth file.
        ValueError: If a truth file's ``"changes"`` list does not have exactly
            one entry per sentence boundary. Labels and pairs could not be
            aligned safely in that case.
    """
    texts = []
    labels = []

    # Sorted so the sample order is deterministic regardless of OS listing order.
    txt_files = sorted(f for f in os.listdir(folder) if f.endswith(".txt"))

    for fname in txt_files:
        problem_id = os.path.splitext(fname)[0]

        with open(os.path.join(folder, fname), "r", encoding="utf8") as f:
            # Blank lines are dropped; each remaining stripped line is a sentence.
            sentences = [line.strip() for line in f if line.strip()]

        truth_path = os.path.join(folder, "truth-" + problem_id + ".json")
        # Use a context manager so the truth file is closed deterministically.
        with open(truth_path, "r", encoding="utf8") as f:
            truth = json.load(f)
        changes = truth["changes"]

        # Expected: exactly one label per boundary between consecutive sentences.
        # A mismatch means labels would be attached to the wrong pairs, so fail
        # loudly with context instead of mislabeling or raising a bare IndexError.
        expected = max(len(sentences) - 1, 0)
        if len(changes) != expected:
            raise ValueError(
                f"{truth_path}: expected {expected} change labels for "
                f"{len(sentences)} sentences, got {len(changes)}"
            )

        # Pair S_i with S_(i+1) and attach the label for that boundary.
        for first, second, label in zip(sentences, sentences[1:], changes):
            texts.append((first, second))
            labels.append(label)

    return Dataset.from_dict({"text_pair": texts, "label": labels})

def tokenize_batch(batch):
    """Tokenize a batch of sentence pairs with the module-level tokenizer.

    Intended for ``Dataset.map(..., batched=True)``: ``batch["text_pair"]``
    is a list of two-element pairs. The first elements are encoded as the
    first segment and the second elements as the second segment. Sequences
    are padded within the batch and truncated to 128 tokens.

    Returns:
        Whatever the tokenizer returns for the pair batch (its encoded fields).
    """
    pairs = batch["text_pair"]
    first_segments = [pair[0] for pair in pairs]
    second_segments = [pair[1] for pair in pairs]

    encode_options = dict(padding=True, truncation=True, max_length=128)
    return tokenizer(first_segments, second_segments, **encode_options)

from datasets import concatenate_datasets

# Difficulty tiers, in the order their datasets are built and concatenated.
_LEVELS = ("easy", "medium", "hard")

# Raw (untokenized) sentence-pair datasets, one per tier and split.
_train_parts = [build_sentence_pair_dataset(f"Data/{level}/train") for level in _LEVELS]
_valid_parts = [build_sentence_pair_dataset(f"Data/{level}/validation") for level in _LEVELS]
train_easy, train_med, train_hard = _train_parts
valid_easy, valid_med, valid_hard = _valid_parts

# Pool all tiers into a single training set and a single evaluation set.
train_dataset = concatenate_datasets(_train_parts)
eval_dataset = concatenate_datasets(_valid_parts)


def _tokenize_pairs(dataset):
    """Tokenize every pair in *dataset* and drop the raw ``text_pair`` column."""
    return dataset.map(
        tokenize_batch,
        batched=True,
        remove_columns=["text_pair"],
    )


train_dataset_tokenized = _tokenize_pairs(train_dataset)
eval_dataset_tokenized = _tokenize_pairs(eval_dataset)

# Write both tokenized splits to disk.
train_dataset_tokenized.save_to_disk("Data/tokenized_train_dataset")
eval_dataset_tokenized.save_to_disk("Data/tokenized_eval_dataset")

In [None]:
# Sanity check: show the first tokenized example of the train and eval splits.
for _tokenized_split in (train_dataset_tokenized, eval_dataset_tokenized):
    print(_tokenized_split[0])