In [1]:
import os
import inspect
import re
import pandas as pd
from pandas.api.extensions import register_dataframe_accessor
from pandas.api.types import is_string_dtype, is_numeric_dtype, is_datetime64_any_dtype
from sqlalchemy import desc
from pandasql import sqldf, PandaSQLException
import openai


@pd.api.extensions.register_dataframe_accessor("querymachine")
class QueryMachineAccessor:
    """
    Pandas DataFrame accessor that adds a '.querymachine.description' field to dataframes
    and builds the column summary QueryMachine includes in its prompt.
    """
    # Literal placeholder substituted with the generated column summary.
    _SUMMARY_PLACEHOLDER = '{COLUMNS_SUMMARY}'

    def __init__(self, pandas_obj: pd.DataFrame) -> None:
        self._validate(pandas_obj)
        self._obj = pandas_obj
        self._description = self._SUMMARY_PLACEHOLDER

    @staticmethod
    def _validate(obj: pd.DataFrame) -> None:
        """Validation hook for the wrapped object; currently accepts anything."""
        pass

    @property
    def description(self) -> str:
        """
        Description of this table that is sent to the language engine.

        Set it manually to add extra information about the table. Include the literal
        text '{COLUMNS_SUMMARY}' wherever the generated column summary should appear.
        By default the description is only that summary; assign None to reset it.
        """
        if self._SUMMARY_PLACEHOLDER not in self._description:
            return self._description
        # str.replace instead of str.format: user-written descriptions may contain other
        # braces (JSON, SQL snippets) that format() would reject with KeyError/IndexError.
        return self._description.replace(self._SUMMARY_PLACEHOLDER, self.columns_summary())

    @description.setter
    def description(self, value: "str | None") -> None:
        # None restores the default description (the column summary only).
        if value is None:
            self._description = self._SUMMARY_PLACEHOLDER
        else:
            self._description = value

    def columns_summary(self) -> str:
        """
        Return a pipe-delimited table with one row per column of the dataframe:
        column name, dtype and the dtype-dependent info produced by column_info().
        """
        summary_lines = ['|column name|data type|info|']
        for col_name in self._obj:
            col = self._obj[col_name]
            summary_lines.append(f'|{col_name}|{col.dtype}|{column_info(col)}|')
        return '\n'.join(summary_lines)
        
    
