# Supervised Fine Tuning with GPT-Like Transformer
## Language Modelling with Dump of Wikipedia

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

import math
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import random
import json
import re
import time
import pickle

from matplotlib import pyplot as plt
#plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['figure.figsize'] = (4, 2)
plt.rcParams['axes.grid'] = True

%load_ext autoreload
%autoreload 2
    
from models.transformer import GPT

from tokenizers import ByteLevelBPETokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'USING DEVICE: {device}')

USING DEVICE: cuda


# Hyperparameters

In [2]:
# Model configuration
sentiment_nano = {
    'features_dim': 64, 
    'num_heads': 8,
    'num_encoder_layers': 2,
    'num_decoder_layers': -1,
    #'ff_dim': 64 * 4,
    'ff_dim': 64 // 16,
    'emb_dropout_prob': 0.1,
    'attn_dropout_prob': 0.0,
    'ff_dropout_prob': 0.1,
    'attn_use_bias': False,
    'ff_use_bias': False,
}

gpt_small = {
    'vocab_size': 50_257,
    'features_dim': 384, 
    'num_heads': 6,
    #'ff_dim': 64 * 4,
    'ff_dim': 384, # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    'num_decoder_layers': 6,
    'emb_dropout_prob': 0.1,
    'attn_dropout_prob': 0.0,
    'ff_dropout_prob': 0.1,
    'attn_use_bias': False,  # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    'ff_use_bias': False,  # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    'vocab_projection_bias': True
}


steplr = {
    'type': 'StepLR',
    #'step_size': 2,
    'step_size': 800,
    #'gamma': 0.80,
    'gamma': 0.99,
}

reduce_lr_on_plateau = {
    'type': 'ReduceLROnPlateau',
    'mode': 'min',
    'factor': 0.1,
    'patience': 3,
    'cooldown': 0,
    'min_lr': 1e-7,
}

# TODO: revise the max_seq_len and context_size
hyperparameters = {
    'seed': 99999,
    'batch_size': 50,
    #'vocab_size': 50_257,
    #'max_seq_len': 256, 
    'context_size': 256,
    'split_ratio': 0.75,
    'num_epochs': 90,
    #'num_training_iters': 5_000,
    #'num_validation_iters': 1_000,
    'optimizer': {
        'learning_rate': 1e-3,
        'momentum': 0.9, # SGD
        'optimizer_betas': (0.9, 0.999), # Adam, AdamW
        'weight_decay': 1e-2, # AdamW
    },
    'clip_grad_norm': 1.0,
    'grad_accum_iter': 4,
    'learning_rate_sched_config': reduce_lr_on_plateau,
    'dataset_dir': '../data/trwiki-20231120-pages-articles/',
    'model_base_name': 'LM_GPTSmall_Wiki_TR',
    'model_config': gpt_small,
}

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False # !!!!!!!!!!!!!!!!!!!!!
    
seed_everything(hyperparameters['seed'])

# Preprocessing
* **UPDATED FOR EMOJIS**

In [4]:
#preprocess_text = lambda x: re.sub(r'[^\w\s]', '', x.lower())
preprocess_text = lambda x: x.lower()

# Tokenizer

In [5]:
from tokenizers import Tokenizer

In [6]:
tokenizer = Tokenizer.from_file('/URL/TO/YOUR/TOKENIZER')

### Template Processing

In [7]:
from tokenizers.processors import TemplateProcessing
#from tokenizers.decoders import ByteLevel as ByteLevelDecoder
#from tokenizers.pre_tokenizers import ByteLevel

#tokenizer.pre_tokenizer = ByteLevel()
#tokenizer.decoder = ByteLevelDecoder()

In [8]:
tokenizer.post_processor = TemplateProcessing(
    single='<start> $A <end>',
    #pair="<start> $A <sep> $B:1 <end>:1",
    special_tokens=[
        ('<start>', tokenizer.token_to_id('<start>')),
        ('<end>', tokenizer.token_to_id('<end>')),
    ],
)

In [9]:
output = tokenizer.encode('merhaba dünya')
print(output.tokens)

['<start>', 'Ġmer', 'haba', 'ĠdÃ¼nya', '<end>']


In [10]:
tokenizer.token_to_id('<pad>')

0

In [11]:
tokenizer.token_to_id('<start>'), tokenizer.token_to_id('<end>')

(4, 5)

In [12]:
encode = lambda s: tokenizer.encode(s).ids # encoder: take a string, output a list of integers
decode = lambda t: tokenizer.decode(t, skip_special_tokens=False) # decoder: take a list of integers, output a string

In [13]:
decode(encode('merhaba dünya'))

'<start> merhaba dünya<end>'

In [14]:
decode(encode('bugün çok güzel selam dünya'))

'<start> bugün çok güzel selam dünya<end>'

In [15]:
tokenizer.decode(tokenizer.encode('merhaba dünya').ids)

' merhaba dünya'

