# Train a GPT model from scratch

In [None]:
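# In this notebook we only use one of the corpus files (the OSCAR one), for the sake of simplicity and performance.
# Uncomment the line below to download it if `oscar.eo.txt` is not already in the working directory.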
#!wget -c https://cdn-datasets.huggingface.co/EsperBERTo/data/oscar.eo.txt

## Create dataset

In [None]:
%%time
from transformers import LineByLineTextDataset

paths = ['oscar.eo.txt']

dataset = LineByLineTextDataset(
    tokenizer=fs_tokenizer,
    file_path=paths[0],
    block_size=128,
)

## Train a tokenizer

In [None]:
%%time 

from tokenizers import ByteLevelBPETokenizer

# Initialize a tokenizer
tokenizer = ByteLevelBPETokenizer()

# Customize training
tokenizer.train(files=paths, vocab_size=52_000, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
])

In [None]:
import os

# `save_model` writes the tokenizer files into an existing folder, so create
# the output directory first (no-op if it is already there). This replaces the
# commented-out `!mkdir GPT2` so the cell works on a fresh checkout.
os.makedirs("./GPT2", exist_ok=True)

# Persist the trained tokenizer files to ./GPT2; the returned list of written
# file paths is displayed as the cell output.
tokenizer.save_model("./GPT2")

## Train a model

### Load Tokenizer

In [None]:
from transformers import GPT2TokenizerFast

# Reload the tokenizer files saved under ./GPT2 as a `transformers` fast tokenizer.
fs_tokenizer = GPT2TokenizerFast.from_pretrained(
    "./GPT2",
    max_len=512,
)

# No dedicated padding token is configured here, so reuse the end-of-sequence
# token for padding.
fs_tokenizer.pad_token = fs_tokenizer.eos_token

### Initialize model

In [None]:
from transformers import GPT2LMHeadModel, GPT2Config

# Build a fresh (randomly initialised) GPT-2 whose vocabulary matches the
# tokenizer trained in this notebook. The library defaults describe the original
# GPT-2 vocabulary, which is smaller than the 52_000-entry tokenizer trained
# above; keeping the defaults would let token ids fall outside the embedding
# table. The special-token ids are likewise taken from the tokenizer rather
# than from the defaults.
config = GPT2Config(
    vocab_size=len(fs_tokenizer),
    bos_token_id=fs_tokenizer.bos_token_id,
    eos_token_id=fs_tokenizer.eos_token_id,
    # The tokenizer pads with its EOS token, so this equals `eos_token_id`.
    pad_token_id=fs_tokenizer.pad_token_id,
)
model = GPT2LMHeadModel(config)

### Define Trainer

In [None]:
from transformers import DataCollatorForLanguageModeling

# Batch builder for training; `mlm=False` turns off masked-language-model
# masking, i.e. the collator is used for plain (causal) language modelling.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=fs_tokenizer)

In [None]:
from transformers import Trainer, TrainingArguments

# Training hyper-parameters. Checkpoints go to ./GPT2 (the same folder that
# holds the tokenizer files), one is written every 10_000 steps and at most
# two are kept.
training_args = TrainingArguments(
    output_dir="./GPT2",
    overwrite_output_dir=True,
    per_device_train_batch_size=64,
    num_train_epochs=1,
    save_total_limit=2,
    save_steps=10_000,
    prediction_loss_only=True,
)

# Wire the model, dataset and collator defined in the previous cells together.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
)

### Display logs in wandb

In [None]:
import wandb
# Authenticate this session with Weights & Biases.
wandb.login()

In [None]:
# Set the WANDB_PROJECT environment variable for this kernel; presumably the
# wandb integration uses it as the project name for the run -- confirm.
%env WANDB_PROJECT=GPT2_train_demo

### Train model

In [None]:
%%time
# Run training with the `trainer` configured above; %%time reports how long the cell took.
trainer.train()

In [None]:
# Save the trained model into ./GPT2, the same folder the tokenizer files were
# written to, so the folder can be loaded as a whole below.
trainer.save_model("./GPT2")

In [None]:
from transformers import pipeline

# Text-generation pipeline that loads both the model and the tokenizer from
# the ./GPT2 folder written by the cells above.
text_gen = pipeline("text-generation", model="./GPT2", tokenizer="./GPT2")

In [None]:
# Quick smoke test: run the pipeline on the prompt "ABC" and display the result.
text_gen("ABC")