# Russian Poetry Generation

#### GPT-2 Pretrain & Fine-Tuning from scratch

## Setup

#### Load Taiga Proza Corpus

In [None]:
import pandas as pd

from pathlib import Path

# Root folder for every dataset artifact this notebook reads or writes
datasets_path = Path(r'datasets')

# The corpus CSV is pipe-delimited; the 'utf-8-sig' codec skips a leading BOM if one is present
df = pd.read_csv(datasets_path/'taiga_proza_1GB_43350.csv', sep='|', encoding='utf-8-sig')
df.head()

In [None]:
# Remove empty rows
# Print the per-column count of missing values first: a bare expression in the middle
# of a cell is evaluated and thrown away, only the last one is displayed.
print(df.isna().sum())

# Rebind instead of mutating in place so the cell is safe to re-run
df = df.dropna()

In [None]:
from datasets import Dataset

# Create a Huggingface Dataset Instance
taiga_proza_russian = Dataset.from_pandas(df)
taiga_proza_russian

In [None]:
# Check first sample
taiga_proza_russian[0]['text'][:512]

#### Cleanup

In [None]:
import re
import regex

from functools import partial
from num2words import num2words

# Regular expressions used by the cleanup pipeline.
# Patterns with variable-width lookbehind (garbage_punctuation_pattern, fix_hyphen_pattern)
# only compile with the third-party `regex` module, not with the stdlib `re`.

# replace "" with «» — each quoted span is matched on its own (the greedy `.+` used to
# swallow everything between the first and the last quote on a line)
programming_quotes_pattern = r'"[^"\n]+"'
fake_dash_pattern = r'-{2,}' # replace -- with —
replace_number_pattern = r'[0-9]+' # replace number to word format
# leave only this letters; `Ё` is listed explicitly because it lies outside the А-Я range
cyrillic_only_pattern = r'[^#А-Яа-яЁё\-—\.,:;!?…|«»„“\s]'
broken_manydots_pattern = r'(\.){2,}' # replace ... with …
double_punct_pattern = r'[\-—\.,:;!?…]{2,}' # if two punctuations are in place, leave first
double_space_pattern = r'\s{2,}'
# replace — . — with first punct (ё/Ё added to the neighbouring-letter classes)
garbage_punctuation_pattern = r'(?<=[А-Яа-яЁё]\s{1,})[\-—\.,:;!?…|«»„“\s]{2,}(?=\s+[А-Яа-яЁё])'
space_after_punct_pattern = r'[\.,:;!?…](?![\s«»„“])' # make sure all classic punctuations have space after them
remove_space_before_punct_pattern = r'\s+[\.,:;!?…]' # make sure there is no space before punctuation
m_dash_with_space_pattern = r'((?<!\s)—)|(—(?!\s))' # make sure m-dash with spaces around it

# Most common words that appear on the left / right side of a hyphenated compound
_hyphen_left_words = 'бледно|почивала|не|давай|выпукло|толстый|норд|западно|день|умница|да|пару|когда|генерал|паинька|луна|опять|буль|один|где|хухры|кой|как|киловатт|что|много|который|давным|винни|англо|жар|мать|путь|она|елки|во|тихо|баба|бутылочно|так|крепко|чуть|волей|пол|гав|ей|кто|рад|нет|пила|изжелта|поди|премьер|юго|купля|прямо|тик|кошки|нежданно|северно|крест|черным|узнал|из|бой|ну|южно|восточно|запад|северо|маленький|в|изба|мини|еле|цып|страсти|кол|дизель|зюйд|динь|сегодня|человеко|статс|ням|по|никогда|мало|на|кое|ого|экс|штабс|север|для|пепельно|раз|какой|год|постояли|вице|завтра|фигли|темно|все|миру'
_hyphen_right_words = 'мотор|соловьиному|юбка|яблока|а|под|первых|день|своему|птичьи|тка|чем|черно|ка|единственный|русый|рыба|накрепко|буль|где|разумница|богу|мигли|иному|что|мир|палки|продажа|никак|вторых|много|капитан|русски|одинешенек|западный|нибудь|куда|час|моему|европейская|птица|во|дама|баба|восточный|так|го|тройку|восток|ледовитый|чуть|накрест|пух|неволей|гав|читальня|кто|мордасти|нет|яга|братски|видимому|розовый|вогнутый|деньской|пустому|либо|героиня|радехонек|мальски|сибирская|с|запад|мухры|маленький|третьих|дорога|то|мальчик|еле|назавтра|цып|таки|зеленый|мышки|лимона|губернатор|серый|динь|давно|ням|красный|сахалинск|прежнему|помалу|адмирал|какими|японский|другой|над|майор|negаданно|смирно|за|какой|немецки|два|постояли|завтра|ост|претолстый|министр|вест|парк|де|юг|гора'.replace('negаданно', 'негаданно')

