# RAG / HyDE Example with Mistral Instruct 7b and Milvus DB

### Imports dependencies and Mistral LLM

In [1]:
import os
import torch
import torchvision
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

home_dir = os.getenv("HOME")

model_name = f'{home_dir}/ext-gits/Mistral-7B-Instruct-v0.3'

tokenizer = AutoTokenizer.from_pretrained(model_name)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

try:
    # Attempt to use 8-bit quantization
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map={"": device},
        trust_remote_code=True
    )
    print("Model loaded with 8-bit quantization.")
except Exception as e:
    print(f"8-bit quantization failed: {e}")
    print("Falling back to Float16 precision.")
    # Load the model with Float16 precision
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map={"": device},
        trust_remote_code=True
    )



  from .autonotebook import tqdm as notebook_tqdm


Using device: mps
8-bit quantization failed: Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`
Falling back to Float16 precision.


Loading checkpoint shards: 100%|██████████| 3/3 [00:10<00:00,  3.60s/it]


### Test LLM

In [2]:

prompt = "Once upon a time"

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)

# Generate response
with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=50,
        num_return_sequences=1,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id
    )

# Decode and print the response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Once upon a time, there was a man who was a gardener. He had the most beautiful garden in the whole country, and every day he would get up and go out into it and tend to it, watching it grow and flourish


### Set Up Vector DB

In [23]:
import pandas as pd

# Load the dataset
df = pd.read_csv('data/wiki_movie_plots_deduped.csv')

# Ensure there are no NaNs in the 'Plot' column
df = df[['Plot']].dropna()

# Find the longest plot and its length
longest_plot = df['Plot'].apply(len).idxmax()  # Find the index of the longest plot
longest_plot_text = df['Plot'].iloc[longest_plot]  # Get the longest plot text
longest_plot_length = len(longest_plot_text)  # Get the length of the longest plot text

print(f"The longest plot is at index {longest_plot} with length {longest_plot_length} characters.")
print(f"Longest Plot: {longest_plot_text[:500]}...")

The longest plot is at index 26064 with length 36773 characters.
Longest Plot: After a brief introduction to some of the main characters of the story, the beginning sees a group of Rishis, led by Vishvamitra, performing a Yajna in a forest not far from Ayodhya, the Capital of the Kingdom of Kosala. This Yajna, like several before it, is interrupted and destroyed by a group of flying demons led by Ravana's Mama(Uncle/Mother's Brother) Maricha. After seeing yet another Yajna destroyed, a despondent Vishvamitra appeals to Lord Vishnu for salvation. Vishnu appears in a spiritu...


In [24]:
# run milvis locally
# run: docker-compose up -d

# data is ..data/wiki_movie_plots_deduped.csv is from
# https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots

from pymilvus import connections, CollectionSchema, DataType, FieldSchema, Collection, utility

# Connect to Milvus
connections.connect(alias="default", host="127.0.0.1", port="19530")

# Check Milvus status
print("Milvus connected:", connections.has_connection(alias="default"))

# Define the schema for the collection
collection_name = "wiki_movie_plots"
dim = 768  # Dimensions for the vector embeddings

# Check if the collection exists and drop it if it does
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!   WILL DROP
# if utility.has_collection(collection_name):
#     collection = Collection(collection_name)
#     collection.drop()
#     print(f"Collection '{collection_name}' dropped.")


fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),  # Primary key field
    FieldSchema(name="plot_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="plot_text", dtype=DataType.VARCHAR, max_length=40000),  # To store original plot text
    FieldSchema(name="release_year", dtype=DataType.INT64),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=255),
]

# Create the schema and collection
schema = CollectionSchema(fields, description="Wikipedia Movie Plots with vector embeddings and original plot text")
collection = Collection(name=collection_name, schema=schema)

print("Milvus collection schema created successfully!")


Milvus connected: True
Collection 'wiki_movie_plots' dropped.
Milvus collection schema created successfully!


In [35]:
# doublecheck

from pymilvus import connections, Collection

# Connect to Milvus
# connections.connect(alias="default", host="localhost", port="19530")
collection.load()

# Load the collection
collection = Collection("wiki_movie_plots")

# Check how many records are in the collection
num_records = collection.num_entities
print(f"Number of records in the collection: {num_records}")

# Query a few records to inspect
if num_records > 0:
    sample_records = collection.query(expr="id >= 1", output_fields=["id", "plot_text"], limit=5)
    print("Sample records:", sample_records)

Number of records in the collection: 34886
Sample records: data: ['{\'id\': 453300591135312087, \'plot_text\': "a bartender is working at a saloon, serving drinks to customers. after he fills a stereotypically irish man\'s bucket with beer, carrie nation and her followers burst inside. they assault the irish man, pulling his hat over his eyes and then dumping the beer over his head. the group then begin wrecking the bar, smashing the fixtures, mirrors, and breaking the cash register. the bartender then sprays seltzer water in nation\'s face before a group of policemen appear and order everybody to leave."}', '{\'id\': 453300591135312088, \'plot_text\': "the moon, painted with a smiling face hangs over a park at night. a young couple walking past a fence learn on a railing and look up. the moon smiles. they embrace, and the moon\'s smile gets bigger. they then sit down on a bench by a tree. the moon\'s view is blocked, causing him to frown. in the last scene, the man fans the woman with

### Insert data

In [27]:
import pandas as pd
import re
import torch
from transformers import AutoTokenizer, AutoModel

# Load the dataset
df = pd.read_csv('data/wiki_movie_plots_deduped.csv')

# Filter the columns we are interested in
df = df[['Release Year', 'Title', 'Plot']].dropna()

# Load the pre-trained transformer model for embeddings
model_name = "sentence-transformers/all-MPNet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Move model to the device (CPU or MPS for MacBooks)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

# Function to clean movie plots
def clean_plot(plot):
    plot = re.sub(r'\[.*?\]', '', plot)  # Remove anything in square brackets
    plot = re.sub(r'\s+', ' ', plot)     # Replace multiple spaces with a single space
    plot = plot.strip()                  # Remove leading/trailing spaces
    return plot.lower()                  # Convert to lowercase

# Function to get vector embeddings from a plot
def get_batch_embeddings(batch_texts):
    inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1)  # Get sentence embedding

batch_size = 32  # Adjust according to your memory capacity

for i in range(0, len(df), batch_size):
    batch_texts = [clean_plot(text) for text in df['Plot'].iloc[i:i+batch_size].tolist()]
    batch_titles = df['Title'].iloc[i:i+batch_size].tolist()
    batch_release_year = df['Release Year'].iloc[i:i+batch_size].tolist()

    # Get embeddings
    batch_embeddings = get_batch_embeddings(batch_texts)
    
    # Move embeddings to CPU before pushing to Milvus
    batch_embeddings_cpu = batch_embeddings.cpu().numpy()

    # Prepare records to insert into Milvus
    records = [
        {
            "release_year": release_year,
            "title": title,
            "plot_embedding": embedding.tolist(),  # Convert to list for insertion
            "plot_text": text
        }
        for release_year, title, embedding, text in zip(batch_release_year, batch_titles, batch_embeddings_cpu, batch_texts)
    ]
    
    # Insert into Milvus
    collection.insert(records)

# Flush to Milvus to ensure all data is written
collection.flush()

collection.create_index(field_name="plot_embedding", index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}})

print(f"Inserted {len(df)} records into Milvus.")

sample_plot = "A bartender is working at a saloon, serving drinks to customers. After he fills a stereotypically Irish man's bucket with beer, Carrie Nation and her followers burst inside."
print("Plot Embedding:",  get_batch_embeddings(sample_plot))

# Takes about 23 minutes with Apple silicon

Plot Embedding: tensor([[ 7.2634e-02,  1.7146e-01, -5.8976e-04, -4.3674e-02, -9.8264e-02,
          8.6167e-02, -1.2725e-01, -9.4895e-02, -2.9519e-02,  8.2841e-02,
          4.7391e-02, -3.6577e-02,  5.8282e-02,  8.1094e-02, -1.6779e-02,
          1.1979e-01,  6.7246e-02,  7.4323e-02,  4.7637e-02, -7.6078e-02,
         -5.8983e-02,  3.1512e-02, -4.1934e-02, -7.7589e-02, -2.4669e-02,
          5.4758e-02, -3.5519e-03,  5.1198e-02, -1.4199e-01,  7.2104e-02,
          2.6151e-03, -1.6058e-02, -5.1134e-02, -4.6279e-03,  4.9275e-06,
         -1.3567e-02, -6.6775e-02, -8.0051e-02, -3.8100e-02, -1.2830e-01,
         -1.4797e-01,  6.6426e-02, -2.5168e-02,  5.1333e-02,  5.7209e-02,
          1.7733e-02, -1.2465e-02,  1.4359e-01,  3.3481e-02,  2.0943e-02,
          1.9472e-02, -3.1945e-02, -1.8737e-01, -1.0638e-01,  1.4778e-02,
         -6.3040e-03,  6.7245e-02,  1.5119e-01, -1.4442e-02, -2.4544e-01,
         -3.1666e-02,  8.5816e-02, -8.8085e-02,  6.1015e-02, -1.5713e-01,
          3.8713e-03, 

### Query

In [39]:
def search_similar_plots(plot_text, top_k=5):
    # Convert the input plot to an embedding
    embedding = get_batch_embeddings([plot_text])  # Pass as a list to get embedding
    
    # Move embedding to CPU and convert to numpy array
    embedding_cpu = embedding.cpu().numpy().squeeze()  # Squeeze to remove extra dimensions
    
    # Convert numpy array to a list of floats
    embedding_list = embedding_cpu.tolist()
    
    # Perform similarity search
    results = collection.search(
        data=[embedding_list],  # Ensure it's a list of floats
        anns_field="plot_embedding",
        param={"metric_type": "L2"},  # Use L2 (Euclidean distance) or the metric you set
        limit=top_k
    )
    
    for result in results[0]:
        # Assuming that 'id' was saved in Milvus with the record and matches the DataFrame index
        original_id = result.id
        
        # Safely handle out-of-bounds issues
        if original_id < len(df):
            plot_text = df.iloc[original_id]['Plot']
            title = df.iloc[original_id]['Title']
            release_year = df.iloc[original_id]['Release Year']
            print(f"Title: {title}, Release Year: {release_year}")
            print(f"Original Plot: {plot_text}")
            print(f"Score: {result.distance}")
        else:
            print(f"Result ID {original_id} is out of bounds for the DataFrame.")

# Load the collection before searching
collection.load()

# Example query
search_similar_plots("A group of people go on an adventure to find treasure.")



Result ID 453300591135315260 is out of bounds for the DataFrame.
Result ID 453300591135328003 is out of bounds for the DataFrame.
Result ID 453300591135325278 is out of bounds for the DataFrame.
Result ID 453300591135327129 is out of bounds for the DataFrame.
Result ID 453300591135317514 is out of bounds for the DataFrame.


# TODO notes:

[] WIP: Figure out why 'Inserted 34886 records into Milvus.' and '134164 data/wiki_movie_plots_deduped.csv' records in dataset

[] Use running code outside of development setup in this notwbook from standalone script imports