In [1]:
import os
import pickle
import sys

import unidecode
from tqdm import tqdm
from transformers import BertTokenizer

sys.path.append("../../../src")
from application_utils.text_utils import get_token_list

In [2]:
# Uncased BERT word-piece tokenizer; the sentences are lower-cased separately before word alignment.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

In [3]:
# Load the pre-built SST trees, indexed by split name (e.g. "test").
# NOTE(review): pickle.load can execute arbitrary code; only load pickles produced by this project.
with open("../../../downloads/sst_data/sst_trees.pickle", mode="rb") as sst_file:
    sst_trees = pickle.load(sst_file)

In [4]:
def map_word_to_token_index(words, tokens):
    """Align word-piece tokens to the whitespace-separated words they came from.

    Tokens are consumed in order. For each word, the current token (with its
    "##" continuation marker removed) is compared against the remainder of the
    word starting at a character cursor. On a match the token is assigned to
    that word and the cursor advances past it. Otherwise the cursor advances by
    one character, so characters that no token reproduces are skipped.

    Args:
        words: list of word strings (the caller passes the lower-cased,
            unidecoded sentence split on whitespace).
        tokens: list of word-piece strings for the same sentence (the caller
            slices off the first and last entries of the tokenizer output).

    Returns:
        (token_to_word_map, word_to_token_map):
        - token_to_word_map maps token index -> word index.
        - word_to_token_map maps word index -> ascending list of token indices.
        A word that matched no token has no entry in word_to_token_map.

    Raises:
        AssertionError: if exactly one of words/tokens is empty, or if not every
            token could be matched against the words.
    """
    token_to_word_map = {}
    word_to_token_map = {}

    # An empty sentence has nothing to align (previously this raised IndexError).
    if not words and not tokens:
        return token_to_word_map, word_to_token_map
    assert tokens, "non-empty word list but no tokens"
    assert words, "non-empty token list but no words"

    t = 0
    # Strip "##" the same way for every token. The first token used to have every
    # "#" removed, so a leading literal "#" became "" and misaligned later tokens.
    token = tokens[t].replace("##", "")

    for w, word in enumerate(words):
        word = str(word)
        i = 0  # character cursor into the current word
        tmp_word = word
        # After the cursor reaches the end of the word, the body runs once more
        # with an empty remainder, so an empty token can still attach to this word.
        while tmp_word:
            tmp_word = word[i:]
            if tmp_word.startswith(token):
                token_to_word_map[t] = w
                word_to_token_map.setdefault(w, []).append(t)
                i += len(token)
                t += 1
                if t >= len(tokens):
                    break
                token = tokens[t].replace("##", "")
            else:
                i += 1

    assert t == len(tokens), "could not align all tokens to the words"
    return token_to_word_map, word_to_token_map

In [5]:

# For each processed SST split, build one record per sentence holding:
# - its word-piece tokens, and
# - every proper sub-phrase, re-expressed as an inclusive
#   (first_token_index, last_token_index) span in token space.

# Optional subtree filters. Both are off, so every sub-phrase is kept
# (these replace two previously commented-out checks).
EXCLUDE_SINGLE_WORD_PHRASES = False   # drop phrases that are one word before tokenization
EXCLUDE_SINGLE_TOKEN_PHRASES = False  # drop phrases that cover a single token

splits = {}
count = 0  # number of subtrees kept across all processed splits
for split in ["test"]:
    token_trees = []

    for index in tqdm(range(len(sst_trees[split]))):
        # As read here: entry [0] is the sentence text, and entry [2] is a list of
        # subtree dicts with "phrase", "position" and "label" keys.
        sentence = sst_trees[split][index][0]
        subtrees = sst_trees[split][index][2]
        sen_len = len(sentence.split())

        tokens = get_token_list(sentence, tokenizer)[1:-1]  # presumably strips [CLS]/[SEP]

        words = unidecode.unidecode(sentence.lower()).split()
        # Subtree positions index sentence.split(); if unidecode changed the word
        # count, every span below would be silently misaligned.
        assert len(words) == sen_len, (sentence, words)
        try:
            token_to_word_map, word_to_token_map = map_word_to_token_index(words, list(tokens))
        except Exception:
            # Show the alignment that failed, then re-raise the original error with its traceback.
            print(tokens)
            print(words)
            raise

        filtered_subtrees = []
        for subtree in subtrees:
            # Skip the subtree that spans the whole original sentence.
            if subtree["phrase"] == sentence:
                continue
            phrase_list = subtree["phrase"].split()
            if EXCLUDE_SINGLE_WORD_PHRASES and len(phrase_list) == 1:
                continue

            pos = subtree["position"]
            # Word span [pos, pos + len(phrase_list) - 1] -> inclusive token span.
            phrase_span_tokenspace = (
                min(word_to_token_map[pos]),
                max(word_to_token_map[pos + len(phrase_list) - 1]),
            )
            first_token_index, last_token_index = phrase_span_tokenspace
            if EXCLUDE_SINGLE_TOKEN_PHRASES and last_token_index == first_token_index:
                continue

            count += 1
            filtered_subtrees.append({"span": phrase_span_tokenspace, "label": subtree["label"], "phrase": subtree["phrase"], "position": pos })

        token_trees.append({"sentence": sentence, "tokens": tokens, "subtrees": filtered_subtrees })
    splits[split] = token_trees

splits["note"] = "the phrase spans need to be shifted right based on which methods index SEP and CLS. the spans mean (first token index, last token index)"

100%|██████████| 2210/2210 [00:02<00:00, 775.08it/s]


In [6]:
# Persist the processed splits (including the "note" entry) for downstream use.
# NOTE(review): the filename says "pairphrase_and_greater", but the processing cell
# does not apply its single-word / single-token phrase filters. Confirm which is intended.
os.makedirs("text_data", exist_ok=True)  # a fresh checkout may not have the output directory yet
with open('text_data/subtree_token_pairphrase_and_greater.pickle', 'wb') as handle:
    pickle.dump(splits, handle, protocol=pickle.HIGHEST_PROTOCOL)