In [1]:
import openai
import pandas as pd
import json
from hana_ml.dataframe import ConnectionContext
from dotenv import load_dotenv
import os

In [2]:
# Load settings from a .env file (if one is present) into the process environment.
load_dotenv()

# HANA connection settings: host and credentials come from the environment,
# while the port and the table name are fixed in the notebook.
dbHost = os.environ.get('DB_HOST')
dbUser = os.environ.get('DB_USER')
dbPwd = os.environ.get('DB_PASS')
dbPort = 443
dbTable = "COUNTRIES_BY_GDP"

In [3]:
# Configure the OpenAI client: the API key and organization are read from the environment.
openai.api_key = os.environ.get('OPENAI_API_KEY')
openai.organization = os.environ.get('OPENAI_ORGANIZATION')

In [4]:
# Open the HANA Cloud connection with encryption enabled.
# NOTE(review): sslValidateCertificate=False turns off server certificate
# validation -- confirm this is intended for anything beyond a demo.
cc = ConnectionContext(
    address=dbHost,
    port=dbPort,
    user=dbUser,
    password=dbPwd,
    encrypt=True,
    sslValidateCertificate=False,
)

In [5]:
# Fetch the whole table through the HANA connection into a local frame.
table = cc.table(dbTable)
df = table.collect()

# GDP is turned into strings before the data is embedded in the prompt
# (presumably so json.dumps can serialise the values -- TODO confirm).
df['GDP'] = df['GDP'].astype(str)

# Map each column name to the plain Python list of that column's values.
data_gdp = {column: df[column].values.tolist() for column in df}

In [15]:
def _clean_generated_sql(raw_sql):
    """Trim model-generated SQL and make sure it is a read-only query.

    Surrounding whitespace and trailing semicolons are removed. Anything that
    does not start with SELECT or WITH is rejected with a ValueError so that
    model output is never executed as DDL/DML.
    """
    statement = raw_sql.strip().rstrip("; \t\r\n")
    if not statement:
        raise ValueError("The model returned an empty SQL statement")
    if not statement.upper().startswith(("SELECT", "WITH")):
        raise ValueError("Refusing to run non-SELECT statement: {}".format(statement))
    return statement


def apply_model(query, temperature=0.0, max_tokens=3500):
    """Ask the OpenAI model for a SQL query answering `query` and run it on HANA.

    The prompt embeds the full table contents (`data_gdp`, JSON-encoded) and the
    table name (`dbTable`), followed by the natural-language question.

    Args:
        query: The question to answer, in natural language.
        temperature: Sampling temperature passed to the completion call.
            Defaults to 0.0 so the same question yields the same SQL across runs.
        max_tokens: Completion token budget passed to the completion call.
            NOTE(review): the prompt and the completion share the model's
            context window, so a large table leaves little room -- confirm
            this budget against the model's limit.

    Returns:
        The collected result of executing the generated SQL through `cc`.

    Raises:
        ValueError: If the model returns nothing, or something other than a
            SELECT/WITH statement.
    """
    prompt = """Please regard the following table: {}

    The table name is {}. Use ' as the quote character. Quote column aliases with ". Write a SQL query to answer the following question: {}""".format(json.dumps(data_gdp), dbTable, query)

    # NOTE(review): text-davinci-003 and the Completion endpoint belong to the
    # legacy OpenAI API -- confirm they are still available for this account.
    request = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens
    )

    # The model output is untrusted text: validate it before handing it to the
    # database. This is a coarse filter only; the DB user should also be
    # restricted to read access.
    sql_query = _clean_generated_sql(request.choices[0].text)
    print("===> {}: {}\n".format(query, sql_query))

    # Distinct local name so the notebook-level pandas `df` is not shadowed.
    result_frame = cc.sql(sql_query)
    return result_frame.collect()


In [17]:
# Alternative sample questions (swap one into the call below to try it):
#   "What is the GDP of the United States in 2022?"
#   "What is the total GDP of Europe in 2022?"
#   "What is the average GDP of all countries in 2022?"
#   "What is the GDP of Italy in 2021?"
#   "How many distinct countries are there in Asia?"
#   "Total European GDP in 2021?"
#   "What countries have a GDP higher than 10 million in 2022? (do not use thousands separators)"
#   "How much has the total GDP in 2022 increased compared to the total GDP in 2021?"
print(
    apply_model("What is the percentual increase of the US 2022 GDP vs 2021?")
)


===> What is the percentual increase of the US 2022 GDP vs 2021?: 

SELECT round(((25035164/22996100) - 1) * 100,2) as "% Increase US GDP"
FROM COUNTRIES_BY_GDP
WHERE "COUNTRY" = 'United States' AND "YEAR" = 2022;

  % Increase US GDP
0              8.87