class QueryMachine():
    """
    Answer natural-language questions about pandas DataFrames by asking an OpenAI
    completion model for an SQLite query and running it with pandasql.
    """
    # NOTE(review): this model name and `openai.Completion` belong to the legacy OpenAI
    # completions API and may no longer be served -- confirm, or override `model`.
    model = "text-davinci-003"

    prompt_format = """
Make sure to join in tables if information from multiple tables is needed for a task.
Task: percentage of True values of column X in table Y
```
SQL query for SQLite:
SELECT (SUM(CASE WHEN X = 'True' THEN 1.0 END) / COUNT(*)) * 100 AS percentage
FROM Y
```
Task: count of rows in table T where date is equal to 11th of August 1993
```
SQL query for SQLite:
SELECT COUNT(*) AS row_count
FROM T
WHERE date(date) = date('1993-08-11')
```
Task: {QUERY}
SQL query for SQLite:
```
"""

    def __init__(self, openai_api_key=None) -> None:
        """
        Configure the OpenAI API key used by all requests.

        Args:
            openai_api_key: explicit key. When falsy, the OPENAI_API_KEY environment
                variable is used, then any key already assigned to openai.api_key.

        Raises:
            Exception: if no key is available from any of those sources.
        """
        # Never ship a real key as a default argument: it leaks through source control
        # and every shared copy of the notebook.
        openai.api_key = (
            openai_api_key
            or os.getenv("OPENAI_API_KEY")
            or getattr(openai, "api_key", None)
        )
        if not openai.api_key:
            raise Exception(
                "OpenAI API key is not set. Either provide it to QueryMachine(openai_api_key='...'), "
                "run openai.api_key = '...', or set it as an env variable OPENAI_API_KEY."
            )
        self.last_prompt = None
        self.last_gpt_response = None

    @staticmethod
    def dataframes_summary(env=None, ignore='^_') -> "str | None":
        """
        Summary of all DataFrames available in the namespace, ignoring those matching the 'ignore' regex.

        Args:
            env: dict of name -> object to scan. When None, the variables of the calling
                frame (the first frame outside this file) are used.
            ignore: regex matched against each name; matching names are skipped. The
                default '^_' skips private names, e.g. IPython's output cache (`_`, `_2`).

        Returns:
            The prompt header describing every DataFrame, or None if none qualify.
        """
        if env is None:
            env = get_outer_frame_variables()
        summary_lines = ['Tables available in the database, with their additional information, are:']
        table_count = 0
        for name, value in env.items():
            if isinstance(value, pd.DataFrame) and (not ignore or not re.match(ignore, name)):
                summary_lines += [
                    f"\n\nTable name: {name}",
                    value.querymachine.description
                ]
                table_count += 1
        if not table_count:
            return None
        return '\n'.join(summary_lines)

    def query(self, query, env=None, show_query=False):
        """
        Query all Pandas DataFrames available in the namespace with a natural language query.
        To limit the tables used in the query, set the 'env' variable to a dict of tables
        (keys are table names, and values are table objects), or set it to globals() or locals().
        To learn more, check pandasql docs.

        Args:
            query: the natural-language task.
            env: tables to expose; when empty or None the caller's variables are used.
            show_query: print the generated SQL before executing it.

        Returns:
            The pandasql result, or None if the query is empty, no DataFrames are found,
            or the generated SQL fails to execute.
        """
        if not query or not query.strip():
            print('Query is empty')
            return None
        env = env or get_outer_frame_variables()
        # Lower-case the first letter so the task reads naturally after "Task: ".
        query = query[0].lower() + query[1:]
        prompt = self.dataframes_summary(env)
        if not prompt:
            print('No dataframes found')
            return None
        prompt += self.prompt_format.format(QUERY=query)
        response = openai.Completion.create(
            model=self.model,
            prompt=prompt,
            temperature=0,
            max_tokens=150,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            # The prompt opens a ``` fence; stop when the model closes it.
            stop=["\n```\n"]
        )
        raw_text = response['choices'][0]['text']
        self.last_gpt_response = raw_text
        sql_query = raw_text.replace('```', '')
        self.last_prompt = (prompt, sql_query)
        if show_query:
            print(sql_query)
        try:
            result = sqldf(sql_query, env)
        except PandaSQLException:
            result = None
            print('Unsuccessful. Try rephrasing your query, or add additional table descriptions in df.querymachine.description.')
            print('You can inspect the generated prompt and GPT response with show_last_prompt().')
        return result

    def generate(self, description, columns, n_rows=10, max_attempts=20):
        """
        Generates a random dataset based on the description and a list of columns.

        Args:
            description: free-text topic of the rows to generate.
            columns: list of column names.
            n_rows: number of rows wanted.
            max_attempts: upper bound on completion requests, so a model that never
                returns well-formed rows cannot loop forever.

        Returns:
            DataFrame with at most n_rows rows (fewer if max_attempts runs out).
            All values are strings as parsed from the model's table output.
        """
        rows = []
        attempts = 0
        while len(rows) < n_rows and attempts < max_attempts:
            attempts += 1
            prompt = f'Fill the table below with {min(n_rows - len(rows) + 5, 30)} random rows about {description}\n\n'
            prompt += f"|{'|'.join(columns)}|\n"
            prompt += f"|{'|'.join(['-'*len(col) for col in columns])}|\n|"
            response = openai.Completion.create(
                model=self.model,
                prompt=prompt,
                temperature=0.3,
                max_tokens=150,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
            )
            # The prompt ends with the opening '|' of the first row; restore it.
            table_text = '|' + response['choices'][0]['text']
            new_rows = [
                [cell.strip() for cell in line.strip()[1:-1].split('|')]
                for line in table_text.split('\n')
                # Skip blank lines and '|---|---|' separator lines.
                if not re.match('^[- |]*$', line)
            ]
            rows += [row for row in new_rows if len(row) == len(columns)]

        return pd.DataFrame(rows, columns=columns).head(n_rows)

    def show_last_prompt(self):
        """Print the prompt and the SQL produced by the last query() call."""
        if self.last_prompt:
            print(self.last_prompt[0])
            print(f'[->]\n{self.last_prompt[1]}')

    # Backward-compatible alias for the original private name.
    _last_prompt = show_last_prompt

    def show_last_query(self):
        """Print the SQL query generated in the last query() call."""
        if self.last_prompt:
            print(self.last_prompt[1])

# Code copied from pandasql
def get_outer_frame_variables():
    """
    Get a dict of local and global variables of the first outer frame from another file.

    Adapted from pandasql. Walks up the call stack from this function until it reaches
    a frame whose code lives in a different file than this one; that frame's globals,
    overlaid with its locals, are returned. If every frame belongs to this file (e.g.
    everything runs in one script), the outermost frame is used instead of failing.
    """
    frame = inspect.currentframe()
    outer_frame = frame
    try:
        cur_filename = frame.f_code.co_filename
        while outer_frame.f_back is not None and outer_frame.f_code.co_filename == cur_filename:
            outer_frame = outer_frame.f_back
        variables = dict(outer_frame.f_globals)
        variables.update(outer_frame.f_locals)
    finally:
        # Drop frame references promptly to avoid reference cycles.
        del frame, outer_frame
    return variables