In [16]:
tokenizer.encode('merhaba dünya').ids

[4, 2992, 42164, 1793, 5]

## Vocabulary size from tokenizer

In [17]:
tokenizer.get_vocab_size()

50257

## Overwrite existing vocab_size

In [18]:
hyperparameters['model_config']['vocab_size'] = tokenizer.get_vocab_size()

# LM Dataset
* NOTE: GIVING UP SOME DATA WHEN LIMITING CONTEXT!
* (Tokenizer cutts off some text files)
* Text files is passed in as Python list (useful for train/test splitting)

In [19]:
class Wiki_TR(Dataset):
    def __init__(self, root_dir, dataset_files_list, tokenizer, context_size):
        self.root_dir = root_dir
        self.dataset_files_list = dataset_files_list
        self.tokenizer = tokenizer
        self.context_size = context_size

        self.file_paths = [os.path.join(root_dir, file) for file in dataset_files_list if not 'combined' in file]
        self.total_files = len(self.file_paths)

        # NOTE: GIVING UP SOME DATA WHEN LIMITING CONTEXT!
        # (Tokenizer cutts off some text files)
        self.tokenizer.enable_padding(length=context_size)
        self.tokenizer.enable_truncation(max_length=context_size)
        
    def __len__(self):
        return self.total_files
    
    def __getitem__(self, idx):
        #sample = self.file_paths[idx]
        with open(self.file_paths[idx], 'r') as f:
            sample = f.read()

        #######################################
        '''
        _high = len(sample)-self.context_size
        
        if _high > 0:
            offset = random.randint(0, _high)
            sample = sample[offset:]
        '''
        #######################################
                
        _encoded = self.tokenizer.encode(sample)

        input_token_ids = torch.tensor(_encoded.ids[:-1], dtype=torch.long)
        attention_mask = torch.tensor(_encoded.attention_mask[:-1], dtype=torch.float32)
        target_token_ids = torch.tensor(_encoded.ids[1:], dtype=torch.long)
        
        input_token_ids, attention_mask, target_token_ids
        
        return input_token_ids, attention_mask, target_token_ids


print()




### Cached Implementation
* Continuous stream of tokens

In [20]:
'''
class Wiki_TR_Cached(Dataset):
    def __init__(self, tokenized_cached_corpus, context_size, start_token_id, end_token_id, start_offset=0, mode='single-step'):
        self.tokenized_cached_corpus = tokenized_cached_corpus
        self.context_size = context_size
        self.start_token = torch.tensor([start_token_id], dtype=torch.long)
        self.end_token = torch.tensor([end_token_id], dtype=torch.long)
        self.start_offset = start_offset
        self.mode = mode

        # TODO: implement context-step and random
        self.available_modes = [
            'single-step',
            'context-step',
            'random'
        ]

        if mode not in self.available_modes:
            raise NotImplementedError(f'Mode {mode} is not valid! Valid options are: {self.available_modes}')
    
    def generate_context(self, tokens, context_size, offset):
        #offset = torch.randint(low=0, high=len(tokens)-context_size, size=(1,))
        #print(f'offset', offset)
    
        start_token = self.start_token
        end_token = self.end_token
    
        selected_context_temp = tokens[offset:offset+context_size]
        
        #print(selected_context_temp[0], selected_context_temp[-1])

        """
        if selected_context_temp[0] == start_token and selected_context_temp[-1] == end_token:
        
            selected_context = tokens[offset:offset+context_size]
        
        elif selected_context_temp[0] == start_token and selected_context_temp[-1] != end_token:
            
            selected_context = tokens[offset:offset-1+context_size]
            selected_context = torch.cat((selected_context, end_token))
    
        elif selected_context_temp[0] != start_token and selected_context_temp[-1] == end_token:
            
            selected_context = tokens[offset+1:offset+context_size]
            selected_context = torch.cat((start_token, selected_context))
        
        elif selected_context_temp[0] != start_token and selected_context_temp[-1] != end_token:
            
            selected_context = tokens[offset+1:offset-1+context_size]
            selected_context = torch.cat((start_token, selected_context, end_token))
        """

        return selected_context_temp
        #return selected_context

    
    def __len__(self):
        return (len(self.tokenized_cached_corpus) - self.context_size)//self.context_size
    
    def __getitem__(self, idx):
        """
        #sample = self.file_paths[idx]
        with open(self.file_paths[idx], 'r') as f:
            sample = f.read()

        #######################################
        _high = len(sample)-self.context_size
        
        if _high > 0:
            offset = random.randint(0, _high)
            sample = sample[offset:]
        #######################################
                
        _encoded = self.tokenizer.encode(sample)

        input_token_ids = torch.tensor(_encoded.ids[:-1], dtype=torch.long)
        attention_mask = torch.tensor(_encoded.attention_mask[:-1], dtype=torch.float32)
        target_token_ids = torch.tensor(_encoded.ids[1:], dtype=torch.long)
        
        input_token_ids, attention_mask, target_token_ids
        """

        context = self.generate_context(self.tokenized_cached_corpus, self.context_size, (idx*self.context_size)+self.start_offset)
        #context = self.generate_context(self.tokenized_cached_corpus, self.context_size, idx+self.start_offset)

        input_token_ids = context[:-1]
        attention_mask = torch.ones_like(input_token_ids, dtype=torch.float32)
        target_token_ids = context[1:]
        
        return input_token_ids, attention_mask, target_token_ids
'''

