# In [None]:
import os

from flask import Flask, request, jsonify, session
import openai
import psycopg2
from psycopg2.extras import RealDictCursor
from flask_session import Session  # This requires Flask-Session

app = Flask(__name__)
# SECRET_KEY is needed for session management. Supply a real random value via
# the FLASK_SECRET_KEY environment variable in any deployment; the fallback is
# only a placeholder kept for backward compatibility.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY', 'your_secret_key')
app.config['SESSION_TYPE'] = 'filesystem'  # Flask-Session keeps session data on the local filesystem
Session(app)

# OpenAI API key: taken from the OPENAI_API_KEY environment variable when set,
# so a key supplied through the environment is no longer overwritten by the
# placeholder below.
openai.api_key = os.environ.get('OPENAI_API_KEY', 'your_openai_api_key')

# Database connection parameters passed to psycopg2.connect(). Each value can be
# overridden via the environment (DB_NAME, DB_USER, DB_PASSWORD, DB_HOST,
# DB_PORT) so credentials need not live in source; the defaults are the values
# that were previously hard-coded here.
db_params = {
    'dbname': os.environ.get('DB_NAME', 'postgres'),
    'user': os.environ.get('DB_USER', 'viewer_account'),
    'password': os.environ.get('DB_PASSWORD', 'empty'),
    'host': os.environ.get('DB_HOST', '136.144.62.142'),
    'port': os.environ.get('DB_PORT', '5432'),
}

# DDL for the queryable tables. It is embedded in the system prompt of
# handle_query so the model knows table and column names. Every statement is
# terminated with ';' so the text is well-formed SQL.
SCHEMA = '''
CREATE TABLE trades (
exchange character varying(20),
symbol character varying(20),
price double precision,
size double precision,
taker_side character varying(5),
trade_id character varying(64),
event_timestamp timestamp without time zone,
atom_timestamp bigint
);

CREATE TABLE trades_l3 (
    exchange character varying(20),
    symbol character varying(20),
    price double precision,
    size double precision,
    taker_side character varying(5),
    trade_id character varying(64),
    maker_order_id character varying(64),
    taker_order_id character varying(64),
    event_timestamp timestamp without time zone,
    atom_timestamp bigint
);

CREATE TABLE candle (
    exchange character varying(20),
    symbol character varying(20),
    start timestamp without time zone,
    "end" timestamp without time zone,
    "interval" character varying(10),
    trades integer,
    closed boolean,
    o double precision,
    h double precision,
    l double precision,
    c double precision,
    v double precision,
    event_timestamp timestamp without time zone,
    atom_timestamp bigint
);

CREATE TABLE ethereum_blocks (
    blocktimestamp timestamp without time zone,
    atomtimestamp bigint,
    number integer,
    hash character(66) NOT NULL,
    parenthash character(66),
    nonce character(18),
    sha3uncles character(66),
    logsbloom character(514),
    transactionsroot character(66),
    stateroot character(66),
    receiptsroot character(66),
    miner character(42),
    difficulty bigint,
    totaldifficulty numeric,
    extradata text,
    size bigint,
    gaslimit numeric,
    gasused numeric
);

CREATE TABLE ethereum_logs (
    atomtimestamp bigint,
    blocktimestamp timestamp without time zone NOT NULL,
    logindex integer NOT NULL,
    transactionindex integer,
    transactionhash character(66) NOT NULL,
    blockhash character(66),
    blocknumber bigint,
    address character(42),
    data text,
    topic0 text,
    topic1 text,
    topic2 text,
    topic3 text
);

CREATE TABLE ethereum_transactions (
    blocktimestamp timestamp without time zone NOT NULL,
    atomtimestamp bigint,
    blocknumber integer,
    blockhash character(66),
    hash character(66) NOT NULL,
    nonce text,
    transactionindex integer,
    fromaddr character(42),
    toaddr character(42),
    value numeric,
    gas bigint,
    gasprice bigint,
    input text,
    maxfeepergas bigint,
    maxpriorityfeepergas bigint,
    type text
);

CREATE TABLE ethereum_token_transfers (
    atomtimestamp bigint,
    blocktimestamp timestamp without time zone NOT NULL,
    tokenaddr character(42),
    fromaddr text,
    toaddr text,
    value numeric,
    transactionhash character(66) NOT NULL,
    logindex integer NOT NULL,
    blocknumber bigint,
    blockhash character(66)
);
'''

