# AI-driven Interactive Metadata 

*Amanda Birmingham, Dept. of Pediatrics, UC San Diego*

A natural-language approach to metadata investigation and cleaning using the `ChatGPT 4` LLM, the `Langchain` AI framework, and AI-based speech recognition.

## Initial set-up

To be performed outside notebook:

## Adjustable LLM settings

In [1]:
import os
# Each constant names the ENVIRONMENT VARIABLE that carries a provider's API
# key; it deliberately does not hold the key itself, so the secret never ends
# up in a notebook variable or in an error message that echoes g_chosen_llm.
g_GEMINI = "GOOGLE_API_KEY"
g_CHATGPT = "OPENAI_API_KEY"

# The provider used for the chat model; set to g_GEMINI to switch providers
g_chosen_llm = g_CHATGPT

In [2]:
# When True, the speech recognizer and its "Record" button are created
g_use_speech = True

In [3]:
# Number of most-recent chat messages kept when the history is trimmed
# before each LLM call.
# WARNING: Increasing this number will increase the amount of information included in each 
# LLM query and thus increase the cost of the queries!
# Decreasing this number will make the LLM forget past exchanges more quickly
g_num_msgs_in_history = 4

In [4]:
# Shared tail of every system prompt; it already ends with a period, so the
# prompts built from it must not append another one.
g_base_prompt = "using pandas 3 and python 3.10+ to clean data in jupyter lab."
# System prompt that casts the LLM as a practicing data scientist
g_ds_prompt = f"You are a data scientist {g_base_prompt}"
# System prompt that casts the LLM as a teacher of the same material
g_pf_prompt = f"You are a professor of data science teaching a class on {g_base_prompt}"

## Imports

In [5]:
import os
import pandas as pd
import re

In [6]:
import ipywidgets as widgets
from IPython import get_ipython
from IPython.display import display
from traitlets import observe, link, Unicode, Bool, Any
#from itables import init_notebook_mode, show
from ipylab import JupyterFrontEnd
import time

In [7]:
import speech_recognition as speech_recog

In [8]:
from langchain_core.messages import trim_messages
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

In [9]:
import qiimp 

In [10]:
import warnings
# warnings.filterwarnings('ignore')

In [11]:
import logging
# Detach every handler already attached to the root logger so that the
# basicConfig call below actually takes effect (it is a no-op when the root
# logger already has handlers).
while logging.root.handlers:
    logging.root.removeHandler(logging.root.handlers[0])
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARNING)

## Model creation

In [12]:
# assert g_chosen_llm in os.environ, f"Please set the {g_chosen_llm} environment variable."

In [13]:
# Build the chat model for the chosen provider. Each provider's langchain
# package is imported inside its own branch, so only the selected provider's
# package has to be installed.
if g_chosen_llm == g_CHATGPT:
    from langchain_openai import ChatOpenAI
    _g_chat_model = ChatOpenAI(model="gpt-4o-mini")
elif g_chosen_llm == g_GEMINI:
    from langchain_google_genai import ChatGoogleGenerativeAI
    _g_chat_model = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )
else:
    # Neither known provider constant matched the configured choice
    raise ValueError(f"Unrecognized llm model '{g_chosen_llm}'")

## Chat creation

In [14]:
def _start_chat(model, custom_prompt):
    """Compile a one-node chat graph that remembers a short message history.

    Parameters
    ----------
    model : chat model exposing ``invoke(messages)``
    custom_prompt : str
        Text sent as the system message ahead of every trimmed history.

    Returns
    -------
    The compiled graph, checkpointed with an in-memory saver.
    """
    # Each message is counted as one "token" (token_counter=len), so
    # max_tokens is effectively the number of latest messages that survive.
    history_trimmer = trim_messages(
        strategy="last",
        max_tokens=g_num_msgs_in_history,
        token_counter=len,
    )

    def call_model(state: MessagesState):
        """Graph node: trim the history, prepend the system prompt, query the model."""
        recent_messages = history_trimmer.invoke(state["messages"])
        prompt_messages = [SystemMessage(content=custom_prompt), *recent_messages]
        return {"messages": model.invoke(prompt_messages)}

    # The graph is a single node reached directly from START
    graph = StateGraph(state_schema=MessagesState)
    graph.add_node("model", call_model)
    graph.add_edge(START, "model")

    # Simple in-memory checkpointer keeps the conversation between invocations
    return graph.compile(checkpointer=MemorySaver())

In [15]:
# The notebook-wide chat app, primed with the "data scientist" system prompt
_g_chat = _start_chat(_g_chat_model, g_ds_prompt)

In [16]:
# AB: Moved front end helpers later

## Prompt engineering

In [17]:
# ai helper prompts
g_unique_set_prefix = "for column named"
g_col_check_prefix = "check column named"
g_code_prefix = "write code to"

# AB: modified and extended
# function helper prompts
g_summarize_statement = "summarize table"
g_summarize_col_statement = "summarize column named"
g_explore_col_statement = "explore column named"

# button prompts
g_add_cell_statement = "add cell"
g_copy_last_statement = "copy it"
g_run_last_statement = "now run it"
g_revert_df_statement = "revert dataframe"

In [18]:
def _remove_known_str(a_prompt, known_str):
    return a_prompt.replace(known_str, "").strip()


def _change_first_char(a_str, upper=True):
    if upper:
        a_lambda = lambda x: x.groups()[0].upper()
    else:
        a_lambda = lambda x: x.groups()[0].lower()
        
    # affect ONLY first letter, leave all the rest alone
    # (so different that a_str.capitialize() or a_str.title())
    return re.sub('([a-zA-Z])', a_lambda, a_str, 1)   


def _expand_prompt(user_prompt, df_name, a_df):
    """Turn a user request into the full text that is sent to the LLM.

    The result always starts with a sentence naming the dataframe and listing
    its columns. Requests starting with one of the known prefixes are
    rewritten by the matching helper; anything else is passed through.
    """
    context = f"You are given the '{df_name}' dataframe with columns {list(a_df.columns)}. "

    # A "check column" request is first converted into a unique-set request,
    # which the next test then picks up.
    if user_prompt.startswith(g_col_check_prefix):
        named_col = _remove_known_str(user_prompt, g_col_check_prefix)
        user_prompt = _get_explore_col_prompt(named_col.replace(" ", "_"))

    if user_prompt.startswith(g_unique_set_prefix):
        request = _get_unique_set_prompt(user_prompt, a_df)
    elif user_prompt.startswith(g_code_prefix):
        request = _get_code_prompt(user_prompt)
    else:
        request = user_prompt

    return context + request