# fix incorrect hyphen spaces around most common words.
# The left list must NOT end with `|`: an empty alternative would let the lookbehind
# succeed after any non-letter, turning ordinary dashes before e.g. «то» into hyphens.
fix_hyphen_pattern = (
    r'(?<=(?<![А-Яа-яЙЁйё])(?:' + _hyphen_left_words + r')(\s+)?)'
    r'\s?[\-—]\s?'
    r'(?=(\s+)?(?:' + _hyphen_right_words + r')(?![А-Яа-яЙЁйё]))'
)
replace_hyphen_with_dash_pattern = r'((?<=\s)-)|(-(?=\s))' # replace hyphen with mdash where it required
# where hyphen is sticked to word it's a dash (the original listed this same alternative twice)
replace_hyphen_with_dash_second_step_pattern = r'\s-'

# Replacement callbacks shared by the cleaners below. Each receives a match object.

def _to_guillemets(match):
    # Drop the surrounding "..." and wrap the trimmed content in «»
    return f'«{match.group(0)[1:-1].strip()}»'

def _number_to_words(match):
    # Spell the matched digits out in Russian
    return num2words(int(match.group(0)), lang='ru')

def _keep_first_char(match):
    return match.group(0)[0]

def _append_space(match):
    return f'{match.group(0)} '

def _strip_whitespace(match):
    return f'{match.group(0).strip()}'

def _pad_with_spaces(match):
    return f' {match.group(0)} '

# One single-argument cleaner per pattern: call as `cleaner(text)` to get the cleaned text.
# The two cleaners built on `regex.sub` use patterns with variable-width lookbehind.
programming_quotes = partial(re.sub, programming_quotes_pattern, _to_guillemets)
fake_dash = partial(re.sub, fake_dash_pattern, '—')
replace_number = partial(re.sub, replace_number_pattern, _number_to_words)
cyrillic_only = partial(re.sub, cyrillic_only_pattern, '')
broken_manydots = partial(re.sub, broken_manydots_pattern, '…')
double_punct = partial(re.sub, double_punct_pattern, _keep_first_char, flags=re.MULTILINE)
remove_double_space = partial(re.sub, double_space_pattern, ' ')
garbage_punctuation = partial(regex.sub, garbage_punctuation_pattern, _keep_first_char)
space_after_punct = partial(re.sub, space_after_punct_pattern, _append_space)
remove_space_before_punct = partial(re.sub, remove_space_before_punct_pattern, _strip_whitespace, flags=re.MULTILINE)
m_dash_with_space = partial(re.sub, m_dash_with_space_pattern, _pad_with_spaces)
fix_hyphen = partial(regex.sub, fix_hyphen_pattern, '-', flags=re.IGNORECASE)
replace_hyphen_with_dash = partial(re.sub, replace_hyphen_with_dash_pattern, '—', flags=re.MULTILINE)
replace_hyphen_with_dash_second_step = partial(re.sub, replace_hyphen_with_dash_second_step_pattern, ' — ', flags=re.MULTILINE)

