In [1]:
# Import Module
import os
# Pin this process to a single GPU. These variables only take effect if they are set
# before CUDA is initialised, which is why they come before `import torch` below.
# NOTE(review): GPU index "3" is machine-specific — adjust for the host running this.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="3"
import math
import time
import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

# Import PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as torch_utils
from torch import optim
from torch.utils.data import DataLoader

# Import Custom Module
from translation.dataset import CustomDataset, PadCollate
from translation.model import Transformer
from translation.optimizer import Ralamb, WarmupLinearSchedule
from translation.rnn_model import Encoder, Decoder, Seq2Seq
from named_entity_recognition.model import NER_model
from utils import accuracy

In [2]:
parser = argparse.ArgumentParser(description='NMT argparser')
# Help strings use argparse's %(default)s placeholder so the advertised default can
# never drift from the real one (several hard-coded "Default is ..." texts had).
parser.add_argument('--save_path', default='./save2', 
                    type=str, help='Directory containing the data/embedding pickles and model checkpoints')
parser.add_argument('--resume', action='store_true',
                    help='If not store, then training from scratch')
parser.add_argument('--baseline', action='store_true',
                    help='If not store, then training from Dynamic Word Embedding')
# Read when building the pretrained NER model under --resume; it was previously
# undeclared, so --resume crashed with AttributeError.
parser.add_argument('--crf_loss', action='store_true',
                    help='Whether the pretrained NER model loaded with --resume uses a CRF loss')
parser.add_argument('--model_setting', default='transformer', choices=['transformer', 'rnn'],
                    type=str, help='Model Selection; transformer vs rnn')
parser.add_argument('--pad_idx', default=0, type=int, help='pad index')
parser.add_argument('--bos_idx', default=1, type=int, help='index of bos token')
parser.add_argument('--eos_idx', default=2, type=int, help='index of eos token')
parser.add_argument('--unk_idx', default=3, type=int, help='index of unk token')

parser.add_argument('--min_len', type=int, default=4, help='Minimum Length of Sentences; Default is %(default)s')
parser.add_argument('--max_len', type=int, default=300, help='Max Length of Source Sentence; Default is %(default)s')

parser.add_argument('--num_epoch', type=int, default=10, help='Epoch count; Default is %(default)s')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size; Default is %(default)s')
parser.add_argument('--lr', type=float, default=5e-5, help='Learning rate; Default is %(default)s')
parser.add_argument('--lr_decay', type=float, default=0.5, help='Learning rate decay; Default is %(default)s')
parser.add_argument('--lr_decay_step', type=int, default=2, help='Learning rate decay step; Default is %(default)s')
# float so fractional clipping thresholds (e.g. 0.5) are accepted; the default value is unchanged.
parser.add_argument('--grad_clip', type=float, default=5, help='Set gradient clipping; Default is %(default)s')
parser.add_argument('--w_decay', type=float, default=1e-6, help='Weight decay; Default is %(default)s')

parser.add_argument('--d_model', type=int, default=512, help='Hidden State Vector Dimension; Default is %(default)s')
parser.add_argument('--d_embedding', type=int, default=256, help='Embedding Vector Dimension; Default is %(default)s')
parser.add_argument('--n_head', type=int, default=8, help='Multihead Count; Default is %(default)s')
parser.add_argument('--dim_feedforward', type=int, default=512, help='Feed-forward Layer Dimension; Default is %(default)s')
parser.add_argument('--num_encoder_layer', default=8, type=int, help='number of encoder layer')
parser.add_argument('--num_decoder_layer', default=8, type=int, help='number of decoder layer')
parser.add_argument('--dropout', type=float, default=0.3, help='Dropout Ratio; Default is %(default)s')

parser.add_argument('--print_freq', type=int, default=300, help='Print train loss frequency; Default is %(default)s')
# Parse an empty list so the notebook runs with defaults regardless of the kernel's argv.
args = parser.parse_args(list())

In [3]:
# Run-specific overrides of the parser defaults above.
# NOTE(review): presumably these must match the dimensions the checkpoint restored
# later was trained with — confirm before changing.
args.d_model, args.dim_feedforward = 768, 2048

