# 0. Load required modules.

In [1]:
import datasets
import spacy
from transformers import AutoTokenizer
from tqdm import tqdm_notebook
import pickle

from transformers import BertTokenizer, BertModel, BertConfig, BertForSequenceClassification
import torch
import numpy as np

# 1. Load tokenizer and dependency extraction module.

In [2]:
nlp = spacy.load('en_core_web_sm')
vocab_file = '/home/skhong/WordImportance/bert/qnli/vocab.txt'
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 2. Load Dataset
- Preprocessed dataset used in this study.

In [26]:
dataset = datasets.Dataset.from_file("/home/skhong/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-0295a6411edbdafe.arrow")

## 2.1 Dataset Structure
- Since the dataset consists of two sentences, namely a premise and a hypothesis, there is an assumption that a [SEP] token is included between the two sentences. 
- Therefore, it is necessary to extract relationships between words within the sentences.

In [27]:
dataset

Dataset(features: {'question': Value(dtype='string', id=None), 'sentence': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['entailment', 'not_entailment'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'tfidf_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}, num_rows: 5463)

In [29]:
len(dataset['tfidf_ids'][0])

512

In [10]:
dataset['sentence2'][0]

TypeError: 'BertTokenizer' object does not support item assignment

# 3. Random_token_select Function
- A function that randomly selects one token from tokens excluding tokens with dependencies.

In [30]:
import random

def random_value_except(lst, excluded_value):
    filtered_list = [item for item in lst if item != excluded_value]

    if filtered_list:
        random_value = random.choice(filtered_list)
        return random_value
    else:
        return None


# 4. Dataset Generater
- Compare tokenization results of two tokenizers, identify sentences with inter-token dependencies, extract token positions within those sentences, as well as positions of tokens without inter-token dependencies.

In [31]:
data_pos = []
data_set1 = []
data_set2 = []

shuffle_index = [i for i in range(len(dataset['question']))]
random.shuffle(shuffle_index)

for i in tqdm_notebook(shuffle_index):    
    y_temp = []
    text1 = dataset['sentence'][i]
    doc1 = nlp(text1)
    tokens1_1 = [d.text for d in doc1]
    tokens2_1 = tokenizer.tokenize(text1)
    
    text2 = dataset['question'][i]
    doc2 = nlp(text2)
    tokens1_2 = [d.text for d in doc2]
    tokens2_2 = tokenizer.tokenize(text2)  
    
    input_ids = torch.Tensor([dataset['input_ids'][i]]).type(torch.int32)
    token_type_ids = torch.Tensor([dataset['token_type_ids'][i]]).type(torch.int32)
    tfidf_ids = torch.Tensor([dataset['tfidf_ids'][i]]).type(torch.int32)
    # 
    
    for token in doc1:
        if (token.text in tokens2_1) and (token.head.text in tokens2_1):
            random_numbers = [ii+1 for ii in range(len(tokens1_1))]
            i_pos = tokens2_1.index(token.text) + 1
            j_pos = tokens2_1.index(token.head.text) + 1
            j_random_pos = random_value_except(random_numbers, j_pos)
        
            data_pos.append((input_ids, token_type_ids, tfidf_ids))
            data_set1.append((i_pos, j_pos))
            data_set2.append((i_pos, j_random_pos))
            
            if len(data_pos) % 20 == 0:
                with open('data_pos.pickle', 'wb') as f:
                    pickle.dump(data_pos, f, pickle.HIGHEST_PROTOCOL)
                with open('data_set1.pickle', 'wb') as f:
                    pickle.dump(data_set1, f, pickle.HIGHEST_PROTOCOL)
                with open('data_set2.pickle', 'wb') as f:
                    pickle.dump(data_set2, f, pickle.HIGHEST_PROTOCOL)
                print(len(data_pos))
                
            if len(data_pos) >= 2000:
                break
                
            break
                
    if len(data_pos) >= 2000:
        break
            
    for token in doc2:
        if (token.text in tokens2_2) and (token.head.text in tokens2_2):
            random_numbers = [ii+1 for ii in range(len(tokens1_1)+1, len(tokens1_1)+1+len(tokens1_2))]
            i_pos = tokens2_2.index(token.text) + 2 + len(tokens2_1)
            j_pos = tokens2_2.index(token.head.text) + 2 + len(tokens2_1)
            j_random_pos = random_value_except(random_numbers, j_pos)
        
            data_pos.append((input_ids, token_type_ids, tfidf_ids))
            data_set1.append((i_pos, j_pos))
            data_set2.append((i_pos, j_random_pos))
            
            if len(data_pos) % 20 == 0:
                with open('data_pos.pickle', 'wb') as f:
                    pickle.dump(data_pos, f, pickle.HIGHEST_PROTOCOL)
                with open('data_set1.pickle', 'wb') as f:
                    pickle.dump(data_set1, f, pickle.HIGHEST_PROTOCOL)
                with open('data_set2.pickle', 'wb') as f:
                    pickle.dump(data_set2, f, pickle.HIGHEST_PROTOCOL)
                print(len(data_pos))
                
            if len(data_pos) >= 2000:
                break
        break

    if len(data_pos) >= 2000:
        break
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm_notebook(shuffle_index):


HBox(children=(FloatProgress(value=0.0, max=5463.0), HTML(value='')))

20
40
60
80
100
120
140
160
180
200
220
240
260
280
300
320
340
360
380
400
420
440
460
480
500
520
540
560
580
600
620
640
660
680
700
720
740
760
780
800
820
840
860
880
900
920
940
960
980
1000
1020
1040
1060
1080
1100
1120
1140
1160
1180
1200
1220
1240
1260
1280
1300
1320
1340
1360
1380
1400
1420
1440
1460
1480
1500
1520
1540
1560
1580
1600
1620
1640
1660
1680
1700
1720
1740
1760
1780
1800
1820
1840
1860
1880
1900
1920
1940
1960
1980
2000



In [32]:
import pickle
with open('data_pos.pickle', 'wb') as f:
    pickle.dump(data_pos, f, pickle.HIGHEST_PROTOCOL)
with open('data_set1.pickle', 'wb') as f:
    pickle.dump(data_set1, f, pickle.HIGHEST_PROTOCOL)
with open('data_set2.pickle', 'wb') as f:
    pickle.dump(data_set2, f, pickle.HIGHEST_PROTOCOL)