In [None]:
def cleanup_taiga_sample(samples):
    """
    Clean a batch of raw texts for `Dataset.map(..., batched=True)`.

    Every text in samples["text"] is passed through the cleanup steps in a fixed order,
    stripped, and gets its first character capitalized (empty results stay empty).

    Returns a dict with a single key, "text_cleaned", holding the cleaned texts in input order.
    """
    # Order matters: each step works on the output of the previous one
    cleanup_steps = (
        programming_quotes,
        fake_dash,
        broken_manydots,
        # replace_number is skipped, for now
        cyrillic_only,
        m_dash_with_space,
        fix_hyphen,
        replace_hyphen_with_dash,
        remove_double_space,
        space_after_punct,
        replace_hyphen_with_dash_second_step,
        garbage_punctuation,
        remove_space_before_punct,
        double_punct,
    )

    cleaned_texts = []

    for text in samples["text"]:
        for step in cleanup_steps:
            text = step(text)

        text = text.strip()

        if text:
            text = text[0].capitalize() + text[1:]

        cleaned_texts.append(text)

    return {"text_cleaned": cleaned_texts}

In [None]:
# Run the cleanup over the whole corpus in batches. All original columns are dropped,
# so the result only carries the new `text_cleaned` column.
taiga_proza_russian = taiga_proza_russian.map(
    cleanup_taiga_sample, 
    batched=True,
    remove_columns=taiga_proza_russian.column_names
)
# Display the resulting dataset (the original referenced the misspelled name
# `taiga_proza_russia`, which raises NameError)
taiga_proza_russian

In [None]:
from datasets import load_from_disk

# Save & load to disk. Will save up some RAM
taiga_proza_russian.save_to_disk(str(datasets_path/'taiga-corpus-full-russian-cleaned'))
taiga_proza_russian = load_from_disk(str(datasets_path/'taiga-corpus-full-russian-cleaned'))

## Preprocessing

#### Tokenization & arabization

In [None]:
import nltk
nltk.download('punkt_tab')

from nltk.tokenize import sent_tokenize, word_tokenize

NEWLINE_TAG = '<|newline|>' # later, will break poem to separate lines

def sentence_and_word_tokenization(sample: str, language: str = 'russian') -> list[list[str]]:
    """
    Split a text into sentences and every sentence into word tokens.

    Parameters
    ----------
    sample : str
        Text to tokenize.
    language : str
        Name of the NLTK Punkt model used for sentence splitting. Defaults to 'russian';
        NLTK itself defaults to 'english', which is the wrong model for this corpus.

    Returns
    -------
    list[list[str]]
        One inner list of whitespace-stripped, non-empty tokens per sentence.
        Sentences that end up with no tokens are left out.
    """
    tokenized_sentences: list[list[str]] = []

    for sentence in sent_tokenize(sample, language=language):
        # Strip surrounding whitespace and discard tokens that become empty
        words = [word.strip() for word in word_tokenize(sentence, language=language)]
        words = [word for word in words if word]

        # Skip sentences without any remaining token
        if words:
            tokenized_sentences.append(words)

    return tokenized_sentences

def arabize_nested_list(nested_list: list[list[str]]) -> list[list[str]]:
    """
    Reverse the word order inside every sentence while keeping the sentence order.

    Returns new lists; the input is not modified.
    """
    reversed_sentences: list[list[str]] = []
    for words in nested_list:
        reversed_sentences.append(list(reversed(words)))
    return reversed_sentences

def apply_nl_tag(nested_list: list[list[str]]) -> list[list[str]]:
    """
    Insert a one-element [NEWLINE_TAG] list after every sentence.

    [[w1, w2], [w3]] -> [[w1, w2], [NEWLINE_TAG], [w3], [NEWLINE_TAG]]
    """
    return [
        part
        for sentence in nested_list
        for part in (sentence, [NEWLINE_TAG])
    ]

