In [1]:
from transformers import BertTokenizer, BertModel
import torch
import numpy as np
from nltk.tokenize import sent_tokenize
from operator import itemgetter 
import pickle
import warnings
import tqdm
import random
warnings.filterwarnings("ignore")

# Preparation

In [3]:
class SentenceDatabase:
    
    def __init__(self, tokenizer : object, filename : str = None, threshold :int = 100):
        self.tokenizer = tokenizer
        self.threshold = threshold
        
        if filename:
            self.sentences = sent_tokenize(open(filename, "r").read().replace("\n",""))
            self.sentences = [i.lower() for i in self.sentences]
            self.database = self.index_words()
            
            with open('datasets/News/sentences.pickle', 'wb') as handle:
                pickle.dump(self.sentences, handle, protocol=pickle.HIGHEST_PROTOCOL)
                
            with open('datasets/News/database.pickle', 'wb') as handle:
                pickle.dump(self.database, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with open('datasets/News/sentences.pickle', 'rb') as handle:
                self.sentences = pickle.load(handle)
                self.sentences = [i.lower() for i in self.sentences]
            
            with open('datasets/News/database.pickle', 'rb') as handle:
                self.database = pickle.load(handle)
        
    def index_words(self):
        database = {}
        for i, sentence in enumerate(self.sentences):
            if len(sentence)< self.threshold:
                tokenized_sentence = self.tokenizer.tokenize(sentence)
                for j, token in enumerate(tokenized_sentence):
                    if token not in database:
                        database[token] = []
                    database[token].append(i)
        return database
    
    
    # The only important function            
    def get_sentences(self, word: str):
        token = self.tokenizer.tokenize(word)[0]
        return itemgetter(*self.database[token])(self.sentences)
    
def get_all_words(tokenizer : object):
    train = open("datasets/Processed/train.txt", "r").read().split("\n")
    train = " ".join([x.split(" ")[0].lower() for x in train if x])
    words = tokenizer.tokenize(train)
    test = open("datasets/Processed/test.txt", "r").read().split("\n")
    test = " ".join([x.split(" ")[0].lower() for x in test if x])
    words.extend(tokenizer.tokenize(test))
    valid = open("datasets/Processed/validate.txt", "r").read().split("\n")
    valid = " ".join([x.split(" ")[0].lower() for x in valid if x])
    words.extend(tokenizer.tokenize(valid))
    return words

# Here we load out tokenizer. Because we are using Bert we need to use BertTokenizer. 
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [13]:
# Here we load a sentence database. You can query this database by calling its function get_sentences
# and by passing it a string. The function will return all sentences in which that string occured.
SentenceDB = SentenceDatabase(tokenizer)

In [4]:
# Here we get all the words that are contained in our train, test and validation set so we can build a dictionary.
words = get_all_words(tokenizer)

In [6]:
class Vocab:
    def __init__(self, values, max_size = -1, min_freq = 0, vocabulary = {}):
        self.vocabulary = vocabulary
        self.create_vocab(values, max_size, min_freq)
        
    def create_vocab(self, values, max_size, min_freq):
        vcounter = {}
        
        for i in values:
            if i not in vcounter:
                vcounter[i] = 0
            vcounter[i] += 1
        
        for k, v in sorted(vcounter.items(), key=lambda item: item[1], reverse=True):
            key = len(self.vocabulary)
            if max_size!=key:
                if min_freq!=0:
                    if v>min_freq:
                        self.vocabulary[k] = key
                else:
                    self.vocabulary[k] = key
        
    def tokenize(self, text):
        encoded = []
        for i in text:
            if i in self.vocabulary:
                encoded.append(self.vocabulary[i])
            else:
                encoded.append(self.vocabulary["<UNK>"])
        encoded = torch.tensor(encoded)
        return encoded

# This counts the frequencies of the words and constructs a dictionary. In short we use this class to get
# all the distinct words that appear in our train, test and valid set ordered by frequency.
vocab = Vocab(words, vocabulary = {})

In [6]:
# This class converts a list of sentences into bert embeddings.
class BertBatchEmbedding:
    def __init__(self):
        self.model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True).eval().cuda()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    def transform(self, sentences : list):
        padded_sequence = self.tokenizer.batch_encode_plus(sentences, return_tensors="pt", pad_to_max_length=True)
        with torch.no_grad():
            out = self.model(padded_sequence['input_ids'].cuda(), padded_sequence["attention_mask"].cuda())
        hidden_states = out[2]
        token_embeddings = torch.stack(hidden_states, dim=0)
        return token_embeddings.permute(1,2,0,3)
    
