In [3]:
from dataclasses import dataclass
import os
from pathlib import Path
import re

import numpy as np

from datasets import load_dataset
from transformers import AutoTokenizer, PreTrainedTokenizer

## Dataset loading

In [4]:
# Root directory for downloaded datasets. Path.home() also resolves the home
# directory when the HOME environment variable is not set, whereas
# os.path.expandvars('$HOME') would leave the literal text '$HOME' in the path.
DATA_PATH = Path.home() / 'data'

In [None]:
# Load the 'wikipedia' dataset, config '20220301.en', keeping downloaded files under DATA_PATH.
# (An earlier variant of this cell used config '20200501.en' with beam_runner='DirectRunner'.)
ds = load_dataset(
    path='wikipedia',
    name='20220301.en',
    cache_dir=str(DATA_PATH),
)

Downloading data:   0%|          | 0/41 [00:00<?, ?files/s]

train-00000-of-00041.parquet:   1%|          | 10.5M/1.05G [00:00<?, ?B/s]

In [4]:
list(ds.keys())

['train']

In [5]:
dst = ds['train']

In [6]:
len(dst)

6078422

In [7]:
# Show the last record. The index is derived from the dataset size rather than
# hard-coded (6078421), so the cell keeps working if the number of rows changes.
dst[len(dst) - 1]