Get original sample.

In [None]:
sample = taiga_proza_russian[0]['text_cleaned']; sample[:512]

Split it into sentences and words.

In [None]:
sentence_and_word_tokenization(sample)[:1]

Arabize the words in each sentence: keep the order of the sentences, but reverse the word order inside each sentence.

In [None]:
arabize_nested_list(sentence_and_word_tokenization(sample))[:1]

Apply a special tag after each sentence to mark where the sentence ends.

```python
'<|newline|>'
```

In [None]:
apply_nl_tag(arabize_nested_list(sentence_and_word_tokenization(sample)))

#### Split into syllables & add stresses

In [None]:
import itertools

from rusyllab import rusyllab
from stressrnn import StressRNN

In [None]:
stress_rnn = StressRNN()

In [None]:
def cumulative_sum(numbers):
    """Return the running totals of `numbers` as a list, e.g. [2, 2, 4] -> [2, 4, 8]."""
    running_totals = []
    for value in numbers:
        if running_totals:
            running_totals.append(running_totals[-1] + value)
        else:
            running_totals.append(value)
    return running_totals

# Lookup table of regular cyrillic vowels with stressed variants
# Layout: the first two rows are keyed by a vowel followed by the ' stress symbol,
# the last two rows by the bare vowel; upper-case and lower-case forms are both listed.
# NOTE(review): the code visible in this notebook only looks up bare-vowel keys —
# confirm whether the apostrophe-suffixed keys are still needed.
vowels_mappings = {
"А'": "А́", "Е'": "Е́", "Ё'": "Ё́", "И'": "И́", "Й'": "Й́", "О'": "О́", "У'": "У́", "Ы'": "Ы́", "Э'": "Э́", "Ю'": "Ю́", "Я'": "Я́",
"а'": "а́", "е'": "е́", "ё'": "ё́", "и'": "и́", "й'": "й́", "о'": "о́", "у'": "у́", "ы'": "ы́", "э'": "э́", "ю'": "ю́", "я'": "я́",   
"А": "А́", "Е": "Е́", "Ё": "Ё́", "И": "И́", "Й": "Й́", "О": "О́", "У": "У́", "Ы": "Ы́", "Э": "Э́", "Ю": "Ю́", "Я": "Я́",   
"а": "а́", "е": "е́", "ё": "ё́", "и": "и́", "й": "й́", "о": "о́", "у": "у́", "ы": "ы́", "э": "э́", "ю": "ю́", "я": "я́",
}

# Lower-case characters treated as vowels when scanning a token (note that 'й' is included)
vowels = "аеёийоуыэюя"

