In [None]:
import os

import openai
import psycopg2
from flask import Flask, request, jsonify, session
from flask_session import Session  # This requires Flask-Session
from psycopg2.extras import RealDictCursor

app = Flask(__name__)
# SECRET_KEY signs the session data, so it must not live in source control.
# Set SECRET_KEY in the environment for any real deployment; the placeholder
# below is only a fallback so local runs keep working unchanged.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'your_secret_key')
app.config['SESSION_TYPE'] = 'filesystem'  # Sessions stored on the filesystem
Session(app)

# The OpenAI API key is taken from the OPENAI_API_KEY environment variable so
# the credential does not have to be committed; the placeholder is the
# fallback when the variable is not set.
openai.api_key = os.environ.get('OPENAI_API_KEY', 'your_openai_api_key')

# Database connection parameters, passed as keyword arguments to
# psycopg2.connect(). Each value can be overridden with the corresponding
# standard libpq environment variable so credentials need not be edited in
# source; the literals are the fallbacks used when a variable is not set.
db_params = {
    'dbname': os.environ.get('PGDATABASE', 'public'),
    'user': os.environ.get('PGUSER', 'postgres'),
    'password': os.environ.get('PGPASSWORD', 'abcdabcd'),
    'host': os.environ.get('PGHOST', '136.144.56.123'),
    'port': os.environ.get('PGPORT', '5432'),
}

# DDL for the tables the generated SQL may reference. The text is interpolated
# verbatim into the system prompt in handle_query, so every statement is
# terminated with ';' to keep it a well-formed sequence of statements.
SCHEMA = '''
CREATE TABLE trades (
exchange character varying(20),
symbol character varying(20),
price double precision,
size double precision,
taker_side character varying(5),
trade_id character varying(64),
event_timestamp timestamp without time zone,
atom_timestamp bigint
);

CREATE TABLE trades_l3 (
    exchange character varying(20),
    symbol character varying(20),
    price double precision,
    size double precision,
    taker_side character varying(5),
    trade_id character varying(64),
    maker_order_id character varying(64),
    taker_order_id character varying(64),
    event_timestamp timestamp without time zone,
    atom_timestamp bigint
);

CREATE TABLE candle (
    exchange character varying(20),
    symbol character varying(20),
    start timestamp without time zone,
    "end" timestamp without time zone,
    "interval" character varying(10),
    trades integer,
    closed boolean,
    o double precision,
    h double precision,
    l double precision,
    c double precision,
    v double precision,
    event_timestamp timestamp without time zone,
    atom_timestamp bigint
);

CREATE TABLE ethereum_blocks (
    blocktimestamp timestamp without time zone,
    atomtimestamp bigint,
    number integer,
    hash character(66) NOT NULL,
    parenthash character(66),
    nonce character(18),
    sha3uncles character(66),
    logsbloom character(514),
    transactionsroot character(66),
    stateroot character(66),
    receiptsroot character(66),
    miner character(42),
    difficulty bigint,
    totaldifficulty numeric,
    extradata text,
    size bigint,
    gaslimit numeric,
    gasused numeric
);

CREATE TABLE ethereum_logs (
    atomtimestamp bigint,
    blocktimestamp timestamp without time zone NOT NULL,
    logindex integer NOT NULL,
    transactionindex integer,
    transactionhash character(66) NOT NULL,
    blockhash character(66),
    blocknumber bigint,
    address character(42),
    data text,
    topic0 text,
    topic1 text,
    topic2 text,
    topic3 text
);

CREATE TABLE ethereum_transactions (
    blocktimestamp timestamp without time zone NOT NULL,
    atomtimestamp bigint,
    blocknumber integer,
    blockhash character(66),
    hash character(66) NOT NULL,
    nonce text,
    transactionindex integer,
    fromaddr character(42),
    toaddr character(42),
    value numeric,
    gas bigint,
    gasprice bigint,
    input text,
    maxfeepergas bigint,
    maxpriorityfeepergas bigint,
    type text
);

CREATE TABLE ethereum_token_transfers (
    atomtimestamp bigint,
    blocktimestamp timestamp without time zone NOT NULL,
    tokenaddr character(42),
    fromaddr text,
    toaddr text,
    value numeric,
    transactionhash character(66) NOT NULL,
    logindex integer NOT NULL,
    blocknumber bigint,
    blockhash character(66)
);
'''

# Name of the Flask session entry under which the per-client chat history
# (a list of message dicts) is stored between requests.
SESSION_KEY = "conversation"

