In [1]:
# import modules
import json
import torch
import datasets
import tokenization

import numpy as np

from transformer.model import BertModel

In [2]:
# read configuration file
config = json.load(open('config.json'))
config

{'kor_vocab_length': 50000,
 'eng_vocab_length': 28998,
 'd_model': 768,
 'd_ff': 2048,
 'd_k': 64,
 'd_v': 64,
 'num_layers': 12,
 'num_heads': 8,
 'start_word': '[SOS]',
 'end_word': '[EOS]',
 'sep_word': '[SEP]',
 'cls_word': '[CLS]',
 'pad_word': '[PAD]',
 'mask_word': '[MASK]'}

In [3]:
# configure device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
# configure tokenizer
tokenizer = tokenization.FullTokenizer(
    vocab_file='vocab/eng_vocab.txt', do_lower_case=False)
tokenizer.tokenize('I love you')




['I', 'love', 'you']

In [5]:
# define sample dataset
dataset = ['Mr. Cassius crossed the highway, and stopped suddenly.',
        'Something glittered in the nearest red pool before him.',
        'Gold, surely!']
tokenized_text = [tokenizer.tokenize(t) for t in dataset]
tokenized_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokenized_text]
print(tokenized_text)
print(tokenized_ids)

[['Mr', '.', 'Cass', '##ius', 'crossed', 'the', 'highway', ',', 'and', 'stopped', 'suddenly', '.'], ['Something', 'g', '##lit', '##tered', 'in', 'the', 'nearest', 'red', 'pool', 'before', 'him', '.'], ['Gold', ',', 'surely', '!']]
[[1830, 121, 17502, 3287, 3811, 1105, 4085, 119, 1107, 2143, 2842, 121], [4264, 178, 12890, 7657, 1109, 1105, 6832, 1896, 4530, 1198, 1142, 121], [3489, 119, 9931, 108]]


In [6]:
## [text[0], text[1]] positive data
## [text[1], text[2]] positive data
## [text[1], text[0]] negative data
## [text[2], text[0]] negative data

max_pred = 5
token_length = 50

pairs = [(0, 1, True), (1, 2, True), (1, 0, False), (2, 0, False)]

input_ids, segment_ids, masked_ids, masked_poses, isNexts = [], [], [], [], []

cls_idx = tokenizer.convert_tokens_to_ids([config['cls_word']])[0]
sep_idx = tokenizer.convert_tokens_to_ids([config['sep_word']])[0]
pad_idx = tokenizer.convert_tokens_to_ids([config['pad_word']])[0]
mask_idx = tokenizer.convert_tokens_to_ids([config['mask_word']])[0]

for pair in pairs:

    tokens_a, tokens_b = tokenized_ids[pair[0]], tokenized_ids[pair[1]]
    isNext = pair[2]

    input_id = [cls_idx] + tokens_a + [sep_idx] + tokens_b + [sep_idx]
    segment_id = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
    
    n_pred = min(max_pred, max(1, int(round(len(input_id) * 0.15))))
    cand_masked_pos = [i for i, token in enumerate(input_id) if token != cls_idx and token != sep_idx]
    np.random.shuffle(cand_masked_pos)

    masked_id, masked_pos = [], []
    for pos in cand_masked_pos[:n_pred]:
        masked_pos.append(pos)
        masked_id.append(input_id[pos])
        if np.random.rand() < 0.8:
            input_id[pos] = mask_idx
        elif np.random.rand() < 0.5:
            input_id[pos] = np.random.choice(config['eng_vocab_length'])

    if max_pred > n_pred:
        n_pad = max_pred - n_pred
        masked_id.extend([0] * n_pad)
        masked_pos.extend([0] * n_pad)

    n_pad = token_length - len(input_id)
    input_id.extend([pad_idx] * n_pad)
    segment_id.extend([pad_idx] * n_pad)

    input_ids.append(input_id)
    segment_ids.append(segment_id)
    masked_ids.append(masked_id)
    masked_poses.append(masked_pos)
    isNexts.append(isNext)

print('-----------------')
for i in range(len(input_ids)):
    print(f'{input_ids[i]}')
    print(f'{segment_ids[i]}')
    print(f'{masked_ids[i]}')
    print(f'{masked_poses[i]}')
    print(f'{isNexts[i]}')

