# RAG / HyDE Example with Mistral Instruct 7b and Milvus DB

### Imports dependencies and Mistral LLM

In [14]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

home_dir = os.getenv("HOME")
model_name = f'{home_dir}/ext-gits/Mistral-7B-Instruct-v0.3'
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map={"": device},
    trust_remote_code=True
)

# Test LLM

prompt = "What happens in the movie Inception?"

inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=50,
        num_return_sequences=1,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id
    )

response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)


Using device: mps


Loading checkpoint shards:  33%|███▎      | 1/3 [00:14<00:29, 14.77s/it]


RuntimeError: MPS backend out of memory (MPS allocated: 36.23 GB, other allocations: 1.88 MB, max allowed: 36.27 GB). Tried to allocate 112.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

### Check data

In [2]:
import pandas as pd
import csv

df = None
# Load the dataset
df = pd.read_csv('data/wiki_movie_plots_deduped.csv')

print(df.shape[0])

empty_rows = df[df.isnull().all(axis=1)]
print(f"Number of completely empty rows: {empty_rows.shape[0]}")

with open('data/wiki_movie_plots_deduped.csv', 'r') as file:
    reader = csv.reader(file)
    total_lines = sum(1 for row in reader)

print(f"Total lines in the file: {total_lines}")

# Ensure there are no NaNs in the 'Plot' column
df = df[['Plot']].dropna()

# Find the longest plot and its length
longest_plot = df['Plot'].apply(len).idxmax()  # Find the index of the longest plot
longest_plot_text = df['Plot'].iloc[longest_plot]  # Get the longest plot text
longest_plot_length = len(longest_plot_text)  # Get the length of the longest plot text

print(f"The longest plot is at index {longest_plot} with length {longest_plot_length} characters.")
print(f"Longest Plot: {longest_plot_text[:500]}...")

non_nan_rows_count = df.dropna().shape[0]
print("non_nan_rows_coun", non_nan_rows_count) # 34886

# need to change logic to accomodate NaNs

# Get the count of NaN values for each column
nan_counts = df.isna().sum()
nan_counts = nan_counts[nan_counts > 0]
print("nan_counts", nan_counts)


print(df.head())

row_count = df.shape[0]
row_count

df = df.drop(df.index)

row_count = df.shape[0]
row_count

34886
Number of completely empty rows: 0
Total lines in the file: 34887
The longest plot is at index 26064 with length 36773 characters.
Longest Plot: After a brief introduction to some of the main characters of the story, the beginning sees a group of Rishis, led by Vishvamitra, performing a Yajna in a forest not far from Ayodhya, the Capital of the Kingdom of Kosala. This Yajna, like several before it, is interrupted and destroyed by a group of flying demons led by Ravana's Mama(Uncle/Mother's Brother) Maricha. After seeing yet another Yajna destroyed, a despondent Vishvamitra appeals to Lord Vishnu for salvation. Vishnu appears in a spiritu...
non_nan_rows_coun 34886
nan_counts Series([], dtype: int64)
                                                Plot
0  A bartender is working at a saloon, serving dr...
1  The moon, painted with a smiling face hangs ov...
2  The film, just over a minute long, is composed...
3  Lasting just 61 seconds and consisting of two ...
4  The earliest know

0

### Load DB

In [3]:
# run milvis locally
# run: docker-compose up -d

# data is ..data/wiki_movie_plots_deduped.csv is from
# https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots

from pymilvus import connections, CollectionSchema, DataType, FieldSchema, Collection, utility

connections.connect(alias="default", host="127.0.0.1", port="19530")

print("Milvus connected:", connections.has_connection(alias="default"))

collection_name = "wiki_movie_plots"
dim = 768  # Dimensions for the vector embeddings

# Check if the collection exists and drop it if it does
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!   WILL DROP
# if utility.has_collection(collection_name):
#     collection = Collection(collection_name)
#     collection.drop()
#     print(f"Collection '{collection_name}' dropped.")


fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),  # Primary key field
    FieldSchema(name="plot_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="plot_text", dtype=DataType.VARCHAR, max_length=40000),  # To store original plot text
    FieldSchema(name="release_year", dtype=DataType.INT64),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=255),
]

# Create the schema and collection
schema = CollectionSchema(fields, description="Wikipedia Movie Plots with vector embeddings and original plot text")
collection = Collection(name=collection_name, schema=schema)

print("Milvus collection schema created successfully!")


Milvus connected: True
Collection 'wiki_movie_plots' dropped.
Milvus collection schema created successfully!


I0000 00:00:1729629804.462455 6676211 fork_posix.cc:77] Other threads are currently calling into gRPC, skipping fork() handlers
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### Helper Functions

