In [1]:
# ! pip install pymilvus==2.3.1

In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv
import os
from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema
from sentence_transformers import SentenceTransformer

### Read news_articles JSON

In [3]:
data = pd.read_json("News_Category_Dataset_v3.json", lines=True, nrows=200) # loading 200 news articles
data = data[['short_description']].copy()
data.reset_index(inplace=True)
data

Unnamed: 0,index,short_description
0,0,Health experts said it is too early to predict...
1,1,He was subdued by passengers and crew when he ...
2,2,"""Until you have a dog you don't understand wha..."
3,3,"""Accidentally put grown-up toothpaste on my to..."
4,4,Amy Cooper accused investment firm Franklin Te...
...,...,...
195,195,Concerned your chest pain might be heart-relat...
196,196,"The Senate GOP leader cited ""candidate quality..."
197,197,"Bryant's widow testified that she ""lives in fe..."
198,198,"Young Thug, along with rapper Gunna, are one o..."


### Vectorization using Sentence transformers

In [4]:
class TextVectorizer:
    '''
    sentence transformers to extract sentence embeddings
    '''
    def vectorize(self, x: pd.Series, dataset: str = "train"):
        x = x.copy()
        model = SentenceTransformer('bert-base-nli-mean-tokens')
        sen_embeddings = model.encode(x)
        return sen_embeddings

In [5]:
vectorizer = TextVectorizer()

In [6]:
# getting max length of article descriptions to be used for VARCHAR while defining schema
max_desc_len = max([len(s) for s in data['short_description']])
max_desc_len

245

In [7]:
# Reading milvus URI & API token from secrets.env
load_dotenv('secrets.env')
uri = os.environ.get("URI")
token = os.environ.get("TOKEN")

In [8]:
# connecting to db
connections.connect("default", uri=uri, token=token)
print(f"Connected to DB")

Connected to DB


In [9]:
collection_name = os.environ.get("COLLECTION_NAME")
check_collection = utility.has_collection(collection_name)

In [10]:
if check_collection:
    drop_result = utility.drop_collection(collection_name)
    print("Droped Existing collection")

Droped Existing collection


In [11]:
# Creating collection schema
dim = 768 # embeddings dim
article_id = FieldSchema(name="article_id", dtype=DataType.INT64, is_primary=True, description="primary id") # primary key
article_embed_field = FieldSchema(name="article_embed", dtype=DataType.FLOAT_VECTOR, dim=dim) # description embeddings
article_desc = FieldSchema(name="article_desc", dtype=DataType.VARCHAR, max_length=(max_desc_len + 50), # using max_desc_len to specify VARCHAR len 
                           is_primary=False, description="short description of the article") # short description of article
schema = CollectionSchema(fields=[article_id, article_embed_field, article_desc], 
                          auto_id=False, description="collection of news articles")
print(f"Creating the collection")
collection = Collection(name=collection_name, schema=schema)
print(f"Schema: {schema}")
print("Success!")

Creating the collection
Schema: {'auto_id': False, 'description': 'collection of news articles', 'fields': [{'name': 'article_id', 'description': 'primary id', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'article_embed', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 768}}, {'name': 'article_desc', 'description': 'short description of the article', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 295}}]}
Success!


In [12]:
# Preparing data to load
article_id = []
article_desc = []
article_embed = []
for i in tqdm(range(len(data))):
    article_id.append(data['index'][i])
    article_desc.append(data['short_description'][i])
    article_embed.append(vectorizer.vectorize([data['short_description'][i]])[0])
docs = [article_id, article_embed, article_desc]

100%|██████████| 200/200 [03:25<00:00,  1.03s/it]


In [13]:
# insert documents into collection
ins_resp = collection.insert(docs)
ins_resp # insert result

(insert count: 200, delete count: 0, upsert count: 0, timestamp: 444532761385500674, success count: 200, err count: 0)

In [14]:
# creating index on embeddings field (article_embed)
# metric type: L2 (euclidean dist). supported: [L2 IP]
index_params = {"index_type": "AUTOINDEX", "metric_type": "L2", "params": {}} 
collection.create_index(field_name='article_embed', index_params=index_params)

alloc_timestamp unimplemented, ignore it


Status(code=0, message=)