# RAG with FLAN-T5

In [58]:
# Packages
from RAG_Functions import *
import time
from pymilvus import Collection, connections
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

## Embedding Model

In [60]:
# embedding model
embedding_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
embedding_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

## Chat Model

In [62]:
# Load model directly
chat_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
chat_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

## Milvus Connection and Collection

In [63]:
connections.connect(host='localhost', port='19530')
collection = Collection("text_embeddings")      # Get an existing collection.
collection.load()

## Perform Chat

In [64]:
# Chat with model
input_text = input()

# Get embedding of input
input_embedding = get_mixedbread_of_query(embedding_model, input_text)

# Start timing query
start_time = time.time()

# Top5 sentences
top5_sentences = return_top_5_sentences(collection, input_embedding)

# End timing query
end_time = time.time()

print(top5_sentences)

['Reddit: We may share information about you with your consent or at your direction.', 'Reddit: You have choices about how to protect and limit the collection, use, and sharing of information about you when you use the Services.', 'Reddit: We may share information if we believe your actions are inconsistent with our User Agreement, rules, or other Reddit policies, or to protect the rights, property, and safety of ourselves and others.', 'Reddit: You may also provide other account information, like an email address, bio, or profile picture.', 'Reddit: We may share information about you that has been aggregated or anonymized such that it cannot reasonably be used to identify you.']


In [65]:
# Construct prompt
prompt_lines = ["Context:"] + top5_sentences + ["User Query:\n" + input_text]
prompt = "\n".join(prompt_lines)
print(prompt)

Context:
Reddit: We may share information about you with your consent or at your direction.
Reddit: You have choices about how to protect and limit the collection, use, and sharing of information about you when you use the Services.
Reddit: We may share information if we believe your actions are inconsistent with our User Agreement, rules, or other Reddit policies, or to protect the rights, property, and safety of ourselves and others.
Reddit: You may also provide other account information, like an email address, bio, or profile picture.
Reddit: We may share information about you that has been aggregated or anonymized such that it cannot reasonably be used to identify you.
User Query:
What does Reddit say about using my personal information?


In [66]:
# Tokenize
input_ids = chat_tokenizer(prompt, return_tensors="pt").input_ids

# Generate
outputs = chat_model.generate(input_ids, max_new_tokens = 100)
print(chat_tokenizer.decode(outputs[0]))

<pad> We may share information about you that has been aggregated or anonymized such that it cannot reasonably be used to identify you.</s>