'\nclass Wiki_TR_Cached(Dataset):\n    def __init__(self, tokenized_cached_corpus, context_size, start_token_id, end_token_id, start_offset=0, mode=\'single-step\'):\n        self.tokenized_cached_corpus = tokenized_cached_corpus\n        self.context_size = context_size\n        self.start_token = torch.tensor([start_token_id], dtype=torch.long)\n        self.end_token = torch.tensor([end_token_id], dtype=torch.long)\n        self.start_offset = start_offset\n        self.mode = mode\n\n        # TODO: implement context-step and random\n        self.available_modes = [\n            \'single-step\',\n            \'context-step\',\n            \'random\'\n        ]\n\n        if mode not in self.available_modes:\n            raise NotImplementedError(f\'Mode {mode} is not valid! Valid options are: {self.available_modes}\')\n    \n    def generate_context(self, tokens, context_size, offset):\n        #offset = torch.randint(low=0, high=len(tokens)-context_size, size=(1,))\n        #print

In [21]:
l1 = list(range(10))
l1

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [22]:
l1[:-1]

[0, 1, 2, 3, 4, 5, 6, 7, 8]

In [23]:
l1[1:]

[1, 2, 3, 4, 5, 6, 7, 8, 9]

### Full List of Text Files

In [24]:
full_text_files_list = os.listdir(hyperparameters['dataset_dir'])
print(f'original len: {len(full_text_files_list)}')

