# RAG with Gemini Pro

Note: Place your Google API key in the Google API Key folder of the directory this code is in.

In [1]:
# Packages
from RAG_Functions import *
import time

## Embedding Model

In [2]:
# Packages
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# embedding model
embedding_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
embedding_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

## Chat Model

In [4]:
import google.generativeai as genai
import os

# Load API key from './Google API Key/data-engineering-project.txt'
with open(os.path.expanduser('./Google/data-engineering-project.txt')) as f:
    GOOGLE_API_KEY = f.read().strip()

genai.configure(api_key=GOOGLE_API_KEY)

chat_model = genai.GenerativeModel('gemini-1.0-pro-latest')
print(chat_model)

genai.GenerativeModel(
    model_name='models/gemini-1.0-pro-latest',
    generation_config={},
    safety_settings={},
    tools=None,
    system_instruction=None,
)


## Milvus Connection

In [5]:
from pymilvus import Collection, connections
connections.connect(host='localhost', port='19530')
collection = Collection("text_embeddings")      # Get an existing collection.
# index_params = {
#     "metric_type": "COSINE",
#     "index_type": "FLAT"#,
#     #"params": {"nlist": 128}
# }
# collection.drop_index()
# collection.create_index(field_name="embedding", index_params=index_params)
# "metric_type": "L2",
#     "index_type": "IVF_FLAT",
#     "params": {"nlist": 128}
collection.load()

## Perform Chat

In [6]:
# Chat with model
input_text = input()

# Get embedding of input
input_embedding = get_mixedbread_of_query(embedding_model, input_text)

# Start timing query
start_time = time.time()

# Top5 sentences
top5_sentences, documents_cited, milvus_query_time = return_top_5_sentences(collection, input_embedding)

# End timing query
end_time = time.time()

# query time
query_time = end_time - start_time

print(top5_sentences)

['AppleServices: Apple may disclose any information we have about you (including your identity) if we determine that such disclosure is necessary in connection with any investigation or complaint regarding your use of the Site, or to identify, contact or bring legal action against someone who may be causing injury to or interference with (either intentionally or unintentionally) Apple   s rights or property, or the rights or property of visitors to or users of the Site, including Apple   s customers.', 'AppleServices: Apple may disclose any information we have about you (including your identity) if we determine that such disclosure is necessary in connection with any investigation or complaint regarding your use of the Site, or to identify, contact or bring legal action against someone who may be causing injury to or interference with (either intentionally or unintentionally) Apple   s rights or property, or the rights or property of visitors to or users of the Site, including Apple   

In [7]:
# Construct prompt
prompt_lines = ["Context That May Be Helpful (You May Disregard if Not Helpful):"] + top5_sentences + ["User Query:\n" + input_text]
prompt = "\n".join(prompt_lines)
print(prompt)

#Context:
#Document Name: <document_filename_1>
#Information: <sentence_1>
#Document Name: <document_filename_2>
#Information: <sentence_2>
#Document Name: <document_filename_3>
#Information: <sentence_3>
#Document Name: <document_filename_4>
#Information: <sentence_4>
#Document Name: <document_filename_5>
#Information: <sentence_5>
#<user_query>

Context That May Be Helpful (You May Disregard if Not Helpful):
AppleServices: Apple may disclose any information we have about you (including your identity) if we determine that such disclosure is necessary in connection with any investigation or complaint regarding your use of the Site, or to identify, contact or bring legal action against someone who may be causing injury to or interference with (either intentionally or unintentionally) Apple   s rights or property, or the rights or property of visitors to or users of the Site, including Apple   s customers.
AppleServices: Apple may disclose any information we have about you (including your identity) if we determine that such disclosure is necessary in connection with any investigation or complaint regarding your use of the Site, or to identify, contact or bring legal action against someone who may be causing injury to or interference with (either intentionally or unintentionally) Apple   s rights or property, or the rights or prope

In [8]:
# Get response
response = chat_model.generate_content(prompt)

# from IPython.display import Markdown
# import textwrap
# def to_markdown(text):
#   text = text.replace('•', '  *')
#   return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))
# to_markdown(response.text)
#print(response.text)

In [9]:
# Format response for user
response_for_user = response.text + "\nDocuments Cited: " + ', '.join(documents_cited) + "\nMilvus Query Time: " + str(round(milvus_query_time, 2)) + ' seconds'
print(response_for_user)

No, Apple does not sell your personal information.
Documents Cited: AppleServices_WebsiteTermsofService.txt, AppleServices_PrivacyPolicy.txt, AppleServices_ApplicationBasedServices.txt
Milvus Query Time: 0.38 seconds


In [10]:
#print(top5_sentences.get('sentence'))

In [11]:
# for hits in top5_sentences:
#     # Get ids
#     print(hits.ids)
    
#     # Get distances
#     print(hits.distances)
    
#     for hit in hits:
#         # Get id
#         print(hit.id)
        
#         # Get distance
#         print(hit.distance) # hit.score
        
#         # Get vector
#         #hit.vector
        
#         # Get output field
#         print(hit.get("sentence"))

In [12]:

# Tokenize
# input_ids = chat_tokenizer(input_text, return_tensors="pt").input_ids

# outputs = chat_model.generate(input_ids)
# print(chat_tokenizer.decode(outputs[0]))