In [1]:
import torch
from transformers import AutoModel, AutoTokenizer
from nltk.tokenize import sent_tokenize
import weaviate

In [2]:
torch.set_grad_enabled(False)

# udpate to use different model if desired
MODEL_NAME = "distilbert-base-uncased"
model = AutoModel.from_pretrained(MODEL_NAME)
# model.to('cuda') # remove if working without GPUs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# initialize nltk (for tokenizing sentences)
import nltk
nltk.download('punkt')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\evanm\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
import requests
import json
import time

# Load data
url = 'https://raw.githubusercontent.com/evanmcfarland/A-Statistical-Approach-to-Happiness/main/data/Mary%20Shelley/mary_books_segments.json'
resp = requests.get(url)
data = json.loads(resp.text)

def text2vec(text):
    tokens_pt = tokenizer(text, padding=True, truncation=True, max_length=500, add_special_tokens = True, return_tensors="pt")
    outputs = model(**tokens_pt)
    return outputs[0].mean(0).mean(0).detach()

def vectorize_posts(posts=[]):
    post_vectors=[]
    before=time.time()
    for i, post in enumerate(posts):
        vec = text2vec(sent_tokenize(post))
        post_vectors += [vec]
        if i % 100 == 0 and i != 0:
            print("So far {} objects vectorized in {}s".format(i, time.time()-before))
    after=time.time()
    
    print("Vectorized {} items in {}s".format(len(posts), after-before))
    
    return post_vectors

# Extract the 'content' values from the dictionaries and store them in a list
content_list = [item['content'] for item in data]

# Vectorize the content list
vectorized_data = vectorize_posts(content_list)

So far 100 objects vectorized in 5.291132211685181s
So far 200 objects vectorized in 11.159983158111572s
So far 300 objects vectorized in 18.557889699935913s
So far 400 objects vectorized in 25.677875757217407s
So far 500 objects vectorized in 33.089988708496094s
So far 600 objects vectorized in 41.09516453742981s
So far 700 objects vectorized in 47.597142457962036s
So far 800 objects vectorized in 55.06011462211609s
So far 900 objects vectorized in 62.57319140434265s
So far 1000 objects vectorized in 71.91520261764526s
So far 1100 objects vectorized in 77.39976787567139s
So far 1200 objects vectorized in 83.18336153030396s
So far 1300 objects vectorized in 88.58209419250488s
So far 1400 objects vectorized in 96.38992595672607s
So far 1500 objects vectorized in 103.67958092689514s
So far 1600 objects vectorized in 112.27004408836365s
So far 1700 objects vectorized in 118.38999342918396s
So far 1800 objects vectorized in 123.95006704330444s
So far 1900 objects vectorized in 129.83985090

In [4]:
Break all this shit up. It's alreay done!

SyntaxError: invalid syntax (3714369493.py, line 1)

# Adding the Class (already done)

In [None]:
client = weaviate.Client(
    url = 'https://uncensoredgreats.weaviate.network',
    auth_client_secret=weaviate.AuthClientPassword(username="evanmcfarland.aa@gmail.com", password="Teslasby25!"),
)


# ===== add class object =====
class_obj = {
    "class": "DistilBERT_Corpus",
    "description": "Full Uncensored Greats Corpus with homeade distilbert vectors.",
    "vectorizer": "none",
    "invertedIndexConfig": {
        "bm25": {
            "b": 0.75,
            "k1": 1.2
        },
        "cleanupIntervalSeconds": 60,
        "stopwords": {
            "additions": None,
            "preset": "en",
            "removals": None
        }
    },
    "properties": [
        {
            "dataType": [
                "text"
            ],
            "description": "Greatest Writings Ever",
            "indexInverted": True,
            "name": "content",
            "tokenization": "word"
        },
        {
            "dataType": [
                "string"
            ],
            "description": "Book name",
            "indexInverted": True,
            "name": "title",
            "tokenization": "word"
        },
        {
            "dataType": [
                "string"
            ],
            "description": "Section label",
            "name": "heading",
            "tokenization": "word"
        },
        {
            "dataType": [
                "string"
            ],
            "description": "Author Name",
            "name": "author",
            "tokenization": "word"
        }
    ]
}


# add the schema
client.schema.create_class(class_obj)

# get the schema
schema = client.schema.get()

# ===== import data ===== 
# Load data 
url = 'https://raw.githubusercontent.com/evanmcfarland/A-Statistical-Approach-to-Happiness/main/data/Mary%20Shelley/mary_books_segments.json'
resp = requests.get(url)
data = json.loads(resp.text)

