#  aitextgen — Train a Custom GPT-2 Model + Tokenizer w/ GPU

by [Max Woolf](https://minimaxir.com)

*Last updated: May 16th, 2021 (aitextgen v0.5.2)*

Train a custom GPT-2 model **for free on a GPU using Colaboratory** with `aitextgen`!

It's recommended to create a model from scratch only if you really need to; otherwise, [finetuning the pretrained 124M GPT-2 model](https://colab.research.google.com/drive/15qBZx5y9rdaQSyWpsreMDnTiZ5IlN0zD?usp=sharing) may give you better results.

For more about `aitextgen`, you can visit [this GitHub repository](https://github.com/minimaxir/aitextgen) or [read the documentation](https://docs.aitextgen.io/).


To get started:

1. Copy this notebook to your Google Drive to keep it and save your changes. (File -> Save a Copy in Drive)
2. Run the cells below:


In [1]:
!pip install -q aitextgen

import logging
import random
from os import environ

# Show INFO-level log messages (e.g. model loading) in the cell output
logging.basicConfig(
    format="%(asctime)s — %(levelname)s — %(name)s — %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

from aitextgen import aitextgen
from aitextgen.colab import mount_gdrive, copy_file_from_gdrive
from aitextgen.TokenDataset import TokenDataset, merge_datasets
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import build_gpt2_config, GPT2ConfigCPU

In [2]:

# Pick one of the 23 CSV shards (texts_000001.csv to texts_000023.csv) at random
items = [f"{i:02d}" for i in range(1, 24)]
shard = random.choice(items)
file_name = f"../input/yhaoo-answers/New folder (2)/texts_0000{shard}.csv"

## GPU

Colaboratory uses an Nvidia P4, an Nvidia T4, or an Nvidia P100 GPU. For finetuning GPT-2 124M, any of these GPUs will be fine, but for text generation, a T4 or a P100 is ideal since they have more VRAM.

You can verify which GPU is active by running `!nvidia-smi` in a cell. If you want to try for a different GPU, go to **Runtime -> Factory Reset Runtime**.
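A minimal stdlib-only sketch of such a GPU check, in case you want it from Python. It assumes the `nvidia-smi` CLI is on the `PATH` of GPU runtimes and returns an empty list on a CPU-only runtime; the helper name `list_gpus` is hypothetical.

```python
import shutil
import subprocess
from typing import List


def list_gpus() -> List[str]:
    """Return the names of visible NVIDIA GPUs (hypothetical helper).

    Assumes the `nvidia-smi` CLI is installed on GPU runtimes; returns an
    empty list when it is missing or fails (e.g. on a CPU-only runtime).
    """
    if shutil.which("nvidia-smi") is None:
        return []
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        return []
    # nvidia-smi prints one GPU name per line with these flags
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


gpus = list_gpus()
print(gpus if gpus else "No NVIDIA GPU detected")
```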

## Training the Tokenizer

Now we can train a Byte-Pair Encoding (BPE) tokenizer on the dataset file selected above. The `train_tokenizer()` function wraps the training method of Hugging Face's `tokenizers` package.

After training completes, this saves one file, **aitextgen.tokenizer.json**, which is needed to rebuild the tokenizer.

In [3]:
train_tokenizer(file_name)
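To illustrate what BPE training does, here is a toy, pure-Python sketch of its core loop: start from single characters, count adjacent symbol pairs across the corpus, and repeatedly merge the most frequent pair into a new symbol. It is a simplified illustration, not the `tokenizers` implementation (which handles many more details, such as byte-level pre-tokenization and special tokens, and is far faster); the function name `learn_bpe_merges` is hypothetical.

```python
from collections import Counter
from typing import Dict, List, Tuple


def learn_bpe_merges(words: List[str], num_merges: int) -> List[Tuple[str, str]]:
    """Learn up to `num_merges` BPE merge rules from a list of words (toy sketch)."""
    # Each distinct word is a tuple of symbols (initially single characters) with a count
    vocab: Dict[Tuple[str, ...], int] = Counter(tuple(w) for w in words)
    merges: List[Tuple[str, str]] = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by how often each word occurs
        pair_counts: Counter = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += freq
        if not pair_counts:
            break  # every word is already a single symbol
        # Most frequent pair; ties go to the pair seen first
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Rewrite every word, replacing occurrences of `best` with the merged symbol
        new_vocab: Dict[Tuple[str, ...], int] = Counter()
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges


merges = learn_bpe_merges(["low", "low", "lower", "newest", "newest", "widest"], num_merges=3)
print(merges)  # [('l', 'o'), ('lo', 'w'), ('e', 's')]
```

Each learned merge becomes a new vocabulary entry, so the number of merges (plus the base symbols) determines the tokenizer's vocabulary size.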






## Specify a Model Configuration

You can use `build_gpt2_config` to specify a model configuration. You will most likely want to adjust `max_length` (context window size) and `n_embd` (embedding size); `vocab_size` should match the vocabulary size the tokenizer was trained with.

The config used here is adapted from the one used to build a [demo Reddit](https://github.com/minimaxir/aitextgen/blob/master/notebooks/reddit_demo.ipynb) model.

In [4]:
config = build_gpt2_config(
    vocab_size=50400,
    max_length=256,  # context window, stored as n_positions / n_ctx
    dropout=0.0,
    n_embd=256,
    n_layer=28,
    n_head=16,
    repetition_penalty=1,
)
config

GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.0,
  "bos_token_id": 0,
  "embd_pdrop": 0.0,
  "eos_token_id": 0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 256,
  "n_embd": 256,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 256,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.0,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.0,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.18.0",
  "use_cache": true,
  "vocab_size": 50400
}
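A quick way to sanity-check a configuration is to count its parameters. The sketch below assumes the standard GPT-2 layout of Hugging Face's `GPT2LMHeadModel` (token and position embeddings; `n_layer` blocks, each with a joint Q/K/V projection, an output projection, a 4×-wide MLP, biases, and two LayerNorms; a final LayerNorm; and an output head tied to the token embeddings), which gives `12 * n_embd**2 + 13 * n_embd` parameters per block. The helper name `gpt2_param_count` is hypothetical.

```python
def gpt2_param_count(vocab_size: int, n_positions: int, n_embd: int, n_layer: int) -> int:
    """Count parameters of a GPT-2 LM with tied input/output embeddings (n_inner at default)."""
    d = n_embd
    embeddings = vocab_size * d + n_positions * d  # wte + wpe
    per_block = (
        (d * 3 * d + 3 * d)    # attention c_attn: joint Q/K/V projection
        + (d * d + d)          # attention c_proj
        + (d * 4 * d + 4 * d)  # MLP c_fc (inner size 4 * n_embd)
        + (4 * d * d + d)      # MLP c_proj
        + 2 * (2 * d)          # ln_1 and ln_2 (weight + bias each)
    )
    final_ln = 2 * d           # ln_f
    return embeddings + n_layer * per_block + final_ln


# Config from the cell above: vocab_size=50400, max_length=256, n_embd=256, n_layer=28
print(gpt2_param_count(50400, 256, 256, 28))   # 35081728
# Sanity check against the released GPT-2 "124M" model
print(gpt2_param_count(50257, 1024, 768, 12))  # 124439808
```

With this notebook's settings, about 12.9M of the roughly 35.1M parameters (around 37%) sit in the token embedding alone, so `vocab_size` has a large effect on the size of a small model like this one.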

## Instantiating Your Custom GPT-2 Model

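Pass the config and tokenizer to `aitextgen()` and you're good to go! Here, `model_folder` also points to a previously trained model, so its weights are loaded instead of being initialized from scratch.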

In [5]:
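# NOTE: this replaces the config built above with aitextgen's GPT2ConfigCPU preset
config = GPT2ConfigCPU()
ai = aitextgen(config=config,
               tokenizer_file="../input/output1/aitextgen.tokenizer.json",  # tokenizer from a previous run
               to_gpu=True,
               model_folder="../input/output1/trained_model")  # previously trained weights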

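A freshly initialized model generates effectively random output. Because this model was loaded from previously trained weights, the samples below are already partly coherent.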

In [6]:
ai.generate(5)

why are sodium?  the egg is the same as the body. the body is different from the water in the upper function. it is the eggs to loose the sun.
What do you think of the songs?  Should you try to read Yahoo! Answers? I'm a little bitch of Jesus with the rest of the songs for the subject. I'm a group that is
Aren't any BSAssets on the Web?  Many of these are probably notable but the Internet can both be removed. It's not as many sites that are on the holidays.
Are there any currently towns waiting for the best law?  Because they are also going to be available (2005) and they are striking a lot of families.
why does it simply make a walk?  At the time it has to be a maissionary you will think if you are going to be sold to pull out of your sister and you will see it.


## Train GPT-2

The next cell starts the actual training of GPT-2 in aitextgen. It runs for `num_steps` steps, and a progress bar shows training progress, the current loss (the lower, the better the model), and the average loss (to give a sense of the loss trajectory).

By default, the model is saved to the `trained_model` folder every `save_every` steps and again when training completes. If you mounted your Google Drive and set `save_gdrive=True`, the model is _also_ saved there in a unique folder.

Training may time out after about 4 hours; if you did not mount Google Drive, make sure you end training and save the results so you don't lose them! (If this happens frequently, you may want to consider [Colab Pro](https://colab.research.google.com/signup).)

Important parameters for `train()`:

- **`line_by_line`**: Set this to `True` if the input text file is a single-column CSV with one record per row; aitextgen will then treat each row as a separate training example.
- **`from_cache`**: If you previously compressed your dataset into a cache file with `TokenDataset` and are training on that cache file, set this to `True`.
- **`num_steps`**: Number of steps to train the model for.
- **`generate_every`**: Interval of steps to generate example text from the model; good for qualitatively validating training.
- **`save_every`**: Interval of steps to save the model: the model will be saved in the VM to `trained_model`.
- **`save_gdrive`**: Set this to `True` to copy the model to a unique folder in your Google Drive, if you have mounted it (e.g. with `mount_gdrive()`).
- **`batch_size`**: Batch size for training; setting it too high will cause the GPU to run out of memory (OOM). _Unlike finetuning, since you are using a small model, you can massively increase the batch size to smooth out training_.
- **`fp16`**: Enables half-precision training for faster/more memory-efficient training. Only works on a T4 or V100 GPU.
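With `line_by_line=True`, `num_steps` and `batch_size` together determine roughly how many passes (epochs) the model makes over the data. A minimal sketch, assuming one training example per CSV row and a single header row (both assumptions; the helper names `count_records` and `estimated_epochs` are hypothetical):

```python
import csv
import io
from typing import Iterable


def count_records(lines: Iterable[str], has_header: bool = True) -> int:
    """Count rows in a single-column CSV; quoted fields may span several lines."""
    rows = sum(1 for _ in csv.reader(lines))
    # Assumption: the first row is a column header, not a training example
    return rows - 1 if has_header and rows > 0 else rows


def estimated_epochs(num_records: int, num_steps: int, batch_size: int) -> float:
    """Approximate passes over the data: examples seen / examples available."""
    return num_steps * batch_size / num_records


# Tiny in-memory example: a header plus three records (the last spans two lines)
sample_csv = 'text\n"first record"\n"second, with a comma"\n"multi\nline record"\n'
n = count_records(io.StringIO(sample_csv))
print(n, estimated_epochs(n, num_steps=6, batch_size=2))  # 3 4.0
```

On the real file, open it with `newline=""` (as the `csv` module documentation recommends) and pass the file object. For reference, the training cell in this notebook (`num_steps=80000`, `batch_size=2785`) processes 80000 × 2785 ≈ 222.8 million examples.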

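Here are other parameters for `train()` that are useful but that you likely do not need to change:

- **`learning_rate`**: Learning rate of the model training.
- **`num_workers`**: Number of worker processes used to load the training data.
- **`max_grad_norm`**: Maximum gradient norm; larger gradients are clipped to this value to stabilize training.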
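The training cell below sets `max_grad_norm=1`, which we take to be a gradient-clipping threshold. As a sketch of the general technique (not aitextgen's code), here is clipping by global L2 norm, the same idea behind PyTorch's `torch.nn.utils.clip_grad_norm_`, on plain Python lists standing in for gradient tensors. The helper name `clip_by_global_norm` is hypothetical.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> List[List[float]]:
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    # Global norm: sqrt of the sum of squares over every element of every tensor
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm <= max_norm:
        return [list(tensor) for tensor in grads]  # within the limit: unchanged copy
    # One shared scale factor preserves the direction of the overall gradient
    scale = max_norm / total_norm
    return [[g * scale for g in tensor] for tensor in grads]


# Two "tensors" with global norm sqrt(3^2 + 4^2) = 5, clipped to 2.5 (scale 0.5)
print(clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=2.5))  # [[1.5, 0.0], [0.0, 2.0]]
```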


In [7]:

# Silence Hugging Face tokenizers' parallelism warning when DataLoader workers are forked
environ["TOKENIZERS_PARALLELISM"] = "false"
ai.train(file_name,
         line_by_line=True,
         from_cache=False,
         num_steps=80000,
         generate_every=5000,
         save_every=False,  # no intermediate checkpoints
         save_gdrive=False,
         learning_rate=1e-3,
         num_workers=2,
         batch_size=2785,
         max_grad_norm=1,
         )


5,000 steps reached: generating sample texts.
What is the legislation of the Latin's lawyer? I am wondering if you can find policies of the reasonable website as well as the lawyer is. However I am not sure if I am not
10,000 steps reached: generating sample texts.
Have you want to get pregnant? Which we need to start to stop bigger? i am a grained girl friend but i have multiply with myself. i have a friend of mine and i have no problem with my life. I am
15,000 steps reached: generating sample texts.
How do I get a diagram pill?  The diagrams are forced to see if it can make you feel bad you have a diet that you can't have. Another big fatch big and can make me feel s
20,000 steps reached: generating sample texts.
can you correct the PowerPower Computers & Fall Installing Firefox?  You have to know if the Meg Bridge Player is required. Anyone can help you? F
25,000 steps reached: generating sample texts.
Is it rather than being switching a baby

You're done! Feel free to go to the **Generate Text From The Trained Model** section to generate text based on your retrained model.


## Load a Trained Model

To reuse a model trained in an earlier session, copy `pytorch_model.bin`, `config.json`, and the `aitextgen.tokenizer.json` tokenizer file from the specified folder in Google Drive into the Colaboratory VM, for example with the `copy_file_from_gdrive()` helper imported at the top of this notebook. (If no folder is specified, the files are assumed to be at the root level of your Google Drive.)

You can then load the retrained model and the metadata necessary to generate text, as shown in the next section.
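Before loading, a quick check that the copied files are in place can save a confusing error later. A minimal stdlib sketch, assuming the model folder holds `pytorch_model.bin` and `config.json` and the tokenizer is a single JSON file, as in this notebook; the helper name `missing_model_files` is hypothetical.

```python
from pathlib import Path
from typing import List


def missing_model_files(model_folder: str, tokenizer_file: str) -> List[str]:
    """Return the expected files that are missing (an empty list means ready to load)."""
    expected = [
        Path(model_folder) / "pytorch_model.bin",  # model weights
        Path(model_folder) / "config.json",        # model configuration
        Path(tokenizer_file),                      # trained tokenizer
    ]
    return [str(p) for p in expected if not p.is_file()]


print(missing_model_files("trained_model", "aitextgen.tokenizer.json"))
```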

## Generate Text From The Trained Model

After you've trained the model or loaded a retrained model from checkpoint, you can now generate text.

**If you just trained a model**, you'll get much faster generation performance if you reload the model; the next cell reloads the model you just trained from the `trained_model` folder.

In [8]:
ai = aitextgen(model_folder="./trained_model",
               tokenizer_file="aitextgen.tokenizer.json",
               to_gpu=True)

`generate()` without any parameters generates a single text from the loaded model to the console.

In [9]:
ai.generate(n=5,
            batch_size=5,
            prompt="what is google in simple words?",
            temperature=1.0,
            top_p=0.9)

what is google in simple words?  Speed editors is in a simple thing. Good luck
what is google in simple words?  On the internet so it will not get to the page to you. \n\nThe type of question is the question.
what is google in simple words?  I have been hoping in Atlanta. Today I need to know what to do with their emotional medication. See the link to the local creation
what is google in simple words?  Ebay
what is google in simple words?  Not all those sounds similar to how the system is and how much it are. There are different faster to get to sending or other pictures in SOOME


If you're creating an API based on your model and need to pass the generated text elsewhere, you can use `text = ai.generate_one()`.

You can also pass in a `prompt` to the generate function to force the text to start with a given character sequence and generate text from there (good if you add an indicator when the text starts).

You can also generate multiple texts at a time by specifying `n`. Passing a `batch_size` generates multiple samples in parallel, giving a massive speedup (in Colaboratory, keep `batch_size` at 50 or below to avoid going OOM).

Other optional-but-helpful parameters for `ai.generate()` and friends:

* **`max_length`**: Number of tokens to generate (default 256). It is limited by the model's context window: up to 1024 tokens for the pretrained GPT-2 models (but _much_ slower), or the `max_length` a custom model was configured with.
* **`temperature`**: The higher the temperature, the crazier the text (default 0.7, recommended to keep between 0.7 and 1.0)
* **`top_k`**: Limits the generated guesses to the top *k* guesses (default 0, which disables the behavior; if the generated output is super crazy, you may want to set `top_k=40`)
* **`top_p`**: Nucleus sampling: limits the generated guesses to the smallest set of tokens whose cumulative probability reaches `top_p` (`top_p=0.9` tends to give good results)
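To make `temperature`, `top_k`, and `top_p` concrete, here is a minimal pure-Python sketch of how these filters are commonly applied to one step's next-token logits before sampling (temperature, then top-k, then top-p). It follows the standard definitions rather than aitextgen's or Hugging Face's exact code; the function name `sample_next_token` is hypothetical.

```python
import math
import random
from typing import List, Optional


def sample_next_token(
    logits: List[float],
    temperature: float = 0.7,
    top_k: int = 0,
    top_p: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample one token id from raw logits with temperature, top-k and top-p filtering."""
    rng = rng or random.Random()
    # Temperature (must be > 0): <1 sharpens the distribution, >1 flattens it
    scaled = [x / temperature for x in logits]
    # Softmax, subtracting the max logit for numerical stability
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Token ids ordered from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # top_k: keep only the k most likely tokens (0 disables this filter)
    if top_k > 0:
        order = order[:top_k]
    # top_p: keep the smallest prefix whose cumulative probability reaches top_p
    kept: List[int] = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # random.choices normalises the weights, i.e. renormalises over the kept tokens
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.1, -1.0]
print([sample_next_token(logits, temperature=1.0, top_p=0.9, rng=rng) for _ in range(5)])
```

With `top_k=40`, only the 40 most likely tokens can ever be chosen; with `top_p=0.9`, the cutoff instead adapts to how peaked the distribution is at each step.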

In [10]:
ai.generate(n=5,
            batch_size=5,
            prompt="what is google in simple terms?",
            temperature=1.0,
            top_p=0.9)

what is google in simple terms?  It's free to give you a lot of chat.
what is google in simple terms? please. Many people are trying to copy it at http://www.convention.com\n\nWhy would i wanna buy it again?????? I know I
what is google in simple terms?  I can make them in the past that you have a dialogue to be sure to get the google. I have to use your website by a portion of this website.
what is google in simple terms?  i have early 80s.\ni cannot search avoid similar to extra google.
what is google in simple terms?  Read a #2 program to solve more details.\n\nRead source!


For bulk generation, you can generate a large amount of texts to a file and sort out the samples locally on your computer. The next cell will generate `num_files` files, each with `n` texts and whatever other parameters you would pass to `generate()`. The files can then be downloaded from the Files sidebar!

You can rerun the cells as many times as you want for even more generated texts!

In [11]:
num_files = 4

for _ in range(num_files):
    ai.generate_to_file(n=1000,
                        batch_size=100,
                        prompt="what is google in simple terms?",
                        temperature=1.0,
                        top_p=0.9)

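To sort through the generated files locally, you can split each file back into individual samples. This sketch assumes each sample is followed by a line of 20 `=` characters, which we believe is `generate_to_file()`'s default delimiter; check your files and pass a different `delim` if needed. It also drops empty and exact-duplicate samples. The helper name `split_samples` is hypothetical.

```python
from typing import List


def split_samples(text: str, delim: str = "=" * 20) -> List[str]:
    """Split generated text into unique, non-empty samples, keeping first-seen order."""
    samples, current = [], []
    for line in text.splitlines():
        if line.strip() == delim:  # a delimiter line closes the current sample
            samples.append("\n".join(current).strip())
            current = []
        else:
            current.append(line)
    samples.append("\n".join(current).strip())  # text after the last delimiter, if any
    seen, unique = set(), []
    for sample in samples:
        if sample and sample not in seen:  # drop empty samples and exact duplicates
            seen.add(sample)
            unique.append(sample)
    return unique


raw = "sample one\n" + "=" * 20 + "\nsample one\n" + "=" * 20 + "\nsample two\n" + "=" * 20 + "\n"
print(split_samples(raw))  # ['sample one', 'sample two']
```

To apply it to a downloaded file, read the file with `open(path, encoding="utf-8")` (where `path` is the generated file's name) and pass its contents to `split_samples()`.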

# LICENSE

MIT License

Copyright (c) 2020-2021 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.