def _get_explore_col_prompt(col_name):
    """Build the unique-set style request asking the LLM to assess one column."""
    question = "suggest the appropriate Pandas data type for the values in this column, very briefly hypothesize about what they represent, and indicate if any look invalid or unexpected."
    return f"{g_unique_set_prefix} {col_name}, {question}"


def _get_unique_set_prompt(user_prompt, a_df):
    """Expand "<prefix> <col_name>, <question>" with the column's set of values.

    Returns a usage hint instead when the request does not start with the
    unique-set prefix or has no comma separating column name from question.
    """
    usage_msg = f"Please phrase your question as {g_unique_set_prefix} <col_name>, <question about unique values of column>"

    if not user_prompt.startswith(g_unique_set_prefix):
        return usage_msg

    # Everything before the first comma names the column; the rest (commas
    # included) is the question itself.
    col_part, *question_parts = user_prompt.split(",")
    if not question_parts:
        return usage_msg

    col_name = col_part.replace(g_unique_set_prefix, "").strip()
    question = ",".join(question_parts)
    return f"For the column named '{col_name}' containing the set of values  {set(a_df[col_name])}, {question}"


def _get_code_prompt(user_prompt):
    """Append code-only output instructions to a "write code to ..." request.

    Returns a usage hint instead when the request lacks the code prefix.
    """
    if user_prompt.startswith(g_code_prefix):
        return f"{user_prompt} Do not include any non-comment explanations, import statements, or the instantiation of the dataframe. Do not include markdown formatting in your output. Provide runnable code as output."

    return f"Please phrase your question as {g_code_prefix} <perform some operation>"

## State management

In [19]:
# Saved state, keyed by cell execution number (see _save_state):
# text most recently produced for the user ...
_g_last_code_out = {}
# ... and snapshots of the working dataframe
_g_last_working_df = {}

In [20]:
# Human-readable names of the two kinds of saved state, used in messages
_g_LAST_CODE_NAME = "AI-generated code"
_g_LAST_DF_NAME = "g_working_df"   

# AB: added get-default functions
def _get_df_name(df_name):
    return df_name if df_name is not None else _g_LAST_DF_NAME


# TODO: this could be used to replace a lot of existing checks
def _get_df(a_df):
    return a_df if a_df is not None else g_working_df


def _save_state(state_dict, obj_to_save, state_name, use_last_execution_num=False):
    """Record obj_to_save in state_dict under a cell execution number.

    The object is stored as given (no copy is made here; callers copy first
    when they need to). With use_last_execution_num the entry is filed under
    the previous execution number instead of the current one. A warning is
    issued when an entry for that number is about to be overwritten.

    Returns state_dict.
    """
    offset = 1 if use_last_execution_num else 0
    execution_num = get_ipython().execution_count - offset

    if execution_num in state_dict:
        warnings.warn(f"{state_name} already contains state for a cell with execution number {execution_num}, which will be overwritten.")

    state_dict[execution_num] = obj_to_save
    return state_dict


def _get_last_state(state_dict):
    last_value = None
    if state_dict is not None and len(state_dict)>0:
        last_key, last_value = next(reversed(state_dict.items()))
    return last_value


def _df_changed():
    """Report whether g_working_df differs from the last stored snapshot.

    Returns False when no snapshot has been stored yet.
    """
    snapshot = _get_last_state(_g_last_working_df)
    if snapshot is None:
        return False
    return not snapshot.equals(g_working_df)


def _revert_df():
    """Swap g_working_df back to the last stored snapshot, if it has changed.

    The current (changed) dataframe is stored first, so it becomes the newest
    snapshot. Returns a message describing what happened.
    """
    global g_working_df

    snapshot = _get_last_state(_g_last_working_df)
    if snapshot is None:
        return f"There is no {_g_LAST_DF_NAME} state stored."

    if not _df_changed():
        return f"{_g_LAST_DF_NAME} has not changed since last saved state."

    store_working_df()
    g_working_df = snapshot
    return f"{_g_LAST_DF_NAME} reverted to last saved state."


def store_working_df(a_df=None, use_last_execution_num=False):
    """Save a dataframe snapshot into the working-dataframe state.

    With no a_df, a copy of g_working_df is saved; an explicitly passed a_df
    is saved as-is, without copying.
    """
    snapshot = g_working_df.copy() if a_df is None else a_df
    _save_state(
        _g_last_working_df,
        snapshot,
        _g_LAST_DF_NAME,
        use_last_execution_num=use_last_execution_num,
    )


def revert_df():
    """Public entry point for reverting g_working_df; returns the status message."""
    return _revert_df()