def execute_query(query, *, read_only=True):
    """Run a SQL statement and return the rows it produces.

    Args:
        query: SQL text, sent to the server verbatim. Because callers may pass
            model- or user-derived SQL, protection against writes comes from
            ``read_only`` rather than from inspecting the text.
        read_only: When true (the default), the session is switched to
            read-only before executing, so the server rejects statements that
            modify data. Pass ``False`` only for trusted SQL that must write.

    Returns:
        A list of dict-like rows (``RealDictCursor``) on success, or ``None``
        when anything fails: connecting, executing, or fetching (for example
        a statement that yields no result set). The error is printed.
    """
    conn = None
    try:
        # Connecting is inside the try so that an unreachable database is
        # reported as None like every other failure, instead of raising.
        conn = psycopg2.connect(**db_params)
        if read_only:
            # Must be set before the first statement opens a transaction.
            conn.set_session(readonly=True)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query)
            result = cursor.fetchall()
        conn.commit()
        return result
    except Exception as e:
        # Deliberately broad: callers rely on None as the only failure signal.
        print(f"Database error: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()

def determine_chart_type(data):
    """Choose chart types for a result set from its column names and count.

    Only the first row is inspected; its keys are taken as the column names.

    Args:
        data: Sequence of row mappings (column name -> value).

    Returns:
        A list of chart type names:
        - ``[]`` when ``data`` is empty (there is nothing to plot);
        - ``['line']`` when any column name contains ``'timestamp'``;
        - ``['scatter']`` when there are exactly two columns;
        - ``['bar']`` when there is a single column (a pie chart needs both a
          label column and a value column);
        - ``['bar', 'pie']`` otherwise.
    """
    if not data:
        return []

    columns = list(data[0])
    if any('timestamp' in column for column in columns):
        return ['line']
    if len(columns) == 2:
        return ['scatter']
    if len(columns) < 2:
        return ['bar']
    return ['bar', 'pie']  # Include 'pie' for categorical data

def generate_charts(data, chart_types):
    """Prepares data for charts to be rendered by Recharts on the frontend.

    Args:
        data: Sequence of row mappings (column name -> value).
        chart_types: Iterable of chart type names to build.

    Returns:
        A dict mapping each requested chart type to a list of entries:
        - ``'line'``, ``'scatter'``, ``'bar'``: a plain-dict copy of each row;
        - ``'pie'``: ``{"name": str(first column), "value": second column}``
          per row; rows with fewer than two columns are skipped because they
          cannot provide both a label and a value;
        - any other type: an empty list.
    """
    charts = {}
    for chart_type in chart_types:
        chart_data = []
        if chart_type in ('line', 'scatter', 'bar'):
            chart_data = [dict(row) for row in data]
        elif chart_type == 'pie':
            for row in data:
                values = list(row.values())
                if len(values) < 2:
                    continue
                chart_data.append({"name": str(values[0]), "value": values[1]})

        charts[chart_type] = chart_data

    return charts

def generate_summary(data):
    """Ask the OpenAI chat API for a natural-language summary of ``data``.

    ``data`` is embedded in the prompt through its ``str`` representation.
    Any exception raised while calling the API or reading the reply is
    printed and swallowed; in that case a fixed fallback message is returned.
    """
    prompt = f"Summarize this data in natural language: {data!s}"
    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # Adjust according to the latest available model
            messages=chat_messages,
            temperature=0.7,
        )
        return completion.choices[0].message['content']
    except Exception as e:
        print(f"Error generating summary: {e}")
        return "Error generating summary."

@app.route('/query', methods=['POST'])
def handle_query():
    """Translate a natural-language question to SQL, run it, and return the result.

    Request: JSON object with a ``question`` string.

    Responses:
        400: the body is not a JSON object or ``question`` is missing/blank.
        500: SQL generation failed, or the query could not be executed.
        200: if the question mentions "chart", ``{"<type>_chart": [...], ...,
             "summary": str}`` (no chart keys when the query returned no
             rows); otherwise ``{"result": rows, "summary": str}``.

    The conversation (user questions and generated SQL) is kept in the session
    under ``SESSION_KEY`` and replayed to the model on later requests.
    """
    payload = request.get_json(silent=True)
    user_question = payload.get('question') if isinstance(payload, dict) else None
    if not isinstance(user_question, str) or not user_question.strip():
        return jsonify({"error": "Request body must be a JSON object with a non-empty 'question' string"}), 400

    conversation_history = session.get(SESSION_KEY, [])

    # NOTE: the continuation lines of this f-string keep their leading spaces;
    # they are part of the prompt text, so do not re-indent them.
    system_message = {"role": "system", "content": f"""You are a program which translates natural language into read-only SQL commands. \
                                               Use the following table schema: {SCHEMA}. You only output SQL queries. Your \
                                               queries are designed to be used as timeseries charts from Apache Superset. \
                                               Trading pairs are in the form "<base>.<quote>", where <base> and <quote> are \
                                               uppercase. Exchanges "binance" and "coinbase" use the trades_l3 table for their \
                                               trades, all other exchanges use the trades table. All exchange names are \
                                               lowercase. Input:"""}
    user_message = {"role": "user", "content": f"Translate this natural language question to SQL: \"{user_question}\""}

    try:
        # Instructions first, then earlier turns, then the new question.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # Adjust according to the latest available model
            messages=[system_message] + conversation_history + [user_message],
            temperature=0.7
        )
        sql_query = response.choices[0].message['content'].strip()
    except Exception as e:
        return jsonify({"error": f"Error generating SQL query: {e}"}), 500

    # Store both sides of the exchange; with only the assistant replies the
    # replayed history would not tell the model what each query was answering.
    session[SESSION_KEY] = conversation_history + [
        user_message,
        {"role": "assistant", "content": sql_query},
    ]

    # Execute the SQL query
    query_result = execute_query(sql_query)
    if query_result is None:
        return jsonify({"error": "Failed to execute SQL query"}), 500

    summary = generate_summary(query_result)

    if 'chart' not in user_question.lower():
        # If not generating a chart, return the query result directly
        return jsonify({"result": query_result, "summary": summary})

    response_data = {"summary": summary}
    if query_result:
        # Charts need at least one row; an empty result yields only the summary.
        chart_types = determine_chart_type(query_result)
        for chart_type, chart_data in generate_charts(query_result, chart_types).items():
            response_data[chart_type + "_chart"] = chart_data
    return jsonify(response_data)

if __name__ == '__main__':
    # Debug mode enables the interactive debugger, which must never be
    # reachable by untrusted clients. It stays on by default for local
    # development (as before); set FLASK_DEBUG to 0/false/no to turn it off
    # without editing the source.
    debug_flag = os.environ.get('FLASK_DEBUG', '1')
    debug_enabled = bool(debug_flag) and debug_flag.lower() not in ('0', 'false', 'no')
    app.run(debug=debug_enabled)