In [None]:
def split_into_syllables_add_stresses(tokens: list[str]) -> list[list[str]]:
    """
    Split every token into syllables, mark the stressed vowel and reverse the syllable order.

    ['Парадокс'] -> [['до́кс', 'ра', 'Па']]

    Parameters
    ----------
    tokens : list[str]
        List of tokens, where each token is a single word.
        For example, ['Глава', 'первая', 'Михаил', 'Иванович', ...]

    Returns
    -------
    list[list[str]]
        One inner list per token holding its syllables in reversed order. When a stress
        was found, the stressed vowel is replaced by its accented variant from
        `vowels_mappings`. Tokens without a detected stress are reversed as well, so every
        word in the output follows the same right-to-left syllable order (previously such
        tokens were left in forward order).
    """
    preprocessed_tokens: list[list[str]] = []

    for token in tokens:

        # IMPORTANT
        # Clean empty spaces
        token = token.strip().replace(' ', '')

        # 01. Break word into syllables: 'Парадокс' -> ['Па', 'ра', 'докс']
        syllables: list[str] = rusyllab.split_word(token)

        # 02. Find stress in token
        #     Result is the same token with a ' char after the stressed vowel
        #     "Глава" -> "Глава'"
        # NOTE(review): the index arithmetic below presumes put_stress only inserts
        # stress symbols and leaves the other characters in place — confirm for tokens
        # that receive more than one stress mark.
        stressed_token = stress_rnn.put_stress(
            token,
            stress_symbol="'",
            accuracy_threshold=0.5,
            replace_similar_symbols=True
        )

        # 03. Find index of the stressed vowel in the original token.
        #     The slice never raises, unlike stressed_token[i + 1].
        #     No break on purpose: the last matching vowel wins, as before.
        stress_vowel_index = None
        for i, char in enumerate(token):
            if char.lower() in vowels and stressed_token[i + 1 : i + 2] == "'":
                stress_vowel_index = i

        # 04. Replace the regular vowel with the stressed one inside the syllable
        #     that contains it: ['Па', 'ра', 'докс'] -> ['Па', 'ра', 'до́кс']
        if stress_vowel_index is not None:
            stress_vowel_char = token[stress_vowel_index]

            # Number of characters covered by the syllables seen so far
            consumed_chars = 0
            for position, syllable in enumerate(syllables):
                consumed_chars += len(syllable)

                # The first syllable whose end lies beyond the index holds the stressed vowel
                if stress_vowel_index < consumed_chars:
                    syllables[position] = syllable.replace(
                        stress_vowel_char, vowels_mappings[stress_vowel_char]
                    )
                    break

        # 05. Reverse syllables ['Па', 'ра', 'до́кс'] -> ['до́кс', 'ра', 'Па']
        preprocessed_tokens.append(syllables[::-1])

    return preprocessed_tokens

Add stresses to text.

In [None]:
split_into_syllables_add_stresses(apply_nl_tag(arabize_nested_list(sentence_and_word_tokenization(sample)))[0])

#### Tokenize & arabize whole dataset

In [None]:
def preprocess_sample(sample):
    """
    Build the `tokens` column for one dataset row.

    The cleaned text is split into sentences and words, the words of every sentence
    are reversed, and a newline-tag entry is added after each sentence.
    """
    return {
        "tokens": apply_nl_tag(
            arabize_nested_list(
                sentence_and_word_tokenization(sample['text_cleaned'])
            )
        )
    }

Apply to whole dataset.

In [None]:
taiga_proza_preprocessed = taiga_proza_russian.map(
    preprocess_sample, batched=False
)

taiga_proza_preprocessed

Check results

In [None]:
taiga_proza_preprocessed[0]['tokens']

Save to disk & load back. This reduces RAM usage in later steps.

In [None]:
# Save and reload from the SAME folder. The original saved to '...-cleaned-preprocessed'
# but loaded '...-cleaned-quarter-preprocessed', which either fails or silently picks up
# a stale dataset from an earlier run.
preprocessed_path = str(datasets_path/'taiga-corpus-full-russian-cleaned-preprocessed')
taiga_proza_preprocessed.save_to_disk(preprocessed_path)
taiga_proza_preprocessed = load_from_disk(preprocessed_path)

#### Break tokens into syllables & add stresses to whole dataset

In [None]:
def syllabize_and_stress_words (tokens: list[str]) -> str:
    """
    Turn the words of one sentence into a single syllable string.

    Syllables of a word are separated by a space, words by ' | '.
    """
    # Each entry is one word rendered as space-separated syllables
    words_as_syllables = (
        ' '.join(word_syllables)
        for word_syllables in split_into_syllables_add_stresses(tokens)
    )

    return ' | '.join(words_as_syllables)

def syllabize_and_stress_sample(sample):
    """
    Build the `syllables` column for one dataset row.

    Newline-tag entries are copied through unchanged, every other entry is syllabized
    and stressed; the pieces are joined with single spaces.
    """
    pieces = []
    for tokens in sample['tokens']:
        if tokens == [NEWLINE_TAG]:
            # Tag entries are kept verbatim
            pieces.append(''.join(tokens))
        else:
            pieces.append(syllabize_and_stress_words(tokens))

    return {"syllables": ' '.join(pieces)}