# decorator
def stateful(func):
    """Decorator that snapshots g_working_df before func runs, if it changed.

    When the working dataframe differs from the last stored snapshot, a new
    snapshot is stored under the previous execution number before func is
    called. The wrapped function's return value is passed through to the
    caller, and its name and docstring are preserved.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if _df_changed():
            store_working_df(use_last_execution_num=True)
        # hand back whatever func returns instead of silently dropping it
        return func(*args, **kwargs)
    return wrapper

## Front-end helpers

In [21]:
# ipylab handle used below to send notebook commands to the front end
_g_front_end = JupyterFrontEnd()

In [40]:
def _set_raw():
    """Issue the front-end command that changes the current cell to raw."""
    _g_front_end.commands.execute('notebook:change-cell-to-raw')
    

def _insert_and_populate(statement=None, move_up=True):
    """Insert a cell below and, if a statement is given, fill it with that text.

    statement : text for the new cell; None leaves the cell empty.
    move_up : when True, the new cell is then moved up one position.
    """
    _g_front_end.commands.execute('notebook:insert-cell-below')
    if statement is not None:
        # short pause before writing into the new cell
        # NOTE(review): presumably gives the front end time to finish the insert -- confirm
        time.sleep(0.2)
        _g_front_end.commands.execute('notebook:replace-selection', { 'text': statement})
        _g_front_end.commands.execute('notebook:enter-edit-mode')

    if move_up:
        _g_front_end.commands.execute('notebook:move-cell-up') 
        

def _insert_and_run(statement=None, move_up=True):
    """Insert and populate a cell, then run it when a statement was given.

    With statement=None only the (empty) cell is inserted; nothing is run.
    """
    _insert_and_populate(statement, move_up)
    
    if statement is not None:
        _g_front_end.commands.execute('notebook:run-cell-and-select-next')  
        _g_front_end.commands.execute('notebook:enter-edit-mode')


# AB: new
def _set_query_to_raw(up_twice=False):
    """Convert the query cell to raw, then move the cursor past its output cell.

    up_twice : move the cursor up two cells instead of one before converting.
    """
    # move the cursor up to the query cell (the one running right now)
    # and set *it* to raw
    _g_front_end.commands.execute('notebook:move-cursor-up')
    if up_twice:
        _g_front_end.commands.execute('notebook:move-cursor-up')
    _set_raw()

    # move the cursor down TWICE: once from the query cell to the
    # cell holding its output (either a raw cell or a created code cell), 
    # and once more to whatever comes next
    _g_front_end.commands.execute('notebook:move-cursor-down')
    _g_front_end.commands.execute('notebook:move-cursor-down')    


def _set_query_and_output_to_raw(output):
    """Show non-code output in a new raw cell and convert the query cell to raw."""
    # creates a (code) cell containing the (non-code) output;
    # cursor is now on that new cell
    _insert_and_populate(output)

    # change now-current (new) cell to raw format
    _set_raw()

    # finally convert the query cell itself and step the cursor onward
    _set_query_to_raw()


def _set_query_to_raw_and_run_code_cell(code_str):
    """Record code_str as the last code output, run it in a new cell, and
    convert the query cell to raw."""
    # print(code_str)
    _save_state(_g_last_code_out, code_str, _g_LAST_CODE_NAME)    # This isn't working correctly
    _insert_and_run(code_str, move_up=True)
    # the cursor is moved up two cells here to reach the query cell
    _set_query_to_raw(up_twice=True)  

## Chat helper function creation

In [25]:
# Instantiating an empty dataframe allows the prompt functions to bind to the
# variable so they can use the real contents later without needing to be
# passed an argument
g_working_df = pd.DataFrame()

In [26]:
# AB: reordered
@stateful
def find_problem_headers(a_df=None):
    """Print the column headers containing characters other than letters,
    digits, period, underscore, or space (or a message if there are none).

    a_df : dataframe to inspect; defaults to g_working_df.
    """
    if a_df is None:
        a_df = g_working_df

    has_bad_char = a_df.columns.str.contains(r'[^a-zA-Z0-9._ ]', regex=True)
    invalid_cols = a_df.columns[has_bad_char]
    if len(invalid_cols) > 0:
        print(invalid_cols)
    else:
        print("No invalid column headers found.")


@stateful
def scrub_headers(a_df=None, lcase_headers=True):
    """Normalize the column headers of a dataframe in place and print them.

    Every character that is not a letter or digit becomes an underscore, runs
    of underscores collapse to one, and leading/trailing underscores are
    removed. Headers are also lower-cased unless lcase_headers is False.

    a_df : dataframe to modify; defaults to g_working_df.
    """
    if a_df is None:
        a_df = g_working_df

    scrubbed = (
        a_df.columns
        .str.replace(r'[^a-zA-Z0-9]', '_', regex=True)
        .str.replace(r'__+', '_', regex=True)
        .str.strip('_')
    )
    if lcase_headers:
        scrubbed = scrubbed.str.lower()

    a_df.columns = scrubbed
    print(a_df.columns)


@stateful
def find_problem_records(a_df=None):
    """Find the records with leading or trailing whitespace in any string field.

    Non-string values are ignored, the same rule scrub_problem_records uses,
    so dataframes holding numeric or other non-string values do not raise.

    a_df : dataframe to inspect; defaults to g_working_df.

    Displays and returns the offending records; prints a message and returns
    None when there are none.
    """
    a_df = a_df if a_df is not None else g_working_df

    # True for each cell holding a string that stripping would change
    has_edge_space = a_df.map(lambda x: isinstance(x, str) and x != x.strip())
    problem_records = a_df[has_edge_space.any(axis=1)]

    if len(problem_records) == 0:
        print("No problem records found.")
        return None

    display(problem_records)
    return problem_records


@stateful
def scrub_problem_records():
    """Replace g_working_df with a version whose string fields are stripped
    of leading and trailing whitespace; non-string values are kept as-is."""
    global g_working_df

    def _strip_if_str(value):
        return value.strip() if isinstance(value, str) else value

    g_working_df = g_working_df.map(_strip_if_str)
    print("Problem records scrubbed.")


# AB: modified params list, added line for input coltype, modified wording for df coltype
def _deterministic_summarize_col(col_name, col_type_names, a_df=None, max_items_shown=None):
    """Build the text lines summarizing one column, without using the LLM.

    Parameters
    ----------
    col_name : column to summarize
    col_type_names : mapping of column name to a type label for display
    a_df : dataframe holding the column; defaults to g_working_df
    max_items_shown : cap on listed unique values; defaults to 10

    Returns
    -------
    list of str: name, rule, type label, uniqueness summary, dtype, spacer.
    """
    a_df = _get_df(a_df)
    if max_items_shown is None:
        max_items_shown = 10

    type_line = f"{col_type_names[col_name]} type"

    a_col = a_df[col_name]
    col_uniques = a_col.unique()
    count_uniques = len(col_uniques)

    if a_col.is_unique:
        uniqueness = f"All {count_uniques} values are unique."
    else:
        uniqueness = f"There are {count_uniques} unique value(s) in {len(a_col)} total values."

    # only say "first N" when the listing is actually truncated
    caveat = f"first {max_items_shown} " if count_uniques > max_items_shown else ""
    listing = f"The {caveat}unique value(s):{col_uniques[:max_items_shown]}."

    return [
        f"{col_name}",
        "================",
        type_line,
        f"{uniqueness} {listing}",
        f"The current dataframe datatype is {a_col.dtype}.",
        " ",
    ]

In [56]:
# Scratch check: get a type label for the columns of g_working_df
col_type_names = _ai_get_col_types('g_working_df', g_working_df)
col_type_names

{'LibraryID': 'identifier',
 'Gender': 'categorical',
 'Year of Birth': 'numeric',
 'Ethnicity': 'categorical',
 'Country of Birth': 'categorical',
 'Weight (kg)': 'numeric',
 'Height (cm)': 'numeric',
 'Alcohol': 'categorical',
 'Passive Smoking': 'categorical',
 'Presence of pets': 'categorical',
 'FLG Genotype': 'categorical',
 'Blomia tropicalis (dust mite)': 'categorical',
 'Dermatophagoides pteronyssinus (dust mite)': 'categorical',
 'Elaeis guineensis (oil palm pollen)': 'categorical',
 'Curvularia spp. (fungus)': 'categorical',
 'Skin Prick Test (≥3+)': 'categorical',
 'Asthma Status': 'categorical',
 'AR Status': 'categorical',
 'AD Status': 'categorical',
 'Sample Type': 'categorical',
 'Sampling_Area': 'categorical',
 'Sampling_Method': 'categorical'}

In [57]:
# Scratch check: summary lines for one column, using the type labels above
col_name = 'Ethnicity'
lines = _deterministic_summarize_col(col_name, col_type_names, a_df=None, max_items_shown=None)
lines

['Ethnicity',
 'categorical type',
 "There are 3 unique value(s) in 94 total values. The unique value(s):['Chinese' 'Caucasian' nan].",
 'The current dataframe datatype is object.',
 ' ']

In [27]:
# AB: moved some private _ask functions up up here to allow
# earlier use, added a few
def _ask(user_prompt, df_name=None, a_df=None):
    """Expand user_prompt and send it to the chat app.

    df_name / a_df default to the working dataframe's name and contents.

    Returns a tuple of (expanded prompt text, chat result).
    """
    expanded_prompt = _expand_prompt(user_prompt, _get_df_name(df_name), _get_df(a_df))

    # every request goes to the same conversation thread
    chat_config = {"configurable": {"thread_id": "1"}}
    chat_input = {"messages": [HumanMessage(content=expanded_prompt)]}
    result = _g_chat.invoke(chat_input, config=chat_config)

    return expanded_prompt, result

def _get_result_text(a_result):
    return a_result.get("messages")[-1].content


def _clean_answer(answer_str):
    answer_str = re.sub("^```python\n", "", answer_str)
    answer_str = re.sub("\n```$", "", answer_str)   
    return answer_str
    

def _get_ask_txt(ask_result):
    """Return the cleaned text of the last message in a chat result."""
    return _clean_answer(_get_result_text(ask_result))

In [28]:
# AB: Placeholder for DP's ai-type-guessing function
# TODO: replace

# The three column type labels the LLM is asked to choose between
_id_type = 'identifier'
_cat_type = 'categorical'
_num_type = 'numeric'

# def _ai_get_col_types(df_name, a_df, col_names_list=None):
#     return {"sample_name":"string", "subjectid": "string", "body_site": "string", "age": "numeric", "sex": "string"}

def _ai_get_col_type_one(col_name, a_df=None, df_name=None):
    """Ask the LLM whether one column is categorical, numeric, or an identifier.

    Parameters
    ----------
    col_name : column to classify
    a_df : dataframe holding the column; defaults to g_working_df
    df_name : name of that dataframe; defaults to the working dataframe's name

    Returns
    -------
    'numeric' or 'categorical' when that word appears in the answer
    (checked in that order), otherwise 'identifier'.
    """
    # resolve the defaults up front so a_df=None does not fail on indexing
    df_name = _get_df_name(df_name)
    a_df = _get_df(a_df)

    col = a_df[col_name]
    _, result = _ask(f'based on the following entries: {col}, in a column named {col_name} '
                     f'is this categorical or numeric or identifier. Give a one word answer '
                     f'either categorical or numeric or identifier.',
                     df_name=df_name, a_df=a_df)

    answer = _get_result_text(result).lower()
    for poss_type in (_num_type, _cat_type):
        if poss_type in answer:
            return poss_type

    # anything unrecognized is treated as an identifier
    return _id_type


def _ai_get_col_types(df_name, a_df, col_names_list=None):
    """Ask the LLM for the type label of columns in a dataframe.

    Parameters
    ----------
    df_name : dataframe name; None means the working dataframe's name
    a_df : dataframe; None means g_working_df
    col_names_list : optional collection of column names to classify. When
        given, only those columns are sent to the LLM (one query each), so no
        queries are spent on columns nobody asked about. None classifies
        every column.

    Returns
    -------
    dict mapping column name to type label, in dataframe column order.
    """
    df_name = _get_df_name(df_name)
    a_df = _get_df(a_df)

    if col_names_list is None:
        cols_to_type = list(a_df.columns)
    else:
        wanted = set(col_names_list)
        cols_to_type = [col for col in a_df.columns if col in wanted]

    return {
        col: _ai_get_col_type_one(col, a_df=a_df, df_name=df_name)
        for col in cols_to_type
    }

In [29]:
# AB: Separated from contents of earlier deterministic prompts cell for clarity 
def _get_coltypes_str(col_types, col_names_list=None):
    if col_names_list == None:
        col_names_list = list(col_types.keys())
        
    min_col_types = {k: v for k, v in col_types.items() if k in col_names_list}
    min_coltypes_str = "{" + ", ".join(f'"{key}": "{value}"' for key, value 
                                   in min_col_types.items()) + "}"
    return min_coltypes_str


def _get_ai_coltypes_str(df_name, a_df, col_names_list=None):
    """Get column types from the LLM and render them as dict-literal text."""
    ai_col_types = _ai_get_col_types(df_name, a_df, col_names_list)
    return _get_coltypes_str(ai_col_types, col_names_list)


def _ai_summarize_col(col_name, df_name, a_df, max_items_shown=None):
    """Generate and run a summarize_col call that carries the LLM's type guess.

    The column's type is obtained from the LLM, baked into the text of a
    ``summarize_col(...)`` call, and that call is run in a new code cell.
    """
    one_coltypes_str = _get_ai_coltypes_str(df_name, a_df, col_names_list=[col_name])
    # {col_name!r} produces a valid string literal even when the column name
    # itself contains a quote character
    result = f"summarize_col({col_name!r}, {one_coltypes_str}, df_name='{df_name}', a_df={df_name}, max_items_shown={max_items_shown})"
    _set_query_to_raw_and_run_code_cell(result)
    # return result


def _ai_summarize(df_name, a_df, max_items_shown=None):
    """Generate and run a summarize call that carries the LLM's type guesses."""
    all_coltypes_str = _get_ai_coltypes_str(df_name, a_df)
    code_str = (
        f"summarize(col_types={all_coltypes_str}, df_name='{df_name}', "
        f"a_df={df_name}, max_items_shown={max_items_shown})"
    )
    _set_query_to_raw_and_run_code_cell(code_str)