# downsample dataset for faster training times
full_text_files_list = full_text_files_list[:len(full_text_files_list)//4]

print(f'downsampled len: {len(full_text_files_list)}')

original len: 423135
downsampled len: 105783


In [25]:
full_ds = Wiki_TR(
    root_dir=hyperparameters['dataset_dir'],
    dataset_files_list=full_text_files_list, 
    tokenizer=tokenizer,
    context_size=hyperparameters['context_size']
)


"""
full_ds = Wiki_TR_Cached(
    tokenized_cached_corpus=cached_tokens,
    context_size=hyperparameters['context_size'],
    start_token_id=tokenizer.token_to_id('<start>'),
    end_token_id=tokenizer.token_to_id('<end>'),
    start_offset=0,
)
"""

len(full_ds)

105783

In [26]:
ex_input_token_ids, ex_attention_mask, ex_target_token_ids = full_ds[0]

In [27]:
print(len(ex_input_token_ids.tolist()))
decode(ex_input_token_ids.tolist())

255


"<start> Cengiz Han (doğum adıyla Temuçin,  – 18 Ağustos 1227), Moğol İmparatorluğu'nun kurucusu ve ilk Kağanı olan Moğol komutan ve [kaynağı | url = https://www.britannica.com/biography/Genghis-Khan | başlık = Genghis Khan; Mongol Ruler | erişimtarihi = 12 Eylül 2020 | tarih =  | çalışma =  | yayıncı = Encyclopædia Britannica | arşivurl = https://web.archive.org/web/20150618194658/https://www.britannica.com/biography/Genghis-Khan | arşivtarihi = 18 Haziran 2015 | ölüurl = hayır }} Hükümdarlığı döneminde gerçekleştirdiği hiçbir savaşı kaybetmeyen Cengiz Han, dünya tarihinin en büyük askeri liderlerinden birisi olarak kabul edilmektedir. \n<end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa

In [28]:
print(len(ex_target_token_ids.tolist()))
decode(ex_target_token_ids.tolist())

255


" Cengiz Han (doğum adıyla Temuçin,  – 18 Ağustos 1227), Moğol İmparatorluğu'nun kurucusu ve ilk Kağanı olan Moğol komutan ve [kaynağı | url = https://www.britannica.com/biography/Genghis-Khan | başlık = Genghis Khan; Mongol Ruler | erişimtarihi = 12 Eylül 2020 | tarih =  | çalışma =  | yayıncı = Encyclopædia Britannica | arşivurl = https://web.archive.org/web/20150618194658/https://www.britannica.com/biography/Genghis-Khan | arşivtarihi = 18 Haziran 2015 | ölüurl = hayır }} Hükümdarlığı döneminde gerçekleştirdiği hiçbir savaşı kaybetmeyen Cengiz Han, dünya tarihinin en büyük askeri liderlerinden birisi olarak kabul edilmektedir. \n<end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

# Train/Test Split

In [29]:
train_size = round(len(full_ds) * hyperparameters['split_ratio'])
test_size = len(full_ds) - train_size

print(f'Full: {len(full_ds)}, train: {train_size}, test: {test_size}, (combined: {train_size+test_size})')

Full: 105783, train: 79337, test: 26446, (combined: 105783)


In [30]:
l2 = list(range(20))
l2

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

In [31]:
l2[:15]

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

In [32]:
l2[15:]

[15, 16, 17, 18, 19]

In [33]:
train_dataset, test_dataset = random_split(full_ds, [train_size, test_size])

len(train_dataset), len(test_dataset)

(79337, 26446)

In [34]:
ex_input_token_ids, ex_attention_mask, ex_target_token_ids = train_dataset[0]
print(decode(ex_input_token_ids.tolist()))
print('*'*20)
print(decode(ex_target_token_ids.tolist()))

<start> Ásmundur Sveinsson (20 Mayıs 1893, Kolsstadir, Batı İzlanda - 9 Aralık 1982, Reykjavik), İzlandalı heykeltıraş. 
1915 yılında Ásmundur Reykjavík'e taşındı ve İzlanda Teknik Okulu'na yazıldı. Orada, heykeltıraş Ríkarður Jónsson yönetiminde çıraklık eğitimi aldı. 1919'da Kopenhag, oradan Stokholm'e, Güzel Sanatlar Akademisi'nde heykeltıraş Carl Milles'le altı sene eğitim görmek üzere taşındı. 
1924'te heykeltıraş Gunnfríður Jónsdóttir'le evlendi. Eşiyle boşandıktan sonra İngrid adında bir kadınla evlendi ve ondan iki çocuğu oldu. Akademi'den mezun olduktan sonra Ásmundur Paris'e eğitimine devam etmek için taşındı ve orada heykeltıraş Charles Despiau'yla çalıştı. 
Ásmundur 1929'da İzlanda'ya dönerdönmez soyut, figüratif çalışmalar yapmaya başladı. Temaları daha çok çalışan kadın ve erkeklerdi - The Blacksmith, The Washer Women ve The Water Carrier eserlerinde olduğu gibi. 
1940'larda başlangıç noktası insan ve hayvanlar olan Ásmundur'un çalışmaları 1950'lere gelindiğinde farklılık

In [35]:
ex_input_token_ids, ex_attention_mask, ex_target_token_ids = test_dataset[0]
print(decode(ex_input_token_ids.tolist()))
print('*'*20)
print(decode(ex_target_token_ids.tolist()))

<start> 1988 Supercopa de España 21 Eylül 1988 ve 29 Eylül 1988 tarihlerinde iki maç olarak oynanan İspanya La Liga şampiyonu olan takım ile Copa del Rey şampiyonu olan takımların aralarında oynadıkları süper kupa karşılaşmasıdır. 
<end><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa

# Dataloader

In [36]:
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=hyperparameters['batch_size'], 
    shuffle=True,
    #num_workers=1, 
    pin_memory=True,
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=hyperparameters['batch_size'], 
    shuffle=True,
    #num_workers=1, 
    pin_memory=True,
)

In [37]:
i, a, t = next(iter(train_dataloader))
i.shape, a.shape, t.shape

(torch.Size([50, 255]), torch.Size([50, 255]), torch.Size([50, 255]))

# Model

In [38]:
model = GPT(**hyperparameters['model_config'])

