# Mastering the BERT Model: Building It from Scratch with PyTorch

Create and explore a BERT model in its most basic form by building it from the ground up with PyTorch.

Source: https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891

## Preparation

- Download data
- Import packages
- Prepare datasets

In [None]:
# Install the third-party packages into the kernel's own environment
# (`%pip` targets the running kernel, `!pip` may hit a different interpreter).
%pip install transformers datasets tokenizers
# Fetch the Cornell Movie-Dialogs corpus unless both required files are already
# present. Steps are chained with `&&` so a failed download or unzip stops the
# sequence instead of leaving an empty ./datasets directory that would make
# later runs skip the download. `wget -O` and `unzip -o` keep a retry after a
# partial run from tripping over leftover files.
!(if [ ! -f "./datasets/movie_conversations.txt" ] || [ ! -f "./datasets/movie_lines.txt" ]; then \
    echo "Downloading datasets" && \
    wget -O cornell_movie_dialogs_corpus.zip http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip && \
    unzip -qq -o cornell_movie_dialogs_corpus.zip && \
    rm cornell_movie_dialogs_corpus.zip && \
    mkdir -p datasets && \
    mv cornell\ movie-dialogs\ corpus/movie_conversations.txt ./datasets && \
    mv cornell\ movie-dialogs\ corpus/movie_lines.txt ./datasets && \
    rm -rf cornell\ movie-dialogs\ corpus; \
fi)

In [None]:
import ast
import os
import random
from pathlib import Path

import torch
import tqdm
from tokenizers import BertWordPieceTokenizer
from torch.utils.data import DataLoader
from transformers import BertTokenizer

from bert_dataset import BERTDataset
from bert_model import BERT, BERTLM
from bert_trainer import BERTTrainer

### setting seed
# Seed every RNG this notebook touches: torch (CPU and CUDA) and Python's
# `random`, which is used further down to pick a dataset sample to preview.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

### setting device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Maximum number of whitespace-separated words kept per utterance; also passed
# to the dataset as its sequence length in a later cell.
MAX_LEN = 64

### loading all data into memory
corpus_movie_conv = './datasets/movie_conversations.txt'
corpus_movie_lines = './datasets/movie_lines.txt'
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as conv_file:
    conv = conv_file.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as lines_file:
    lines = lines_file.readlines()

### splitting each record on the corpus field separator
FIELD_SEP = " +++$+++ "

# Map line id (first field) -> utterance text (last field).
lines_dic = {}
for line in lines:
    objects = line.split(FIELD_SEP)
    lines_dic[objects[0]] = objects[-1]

### generate question answer pairs
# The last field of every conversation record is a Python-style list literal of
# line ids. It is parsed with ast.literal_eval rather than eval so that text
# coming from the downloaded file can never execute as code.
pairs = []
for con in conv:
    if not con.strip():
        # tolerate blank lines instead of failing on an empty literal
        continue
    ids = ast.literal_eval(con.split(FIELD_SEP)[-1])

    # Each pair of consecutive utterances in a conversation becomes one
    # [first, second] sample, each truncated to at most MAX_LEN words.
    for first_id, second_id in zip(ids, ids[1:]):
        first = lines_dic[first_id].strip()
        second = lines_dic[second_id].strip()
        pairs.append([
            ' '.join(first.split()[:MAX_LEN]),
            ' '.join(second.split()[:MAX_LEN]),
        ])

print(pairs[20])

## Prepare modules

- Make WordPiece tokenizer
- Prepare a dataset loader
- Build a model

In [None]:
# WordPiece tokenizer

### save data as txt file
# Dump the first sentence of every pair into ./data in files of up to 10K
# sentences each. Slicing by chunk also writes the final, shorter chunk, so no
# trailing sentences are dropped from the tokenizer's training text.
# The dump is skipped when ./data already exists.
SENTENCES_PER_FILE = 10000
if not os.path.exists('./data'):
    os.mkdir('./data')
    first_sentences = [x[0] for x in pairs]
    chunk_starts = range(0, len(first_sentences), SENTENCES_PER_FILE)

    for file_count, start in enumerate(tqdm.tqdm(chunk_starts)):
        text_data = first_sentences[start:start + SENTENCES_PER_FILE]
        with open(f'./data/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(text_data))

paths = [str(x) for x in Path('./data').glob('**/*.txt')]

### training own tokenizer
# Train and save the vocabulary only when the vocab file is missing. The
# tokenizer used downstream is always the BertTokenizer loaded from that file,
# so retraining when it already exists would be discarded work.
if not os.path.exists('./bert-it-1/bert-it-vocab.txt'):
    wordpiece_tokenizer = BertWordPieceTokenizer(
        clean_text=True,
        handle_chinese_chars=False,
        strip_accents=False,
        lowercase=True
    )

    wordpiece_tokenizer.train(
        files=paths,
        vocab_size=30_000,
        min_frequency=5,
        limit_alphabet=1000,
        wordpieces_prefix='##',
        special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']
    )

    # exist_ok covers a directory left behind by an interrupted earlier run
    os.makedirs('./bert-it-1', exist_ok=True)
    # save_model with prefix 'bert-it' produces the vocab file loaded below
    wordpiece_tokenizer.save_model('./bert-it-1', 'bert-it')

tokenizer = BertTokenizer.from_pretrained('./bert-it-1/bert-it-vocab.txt', local_files_only=True)

In [None]:
# Wrap the sentence pairs in the project's BERTDataset, using the trained
# tokenizer and the same MAX_LEN that was used to truncate the utterances.
train_data = BERTDataset(
    pairs, seq_len=MAX_LEN, tokenizer=tokenizer)

# Pinned host memory only helps when batches are copied to a GPU; requesting it
# on a CPU-only machine has no benefit and makes PyTorch emit a warning.
train_loader = DataLoader(
    train_data, batch_size=32, shuffle=True, pin_memory=(device == 'cuda'))

# Preview one randomly chosen encoded sample as a sanity check.
sample_index = random.randrange(len(train_data))
print(train_data[sample_index])

In [None]:
# Encoder hyper-parameters, collected in one place so they are easy to tweak.
# The vocabulary size is taken from the tokenizer loaded earlier.
bert_config = dict(
    vocab_size=len(tokenizer.vocab),
    d_model=768,
    n_layers=12,
    heads=12,
    dropout=0.1,
    device=device,
)

# Instantiate the project's BERT encoder with those settings.
bert_model = BERT(**bert_config)

## Train BERT from scratch

- Run training iteration

In [None]:
# Wrap the encoder in the project's BERTLM, sized to the tokenizer vocabulary,
# and hand it to the trainer together with the data loader and target device.
bert_lm = BERTLM(bert_model, len(tokenizer.vocab))
bert_trainer = BERTTrainer(bert_lm, train_loader, device=device)

# Number of full passes over train_loader.
epochs = 20

# The trainer receives the zero-based epoch index on every pass.
for epoch in range(epochs):
    bert_trainer.train(epoch)