In [None]:
# Global switch forwarded to shuffle_chunk_dataset: when True every chunk is
# rewritten by replace_caps (case/spacing markers) before being stored.
preprocess: bool = True

# Prepare Dataset

In [None]:
%%capture
!pip install datasets
!pip install binpacking
!pip install sentencepiece

In [None]:
import regex as re
import string
import io

import binpacking
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def fragment_text_to_words(text):
    """Split ``text`` into words, each carrying the whitespace before it.

    Concatenating the returned pieces reproduces ``text`` except for any
    trailing whitespace, which no piece captures.
    """
    return [match.group(0) for match in re.finditer(r'\s*\S+', text)]

def fragment_texts(examples):
    """Pair each document in ``examples['text']`` with its word fragments.

    Returns a list of ``(text, words)`` tuples in input order, where
    ``words`` comes from ``fragment_text_to_words``.
    """
    pairs = []
    for doc in examples['text']:
        pairs.append((doc, fragment_text_to_words(doc)))
    return pairs

def replace_caps_fn(m):
    """``re.sub`` callback that wraps an all-caps run in caps markers.

    Group 1 (the run) is lowercased and placed between ``<capss>`` and
    ``<capse>``; group 3 (the separator after it) is appended unchanged.
    """
    caps_run, separator = m.group(1, 3)
    return ''.join(('<capss>', caps_run.lower(), '<capse>', separator))

def replace_caps(text):
    """Encode capitalization and word gluing as explicit marker tokens.

    The output is lowercase; case information is carried by markers:

    * ``<bksp>`` plus a space is inserted between a character that is neither
      a space nor a word character and a following character that is not
      whitespace, punctuation or a digit, e.g. ``"second-year`` becomes
      ``"<bksp> second-<bksp> year``.
    * A run of uppercase words (each at least two letters, separated by
      whitespace/punctuation) that is followed by whitespace, punctuation or
      the end of the text becomes ``<capss>`` + lowercased run + ``<capse>``.
    * Every remaining uppercase letter (with any spaces before it) is
      lowercased and prefixed by ``<shift>``.
    * If the input did not start with a space, one is added internally and
      the result is prefixed with ``<bksp>``.

    Args:
        text: Input string; may be empty.

    Returns:
        The marked-up string ('' for empty input).
    """
    if not text:
        # Nothing to encode; indexing text[0] would raise IndexError here.
        return text

    add_bksp_start = not text.startswith(' ')
    if add_bksp_start:
        # Normalize to a leading space; the <bksp> prefix added at the end
        # records that this space was not part of the input.
        text = ' ' + text

    text = re.sub(r'([^\x20\w])([^\s\p{P}\p{Nd}])', r'\1<bksp> \2', text)

    # Raw strings: '\p' and '\s' are invalid escapes in plain Python string
    # literals (SyntaxWarning on 3.12+). The '\Z' alternative lets a caps run
    # that ends the text be wrapped too, instead of getting per-letter <shift>.
    caps_run = r'(\x20*\p{Lu}{2,}([\s\p{P}]*\p{Lu}{2,})*)([\s\p{P}]|\Z)'
    text = re.sub(caps_run, replace_caps_fn, text)

    text = re.sub(r'(\x20*[\p{Lu}])',
                  lambda m: '<shift>' + m.group(1).lower(), text)

    if add_bksp_start:
        text = '<bksp>' + text
    return text

In [None]:
def chunk_examples(examples, length=4096, preprocess=True):
    """Turn a batch of documents into chunks of at most ``length`` words.

    Used as a batched ``datasets`` map function: it reads ``examples['text']``
    and returns two new columns.

    * A document with more than ``length`` words contributes one uniformly
      chosen window of exactly ``length`` consecutive words; the rest of the
      document is dropped.
    * Shorter documents are bin-packed by word count into groups totalling at
      most ``length`` words, and each group is joined with single spaces.

    Args:
        examples: Batch mapping with a ``'text'`` list of strings.
        length: Maximum chunk size in words, as split by
            ``fragment_text_to_words``.
        preprocess: If True, pass every chunk through ``replace_caps``.

    Returns:
        ``{"chunks": [...], "lengths": [...]}`` where ``lengths[i]`` is the
        number of words that went into ``chunks[i]``.
    """
    chunks = []
    lengths = []
    small_docs = []

    for text, words in fragment_texts(examples):
        if len(words) > length:
            # randint's upper bound is exclusive; the +1 makes the final
            # window (ending on the last word) selectable as well.
            start = np.random.randint(0, len(words) - length + 1)
            chunks.append(''.join(words[start:start + length]))
            lengths.append(length)
        else:
            small_docs.append((text, len(words)))

    # NOTE(review): binpacking.to_constant_volume appears to inspect its first
    # element and may fail on an empty list (a batch of only long documents);
    # confirm for the pinned version. Skipping it is equivalent anyway.
    bins = (binpacking.to_constant_volume(small_docs, length, weight_pos=1)
            if small_docs else [])
    chunks += [' '.join(doc for doc, _ in group) for group in bins]
    lengths += [sum(n_words for _, n_words in group) for group in bins]

    if preprocess:
        chunks = [replace_caps(chunk) for chunk in chunks]

    return {"chunks": chunks, "lengths": lengths}