In [3]:
def get_batch_embeddings(batch_texts, model, tokenizer, device, remove_token_type_ids=False, mean_pooling=True):
    """
    Generalized function to get embeddings for different use cases.
    
    Parameters:
    - batch_texts: List of input texts to be embedded.
    - model: The pre-trained model used for generating embeddings.
    - tokenizer: The tokenizer corresponding to the model.
    - device: The device to run the model on ('cpu', 'cuda', 'mps').
    - remove_token_type_ids: Set to True if 'token_type_ids' should be removed from input dict (e.g., for causal LMs like GPT/Mistral).
    - mean_pooling: Set to True if you want to perform mean pooling over the output embeddings.
    
    Returns:
    - embeddings: The embeddings for the input batch.
    """
    inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    if remove_token_type_ids and 'token_type_ids' in inputs:
        del inputs['token_type_ids']
    with torch.no_grad():
        outputs = model(**inputs)
    if mean_pooling:
        return outputs.last_hidden_state.mean(dim=1)
    else:
        return outputs.logits.mean(dim=1)


import re

def clean_plot(plot):
    plot = re.sub(r'\[.*?\]', '', plot)  # Remove anything in square brackets
    plot = re.sub(r'\s+', ' ', plot)     # Replace multiple spaces with a single space
    plot = plot.strip()                  # Remove leading/trailing spaces
    return plot.lower()                  # Convert to lowercase




from pymilvus import Collection

def search_similar_plots(plot_text, *args, top_k=5, **kwargs):

    collection = Collection("wiki_movie_plots")

    embedding = get_batch_embeddings([plot_text], *args, **kwargs)
    embedding_cpu = embedding.cpu().numpy().squeeze()
    embedding_list = embedding_cpu.tolist()
    
    results = collection.search(
        data=[embedding_list],
        anns_field="plot_embedding",
        param={"metric_type": "L2"},
        limit=top_k,
        output_fields=["id"]
    )

    # wtf
    csv_file_path = 'data/wiki_movie_plots_deduped.csv'
    df = pd.read_csv(csv_file_path)
    df = df[['Release Year', 'Title', 'Plot']].dropna()
    
    response = []
    for result in results[0]:
        original_id = result.entity.get("id")
        
        if original_id < len(df):
            plot_text = df.iloc[original_id]['Plot']
            title = df.iloc[original_id]['Title']
            release_year = df.iloc[original_id]['Release Year']
            response.append({
                "title": title,
                "release_year": release_year,
                "plot_text": plot_text,
                "score": result.distance
            })
        else:
            print(f"Result ID {original_id} is out of bounds for the DataFrame.")

    return response


import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from pymilvus import Collection

def insert_embeddings_to_milvus(
    csv_file_path, 
    collection, 
    embeddings_model_name, 
    batch_size=32, 
    device=None
):
    
    df = pd.read_csv(csv_file_path)
    df = df[['Release Year', 'Title', 'Plot']].dropna()

    tokenizer = AutoTokenizer.from_pretrained(embeddings_model_name)
    model = AutoModel.from_pretrained(embeddings_model_name)

    device = device or torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)

    for i in range(0, len(df), batch_size):
        batch_texts = [clean_plot(text) for text in df['Plot'].iloc[i:i+batch_size].tolist()]
        batch_titles = df['Title'].iloc[i:i+batch_size].tolist()
        batch_release_year = df['Release Year'].iloc[i:i+batch_size].tolist()
        batch_ids = df.index[i:i+batch_size].tolist() 
        batch_embeddings = get_batch_embeddings(batch_texts, model, tokenizer, device, mean_pooling=True)
        batch_embeddings_cpu = batch_embeddings.cpu().numpy()

        # Prepare records
        records = [
            {
                "id": id_value,
                "release_year": release_year,
                "title": title,
                "plot_embedding": embedding.tolist(),  # Convert to list for insertion
                "plot_text": text
            }
            for id_value, release_year, title, embedding, text in zip(batch_ids, batch_release_year, batch_titles, batch_embeddings_cpu, batch_texts)
        ]
        collection.insert(records)

    # Flush and create index
    collection.flush()
    collection.create_index(field_name="plot_embedding", index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}})
    collection.load()

    # Clear memory
    torch.mps.empty_cache()


### Insert data

In [6]:
from pymilvus import Collection



# csv_file_path = 'data/wiki_movie_plots_deduped.csv'
# embeddings_model_name = "sentence-transformers/bert-base-nli-mean-tokens"
# collection = Collection("wiki_movie_plots")

# insert_embeddings_to_milvus(csv_file_path, collection, embeddings_model_name)

