In [22]:
%%writefile requirements.txt
sqlalchemy==1.4.47
snowflake-sqlalchemy
langchain==0.0.166
sqlalchemy-aurora-data-api
PyAthena[SQLAlchemy]==2.25.2
anthropic
openai
redshift-connector==2.0.910
sqlalchemy-redshift==0.8.14

Overwriting requirements.txt


In [23]:
!pip install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com


In [24]:
import json
import os
from getpass import getpass
from typing import Dict

import boto3

import sqlalchemy
from sqlalchemy import create_engine
from snowflake.sqlalchemy import URL

from langchain.docstore.document import Document
from langchain import PromptTemplate,SagemakerEndpoint,SQLDatabase, SQLDatabaseChain, LLMChain
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import SQLDatabaseSequentialChain

from langchain.chains.api.prompt import API_RESPONSE_PROMPT
from langchain.chains import APIChain
from langchain.prompts.prompt import PromptTemplate
from langchain.chat_models import ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langchain.chains.api import open_meteo_docs

In [25]:
CFN_STACK_NAME = "stacksnow"

# list_stacks is paginated and also reports deleted stacks (StackStatus DELETE_COMPLETE).
# Walk every page and ignore deleted stacks; otherwise a deleted stack would count as
# "found" and the describe_stacks calls in the next cell would fail.
_cfn_list_paginator = boto3.client('cloudformation').get_paginator('list_stacks')
stack_found = any(
    summary['StackName'] == CFN_STACK_NAME and summary['StackStatus'] != 'DELETE_COMPLETE'
    for page in _cfn_list_paginator.paginate()
    for summary in page['StackSummaries']
)

In [26]:
from typing import List
def get_cfn_outputs(stackname: str, cfn=None) -> Dict[str, str]:
    """Return the Outputs of a CloudFormation stack as an ``{OutputKey: OutputValue}`` dict.

    Args:
        stackname: Name (or ID) of the stack to describe.
        cfn: Optional CloudFormation client to reuse. A new ``boto3`` client is
            created when omitted.

    Returns:
        Mapping of output key to output value. Empty if the stack reports no outputs.
    """
    if cfn is None:
        cfn = boto3.client('cloudformation')
    stack = cfn.describe_stacks(StackName=stackname)['Stacks'][0]
    # 'Outputs' may be absent from the response when the stack declares none.
    return {output['OutputKey']: output['OutputValue'] for output in stack.get('Outputs', [])}

def get_cfn_parameters(stackname: str, cfn=None) -> Dict[str, str]:
    """Return the input Parameters of a CloudFormation stack as an ``{ParameterKey: ParameterValue}`` dict.

    Args:
        stackname: Name (or ID) of the stack to describe.
        cfn: Optional CloudFormation client to reuse. A new ``boto3`` client is
            created when omitted.

    Returns:
        Mapping of parameter key to parameter value. Empty if the stack reports no parameters.
    """
    if cfn is None:
        cfn = boto3.client('cloudformation')
    stack = cfn.describe_stacks(StackName=stackname)['Stacks'][0]
    # 'Parameters' may be absent from the response when the stack takes none.
    return {param['ParameterKey']: param['ParameterValue'] for param in stack.get('Parameters', [])}

# Pull the Glue/S3 resource names this demo needs from the CloudFormation stack.
if stack_found:
    outputs = get_cfn_outputs(CFN_STACK_NAME)
    params = get_cfn_parameters(CFN_STACK_NAME)
    glue_crawler_name = params['CFNCrawlerName']
    glue_database_name = params['CFNDatabaseName']
    glue_databucket_name = params['DataBucketName']
    region = outputs['Region']
    print(f"cfn outputs={outputs}\nparams={params}")
else:
    print(f"Stack {CFN_STACK_NAME!r} not found. Recheck your cloudformation stack name")

cfn outputs={'LLMEndpointName': 'aws-genai-mda-blog-flan-t5-xxl-endpoint-bc582dc0', 'SageMakerNotebookURL': 'https://console.aws.amazon.com/sagemaker/home?region=us-east-1#/notebook-instances/openNotebook/aws-genai-mda-blog?view=classic', 'GlueCrawlerName': 'cfn-crawler-json', 'Region': 'us-east-1'}
params={'SageMakerIAMRole': 'awsGenAIMDAblogIAMRole', 'DataBucketName': 'genai-sample-bharath', 'SageMakerNotebookName': 'aws-genai-mda-blog', 'CFNCrawlerName': 'cfn-crawler-json', 'CFNTablePrefixName': 'cfn_', 'CFNDatabaseName': 'cfn_covid_lake'}


In [41]:
# LLM
# Secrets are read from environment variables (or prompted for interactively) instead of
# being hardcoded. The API key and Snowflake password previously embedded in this cell
# were committed in plain text and should be rotated.
# The key is passed to ChatOpenAI, so it must be an OpenAI API key.
openai_api_key = os.environ.get("OPENAI_API_KEY") or getpass("OpenAI API key: ")
anthropic_secret = openai_api_key  # backward-compatible alias for the old (misnamed) variable
llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)