@stateful
def summarize_col(col_name, col_types=None, df_name=None, a_df=None, max_items_shown=None):
    """Print a summary of one column.

    With col_types given, the summary is built directly and printed. Without
    it, the types are requested from the LLM and a follow-up summarize_col
    call containing them is generated and run.
    """
    df_name = _get_df_name(df_name)
    a_df = _get_df(a_df)

    if col_types is not None:
        summary_lines = _deterministic_summarize_col(col_name, col_types, a_df, max_items_shown)
        print("\n".join(summary_lines))
        return

    _ai_summarize_col(col_name, df_name, a_df, max_items_shown)


@stateful
def summarize(df_name=None, a_df=None, col_types=None, max_items_shown=None):
    """Print a summary of the dataframe and each of its columns.

    With col_types given, the summary is built directly and printed. Without
    it, the types are requested from the LLM and a follow-up summarize call
    containing them is generated and run.
    """
    df_name = _get_df_name(df_name)
    a_df = _get_df(a_df)

    if col_types is None:
        _ai_summarize(df_name, a_df, max_items_shown)
        return

    report_lines = [
        f"The dataframe '{df_name}' has {len(a_df)} rows and {len(a_df.columns)} columns.",
        " ",
    ]
    for each_col in a_df.columns:
        report_lines += _deterministic_summarize_col(each_col, col_types, a_df, max_items_shown)

    print("\n".join(report_lines))

