# Chapter 12 Training A Transformer to Generate Text

This chapter covers

* Building a scaled-down version of the GPT-2XL model tailored to your needs
* Preparing training data for training a GPT-style Transformer
* Training a GPT-style Transformer from scratch
* Generating text using the trained GPT model

In Chapter 11, we developed the GPT-2XL model from scratch but were unable to train it due to its vast number of parameters. Training a model with 1.5 billion parameters requires supercomputing facilities and an enormous amount of data. Consequently, we loaded pre-trained weights from OpenAI into our model and then used the GPT-2XL model to generate text.

However, Learning how to train a Transformer model from scratch is crucial for several reasons. First, while this book doesn't directly cover fine-tuning a pre-rained model, understanding how to train a Transformer equips you with the skills needed for fine-tuning. Training a model involves initializing parameters randomly, whereas fine-tuning involves loading pre-trained weights and further training the model. Second, training or fine-tuning a Transformer enables you to customize the model to meet your specific needs and domain, which can significantly enhance its performance and relevance for your use case. Finally, training your own Transformer or fine-tuning an existing one provides greater control over data and privacy, which is particularly important for sensitive applications or handling proprietary data. In summary, mastering the training and fine-tuning of Transformers is essential for anyone looking to harness the power of language models for specific applications while maintaining privacy and control. 

Therefore, in this chapter, we’ll construct a scaled-down version of the GPT model with approximately five million parameters. This smaller model follows the architecture of the GPT-2XL model, with significant differences being its composition of only three decoder blocks and an embedding dimension of 256, compared to the original's 48 decoder blocks and an embedding dimension of 1600. By scaling down the GPT model to about 5 million parameters, we can train it on a regular computer. 

The generated text's style will depend on the training data. When training a model from scratch for text generation, both text length and variation are crucial. The training material must be extensive enough for the model to learn and mimic a particular writing style effectively. At the same time, if the training material lacks variation, the model may simply replicate passages from the training text. On the other hand, if the material is too long, training may require excessive computational resources. Therefore, we will use three novels by Ernest Hemingway as our training material: The Old Man and the Sea, A Farewell to Arms, and For Whom the Bell Tolls. This selection ensures that our training data has sufficient length and variation for effective learning, without being so long that training becomes impractical.

Since GPT models cannot process raw text directly, we will first tokenize the text into words. We will then create a dictionary to map each unique token to an index. Using this dictionary, we will convert the text into a long sequence of integers, ready for input into a neural network.

We will use sequences of 128 indexes as input to train the GPT model. As in Chapters 8 and 10, we will shift the input sequence by one token to the right and use it as the output. This approach forces the model to predict the next word in a sentence based on the current token and all previous tokens in the sequence.

A key challenge is determining the optimal number of epochs for training the model. Our goal is not merely to minimize the cross-entropy loss, as doing so could lead to overfitting, where the model simply replicates passages from the training text. To tackle this issue, we plan to train the model for 40 epochs. We will save the model at ten-epoch intervals and evaluate which version can generate coherent text without merely copying passages from the training material. Alternatively, one could potentially use a validation set to assess the performance of the model and decide when to stop training, as we did in Chapter 2. 

Once our GPT model is trained, we will use it to generate text autoregressively, as we did in Chapter 11. We’ll test different versions of the trained model. The model trained for 40 epochs produces very coherent text, capturing Hemingway's distinctive style. However, it may also generate text partly copied from the training material, especially if the prompt is similar to passages in the training text. The model trained for 20 epochs also generates coherent text, albeit with occasional grammatical errors, but is less likely to directly copy from the training text.

The primary goal of this chapter is not necessarily to generate the most coherent text possible, which presents significant challenges. Instead, our objective is to teach you how to build a GPT-style model from scratch, tailored to real-world applications and your specific needs. More importantly, this chapter outlines the steps involved in training a GPT model from scratch. You will learn how to select training text based on your objectives, tokenize the text and convert it to indexes, and prepare batches of training data. You will also learn how to determine the number of epochs for training. Once the model is trained, you will learn how to generate text using the model and how to avoid generating text directly copied from the training material.

# 1.	How to build and train a GPT from scratch?
# 2.	Tokenize text of Hemingway novels
## 2.1. 	Tokenize the text

In [1]:
with open("files/OldManAndSea.txt","r", encoding='utf-8-sig') as f:
    text=f.read()
text=list(text)    #A
for i in range(len(text)):
    if text[i]=='"':
        if text[i+1]==' ' or text[i+1]=='\n':
            text[i]='”'    #B
        if text[i+1]!=' ' and text[i+1]!='\n':
            text[i]='“'    #C
    if text[i]=="'":
        if text[i-1]!=' ' and text[i-1]!='\n':
            text[i]='’'    #D   
text="".join(text)    #E