In [None]:
import datasets
import functools

# Stream the English Wikipedia "20220301.en" config rather than downloading
# the full dataset up front.
wiki_dataset = datasets.load_dataset(
    path="wikipedia",
    name="20220301.en",
    split="train",
    streaming=True,
)

def shuffle_chunk_dataset(ds, column_names, skip=None, preprocess=True,
                          seed=42, batch_size=10000):
    """Shuffle a dataset and re-map it into word-bounded text chunks.

    Args:
        ds: Dataset exposing ``shuffle``/``skip``/``map`` (here a streaming
            Hugging Face dataset) whose rows have a ``text`` field.
        column_names: Input columns to drop, so that the output holds only
            the ``chunks`` and ``lengths`` columns from ``chunk_examples``.
        skip: If truthy, number of shuffled examples to skip first.
        preprocess: Forwarded to ``chunk_examples``.
        seed: Shuffle seed (defaults to the previously hard-coded 42).
        batch_size: Documents handed to each ``chunk_examples`` call; larger
            batches give the bin packer more short texts to combine.

    Returns:
        The shuffled, chunk-mapped dataset.
    """
    ds = ds.shuffle(seed=seed)
    if skip:
        ds = ds.skip(skip)
    chunker = functools.partial(chunk_examples, preprocess=preprocess)
    return ds.map(chunker, batched=True, remove_columns=column_names,
                  batch_size=batch_size)

# Rebind to the shuffled, chunked view; the column list is read from the raw
# stream before the name is replaced.
wiki_dataset = shuffle_chunk_dataset(
    wiki_dataset,
    column_names=wiki_dataset.column_names,
    preprocess=preprocess,
)

Downloading builder script:   0%|          | 0.00/35.9k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/30.4k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/16.3k [00:00<?, ?B/s]

In [None]:
def data_gen(ds, num_tokens=10000, count_filter=None, print_every=10**6):
    """Yield chunk texts from ``ds`` until a word budget is exceeded.

    Args:
        ds: Iterable of mappings with ``'chunks'`` (text) and ``'lengths'``
            (word count), such as the output of ``shuffle_chunk_dataset``.
        num_tokens: Budget on the running sum of ``'lengths'``. The example
            that pushes the total above it is not yielded and iteration stops.
        count_filter: Unused; kept so existing callers keep working.
        print_every: Print the running total whenever more than this many
            words have been consumed since the previous print.

    Yields:
        The ``'chunks'`` value of each accepted example.
    """
    total = 0
    since_report = 0
    for example in ds:
        total += example['lengths']
        since_report += example['lengths']
        if total > num_tokens:
            break
        if since_report > print_every:
            since_report = 0
            print(total)
        yield example['chunks']

In [None]:
import gzip

In [None]:
# One chunk per line; unicode_escape turns embedded newlines and non-ASCII
# characters into escape sequences so each chunk stays on a single line.
with gzip.open('file.gz', 'wb') as corpus_out:
    corpus_out.writelines(
        chunk.encode('unicode_escape') + b'\n'
        for chunk in data_gen(wiki_dataset, num_tokens=7.5*10**7)
    )

1003520
2007040
3010560
4014080
5017600
6021120
7024640
8028160
9031680
10035200
11038720
12042240
13045590
14049081
15052475
16055881
17059269
18062593
19065936
20069290
21072463
22075579
23078655
24079742
25083262
26086782
27090302
28093822
29097342
30100862
31104382
32107902
33111422
34114942
35118425
36121925
37125437
38128867
39132268
40135611
41138949
42142248
43145492
44148844
45151577
46155097
47158617
48162137
49165657
50169177
51172697
52176217
53179737
54183219
55186726
56190202
57193602
58196957
59200306
60203643
61206925
62210243
63213631
64216284
65219804
66223324
67226844
68230364
69233884
70237404
71240853
72244367
73247831
74251283


# Reload the corpus and train the tokenizer

In [None]:
!pip install sentencepiece



In [None]:
import sentencepiece as spm
import gzip
import io
import string

