In [18]:
import os
import sys
import configparser
from dataclasses import dataclass

import datasets

from nltk import sent_tokenize
from datasets import load_from_disk
from tokenizers import Tokenizer
from transformers import BertModel, AutoTokenizer, AutoModel, BertTokenizerFast
from torch.utils.data import DataLoader

# Loading

In [4]:
file_name = "../tokenizers/tokenizer_2_1000.json"

base_tokenizer = Tokenizer.from_file(file_name)
tokenizer = BertTokenizerFast(tokenizer_object=base_tokenizer)

In [7]:
root_dir = r"C:\Users\onurg\.cache\huggingface\datasets"
data_dir = "updated_wiki40b"

path = os.path.join(root_dir, data_dir, "long_small_dataset")

In [8]:
dataset = load_from_disk(path)
dataset

Dataset({
    features: ['article_1', 'article_2', 'pair', 'wikidata_id'],
    num_rows: 1558009
})

In [9]:
config = configparser.ConfigParser()
config.read('config.ini')

[]

# Prototyping

In [10]:
small_dataset = dataset.select(range(100))

In [47]:
@dataclass
class Config:
    def __init__(self, max_sentence_len, max_document_len, per_device_train_batch_size):
      self.max_sentence_len = max_sentence_len
      self.max_document_len = max_document_len
      self.per_device_train_batch_size = per_device_train_batch_size

args = Config(128,32, 8)

In [12]:
def tokenize(example, tokenizer, args):
    # https://github.com/castorini/hedwig/blob/master/datasets/bert_processors/abstract_processor.py
    # https://github.com/abhishekkrthakur/bert-entity-extraction/blob/master/src/dataset.py

    def tokenize_helper(article, tokenizer, args):
        # TODO: comment
        sentences = [tokenizer.encode(sentence, add_special_tokens=False) for sentence in sent_tokenize(article)] 
        sentences = [sentence[:args.max_sentence_len - 2] for sentence in sentences]
        sentences = [[tokenizer.convert_tokens_to_ids("[CLS]")] + sentence + [tokenizer.convert_tokens_to_ids("[SEP]")] for sentence in sentences]

        sentence_lengths = [len(sentence) for sentence in sentences]
        mask = [[1]*sen_len for sen_len in sentence_lengths]

        return sentences, mask
 
    for i in range(1, 3):
        
        # example[f"article_{i}"] = [tokenizer.encode(sentence, 
        #                                             truncation=True,
        #                                             add_special_tokens=True,                                            
        #                                             max_length=args.max_sentence_len) for sentence in sent_tokenize(example[f"article_{i}"])]
                                                
        #print(example)

        example[f"article_{i}"], example[f"mask_{i}"] = tokenize_helper(example[f"article_{i}"], tokenizer, args)
        
    return example

In [14]:
new_dataset = small_dataset.map(tokenize, fn_kwargs={"tokenizer": tokenizer, "args": args})

100%|██████████| 100/100 [00:02<00:00, 34.85ex/s]


In [15]:
new_dataset

Dataset({
    features: ['article_1', 'article_2', 'pair', 'wikidata_id', 'mask_1', 'mask_2'],
    num_rows: 100
})

In [44]:
@dataclass
class CustomDataCollator:
    tokenizer: None
    max_sentence_len: int = 128
    max_document_len: int = 32
    return_tensors: str = "pt"
    
    def __call__(self, features: list) -> dict:
        print(len(features))



In [45]:
data_collator = CustomDataCollator(tokenizer=tokenizer, 
                                   max_sentence_len=args.max_sentence_len, 
                                   max_document_len=args.max_doc_len)

train_dataloader = DataLoader(
    new_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
)

In [46]:
for i in train_dataloader:
    print(i)

8
None
8
None
8
None
8
None
8
None
8
None
8
None
8
None
8
None
8
None
8
None
8
None
4
None