print('done')

done


### Doublecheck ingestion row count

In [7]:
## Run this to connect to db if already performed data insert
## ie: you run out or memory and restart the kernel and need to reconnect to db

from pymilvus import connections, Collection

# Connect to Milvus
connections.connect(alias="default", host="localhost", port="19530")
collection = Collection("wiki_movie_plots")
collection.load()

num_records = collection.num_entities
print(f"Number of records in the collection: {num_records}")

if num_records > 0:
    sample_records = collection.query(expr="id >= 1", output_fields=["id", "plot_text"], limit=5)
    print("Sample records:", sample_records)

Number of records in the collection: 34886
Sample records: data: ['{\'id\': 1, \'plot_text\': "the moon, painted with a smiling face hangs over a park at night. a young couple walking past a fence learn on a railing and look up. the moon smiles. they embrace, and the moon\'s smile gets bigger. they then sit down on a bench by a tree. the moon\'s view is blocked, causing him to frown. in the last scene, the man fans the woman with his hat because the moon has left the sky and is perched over her shoulder to see everything better."}', "{'id': 2, 'plot_text': 'the film, just over a minute long, is composed of two shots. in the first, a girl sits at the base of an altar or tomb, her face hidden from the camera. at the center of the altar, a viewing portal displays the portraits of three u.s. presidents—abraham lincoln, james a. garfield, and william mckinley—each victims of assassination. in the second shot, which runs just over eight seconds long, an assassin kneels feet of lady justice.'

# RAG function

### Load models

In [8]:

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import os, torch

home_dir = os.getenv("HOME")

rag_model_name = f'{home_dir}/ext-gits/Mistral-7B-Instruct-v0.3'
rag_tokenizer = AutoTokenizer.from_pretrained(rag_model_name)
rag_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
rag_model = AutoModelForCausalLM.from_pretrained(
    rag_model_name,
    torch_dtype=torch.float16,
    device_map={"": device},
    trust_remote_code=True
    )

if rag_tokenizer.pad_token is None:
    rag_tokenizer.add_special_tokens({'pad_token':rag_tokenizer.eos_token})

rag_model.to(device)


db_model_name = "sentence-transformers/bert-base-nli-mean-tokens"
db_tokenizer = AutoTokenizer.from_pretrained(db_model_name)
db_model = AutoModel.from_pretrained(db_model_name)

db_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
db_model.to(device)

print('got loaded')


Loading checkpoint shards: 100%|██████████| 3/3 [00:18<00:00,  6.23s/it]


got loaded


### Run Query

In [13]:
def plots_query(query, *args, **kwargs):
    search_results = search_similar_plots(query, top_k=2, *args, **kwargs)
    context_texts = [res['plot_text'] for res in search_results]
    return query + "\n\n".join(context_texts)

def generate_response(combined_texts, *args, max_new_tokens=50):
    model, tokenizer, device = args

    new_prompt = f'Givent the following descriptions {combined_texts}, create a new plotline'

    inputs = tokenizer(new_prompt, return_tensors="pt", padding=True).to(device)

    print('')
    print('')
    print('outcome:')
    print('')
    print('')

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # max_length=500,
            num_return_sequences=1,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            pad_token_id=tokenizer.pad_token_id
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response


combined_query = plots_query(
    "a flight turns dangerous when the passengers find a ton of snakes on board",
    db_model, 
    db_tokenizer, 
    db_device, 
    mean_pooling=True,
    remove_token_type_ids=True
    )

outcome = generate_response(
    combined_query,
    rag_model, 
    rag_tokenizer, 
    rag_device, 
    )

print(outcome)





outcome:


As Pettis describes developments in the streamroom to Mancini, the scene is shown of three men and three women meeting in the steamroom of a luxury hotel as part of an online dating promotion, then being locked in together. When they discover that they have been locked in, they react badly: Frank (Quinn Duffy) becomes abusive to Jessie (Eve Mauro), and is killed in her defense by openly neurotic Margaret (Cordelia Reynolds). Jessie is killed with a nail gun by an unseen assailant when she pokes her head through the small window in the steamroom door; Christopher (Patrick Muldoon) is injured in the hand with a nail as the window is boarded over from outside. Margaret becomes agitated and commits suicide. Grant (Eric Roberts) is bludgeoned by Catherine (Megan Brown) after he accuses her and Christopher of being allies of the perpetrators and repeatedly holds her head underwater.
Mancini's call to Pettis' psychiatrist finally brings staff from the local state psychiatric hosp

## TODOS:

[] Refactor rag implementation

## 

[] Adjust mistral params

##

[] Create HyDE