In [2]:
with open("files/ToWhomTheBellTolls.txt","r", encoding='utf-8-sig') as f:
    text1=f.read()    #A

with open("files/FarewellToArms.txt","r", encoding='utf-8-sig') as f:
    text2=f.read()    #B

text=text+" "+text1+" "+text2    #C

with open("files/ThreeNovels.txt","w", 
          encoding='utf-8-sig') as f:
    f.write(text)    #D
print(text[:250])

He was an old man who fished alone in a skiff in the Gulf Stream and he
had gone eighty-four days now without taking a fish.  In the first
forty days a boy had been with him.  But after forty days without a
fish the boy’s parents had told him that th


In [3]:
text=text.lower().replace("\n", " ")
chars=set(text.lower())
punctuations=[i for i in chars if i.isalpha()==False
              and i.isdigit()==False]
print(punctuations)

for x in punctuations:
    text=text.replace(f"{x}", f" {x} ")
text_tokenized=text.split()

unique_tokens=set(text_tokenized)
print(len(unique_tokens))

[')', '.', '&', ':', '(', ';', '-', '!', '“', ' ', '‘', '”', '?', ',', '’']
10599


In [4]:
from collections import Counter   

word_counts=Counter(text_tokenized)    
words=sorted(word_counts, key=word_counts.get,
                      reverse=True)     
words.append("UNK")    #A 
text_length=len(text_tokenized)
ntokens=len(words)    #B
print(f"the text contains {text_length} words")
print(f"there are {ntokens} unique tokens")  
word_to_int={v:k for k,v in enumerate(words)}    #C 
int_to_word={v:k for k,v in word_to_int.items()}    #D
print({k:v for k,v in word_to_int.items() if k in words[:10]})
print({k:v for k,v in int_to_word.items() if v in words[:10]})

the text contains 698207 words
there are 10600 unique tokens
{'.': 0, 'the': 1, ',': 2, '“': 3, '”': 4, 'and': 5, 'i': 6, 'to': 7, 'he': 8, 'it': 9}
{0: '.', 1: 'the', 2: ',', 3: '“', 4: '”', 5: 'and', 6: 'i', 7: 'to', 8: 'he', 9: 'it'}


In [5]:
print(text_tokenized[0:20])
wordidx=[word_to_int[w] for w in text_tokenized]  
print([word_to_int[w] for w in text_tokenized[0:20]])

['he', 'was', 'an', 'old', 'man', 'who', 'fished', 'alone', 'in', 'a', 'skiff', 'in', 'the', 'gulf', 'stream', 'and', 'he', 'had', 'gone', 'eighty']
[8, 16, 98, 110, 67, 85, 6052, 314, 14, 11, 1039, 14, 1, 3193, 507, 5, 8, 25, 223, 3125]


## 2.2	Create batches for training

In [6]:
import torch

seq_len=128  
xys=[]
for n in range(0, len(wordidx)-seq_len-1):
    x = wordidx[n:n+seq_len]
    y = wordidx[n+1:n+seq_len+1]
    xys.append((torch.tensor(x),(torch.tensor(y))))

In [7]:
from torch.utils.data import DataLoader

torch.manual_seed(42)
batch_size=32
loader = DataLoader(xys, batch_size=batch_size, shuffle=True)

x,y=next(iter(loader))
print(x)
print(y)
print(x.shape,y.shape)

tensor([[   3,  129,    9,  ...,   11,  251,   10],
        [   5,   41,   32,  ...,  995,   52,   23],
        [   6,   25,   11,  ...,   15,    0,   24],
        ...,
        [1254,    0,    4,  ...,   15,    0,    3],
        [  17,    8, 1388,  ...,    0,    8,   16],
        [  55,   20,  156,  ...,   74,   76,   12]])
tensor([[ 129,    9,   23,  ...,  251,   10,    1],
        [  41,   32,   34,  ...,   52,   23,    1],
        [  25,   11,   59,  ...,    0,   24,   25],
        ...,
        [   0,    4,    3,  ...,    0,    3,   93],
        [   8, 1388,    1,  ...,    8,   16, 1437],
        [  20,  156,  970,  ...,   76,   12,   29]])
torch.Size([32, 128]) torch.Size([32, 128])


# 3	Build a GPT to generate text
## 3.1	Model the causal self-attention mechanism

In [8]:
import torch
from torch import nn
import math

device="cuda" if torch.cuda.is_available() else "cpu"
class GELU(nn.Module):
    def forward(self, x):
        return 0.5*x*(1.0+torch.tanh(math.sqrt(2.0/math.pi)*\
                       (x + 0.044715 * torch.pow(x, 3.0))))