print(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')

Trainable parameters: 6,989,185


In [39]:
for n, m in model.named_children():
    print(f'{n} parameters: {sum(p.numel() for p in m.parameters() if p.requires_grad):,}')

token_emb parameters: 810,256
emb_dropout_prob parameters: 0
dec_layers parameters: 5,317,632
layernorm_final parameters: 768
vocab_projection parameters: 860,529


In [40]:
_batch_size, _seq_len = 1, 20

pred = model(torch.randint(low=0, high=hyperparameters['model_config']['vocab_size'], size=(_batch_size, _seq_len)))
pred.shape

torch.Size([1, 20, 50257])

# Test input with mask

In [41]:
torch.tensor([[1.0]*20 for _ in range(_batch_size)]).shape

torch.Size([1, 20])

In [42]:
pred = model(
    x_input=torch.randint(low=0, high=hyperparameters['model_config']['vocab_size'], size=(_batch_size, _seq_len)),
    pad_mask=torch.tensor([[1.0]*20 for _ in range(_batch_size)]) # all 1.0
)
pred.shape

torch.Size([1, 20, 50257])

In [43]:
ii = torch.randn(3, 5, 2)
ii

tensor([[[-0.4326,  0.1047],
         [ 0.1812,  0.5329],
         [ 0.3142,  0.5814],
         [ 0.4007, -1.6001],
         [-1.5120, -2.6613]],

        [[-0.5496, -0.2339],
         [ 1.0727,  1.2074],
         [ 1.2624,  0.1069],
         [-0.1888,  0.0077],
         [ 1.3330, -1.8660]],

        [[ 1.5329, -0.3514],
         [ 0.3858, -1.2218],
         [-1.0494,  0.3704],
         [-0.9392, -0.0094],
         [-0.3613, -0.1364]]])

In [44]:
tt = torch.randint(low=0, high=2, size=(3, 5))
print(tt.shape)
tt

torch.Size([3, 5])


tensor([[1, 0, 1, 0, 0],
        [0, 0, 1, 1, 0],
        [1, 1, 1, 0, 1]])

In [45]:
print(ii.view(-1, ii.size(-1)).shape)
print(tt.view(-1).shape)

torch.Size([15, 2])
torch.Size([15])


In [46]:
F.cross_entropy(ii.view(-1, ii.size(-1)), tt.view(-1))

tensor(0.7928)

# Training

### Without AMP 

In [47]:
PBAR_UPDATE_FREQ = 60

list_avg = lambda l: sum(l)/len(l)

In [48]:
def train_iter(dataloader, model, optimizer, criterion, scaler, epoch, clip_grad_norm, pbar_update_freq, device):
    model.train()
    
    #avg_loss = 0.0
    avg_loss = []
    count = 0

    # Accuracy
    #acc_correct = 0
    #acc_total = 0
    
    pbar = tqdm(dataloader, unit=' batch', leave=False)
    pbar.set_description(f'Epoch: {epoch}, Train')
    
    for input_token_ids_batch, attention_mask_batch, target_token_ids_batch in dataloader:

        input_token_ids_batch = input_token_ids_batch.to(device)
        attention_mask_batch = attention_mask_batch.to(device)
        target_token_ids_batch = target_token_ids_batch.to(device)

        optimizer.zero_grad(set_to_none=True)

        pred_token_ids_batch = model(
            x_input=input_token_ids_batch, 
            pad_mask=attention_mask_batch
        )

        # Combine batch and seq_len dims together to form a "longer batch"
        loss = criterion(
            pred_token_ids_batch.view(-1, pred_token_ids_batch.size(-1)),
            target_token_ids_batch.view(-1)
        )
        #loss = criterion(pred_token_ids_batch, target_token_ids_batch)

        #avg_loss += loss.item()
        avg_loss.append(loss.item())
        count += 1

        """
        _, pred_classes = torch.max(pred_token_ids_batch, 1)
        acc_correct += (pred_classes == target_token_ids_batch).sum().item()
        acc_total += target_token_ids_batch.shape[0]
        """

        if count % pbar_update_freq  == 0:
            #iter_accuracy = 100.0 * acc_correct / acc_total
            #pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f} Acc: {iter_accuracy:.2f}')
            pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f}')
            pbar.update(pbar_update_freq)
        
        loss.backward()
        #scaler.scale(loss).backward()
        
        ### CLIPPING ######
        # Prevent exploding gradients with gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        ###################

        optimizer.step()
        #optimizer.optimizer.step()
        #optimizer.lr_step()
        #scaler.step(optimizer.optimizer)
        #scaler.update()

    pbar.close()

    #final_accuracy = 100.0 * acc_correct / acc_total
    #return list_avg(avg_loss), final_accuracy
    return list_avg(avg_loss)
    
@torch.no_grad()
def eval_iter(dataloader, model, criterion, epoch, pbar_update_freq, device):
    model.eval()
    #avg_loss = 0.0
    avg_loss = []
    count = 0

    # Accuracy
    #acc_correct = 0
    #acc_total = 0
    
    pbar = tqdm(dataloader, unit=' batch', leave=False)
    pbar.set_description(f'Epoch: {epoch}, Eval')
    
    for input_token_ids_batch, attention_mask_batch, target_token_ids_batch in dataloader:

        input_token_ids_batch = input_token_ids_batch.to(device)
        attention_mask_batch = attention_mask_batch.to(device)
        target_token_ids_batch = target_token_ids_batch.to(device)

        pred_token_ids_batch = model(
            x_input=input_token_ids_batch, 
            pad_mask=attention_mask_batch
        )

        # Combine batch and seq_len dims together to form a "longer batch"
        loss = criterion(
            pred_token_ids_batch.view(-1, pred_token_ids_batch.size(-1)),
            target_token_ids_batch.view(-1)
        )
        #loss = criterion(pred_token_ids_batch, target_token_ids_batch)
            
        #avg_loss += loss.item()
        avg_loss.append(loss.item())
        count += 1
        
        """
        _, pred_classes = torch.max(pred_token_ids_batch, 1)
        acc_correct += (pred_classes == target_token_ids_batch).sum().item()
        acc_total += target_token_ids_batch.shape[0]
        """

        if count % pbar_update_freq  == 0:
            #iter_accuracy = 100.0 * acc_correct / acc_total
            #pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f} Acc: {iter_accuracy:.2f}')
            pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f}')
            pbar.update(pbar_update_freq)

    pbar.close()
    
    #final_accuracy = 100.0 * acc_correct / acc_total
    #return list_avg(avg_loss), final_accuracy
    return list_avg(avg_loss)