In [None]:
syllabize_and_stress_sample(taiga_proza_preprocessed[0])['syllables']

Preprocess whole dataset

In [None]:
taiga_proza_stressed = taiga_proza_preprocessed.map(
    syllabize_and_stress_sample, batched=False
)
taiga_proza_stressed

Add a quatrain tag after every 4th newline tag to mark the end of a quatrain. Later this will help the model write poetry in quatrains.

In [None]:
import itertools
from functools import partial

QUATRAIN_TAG = '<|quatrain|>'

def apply_quatrain_tag(text: str) -> str:
    """
    Append QUATRAIN_TAG to every 4th '<|newline|>' tag of `text`.

    The counter is created per call, so each text is numbered from its own first
    newline tag. The previous version kept the counter in a lambda default argument,
    which is created once and shared by every call: tag positions then depended on
    how many texts had been processed before.
    """
    newline_counter = itertools.count(start=1)

    def _tag_every_fourth(match):
        newline_tag = match.group()
        # next(...) % 4 is 0 (falsy) exactly on the 4th, 8th, ... newline tag
        return newline_tag if next(newline_counter) % 4 else newline_tag + QUATRAIN_TAG

    # Raw string: `\|` is not a valid escape in a normal string literal
    return re.sub(r'<\|newline\|>', _tag_every_fourth, text)

apply_quatrain_tag(taiga_proza_stressed[0]['syllables'])

In [None]:
def apply_quatrain_sample(sample):
    """Build the `quatrained` column: the row's syllable string with quatrain tags added."""
    quatrained_text = apply_quatrain_tag(sample['syllables'])
    return {"quatrained": quatrained_text}

taiga_proza_stressed = taiga_proza_stressed.map(
    apply_quatrain_sample, batched=False
)
taiga_proza_stressed

In [None]:
taiga_proza_stressed.save_to_disk(str(datasets_path/'taiga-corpus-full-russian-cleaned-quarter-preprocessed-stressed'))

## Create Tokenizer From Scratch

https://huggingface.co/learn/nlp-course/chapter6/8?fw=pt

[add_special_tokens.example. Add special tokens to already pretrained model or tokenizer](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.add_special_tokens.example)

#### Create BPE Tokenizer

GPT-2 compatible.

In [None]:
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

tokenizer = Tokenizer(models.BPE())

In [None]:
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

In [None]:
tokenizer.pre_tokenizer.pre_tokenize_str(taiga_proza_stressed[0]['quatrained'])

In [None]:
def get_training_corpus(num_batches = 10):
    """
    Return a generator over the `quatrained` texts of `taiga_proza_stressed`.

    Each item is a list of up to `num_batches` consecutive texts (the parameter is the
    size of one batch). The generator is single-use.
    """
    batch_starts = range(0, len(taiga_proza_stressed), num_batches)
    return (
        taiga_proza_stressed[start : start + num_batches]["quatrained"]
        for start in batch_starts
    )

training_corpus = get_training_corpus()
training_corpus

#### Setup BPE Tokenizer trainer

https://huggingface.co/docs/tokenizers/api/trainers#tokenizers.trainers.BpeTrainer

Include stressed and regular vowels as special tokens, so that each of them is treated as a single letter.

