In [1]:
# !pip3 install transformers
# !pip3 install nltk
# !pip3 install torch
# !pip3 install weaviate-client==3.2.2

## Import the BERT transformer model and pytorch

We are using the `bert-base-uncased` model in this example, but any model will work. Feel free to adjust accordingly.

## Initialize Weaviate Client
This assumes you have Weaviate running locally on `:8080`. Adjust URL accordingly. You could also enter the WCS URL here, for example, if you are running a WCS cloud instance instead of running Weaviate locally.

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer
from nltk.tokenize import sent_tokenize
import weaviate

torch.set_grad_enabled(False)

# udpated to use different model if desired
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
model = AutoModel.from_pretrained(MODEL_NAME)
model.to('cuda') # remove if working without GPUs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# initialize nltk (for tokenizing sentences)
import nltk
nltk.download('punkt')

# initialize weaviate client for importing and searching
client = weaviate.Client("http://64.71.146.93:8080")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/stefanbogdan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Load dataset from disk
Create some helper functions to create the dataset (20-newsgroup text posts) from disk. These methods are specific to the structure of your dataset, adjust accordingly.

In [None]:
import os
import random

def get_post_filenames(limit_objects=100):
    file_names = []
    i=0
    for root, dirs, files in os.walk("./data/20news-bydate-test"):
        for filename in files:
            path = os.path.join(root, filename)
            file_names += [path]
        
    random.shuffle(file_names)
    limit_objects = min(len(file_names), limit_objects)
      
    file_names = file_names[:limit_objects]

    return file_names

def read_posts(filenames=[]):
    posts = []
    for filename in filenames:
        f = open(filename, encoding="utf-8", errors='ignore')
        post = f.read()
        
        # strip the headers (the first occurrence of two newlines)
        post = post[post.find('\n\n'):]
        
        # remove posts with less than 10 words to remove some of the noise
        if len(post.split(' ')) < 10:
               continue
        
        post = post.replace('\n', ' ').replace('\t', ' ')
        if len(post) > 1000:
            post = post[:1000]
        posts += [post]

    return posts       


## Vectorize Dataset using BERT

The following is a helper function to vectorize all posts (using our BERT transformer) which are entered as an array. The return array contains all the vectors in the same order. BERT is optimized to run on GPUs, if you're using CPUs this might take a while. 

In [None]:
import time

def text2vec(text):
    tokens_pt = tokenizer(text, padding=True, truncation=True, max_length=500, add_special_tokens = True, return_tensors="pt")
    outputs = model(**tokens_pt)
    tokens_pt.to('cuda') # remove if working without GPUs
    return outputs[0].mean(0).mean(0).detach()

def vectorize_posts(posts=[]):
    post_vectors=[]
    before=time.time()
    for i, post in enumerate(posts):
        vec=text2vec(sent_tokenize(post))
        post_vectors += [vec]
        if i % 25 == 0 and i != 0:
            print("So far {} objects vectorized in {}s".format(i, time.time()-before))
    after=time.time()
    
    print("Vectorized {} items in {}s".format(len(posts), after-before))
    
    return post_vectors

### Run everything we have so far

It is now time to run the functions we defined before. Let's load 50 random posts from disk, then vectorize them using BERT.

## Initialize Weaviate

Now that we have vectors we can import both the posts and the vectors into Weaviate, so we can then search through them.

### Init a simple schema
Our schema is very simple, we just have one object class, the "Post". A post class has just a single property, which we call "content" and is of type "text".

Each class in schema creates one index, so by running the below we tell weaviate to create one brand new vector index waiting for us to import data.

In [None]:
def init_weaviate_schema(client):
    # a simple schema containing just a single class for our posts
    schema = {
        "classes": [{
                "class": "Post",
                "vectorizer": "none", # explicitly tell Weaviate not to vectorize anything, we are providing the vectors ourselves through our BERT model
                "properties": [{
                    "name": "content",
                    "dataType": ["text"],
                }]
        }]
    }

    # cleanup from previous runs
    client.schema.delete_all()

    client.schema.create(schema)

In [None]:
def import_posts_with_vectors(posts, vectors, client):
    if len(posts) != len(vectors):
        raise Exception("len of posts ({}) and vectors ({}) does not match".format(len(posts), len(vectors)))
        
    for i, post in enumerate(posts):
        try:
            client.data_object.create(
                data_object={"content": post},
                class_name='Post',
                vector=vectors[i]
            )
        except:
            print(res)

In [None]:
def search(query="", limit=3):
    before = time.time()
    vec = text2vec(query)
    vec_took = time.time() - before

    before = time.time()
    near_vec = {"vector": vec}
    res = client \
        .query.get("Post", ["content", "_additional {certainty}"]) \
        .with_near_vector(near_vec) \
        .with_limit(limit) \
        .do()
    search_took = time.time() - before

    print("\nQuery \"{}\" with {} results took {:.3f}s ({:.3f}s to vectorize and {:.3f}s to search)" \
          .format(query, limit, vec_took+search_took, vec_took, search_took))
    for post in res["data"]["Get"]["Post"]:
        print("{:.4f}: {}".format(post["_additional"]["certainty"], post["content"]))
        print('---')

In [None]:
init_weaviate_schema(client)
posts = read_posts(get_post_filenames(100))
vectors = vectorize_posts(posts)
import_posts_with_vectors(posts, vectors, client)

So far 25 objects vectorized in 17.370927810668945s
So far 50 objects vectorized in 31.07236075401306s
So far 75 objects vectorized in 49.37031865119934s
Vectorized 100 items in 64.40030074119568s


In [None]:
search("the best camera lens", 1)
search("motorcycle trip", 1)
search("which software do i need to view jpeg files", 1)
search("windows vs mac", 1)


Query "the best camera lens" with 1 results took 0.058s (0.034s to vectorize and 0.024s to search)
0.8271:   Hi,  I have a HP LaserJet III for sale.  It's been printed for less than 1500 pages according the self test report.  I am asking $1000 for it.  If interested, please e-mail.  Thanks! 
---

Query "motorcycle trip" with 1 results took 0.035s (0.025s to vectorize and 0.011s to search)
0.8202:   Hi,  I have a HP LaserJet III for sale.  It's been printed for less than 1500 pages according the self test report.  I am asking $1000 for it.  If interested, please e-mail.  Thanks! 
---

Query "which software do i need to view jpeg files" with 1 results took 0.055s (0.044s to vectorize and 0.011s to search)
0.9191:   Is anybody using David Rapier's Hebrew Quiz software?  And can tell me how to *space* when typing in the Hebrew?  (space bar doesn't work, for me anyway...)  Email please; thanks.  Ken --  miner@kuhub.cc.ukans.edu opinions are my own      
---

Query "windows vs mac" with 1 r