### With AMP 

In [49]:
def train_iter_amp(dataloader, model, optimizer, criterion, scaler, epoch, clip_grad_norm, pbar_update_freq, grad_accum_iter, device):
    model.train()
    
    #avg_loss = 0.0
    avg_loss = []
    count = 0

    # Accuracy
    #acc_correct = 0
    #acc_total = 0
    
    pbar = tqdm(dataloader, unit=' batch', leave=False)
    pbar.set_description(f'Epoch: {epoch}, Train')
        
    for batch_idx, (input_token_ids_batch, attention_mask_batch, target_token_ids_batch) in enumerate(dataloader):

        input_token_ids_batch = input_token_ids_batch.to(device)
        attention_mask_batch = attention_mask_batch.to(device)
        target_token_ids_batch = target_token_ids_batch.to(device)

        #optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type=device, dtype=torch.float16, enabled=True):
            pred_token_ids_batch = model(
                x_input=input_token_ids_batch, 
                pad_mask=attention_mask_batch
            )
    
            # Combine batch and seq_len dims together to form a "longer batch"
            loss = criterion(
                pred_token_ids_batch.view(-1, pred_token_ids_batch.size(-1)),
                target_token_ids_batch.view(-1)
            )
            #loss = criterion(pred_token_ids_batch, target_token_ids_batch)

        #loss.backward()
        #scaler.scale(loss).backward()
        
        ### CLIPPING ######
        # Prevent exploding gradients with gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        ###################

        #optimizer.step()
        #optimizer.optimizer.step()
        #optimizer.lr_step()
        #scaler.step(optimizer.optimizer)
        #scaler.update()

        #optimizer.zero_grad(set_to_none=True) # set_to_none=True here can modestly improve performance
        scaler.scale(loss).backward()

        ####################
        # GRADIENT ACCUMULATION
        ####################
        if ((batch_idx + 1) % grad_accum_iter == 0) or ((batch_idx + 1) == len(dataloader)):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True) # set_to_none=True here can modestly improve performance
        ####################
        
        #avg_loss += loss.item()
        avg_loss.append(loss.item())
        count += 1

        """
        _, pred_classes = torch.max(pred_token_ids_batch, 1)
        acc_correct += (pred_classes == target_token_ids_batch).sum().item()
        acc_total += target_token_ids_batch.shape[0]
        """

        if count % pbar_update_freq  == 0:
            #iter_accuracy = 100.0 * acc_correct / acc_total
            #pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f} Acc: {iter_accuracy:.2f}')
            pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f}')
            pbar.update(pbar_update_freq)
        
    pbar.close()

    #final_accuracy = 100.0 * acc_correct / acc_total
    #return list_avg(avg_loss), final_accuracy
    return list_avg(avg_loss)


@torch.no_grad()
def eval_iter_amp(dataloader, model, criterion, epoch, pbar_update_freq, device):
    model.eval()
    #avg_loss = 0.0
    avg_loss = []
    count = 0

    # Accuracy
    #acc_correct = 0
    #acc_total = 0
    
    pbar = tqdm(dataloader, unit=' batch', leave=False)
    pbar.set_description(f'Epoch: {epoch}, Eval')
    
    for input_token_ids_batch, attention_mask_batch, target_token_ids_batch in dataloader:

        input_token_ids_batch = input_token_ids_batch.to(device)
        attention_mask_batch = attention_mask_batch.to(device)
        target_token_ids_batch = target_token_ids_batch.to(device)

        with torch.autocast(device_type=device, dtype=torch.float16, enabled=True):
            pred_token_ids_batch = model(
                x_input=input_token_ids_batch, 
                pad_mask=attention_mask_batch
            )
    
            # Combine batch and seq_len dims together to form a "longer batch"
            loss = criterion(
                pred_token_ids_batch.view(-1, pred_token_ids_batch.size(-1)),
                target_token_ids_batch.view(-1)
            )
            #loss = criterion(pred_token_ids_batch, target_token_ids_batch)
            
        #avg_loss += loss.item()
        avg_loss.append(loss.item())
        count += 1
        
        """
        _, pred_classes = torch.max(pred_token_ids_batch, 1)
        acc_correct += (pred_classes == target_token_ids_batch).sum().item()
        acc_total += target_token_ids_batch.shape[0]
        """

        if count % pbar_update_freq  == 0:
            #iter_accuracy = 100.0 * acc_correct / acc_total
            #pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f} Acc: {iter_accuracy:.2f}')
            pbar.set_postfix_str(f'Loss: {list_avg(avg_loss):.4f}')
            pbar.update(pbar_update_freq)

    pbar.close()
    
    #final_accuracy = 100.0 * acc_correct / acc_total
    #return list_avg(avg_loss), final_accuracy
    return list_avg(avg_loss)

# Optimizer, Scheduler & Loss