# Key in the Flask session under which this client's chat history (a list of
# {"role": ..., "content": ...} message dicts) is stored between requests.
SESSION_KEY: str = 'conversation'

def execute_query(query):
    """Execute a SQL query in a read-only transaction and return all rows.

    The SQL text is produced by a language model from user input, so it is
    treated as untrusted. The session is put into read-only mode and nothing
    is ever committed, so a generated INSERT/UPDATE/DROP cannot persist
    changes.

    Args:
        query: SQL text to execute (no parameters are bound).

    Returns:
        A list of dict-like rows (``RealDictCursor``). Returns ``None`` if
        connecting, executing or fetching fails; the error and the failing
        SQL are printed.
    """
    conn = None
    try:
        # Connect inside the try so connection failures also yield None
        # instead of escaping to the caller.
        conn = psycopg2.connect(**db_params)
        # Enforce the "read-only" contract at the database level instead of
        # trusting the model's output.
        conn.set_session(readonly=True)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query)
            return cursor.fetchall()
    except psycopg2.Error as e:
        print(f"Database error: {e}")
        print(f"Failed SQL: {query}")  # Log the failed SQL query for debugging
        return None
    finally:
        if conn is not None:
            # Closing without commit discards the (read-only) transaction.
            conn.close()

def find_group_column(data, max_groups=10):
    """
    Find a column that can split rows into separate series (e.g. one line per value).

    A candidate column must satisfy two conditions:

    * its value in the first row is non-numeric (per ``is_numeric``);
    * its name does not contain 'time' or 'date' (case-insensitive, matching
      the other helpers in this module).

    The first candidate with more than one and at most ``max_groups``
    distinct values across all rows is returned.

    Args:
        data: list of row dicts, e.g. as returned by ``execute_query``.
        max_groups: upper bound on distinct values for a usable grouping
            column (default 10; adjust based on your data).

    Returns:
        ``(column_name, True)`` if a suitable column exists, otherwise
        ``(None, False)``. Empty ``data`` also yields ``(None, False)``.
    """
    if not data:
        return None, False

    first_row = data[0]
    candidates = [
        col for col in first_row
        if not is_numeric(first_row[col])
        and not any(marker in col.lower() for marker in ('time', 'date'))
    ]

    for col in candidates:
        try:
            unique_values = {row.get(col) for row in data}
        except TypeError:
            # Unhashable cell values (e.g. JSON arrays/objects) cannot be
            # counted as groups; try the next candidate.
            continue
        if 1 < len(unique_values) <= max_groups:
            return col, True
    return None, False

def is_numeric(value):
    """Return True if ``float(value)`` succeeds, False otherwise.

    Ints, floats, Decimals, bools and numeric strings count as numeric.
    ``float`` rejects some values with ``TypeError`` rather than
    ``ValueError``: ``None`` (SQL NULL), datetime objects, lists and dicts.
    Those are reported as non-numeric instead of propagating the exception.
    """
    try:
        float(value)
    except (TypeError, ValueError):
        return False
    return True
    
def rearrange_columns_for_datetime(data):
    """
    Return the rows with their time/date columns moved to the front.

    A column counts as time/date when its lower-cased name contains 'time' or
    'date'. Within each group (time/date first, then the rest) the original
    column order is kept. Each row is rebuilt as a new dict. A falsy ``data``
    (e.g. an empty list) is returned unchanged.
    """
    if not data:
        return data

    def temporal(column):
        lowered = column.lower()
        return any(marker in lowered for marker in ['time', 'date'])

    def front_load(row):
        leading = {name: cell for name, cell in row.items() if temporal(name)}
        trailing = {name: cell for name, cell in row.items() if name not in leading}
        return {**leading, **trailing}

    return [front_load(row) for row in data]