In [None]:
# Trainer configuration for the BPE tokenizer.
# special_tokens lists, in order: the end-of-text marker, accented vowels (upper/lower),
# plain vowels (upper/lower), the newline and quatrain tags, and punctuation marks.
# NOTE(review): plain vowels are registered as special tokens too, so presumably no BPE
# merge can span a vowel — confirm this is the intended vocabulary design.
trainer = trainers.BpeTrainer(
    vocab_size=52000,
    min_frequency=2,
    special_tokens=[
        "<|endoftext|>",
        "А́", "Е́", "Ё́", "И́", "Й́", "О́", "У́", "Ы́", "Э́", "Ю́", "Я́",
        "а́", "е́", "ё́", "и́", "й́", "о́", "у́", "ы́", "э́", "ю́", "я́",
        "А", "Е", "Ё", "И", "Й", "О", "У", "Ы", "Э", "Ю", "Я",
        "а", "е", "ё", "и", "й", "о", "у", "ы", "э", "ю", "я",
        NEWLINE_TAG,
        QUATRAIN_TAG,
        "-", "—", ".", ",", ":", ";", "!", "?", "…", "|",
        # Guillemets are deliberately left out of the special tokens
        #"«", "»", # not working as expected
        "„", "“"
    ],
    show_progress=True
)

Train BPE Tokenizer

In [None]:
# Build a fresh generator right here: the `training_corpus` object created in an earlier
# cell is single-use, so re-running this cell after it was consumed would train on nothing.
tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)

See results

In [None]:
encoding = tokenizer.encode(taiga_proza_stressed[0]['quatrained'][:512])

print(encoding.tokens)

Init ByteLevel decoder

In [None]:
from tokenizers.decoders import ByteLevel

decoder = ByteLevel()

In [None]:
print(decoder.decode(tokenizer.encode(taiga_proza_stressed[0]['quatrained'][:512]).tokens))

Check special stressed characters

In [None]:
print(tokenizer.encode("а́").tokens)

In [None]:
print(decoder.decode(['а́']))

In [None]:
print(decoder.decode(['ĠÐĵÐ»']))

Apply byte-level postprocessing

In [None]:
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

In [None]:
tokenizer.decoder = decoders.ByteLevel()

Check how it works. Show special tokens, because they are important part of the training.

In [None]:
print(tokenizer.decode(tokenizer.encode(taiga_proza_stressed[0]['quatrained'][:512]).ids, skip_special_tokens=False))

In [None]:
tokenizer.encode('« »').ids

In [None]:
# Round-trip the same string that was encoded in the previous cell. The original decoded
# the hard-coded ids [102, 82, 357], which are only valid for one particular trained
# vocabulary and change whenever the tokenizer is retrained.
tokenizer.decode(tokenizer.encode('« »').ids, skip_special_tokens=False)

Wrap pretrained tokenizer into GPT2 Tokenizer

In [None]:
from transformers import GPT2TokenizerFast

wrapped_tokenizer = GPT2TokenizerFast(tokenizer_object=tokenizer)
wrapped_tokenizer

In [None]:
wrapped_tokenizer.save_pretrained("tokenizers/poetry-generator-punctuation-quatrained")

## Train/valid/test splits

#### Split the data, remove unnecessary columns

In [None]:
from datasets import DatasetDict

# Setup split seed for reproducibility
SPLIT_SEED=42

# Drop the intermediate columns; the quatrained syllable string becomes `text`
taiga_proza_stressed = (
    taiga_proza_stressed
    .remove_columns(['text_cleaned', 'tokens', 'syllables'])
    .rename_column("quatrained", "text")
)

# Step 1: hold out 20% of the rows
train_test_dataset = taiga_proza_stressed.train_test_split(test_size=0.2, seed=SPLIT_SEED)

# Step 2: cut the held-out part in half -> validation and test
test_valid_dataset = train_test_dataset['test'].train_test_split(test_size=0.5, seed=SPLIT_SEED)

# Final layout: train / test / valid
raw_datasets = DatasetDict({
    "train": train_test_dataset['train'],
    "test": test_valid_dataset['test'],
    "valid": test_valid_dataset['train'],
})
raw_datasets

In [None]:
raw_datasets['train'][0]['text']

Save to disk & load back. This reduces RAM usage later, when tokenization is applied.

