In [78]:
import os
from getpass import getpass

import openai
import pandas as pd
from langchain.agents import create_spark_dataframe_agent
from langchain.llms import OpenAI
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession

In [7]:
# Show every column when a pandas frame is rendered.
pd.set_option("display.max_columns", None)

# Ask for the key interactively so it is never stored in the notebook source,
# then expose it to both the `openai` module and the process environment.
openai.api_key = getpass("Insert your key:")
os.environ['OPENAI_API_KEY'] = openai.api_key

Insert your key: ········


In [79]:
# Reuse the active Spark session if one exists, otherwise start one.
spark = SparkSession.builder.getOrCreate()

csv_file_path = "data/hr_ibm.csv"

# First line of the file is the header; column types are inferred from the data.
df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(csv_file_path)
)

# Preview only three rows, converted to pandas for rich display.
df.limit(3).toPandas()

Unnamed: 0,Age,Attrition,BusinessTravel,DailyRate,Department,DistanceFromHome,Education,EducationField,EmployeeCount,EmployeeNumber,EnvironmentSatisfaction,Gender,HourlyRate,JobInvolvement,JobLevel,JobRole,JobSatisfaction,MaritalStatus,MonthlyIncome,MonthlyRate,NumCompaniesWorked,Over18,OverTime,PercentSalaryHike,PerformanceRating,RelationshipSatisfaction,StandardHours,StockOptionLevel,TotalWorkingYears,TrainingTimesLastYear,WorkLifeBalance,YearsAtCompany,YearsInCurrentRole,YearsSinceLastPromotion,YearsWithCurrManager
0,41,Yes,Travel_Rarely,1102,Sales,1,2,Life Sciences,1,1,2,Female,94,3,2,Sales Executive,4,Single,5993,19479,8,Y,Yes,11,3,1,80,0,8,0,1,6,4,0,5
1,49,No,Travel_Frequently,279,Research & Development,8,1,Life Sciences,1,2,3,Male,61,2,2,Research Scientist,2,Married,5130,24907,1,Y,No,23,4,4,80,1,10,3,3,10,7,1,7
2,37,Yes,Travel_Rarely,1373,Research & Development,2,2,Other,1,4,4,Male,92,2,1,Laboratory Technician,3,Single,2090,2396,6,Y,Yes,15,3,2,80,0,7,3,3,0,0,0,0


In [80]:
# Completion model used by the agents in this notebook.
# NOTE(review): confirm this model id is still served by the OpenAI API before re-running.
MODEL_ID = "text-davinci-003"

# temperature=0 is passed so the model is asked for its least random completion.
model = OpenAI(model=MODEL_ID, temperature=0)

# verbose=True prints the agent's Thought / Action / Observation trace.
agent = create_spark_dataframe_agent(llm=model, df=df, verbose=True)

