In [49]:
import pandas as pd
import tensorflow_hub as hub
import openai 
import numpy as np
import altair as alt
from openai import OpenAI
from sql_metadata import Parser
import os
from IPython.display import display, Markdown

# Read the MIT DW queries 

In [157]:
df_q_original = pd.read_json('data/original_queries.json')

In [150]:
df_q_original

Unnamed: 0,questions,sql
0,How many organisation units in human ressource...,SELECT count(*) from (select hr_org_unit_key f...
1,How many organisation units in human ressource...,SELECT count(*) from (select hr_org_unit_key f...
2,Count the number of courses offered by each de...,"SELECT department_id, COUNT(*) AS course_count..."
3,"What are the past, upcoming, and, current cour...","select distinct subject_title, effective_term_..."
4,Which faculty has had the most location changes?,select max(cnt) from (select count(*) as cnt f...
5,What is the average number of courses a studen...,select avg(cnt) from (select count(*) as cnt f...
6,Which students are beeing employed by mit?,select distinct empl.full_name from mit_studen...
7,what is the shortest course available this sem...,"select distinct subject_title, term_end_date-t..."
8,Show buildings and their adress.,"SELECT fb.building_name, fba.* FROM wareuser.f..."
9,How many students do we have?,select count(*) from mit_student_directory;


### Use the one that is reformatted and slightly cleaned.

In [208]:
df_q = pd.read_json('data/queries.json')

In [209]:
df_q

Unnamed: 0,questions,sql
0,How many organisation units in human resources...,SELECT COUNT(*) FROM (SELECT hr_org_unit_key F...
1,How many organisation units in human resources...,SELECT COUNT(*) FROM (SELECT hr_org_unit_key F...
2,Count the number of courses offered by each de...,"SELECT department_id, COUNT(*) AS course_count..."
3,"What are the past, upcoming, and current courses?","SELECT DISTINCT subject_title, effective_term_..."
4,Which faculty has had the most location changes?,SELECT MAX(cnt) FROM (SELECT COUNT(*) AS cnt F...
5,What is the average number of courses a studen...,SELECT AVG(cnt) FROM (SELECT COUNT(*) AS cnt F...
6,Which students are employed by mit?,SELECT DISTINCT empl.full_name FROM mit_studen...
7,What is the shortest course available this sem...,"SELECT DISTINCT subject_title, term_end_date -..."
8,Show buildings and their address.,"SELECT fb.building_name, fba.* FROM wareuser.f..."
9,How many students do we have?,SELECT COUNT(*) FROM mit_student_directory;


## Rename the _questions_ column as _question_


In [210]:
df_q = df_q.rename(columns={"questions":"question"})

## Identify the tables used by each query 

In [211]:
tables = []
for q in df_q['sql']: 
    t = Parser(q).tables
    tables.append([v.upper() for v in t])
df_q['tables'] = tables


## Identify the unique tables referred by the corpus 

In [212]:
unique_tables = set()
for  l in tables: 
    for t in l:
        unique_tables.add(t.upper())

In [214]:
unique_tables

{'ACADEMIC_TERMS',
 'ACADEMIC_TERMS_ALL',
 'CIS_COURSE_CATALOG',
 'COURSE_CATALOG_SUBJECT_OFFERED',
 'EMPLOYEE_DIRECTORY',
 'FCLT_BUILDING',
 'FCLT_ROOMS',
 'HR_ORG_UNIT',
 'HR_ORG_UNIT_NEW',
 'IAP_SUBJECT_CATEGORY',
 'IAP_SUBJECT_DETAIL',
 'IAP_SUBJECT_PERSON',
 'IAP_SUBJECT_SESSION',
 'LIBRARY_MATERIAL_STATUS',
 'MIT_STUDENT_DIRECTORY',
 'SPACE_UNIT',
 'STUDENT_DEPARTMENT',
 'SUBJECT_ENROLLABLE',
 'SUBJECT_OFFERED',
 'SUBJECT_SUMMARY',
 'TIP_DETAIL',
 'TIP_MATERIAL',
 'TIP_MATERIAL_STATUS',
 'TIP_SUBJECT_OFFERED',
 'TOP_LEVEL_DOMAIN',
 'WAREUSER.CIS_COURSE_CATALOG',
 'WAREUSER.FAC_BUILDING',
 'WAREUSER.FAC_BUILDING_ADDRESS',
 'WAREUSER.FCLT_BUILDING_ADDRESS_HIST',
 'WAREUSER.SE_PERSON',
 'ZIP_CANADA',
 'ZIP_USA'}

In [215]:
len(unique_tables)

32

In [216]:
prefix='data/views/'

## Go thrugh the `unique_tables` list and save the respective schemas as proper CSV files 

In [217]:
for t in unique_tables: 
    schema_file = prefix + 'schema/' + t + '.csv'
    if os.path.isfile(schema_file): 
        s = pd.read_csv(schema_file, header=None)
        s = s.drop(1) # delete the "header" line 
        column_names = [h.strip() for h in s[0][0].split(';')]
        s_new = pd.DataFrame(list(s[0][1:].apply(lambda row: [c.strip() for c in row.split(';')])), columns=column_names)
        new_schema_file = prefix + 'schema_new/' + t + '.csv'
        s_new.to_csv(new_schema_file, index=False)

# Treat ChatGPT as an all-knowing oracle
## Prompt
_Your job is to write SQL queries that answer a user's question using the tables in the MIT Data Warehouse. 
The MIT Data Warehouse is a central data source that combines data from various administrative systems at 
MIT, containing information about students, faculty, and personnel. You can find more about the MIT data 
warehouse tables at https://web.mit.edu/warehouse/metadata/tables/all_tables.html._

_How many organisation units in human resources are relabeled?_

In [218]:
key = os.environ['OPENAI_API_KEY']

In [219]:
client = OpenAI(api_key=key)

## Populate the prompt and call the Open AI API

In [None]:
# TODO: do this in batch 
gpt_sql_nocontext = []
for q in df_q['question']: 
    response = client.chat.completions.create(
      model="gpt-4",
      messages = [
          {
          "role": "system",
          "content": "Your job is to write SQL queries that answer a user's question using the tables in the MIT Data Warehouse. \
          The MIT Data Warehouse is a central data source that combines data from various administrative systems at MIT, containing \
          information about students, faculty, and personnel. You can find more about the MIT data warehouse tables at \
          https://web.mit.edu/warehouse/metadata/tables/all_tables.html. \
          Reply with only the answer in SQL and include no linebreaks, newlines, escape characters or other commentary."
          # Provide your answer in JSON form."
          # Reply with only the answer in JSON form and include no other commentary"
        },
          {"role":"user", 
           "content":q}]
          )
    gpt_sql_nocontext.append(response.choices[0].message.content)

In [164]:
df_q['gpt_sql_nocontext']=gpt_sql_nocontext

In [168]:
df_q

Unnamed: 0,question,sql,tables,gpt_sql_nocontext
0,How many organisation units in human resources...,SELECT COUNT(*) FROM (SELECT hr_org_unit_key F...,"[HR_ORG_UNIT_NEW, HR_ORG_UNIT]",SELECT COUNT(*) FROM HR_ORGANIZATION_UNITS WHE...
1,How many organisation units in human resources...,SELECT COUNT(*) FROM (SELECT hr_org_unit_key F...,"[HR_ORG_UNIT_NEW, HR_ORG_UNIT]",SELECT COUNT(*) FROM human_resources WHERE new...
2,Count the number of courses offered by each de...,"SELECT department_id, COUNT(*) AS course_count...",[WAREUSER.CIS_COURSE_CATALOG],"""SELECT department_id, COUNT(*) FROM Courses G..."
3,"What are the past, upcoming, and current courses?","SELECT DISTINCT subject_title, effective_term_...","[CIS_COURSE_CATALOG, ACADEMIC_TERMS]",SELECT * FROM courses WHERE course_start_date ...
4,Which faculty has had the most location changes?,SELECT MAX(cnt) FROM (SELECT COUNT(*) AS cnt F...,[WAREUSER.FCLT_BUILDING_ADDRESS_HIST],"SELECT faculty_id, COUNT(DISTINCT location_id)..."
5,What is the average number of courses a studen...,SELECT AVG(cnt) FROM (SELECT COUNT(*) AS cnt F...,"[MIT_STUDENT_DIRECTORY, SUBJECT_ENROLLABLE]",SELECT AVG(course_count) FROM (SELECT student_...
6,Which students are employed by mit?,SELECT DISTINCT empl.full_name FROM mit_studen...,"[MIT_STUDENT_DIRECTORY, EMPLOYEE_DIRECTORY]","SELECT s.student_id, s.first_name, s.last_name..."
7,What is the shortest course available this sem...,"SELECT DISTINCT subject_title, term_end_date -...","[ACADEMIC_TERMS, CIS_COURSE_CATALOG]","SELECT MIN(course_length), course_name FROM co..."
8,Show buildings and their address.,"SELECT fb.building_name, fba.* FROM wareuser.f...","[WAREUSER.FAC_BUILDING, WAREUSER.FAC_BUILDING_...","SELECT building_number, building_name, street_..."
9,How many students do we have?,SELECT COUNT(*) FROM mit_student_directory;,[MIT_STUDENT_DIRECTORY],SELECT COUNT(*) FROM students;


# Give ChatGPT the "database" schema as context and elicit its answer


## Build the schema prompt 

In [137]:
def table_to_sql_create(t): 
    schema_file=prefix + 'schema_new/' + t + '.csv'
    p = None
    if os.path.isfile(schema_file):
        df_t = pd.read_csv(schema_file)
        p = "CREATE TABLE "+ t+ "(\n"  
        p = p + '\n'.join(list(df_t[['COLUMN_NAME', 'DATA_TYPE']].apply(lambda row: ' '.join(row.astype(str)), axis=1)))
        p = p + ')\n'
    return p
    

In [175]:
prompt = "** all the tables in the database ** \n\n"
cnt = 0 
for t in unique_tables:
    p = table_to_sql_create(t)
    if not p: 
        print(t)
    if p: 
        cnt = cnt + 1
        prompt = prompt + p
        prompt = prompt + "\n\n"


WAREUSER.FAC_BUILDING
WAREUSER.FCLT_BUILDING_ADDRESS_HIST
WAREUSER.FAC_BUILDING_ADDRESS
WAREUSER.CIS_COURSE_CATALOG
WAREUSER.SE_PERSON


### We seem to be missing some of the tables (above) 

In [205]:
## The "schema" prompt, which we append to the system prompt below 
print(prompt)

** all the tables in the database ** 

CREATE TABLE ACADEMIC_TERMS(
ACADEMIC_TERMS_KEY VARCHAR2
TERM_CODE VARCHAR2
TERM_DESCRIPTION VARCHAR2
TERM_SELECTOR VARCHAR2
TERM_START_DATE DATE
TERM_END_DATE DATE
ACADEMIC_YEAR VARCHAR2
ACADEMIC_YEAR_DESC VARCHAR2
IS_CURRENT_TERM VARCHAR2
IS_REGULAR_TERM VARCHAR2
TERM_STATUS_INDICATOR VARCHAR2
TERM_STATUS VARCHAR2
FINANCIAL_AID_YEAR VARCHAR2
DEGREE_YEAR VARCHAR2
LAST_DAY_OF_FINAL_EXAM DATE
PRE_REGISTRATION_START_DAY DATE
REGISTRATION_DAY DATE
FIRST_DAY_OF_CLASSES DATE
LAST_DAY_OF_CLASSES DATE
ADD_DATE DATE
DROP_DATE DATE
GRADUATE_AWARD_START_DATE DATE
GRADUATE_AWARD_END_DATE DATE
WAREHOUSE_LOAD_DATE DATE)


CREATE TABLE EMPLOYEE_DIRECTORY(
MIT_ID VARCHAR2
LAST_NAME VARCHAR2
FIRST_NAME VARCHAR2
MIDDLE_NAME VARCHAR2
FULL_NAME VARCHAR2
DIRECTORY_FULL_NAME VARCHAR2
OFFICE_LOCATION VARCHAR2
OFFICE_PHONE VARCHAR2
DIRECTORY_TITLE VARCHAR2
PRIMARY_TITLE VARCHAR2
DEPARTMENT_NUMBER VARCHAR2
DEPARTMENT_NAME VARCHAR2
KRB_NAME VARCHAR2
KRB_NAME_UPPERCASE VAR

## We're now ready to make the api call 

In [143]:
gpt_sql_global_schema = []
for q in df_q['question']: 
    response = client.chat.completions.create(
      model="gpt-4",
      messages = [
          {
          "role": "system",
          "content": "Your job is to write SQL queries that answer a user's question using the tables in the MIT Data Warehouse. \
          The MIT Data Warehouse is a central data source that combines data from various administrative systems at MIT, containing \
          information about students, faculty, and personnel. The schemas of the all the tables that you might need \
          for answering the user questions are below. There are 27 tables in the database.  Reply with only the answer \
          in SQL and include no linebreaks, newlines, escape characters or other commentary.\n\n" + prompt 
          # Provide your answer in JSON form."
          # Reply with only the answer in JSON form and include no other commentary"
        },
          {"role":"user", 
           "content":q}]
          )
    gpt_sql_global_schema.append(response.choices[0].message.content)

In [200]:
df_q['gpt_sql_global_schema']=gpt_sql_global_schema

In [202]:
df_q.to_csv('data/gpt_sql_v0.csv', index=False)

In [277]:
df_q = pd.read_csv('data/gpt_sql_v0.csv')

In [223]:
df_q

Unnamed: 0,question,sql,tables_referred,gpt_sql_nocontext,gpt_sql_global_schema
0,How many organisation units in human resources...,SELECT COUNT(*) FROM (SELECT hr_org_unit_key F...,"['HR_ORG_UNIT_NEW', 'HR_ORG_UNIT']",SELECT COUNT(*) FROM HR_ORGANIZATION_UNITS WHE...,SELECT COUNT(*) FROM HR_ORG_UNIT_NEW WHERE HR_...
1,How many organisation units in human resources...,SELECT COUNT(*) FROM (SELECT hr_org_unit_key F...,"['HR_ORG_UNIT_NEW', 'HR_ORG_UNIT']",SELECT COUNT(*) FROM human_resources WHERE new...,SELECT COUNT(*) FROM HR_ORG_UNIT_NEW WHERE HR_...
2,Count the number of courses offered by each de...,"SELECT department_id, COUNT(*) AS course_count...",['WAREUSER.CIS_COURSE_CATALOG'],"""SELECT department_id, COUNT(*) FROM Courses G...","SELECT DEPARTMENT_NAME, COUNT(*) FROM COURSE_C..."
3,"What are the past, upcoming, and current courses?","SELECT DISTINCT subject_title, effective_term_...","['CIS_COURSE_CATALOG', 'ACADEMIC_TERMS']",SELECT * FROM courses WHERE course_start_date ...,"SELECT SUBJECT_ID, SUBJECT_TITLE, TERM_CODE FR..."
4,Which faculty has had the most location changes?,SELECT MAX(cnt) FROM (SELECT COUNT(*) AS cnt F...,['WAREUSER.FCLT_BUILDING_ADDRESS_HIST'],"SELECT faculty_id, COUNT(DISTINCT location_id)...","SELECT EMPLOYEE_DIRECTORY.FULL_NAME, COUNT(FCL..."
5,What is the average number of courses a studen...,SELECT AVG(cnt) FROM (SELECT COUNT(*) AS cnt F...,"['MIT_STUDENT_DIRECTORY', 'SUBJECT_ENROLLABLE']",SELECT AVG(course_count) FROM (SELECT student_...,SELECT AVG(NUM_ENROLLED_STUDENTS) FROM SUBJECT...
6,Which students are employed by mit?,SELECT DISTINCT empl.full_name FROM mit_studen...,"['MIT_STUDENT_DIRECTORY', 'EMPLOYEE_DIRECTORY']","SELECT s.student_id, s.first_name, s.last_name...","SELECT FIRST_NAME, LAST_NAME FROM MIT_STUDENT_..."
7,What is the shortest course available this sem...,"SELECT DISTINCT subject_title, term_end_date -...","['ACADEMIC_TERMS', 'CIS_COURSE_CATALOG']","SELECT MIN(course_length), course_name FROM co...","SELECT SUBJECT_TITLE, MIN(TERM_DURATION) AS MI..."
8,Show buildings and their address.,"SELECT fb.building_name, fba.* FROM wareuser.f...","['WAREUSER.FAC_BUILDING', 'WAREUSER.FAC_BUILDI...","SELECT building_number, building_name, street_...","SELECT BUILDING_NUMBER, BUILDING_NAME, LATITUD..."
9,How many students do we have?,SELECT COUNT(*) FROM mit_student_directory;,['MIT_STUDENT_DIRECTORY'],SELECT COUNT(*) FROM students;,SELECT COUNT(*) FROM MIT_STUDENT_DIRECTORY


In [291]:
tables = []
cnt = 0 
for q in df_q['sql']: 
    print(q)
    cnt = cnt + 1
    try: 
        t = Parser(q).tables
        tables.append([v.upper() for v in t])
    except:
        tables.append(None)
df_q['tables_referred'] = tables


SELECT COUNT(*) FROM (SELECT hr_org_unit_key FROM hr_org_unit_new INTERSECT SELECT hr_org_unit_key FROM hr_org_unit);
SELECT COUNT(*) FROM (SELECT hr_org_unit_key FROM hr_org_unit_new MINUS SELECT hr_org_unit_key FROM hr_org_unit);
SELECT department_id, COUNT(*) AS course_count FROM wareuser.cis_course_catalog GROUP BY department_id;
SELECT DISTINCT subject_title, effective_term_code, CASE WHEN term_start_date > SYSDATE THEN 'Upcoming' WHEN term_start_date <= SYSDATE AND term_end_date > SYSDATE THEN 'Current' ELSE 'Past' END AS term_status FROM cis_course_catalog, academic_terms WHERE effective_term_code = term_code;
SELECT MAX(cnt) FROM (SELECT COUNT(*) AS cnt FROM WAREUSER.FCLT_BUILDING_ADDRESS_HIST GROUP BY fclt_building_key);
SELECT AVG(cnt) FROM (SELECT COUNT(*) AS cnt FROM mit_student_directory, subject_enrollable WHERE offer_dept_code = department GROUP BY full_name);
SELECT DISTINCT empl.full_name FROM mit_student_directory mit, employee_directory empl WHERE empl.full_name = mi

In [325]:
accuracy = []
for t0, t1 in zip(df_q['tables_referred'], df_q['gpt_sql_global_schema_retrieved']):
    a = None 
    k='WAREUSER'
    f = False
    for t in t0:
        # ignore the sql queries referring to the tables with the 'WAREUSER' 
        # prefix, for which we didn't find the schemas in the schema directory 
        if t.find(k) > -1: 
            f = True
            break
    if t0 and t1 and not f: 
        a = len(set(t0).intersection(set(t1)))/len(t0)
    accuracy.append(a)
df_q['gpt_sql_global_schema_accuracy']=accuracy 