In [None]:
# delete class "Paragraphs" THIS WILL DELETE ALL DATA IN THIS CLASS
# client.schema.delete_class("DistilBERT_Corpus")

schema = client.schema.get()
print(json.dumps(schema, indent=4))

# Doing the import

In [12]:
# Trying a different approach: 
def delay_if_needed(i):
    global last_successful_index
    if (i + 1) % 1000 == 0:
        print(f"Pausing for 1 minute to avoid rate limit")
        time.sleep(60)
        last_successful_index = i + 1  # Update the last successful index


def split_data(data, chunk_size):
    return [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]


chunk_size = 10000
data_chunks = split_data(data, chunk_size)
vectorized_data_chunks = split_data(vectorized_data, chunk_size)

for chunk_index, (data_chunk, vectorized_data_chunk) in enumerate(zip(data_chunks, vectorized_data_chunks)):
    print(f"Processing chunk {chunk_index + 1}")
    
    with client.batch as batch:
        for i, d in enumerate(data_chunk):
            current_index = i + chunk_index * chunk_size
            print(f"importing text chunk: {current_index + 1}")

            properties = {
                "title": json.dumps(d["title"], ensure_ascii=False),
                "heading": json.dumps(d["heading"], ensure_ascii=False),
                "content": json.dumps(d["content"], ensure_ascii=False),
                "author": json.dumps(d["author"], ensure_ascii=False)
            }

            client.batch.add_data_object(properties, "DistilBERT_Corpus", vector=vectorized_data_chunk[i])
            
            delay_if_needed(current_index)


Processing chunk 1
importing text chunk: 1
importing text chunk: 2
importing text chunk: 3
importing text chunk: 4
importing text chunk: 5
importing text chunk: 6
importing text chunk: 7
importing text chunk: 8
importing text chunk: 9
importing text chunk: 10
importing text chunk: 11
importing text chunk: 12
importing text chunk: 13
importing text chunk: 14
importing text chunk: 15
importing text chunk: 16
importing text chunk: 17
importing text chunk: 18
importing text chunk: 19
importing text chunk: 20
importing text chunk: 21
importing text chunk: 22
importing text chunk: 23
importing text chunk: 24
importing text chunk: 25
importing text chunk: 26
importing text chunk: 27
importing text chunk: 28
importing text chunk: 29
importing text chunk: 30
importing text chunk: 31
importing text chunk: 32
importing text chunk: 33
importing text chunk: 34
importing text chunk: 35
importing text chunk: 36
importing text chunk: 37
importing text chunk: 38
importing text chunk: 39
importing text 

ConnectionError: Batch was not added to weaviate.

[ERROR] Batch ConnectionError Exception occurred! Retrying in 2s. [1/3]
[ERROR] Batch ConnectionError Exception occurred! Retrying in 4s. [2/3]
[ERROR] Batch ConnectionError Exception occurred! Retrying in 6s. [3/3]


In [5]:
# Nietzsche done up to 16740/18830  
# Mary done 25000 / 26119  (for some resaon it skipped from 25000 after waiting for one minute, jumped to 27232)

client = weaviate.Client(
    url = 'https://uncensoredgreats.weaviate.network',
    auth_client_secret=weaviate.AuthClientPassword(username="evanmcfarland.aa@gmail.com", password="Teslasby25!")
)

In [10]:
# Batch Process with rate limit fix.
import time

# Global variable to store the last successful index
last_successful_index = 22769

def delay_if_needed(i):
    global last_successful_index
    if (i + 1) % 5000 == 0:
        print(f"Pausing for 1 minute to avoid rate limit")
        time.sleep(60)
        last_successful_index = i + 1  # Update the last successful index

# Main part of the code
with client.batch as batch:
    for i, d in enumerate(data[last_successful_index:]):
        current_index = i + last_successful_index
        print(f"importing text chunk: {current_index + 1}")

        properties = {
            "title": json.dumps(d["title"], ensure_ascii=False),
            "heading": json.dumps(d["heading"], ensure_ascii=False),
            "content": json.dumps(d["content"], ensure_ascii=False),
            "author": json.dumps(d["author"], ensure_ascii=False)
        }

        client.batch.add_data_object(properties, "DistilBERT_Corpus", vector=vectorized_data[current_index])
        
        delay_if_needed(current_index)