# Create connection to Snowflake.
# Non-secret settings keep their previous values as defaults and can be overridden via env vars.
account_identifier = os.environ.get("SNOWFLAKE_ACCOUNT", "AWSPARTNER")
user = os.environ.get("SNOWFLAKE_USER", "BHARATHS")
password = os.environ.get("SNOWFLAKE_PASSWORD") or getpass("Snowflake password: ")
database_name = os.environ.get("SNOWFLAKE_DATABASE", "MOVIELENS")
schema_name = os.environ.get("SNOWFLAKE_SCHEMA", "PUBLIC")
table = '%'
warehouse_name = os.environ.get("SNOWFLAKE_WAREHOUSE", "SUREBART")
# NOTE(review): ACCOUNTADMIN is far more privilege than read-only querying needs;
# consider a least-privilege role.
role_name = os.environ.get("SNOWFLAKE_ROLE", "ACCOUNTADMIN")

# Build the URL with snowflake.sqlalchemy.URL rather than an f-string, so credentials
# containing URL-special characters ('@', ':', '/') are encoded instead of corrupting it.
conn_string = URL(
    account=account_identifier,
    user=user,
    password=password,
    database=database_name,
    schema=schema_name,
    warehouse=warehouse_name,
    role=role_name,
)
engine_snowflake = create_engine(conn_string)
dbsnowflake = SQLDatabase(engine_snowflake)
# Glue Data Catalog database(s) whose tables parse_catalog harvests.
gdc = ['snowdb']

In [42]:
# Generate the dynamic prompt context from the Glue Data Catalog:
# harvest crawler metadata into one "classification|database|table|column" line per column.

def parse_catalog(databases=None, glue_client=None):
    """Flatten Glue Data Catalog metadata into the text catalog embedded in the LLM prompt.

    Each column of every table in ``databases`` becomes one line
    ``classification|database|table|column``. A final ``api|meteo|weather|weather``
    line advertises the weather API channel. Every line, including the first, is
    preceded by a newline.

    Args:
        databases: Glue database names to harvest. Defaults to the notebook-level ``gdc``.
        glue_client: Optional Glue client to reuse. A new ``boto3`` client is created
            when omitted.

    Returns:
        The catalog as a single string.
    """
    if databases is None:
        databases = gdc
    if glue_client is None:
        glue_client = boto3.client('glue')

    catalog_lines = []
    # get_tables is paginated: follow every page, not only the first response.
    paginator = glue_client.get_paginator('get_tables')
    for db in databases:
        for page in paginator.paginate(DatabaseName=db):
            for table in page['TableList']:
                descriptor = table.get('StorageDescriptor', {})
                # S3-backed tables are labelled 's3'. Other sources carry their type in the
                # table's 'classification' parameter (e.g. 'snowflake'). Fall back to
                # 'unknown' instead of raising KeyError when it is missing.
                if (descriptor.get('Location') or '').startswith('s3'):
                    classification = 's3'
                else:
                    classification = table.get('Parameters', {}).get('classification', 'unknown')
                for column in descriptor.get('Columns', []):
                    catalog_lines.append(
                        f"{classification}|{table['DatabaseName']}|{table['Name']}|{column['Name']}"
                    )

    # API channel: append the weather API to the unified catalog.
    catalog_lines.append('api|meteo|weather|weather')
    return ''.join('\n' + line for line in catalog_lines)

glue_catalog = parse_catalog()

# Show only the last 10 catalog lines to keep the cell output short.
print(*glue_catalog.splitlines()[-10:], sep='\n')


snowflake|snowdb|movielens_public_movies|movieid
snowflake|snowdb|movielens_public_movies|genres
snowflake|snowdb|movielens_public_movies|title
api|meteo|weather|weather