def column_info(col, max_unique_values=30):
    """
    Info about a specific column, different depending on its type.

    Args:
        col: pandas Series to describe.
        max_unique_values: how many distinct values to list for string/categorical
            columns before truncating with '...'.

    Returns:
        str: 'unique values: ...' for string/categorical columns, 'values: 0, 1' for
        booleans, 'min=..., max=...' for numeric columns, 'first=..., last=...' for
        datetimes, and '' for anything else.
    """
    if is_string_dtype(col) or isinstance(col.dtype, pd.CategoricalDtype):
        unique = col.unique().tolist()
        summary = 'unique values: ' + ', '.join(map(str, unique[:max_unique_values]))
        if len(unique) > max_unique_values:
            summary += '...'
    elif pd.api.types.is_bool_dtype(col):
        # Checked before numeric (bool counts as numeric); SQLite stores booleans as 0/1.
        summary = "values: 0, 1"
    elif is_numeric_dtype(col):
        summary = f"min={col.min()}, max={col.max()}"
    elif is_datetime64_any_dtype(col):
        summary = f"first={col.min()}, last={col.max()}"
    else:
        summary = ''
    return summary

## Titanic Dataset

In [2]:
# pandas is already imported in the first cell; only seaborn is new here.
import seaborn as sns

# Main dataset to show qm capabilities
titanic = sns.load_dataset('titanic')
titanic.head()



Unnamed: 0,survived,pclass,sex,age,sibsp,parch,fare,embarked,class,who,adult_male,deck,embark_town,alive,alone
0,0,3,male,22.0,1,0,7.25,S,Third,man,True,,Southampton,no,False
1,1,1,female,38.0,1,0,71.2833,C,First,woman,False,C,Cherbourg,yes,False
2,1,3,female,26.0,0,0,7.925,S,Third,woman,False,,Southampton,yes,True
3,1,1,female,35.0,1,0,53.1,S,First,woman,False,C,Southampton,yes,False
4,0,3,male,35.0,0,0,8.05,S,Third,man,True,,Southampton,no,True


In [3]:
# Take the API key from the OPENAI_API_KEY environment variable; never hardcode keys in a notebook.
qm = QueryMachine(openai_api_key=os.getenv("OPENAI_API_KEY"))

In [4]:
# "Tell me the average age of the male survivors in titanic."
qm.query(query="titanic에서 살아남은 남성의 나이의 평균을 알려줘", show_query=True)

SELECT AVG(age) AS avg_age
FROM titanic
WHERE survived = 1 AND sex = 'male'



Unnamed: 0,avg_age
0,27.276022


In [5]:
# "Tell me the Titanic survival rate by sex."
qm.query(query="성별에 따른 타이타닉 생존 비율을 알려줘", show_query=True)

SELECT sex, (SUM(CASE WHEN survived = 1 THEN 1.0 END) / COUNT(*)) * 100 AS survival_rate
FROM titanic
GROUP BY sex



Unnamed: 0,sex,survival_rate
0,female,74.203822
1,male,18.890815


## Iris Dataset

In [6]:
# Load seaborn's 'iris' sample dataset and preview its first rows.
iris = sns.load_dataset(name='iris')
iris.head(n=5)

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa


In [7]:
# Take the API key from the OPENAI_API_KEY environment variable; never hardcode keys in a notebook.
qm2 = QueryMachine(openai_api_key=os.getenv("OPENAI_API_KEY"))

In [8]:
# "Tell me the species with the highest mean sepal_length and sepal_width."
qm2.query(query="sepal_length와 sepal_width의 평균이 가장 높은 종을 알려줘", show_query=True)

SELECT species, AVG(sepal_length) AS avg_sepal_length, AVG(sepal_width) AS avg_sepal_width
FROM iris
GROUP BY species
ORDER BY avg_sepal_length DESC, avg_sepal_width DESC
LIMIT 1



Unnamed: 0,species,avg_sepal_length,avg_sepal_width
0,virginica,6.588,2.974


In [9]:
# "Tell me the top 5 rows with the largest sepal_length."
qm2.query(query="sepal_length가 큰 상위 5개의 row를 알려줘", show_query=True)

SELECT *
FROM iris
ORDER BY sepal_length DESC
LIMIT 5



Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,7.9,3.8,6.4,2.0,virginica
1,7.7,3.8,6.7,2.2,virginica
2,7.7,2.6,6.9,2.3,virginica
3,7.7,2.8,6.7,2.0,virginica
4,7.7,3.0,6.1,2.3,virginica
