In [1]:
# Automatically reload imported modules before each cell executes, so edits
# to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass
from functools import partial
import os
from pathlib import Path
import re
import sys
from typing import Dict, List, Optional, Tuple, Any

if '..' not in sys.path: sys.path.append('..')

import numpy as np

from datasets import load_dataset
from transformers import AutoTokenizer, PreTrainedTokenizer

import torch

# from mllm.train.utils import WordToks

## Dataset loading

In [36]:
# Root directory for the Hugging Face datasets cache.
# Set the DATA_PATH environment variable (e.g. to ~/data) to override; otherwise
# the local data drive used when this notebook was written is used.
DATA_PATH = Path(os.environ.get('DATA_PATH', 'Q:/data')).expanduser()
print(f"DATA_PATH: {DATA_PATH}")

DATA_PATH: Q:\data


In [4]:
# Load the English Wikipedia snapshot 20220301.en; downloaded files are cached under DATA_PATH.
# Older snapshot (passes an Apache Beam runner, presumably to build it locally), kept for reference:
# ds = load_dataset('wikipedia', '20200501.en',  beam_runner='DirectRunner', cache_dir=str(DATA_PATH), trust_remote_code=True)
# NOTE(review): trust_remote_code=True executes the dataset's loading script fetched from the Hub;
# only keep it for trusted dataset sources.
ds = load_dataset('wikipedia', '20220301.en', cache_dir=str(DATA_PATH), trust_remote_code=True)
# Same load without trust_remote_code, kept for reference:
# ds = load_dataset('wikipedia', '20220301.en', cache_dir=str(DATA_PATH))

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [5]:
# Splits available in the loaded dataset dict.
list(ds.keys())

['train']

In [6]:
# Work directly with the train split.
dst = ds['train']

In [7]:
# Number of rows (articles) in the train split.
len(dst)

6458670

In [9]:
# Inspect an arbitrary row deep into the split.
dst[6078421]