In [63]:
# Scratch check: LLM-typed summary of a single column
a_df = g_working_df
df_name = 'g_working_df'
col_name = 'Gender'
_ai_summarize_col(col_name, df_name, a_df, max_items_shown=10)

# same summary with the column type supplied explicitly
summarize_col('Gender', {"Gender": "categorical"}, df_name='g_working_df', a_df=g_working_df, max_items_shown=10)


In [68]:
# Scratch check: col_types=None takes the LLM-typed path
summarize_col('Gender', None, df_name='g_working_df', a_df=g_working_df, max_items_shown=10)

In [69]:
# Scratch check: a request using the "write code to" prefix
ask('write code to add two numbers')

In [30]:
# AB: modified many of these
def _explore_col(col_name, df_name=None, a_df=None):
    """Send the column-exploration request for col_name to the LLM.

    Returns the (expanded prompt, chat result) tuple produced by _ask.
    """
    return _ask(_get_explore_col_prompt(col_name), df_name, a_df)


@stateful
def explore_col(col_name, df_name=None, a_df=None):
    """Ask the LLM to assess a column and show its answer in a raw cell."""
    _prompt, chat_result = _explore_col(col_name, df_name, a_df)
    _set_query_and_output_to_raw(_get_result_text(chat_result))



def _run_predefined_prompts(user_prompt):
    """Run the notebook function matching a predefined request, if any.

    Returns True when the request was NOT recognized (so the caller should
    fall back to the LLM) and False when it was handled here.
    """
    # The table summary must match the whole request, not just its start;
    # otherwise a request that merely begins with those words would be
    # caught by mistake.
    if user_prompt == g_summarize_statement:
        summarize()
        return False

    if user_prompt.startswith(g_summarize_col_statement):
        summarize_col(_remove_known_str(user_prompt, g_summarize_col_statement))
        return False

    if user_prompt.startswith(g_explore_col_statement):
        explore_col(_remove_known_str(user_prompt, g_explore_col_statement))
        return False

    return True
    

# DP: modified to optionally not track the calls to get column type
@stateful
def ask(user_prompt, df_name=None, a_df=None, show_prompt=False, track=True):
    """Handle a natural-language request from the user.

    Predefined requests are run by notebook functions; anything else is sent
    to the LLM and the answer is shown in a raw cell.

    Parameters
    ----------
    user_prompt : the request; its first letter is lower-cased before matching
    df_name, a_df : dataframe name and contents, defaulting to the working ones
    show_prompt : when True, the expanded prompt is shown above the answer
    track : when True, the output is saved as the last code output
    """
    user_prompt = _change_first_char(user_prompt, upper=False)

    was_unrecognized = _run_predefined_prompts(user_prompt)
    if not was_unrecognized:
        return

    # not a predefined request, so ask the AI
    expanded_prompt, chat_result = _ask(user_prompt, df_name, a_df)
    output = _get_ask_txt(chat_result)

    if show_prompt:
        output = expanded_prompt + "\n\n" + output
    if track:
        _save_state(_g_last_code_out, output, _g_LAST_CODE_NAME)

    _set_query_and_output_to_raw(output)

In [30]:
# Scratch check: summarize the working dataframe with all defaults
summarize()

## Widget creation