In [9]:
class Config():
    def __init__(self):
        self.n_layer = 3
        self.n_head = 4
        self.n_embd = 256
        self.vocab_size = ntokens
        self.block_size = 128 
        self.embd_pdrop = 0.1
        self.resid_pdrop = 0.1
        self.attn_pdrop = 0.1
        
# instantiate a Config() class
config=Config()

In [10]:
import torch.nn.functional as F
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.register_buffer("bias", torch.tril(torch.ones(\
                   config.block_size, config.block_size))
             .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() 
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        hs = C // self.n_head
        k = k.view(B, T, self.n_head, hs).transpose(1, 2) 
        q = q.view(B, T, self.n_head, hs).transpose(1, 2) 
        v = v.view(B, T, self.n_head, hs).transpose(1, 2) 

        att = (q @ k.transpose(-2, -1)) *\
            (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, \
                              float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v 
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

## 3.2	Build the GPT model

In [11]:
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj = nn.Linear(4 * config.n_embd, config.n_embd),
            act    = GELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))
        m = self.mlp
        self.mlpf=lambda x:m.dropout(m.c_proj(m.act(m.c_fc(x)))) 

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x


In [12]:
class Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.embd_pdrop),
            h = nn.ModuleList([Block(config) 
                               for _ in range(config.n_layer)]),   
            ln_f = nn.LayerNorm(config.n_embd),))
        self.lm_head = nn.Linear(config.n_embd,
                                 config.vocab_size, bias=False)      
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):    
                torch.nn.init.normal_(p, mean=0.0, 
                  std=0.02/math.sqrt(2 * config.n_layer))
    def forward(self, idx, targets=None):
        b, t = idx.size()
        pos = torch.arange(0,t,dtype=torch.long).unsqueeze(0).to(device)
        tok_emb = self.transformer.wte(idx) 
        pos_emb = self.transformer.wpe(pos) 
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits


In [13]:
model=Model(config)
model.to(device)
num=sum(p.numel() for p in model.transformer.parameters())
print("number of parameters: %.2fM" % (num/1e6,))
print(model)

