In [70]:
import os
import sys
import configparser
from dataclasses import dataclass

import datasets
import torch

from nltk import sent_tokenize
from datasets import load_from_disk
from tokenizers import Tokenizer
from transformers import BertModel, AutoTokenizer, AutoModel, BertTokenizerFast
from torch.utils.data import DataLoader

# Loading

In [2]:
file_name = "../tokenizers/tokenizer_2_1000.json"

base_tokenizer = Tokenizer.from_file(file_name)
tokenizer = BertTokenizerFast(tokenizer_object=base_tokenizer)

In [6]:
config = configparser.ConfigParser()
config.read('../config.ini')

root_dir = r"C:\Users\onurg\.cache\huggingface\datasets"
data_dir = "updated_wiki40b"

path = os.path.join(root_dir, data_dir, "long_small_dataset")

dataset = load_from_disk(path)
dataset

# Prototyping

In [12]:
small_dataset = dataset.select(range(100))

In [138]:
@dataclass
class Config:
    def __init__(self, max_sentence_len, max_document_len, per_device_train_batch_size):
      self.max_sentence_len = max_sentence_len
      self.max_document_len = max_document_len
      self.per_device_train_batch_size = per_device_train_batch_size

args = Config(128,64, 8)

In [141]:
def tokenize(example, tokenizer, args):
    # https://github.com/castorini/hedwig/blob/master/datasets/bert_processors/abstract_processor.py
    # https://github.com/abhishekkrthakur/bert-entity-extraction/blob/master/src/dataset.py

    def tokenize_helper(article, tokenizer, args):
        # TODO: comment
        sentences = [tokenizer.encode(sentence, add_special_tokens=False) for sentence in sent_tokenize(article)] 
        sentences = [sentence[:args.max_sentence_len - 2] for sentence in sentences]
        sentences = [[tokenizer.convert_tokens_to_ids("[CLS]")] + sentence + [tokenizer.convert_tokens_to_ids("[SEP]")] for sentence in sentences]

        sentence_lengths = [len(sentence) for sentence in sentences]
        # TODO: check for attention_mask ID
        mask = [[1]*sen_len for sen_len in sentence_lengths]

        return sentences, mask
 
    # TODO: make "end" dynamic
    for i in range(1, 3):
        
        # example[f"article_{i}"] = [tokenizer.encode(sentence, 
        #                                             truncation=True,
        #                                             add_special_tokens=True,                                            
        #                                             max_length=args.max_sentence_len) for sentence in sent_tokenize(example[f"article_{i}"])]

        example[f"article_{i}"], example[f"mask_{i}"] = tokenize_helper(example[f"article_{i}"], tokenizer, args)
        
    return example

In [142]:
new_dataset = small_dataset.map(tokenize, fn_kwargs={"tokenizer": tokenizer, "args": args})

100%|██████████| 100/100 [00:02<00:00, 44.69ex/s]


## Data Colator

In [170]:
@dataclass
class CustomDataCollator:
    """ A data collator which can be used for dynamic padding, when each instance of a batch is a 
    list of lists. Each sentence is a list and each document (instance of a batch) contains multiple 
    sentences.
    """
    tokenizer: None
    max_sentence_len: int = 128
    max_document_len: int = 32
    return_tensors: str = "pt"
    

    def __call__(self, features: list) -> dict:
        batch = {}

        # TODO: make article number dynamic
        for article_number in range(1, 3):
            batch_sentences = list()
            batch_masks = list()
            
            sen_len_article = [len(sentence) for instance in features for sentence in instance[f"article_{article_number}"]]
            sen_len_mask = [len(sentence) for instance in features for sentence in instance[f"mask_{article_number}"]]
            
            assert sen_len_article == sen_len_mask, (
                f"There is a mismatch for article_{article_number} and mask_{article_number}."
                )
            
            sen_len = min(self.max_sentence_len, max(sen_len_article))
            
            doc_len_article = [len(instance[f"mask_{article_number}"]) for instance in features]
            doc_len = min(self.max_document_len, max(doc_len_article))

            for id in range(len(features)):
                sentences, masks = self.pad_sentence(sen_len, features[id], article_number)
                self.pad_document(sentences, masks, doc_len)

                batch_sentences.append(sentences)
                batch_masks.append(masks)
      
            # TODO: decide on dtype for tensor, torch.int/torch.long?
            batch[f"article_1": torch.tensor(batch_sentences, dtype=torch.int64)]
            batch[f"mask_1": torch.tensor(batch_masks, dtype=torch.int64)]
        
        return batch

    def pad_sentence(self, sen_len: int, feature: dict, i: int) -> tuple():
        sentences = [sentence + [self.tokenizer.convert_tokens_to_ids("[PAD]")] * (sen_len - len(sentence))  for sentence in feature[f"article_{i}"]]
        # TODO: check for attention_mask ID
        masks = [sentence + [0] * (sen_len - len(sentence))  for sentence in feature[f"mask_{i}"]]

        return sentences, masks

    def pad_document(self, sentences: list, masks: list, doc_len: int):
        mask_padding_array = [0 for i0 in range(len(masks[0]))]
        sentence_padding_array = [self.tokenizer.convert_tokens_to_ids("[PAD]") for i0 in range(len(sentences[0]))]

        if len(sentences) < doc_len:
            sentences += [sentence_padding_array for difference in range(doc_len - len(sentences))]
            masks += [mask_padding_array for difference in range(doc_len - len(masks))]
        elif len(sentences) > doc_len:
            sentences = sentences[: doc_len]
            masks = masks[: doc_len]

In [171]:
data_collator = CustomDataCollator(tokenizer=tokenizer, 
                                   max_sentence_len=args.max_sentence_len, 
                                   max_document_len=args.max_document_len)

train_dataloader = DataLoader(
    new_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
)

In [178]:
for i in train_dataloader:
    print(i)

36
36
64
64


SystemExit: 