def determine_chart_type(data):
    """
    Suggest chart types for a query result.

    The decision is made from the first row's column names and values. A
    column is time/date-like when its name contains 'time' or 'date'; the
    check is case-insensitive, consistent with
    ``rearrange_columns_for_datetime``.

    Rules, checked in order:

    * any time/date-like column → ``['multi-line']`` with the grouping column
      from ``find_group_column`` if one is suitable, else ``['line']``;
    * more than one numeric column → ``['scatter']``;
    * exactly one numeric column and at least one other non-time column →
      ``['bar', 'pie']``;
    * anything else, including empty data → ``['text']``.

    Returns:
        ``(chart_types, group_column)``, where ``group_column`` is ``None``
        except for 'multi-line'.
    """
    if not data:
        return ['text'], None  # Nothing to plot

    # Put date/time columns first (does not change the decision below, but
    # keeps column order consistent with the other helpers).
    data = rearrange_columns_for_datetime(data)

    first_row = data[0]
    time_date_columns = [
        col for col in first_row
        if any(marker in col.lower() for marker in ('time', 'date'))
    ]
    numeric_columns = [col for col, val in first_row.items() if is_numeric(val)]
    categorical_columns = [
        col for col in first_row
        if col not in numeric_columns and col not in time_date_columns
    ]

    if time_date_columns:
        group_column, is_suitable_for_multi_line = find_group_column(data)
        if is_suitable_for_multi_line:
            return ['multi-line'], group_column
        return ['line'], None
    if len(numeric_columns) > 1:
        return ['scatter'], None
    if categorical_columns and len(numeric_columns) == 1:
        return ['bar', 'pie'], None
    return ['text'], None  # Fallback for unexpected data shapes

def _is_time_date_key(key):
    """Return True if a column name contains 'time' or 'date' (case-insensitive)."""
    lowered = key.lower()
    return 'time' in lowered or 'date' in lowered


def _time_date_keys(row):
    """Return a row's time/date column names, ordered for x-axis selection.

    The sort key ``('time' in name, 'date' in name)`` orders names in three
    tiers: 'date'-only names first, then 'time'-only names, then names with
    both. Ties keep the row's column order because ``sorted`` is stable.
    """
    return sorted(
        (key for key in row if _is_time_date_key(key)),
        key=lambda key: ('time' in key.lower(), 'date' in key.lower()),
    )