In [31]:
# from https://github.com/jupyter-widgets/ipywidgets/issues/2962#issuecomment-724210454
class ConfirmationButton(widgets.HBox):
    """A button that asks for confirmation before firing its click callbacks.

    Clicking the main button replaces it with "Confirm" and "Cancel" buttons;
    clicking either of those restores the main button. Callbacks registered
    through on_click are attached to the "Confirm" button only.
    """
    # Traits declared on the container so they can be linked to the
    # same-named traits of the inner button (see __init__)
    button_style = Any(default_value='')
    description = Unicode()
    disabled = Bool()
    icon = Unicode()
    layout = Any()
    style = Any()
    tooltip = Unicode()
    
    def  __init__(self, **kwargs):
        """Create the main, confirm, and cancel buttons; kwargs go to both
        the container and the main button."""
        super().__init__(**kwargs)
        self._button = widgets.Button(**kwargs)
        self._confirm_btn = widgets.Button(description='Confirm', icon='check', 
                                           button_style='success', layout=dict(width='auto'))
        self._cancel_btn = widgets.Button(description='Cancel', icon='times', 
                                          button_style='warning', layout=dict(width='auto'))
        # all three buttons toggle which buttons are visible
        self._button.on_click(self._on_btn_click)
        self._cancel_btn.on_click(self._on_btn_click)
        self._confirm_btn.on_click(self._on_btn_click)
        # only the main button is shown at first
        self.children = [self._button]
        # link each of the main button's keys not starting with an underscore
        # to the attribute of the same name on this container
        for key in self._button.keys:
            if key[0]!='_':
                link((self._button,key), (self, key))
        
    def on_click(self, *args, **kwargs):
        """Register a click callback on the "Confirm" button."""
        self._confirm_btn.on_click(*args, **kwargs)
        
    def _on_btn_click(self, b):
        """Swap the visible buttons: main -> confirm/cancel, anything else -> main."""
        if b==self._button:
            self.children = [self._confirm_btn, self._cancel_btn]
        else:
            self.children = [self._button]
In [32]:
def _add_cell(a_button=None):
    """Button callback: insert an empty cell and move it up one position."""
    #_g_front_end.commands.execute('notebook:insert-cell-below')
    #_g_front_end.commands.execute('notebook:move-cell-up')  
    # with no statement, _insert_and_run only inserts; nothing is run
    _insert_and_run(None, move_up=True)
    
def _revert_df_from_button(a_button=None, move_up=True):
    """Button callback: revert the working dataframe and show the outcome.

    The status message returned by revert_df is plain English, not code, so
    it is placed in a new cell that is converted to raw rather than executed.
    """
    statement = revert_df()
    # populate only: running a non-code message would raise a syntax error
    _insert_and_populate(statement, move_up=move_up)
    # the cursor is on the new cell, so this converts the message cell
    _set_raw()


def _copy_or_run_suggestion(a_button, move_up, copy_only=True):
    """Put the last saved output into a new cell, optionally running it.

    When nothing has been saved, a message saying so is used instead.
    """
    statement = _get_last_state(_g_last_code_out) or f'There is no {_g_LAST_CODE_NAME} stored.'
    place_in_cell = _insert_and_populate if copy_only else _insert_and_run
    place_in_cell(statement, move_up=move_up)

@stateful
def copy_suggestion(a_button=None, move_up=True):
    """Button callback: copy the last saved output into a new cell (not run),
    then convert the query cell to raw."""
    _copy_or_run_suggestion(a_button, move_up, copy_only=True)
    _set_query_to_raw(up_twice=False)


@stateful
def run_suggestion(a_button=None, move_up=True):
    """Button callback: run the last saved output in a new cell, then convert
    the query cell to raw."""
    _copy_or_run_suggestion(a_button, move_up, copy_only=False)
    # after a run the cursor is moved up two cells to reach the query cell
    _set_query_to_raw(up_twice=True)

In [33]:
def _add_first_button(a_widget, a_button):
    curr_buttons = list(a_widget.children)
    curr_buttons.insert(0, a_button)
    return tuple(curr_buttons)  

In [34]:
# "Add Cell": inserts an empty code cell
_g_add_cell_button = widgets.Button(
    description="Add Cell",
    button_style="primary",  # full blue
    tooltip="Add an empty code cell",
    icon="plus"
)
_g_add_cell_button.on_click(_add_cell)


# "Copy Suggestion": puts the last saved output into a new cell without running it
_g_copy_suggestion_button = widgets.Button(
    description="Copy Suggestion",
    tooltip="Copy AI-generated code to a new cell",
    icon="copy"
)
_g_copy_suggestion_button.style.button_color = 'lightgreen'
_g_copy_suggestion_button.on_click(copy_suggestion)


# "Run Suggestion": puts the last saved output into a new cell and runs it
_g_run_suggestion_button = widgets.Button(
    description="Run Suggestion",
    button_style="info",  # light blue
    tooltip="Run last AI-generated code in a new cell",
    icon="run"
)
_g_run_suggestion_button.on_click(run_suggestion)

# "Revert Df": asks for confirmation before reverting the working dataframe
_g_undo_button = ConfirmationButton(
    description='Revert Df', 
    tooltip="Revert dataframe to last stored state",
    button_style="warning"  # red
)
_g_undo_button.on_click(_revert_df_from_button)


# The button bar shown to the user
g_buttons = widgets.HBox([_g_add_cell_button, _g_copy_suggestion_button, _g_run_suggestion_button, _g_undo_button])

In [35]:
_g_unrecognized_msg = "I didn't catch that."

if g_use_speech:
    _g_speech_recognizer = speech_recog.Recognizer()
    
    # lifted from https://youtu.be/2kSPbH4jWME
    _g_record_button = widgets.Button(
        description="Record",
        disabled=False,
        button_style="success",  # full green
        icon="microphone"
    )
    
    def _record_audio(a_button):
        """Button callback: record a spoken request and act on it.

        The exact button phrases trigger the matching action; any other
        recognized text is run as an ``ask(...)`` call in a new cell. When
        nothing intelligible was heard (or no speech started in time), a
        cell printing the "didn't catch that" message is run instead.
        """
        a_button.description = "Recording"
        try:
            with speech_recog.Microphone() as source:
                # wait up to 10 seconds for speech to start, and record a
                # phrase of at most 10 seconds
                audio = _g_speech_recognizer.listen(
                    source, timeout=10, phrase_time_limit=10)
            txt = _g_speech_recognizer.recognize_google(audio)
        except (speech_recog.UnknownValueError, speech_recog.WaitTimeoutError):
            txt = None
        finally:
            # always restore the label, even if recording or recognition failed
            a_button.description = "Record"

        if txt is None:
            # !r yields a valid string literal despite the apostrophe in the message
            _insert_and_run(f"print({_g_unrecognized_msg!r})")
        elif txt == g_add_cell_statement:
            _add_cell()
        elif txt == g_run_last_statement:
            run_suggestion()
        elif txt == g_revert_df_statement:
            _revert_df()
        else:
            # !r quotes/escapes the spoken text so apostrophes cannot break the call
            statement = f"ask({txt!r})"
            _insert_and_run(statement)


    _g_record_button.on_click(_record_audio)
    g_buttons.children = _add_first_button(g_buttons, _g_record_button)