importing text chunk: 22770
importing text chunk: 22771
importing text chunk: 22772
importing text chunk: 22773
importing text chunk: 22774
importing text chunk: 22775
importing text chunk: 22776
importing text chunk: 22777
importing text chunk: 22778
importing text chunk: 22779
importing text chunk: 22780
importing text chunk: 22781
importing text chunk: 22782
importing text chunk: 22783
importing text chunk: 22784
importing text chunk: 22785
importing text chunk: 22786
importing text chunk: 22787
importing text chunk: 22788
importing text chunk: 22789
importing text chunk: 22790
importing text chunk: 22791
importing text chunk: 22792
importing text chunk: 22793
importing text chunk: 22794
importing text chunk: 22795
importing text chunk: 22796
importing text chunk: 22797
importing text chunk: 22798
importing text chunk: 22799
importing text chunk: 22800
importing text chunk: 22801
importing text chunk: 22802
importing text chunk: 22803
importing text chunk: 22804
importing text chunk

ConnectionError: Batch was not added to weaviate.

[ERROR] Batch SSLError Exception occurred! Retrying in 2s. [1/3]
[ERROR] Batch SSLError Exception occurred! Retrying in 4s. [2/3]
[ERROR] Batch SSLError Exception occurred! Retrying in 6s. [3/3]


In [7]:
i

16739

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer

torch.set_grad_enabled(False)
MODEL_NAME = "distilbert-base-uncased"
model = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def text2vec(text):
    tokens_pt = tokenizer(text, padding=True, truncation=True, max_length=500, add_special_tokens=True, return_tensors="pt")
    outputs = model(**tokens_pt)
    return outputs[0].mean(0).mean(0).detach()

def vectorize_query(query):
    return text2vec(query).numpy()

from weaviate import Client

client = weaviate.Client(
    url = 'https://uncensoredgreats.weaviate.network',
    auth_client_secret=weaviate.AuthClientPassword(username="evanmcfarland.aa@gmail.com", password="Teslasby25!"),
)

# def get_relevant_results(vector, n_results=5):
#     query = {
#         "vector": vector.tolist(),
#         "n": n_results,
#         "certainty": 0.7
#     }
#     response = client.query.get("DistilBERT_Corpus", query)
#     return response['result']


# def search(query, limit=3):
#     vectorized_input = vectorize_query(query)
#     relevant_results = get_relevant_results(vectorized_input, limit)

#     for i, result in enumerate(relevant_results):
#         print(f"Result {i + 1}:")
#         print(f"Title: {result['title']}")
#         print(f"Heading: {result['heading']}")
#         print(f"Content: {result['content']}")
#         print(f"Author: {result['author']}")
#         print("\n")

# search("What is the meaning of life?")


In [None]:

# def search(query="", limit=3):
#     before = time.time()
#     vec = text2vec(query)
#     vec_took = time.time() - before

#     before = time.time()
#     near_vec = {"vector": vec.tolist()}
#     res = client \
#         .query.get("DistilBERT_Corpus", ["content", "_additional {certainty}"]) \
#         .with_near_vector(near_vec) \
#         .with_limit(limit) \
#         .do()
#     search_took = time.time() - before

#     print("\nQuery \"{}\" with {} results took {:.3f}s ({:.3f}s to vectorize and {:.3f}s to search)" \
#           .format(query, limit, vec_took+search_took, vec_took, search_took))
#     for post in res["data"]["Get"]["DistilBERT_Corpus"]:
#         print("{:.4f}: {}".format(post["_additional"]["certainty"], post["content"]))
#         print('---')


def search(query="", limit=3):
    before = time.time()
    vec = text2vec(query)
    vec_took = time.time() - before

    before = time.time()
    near_vec = {"vector": vec.tolist()}
    res = client \
        .query.get("DistilBERT_Corpus", ["title", "heading", "author", "content", "_additional {certainty}"]) \
        .with_near_vector(near_vec) \
        .with_limit(limit) \
        .do()
    search_took = time.time() - before

    print("\nQuery \"{}\" with {} results took {:.3f}s ({:.3f}s to vectorize and {:.3f}s to search)" \
          .format(query, limit, vec_took+search_took, vec_took, search_took))
    for post in res["data"]["Get"]["DistilBERT_Corpus"]:
        print("Title: {}\nHeading: {}\nAuthor: {}\nCertainty: {:.4f}\nContent: {}"
              .format(post["title"], post["heading"], post["author"], post["_additional"]["certainty"], post["content"]))
        print('---')


In [None]:
search("Hello wtf are you working?", 10)