bertbatch = BertBatchEmbedding()

# Creation of embeddings

In [7]:
def get_vector(search, tokenized_sentences, embeddings):
    counter = 0
    vector = None
    for i, tokenized_sentence in enumerate(tokenized_sentences):
        for j, token in enumerate(tokenized_sentence):
            if token==search:
                if vector==None:
                    vector = torch.flatten(embeddings[i,:][j,:])
                else:
                    vector += torch.flatten(embeddings[i,:][j,:])
                counter += 1

        if counter==100:
            break
    return vector/counter


def write(token_embeddings, missing):
    with open('datasets/embeddings/embeddings.pickle', 'wb') as handle:
        pickle.dump(token_embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
    with open('datasets/embeddings/missing.pickle', 'wb') as handle:
        pickle.dump(missing, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [3]:
with open('datasets/embeddings/embeddings.pickle', 'rb') as handle:
    token_embeddings = pickle.load(handle)
    
with open('datasets/embeddings/missing.pickle', 'rb') as handle:
    missing = pickle.load(handle)

In [33]:
len(token_embeddings["the"])

9984

In [11]:
import time

counter = 0
for word in tqdm.tqdm(vocab.vocabulary):
    if word not in token_embeddings:
        try:
            sentences = SentenceDB.get_sentences(word)
            if len(sentences)>100:
                random.shuffle(list(sentences))
                sentences = sentences[:100]
            if sentences!=[]:
                search = tokenizer.encode(word)[1]
                embeddings = bertbatch.transform(sentences)
                tokenized_sentences = tokenizer.batch_encode_plus(sentences)
                vector = get_vector(search, tokenized_sentences["input_ids"], embeddings)
                if counter%1000==0:
                    write(token_embeddings, missing)
                token_embeddings[word] = vector.cpu()
        except KeyboardInterrupt:
            print("Interupted")
            break
        except:
            missing.append(word)
            print(word)
        counter+=1
write(token_embeddings, missing)

 27%|██▋       | 4334/15843 [00:00<00:00, 29634.72it/s]

-
'
"
–
/
prelate
dioceses
ecclesiastical


 40%|████      | 6354/15843 [00:00<00:00, 14219.29it/s]

pietro
theologians
amalgamation


 56%|█████▌    | 8841/15843 [1:02:19<3:56:51,  2.03s/it]

utc


 61%|██████▏   | 9711/15843 [1:31:51<3:18:44,  1.94s/it]

nottinghamshire


 63%|██████▎   | 10025/15843 [1:41:53<3:25:11,  2.12s/it]

vincenzo


 69%|██████▉   | 10926/15843 [2:11:05<2:08:29,  1.57s/it]

html


 72%|███████▏  | 11329/15843 [2:23:16<1:57:27,  1.56s/it]

thence


 75%|███████▌  | 11959/15843 [2:42:03<2:23:48,  2.22s/it]

juris


 77%|███████▋  | 12176/15843 [2:48:23<2:35:54,  2.55s/it]

hz


 77%|███████▋  | 12234/15843 [2:50:14<1:34:09,  1.57s/it]

fledged


 79%|███████▊  | 12452/15843 [2:56:30<1:15:47,  1.34s/it]

genoa


 79%|███████▉  | 12515/15843 [2:58:18<2:09:31,  2.34s/it]

abbess


 79%|███████▉  | 12530/15843 [2:58:35<47:31,  1.16it/s]  

berber


 81%|████████▏ | 12885/15843 [3:08:59<1:10:53,  1.44s/it]

gunners


 82%|████████▏ | 12969/15843 [3:10:52<49:12,  1.03s/it]  

shrewsbury


 83%|████████▎ | 13152/15843 [3:15:50<52:12,  1.16s/it]  

successively


 83%|████████▎ | 13171/15843 [3:16:15<50:55,  1.14s/it]  

johnstone


 83%|████████▎ | 13191/15843 [3:16:51<1:00:46,  1.37s/it]

magdalene


 84%|████████▍ | 13272/15843 [3:19:19<50:03,  1.17s/it]  

1658


 85%|████████▍ | 13460/15843 [3:24:29<1:14:11,  1.87s/it]

gmbh


 85%|████████▌ | 13531/15843 [3:26:40<48:26,  1.26s/it]  

1570


 86%|████████▌ | 13557/15843 [3:27:26<1:34:15,  2.47s/it]

mongol


 86%|████████▌ | 13621/15843 [3:29:15<19:55,  1.86it/s]  

aachen


 87%|████████▋ | 13808/15843 [3:33:54<41:22,  1.22s/it]  

antonia


 88%|████████▊ | 14021/15843 [3:39:47<45:49,  1.51s/it]  

saxons


 91%|█████████ | 14354/15843 [3:48:39<51:54,  2.09s/it]  

1757


 91%|█████████ | 14356/15843 [3:48:40<39:24,  1.59s/it]

1764


 91%|█████████ | 14362/15843 [3:48:47<29:44,  1.21s/it]

ptolemy


 92%|█████████▏| 14507/15843 [3:52:38<1:06:31,  2.99s/it]

tq


 92%|█████████▏| 14518/15843 [3:52:58<45:15,  2.05s/it]  

administratively


 94%|█████████▍| 14934/15843 [4:05:20<33:29,  2.21s/it]  

rochdale


 95%|█████████▌| 15091/15843 [4:09:31<15:20,  1.22s/it]

terminates


 96%|█████████▋| 15252/15843 [4:14:30<17:48,  1.81s/it]

computed


 97%|█████████▋| 15302/15843 [4:16:09<12:25,  1.38s/it]

צ


 98%|█████████▊| 15486/15843 [4:21:14<13:51,  2.33s/it]

vertices


 98%|█████████▊| 15599/15843 [4:24:21<06:03,  1.49s/it]

cyrillic


 99%|█████████▉| 15662/15843 [4:26:17<02:29,  1.21it/s]

amalgamated
navarre


100%|██████████| 15843/15843 [4:31:15<00:00,  1.03s/it]


In [11]:
with open('datasets/embeddings.pickle', 'rb') as handle:
    token_embeddings = pickle.load(handle)

In [8]:
voc = {
    "<PAD>":0,
    "<UNK>":1,
}
vocab = Vocab(words, vocabulary = voc)

In [16]:
vocab.vocabulary

{'<PAD>': 0,
 '<UNK>': 1,
 'the': 2,
 ',': 3,
 '.': 4,
 'to': 5,
 'of': 6,
 'and': 7,
 'in': 8,
 'a': 9,
 'that': 10,
 '’': 11,
 's': 12,
 'is': 13,
 '“': 14,
 '”': 15,
 '-': 16,
 'for': 17,
 'it': 18,
 'on': 19,
 'was': 20,
 'he': 21,
 'with': 22,
 'as': 23,
 'this': 24,
 'be': 25,
 'by': 26,
 'not': 27,
 'his': 28,
 'have': 29,
 'are': 30,
 '##s': 31,
 ':': 32,
 'has': 33,
 'i': 34,
 'from': 35,
 'at': 36,
 'they': 37,
 'who': 38,
 "'": 39,
 'an': 40,
 'said': 41,
 '"': 42,
 'but': 43,
 'we': 44,
 'you': 45,
 '?': 46,
 't': 47,
 'will': 48,
 'trump': 49,
 'or': 50,
 'had': 51,
 ')': 52,
 '(': 53,
 'all': 54,
 'were': 55,
 'about': 56,
 'their': 57,
 'what': 58,
 'one': 59,
 'which': 60,
 'been': 61,
 'no': 62,
 'our': 63,
 'out': 64,
 'there': 65,
 'if': 66,
 '—': 67,
 'would': 68,
 'so': 69,
 'people': 70,
 'do': 71,
 'also': 72,
 'she': 73,
 'when': 74,
 'after': 75,
 'up': 76,
 'her': 77,
 'more': 78,
 'should': 79,
 'us': 80,
 'can': 81,
 'president': 82,
 'him': 83,
 'church': 8

In [13]:
def generate_embedding_matrix(vocab, token_embeddings):
    lenght = len(vocab.vocabulary)
    embeddings = np.random.normal(0, 1, (lenght, len(token_embeddings["the"])))
    for i in vocab.vocabulary:
        if i in token_embeddings:
            values = token_embeddings[i]
            if len(values)!=0:
                embeddings[vocab.vocabulary[i],:] = values
    return embeddings

In [14]:
weights = generate_embedding_matrix(vocab, token_embeddings)

In [16]:
with open('datasets/weights.pickle', 'wb') as handle:
    pickle.dump(weights, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [10]:
with open('datasets/vocab.pickle', 'wb') as handle:
    pickle.dump(vocab.vocabulary, handle, protocol=pickle.HIGHEST_PROTOCOL)