In [43]:
#Function 1 'Infer Channel'
#define a function that infers the channel/database/table and sets the database for querying
def identify_channel(query):
    """Ask the LLM which catalog channel can answer ``query`` and select the matching backend.

    The prompt embeds the flattened Glue catalog (``glue_catalog``) and asks the LLM to
    name the database (and database.table) holding the answer. The reply is then scanned
    for a channel keyword in this order: snowflake, redshift, s3, rdsmysql, api.

    Args:
        query: The user's natural-language question.

    Returns:
        ``(channel, db)``. ``channel`` is ``'db'`` or ``'api'``. ``db`` is the
        ``SQLDatabase`` to query for ``'db'``, and ``None`` for ``'api'``.

    Raises:
        Exception: if the LLM reply mentions none of the known channels.
    """
    #Prompt 1 'Infer Channel'
    # The catalog is concatenated into the template text. Escape any braces in it so
    # PromptTemplate does not treat them as (missing) template variables.
    catalog_text = glue_catalog.replace("{", "{{").replace("}", "}}")
    prompt_template = """
     From the table below, find the database (in column database) which will contain the data (in corresponding column_names) to answer the question 
     {query} \n
     """+catalog_text +""" 
     Give your answer as database == 
     Also,give your answer as database.table == 
     """
    PROMPT_channel = PromptTemplate(template=prompt_template, input_variables=["query"])

    llm_chain = LLMChain(prompt=PROMPT_channel, llm=llm)
    generated_texts = llm_chain.run(query)
    print(generated_texts)

    # Match case-insensitively: the LLM may capitalise names ("Snowflake", "S3", "API").
    # NOTE(review): plain substring matching is fragile ('api' also matches words such as
    # 'rapid'); parsing the "database ==" line of the reply would be more robust.
    answer = generated_texts.lower()
    # Initialised so the API branch, which has no SQL database, can still return it.
    db = None
    # NOTE(review): dbredshift, dbathena and dbrds are not created in the visible cells;
    # selecting those channels needs the corresponding SQLDatabase defined first.
    if 'snowflake' in answer:
        channel = 'db'
        db = dbsnowflake
        print("SET database to snowflake")
    elif 'redshift' in answer:
        channel = 'db'
        db = dbredshift
        print("SET database to redshift")
    elif 's3' in answer:
        channel = 'db'
        db = dbathena
        print("SET database to athena")
    elif 'rdsmysql' in answer:
        channel = 'db'
        db = dbrds
        print("SET database to rds")
    elif 'api' in answer:
        channel = 'api'
        print("SET database to weather api")
    else:
        raise Exception("User question cannot be answered by any of the channels mentioned in the catalog")
    print("Step complete. Channel is: ", channel)

    return channel, db

#Function 2 'Run Query'
#define a function that routes the question to the inferred channel and returns the answer
def run_query(query):
    """Answer ``query`` by routing it to the channel chosen by ``identify_channel``.

    For the ``'db'`` channel, a LangChain ``SQLDatabaseChain`` uses the custom prompt
    below to turn the question into SQL, runs it on the selected database and phrases
    the result. For the ``'api'`` channel, an ``APIChain`` built from the Open-Meteo API
    docs answers it instead.

    Args:
        query: The user's natural-language question.

    Returns:
        The chain's textual answer.

    Raises:
        Exception: if the channel is neither ``'db'`` nor ``'api'``.
    """
    # Step 1: pick the data channel (and database) for this question.
    channel, db = identify_channel(query)

    ##Prompt 2 'Run Query'
    # Text-to-SQL rules for SQLDatabaseChain. {dialect}, {table_info} and {input} are the
    # template variables declared below.
    _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.

    Do not append 'Query:' to SQLQuery.
    
    Display SQLResult after the query is run in plain english that users can understand. 

    Provide answer in simple english statement.
 
    Only use the following tables:

    {table_info}
    If someone asks for the sales, they really mean the tickit.sales table.
    If someone asks for the sales date, they really mean the column tickit.sales.saletime.

    Question: {input}"""

    PROMPT_sql = PromptTemplate(
        template=_DEFAULT_TEMPLATE,
        input_variables=["input", "table_info", "dialect"],
    )

    if channel == 'api':
        weather_chain = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS, verbose=True)
        return weather_chain.run(query)

    if channel != 'db':
        raise Exception("Unlisted channel. Check your unified catalog")

    sql_chain = SQLDatabaseChain.from_llm(
        llm, db, prompt=PROMPT_sql, verbose=True, return_intermediate_steps=False
    )
    return sql_chain.run(query)

In [57]:
# Snowflake (MOVIELENS) example question.
# An earlier finance example, kept for reference:
# query = """Which stock performed the best and the worst in May of 2013?"""
# A plain string literal: the previous four-quote literal prepended a stray '"' to the question.
query = "which movie title contain word Toy"

# Response from LangChain
response = run_query(query)
print("----------------------------------------------------------------------")
print(f'SQL and response from user query {query}  \n  {response}')

The database that will contain the data to answer the question "which movie title contains the word Toy" is snowflake.

Answer:
database == snowflake
database.table == snowflake.movielens_public_movies
SET database to snowflake
Step complete. Channel is:  db


[1m> Entering new SQLDatabaseChain chain...[0m
"which movie title contain word Toy
SQLQuery:[32;1m[1;3mSELECT title
FROM movies
WHERE title LIKE '%Toy%'
[0m
SQLResult: [33;1m[1;3m[('Toy Story (1995)',), ('Babes in Toyland (1961)',), ('Toys (1992)',), ('Babes in Toyland (1934)',), ('Toy Story 2 (1999)',), ('Toy, The (1982)',), ('Toy Soldiers (1991)',), ('Toy Story 3 (2010)',)][0m
Answer:[32;1m[1;3mThe movies with titles that contain the word "Toy" are "Toy Story (1995)", "Babes in Toyland (1961)", "Toys (1992)", "Babes in Toyland (1934)", "Toy Story 2 (1999)", "Toy, The (1982)", "Toy Soldiers (1991)", and "Toy Story 3 (2010)".[0m
[1m> Finished chain.[0m
-----------------------------------------------------------------