In [3]:
# Prefer the GPU when CUDA is available; tensors and the model are moved to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#===================================#
#============Data Load==============#
#===================================#

print('Data Load & Setting!')
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(os.path.join(args.save_path, 'nmt_processed.pkl'), 'rb') as f:
    data_ = pickle.load(f)

# Unpack the train/test index sequences and vocabularies from the processed dict.
hj_train_indices, hj_test_indices = data_['hj_train_indices'], data_['hj_test_indices']
kr_train_indices, kr_test_indices = data_['kr_train_indices'], data_['kr_test_indices']
king_train_indices, king_test_indices = data_['king_train_indices'], data_['king_test_indices']
hj_word2id, hj_id2word = data_['hj_word2id'], data_['hj_id2word']
kr_word2id, kr_id2word = data_['kr_word2id'], data_['kr_id2word']
del data_  # the raw dict is no longer needed once unpacked

# hj_* is the source side and kr_* the target side of the translation task.
src_vocab_num = len(hj_word2id)
trg_vocab_num = len(kr_word2id)

#===================================#
#========DataLoader Setting=========#
#===================================#

# Build train/valid datasets with the same length filtering.
dataset_dict = {
    'train': CustomDataset(hj_train_indices, kr_train_indices, king_train_indices,
                           min_len=args.min_len, max_len=args.max_len),
    'valid': CustomDataset(hj_test_indices, kr_test_indices, king_test_indices,
                           min_len=args.min_len, max_len=args.max_len),
}
# Training batches are shuffled and the trailing partial batch is dropped. The
# validation split is iterated in a fixed order and in full, so evaluation (and the
# batch inspected by the decoding cells below) is reproducible and no examples are
# silently skipped. Pinned memory only helps host->GPU copies, so enable it with CUDA only.
dataloader_dict = {
    'train': DataLoader(dataset_dict['train'], collate_fn=PadCollate(), drop_last=True,
                        batch_size=args.batch_size, shuffle=True,
                        pin_memory=torch.cuda.is_available()),
    'valid': DataLoader(dataset_dict['valid'], collate_fn=PadCollate(), drop_last=False,
                        batch_size=args.batch_size, shuffle=False,
                        pin_memory=torch.cuda.is_available()),
}

#====================================#
#==========DWE Results Open==========#
#====================================#

# Load the precomputed embedding matrix (the "DWE results") passed to the model builders.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(os.path.join(args.save_path, 'emb_mat.pkl'), mode='rb') as emb_file:
    emb_mat = pickle.load(emb_file)

#===================================#
#===========Model Setting===========#
#===================================#

print("Build model")
if args.model_setting == 'transformer':
    model = Transformer(emb_mat, hj_word2id, src_vocab_num, trg_vocab_num, pad_idx=args.pad_idx, bos_idx=args.bos_idx,
                eos_idx=args.eos_idx, max_len=args.max_len,
                d_model=args.d_model, d_embedding=args.d_embedding, n_head=args.n_head,
                dim_feedforward=args.dim_feedforward, dropout=args.dropout,
                num_encoder_layer=args.num_encoder_layer, num_decoder_layer=args.num_decoder_layer,
                baseline=args.baseline, device=device)
elif args.model_setting == 'rnn':
    encoder = Encoder(src_vocab_num, args.d_embedding, args.d_model,
                    emb_mat, hj_word2id, n_layers=args.num_encoder_layer,
                    pad_idx=args.pad_idx, dropout=args.dropout)
    decoder = Decoder(args.d_embedding, args.d_model, trg_vocab_num, n_layers=args.num_decoder_layer,
                    pad_idx=args.pad_idx, dropout=args.dropout)
    model = Seq2Seq(encoder, decoder, device)
else:
    # ValueError subclasses Exception, so existing `except Exception` handlers still apply.
    raise ValueError(f'Model error: unknown model_setting {args.model_setting!r}')

