In [1]:
import torch
from torch import nn
from torch.nn import functional
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding
import torch.optim as optim

from dataset.create_dataset import create_data_loader
from layers.model import Transformer, AutoregressiveWrapper
from transformers import BertTokenizer

import numpy as np

from test_model.test_model import TestModel
from tqdm import tqdm

import wandb

import time
from torch.profiler import profile, record_function, ProfilerActivity


In [2]:
CONFIG = {
    "architecture": "Transformer", # Wandb only
    "dataset": "wikitext-103-raw-v1", # Wandb only
    "batch_size": 10,
    "accumulation": 6,
    "embedding_size": 768,
    "max_sequence_length": 256,
    "number_of_layers": 12,
    "number_of_heads": 12,
    "additional_feed_forward_layers": 0,
    "extention_factor": 4,
    "attention_activation": "softmax",
    "dropout_rate": 0.1,
    'train_size': 2**20,
    'test_size': 128,
    'model_path': None, # "savepoints/likely-durian-120"
}
CONFIG["lr"] = 0.001 / np.sqrt(CONFIG["batch_size"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [3]:
def test_model(pipeline, model, loss_function):
    model.eval()
    total_loss = 0

    for batch in test_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        model_output, target = pipeline(input_ids, attention_mask)

        loss = loss_function(model_output.transpose(1, 2), target)

        total_loss += float(loss)

    total_loss /= len(test_dataloader)# * CONFIG["batch_size"]

    return total_loss


def train(CONFIG, pipeline, model, optimizer, loss_function, model_tester, wandb):
    train_config = {
        "test_every": 1024 // CONFIG["batch_size"],
        "log_traing_metrics_every": 64 // CONFIG["batch_size"],
    }

    train_time = 0
    test_time = 0
    last_moment = time.time()

    model.train()

    batch_num = 0
    train_losses = []
    
    for batch in tqdm(train_dataloader, desc="Training Progress"):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        model_output, target = pipeline(input_ids, attention_mask)
        loss = loss_function(model_output.transpose(1, 2), target)
        train_losses.append(float(loss))
        
        loss = loss / CONFIG['accumulation']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        
        model_output.detach()
        loss.detach()

        if ((batch_num + 1) % CONFIG['accumulation'] == 0) or (batch_num + 1 == len(train_dataloader)):
            optimizer.step()
            optimizer.zero_grad()

        batch_num += 1

        if batch_num % train_config["log_traing_metrics_every"] == 0:
            train_time += time.time() - last_moment
            last_moment = time.time()

            datapoints_processed_total = batch_num * CONFIG["batch_size"]
            wandb.log({
                "train_loss": sum(train_losses[-train_config["log_traing_metrics_every"]:]) / train_config["log_traing_metrics_every"],
                "datapoints_processed_total": datapoints_processed_total,
                "train_time": train_time,
            })

        if batch_num % train_config["test_every"] == 0:
            train_time += time.time() - last_moment
            last_moment = time.time()

            metrics = model_tester.test_model(pipeline, test_dataloader)
            test_loss = metrics['loss']
            bleu = metrics['bleu']
            #bert_f1 = metrics['bert_f1']
            rouge1 = metrics['rouge1']
            rouge2 = metrics['rouge2']
            rougeL = metrics['rougeL']

            test_time += time.time() - last_moment
            last_moment = time.time()

            datapoints_processed_total = batch_num * CONFIG["batch_size"]

            wandb.log({
                "test_loss": test_loss,
                "bleu": bleu,
                #"bert_f1": bert_f1,
                "rouge1": rouge1,
                "rouge2": rouge2,
                "rougeL": rougeL,
                "datapoints_processed_total": datapoints_processed_total,
                "test_time": test_time,
            })

In [4]:
def create_model(CONFIG, model_path=None):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    number_of_tokens = tokenizer.vocab_size
    
    model = Transformer(
        embedding_size=CONFIG["embedding_size"],
        number_of_tokens=number_of_tokens,
        number_of_heads=CONFIG["number_of_heads"],
        number_of_layers=CONFIG["number_of_layers"],
        extention_factor=CONFIG["extention_factor"],
        additional_feed_forward_layers=CONFIG["additional_feed_forward_layers"],
        attention_activation=CONFIG["attention_activation"],
        dropout_rate=CONFIG["dropout_rate"],
        max_sequence_length=CONFIG["max_sequence_length"]
    ).to(device)
    if model_path:
        model.load_state_dict(torch.load(model_path))
        
    pipeline = AutoregressiveWrapper(model).to(device)
    loss_function = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=CONFIG["lr"])
    model_tester = TestModel(tokenizer, model)

    return pipeline, model, optimizer, loss_function, model_tester

In [None]:
for i in range(1):
    train_dataloader, test_dataloader, _ = create_data_loader(batch_size=CONFIG["batch_size"],
                                    max_sequence_size=CONFIG["max_sequence_length"],
                                    train_size=CONFIG['train_size'], test_size=CONFIG['test_size'])

    wandb_run = wandb.init(
        # set the wandb project where this run will be logged
        project="transformer",
        tags=["long_training_testing"],
        
        # track hyperparameters and run metadata
        config=CONFIG
    )
    
    load_path = CONFIG['model_path']

    pipeline, model, optimizer, loss_function, model_tester = create_model(CONFIG, load_path)
    num_parameters, num_trainable_parameters, memory_allocated = pipeline.count_parameters() 
    print('number of parameters =', num_parameters)
    print('number of trainable parameters =', num_trainable_parameters)
    print('memory allocated in GB =', memory_allocated)
    prof = train(CONFIG, pipeline, model, optimizer, loss_function, model_tester, wandb)


Found cached dataset wikitext (C:/Users/skoro/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\skoro\.cache\huggingface\datasets\wikitext\wikitext-103-raw-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-be33ed5b4b4522e7.arrow
Loading cached processed dataset at C:\Users\skoro\.cache\huggingface\datasets\wikitext\wikitext-103-raw-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-26a133c3c002dbd8.arrow
[34m[1mwandb[0m: Currently logged in as: [33mskorodumov-work[0m ([33m8667[0m). Use [1m`wandb login --relogin`[0m to force relogin


number of parameters = 131968314
number of trainable parameters = 131968314
memory allocated in GB = 0.4916202798485756


Training Progress:   2%|▍                      | 2121/104858 [23:40<18:22:18,  1.55it/s]

In [6]:
PATH = "savepoints/" + wandb_run.name
torch.save(model.state_dict(), PATH)

In [None]:
wandb.finish()

In [5]:
pipeline, model, optimizer, loss_function, model_tester = create_model(CONFIG, CONFIG['model_path'])

In [6]:
train_dataloader, test_dataloader, _ = create_data_loader(batch_size=CONFIG["batch_size"],
                                    max_sequence_size=CONFIG["max_sequence_length"],
                                    train_size=CONFIG['train_size'], test_size=CONFIG['test_size'])

for batch in train_dataloader:
    break

Found cached dataset wikitext (C:/Users/skoro/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at C:\Users\skoro\.cache\huggingface\datasets\wikitext\wikitext-103-raw-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-be33ed5b4b4522e7.arrow
Loading cached processed dataset at C:\Users\skoro\.cache\huggingface\datasets\wikitext\wikitext-103-raw-v1\1.0.0\a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126\cache-26a133c3c002dbd8.arrow


In [27]:
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    i = 0
    for batch in train_dataloader:
        i += 1

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        print(input_ids.size())

        with record_function("model_inference"):
            model_output, target = pipeline(input_ids, attention_mask)

        if i == 1:
            break

torch.Size([9, 256])


In [28]:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=-1))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        24.75%      31.094ms        98.50%     123.752ms     123.752ms       4.887ms         2.44%     199.077ms     199.077ms             1  
                                           aten::linear         7.32%       9.202ms        27.97%      35.146ms      74.938us       4.364ms         2.18%     118.161ms     251.942us           469  
         

In [10]:
from random import randint

def predict_next(pipeline, input_text, num_predicted_tokens, tokenizer):
    input_tokens = tokenizer.encode(input_text, return_tensors="pt")
    input_tokens = input_tokens[:, :-1].to(device)
    pipeline.eval()
    
    for i in range(num_predicted_tokens):
        mask = torch.ones_like(input_tokens)

        with torch.no_grad():
            probabilities = pipeline.next_token_probabilities(input_tokens, mask)
        
        answer = probabilities.argsort(dim=-1)[:, -randint(1, 2)].unsqueeze(0)
        input_tokens = torch.cat((input_tokens, answer), dim=1)
        
    return tokenizer.decode(input_tokens[0])

In [11]:
input_text = """
london is the capital of
"""
num_predicted_tokens = 50
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

answer = predict_next(pipeline, input_text, num_predicted_tokens, tokenizer)

In [12]:
answer

"[CLS] london is the capital of the country, and is the home to a family of local residents. [SEP] are a family, but only the home is a family. [SEP] are the family's home, and the home is a small room, which is a small room. the"

In [None]:
"""
Performance 
first approach
batch_size, accumulation, num DP, time
4           4             2**9    40s
8           1             2**9    37s
6           4             2**9    34s
4           6             2**9    39s
5           5             2**9    37s
8           2             2**9    34s
10          2             2**9    33s
8           3             2**9    33s
11          1             2**9    34s
14          1             2**9    34s
second approach
batch_size, accumulation, num DP, time
14          1             2**9    32s
8           3             2**9    34s
8           4             2**9    33s
10          4             2**9    33s
"""

In [None]:
"""
GPU comparison
1. GOOGLE COLAB
-100 CU are given for 10 dollars
-V100, 32 GB, 5 CU per hour, 20 hours, 706 DP, 14120 DP total, 1412 DP per dollar
-A100, 40 GB, 15 CU per hour, 6.5 hours, 2179 DP, 14163 DP total, 1416 DP per dollar

2. VAST.AI
-RTX 4090, 24 GB, 0.4 dollars per hour, 1720 DP, 4300 DP per dollar, 177% while training on 4 GPUs
-RTX 3080, 10 GB, 0.117 dollars per hour, 1022 DP, 8717 DP per dollar, 177% while training on 4 GPUs
-RTX 3090, 24 GB, 0.23 dollars per hour, 1071 DP, 4652 DP per dollar, 177% while training on 4 GPUs
-A5000, 24 GB, 0.22 dollars per hour, 1061 DP, 4686 DP per dollar, 389% while training on 4 GPUs
"""

In [None]:
"""
TODO
optimize
-understand time wasting
-choose batch size


beam search

new dataset
-find new dataset for big training
-make some analysis
-avg, mean, each percentile sequence length
-find info in the net

model parameters
-max sequence length
-estimate batch size
-choose parameters for n GB VRAM
-choose model parameters for 5-6 GB model

filter model
-

new conversation dataset 
-analyse
-understand how to train and test conversation model
-try to train

metrics
-create metrics for conversation

multiple GPUs
-write code to run model on several gpus
-try to run model on vast.ai
-also model must be saved per epoch

site
-create django project with model which is placed on google colab
-create 


"""

In [None]:
"""
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        24.75%      31.094ms        98.50%     123.752ms     123.752ms       4.887ms         2.44%     199.077ms     199.077ms             1  
                                           aten::linear         7.32%       9.202ms        27.97%      35.146ms      74.938us       4.364ms         2.18%     118.161ms     251.942us           469  
                                            aten::addmm        11.76%      14.781ms        13.69%      17.198ms      36.670us     106.835ms        53.31%     108.917ms     232.232us           469  
                                           aten::matmul         6.70%       8.419ms        19.93%      25.042ms      86.951us       3.189ms         1.59%      20.345ms      70.642us           288  
                                      aten::masked_fill         2.56%       3.218ms         9.28%      11.657ms      80.951us       1.305ms         0.65%      14.171ms      98.410us           144  
                                              aten::bmm         4.73%       5.937ms         4.73%       5.937ms      20.615us      11.271ms         5.62%      11.271ms      39.135us           288  
                                              aten::sub         2.30%       2.887ms         2.30%       2.887ms      10.024us       9.934ms         4.96%       9.934ms      34.493us           288  
                                           aten::expand         6.88%       8.645ms         7.23%       9.082ms       6.149us       4.384ms         2.19%       6.691ms       4.530us          1477  
                                     aten::masked_fill_         1.66%       2.085ms         2.25%       2.826ms      19.625us       5.498ms         2.74%       6.153ms      42.729us           144  
                                             aten::tril         1.07%       1.345ms         1.07%       1.345ms       9.340us       5.948ms         2.97%       5.948ms      41.306us           144  
                                            aten::clone         1.15%       1.445ms         3.35%       4.213ms      29.055us     647.000us         0.32%       5.442ms      37.531us           145  
                                            aten::copy_         1.63%       2.051ms         1.63%       2.051ms      13.765us       5.083ms         2.54%       5.083ms      34.114us           149  
                                             aten::rsub         0.75%     946.000us         1.84%       2.308ms      16.028us     428.000us         0.21%       4.786ms      33.236us           144  
                                          aten::softmax         0.57%     716.000us         1.82%       2.291ms      15.910us     425.000us         0.21%       4.265ms      29.618us           144  
                                         aten::_softmax         1.25%       1.575ms         1.25%       1.575ms      10.938us       3.840ms         1.92%       3.840ms      26.667us           144  
                                              aten::mul         0.94%       1.185ms         0.94%       1.185ms       8.229us       3.594ms         1.79%       3.594ms      24.958us           144  
                                                aten::t         2.69%       3.379ms         4.81%       6.037ms      12.872us       1.402ms         0.70%       3.477ms       7.414us           469  
                                       aten::as_strided         0.99%       1.239ms         0.99%       1.239ms       0.548us       3.476ms         1.73%       3.476ms       1.537us          2262  
                                              aten::div         1.54%       1.941ms         1.54%       1.941ms      13.479us       3.397ms         1.69%       3.397ms      23.590us           144  
                                        aten::ones_like         1.04%       1.309ms         2.61%       3.282ms      22.792us     650.000us         0.32%       3.219ms      22.354us           144  
                                              aten::add         0.25%     318.000us         0.25%     318.000us      12.720us       2.853ms         1.42%       2.853ms     114.120us            25  
                                          aten::reshape         3.00%       3.767ms         4.33%       5.436ms       9.421us       1.915ms         0.96%       2.799ms       4.851us           577  
                                       aten::layer_norm         0.11%     138.000us         1.00%       1.255ms      50.200us      71.000us         0.04%       2.734ms     109.360us            25  
                                        aten::transpose         2.55%       3.202ms         3.02%       3.799ms       6.197us       1.797ms         0.90%       2.708ms       4.418us           613  
                                aten::native_layer_norm         0.72%     907.000us         0.89%       1.117ms      44.680us       2.478ms         1.24%       2.663ms     106.520us            25  
                                            aten::fill_         0.54%     680.000us         0.54%     680.000us       4.722us       1.923ms         0.96%       1.923ms      13.354us           144  
                                       aten::empty_like         1.30%       1.635ms         2.07%       2.596ms       8.294us       1.130ms         0.56%       1.596ms       5.099us           313  
                                          aten::dropout         0.06%      75.000us         0.47%     587.000us      48.917us      34.000us         0.02%       1.520ms     126.667us            12  
                                   aten::native_dropout         0.26%     324.000us         0.41%     512.000us      42.667us       1.381ms         0.69%       1.486ms     123.833us            12  
                                             aten::view         2.18%       2.743ms         2.18%       2.743ms       2.768us       1.482ms         0.74%       1.482ms       1.495us           991  
                                               aten::eq         1.44%       1.809ms         1.44%       1.809ms      12.562us       1.294ms         0.65%       1.294ms       8.986us           144  
                                               aten::to         0.02%      21.000us         0.39%     491.000us      98.200us      14.000us         0.01%       1.170ms     234.000us             5  
                                         aten::_to_copy         0.04%      51.000us         0.37%     470.000us     117.500us      18.000us         0.01%       1.156ms     289.000us             4  
                                              aten::cat         0.29%     368.000us         0.32%     401.000us      28.643us     895.000us         0.45%     910.000us      65.000us            14  
                                   aten::_reshape_alias         1.29%       1.623ms         1.29%       1.623ms       2.818us     867.000us         0.43%     867.000us       1.505us           576  
                                        aten::unsqueeze         0.89%       1.119ms         1.05%       1.315ms       9.132us     426.000us         0.21%     639.000us       4.438us           144  
                                     aten::_unsafe_view         0.66%     831.000us         0.66%     831.000us       2.875us     440.000us         0.22%     440.000us       1.522us           289  
                                            aten::empty         0.48%     598.000us         0.48%     598.000us       2.694us     319.000us         0.16%     319.000us       1.437us           222  
                                    aten::empty_strided         0.46%     576.000us         0.46%     576.000us       3.349us     266.000us         0.13%     266.000us       1.547us           172  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...         0.88%       1.108ms         1.12%       1.410ms       1.410ms      45.000us         0.02%     188.000us     188.000us             1  
                                        aten::embedding         0.03%      32.000us         0.10%     121.000us     121.000us       6.000us         0.00%     123.000us     123.000us             1  
                                     aten::index_select         0.02%      23.000us         0.02%      29.000us      29.000us      93.000us         0.05%      95.000us      95.000us             1  
                                           aten::select         0.08%     104.000us         0.09%     107.000us       5.944us      48.000us         0.02%      78.000us       4.333us            18  
                                            aten::slice         0.06%      75.000us         0.06%      81.000us       8.100us      29.000us         0.01%      44.000us       4.400us            10  
                                            aten::stack         0.03%      38.000us         0.10%     122.000us      61.000us       9.000us         0.00%      32.000us      16.000us             2  
                                           aten::narrow         0.01%      14.000us         0.03%      33.000us      16.500us       7.000us         0.00%      15.000us       7.500us             2  
                                          aten::detach_         0.01%       7.000us         0.01%      10.000us       5.000us       7.000us         0.00%       9.000us       4.500us             2  
                                             aten::item         0.01%      11.000us         0.01%      14.000us      14.000us       4.000us         0.00%       5.000us       5.000us             1  
                                           aten::detach         0.00%       4.000us         0.01%       7.000us       7.000us       4.000us         0.00%       5.000us       5.000us             1  
                                       aten::lift_fresh         0.00%       3.000us         0.00%       3.000us       1.500us       3.000us         0.00%       3.000us       1.500us             2  
                                                detach_         0.00%       3.000us         0.00%       3.000us       1.500us       2.000us         0.00%       2.000us       1.000us             2  
                                          aten::random_         0.02%      19.000us         0.02%      19.000us      19.000us       1.000us         0.00%       1.000us       1.000us             1  
                              aten::_local_scalar_dense         0.00%       3.000us         0.00%       3.000us       3.000us       1.000us         0.00%       1.000us       1.000us             1  
                                          aten::resize_         0.00%       5.000us         0.00%       5.000us       5.000us       1.000us         0.00%       1.000us       1.000us             1  
                                                 detach         0.00%       3.000us         0.00%       3.000us       3.000us       1.000us         0.00%       1.000us       1.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 125.636ms
Self CUDA time total: 200.421ms
"""