{'title': 'Overload (novel)',
 'text': 'Overload (1979) is a novel by Arthur Hailey, concerning the electricity production industry in California and the activities of the employees and others involved with Golden State Power and Light, a fictional California public service company. The plot follows many of the issues of the day, including race relations, corporate politics, business ethics, terrorism and journalism. (Hailey would later explore (television) journalism in another novel, The Evening News.)\n\nPlot Synopsis\nThe novel is described from the point of view of vice-president of Golden State Power and Light, Nimrod "Nim" Goldman, who, despite being married, tends to be somewhat of a Lothario and has many extramarital affairs. The geographic area of service of the fictional electric utility, Golden State Power and Light, matches the actual Northern California footprint of the real-life Pacific Gas and Electric Company.\n\nGolden State Power and Light is a public utility, supply

## Tokenization

In [11]:
pretrained_model_name = 'bert-base-uncased'
tkz = AutoTokenizer.from_pretrained(pretrained_model_name)
tkz

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [13]:
item = dst[10]
item

{'title': 'Ollombo District',
 'text': 'Ollombo is a district in the Plateaux  Department of Republic of the Congo.\n\nReferences \n\nCategory:Plateaux Department (Republic of the Congo)\nCategory:Districts of the Republic of the Congo'}

In [14]:
# Pull the two fields out of the sample record and show its wordpiece tokens.
title = item['title']
text = item['text']
print(f'{title}. Text len: {len(text)}')
toks = tkz.tokenize(text)
print(toks)

Ollombo District. Text len: 190
['ol', '##lom', '##bo', 'is', 'a', 'district', 'in', 'the', 'plateau', '##x', 'department', 'of', 'republic', 'of', 'the', 'congo', '.', 'references', 'category', ':', 'plateau', '##x', 'department', '(', 'republic', 'of', 'the', 'congo', ')', 'category', ':', 'districts', 'of', 'the', 'republic', 'of', 'the', 'congo']


In [26]:
res1 = tkz(text, is_split_into_words=False)
print(res1)

{'input_ids': [101, 19330, 21297, 5092, 2003, 1037, 2212, 1999, 1996, 9814, 2595, 2533, 1997, 3072, 1997, 1996, 9030, 1012, 7604, 4696, 1024, 9814, 2595, 2533, 1006, 3072, 1997, 1996, 9030, 1007, 4696, 1024, 4733, 1997, 1996, 3072, 1997, 1996, 9030, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [27]:
res2 = tkz(text.split(), is_split_into_words=True)
print(res2)

{'input_ids': [101, 19330, 21297, 5092, 2003, 1037, 2212, 1999, 1996, 9814, 2595, 2533, 1997, 3072, 1997, 1996, 9030, 1012, 7604, 4696, 1024, 9814, 2595, 2533, 1006, 3072, 1997, 1996, 9030, 1007, 4696, 1024, 4733, 1997, 1996, 3072, 1997, 1996, 9030, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [28]:
res1 == res2

True

In [37]:
res3 = tkz(text.split(), add_special_tokens=False)
print(text.split())
res3


['Ollombo', 'is', 'a', 'district', 'in', 'the', 'Plateaux', 'Department', 'of', 'Republic', 'of', 'the', 'Congo.', 'References', 'Category:Plateaux', 'Department', '(Republic', 'of', 'the', 'Congo)', 'Category:Districts', 'of', 'the', 'Republic', 'of', 'the', 'Congo']


{'input_ids': [[19330, 21297, 5092], [2003], [1037], [2212], [1999], [1996], [9814, 2595], [2533], [1997], [3072], [1997], [1996], [9030, 1012], [7604], [4696, 1024, 9814, 2595], [2533], [1006, 3072], [1997], [1996], [9030, 1007], [4696, 1024, 4733], [1997], [1996], [3072], [1997], [1996], [9030]], 'token_type_ids': [[0, 0, 0], [0], [0], [0], [0], [0], [0, 0], [0], [0], [0], [0], [0], [0, 0], [0], [0, 0, 0, 0], [0], [0, 0], [0], [0], [0, 0], [0, 0, 0], [0], [0], [0], [0], [0], [0]], 'attention_mask': [[1, 1, 1], [1], [1], [1], [1], [1], [1, 1], [1], [1], [1], [1], [1], [1, 1], [1], [1, 1, 1, 1], [1], [1, 1], [1], [1], [1, 1], [1, 1, 1], [1], [1], [1], [1], [1], [1]]}

In [34]:
words3 = tkz.batch_decode(res3.input_ids)
words3

['ollombo',
 'is',
 'a',
 'district',
 'in',
 'the',
 'plateaux',
 'department',
 'of',
 'republic',
 'of',
 'the',
 'congo.',
 'references',
 'category : plateaux',
 'department',
 '( republic',
 'of',
 'the',
 'congo )',
 'category : districts',
 'of',
 'the',
 'republic',
 'of',
 'the',
 'congo']

In [36]:
s1 = tkz.decode(res1.input_ids, skip_special_tokens=True)
s2 = ' '.join(words3)
print(s2)
s1 == s2

ollombo is a district in the plateaux department of republic of the congo. references category : plateaux department ( republic of the congo ) category : districts of the republic of the congo


True

In [None]:
# tokenize() needs the text to split; calling it without arguments raises TypeError.
tkz.tokenize(text)

In [41]:
# r'\b' matches the empty string at word boundaries, so splitting on it cuts the
# text into pieces without consuming any character; the last line checks that
# the pieces concatenate back to the original text.
pat = re.compile(r'\b')
parts = re.split(pat, text)
text == ''.join(parts)

True

In [47]:
# Lower-case and strip every piece, dropping the ones that end up empty
# (pieces that were empty or consisted only of whitespace).
parts_nospace = [piece for piece in (p.lower().strip() for p in parts) if piece]
parts_nospace

['ollombo',
 'is',
 'a',
 'district',
 'in',
 'the',
 'plateaux',
 'department',
 'of',
 'republic',
 'of',
 'the',
 'congo',
 '.',
 'references',
 'category',
 ':',
 'plateaux',
 'department',
 '(',
 'republic',
 'of',
 'the',
 'congo',
 ')',
 'category',
 ':',
 'districts',
 'of',
 'the',
 'republic',
 'of',
 'the',
 'congo']

In [49]:
# Re-join the regex-split pieces with single spaces and compare with the
# string obtained by decoding the tokenizer output.
s2 = ' '.join(p.lower() for p in parts_nospace)
s2 == s1

False

In [50]:
s1

'ollombo is a district in the plateaux department of republic of the congo. references category : plateaux department ( republic of the congo ) category : districts of the republic of the congo'

In [51]:
s2

'ollombo is a district in the plateaux department of republic of the congo . references category : plateaux department ( republic of the congo ) category : districts of the republic of the congo'

In [59]:
# Append a few '#'-heavy fragments to see how literal '##' in the input is
# tokenized, then pair every token id with its string form.
s = s1 + ' abc## ## ####xxx'
toks_ids = tkz(s, add_special_tokens=False).input_ids
toks_str = tkz.convert_ids_to_tokens(toks_ids)
assert len(toks_ids) == len(toks_str)
ids_to_str = list(zip(toks_ids, toks_str))
ids_to_str

[(19330, 'ol'),
 (21297, '##lom'),
 (5092, '##bo'),
 (2003, 'is'),
 (1037, 'a'),
 (2212, 'district'),
 (1999, 'in'),
 (1996, 'the'),
 (9814, 'plateau'),
 (2595, '##x'),
 (2533, 'department'),
 (1997, 'of'),
 (3072, 'republic'),
 (1997, 'of'),
 (1996, 'the'),
 (9030, 'congo'),
 (1012, '.'),
 (7604, 'references'),
 (4696, 'category'),
 (1024, ':'),
 (9814, 'plateau'),
 (2595, '##x'),
 (2533, 'department'),
 (1006, '('),
 (3072, 'republic'),
 (1997, 'of'),
 (1996, 'the'),
 (9030, 'congo'),
 (1007, ')'),
 (4696, 'category'),
 (1024, ':'),
 (4733, 'districts'),
 (1997, 'of'),
 (1996, 'the'),
 (3072, 'republic'),
 (1997, 'of'),
 (1996, 'the'),
 (9030, 'congo'),
 (5925, 'abc'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (22038, 'xx'),
 (2595, '##x')]

In [57]:
tkz.decode(toks_ids)

'ollombo is a district in the plateaux department of republic of the congo. references category : plateaux department ( republic of the congo ) category : districts of the republic of the congo abc # # # # # # # #'

In [None]:
class WordToks:
    """Word-level view of a tokenized string plus a randomly chosen target span of words.

    The string is tokenized without special tokens. Consecutive tokens are
    grouped into words: a token whose string form starts with '##' is treated
    as a continuation of the preceding word. A contiguous span of words
    (offset and length, both counted in words) is then drawn at random.

    Attributes:
        tkz: tokenizer used to encode `s`.
        s: source string.
        toks_ids: token ids of `s` (no special tokens).
        toks_strs: string form of every id in `toks_ids`.
        words_inds_lens: one `(first_token_index, n_tokens)` pair per word.
        tags_names: names of the tags (shared class-level list).
        tags_dict: maps every tag name to its `<name>` marker string.
        max_tgt_len_fraq: upper bound on the target length as a fraction of the
            number of words; non-positive disables this bound.
        max_tgt_len: upper bound on the target length in words; non-positive
            disables this bound.
        words_inds_tgt: `(off_words_tgt, n_words_tgt)` of the sampled target span.
        off_words_tgt: index of the first word of the target span.
        n_words_tgt: number of words in the target span.
    """
    # The tokenizer type is written as a string so that the annotation is not
    # evaluated when the class body runs.
    tkz: 'PreTrainedTokenizer'
    s: str
    toks_ids: list[int]
    toks_strs: list[str]
    words_inds_lens: list[tuple[int, int]]
    tags_names: list[str] = ['cite_begin', 'cite_end']
    tags_dict: dict[str, str]
    max_tgt_len_fraq: float
    max_tgt_len: int
    words_inds_tgt: tuple[int, int]
    off_words_tgt: int
    n_words_tgt: int

    def __init__(self, tkz: 'PreTrainedTokenizer', s: str, max_tgt_len_fraq: float = 0, max_tgt_len: int = 0):
        """Tokenize `s`, group the tokens into words and sample a target span.

        Args:
            tkz: tokenizer; it is called as `tkz(s, add_special_tokens=False)` and
                must provide `convert_ids_to_tokens`.
            s: string to tokenize.
            max_tgt_len_fraq: maximum target length as a fraction of the word count.
            max_tgt_len: maximum target length in words.

        Raises:
            AssertionError: if neither length limit is positive, if `s` produces
                no tokens, or if no word would remain outside the target span.
        """
        self.tkz = tkz
        self.s = s
        self.toks_ids = self.tkz(s, add_special_tokens=False).input_ids
        self.toks_strs = self.tkz.convert_ids_to_tokens(self.toks_ids)
        self.words_inds_lens = self.calc_inds_lens()
        self.tags_dict = {tname: f'<{tname}>' for tname in self.tags_names}
        assert max_tgt_len_fraq > 0 or max_tgt_len > 0, \
            f'At least max_tgt_len_fraq (={max_tgt_len_fraq}) or max_tgt_len (={max_tgt_len}) must be positive.'
        self.max_tgt_len_fraq = max_tgt_len_fraq
        self.max_tgt_len = max_tgt_len
        # Sample the span exactly once: gen_words_inds() is random, so a second
        # call would make the tuple and the unpacked fields disagree.
        self.words_inds_tgt = self.gen_words_inds()
        self.off_words_tgt, self.n_words_tgt = self.words_inds_tgt

    def calc_inds_lens(self) -> list[tuple[int, int]]:
        """Group tokens into words.

        Returns:
            A list with one `(first_token_index, n_tokens)` pair per word, in
            order of appearance. A word starts at every token that does not
            begin with '##' and extends over the '##' tokens that follow it.

        Raises:
            AssertionError: if `toks_ids` and `toks_strs` differ in length, if
                there are no tokens, or if the first token starts with '##'.
        """
        n_toks_ids, n_toks_strs = len(self.toks_ids), len(self.toks_strs)
        assert n_toks_ids == n_toks_strs, f'n_toks_ids (={n_toks_ids}) must be equal to n_toks_strs (={n_toks_strs})'
        assert n_toks_ids > 0, 'Token sequence is empty.'
        assert not self.toks_strs[0].startswith('##'), f'First token cannot start from ##. Tokens: {self.toks_strs}'
        res = []
        off, len_ = 0, 1
        for i in range(1, n_toks_strs):
            if self.toks_strs[i].startswith('##'):
                # Continuation piece: it extends the current word.
                len_ += 1
            else:
                # A new word begins: close the current one and restart the count at 1.
                res.append((off, len_))
                off, len_ = i, 1
        res.append((off, len_))
        return res

    def gen_words_inds(self) -> tuple[int, int]:
        """Randomly choose a contiguous span of words.

        The maximum span length is derived from whichever of `max_tgt_len` and
        `max_tgt_len_fraq` is positive (the smaller result when both are), is
        capped at half of the word count, and is at least 1. The span length is
        drawn with `np.random.randint` from `[1, max_len]` and the offset from
        the positions at which the span still fits inside the word sequence.

        Returns:
            `(off, cite_len)`: index of the first word of the span and the
            number of words in it.

        Raises:
            AssertionError: if the span would cover every word.
        """
        n_words = len(self.words_inds_lens)
        if self.max_tgt_len <= 0:
            max_len = int(self.max_tgt_len_fraq * n_words)
        elif self.max_tgt_len_fraq <= 0:
            max_len = self.max_tgt_len
        else:
            max_len = min(self.max_tgt_len, int(self.max_tgt_len_fraq * n_words))
        max_len = min(max_len, int(0.5 * n_words))
        max_len = max(max_len, 1)
        cite_len = np.random.randint(1, max_len + 1)
        n_rest = n_words - cite_len
        assert n_rest > 0, f'n_rest (={n_rest}) must be positive.'
        off = np.random.randint(n_rest + 1)
        # Return the sampled span length, not the total number of words.
        return off, cite_len

    def create_tgt(self) -> tuple[list[int], list[int]]:
        """Placeholder: not implemented yet, currently returns None."""
        pass
        

        