def generate_charts(data, chart_types, group_column=None):
    """Build chart-ready data from query rows for each requested chart type.

    Args:
        data: list of row dicts (as returned by ``execute_query``).
        chart_types: iterable of chart type names. 'line', 'multi-line',
            'scatter', 'bar' and 'pie' are understood; any other name (e.g.
            'text') maps to an empty list.
        group_column: optional. For 'multi-line' it is the column whose
            values split points into series; for 'scatter' it supplies each
            point's label.

    Returns:
        dict mapping each chart type to its data:

        * 'line': ``[{'x', 'y'}, ...]``, where x is the first time/date column
          and y the first other numeric column of each row;
        * 'multi-line': ``[{'group': value, 'data': [point, ...]}, ...]``.
          When no ``group_column`` is given, all points form a single
          ``None`` group;
        * 'scatter': ``[{'x', 'y', 'label'}, ...]``;
        * 'bar': ``[{'category', 'value'}, ...]``;
        * 'pie': ``[{'label', 'value'}, ...]``, with columns chosen from the
          first row.

        Rows lacking the required columns are skipped with a printed warning.
    """
    charts = {}
    for chart_type in chart_types:
        chart_data = []
        if chart_type in ('multi-line', 'line'):
            for item in data:
                time_date_keys = _time_date_keys(item)
                numeric_keys = [key for key in item if is_numeric(item[key]) and key not in time_date_keys]
                if not (time_date_keys and numeric_keys):
                    # A time/date column and a numeric column are both needed.
                    print("Warning: Insufficient data for line/multi-line chart. A time/date and a numeric column are required.")
                    continue
                # First time/date column on x, first remaining numeric column on y.
                entry = {'x': item[time_date_keys[0]], 'y': item[numeric_keys[0]]}
                if chart_type == 'multi-line' and group_column:
                    entry['group'] = item.get(group_column)
                chart_data.append(entry)

            if chart_type == 'multi-line':
                # One series per group value, in first-seen order. entry.get()
                # puts every point into a single None series when no
                # group_column was given (indexing entry['group'] raised
                # KeyError in that case).
                series = {}
                for entry in chart_data:
                    series.setdefault(entry.get('group'), []).append(entry)
                chart_data = [{'group': group, 'data': entries} for group, entries in series.items()]
        elif chart_type == 'scatter':
            for item in data:
                time_date_keys = _time_date_keys(item)
                numeric_keys = [key for key in item if is_numeric(item[key])]

                if time_date_keys:
                    # Prefer a time/date x-axis; y is the first non-time/date numeric column.
                    x_axis_key = time_date_keys[0]
                    value_keys = [key for key in numeric_keys if key not in time_date_keys]
                    if not value_keys:
                        print("Warning: No suitable numeric column available for scatter chart y-axis.")
                        continue
                    y_axis_key = value_keys[0]
                elif len(numeric_keys) >= 2:
                    x_axis_key, y_axis_key = numeric_keys[0], numeric_keys[1]
                else:
                    print("Warning: Insufficient data for scatter chart. At least one time/date and one numeric column, or two numeric columns are required.")
                    continue

                chart_data.append({
                    'x': item[x_axis_key],
                    'y': item[y_axis_key],
                    'label': item.get(group_column) if group_column else None,  # Include group label if applicable
                })
        elif chart_type == 'bar':
            for item in data:
                categorical_keys = [key for key in item if not is_numeric(item[key]) and not _is_time_date_key(key)]
                numeric_keys = [key for key in item if is_numeric(item[key])]
                if categorical_keys and numeric_keys:
                    chart_data.append({
                        'category': item[categorical_keys[0]],
                        'value': item[numeric_keys[0]],
                    })
                else:
                    print("Warning: Insufficient data for bar chart. At least one categorical and one numeric column are required.")
        elif chart_type == 'pie':
            # Label/value columns are chosen once from the first row (none for empty data).
            first_row = data[0] if data else {}
            categorical_keys = [key for key in first_row if not is_numeric(first_row[key]) and not _is_time_date_key(key)]
            numeric_keys = [key for key in first_row if is_numeric(first_row[key])]

            if not categorical_keys or not numeric_keys:
                print("Warning: Insufficient data for pie chart. A categorical and a numeric column are required.")
            else:
                label_key = categorical_keys[0]
                value_key = numeric_keys[0]
                chart_data = [{'label': item[label_key], 'value': item[value_key]} for item in data]

        charts[chart_type] = chart_data

    return charts

def generate_summary(data, max_chars=12000):
    """Generate a natural-language summary of ``data`` using OpenAI's chat API.

    Args:
        data: the object to summarize (typically query rows); it is rendered
            with ``str()`` and embedded in the prompt.
        max_chars: maximum number of characters of that rendering sent to the
            model. Longer renderings are cut off and marked as truncated so
            large result sets do not overflow the model's context window.
            ``None`` disables truncation.

    Returns:
        The summary text, or ``"Error generating summary."`` if the API call
        fails (the error is printed).
    """
    rendered = str(data)
    if max_chars is not None and len(rendered) > max_chars:
        rendered = rendered[:max_chars] + " ... [truncated]"
    prompt = "Summarize this data in natural language: " + rendered
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # Adjust according to the latest available model
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7
        )
        return response.choices[0].message['content']
    except Exception as e:
        # Best-effort: a failed summary must not fail the whole request.
        print(f"Error generating summary: {e}")
        return "Error generating summary."