In [None]:
def iterate(fname):
    """Stream chunks back from a gzip corpus written one chunk per line.

    Each line is expected to be a chunk encoded with ``'unicode_escape'``,
    which escapes newline characters inside the chunk, so the only raw
    newline byte on a line is its terminator.

    Args:
        fname: Path to the gzip-compressed corpus.

    Yields:
        Decoded chunk strings, in file order.
    """
    with gzip.open(fname, 'rb') as corpus_in:
        for line in corpus_in:
            # Strip only a real terminator: the final line may lack one, and
            # blindly slicing off a byte would eat a real character.
            if line.endswith(b'\n'):
                line = line[:-1]
            yield line.decode('unicode_escape')

In [None]:
# Sanity check: decode the first stored chunk to confirm the round trip.
first_chunk = next(iterate('file.gz'))
first_chunk

' students are sometimes called freshers early in the academic year; however, there are no specific names for those in other years nor for school pupils.<shift> graduate and professional students in the<shift> united<shift> states are known by their year of study, such as a "<bksp> second-<bksp> year medical student" or a "<bksp> fifth-<bksp> year doctoral candidate."<shift> law students are often referred to as "1<shift>l", "2<shift>l", or "3<shift>l" rather than "<bksp> nth-<bksp> year law students"; similarly, medical students are frequently referred to as "<bksp><shift> m1", "<bksp><shift> m2", "<bksp><shift> m3", or "<bksp><shift> m4".\t\n\n<bksp><shift> while anyone in the<capss> us<capse> who finishes studying at any educational institution by passing relevant examinations is said to graduate and to be a graduate, in the<capss> uk<capse> only degree and above level students can graduate.<shift> student itself has a wider meaning in<shift> am<shift>e, meaning any person of any ag

In [None]:
# Train a unigram SentencePiece model straight from the gzip corpus, writing
# the serialized model into an in-memory buffer. The replace_caps markers and
# every ASCII digit/punctuation character are registered as user-defined
# symbols.
model = io.BytesIO()
spm.SentencePieceTrainer.train(
    sentence_iterator=iterate('file.gz'),
    model_writer=model,
    vocab_size=16384,
    user_defined_symbols=(['<shift>', '<capss>', '<capse>', '<bksp>']
                          + list(string.digits + string.punctuation)),
    model_type='unigram',
    max_sentence_length=48000,
    character_coverage=0.99,
)
sp = spm.SentencePieceProcessor(model_proto=model.getvalue())

In [None]:
# Inspect the learned vocabulary. The range is sized from the model itself so
# this cell stays correct if vocab_size changes in the training cell.
sp.id_to_piece(list(range(sp.get_piece_size())))

['<unk>',
 '<s>',
 '</s>',
 '<shift>',
 '<capss>',
 '<capse>',
 '<bksp>',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '!',
 '"',
 '#',
 '$',
 '%',
 '&',
 "'",
 '(',
 ')',
 '*',
 '+',
 ',',
 '-',
 '.',
 '/',
 ':',
 ';',
 '<',
 '=',
 '>',
 '?',
 '@',
 '[',
 '\\',
 ']',
 '^',
 '_',
 '`',
 '{',
 '|',
 '}',
 '~',
 '▁',
 '▁the',
 '▁of',
 's',
 '▁and',
 '▁in',
 '▁to',
 '▁a',
 '▁is',
 '▁as',
 '▁was',
 '▁for',
 '▁by',
 '▁that',
 '▁with',
 '▁on',
 'e',
 '▁e',
 '▁from',
 '▁are',
 '▁it',
 'a',
 'ed',
 '▁at',
 '▁his',
 'i',
 '▁be',
 '▁an',
 '▁he',
 '▁which',
 '▁or',
 'ing',
 'd',
 '▁this',
 '▁were',
 't',
 '▁also',
 'o',
 '▁not',
 '▁have',
 '▁had',
 '▁has',
 'n',
 '▁but',
 '▁their',
 'th',
 '▁first',
 '▁one',
 '▁other',
 '▁its',
 '▁new',
 '▁they',
 'y',
 'or',
 'r',
 '▁can',
 '▁been',
 '▁after',
 '▁two',
 'u',
 '▁such',
 '▁who',
 '▁all',
 '▁american',
 'p',
 '▁more',
 'er',
 '▁used',
 '▁b',
 'ly',
 '▁may',
 '▁some',
 'on',
 '▁when',
 'c',
 '▁time',
 '▁into',
 'al',
 '▁most',
 'es',

In [None]:
# Persist the serialized tokenizer held in the in-memory buffer.
with open('capstoken.spm', 'wb') as spm_file:
    spm_file.write(model.getvalue())