agent.run("how many rows are there?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: I need to find out how many rows are in the dataframe
Action: python_repl_ast
Action Input: df.count()[0m
Observation: [36;1m[1;3m1470[0m
Thought:[32;1m[1;3m I now know the final answer
Final Answer: There are 1470 rows in the dataframe.[0m

[1m> Finished chain.[0m


'There are 1470 rows in the dataframe.'

## Multiple datasets

In [146]:
from typing import Any, Dict, List, Optional

from langchain.llms.base import BaseLLM
from langchain.callbacks.base import BaseCallbackManager
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.spark.base import _validate_spark_connect_df, _validate_spark_df
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain


# Prompt text placed before the tool descriptions. `{dataframe_names}` is
# filled in by `create_multiple_spark_dataframes_agent` via `str.format`.
PREFIX = (
    "\n"
    "You are working with a list of spark dataframes in Python. "
    "The name of the dataframes are `{dataframe_names}`.\n"
    "You should use the tools below to answer the question posed of you:"
)

# Prompt text placed after the tool descriptions. `{dataframe_names}` is
# replaced with one placeholder per dataframe; `{input}` and
# `{agent_scratchpad}` are the question and the agent's running notes.
SUFFIX = (
    "\n"
    "This is the result of `print(df.first())`:\n"
    "{dataframe_names}\n"
    "\n"
    "Begin!\n"
    "Question: {input}\n"
    "{agent_scratchpad}"
)


def create_multiple_spark_dataframes_agent(
    llm: BaseLLM,
    dataframes_dict: Dict[str, DataFrame],
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Build a zero-shot agent that can query several named Spark dataframes.

    Every entry of ``dataframes_dict`` is made available to the agent's Python
    REPL tool under its key, and the string form of each dataframe's first row
    is embedded in the prompt next to that key.

    Args:
        llm: Language model that drives the agent.
        dataframes_dict: Mapping of variable name to Spark dataframe. Keys are
            used as REPL variable names and as prompt placeholders, so they
            should be plain identifiers without braces.
        callback_manager: Passed to the LLM chain, the agent and the executor.
        prefix: Template for the text before the tool descriptions. Rendered
            with ``str.format``: ``{dataframe_names}`` becomes the back-tick
            separated dataframe names. Literal braces must be doubled.
        suffix: Template for the text after the tool descriptions. Rendered
            with ``str.format``: ``{dataframe_names}`` becomes one
            ``name: {name}`` line per dataframe, while ``{input}`` and
            ``{agent_scratchpad}`` are kept as prompt placeholders.
        input_variables: Prompt input variables. Defaults to the dataframe
            names followed by ``"input"`` and ``"agent_scratchpad"``.
        verbose: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        return_intermediate_steps: Forwarded to the executor.
        max_iterations: Forwarded to the executor.
        max_execution_time: Forwarded to the executor.
        early_stopping_method: Forwarded to the executor.
        agent_executor_kwargs: Extra keyword arguments for the executor.
        **kwargs: Extra keyword arguments for ``ZeroShotAgent``.

    Returns:
        The ``AgentExecutor`` wrapping the agent and its Python REPL tool.

    Raises:
        ValueError: If ``dataframes_dict`` is empty, or if any value is
            rejected by both langchain Spark dataframe validators.
    """
    if not dataframes_dict:
        raise ValueError("`dataframes_dict` must contain at least one dataframe.")

    # Check each entry on its own so the error can name the offending keys.
    rejected = [
        name
        for name, frame in dataframes_dict.items()
        if not (_validate_spark_df(frame) or _validate_spark_connect_df(frame))
    ]
    if rejected:
        raise ValueError(
            "Every value in `dataframes_dict` must be a Spark dataframe "
            "(if pyspark is missing, run `pip install pyspark`). "
            f"Invalid entries: {rejected}"
        )

    names = list(dataframes_dict)

    # Format the *arguments*, not the module-level defaults, so that custom
    # templates supplied by the caller are honoured.
    prefix = prefix.format(dataframe_names="`, `".join(names))
    suffix = suffix.format(
        # One labelled placeholder per dataframe, so the model can tell which
        # first row belongs to which dataframe.
        dataframe_names="\n".join(f"{name}: {{{name}}}" for name in names),
        # Re-emit these two as placeholders for the prompt template.
        input="{input}",
        agent_scratchpad="{agent_scratchpad}",
    )

    if input_variables is None:
        input_variables = names + ["input", "agent_scratchpad"]

    # Give the REPL its own namespace: code run by the agent must not be able
    # to add or rebind entries in the caller's dictionary.
    tools = [PythonAstREPLTool(locals=dict(dataframes_dict))]

    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix, input_variables=input_variables
    )

    # `.first()` is evaluated once per dataframe here, at construction time.
    first_rows = {name: str(frame.first()) for name, frame in dataframes_dict.items()}
    partial_prompt = prompt.partial(**first_rows)

    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=[tool.name for tool in tools],
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )

In [147]:
# Two small dataframes that share the `cnpj` column, used to exercise the
# multi-dataframe agent.
df_empresa_1 = spark.createDataFrame([(1, 'a'), (2, 'b')], ['cnpj', 'classe'])
df_empresa_2 = spark.createDataFrame([(1, 'x'), (2, 'y'), (3, 'z')], ['cnpj', 'value'])

# The keys are the names under which the agent sees each dataframe.
my_dfs_dict = {'df_empresa_1': df_empresa_1, 'df_empresa_2': df_empresa_2}

# Use the notebook-wide MODEL_ID instead of repeating the model name literal.
# Note: `model` and `agent` rebind the names used by the single-dataframe cell.
model = OpenAI(temperature=0, model=MODEL_ID)

agent = create_multiple_spark_dataframes_agent(
    llm=model,
    dataframes_dict=my_dfs_dict,
    verbose=True,
)

In [148]:
# Check that the agent can address the first dataframe by its dictionary key.
agent.run("how many rows are in df_empresa_1?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: I need to find out how many rows are in the dataframe
Action: python_repl_ast
Action Input: df_empresa_1.count()[0m
Observation: [36;1m[1;3m2[0m
Thought:[32;1m[1;3m I now know the final answer
Final Answer: There are 2 rows in df_empresa_1.[0m

[1m> Finished chain.[0m


'There are 2 rows in df_empresa_1.'

In [149]:
# Same question for the second dataframe, to confirm both names are usable.
agent.run("how many rows are in df_empresa_2?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: I need to find out how many rows are in the dataframe
Action: python_repl_ast
Action Input: len(df_empresa_2.collect())[0m
Observation: [36;1m[1;3m3[0m
Thought:[32;1m[1;3m I now know the final answer
Final Answer: There are 3 rows in df_empresa_2.[0m

[1m> Finished chain.[0m


'There are 3 rows in df_empresa_2.'

In [152]:
# Portuguese prompt: "join the two tables and return the dataframe".
# In the recorded run the final answer is the joined DataFrame's repr
# (column names and types), not its rows.
agent.run("faça o join das duas tabelas e retorne o dataframe")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: preciso juntar os dois dataframes
Action: python_repl_ast
Action Input: df_empresa_1.join(df_empresa_2, on='cnpj')[0m
Observation: [36;1m[1;3mDataFrame[cnpj: bigint, classe: string, value: string][0m
Thought:[32;1m[1;3m Agora eu tenho o dataframe final
Final Answer: O dataframe final é DataFrame[cnpj: bigint, classe: string, value: string][0m

[1m> Finished chain.[0m


'O dataframe final é DataFrame[cnpj: bigint, classe: string, value: string]'

In [33]:
# Exploration: print the docstring attached to the dataframe object.
print(df_empresa_1.__doc__)

A distributed collection of data grouped into named columns.

    .. versionadded:: 1.3.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Examples
    --------
    A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
    and can be created using various functions in :class:`SparkSession`:

    >>> people = spark.createDataFrame([
    ...     {"deptId": 1, "age": 40, "name": "Hyukjin Kwon", "gender": "M", "salary": 50},
    ...     {"deptId": 1, "age": 50, "name": "Takuya Ueshin", "gender": "M", "salary": 100},
    ...     {"deptId": 2, "age": 60, "name": "Xinrong Meng", "gender": "F", "salary": 150},
    ...     {"deptId": 3, "age": 20, "name": "Haejoon Lee", "gender": "M", "salary": 200}
    ... ])

    Once created, it can be manipulated using the various domain-specific-language
    (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.

    To select a column from the :class:`DataFrame`, use the apply method:

    >>> age_col = people

In [27]:
# Exploration: list the attribute and method names available on the dataframe.
dir(df_empresa_1)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_collect_as_arrow',
 '_jcols',
 '_jdf',
 '_jmap',
 '_joinAsOf',
 '_jseq',
 '_lazy_rdd',
 '_repr_html_',
 '_sc',
 '_schema',
 '_session',
 '_sort_cols',
 '_sql_ctx',
 '_support_repr_html',
 '_to_corrected_pandas_type',
 'agg',
 'alias',
 'approxQuantile',
 'cache',
 'checkpoint',
 'coalesce',
 'colRegex',
 'collect',
 'columns',
 'corr',
 'count',
 'cov',
 'createGlobalTempView',
 'createOrReplaceGlobalTempView',
 'createOrReplaceTempView',
 'createTempView',
 'crossJoin',
 'crosstab',
 'cube',
 'describe',
 'distinct',
 'drop',
 'dropDuplicates',
 'drop_duplicates',
 'drop