l4rz/gpt-2-training

Training GPT-2 on a Russian language corpus

Disclaimer: Neither I nor this repo is associated in any way with OpenAI. I did my own research to the best of my ability; nevertheless, I might be completely wrong about anything expressed below.

TL;DR

I've trained a large GPT-2 (1.25B parameters) on a pretty diverse Russian press corpus (~4Gb), achieved a training loss of 2.42 and liked the results. The trained model is available for download.

Sample, 143k steps

Table of Contents

  1. Quick start
  2. Training environment
  3. Dataset preparation
  4. Experiments
  5. Downloads

1. Quick start

  1. clone the nshepperd repo
  2. comment out the if layer == 10: line in model.py for checkpointing to work properly (to save memory)
  3. install Google SentencePiece
  4. use src/encoder_sp.py from this repo (copy to src/ directory)
  5. replace all relevant import encoder with import encoder_sp as encoder in relevant files (encode.py and sampling scripts)
  6. train the sp tokenizer model using your dataset
  spm_train --character_coverage=1  --model_prefix=sp \
      --vocab_size=50257 --model_type=bpe \
      --user_defined_symbols '<|n|>,<|endoftext|>' \
      --max_sentence_length=32768 \
      --input_sentence_size=10000000 \
      --input dataset.txt
  7. copy sp.* files to the model directory
  8. encode the dataset using the trained sp tokenizer model
  mkdir /tmp/spencode
  spm_encode --model="models/1558M/sp.model" \
      --output_format=id < dataset.txt | \
      split --lines=100000 --additional-suffix=.ids \
      - /tmp/spencode/part$(printf %05d $i)
  PYTHONPATH=src ./encode.py --model_name="1558M" \
      /tmp/spencode/ dataset.npz
  rm -rf /tmp/spencode
  9. put the proper hparams.json in the model directory (since we are training the model from scratch, we do not need checkpoint files, etc)
  10. initialize the model by running sess.run(tf.global_variables_initializer()) instead of saver.restore(sess, ckpt) (sorry, I haven't bothered to make a separate script)
  11. proceed with warmup and training
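
For the hparams.json step, a minimal sketch of generating the file from Python is shown below. The helper name write_hparams is hypothetical, and the default n_vocab=50257 is assumed to match the --vocab_size passed to spm_train above; adjust both to your own setup.

```python
import json
import os


def write_hparams(model_dir, n_embd, n_head, n_layer, n_vocab=50257, n_ctx=1024):
    """Write a GPT-2 style hparams.json into model_dir and return its path.

    n_vocab is assumed to equal the SentencePiece vocabulary size.
    """
    # GPT-2 splits the embedding evenly across attention heads.
    if n_embd % n_head != 0:
        raise ValueError("n_embd must be divisible by n_head")
    hparams = {
        "n_vocab": n_vocab,
        "n_ctx": n_ctx,
        "n_embd": n_embd,
        "n_head": n_head,
        "n_layer": n_layer,
    }
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, "hparams.json")
    with open(path, "w") as f:
        json.dump(hparams, f, indent=2)
    return path


# Example (hypothetical directory name):
# write_hparams("models/1250M", n_embd=2000, n_head=16, n_layer=24)
```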

2. Training environment

I used the nshepperd implementation of the training script. Cybertronai gradient-checkpointing was used in order to save GPU RAM.

I've employed the Google SentencePiece tokenizer (it's pointless to try using the original vocab.bpe since it's not aware of Cyrillic subwords).

For multi-GPU distributed training, I've utilized the Horovod framework. Using Horovod is basically as easy as wrapping your optimizer in hvd.DistributedOptimizer, as well as adding proper initialization and variable broadcast calls.

import horovod.tensorflow as hvd
...
hvd.init()
...
config.gpu_options.visible_device_list = str(hvd.local_rank())
...
bcast = hvd.broadcast_global_variables(0)
...
opt = tf.train.AdamOptimizer(decayed_lr)
opt = hvd.DistributedOptimizer(opt)

To run a Horovod cluster with two or more servers, configure ssh so that the user running the training script can log in to the other servers without a password. To minimize latency, the servers should be interconnected with a 1Gbps or faster link.

Training large models (>1B parameters) involves quite a long initialization, so I found it beneficial to increase the startup timeout in horovodrun:

horovodrun --start-timeout 600 \
-np 4 -H localhost:4 python3 train-1250M.py --dataset dataset.npz

Running four workers on a single machine utilized almost 200Gb of DRAM with large models.

3. Dataset preparation

Getting a large enough corpus of Russian text is quite simple; for example, there is a 568Gb one on OSCAR. However, corpora like this are unsuitable for training unsupervised language models in real life because of their quality. One needs a fairly clean collection of quality articles. While preparing the WebText dataset, OpenAI did a clever trick of outsourcing text cleaning to Reddit users.

I scraped a couple of Russian press sites, parsed HTML with beautifulsoup4 and saved parsed texts as well as metadata (headers, TL;DRs, timestamps) for further sorting and postprocessing in PKLs.

(Instead of scraping, one can use the Taiga [1] dataset. Unfortunately, I found it after I had already assembled my own.)

To begin with, I got rid of texts that had a significant percentage of non-Cyrillic characters. I also discarded texts with cruel and unusual formatting (tables, programming code), repetitive ones (stock market reports, weather, sports) and ones that were too boring (official documents). Tabs, spaces and dashes were normalized. Hashtags and weird glyphs were filtered out too. Texts shorter than 1024 bytes were discarded.

Text paragraphs were separated with newlines (\n) and <|n|> tokens. Each text fragment was suffixed with an <|endoftext|> token.
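
The filtering and formatting rules above can be sketched roughly as follows. This is not the actual cleaning script: the 0.9 Cyrillic threshold, the choice to normalize dashes to a plain hyphen, and the exact way <|n|> and \n are combined are all assumptions made for illustration, and every function name is hypothetical.

```python
import re

CYRILLIC = re.compile(r"[\u0400-\u04FF]")
LETTER = re.compile(r"[^\W\d_]")  # any Unicode letter


def cyrillic_share(text):
    """Fraction of letters in text that are Cyrillic (0.0 if no letters)."""
    letters = LETTER.findall(text)
    if not letters:
        return 0.0
    return sum(1 for ch in letters if CYRILLIC.match(ch)) / len(letters)


def keep_text(text, min_bytes=1024, min_cyrillic=0.9):
    """Keep texts of at least min_bytes UTF-8 bytes that are mostly Cyrillic.

    min_cyrillic is an assumed threshold; the source only says
    "a significant percentage".
    """
    long_enough = len(text.encode("utf-8")) >= min_bytes
    return long_enough and cyrillic_share(text) >= min_cyrillic


def normalize(text):
    """Normalize tabs, dash variants and runs of spaces (assumed rules)."""
    text = text.replace("\t", " ")
    text = re.sub(r"[\u2012\u2013\u2014\u2015]", "-", text)
    return re.sub(r" {2,}", " ", text)


def format_document(paragraphs):
    """Join paragraphs with <|n|> plus newline and append <|endoftext|>."""
    body = "<|n|>\n".join(p.strip() for p in paragraphs if p.strip())
    return body + "<|endoftext|>\n"
```

Heuristics for tables, code, repetitive reports and hashtags are left out here, since the text does not describe how they were detected.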

Overall, a lot of effort has been put into cleaning the dataset. Having a strictly monolingual dataset is a particular privilege of English; modern Russian texts always include some percentage of Latin (English) proper nouns such as people's and companies' names, social media accounts, quotes, etc.

I ended up with two datasets, ~2Gb and ~4Gb. These were much smaller than the 40Gb WebText dataset; nevertheless, I considered them diverse enough. Moreover, they should have worked for my plan (overfit the model on a smaller dataset and then add more data).

After sentencepiece encoding, the ~2Gb dataset became ~211M tokens. This means that the compression ratio of bytes to BPE tokens is around 9:1, or 4.5:1 in characters, taking UTF-8 into account (most Cyrillic characters take two bytes). This ratio is much higher than that of the vocab.bpe used with the original GPT-2.
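
The arithmetic behind these ratios, as a quick sketch. It assumes "~2Gb" means 2 x 10^9 bytes and that the text is almost entirely two-byte Cyrillic characters, so the results are approximate.

```python
# Approximate figures taken from the text above.
dataset_bytes = 2 * 10**9   # "~2Gb", assumed to be decimal gigabytes
dataset_tokens = 211 * 10**6

bytes_per_token = dataset_bytes / dataset_tokens

# Assumption: nearly every character is Cyrillic, i.e. 2 bytes in UTF-8.
chars_per_token = bytes_per_token / 2

print(f"bytes per token: {bytes_per_token:.2f}")
print(f"chars per token: {chars_per_token:.2f}")
```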

Because I experimented with encoding using various sentencepiece models, I found it useful to add the last digits of md5sum of sp.model to the encoded datasets, snapshots and samples file names to avoid confusion.
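
A minimal sketch of that naming scheme follows. The number of digits kept (6) and the dash-separated file name layout are assumptions; the helper names are hypothetical.

```python
import hashlib
import os


def tag_for_bytes(data, digits=6):
    """Return the last `digits` hex characters of the MD5 of data."""
    return hashlib.md5(data).hexdigest()[-digits:]


def tag_for_file(path, digits=6):
    """Same as tag_for_bytes, but streams the file from disk."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()[-digits:]


def tagged_name(filename, tag):
    """Insert the tag before the extension: dataset.npz -> dataset-<tag>.npz."""
    stem, ext = os.path.splitext(filename)
    return f"{stem}-{tag}{ext}"


# Example (hypothetical paths):
# tagged_name("dataset.npz", tag_for_file("models/1558M/sp.model"))
```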

4. Experiments

Smaller models

I had previously tried to train GPT-2 on a Russian dataset when the 117M model was released. I only had a 1080ti at my disposal at that time, so I was training with small batch sizes. The most I was able to get was a loss of 3.00 after 20 epochs.

I've decided to run this experiment again, on Tesla V100s. I've settled on batch size of 24 per worker (192 combined for 8 GPUs).

The model was initialized from scratch and warmed up with LR=10^-8 for 1000 steps. The initial LR was 10^-4 for the first 10 epochs, then decreased to 10^-6.
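
One possible shape for such a schedule is sketched below. Only the three LR values and the 1000 warmup steps come from the text; the exponential decay curve and the hold_until and decay_steps values are hypothetical, since the text does not say how the LR was decreased or how many steps one epoch took.

```python
def lr_at(step, warmup_steps=1000, hold_until=50_000, decay_steps=70_000,
          warmup_lr=1e-8, peak_lr=1e-4, final_lr=1e-6):
    """Piecewise learning rate: flat warmup, flat peak, then decay.

    hold_until and decay_steps are placeholder values, and the
    exponential interpolation is an assumption.
    """
    if step < warmup_steps:
        return warmup_lr
    if step < hold_until:
        return peak_lr
    progress = min(1.0, (step - hold_until) / decay_steps)
    # Geometric interpolation between peak_lr and final_lr.
    return peak_lr * (final_lr / peak_lr) ** progress


for s in (0, 1_000, 50_000, 85_000, 120_000):
    print(s, lr_at(s))
```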

After 25 epochs and 123k steps on a 117M-sized model (the first 20 epochs took approximately 150 GPU-hours), I got a training loss of 2.86. The quality of the samples was far from desired. In addition to the usual GPT-2 glitches (such as repeated words in a sentence), the text was less coherent than that of the English 117M model released by OpenAI.

Reducing the dataset size (from 211M to 150M tokens) and filtering out the remaining English characters did not help much.

The 117M model as of my last run, trained for 123k steps, complete with the sp.vocab and dataset used in training, is available for download.

Larger models

I achieved similar results with larger models:

| Model size | Duration | Training loss |
| --- | --- | --- |
| 117M | 25 epochs, 123k steps | 2.86 |
| 345M | 13 epochs, 507k steps | 2.83 |
| 774M | 20 epochs, ??? steps | 2.90 |

It looked like I was always hitting this 2.80 floor; something was wrong.

Decreasing max token length in vocab

I noticed that the compression ratio (bytes to tokens) of Russian text encoded to BPE is 2-3 times higher than that of the original GPT-2 vocab.bpe. Observing that a Russian text snippet translated into English varies just 10-15 percent in length, I assumed that the text complexity per 1024 tokens would be much higher for Russian, and that this would lead to a more perplexed model if not addressed.

I tried to decrease the maximum length of the subword fragment by training the sentencepiece model with spm_train --max_sentencepiece_length 5. The 211M-token dataset thus became a 315M-token one. Training a model with this type of encoding produced far worse results, though the curve of training loss per epoch was much steeper and the final training loss was much lower than with the original sentencepiece model (just 2.02 after 4.5 epochs and 12,000 steps). The better the tokenizer performs, the worse the indicated perplexity of the model is.
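
One way to see why per-token loss is not comparable across tokenizers is to convert it to loss per byte of text. The sketch below does this with the figures quoted above. It assumes the ~2Gb dataset is 2 x 10^9 bytes, and it compares runs of very different length (25 epochs vs 4.5 epochs), so treat it as an illustration of the unit conversion rather than a fair benchmark.

```python
DATASET_BYTES = 2 * 10**9  # "~2Gb", assumed to be decimal gigabytes


def loss_per_byte(loss_per_token, dataset_tokens, dataset_bytes=DATASET_BYTES):
    """Convert a mean per-token loss (nats) into nats per byte of text."""
    bytes_per_token = dataset_bytes / dataset_tokens
    return loss_per_token / bytes_per_token


# Original tokenizer: 211M tokens, 117M model, training loss 2.86.
original = loss_per_byte(2.86, 211 * 10**6)
# Tokenizer with --max_sentencepiece_length 5: 315M tokens, training loss 2.02.
short_pieces = loss_per_byte(2.02, 315 * 10**6)

print(f"original tokenizer: {original:.3f} nats/byte")
print(f"short pieces:       {short_pieces:.3f} nats/byte")
```

Under these assumptions the much lower per-token loss of the short-piece tokenizer does not translate into a lower loss per byte.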

Addressing the language complexity

Russian grammar is rather complex. Word order in Russian is much more flexible compared to English. Russian uses cases, which means that every noun changes its ending, depending on what function it has in the sentence. Moreover, depending on the case, as well as singular or plural form, not only nouns are changing their endings but also adjectives, pronouns and some other parts of speech.

In order to address the complexity, I tried to increase the capacity of the model. The most logical way [2] seemed to be increasing the embedding size, the n_embd parameter that defines the size of both token embeddings (wte) and positional embeddings (wpe).

I wanted to increase n_embd (AKA D_model) beyond 1600, the value used in the 1558M model, but I learned that the number of parameters can quickly grow beyond 2B. A model with { "n_ctx": 1024, "n_embd": 2400, "n_head": 16, "n_layer": 24 } takes 6.8Gb on disk and becomes rather impractical to train on a 32Gb GPU. Therefore, I settled on { "n_ctx": 1024, "n_embd": 2000, "n_head": 16, "n_layer": 24 }. The number of layers and attention heads was the same as in the 345M model, but n_embd was greater than in the 1558M model.
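
To see how n_embd drives the model size, here is a sketch that counts the weights of a standard GPT-2 architecture (token and positional embeddings, per-layer attention and MLP weights with biases, layer norms, output layer tied to wte). The function name is hypothetical, and n_vocab=50257 is assumed from the tokenizer settings in the quick start.

```python
def gpt2_param_count(n_vocab, n_ctx, n_embd, n_layer):
    """Count trainable parameters of a GPT-2 style model.

    n_head does not affect the count: heads only split n_embd.
    """
    wte = n_vocab * n_embd            # token embeddings
    wpe = n_ctx * n_embd              # positional embeddings
    attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)
    mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    layer_norms = 2 * (2 * n_embd)    # ln_1 and ln_2, gain and bias each
    per_layer = attn + mlp + layer_norms
    ln_f = 2 * n_embd                 # final layer norm
    return wte + wpe + n_layer * per_layer + ln_f


for n_embd in (2000, 2400):
    n = gpt2_param_count(n_vocab=50257, n_ctx=1024, n_embd=n_embd, n_layer=24)
    print(f"n_embd={n_embd}: {n / 1e6:.0f}M parameters")
```

Since the per-layer term grows as 12 * n_embd^2, going from 2000 to 2400 adds roughly half a billion parameters at 24 layers.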

1250M model

The new model with D_model=2000 and approximately 1250M parameters was initialized and trained with the same 211M-token dataset.

Training log of 1250M model first training run

With 32Gb of VRAM, I was able to use a batch size of 16 per worker (128 combined for 8 GPUs). The initial LR was 10^-4. The complete training log, reflecting LR changes, is available here.

Starting from a training loss of 3.00 (~6 epochs), the samples began to demonstrate consistency and were generally better than in the previous run. I continued training for 310 wallclock hours (1800 accumulated GPU-hours) and 27 epochs, and reached a training loss of 2.54.

Unconditional generation samples

Conditional generation samples

I've also tested the model's ability to perform summarization on news articles from the web. Some percentage of news articles in the dataset were salted with ТЛ;ДР: followed by the article's summary or headline. Summarization behaviour was induced by top-k random sampling with k = 2 and by providing the model with a text to summarize followed by ТЛ;ДР: as conditional input.

interactive_conditional_samples.py --temperature=0.8 --top_k=2 --length=100
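
For readers unfamiliar with top-k sampling, here is a minimal, framework-free sketch of what --top_k=2 with --temperature=0.8 does to a single step's logits, plus the prompt construction. It is an illustration only, not the sampling code used by the scripts; the helper names and the exact whitespace before ТЛ;ДР: are assumptions.

```python
import math
import random


def top_k_sample(logits, k=2, temperature=0.8, rng=random):
    """Sample a token id from the k highest-scoring logits.

    All other tokens get zero probability; the kept logits are
    divided by temperature and passed through a softmax.
    """
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


def summarization_prompt(article):
    """Append the ТЛ;ДР: marker to the article text (assumed layout)."""
    return article.rstrip() + "\nТЛ;ДР:"
```

With k = 2 the model can only ever pick between its two most likely next tokens, which keeps the output close to greedy decoding while still allowing some variation.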

A couple of summarization results

Update

After continuing the training of this model with the 2Gb dataset replaced by the 4Gb one (415M tokens), a training loss of 2.42 was achieved after 80k more steps (+12.5 epochs, an additional 1600 GPU-hours).

Training log of the 2nd run

Update 2

In order to push things further, the 4Gb dataset (415M tokens) was augmented with 3Gb of filtered fanfics, becoming a 7Gb one (862M tokens). Training was continued for 140k more steps (+10.5 epochs, an additional 2800 GPU-hours), achieving a training loss of 2.64. The final LR was 9.75 x 10^-7.

Sample, 282k steps

Training log of the 3rd run

5. Downloads

Pre-trained models

NOTE: I apologize: while packaging some of the pre-trained models I forgot to include the checkpoint file. It's impractical for me to re-upload those rather huge files, so please create a checkpoint file in the model dir, e.g. for the 61k steps one:

cd model_dir
echo model_checkpoint_path: \"model-61000\" > checkpoint

  1. 117M model trained with 2Gb dataset, with sp vocab/model, 1.35Gb file

  2. 1250M model trained with 2Gb dataset, 61k steps, training loss 2.54, 4.69Gb file

  3. 1250M model trained with 4Gb dataset, from 61k steps to 143k steps, training loss 2.42, 4.69Gb file

  4. 1250M model trained with 7Gb dataset (press and fanfics), from 143k steps to 282k steps, training loss 2.64, 4.69Gb file

Written by

l4rz

Footnotes

  1. https://github.com/TatianaShavrina/taiga_site

  2. Other approaches have been tried but ultimately failed.