In [50]:
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=hyperparameters['optimizer']['learning_rate'],
    betas=hyperparameters['optimizer']['optimizer_betas'],
    #weight_decay=hyperparameters['optimizer']['weight_decay']
)

# PAD ID IGNORE!!
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id('<pad>'))

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode=hyperparameters['learning_rate_sched_config']['mode'],
    factor=hyperparameters['learning_rate_sched_config']['factor'],
    patience=hyperparameters['learning_rate_sched_config']['patience'],
    cooldown=hyperparameters['learning_rate_sched_config']['cooldown'],
    min_lr=hyperparameters['learning_rate_sched_config']['min_lr'],
    verbose=True,
)

# Grad Scaler (FP16) For Automatic Mixed Precision (AMP)

In [51]:
scaler = torch.cuda.amp.GradScaler()

# Save/Load Functions

In [52]:
def save_model(model, optimizer, root_folder, file_name, hyperparameter_dict, metrics_dict, last_epoch, verbose=False):
    os.makedirs(root_folder, exist_ok=True)
    model_full_path = os.path.join(root_folder, file_name+'.pt')
    
    torch.save({
        'hyperparameters': hyperparameter_dict,
        'metrics': metrics_dict,
        'model_state_dict': model.state_dict(),
        #'optimizer_state_dict': optimizer.optimizer.state_dict()
        'optimizer_state_dict': optimizer.state_dict(),
        'saved_time_unix': time.time(),
        'saved_time_asctime': time.asctime(),
        'last_epoch': last_epoch,
    }, model_full_path)
    
    if verbose:
        print(f'Model: {file_name} is saved successfully')
    
    
def load_model(model, optimizer, root_folder, file_name):
    model_full_path = os.path.join(root_folder, file_name+'.pt')
    #checkpoint = torch.load(model_full_path, map_location='cpu')
    checkpoint = torch.load(model_full_path)
    
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)
    
    #if optimizer is not None:
    #    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    print(f'Model: {file_name} is loaded successfully')
    
    return checkpoint

# Language Modelling Functions

In [53]:
def generate_multinomial_sampling(model, idx, max_new_tokens, max_seq_len, temp=1.0, topk=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -max_seq_len:]

        logits = model(x_input=idx_cond, x_cross=None, pad_mask=None)
        
        """
        if topk is not None:
            logits = logits.topk(topk, dim=1).values
        """
        
        logits = logits[:, -1, :]
                   
        if topk is not None:
            _values, _indices = F.softmax(logits/temp, dim=-1).topk(topk, dim=1)
            #probs = F.softmax(logits/temp, dim=-1).topk(topk, dim=-1).values
            probs = _values
        else:
            probs = F.softmax(logits/temp, dim=-1)
        
        # sample from the distribution
        _idx_next = torch.multinomial(probs, num_samples=1) 
        #print(_idx_next)
        #print(_indices)
        idx_next = _indices[:, _idx_next[0]]
        #print(f'idx_next: {idx_next}')
        idx = torch.cat((idx, idx_next), dim=1)
    return idx


def generate_text(model, device, n_tokens, temp, context=None, topk=None, remove_newlines=True):
    """
    Assumes that model is already on "device"
    """
    model.eval()
    
    if context is None:
        #context = torch.zeros((1, 1), dtype=torch.long, device=device)
        context = torch.tensor([encode('')], dtype=torch.long, device=device)
    else:
        context = torch.tensor([encode(context)], dtype=torch.long, device=device)

    _genereted = generate_multinomial_sampling(
            model,
            context, 
            max_new_tokens=n_tokens, 
            max_seq_len=hyperparameters['context_size'], 
            temp=temp,
            topk=topk
        )
    
    generated = decode(
        _genereted[0].tolist()
    )
    
    if remove_newlines:
        generated = generated.replace('\n', '')
        
    print(generated)

### Start Epoch

In [54]:
START_EPOCH = 1

### Load Pre-Trained Model (Optional)

In [55]:
#checkpoint = load_model(model, optimizer, './saved_models', f"{hyperparameters['model_base_name']}_best")
#START_EPOCH = checkpoint['last_epoch'] + 1

# Start Training

In [56]:
metrics = {
    'best_state_dict': None,
    'best_epoch': -1,
    'best_val_loss': float('inf'), # TODO: val_acc can be also used
    'best_val_acc': 0.0,
}

