# 2주차 BERT I/O만들기

https://github.com/huggingface/tokenizers

In [1]:
from tokenizers import Tokenizer
from tokenizers.models import BPE

In [2]:
tokenizer = Tokenizer(BPE())

In [3]:
filepath = "/home/long8v/torch_study/paper/file/petitions_splited_mecab/petitions_splited_mecab.txt"

In [154]:
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.train(files=[filepath], trainer=trainer)

In [230]:
with open(filepath, 'r') as f:
    corpus = f.readlines()

In [231]:
corpus_dev = corpus[:100]

In [463]:
output = tokenizer.encode_batch(['안녕하세요?', '반갑습니다'])
output[0].tokens

['안녕', '하', '세요', '?']

In [464]:
from torch.utils.data import Dataset, DataLoader

In [465]:
import random
import torch

In [466]:
vocab = tokenizer.get_vocab()
random_token = random.choice(list(vocab))
random_ids = vocab[random_token]
random_ids, random_token

(7557, ' 빼')

In [467]:
tokenizer.get_vocab()['<mask>']

30000

In [468]:
import math

In [469]:
class BERT_input:
    def __init__(self, vocab, encoded, masking_ratio=0.15, mask='<mask>'):
        self.vocab = vocab
        self.tokens = encoded.tokens
        self.ids = encoded.ids
        self.tokens_idx = [_ for _ in range(len(self.tokens))]
        self.masking_ratio = masking_ratio
        self.get_mask_token()
        self.mask = mask
        print(self.mask_token)
        self.mask_idx = [1 if _ in self.mask_token else 0 for _ in range(len(self.tokens))]
        self.replaced_token = self.get_replaced_tokens()
        self.replaced_idx = [vocab[tok] for tok in self.replaced_token]
        
    def get_mask_token(self):
        def random_choice_with_prob(lst, p):
            k = round(len(lst) * p)
            return random.sample(lst, k=k)
        
        mask_token = random_choice_with_prob(self.tokens_idx, p=self.masking_ratio)
        print(len(mask_token))
        self.mask_token = mask_token[:]
        self.mask_mask = random_choice_with_prob(mask_token, p=0.8)
        for m in self.mask_mask:
            mask_token.remove(m)
        self.mask_replace = random_choice_with_prob(mask_token, p=0.5)
        for m in self.mask_replace:
            mask_token.remove(m)
        self.mask_remain = mask_token
    
    def get_replaced_tokens(self):      
        # 바꾸는
        return [self.replace_token(idx, token) for idx, token in enumerate(self.tokens)]
  
    def replace_token(self, idx, token):
        if idx in self.mask_mask:
            return self.mask
        elif idx in self.mask_replace:
            replaced = self.get_random_token()
            return replaced
        elif idx in self.mask_remain:
            return token
        return token

    def get_random_token(self):
        random_token = random.choice(list(self.vocab))
        return random_token

In [470]:
math.ceil(4*0.5)

2

In [471]:
output[0]

Encoding(num_tokens=4, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [472]:
bt = BERT_input(tokenizer.get_vocab(), output[0])

1
[0]


In [473]:
bt.tokens, bt.ids, bt.replaced_token, bt.replaced_idx

(['안녕', '하', '세요', '?'],
 [3475, 2458, 19544, 20],
 ['<mask>', '하', '세요', '?'],
 [30000, 2458, 19544, 20])

In [474]:
class BERT_Dataset(Dataset):
    def __init__(self, corpus, tokenizer):
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.vocab = tokenizer.get_vocab()
    def __len__(self, ):
        return len(self.corpus)
    
    def __getitem__(self, idx):
        sample = self.corpus[idx]
        output = tokenizer.encode(sample)
        return BERT_input(self.vocab, output, 0.5)
        

In [475]:
bd = BERT_Dataset(corpus_dev, tokenizer)
for _ in bd:
    print(_.tokens)
    print(_.ids)
    print(_.mask_idx)
    print(_.replaced_token)
    print(_.replaced_idx)
    break

9
[15, 1, 12, 3, 17, 16, 7, 11, 0]
['국민 과 ', '소통', ' 하 시 고 ', '자유', ' 롭 고 ', '행복 한 ', '나라 를 만들', ' 기 위해 ', '힘', '쓰 고 ', '계신', ' 대통령', ' 께 ', '존경', ' 과 ', '찬', '사 를 ', '올립니다 .\n']
[12278, 6024, 4626, 3348, 17495, 9887, 19397, 3776, 2594, 5552, 6034, 2892, 25545, 4836, 2679, 2101, 2883, 7543]
[1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1]
['<mask>', '<mask>', ' 하 시 고 ', '<mask>', ' 롭 고 ', '행복 한 ', '나라 를 만들', ' 기 위해 ', '힘', '쓰 고 ', '계신', '<mask>', '<mask>', '존경', ' 과 ', '<mask>', '<mask>', ' 위법']
[30000, 30000, 4626, 30000, 17495, 9887, 19397, 3776, 2594, 5552, 6034, 30000, 30000, 4836, 2679, 30000, 30000, 29047]


In [476]:
for _ in DataLoader(bd):
    print(_)
    break

9
[4, 17, 12, 6, 14, 3, 13, 0, 11]


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class '__main__.BERT_input'>