<a target="_blank" href="https://colab.research.google.com/github/mHemaAP/S17/blob/main/bert_transformer_train.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
# Fetch the S17 repository, which provides the `transformer_model` package,
# the `BERT_data/` directory and `vocab.txt` used by the cells below.
# NOTE(review): re-running this cell in the same session will fail because
# the destination directory already exists — restart the runtime first.
!git clone https://github.com/mHemaAP/S17.git

Cloning into 'S17'...
remote: Enumerating objects: 316, done.[K
remote: Counting objects: 100% (2/2), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 316 (delta 0), reused 0 (delta 0), pack-reused 314[K
Receiving objects: 100% (316/316), 16.19 MiB | 16.20 MiB/s, done.
Resolving deltas: 100% (8/8), done.


In [2]:
# Work from inside the cloned repository so that `transformer_model` is
# importable and the relative data paths ('BERT_data/training.txt',
# 'vocab.txt') used later resolve; then list the contents as a sanity check.
%cd S17
%ls

/content/S17
[0m[01;34mBERT_data[0m/                    names.tsv           [01;34mtransformer_model[0m/
bert_transformer_train.ipynb  [01;34mpizza_steak_sushi[0m/  values.tsv
[01;34mGPT_data[0m/                     README.md           vit_transformer_train.ipynb
gpt_transformer_train.ipynb   [01;34msuper_repo[0m/         vocab.txt


In [3]:
import torch
import random
import numpy as np
from collections import Counter
from os.path import exists
import torch.optim as optim
import torch.nn as nn

In [4]:
from transformer_model.common_model import Transformer
from transformer_model.datamodules.bert_datamodule import SentencesDataset, create_sentences_and_vocab
from transformer_model.models.bert.bert_train import bert_train

Number of patches (N) with image height (H=224), width (W=224) and patch size (P=16): 196
Input shape (single 2D image): (224, 224, 3)
Output shape (single 2D image flattened into patches): (196, 768)


In [5]:
print('Initializing data for BERT...')

# Reproducibility: seed every global RNG imported by this notebook so that a
# fresh "Restart & Run All" yields the same DataLoader shuffling order and the
# same weight initialisation in the model cell that follows.
# NOTE(review): some GPU kernels may still be nondeterministic; whether the
# dataset's sampling relies on `random`, NumPy or torch is not visible here,
# so all three are seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyper-parameters. `seq_len` is passed to both the dataset and the model;
# `embed_size`, `inner_ff_size`, `n_heads`, `n_code` and `dropout` are passed
# to the Transformer constructor in the model cell.
batch_size = 1024
seq_len = 20
embed_size = 128
inner_ff_size = embed_size * 4
n_heads = 8
n_code = 8
# Not referenced by the cells visible here: the model sizes its embedding
# table from len(dataset.vocab) instead. Kept in case later cells use it.
n_vocab = 40000
dropout = 0.1
# n_workers = 12

# Keyword arguments forwarded verbatim to optim.Adam in the model cell.
optim_kwargs = {'lr':1e-4, 'weight_decay':1e-4, 'betas':(.9,.999)}

#1) Configure text
print('Configuring Text...')
sentence_path = 'BERT_data/training.txt'
vocab_path = "vocab.txt"

# Project helper that returns the sentences and the vocabulary built from
# (or loaded via) the two paths above.
sentences, vocab = create_sentences_and_vocab(sentence_path, vocab_path)
print('Creating Dataset...')
dataset = SentencesDataset(sentences, vocab, seq_len)
# drop_last=True discards a final incomplete batch, so every batch holds
# exactly `batch_size` samples.
kwargs = {'shuffle':True,  'drop_last':True, 'pin_memory':True, 'batch_size':batch_size}
data_loader = torch.utils.data.DataLoader(dataset, **kwargs)

Initializing data for BERT...
Configuring Text...
tokenizing sentences...
creating/loading vocab...
Creating Dataset...


In [7]:
print('Initializing BERT Transformer model...')
# The embedding table is sized from the dataset's actual vocabulary.
bert_model = Transformer(n_code=n_code, n_heads=n_heads, embed_size=embed_size,
                    inner_ff_size=inner_ff_size, n_embeddings=len(dataset.vocab),
                    seq_len=seq_len, dropout=dropout, algorithm="BERT")
# Requires a CUDA-capable runtime (e.g. a Colab GPU instance).
bert_model = bert_model.cuda()

print('Initializing Optimizer and Loss functions...')
optimizer = optim.Adam(bert_model.parameters(), **optim_kwargs)
# Targets equal to dataset.IGNORE_IDX do not contribute to the loss.
loss_model = nn.CrossEntropyLoss(ignore_index=dataset.IGNORE_IDX)

# Project training loop; returns the trained model.
bert_model = bert_train(bert_model, optimizer, data_loader, loss_model)

print('Saving Embeddings...')
# Export at most the first 3000 embeddings. Clamp to the vocabulary size so
# the rvocab lookup below cannot run past the end for a small vocabulary and
# so values.tsv and names.tsv always have the same number of rows.
# NOTE(review): assumes dataset.rvocab is indexable by 0..len(vocab)-1.
N = min(3000, len(dataset.vocab))
np.savetxt('values.tsv',
           np.round(bert_model.embeddings.weight.detach().cpu().numpy()[0:N], 2),
           delimiter='\t', fmt='%1.2f')
s = [dataset.rvocab[i] for i in range(N)]
# Use a context manager so the file is flushed and closed deterministically,
# and write UTF-8 explicitly so non-ASCII tokens do not depend on the locale.
with open('names.tsv', 'w', encoding='utf-8') as names_file:
    names_file.write('\n'.join(s))

print('Training end')

Initializing BERT Transformer model...
Initializing Optimizer and Loss functions...
Training BERT...
it: 0  | loss 10.2  | Δw: 1.254
it: 10  | loss 9.47  | Δw: 0.573
it: 20  | loss 9.27  | Δw: 0.38
it: 30  | loss 9.08  | Δw: 0.306
it: 40  | loss 8.95  | Δw: 0.248
it: 50  | loss 8.78  | Δw: 0.226
it: 60  | loss 8.63  | Δw: 0.197
it: 70  | loss 8.49  | Δw: 0.196
it: 80  | loss 8.32  | Δw: 0.184
it: 90  | loss 8.19  | Δw: 0.178
it: 100  | loss 8.07  | Δw: 0.174
it: 110  | loss 7.9  | Δw: 0.162
it: 120  | loss 7.75  | Δw: 0.156
it: 130  | loss 7.64  | Δw: 0.157
it: 140  | loss 7.48  | Δw: 0.148
it: 150  | loss 7.49  | Δw: 0.143
it: 160  | loss 7.34  | Δw: 0.147
it: 170  | loss 7.23  | Δw: 0.143
it: 180  | loss 7.14  | Δw: 0.139
it: 190  | loss 7.0  | Δw: 0.138
it: 200  | loss 6.9  | Δw: 0.135
it: 210  | loss 6.85  | Δw: 0.135
it: 220  | loss 6.77  | Δw: 0.13
it: 230  | loss 6.69  | Δw: 0.133
it: 240  | loss 6.72  | Δw: 0.136
it: 250  | loss 6.7  | Δw: 0.131
it: 260  | loss 6.62  | Δw: 0.13