In [57]:
def start_training(start_epoch=1):
    print(f'Start trainin from epoch: {start_epoch}')
    
    model.to(device)
    criterion.to(device)
    
    for epoch in range(start_epoch, hyperparameters['num_epochs']+1):
        
        #train_loss = train_iter(
        train_loss = train_iter_amp(
            train_dataloader, 
            model, 
            optimizer, 
            criterion, 
            scaler, 
            epoch,
            hyperparameters['clip_grad_norm'], 
            PBAR_UPDATE_FREQ, 
            hyperparameters['grad_accum_iter'],
            device
        )
        
        #val_loss = eval_iter(
        val_loss = eval_iter_amp(
            test_dataloader, 
            model, 
            criterion, 
            epoch, 
            PBAR_UPDATE_FREQ, 
            device
        )

        """
        generate_text(
            model=model, 
            device=device,
            n_tokens=150, 
            temp=0.65, 
            context='Bu film ',
            topk=150,
            remove_newlines=False
        )
        
        print('*'*10)
        
        generate_text(
            model=model, 
            device=device,
            n_tokens=150, 
            temp=0.85, 
            context='Bu film ',
            topk=150,
            remove_newlines=False
        )
        """
        
        print(f'Epoch: {epoch}, [LOSS] train: {train_loss:.4f}, val: {val_loss:.4f}')

        
        save_model(
            model, 
            optimizer, 
            './saved_models', 
            f"{hyperparameters['model_base_name']}_checkpoint",
            hyperparameters, 
            metrics,
            epoch, 
            verbose=False
        )

        
        if val_loss < metrics['best_val_loss']:
            metrics['best_state_dict'] = model.state_dict().copy()
            metrics['best_epoch'] = epoch
            metrics['best_val_loss'] = val_loss
    
            save_model(
                model, 
                optimizer, 
                './saved_models', 
                f"{hyperparameters['model_base_name']}_best", 
                hyperparameters, 
                metrics,
                epoch,
                verbose=True
            )
        

        """
        if val_acc > metrics['best_val_acc']:
            metrics['best_state_dict'] = model.state_dict().copy()
            metrics['best_epoch'] = epoch
            metrics['best_val_acc'] = val_acc
    
            save_model(
                model, 
                optimizer, 
                './saved_models', 
                'LM_Sentiment_best', 
                hyperparameters, 
                metrics,
                epoch,
                verbose=True
            )

        """
        
        # LR Scheduling
        lr_scheduler.step(val_loss)

### Start Training

In [58]:
start_training(start_epoch=START_EPOCH)

Start trainin from epoch: 1


                                                                                                                        

Epoch: 1, [LOSS] train: 8.4243, val: 8.0453
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 2, [LOSS] train: 7.8181, val: 7.5801
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 3, [LOSS] train: 7.4129, val: 7.2548
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 4, [LOSS] train: 7.0994, val: 7.0219
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 5, [LOSS] train: 6.9249, val: 6.9030
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 6, [LOSS] train: 6.8123, val: 6.9253


                                                                                                                        

Epoch: 7, [LOSS] train: 6.7219, val: 6.6735
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 8, [LOSS] train: 6.5870, val: 6.5685
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 9, [LOSS] train: 6.4952, val: 6.5022
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 10, [LOSS] train: 6.4146, val: 6.4094
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 11, [LOSS] train: 6.4183, val: 6.3712
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 12, [LOSS] train: 6.3092, val: 6.3045
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 13, [LOSS] train: 6.2360, val: 6.2228
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 14, [LOSS] train: 6.2069, val: 6.1640
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 15, [LOSS] train: 6.1501, val: 6.1459
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 16, [LOSS] train: 6.1133, val: 6.1361
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 17, [LOSS] train: 6.0448, val: 6.0326
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 18, [LOSS] train: 5.9938, val: 5.9833
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 19, [LOSS] train: 5.9363, val: 5.9551
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 20, [LOSS] train: 5.9134, val: 5.9577


                                                                                                                        

Epoch: 21, [LOSS] train: 5.8519, val: 5.8590
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 22, [LOSS] train: 5.8044, val: 5.8398
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 23, [LOSS] train: 5.7889, val: 5.7794
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 24, [LOSS] train: 5.7278, val: 5.7478
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 25, [LOSS] train: 5.8036, val: 5.7444
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 26, [LOSS] train: 5.6796, val: 5.6826
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


                                                                                                                        

Epoch: 27, [LOSS] train: 5.8086, val: 5.7112


                                                                                                                        

Epoch: 28, [LOSS] train: 5.6327, val: 5.6471
Model: LM_GPTSmall_Wiki_TR_best is saved successfully


Epoch: 29, Train:  23%|█████████▉                                  | 360/1587 [01:31<05:11,  3.94 batch/s, Loss: 5.5870]

KeyboardInterrupt: 

# Revert to Best Checkpoint (Opitional)

In [None]:
#model.load_state_dict(metrics['best_state_dict'])
#print(f'Loaded Epoch: {metrics["best_epoch"]}, with val acc: {metrics["best_val_loss"]:.2f}')

# Last Save (Optional)

In [None]:
"""
save_model(
    model, 
    optimizer, 
    './saved_models', 
    f"{hyperparameters['model_base_name']}_latest",
    hyperparameters, 
    metrics,
)
"""

In [None]:
"""
cp = load_model(
    model, 
    optimizer, 
    './saved_models', 
    f"{hyperparameters['model_base_name']}_latest",
)
"""