if args.resume:
    # Warm-start the NMT encoder from a pretrained NER model, then freeze it.
    if not hasattr(model, 'transformer_encoder'):
        raise ValueError('--resume copies a pretrained transformer encoder; '
                         'use model_setting="transformer"')
    # getattr: older parsers never declared --crf_loss, which made this branch crash.
    model_ner = NER_model(emb_mat=emb_mat, word2id=hj_word2id, pad_idx=args.pad_idx, bos_idx=args.bos_idx, eos_idx=args.eos_idx, max_len=args.max_len,
                    d_model=args.d_model, d_embedding=args.d_embedding, n_head=args.n_head,
                    dim_feedforward=args.dim_feedforward, n_layers=args.num_encoder_layer, dropout=args.dropout,
                    crf_loss=getattr(args, 'crf_loss', False), device=device)
    # map_location lets a GPU-saved checkpoint load on whichever device was selected.
    model_ner.load_state_dict(torch.load(os.path.join(args.save_path, 'ner_model_False.pt'),
                                         map_location=device))
    model.transformer_encoder.load_state_dict(model_ner.transformer_encoder.state_dict())
    for param in model.transformer_encoder.parameters():
        param.requires_grad = False
print("Total Parameters:", sum(p.nelement() for p in model.parameters()))

Data Load & Setting!
Build model
Total Parameters: 65091008


In [4]:
# Restore the trained NMT weights from the configured save directory (previously a
# hard-coded './save2' that ignored --save_path). map_location lets a GPU-saved
# checkpoint load on a CPU-only host too.
# NOTE(review): the 'False' suffix looks tied to a run flag (e.g. --baseline) — confirm.
model.load_state_dict(torch.load(os.path.join(args.save_path, 'nmt_model_False.pt'),
                                 map_location=device))
model = model.to(device)
model = model.eval()  # inference mode (disables dropout)

In [5]:
# Collect per-sentence target and predicted token-id lists over the validation set.
# A stray debug `break` used to end the first iteration before the forward pass, so
# nothing was ever collected. MAX_EVAL_BATCHES keeps the quick single-batch
# inspection explicit; set it to None to evaluate the whole split.
MAX_EVAL_BATCHES = 1

total_trg_list = list()
total_pred_list = list()

if args.model_setting != 'transformer':
    raise NotImplementedError('Validation predictions are only implemented for model_setting="transformer"')

for batch_idx, (src, trg, king_id) in enumerate(tqdm(dataloader_dict['valid'])):
    # Source, target sentence setting
    label_sequences = trg.to(device, non_blocking=True)
    input_sequences = src.to(device, non_blocking=True)
    king_id = king_id.to(device, non_blocking=True)

    # Flattened targets without padding. The model receives the same mask, so its
    # outputs presumably line up 1:1 with these tokens — TODO confirm.
    non_pad = label_sequences != args.pad_idx
    trg_sequences_target = label_sequences[non_pad].contiguous().view(-1)

    # Target masking (transposed, matching how the model is called elsewhere)
    tgt_mask = model.generate_square_subsequent_mask(label_sequences.size(1))
    tgt_mask = tgt_mask.to(device, non_blocking=True).transpose(0, 1)

    # Model forward (inference only; no loss is computed here)
    with torch.no_grad():
        predicted = model(input_sequences, label_sequences, king_id, tgt_mask, non_pad)
    # Top-1 token per position; loop-invariant, so computed once per batch.
    pred_ids = predicted.topk(1, 1, True, True)[1].squeeze(1)

    # Split the flat token stream into sentences, each ending at an EOS token.
    start = 0
    for end_ix in torch.where(trg_sequences_target == args.eos_idx)[0].tolist():
        total_trg_list.append(trg_sequences_target[start:end_ix + 1].tolist())
        total_pred_list.append(pred_ids[start:end_ix + 1].tolist())
        start = end_ix + 1

    # Check the limit at the end of the body so input_sequences/king_id stay
    # consistent (both from the last processed batch) for the cells below.
    if MAX_EVAL_BATCHES is not None and batch_idx + 1 >= MAX_EVAL_BATCHES:
        break

  0%|          | 0/63390 [00:00<?, ?it/s]


In [10]:
# Inspect the source token ids of the batch left over from the validation loop above.
# NOTE(review): recorded as In [10], i.e. run after the decoding cells — execution order is not linear.
input_sequences