# end if use speech

## Dataframe helpers

In [36]:
# Names of the separators load_df understands
g_TAB_SEP = "tab"
g_COMMA_SEP = "comma"

#pd.set_option('display.max_rows', None)
#pd.set_option('display.max_columns', None)


def load_df(fp, sep_name=g_TAB_SEP, dtype="str", override=False):
    """Load a delimited file into g_working_df and store a snapshot of it.

    Parameters
    ----------
    fp : str or path-like
        File to read.
    sep_name : str
        g_TAB_SEP or g_COMMA_SEP.
    dtype : passed to pandas.read_csv
    override : bool
        When False, the load is refused (a message is printed and None is
        returned) if the file extension suggests the other separator, or if
        g_working_df already holds rows. True skips both checks.

    Returns
    -------
    The loaded dataframe, or None when the load was refused or the separator
    name is not recognized.
    """
    global g_working_df

    # str() lets callers pass pathlib.Path objects as well as plain strings
    fp_str = str(fp)

    if not override:
        proposed_sep = None
        if fp_str.endswith(".csv") and sep_name == g_TAB_SEP:
            proposed_sep = g_COMMA_SEP
        elif fp_str.endswith((".txt", ".tsv")) and sep_name == g_COMMA_SEP:
            proposed_sep = g_TAB_SEP
    
        if proposed_sep is not None:
            msg = (f"Are you sure this file shouldn't be loaded with a {proposed_sep}?\n"
                   f"If it should be, rerun `load_df` with the `sep_name` parameter "
                   f"set to {proposed_sep}.\n"
                   f"If not, you can run `load_df` with the `override` parameter "
                   f"set to True.")
            print(msg)
            return

    if sep_name == g_TAB_SEP:
        real_sep = "\t"
    elif sep_name == g_COMMA_SEP:
        real_sep = ","
    else:
        msg = (f"'{sep_name}' is an unrecognized separator type.  Please choose one of the "
               f" following recognized separators: {[g_TAB_SEP, g_COMMA_SEP]}.")
        print(msg)
        return 

    # Refuse BEFORE reading, so a refused load costs no file I/O
    if not override and len(g_working_df) > 0:
        msg = ("This load will overwrite the current contents of g_working_df.\n"
               "If you don't want to load these contents, copy g_working_df to another "
               "dataframe variable before running this.\n"
               "If you really don't care, rerun `load_df` with the `override` "
               "parameter set to True.")
        print(msg)
        return
            
    g_working_df = pd.read_csv(fp, sep=real_sep, dtype=dtype)
    store_working_df()
    return g_working_df
    #display_df()
    #display_df()


@stateful
def display_df():
    """Display g_working_df, interactively when possible.

    Dataframes with more than 300 rows or 100 columns get the standard
    (partial) display. Smaller ones are shown through the interactive `show`
    function when one is available in the notebook namespace; this notebook's
    itables import is optional, so the standard display is the fallback.
    """
    # look the function up instead of referencing it directly, so a missing
    # itables import does not raise a NameError
    interactive_show = globals().get("show")

    if g_working_df.shape[0] > 300 or g_working_df.shape[1] > 100:
        print("Dataframe is too large for interactive display; this is a partial visualization.")
        display(g_working_df)
    elif interactive_show is None:
        print("Interactive display is not available; showing the standard view.")
        display(g_working_df)
    else:
        interactive_show(
            g_working_df,
            layout={"top1": "searchPanes"},
            searchPanes={"layout": "columns-3", "cascadePanes": True, "columns": [0]},  # not sure how to use this columns setting
            lengthMenu=[100, 200, 300],
            buttons=[
                "colvis",
                {"extend": "csvHtml5", "title": "metadata"},  # TODO: will want to autogenerate this name
                {"extend": "excelHtml5", "title": "metadata"},
            ],
            fixedColumns={"start": 1},
            scrollX=True,
            scrollY="200px", scrollCollapse=True, paging=False,
        )


# conda install dtale -c conda-forge
#import dtale
#d = dtale.show(g_working_df, host='localhost', hide_drop_rows=True, hide_header_editor=True, allow_cell_edits=False, hide_column_menus=True)
#d

## Interactive investigation

To talk to the AI, either type your request as the argument to the `ask()` function or, if voice control is enabled, click the record button (which will call the `ask` function with your spoken input).

Special statements:
* `check column named <column name>`: asks the AI to draw summarized conclusions about the column and its contents.
    * Example: `ask('check column named time initiate breast')`
* `write code to <description of action>`: asks the AI to limit its responses to code and comments only.
    * Example: `ask('write code to replace within one hour of birth with less one hour')`
* `now run it`: asks the AI to run the last code it wrote

In [37]:
# Input file locations given as absolute paths under /Users/abirmingham/Desktop/trpca.
# NOTE(review): the next cell reassigns all three of these names to relative
# ./proof_of_concept_nb/ paths, so these values only take effect if that cell
# is skipped.
external_metadata_fp = "/Users/abirmingham/Desktop/trpca/15613_PRJNA971252.csv"
qiita_metadata_fp = "/Users/abirmingham/Desktop/trpca/15613_20240711-144128.txt"
study_config_fp = '/Users/abirmingham/Desktop/trpca/trpca_study.yml'

In [37]:
# Input file locations relative to the notebook, under ./proof_of_concept_nb/.
# external_metadata_fp: loaded below via load_df(..., sep_name="comma").
# qiita_metadata_fp: only used in a commented-out load_df(..., sep_name="tab") cell further down.
# study_config_fp: only used in a commented-out
#   qiimp.get_extended_metadata_from_df_and_yaml(...) call further down.
external_metadata_fp = "./proof_of_concept_nb/15612_expanded_sample_info_10112024_PRJNA277905.csv"
qiita_metadata_fp = "./proof_of_concept_nb/15612_20240714-052306.txt"
study_config_fp = './proof_of_concept_nb/trpca_study.yml'