-----------------
[103, 1830, 121, 17502, 3287, 3811, 1105, 4085, 119, 105, 2143, 2842, 121, 104, 4264, 178, 12890, 7657, 105, 1105, 6832, 1896, 4530, 1198, 1142, 121, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[2143, 1107, 1109, 1105, 0]
[10, 9, 18, 19, 0]
True
[103, 4264, 178, 105, 7657, 1109, 1105, 105, 1896, 4530, 1198, 1142, 3892, 104, 3489, 119, 9931, 108, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[6832, 121, 12890, 0, 0]
[7, 12, 3, 0, 0]
True
[103, 4264, 178, 12890, 105, 1109, 1105, 6832, 5468, 105, 1198, 1142, 21089, 104, 1830, 121, 17502, 3287, 3811, 1105, 4085, 119, 1107, 2143, 2842, 121, 104, 

In [7]:
# convert ids to torch tensor
input_ids_tensor = torch.as_tensor(input_ids, dtype=torch.long).to(device)
segment_ids_tensor = torch.as_tensor(segment_ids, dtype=torch.long).to(device)
masked_ids_tensor = torch.as_tensor(masked_ids, dtype=torch.long).to(device)
masked_poses_tensor = torch.as_tensor(masked_poses, dtype=torch.long).to(device)
isNexts_tensor = torch.as_tensor(isNexts, dtype=torch.long).to(device)

In [8]:
## configure model, optimizer, criterion

pad_index = tokenizer.convert_tokens_to_ids([config['pad_word']])[0]

bert = BertModel(
    vocab_size=config['eng_vocab_length'],
    d_model=config['d_model'],
    d_ff=config['d_ff'], d_k=config['d_k'],
    d_v=config['d_v'], n_heads=config['num_heads'],
    n_layers=config['num_layers'], pad_index=pad_index,
    device=device).to(device)

optimizer = torch.optim.Adam(bert.parameters(), lr=5e-5, betas=(0.9, 0.999), weight_decay=0.01)
criterion = torch.nn.CrossEntropyLoss()
bert

BertModel(
  (tok_embed): Embedding(28998, 768)
  (pos_embed): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (seg_embed): Embedding(2, 768)
  (layers): ModuleList(
    (0): EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (WQ): Linear(in_features=768, out_features=512, bias=True)
        (WK): Linear(in_features=768, out_features=512, bias=True)
        (WV): Linear(in_features=768, out_features=512, bias=True)
        (linear): Linear(in_features=512, out_features=768, bias=True)
        (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (l1): Linear(in_features=768, out_features=2048, bias=True)
        (l2): Linear(in_features=2048, out_features=768, bias=True)
        (relu): GELU()
        (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (WQ): Linear(in_features=768

In [9]:
## training.......
bert_step = 0
for epoch in range(50):

    optimizer.zero_grad()
    logits_lm, logits_clsf, _ = bert(input_ids_tensor, segment_ids_tensor, masked_poses_tensor)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_ids_tensor)
    loss_clsf = criterion(logits_clsf, isNexts_tensor)
    loss = loss_lm + loss_clsf
    loss.backward()
    optimizer.step()
    
    print(epoch, loss.item())

0 110.25227355957031
1 47.35847854614258
2 34.622554779052734
3 22.38869285583496
4 13.62109375
5 7.63401985168457
6 3.9623594284057617
7 2.758429527282715
8 3.287774085998535
9 3.4298293590545654
10 2.2902426719665527
11 2.5123467445373535
12 2.145576000213623
13 1.2739777565002441
14 0.8557220697402954
15 1.2299394607543945
16 1.0996782779693604
17 1.6305320262908936
18 0.510784387588501
19 0.833055853843689
20 0.6597347259521484
21 0.6508806347846985
22 0.47196927666664124
23 0.45319467782974243
24 0.6112233996391296
25 0.42565232515335083
26 0.42086994647979736
27 0.46858450770378113
28 0.3950214385986328
29 0.3883059322834015
30 0.3893374502658844
31 0.3799106180667877
32 0.3650440275669098
33 0.3586336672306061
34 0.3481409549713135
35 0.33903467655181885
36 0.33206090331077576
37 0.32673773169517517
38 0.32280218601226807
39 0.3178752362728119
40 0.3109605610370636
41 0.3043636083602905
42 0.29847440123558044
43 0.2930338382720947
44 0.2878718376159668
45 0.28281617164611816
46 

In [10]:
## check correctly trained

logits_lm, logits_clsf, _ = bert(input_ids_tensor, segment_ids_tensor, masked_poses_tensor)

print('--------')
print('--------')
print('predict')
print(torch.argmax(logits_lm, axis=2))
print('answer')
print(masked_ids_tensor)
print('--------')

print('predict')
print(torch.argmax(logits_clsf, axis=1))
print('answer')
print(isNexts_tensor)
print('--------')
print('--------')

--------
--------
predict
tensor([[ 2143,  1107,  1109,  1105,     0],
        [ 6832,   121, 12890,     0,     0],
        [  121,  1896,  7657,  4530,     0],
        [ 9931, 17502,   119,     0,     0]])
answer
tensor([[ 2143,  1107,  1109,  1105,     0],
        [ 6832,   121, 12890,     0,     0],
        [  121,  1896,  7657,  4530,     0],
        [ 9931, 17502,   119,     0,     0]])
--------
predict
tensor([1, 1, 0, 0])
answer
tensor([1, 1, 0, 0])
--------
--------