tensor([[   1,   13,  530,  112,   20,   13,  530,  149, 6633,   23,   24, 3203,
          296,   56,    3,    6,  273,   68,  129, 4459,  123,  517,  890,  150,
          306,  360,   41,    3,    6, 4459, 1443,  123,  929,  243,  150,  306,
          360,   41,    3,    6,   68, 3081, 3252, 2496,  567,  119, 3919,    3,
            6,   15, 1555,   89,  727,    3,    6,  267,  534,  188,  169,  148,
           37,    3,    6,  151,   56,    3,    6,    3,   87,  448,  214,   34,
           64,    3,    6,  125,   12,   51, 1061,   34,  385,  895,  154, 1860,
          420,  658,    3,    3,    2]], device='cuda:0')

In [11]:
# Inspect the king id paired with the batch captured by the validation loop above.
king_id

tensor([[18]], device='cuda:0')

In [6]:
# Greedy-decoding setup for the captured batch: embed the source, build its padding
# mask, and seed every target sequence with BOS. no_grad avoids building an autograd
# graph during pure inference.
with torch.no_grad():
    # transpose(0, 1): presumably batch-first -> sequence-first for the encoder — TODO confirm
    encoder_out = model.src_embedding(input_sequences, king_id).transpose(0, 1)
src_key_padding_mask = (input_sequences == model.pad_idx)
# One BOS token per sequence in the batch: shape (batch, 1), int64.
predicted = torch.full((encoder_out.size(1), 1), model.bos_idx, dtype=torch.long, device=device)

In [7]:
# Run the encoder stack over the embedded source, ignoring padded positions.
# no_grad: inference only, so no autograd graph is needed.
with torch.no_grad():
    encoder_out = model.transformer_encoder(encoder_out, src_key_padding_mask=src_key_padding_mask)

In [8]:
# Greedy autoregressive decoding. At each step, re-run the decoder stack on the
# tokens generated so far and append the arg-max token predicted at the last position.
# Runs under no_grad (inference only) and stops early once every sequence in the
# batch has produced EOS, instead of always running model.max_len steps.
finished = torch.zeros(predicted.size(0), dtype=torch.bool, device=predicted.device)
with torch.no_grad():
    for _ in tqdm(range(model.max_len)):
        # NOTE(review): the causal mask is transposed here, as in the training/eval
        # calls. With PyTorch's usual convention a transpose would hide past tokens
        # rather than future ones — confirm generate_square_subsequent_mask's orientation.
        tgt_mask = model.generate_square_subsequent_mask(predicted.size(1))
        tgt_mask = tgt_mask.to(device, non_blocking=True).transpose(0, 1)

        tgt_key_padding_mask = (predicted == model.pad_idx)
        decoder_out = model.trg_embedding(predicted).transpose(0, 1)

        for decoder_layer in model.decoders:
            decoder_out = decoder_layer(decoder_out, encoder_out, tgt_mask=tgt_mask,
                                        memory_key_padding_mask=src_key_padding_mask,
                                        tgt_key_padding_mask=tgt_key_padding_mask)

        # Only the last position is needed to predict the next token.
        decoder_out = F.gelu(model.trg_output_linear(decoder_out[-1]))
        y_pred = model.trg_output_linear2(decoder_out)
        y_pred_id = y_pred.max(dim=1)[1].unsqueeze(1)

        predicted = torch.cat([predicted, y_pred_id], dim=1)

        finished |= y_pred_id.squeeze(1) == args.eos_idx
        if finished.all():
            break

  1%|▏         | 4/300 [00:00<00:09, 30.69it/s]

0
1
2
3
4
5
6
7


  4%|▍         | 12/300 [00:00<00:09, 29.32it/s]

8
9
10
11
12


  6%|▌         | 18/300 [00:00<00:10, 27.09it/s]

13
14
15
16
17
18


 10%|█         | 30/300 [00:00<00:07, 37.53it/s]

19
20
21
22
23
24
25
26
27
28
29
30
31
32


 15%|█▍        | 44/300 [00:01<00:05, 48.85it/s]

