In [1]:
import os,openai,json
from dotenv import find_dotenv,load_dotenv
load_dotenv(find_dotenv())
openai.api_key = os.environ['OPENAI_API_KEY']

In [2]:
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import HumanMessage

In [3]:
DSL_TEMPLATE = """\
Given an input question, create a syntactically correct Elasticsearch query to run. \
Unless the user specifies in their question a specific number of examples they wish to obtain, \
always limit your query to at most {top_k} results. \
You can order the results by a relevant column to return the most interesting examples in the database.

Unless told to do not query for all the columns from a specific index, \
only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the mapping description. \
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which index. \
Return the query as valid json.

Use the following format:

Question: Question here
ESQuery: Elasticsearch Query formatted as json

Mapping: {mapping}

Question: {input}
ESQuery:
"""

In [4]:
top_k = 3

mapping = {
  "mappings": {
    "properties": {
      "name":    { "type" : "text" },
      "age":     { "type" : "integer" },
      "occupation":{ "type" : "text" }
    }
  }
}


In [5]:
query = '19歳以下の会社員を教えてください。'
prompt = DSL_TEMPLATE.format(
    top_k=top_k,
    mapping=json.dumps(mapping,indent=2),
    input=query
)

In [6]:
def get_os_dsl(query):
    prompt = DSL_TEMPLATE.format(
        top_k=top_k,
        mapping=json.dumps(mapping,indent=2),
        input=query
    )
    chat = ChatOpenAI()
    messages = [HumanMessage(content=prompt)]
    response = chat(messages)
    # print(response.content)
    return response.content

In [7]:
query = '19歳以下の会社員を教えてください。'
response = get_os_dsl(query)
print(response)

Question: 19歳以下の会社員を教えてください。
ESQuery: 
{
  "query": {
    "bool": {
      "must": [
        {
          "range": {
            "age": {
              "lte": 19
            }
          }
        },
        {
          "match": {
            "occupation": "会社員"
          }
        }
      ]
    }
  },
  "size": 3
}


In [8]:
query = '加藤さんの年齢の平均を教えてください。'
response = get_os_dsl(query)
print(response)

{
  "query": {
    "match": {
      "name": "加藤さん"
    }
  },
  "aggs": {
    "average_age": {
      "avg": {
        "field": "age"
      }
    }
  },
  "size": 3
}


In [9]:
import tiktoken
enc = tiktoken.encoding_for_model(model_name='gpt-3.5-turbo')
print(len(enc.encode(prompt)))

239