In [38]:
# Load the external metadata file into the working dataframe.
# No `override` argument is passed here.  Per the check inside load_df, when
# `override` is falsy and g_working_df already has rows, load_df prints a
# warning and returns without replacing it; rerun with override=True to force
# the load.  (NOTE(review): the default value of `override` is not visible in
# this part of the notebook -- confirm.)
load_df(external_metadata_fp, sep_name="comma")

Unnamed: 0,LibraryID,Gender,Year of Birth,Ethnicity,Country of Birth,Weight (kg),Height (cm),Alcohol,Passive Smoking,Presence of pets,...,Dermatophagoides pteronyssinus (dust mite),Elaeis guineensis (oil palm pollen),Curvularia spp. (fungus),Skin Prick Test (≥3+),Asthma Status,AR Status,AD Status,Sample Type,Sampling_Area,Sampling_Method
0,WBE005,M,1986,Chinese,Indonesia,Not collected,Not collected,Non-drinker,No,No,...,0,0,0,No,Control,Control,Control,Control,Antecubital fossa,tape
1,WBE006,M,1986,Chinese,Indonesia,Not collected,Not collected,Non-drinker,No,No,...,0,0,0,No,Control,Control,Control,Control,Antecubital fossa,tape
2,WBE007,F,1987,Chinese,Singapore,47,Not collected,Non-drinker,No,No,...,0,0,0,No,Indeterminate,Indeterminate,Control,Control,Antecubital fossa,tape
3,WBE008,F,1987,Chinese,Singapore,47,Not collected,Non-drinker,No,No,...,0,0,0,No,Indeterminate,Indeterminate,Control,Control,Antecubital fossa,tape
4,WBE017,M,1988,Chinese,China,63,178,Occasionally,Yes,No,...,0,0,0,No,Control,Indeterminate,Control,Control,Antecubital fossa,tape
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
89,WOS013,M,1963,Caucasian,Poland,82,189,Occasionally,No,No,...,Indeterminate,Indeterminate,Indeterminate,Indeterminate,Control,Control,Control,Control,(Right) Antecubital Fossa,tape
90,WOS015,M,1963,Caucasian,Poland,82,189,Occasionally,No,No,...,Indeterminate,Indeterminate,Indeterminate,Indeterminate,Control,Control,Control,Control,(Right) Antecubital Fossa,cup scrub
91,WOS018,,,,,,,,,,...,,,,,,,,Gloves,,tape
92,WOS019,,,,,,,,,,...,,,,,,,,Bench Top,,tape


## TEMPORARY: Examples of code that preserves AI results while preventing their (non-deterministic) rerun

In [None]:
# example: summarize table using recognized phrase # DP: not working
ask('summarize table')

In [None]:
# example: summarize table using direct function call
summarize()

In [None]:
# example: summarize a column using recognized phrase
ask("summarize column named LibraryID")

In [None]:
# example: summarize a column using direct function call
summarize_col("Gender")

In [None]:
# example: explore a column using recognized phrase
ask("explore column named Gender")

In [None]:
# example: explore a column using direct function call
explore_col("Gender")

In [None]:
# example: ask arbitrary question
ask("Describe how best to check an id column for uniqueness.")

In [None]:
# example: write arbitrary code
ask("write code to change name of Gender column to gender")

In [None]:
# example: run suggested code
run_suggestion()

In [None]:
# example: write arbitrary code #2
ask("write code to change name of LibraryID column to sample_name")

In [None]:
# example: copy but do not run suggestion
copy_suggestion()

In [None]:
#display(g_buttons)

In [None]:
#print(g_working_df[g_working_df['sampletype_shorthand'].isna()])

In [None]:
#qiimp_cols = qiimp.get_reserved_cols(g_working_df, {})

In [None]:
#qiimp.find_common_col_names(g_working_df.columns, qiimp_cols, [qiimp.HOSTTYPE_SHORTHAND_KEY, qiimp.SAMPLETYPE_SHORTHAND_KEY], [qiimp.HOSTTYPE_SHORTHAND_KEY, qiimp.SAMPLETYPE_SHORTHAND_KEY])

In [None]:
#ask("Write code to rename the 'sex' and 'sample_name' columns with the prefix '_external' at the end.")

In [None]:
#summarize()

In [None]:
#g_working_df['sex'] = g_working_df['sex_external'].apply(qiimp.standardize_input_sex)

In [None]:
#g_working_df['host_life_stage'] = g_working_df['host_age'].apply(qiimp.set_life_stage_from_age_yrs)

In [None]:
# I know these values for this particular study.
# Note it had only healthy participants
#g_working_df["has_skin_disorder"] = False
#g_working_df["lesional_status"] = "not applicable"  
#g_working_df["ethnicity"] = "not provided"

In [None]:
#ext_metadata_df = g_working_df.copy()

In [None]:
#load_df(qiita_metadata_fp, sep_name="tab", override=True)

In [None]:
#summarize()

In [None]:
#qiimp.find_common_df_cols(g_working_df, ext_metadata_df)

In [None]:
#g_working_df.rename(columns={'host_age': 'host_age_external'}, inplace=True)

In [None]:
#qiimp.find_common_df_cols(g_working_df, ext_metadata_df)

In [None]:
#merged_df = qiimp.merge_many_to_one_metadata(g_working_df, ext_metadata_df, "sample_alias", "sample_name_external", "qiita", "external", "outer")
#merged_df

In [None]:
#merged_df[merged_df["sample_name"].isna()]

In [None]:
#merged_df[merged_df["sample_alias"].isna()]

In [None]:
#g_working_df = merged_df.copy()

In [None]:
# extended_df, validation_msgs = qiimp.get_extended_metadata_from_df_and_yaml(g_working_df, study_config_fp)

In [None]:
#extended_df[extended_df['qc_note'].notnull() & (extended_df['qc_note'] != '')]

In [None]:
# validation_msgs

In [None]:
#qiimp.write_metadata_results(extended_df, validation_msgs, '/Users/abirmingham/Desktop/trpca', "15613_merged_metadata_standardized", suppress_empty_fails=True)