@app.route('/query', methods=['POST'])
def handle_query():
    """Answer a natural-language question by translating it to SQL and running it.

    Expects a JSON body ``{"question": "..."}``.

    1. The question is translated to SQL by the OpenAI chat API. Recent
       earlier turns stored in the Flask session are sent along so follow-up
       questions have context.
    2. The SQL is executed with ``execute_query``.
    3. The response depends on the question:

       * if it contains "chart" (case-insensitive), the response maps
         ``"<chart_type>_chart"`` to the data from ``generate_charts`` and
         adds a ``"summary"``;
       * otherwise the response is ``{"result": rows, "summary": summary}``.

    Returns:
        A JSON response. Status 400 if ``question`` is missing or not a
        non-empty string; 500 if SQL generation or execution fails.
    """
    # How many of the most recent chat messages (user + assistant) are kept as
    # context; bounds both the session size and the prompt length.
    max_history_messages = 20

    payload = request.get_json(silent=True)
    user_question = payload.get('question') if isinstance(payload, dict) else None
    if not isinstance(user_question, str) or not user_question.strip():
        return jsonify({"error": "Request body must be JSON with a non-empty 'question' string"}), 400

    conversation_history = session.get(SESSION_KEY, [])
    user_message = {"role": "user", "content": f"Translate this natural language question to SQL: \"{user_question}\""}

    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # Adjust according to the latest available model
            # The system prompt comes first, then prior turns, then the new question.
            messages=[
                {"role": "system", "content": f"""You are a program which translates natural language into read-only SQL commands. \
                                               Use the following table schema: {SCHEMA}. You only output SQL queries. Your \
                                               queries are designed to be used as timeseries charts from Apache Superset. \
                                               Trading pairs are in the form "<base>.<quote>", where <base> and <quote> are \
                                               uppercase. Exchanges "binance" and "coinbase" use the trades_l3 table for their \
                                               trades, all other exchanges use the trades table. All exchange names are \
                                               lowercase. Input:"""},
                *conversation_history[-max_history_messages:],
                user_message,
            ],
            temperature=0
        )
        sql_query = response.choices[0].message['content'].strip()

        # The model sometimes wraps its answer in a Markdown fence or prefixes
        # it with "sql"; remove either (case-insensitively), then any
        # leftover backticks and whitespace.
        lowered = sql_query.lower()
        if lowered.startswith("```sql"):
            sql_query = sql_query[6:]
        elif lowered.startswith("sql"):
            sql_query = sql_query[3:]
        sql_query = sql_query.strip(" `\r\n\t")

        # Record both sides of the exchange; storing only the assistant's SQL
        # would leave the model without the questions it answered.
        conversation_history = conversation_history + [
            user_message,
            {"role": "assistant", "content": sql_query},
        ]
        session[SESSION_KEY] = conversation_history[-max_history_messages:]
    except Exception as e:
        return jsonify({"error": f"Error generating SQL query: {e}"}), 500

    query_result = execute_query(sql_query)
    if query_result is None:
        return jsonify({"error": "Failed to execute SQL query"}), 500

    if 'chart' not in user_question.lower():
        # No chart requested: return the raw rows with a summary.
        summary = generate_summary(query_result)
        return jsonify({"result": query_result, "summary": summary})

    # determine_chart_type returns (chart_types, group_column); both are needed
    # by generate_charts (passing the whole tuple as chart_types crashed).
    chart_types, group_column = determine_chart_type(query_result)
    charts = generate_charts(query_result, chart_types, group_column)
    summary = generate_summary(query_result)

    response_data = {f"{chart_type}_chart": chart_data for chart_type, chart_data in charts.items()}
    response_data["summary"] = summary
    return jsonify(response_data)

if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so debug mode is opt-in (FLASK_DEBUG=1) rather than always on.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1')