number of parameters: 5.12M
Model(
  (transformer): ModuleDict(
    (wte): Embedding(10600, 256)
    (wpe): Embedding(128, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-2): 3 x Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=256, out_features=768, bias=True)
          (c_proj): Linear(in_features=256, out_features=256, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): ModuleDict(
          (c_fc): Linear(in_features=256, out_features=1024, bias=True)
          (c_proj): Linear(in_features=1024, out_features=256, bias=True)
          (act): GELU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
 

# 4	Train the GPT model to generate text
## 4.1	Train the GPT model

In [14]:
lr=0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss()

In [15]:
model.train()  
for i in range(1,41):
    tloss = 0.
    for idx, (x,y) in enumerate(loader):
        x,y=x.to(device),y.to(device)
        output = model(x)
        loss=loss_func(output.view(-1,output.size(-1)),
                           y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(),1)
        optimizer.step()
        tloss += loss.item()
    print(f'epoch {i} loss {tloss/(idx+1)}') 
    if i%10==0:
        torch.save(model.state_dict(),f'files/GPTe{i}.pth') 

## 4.2	A function to generate text

In [16]:
def sample(idx, weights, max_new_tokens, temperature=1.0, top_k=None):
    model.eval()
    model.load_state_dict(torch.load(weights))
    # keep track of the length of the original indexes
    original_length=len(idx[0])
    # add a fixed number of tokens to prompt
    for _ in range(max_new_tokens):
        # if the text is more than 1024 tokenx, trim it
        if idx.size(1) <= config.block_size:
            idx_cond = idx  
        else:
            idx_cond = idx[:, -config.block_size:]
        # predict the logits for the index in sequence
        logits = model(idx_cond.to(device))
        # pluck the logits at the final step; apply temperature 
        logits = logits[:, -1, :] / temperature
        # crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)
        idx_next=torch.multinomial(probs,num_samples=1)
        idx = torch.cat((idx, idx_next.cpu()), dim=1)
    # keep only new tokens
    return idx[:, original_length:]  

In [17]:
UNK=word_to_int["UNK"]
def generate(prompt, weights, max_new_tokens, temperature=1.0,
             top_k=None):
    assert len(prompt)>0, "prompt must contain at least one token"
    text=prompt.lower().replace("\n", " ")
    for x in punctuations:
        text=text.replace(f"{x}", f" {x} ")
    text_tokenized=text.split() 
    idx=[word_to_int.get(w,UNK) for w in text_tokenized]
    idx=torch.LongTensor(idx).unsqueeze(0)
    # add a fixed number of tokens to prompt
    idx=sample(idx, weights, max_new_tokens, temperature=1.0, top_k=None)
    # convert indexes to text
    tokens=[int_to_word[i] for i in idx.squeeze().numpy()] 
    text=" ".join(tokens)
    for x in '''”).:;!?,-‘’''':
        text=text.replace(f" {x}", f"{x}") 
    for x in '''“(-‘’''':
        text=text.replace(f"{x} ", f"{x}")     
    return prompt+" "+text

# 4.3	Text generation with different versions of the trained model

In [18]:
prompt="UNK"
for i in range(10):
    torch.manual_seed(i)
    print(generate(prompt,'files/GPTe20.pth',max_new_tokens=20)[4:])
    print("-"*50)

way.” “kümmel,” i said. “it’s the way to talk about it
--------------------------------------------------
,” robert jordan said. “but do not realize how far he is ruined.” “pero
--------------------------------------------------
in the fog, robert jordan thought. and then, without looking at last, so good, he
--------------------------------------------------
pot of yellow rice and fish and the boy loved him. “no,” the boy said.
--------------------------------------------------
the line now. it’s wonderful.” “he’s crazy about the brave.”
--------------------------------------------------
candle to us. “and if the maria kisses thee again i will commence kissing thee myself. it
--------------------------------------------------
?” “do you have to for the moment.” robert jordan got up and walked away in
--------------------------------------------------
. a uniform for my father, he thought. i’ll say them later. just then he
--------------------------------------------------
and more pra

In [19]:
prompt="UNK"
for i in range(10):
    torch.manual_seed(i)
    print(generate(prompt,'files/GPTe40.pth',max_new_tokens=20)[4:])
    print("-"*50)

way.” “kümmel, and i will enjoy the killing. they must have brought me a spit
--------------------------------------------------
,” robert jordan said. “but do not tell me that he saw anything.” “not
--------------------------------------------------
in the first time he had bit the ear like that and held onto it, his neck and jaws
--------------------------------------------------
pot of yellow rice with fish. it was cold now in the head and he could not see the
--------------------------------------------------
the line of his mouth. he thought.” “the laughing hurt him.” “i can
--------------------------------------------------
candle made? that was the worst day of my life until one other day.” “don’
--------------------------------------------------
?” “do you have to for the moment.” robert jordan took the glasses and opened the
--------------------------------------------------
. that’s what they don’t marry.” i reached for her hand. “don
-----------------------------------------

In [20]:
# Answer to exercis 12.1
prompt="UNK"
torch.manual_seed(42)
print(generate(prompt,'files/GPTe10.pth',max_new_tokens=50)[4:])

. i know that the doctor was a doctor who would stay with a beard and wore a red scar across the room. “how is she in the legs?” “she could not come,” i said. “you can go too. if


In [21]:
prompt="the old man saw the shark near the"
for i in range(10):
    torch.manual_seed(i)
    print(generate(prompt,'files/GPTe40.pth',max_new_tokens=20))
    print("-"*50)

the old man saw the shark near the old man’s head with his tail out and the old man hit him squarely in the center of
--------------------------------------------------
the old man saw the shark near the boat with one hand. he had no feeling of the morning but he started to pull on it gently
--------------------------------------------------
the old man saw the shark near the old man’s head. then he went back to another man in and leaned over and dipped the
--------------------------------------------------
the old man saw the shark near the fish now, and the old man was asleep in the water as he rowed he was out of the
--------------------------------------------------
the old man saw the shark near the boat. it was a nice-boat. he saw the old man’s head and he started
--------------------------------------------------
the old man saw the shark near the boat to see him clearly and he was afraid that he was higher out of the water and the old
-------------------------------------------

In [22]:
prompt="the old man saw the shark near the"
for i in range(10):
    torch.manual_seed(i)
    print(generate(prompt,'files/GPTe20.pth',max_new_tokens=20,
                  temperature=0.9,top_k=50))
    print("-"*50)

the old man saw the shark near the boat. then he swung the great fish that was more comfortable in the sun. the old man could
--------------------------------------------------
the old man saw the shark near the boat with one hand. he wore his overcoat and carried the submachine gun muzzle down, carrying it in
--------------------------------------------------
the old man saw the shark near the boat with its long dip sharply and the old man stabbed him in the morning. he could not see
--------------------------------------------------
the old man saw the shark near the fish that was now heavy and long and grave he had taken no part in. he was still under
--------------------------------------------------
the old man saw the shark near the boat. it was a nice little light. then he rowed out and the old man was asleep over
--------------------------------------------------
the old man saw the shark near the boat to come. “old man’s shack and i’ll fill the water with him in
--------------

In [23]:
# answer to exercise 12.2
prompt="the old man saw the shark near the"
torch.manual_seed(42)
print(generate(prompt,'files/GPTe40.pth',max_new_tokens=50,
                  temperature=0.95,top_k=100))

the old man saw the shark near the old man’s head. then he went back inside the skiff and rested on the line as he leaned over the side-side and washed the flying fish in the water, noting the speed of the water against his hand. his hand was phosphorescent from
