In [1]:
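!pip install -q google-generativeai
!pip install -q langchain-google-genai
!pip install -q langchain
# Also imported below but not covered by the installs above
!pip install -q langchain-community langchain-text-splitters langchain-chroma langchain-openai gitpython python-dotenv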
!pip show google-generativeai
!pip show langchain-google-genai
!pip show langchain


Name: google-generativeai
Version: 0.3.2
Summary: Google Generative AI High level API client library and tools.
Home-page: https://github.com/google/generative-ai-python
Author: Google LLC
Author-email: googleapis-packages@google.com
License: Apache 2.0
Location: /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages
Requires: google-ai-generativelanguage, google-api-core, google-auth, protobuf, tqdm, typing-extensions
Required-by: langchain-google-genai
Name: langchain-google-genai
Version: 0.0.9
Summary: An integration package connecting Google's genai package and LangChain
Home-page: https://github.com/langchain-ai/langchain
Author: 
Author-email: 
License: MIT
Location: /Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages
Requires: google-generativeai, langchain-core
Required-by: 
Name: langchain
Version: 0.1.10
Summary: Building applications with LLMs through composability
Home-page: https://github.com/langchain-ai/langchain
Auth
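The `pip show` output above is long and gets cut off. If only the version numbers matter, a lighter check could use the standard library's `importlib.metadata`. This is a minimal sketch: `installed_versions` is a hypothetical helper, and the package list simply repeats the three installs above.

```python
from importlib import metadata

# Distribution names as they appear on PyPI (the three installed above).
PACKAGES = ["google-generativeai", "langchain-google-genai", "langchain"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version or 'not installed'}")
```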

In [25]:
from git import Repo
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import LanguageParser
from langchain_text_splitters import Language

In [26]:
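# Clone the example repo (reuse an existing clone when the cell is re-run)
import os
repo_path = "/Users/swaramenon/Documents/gemai"
repo = (Repo(repo_path) if os.path.isdir(os.path.join(repo_path, ".git"))
        else Repo.clone_from("https://github.com/redapt/pyspark-s3-parquet-example", to_path=repo_path))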

In [27]:
# Load
loader = GenericLoader.from_filesystem(
    repo_path + "/pyspark-scripts",
    glob="**/*",
    suffixes=[".py"],
    exclude=["**/non-utf8-encoding.py"],
    parser=LanguageParser(language=Language.PYTHON, parser_threshold=500),
)
documents = loader.load()
len(documents)

2
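`GenericLoader.from_filesystem` first selects files and only then parses them: paths matching `glob` under the root, limited to the listed `suffixes`, minus anything matching `exclude`. The sketch below approximates just that selection step with `pathlib` and `fnmatch`, which can help confirm which scripts end up in `documents`. `find_source_files` is a hypothetical helper, and LangChain's own matching rules may differ in edge cases.

```python
from fnmatch import fnmatchcase
from pathlib import Path


def find_source_files(root, glob="**/*", suffixes=(".py",), exclude=()):
    """Approximate a filesystem loader's file selection (hypothetical helper).

    Returns POSIX-style paths relative to `root`, in sorted order.
    """
    root = Path(root)
    selected = []
    for path in sorted(root.glob(glob)):
        # Skip directories and files whose extension is not allowed.
        if not path.is_file() or path.suffix not in suffixes:
            continue
        # Exclusion patterns are matched against the path as yielded by glob,
        # so with an absolute root "**/name.py" matches that file at any depth.
        if any(fnmatchcase(path.as_posix(), pattern) for pattern in exclude):
            continue
        selected.append(path.relative_to(root).as_posix())
    return selected


# Example against the cloned repo (not run here):
# find_source_files(repo_path + "/pyspark-scripts", exclude=["**/non-utf8-encoding.py"])
```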

In [28]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

python_splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
)
texts = python_splitter.split_documents(documents)
len(texts)

3
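`RecursiveCharacterTextSplitter` tries coarse separators first, falls back to finer ones only for pieces that are still too long, and then merges small pieces back up to `chunk_size`, repeating up to `chunk_overlap` characters between neighbouring chunks. The code below is a simplified pure-Python sketch of that idea, not LangChain's implementation: it uses generic separators (blank lines, lines, words, characters) instead of the Python-aware ones `from_language` configures, and it drops the separator at chunk boundaries. `merge_pieces` and `recursive_split` are hypothetical names.

```python
def merge_pieces(pieces, sep, chunk_size, chunk_overlap):
    """Greedily join pieces into chunks of at most chunk_size characters.

    When a chunk is emitted, trailing pieces totalling at most chunk_overlap
    characters are carried into the next chunk.
    """
    chunks, window = [], []
    for piece in pieces:
        if window and len(sep.join(window + [piece])) > chunk_size:
            chunks.append(sep.join(window))
            # Drop pieces from the front until the kept tail fits the overlap
            # budget and leaves room for the incoming piece.
            while window and (
                len(sep.join(window)) > chunk_overlap
                or len(sep.join(window + [piece])) > chunk_size
            ):
                window.pop(0)
        window.append(piece)
    if window:
        chunks.append(sep.join(window))
    return chunks


def recursive_split(text, separators=("\n\n", "\n", " ", ""), chunk_size=2000, chunk_overlap=200):
    """Split on the coarsest separator present; recurse with finer ones
    only for pieces that are still longer than chunk_size."""
    separators = list(separators)
    sep, rest = separators[-1], []
    for i, candidate in enumerate(separators):
        if candidate == "" or candidate in text:
            sep, rest = candidate, separators[i + 1:]
            break
    pieces = list(text) if sep == "" else [p for p in text.split(sep) if p]
    chunks, small = [], []
    for piece in pieces:
        if len(piece) <= chunk_size:
            small.append(piece)
            continue
        # Oversized piece: flush the pending small pieces, then split it finer.
        if small:
            chunks.extend(merge_pieces(small, sep, chunk_size, chunk_overlap))
            small = []
        if rest:
            chunks.extend(recursive_split(piece, rest, chunk_size, chunk_overlap))
        else:
            chunks.append(piece)  # nothing finer to try; keep it whole
    if small:
        chunks.extend(merge_pieces(small, sep, chunk_size, chunk_overlap))
    return chunks


print(recursive_split("aaaa bbbb cccc dddd", [" ", ""], chunk_size=9, chunk_overlap=4))
```

With a 9-character limit and a 4-character overlap, each chunk repeats the last word of the previous one.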

In [30]:
import dotenv

dotenv.load_dotenv()

True
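`load_dotenv()` returning `True` only says a `.env` file was found; it does not confirm that the keys the later cells need are set. A quick check might look like the sketch below. The key names are an assumption based on the classes used later: `OpenAIEmbeddings` reads `OPENAI_API_KEY` and `ChatGoogleGenerativeAI` reads `GOOGLE_API_KEY`. `missing_keys` is a hypothetical helper.

```python
import os

# Assumed keys: OpenAIEmbeddings -> OPENAI_API_KEY, ChatGoogleGenerativeAI -> GOOGLE_API_KEY.
REQUIRED_KEYS = ("OPENAI_API_KEY", "GOOGLE_API_KEY")


def missing_keys(required=REQUIRED_KEYS, environ=None):
    """Return the required variable names that are unset or empty."""
    environ = os.environ if environ is None else environ
    return [key for key in required if not environ.get(key)]


missing = missing_keys()
if missing:
    print("Missing environment variables:", ", ".join(missing))
```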

In [31]:
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

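# OpenAIEmbeddings reads OPENAI_API_KEY; disallowed_special=() lets tiktoken encode
# source text that happens to contain special-token strings
db = Chroma.from_documents(texts, OpenAIEmbeddings(disallowed_special=()))
retriever = db.as_retriever(
    search_type="mmr",  # maximal marginal relevance; also try "similarity"
    search_kwargs={"k": 8},  # MMR first fetches fetch_k=20 candidates (default), hence the n_results warnings below
)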
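With `search_type="mmr"`, the retriever first fetches `fetch_k` candidates (20 by default) by similarity, then greedily picks `k` of them, trading relevance to the query against redundancy with chunks already picked. Because the index holds only 3 chunks, Chroma caps the 20 requested results at 3, which is the warning printed on every query below. The pure-Python sketch that follows illustrates the selection rule with cosine similarity; it is an illustration of MMR, not the code Chroma or LangChain runs, and `lambda_mult=0.5` is assumed as the default weight.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mmr(query, candidates, k=8, fetch_k=20, lambda_mult=0.5):
    """Return indices of up to k candidates chosen by maximal marginal relevance."""
    # Step 1: keep the fetch_k candidates most similar to the query
    # (with only 3 vectors, asking for 20 simply yields all 3).
    ranked = sorted(range(len(candidates)),
                    key=lambda i: cosine(query, candidates[i]), reverse=True)
    pool = ranked[:fetch_k]
    selected = []
    # Step 2: repeatedly add the candidate with the best balance of
    # relevance to the query and dissimilarity to what is already selected.
    while pool and len(selected) < k:
        def score(i):
            relevance = cosine(query, candidates[i])
            redundancy = max((cosine(candidates[i], candidates[j]) for j in selected),
                             default=0.0)
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy
        best = max(pool, key=score)
        selected.append(best)
        pool.remove(best)
    return selected


vectors = [[1.0, 0.0], [1.0, 0.05], [0.5, 1.0]]
print(mmr([1.0, 0.2], vectors, k=2))
```

Plain similarity would return the two near-duplicate vectors (indices 1 and 0); MMR keeps the most relevant one and then prefers the more distinct vector at index 2.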

In [34]:
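from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

# Merge system messages into the first human message (for models without system-message support)
llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)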
# Prompt that turns the conversation into a standalone search query
rephrase_prompt = ChatPromptTemplate.from_messages(
    [
        ("placeholder", "{chat_history}"),
        ("user", "{input}"),
        (
            "user",
            "Given the above conversation, generate a search query to look up to get information relevant to the conversation",
        ),
    ]
)
retriever_chain = create_history_aware_retriever(llm, retriever, rephrase_prompt)

# Prompt that answers from the retrieved chunks, which are "stuffed" into {context}
answer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer the user's questions based on the below context:\n\n{context}",
        ),
        ("placeholder", "{chat_history}"),
        ("user", "{input}"),
    ]
)
document_chain = create_stuff_documents_chain(llm, answer_prompt)

qa = create_retrieval_chain(retriever_chain, document_chain)
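The chain has two stages. `create_history_aware_retriever` sends the input straight to the retriever when `chat_history` is empty, and otherwise asks the LLM to rewrite the conversation into a search query first; `create_stuff_documents_chain` then joins the retrieved chunks into `{context}` and asks the LLM for the answer, and `create_retrieval_chain` returns `input`, `context` and `answer`. The sketch below mimics that data flow with plain functions so it can run without any model; `history_aware_rag`, the stub LLM and the stub retriever are hypothetical and only illustrate the order of calls.

```python
from typing import Callable, List, Sequence, Tuple

Message = Tuple[str, str]  # (role, text), e.g. ("user", "explain the repo")

REPHRASE_INSTRUCTION = (
    "Given the above conversation, generate a search query to look up "
    "to get information relevant to the conversation"
)


def history_aware_rag(
    question: str,
    chat_history: Sequence[Message],
    llm: Callable[[List[Message]], str],
    retrieve: Callable[[str], List[str]],
) -> dict:
    """Rewrite the question only when there is history, retrieve, then answer."""
    if chat_history:
        rephrase = list(chat_history) + [("user", question), ("user", REPHRASE_INSTRUCTION)]
        search_query = llm(rephrase)
    else:
        search_query = question  # no history: the input goes to the retriever as-is
    docs = retrieve(search_query)
    # "Stuff" all retrieved chunks into a single system message.
    context = "\n\n".join(docs)
    answer_prompt = [("system", "Answer the user's questions based on the below context:\n\n" + context)]
    answer_prompt += list(chat_history) + [("user", question)]
    return {"input": question, "context": docs, "answer": llm(answer_prompt)}
```

Since the cells below call `qa.invoke` with `input` only, every question takes the no-history branch and is sent to the retriever unchanged.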


In [35]:
question = "explain the repo"
result = qa.invoke({"input": question})
result["answer"]

Number of requested results 20 is greater than number of elements in index 3, updating n_results = 3


"1. **How does the code load the .parquet file into Spark SQL?**\n   The code uses the `read.parquet` method of the SQLContext to load the .parquet file into a DataFrame. The DataFrame is then registered as a temporary table using the `registerTempTable` method.\n\n2. **What SQL query is used to select all rows from the temporary table?**\n   The SQL query `SELECT * FROM parquetFile` is used to select all rows from the temporary table.\n\n3. **What SQL query is used to filter the temporary table for rows where the `name` column contains the string `UNITED`?**\n   The SQL query `SELECT name FROM parquetFile WHERE name LIKE '%UNITED%'` is used to filter the temporary table for rows where the `name` column contains the string `UNITED`."
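The follow-up questions below ("what is the input to the read", "what is the output to the write") lean on earlier turns, but each `qa.invoke` call passes only `input`, so the chain never sees the conversation. One way to carry it forward is sketched here. `ask` is a hypothetical helper; it assumes `chain.invoke` returns a dict with an `"answer"` key (as `qa` does) and that the installed langchain-core accepts `(role, text)` tuples for the `{chat_history}` placeholder. Older versions may need `HumanMessage`/`AIMessage` objects instead.

```python
def ask(chain, question, history):
    """Invoke `chain` with the running history, then record this turn in it."""
    # Pass a copy so the chain sees the history as it was before this turn.
    result = chain.invoke({"input": question, "chat_history": list(history)})
    history.append(("human", question))
    history.append(("ai", result["answer"]))
    return result["answer"]


# Usage with the notebook's chain (not run here):
# history = []
# ask(qa, "explain the repo", history)
# ask(qa, "what is the input to the read", history)
```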

In [36]:
question = "what is the input to the read"
result = qa.invoke({"input": question})
result["answer"]

Number of requested results 20 is greater than number of elements in index 3, updating n_results = 3


'The input to the read method is a file path. In this case, the file path is "s3://jon-parquet-format/nation.plain.parquet" for the AWS EMR Spark service and "../test-data/nation.plain.parquet" for the local instance of Spark.'

In [37]:
question = "what is the output to the write"
result = qa.invoke({"input": question})
result["answer"]

Number of requested results 20 is greater than number of elements in index 3, updating n_results = 3


"```\nSuccessfully imported Spark Modules -- `SparkContext, SQLContext`\n-------------------------------------------------------------------------------\nAll Nations and Comments -- `SELECT * FROM parquetFile`\n-------------------------------------------------------------------------------\nCountry: REGION 1 Ipsum Comment: Our main region\nCountry: MIDDLE EAST Ipsum Comment: Our second main region\nCountry: ASIA Ipsum Comment: Our largest region\nCountry: EUROPE Ipsum Comment: Our EU region\nCountry: AFRICA Ipsum Comment: Our African region\nCountry: UNITED STATES Ipsum Comment: Our home country\nCountry: CANADA Ipsum Comment: Our subsidiary\nCountry: UNITED KINGDOM Ipsum Comment: Our UK office\n-------------------------------------------------------------------------------\nNations Filtered -- `SELECT name FROM parquetFile WHERE name LIKE '%UNITED%'`\n-------------------------------------------------------------------------------\nCountry: UNITED STATES\nCountry: UNITED KINGDOM\n```"

In [38]:
question = "Is there any data transformation?"
result = qa.invoke({"input": question})
result["answer"]

Number of requested results 20 is greater than number of elements in index 3, updating n_results = 3


'Yes, the data is transformed from a parquet file into a DataFrame using the `read.parquet` method. The DataFrame is then converted into a temporary table using the `registerTempTable` method. This allows SQL queries to be run against the data using the `sql` method.'