33
34
35
36
37
38
39
40
41
42
43
44
45
46


 19%|█▉        | 58/300 [00:01<00:04, 57.40it/s]

47
48
49
50
51
52
53
54
55
56
57
58
59
60


 22%|██▏       | 65/300 [00:01<00:04, 56.21it/s]

61
62
63
64
65
66
67


 24%|██▎       | 71/300 [00:01<00:06, 37.79it/s]

68
69
70
71


 25%|██▌       | 76/300 [00:01<00:08, 27.56it/s]

72
73
74
75


 27%|██▋       | 80/300 [00:02<00:09, 23.14it/s]

76
77
78
79


 28%|██▊       | 84/300 [00:02<00:10, 20.87it/s]

80
81
82
83


 31%|███       | 92/300 [00:02<00:08, 25.07it/s]

84
85
86
87
88
89
90
91


 32%|███▏      | 96/300 [00:02<00:07, 27.57it/s]

92
93
94
95
96
97
98


 35%|███▍      | 104/300 [00:03<00:06, 28.27it/s]

99
100
101
102
103
104


 37%|███▋      | 112/300 [00:03<00:06, 31.07it/s]

105
106
107
108
109
110
111
112


 40%|████      | 120/300 [00:03<00:05, 33.01it/s]

113
114
115
116
117
118
119
120


 45%|████▍     | 134/300 [00:03<00:03, 44.25it/s]

121
122
123
124
125
126
127
128
129
130
131
132
133
134


 49%|████▉     | 148/300 [00:03<00:02, 53.58it/s]

135
136
137
138
139
140
141
142
143
144
145
146
147
148


 54%|█████▎    | 161/300 [00:04<00:02, 56.08it/s]

149
150
151
152
153
154
155
156
157
158
159
160


 58%|█████▊    | 175/300 [00:04<00:02, 60.63it/s]

161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178


 61%|██████    | 182/300 [00:04<00:03, 33.11it/s]

179
180
181
182


 62%|██████▏   | 187/300 [00:05<00:04, 25.12it/s]

183
184
185
186


 64%|██████▎   | 191/300 [00:05<00:05, 21.51it/s]

187
188
189
190


 65%|██████▌   | 195/300 [00:05<00:04, 22.86it/s]

191
192
193
194
195
196
197


 66%|██████▋   | 199/300 [00:06<00:07, 13.86it/s]

198
199


 70%|██████▉   | 209/300 [00:06<00:05, 16.10it/s]

200
201
202
203
204
205
206
207
208
209
210
211
212
213


 74%|███████▍  | 222/300 [00:06<00:03, 25.21it/s]

214
215
216
217
218
219
220
221
222
223
224
225


 79%|███████▊  | 236/300 [00:06<00:01, 36.83it/s]

226
227
228
229
230
231
232
233
234
235
236
237
238
239


 81%|████████▏ | 244/300 [00:07<00:01, 43.14it/s]

240
241
242
243
244
245
246


 84%|████████▎ | 251/300 [00:07<00:01, 30.36it/s]

247
248
249
250
251
252
253
254


 85%|████████▌ | 256/300 [00:07<00:01, 23.87it/s]

255
256
257
258


 88%|████████▊ | 264/300 [00:08<00:01, 22.56it/s]

259
260
261
262
263
264
265
266


 89%|████████▉ | 268/300 [00:08<00:02, 15.72it/s]

267
268
269
270
271
272
273


 92%|█████████▏| 275/300 [00:08<00:01, 16.96it/s]

274
275


 95%|█████████▌| 285/300 [00:09<00:01, 13.98it/s]

276
277
278
279
280
281
282
283
284
285
286
287
288


100%|██████████| 300/300 [00:09<00:00, 30.60it/s]

289
290
291
292
293
294
295
296
297
298
299





In [9]:
# Generated token ids: the BOS seed followed by one greedy prediction per step.
# NOTE(review): in the recorded output every id is 1 (the default bos_idx), i.e. the
# model only ever predicted BOS — worth investigating (e.g. the transposed tgt_mask).
predicted

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1