In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Put the parent directory first on the module search path so that the
# `consonant` package imported in the next cell can be found.
# NOTE(review): assumes the notebook sits one level below the directory that
# contains `consonant/` and is run from its own folder — confirm.
sys.path.insert(0, '../')

In [3]:
from consonant.model.tokenization import NGRAMTokenizer

In [4]:
# Two sample inputs; only the first one is displayed below.
sentence = ["내가 너 엄청 좋아해?!이 기호는 불가능$^* 잘릴 예정인 텍스트", "너도 나 좋아하니?"]
tokenizer = NGRAMTokenizer(3)

# Report the sizes of the two vocabularies held by the tokenizer.
print("Num Head Vocab:", len(tokenizer.head2id))
print("Num Midtail Vocab:", len(tokenizer.midtail2id))

# Encode the whole batch, then pull the three output fields out of the result.
result = tokenizer.encode(
    sentence,
    max_char_length=30,
    return_attention_mask=True,
)  # , return_tensors='pt')
head_ids, midtail_ids, attention_masks = (
    result[field] for field in ('head_ids', 'midtail_ids', 'attention_masks')
)

# --- Encoding: show every encoded field for the first sentence ---
print('\n Encoding Example :', sentence[0])
print("=========================")

print("Head Consonant ID", "0: [PAD], 1: [CLS], 2: [SEP] \n", head_ids[0], sep="\n")

# A leading empty item with sep="\n" yields the blank spacer line.
print("", "Midtail Consonant ID", midtail_ids[0], sep="\n")

print("", "Attention Mask", attention_masks[0], sep="\n")

# --- Decoding: rebuild the text of the first sentence from its ids ---
print('\n Decoding Example')
print("=========================")
print("Unknown consonant replaced to @\n")

# `result` is rebound to the decoded output here.
result = tokenizer.decode_sent(head_ids[0], midtail_ids[0])
print(result)



Num Head Vocab: 17579
Num Midtail Vocab: 589

 Encoding Example : 내가 너 엄청 좋아해?!이 기호는 불가능$^* 잘릴 예정인 텍스트
Head Consonant ID
0: [PAD], 1: [CLS], 2: [SEP] 

[    1   244  6269  4744   237  6105   492 12717 14218   515 13340 12824
 16985  2153  3175 12178   210  5394 17137  6101   374  9658  4975  6249
  4221  4215  4078   509 13159     2]

Midtail Consonant ID
[-100   29    1    0  113    0  129  134    0  252    1   29    0    0
  561    0  561  225  509    0  373    1  526    0    0    0    0    9
  569 -100]

Attention Mask
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

 Decoding Example
Unknown consonant replaced to @

내가 너 엄청 좋아해?!이 기호는 불가능 잘릴


### Ignore midtail output for non-Korean characters
During decoding, if the model predicts a non-empty midtail consonant for a non-Korean character,  
it can corrupt the decoded text.  
Therefore, when the head consonant is non-Korean, the midtail output is ignored and the character is decoded from the head alone.  

In [6]:
# Demonstrate that decoding ignores the midtail prediction at a non-Korean
# position. Index 3 of the first sentence is the space after "내가"
# ([CLS], 내, 가, ' '), and 999 is larger than the midtail vocabulary size
# printed by the encoding cell (589 in the recorded run).
#
# Perturb a copy rather than `midtail_ids[0]` itself: writing into the encoded
# array in place would leave the shared `midtail_ids` corrupted for any cell
# that uses it afterwards (or for a re-run of the decoding step without
# re-encoding first).
perturbed_midtail_ids = midtail_ids[0].copy()
perturbed_midtail_ids[3] = 999
print()
print("Perturbed Midtail Consonant ID")
print('->', perturbed_midtail_ids)

result = tokenizer.decode_sent(head_ids[0], perturbed_midtail_ids)
print("=========================")
print(result)  # the space at index 3 is still decoded as a space


Perturbed Midtail Consonant ID
-> [-100   29    1  999  113    0  129  134    0  252    1   29    0    0
  561    0  561  225  509    0  373    1  526    0    0    0    0    9
  569 -100]
내가 너 엄청 좋아해?!이 기호는 불가능 잘릴