{'title': 'Overload (novel)',
 'text': 'Overload (1979) is a novel by Arthur Hailey, concerning the electricity production industry in California and the activities of the employees and others involved with Golden State Power and Light, a fictional California public service company. The plot follows many of the issues of the day, including race relations, corporate politics, business ethics, terrorism and journalism. (Hailey would later explore (television) journalism in another novel, The Evening News.)\n\nPlot Synopsis\nThe novel is described from the point of view of vice-president of Golden State Power and Light, Nimrod "Nim" Goldman, who, despite being married, tends to be somewhat of a Lothario and has many extramarital affairs. The geographic area of service of the fictional electric utility, Golden State Power and Light, matches the actual Northern California footprint of the real-life Pacific Gas and Electric Company.\n\nGolden State Power and Light is a public utility, supply

In [10]:
# Columns of the train split.
# NOTE(review): the saved output below lists only title/text, but later saved outputs show
# id/url as well. The outputs appear to come from different sessions; re-run to refresh them.
dst.column_names

['title', 'text']

## Tokenization

In [8]:
# Tokenizer used by the tokenization experiments below.
pretrained_model_name = 'bert-base-uncased'
tkz = AutoTokenizer.from_pretrained(pretrained_model_name)
tkz

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [30]:
# Sample article used for the tokenization experiments.
item = dst[20]
item

{'id': '569',
 'url': 'https://en.wikipedia.org/wiki/Anthropology',
 'title': 'Anthropology',
 'text': 'Anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the present and past, including past human species. Social anthropology studies patterns of behaviour, while cultural anthropology studies cultural meaning, including norms and values. A portmanteau sociocultural anthropology is commonly used today. Linguistic anthropology studies how language influences social life. Biological or physical anthropology studies the biological development of humans.\n\nArchaeological anthropology, often termed as \'anthropology of the past\', studies human activity through investigation of physical evidence. It is considered a branch of anthropology in North America and Asia, while in Europe archaeology is viewed as a discipline in its own right or grouped under other related disciplines, such as history.\n\nEtym

In [32]:
# Keep only a short 150-character prefix of the article and show its word-piece tokens.
# The cut can land mid-word (e.g. 'pres' -> 'pre', '##s').
title, text = item['title'], item['text'][:150]
print(f'{title}. Text len: {len(text)}')
inp_toks = tkz.tokenize(text)
print(inp_toks)

Anthropology. Text len: 150
['anthropology', 'is', 'the', 'scientific', 'study', 'of', 'humanity', ',', 'concerned', 'with', 'human', 'behavior', ',', 'human', 'biology', ',', 'cultures', ',', 'societies', ',', 'and', 'linguistics', ',', 'in', 'both', 'the', 'pre', '##s']


In [33]:
# Encode the raw string; special tokens ([CLS]=101 ... [SEP]=102) are added by default.
res1 = tkz(text, is_split_into_words=False)
print(res1)

{'input_ids': [101, 12795, 2003, 1996, 4045, 2817, 1997, 8438, 1010, 4986, 2007, 2529, 5248, 1010, 2529, 7366, 1010, 8578, 1010, 8384, 1010, 1998, 15397, 1010, 1999, 2119, 1996, 3653, 2015, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [34]:
# Encode pre-split whitespace words as a single sequence; compared with res1 below.
res2 = tkz(text.split(), is_split_into_words=True)
print(res2)

{'input_ids': [101, 12795, 2003, 1996, 4045, 2817, 1997, 8438, 1010, 4986, 2007, 2529, 5248, 1010, 2529, 7366, 1010, 8578, 1010, 8384, 1010, 1998, 15397, 1010, 1999, 2119, 1996, 3653, 2015, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [35]:
# Do both calling conventions produce the same encoding for this text?
res1 == res2

True

In [36]:
# Without is_split_into_words a list of strings is treated as a batch:
# each whitespace word gets its own id list (see the nested output).
res3 = tkz(text.split(), add_special_tokens=False)
print(text.split())
res3


['Anthropology', 'is', 'the', 'scientific', 'study', 'of', 'humanity,', 'concerned', 'with', 'human', 'behavior,', 'human', 'biology,', 'cultures,', 'societies,', 'and', 'linguistics,', 'in', 'both', 'the', 'pres']


{'input_ids': [[12795], [2003], [1996], [4045], [2817], [1997], [8438, 1010], [4986], [2007], [2529], [5248, 1010], [2529], [7366, 1010], [8578, 1010], [8384, 1010], [1998], [15397, 1010], [1999], [2119], [1996], [3653, 2015]], 'token_type_ids': [[0], [0], [0], [0], [0], [0], [0, 0], [0], [0], [0], [0, 0], [0], [0, 0], [0, 0], [0, 0], [0], [0, 0], [0], [0], [0], [0, 0]], 'attention_mask': [[1], [1], [1], [1], [1], [1], [1, 1], [1], [1], [1], [1, 1], [1], [1, 1], [1, 1], [1, 1], [1], [1, 1], [1], [1], [1], [1, 1]]}

In [37]:
# Decode each word's ids separately to recover the (lower-cased) words.
words3 = tkz.batch_decode(res3.input_ids)
words3

['anthropology',
 'is',
 'the',
 'scientific',
 'study',
 'of',
 'humanity,',
 'concerned',
 'with',
 'human',
 'behavior,',
 'human',
 'biology,',
 'cultures,',
 'societies,',
 'and',
 'linguistics,',
 'in',
 'both',
 'the',
 'pres']

In [38]:
# Check that decoding the full encoding equals joining the per-word decodes with spaces.
s1 = tkz.decode(res1.input_ids, skip_special_tokens=True)
s2 = ' '.join(words3)
print(s2)
s1 == s2

anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the pres


True

In [40]:
# Split on regex word boundaries; the split is lossless, so joining the parts restores the text.
pat = re.compile(r'\b')
parts = pat.split(text)
''.join(parts) == text

True

In [41]:
# Normalize the word-boundary parts: lower-case, strip whitespace, and drop the
# parts that become empty (the pure-whitespace separators between words).
parts_nospace = [p for p in (part.lower().strip() for part in parts) if p]
parts_nospace

['anthropology',
 'is',
 'the',
 'scientific',
 'study',
 'of',
 'humanity',
 ',',
 'concerned',
 'with',
 'human',
 'behavior',
 ',',
 'human',
 'biology',
 ',',
 'cultures',
 ',',
 'societies',
 ',',
 'and',
 'linguistics',
 ',',
 'in',
 'both',
 'the',
 'pres']

In [42]:
# Join the normalized parts with single spaces. They are already lower-cased, but
# punctuation is a separate part here, so the result differs from s1.
s2 = ' '.join(parts_nospace)
s1 == s2

False

In [43]:
# Tokenizer decode: punctuation stays attached to the preceding word.
s1

'anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the pres'

In [44]:
# Regex-based join: punctuation is separated by spaces.
s2

'anthropology is the scientific study of humanity , concerned with human behavior , human biology , cultures , societies , and linguistics , in both the pres'

In [45]:
# Append strings containing literal '#' characters to see how they tokenize
# compared with BERT's '##' word-piece continuation prefix.
s = s1 + ' abc## ## ####xxx'
toks_ids = tkz(s, add_special_tokens=False).input_ids
toks_str = tkz.convert_ids_to_tokens(toks_ids)
assert len(toks_ids) == len(toks_str)
# Pair each token id with its token string.
ids_to_str = list(zip(toks_ids, toks_str))
ids_to_str

[(12795, 'anthropology'),
 (2003, 'is'),
 (1996, 'the'),
 (4045, 'scientific'),
 (2817, 'study'),
 (1997, 'of'),
 (8438, 'humanity'),
 (1010, ','),
 (4986, 'concerned'),
 (2007, 'with'),
 (2529, 'human'),
 (5248, 'behavior'),
 (1010, ','),
 (2529, 'human'),
 (7366, 'biology'),
 (1010, ','),
 (8578, 'cultures'),
 (1010, ','),
 (8384, 'societies'),
 (1010, ','),
 (1998, 'and'),
 (15397, 'linguistics'),
 (1010, ','),
 (1999, 'in'),
 (2119, 'both'),
 (1996, 'the'),
 (3653, 'pre'),
 (2015, '##s'),
 (5925, 'abc'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (1001, '#'),
 (22038, 'xx'),
 (2595, '##x')]

In [46]:
# Decoding merges '##' continuations ('xx' + '##x' -> 'xxx'), while literal '#'
# tokens come back space-separated.
tkz.decode(toks_ids)

'anthropology is the scientific study of humanity, concerned with human behavior, human biology, cultures, societies, and linguistics, in both the pres abc # # # # # # # # xxx'

In [None]:
# WordToks is used here, but the imports cell has its import commented out.
# Import it locally so this cell also runs on a fresh kernel.
# NOTE(review): module path taken from that commented-out import — confirm it is current.
from mllm.train.utils import WordToks

# Build a masked-target example from `s` and show the input, masked input, and target text.
wt = WordToks(tkz, s, max_tgt_len_freq=0.2, max_tgt_len=10)
inp_str, inp_masked_str, tgt_str = tkz.decode(wt.inp_toks), tkz.decode(wt.inp_masked_toks), tkz.decode(wt.tgt_toks)
print(inp_str)
print(inp_masked_str)
print(tgt_str)

anthropology is the scientific study of humanity, concerned with human behavior, < | cite _ begin | > human biology, cultures < | cite _ end | >, societies, and linguistics, in both the pres abc # # # # # # # # xxx
anthropology is the scientific study of humanity, concerned with human behavior, < | cite _ begin | > [MASK] [MASK] [MASK] [MASK] < | cite _ end | >, societies, and linguistics, in both the pres abc # # # # # # # # xxx
human biology, cultures


## Mapping and processing the dataset

In [6]:
# Re-create the tokenizer so this section can run on its own; use_fast=True
# selects the fast tokenizer (is_fast=True in the repr below).
tkz = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
tkz

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [26]:
# Small subset of the train split (first n rows) for quick Dataset.map experiments.
n = 5
ds_sub = ds['train'].select(range(n))
len(ds_sub), ds_sub.column_names

(5, ['id', 'url', 'title', 'text'])

In [27]:
def tokenize_item(tokenizer: PreTrainedTokenizer, item):
    """Tokenize ``item['text']`` without special tokens and attach the result.

    Returns a new dict with every original field of ``item`` followed by
    ``'toks'`` (the token ids) and ``'toks_len'`` (how many there are).
    No truncation is requested, so long texts yield long id lists.
    """
    token_ids = tokenizer(item['text'], add_special_tokens=False).input_ids
    result = dict(item)
    result.update(toks=token_ids, toks_len=len(token_ids))
    return result

def extract_masked_input(item: Dict, mask_token_id: int = 103, pad_token_id: int = 0, max_seq_length: int = 512) -> Dict:
    """Keep the first ``max_seq_length`` ids of ``item['toks']`` as ``'input_ids'``.

    Returns a new dict with all fields of ``item`` plus ``'input_ids'`` as a
    ``torch.long`` tensor. ``mask_token_id`` and ``pad_token_id`` are accepted
    but currently unused: despite the name, no masking or padding happens.
    """
    head = item['toks'][:max_seq_length]
    out = dict(item)
    out['input_ids'] = torch.tensor(head, dtype=torch.long)
    return out

def extract_masked_input_rnd(item: Dict, mask_token_id: int = 103, pad_token_id: int = 0, max_seq_len: int = 512,
                             max_offset: Optional[int] = 3, rng: Optional[np.random.Generator] = None) -> Dict:
    """Take a window of at most ``max_seq_len`` ids from ``item['toks']`` at a random start.

    Args:
        item: Row with ``'toks'`` (token ids) and ``'toks_len'`` (their count),
            e.g. as produced by ``tokenize_item``.
        mask_token_id: Accepted for signature parity with ``extract_masked_input``; unused.
        pad_token_id: Accepted for signature parity; unused.
        max_seq_len: Maximum window length. Sequences that are not longer than
            this are returned whole.
        max_offset: Exclusive upper bound on the start offset: the window starts at
            one of ``0 .. max_offset - 1``, further limited by the number of valid
            starts. The default 3 reproduces the previous hard-coded cap. Pass
            ``None`` to sample uniformly over every valid start.
        rng: Optional NumPy ``Generator`` for reproducible offsets. When omitted,
            the global ``np.random`` state is used, as before.

    Returns:
        A new dict with all fields of ``item`` plus ``'input_ids'`` (``torch.long`` tensor).

    Raises:
        ValueError: if ``max_offset`` is given and is less than 1.
    """
    if max_offset is not None and max_offset < 1:
        raise ValueError(f'max_offset must be >= 1 or None, got {max_offset}')

    toks = item['toks']
    cur_len = item['toks_len']
    win_len = min(cur_len, max_seq_len)
    input_ids = toks
    if win_len < cur_len:
        # Number of start positions at which a full window of win_len ids fits.
        n_starts = cur_len - win_len + 1
        if max_offset is not None:
            n_starts = min(n_starts, max_offset)
        if rng is None:
            off = np.random.randint(0, n_starts)
        else:
            off = int(rng.integers(0, n_starts))
        input_ids = toks[off:off + win_len]

    return {
        **item,
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
    }



In [28]:
# Tokenize every article in the subset. The "sequence length is longer than the
# specified maximum" warning is expected: tokenize_item requests no truncation,
# so whole articles are encoded.
ds_sub1 = ds_sub.map(partial(tokenize_item, tkz))
ds_sub1['toks_len']

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (8349 > 512). Running this sequence through the model will result in indexing errors


[8349, 9058, 4568, 2322, 15590]

In [29]:
# Deterministic variant: keep the first 5 token ids of each article.
ds_sub2 = ds_sub1.map(extract_masked_input, fn_kwargs={'mask_token_id': tkz.mask_token_id, 'max_seq_length': 5})
ds_sub2['input_ids']

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

[[9617, 11140, 2964, 2003, 1037],
 [19465, 2003, 1037, 11265, 10976],
 [2632, 28759, 1006, 1025, 1007],
 [1037, 1010, 2030, 1037, 1010],
 [6041, 1006, 1007, 2003, 1037]]

In [30]:
# Random-offset variant: 5-token windows with a random start.
# The map function is stochastic, so bypass the datasets cache. Otherwise a
# re-run can silently reuse the previously cached crops instead of re-sampling.
ds_sub3 = ds_sub1.map(
    extract_masked_input_rnd,
    fn_kwargs={'mask_token_id': tkz.mask_token_id, 'max_seq_len': 5},
    load_from_cache_file=False,
)
ds_sub3['input_ids']

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

[[11140, 2964, 2003, 1037, 2576],
 [2003, 1037, 11265, 10976, 24844],
 [1006, 1025, 1007, 2003, 1996],
 [1037, 1010, 2030, 1037, 1010],
 [6041, 1006, 1007, 2003, 1037]]