# Semantic Search with Movie Plots

How do you find movies based on what they're about? Semantic search.

Given a search phrase, we can rank the movies in a database by how similar their plots are to that phrase. In this example, we build semantic search over the movies in the Wikipedia-Movie-Plots Dataset found on [Kaggle](https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots). The sentence-transformers library turns each plot into an embedding vector, and a vector database stores those vectors and finds the ones closest to the embedded search phrase. We use [Milvus Lite](https://milvus.io/docs/milvus_lite.md) to run the vector database locally.
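At its core, semantic search embeds every plot and the query as vectors and returns the plots whose vectors lie closest to the query vector. The minimal sketch below does this by brute force over hand-made three-dimensional vectors; the titles and numbers are hypothetical stand-ins for real embeddings, and a vector database such as Milvus replaces the linear scan with an index.

```python
import math

# Hypothetical toy "embeddings"; real ones would come from a sentence transformer.
corpus = {
    "Boxing Movie": [0.9, 0.1, 0.0],
    "Space Movie": [0.0, 0.2, 0.95],
    "Cooking Movie": [0.1, 0.9, 0.1],
}

def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def top_k(query, vectors, k):
    """Return the k (title, distance) pairs closest to the query, nearest first."""
    scored = [(title, l2_distance(query, vec)) for title, vec in vectors.items()]
    return sorted(scored, key=lambda pair: pair[1])[:k]

# A query vector that points roughly in the "boxing" direction.
results = top_k([0.8, 0.2, 0.1], corpus, k=2)
print(results)
```

The scan costs one distance computation per stored vector, which is what an approximate index avoids on large collections.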

We begin by installing the necessary libraries:

In [1]:
! pip install pymilvus sentence-transformers gdown milvus

Collecting pymilvus
  Downloading pymilvus-2.2.7-py3-none-any.whl (133 kB)
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
  Preparing metadata (setup.py) ... done
Collecting gdown
  Using cached gdown-4.7.1-py3-none-any.whl (15 kB)
Collecting milvus
  Using cached milvus-2.2.5-py3-none-macosx_11_0_arm64.whl (23.5 MB)
Collecting pandas>=1.2.4
  Using cached pandas-2.0.0-cp310-cp310-macosx_11_0_arm64.whl (10.8 MB)
Collecting ujson>=2.0.0
  Using cached ujson-5.7.0-cp310-cp310-macosx_11_0_arm64.whl (53 kB)
Collecting protobuf>=3.20.0
  Using cached protobuf-4.22.3-cp37-abi3-macosx_10_9_universal2.whl (397 kB)
Collecting grpcio<=1.53.0,>=1.49.1
  Using cached grpcio-1.53.0-cp310

Next, we download the data and unzip it.

In [2]:
import gdown
url = 'https://drive.google.com/uc?id=11ISS45aO2ubNCGaC3Lvd3D7NT8Y7MeO8'
output = './movies.zip'
gdown.download(url, output)

import zipfile

with zipfile.ZipFile("./movies.zip","r") as zip_ref:
    zip_ref.extractall("./movies")

Downloading...
From (uriginal): https://drive.google.com/uc?id=11ISS45aO2ubNCGaC3Lvd3D7NT8Y7MeO8
From (redirected): https://drive.google.com/uc?id=11ISS45aO2ubNCGaC3Lvd3D7NT8Y7MeO8&confirm=t&uuid=44b14708-fbce-426f-99fc-1f87e415159f
To: /Users/yujiantang/Documents/workspace/bootcamp/notebooks/text/movies.zip
100%|██████████| 30.9M/30.9M [00:03<00:00, 8.38MB/s]


Next, we define a few constants for the collection, the embedding step, and the search.

In [3]:
COLLECTION_NAME = 'movies_db'  # Collection name
DIMENSION = 384  # Embeddings size

# Inference Arguments
BATCH_SIZE = 128

# Search Arguments
TOP_K = 3

With the constants in place, we start a local Milvus server and connect to it. In the cell after that, we drop any existing collection with the same name so that rerunning the notebook does not duplicate data.

In [4]:
from milvus import default_server
from pymilvus import connections, utility

# (OPTIONAL) Set this if you want to store all related data in a specific location
# Default location:
#   %APPDATA%/milvus-io/milvus-server on windows
#   ~/.milvus-io/milvus-server on linux
# default_server.set_base_dir('milvus_data')

# (OPTIONAL) Uncomment to clean up data from previous runs
# default_server.cleanup()

# Start the Milvus server
default_server.start()

# Now you can connect to localhost on the server's port,
# which is given by default_server.listen_port
connections.connect(host='127.0.0.1', port=default_server.listen_port)

# Check if the server is ready.
print(utility.get_server_version())

[get_server_version] retry:4, cost: 0.27s, reason: <_InactiveRpcError: StatusCode.UNAVAILABLE, internal: Milvus Proxy is not ready yet. please wait>
[get_server_version] retry:5, cost: 0.81s, reason: <_InactiveRpcError: StatusCode.UNAVAILABLE, internal: Milvus Proxy is not ready yet. please wait>


v2.2.5-dev


In [5]:
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

Now we have an instance of a vector database spun up. Let's define our schema and create a collection.

For these movies, each object in the database needs three components: an ID, a title, and the embedding.

In [6]:
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection


# Create collection which includes the id, title, and embedding.
fields = [
    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200),  # VARCHAR fields need a maximum length; this example uses 200
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)
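Each row inserted later has to fit this schema: a title no longer than the `max_length` of 200 and a float vector of exactly `DIMENSION` values, while `id` is generated by Milvus because `auto_id=True`. The sketch below is a hypothetical client-side check (`validate_row` is not part of pymilvus) that catches mismatches before they reach the server. It measures the title's UTF-8 byte length, which is the stricter of the two possible readings of the limit.

```python
DIMENSION = 384         # same value as the constant defined earlier
MAX_TITLE_LENGTH = 200  # same value as max_length on the 'title' field

def validate_row(title, embedding):
    """Hypothetical pre-insert check mirroring the collection schema."""
    # Use the encoded byte length, the conservative reading of the VARCHAR limit.
    if len(title.encode("utf-8")) > MAX_TITLE_LENGTH:
        raise ValueError(f"title too long: {title[:30]!r}...")
    # The embedding field is a FLOAT_VECTOR with dim=DIMENSION.
    if len(embedding) != DIMENSION:
        raise ValueError(f"expected {DIMENSION} values, got {len(embedding)}")
    return True

print(validate_row("Rocky IV", [0.0] * DIMENSION))
```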

Next, we define the vector index. For this example, we use an IVF_FLAT index with the L2 (Euclidean) distance metric and `nlist` set to 128, which partitions the vectors into 128 clusters, just as in the
[reverse image search example notebook](../vision/reverse_painting_search.ipynb).

In [8]:
index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()
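An IVF (inverted file) index groups the vectors into `nlist` clusters around centroids. At query time, only the `nprobe` clusters whose centroids are closest to the query are scanned, which is why the search call later passes `nprobe`; IVF_FLAT keeps the original vectors in each cluster, so ranking inside the probed clusters is exact. The toy version below hard-codes two centroids instead of training them with k-means as a real index would, so treat it as an illustration of the idea rather than of Milvus internals.

```python
import math

def l2(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

# Hypothetical centroids; a real IVF index learns nlist of them with k-means.
centroids = [[0.0, 0.0], [10.0, 10.0]]
vectors = {"a": [0.5, 0.2], "b": [1.0, -0.3], "c": [9.5, 10.2], "d": [10.4, 9.9]}

# Build the inverted lists: each vector is filed under its nearest centroid.
inverted_lists = {i: [] for i in range(len(centroids))}
for name, vec in vectors.items():
    nearest = min(range(len(centroids)), key=lambda i: l2(vec, centroids[i]))
    inverted_lists[nearest].append(name)

def ivf_search(query, k, nprobe):
    """Scan only the nprobe closest clusters, then rank their members exactly."""
    probe = sorted(range(len(centroids)), key=lambda i: l2(query, centroids[i]))[:nprobe]
    candidates = [name for i in probe for name in inverted_lists[i]]
    return sorted(candidates, key=lambda name: l2(query, vectors[name]))[:k]

print(inverted_lists)                         # {0: ['a', 'b'], 1: ['c', 'd']}
print(ivf_search([9.0, 9.0], k=1, nprobe=1))  # ['c']
```

Raising `nprobe` scans more clusters, trading speed for a lower chance of missing a true neighbor that sits just across a cluster boundary.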

With our local vector database set up, we can dive into creating vectors out of movie plots and putting them into a vector space.

For this example, we use the [MiniLM L6 v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) sentence transformer, which maps text to 384-dimensional vectors, matching the `DIMENSION` constant above.

In [9]:
import csv
from sentence_transformers import SentenceTransformer

transformer = SentenceTransformer('all-MiniLM-L6-v2')

Downloading (…)e9125/.gitattributes: 100%|██████████| 1.18k/1.18k [00:00<00:00, 784kB/s]
Downloading (…)_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 159kB/s]
Downloading (…)7e55de9125/README.md: 100%|██████████| 10.6k/10.6k [00:00<00:00, 6.35MB/s]
Downloading (…)55de9125/config.json: 100%|██████████| 612/612 [00:00<00:00, 437kB/s]
Downloading (…)ce_transformers.json: 100%|██████████| 116/116 [00:00<00:00, 93.9kB/s]
Downloading (…)125/data_config.json: 100%|██████████| 39.3k/39.3k [00:00<00:00, 2.43MB/s]
Downloading pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:09<00:00, 9.14MB/s]
Downloading (…)nce_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<00:00, 38.1kB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 73.4kB/s]
Downloading (…)e9125/tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 6.21MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 350/350 [00:00<00:00, 145kB/s]
Downloading (…)9125/train_script.py: 100%|██
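According to its model card, `all-MiniLM-L6-v2` turns a passage into one vector by averaging (mean-pooling) the per-token embeddings, ignoring padding, and then normalizing the result to unit length. The sketch below reproduces only those two steps on made-up three-dimensional token vectors; the transformer layers that produce real 384-dimensional token embeddings are out of scope.

```python
import math

def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors whose mask value is 1 (padding is ignored)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            total = [t + v for t, v in zip(total, vec)]
            count += 1
    return [t / count for t in total]

def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

# Hypothetical token embeddings for a 3-token input plus one padding token.
tokens = [[1.0, 2.0, 2.0], [3.0, 2.0, 0.0], [2.0, 2.0, 1.0], [9.0, 9.0, 9.0]]
mask = [1, 1, 1, 0]

pooled = mean_pool(tokens, mask)
embedding = l2_normalize(pooled)
print(pooled)  # [2.0, 2.0, 1.0]
print(embedding)
```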

With the embedding model loaded, we need the movie titles and plots to embed. According to the [Kaggle page](https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots), the data contains eight columns. We only need the title (column 2) and the plot (column 8), so our `csv_load` function extracts just those and skips any row where either field is empty.

The second function, `embed_insert`, takes a batch in the form `[titles, plots]`, encodes the plots with the sentence transformer, and inserts the titles together with their embeddings into the collection.

In [10]:
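# Extract the movie titles and plots
def csv_load(file):
    with open(file, newline='', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # Skip the header row
        for row in reader:
            # Skip rows that are missing a title or a plot
            if '' in (row[1], row[7]):
                continue
            yield (row[1], row[7])


# Embed the plots with SentenceTransformer and insert them with their titles
def embed_insert(data: list):
    # data is [titles, plots]; encode returns one vector per plot
    embeds = transformer.encode(data[1])
    ins = [
        data[0],              # title field
        [x for x in embeds],  # embedding field (id is generated by auto_id)
    ]
    collection.insert(ins)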
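To check the column positions and the empty-field filter without downloading the dataset, the sketch below applies the same kind of parsing to a small in-memory CSV built with `io.StringIO`. The rows are invented for illustration and only the eight-column layout follows the Kaggle description; `load_title_plot` is a hypothetical name, and it assumes the file starts with a header row.

```python
import csv
import io

# Invented sample rows in the eight-column layout described on the Kaggle page.
SAMPLE = (
    "Release Year,Title,Origin/Ethnicity,Director,Cast,Genre,Wiki Page,Plot\n"
    '2001,The Ring Test,American,A. Director,"Actor One, Actor Two",sports,'
    'https://example.org/a,"A boxer trains for a rematch."\n'
    "2002,Untitled Draft,American,B. Director,,drama,https://example.org/b,\n"
)

def load_title_plot(f):
    """Yield (title, plot) pairs, skipping the header and rows missing either field."""
    reader = csv.reader(f, delimiter=",")
    next(reader, None)  # assume the first row is the header
    for row in reader:
        if "" in (row[1], row[7]):
            continue
        yield (row[1], row[7])

rows = list(load_title_plot(io.StringIO(SAMPLE)))
print(rows)  # [('The Ring Test', 'A boxer trains for a rematch.')]
```

The quoted cast field shows why `csv.reader` is used instead of a plain `split(',')`: commas inside quotes stay within a single column.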

In [11]:
data_batch = [[],[]]

for title, plot in csv_load('./movies/plots.csv'):
    data_batch[0].append(title)
    data_batch[1].append(plot)
    if len(data_batch[0]) % BATCH_SIZE == 0:
        embed_insert(data_batch)
        data_batch = [[],[]]

# Embed and insert the remainder
if len(data_batch[0]) != 0:
    embed_insert(data_batch)

# Call a flush to index any unsealed segments.
collection.flush()
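The loop above accumulates titles and plots until it holds `BATCH_SIZE` of them, encodes and inserts that batch, and finally handles the shorter batch left over at the end. The same pattern can be written as a small reusable generator; `batched` below is a hypothetical helper (Python 3.12 ships a similar `itertools.batched` that yields tuples), shown with a batch size of 3 so the remainder is easy to see.

```python
from itertools import islice

def batched(iterable, size):
    """Yield lists of up to `size` items; the last list may be shorter."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk

pairs = [(f"title {i}", f"plot {i}") for i in range(7)]
for chunk in batched(pairs, 3):
    titles = [t for t, _ in chunk]
    plots = [p for _, p in chunk]
    # In the notebook, this is where embed_insert([titles, plots]) would run.
    print(len(titles), titles[0])
```

Because the generator emits the remainder itself, the separate "embed and insert the remainder" step is no longer needed.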

In [33]:
import time

# Find the movies whose plots most closely match these phrases.
search_terms = ['We do not talk about fight club.', 'Boxing with a Russian.']

# Embed the search phrases with the same model used for the plots
def embed_search(data):
    embeds = transformer.encode(data)
    return [x for x in embeds]

search_data = embed_search(search_terms)

start = time.time()
res = collection.search(
    data=search_data,  # Embedded search vectors
    anns_field="embedding",  # Search across embeddings
    param={"metric_type": "L2",
           "params": {"nprobe": 10}},  # Scan the 10 closest of the 128 clusters
    limit=TOP_K,  # Limit to top_k results per search
    output_fields=['title']  # Include title field in result
)
end = time.time()

for hits_i, hits in enumerate(res):
    print('Title:', search_terms[hits_i])
    # Total time for the whole batch of queries, so it repeats for each term
    print('Search Time:', end-start)
    print('Results:')
    for hit in hits:
        print(hit.entity.get('title'), '----', hit.distance)
    print()

Title: We do not talk about fight club.
Search Time: 0.004420757293701172
Results:
Fight Club – Members Only ---- 1.2392218112945557
Boxer ---- 1.398276925086975
Battle Creek Brawl ---- 1.400315761566162

Title: Boxing with a Russian.
Search Time: 0.004420757293701172
Results:
Never Say Die ---- 0.8928807377815247
Rocky IV ---- 0.9470371007919312
Shadowboxing ---- 0.9799186587333679
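Smaller distances mean closer matches, so each result list runs from best to worst. Because `all-MiniLM-L6-v2` produces unit-length embeddings, ranking by L2 distance gives the same order as ranking by cosine similarity: for unit vectors a and b, ||a - b||² = 2 - 2·cos(a, b). The sketch below checks that identity numerically on two hypothetical vectors; it does not attempt to convert the distances printed above, which come from Milvus.

```python
import math

def normalize(v):
    """Scale a vector to unit length, as the embedding model does."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def squared_l2(a, b):
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

# Hypothetical embeddings, normalized to unit length like the model's output.
a = normalize([3.0, 4.0, 0.0])
b = normalize([1.0, 1.0, 1.0])

# The two values agree up to floating-point rounding.
print(squared_l2(a, b), 2 - 2 * cosine(a, b))
```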



In [None]:
# cleanup
default_server.stop()