In [None]:
raw_datasets.save_to_disk(str(datasets_path/'taiga-corpus-full-russian-cleaned-quarter-preprocessed-stressed_train_test_valid'))
raw_datasets = load_from_disk(str(datasets_path/'taiga-corpus-full-russian-cleaned-quarter-preprocessed-stressed_train_test_valid'))

## Tokenization

#### Load pretrained tokenizer

In [None]:
from transformers import AutoTokenizer

context_length = 512 # we will train it on 4090 RTX
tokenizer = AutoTokenizer.from_pretrained("tokenizers/poetry-generator-punctuation-quatrained")

outputs = tokenizer(
    raw_datasets["train"][:2]["text"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

#### Tokenize whole dataset

In [None]:
def tokenize(element):
    """
    Tokenize a batch of texts into chunks of exactly `context_length` tokens.

    Long texts are cut into several chunks (overflowing tokens are returned as extra
    chunks); every chunk whose length differs from `context_length` is discarded.

    Returns a dict with a single key, "input_ids", holding the kept chunks.
    """
    encoded = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )

    # Keep full-size chunks only
    full_chunks = [
        chunk_ids
        for chunk_length, chunk_ids in zip(encoded["length"], encoded["input_ids"])
        if chunk_length == context_length
    ]
    return {"input_ids": full_chunks}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)
tokenized_datasets

#### Save to disk

In [None]:
tokenized_datasets.save_to_disk(str(datasets_path/'tokenized/taiga-corpus-full-russian-cleaned-quarter-preprocessed-stressed_train_test_valid'))

<hr/>

## Pretrain GPT-2 on Preprocessed Taiga Proza

#### See pretraining script

<hr/>

## Metrics. Perplexity

https://huggingface.co/docs/transformers/perplexity#example-calculating-perplexity-with-gpt-2-in--transformers

#### Do evaluation

Init model & tokenizer

In [None]:
from transformers import GPT2LMHeadModel, AutoTokenizer, GenerationConfig

import torch

# Put here correct paths to the models and tokenizer of course
tokenizer = AutoTokenizer.from_pretrained("tokenizers/poetry-generator-lemmatized-arabized-syllabized-stressed-ver2")
model = GPT2LMHeadModel.from_pretrained("models/taiga_proza_5GB_216750_syllables_stresses_arabized_ver2")

# `device` was never defined in this notebook, which made the next line raise NameError.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device) # important!

Get tensors. Take 500 samples, for faster inference.

In [None]:
encodings = tokenizer(' '.join(raw_datasets['test'][:500]["text"]), return_tensors="pt")

Compute perplexity.

In [None]:
import torch
from tqdm import tqdm

# Sliding-window perplexity over one long token sequence.
max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

# Accumulate the SUM of token-level losses and the number of scored tokens.
# Averaging the per-window mean losses directly would give a short last window
# the same weight as a full one.
nll_sum = 0.0
n_tokens = 0
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    # Follow the model's own device instead of relying on a global `device` variable
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
    target_ids = input_ids.clone()
    # Tokens already scored by a previous window only serve as context
    target_ids[:, :-trg_len] = -100

    # The model shifts the labels left by one internally, so the label at position 0
    # never contributes to the loss.
    num_loss_tokens = (target_ids[:, 1:] != -100).sum().item()

    # A window without any scored token would yield a NaN loss: skip it
    if num_loss_tokens > 0:
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

        # outputs.loss is the mean over the scored tokens of this window;
        # multiply back to obtain the window's summed negative log-likelihood
        nll_sum += outputs.loss.item() * num_loss_tokens
        n_tokens += num_loss_tokens

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

# Perplexity = exp(mean negative log-likelihood per scored token)
ppl = torch.exp(torch.tensor(nll_sum / n_tokens))
ppl

## Fine-Tune GPT-2 on Russian Poetry Corpus

#### See corresponding notebook