In [70]:
import gensim as gensim
from sqlalchemy import Column, Integer, String, ARRAY, Float
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy import create_engine

uri = f"postgresql://erik:erik@localhost:5432/erik_db"
engine = create_engine(uri)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)

Base = declarative_base()

class Sentence(Base):
    __tablename__ = 'sentences'
    id = Column(Integer, primary_key=True)
    text = Column(String(500), nullable=False)
    vector = Column(ARRAY(Float), nullable=False)

    def __lt__(self, other):
        return True

In [19]:
# Drop the existing "sentences" table
Base.metadata.drop_all(engine)

# Recreate the "sentences" table
Base.metadata.create_all(engine)

In [10]:
import re
from nltk import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer


class DataAdder:

    def __init__(self):
        self.THRESHOLD = 60

    def clean_text(self, text):

        text = self.remove_brackets(text)

        # Tokenize the text into sentences
        sentences = sent_tokenize(text)

        if len(sentences) == 0:
            return None

        if len(sentences) == 1:
            if len(sentences[0]) < self.THRESHOLD:
                return None

        # Ensure one space between each word in each sentence
        formatted_sentences = [' '.join(sentence.split()) for sentence in sentences]
        return formatted_sentences

    def remove_brackets(self, string):
        # Remove brackets and their contents (including nested brackets)
        pattern = r'\([^()]*\)|\[[^\]]*\]'
        while re.search(pattern, string):
            string = re.sub(pattern, '', string)
        return string

    def read_file(self, name):
        all_data = []
        with open(name, "r", encoding="utf-8") as f:
            raw_text = f.readlines()
            for line in raw_text:
                text = self.clean_text(line)
                if text is None:
                    continue
                all_data.extend(text)
        return all_data



da = DataAdder()
all_data = da.read_file("data.txt")
print(len(all_data))

99


In [11]:
from transformers import AutoTokenizer, AutoModel

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [75]:
import torch


with torch.no_grad():
    with Session() as session:
        for sentence in all_data:
            tokens = tokenizer.encode(sentence, add_special_tokens=True)
            input_ids = torch.tensor([tokens])
            outputs = model(input_ids)
            vector = outputs.last_hidden_state.mean(dim=1).squeeze()

            s = Sentence()
            s.text = sentence
            s.vector = vector.tolist()
            session.add(s)
        session.commit()

67
160
106
128
124
140
174
86
171
192
98
283
111
216
64
96
89
121
111
70
115
184
77
65
77
108
144
239
70
113
142
118
166
262
137
150
162
71
73
229
123
130
184
109
79
254
337
69
98
144
137
138
76
92
122
77
114
290
152
119
157
153
123
133
60
99
73
117
130
81
89
274
144
92
147
89
152
109
73
349
96
137
68
137
236
194
131
174
117
125
89
38
178
150
12
211
79
126
121


In [76]:
import heapq
import numpy as np
from scipy.spatial.distance import cosine


class FixedSizeList:
    def __init__(self, size):
        self.size = size
        self.data = []

    def push(self, score, item):
        if len(self.data) < self.size:
            heapq.heappush(self.data, (score, item))
        else:
            heapq.heappushpop(self.data, (score, item))

    def get_list(self):
        return sorted(self.data)

question = "First book?"

with torch.no_grad():
    tokens = tokenizer.encode(question, add_special_tokens=True)
    input_ids = torch.tensor([tokens])
    outputs = model(input_ids)
    target_vector = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

# print(target_vector)
batch_size = 32
fsl = FixedSizeList(10)
with Session() as session:
    offset = 0
    batch_count = 0
    while True:
        sentences = session.query(Sentence).offset(offset).limit(batch_size).all()
        if not sentences:
            break  # No more sentences, end the loop
        for sentence in sentences:
            sentence_vector = np.array(sentence.vector)
            similarity = 1 - cosine(target_vector, sentence_vector)
            fsl.push(similarity, sentence)
        offset += batch_size
        batch_count += 1

for score, sent in fsl.get_list():
    print(score, sent.text)


0.5299220765956472 The New Testament was translated into southern Estonian in 1686 .
0.5299220765956472 The New Testament was translated into southern Estonian in 1686 .
0.5302152957964208 An Estonian grammar book to be used by priests was printed in German in 1637.
0.5302152957964208 An Estonian grammar book to be used by priests was printed in German in 1637.
0.5446677019350086 The book was a Lutheran manuscript, which never reached the reader and was destroyed immediately after publication.
0.5446677019350086 The book was a Lutheran manuscript, which never reached the reader and was destroyed immediately after publication.
0.5724760011496554 they created new words out of nothing.
0.5724760011496554 they created new words out of nothing.
0.6273926247617584 Examples are
0.6273926247617584 Examples are
