Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 60 additions & 49 deletions promptsource/promptsource.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import textwrap

import streamlit as st
from pygments.formatters import HtmlFormatter
from session import _get_state
from utils import get_dataset, get_dataset_confs, list_datasets, removeHyphen, renameDatasetColumn, render_features

Expand Down Expand Up @@ -30,6 +33,11 @@ def reset_template_state():
st.set_page_config(layout="wide")
st.sidebar.title("PromptSource 🌸")

#
# Adds pygments styles to the page.
#
st.markdown("<style>" + HtmlFormatter().get_style_defs(".highlight") + "</style>", unsafe_allow_html=True)

#
# Loads template data
#
Expand Down Expand Up @@ -65,18 +73,19 @@ def reset_template_state():
priority_max_templates,
state,
)
counts = template_collection.get_templates_count()


#
# Select a dataset
# Select a dataset - starts with ag_news
#
dataset_key = st.sidebar.selectbox(
"Dataset",
dataset_list,
key="dataset_select",
index=12, # AG_NEWS
help="Select the dataset to work on.",
)
st.sidebar.write("HINT: Try ag_news or boolq for examples.")

# On dataset change, clear working priority dataset
# retained in the priority list with more than priority_max_templates
Expand Down Expand Up @@ -152,7 +161,9 @@ def reset_template_state():

st.sidebar.write(example)

col1, _, col2 = st.beta_columns([18, 1, 6])
st.markdown("## Template Creator")

col1a, col1b, _, col2 = st.beta_columns([9, 9, 1, 6])

# current_templates_key and state.templates_key are keys for the templates object
current_templates_key = (dataset_key, conf_option.name if conf_option else None)
Expand All @@ -162,64 +173,65 @@ def reset_template_state():
state.templates_key = current_templates_key
reset_template_state()

with col1:
with st.beta_expander("Select Template", expanded=True):
with st.form("new_template_form"):
new_template_name = st.text_input(
"New Template Name",
key="new_template",
value="",
help="Enter name and hit enter to create a new template.",
with col1a, st.form("new_template_form"):
new_template_name = st.text_input(
"Create a New Template",
key="new_template",
value="",
help="Enter name and hit enter to create a new template.",
)
new_template_submitted = st.form_submit_button("Create")
if new_template_submitted:
if new_template_name in dataset_templates.all_template_names:
st.error(
f"A template with the name {new_template_name} already exists "
f"for dataset {state.templates_key}."
)
new_template_submitted = st.form_submit_button("Create")
if new_template_submitted:
if new_template_name in dataset_templates.all_template_names:
st.error(
f"A template with the name {new_template_name} already exists "
f"for dataset {state.templates_key}."
)
elif new_template_name == "":
st.error("Need to provide a template name.")
else:
template = Template(new_template_name, "", "")
dataset_templates.add_template(template)
reset_template_state()
state.template_name = new_template_name
# Keep the current working dataset in priority list
if priority_filter:
state.working_priority_ds = dataset_key
else:
state.new_template_name = None

dataset_templates = template_collection.get_dataset(*state.templates_key)
template_list = dataset_templates.all_template_names
if state.template_name:
index = template_list.index(state.template_name)
elif new_template_name == "":
st.error("Need to provide a template name.")
else:
index = 0
state.template_name = st.selectbox(
"", template_list, key="template_select", index=index, help="Select the template to work on."
)

if st.button("Delete Template", key="delete_template"):
dataset_templates.remove_template(state.template_name)
template = Template(new_template_name, "", "")
dataset_templates.add_template(template)
reset_template_state()
state.template_name = new_template_name
# Keep the current working dataset in priority list
if priority_filter:
state.working_priority_ds = dataset_key
else:
state.new_template_name = None

with col1b, st.beta_expander("or Select Template", expanded=True):
dataset_templates = template_collection.get_dataset(*state.templates_key)
template_list = dataset_templates.all_template_names
if state.template_name:
index = template_list.index(state.template_name)
else:
index = 0
state.template_name = st.selectbox(
"", template_list, key="template_select", index=index, help="Select the template to work on."
)

if st.button("Delete Template", key="delete_template"):
dataset_templates.remove_template(state.template_name)
reset_template_state()

col1, _, col2 = st.beta_columns([18, 1, 6])
with col1:
if state.template_name is not None:
template = dataset_templates[state.template_name]
#
# If template is selected, displays template editor
#
with st.form("edit_template_form"):
updated_template_name = st.text_area("Name", height=40, value=template.name)
state.jinja = st.text_area("Template", height=40, value=template.jinja)

state.reference = st.text_area(
updated_template_name = st.text_input("Name", value=template.name)
state.reference = st.text_input(
"Template Reference",
help="Short description of the template and/or paper reference for the template.",
value=template.reference,
)

state.jinja = st.text_area("Template", height=40, value=template.jinja)

if st.form_submit_button("Save"):
if (
updated_template_name in dataset_templates.all_template_names
Expand All @@ -244,14 +256,13 @@ def reset_template_state():
with col2:
if state.template_name is not None:
st.empty()
st.subheader("Template Output")
template = dataset_templates[state.template_name]
prompt = template.apply(example)
st.write("Prompt + X")
st.text(prompt[0])
st.text(textwrap.fill(prompt[0], width=40))
if len(prompt) > 1:
st.write("Y")
st.text(prompt[1])
st.text(textwrap.fill(prompt[1], width=40))

#
# Must sync state at end
Expand Down
Loading