In [5]:
# Import Module
import argparse
import math
import os
import pickle
import random
import time
import warnings

import numpy as np
import pandas as pd
import sentencepiece as spm
from tqdm import tqdm

# Import PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as torch_utils
from torch import optim
from torch.utils.data import DataLoader

# Import Custom Module
from dataset2 import HanjaDataset, PadCollate
from module import Encoder

In [3]:
# Setting
warnings.simplefilter("ignore", UserWarning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the exploration below may draw from (dropout, nn.Linear
# initialisation and, presumably, HanjaDataset's random masking -- confirm)
# so that Restart & Run All reproduces the displayed outputs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data Loading
# Override the location with the HANJA_DATA_PATH environment variable; the
# default is the original author's machine-specific path.
DATA_PATH = os.environ.get('HANJA_DATA_PATH',
                           '/home/jin/joseon_backup/preprocessed_data.pkl')

print('Data loading and data spliting...')
# NOTE(security): pickle.load can execute arbitrary code -- only load a
# file you trust. The handle is closed as soon as the object is read.
with open(DATA_PATH, 'rb') as f:
    data = pickle.load(f)

# Vocabularies: the token -> id mappings and their keys in insertion order.
src_word2id = data['hanja_word2id']
src_vocab = list(src_word2id)
trg_word2id = data['korean_word2id']
trg_vocab = list(trg_word2id)

# Train/valid sequences exactly as stored in the pickle (the key names
# suggest they are already split and converted to ids).
train_src_list = data['train_hanja_indices']
train_trg_list = data['train_korean_indices']
train_add_hanja = data['train_additional_hanja_indices']
valid_src_list = data['valid_hanja_indices']
valid_trg_list = data['valid_korean_indices']
valid_add_hanja = data['valid_additional_hanja_indices']

src_vocab_num = len(src_vocab)
trg_vocab_num = len(trg_vocab)

# Drop the raw dict so only the extracted pieces stay in kernel memory.
del data
print('Done!')

Data loading and data spliting...
Done!


In [6]:
# Masked-token dataset over the training source sequences, plus a small,
# unshuffled loader used only to pull one batch for inspection.
h_dataset = HanjaDataset(
    train_src_list,
    train_add_hanja,
    pad_idx=0,
    mask_idx=4,
    min_len=4,
    src_max_len=150,
)
h_loader = DataLoader(
    h_dataset,
    collate_fn=PadCollate(),
    drop_last=True,
    batch_size=4,
)

In [14]:
# Encoder over the source vocabulary; the two 256s are presumably the
# embedding and hidden sizes -- confirm against module.Encoder.
encoder = Encoder(
    src_vocab_num,
    256,
    256,
    n_layers=6,
    pad_idx=0,
    dropout=0.3,
    embedding_dropout=0.2,
)

In [7]:
# Pull a single (src, trg) batch for interactive inspection.
# next(iter(...)) states that intent directly. Unlike the `for ...: break`
# form, an empty loader now fails loudly here instead of leaving src/trg
# undefined (or stale from an earlier run). drop_last=True means a dataset
# smaller than batch_size yields no batches at all.
try:
    src, trg = next(iter(h_loader))
except StopIteration:
    raise RuntimeError(
        'h_loader yielded no batches; check the dataset size against '
        'batch_size (drop_last=True)') from None

In [19]:
src.shape  # batch of source ids; shown as torch.Size

torch.Size([4, 49])

In [17]:
# Swap the two axes of the 2-D src batch before encoding (presumably the
# encoder expects sequence-first input; its output below is (49, 4, 256)).
output, hidden = encoder(src.t())

In [18]:
output.shape  # encoder output shape

torch.Size([49, 4, 256])

In [21]:
# Boolean mask of target positions that hold a token id (non-zero entries).
masked_position = trg.ne(0)

In [23]:
# Back to batch-first, then keep only the masked positions -> (num_masked, hidden).
output.transpose(1, 0)[masked_position].shape

torch.Size([15, 256])

In [26]:
# Encoder states at the masked positions, flattened to (num_masked, hidden).
t = torch.transpose(output, 0, 1)[masked_position]

In [27]:
# Output projection for masked-token prediction. Both sizes are derived
# instead of hard-coded (the original used 256 and 32000):
#   * in_features follows the encoder's output feature size;
#   * out_features follows the source vocabulary. The masked targets in `trg`
#     come from HanjaDataset built on train_src_list, so they are presumably
#     source-vocab ids (confirm). A hard-coded 32000 silently breaks as soon
#     as src_vocab_num differs.
q = nn.Linear(output.size(-1), src_vocab_num)

In [28]:
# Project the masked-position encoder states to vocabulary logits.
qq = q(torch.transpose(output, 0, 1)[masked_position])

In [30]:
qq.shape  # (num_masked, vocab logits)

torch.Size([15, 32000])

In [32]:
trg.masked_select(masked_position).shape  # one target id per masked position

torch.Size([15])

In [33]:
trg.masked_select(masked_position)  # target ids at masked positions, row-major order

tensor([  126, 25009,     5,  4672,  2700,    13,     5,     7,   896,   537,
            5,  1253, 12238,     5,  2059])

In [12]:
# Recomputes the same non-zero target mask as the earlier `masked_position`
# cell; redundant under Restart & Run All.
masked_position = trg.ne(0)

In [13]:
# Display the boolean mask: True where trg is non-zero (0 is the pad value,
# pad_idx=0), i.e. presumably the positions HanjaDataset masked.
masked_position

tensor([[ True,  True, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False,  True, False, False, False, False, False,
         False,  True, False, False, False,  True, False, False, False],
        [False, False, False, False, False,  True, False,  True, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [ True, False,  True,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False,

In [9]:
# Display the target batch: non-zero entries hold token ids at the selected
# (presumably masked) positions, 0 everywhere else.
trg

tensor([[  126, 25009,     0,     0,     0,     0,     0,     5,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  4672,     0,     0,     0,     0,     0,
             0,  2700,     0,     0,     0,    13,     0,     0,     0],
        [    0,     0,     0,     0,     0,     5,     0,     7,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  896,     0,   537,     5,  1253,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,